From 46eecd110a4017ea0c86cbb1010d0ccd6a5eb2ef Mon Sep 17 00:00:00 2001
From: Matei Zaharia
Date: Sat, 31 Aug 2013 19:27:07 -0700
Subject: Initial work to rename package to org.apache.spark

---
README.md | 2 +- assembly/pom.xml | 16 +- assembly/src/main/assembly/assembly.xml | 10 +- bagel/pom.xml | 6 +- .../main/scala/org/apache/spark/bagel/Bagel.scala | 293 +++++ bagel/src/main/scala/spark/bagel/Bagel.scala | 294 ----- bagel/src/test/scala/bagel/BagelSuite.scala | 118 -- .../scala/org/apache/spark/bagel/BagelSuite.scala | 116 ++ bin/start-master.sh | 2 +- bin/start-slave.sh | 2 +- bin/stop-master.sh | 2 +- bin/stop-slaves.sh | 4 +- core/pom.xml | 4 +- .../org/apache/spark/network/netty/FileClient.java | 89 ++ .../netty/FileClientChannelInitializer.java | 41 + .../spark/network/netty/FileClientHandler.java | 60 + .../org/apache/spark/network/netty/FileServer.java | 103 ++ .../netty/FileServerChannelInitializer.java | 42 + .../spark/network/netty/FileServerHandler.java | 82 ++ .../apache/spark/network/netty/PathResolver.java | 29 + .../main/java/spark/network/netty/FileClient.java | 89 -- .../netty/FileClientChannelInitializer.java | 41 - .../spark/network/netty/FileClientHandler.java | 60 - .../main/java/spark/network/netty/FileServer.java | 103 -- .../netty/FileServerChannelInitializer.java | 42 - .../spark/network/netty/FileServerHandler.java | 82 -- .../java/spark/network/netty/PathResolver.java | 29 - .../org/apache/spark/ui/static/bootstrap.min.css | 874 +++++++++++++ .../org/apache/spark/ui/static/sorttable.js | 495 ++++++++ .../spark/ui/static/spark-logo-77x50px-hd.png | Bin 0 -> 3536 bytes .../org/apache/spark/ui/static/spark_logo.png | Bin 0 -> 14233 bytes .../resources/org/apache/spark/ui/static/webui.css | 63 + .../resources/spark/ui/static/bootstrap.min.css | 874 ------------- .../main/resources/spark/ui/static/sorttable.js | 495 -------- .../spark/ui/static/spark-logo-77x50px-hd.png | Bin 3536 -> 0 bytes .../main/resources/spark/ui/static/spark_logo.png | Bin 14233 -> 0 bytes core/src/main/resources/spark/ui/static/webui.css | 63 - .../main/scala/org/apache/spark/Accumulators.scala | 256 ++++ .../main/scala/org/apache/spark/Aggregator.scala | 61 + .../apache/spark/BlockStoreShuffleFetcher.scala | 89 ++ .../main/scala/org/apache/spark/CacheManager.scala | 82 ++ .../scala/org/apache/spark/ClosureCleaner.scala | 231 ++++ .../main/scala/org/apache/spark/Dependency.scala | 81 ++ .../org/apache/spark/DoubleRDDFunctions.scala | 78 ++ .../org/apache/spark/FetchFailedException.scala | 44 + .../scala/org/apache/spark/HttpFileServer.scala | 62 + .../main/scala/org/apache/spark/HttpServer.scala | 88 ++ .../scala/org/apache/spark/JavaSerializer.scala | 83 ++ .../scala/org/apache/spark/KryoSerializer.scala | 156 +++ core/src/main/scala/org/apache/spark/Logging.scala | 95 ++ .../scala/org/apache/spark/MapOutputTracker.scala | 338 +++++ .../scala/org/apache/spark/PairRDDFunctions.scala | 703 +++++++++++ .../main/scala/org/apache/spark/Partition.scala | 31 + .../main/scala/org/apache/spark/Partitioner.scala | 135 ++ core/src/main/scala/org/apache/spark/RDD.scala | 957 ++++++++++++++ .../scala/org/apache/spark/RDDCheckpointData.scala | 130 ++ .../apache/spark/SequenceFileRDDFunctions.scala | 107 ++ .../org/apache/spark/SerializableWritable.scala | 42 + .../scala/org/apache/spark/ShuffleFetcher.scala | 35 + .../scala/org/apache/spark/SizeEstimator.scala | 283 +++++ .../main/scala/org/apache/spark/SparkContext.scala | 995 +++++++++++++++ .../src/main/scala/org/apache/spark/SparkEnv.scala | 240 
++++ .../scala/org/apache/spark/SparkException.scala | 24 + .../main/scala/org/apache/spark/SparkFiles.java | 42 + .../scala/org/apache/spark/SparkHadoopWriter.scala | 201 +++ .../main/scala/org/apache/spark/TaskContext.scala | 41 + .../scala/org/apache/spark/TaskEndReason.scala | 51 + .../main/scala/org/apache/spark/TaskState.scala | 51 + core/src/main/scala/org/apache/spark/Utils.scala | 780 ++++++++++++ .../org/apache/spark/api/java/JavaDoubleRDD.scala | 167 +++ .../org/apache/spark/api/java/JavaPairRDD.scala | 601 +++++++++ .../scala/org/apache/spark/api/java/JavaRDD.scala | 114 ++ .../org/apache/spark/api/java/JavaRDDLike.scala | 426 +++++++ .../apache/spark/api/java/JavaSparkContext.scala | 418 +++++++ .../java/JavaSparkContextVarargsWorkaround.java | 64 + .../org/apache/spark/api/java/JavaUtils.scala | 28 + .../org/apache/spark/api/java/StorageLevels.java | 48 + .../api/java/function/DoubleFlatMapFunction.java | 37 + .../spark/api/java/function/DoubleFunction.java | 34 + .../spark/api/java/function/FlatMapFunction.scala | 28 + .../spark/api/java/function/FlatMapFunction2.scala | 28 + .../apache/spark/api/java/function/Function.java | 39 + .../apache/spark/api/java/function/Function2.java | 38 + .../api/java/function/PairFlatMapFunction.java | 46 + .../spark/api/java/function/PairFunction.java | 45 + .../spark/api/java/function/VoidFunction.scala | 33 + .../spark/api/java/function/WrappedFunction1.scala | 32 + .../spark/api/java/function/WrappedFunction2.scala | 32 + .../spark/api/python/PythonPartitioner.scala | 50 + .../org/apache/spark/api/python/PythonRDD.scala | 344 ++++++ .../spark/api/python/PythonWorkerFactory.scala | 132 ++ .../spark/broadcast/BitTorrentBroadcast.scala | 1057 ++++++++++++++++ .../org/apache/spark/broadcast/Broadcast.scala | 70 ++ .../apache/spark/broadcast/BroadcastFactory.scala | 30 + .../org/apache/spark/broadcast/HttpBroadcast.scala | 171 +++ .../org/apache/spark/broadcast/MultiTracker.scala | 409 ++++++ .../org/apache/spark/broadcast/SourceInfo.scala | 54 + .../org/apache/spark/broadcast/TreeBroadcast.scala | 602 +++++++++ .../spark/deploy/ApplicationDescription.scala | 32 + .../scala/org/apache/spark/deploy/Command.scala | 26 + .../org/apache/spark/deploy/DeployMessage.scala | 130 ++ .../org/apache/spark/deploy/ExecutorState.scala | 28 + .../org/apache/spark/deploy/JsonProtocol.scala | 86 ++ .../apache/spark/deploy/LocalSparkCluster.scala | 69 ++ .../org/apache/spark/deploy/SparkHadoopUtil.scala | 36 + .../main/scala/org/apache/spark/deploy/WebUI.scala | 47 + .../org/apache/spark/deploy/client/Client.scala | 145 +++ .../spark/deploy/client/ClientListener.scala | 35 + .../apache/spark/deploy/client/TestClient.scala | 51 + .../apache/spark/deploy/client/TestExecutor.scala | 27 + .../spark/deploy/master/ApplicationInfo.scala | 85 ++ .../spark/deploy/master/ApplicationSource.scala | 24 + .../spark/deploy/master/ApplicationState.scala | 28 + .../apache/spark/deploy/master/ExecutorInfo.scala | 32 + .../org/apache/spark/deploy/master/Master.scala | 386 ++++++ .../spark/deploy/master/MasterArguments.scala | 89 ++ .../apache/spark/deploy/master/MasterSource.scala | 25 + .../apache/spark/deploy/master/WorkerInfo.scala | 77 ++ .../apache/spark/deploy/master/WorkerState.scala | 24 + .../spark/deploy/master/ui/ApplicationPage.scala | 118 ++ .../apache/spark/deploy/master/ui/IndexPage.scala | 141 +++ .../spark/deploy/master/ui/MasterWebUI.scala | 80 ++ .../spark/deploy/worker/ExecutorRunner.scala | 199 +++ .../org/apache/spark/deploy/worker/Worker.scala | 213 ++++ 
.../spark/deploy/worker/WorkerArguments.scala | 153 +++ .../apache/spark/deploy/worker/WorkerSource.scala | 34 + .../apache/spark/deploy/worker/ui/IndexPage.scala | 115 ++ .../spark/deploy/worker/ui/WorkerWebUI.scala | 190 +++ .../scala/org/apache/spark/executor/Executor.scala | 269 ++++ .../apache/spark/executor/ExecutorBackend.scala | 28 + .../apache/spark/executor/ExecutorExitCode.scala | 60 + .../org/apache/spark/executor/ExecutorSource.scala | 55 + .../spark/executor/ExecutorURLClassLoader.scala | 31 + .../spark/executor/MesosExecutorBackend.scala | 95 ++ .../spark/executor/StandaloneExecutorBackend.scala | 107 ++ .../org/apache/spark/executor/TaskMetrics.scala | 105 ++ .../org/apache/spark/io/CompressionCodec.scala | 82 ++ .../org/apache/spark/metrics/MetricsConfig.scala | 100 ++ .../org/apache/spark/metrics/MetricsSystem.scala | 163 +++ .../apache/spark/metrics/sink/ConsoleSink.scala | 59 + .../org/apache/spark/metrics/sink/CsvSink.scala | 68 + .../org/apache/spark/metrics/sink/JmxSink.scala | 35 + .../apache/spark/metrics/sink/MetricsServlet.scala | 55 + .../scala/org/apache/spark/metrics/sink/Sink.scala | 23 + .../apache/spark/metrics/source/JvmSource.scala | 32 + .../org/apache/spark/metrics/source/Source.scala | 25 + .../org/apache/spark/network/BufferMessage.scala | 111 ++ .../org/apache/spark/network/Connection.scala | 586 +++++++++ .../apache/spark/network/ConnectionManager.scala | 720 +++++++++++ .../apache/spark/network/ConnectionManagerId.scala | 38 + .../spark/network/ConnectionManagerTest.scala | 102 ++ .../scala/org/apache/spark/network/Message.scala | 93 ++ .../org/apache/spark/network/MessageChunk.scala | 42 + .../apache/spark/network/MessageChunkHeader.scala | 75 ++ .../org/apache/spark/network/ReceiverTest.scala | 37 + .../org/apache/spark/network/SenderTest.scala | 70 ++ .../apache/spark/network/netty/FileHeader.scala | 74 ++ .../apache/spark/network/netty/ShuffleCopier.scala | 118 ++ .../apache/spark/network/netty/ShuffleSender.scala | 70 ++ core/src/main/scala/org/apache/spark/package.scala | 32 + .../spark/partial/ApproximateActionListener.scala | 87 ++ .../spark/partial/ApproximateEvaluator.scala | 27 + .../org/apache/spark/partial/BoundedDouble.scala | 25 + .../org/apache/spark/partial/CountEvaluator.scala | 55 + .../spark/partial/GroupedCountEvaluator.scala | 79 ++ .../spark/partial/GroupedMeanEvaluator.scala | 82 ++ .../apache/spark/partial/GroupedSumEvaluator.scala | 89 ++ .../org/apache/spark/partial/MeanEvaluator.scala | 58 + .../org/apache/spark/partial/PartialResult.scala | 137 ++ .../org/apache/spark/partial/StudentTCacher.scala | 43 + .../org/apache/spark/partial/SumEvaluator.scala | 68 + .../main/scala/org/apache/spark/rdd/BlockRDD.scala | 51 + .../scala/org/apache/spark/rdd/CartesianRDD.scala | 90 ++ .../scala/org/apache/spark/rdd/CheckpointRDD.scala | 155 +++ .../scala/org/apache/spark/rdd/CoGroupedRDD.scala | 144 +++ .../scala/org/apache/spark/rdd/CoalescedRDD.scala | 342 +++++ .../main/scala/org/apache/spark/rdd/EmptyRDD.scala | 33 + .../scala/org/apache/spark/rdd/FilteredRDD.scala | 33 + .../scala/org/apache/spark/rdd/FlatMappedRDD.scala | 33 + .../org/apache/spark/rdd/FlatMappedValuesRDD.scala | 36 + .../scala/org/apache/spark/rdd/GlommedRDD.scala | 29 + .../scala/org/apache/spark/rdd/HadoopRDD.scala | 137 ++ .../main/scala/org/apache/spark/rdd/JdbcRDD.scala | 120 ++ .../org/apache/spark/rdd/MapPartitionsRDD.scala | 37 + .../spark/rdd/MapPartitionsWithIndexRDD.scala | 41 + .../scala/org/apache/spark/rdd/MappedRDD.scala | 30 + 
.../org/apache/spark/rdd/MappedValuesRDD.scala | 34 + .../scala/org/apache/spark/rdd/NewHadoopRDD.scala | 126 ++ .../org/apache/spark/rdd/OrderedRDDFunctions.scala | 51 + .../apache/spark/rdd/ParallelCollectionRDD.scala | 151 +++ .../org/apache/spark/rdd/PartitionPruningRDD.scala | 72 ++ .../main/scala/org/apache/spark/rdd/PipedRDD.scala | 125 ++ .../scala/org/apache/spark/rdd/SampledRDD.scala | 66 + .../scala/org/apache/spark/rdd/ShuffledRDD.scala | 67 + .../scala/org/apache/spark/rdd/SubtractedRDD.scala | 129 ++ .../main/scala/org/apache/spark/rdd/UnionRDD.scala | 73 ++ .../org/apache/spark/rdd/ZippedPartitionsRDD.scala | 143 +++ .../scala/org/apache/spark/rdd/ZippedRDD.scala | 85 ++ .../org/apache/spark/scheduler/ActiveJob.scala | 39 + .../org/apache/spark/scheduler/DAGScheduler.scala | 849 +++++++++++++ .../apache/spark/scheduler/DAGSchedulerEvent.scala | 63 + .../spark/scheduler/DAGSchedulerSource.scala | 30 + .../apache/spark/scheduler/InputFormatInfo.scala | 178 +++ .../org/apache/spark/scheduler/JobListener.scala | 28 + .../org/apache/spark/scheduler/JobLogger.scala | 292 +++++ .../org/apache/spark/scheduler/JobResult.scala | 26 + .../org/apache/spark/scheduler/JobWaiter.scala | 66 + .../org/apache/spark/scheduler/MapStatus.scala | 44 + .../org/apache/spark/scheduler/ResultTask.scala | 134 ++ .../apache/spark/scheduler/ShuffleMapTask.scala | 189 +++ .../org/apache/spark/scheduler/SparkListener.scala | 204 +++ .../apache/spark/scheduler/SparkListenerBus.scala | 74 ++ .../org/apache/spark/scheduler/SplitInfo.scala | 78 ++ .../scala/org/apache/spark/scheduler/Stage.scala | 112 ++ .../org/apache/spark/scheduler/StageInfo.scala | 29 + .../scala/org/apache/spark/scheduler/Task.scala | 115 ++ .../org/apache/spark/scheduler/TaskLocation.scala | 34 + .../org/apache/spark/scheduler/TaskResult.scala | 72 ++ .../org/apache/spark/scheduler/TaskScheduler.scala | 52 + .../spark/scheduler/TaskSchedulerListener.scala | 45 + .../scala/org/apache/spark/scheduler/TaskSet.scala | 35 + .../spark/scheduler/cluster/ClusterScheduler.scala | 440 +++++++ .../scheduler/cluster/ClusterTaskSetManager.scala | 712 +++++++++++ .../scheduler/cluster/ExecutorLossReason.scala | 38 + .../org/apache/spark/scheduler/cluster/Pool.scala | 121 ++ .../spark/scheduler/cluster/Schedulable.scala | 48 + .../scheduler/cluster/SchedulableBuilder.scala | 137 ++ .../spark/scheduler/cluster/SchedulerBackend.scala | 37 + .../scheduler/cluster/SchedulingAlgorithm.scala | 81 ++ .../spark/scheduler/cluster/SchedulingMode.scala | 29 + .../cluster/SparkDeploySchedulerBackend.scala | 91 ++ .../cluster/StandaloneClusterMessage.scala | 63 + .../cluster/StandaloneSchedulerBackend.scala | 198 +++ .../spark/scheduler/cluster/TaskDescription.scala | 37 + .../apache/spark/scheduler/cluster/TaskInfo.scala | 72 ++ .../spark/scheduler/cluster/TaskLocality.scala | 32 + .../spark/scheduler/cluster/TaskSetManager.scala | 51 + .../spark/scheduler/cluster/WorkerOffer.scala | 24 + .../spark/scheduler/local/LocalScheduler.scala | 272 ++++ .../scheduler/local/LocalTaskSetManager.scala | 194 +++ .../mesos/CoarseMesosSchedulerBackend.scala | 286 +++++ .../scheduler/mesos/MesosSchedulerBackend.scala | 342 +++++ .../org/apache/spark/serializer/Serializer.scala | 112 ++ .../spark/serializer/SerializerManager.scala | 62 + .../org/apache/spark/storage/BlockException.scala | 22 + .../apache/spark/storage/BlockFetchTracker.scala | 27 + .../spark/storage/BlockFetcherIterator.scala | 348 ++++++ .../org/apache/spark/storage/BlockManager.scala | 1046 ++++++++++++++++ 
.../org/apache/spark/storage/BlockManagerId.scala | 118 ++ .../apache/spark/storage/BlockManagerMaster.scala | 178 +++ .../spark/storage/BlockManagerMasterActor.scala | 404 ++++++ .../spark/storage/BlockManagerMessages.scala | 110 ++ .../spark/storage/BlockManagerSlaveActor.scala | 39 + .../apache/spark/storage/BlockManagerSource.scala | 48 + .../apache/spark/storage/BlockManagerWorker.scala | 139 +++ .../org/apache/spark/storage/BlockMessage.scala | 223 ++++ .../apache/spark/storage/BlockMessageArray.scala | 159 +++ .../apache/spark/storage/BlockObjectWriter.scala | 65 + .../org/apache/spark/storage/BlockStore.scala | 61 + .../scala/org/apache/spark/storage/DiskStore.scala | 329 +++++ .../org/apache/spark/storage/MemoryStore.scala | 257 ++++ .../scala/org/apache/spark/storage/PutResult.scala | 26 + .../apache/spark/storage/ShuffleBlockManager.scala | 67 + .../org/apache/spark/storage/StorageLevel.scala | 146 +++ .../org/apache/spark/storage/StorageUtils.scala | 115 ++ .../org/apache/spark/storage/ThreadingTest.scala | 113 ++ .../scala/org/apache/spark/ui/JettyUtils.scala | 132 ++ core/src/main/scala/org/apache/spark/ui/Page.scala | 22 + .../main/scala/org/apache/spark/ui/SparkUI.scala | 87 ++ .../main/scala/org/apache/spark/ui/UIUtils.scala | 131 ++ .../org/apache/spark/ui/UIWorkloadGenerator.scala | 105 ++ .../org/apache/spark/ui/env/EnvironmentUI.scala | 91 ++ .../org/apache/spark/ui/exec/ExecutorsUI.scala | 136 ++ .../scala/org/apache/spark/ui/jobs/IndexPage.scala | 90 ++ .../apache/spark/ui/jobs/JobProgressListener.scala | 156 +++ .../org/apache/spark/ui/jobs/JobProgressUI.scala | 60 + .../scala/org/apache/spark/ui/jobs/PoolPage.scala | 32 + .../scala/org/apache/spark/ui/jobs/PoolTable.scala | 55 + .../scala/org/apache/spark/ui/jobs/StagePage.scala | 183 +++ .../org/apache/spark/ui/jobs/StageTable.scala | 107 ++ .../apache/spark/ui/storage/BlockManagerUI.scala | 41 + .../org/apache/spark/ui/storage/IndexPage.scala | 65 + .../org/apache/spark/ui/storage/RDDPage.scala | 132 ++ .../scala/org/apache/spark/util/AkkaUtils.scala | 72 ++ .../apache/spark/util/BoundedPriorityQueue.scala | 62 + .../apache/spark/util/ByteBufferInputStream.scala | 80 ++ .../main/scala/org/apache/spark/util/Clock.scala | 29 + .../org/apache/spark/util/CompletionIterator.scala | 42 + .../scala/org/apache/spark/util/Distribution.scala | 82 ++ .../scala/org/apache/spark/util/IdGenerator.scala | 31 + .../scala/org/apache/spark/util/IntParam.scala | 31 + .../scala/org/apache/spark/util/MemoryParam.scala | 34 + .../org/apache/spark/util/MetadataCleaner.scala | 61 + .../scala/org/apache/spark/util/MutablePair.scala | 36 + .../scala/org/apache/spark/util/NextIterator.scala | 88 ++ .../spark/util/RateLimitedOutputStream.scala | 79 ++ .../org/apache/spark/util/SerializableBuffer.scala | 54 + .../scala/org/apache/spark/util/StatCounter.scala | 131 ++ .../org/apache/spark/util/TimeStampedHashMap.scala | 122 ++ .../org/apache/spark/util/TimeStampedHashSet.scala | 86 ++ .../main/scala/org/apache/spark/util/Vector.scala | 139 +++ core/src/main/scala/spark/Accumulators.scala | 256 ---- core/src/main/scala/spark/Aggregator.scala | 61 - .../scala/spark/BlockStoreShuffleFetcher.scala | 89 -- core/src/main/scala/spark/CacheManager.scala | 82 -- core/src/main/scala/spark/ClosureCleaner.scala | 231 ---- core/src/main/scala/spark/Dependency.scala | 81 -- core/src/main/scala/spark/DoubleRDDFunctions.scala | 78 -- .../main/scala/spark/FetchFailedException.scala | 44 - core/src/main/scala/spark/HttpFileServer.scala | 62 - 
core/src/main/scala/spark/HttpServer.scala | 88 -- core/src/main/scala/spark/JavaSerializer.scala | 83 -- core/src/main/scala/spark/KryoSerializer.scala | 156 --- core/src/main/scala/spark/Logging.scala | 95 -- core/src/main/scala/spark/MapOutputTracker.scala | 338 ----- core/src/main/scala/spark/PairRDDFunctions.scala | 703 ----------- core/src/main/scala/spark/Partition.scala | 31 - core/src/main/scala/spark/Partitioner.scala | 135 -- core/src/main/scala/spark/RDD.scala | 957 -------------- core/src/main/scala/spark/RDDCheckpointData.scala | 130 -- .../scala/spark/SequenceFileRDDFunctions.scala | 107 -- .../main/scala/spark/SerializableWritable.scala | 42 - core/src/main/scala/spark/ShuffleFetcher.scala | 35 - core/src/main/scala/spark/SizeEstimator.scala | 283 ----- core/src/main/scala/spark/SparkContext.scala | 995 --------------- core/src/main/scala/spark/SparkEnv.scala | 241 ---- core/src/main/scala/spark/SparkException.scala | 24 - core/src/main/scala/spark/SparkFiles.java | 42 - core/src/main/scala/spark/SparkHadoopWriter.scala | 201 --- core/src/main/scala/spark/TaskContext.scala | 41 - core/src/main/scala/spark/TaskEndReason.scala | 51 - core/src/main/scala/spark/TaskState.scala | 51 - core/src/main/scala/spark/Utils.scala | 780 ------------ .../main/scala/spark/api/java/JavaDoubleRDD.scala | 167 --- .../main/scala/spark/api/java/JavaPairRDD.scala | 601 --------- core/src/main/scala/spark/api/java/JavaRDD.scala | 114 -- .../main/scala/spark/api/java/JavaRDDLike.scala | 426 ------- .../scala/spark/api/java/JavaSparkContext.scala | 418 ------- .../java/JavaSparkContextVarargsWorkaround.java | 64 - core/src/main/scala/spark/api/java/JavaUtils.scala | 28 - .../main/scala/spark/api/java/StorageLevels.java | 48 - .../api/java/function/DoubleFlatMapFunction.java | 37 - .../spark/api/java/function/DoubleFunction.java | 34 - .../spark/api/java/function/FlatMapFunction.scala | 28 - .../spark/api/java/function/FlatMapFunction2.scala | 28 - .../scala/spark/api/java/function/Function.java | 39 - .../scala/spark/api/java/function/Function2.java | 38 - .../api/java/function/PairFlatMapFunction.java | 46 - .../spark/api/java/function/PairFunction.java | 45 - .../spark/api/java/function/VoidFunction.scala | 33 - .../spark/api/java/function/WrappedFunction1.scala | 32 - .../spark/api/java/function/WrappedFunction2.scala | 32 - .../scala/spark/api/python/PythonPartitioner.scala | 50 - .../main/scala/spark/api/python/PythonRDD.scala | 344 ------ .../spark/api/python/PythonWorkerFactory.scala | 132 -- .../spark/broadcast/BitTorrentBroadcast.scala | 1057 ---------------- .../src/main/scala/spark/broadcast/Broadcast.scala | 70 -- .../scala/spark/broadcast/BroadcastFactory.scala | 30 - .../main/scala/spark/broadcast/HttpBroadcast.scala | 171 --- .../main/scala/spark/broadcast/MultiTracker.scala | 409 ------ .../main/scala/spark/broadcast/SourceInfo.scala | 54 - .../main/scala/spark/broadcast/TreeBroadcast.scala | 602 --------- .../spark/deploy/ApplicationDescription.scala | 32 - core/src/main/scala/spark/deploy/Command.scala | 26 - .../main/scala/spark/deploy/DeployMessage.scala | 130 -- .../main/scala/spark/deploy/ExecutorState.scala | 28 - .../src/main/scala/spark/deploy/JsonProtocol.scala | 86 -- .../scala/spark/deploy/LocalSparkCluster.scala | 69 -- .../main/scala/spark/deploy/SparkHadoopUtil.scala | 36 - core/src/main/scala/spark/deploy/WebUI.scala | 47 - .../main/scala/spark/deploy/client/Client.scala | 145 --- .../scala/spark/deploy/client/ClientListener.scala | 35 - 
.../scala/spark/deploy/client/TestClient.scala | 51 - .../scala/spark/deploy/client/TestExecutor.scala | 27 - .../spark/deploy/master/ApplicationInfo.scala | 85 -- .../spark/deploy/master/ApplicationSource.scala | 24 - .../spark/deploy/master/ApplicationState.scala | 28 - .../scala/spark/deploy/master/ExecutorInfo.scala | 32 - .../main/scala/spark/deploy/master/Master.scala | 386 ------ .../spark/deploy/master/MasterArguments.scala | 89 -- .../scala/spark/deploy/master/MasterSource.scala | 25 - .../scala/spark/deploy/master/WorkerInfo.scala | 77 -- .../scala/spark/deploy/master/WorkerState.scala | 24 - .../spark/deploy/master/ui/ApplicationPage.scala | 118 -- .../scala/spark/deploy/master/ui/IndexPage.scala | 141 --- .../scala/spark/deploy/master/ui/MasterWebUI.scala | 80 -- .../scala/spark/deploy/worker/ExecutorRunner.scala | 199 --- .../main/scala/spark/deploy/worker/Worker.scala | 213 ---- .../spark/deploy/worker/WorkerArguments.scala | 153 --- .../scala/spark/deploy/worker/WorkerSource.scala | 34 - .../scala/spark/deploy/worker/ui/IndexPage.scala | 115 -- .../scala/spark/deploy/worker/ui/WorkerWebUI.scala | 190 --- core/src/main/scala/spark/executor/Executor.scala | 269 ---- .../scala/spark/executor/ExecutorBackend.scala | 28 - .../scala/spark/executor/ExecutorExitCode.scala | 60 - .../main/scala/spark/executor/ExecutorSource.scala | 55 - .../spark/executor/ExecutorURLClassLoader.scala | 31 - .../spark/executor/MesosExecutorBackend.scala | 95 -- .../spark/executor/StandaloneExecutorBackend.scala | 107 -- .../main/scala/spark/executor/TaskMetrics.scala | 105 -- .../src/main/scala/spark/io/CompressionCodec.scala | 82 -- .../main/scala/spark/metrics/MetricsConfig.scala | 100 -- .../main/scala/spark/metrics/MetricsSystem.scala | 163 --- .../scala/spark/metrics/sink/ConsoleSink.scala | 59 - .../main/scala/spark/metrics/sink/CsvSink.scala | 68 - .../main/scala/spark/metrics/sink/JmxSink.scala | 35 - .../scala/spark/metrics/sink/MetricsServlet.scala | 55 - core/src/main/scala/spark/metrics/sink/Sink.scala | 23 - .../scala/spark/metrics/source/JvmSource.scala | 32 - .../main/scala/spark/metrics/source/Source.scala | 25 - .../main/scala/spark/network/BufferMessage.scala | 111 -- core/src/main/scala/spark/network/Connection.scala | 586 --------- .../scala/spark/network/ConnectionManager.scala | 720 ----------- .../scala/spark/network/ConnectionManagerId.scala | 38 - .../spark/network/ConnectionManagerTest.scala | 102 -- core/src/main/scala/spark/network/Message.scala | 93 -- .../main/scala/spark/network/MessageChunk.scala | 42 - .../scala/spark/network/MessageChunkHeader.scala | 75 -- .../main/scala/spark/network/ReceiverTest.scala | 37 - core/src/main/scala/spark/network/SenderTest.scala | 70 -- .../scala/spark/network/netty/FileHeader.scala | 74 -- .../scala/spark/network/netty/ShuffleCopier.scala | 118 -- .../scala/spark/network/netty/ShuffleSender.scala | 70 -- core/src/main/scala/spark/package.scala | 32 - .../spark/partial/ApproximateActionListener.scala | 87 -- .../scala/spark/partial/ApproximateEvaluator.scala | 27 - .../main/scala/spark/partial/BoundedDouble.scala | 25 - .../main/scala/spark/partial/CountEvaluator.scala | 55 - .../spark/partial/GroupedCountEvaluator.scala | 79 -- .../scala/spark/partial/GroupedMeanEvaluator.scala | 82 -- .../scala/spark/partial/GroupedSumEvaluator.scala | 89 -- .../main/scala/spark/partial/MeanEvaluator.scala | 58 - .../main/scala/spark/partial/PartialResult.scala | 137 -- .../main/scala/spark/partial/StudentTCacher.scala | 43 - 
.../main/scala/spark/partial/SumEvaluator.scala | 68 - core/src/main/scala/spark/rdd/BlockRDD.scala | 51 - core/src/main/scala/spark/rdd/CartesianRDD.scala | 90 -- core/src/main/scala/spark/rdd/CheckpointRDD.scala | 155 --- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 144 --- core/src/main/scala/spark/rdd/CoalescedRDD.scala | 342 ----- core/src/main/scala/spark/rdd/EmptyRDD.scala | 33 - core/src/main/scala/spark/rdd/FilteredRDD.scala | 33 - core/src/main/scala/spark/rdd/FlatMappedRDD.scala | 33 - .../main/scala/spark/rdd/FlatMappedValuesRDD.scala | 36 - core/src/main/scala/spark/rdd/GlommedRDD.scala | 29 - core/src/main/scala/spark/rdd/HadoopRDD.scala | 137 -- core/src/main/scala/spark/rdd/JdbcRDD.scala | 120 -- .../main/scala/spark/rdd/MapPartitionsRDD.scala | 37 - .../spark/rdd/MapPartitionsWithIndexRDD.scala | 41 - core/src/main/scala/spark/rdd/MappedRDD.scala | 30 - .../src/main/scala/spark/rdd/MappedValuesRDD.scala | 34 - core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 126 -- .../main/scala/spark/rdd/OrderedRDDFunctions.scala | 51 - .../scala/spark/rdd/ParallelCollectionRDD.scala | 151 --- .../main/scala/spark/rdd/PartitionPruningRDD.scala | 72 -- core/src/main/scala/spark/rdd/PipedRDD.scala | 125 -- core/src/main/scala/spark/rdd/SampledRDD.scala | 66 - core/src/main/scala/spark/rdd/ShuffledRDD.scala | 67 - core/src/main/scala/spark/rdd/SubtractedRDD.scala | 129 -- core/src/main/scala/spark/rdd/UnionRDD.scala | 73 -- .../main/scala/spark/rdd/ZippedPartitionsRDD.scala | 143 --- core/src/main/scala/spark/rdd/ZippedRDD.scala | 85 -- .../src/main/scala/spark/scheduler/ActiveJob.scala | 39 - .../main/scala/spark/scheduler/DAGScheduler.scala | 849 ------------- .../scala/spark/scheduler/DAGSchedulerEvent.scala | 63 - .../scala/spark/scheduler/DAGSchedulerSource.scala | 30 - .../scala/spark/scheduler/InputFormatInfo.scala | 178 --- .../main/scala/spark/scheduler/JobListener.scala | 28 - .../src/main/scala/spark/scheduler/JobLogger.scala | 292 ----- .../src/main/scala/spark/scheduler/JobResult.scala | 26 - .../src/main/scala/spark/scheduler/JobWaiter.scala | 66 - .../src/main/scala/spark/scheduler/MapStatus.scala | 44 - .../main/scala/spark/scheduler/ResultTask.scala | 134 -- .../scala/spark/scheduler/ShuffleMapTask.scala | 189 --- .../main/scala/spark/scheduler/SparkListener.scala | 204 --- .../scala/spark/scheduler/SparkListenerBus.scala | 74 -- .../src/main/scala/spark/scheduler/SplitInfo.scala | 78 -- core/src/main/scala/spark/scheduler/Stage.scala | 112 -- .../src/main/scala/spark/scheduler/StageInfo.scala | 29 - core/src/main/scala/spark/scheduler/Task.scala | 115 -- .../main/scala/spark/scheduler/TaskLocation.scala | 34 - .../main/scala/spark/scheduler/TaskResult.scala | 72 -- .../main/scala/spark/scheduler/TaskScheduler.scala | 52 - .../spark/scheduler/TaskSchedulerListener.scala | 45 - core/src/main/scala/spark/scheduler/TaskSet.scala | 35 - .../spark/scheduler/cluster/ClusterScheduler.scala | 440 ------- .../scheduler/cluster/ClusterTaskSetManager.scala | 712 ----------- .../scheduler/cluster/ExecutorLossReason.scala | 38 - .../main/scala/spark/scheduler/cluster/Pool.scala | 121 -- .../spark/scheduler/cluster/Schedulable.scala | 48 - .../scheduler/cluster/SchedulableBuilder.scala | 137 -- .../spark/scheduler/cluster/SchedulerBackend.scala | 37 - .../scheduler/cluster/SchedulingAlgorithm.scala | 81 -- .../spark/scheduler/cluster/SchedulingMode.scala | 29 - .../cluster/SparkDeploySchedulerBackend.scala | 90 -- .../cluster/StandaloneClusterMessage.scala | 63 - 
.../cluster/StandaloneSchedulerBackend.scala | 198 --- .../spark/scheduler/cluster/TaskDescription.scala | 37 - .../scala/spark/scheduler/cluster/TaskInfo.scala | 72 -- .../spark/scheduler/cluster/TaskLocality.scala | 32 - .../spark/scheduler/cluster/TaskSetManager.scala | 51 - .../spark/scheduler/cluster/WorkerOffer.scala | 24 - .../spark/scheduler/local/LocalScheduler.scala | 272 ---- .../scheduler/local/LocalTaskSetManager.scala | 194 --- .../mesos/CoarseMesosSchedulerBackend.scala | 284 ----- .../scheduler/mesos/MesosSchedulerBackend.scala | 342 ----- .../main/scala/spark/serializer/Serializer.scala | 112 -- .../scala/spark/serializer/SerializerManager.scala | 62 - .../main/scala/spark/storage/BlockException.scala | 22 - .../scala/spark/storage/BlockFetchTracker.scala | 27 - .../scala/spark/storage/BlockFetcherIterator.scala | 348 ------ .../main/scala/spark/storage/BlockManager.scala | 1046 ---------------- .../main/scala/spark/storage/BlockManagerId.scala | 118 -- .../scala/spark/storage/BlockManagerMaster.scala | 178 --- .../spark/storage/BlockManagerMasterActor.scala | 404 ------ .../scala/spark/storage/BlockManagerMessages.scala | 110 -- .../spark/storage/BlockManagerSlaveActor.scala | 39 - .../scala/spark/storage/BlockManagerSource.scala | 48 - .../scala/spark/storage/BlockManagerWorker.scala | 139 --- .../main/scala/spark/storage/BlockMessage.scala | 223 ---- .../scala/spark/storage/BlockMessageArray.scala | 159 --- .../scala/spark/storage/BlockObjectWriter.scala | 65 - core/src/main/scala/spark/storage/BlockStore.scala | 61 - core/src/main/scala/spark/storage/DiskStore.scala | 329 ----- .../src/main/scala/spark/storage/MemoryStore.scala | 257 ---- core/src/main/scala/spark/storage/PutResult.scala | 26 - .../scala/spark/storage/ShuffleBlockManager.scala | 67 - .../main/scala/spark/storage/StorageLevel.scala | 146 --- .../main/scala/spark/storage/StorageUtils.scala | 115 -- .../main/scala/spark/storage/ThreadingTest.scala | 113 -- core/src/main/scala/spark/ui/JettyUtils.scala | 132 -- core/src/main/scala/spark/ui/Page.scala | 22 - core/src/main/scala/spark/ui/SparkUI.scala | 87 -- core/src/main/scala/spark/ui/UIUtils.scala | 131 -- .../main/scala/spark/ui/UIWorkloadGenerator.scala | 105 -- .../main/scala/spark/ui/env/EnvironmentUI.scala | 91 -- .../src/main/scala/spark/ui/exec/ExecutorsUI.scala | 136 -- core/src/main/scala/spark/ui/jobs/IndexPage.scala | 90 -- .../scala/spark/ui/jobs/JobProgressListener.scala | 156 --- .../main/scala/spark/ui/jobs/JobProgressUI.scala | 60 - core/src/main/scala/spark/ui/jobs/PoolPage.scala | 32 - core/src/main/scala/spark/ui/jobs/PoolTable.scala | 55 - core/src/main/scala/spark/ui/jobs/StagePage.scala | 183 --- core/src/main/scala/spark/ui/jobs/StageTable.scala | 107 -- .../scala/spark/ui/storage/BlockManagerUI.scala | 41 - .../main/scala/spark/ui/storage/IndexPage.scala | 65 - core/src/main/scala/spark/ui/storage/RDDPage.scala | 132 -- core/src/main/scala/spark/util/AkkaUtils.scala | 72 -- .../scala/spark/util/BoundedPriorityQueue.scala | 62 - .../scala/spark/util/ByteBufferInputStream.scala | 80 -- core/src/main/scala/spark/util/Clock.scala | 29 - .../main/scala/spark/util/CompletionIterator.scala | 42 - core/src/main/scala/spark/util/Distribution.scala | 82 -- core/src/main/scala/spark/util/IdGenerator.scala | 31 - core/src/main/scala/spark/util/IntParam.scala | 31 - core/src/main/scala/spark/util/MemoryParam.scala | 34 - .../main/scala/spark/util/MetadataCleaner.scala | 61 - core/src/main/scala/spark/util/MutablePair.scala | 36 - 
core/src/main/scala/spark/util/NextIterator.scala | 88 -- .../scala/spark/util/RateLimitedOutputStream.scala | 79 -- .../main/scala/spark/util/SerializableBuffer.scala | 54 - core/src/main/scala/spark/util/StatCounter.scala | 131 -- .../main/scala/spark/util/TimeStampedHashMap.scala | 121 -- .../main/scala/spark/util/TimeStampedHashSet.scala | 86 -- core/src/main/scala/spark/util/Vector.scala | 139 --- .../test/resources/test_metrics_config.properties | 2 +- .../test/resources/test_metrics_system.properties | 6 +- .../scala/org/apache/spark/AccumulatorSuite.scala | 143 +++ .../scala/org/apache/spark/BroadcastSuite.scala | 39 + .../scala/org/apache/spark/CheckpointSuite.scala | 392 ++++++ .../org/apache/spark/ClosureCleanerSuite.scala | 146 +++ .../scala/org/apache/spark/DistributedSuite.scala | 362 ++++++ .../test/scala/org/apache/spark/DriverSuite.scala | 54 + .../test/scala/org/apache/spark/FailureSuite.scala | 127 ++ .../scala/org/apache/spark/FileServerSuite.scala | 123 ++ .../test/scala/org/apache/spark/FileSuite.scala | 212 ++++ .../test/scala/org/apache/spark/JavaAPISuite.java | 865 +++++++++++++ .../org/apache/spark/KryoSerializerSuite.scala | 208 ++++ .../scala/org/apache/spark/LocalSparkContext.scala | 68 + .../org/apache/spark/MapOutputTrackerSuite.scala | 136 ++ .../org/apache/spark/PairRDDFunctionsSuite.scala | 299 +++++ .../apache/spark/PartitionPruningRDDSuite.scala | 28 + .../scala/org/apache/spark/PartitioningSuite.scala | 150 +++ .../scala/org/apache/spark/PipedRDDSuite.scala | 93 ++ .../src/test/scala/org/apache/spark/RDDSuite.scala | 389 ++++++ .../org/apache/spark/SharedSparkContext.scala | 42 + .../scala/org/apache/spark/ShuffleNettySuite.scala | 34 + .../test/scala/org/apache/spark/ShuffleSuite.scala | 210 ++++ .../org/apache/spark/SizeEstimatorSuite.scala | 164 +++ .../test/scala/org/apache/spark/SortingSuite.scala | 123 ++ .../org/apache/spark/SparkContextInfoSuite.scala | 60 + .../scala/org/apache/spark/ThreadingSuite.scala | 152 +++ .../scala/org/apache/spark/UnpersistSuite.scala | 47 + .../test/scala/org/apache/spark/UtilsSuite.scala | 139 +++ .../org/apache/spark/ZippedPartitionsSuite.scala | 50 + .../apache/spark/io/CompressionCodecSuite.scala | 62 + .../apache/spark/metrics/MetricsConfigSuite.scala | 89 ++ .../apache/spark/metrics/MetricsSystemSuite.scala | 54 + .../scala/org/apache/spark/rdd/JdbcRDDSuite.scala | 73 ++ .../spark/rdd/ParallelCollectionSplitSuite.scala | 212 ++++ .../apache/spark/scheduler/DAGSchedulerSuite.scala | 421 +++++++ .../apache/spark/scheduler/JobLoggerSuite.scala | 121 ++ .../spark/scheduler/SparkListenerSuite.scala | 102 ++ .../apache/spark/scheduler/TaskContextSuite.scala | 49 + .../scheduler/cluster/ClusterSchedulerSuite.scala | 266 ++++ .../cluster/ClusterTaskSetManagerSuite.scala | 273 ++++ .../apache/spark/scheduler/cluster/FakeTask.scala | 26 + .../scheduler/local/LocalSchedulerSuite.scala | 223 ++++ .../apache/spark/storage/BlockManagerSuite.scala | 666 ++++++++++ .../test/scala/org/apache/spark/ui/UISuite.scala | 47 + .../org/apache/spark/util/DistributionSuite.scala | 42 + .../scala/org/apache/spark/util/FakeClock.scala | 26 + .../org/apache/spark/util/NextIteratorSuite.scala | 85 ++ .../spark/util/RateLimitedOutputStreamSuite.scala | 40 + core/src/test/scala/spark/AccumulatorSuite.scala | 143 --- core/src/test/scala/spark/BroadcastSuite.scala | 39 - core/src/test/scala/spark/CheckpointSuite.scala | 392 ------ .../src/test/scala/spark/ClosureCleanerSuite.scala | 146 --- core/src/test/scala/spark/DistributedSuite.scala | 362 
------ core/src/test/scala/spark/DriverSuite.scala | 54 - core/src/test/scala/spark/FailureSuite.scala | 127 -- core/src/test/scala/spark/FileServerSuite.scala | 123 -- core/src/test/scala/spark/FileSuite.scala | 212 ---- core/src/test/scala/spark/JavaAPISuite.java | 865 ------------- .../src/test/scala/spark/KryoSerializerSuite.scala | 208 ---- core/src/test/scala/spark/LocalSparkContext.scala | 68 - .../test/scala/spark/MapOutputTrackerSuite.scala | 136 -- .../test/scala/spark/PairRDDFunctionsSuite.scala | 299 ----- .../scala/spark/PartitionPruningRDDSuite.scala | 28 - core/src/test/scala/spark/PartitioningSuite.scala | 150 --- core/src/test/scala/spark/PipedRDDSuite.scala | 93 -- core/src/test/scala/spark/RDDSuite.scala | 389 ------ core/src/test/scala/spark/SharedSparkContext.scala | 42 - core/src/test/scala/spark/ShuffleNettySuite.scala | 34 - core/src/test/scala/spark/ShuffleSuite.scala | 210 ---- core/src/test/scala/spark/SizeEstimatorSuite.scala | 164 --- core/src/test/scala/spark/SortingSuite.scala | 123 -- .../test/scala/spark/SparkContextInfoSuite.scala | 60 - core/src/test/scala/spark/ThreadingSuite.scala | 152 --- core/src/test/scala/spark/UnpersistSuite.scala | 47 - core/src/test/scala/spark/UtilsSuite.scala | 139 --- .../test/scala/spark/ZippedPartitionsSuite.scala | 50 - .../scala/spark/io/CompressionCodecSuite.scala | 62 - .../scala/spark/metrics/MetricsConfigSuite.scala | 89 -- .../scala/spark/metrics/MetricsSystemSuite.scala | 53 - core/src/test/scala/spark/rdd/JdbcRDDSuite.scala | 73 -- .../spark/rdd/ParallelCollectionSplitSuite.scala | 212 ---- .../scala/spark/scheduler/DAGSchedulerSuite.scala | 421 ------- .../scala/spark/scheduler/JobLoggerSuite.scala | 121 -- .../scala/spark/scheduler/SparkListenerSuite.scala | 102 -- .../scala/spark/scheduler/TaskContextSuite.scala | 49 - .../scheduler/cluster/ClusterSchedulerSuite.scala | 266 ---- .../cluster/ClusterTaskSetManagerSuite.scala | 273 ---- .../scala/spark/scheduler/cluster/FakeTask.scala | 26 - .../scheduler/local/LocalSchedulerSuite.scala | 223 ---- .../scala/spark/storage/BlockManagerSuite.scala | 665 ---------- core/src/test/scala/spark/ui/UISuite.scala | 47 - .../test/scala/spark/util/DistributionSuite.scala | 42 - core/src/test/scala/spark/util/FakeClock.scala | 26 - .../test/scala/spark/util/NextIteratorSuite.scala | 85 -- .../spark/util/RateLimitedOutputStreamSuite.scala | 40 - docs/_layouts/global.html | 2 +- examples/pom.xml | 14 +- .../java/org/apache/spark/examples/JavaHdfsLR.java | 140 +++ .../java/org/apache/spark/examples/JavaKMeans.java | 131 ++ .../org/apache/spark/examples/JavaLogQuery.java | 131 ++ .../org/apache/spark/examples/JavaPageRank.java | 115 ++ .../org/apache/spark/examples/JavaSparkPi.java | 65 + .../java/org/apache/spark/examples/JavaTC.java | 97 ++ .../org/apache/spark/examples/JavaWordCount.java | 66 + .../org/apache/spark/mllib/examples/JavaALS.java | 87 ++ .../apache/spark/mllib/examples/JavaKMeans.java | 81 ++ .../org/apache/spark/mllib/examples/JavaLR.java | 85 ++ .../streaming/examples/JavaFlumeEventCount.java | 68 + .../streaming/examples/JavaNetworkWordCount.java | 79 ++ .../spark/streaming/examples/JavaQueueStream.java | 80 ++ .../src/main/java/spark/examples/JavaHdfsLR.java | 140 --- .../src/main/java/spark/examples/JavaKMeans.java | 131 -- .../src/main/java/spark/examples/JavaLogQuery.java | 131 -- .../src/main/java/spark/examples/JavaPageRank.java | 115 -- .../src/main/java/spark/examples/JavaSparkPi.java | 65 - examples/src/main/java/spark/examples/JavaTC.java | 97 -- 
.../main/java/spark/examples/JavaWordCount.java | 66 - .../main/java/spark/mllib/examples/JavaALS.java | 87 -- .../main/java/spark/mllib/examples/JavaKMeans.java | 81 -- .../src/main/java/spark/mllib/examples/JavaLR.java | 85 -- .../streaming/examples/JavaFlumeEventCount.java | 68 - .../streaming/examples/JavaNetworkWordCount.java | 79 -- .../spark/streaming/examples/JavaQueueStream.java | 80 -- .../org/apache/spark/examples/BroadcastTest.scala | 50 + .../org/apache/spark/examples/CassandraTest.scala | 213 ++++ .../spark/examples/ExceptionHandlingTest.scala | 38 + .../org/apache/spark/examples/GroupByTest.scala | 57 + .../org/apache/spark/examples/HBaseTest.scala | 52 + .../scala/org/apache/spark/examples/HdfsTest.scala | 37 + .../scala/org/apache/spark/examples/LocalALS.scala | 140 +++ .../org/apache/spark/examples/LocalFileLR.scala | 55 + .../org/apache/spark/examples/LocalKMeans.scala | 99 ++ .../scala/org/apache/spark/examples/LocalLR.scala | 63 + .../scala/org/apache/spark/examples/LocalPi.scala | 34 + .../scala/org/apache/spark/examples/LogQuery.scala | 85 ++ .../apache/spark/examples/MultiBroadcastTest.scala | 53 + .../spark/examples/SimpleSkewedGroupByTest.scala | 71 ++ .../apache/spark/examples/SkewedGroupByTest.scala | 61 + .../scala/org/apache/spark/examples/SparkALS.scala | 143 +++ .../org/apache/spark/examples/SparkHdfsLR.scala | 78 ++ .../org/apache/spark/examples/SparkKMeans.scala | 91 ++ .../scala/org/apache/spark/examples/SparkLR.scala | 71 ++ .../org/apache/spark/examples/SparkPageRank.scala | 46 + .../scala/org/apache/spark/examples/SparkPi.scala | 43 + .../scala/org/apache/spark/examples/SparkTC.scala | 75 ++ .../spark/examples/bagel/PageRankUtils.scala | 123 ++ .../spark/examples/bagel/WikipediaPageRank.scala | 101 ++ .../bagel/WikipediaPageRankStandalone.scala | 223 ++++ .../spark/streaming/examples/ActorWordCount.scala | 175 +++ .../spark/streaming/examples/FlumeEventCount.scala | 61 + .../spark/streaming/examples/HdfsWordCount.scala | 54 + .../spark/streaming/examples/KafkaWordCount.scala | 98 ++ .../streaming/examples/NetworkWordCount.scala | 54 + .../spark/streaming/examples/QueueStream.scala | 57 + .../spark/streaming/examples/RawNetworkGrep.scala | 64 + .../examples/StatefulNetworkWordCount.scala | 67 + .../streaming/examples/TwitterAlgebirdCMS.scala | 110 ++ .../streaming/examples/TwitterAlgebirdHLL.scala | 88 ++ .../streaming/examples/TwitterPopularTags.scala | 70 ++ .../spark/streaming/examples/ZeroMQWordCount.scala | 91 ++ .../examples/clickstream/PageViewGenerator.scala | 102 ++ .../examples/clickstream/PageViewStream.scala | 101 ++ .../main/scala/spark/examples/BroadcastTest.scala | 50 - .../main/scala/spark/examples/CassandraTest.scala | 213 ---- .../spark/examples/ExceptionHandlingTest.scala | 38 - .../main/scala/spark/examples/GroupByTest.scala | 57 - .../src/main/scala/spark/examples/HBaseTest.scala | 52 - .../src/main/scala/spark/examples/HdfsTest.scala | 37 - .../src/main/scala/spark/examples/LocalALS.scala | 140 --- .../main/scala/spark/examples/LocalFileLR.scala | 55 - .../main/scala/spark/examples/LocalKMeans.scala | 99 -- .../src/main/scala/spark/examples/LocalLR.scala | 63 - .../src/main/scala/spark/examples/LocalPi.scala | 34 - .../src/main/scala/spark/examples/LogQuery.scala | 85 -- .../scala/spark/examples/MultiBroadcastTest.scala | 53 - .../spark/examples/SimpleSkewedGroupByTest.scala | 71 -- .../scala/spark/examples/SkewedGroupByTest.scala | 61 - .../src/main/scala/spark/examples/SparkALS.scala | 143 --- 
.../main/scala/spark/examples/SparkHdfsLR.scala | 78 -- .../main/scala/spark/examples/SparkKMeans.scala | 91 -- .../src/main/scala/spark/examples/SparkLR.scala | 71 -- .../main/scala/spark/examples/SparkPageRank.scala | 46 - .../src/main/scala/spark/examples/SparkPi.scala | 43 - .../src/main/scala/spark/examples/SparkTC.scala | 75 -- .../scala/spark/examples/bagel/PageRankUtils.scala | 123 -- .../spark/examples/bagel/WikipediaPageRank.scala | 101 -- .../bagel/WikipediaPageRankStandalone.scala | 223 ---- .../spark/streaming/examples/ActorWordCount.scala | 175 --- .../spark/streaming/examples/FlumeEventCount.scala | 61 - .../spark/streaming/examples/HdfsWordCount.scala | 54 - .../spark/streaming/examples/KafkaWordCount.scala | 98 -- .../streaming/examples/NetworkWordCount.scala | 54 - .../spark/streaming/examples/QueueStream.scala | 57 - .../spark/streaming/examples/RawNetworkGrep.scala | 64 - .../examples/StatefulNetworkWordCount.scala | 67 - .../streaming/examples/TwitterAlgebirdCMS.scala | 110 -- .../streaming/examples/TwitterAlgebirdHLL.scala | 88 -- .../streaming/examples/TwitterPopularTags.scala | 70 -- .../spark/streaming/examples/ZeroMQWordCount.scala | 91 -- .../examples/clickstream/PageViewGenerator.scala | 102 -- .../examples/clickstream/PageViewStream.scala | 101 -- mllib/pom.xml | 6 +- .../mllib/classification/ClassificationModel.scala | 21 + .../mllib/classification/LogisticRegression.scala | 188 +++ .../apache/spark/mllib/classification/SVM.scala | 187 +++ .../org/apache/spark/mllib/clustering/KMeans.scala | 335 +++++ .../spark/mllib/clustering/KMeansModel.scala | 44 + .../spark/mllib/clustering/LocalKMeans.scala | 105 ++ .../apache/spark/mllib/optimization/Gradient.scala | 98 ++ .../spark/mllib/optimization/GradientDescent.scala | 166 +++ .../spark/mllib/optimization/Optimizer.scala | 29 + .../apache/spark/mllib/optimization/Updater.scala | 99 ++ .../apache/spark/mllib/recommendation/ALS.scala | 453 +++++++ .../recommendation/MatrixFactorizationModel.scala | 49 + .../regression/GeneralizedLinearAlgorithm.scala | 159 +++ .../spark/mllib/regression/LabeledPoint.scala | 26 + .../org/apache/spark/mllib/regression/Lasso.scala | 210 ++++ .../spark/mllib/regression/LinearRegression.scala | 167 +++ .../spark/mllib/regression/RegressionModel.scala | 38 + .../spark/mllib/regression/RidgeRegression.scala | 213 ++++ .../apache/spark/mllib/util/DataValidators.scala | 42 + .../spark/mllib/util/KMeansDataGenerator.scala | 84 ++ .../spark/mllib/util/LinearDataGenerator.scala | 132 ++ .../util/LogisticRegressionDataGenerator.scala | 81 ++ .../apache/spark/mllib/util/MFDataGenerator.scala | 113 ++ .../org/apache/spark/mllib/util/MLUtils.scala | 122 ++ .../apache/spark/mllib/util/SVMDataGenerator.scala | 50 + .../mllib/classification/ClassificationModel.scala | 21 - .../mllib/classification/LogisticRegression.scala | 188 --- .../scala/spark/mllib/classification/SVM.scala | 187 --- .../main/scala/spark/mllib/clustering/KMeans.scala | 335 ----- .../scala/spark/mllib/clustering/KMeansModel.scala | 44 - .../scala/spark/mllib/clustering/LocalKMeans.scala | 105 -- .../scala/spark/mllib/optimization/Gradient.scala | 98 -- .../spark/mllib/optimization/GradientDescent.scala | 166 --- .../scala/spark/mllib/optimization/Optimizer.scala | 29 - .../scala/spark/mllib/optimization/Updater.scala | 99 -- .../scala/spark/mllib/recommendation/ALS.scala | 453 ------- .../recommendation/MatrixFactorizationModel.scala | 49 - .../regression/GeneralizedLinearAlgorithm.scala | 159 --- 
.../spark/mllib/regression/LabeledPoint.scala | 26 - .../main/scala/spark/mllib/regression/Lasso.scala | 210 ---- .../spark/mllib/regression/LinearRegression.scala | 167 --- .../spark/mllib/regression/RegressionModel.scala | 38 - .../spark/mllib/regression/RidgeRegression.scala | 213 ---- .../scala/spark/mllib/util/DataValidators.scala | 42 - .../spark/mllib/util/KMeansDataGenerator.scala | 84 -- .../spark/mllib/util/LinearDataGenerator.scala | 132 -- .../util/LogisticRegressionDataGenerator.scala | 81 -- .../scala/spark/mllib/util/MFDataGenerator.scala | 113 -- .../src/main/scala/spark/mllib/util/MLUtils.scala | 122 -- .../scala/spark/mllib/util/SVMDataGenerator.scala | 50 - .../JavaLogisticRegressionSuite.java | 98 ++ .../spark/mllib/classification/JavaSVMSuite.java | 98 ++ .../spark/mllib/clustering/JavaKMeansSuite.java | 115 ++ .../spark/mllib/recommendation/JavaALSSuite.java | 110 ++ .../spark/mllib/regression/JavaLassoSuite.java | 97 ++ .../regression/JavaLinearRegressionSuite.java | 94 ++ .../mllib/regression/JavaRidgeRegressionSuite.java | 110 ++ .../JavaLogisticRegressionSuite.java | 98 -- .../spark/mllib/classification/JavaSVMSuite.java | 98 -- .../spark/mllib/clustering/JavaKMeansSuite.java | 115 -- .../spark/mllib/recommendation/JavaALSSuite.java | 110 -- .../spark/mllib/regression/JavaLassoSuite.java | 97 -- .../regression/JavaLinearRegressionSuite.java | 94 -- .../mllib/regression/JavaRidgeRegressionSuite.java | 110 -- .../classification/LogisticRegressionSuite.scala | 150 +++ .../spark/mllib/classification/SVMSuite.scala | 169 +++ .../spark/mllib/clustering/KMeansSuite.scala | 173 +++ .../spark/mllib/recommendation/ALSSuite.scala | 125 ++ .../apache/spark/mllib/regression/LassoSuite.scala | 121 ++ .../mllib/regression/LinearRegressionSuite.scala | 72 ++ .../mllib/regression/RidgeRegressionSuite.scala | 90 ++ .../classification/LogisticRegressionSuite.scala | 150 --- .../spark/mllib/classification/SVMSuite.scala | 169 --- .../scala/spark/mllib/clustering/KMeansSuite.scala | 173 --- .../spark/mllib/recommendation/ALSSuite.scala | 125 -- .../scala/spark/mllib/regression/LassoSuite.scala | 121 -- .../mllib/regression/LinearRegressionSuite.scala | 72 -- .../mllib/regression/RidgeRegressionSuite.scala | 90 -- pom.xml | 14 +- project/SparkBuild.scala | 12 +- python/pyspark/context.py | 4 +- python/pyspark/files.py | 2 +- python/pyspark/java_gateway.py | 4 +- repl-bin/pom.xml | 12 +- repl/pom.xml | 12 +- .../apache/spark/repl/ExecutorClassLoader.scala | 124 ++ .../main/scala/org/apache/spark/repl/Main.scala | 33 + .../scala/org/apache/spark/repl/SparkHelper.scala | 5 + .../scala/org/apache/spark/repl/SparkILoop.scala | 1008 +++++++++++++++ .../scala/org/apache/spark/repl/SparkIMain.scala | 1160 +++++++++++++++++ .../org/apache/spark/repl/SparkISettings.scala | 63 + .../scala/org/apache/spark/repl/SparkImports.scala | 214 ++++ .../apache/spark/repl/SparkJLineCompletion.scala | 379 ++++++ .../org/apache/spark/repl/SparkJLineReader.scala | 79 ++ .../apache/spark/repl/SparkMemberHandlers.scala | 207 ++++ .../scala/spark/repl/ExecutorClassLoader.scala | 124 -- repl/src/main/scala/spark/repl/Main.scala | 33 - repl/src/main/scala/spark/repl/SparkHelper.scala | 5 - repl/src/main/scala/spark/repl/SparkILoop.scala | 1008 --------------- repl/src/main/scala/spark/repl/SparkIMain.scala | 1160 ----------------- .../src/main/scala/spark/repl/SparkISettings.scala | 63 - repl/src/main/scala/spark/repl/SparkImports.scala | 214 ---- .../scala/spark/repl/SparkJLineCompletion.scala | 379 ------ 
.../main/scala/spark/repl/SparkJLineReader.scala | 79 -- .../scala/spark/repl/SparkMemberHandlers.scala | 207 ---- .../scala/org/apache/spark/repl/ReplSuite.scala | 207 ++++ repl/src/test/scala/spark/repl/ReplSuite.scala | 207 ---- spark-executor | 2 +- spark-shell | 2 +- spark-shell.cmd | 2 +- streaming/pom.xml | 6 +- .../org/apache/spark/streaming/Checkpoint.scala | 190 +++ .../scala/org/apache/spark/streaming/DStream.scala | 702 +++++++++++ .../spark/streaming/DStreamCheckpointData.scala | 110 ++ .../org/apache/spark/streaming/DStreamGraph.scala | 167 +++ .../org/apache/spark/streaming/Duration.scala | 83 ++ .../org/apache/spark/streaming/Interval.scala | 59 + .../scala/org/apache/spark/streaming/Job.scala | 41 + .../org/apache/spark/streaming/JobManager.scala | 88 ++ .../spark/streaming/NetworkInputTracker.scala | 173 +++ .../spark/streaming/PairDStreamFunctions.scala | 534 ++++++++ .../org/apache/spark/streaming/Scheduler.scala | 131 ++ .../apache/spark/streaming/StreamingContext.scala | 563 +++++++++ .../scala/org/apache/spark/streaming/Time.scala | 72 ++ .../spark/streaming/api/java/JavaDStream.scala | 102 ++ .../spark/streaming/api/java/JavaDStreamLike.scala | 316 +++++ .../spark/streaming/api/java/JavaPairDStream.scala | 613 +++++++++ .../streaming/api/java/JavaStreamingContext.scala | 614 +++++++++ .../spark/streaming/dstream/CoGroupedDStream.scala | 57 + .../streaming/dstream/ConstantInputDStream.scala | 36 + .../spark/streaming/dstream/FileInputDStream.scala | 199 +++ .../spark/streaming/dstream/FilteredDStream.scala | 38 + .../streaming/dstream/FlatMapValuedDStream.scala | 37 + .../streaming/dstream/FlatMappedDStream.scala | 37 + .../streaming/dstream/FlumeInputDStream.scala | 154 +++ .../spark/streaming/dstream/ForEachDStream.scala | 45 + .../spark/streaming/dstream/GlommedDStream.scala | 34 + .../spark/streaming/dstream/InputDStream.scala | 70 ++ .../streaming/dstream/KafkaInputDStream.scala | 141 +++ .../streaming/dstream/MapPartitionedDStream.scala | 38 + .../spark/streaming/dstream/MapValuedDStream.scala | 38 + .../spark/streaming/dstream/MappedDStream.scala | 37 + .../streaming/dstream/NetworkInputDStream.scala | 272 ++++ .../streaming/dstream/PluggableInputDStream.scala | 30 + .../streaming/dstream/QueueInputDStream.scala | 59 + .../spark/streaming/dstream/RawInputDStream.scala | 108 ++ .../streaming/dstream/ReducedWindowedDStream.scala | 174 +++ .../spark/streaming/dstream/ShuffledDStream.scala | 44 + .../streaming/dstream/SocketInputDStream.scala | 94 ++ .../spark/streaming/dstream/StateDStream.scala | 109 ++ .../streaming/dstream/TransformedDStream.scala | 36 + .../streaming/dstream/TwitterInputDStream.scala | 99 ++ .../spark/streaming/dstream/UnionDStream.scala | 57 + .../spark/streaming/dstream/WindowedDStream.scala | 57 + .../spark/streaming/receivers/ActorReceiver.scala | 175 +++ .../spark/streaming/receivers/ZeroMQReceiver.scala | 50 + .../org/apache/spark/streaming/util/Clock.scala | 101 ++ .../spark/streaming/util/MasterFailureTest.scala | 414 +++++++ .../spark/streaming/util/RawTextHelper.scala | 115 ++ .../spark/streaming/util/RawTextSender.scala | 77 ++ .../spark/streaming/util/RecurringTimer.scala | 94 ++ .../main/scala/spark/streaming/Checkpoint.scala | 190 --- .../src/main/scala/spark/streaming/DStream.scala | 700 ----------- .../spark/streaming/DStreamCheckpointData.scala | 110 -- .../main/scala/spark/streaming/DStreamGraph.scala | 167 --- .../src/main/scala/spark/streaming/Duration.scala | 83 -- .../src/main/scala/spark/streaming/Interval.scala | 59 - 
streaming/src/main/scala/spark/streaming/Job.scala | 41 - .../main/scala/spark/streaming/JobManager.scala | 88 -- .../spark/streaming/NetworkInputTracker.scala | 173 --- .../spark/streaming/PairDStreamFunctions.scala | 534 -------- .../src/main/scala/spark/streaming/Scheduler.scala | 130 -- .../scala/spark/streaming/StreamingContext.scala | 563 --------- .../src/main/scala/spark/streaming/Time.scala | 72 -- .../spark/streaming/api/java/JavaDStream.scala | 102 -- .../spark/streaming/api/java/JavaDStreamLike.scala | 316 ----- .../spark/streaming/api/java/JavaPairDStream.scala | 613 --------- .../streaming/api/java/JavaStreamingContext.scala | 613 --------- .../spark/streaming/dstream/CoGroupedDStream.scala | 57 - .../streaming/dstream/ConstantInputDStream.scala | 36 - .../spark/streaming/dstream/FileInputDStream.scala | 199 --- .../spark/streaming/dstream/FilteredDStream.scala | 38 - .../streaming/dstream/FlatMapValuedDStream.scala | 37 - .../streaming/dstream/FlatMappedDStream.scala | 37 - .../streaming/dstream/FlumeInputDStream.scala | 154 --- .../spark/streaming/dstream/ForEachDStream.scala | 45 - .../spark/streaming/dstream/GlommedDStream.scala | 34 - .../spark/streaming/dstream/InputDStream.scala | 70 -- .../streaming/dstream/KafkaInputDStream.scala | 141 --- .../streaming/dstream/MapPartitionedDStream.scala | 38 - .../spark/streaming/dstream/MapValuedDStream.scala | 38 - .../spark/streaming/dstream/MappedDStream.scala | 37 - .../streaming/dstream/NetworkInputDStream.scala | 272 ---- .../streaming/dstream/PluggableInputDStream.scala | 30 - .../streaming/dstream/QueueInputDStream.scala | 59 - .../spark/streaming/dstream/RawInputDStream.scala | 108 -- .../streaming/dstream/ReducedWindowedDStream.scala | 174 --- .../spark/streaming/dstream/ShuffledDStream.scala | 44 - .../streaming/dstream/SocketInputDStream.scala | 94 -- .../spark/streaming/dstream/StateDStream.scala | 109 -- .../streaming/dstream/TransformedDStream.scala | 36 - .../streaming/dstream/TwitterInputDStream.scala | 99 -- .../spark/streaming/dstream/UnionDStream.scala | 57 - .../spark/streaming/dstream/WindowedDStream.scala | 57 - .../spark/streaming/receivers/ActorReceiver.scala | 175 --- .../spark/streaming/receivers/ZeroMQReceiver.scala | 50 - .../main/scala/spark/streaming/util/Clock.scala | 101 -- .../spark/streaming/util/MasterFailureTest.scala | 414 ------- .../scala/spark/streaming/util/RawTextHelper.scala | 115 -- .../scala/spark/streaming/util/RawTextSender.scala | 77 -- .../spark/streaming/util/RecurringTimer.scala | 94 -- .../org/apache/spark/streaming/JavaAPISuite.java | 1304 ++++++++++++++++++++ .../org/apache/spark/streaming/JavaTestUtils.scala | 85 ++ .../test/java/spark/streaming/JavaAPISuite.java | 1304 -------------------- .../test/java/spark/streaming/JavaTestUtils.scala | 84 -- .../spark/streaming/BasicOperationsSuite.scala | 322 +++++ .../apache/spark/streaming/CheckpointSuite.scala | 372 ++++++ .../org/apache/spark/streaming/FailureSuite.scala | 57 + .../apache/spark/streaming/InputStreamsSuite.scala | 349 ++++++ .../org/apache/spark/streaming/TestSuiteBase.scala | 314 +++++ .../spark/streaming/WindowOperationsSuite.scala | 340 +++++ .../spark/streaming/BasicOperationsSuite.scala | 322 ----- .../scala/spark/streaming/CheckpointSuite.scala | 372 ------ .../test/scala/spark/streaming/FailureSuite.scala | 57 - .../scala/spark/streaming/InputStreamsSuite.scala | 349 ------ .../test/scala/spark/streaming/TestSuiteBase.scala | 314 ----- .../spark/streaming/WindowOperationsSuite.scala | 340 ----- tools/pom.xml 
| 8 +- .../spark/tools/JavaAPICompletenessChecker.scala | 360 ++++++ .../spark/tools/JavaAPICompletenessChecker.scala | 360 ------ yarn/pom.xml | 6 +- .../spark/deploy/yarn/ApplicationMaster.scala | 371 ++++++ .../deploy/yarn/ApplicationMasterArguments.scala | 94 ++ .../org/apache/spark/deploy/yarn/Client.scala | 336 +++++ .../apache/spark/deploy/yarn/ClientArguments.scala | 116 ++ .../apache/spark/deploy/yarn/WorkerRunnable.scala | 224 ++++ .../spark/deploy/yarn/YarnAllocationHandler.scala | 564 +++++++++ .../spark/deploy/yarn/YarnSparkHadoopUtil.scala | 46 + .../scheduler/cluster/YarnClusterScheduler.scala | 52 + .../spark/deploy/yarn/ApplicationMaster.scala | 371 ------ .../deploy/yarn/ApplicationMasterArguments.scala | 94 -- yarn/src/main/scala/spark/deploy/yarn/Client.scala | 336 ----- .../scala/spark/deploy/yarn/ClientArguments.scala | 116 -- .../scala/spark/deploy/yarn/WorkerRunnable.scala | 224 ---- .../spark/deploy/yarn/YarnAllocationHandler.scala | 564 --------- .../spark/deploy/yarn/YarnSparkHadoopUtil.scala | 46 - .../scheduler/cluster/YarnClusterScheduler.scala | 52 - 1015 files changed, 70352 insertions(+), 70341 deletions(-) create mode 100644 bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala delete mode 100644 bagel/src/main/scala/spark/bagel/Bagel.scala delete mode 100644 bagel/src/test/scala/bagel/BagelSuite.scala create mode 100644 bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala create mode 100644 core/src/main/java/org/apache/spark/network/netty/FileClient.java create mode 100644 core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java create mode 100644 core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java create mode 100644 core/src/main/java/org/apache/spark/network/netty/FileServer.java create mode 100644 core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java create mode 100644 core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java create mode 100755 core/src/main/java/org/apache/spark/network/netty/PathResolver.java delete mode 100644 core/src/main/java/spark/network/netty/FileClient.java delete mode 100644 core/src/main/java/spark/network/netty/FileClientChannelInitializer.java delete mode 100644 core/src/main/java/spark/network/netty/FileClientHandler.java delete mode 100644 core/src/main/java/spark/network/netty/FileServer.java delete mode 100644 core/src/main/java/spark/network/netty/FileServerChannelInitializer.java delete mode 100644 core/src/main/java/spark/network/netty/FileServerHandler.java delete mode 100755 core/src/main/java/spark/network/netty/PathResolver.java create mode 100755 core/src/main/resources/org/apache/spark/ui/static/bootstrap.min.css create mode 100644 core/src/main/resources/org/apache/spark/ui/static/sorttable.js create mode 100644 core/src/main/resources/org/apache/spark/ui/static/spark-logo-77x50px-hd.png create mode 100644 core/src/main/resources/org/apache/spark/ui/static/spark_logo.png create mode 100644 core/src/main/resources/org/apache/spark/ui/static/webui.css delete mode 100755 core/src/main/resources/spark/ui/static/bootstrap.min.css delete mode 100644 core/src/main/resources/spark/ui/static/sorttable.js delete mode 100644 core/src/main/resources/spark/ui/static/spark-logo-77x50px-hd.png delete mode 100644 core/src/main/resources/spark/ui/static/spark_logo.png delete mode 100644 core/src/main/resources/spark/ui/static/webui.css create mode 100644 core/src/main/scala/org/apache/spark/Accumulators.scala create mode 
100644 core/src/main/scala/org/apache/spark/Aggregator.scala create mode 100644 core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala create mode 100644 core/src/main/scala/org/apache/spark/CacheManager.scala create mode 100644 core/src/main/scala/org/apache/spark/ClosureCleaner.scala create mode 100644 core/src/main/scala/org/apache/spark/Dependency.scala create mode 100644 core/src/main/scala/org/apache/spark/DoubleRDDFunctions.scala create mode 100644 core/src/main/scala/org/apache/spark/FetchFailedException.scala create mode 100644 core/src/main/scala/org/apache/spark/HttpFileServer.scala create mode 100644 core/src/main/scala/org/apache/spark/HttpServer.scala create mode 100644 core/src/main/scala/org/apache/spark/JavaSerializer.scala create mode 100644 core/src/main/scala/org/apache/spark/KryoSerializer.scala create mode 100644 core/src/main/scala/org/apache/spark/Logging.scala create mode 100644 core/src/main/scala/org/apache/spark/MapOutputTracker.scala create mode 100644 core/src/main/scala/org/apache/spark/PairRDDFunctions.scala create mode 100644 core/src/main/scala/org/apache/spark/Partition.scala create mode 100644 core/src/main/scala/org/apache/spark/Partitioner.scala create mode 100644 core/src/main/scala/org/apache/spark/RDD.scala create mode 100644 core/src/main/scala/org/apache/spark/RDDCheckpointData.scala create mode 100644 core/src/main/scala/org/apache/spark/SequenceFileRDDFunctions.scala create mode 100644 core/src/main/scala/org/apache/spark/SerializableWritable.scala create mode 100644 core/src/main/scala/org/apache/spark/ShuffleFetcher.scala create mode 100644 core/src/main/scala/org/apache/spark/SizeEstimator.scala create mode 100644 core/src/main/scala/org/apache/spark/SparkContext.scala create mode 100644 core/src/main/scala/org/apache/spark/SparkEnv.scala create mode 100644 core/src/main/scala/org/apache/spark/SparkException.scala create mode 100644 core/src/main/scala/org/apache/spark/SparkFiles.java create mode 100644 core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala create mode 100644 core/src/main/scala/org/apache/spark/TaskContext.scala create mode 100644 core/src/main/scala/org/apache/spark/TaskEndReason.scala create mode 100644 core/src/main/scala/org/apache/spark/TaskState.scala create mode 100644 core/src/main/scala/org/apache/spark/Utils.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java create mode 100644 core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/StorageLevels.java create mode 100644 core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java create mode 100644 core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java create mode 100644 core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/function/Function.java create mode 100644 
core/src/main/scala/org/apache/spark/api/java/function/Function2.java create mode 100644 core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java create mode 100644 core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java create mode 100644 core/src/main/scala/org/apache/spark/api/java/function/VoidFunction.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction1.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction2.scala create mode 100644 core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala create mode 100644 core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala create mode 100644 core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala create mode 100644 core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala create mode 100644 core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala create mode 100644 core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala create mode 100644 core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala create mode 100644 core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala create mode 100644 core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/Command.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/WebUI.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/client/Client.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/Master.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala create mode 100644 
core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala create mode 100644 core/src/main/scala/org/apache/spark/executor/Executor.scala create mode 100644 core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala create mode 100644 core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala create mode 100644 core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala create mode 100644 core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala create mode 100644 core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala create mode 100644 core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala create mode 100644 core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala create mode 100644 core/src/main/scala/org/apache/spark/io/CompressionCodec.scala create mode 100644 core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala create mode 100644 core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala create mode 100644 core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala create mode 100644 core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala create mode 100644 core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala create mode 100644 core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala create mode 100644 core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala create mode 100644 core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala create mode 100644 core/src/main/scala/org/apache/spark/metrics/source/Source.scala create mode 100644 core/src/main/scala/org/apache/spark/network/BufferMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/network/Connection.scala create mode 100644 core/src/main/scala/org/apache/spark/network/ConnectionManager.scala create mode 100644 core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala create mode 100644 core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala create mode 100644 core/src/main/scala/org/apache/spark/network/Message.scala create mode 100644 core/src/main/scala/org/apache/spark/network/MessageChunk.scala create mode 100644 core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala create mode 100644 core/src/main/scala/org/apache/spark/network/ReceiverTest.scala create mode 100644 core/src/main/scala/org/apache/spark/network/SenderTest.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala create mode 100644 core/src/main/scala/org/apache/spark/package.scala create mode 100644 core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala create mode 100644 core/src/main/scala/org/apache/spark/partial/ApproximateEvaluator.scala create mode 100644 core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala create mode 100644 
core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala create mode 100644 core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala create mode 100644 core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala create mode 100644 core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala create mode 100644 core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala create mode 100644 core/src/main/scala/org/apache/spark/partial/PartialResult.scala create mode 100644 core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala create mode 100644 core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/FlatMappedValuesRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/JobListener.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/JobResult.scala create mode 100644 
core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/Stage.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/Task.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/Pool.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/Schedulable.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulableBuilder.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingAlgorithm.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingMode.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/TaskDescription.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/TaskInfo.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/TaskLocality.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/TaskSetManager.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackend.scala create mode 100644 core/src/main/scala/org/apache/spark/serializer/Serializer.scala create mode 100644 core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockException.scala create 
mode 100644 core/src/main/scala/org/apache/spark/storage/BlockFetchTracker.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockManager.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/BlockStore.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/DiskStore.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/MemoryStore.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/PutResult.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/StorageLevel.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/StorageUtils.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/JettyUtils.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/Page.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/SparkUI.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/UIUtils.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala create mode 100644 core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala create mode 100644 core/src/main/scala/org/apache/spark/util/AkkaUtils.scala create mode 100644 core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala create mode 100644 core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala create mode 100644 core/src/main/scala/org/apache/spark/util/Clock.scala create mode 100644 
core/src/main/scala/org/apache/spark/util/CompletionIterator.scala create mode 100644 core/src/main/scala/org/apache/spark/util/Distribution.scala create mode 100644 core/src/main/scala/org/apache/spark/util/IdGenerator.scala create mode 100644 core/src/main/scala/org/apache/spark/util/IntParam.scala create mode 100644 core/src/main/scala/org/apache/spark/util/MemoryParam.scala create mode 100644 core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala create mode 100644 core/src/main/scala/org/apache/spark/util/MutablePair.scala create mode 100644 core/src/main/scala/org/apache/spark/util/NextIterator.scala create mode 100644 core/src/main/scala/org/apache/spark/util/RateLimitedOutputStream.scala create mode 100644 core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala create mode 100644 core/src/main/scala/org/apache/spark/util/StatCounter.scala create mode 100644 core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala create mode 100644 core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala create mode 100644 core/src/main/scala/org/apache/spark/util/Vector.scala delete mode 100644 core/src/main/scala/spark/Accumulators.scala delete mode 100644 core/src/main/scala/spark/Aggregator.scala delete mode 100644 core/src/main/scala/spark/BlockStoreShuffleFetcher.scala delete mode 100644 core/src/main/scala/spark/CacheManager.scala delete mode 100644 core/src/main/scala/spark/ClosureCleaner.scala delete mode 100644 core/src/main/scala/spark/Dependency.scala delete mode 100644 core/src/main/scala/spark/DoubleRDDFunctions.scala delete mode 100644 core/src/main/scala/spark/FetchFailedException.scala delete mode 100644 core/src/main/scala/spark/HttpFileServer.scala delete mode 100644 core/src/main/scala/spark/HttpServer.scala delete mode 100644 core/src/main/scala/spark/JavaSerializer.scala delete mode 100644 core/src/main/scala/spark/KryoSerializer.scala delete mode 100644 core/src/main/scala/spark/Logging.scala delete mode 100644 core/src/main/scala/spark/MapOutputTracker.scala delete mode 100644 core/src/main/scala/spark/PairRDDFunctions.scala delete mode 100644 core/src/main/scala/spark/Partition.scala delete mode 100644 core/src/main/scala/spark/Partitioner.scala delete mode 100644 core/src/main/scala/spark/RDD.scala delete mode 100644 core/src/main/scala/spark/RDDCheckpointData.scala delete mode 100644 core/src/main/scala/spark/SequenceFileRDDFunctions.scala delete mode 100644 core/src/main/scala/spark/SerializableWritable.scala delete mode 100644 core/src/main/scala/spark/ShuffleFetcher.scala delete mode 100644 core/src/main/scala/spark/SizeEstimator.scala delete mode 100644 core/src/main/scala/spark/SparkContext.scala delete mode 100644 core/src/main/scala/spark/SparkEnv.scala delete mode 100644 core/src/main/scala/spark/SparkException.scala delete mode 100644 core/src/main/scala/spark/SparkFiles.java delete mode 100644 core/src/main/scala/spark/SparkHadoopWriter.scala delete mode 100644 core/src/main/scala/spark/TaskContext.scala delete mode 100644 core/src/main/scala/spark/TaskEndReason.scala delete mode 100644 core/src/main/scala/spark/TaskState.scala delete mode 100644 core/src/main/scala/spark/Utils.scala delete mode 100644 core/src/main/scala/spark/api/java/JavaDoubleRDD.scala delete mode 100644 core/src/main/scala/spark/api/java/JavaPairRDD.scala delete mode 100644 core/src/main/scala/spark/api/java/JavaRDD.scala delete mode 100644 core/src/main/scala/spark/api/java/JavaRDDLike.scala delete mode 100644 
core/src/main/scala/spark/api/java/JavaSparkContext.scala delete mode 100644 core/src/main/scala/spark/api/java/JavaSparkContextVarargsWorkaround.java delete mode 100644 core/src/main/scala/spark/api/java/JavaUtils.scala delete mode 100644 core/src/main/scala/spark/api/java/StorageLevels.java delete mode 100644 core/src/main/scala/spark/api/java/function/DoubleFlatMapFunction.java delete mode 100644 core/src/main/scala/spark/api/java/function/DoubleFunction.java delete mode 100644 core/src/main/scala/spark/api/java/function/FlatMapFunction.scala delete mode 100644 core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala delete mode 100644 core/src/main/scala/spark/api/java/function/Function.java delete mode 100644 core/src/main/scala/spark/api/java/function/Function2.java delete mode 100644 core/src/main/scala/spark/api/java/function/PairFlatMapFunction.java delete mode 100644 core/src/main/scala/spark/api/java/function/PairFunction.java delete mode 100644 core/src/main/scala/spark/api/java/function/VoidFunction.scala delete mode 100644 core/src/main/scala/spark/api/java/function/WrappedFunction1.scala delete mode 100644 core/src/main/scala/spark/api/java/function/WrappedFunction2.scala delete mode 100644 core/src/main/scala/spark/api/python/PythonPartitioner.scala delete mode 100644 core/src/main/scala/spark/api/python/PythonRDD.scala delete mode 100644 core/src/main/scala/spark/api/python/PythonWorkerFactory.scala delete mode 100644 core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala delete mode 100644 core/src/main/scala/spark/broadcast/Broadcast.scala delete mode 100644 core/src/main/scala/spark/broadcast/BroadcastFactory.scala delete mode 100644 core/src/main/scala/spark/broadcast/HttpBroadcast.scala delete mode 100644 core/src/main/scala/spark/broadcast/MultiTracker.scala delete mode 100644 core/src/main/scala/spark/broadcast/SourceInfo.scala delete mode 100644 core/src/main/scala/spark/broadcast/TreeBroadcast.scala delete mode 100644 core/src/main/scala/spark/deploy/ApplicationDescription.scala delete mode 100644 core/src/main/scala/spark/deploy/Command.scala delete mode 100644 core/src/main/scala/spark/deploy/DeployMessage.scala delete mode 100644 core/src/main/scala/spark/deploy/ExecutorState.scala delete mode 100644 core/src/main/scala/spark/deploy/JsonProtocol.scala delete mode 100644 core/src/main/scala/spark/deploy/LocalSparkCluster.scala delete mode 100644 core/src/main/scala/spark/deploy/SparkHadoopUtil.scala delete mode 100644 core/src/main/scala/spark/deploy/WebUI.scala delete mode 100644 core/src/main/scala/spark/deploy/client/Client.scala delete mode 100644 core/src/main/scala/spark/deploy/client/ClientListener.scala delete mode 100644 core/src/main/scala/spark/deploy/client/TestClient.scala delete mode 100644 core/src/main/scala/spark/deploy/client/TestExecutor.scala delete mode 100644 core/src/main/scala/spark/deploy/master/ApplicationInfo.scala delete mode 100644 core/src/main/scala/spark/deploy/master/ApplicationSource.scala delete mode 100644 core/src/main/scala/spark/deploy/master/ApplicationState.scala delete mode 100644 core/src/main/scala/spark/deploy/master/ExecutorInfo.scala delete mode 100644 core/src/main/scala/spark/deploy/master/Master.scala delete mode 100644 core/src/main/scala/spark/deploy/master/MasterArguments.scala delete mode 100644 core/src/main/scala/spark/deploy/master/MasterSource.scala delete mode 100644 core/src/main/scala/spark/deploy/master/WorkerInfo.scala delete mode 100644 
core/src/main/scala/spark/deploy/master/WorkerState.scala delete mode 100644 core/src/main/scala/spark/deploy/master/ui/ApplicationPage.scala delete mode 100644 core/src/main/scala/spark/deploy/master/ui/IndexPage.scala delete mode 100644 core/src/main/scala/spark/deploy/master/ui/MasterWebUI.scala delete mode 100644 core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala delete mode 100644 core/src/main/scala/spark/deploy/worker/Worker.scala delete mode 100644 core/src/main/scala/spark/deploy/worker/WorkerArguments.scala delete mode 100644 core/src/main/scala/spark/deploy/worker/WorkerSource.scala delete mode 100644 core/src/main/scala/spark/deploy/worker/ui/IndexPage.scala delete mode 100644 core/src/main/scala/spark/deploy/worker/ui/WorkerWebUI.scala delete mode 100644 core/src/main/scala/spark/executor/Executor.scala delete mode 100644 core/src/main/scala/spark/executor/ExecutorBackend.scala delete mode 100644 core/src/main/scala/spark/executor/ExecutorExitCode.scala delete mode 100644 core/src/main/scala/spark/executor/ExecutorSource.scala delete mode 100644 core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala delete mode 100644 core/src/main/scala/spark/executor/MesosExecutorBackend.scala delete mode 100644 core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala delete mode 100644 core/src/main/scala/spark/executor/TaskMetrics.scala delete mode 100644 core/src/main/scala/spark/io/CompressionCodec.scala delete mode 100644 core/src/main/scala/spark/metrics/MetricsConfig.scala delete mode 100644 core/src/main/scala/spark/metrics/MetricsSystem.scala delete mode 100644 core/src/main/scala/spark/metrics/sink/ConsoleSink.scala delete mode 100644 core/src/main/scala/spark/metrics/sink/CsvSink.scala delete mode 100644 core/src/main/scala/spark/metrics/sink/JmxSink.scala delete mode 100644 core/src/main/scala/spark/metrics/sink/MetricsServlet.scala delete mode 100644 core/src/main/scala/spark/metrics/sink/Sink.scala delete mode 100644 core/src/main/scala/spark/metrics/source/JvmSource.scala delete mode 100644 core/src/main/scala/spark/metrics/source/Source.scala delete mode 100644 core/src/main/scala/spark/network/BufferMessage.scala delete mode 100644 core/src/main/scala/spark/network/Connection.scala delete mode 100644 core/src/main/scala/spark/network/ConnectionManager.scala delete mode 100644 core/src/main/scala/spark/network/ConnectionManagerId.scala delete mode 100644 core/src/main/scala/spark/network/ConnectionManagerTest.scala delete mode 100644 core/src/main/scala/spark/network/Message.scala delete mode 100644 core/src/main/scala/spark/network/MessageChunk.scala delete mode 100644 core/src/main/scala/spark/network/MessageChunkHeader.scala delete mode 100644 core/src/main/scala/spark/network/ReceiverTest.scala delete mode 100644 core/src/main/scala/spark/network/SenderTest.scala delete mode 100644 core/src/main/scala/spark/network/netty/FileHeader.scala delete mode 100644 core/src/main/scala/spark/network/netty/ShuffleCopier.scala delete mode 100644 core/src/main/scala/spark/network/netty/ShuffleSender.scala delete mode 100644 core/src/main/scala/spark/package.scala delete mode 100644 core/src/main/scala/spark/partial/ApproximateActionListener.scala delete mode 100644 core/src/main/scala/spark/partial/ApproximateEvaluator.scala delete mode 100644 core/src/main/scala/spark/partial/BoundedDouble.scala delete mode 100644 core/src/main/scala/spark/partial/CountEvaluator.scala delete mode 100644 core/src/main/scala/spark/partial/GroupedCountEvaluator.scala delete 
mode 100644 core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala delete mode 100644 core/src/main/scala/spark/partial/GroupedSumEvaluator.scala delete mode 100644 core/src/main/scala/spark/partial/MeanEvaluator.scala delete mode 100644 core/src/main/scala/spark/partial/PartialResult.scala delete mode 100644 core/src/main/scala/spark/partial/StudentTCacher.scala delete mode 100644 core/src/main/scala/spark/partial/SumEvaluator.scala delete mode 100644 core/src/main/scala/spark/rdd/BlockRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/CartesianRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/CheckpointRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/CoGroupedRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/CoalescedRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/EmptyRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/FilteredRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/FlatMappedRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/FlatMappedValuesRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/GlommedRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/HadoopRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/JdbcRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/MapPartitionsRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/MapPartitionsWithIndexRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/MappedRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/MappedValuesRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/NewHadoopRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/OrderedRDDFunctions.scala delete mode 100644 core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/PartitionPruningRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/PipedRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/SampledRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/ShuffledRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/SubtractedRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/UnionRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/ZippedRDD.scala delete mode 100644 core/src/main/scala/spark/scheduler/ActiveJob.scala delete mode 100644 core/src/main/scala/spark/scheduler/DAGScheduler.scala delete mode 100644 core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala delete mode 100644 core/src/main/scala/spark/scheduler/DAGSchedulerSource.scala delete mode 100644 core/src/main/scala/spark/scheduler/InputFormatInfo.scala delete mode 100644 core/src/main/scala/spark/scheduler/JobListener.scala delete mode 100644 core/src/main/scala/spark/scheduler/JobLogger.scala delete mode 100644 core/src/main/scala/spark/scheduler/JobResult.scala delete mode 100644 core/src/main/scala/spark/scheduler/JobWaiter.scala delete mode 100644 core/src/main/scala/spark/scheduler/MapStatus.scala delete mode 100644 core/src/main/scala/spark/scheduler/ResultTask.scala delete mode 100644 core/src/main/scala/spark/scheduler/ShuffleMapTask.scala delete mode 100644 core/src/main/scala/spark/scheduler/SparkListener.scala delete mode 100644 core/src/main/scala/spark/scheduler/SparkListenerBus.scala delete mode 100644 core/src/main/scala/spark/scheduler/SplitInfo.scala delete mode 100644 core/src/main/scala/spark/scheduler/Stage.scala delete mode 100644 core/src/main/scala/spark/scheduler/StageInfo.scala 
delete mode 100644 core/src/main/scala/spark/scheduler/Task.scala delete mode 100644 core/src/main/scala/spark/scheduler/TaskLocation.scala delete mode 100644 core/src/main/scala/spark/scheduler/TaskResult.scala delete mode 100644 core/src/main/scala/spark/scheduler/TaskScheduler.scala delete mode 100644 core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala delete mode 100644 core/src/main/scala/spark/scheduler/TaskSet.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/Pool.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/Schedulable.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/TaskLocality.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala delete mode 100644 core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala delete mode 100644 core/src/main/scala/spark/scheduler/local/LocalScheduler.scala delete mode 100644 core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala delete mode 100644 core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala delete mode 100644 core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala delete mode 100644 core/src/main/scala/spark/serializer/Serializer.scala delete mode 100644 core/src/main/scala/spark/serializer/SerializerManager.scala delete mode 100644 core/src/main/scala/spark/storage/BlockException.scala delete mode 100644 core/src/main/scala/spark/storage/BlockFetchTracker.scala delete mode 100644 core/src/main/scala/spark/storage/BlockFetcherIterator.scala delete mode 100644 core/src/main/scala/spark/storage/BlockManager.scala delete mode 100644 core/src/main/scala/spark/storage/BlockManagerId.scala delete mode 100644 core/src/main/scala/spark/storage/BlockManagerMaster.scala delete mode 100644 core/src/main/scala/spark/storage/BlockManagerMasterActor.scala delete mode 100644 core/src/main/scala/spark/storage/BlockManagerMessages.scala delete mode 100644 core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala delete mode 100644 core/src/main/scala/spark/storage/BlockManagerSource.scala delete mode 100644 core/src/main/scala/spark/storage/BlockManagerWorker.scala delete mode 100644 core/src/main/scala/spark/storage/BlockMessage.scala delete mode 100644 core/src/main/scala/spark/storage/BlockMessageArray.scala delete mode 100644 core/src/main/scala/spark/storage/BlockObjectWriter.scala delete mode 100644 
core/src/main/scala/spark/storage/BlockStore.scala delete mode 100644 core/src/main/scala/spark/storage/DiskStore.scala delete mode 100644 core/src/main/scala/spark/storage/MemoryStore.scala delete mode 100644 core/src/main/scala/spark/storage/PutResult.scala delete mode 100644 core/src/main/scala/spark/storage/ShuffleBlockManager.scala delete mode 100644 core/src/main/scala/spark/storage/StorageLevel.scala delete mode 100644 core/src/main/scala/spark/storage/StorageUtils.scala delete mode 100644 core/src/main/scala/spark/storage/ThreadingTest.scala delete mode 100644 core/src/main/scala/spark/ui/JettyUtils.scala delete mode 100644 core/src/main/scala/spark/ui/Page.scala delete mode 100644 core/src/main/scala/spark/ui/SparkUI.scala delete mode 100644 core/src/main/scala/spark/ui/UIUtils.scala delete mode 100644 core/src/main/scala/spark/ui/UIWorkloadGenerator.scala delete mode 100644 core/src/main/scala/spark/ui/env/EnvironmentUI.scala delete mode 100644 core/src/main/scala/spark/ui/exec/ExecutorsUI.scala delete mode 100644 core/src/main/scala/spark/ui/jobs/IndexPage.scala delete mode 100644 core/src/main/scala/spark/ui/jobs/JobProgressListener.scala delete mode 100644 core/src/main/scala/spark/ui/jobs/JobProgressUI.scala delete mode 100644 core/src/main/scala/spark/ui/jobs/PoolPage.scala delete mode 100644 core/src/main/scala/spark/ui/jobs/PoolTable.scala delete mode 100644 core/src/main/scala/spark/ui/jobs/StagePage.scala delete mode 100644 core/src/main/scala/spark/ui/jobs/StageTable.scala delete mode 100644 core/src/main/scala/spark/ui/storage/BlockManagerUI.scala delete mode 100644 core/src/main/scala/spark/ui/storage/IndexPage.scala delete mode 100644 core/src/main/scala/spark/ui/storage/RDDPage.scala delete mode 100644 core/src/main/scala/spark/util/AkkaUtils.scala delete mode 100644 core/src/main/scala/spark/util/BoundedPriorityQueue.scala delete mode 100644 core/src/main/scala/spark/util/ByteBufferInputStream.scala delete mode 100644 core/src/main/scala/spark/util/Clock.scala delete mode 100644 core/src/main/scala/spark/util/CompletionIterator.scala delete mode 100644 core/src/main/scala/spark/util/Distribution.scala delete mode 100644 core/src/main/scala/spark/util/IdGenerator.scala delete mode 100644 core/src/main/scala/spark/util/IntParam.scala delete mode 100644 core/src/main/scala/spark/util/MemoryParam.scala delete mode 100644 core/src/main/scala/spark/util/MetadataCleaner.scala delete mode 100644 core/src/main/scala/spark/util/MutablePair.scala delete mode 100644 core/src/main/scala/spark/util/NextIterator.scala delete mode 100644 core/src/main/scala/spark/util/RateLimitedOutputStream.scala delete mode 100644 core/src/main/scala/spark/util/SerializableBuffer.scala delete mode 100644 core/src/main/scala/spark/util/StatCounter.scala delete mode 100644 core/src/main/scala/spark/util/TimeStampedHashMap.scala delete mode 100644 core/src/main/scala/spark/util/TimeStampedHashSet.scala delete mode 100644 core/src/main/scala/spark/util/Vector.scala create mode 100644 core/src/test/scala/org/apache/spark/AccumulatorSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/BroadcastSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/CheckpointSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/ClosureCleanerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/DistributedSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/DriverSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/FailureSuite.scala 
create mode 100644 core/src/test/scala/org/apache/spark/FileServerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/FileSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/JavaAPISuite.java create mode 100644 core/src/test/scala/org/apache/spark/KryoSerializerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/LocalSparkContext.scala create mode 100644 core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/PairRDDFunctionsSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/PartitioningSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/PipedRDDSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/RDDSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/SharedSparkContext.scala create mode 100644 core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala create mode 100644 core/src/test/scala/org/apache/spark/ShuffleSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/SizeEstimatorSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/SortingSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/ThreadingSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/UnpersistSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/UtilsSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/ZippedPartitionsSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/ui/UISuite.scala create mode 100644 core/src/test/scala/org/apache/spark/util/DistributionSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/util/FakeClock.scala create mode 100644 core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/util/RateLimitedOutputStreamSuite.scala delete mode 100644 core/src/test/scala/spark/AccumulatorSuite.scala delete mode 100644 core/src/test/scala/spark/BroadcastSuite.scala delete mode 100644 core/src/test/scala/spark/CheckpointSuite.scala 
delete mode 100644 core/src/test/scala/spark/ClosureCleanerSuite.scala delete mode 100644 core/src/test/scala/spark/DistributedSuite.scala delete mode 100644 core/src/test/scala/spark/DriverSuite.scala delete mode 100644 core/src/test/scala/spark/FailureSuite.scala delete mode 100644 core/src/test/scala/spark/FileServerSuite.scala delete mode 100644 core/src/test/scala/spark/FileSuite.scala delete mode 100644 core/src/test/scala/spark/JavaAPISuite.java delete mode 100644 core/src/test/scala/spark/KryoSerializerSuite.scala delete mode 100644 core/src/test/scala/spark/LocalSparkContext.scala delete mode 100644 core/src/test/scala/spark/MapOutputTrackerSuite.scala delete mode 100644 core/src/test/scala/spark/PairRDDFunctionsSuite.scala delete mode 100644 core/src/test/scala/spark/PartitionPruningRDDSuite.scala delete mode 100644 core/src/test/scala/spark/PartitioningSuite.scala delete mode 100644 core/src/test/scala/spark/PipedRDDSuite.scala delete mode 100644 core/src/test/scala/spark/RDDSuite.scala delete mode 100644 core/src/test/scala/spark/SharedSparkContext.scala delete mode 100644 core/src/test/scala/spark/ShuffleNettySuite.scala delete mode 100644 core/src/test/scala/spark/ShuffleSuite.scala delete mode 100644 core/src/test/scala/spark/SizeEstimatorSuite.scala delete mode 100644 core/src/test/scala/spark/SortingSuite.scala delete mode 100644 core/src/test/scala/spark/SparkContextInfoSuite.scala delete mode 100644 core/src/test/scala/spark/ThreadingSuite.scala delete mode 100644 core/src/test/scala/spark/UnpersistSuite.scala delete mode 100644 core/src/test/scala/spark/UtilsSuite.scala delete mode 100644 core/src/test/scala/spark/ZippedPartitionsSuite.scala delete mode 100644 core/src/test/scala/spark/io/CompressionCodecSuite.scala delete mode 100644 core/src/test/scala/spark/metrics/MetricsConfigSuite.scala delete mode 100644 core/src/test/scala/spark/metrics/MetricsSystemSuite.scala delete mode 100644 core/src/test/scala/spark/rdd/JdbcRDDSuite.scala delete mode 100644 core/src/test/scala/spark/rdd/ParallelCollectionSplitSuite.scala delete mode 100644 core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala delete mode 100644 core/src/test/scala/spark/scheduler/JobLoggerSuite.scala delete mode 100644 core/src/test/scala/spark/scheduler/SparkListenerSuite.scala delete mode 100644 core/src/test/scala/spark/scheduler/TaskContextSuite.scala delete mode 100644 core/src/test/scala/spark/scheduler/cluster/ClusterSchedulerSuite.scala delete mode 100644 core/src/test/scala/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala delete mode 100644 core/src/test/scala/spark/scheduler/cluster/FakeTask.scala delete mode 100644 core/src/test/scala/spark/scheduler/local/LocalSchedulerSuite.scala delete mode 100644 core/src/test/scala/spark/storage/BlockManagerSuite.scala delete mode 100644 core/src/test/scala/spark/ui/UISuite.scala delete mode 100644 core/src/test/scala/spark/util/DistributionSuite.scala delete mode 100644 core/src/test/scala/spark/util/FakeClock.scala delete mode 100644 core/src/test/scala/spark/util/NextIteratorSuite.scala delete mode 100644 core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala create mode 100644 examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java create mode 100644 examples/src/main/java/org/apache/spark/examples/JavaKMeans.java create mode 100644 examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java create mode 100644 examples/src/main/java/org/apache/spark/examples/JavaPageRank.java create mode 100644 
examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java create mode 100644 examples/src/main/java/org/apache/spark/examples/JavaTC.java create mode 100644 examples/src/main/java/org/apache/spark/examples/JavaWordCount.java create mode 100644 examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java create mode 100644 examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java create mode 100644 examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java create mode 100644 examples/src/main/java/org/apache/spark/streaming/examples/JavaFlumeEventCount.java create mode 100644 examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java create mode 100644 examples/src/main/java/org/apache/spark/streaming/examples/JavaQueueStream.java delete mode 100644 examples/src/main/java/spark/examples/JavaHdfsLR.java delete mode 100644 examples/src/main/java/spark/examples/JavaKMeans.java delete mode 100644 examples/src/main/java/spark/examples/JavaLogQuery.java delete mode 100644 examples/src/main/java/spark/examples/JavaPageRank.java delete mode 100644 examples/src/main/java/spark/examples/JavaSparkPi.java delete mode 100644 examples/src/main/java/spark/examples/JavaTC.java delete mode 100644 examples/src/main/java/spark/examples/JavaWordCount.java delete mode 100644 examples/src/main/java/spark/mllib/examples/JavaALS.java delete mode 100644 examples/src/main/java/spark/mllib/examples/JavaKMeans.java delete mode 100644 examples/src/main/java/spark/mllib/examples/JavaLR.java delete mode 100644 examples/src/main/java/spark/streaming/examples/JavaFlumeEventCount.java delete mode 100644 examples/src/main/java/spark/streaming/examples/JavaNetworkWordCount.java delete mode 100644 examples/src/main/java/spark/streaming/examples/JavaQueueStream.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/LocalALS.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/LocalLR.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/LocalPi.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/LogQuery.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/SparkALS.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/SparkLR.scala create mode 100644 
examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/SparkPi.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/SparkTC.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/FlumeEventCount.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/HdfsWordCount.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/StatefulNetworkWordCount.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdCMS.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdHLL.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/TwitterPopularTags.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala create mode 100644 examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewStream.scala delete mode 100644 examples/src/main/scala/spark/examples/BroadcastTest.scala delete mode 100644 examples/src/main/scala/spark/examples/CassandraTest.scala delete mode 100644 examples/src/main/scala/spark/examples/ExceptionHandlingTest.scala delete mode 100644 examples/src/main/scala/spark/examples/GroupByTest.scala delete mode 100644 examples/src/main/scala/spark/examples/HBaseTest.scala delete mode 100644 examples/src/main/scala/spark/examples/HdfsTest.scala delete mode 100644 examples/src/main/scala/spark/examples/LocalALS.scala delete mode 100644 examples/src/main/scala/spark/examples/LocalFileLR.scala delete mode 100644 examples/src/main/scala/spark/examples/LocalKMeans.scala delete mode 100644 examples/src/main/scala/spark/examples/LocalLR.scala delete mode 100644 examples/src/main/scala/spark/examples/LocalPi.scala delete mode 100644 examples/src/main/scala/spark/examples/LogQuery.scala delete mode 100644 examples/src/main/scala/spark/examples/MultiBroadcastTest.scala delete mode 100644 examples/src/main/scala/spark/examples/SimpleSkewedGroupByTest.scala delete mode 100644 examples/src/main/scala/spark/examples/SkewedGroupByTest.scala delete mode 100644 examples/src/main/scala/spark/examples/SparkALS.scala delete mode 100644 examples/src/main/scala/spark/examples/SparkHdfsLR.scala delete mode 100644 examples/src/main/scala/spark/examples/SparkKMeans.scala delete mode 100644 examples/src/main/scala/spark/examples/SparkLR.scala delete mode 100644 examples/src/main/scala/spark/examples/SparkPageRank.scala delete mode 100644 
examples/src/main/scala/spark/examples/SparkPi.scala delete mode 100644 examples/src/main/scala/spark/examples/SparkTC.scala delete mode 100644 examples/src/main/scala/spark/examples/bagel/PageRankUtils.scala delete mode 100644 examples/src/main/scala/spark/examples/bagel/WikipediaPageRank.scala delete mode 100644 examples/src/main/scala/spark/examples/bagel/WikipediaPageRankStandalone.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/ActorWordCount.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/QueueStream.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/ZeroMQWordCount.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala delete mode 100644 examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala create mode 100644 
mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala delete mode 100644 mllib/src/main/scala/spark/mllib/classification/ClassificationModel.scala delete mode 100644 mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala delete mode 100644 mllib/src/main/scala/spark/mllib/classification/SVM.scala delete mode 100644 mllib/src/main/scala/spark/mllib/clustering/KMeans.scala delete mode 100644 mllib/src/main/scala/spark/mllib/clustering/KMeansModel.scala delete mode 100644 mllib/src/main/scala/spark/mllib/clustering/LocalKMeans.scala delete mode 100644 mllib/src/main/scala/spark/mllib/optimization/Gradient.scala delete mode 100644 mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala delete mode 100644 mllib/src/main/scala/spark/mllib/optimization/Optimizer.scala delete mode 100644 mllib/src/main/scala/spark/mllib/optimization/Updater.scala delete mode 100644 mllib/src/main/scala/spark/mllib/recommendation/ALS.scala delete mode 100644 mllib/src/main/scala/spark/mllib/recommendation/MatrixFactorizationModel.scala delete mode 100644 mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala delete mode 100644 mllib/src/main/scala/spark/mllib/regression/LabeledPoint.scala delete mode 100644 mllib/src/main/scala/spark/mllib/regression/Lasso.scala delete mode 100644 mllib/src/main/scala/spark/mllib/regression/LinearRegression.scala delete mode 100644 mllib/src/main/scala/spark/mllib/regression/RegressionModel.scala delete mode 100644 mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala delete mode 100644 mllib/src/main/scala/spark/mllib/util/DataValidators.scala delete mode 100644 mllib/src/main/scala/spark/mllib/util/KMeansDataGenerator.scala delete mode 100644 mllib/src/main/scala/spark/mllib/util/LinearDataGenerator.scala delete mode 100644 mllib/src/main/scala/spark/mllib/util/LogisticRegressionDataGenerator.scala delete mode 100644 mllib/src/main/scala/spark/mllib/util/MFDataGenerator.scala delete mode 100644 mllib/src/main/scala/spark/mllib/util/MLUtils.scala delete mode 100644 mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala create mode 100644 mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java delete mode 100644 mllib/src/test/java/spark/mllib/classification/JavaLogisticRegressionSuite.java delete mode 100644 mllib/src/test/java/spark/mllib/classification/JavaSVMSuite.java delete mode 100644 
mllib/src/test/java/spark/mllib/clustering/JavaKMeansSuite.java delete mode 100644 mllib/src/test/java/spark/mllib/recommendation/JavaALSSuite.java delete mode 100644 mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java delete mode 100644 mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java delete mode 100644 mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala delete mode 100644 mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala delete mode 100644 mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala delete mode 100644 mllib/src/test/scala/spark/mllib/clustering/KMeansSuite.scala delete mode 100644 mllib/src/test/scala/spark/mllib/recommendation/ALSSuite.scala delete mode 100644 mllib/src/test/scala/spark/mllib/regression/LassoSuite.scala delete mode 100644 mllib/src/test/scala/spark/mllib/regression/LinearRegressionSuite.scala delete mode 100644 mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala create mode 100644 repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala create mode 100644 repl/src/main/scala/org/apache/spark/repl/Main.scala create mode 100644 repl/src/main/scala/org/apache/spark/repl/SparkHelper.scala create mode 100644 repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala create mode 100644 repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala create mode 100644 repl/src/main/scala/org/apache/spark/repl/SparkISettings.scala create mode 100644 repl/src/main/scala/org/apache/spark/repl/SparkImports.scala create mode 100644 repl/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala create mode 100644 repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala create mode 100644 repl/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala delete mode 100644 repl/src/main/scala/spark/repl/ExecutorClassLoader.scala delete mode 100644 repl/src/main/scala/spark/repl/Main.scala delete mode 100644 repl/src/main/scala/spark/repl/SparkHelper.scala delete mode 100644 repl/src/main/scala/spark/repl/SparkILoop.scala delete mode 100644 repl/src/main/scala/spark/repl/SparkIMain.scala delete mode 100644 repl/src/main/scala/spark/repl/SparkISettings.scala delete mode 100644 repl/src/main/scala/spark/repl/SparkImports.scala delete mode 100644 repl/src/main/scala/spark/repl/SparkJLineCompletion.scala delete mode 100644 repl/src/main/scala/spark/repl/SparkJLineReader.scala delete mode 100644 repl/src/main/scala/spark/repl/SparkMemberHandlers.scala create mode 100644 repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala delete mode 100644 repl/src/test/scala/spark/repl/ReplSuite.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/DStream.scala create mode 100644 
streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/Duration.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/Interval.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/Job.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/Time.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/CoGroupedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/FlumeInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/KafkaInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala create mode 
100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/TwitterInputDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/receivers/ZeroMQReceiver.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala delete mode 100644 streaming/src/main/scala/spark/streaming/Checkpoint.scala delete mode 100644 streaming/src/main/scala/spark/streaming/DStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala delete mode 100644 streaming/src/main/scala/spark/streaming/DStreamGraph.scala delete mode 100644 streaming/src/main/scala/spark/streaming/Duration.scala delete mode 100644 streaming/src/main/scala/spark/streaming/Interval.scala delete mode 100644 streaming/src/main/scala/spark/streaming/Job.scala delete mode 100644 streaming/src/main/scala/spark/streaming/JobManager.scala delete mode 100644 streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala delete mode 100644 streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala delete mode 100644 streaming/src/main/scala/spark/streaming/Scheduler.scala delete mode 100644 streaming/src/main/scala/spark/streaming/StreamingContext.scala delete mode 100644 streaming/src/main/scala/spark/streaming/Time.scala delete mode 100644 streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala delete mode 100644 streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala delete mode 100644 
streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/PluggableInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala delete mode 100644 streaming/src/main/scala/spark/streaming/receivers/ActorReceiver.scala delete mode 100644 streaming/src/main/scala/spark/streaming/receivers/ZeroMQReceiver.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/Clock.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/RawTextSender.scala delete mode 100644 streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala create mode 100644 streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java create mode 100644 streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala delete mode 100644 streaming/src/test/java/spark/streaming/JavaAPISuite.java delete mode 100644 streaming/src/test/java/spark/streaming/JavaTestUtils.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala delete mode 100644 streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala delete mode 100644 streaming/src/test/scala/spark/streaming/CheckpointSuite.scala delete mode 100644 streaming/src/test/scala/spark/streaming/FailureSuite.scala delete mode 100644 streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala delete mode 100644 streaming/src/test/scala/spark/streaming/TestSuiteBase.scala delete mode 100644 streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala create mode 100644 tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala 
delete mode 100644 tools/src/main/scala/spark/tools/JavaAPICompletenessChecker.scala create mode 100644 yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala create mode 100644 yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala create mode 100644 yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala create mode 100644 yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala create mode 100644 yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala create mode 100644 yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala create mode 100644 yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala create mode 100644 yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala delete mode 100644 yarn/src/main/scala/spark/deploy/yarn/ApplicationMaster.scala delete mode 100644 yarn/src/main/scala/spark/deploy/yarn/ApplicationMasterArguments.scala delete mode 100644 yarn/src/main/scala/spark/deploy/yarn/Client.scala delete mode 100644 yarn/src/main/scala/spark/deploy/yarn/ClientArguments.scala delete mode 100644 yarn/src/main/scala/spark/deploy/yarn/WorkerRunnable.scala delete mode 100644 yarn/src/main/scala/spark/deploy/yarn/YarnAllocationHandler.scala delete mode 100644 yarn/src/main/scala/spark/deploy/yarn/YarnSparkHadoopUtil.scala delete mode 100644 yarn/src/main/scala/spark/scheduler/cluster/YarnClusterScheduler.scala diff --git a/README.md b/README.md index 2ddfe862a2..c4170650f7 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ Or, for the Python API, the Python shell (`./pyspark`). Spark also comes with several sample programs in the `examples` directory. To run one of them, use `./run-example `. For example: - ./run-example spark.examples.SparkLR local[2] + ./run-example org.apache.spark.examples.SparkLR local[2] will run the Logistic Regression example locally on 2 CPUs. 
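For downstream applications, the practical effect of this rename is a one-line change wherever Spark classes are imported or fully qualified class names are used. A minimal sketch of the migration, using a hypothetical user program that is not part of this patch:

    // Before this patch, user code imported classes from the `spark` root package:
    //   import spark.SparkContext
    //   import spark.SparkContext._

    // After it, the same imports move under `org.apache.spark`:
    import org.apache.spark.SparkContext
    import org.apache.spark.SparkContext._

    object WordCount {
      def main(args: Array[String]) {
        val sc = new SparkContext("local[2]", "WordCount")
        val counts = sc.textFile(args(0))
          .flatMap(_.split(" "))
          .map(word => (word, 1))
          .reduceByKey(_ + _)
        counts.collect().foreach(println)
        sc.stop()
      }
    }

Fully qualified names passed to scripts change the same way, as the `run-example` line above and the daemon-script hunks below show.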
diff --git a/assembly/pom.xml b/assembly/pom.xml index 74990b6361..dc63811b76 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -19,13 +19,13 @@ 4.0.0 - org.spark-project + org.apache.spark spark-parent 0.8.0-SNAPSHOT ../pom.xml - org.spark-project + org.apache.spark spark-assembly Spark Project Assembly http://spark-project.org/ @@ -40,27 +40,27 @@ - org.spark-project + org.apache.spark spark-core ${project.version} - org.spark-project + org.apache.spark spark-bagel ${project.version} - org.spark-project + org.apache.spark spark-mllib ${project.version} - org.spark-project + org.apache.spark spark-repl ${project.version} - org.spark-project + org.apache.spark spark-streaming ${project.version} @@ -121,7 +121,7 @@ hadoop2-yarn - org.spark-project + org.apache.spark spark-yarn ${project.version} diff --git a/assembly/src/main/assembly/assembly.xml b/assembly/src/main/assembly/assembly.xml index 4543b52c93..47d3fa93d0 100644 --- a/assembly/src/main/assembly/assembly.xml +++ b/assembly/src/main/assembly/assembly.xml @@ -30,9 +30,9 @@ - ${project.parent.basedir}/core/src/main/resources/spark/ui/static/ + ${project.parent.basedir}/core/src/main/resources/org/apache/spark/ui/static/ - /ui-resources/spark/ui/static + /ui-resources/org/apache/spark/ui/static **/* @@ -63,10 +63,10 @@ - org.spark-project:*:jar + org.apache.spark:*:jar - org.spark-project:spark-assembly:jar + org.apache.spark:spark-assembly:jar @@ -77,7 +77,7 @@ false org.apache.hadoop:*:jar - org.spark-project:*:jar + org.apache.spark:*:jar diff --git a/bagel/pom.xml b/bagel/pom.xml index cbcf8d1239..9340991377 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -19,13 +19,13 @@ 4.0.0 - org.spark-project + org.apache.spark spark-parent 0.8.0-SNAPSHOT ../pom.xml - org.spark-project + org.apache.spark spark-bagel jar Spark Project Bagel @@ -33,7 +33,7 @@ - org.spark-project + org.apache.spark spark-core ${project.version} diff --git a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala new file mode 100644 index 0000000000..fec8737fcd --- /dev/null +++ b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.bagel + +import org.apache.spark._ +import org.apache.spark.SparkContext._ + +import org.apache.spark.storage.StorageLevel + +object Bagel extends Logging { + val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK + + /** + * Runs a Bagel program. + * @param sc [[org.apache.spark.SparkContext]] to use for the program. + * @param vertices vertices of the graph represented as an RDD of (Key, Vertex) pairs. Often the Key will be + * the vertex id. 
+   * @param messages initial set of messages represented as an RDD of (Key, Message) pairs. Often this will be an
+   * empty array, i.e. sc.parallelize(Array[(K, M)]()).
+   * @param combiner [[org.apache.spark.bagel.Combiner]] combines multiple individual messages to a given vertex into one
+   * message before sending (which often involves network I/O).
+   * @param aggregator [[org.apache.spark.bagel.Aggregator]] performs a reduce across all vertices after each superstep,
+   * and provides the result to each vertex in the next superstep.
+   * @param partitioner [[org.apache.spark.Partitioner]] partitions values by key
+   * @param numPartitions number of partitions across which to split the graph.
+   * Default is the default parallelism of the SparkContext
+   * @param storageLevel [[org.apache.spark.storage.StorageLevel]] to use for caching of intermediate RDDs in each superstep.
+   * Defaults to caching in memory.
+   * @param compute function that takes a Vertex, optional set of (possibly combined) messages to the Vertex,
+   * optional Aggregator and the current superstep,
+   * and returns a set of (Vertex, outgoing Messages) pairs
+   * @tparam K key
+   * @tparam V vertex type
+   * @tparam M message type
+   * @tparam C combiner
+   * @tparam A aggregator
+   * @return an RDD of (K, V) pairs representing the graph after completion of the program
+   */
+  def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest,
+          C: Manifest, A: Manifest](
+    sc: SparkContext,
+    vertices: RDD[(K, V)],
+    messages: RDD[(K, M)],
+    combiner: Combiner[M, C],
+    aggregator: Option[Aggregator[V, A]],
+    partitioner: Partitioner,
+    numPartitions: Int,
+    storageLevel: StorageLevel = DEFAULT_STORAGE_LEVEL
+  )(
+    compute: (V, Option[C], Option[A], Int) => (V, Array[M])
+  ): RDD[(K, V)] = {
+    val splits = if (numPartitions != 0) numPartitions else sc.defaultParallelism
+
+    var superstep = 0
+    var verts = vertices
+    var msgs = messages
+    var noActivity = false
+    do {
+      logInfo("Starting superstep "+superstep+".")
+      val startTime = System.currentTimeMillis
+
+      val aggregated = agg(verts, aggregator)
+      val combinedMsgs = msgs.combineByKey(
+        combiner.createCombiner _, combiner.mergeMsg _, combiner.mergeCombiners _, partitioner)
+      val grouped = combinedMsgs.groupWith(verts)
+      val superstep_ = superstep // Create a read-only copy of superstep for capture in closure
+      val (processed, numMsgs, numActiveVerts) =
+        comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep_), storageLevel)
+
+      val timeTaken = System.currentTimeMillis - startTime
+      logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
+
+      verts = processed.mapValues { case (vert, msgs) => vert }
+      msgs = processed.flatMap {
+        case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m))
+      }
+      superstep += 1
+
+      noActivity = numMsgs == 0 && numActiveVerts == 0
+    } while (!noActivity)
+
+    verts
+  }
+
+  /** Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]] and the default storage level */
+  def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
+    sc: SparkContext,
+    vertices: RDD[(K, V)],
+    messages: RDD[(K, M)],
+    combiner: Combiner[M, C],
+    partitioner: Partitioner,
+    numPartitions: Int
+  )(
+    compute: (V, Option[C], Int) => (V, Array[M])
+  ): RDD[(K, V)] = run(sc, vertices, messages, combiner, partitioner, numPartitions, DEFAULT_STORAGE_LEVEL)(compute)
+
+  /** Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]] */
+  def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
+ sc: SparkContext, + vertices: RDD[(K, V)], + messages: RDD[(K, M)], + combiner: Combiner[M, C], + partitioner: Partitioner, + numPartitions: Int, + storageLevel: StorageLevel + )( + compute: (V, Option[C], Int) => (V, Array[M]) + ): RDD[(K, V)] = { + run[K, V, M, C, Nothing]( + sc, vertices, messages, combiner, None, partitioner, numPartitions, storageLevel)( + addAggregatorArg[K, V, M, C](compute)) + } + + /** + * Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], default [[org.apache.spark.HashPartitioner]] + * and default storage level + */ + def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( + sc: SparkContext, + vertices: RDD[(K, V)], + messages: RDD[(K, M)], + combiner: Combiner[M, C], + numPartitions: Int + )( + compute: (V, Option[C], Int) => (V, Array[M]) + ): RDD[(K, V)] = run(sc, vertices, messages, combiner, numPartitions, DEFAULT_STORAGE_LEVEL)(compute) + + /** Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]] and the default [[org.apache.spark.HashPartitioner]]*/ + def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( + sc: SparkContext, + vertices: RDD[(K, V)], + messages: RDD[(K, M)], + combiner: Combiner[M, C], + numPartitions: Int, + storageLevel: StorageLevel + )( + compute: (V, Option[C], Int) => (V, Array[M]) + ): RDD[(K, V)] = { + val part = new HashPartitioner(numPartitions) + run[K, V, M, C, Nothing]( + sc, vertices, messages, combiner, None, part, numPartitions, storageLevel)( + addAggregatorArg[K, V, M, C](compute)) + } + + /** + * Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], default [[org.apache.spark.HashPartitioner]], + * [[org.apache.spark.bagel.DefaultCombiner]] and the default storage level + */ + def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest]( + sc: SparkContext, + vertices: RDD[(K, V)], + messages: RDD[(K, M)], + numPartitions: Int + )( + compute: (V, Option[Array[M]], Int) => (V, Array[M]) + ): RDD[(K, V)] = run(sc, vertices, messages, numPartitions, DEFAULT_STORAGE_LEVEL)(compute) + + /** + * Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], the default [[org.apache.spark.HashPartitioner]] + * and [[org.apache.spark.bagel.DefaultCombiner]] + */ + def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest]( + sc: SparkContext, + vertices: RDD[(K, V)], + messages: RDD[(K, M)], + numPartitions: Int, + storageLevel: StorageLevel + )( + compute: (V, Option[Array[M]], Int) => (V, Array[M]) + ): RDD[(K, V)] = { + val part = new HashPartitioner(numPartitions) + run[K, V, M, Array[M], Nothing]( + sc, vertices, messages, new DefaultCombiner(), None, part, numPartitions, storageLevel)( + addAggregatorArg[K, V, M, Array[M]](compute)) + } + + /** + * Aggregates the given vertices using the given aggregator, if it + * is specified. + */ + private def agg[K, V <: Vertex, A: Manifest]( + verts: RDD[(K, V)], + aggregator: Option[Aggregator[V, A]] + ): Option[A] = aggregator match { + case Some(a) => + Some(verts.map { + case (id, vert) => a.createAggregator(vert) + }.reduce(a.mergeAggregators(_, _))) + case None => None + } + + /** + * Processes the given vertex-message RDD using the compute + * function. Returns the processed RDD, the number of messages + * created, and the number of active vertices. 
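+   *
+   * Note: the processed RDD is persisted and then forced with a foreach action
+   * before the accumulators are read, because accumulator values are only
+   * well-defined after an action has evaluated the RDD.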
+ */ + private def comp[K: Manifest, V <: Vertex, M <: Message[K], C]( + sc: SparkContext, + grouped: RDD[(K, (Seq[C], Seq[V]))], + compute: (V, Option[C]) => (V, Array[M]), + storageLevel: StorageLevel + ): (RDD[(K, (V, Array[M]))], Int, Int) = { + var numMsgs = sc.accumulator(0) + var numActiveVerts = sc.accumulator(0) + val processed = grouped.flatMapValues { + case (_, vs) if vs.size == 0 => None + case (c, vs) => + val (newVert, newMsgs) = + compute(vs(0), c match { + case Seq(comb) => Some(comb) + case Seq() => None + }) + + numMsgs += newMsgs.size + if (newVert.active) + numActiveVerts += 1 + + Some((newVert, newMsgs)) + }.persist(storageLevel) + + // Force evaluation of processed RDD for accurate performance measurements + processed.foreach(x => {}) + + (processed, numMsgs.value, numActiveVerts.value) + } + + /** + * Converts a compute function that doesn't take an aggregator to + * one that does, so it can be passed to Bagel.run. + */ + private def addAggregatorArg[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C]( + compute: (V, Option[C], Int) => (V, Array[M]) + ): (V, Option[C], Option[Nothing], Int) => (V, Array[M]) = { + (vert: V, msgs: Option[C], aggregated: Option[Nothing], superstep: Int) => + compute(vert, msgs, superstep) + } +} + +trait Combiner[M, C] { + def createCombiner(msg: M): C + def mergeMsg(combiner: C, msg: M): C + def mergeCombiners(a: C, b: C): C +} + +trait Aggregator[V, A] { + def createAggregator(vert: V): A + def mergeAggregators(a: A, b: A): A +} + +/** Default combiner that simply appends messages together (i.e. performs no aggregation) */ +class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializable { + def createCombiner(msg: M): Array[M] = + Array(msg) + def mergeMsg(combiner: Array[M], msg: M): Array[M] = + combiner :+ msg + def mergeCombiners(a: Array[M], b: Array[M]): Array[M] = + a ++ b +} + +/** + * Represents a Bagel vertex. + * + * Subclasses may store state along with each vertex and must + * inherit from java.io.Serializable or scala.Serializable. + */ +trait Vertex { + def active: Boolean +} + +/** + * Represents a Bagel message to a target vertex. + * + * Subclasses may contain a payload to deliver to the target vertex + * and must inherit from java.io.Serializable or scala.Serializable. + */ +trait Message[K] { + def targetId: K +} diff --git a/bagel/src/main/scala/spark/bagel/Bagel.scala b/bagel/src/main/scala/spark/bagel/Bagel.scala deleted file mode 100644 index 80c8d53d2b..0000000000 --- a/bagel/src/main/scala/spark/bagel/Bagel.scala +++ /dev/null @@ -1,294 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.bagel - -import spark._ -import spark.SparkContext._ - -import scala.collection.mutable.ArrayBuffer -import storage.StorageLevel - -object Bagel extends Logging { - val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK - - /** - * Runs a Bagel program. - * @param sc [[spark.SparkContext]] to use for the program. - * @param vertices vertices of the graph represented as an RDD of (Key, Vertex) pairs. Often the Key will be - * the vertex id. - * @param messages initial set of messages represented as an RDD of (Key, Message) pairs. Often this will be an - * empty array, i.e. sc.parallelize(Array[K, Message]()). - * @param combiner [[spark.bagel.Combiner]] combines multiple individual messages to a given vertex into one - * message before sending (which often involves network I/O). - * @param aggregator [[spark.bagel.Aggregator]] performs a reduce across all vertices after each superstep, - * and provides the result to each vertex in the next superstep. - * @param partitioner [[spark.Partitioner]] partitions values by key - * @param numPartitions number of partitions across which to split the graph. - * Default is the default parallelism of the SparkContext - * @param storageLevel [[spark.storage.StorageLevel]] to use for caching of intermediate RDDs in each superstep. - * Defaults to caching in memory. - * @param compute function that takes a Vertex, optional set of (possibly combined) messages to the Vertex, - * optional Aggregator and the current superstep, - * and returns a set of (Vertex, outgoing Messages) pairs - * @tparam K key - * @tparam V vertex type - * @tparam M message type - * @tparam C combiner - * @tparam A aggregator - * @return an RDD of (K, V) pairs representing the graph after completion of the program - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, - C: Manifest, A: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - aggregator: Option[Aggregator[V, A]], - partitioner: Partitioner, - numPartitions: Int, - storageLevel: StorageLevel = DEFAULT_STORAGE_LEVEL - )( - compute: (V, Option[C], Option[A], Int) => (V, Array[M]) - ): RDD[(K, V)] = { - val splits = if (numPartitions != 0) numPartitions else sc.defaultParallelism - - var superstep = 0 - var verts = vertices - var msgs = messages - var noActivity = false - do { - logInfo("Starting superstep "+superstep+".") - val startTime = System.currentTimeMillis - - val aggregated = agg(verts, aggregator) - val combinedMsgs = msgs.combineByKey( - combiner.createCombiner _, combiner.mergeMsg _, combiner.mergeCombiners _, partitioner) - val grouped = combinedMsgs.groupWith(verts) - val superstep_ = superstep // Create a read-only copy of superstep for capture in closure - val (processed, numMsgs, numActiveVerts) = - comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep_), storageLevel) - - val timeTaken = System.currentTimeMillis - startTime - logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000)) - - verts = processed.mapValues { case (vert, msgs) => vert } - msgs = processed.flatMap { - case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m)) - } - superstep += 1 - - noActivity = numMsgs == 0 && numActiveVerts == 0 - } while (!noActivity) - - verts - } - - /** Runs a Bagel program with no [[spark.bagel.Aggregator]] and the default storage level */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: 
RDD[(K, M)], - combiner: Combiner[M, C], - partitioner: Partitioner, - numPartitions: Int - )( - compute: (V, Option[C], Int) => (V, Array[M]) - ): RDD[(K, V)] = run(sc, vertices, messages, combiner, numPartitions, DEFAULT_STORAGE_LEVEL)(compute) - - /** Runs a Bagel program with no [[spark.bagel.Aggregator]] */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - partitioner: Partitioner, - numPartitions: Int, - storageLevel: StorageLevel - )( - compute: (V, Option[C], Int) => (V, Array[M]) - ): RDD[(K, V)] = { - run[K, V, M, C, Nothing]( - sc, vertices, messages, combiner, None, partitioner, numPartitions, storageLevel)( - addAggregatorArg[K, V, M, C](compute)) - } - - /** - * Runs a Bagel program with no [[spark.bagel.Aggregator]], default [[spark.HashPartitioner]] - * and default storage level - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - numPartitions: Int - )( - compute: (V, Option[C], Int) => (V, Array[M]) - ): RDD[(K, V)] = run(sc, vertices, messages, combiner, numPartitions, DEFAULT_STORAGE_LEVEL)(compute) - - /** Runs a Bagel program with no [[spark.bagel.Aggregator]] and the default [[spark.HashPartitioner]]*/ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - numPartitions: Int, - storageLevel: StorageLevel - )( - compute: (V, Option[C], Int) => (V, Array[M]) - ): RDD[(K, V)] = { - val part = new HashPartitioner(numPartitions) - run[K, V, M, C, Nothing]( - sc, vertices, messages, combiner, None, part, numPartitions, storageLevel)( - addAggregatorArg[K, V, M, C](compute)) - } - - /** - * Runs a Bagel program with no [[spark.bagel.Aggregator]], default [[spark.HashPartitioner]], - * [[spark.bagel.DefaultCombiner]] and the default storage level - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - numPartitions: Int - )( - compute: (V, Option[Array[M]], Int) => (V, Array[M]) - ): RDD[(K, V)] = run(sc, vertices, messages, numPartitions, DEFAULT_STORAGE_LEVEL)(compute) - - /** - * Runs a Bagel program with no [[spark.bagel.Aggregator]], the default [[spark.HashPartitioner]] - * and [[spark.bagel.DefaultCombiner]] - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - numPartitions: Int, - storageLevel: StorageLevel - )( - compute: (V, Option[Array[M]], Int) => (V, Array[M]) - ): RDD[(K, V)] = { - val part = new HashPartitioner(numPartitions) - run[K, V, M, Array[M], Nothing]( - sc, vertices, messages, new DefaultCombiner(), None, part, numPartitions, storageLevel)( - addAggregatorArg[K, V, M, Array[M]](compute)) - } - - /** - * Aggregates the given vertices using the given aggregator, if it - * is specified. 
- */ - private def agg[K, V <: Vertex, A: Manifest]( - verts: RDD[(K, V)], - aggregator: Option[Aggregator[V, A]] - ): Option[A] = aggregator match { - case Some(a) => - Some(verts.map { - case (id, vert) => a.createAggregator(vert) - }.reduce(a.mergeAggregators(_, _))) - case None => None - } - - /** - * Processes the given vertex-message RDD using the compute - * function. Returns the processed RDD, the number of messages - * created, and the number of active vertices. - */ - private def comp[K: Manifest, V <: Vertex, M <: Message[K], C]( - sc: SparkContext, - grouped: RDD[(K, (Seq[C], Seq[V]))], - compute: (V, Option[C]) => (V, Array[M]), - storageLevel: StorageLevel - ): (RDD[(K, (V, Array[M]))], Int, Int) = { - var numMsgs = sc.accumulator(0) - var numActiveVerts = sc.accumulator(0) - val processed = grouped.flatMapValues { - case (_, vs) if vs.size == 0 => None - case (c, vs) => - val (newVert, newMsgs) = - compute(vs(0), c match { - case Seq(comb) => Some(comb) - case Seq() => None - }) - - numMsgs += newMsgs.size - if (newVert.active) - numActiveVerts += 1 - - Some((newVert, newMsgs)) - }.persist(storageLevel) - - // Force evaluation of processed RDD for accurate performance measurements - processed.foreach(x => {}) - - (processed, numMsgs.value, numActiveVerts.value) - } - - /** - * Converts a compute function that doesn't take an aggregator to - * one that does, so it can be passed to Bagel.run. - */ - private def addAggregatorArg[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C]( - compute: (V, Option[C], Int) => (V, Array[M]) - ): (V, Option[C], Option[Nothing], Int) => (V, Array[M]) = { - (vert: V, msgs: Option[C], aggregated: Option[Nothing], superstep: Int) => - compute(vert, msgs, superstep) - } -} - -trait Combiner[M, C] { - def createCombiner(msg: M): C - def mergeMsg(combiner: C, msg: M): C - def mergeCombiners(a: C, b: C): C -} - -trait Aggregator[V, A] { - def createAggregator(vert: V): A - def mergeAggregators(a: A, b: A): A -} - -/** Default combiner that simply appends messages together (i.e. performs no aggregation) */ -class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializable { - def createCombiner(msg: M): Array[M] = - Array(msg) - def mergeMsg(combiner: Array[M], msg: M): Array[M] = - combiner :+ msg - def mergeCombiners(a: Array[M], b: Array[M]): Array[M] = - a ++ b -} - -/** - * Represents a Bagel vertex. - * - * Subclasses may store state along with each vertex and must - * inherit from java.io.Serializable or scala.Serializable. - */ -trait Vertex { - def active: Boolean -} - -/** - * Represents a Bagel message to a target vertex. - * - * Subclasses may contain a payload to deliver to the target vertex - * and must inherit from java.io.Serializable or scala.Serializable. - */ -trait Message[K] { - def targetId: K -} diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala deleted file mode 100644 index ef2d57fbd0..0000000000 --- a/bagel/src/test/scala/bagel/BagelSuite.scala +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.bagel - -import org.scalatest.{FunSuite, Assertions, BeforeAndAfter} -import org.scalatest.concurrent.Timeouts -import org.scalatest.time.SpanSugar._ - -import scala.collection.mutable.ArrayBuffer - -import spark._ -import storage.StorageLevel - -class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable -class TestMessage(val targetId: String) extends Message[String] with Serializable - -class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeouts { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") - } - - test("halting by voting") { - sc = new SparkContext("local", "test") - val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 5 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - - test("halting by message silence") { - sc = new SparkContext("local", "test") - val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(false, 0)))) - val msgs = sc.parallelize(Array("a" -> new TestMessage("a"))) - val numSupersteps = 5 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - val msgsOut = - msgs match { - case Some(ms) if (superstep < numSupersteps - 1) => - ms - case _ => - Array[TestMessage]() - } - (new TestVertex(self.active, self.age + 1), msgsOut) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - - test("large number of iterations") { - // This tests whether jobs with a large number of iterations finish in a reasonable time, - // because non-memoized recursion in RDD or DAGScheduler used to cause them to hang - failAfter(10 seconds) { - sc = new SparkContext("local", "test") - val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 50 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - } - - test("using non-default persistence level") { - failAfter(10 seconds) { - sc = new SparkContext("local", "test") - val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 50 - val 
result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - } -} diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala new file mode 100644 index 0000000000..7b954a4775 --- /dev/null +++ b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.bagel + +import org.scalatest.{BeforeAndAfter, FunSuite, Assertions} +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ + +import org.apache.spark._ +import org.apache.spark.storage.StorageLevel + +class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable +class TestMessage(val targetId: String) extends Message[String] with Serializable + +class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeouts { + + var sc: SparkContext = _ + + after { + if (sc != null) { + sc.stop() + sc = null + } + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + } + + test("halting by voting") { + sc = new SparkContext("local", "test") + val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0)))) + val msgs = sc.parallelize(Array[(String, TestMessage)]()) + val numSupersteps = 5 + val result = + Bagel.run(sc, verts, msgs, sc.defaultParallelism) { + (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => + (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) + } + for ((id, vert) <- result.collect) { + assert(vert.age === numSupersteps) + } + } + + test("halting by message silence") { + sc = new SparkContext("local", "test") + val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(false, 0)))) + val msgs = sc.parallelize(Array("a" -> new TestMessage("a"))) + val numSupersteps = 5 + val result = + Bagel.run(sc, verts, msgs, sc.defaultParallelism) { + (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => + val msgsOut = + msgs match { + case Some(ms) if (superstep < numSupersteps - 1) => + ms + case _ => + Array[TestMessage]() + } + (new TestVertex(self.active, self.age + 1), msgsOut) + } + for ((id, vert) <- result.collect) { + assert(vert.age === numSupersteps) + } + } + + test("large number of iterations") { + // This tests whether jobs with 
a large number of iterations finish in a reasonable time, + // because non-memoized recursion in RDD or DAGScheduler used to cause them to hang + failAfter(10 seconds) { + sc = new SparkContext("local", "test") + val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) + val msgs = sc.parallelize(Array[(String, TestMessage)]()) + val numSupersteps = 50 + val result = + Bagel.run(sc, verts, msgs, sc.defaultParallelism) { + (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => + (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) + } + for ((id, vert) <- result.collect) { + assert(vert.age === numSupersteps) + } + } + } + + test("using non-default persistence level") { + failAfter(10 seconds) { + sc = new SparkContext("local", "test") + val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) + val msgs = sc.parallelize(Array[(String, TestMessage)]()) + val numSupersteps = 50 + val result = + Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) { + (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => + (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) + } + for ((id, vert) <- result.collect) { + assert(vert.age === numSupersteps) + } + } + } +} diff --git a/bin/start-master.sh b/bin/start-master.sh index 2288fb19d7..648c7ae75f 100755 --- a/bin/start-master.sh +++ b/bin/start-master.sh @@ -49,4 +49,4 @@ if [ "$SPARK_PUBLIC_DNS" = "" ]; then fi fi -"$bin"/spark-daemon.sh start spark.deploy.master.Master 1 --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT +"$bin"/spark-daemon.sh start org.apache.spark.deploy.master.Master 1 --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT diff --git a/bin/start-slave.sh b/bin/start-slave.sh index d6db16882d..4eefa20944 100755 --- a/bin/start-slave.sh +++ b/bin/start-slave.sh @@ -32,4 +32,4 @@ if [ "$SPARK_PUBLIC_DNS" = "" ]; then fi fi -"$bin"/spark-daemon.sh start spark.deploy.worker.Worker "$@" +"$bin"/spark-daemon.sh start org.apache.spark.deploy.worker.Worker "$@" diff --git a/bin/stop-master.sh b/bin/stop-master.sh index 31a610bf9d..310e33bedc 100755 --- a/bin/stop-master.sh +++ b/bin/stop-master.sh @@ -24,4 +24,4 @@ bin=`cd "$bin"; pwd` . 
"$bin/spark-config.sh" -"$bin"/spark-daemon.sh stop spark.deploy.master.Master 1 +"$bin"/spark-daemon.sh stop org.apache.spark.deploy.master.Master 1 diff --git a/bin/stop-slaves.sh b/bin/stop-slaves.sh index 8e056f23d4..03e416a132 100755 --- a/bin/stop-slaves.sh +++ b/bin/stop-slaves.sh @@ -29,9 +29,9 @@ if [ -f "${SPARK_CONF_DIR}/spark-env.sh" ]; then fi if [ "$SPARK_WORKER_INSTANCES" = "" ]; then - "$bin"/spark-daemons.sh stop spark.deploy.worker.Worker 1 + "$bin"/spark-daemons.sh stop org.apache.spark.deploy.worker.Worker 1 else for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do - "$bin"/spark-daemons.sh stop spark.deploy.worker.Worker $(( $i + 1 )) + "$bin"/spark-daemons.sh stop org.apache.spark.deploy.worker.Worker $(( $i + 1 )) done fi diff --git a/core/pom.xml b/core/pom.xml index 53696367e9..c803217f96 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -19,13 +19,13 @@ 4.0.0 - org.spark-project + org.apache.spark spark-parent 0.8.0-SNAPSHOT ../pom.xml - org.spark-project + org.apache.spark spark-core jar Spark Project Core diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClient.java b/core/src/main/java/org/apache/spark/network/netty/FileClient.java new file mode 100644 index 0000000000..20a7a3aa8c --- /dev/null +++ b/core/src/main/java/org/apache/spark/network/netty/FileClient.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty; + +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelOption; +import io.netty.channel.oio.OioEventLoopGroup; +import io.netty.channel.socket.oio.OioSocketChannel; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class FileClient { + + private Logger LOG = LoggerFactory.getLogger(this.getClass().getName()); + private FileClientHandler handler = null; + private Channel channel = null; + private Bootstrap bootstrap = null; + private int connectTimeout = 60*1000; // 1 min + + public FileClient(FileClientHandler handler, int connectTimeout) { + this.handler = handler; + this.connectTimeout = connectTimeout; + } + + public void init() { + bootstrap = new Bootstrap(); + bootstrap.group(new OioEventLoopGroup()) + .channel(OioSocketChannel.class) + .option(ChannelOption.SO_KEEPALIVE, true) + .option(ChannelOption.TCP_NODELAY, true) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout) + .handler(new FileClientChannelInitializer(handler)); + } + + public void connect(String host, int port) { + try { + // Start the connection attempt. 
+      channel = bootstrap.connect(host, port).sync().channel();
+      // The channel is closed by FileClientHandler once the requested file has
+      // been fully received; callers may block on waitForClose() until then.
+    } catch (InterruptedException e) {
+      close();
+    }
+  }
+
+  public void waitForClose() {
+    try {
+      channel.closeFuture().sync();
+    } catch (InterruptedException e) {
+      LOG.warn("FileClient interrupted", e);
+    }
+  }
+
+  public void sendRequest(String file) {
+    // Precondition: connect() has succeeded, so channel is non-null.
+    // The protocol is line-based: the requested path is terminated by "\r\n".
+    channel.write(file + "\r\n");
+  }
+
+  public void close() {
+    if (channel != null) {
+      channel.close();
+      channel = null;
+    }
+    if (bootstrap != null) {
+      bootstrap.shutdown();
+      bootstrap = null;
+    }
+  }
+}
diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java b/core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java
new file mode 100644
index 0000000000..65ee15d63b
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty;
+
+import io.netty.buffer.BufType;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.codec.string.StringEncoder;
+
+
+class FileClientChannelInitializer extends ChannelInitializer<SocketChannel> {
+
+  private FileClientHandler fhandler;
+
+  public FileClientChannelInitializer(FileClientHandler handler) {
+    fhandler = handler;
+  }
+
+  @Override
+  public void initChannel(SocketChannel channel) {
+    // file no more than 2G
+    channel.pipeline()
+      .addLast("encoder", new StringEncoder(BufType.BYTE))
+      .addLast("handler", fhandler);
+  }
+}
diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java
new file mode 100644
index 0000000000..c4aa2669e0
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundByteHandlerAdapter; + + +abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { + + private FileHeader currentHeader = null; + + private volatile boolean handlerCalled = false; + + public boolean isComplete() { + return handlerCalled; + } + + public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header); + public abstract void handleError(String blockId); + + @Override + public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) { + // Use direct buffer if possible. + return ctx.alloc().ioBuffer(); + } + + @Override + public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) { + // get header + if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) { + currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE())); + } + // get file + if(in.readableBytes() >= currentHeader.fileLen()) { + handle(ctx, in, currentHeader); + handlerCalled = true; + currentHeader = null; + ctx.close(); + } + } + +} + diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServer.java b/core/src/main/java/org/apache/spark/network/netty/FileServer.java new file mode 100644 index 0000000000..666432474d --- /dev/null +++ b/core/src/main/java/org/apache/spark/network/netty/FileServer.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty; + +import java.net.InetSocketAddress; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelOption; +import io.netty.channel.oio.OioEventLoopGroup; +import io.netty.channel.socket.oio.OioServerSocketChannel; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * Server that accept the path of a file an echo back its content. + */ +class FileServer { + + private Logger LOG = LoggerFactory.getLogger(this.getClass().getName()); + + private ServerBootstrap bootstrap = null; + private ChannelFuture channelFuture = null; + private int port = 0; + private Thread blockingThread = null; + + public FileServer(PathResolver pResolver, int port) { + InetSocketAddress addr = new InetSocketAddress(port); + + // Configure the server. 
+ bootstrap = new ServerBootstrap(); + bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup()) + .channel(OioServerSocketChannel.class) + .option(ChannelOption.SO_BACKLOG, 100) + .option(ChannelOption.SO_RCVBUF, 1500) + .childHandler(new FileServerChannelInitializer(pResolver)); + // Start the server. + channelFuture = bootstrap.bind(addr); + try { + // Get the address we bound to. + InetSocketAddress boundAddress = + ((InetSocketAddress) channelFuture.sync().channel().localAddress()); + this.port = boundAddress.getPort(); + } catch (InterruptedException ie) { + this.port = 0; + } + } + + /** + * Start the file server asynchronously in a new thread. + */ + public void start() { + blockingThread = new Thread() { + public void run() { + try { + channelFuture.channel().closeFuture().sync(); + LOG.info("FileServer exiting"); + } catch (InterruptedException e) { + LOG.error("File server start got interrupted", e); + } + // NOTE: bootstrap is shutdown in stop() + } + }; + blockingThread.setDaemon(true); + blockingThread.start(); + } + + public int getPort() { + return port; + } + + public void stop() { + // Close the bound channel. + if (channelFuture != null) { + channelFuture.channel().close(); + channelFuture = null; + } + // Shutdown bootstrap. + if (bootstrap != null) { + bootstrap.shutdown(); + bootstrap = null; + } + // TODO: Shutdown all accepted channels as well ? + } +} diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java b/core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java new file mode 100644 index 0000000000..833af1632d --- /dev/null +++ b/core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.netty; + +import io.netty.channel.ChannelInitializer; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.codec.DelimiterBasedFrameDecoder; +import io.netty.handler.codec.Delimiters; +import io.netty.handler.codec.string.StringDecoder; + + +class FileServerChannelInitializer extends ChannelInitializer { + + PathResolver pResolver; + + public FileServerChannelInitializer(PathResolver pResolver) { + this.pResolver = pResolver; + } + + @Override + public void initChannel(SocketChannel channel) { + channel.pipeline() + .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter())) + .addLast("strDecoder", new StringDecoder()) + .addLast("handler", new FileServerHandler(pResolver)); + } +} diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java new file mode 100644 index 0000000000..d3d57a0255 --- /dev/null +++ b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.netty; + +import java.io.File; +import java.io.FileInputStream; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundMessageHandlerAdapter; +import io.netty.channel.DefaultFileRegion; + + +class FileServerHandler extends ChannelInboundMessageHandlerAdapter { + + PathResolver pResolver; + + public FileServerHandler(PathResolver pResolver){ + this.pResolver = pResolver; + } + + @Override + public void messageReceived(ChannelHandlerContext ctx, String blockId) { + String path = pResolver.getAbsolutePath(blockId); + // if getFilePath returns null, close the channel + if (path == null) { + //ctx.close(); + return; + } + File file = new File(path); + if (file.exists()) { + if (!file.isFile()) { + //logger.info("Not a file : " + file.getAbsolutePath()); + ctx.write(new FileHeader(0, blockId).buffer()); + ctx.flush(); + return; + } + long length = file.length(); + if (length > Integer.MAX_VALUE || length <= 0) { + //logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length); + ctx.write(new FileHeader(0, blockId).buffer()); + ctx.flush(); + return; + } + int len = new Long(length).intValue(); + //logger.info("Sending block "+blockId+" filelen = "+len); + //logger.info("header = "+ (new FileHeader(len, blockId)).buffer()); + ctx.write((new FileHeader(len, blockId)).buffer()); + try { + ctx.sendFile(new DefaultFileRegion(new FileInputStream(file) + .getChannel(), 0, file.length())); + } catch (Exception e) { + //logger.warning("Exception when sending file : " + file.getAbsolutePath()); + e.printStackTrace(); + } + } else { + //logger.warning("File not found: " + file.getAbsolutePath()); + ctx.write(new FileHeader(0, blockId).buffer()); + } + ctx.flush(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + cause.printStackTrace(); + ctx.close(); + } +} diff --git a/core/src/main/java/org/apache/spark/network/netty/PathResolver.java b/core/src/main/java/org/apache/spark/network/netty/PathResolver.java new file mode 100755 index 0000000000..94c034cad0 --- /dev/null +++ b/core/src/main/java/org/apache/spark/network/netty/PathResolver.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.netty; + + +public interface PathResolver { + /** + * Get the absolute path of the file + * + * @param fileId + * @return the absolute path of file + */ + public String getAbsolutePath(String fileId); +} diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java deleted file mode 100644 index 0625a6d502..0000000000 --- a/core/src/main/java/spark/network/netty/FileClient.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network.netty; - -import io.netty.bootstrap.Bootstrap; -import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelOption; -import io.netty.channel.oio.OioEventLoopGroup; -import io.netty.channel.socket.oio.OioSocketChannel; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -class FileClient { - - private Logger LOG = LoggerFactory.getLogger(this.getClass().getName()); - private FileClientHandler handler = null; - private Channel channel = null; - private Bootstrap bootstrap = null; - private int connectTimeout = 60*1000; // 1 min - - public FileClient(FileClientHandler handler, int connectTimeout) { - this.handler = handler; - this.connectTimeout = connectTimeout; - } - - public void init() { - bootstrap = new Bootstrap(); - bootstrap.group(new OioEventLoopGroup()) - .channel(OioSocketChannel.class) - .option(ChannelOption.SO_KEEPALIVE, true) - .option(ChannelOption.TCP_NODELAY, true) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout) - .handler(new FileClientChannelInitializer(handler)); - } - - public void connect(String host, int port) { - try { - // Start the connection attempt. 
- channel = bootstrap.connect(host, port).sync().channel(); - // ChannelFuture cf = channel.closeFuture(); - //cf.addListener(new ChannelCloseListener(this)); - } catch (InterruptedException e) { - close(); - } - } - - public void waitForClose() { - try { - channel.closeFuture().sync(); - } catch (InterruptedException e) { - LOG.warn("FileClient interrupted", e); - } - } - - public void sendRequest(String file) { - //assert(file == null); - //assert(channel == null); - channel.write(file + "\r\n"); - } - - public void close() { - if(channel != null) { - channel.close(); - channel = null; - } - if ( bootstrap!=null) { - bootstrap.shutdown(); - bootstrap = null; - } - } -} diff --git a/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java deleted file mode 100644 index 05ad4b61d7..0000000000 --- a/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network.netty; - -import io.netty.buffer.BufType; -import io.netty.channel.ChannelInitializer; -import io.netty.channel.socket.SocketChannel; -import io.netty.handler.codec.string.StringEncoder; - - -class FileClientChannelInitializer extends ChannelInitializer { - - private FileClientHandler fhandler; - - public FileClientChannelInitializer(FileClientHandler handler) { - fhandler = handler; - } - - @Override - public void initChannel(SocketChannel channel) { - // file no more than 2G - channel.pipeline() - .addLast("encoder", new StringEncoder(BufType.BYTE)) - .addLast("handler", fhandler); - } -} diff --git a/core/src/main/java/spark/network/netty/FileClientHandler.java b/core/src/main/java/spark/network/netty/FileClientHandler.java deleted file mode 100644 index e8cd9801f6..0000000000 --- a/core/src/main/java/spark/network/netty/FileClientHandler.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.network.netty; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundByteHandlerAdapter; - - -abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { - - private FileHeader currentHeader = null; - - private volatile boolean handlerCalled = false; - - public boolean isComplete() { - return handlerCalled; - } - - public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header); - public abstract void handleError(String blockId); - - @Override - public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) { - // Use direct buffer if possible. - return ctx.alloc().ioBuffer(); - } - - @Override - public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) { - // get header - if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) { - currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE())); - } - // get file - if(in.readableBytes() >= currentHeader.fileLen()) { - handle(ctx, in, currentHeader); - handlerCalled = true; - currentHeader = null; - ctx.close(); - } - } - -} - diff --git a/core/src/main/java/spark/network/netty/FileServer.java b/core/src/main/java/spark/network/netty/FileServer.java deleted file mode 100644 index 9f009a61d5..0000000000 --- a/core/src/main/java/spark/network/netty/FileServer.java +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network.netty; - -import java.net.InetSocketAddress; - -import io.netty.bootstrap.ServerBootstrap; -import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelOption; -import io.netty.channel.oio.OioEventLoopGroup; -import io.netty.channel.socket.oio.OioServerSocketChannel; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - - -/** - * Server that accept the path of a file an echo back its content. - */ -class FileServer { - - private Logger LOG = LoggerFactory.getLogger(this.getClass().getName()); - - private ServerBootstrap bootstrap = null; - private ChannelFuture channelFuture = null; - private int port = 0; - private Thread blockingThread = null; - - public FileServer(PathResolver pResolver, int port) { - InetSocketAddress addr = new InetSocketAddress(port); - - // Configure the server. - bootstrap = new ServerBootstrap(); - bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup()) - .channel(OioServerSocketChannel.class) - .option(ChannelOption.SO_BACKLOG, 100) - .option(ChannelOption.SO_RCVBUF, 1500) - .childHandler(new FileServerChannelInitializer(pResolver)); - // Start the server. - channelFuture = bootstrap.bind(addr); - try { - // Get the address we bound to. 
- InetSocketAddress boundAddress = - ((InetSocketAddress) channelFuture.sync().channel().localAddress()); - this.port = boundAddress.getPort(); - } catch (InterruptedException ie) { - this.port = 0; - } - } - - /** - * Start the file server asynchronously in a new thread. - */ - public void start() { - blockingThread = new Thread() { - public void run() { - try { - channelFuture.channel().closeFuture().sync(); - LOG.info("FileServer exiting"); - } catch (InterruptedException e) { - LOG.error("File server start got interrupted", e); - } - // NOTE: bootstrap is shutdown in stop() - } - }; - blockingThread.setDaemon(true); - blockingThread.start(); - } - - public int getPort() { - return port; - } - - public void stop() { - // Close the bound channel. - if (channelFuture != null) { - channelFuture.channel().close(); - channelFuture = null; - } - // Shutdown bootstrap. - if (bootstrap != null) { - bootstrap.shutdown(); - bootstrap = null; - } - // TODO: Shutdown all accepted channels as well ? - } -} diff --git a/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java deleted file mode 100644 index 50c57a81a3..0000000000 --- a/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network.netty; - -import io.netty.channel.ChannelInitializer; -import io.netty.channel.socket.SocketChannel; -import io.netty.handler.codec.DelimiterBasedFrameDecoder; -import io.netty.handler.codec.Delimiters; -import io.netty.handler.codec.string.StringDecoder; - - -class FileServerChannelInitializer extends ChannelInitializer { - - PathResolver pResolver; - - public FileServerChannelInitializer(PathResolver pResolver) { - this.pResolver = pResolver; - } - - @Override - public void initChannel(SocketChannel channel) { - channel.pipeline() - .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter())) - .addLast("strDecoder", new StringDecoder()) - .addLast("handler", new FileServerHandler(pResolver)); - } -} diff --git a/core/src/main/java/spark/network/netty/FileServerHandler.java b/core/src/main/java/spark/network/netty/FileServerHandler.java deleted file mode 100644 index 176ba8da49..0000000000 --- a/core/src/main/java/spark/network/netty/FileServerHandler.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network.netty; - -import java.io.File; -import java.io.FileInputStream; - -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundMessageHandlerAdapter; -import io.netty.channel.DefaultFileRegion; - - -class FileServerHandler extends ChannelInboundMessageHandlerAdapter { - - PathResolver pResolver; - - public FileServerHandler(PathResolver pResolver){ - this.pResolver = pResolver; - } - - @Override - public void messageReceived(ChannelHandlerContext ctx, String blockId) { - String path = pResolver.getAbsolutePath(blockId); - // if getFilePath returns null, close the channel - if (path == null) { - //ctx.close(); - return; - } - File file = new File(path); - if (file.exists()) { - if (!file.isFile()) { - //logger.info("Not a file : " + file.getAbsolutePath()); - ctx.write(new FileHeader(0, blockId).buffer()); - ctx.flush(); - return; - } - long length = file.length(); - if (length > Integer.MAX_VALUE || length <= 0) { - //logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length); - ctx.write(new FileHeader(0, blockId).buffer()); - ctx.flush(); - return; - } - int len = new Long(length).intValue(); - //logger.info("Sending block "+blockId+" filelen = "+len); - //logger.info("header = "+ (new FileHeader(len, blockId)).buffer()); - ctx.write((new FileHeader(len, blockId)).buffer()); - try { - ctx.sendFile(new DefaultFileRegion(new FileInputStream(file) - .getChannel(), 0, file.length())); - } catch (Exception e) { - //logger.warning("Exception when sending file : " + file.getAbsolutePath()); - e.printStackTrace(); - } - } else { - //logger.warning("File not found: " + file.getAbsolutePath()); - ctx.write(new FileHeader(0, blockId).buffer()); - } - ctx.flush(); - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - cause.printStackTrace(); - ctx.close(); - } -} diff --git a/core/src/main/java/spark/network/netty/PathResolver.java b/core/src/main/java/spark/network/netty/PathResolver.java deleted file mode 100755 index f446c55b19..0000000000 --- a/core/src/main/java/spark/network/netty/PathResolver.java +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network.netty; - - -public interface PathResolver { - /** - * Get the absolute path of the file - * - * @param fileId - * @return the absolute path of file - */ - public String getAbsolutePath(String fileId); -} diff --git a/core/src/main/resources/org/apache/spark/ui/static/bootstrap.min.css b/core/src/main/resources/org/apache/spark/ui/static/bootstrap.min.css new file mode 100755 index 0000000000..13cef3d6f1 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/bootstrap.min.css @@ -0,0 +1,874 @@ +/*! + * Bootstrap v2.3.2 + * + * Copyright 2013 Twitter, Inc + * Licensed under the Apache License v2.0 + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Designed and built with all the love in the world @twitter by @mdo and @fat. + */ +.clearfix{*zoom:1;}.clearfix:before,.clearfix:after{display:table;content:"";line-height:0;} +.clearfix:after{clear:both;} +.hide-text{font:0/0 a;color:transparent;text-shadow:none;background-color:transparent;border:0;} +.input-block-level{display:block;width:100%;min-height:30px;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;} +article,aside,details,figcaption,figure,footer,header,hgroup,nav,section{display:block;} +audio,canvas,video{display:inline-block;*display:inline;*zoom:1;} +audio:not([controls]){display:none;} +html{font-size:100%;-webkit-text-size-adjust:100%;-ms-text-size-adjust:100%;} +a:focus{outline:thin dotted #333;outline:5px auto -webkit-focus-ring-color;outline-offset:-2px;} +a:hover,a:active{outline:0;} +sub,sup{position:relative;font-size:75%;line-height:0;vertical-align:baseline;} +sup{top:-0.5em;} +sub{bottom:-0.25em;} +img{max-width:100%;width:auto\9;height:auto;vertical-align:middle;border:0;-ms-interpolation-mode:bicubic;} +#map_canvas img,.google-maps img{max-width:none;} +button,input,select,textarea{margin:0;font-size:100%;vertical-align:middle;} +button,input{*overflow:visible;line-height:normal;} +button::-moz-focus-inner,input::-moz-focus-inner{padding:0;border:0;} +button,html input[type="button"],input[type="reset"],input[type="submit"]{-webkit-appearance:button;cursor:pointer;} +label,select,button,input[type="button"],input[type="reset"],input[type="submit"],input[type="radio"],input[type="checkbox"]{cursor:pointer;} +input[type="search"]{-webkit-box-sizing:content-box;-moz-box-sizing:content-box;box-sizing:content-box;-webkit-appearance:textfield;} +input[type="search"]::-webkit-search-decoration,input[type="search"]::-webkit-search-cancel-button{-webkit-appearance:none;} +textarea{overflow:auto;vertical-align:top;} +@media print{*{text-shadow:none !important;color:#000 !important;background:transparent !important;box-shadow:none !important;} a,a:visited{text-decoration:underline;} a[href]:after{content:" (" attr(href) ")";} abbr[title]:after{content:" (" attr(title) ")";} .ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:"";} pre,blockquote{border:1px solid #999;page-break-inside:avoid;} thead{display:table-header-group;} tr,img{page-break-inside:avoid;} img{max-width:100% !important;} @page {margin:0.5cm;}p,h2,h3{orphans:3;widows:3;} h2,h3{page-break-after:avoid;}}body{margin:0;font-family:"Helvetica Neue",Helvetica,Arial,sans-serif;font-size:14px;line-height:20px;color:#333333;background-color:#ffffff;} +a{color:#0088cc;text-decoration:none;} +a:hover,a:focus{color:#005580;text-decoration:underline;} 
+.img-rounded{-webkit-border-radius:6px;-moz-border-radius:6px;border-radius:6px;} +.img-polaroid{padding:4px;background-color:#fff;border:1px solid #ccc;border:1px solid rgba(0, 0, 0, 0.2);-webkit-box-shadow:0 1px 3px rgba(0, 0, 0, 0.1);-moz-box-shadow:0 1px 3px rgba(0, 0, 0, 0.1);box-shadow:0 1px 3px rgba(0, 0, 0, 0.1);} +.img-circle{-webkit-border-radius:500px;-moz-border-radius:500px;border-radius:500px;} +.row{margin-left:-20px;*zoom:1;}.row:before,.row:after{display:table;content:"";line-height:0;} +.row:after{clear:both;} +[class*="span"]{float:left;min-height:1px;margin-left:20px;} +.container,.navbar-static-top .container,.navbar-fixed-top .container,.navbar-fixed-bottom .container{width:940px;} +.span12{width:940px;} +.span11{width:860px;} +.span10{width:780px;} +.span9{width:700px;} +.span8{width:620px;} +.span7{width:540px;} +.span6{width:460px;} +.span5{width:380px;} +.span4{width:300px;} +.span3{width:220px;} +.span2{width:140px;} +.span1{width:60px;} +.offset12{margin-left:980px;} +.offset11{margin-left:900px;} +.offset10{margin-left:820px;} +.offset9{margin-left:740px;} +.offset8{margin-left:660px;} +.offset7{margin-left:580px;} +.offset6{margin-left:500px;} +.offset5{margin-left:420px;} +.offset4{margin-left:340px;} +.offset3{margin-left:260px;} +.offset2{margin-left:180px;} +.offset1{margin-left:100px;} +.row-fluid{width:100%;*zoom:1;}.row-fluid:before,.row-fluid:after{display:table;content:"";line-height:0;} +.row-fluid:after{clear:both;} +.row-fluid [class*="span"]{display:block;width:100%;min-height:30px;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;float:left;margin-left:2.127659574468085%;*margin-left:2.074468085106383%;} +.row-fluid [class*="span"]:first-child{margin-left:0;} +.row-fluid .controls-row [class*="span"]+[class*="span"]{margin-left:2.127659574468085%;} +.row-fluid .span12{width:100%;*width:99.94680851063829%;} +.row-fluid .span11{width:91.48936170212765%;*width:91.43617021276594%;} +.row-fluid .span10{width:82.97872340425532%;*width:82.92553191489361%;} +.row-fluid .span9{width:74.46808510638297%;*width:74.41489361702126%;} +.row-fluid .span8{width:65.95744680851064%;*width:65.90425531914893%;} +.row-fluid .span7{width:57.44680851063829%;*width:57.39361702127659%;} +.row-fluid .span6{width:48.93617021276595%;*width:48.88297872340425%;} +.row-fluid .span5{width:40.42553191489362%;*width:40.37234042553192%;} +.row-fluid .span4{width:31.914893617021278%;*width:31.861702127659576%;} +.row-fluid .span3{width:23.404255319148934%;*width:23.351063829787233%;} +.row-fluid .span2{width:14.893617021276595%;*width:14.840425531914894%;} +.row-fluid .span1{width:6.382978723404255%;*width:6.329787234042553%;} +.row-fluid .offset12{margin-left:104.25531914893617%;*margin-left:104.14893617021275%;} +.row-fluid .offset12:first-child{margin-left:102.12765957446808%;*margin-left:102.02127659574467%;} +.row-fluid .offset11{margin-left:95.74468085106382%;*margin-left:95.6382978723404%;} +.row-fluid .offset11:first-child{margin-left:93.61702127659574%;*margin-left:93.51063829787232%;} +.row-fluid .offset10{margin-left:87.23404255319149%;*margin-left:87.12765957446807%;} +.row-fluid .offset10:first-child{margin-left:85.1063829787234%;*margin-left:84.99999999999999%;} +.row-fluid .offset9{margin-left:78.72340425531914%;*margin-left:78.61702127659572%;} +.row-fluid .offset9:first-child{margin-left:76.59574468085106%;*margin-left:76.48936170212764%;} +.row-fluid .offset8{margin-left:70.2127659574468%;*margin-left:70.10638297872339%;} +.row-fluid 
.offset8:first-child{margin-left:68.08510638297872%;*margin-left:67.9787234042553%;} +.row-fluid .offset7{margin-left:61.70212765957446%;*margin-left:61.59574468085106%;} +.row-fluid .offset7:first-child{margin-left:59.574468085106375%;*margin-left:59.46808510638297%;} +.row-fluid .offset6{margin-left:53.191489361702125%;*margin-left:53.085106382978715%;} +.row-fluid .offset6:first-child{margin-left:51.063829787234035%;*margin-left:50.95744680851063%;} +.row-fluid .offset5{margin-left:44.68085106382979%;*margin-left:44.57446808510638%;} +.row-fluid .offset5:first-child{margin-left:42.5531914893617%;*margin-left:42.4468085106383%;} +.row-fluid .offset4{margin-left:36.170212765957444%;*margin-left:36.06382978723405%;} +.row-fluid .offset4:first-child{margin-left:34.04255319148936%;*margin-left:33.93617021276596%;} +.row-fluid .offset3{margin-left:27.659574468085104%;*margin-left:27.5531914893617%;} +.row-fluid .offset3:first-child{margin-left:25.53191489361702%;*margin-left:25.425531914893618%;} +.row-fluid .offset2{margin-left:19.148936170212764%;*margin-left:19.04255319148936%;} +.row-fluid .offset2:first-child{margin-left:17.02127659574468%;*margin-left:16.914893617021278%;} +.row-fluid .offset1{margin-left:10.638297872340425%;*margin-left:10.53191489361702%;} +.row-fluid .offset1:first-child{margin-left:8.51063829787234%;*margin-left:8.404255319148938%;} +[class*="span"].hide,.row-fluid [class*="span"].hide{display:none;} +[class*="span"].pull-right,.row-fluid [class*="span"].pull-right{float:right;} +.container{margin-right:auto;margin-left:auto;*zoom:1;}.container:before,.container:after{display:table;content:"";line-height:0;} +.container:after{clear:both;} +.container-fluid{padding-right:20px;padding-left:20px;*zoom:1;}.container-fluid:before,.container-fluid:after{display:table;content:"";line-height:0;} +.container-fluid:after{clear:both;} +p{margin:0 0 10px;} +.lead{margin-bottom:20px;font-size:21px;font-weight:200;line-height:30px;} +small{font-size:85%;} +strong{font-weight:bold;} +em{font-style:italic;} +cite{font-style:normal;} +.muted{color:#999999;} +a.muted:hover,a.muted:focus{color:#808080;} +.text-warning{color:#c09853;} +a.text-warning:hover,a.text-warning:focus{color:#a47e3c;} +.text-error{color:#b94a48;} +a.text-error:hover,a.text-error:focus{color:#953b39;} +.text-info{color:#3a87ad;} +a.text-info:hover,a.text-info:focus{color:#2d6987;} +.text-success{color:#468847;} +a.text-success:hover,a.text-success:focus{color:#356635;} +.text-left{text-align:left;} +.text-right{text-align:right;} +.text-center{text-align:center;} +h1,h2,h3,h4,h5,h6{margin:10px 0;font-family:inherit;font-weight:bold;line-height:20px;color:inherit;text-rendering:optimizelegibility;}h1 small,h2 small,h3 small,h4 small,h5 small,h6 small{font-weight:normal;line-height:1;color:#999999;} +h1,h2,h3{line-height:40px;} +h1{font-size:38.5px;} +h2{font-size:31.5px;} +h3{font-size:24.5px;} +h4{font-size:17.5px;} +h5{font-size:14px;} +h6{font-size:11.9px;} +h1 small{font-size:24.5px;} +h2 small{font-size:17.5px;} +h3 small{font-size:14px;} +h4 small{font-size:14px;} +.page-header{padding-bottom:9px;margin:20px 0 30px;border-bottom:1px solid #eeeeee;} +ul,ol{padding:0;margin:0 0 10px 25px;} +ul ul,ul ol,ol ol,ol ul{margin-bottom:0;} +li{line-height:20px;} +ul.unstyled,ol.unstyled{margin-left:0;list-style:none;} +ul.inline,ol.inline{margin-left:0;list-style:none;}ul.inline>li,ol.inline>li{display:inline-block;*display:inline;*zoom:1;padding-left:5px;padding-right:5px;} +dl{margin-bottom:20px;} 
+dt,dd{line-height:20px;} +dt{font-weight:bold;} +dd{margin-left:10px;} +.dl-horizontal{*zoom:1;}.dl-horizontal:before,.dl-horizontal:after{display:table;content:"";line-height:0;} +.dl-horizontal:after{clear:both;} +.dl-horizontal dt{float:left;width:160px;clear:left;text-align:right;overflow:hidden;text-overflow:ellipsis;white-space:nowrap;} +.dl-horizontal dd{margin-left:180px;} +hr{margin:20px 0;border:0;border-top:1px solid #eeeeee;border-bottom:1px solid #ffffff;} +abbr[title],abbr[data-original-title]{cursor:help;border-bottom:1px dotted #999999;} +abbr.initialism{font-size:90%;text-transform:uppercase;} +blockquote{padding:0 0 0 15px;margin:0 0 20px;border-left:5px solid #eeeeee;}blockquote p{margin-bottom:0;font-size:17.5px;font-weight:300;line-height:1.25;} +blockquote small{display:block;line-height:20px;color:#999999;}blockquote small:before{content:'\2014 \00A0';} +blockquote.pull-right{float:right;padding-right:15px;padding-left:0;border-right:5px solid #eeeeee;border-left:0;}blockquote.pull-right p,blockquote.pull-right small{text-align:right;} +blockquote.pull-right small:before{content:'';} +blockquote.pull-right small:after{content:'\00A0 \2014';} +q:before,q:after,blockquote:before,blockquote:after{content:"";} +address{display:block;margin-bottom:20px;font-style:normal;line-height:20px;} +code,pre{padding:0 3px 2px;font-family:Monaco,Menlo,Consolas,"Courier New",monospace;font-size:12px;color:#333333;-webkit-border-radius:3px;-moz-border-radius:3px;border-radius:3px;} +code{padding:2px 4px;color:#d14;background-color:#f7f7f9;border:1px solid #e1e1e8;white-space:nowrap;} +pre{display:block;padding:9.5px;margin:0 0 10px;font-size:13px;line-height:20px;word-break:break-all;word-wrap:break-word;white-space:pre;white-space:pre-wrap;background-color:#f5f5f5;border:1px solid #ccc;border:1px solid rgba(0, 0, 0, 0.15);-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;}pre.prettyprint{margin-bottom:20px;} +pre code{padding:0;color:inherit;white-space:pre;white-space:pre-wrap;background-color:transparent;border:0;} +.pre-scrollable{max-height:340px;overflow-y:scroll;} +.label,.badge{display:inline-block;padding:2px 4px;font-size:11.844px;font-weight:bold;line-height:14px;color:#ffffff;vertical-align:baseline;white-space:nowrap;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25);background-color:#999999;} +.label{-webkit-border-radius:3px;-moz-border-radius:3px;border-radius:3px;} +.badge{padding-left:9px;padding-right:9px;-webkit-border-radius:9px;-moz-border-radius:9px;border-radius:9px;} +.label:empty,.badge:empty{display:none;} +a.label:hover,a.label:focus,a.badge:hover,a.badge:focus{color:#ffffff;text-decoration:none;cursor:pointer;} +.label-important,.badge-important{background-color:#b94a48;} +.label-important[href],.badge-important[href]{background-color:#953b39;} +.label-warning,.badge-warning{background-color:#f89406;} +.label-warning[href],.badge-warning[href]{background-color:#c67605;} +.label-success,.badge-success{background-color:#468847;} +.label-success[href],.badge-success[href]{background-color:#356635;} +.label-info,.badge-info{background-color:#3a87ad;} +.label-info[href],.badge-info[href]{background-color:#2d6987;} +.label-inverse,.badge-inverse{background-color:#333333;} +.label-inverse[href],.badge-inverse[href]{background-color:#1a1a1a;} +.btn .label,.btn .badge{position:relative;top:-1px;} +.btn-mini .label,.btn-mini .badge{top:0;} +table{max-width:100%;background-color:transparent;border-collapse:collapse;border-spacing:0;} 
+.table{width:100%;margin-bottom:20px;}.table th,.table td{padding:8px;line-height:20px;text-align:left;vertical-align:top;border-top:1px solid #dddddd;} +.table th{font-weight:bold;} +.table thead th{vertical-align:bottom;} +.table caption+thead tr:first-child th,.table caption+thead tr:first-child td,.table colgroup+thead tr:first-child th,.table colgroup+thead tr:first-child td,.table thead:first-child tr:first-child th,.table thead:first-child tr:first-child td{border-top:0;} +.table tbody+tbody{border-top:2px solid #dddddd;} +.table .table{background-color:#ffffff;} +.table-condensed th,.table-condensed td{padding:4px 5px;} +.table-bordered{border:1px solid #dddddd;border-collapse:separate;*border-collapse:collapse;border-left:0;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;}.table-bordered th,.table-bordered td{border-left:1px solid #dddddd;} +.table-bordered caption+thead tr:first-child th,.table-bordered caption+tbody tr:first-child th,.table-bordered caption+tbody tr:first-child td,.table-bordered colgroup+thead tr:first-child th,.table-bordered colgroup+tbody tr:first-child th,.table-bordered colgroup+tbody tr:first-child td,.table-bordered thead:first-child tr:first-child th,.table-bordered tbody:first-child tr:first-child th,.table-bordered tbody:first-child tr:first-child td{border-top:0;} +.table-bordered thead:first-child tr:first-child>th:first-child,.table-bordered tbody:first-child tr:first-child>td:first-child,.table-bordered tbody:first-child tr:first-child>th:first-child{-webkit-border-top-left-radius:4px;-moz-border-radius-topleft:4px;border-top-left-radius:4px;} +.table-bordered thead:first-child tr:first-child>th:last-child,.table-bordered tbody:first-child tr:first-child>td:last-child,.table-bordered tbody:first-child tr:first-child>th:last-child{-webkit-border-top-right-radius:4px;-moz-border-radius-topright:4px;border-top-right-radius:4px;} +.table-bordered thead:last-child tr:last-child>th:first-child,.table-bordered tbody:last-child tr:last-child>td:first-child,.table-bordered tbody:last-child tr:last-child>th:first-child,.table-bordered tfoot:last-child tr:last-child>td:first-child,.table-bordered tfoot:last-child tr:last-child>th:first-child{-webkit-border-bottom-left-radius:4px;-moz-border-radius-bottomleft:4px;border-bottom-left-radius:4px;} +.table-bordered thead:last-child tr:last-child>th:last-child,.table-bordered tbody:last-child tr:last-child>td:last-child,.table-bordered tbody:last-child tr:last-child>th:last-child,.table-bordered tfoot:last-child tr:last-child>td:last-child,.table-bordered tfoot:last-child tr:last-child>th:last-child{-webkit-border-bottom-right-radius:4px;-moz-border-radius-bottomright:4px;border-bottom-right-radius:4px;} +.table-bordered tfoot+tbody:last-child tr:last-child td:first-child{-webkit-border-bottom-left-radius:0;-moz-border-radius-bottomleft:0;border-bottom-left-radius:0;} +.table-bordered tfoot+tbody:last-child tr:last-child td:last-child{-webkit-border-bottom-right-radius:0;-moz-border-radius-bottomright:0;border-bottom-right-radius:0;} +.table-bordered caption+thead tr:first-child th:first-child,.table-bordered caption+tbody tr:first-child td:first-child,.table-bordered colgroup+thead tr:first-child th:first-child,.table-bordered colgroup+tbody tr:first-child td:first-child{-webkit-border-top-left-radius:4px;-moz-border-radius-topleft:4px;border-top-left-radius:4px;} +.table-bordered caption+thead tr:first-child th:last-child,.table-bordered caption+tbody tr:first-child 
td:last-child,.table-bordered colgroup+thead tr:first-child th:last-child,.table-bordered colgroup+tbody tr:first-child td:last-child{-webkit-border-top-right-radius:4px;-moz-border-radius-topright:4px;border-top-right-radius:4px;} +.table-striped tbody>tr:nth-child(odd)>td,.table-striped tbody>tr:nth-child(odd)>th{background-color:#f9f9f9;} +.table-hover tbody tr:hover>td,.table-hover tbody tr:hover>th{background-color:#f5f5f5;} +table td[class*="span"],table th[class*="span"],.row-fluid table td[class*="span"],.row-fluid table th[class*="span"]{display:table-cell;float:none;margin-left:0;} +.table td.span1,.table th.span1{float:none;width:44px;margin-left:0;} +.table td.span2,.table th.span2{float:none;width:124px;margin-left:0;} +.table td.span3,.table th.span3{float:none;width:204px;margin-left:0;} +.table td.span4,.table th.span4{float:none;width:284px;margin-left:0;} +.table td.span5,.table th.span5{float:none;width:364px;margin-left:0;} +.table td.span6,.table th.span6{float:none;width:444px;margin-left:0;} +.table td.span7,.table th.span7{float:none;width:524px;margin-left:0;} +.table td.span8,.table th.span8{float:none;width:604px;margin-left:0;} +.table td.span9,.table th.span9{float:none;width:684px;margin-left:0;} +.table td.span10,.table th.span10{float:none;width:764px;margin-left:0;} +.table td.span11,.table th.span11{float:none;width:844px;margin-left:0;} +.table td.span12,.table th.span12{float:none;width:924px;margin-left:0;} +.table tbody tr.success>td{background-color:#dff0d8;} +.table tbody tr.error>td{background-color:#f2dede;} +.table tbody tr.warning>td{background-color:#fcf8e3;} +.table tbody tr.info>td{background-color:#d9edf7;} +.table-hover tbody tr.success:hover>td{background-color:#d0e9c6;} +.table-hover tbody tr.error:hover>td{background-color:#ebcccc;} +.table-hover tbody tr.warning:hover>td{background-color:#faf2cc;} +.table-hover tbody tr.info:hover>td{background-color:#c4e3f3;} +form{margin:0 0 20px;} +fieldset{padding:0;margin:0;border:0;} +legend{display:block;width:100%;padding:0;margin-bottom:20px;font-size:21px;line-height:40px;color:#333333;border:0;border-bottom:1px solid #e5e5e5;}legend small{font-size:15px;color:#999999;} +label,input,button,select,textarea{font-size:14px;font-weight:normal;line-height:20px;} +input,button,select,textarea{font-family:"Helvetica Neue",Helvetica,Arial,sans-serif;} +label{display:block;margin-bottom:5px;} +select,textarea,input[type="text"],input[type="password"],input[type="datetime"],input[type="datetime-local"],input[type="date"],input[type="month"],input[type="time"],input[type="week"],input[type="number"],input[type="email"],input[type="url"],input[type="search"],input[type="tel"],input[type="color"],.uneditable-input{display:inline-block;height:20px;padding:4px 6px;margin-bottom:10px;font-size:14px;line-height:20px;color:#555555;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;vertical-align:middle;} +input,textarea,.uneditable-input{width:206px;} +textarea{height:auto;} +textarea,input[type="text"],input[type="password"],input[type="datetime"],input[type="datetime-local"],input[type="date"],input[type="month"],input[type="time"],input[type="week"],input[type="number"],input[type="email"],input[type="url"],input[type="search"],input[type="tel"],input[type="color"],.uneditable-input{background-color:#ffffff;border:1px solid #cccccc;-webkit-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);-moz-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);box-shadow:inset 0 1px 1px rgba(0, 0, 0, 
0.075);-webkit-transition:border linear .2s, box-shadow linear .2s;-moz-transition:border linear .2s, box-shadow linear .2s;-o-transition:border linear .2s, box-shadow linear .2s;transition:border linear .2s, box-shadow linear .2s;}textarea:focus,input[type="text"]:focus,input[type="password"]:focus,input[type="datetime"]:focus,input[type="datetime-local"]:focus,input[type="date"]:focus,input[type="month"]:focus,input[type="time"]:focus,input[type="week"]:focus,input[type="number"]:focus,input[type="email"]:focus,input[type="url"]:focus,input[type="search"]:focus,input[type="tel"]:focus,input[type="color"]:focus,.uneditable-input:focus{border-color:rgba(82, 168, 236, 0.8);outline:0;outline:thin dotted \9;-webkit-box-shadow:inset 0 1px 1px rgba(0,0,0,.075), 0 0 8px rgba(82,168,236,.6);-moz-box-shadow:inset 0 1px 1px rgba(0,0,0,.075), 0 0 8px rgba(82,168,236,.6);box-shadow:inset 0 1px 1px rgba(0,0,0,.075), 0 0 8px rgba(82,168,236,.6);} +input[type="radio"],input[type="checkbox"]{margin:4px 0 0;*margin-top:0;margin-top:1px \9;line-height:normal;} +input[type="file"],input[type="image"],input[type="submit"],input[type="reset"],input[type="button"],input[type="radio"],input[type="checkbox"]{width:auto;} +select,input[type="file"]{height:30px;*margin-top:4px;line-height:30px;} +select{width:220px;border:1px solid #cccccc;background-color:#ffffff;} +select[multiple],select[size]{height:auto;} +select:focus,input[type="file"]:focus,input[type="radio"]:focus,input[type="checkbox"]:focus{outline:thin dotted #333;outline:5px auto -webkit-focus-ring-color;outline-offset:-2px;} +.uneditable-input,.uneditable-textarea{color:#999999;background-color:#fcfcfc;border-color:#cccccc;-webkit-box-shadow:inset 0 1px 2px rgba(0, 0, 0, 0.025);-moz-box-shadow:inset 0 1px 2px rgba(0, 0, 0, 0.025);box-shadow:inset 0 1px 2px rgba(0, 0, 0, 0.025);cursor:not-allowed;} +.uneditable-input{overflow:hidden;white-space:nowrap;} +.uneditable-textarea{width:auto;height:auto;} +input:-moz-placeholder,textarea:-moz-placeholder{color:#999999;} +input:-ms-input-placeholder,textarea:-ms-input-placeholder{color:#999999;} +input::-webkit-input-placeholder,textarea::-webkit-input-placeholder{color:#999999;} +.radio,.checkbox{min-height:20px;padding-left:20px;} +.radio input[type="radio"],.checkbox input[type="checkbox"]{float:left;margin-left:-20px;} +.controls>.radio:first-child,.controls>.checkbox:first-child{padding-top:5px;} +.radio.inline,.checkbox.inline{display:inline-block;padding-top:5px;margin-bottom:0;vertical-align:middle;} +.radio.inline+.radio.inline,.checkbox.inline+.checkbox.inline{margin-left:10px;} +.input-mini{width:60px;} +.input-small{width:90px;} +.input-medium{width:150px;} +.input-large{width:210px;} +.input-xlarge{width:270px;} +.input-xxlarge{width:530px;} +input[class*="span"],select[class*="span"],textarea[class*="span"],.uneditable-input[class*="span"],.row-fluid input[class*="span"],.row-fluid select[class*="span"],.row-fluid textarea[class*="span"],.row-fluid .uneditable-input[class*="span"]{float:none;margin-left:0;} +.input-append input[class*="span"],.input-append .uneditable-input[class*="span"],.input-prepend input[class*="span"],.input-prepend .uneditable-input[class*="span"],.row-fluid input[class*="span"],.row-fluid select[class*="span"],.row-fluid textarea[class*="span"],.row-fluid .uneditable-input[class*="span"],.row-fluid .input-prepend [class*="span"],.row-fluid .input-append [class*="span"]{display:inline-block;} +input,textarea,.uneditable-input{margin-left:0;} +.controls-row 
[class*="span"]+[class*="span"]{margin-left:20px;} +input.span12,textarea.span12,.uneditable-input.span12{width:926px;} +input.span11,textarea.span11,.uneditable-input.span11{width:846px;} +input.span10,textarea.span10,.uneditable-input.span10{width:766px;} +input.span9,textarea.span9,.uneditable-input.span9{width:686px;} +input.span8,textarea.span8,.uneditable-input.span8{width:606px;} +input.span7,textarea.span7,.uneditable-input.span7{width:526px;} +input.span6,textarea.span6,.uneditable-input.span6{width:446px;} +input.span5,textarea.span5,.uneditable-input.span5{width:366px;} +input.span4,textarea.span4,.uneditable-input.span4{width:286px;} +input.span3,textarea.span3,.uneditable-input.span3{width:206px;} +input.span2,textarea.span2,.uneditable-input.span2{width:126px;} +input.span1,textarea.span1,.uneditable-input.span1{width:46px;} +.controls-row{*zoom:1;}.controls-row:before,.controls-row:after{display:table;content:"";line-height:0;} +.controls-row:after{clear:both;} +.controls-row [class*="span"],.row-fluid .controls-row [class*="span"]{float:left;} +.controls-row .checkbox[class*="span"],.controls-row .radio[class*="span"]{padding-top:5px;} +input[disabled],select[disabled],textarea[disabled],input[readonly],select[readonly],textarea[readonly]{cursor:not-allowed;background-color:#eeeeee;} +input[type="radio"][disabled],input[type="checkbox"][disabled],input[type="radio"][readonly],input[type="checkbox"][readonly]{background-color:transparent;} +.control-group.warning .control-label,.control-group.warning .help-block,.control-group.warning .help-inline{color:#c09853;} +.control-group.warning .checkbox,.control-group.warning .radio,.control-group.warning input,.control-group.warning select,.control-group.warning textarea{color:#c09853;} +.control-group.warning input,.control-group.warning select,.control-group.warning textarea{border-color:#c09853;-webkit-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);-moz-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);}.control-group.warning input:focus,.control-group.warning select:focus,.control-group.warning textarea:focus{border-color:#a47e3c;-webkit-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #dbc59e;-moz-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #dbc59e;box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #dbc59e;} +.control-group.warning .input-prepend .add-on,.control-group.warning .input-append .add-on{color:#c09853;background-color:#fcf8e3;border-color:#c09853;} +.control-group.error .control-label,.control-group.error .help-block,.control-group.error .help-inline{color:#b94a48;} +.control-group.error .checkbox,.control-group.error .radio,.control-group.error input,.control-group.error select,.control-group.error textarea{color:#b94a48;} +.control-group.error input,.control-group.error select,.control-group.error textarea{border-color:#b94a48;-webkit-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);-moz-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);}.control-group.error input:focus,.control-group.error select:focus,.control-group.error textarea:focus{border-color:#953b39;-webkit-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #d59392;-moz-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #d59392;box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #d59392;} +.control-group.error .input-prepend .add-on,.control-group.error .input-append 
.add-on{color:#b94a48;background-color:#f2dede;border-color:#b94a48;} +.control-group.success .control-label,.control-group.success .help-block,.control-group.success .help-inline{color:#468847;} +.control-group.success .checkbox,.control-group.success .radio,.control-group.success input,.control-group.success select,.control-group.success textarea{color:#468847;} +.control-group.success input,.control-group.success select,.control-group.success textarea{border-color:#468847;-webkit-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);-moz-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);}.control-group.success input:focus,.control-group.success select:focus,.control-group.success textarea:focus{border-color:#356635;-webkit-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #7aba7b;-moz-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #7aba7b;box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #7aba7b;} +.control-group.success .input-prepend .add-on,.control-group.success .input-append .add-on{color:#468847;background-color:#dff0d8;border-color:#468847;} +.control-group.info .control-label,.control-group.info .help-block,.control-group.info .help-inline{color:#3a87ad;} +.control-group.info .checkbox,.control-group.info .radio,.control-group.info input,.control-group.info select,.control-group.info textarea{color:#3a87ad;} +.control-group.info input,.control-group.info select,.control-group.info textarea{border-color:#3a87ad;-webkit-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);-moz-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075);}.control-group.info input:focus,.control-group.info select:focus,.control-group.info textarea:focus{border-color:#2d6987;-webkit-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #7ab5d3;-moz-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #7ab5d3;box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.075),0 0 6px #7ab5d3;} +.control-group.info .input-prepend .add-on,.control-group.info .input-append .add-on{color:#3a87ad;background-color:#d9edf7;border-color:#3a87ad;} +input:focus:invalid,textarea:focus:invalid,select:focus:invalid{color:#b94a48;border-color:#ee5f5b;}input:focus:invalid:focus,textarea:focus:invalid:focus,select:focus:invalid:focus{border-color:#e9322d;-webkit-box-shadow:0 0 6px #f8b9b7;-moz-box-shadow:0 0 6px #f8b9b7;box-shadow:0 0 6px #f8b9b7;} +.form-actions{padding:19px 20px 20px;margin-top:20px;margin-bottom:20px;background-color:#f5f5f5;border-top:1px solid #e5e5e5;*zoom:1;}.form-actions:before,.form-actions:after{display:table;content:"";line-height:0;} +.form-actions:after{clear:both;} +.help-block,.help-inline{color:#595959;} +.help-block{display:block;margin-bottom:10px;} +.help-inline{display:inline-block;*display:inline;*zoom:1;vertical-align:middle;padding-left:5px;} +.input-append,.input-prepend{display:inline-block;margin-bottom:10px;vertical-align:middle;font-size:0;white-space:nowrap;}.input-append input,.input-prepend input,.input-append select,.input-prepend select,.input-append .uneditable-input,.input-prepend .uneditable-input,.input-append .dropdown-menu,.input-prepend .dropdown-menu,.input-append .popover,.input-prepend .popover{font-size:14px;} +.input-append input,.input-prepend input,.input-append select,.input-prepend select,.input-append .uneditable-input,.input-prepend .uneditable-input{position:relative;margin-bottom:0;*margin-left:0;vertical-align:top;-webkit-border-radius:0 4px 4px 0;-moz-border-radius:0 4px 
4px 0;border-radius:0 4px 4px 0;}.input-append input:focus,.input-prepend input:focus,.input-append select:focus,.input-prepend select:focus,.input-append .uneditable-input:focus,.input-prepend .uneditable-input:focus{z-index:2;} +.input-append .add-on,.input-prepend .add-on{display:inline-block;width:auto;height:20px;min-width:16px;padding:4px 5px;font-size:14px;font-weight:normal;line-height:20px;text-align:center;text-shadow:0 1px 0 #ffffff;background-color:#eeeeee;border:1px solid #ccc;} +.input-append .add-on,.input-prepend .add-on,.input-append .btn,.input-prepend .btn,.input-append .btn-group>.dropdown-toggle,.input-prepend .btn-group>.dropdown-toggle{vertical-align:top;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;} +.input-append .active,.input-prepend .active{background-color:#a9dba9;border-color:#46a546;} +.input-prepend .add-on,.input-prepend .btn{margin-right:-1px;} +.input-prepend .add-on:first-child,.input-prepend .btn:first-child{-webkit-border-radius:4px 0 0 4px;-moz-border-radius:4px 0 0 4px;border-radius:4px 0 0 4px;} +.input-append input,.input-append select,.input-append .uneditable-input{-webkit-border-radius:4px 0 0 4px;-moz-border-radius:4px 0 0 4px;border-radius:4px 0 0 4px;}.input-append input+.btn-group .btn:last-child,.input-append select+.btn-group .btn:last-child,.input-append .uneditable-input+.btn-group .btn:last-child{-webkit-border-radius:0 4px 4px 0;-moz-border-radius:0 4px 4px 0;border-radius:0 4px 4px 0;} +.input-append .add-on,.input-append .btn,.input-append .btn-group{margin-left:-1px;} +.input-append .add-on:last-child,.input-append .btn:last-child,.input-append .btn-group:last-child>.dropdown-toggle{-webkit-border-radius:0 4px 4px 0;-moz-border-radius:0 4px 4px 0;border-radius:0 4px 4px 0;} +.input-prepend.input-append input,.input-prepend.input-append select,.input-prepend.input-append .uneditable-input{-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;}.input-prepend.input-append input+.btn-group .btn,.input-prepend.input-append select+.btn-group .btn,.input-prepend.input-append .uneditable-input+.btn-group .btn{-webkit-border-radius:0 4px 4px 0;-moz-border-radius:0 4px 4px 0;border-radius:0 4px 4px 0;} +.input-prepend.input-append .add-on:first-child,.input-prepend.input-append .btn:first-child{margin-right:-1px;-webkit-border-radius:4px 0 0 4px;-moz-border-radius:4px 0 0 4px;border-radius:4px 0 0 4px;} +.input-prepend.input-append .add-on:last-child,.input-prepend.input-append .btn:last-child{margin-left:-1px;-webkit-border-radius:0 4px 4px 0;-moz-border-radius:0 4px 4px 0;border-radius:0 4px 4px 0;} +.input-prepend.input-append .btn-group:first-child{margin-left:0;} +input.search-query{padding-right:14px;padding-right:4px \9;padding-left:14px;padding-left:4px \9;margin-bottom:0;-webkit-border-radius:15px;-moz-border-radius:15px;border-radius:15px;} +.form-search .input-append .search-query,.form-search .input-prepend .search-query{-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;} +.form-search .input-append .search-query{-webkit-border-radius:14px 0 0 14px;-moz-border-radius:14px 0 0 14px;border-radius:14px 0 0 14px;} +.form-search .input-append .btn{-webkit-border-radius:0 14px 14px 0;-moz-border-radius:0 14px 14px 0;border-radius:0 14px 14px 0;} +.form-search .input-prepend .search-query{-webkit-border-radius:0 14px 14px 0;-moz-border-radius:0 14px 14px 0;border-radius:0 14px 14px 0;} +.form-search .input-prepend .btn{-webkit-border-radius:14px 0 0 14px;-moz-border-radius:14px 0 0 
14px;border-radius:14px 0 0 14px;} +.form-search input,.form-inline input,.form-horizontal input,.form-search textarea,.form-inline textarea,.form-horizontal textarea,.form-search select,.form-inline select,.form-horizontal select,.form-search .help-inline,.form-inline .help-inline,.form-horizontal .help-inline,.form-search .uneditable-input,.form-inline .uneditable-input,.form-horizontal .uneditable-input,.form-search .input-prepend,.form-inline .input-prepend,.form-horizontal .input-prepend,.form-search .input-append,.form-inline .input-append,.form-horizontal .input-append{display:inline-block;*display:inline;*zoom:1;margin-bottom:0;vertical-align:middle;} +.form-search .hide,.form-inline .hide,.form-horizontal .hide{display:none;} +.form-search label,.form-inline label,.form-search .btn-group,.form-inline .btn-group{display:inline-block;} +.form-search .input-append,.form-inline .input-append,.form-search .input-prepend,.form-inline .input-prepend{margin-bottom:0;} +.form-search .radio,.form-search .checkbox,.form-inline .radio,.form-inline .checkbox{padding-left:0;margin-bottom:0;vertical-align:middle;} +.form-search .radio input[type="radio"],.form-search .checkbox input[type="checkbox"],.form-inline .radio input[type="radio"],.form-inline .checkbox input[type="checkbox"]{float:left;margin-right:3px;margin-left:0;} +.control-group{margin-bottom:10px;} +legend+.control-group{margin-top:20px;-webkit-margin-top-collapse:separate;} +.form-horizontal .control-group{margin-bottom:20px;*zoom:1;}.form-horizontal .control-group:before,.form-horizontal .control-group:after{display:table;content:"";line-height:0;} +.form-horizontal .control-group:after{clear:both;} +.form-horizontal .control-label{float:left;width:160px;padding-top:5px;text-align:right;} +.form-horizontal .controls{*display:inline-block;*padding-left:20px;margin-left:180px;*margin-left:0;}.form-horizontal .controls:first-child{*padding-left:180px;} +.form-horizontal .help-block{margin-bottom:0;} +.form-horizontal input+.help-block,.form-horizontal select+.help-block,.form-horizontal textarea+.help-block,.form-horizontal .uneditable-input+.help-block,.form-horizontal .input-prepend+.help-block,.form-horizontal .input-append+.help-block{margin-top:10px;} +.form-horizontal .form-actions{padding-left:180px;} +.btn{display:inline-block;*display:inline;*zoom:1;padding:4px 12px;margin-bottom:0;font-size:14px;line-height:20px;text-align:center;vertical-align:middle;cursor:pointer;color:#333333;text-shadow:0 1px 1px rgba(255, 255, 255, 0.75);background-color:#f5f5f5;background-image:-moz-linear-gradient(top, #ffffff, #e6e6e6);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#ffffff), to(#e6e6e6));background-image:-webkit-linear-gradient(top, #ffffff, #e6e6e6);background-image:-o-linear-gradient(top, #ffffff, #e6e6e6);background-image:linear-gradient(to bottom, #ffffff, #e6e6e6);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ffffffff', endColorstr='#ffe6e6e6', GradientType=0);border-color:#e6e6e6 #e6e6e6 #bfbfbf;border-color:rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.25);*background-color:#e6e6e6;filter:progid:DXImageTransform.Microsoft.gradient(enabled = false);border:1px solid #cccccc;*border:0;border-bottom-color:#b3b3b3;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;*margin-left:.3em;-webkit-box-shadow:inset 0 1px 0 rgba(255,255,255,.2), 0 1px 2px rgba(0,0,0,.05);-moz-box-shadow:inset 0 1px 0 rgba(255,255,255,.2), 0 1px 2px 
rgba(0,0,0,.05);box-shadow:inset 0 1px 0 rgba(255,255,255,.2), 0 1px 2px rgba(0,0,0,.05);}.btn:hover,.btn:focus,.btn:active,.btn.active,.btn.disabled,.btn[disabled]{color:#333333;background-color:#e6e6e6;*background-color:#d9d9d9;} +.btn:active,.btn.active{background-color:#cccccc \9;} +.btn:first-child{*margin-left:0;} +.btn:hover,.btn:focus{color:#333333;text-decoration:none;background-position:0 -15px;-webkit-transition:background-position 0.1s linear;-moz-transition:background-position 0.1s linear;-o-transition:background-position 0.1s linear;transition:background-position 0.1s linear;} +.btn:focus{outline:thin dotted #333;outline:5px auto -webkit-focus-ring-color;outline-offset:-2px;} +.btn.active,.btn:active{background-image:none;outline:0;-webkit-box-shadow:inset 0 2px 4px rgba(0,0,0,.15), 0 1px 2px rgba(0,0,0,.05);-moz-box-shadow:inset 0 2px 4px rgba(0,0,0,.15), 0 1px 2px rgba(0,0,0,.05);box-shadow:inset 0 2px 4px rgba(0,0,0,.15), 0 1px 2px rgba(0,0,0,.05);} +.btn.disabled,.btn[disabled]{cursor:default;background-image:none;opacity:0.65;filter:alpha(opacity=65);-webkit-box-shadow:none;-moz-box-shadow:none;box-shadow:none;} +.btn-large{padding:11px 19px;font-size:17.5px;-webkit-border-radius:6px;-moz-border-radius:6px;border-radius:6px;} +.btn-large [class^="icon-"],.btn-large [class*=" icon-"]{margin-top:4px;} +.btn-small{padding:2px 10px;font-size:11.9px;-webkit-border-radius:3px;-moz-border-radius:3px;border-radius:3px;} +.btn-small [class^="icon-"],.btn-small [class*=" icon-"]{margin-top:0;} +.btn-mini [class^="icon-"],.btn-mini [class*=" icon-"]{margin-top:-1px;} +.btn-mini{padding:0 6px;font-size:10.5px;-webkit-border-radius:3px;-moz-border-radius:3px;border-radius:3px;} +.btn-block{display:block;width:100%;padding-left:0;padding-right:0;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;} +.btn-block+.btn-block{margin-top:5px;} +input[type="submit"].btn-block,input[type="reset"].btn-block,input[type="button"].btn-block{width:100%;} +.btn-primary.active,.btn-warning.active,.btn-danger.active,.btn-success.active,.btn-info.active,.btn-inverse.active{color:rgba(255, 255, 255, 0.75);} +.btn-primary{color:#ffffff;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25);background-color:#006dcc;background-image:-moz-linear-gradient(top, #0088cc, #0044cc);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#0088cc), to(#0044cc));background-image:-webkit-linear-gradient(top, #0088cc, #0044cc);background-image:-o-linear-gradient(top, #0088cc, #0044cc);background-image:linear-gradient(to bottom, #0088cc, #0044cc);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ff0088cc', endColorstr='#ff0044cc', GradientType=0);border-color:#0044cc #0044cc #002a80;border-color:rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.25);*background-color:#0044cc;filter:progid:DXImageTransform.Microsoft.gradient(enabled = false);}.btn-primary:hover,.btn-primary:focus,.btn-primary:active,.btn-primary.active,.btn-primary.disabled,.btn-primary[disabled]{color:#ffffff;background-color:#0044cc;*background-color:#003bb3;} +.btn-primary:active,.btn-primary.active{background-color:#003399 \9;} +.btn-warning{color:#ffffff;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25);background-color:#faa732;background-image:-moz-linear-gradient(top, #fbb450, #f89406);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#fbb450), to(#f89406));background-image:-webkit-linear-gradient(top, #fbb450, #f89406);background-image:-o-linear-gradient(top, #fbb450, 
#f89406);background-image:linear-gradient(to bottom, #fbb450, #f89406);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#fffbb450', endColorstr='#fff89406', GradientType=0);border-color:#f89406 #f89406 #ad6704;border-color:rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.25);*background-color:#f89406;filter:progid:DXImageTransform.Microsoft.gradient(enabled = false);}.btn-warning:hover,.btn-warning:focus,.btn-warning:active,.btn-warning.active,.btn-warning.disabled,.btn-warning[disabled]{color:#ffffff;background-color:#f89406;*background-color:#df8505;} +.btn-warning:active,.btn-warning.active{background-color:#c67605 \9;} +.btn-danger{color:#ffffff;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25);background-color:#da4f49;background-image:-moz-linear-gradient(top, #ee5f5b, #bd362f);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#ee5f5b), to(#bd362f));background-image:-webkit-linear-gradient(top, #ee5f5b, #bd362f);background-image:-o-linear-gradient(top, #ee5f5b, #bd362f);background-image:linear-gradient(to bottom, #ee5f5b, #bd362f);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ffee5f5b', endColorstr='#ffbd362f', GradientType=0);border-color:#bd362f #bd362f #802420;border-color:rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.25);*background-color:#bd362f;filter:progid:DXImageTransform.Microsoft.gradient(enabled = false);}.btn-danger:hover,.btn-danger:focus,.btn-danger:active,.btn-danger.active,.btn-danger.disabled,.btn-danger[disabled]{color:#ffffff;background-color:#bd362f;*background-color:#a9302a;} +.btn-danger:active,.btn-danger.active{background-color:#942a25 \9;} +.btn-success{color:#ffffff;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25);background-color:#5bb75b;background-image:-moz-linear-gradient(top, #62c462, #51a351);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#62c462), to(#51a351));background-image:-webkit-linear-gradient(top, #62c462, #51a351);background-image:-o-linear-gradient(top, #62c462, #51a351);background-image:linear-gradient(to bottom, #62c462, #51a351);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ff62c462', endColorstr='#ff51a351', GradientType=0);border-color:#51a351 #51a351 #387038;border-color:rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.25);*background-color:#51a351;filter:progid:DXImageTransform.Microsoft.gradient(enabled = false);}.btn-success:hover,.btn-success:focus,.btn-success:active,.btn-success.active,.btn-success.disabled,.btn-success[disabled]{color:#ffffff;background-color:#51a351;*background-color:#499249;} +.btn-success:active,.btn-success.active{background-color:#408140 \9;} +.btn-info{color:#ffffff;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25);background-color:#49afcd;background-image:-moz-linear-gradient(top, #5bc0de, #2f96b4);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#5bc0de), to(#2f96b4));background-image:-webkit-linear-gradient(top, #5bc0de, #2f96b4);background-image:-o-linear-gradient(top, #5bc0de, #2f96b4);background-image:linear-gradient(to bottom, #5bc0de, #2f96b4);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ff5bc0de', endColorstr='#ff2f96b4', GradientType=0);border-color:#2f96b4 #2f96b4 #1f6377;border-color:rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.25);*background-color:#2f96b4;filter:progid:DXImageTransform.Microsoft.gradient(enabled = 
false);}.btn-info:hover,.btn-info:focus,.btn-info:active,.btn-info.active,.btn-info.disabled,.btn-info[disabled]{color:#ffffff;background-color:#2f96b4;*background-color:#2a85a0;} +.btn-info:active,.btn-info.active{background-color:#24748c \9;} +.btn-inverse{color:#ffffff;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25);background-color:#363636;background-image:-moz-linear-gradient(top, #444444, #222222);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#444444), to(#222222));background-image:-webkit-linear-gradient(top, #444444, #222222);background-image:-o-linear-gradient(top, #444444, #222222);background-image:linear-gradient(to bottom, #444444, #222222);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ff444444', endColorstr='#ff222222', GradientType=0);border-color:#222222 #222222 #000000;border-color:rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.25);*background-color:#222222;filter:progid:DXImageTransform.Microsoft.gradient(enabled = false);}.btn-inverse:hover,.btn-inverse:focus,.btn-inverse:active,.btn-inverse.active,.btn-inverse.disabled,.btn-inverse[disabled]{color:#ffffff;background-color:#222222;*background-color:#151515;} +.btn-inverse:active,.btn-inverse.active{background-color:#080808 \9;} +button.btn,input[type="submit"].btn{*padding-top:3px;*padding-bottom:3px;}button.btn::-moz-focus-inner,input[type="submit"].btn::-moz-focus-inner{padding:0;border:0;} +button.btn.btn-large,input[type="submit"].btn.btn-large{*padding-top:7px;*padding-bottom:7px;} +button.btn.btn-small,input[type="submit"].btn.btn-small{*padding-top:3px;*padding-bottom:3px;} +button.btn.btn-mini,input[type="submit"].btn.btn-mini{*padding-top:1px;*padding-bottom:1px;} +.btn-link,.btn-link:active,.btn-link[disabled]{background-color:transparent;background-image:none;-webkit-box-shadow:none;-moz-box-shadow:none;box-shadow:none;} +.btn-link{border-color:transparent;cursor:pointer;color:#0088cc;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;} +.btn-link:hover,.btn-link:focus{color:#005580;text-decoration:underline;background-color:transparent;} +.btn-link[disabled]:hover,.btn-link[disabled]:focus{color:#333333;text-decoration:none;} +[class^="icon-"],[class*=" icon-"]{display:inline-block;width:14px;height:14px;*margin-right:.3em;line-height:14px;vertical-align:text-top;background-image:url("../img/glyphicons-halflings.png");background-position:14px 14px;background-repeat:no-repeat;margin-top:1px;} +.icon-white,.nav-pills>.active>a>[class^="icon-"],.nav-pills>.active>a>[class*=" icon-"],.nav-list>.active>a>[class^="icon-"],.nav-list>.active>a>[class*=" icon-"],.navbar-inverse .nav>.active>a>[class^="icon-"],.navbar-inverse .nav>.active>a>[class*=" icon-"],.dropdown-menu>li>a:hover>[class^="icon-"],.dropdown-menu>li>a:focus>[class^="icon-"],.dropdown-menu>li>a:hover>[class*=" icon-"],.dropdown-menu>li>a:focus>[class*=" icon-"],.dropdown-menu>.active>a>[class^="icon-"],.dropdown-menu>.active>a>[class*=" icon-"],.dropdown-submenu:hover>a>[class^="icon-"],.dropdown-submenu:focus>a>[class^="icon-"],.dropdown-submenu:hover>a>[class*=" icon-"],.dropdown-submenu:focus>a>[class*=" icon-"]{background-image:url("../img/glyphicons-halflings-white.png");} +.icon-glass{background-position:0 0;} +.icon-music{background-position:-24px 0;} +.icon-search{background-position:-48px 0;} +.icon-envelope{background-position:-72px 0;} +.icon-heart{background-position:-96px 0;} +.icon-star{background-position:-120px 0;} +.icon-star-empty{background-position:-144px 0;} 
+.icon-user{background-position:-168px 0;} +.icon-film{background-position:-192px 0;} +.icon-th-large{background-position:-216px 0;} +.icon-th{background-position:-240px 0;} +.icon-th-list{background-position:-264px 0;} +.icon-ok{background-position:-288px 0;} +.icon-remove{background-position:-312px 0;} +.icon-zoom-in{background-position:-336px 0;} +.icon-zoom-out{background-position:-360px 0;} +.icon-off{background-position:-384px 0;} +.icon-signal{background-position:-408px 0;} +.icon-cog{background-position:-432px 0;} +.icon-trash{background-position:-456px 0;} +.icon-home{background-position:0 -24px;} +.icon-file{background-position:-24px -24px;} +.icon-time{background-position:-48px -24px;} +.icon-road{background-position:-72px -24px;} +.icon-download-alt{background-position:-96px -24px;} +.icon-download{background-position:-120px -24px;} +.icon-upload{background-position:-144px -24px;} +.icon-inbox{background-position:-168px -24px;} +.icon-play-circle{background-position:-192px -24px;} +.icon-repeat{background-position:-216px -24px;} +.icon-refresh{background-position:-240px -24px;} +.icon-list-alt{background-position:-264px -24px;} +.icon-lock{background-position:-287px -24px;} +.icon-flag{background-position:-312px -24px;} +.icon-headphones{background-position:-336px -24px;} +.icon-volume-off{background-position:-360px -24px;} +.icon-volume-down{background-position:-384px -24px;} +.icon-volume-up{background-position:-408px -24px;} +.icon-qrcode{background-position:-432px -24px;} +.icon-barcode{background-position:-456px -24px;} +.icon-tag{background-position:0 -48px;} +.icon-tags{background-position:-25px -48px;} +.icon-book{background-position:-48px -48px;} +.icon-bookmark{background-position:-72px -48px;} +.icon-print{background-position:-96px -48px;} +.icon-camera{background-position:-120px -48px;} +.icon-font{background-position:-144px -48px;} +.icon-bold{background-position:-167px -48px;} +.icon-italic{background-position:-192px -48px;} +.icon-text-height{background-position:-216px -48px;} +.icon-text-width{background-position:-240px -48px;} +.icon-align-left{background-position:-264px -48px;} +.icon-align-center{background-position:-288px -48px;} +.icon-align-right{background-position:-312px -48px;} +.icon-align-justify{background-position:-336px -48px;} +.icon-list{background-position:-360px -48px;} +.icon-indent-left{background-position:-384px -48px;} +.icon-indent-right{background-position:-408px -48px;} +.icon-facetime-video{background-position:-432px -48px;} +.icon-picture{background-position:-456px -48px;} +.icon-pencil{background-position:0 -72px;} +.icon-map-marker{background-position:-24px -72px;} +.icon-adjust{background-position:-48px -72px;} +.icon-tint{background-position:-72px -72px;} +.icon-edit{background-position:-96px -72px;} +.icon-share{background-position:-120px -72px;} +.icon-check{background-position:-144px -72px;} +.icon-move{background-position:-168px -72px;} +.icon-step-backward{background-position:-192px -72px;} +.icon-fast-backward{background-position:-216px -72px;} +.icon-backward{background-position:-240px -72px;} +.icon-play{background-position:-264px -72px;} +.icon-pause{background-position:-288px -72px;} +.icon-stop{background-position:-312px -72px;} +.icon-forward{background-position:-336px -72px;} +.icon-fast-forward{background-position:-360px -72px;} +.icon-step-forward{background-position:-384px -72px;} +.icon-eject{background-position:-408px -72px;} +.icon-chevron-left{background-position:-432px -72px;} 
+.icon-chevron-right{background-position:-456px -72px;} +.icon-plus-sign{background-position:0 -96px;} +.icon-minus-sign{background-position:-24px -96px;} +.icon-remove-sign{background-position:-48px -96px;} +.icon-ok-sign{background-position:-72px -96px;} +.icon-question-sign{background-position:-96px -96px;} +.icon-info-sign{background-position:-120px -96px;} +.icon-screenshot{background-position:-144px -96px;} +.icon-remove-circle{background-position:-168px -96px;} +.icon-ok-circle{background-position:-192px -96px;} +.icon-ban-circle{background-position:-216px -96px;} +.icon-arrow-left{background-position:-240px -96px;} +.icon-arrow-right{background-position:-264px -96px;} +.icon-arrow-up{background-position:-289px -96px;} +.icon-arrow-down{background-position:-312px -96px;} +.icon-share-alt{background-position:-336px -96px;} +.icon-resize-full{background-position:-360px -96px;} +.icon-resize-small{background-position:-384px -96px;} +.icon-plus{background-position:-408px -96px;} +.icon-minus{background-position:-433px -96px;} +.icon-asterisk{background-position:-456px -96px;} +.icon-exclamation-sign{background-position:0 -120px;} +.icon-gift{background-position:-24px -120px;} +.icon-leaf{background-position:-48px -120px;} +.icon-fire{background-position:-72px -120px;} +.icon-eye-open{background-position:-96px -120px;} +.icon-eye-close{background-position:-120px -120px;} +.icon-warning-sign{background-position:-144px -120px;} +.icon-plane{background-position:-168px -120px;} +.icon-calendar{background-position:-192px -120px;} +.icon-random{background-position:-216px -120px;width:16px;} +.icon-comment{background-position:-240px -120px;} +.icon-magnet{background-position:-264px -120px;} +.icon-chevron-up{background-position:-288px -120px;} +.icon-chevron-down{background-position:-313px -119px;} +.icon-retweet{background-position:-336px -120px;} +.icon-shopping-cart{background-position:-360px -120px;} +.icon-folder-close{background-position:-384px -120px;width:16px;} +.icon-folder-open{background-position:-408px -120px;width:16px;} +.icon-resize-vertical{background-position:-432px -119px;} +.icon-resize-horizontal{background-position:-456px -118px;} +.icon-hdd{background-position:0 -144px;} +.icon-bullhorn{background-position:-24px -144px;} +.icon-bell{background-position:-48px -144px;} +.icon-certificate{background-position:-72px -144px;} +.icon-thumbs-up{background-position:-96px -144px;} +.icon-thumbs-down{background-position:-120px -144px;} +.icon-hand-right{background-position:-144px -144px;} +.icon-hand-left{background-position:-168px -144px;} +.icon-hand-up{background-position:-192px -144px;} +.icon-hand-down{background-position:-216px -144px;} +.icon-circle-arrow-right{background-position:-240px -144px;} +.icon-circle-arrow-left{background-position:-264px -144px;} +.icon-circle-arrow-up{background-position:-288px -144px;} +.icon-circle-arrow-down{background-position:-312px -144px;} +.icon-globe{background-position:-336px -144px;} +.icon-wrench{background-position:-360px -144px;} +.icon-tasks{background-position:-384px -144px;} +.icon-filter{background-position:-408px -144px;} +.icon-briefcase{background-position:-432px -144px;} +.icon-fullscreen{background-position:-456px -144px;} +.btn-group{position:relative;display:inline-block;*display:inline;*zoom:1;font-size:0;vertical-align:middle;white-space:nowrap;*margin-left:.3em;}.btn-group:first-child{*margin-left:0;} +.btn-group+.btn-group{margin-left:5px;} 
+.btn-toolbar{font-size:0;margin-top:10px;margin-bottom:10px;}.btn-toolbar>.btn+.btn,.btn-toolbar>.btn-group+.btn,.btn-toolbar>.btn+.btn-group{margin-left:5px;} +.btn-group>.btn{position:relative;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;} +.btn-group>.btn+.btn{margin-left:-1px;} +.btn-group>.btn,.btn-group>.dropdown-menu,.btn-group>.popover{font-size:14px;} +.btn-group>.btn-mini{font-size:10.5px;} +.btn-group>.btn-small{font-size:11.9px;} +.btn-group>.btn-large{font-size:17.5px;} +.btn-group>.btn:first-child{margin-left:0;-webkit-border-top-left-radius:4px;-moz-border-radius-topleft:4px;border-top-left-radius:4px;-webkit-border-bottom-left-radius:4px;-moz-border-radius-bottomleft:4px;border-bottom-left-radius:4px;} +.btn-group>.btn:last-child,.btn-group>.dropdown-toggle{-webkit-border-top-right-radius:4px;-moz-border-radius-topright:4px;border-top-right-radius:4px;-webkit-border-bottom-right-radius:4px;-moz-border-radius-bottomright:4px;border-bottom-right-radius:4px;} +.btn-group>.btn.large:first-child{margin-left:0;-webkit-border-top-left-radius:6px;-moz-border-radius-topleft:6px;border-top-left-radius:6px;-webkit-border-bottom-left-radius:6px;-moz-border-radius-bottomleft:6px;border-bottom-left-radius:6px;} +.btn-group>.btn.large:last-child,.btn-group>.large.dropdown-toggle{-webkit-border-top-right-radius:6px;-moz-border-radius-topright:6px;border-top-right-radius:6px;-webkit-border-bottom-right-radius:6px;-moz-border-radius-bottomright:6px;border-bottom-right-radius:6px;} +.btn-group>.btn:hover,.btn-group>.btn:focus,.btn-group>.btn:active,.btn-group>.btn.active{z-index:2;} +.btn-group .dropdown-toggle:active,.btn-group.open .dropdown-toggle{outline:0;} +.btn-group>.btn+.dropdown-toggle{padding-left:8px;padding-right:8px;-webkit-box-shadow:inset 1px 0 0 rgba(255,255,255,.125), inset 0 1px 0 rgba(255,255,255,.2), 0 1px 2px rgba(0,0,0,.05);-moz-box-shadow:inset 1px 0 0 rgba(255,255,255,.125), inset 0 1px 0 rgba(255,255,255,.2), 0 1px 2px rgba(0,0,0,.05);box-shadow:inset 1px 0 0 rgba(255,255,255,.125), inset 0 1px 0 rgba(255,255,255,.2), 0 1px 2px rgba(0,0,0,.05);*padding-top:5px;*padding-bottom:5px;} +.btn-group>.btn-mini+.dropdown-toggle{padding-left:5px;padding-right:5px;*padding-top:2px;*padding-bottom:2px;} +.btn-group>.btn-small+.dropdown-toggle{*padding-top:5px;*padding-bottom:4px;} +.btn-group>.btn-large+.dropdown-toggle{padding-left:12px;padding-right:12px;*padding-top:7px;*padding-bottom:7px;} +.btn-group.open .dropdown-toggle{background-image:none;-webkit-box-shadow:inset 0 2px 4px rgba(0,0,0,.15), 0 1px 2px rgba(0,0,0,.05);-moz-box-shadow:inset 0 2px 4px rgba(0,0,0,.15), 0 1px 2px rgba(0,0,0,.05);box-shadow:inset 0 2px 4px rgba(0,0,0,.15), 0 1px 2px rgba(0,0,0,.05);} +.btn-group.open .btn.dropdown-toggle{background-color:#e6e6e6;} +.btn-group.open .btn-primary.dropdown-toggle{background-color:#0044cc;} +.btn-group.open .btn-warning.dropdown-toggle{background-color:#f89406;} +.btn-group.open .btn-danger.dropdown-toggle{background-color:#bd362f;} +.btn-group.open .btn-success.dropdown-toggle{background-color:#51a351;} +.btn-group.open .btn-info.dropdown-toggle{background-color:#2f96b4;} +.btn-group.open .btn-inverse.dropdown-toggle{background-color:#222222;} +.btn .caret{margin-top:8px;margin-left:0;} +.btn-large .caret{margin-top:6px;} +.btn-large .caret{border-left-width:5px;border-right-width:5px;border-top-width:5px;} +.btn-mini .caret,.btn-small .caret{margin-top:8px;} +.dropup .btn-large .caret{border-bottom-width:5px;} +.btn-primary .caret,.btn-warning 
.caret,.btn-danger .caret,.btn-info .caret,.btn-success .caret,.btn-inverse .caret{border-top-color:#ffffff;border-bottom-color:#ffffff;} +.btn-group-vertical{display:inline-block;*display:inline;*zoom:1;} +.btn-group-vertical>.btn{display:block;float:none;max-width:100%;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;} +.btn-group-vertical>.btn+.btn{margin-left:0;margin-top:-1px;} +.btn-group-vertical>.btn:first-child{-webkit-border-radius:4px 4px 0 0;-moz-border-radius:4px 4px 0 0;border-radius:4px 4px 0 0;} +.btn-group-vertical>.btn:last-child{-webkit-border-radius:0 0 4px 4px;-moz-border-radius:0 0 4px 4px;border-radius:0 0 4px 4px;} +.btn-group-vertical>.btn-large:first-child{-webkit-border-radius:6px 6px 0 0;-moz-border-radius:6px 6px 0 0;border-radius:6px 6px 0 0;} +.btn-group-vertical>.btn-large:last-child{-webkit-border-radius:0 0 6px 6px;-moz-border-radius:0 0 6px 6px;border-radius:0 0 6px 6px;} +.nav{margin-left:0;margin-bottom:20px;list-style:none;} +.nav>li>a{display:block;} +.nav>li>a:hover,.nav>li>a:focus{text-decoration:none;background-color:#eeeeee;} +.nav>li>a>img{max-width:none;} +.nav>.pull-right{float:right;} +.nav-header{display:block;padding:3px 15px;font-size:11px;font-weight:bold;line-height:20px;color:#999999;text-shadow:0 1px 0 rgba(255, 255, 255, 0.5);text-transform:uppercase;} +.nav li+.nav-header{margin-top:9px;} +.nav-list{padding-left:15px;padding-right:15px;margin-bottom:0;} +.nav-list>li>a,.nav-list .nav-header{margin-left:-15px;margin-right:-15px;text-shadow:0 1px 0 rgba(255, 255, 255, 0.5);} +.nav-list>li>a{padding:3px 15px;} +.nav-list>.active>a,.nav-list>.active>a:hover,.nav-list>.active>a:focus{color:#ffffff;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.2);background-color:#0088cc;} +.nav-list [class^="icon-"],.nav-list [class*=" icon-"]{margin-right:2px;} +.nav-list .divider{*width:100%;height:1px;margin:9px 1px;*margin:-5px 0 5px;overflow:hidden;background-color:#e5e5e5;border-bottom:1px solid #ffffff;} +.nav-tabs,.nav-pills{*zoom:1;}.nav-tabs:before,.nav-pills:before,.nav-tabs:after,.nav-pills:after{display:table;content:"";line-height:0;} +.nav-tabs:after,.nav-pills:after{clear:both;} +.nav-tabs>li,.nav-pills>li{float:left;} +.nav-tabs>li>a,.nav-pills>li>a{padding-right:12px;padding-left:12px;margin-right:2px;line-height:14px;} +.nav-tabs{border-bottom:1px solid #ddd;} +.nav-tabs>li{margin-bottom:-1px;} +.nav-tabs>li>a{padding-top:8px;padding-bottom:8px;line-height:20px;border:1px solid transparent;-webkit-border-radius:4px 4px 0 0;-moz-border-radius:4px 4px 0 0;border-radius:4px 4px 0 0;}.nav-tabs>li>a:hover,.nav-tabs>li>a:focus{border-color:#eeeeee #eeeeee #dddddd;} +.nav-tabs>.active>a,.nav-tabs>.active>a:hover,.nav-tabs>.active>a:focus{color:#555555;background-color:#ffffff;border:1px solid #ddd;border-bottom-color:transparent;cursor:default;} +.nav-pills>li>a{padding-top:8px;padding-bottom:8px;margin-top:2px;margin-bottom:2px;-webkit-border-radius:5px;-moz-border-radius:5px;border-radius:5px;} +.nav-pills>.active>a,.nav-pills>.active>a:hover,.nav-pills>.active>a:focus{color:#ffffff;background-color:#0088cc;} +.nav-stacked>li{float:none;} +.nav-stacked>li>a{margin-right:0;} +.nav-tabs.nav-stacked{border-bottom:0;} +.nav-tabs.nav-stacked>li>a{border:1px solid #ddd;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;} 
+.nav-tabs.nav-stacked>li:first-child>a{-webkit-border-top-right-radius:4px;-moz-border-radius-topright:4px;border-top-right-radius:4px;-webkit-border-top-left-radius:4px;-moz-border-radius-topleft:4px;border-top-left-radius:4px;} +.nav-tabs.nav-stacked>li:last-child>a{-webkit-border-bottom-right-radius:4px;-moz-border-radius-bottomright:4px;border-bottom-right-radius:4px;-webkit-border-bottom-left-radius:4px;-moz-border-radius-bottomleft:4px;border-bottom-left-radius:4px;} +.nav-tabs.nav-stacked>li>a:hover,.nav-tabs.nav-stacked>li>a:focus{border-color:#ddd;z-index:2;} +.nav-pills.nav-stacked>li>a{margin-bottom:3px;} +.nav-pills.nav-stacked>li:last-child>a{margin-bottom:1px;} +.nav-tabs .dropdown-menu{-webkit-border-radius:0 0 6px 6px;-moz-border-radius:0 0 6px 6px;border-radius:0 0 6px 6px;} +.nav-pills .dropdown-menu{-webkit-border-radius:6px;-moz-border-radius:6px;border-radius:6px;} +.nav .dropdown-toggle .caret{border-top-color:#0088cc;border-bottom-color:#0088cc;margin-top:6px;} +.nav .dropdown-toggle:hover .caret,.nav .dropdown-toggle:focus .caret{border-top-color:#005580;border-bottom-color:#005580;} +.nav-tabs .dropdown-toggle .caret{margin-top:8px;} +.nav .active .dropdown-toggle .caret{border-top-color:#fff;border-bottom-color:#fff;} +.nav-tabs .active .dropdown-toggle .caret{border-top-color:#555555;border-bottom-color:#555555;} +.nav>.dropdown.active>a:hover,.nav>.dropdown.active>a:focus{cursor:pointer;} +.nav-tabs .open .dropdown-toggle,.nav-pills .open .dropdown-toggle,.nav>li.dropdown.open.active>a:hover,.nav>li.dropdown.open.active>a:focus{color:#ffffff;background-color:#999999;border-color:#999999;} +.nav li.dropdown.open .caret,.nav li.dropdown.open.active .caret,.nav li.dropdown.open a:hover .caret,.nav li.dropdown.open a:focus .caret{border-top-color:#ffffff;border-bottom-color:#ffffff;opacity:1;filter:alpha(opacity=100);} +.tabs-stacked .open>a:hover,.tabs-stacked .open>a:focus{border-color:#999999;} +.tabbable{*zoom:1;}.tabbable:before,.tabbable:after{display:table;content:"";line-height:0;} +.tabbable:after{clear:both;} +.tab-content{overflow:auto;} +.tabs-below>.nav-tabs,.tabs-right>.nav-tabs,.tabs-left>.nav-tabs{border-bottom:0;} +.tab-content>.tab-pane,.pill-content>.pill-pane{display:none;} +.tab-content>.active,.pill-content>.active{display:block;} +.tabs-below>.nav-tabs{border-top:1px solid #ddd;} +.tabs-below>.nav-tabs>li{margin-top:-1px;margin-bottom:0;} +.tabs-below>.nav-tabs>li>a{-webkit-border-radius:0 0 4px 4px;-moz-border-radius:0 0 4px 4px;border-radius:0 0 4px 4px;}.tabs-below>.nav-tabs>li>a:hover,.tabs-below>.nav-tabs>li>a:focus{border-bottom-color:transparent;border-top-color:#ddd;} +.tabs-below>.nav-tabs>.active>a,.tabs-below>.nav-tabs>.active>a:hover,.tabs-below>.nav-tabs>.active>a:focus{border-color:transparent #ddd #ddd #ddd;} +.tabs-left>.nav-tabs>li,.tabs-right>.nav-tabs>li{float:none;} +.tabs-left>.nav-tabs>li>a,.tabs-right>.nav-tabs>li>a{min-width:74px;margin-right:0;margin-bottom:3px;} +.tabs-left>.nav-tabs{float:left;margin-right:19px;border-right:1px solid #ddd;} +.tabs-left>.nav-tabs>li>a{margin-right:-1px;-webkit-border-radius:4px 0 0 4px;-moz-border-radius:4px 0 0 4px;border-radius:4px 0 0 4px;} +.tabs-left>.nav-tabs>li>a:hover,.tabs-left>.nav-tabs>li>a:focus{border-color:#eeeeee #dddddd #eeeeee #eeeeee;} +.tabs-left>.nav-tabs .active>a,.tabs-left>.nav-tabs .active>a:hover,.tabs-left>.nav-tabs .active>a:focus{border-color:#ddd transparent #ddd #ddd;*border-right-color:#ffffff;} 
+.tabs-right>.nav-tabs{float:right;margin-left:19px;border-left:1px solid #ddd;} +.tabs-right>.nav-tabs>li>a{margin-left:-1px;-webkit-border-radius:0 4px 4px 0;-moz-border-radius:0 4px 4px 0;border-radius:0 4px 4px 0;} +.tabs-right>.nav-tabs>li>a:hover,.tabs-right>.nav-tabs>li>a:focus{border-color:#eeeeee #eeeeee #eeeeee #dddddd;} +.tabs-right>.nav-tabs .active>a,.tabs-right>.nav-tabs .active>a:hover,.tabs-right>.nav-tabs .active>a:focus{border-color:#ddd #ddd #ddd transparent;*border-left-color:#ffffff;} +.nav>.disabled>a{color:#999999;} +.nav>.disabled>a:hover,.nav>.disabled>a:focus{text-decoration:none;background-color:transparent;cursor:default;} +.navbar{overflow:visible;margin-bottom:20px;*position:relative;*z-index:2;} +.navbar-inner{min-height:40px;padding-left:20px;padding-right:20px;background-color:#fafafa;background-image:-moz-linear-gradient(top, #ffffff, #f2f2f2);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#ffffff), to(#f2f2f2));background-image:-webkit-linear-gradient(top, #ffffff, #f2f2f2);background-image:-o-linear-gradient(top, #ffffff, #f2f2f2);background-image:linear-gradient(to bottom, #ffffff, #f2f2f2);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ffffffff', endColorstr='#fff2f2f2', GradientType=0);border:1px solid #d4d4d4;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;-webkit-box-shadow:0 1px 4px rgba(0, 0, 0, 0.065);-moz-box-shadow:0 1px 4px rgba(0, 0, 0, 0.065);box-shadow:0 1px 4px rgba(0, 0, 0, 0.065);*zoom:1;}.navbar-inner:before,.navbar-inner:after{display:table;content:"";line-height:0;} +.navbar-inner:after{clear:both;} +.navbar .container{width:auto;} +.nav-collapse.collapse{height:auto;overflow:visible;} +.navbar .brand{float:left;display:block;padding:10px 20px 10px;margin-left:-20px;font-size:20px;font-weight:200;color:#777777;text-shadow:0 1px 0 #ffffff;}.navbar .brand:hover,.navbar .brand:focus{text-decoration:none;} +.navbar-text{margin-bottom:0;line-height:40px;color:#777777;} +.navbar-link{color:#777777;}.navbar-link:hover,.navbar-link:focus{color:#333333;} +.navbar .divider-vertical{height:40px;margin:0 9px;border-left:1px solid #f2f2f2;border-right:1px solid #ffffff;} +.navbar .btn,.navbar .btn-group{margin-top:5px;} +.navbar .btn-group .btn,.navbar .input-prepend .btn,.navbar .input-append .btn,.navbar .input-prepend .btn-group,.navbar .input-append .btn-group{margin-top:0;} +.navbar-form{margin-bottom:0;*zoom:1;}.navbar-form:before,.navbar-form:after{display:table;content:"";line-height:0;} +.navbar-form:after{clear:both;} +.navbar-form input,.navbar-form select,.navbar-form .radio,.navbar-form .checkbox{margin-top:5px;} +.navbar-form input,.navbar-form select,.navbar-form .btn{display:inline-block;margin-bottom:0;} +.navbar-form input[type="image"],.navbar-form input[type="checkbox"],.navbar-form input[type="radio"]{margin-top:3px;} +.navbar-form .input-append,.navbar-form .input-prepend{margin-top:5px;white-space:nowrap;}.navbar-form .input-append input,.navbar-form .input-prepend input{margin-top:0;} +.navbar-search{position:relative;float:left;margin-top:5px;margin-bottom:0;}.navbar-search .search-query{margin-bottom:0;padding:4px 14px;font-family:"Helvetica Neue",Helvetica,Arial,sans-serif;font-size:13px;font-weight:normal;line-height:1;-webkit-border-radius:15px;-moz-border-radius:15px;border-radius:15px;} +.navbar-static-top{position:static;margin-bottom:0;}.navbar-static-top .navbar-inner{-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;} 
+.navbar-fixed-top,.navbar-fixed-bottom{position:fixed;right:0;left:0;z-index:1030;margin-bottom:0;} +.navbar-fixed-top .navbar-inner,.navbar-static-top .navbar-inner{border-width:0 0 1px;} +.navbar-fixed-bottom .navbar-inner{border-width:1px 0 0;} +.navbar-fixed-top .navbar-inner,.navbar-fixed-bottom .navbar-inner{padding-left:0;padding-right:0;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;} +.navbar-static-top .container,.navbar-fixed-top .container,.navbar-fixed-bottom .container{width:940px;} +.navbar-fixed-top{top:0;} +.navbar-fixed-top .navbar-inner,.navbar-static-top .navbar-inner{-webkit-box-shadow:0 1px 10px rgba(0,0,0,.1);-moz-box-shadow:0 1px 10px rgba(0,0,0,.1);box-shadow:0 1px 10px rgba(0,0,0,.1);} +.navbar-fixed-bottom{bottom:0;}.navbar-fixed-bottom .navbar-inner{-webkit-box-shadow:0 -1px 10px rgba(0,0,0,.1);-moz-box-shadow:0 -1px 10px rgba(0,0,0,.1);box-shadow:0 -1px 10px rgba(0,0,0,.1);} +.navbar .nav{position:relative;left:0;display:block;float:left;margin:0 10px 0 0;} +.navbar .nav.pull-right{float:right;margin-right:0;} +.navbar .nav>li{float:left;} +.navbar .nav>li>a{float:none;padding:10px 15px 10px;color:#777777;text-decoration:none;text-shadow:0 1px 0 #ffffff;} +.navbar .nav .dropdown-toggle .caret{margin-top:8px;} +.navbar .nav>li>a:focus,.navbar .nav>li>a:hover{background-color:transparent;color:#333333;text-decoration:none;} +.navbar .nav>.active>a,.navbar .nav>.active>a:hover,.navbar .nav>.active>a:focus{color:#555555;text-decoration:none;background-color:#e5e5e5;-webkit-box-shadow:inset 0 3px 8px rgba(0, 0, 0, 0.125);-moz-box-shadow:inset 0 3px 8px rgba(0, 0, 0, 0.125);box-shadow:inset 0 3px 8px rgba(0, 0, 0, 0.125);} +.navbar .btn-navbar{display:none;float:right;padding:7px 10px;margin-left:5px;margin-right:5px;color:#ffffff;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25);background-color:#ededed;background-image:-moz-linear-gradient(top, #f2f2f2, #e5e5e5);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#f2f2f2), to(#e5e5e5));background-image:-webkit-linear-gradient(top, #f2f2f2, #e5e5e5);background-image:-o-linear-gradient(top, #f2f2f2, #e5e5e5);background-image:linear-gradient(to bottom, #f2f2f2, #e5e5e5);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#fff2f2f2', endColorstr='#ffe5e5e5', GradientType=0);border-color:#e5e5e5 #e5e5e5 #bfbfbf;border-color:rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.25);*background-color:#e5e5e5;filter:progid:DXImageTransform.Microsoft.gradient(enabled = false);-webkit-box-shadow:inset 0 1px 0 rgba(255,255,255,.1), 0 1px 0 rgba(255,255,255,.075);-moz-box-shadow:inset 0 1px 0 rgba(255,255,255,.1), 0 1px 0 rgba(255,255,255,.075);box-shadow:inset 0 1px 0 rgba(255,255,255,.1), 0 1px 0 rgba(255,255,255,.075);}.navbar .btn-navbar:hover,.navbar .btn-navbar:focus,.navbar .btn-navbar:active,.navbar .btn-navbar.active,.navbar .btn-navbar.disabled,.navbar .btn-navbar[disabled]{color:#ffffff;background-color:#e5e5e5;*background-color:#d9d9d9;} +.navbar .btn-navbar:active,.navbar .btn-navbar.active{background-color:#cccccc \9;} +.navbar .btn-navbar .icon-bar{display:block;width:18px;height:2px;background-color:#f5f5f5;-webkit-border-radius:1px;-moz-border-radius:1px;border-radius:1px;-webkit-box-shadow:0 1px 0 rgba(0, 0, 0, 0.25);-moz-box-shadow:0 1px 0 rgba(0, 0, 0, 0.25);box-shadow:0 1px 0 rgba(0, 0, 0, 0.25);} +.btn-navbar .icon-bar+.icon-bar{margin-top:3px;} +.navbar .nav>li>.dropdown-menu:before{content:'';display:inline-block;border-left:7px solid 
transparent;border-right:7px solid transparent;border-bottom:7px solid #ccc;border-bottom-color:rgba(0, 0, 0, 0.2);position:absolute;top:-7px;left:9px;} +.navbar .nav>li>.dropdown-menu:after{content:'';display:inline-block;border-left:6px solid transparent;border-right:6px solid transparent;border-bottom:6px solid #ffffff;position:absolute;top:-6px;left:10px;} +.navbar-fixed-bottom .nav>li>.dropdown-menu:before{border-top:7px solid #ccc;border-top-color:rgba(0, 0, 0, 0.2);border-bottom:0;bottom:-7px;top:auto;} +.navbar-fixed-bottom .nav>li>.dropdown-menu:after{border-top:6px solid #ffffff;border-bottom:0;bottom:-6px;top:auto;} +.navbar .nav li.dropdown>a:hover .caret,.navbar .nav li.dropdown>a:focus .caret{border-top-color:#333333;border-bottom-color:#333333;} +.navbar .nav li.dropdown.open>.dropdown-toggle,.navbar .nav li.dropdown.active>.dropdown-toggle,.navbar .nav li.dropdown.open.active>.dropdown-toggle{background-color:#e5e5e5;color:#555555;} +.navbar .nav li.dropdown>.dropdown-toggle .caret{border-top-color:#777777;border-bottom-color:#777777;} +.navbar .nav li.dropdown.open>.dropdown-toggle .caret,.navbar .nav li.dropdown.active>.dropdown-toggle .caret,.navbar .nav li.dropdown.open.active>.dropdown-toggle .caret{border-top-color:#555555;border-bottom-color:#555555;} +.navbar .pull-right>li>.dropdown-menu,.navbar .nav>li>.dropdown-menu.pull-right{left:auto;right:0;}.navbar .pull-right>li>.dropdown-menu:before,.navbar .nav>li>.dropdown-menu.pull-right:before{left:auto;right:12px;} +.navbar .pull-right>li>.dropdown-menu:after,.navbar .nav>li>.dropdown-menu.pull-right:after{left:auto;right:13px;} +.navbar .pull-right>li>.dropdown-menu .dropdown-menu,.navbar .nav>li>.dropdown-menu.pull-right .dropdown-menu{left:auto;right:100%;margin-left:0;margin-right:-1px;-webkit-border-radius:6px 0 6px 6px;-moz-border-radius:6px 0 6px 6px;border-radius:6px 0 6px 6px;} +.navbar-inverse .navbar-inner{background-color:#1b1b1b;background-image:-moz-linear-gradient(top, #222222, #111111);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#222222), to(#111111));background-image:-webkit-linear-gradient(top, #222222, #111111);background-image:-o-linear-gradient(top, #222222, #111111);background-image:linear-gradient(to bottom, #222222, #111111);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ff222222', endColorstr='#ff111111', GradientType=0);border-color:#252525;} +.navbar-inverse .brand,.navbar-inverse .nav>li>a{color:#999999;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25);}.navbar-inverse .brand:hover,.navbar-inverse .nav>li>a:hover,.navbar-inverse .brand:focus,.navbar-inverse .nav>li>a:focus{color:#ffffff;} +.navbar-inverse .brand{color:#999999;} +.navbar-inverse .navbar-text{color:#999999;} +.navbar-inverse .nav>li>a:focus,.navbar-inverse .nav>li>a:hover{background-color:transparent;color:#ffffff;} +.navbar-inverse .nav .active>a,.navbar-inverse .nav .active>a:hover,.navbar-inverse .nav .active>a:focus{color:#ffffff;background-color:#111111;} +.navbar-inverse .navbar-link{color:#999999;}.navbar-inverse .navbar-link:hover,.navbar-inverse .navbar-link:focus{color:#ffffff;} +.navbar-inverse .divider-vertical{border-left-color:#111111;border-right-color:#222222;} +.navbar-inverse .nav li.dropdown.open>.dropdown-toggle,.navbar-inverse .nav li.dropdown.active>.dropdown-toggle,.navbar-inverse .nav li.dropdown.open.active>.dropdown-toggle{background-color:#111111;color:#ffffff;} +.navbar-inverse .nav li.dropdown>a:hover .caret,.navbar-inverse .nav 
li.dropdown>a:focus .caret{border-top-color:#ffffff;border-bottom-color:#ffffff;} +.navbar-inverse .nav li.dropdown>.dropdown-toggle .caret{border-top-color:#999999;border-bottom-color:#999999;} +.navbar-inverse .nav li.dropdown.open>.dropdown-toggle .caret,.navbar-inverse .nav li.dropdown.active>.dropdown-toggle .caret,.navbar-inverse .nav li.dropdown.open.active>.dropdown-toggle .caret{border-top-color:#ffffff;border-bottom-color:#ffffff;} +.navbar-inverse .navbar-search .search-query{color:#ffffff;background-color:#515151;border-color:#111111;-webkit-box-shadow:inset 0 1px 2px rgba(0,0,0,.1), 0 1px 0 rgba(255,255,255,.15);-moz-box-shadow:inset 0 1px 2px rgba(0,0,0,.1), 0 1px 0 rgba(255,255,255,.15);box-shadow:inset 0 1px 2px rgba(0,0,0,.1), 0 1px 0 rgba(255,255,255,.15);-webkit-transition:none;-moz-transition:none;-o-transition:none;transition:none;}.navbar-inverse .navbar-search .search-query:-moz-placeholder{color:#cccccc;} +.navbar-inverse .navbar-search .search-query:-ms-input-placeholder{color:#cccccc;} +.navbar-inverse .navbar-search .search-query::-webkit-input-placeholder{color:#cccccc;} +.navbar-inverse .navbar-search .search-query:focus,.navbar-inverse .navbar-search .search-query.focused{padding:5px 15px;color:#333333;text-shadow:0 1px 0 #ffffff;background-color:#ffffff;border:0;-webkit-box-shadow:0 0 3px rgba(0, 0, 0, 0.15);-moz-box-shadow:0 0 3px rgba(0, 0, 0, 0.15);box-shadow:0 0 3px rgba(0, 0, 0, 0.15);outline:0;} +.navbar-inverse .btn-navbar{color:#ffffff;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25);background-color:#0e0e0e;background-image:-moz-linear-gradient(top, #151515, #040404);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#151515), to(#040404));background-image:-webkit-linear-gradient(top, #151515, #040404);background-image:-o-linear-gradient(top, #151515, #040404);background-image:linear-gradient(to bottom, #151515, #040404);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ff151515', endColorstr='#ff040404', GradientType=0);border-color:#040404 #040404 #000000;border-color:rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.1) rgba(0, 0, 0, 0.25);*background-color:#040404;filter:progid:DXImageTransform.Microsoft.gradient(enabled = false);}.navbar-inverse .btn-navbar:hover,.navbar-inverse .btn-navbar:focus,.navbar-inverse .btn-navbar:active,.navbar-inverse .btn-navbar.active,.navbar-inverse .btn-navbar.disabled,.navbar-inverse .btn-navbar[disabled]{color:#ffffff;background-color:#040404;*background-color:#000000;} +.navbar-inverse .btn-navbar:active,.navbar-inverse .btn-navbar.active{background-color:#000000 \9;} +.breadcrumb{padding:8px 15px;margin:0 0 20px;list-style:none;background-color:#f5f5f5;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;}.breadcrumb>li{display:inline-block;*display:inline;*zoom:1;text-shadow:0 1px 0 #ffffff;}.breadcrumb>li>.divider{padding:0 5px;color:#ccc;} +.breadcrumb>.active{color:#999999;} +.pagination{margin:20px 0;} +.pagination ul{display:inline-block;*display:inline;*zoom:1;margin-left:0;margin-bottom:0;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;-webkit-box-shadow:0 1px 2px rgba(0, 0, 0, 0.05);-moz-box-shadow:0 1px 2px rgba(0, 0, 0, 0.05);box-shadow:0 1px 2px rgba(0, 0, 0, 0.05);} +.pagination ul>li{display:inline;} +.pagination ul>li>a,.pagination ul>li>span{float:left;padding:4px 12px;line-height:20px;text-decoration:none;background-color:#ffffff;border:1px solid #dddddd;border-left-width:0;} +.pagination ul>li>a:hover,.pagination 
ul>li>a:focus,.pagination ul>.active>a,.pagination ul>.active>span{background-color:#f5f5f5;} +.pagination ul>.active>a,.pagination ul>.active>span{color:#999999;cursor:default;} +.pagination ul>.disabled>span,.pagination ul>.disabled>a,.pagination ul>.disabled>a:hover,.pagination ul>.disabled>a:focus{color:#999999;background-color:transparent;cursor:default;} +.pagination ul>li:first-child>a,.pagination ul>li:first-child>span{border-left-width:1px;-webkit-border-top-left-radius:4px;-moz-border-radius-topleft:4px;border-top-left-radius:4px;-webkit-border-bottom-left-radius:4px;-moz-border-radius-bottomleft:4px;border-bottom-left-radius:4px;} +.pagination ul>li:last-child>a,.pagination ul>li:last-child>span{-webkit-border-top-right-radius:4px;-moz-border-radius-topright:4px;border-top-right-radius:4px;-webkit-border-bottom-right-radius:4px;-moz-border-radius-bottomright:4px;border-bottom-right-radius:4px;} +.pagination-centered{text-align:center;} +.pagination-right{text-align:right;} +.pagination-large ul>li>a,.pagination-large ul>li>span{padding:11px 19px;font-size:17.5px;} +.pagination-large ul>li:first-child>a,.pagination-large ul>li:first-child>span{-webkit-border-top-left-radius:6px;-moz-border-radius-topleft:6px;border-top-left-radius:6px;-webkit-border-bottom-left-radius:6px;-moz-border-radius-bottomleft:6px;border-bottom-left-radius:6px;} +.pagination-large ul>li:last-child>a,.pagination-large ul>li:last-child>span{-webkit-border-top-right-radius:6px;-moz-border-radius-topright:6px;border-top-right-radius:6px;-webkit-border-bottom-right-radius:6px;-moz-border-radius-bottomright:6px;border-bottom-right-radius:6px;} +.pagination-mini ul>li:first-child>a,.pagination-small ul>li:first-child>a,.pagination-mini ul>li:first-child>span,.pagination-small ul>li:first-child>span{-webkit-border-top-left-radius:3px;-moz-border-radius-topleft:3px;border-top-left-radius:3px;-webkit-border-bottom-left-radius:3px;-moz-border-radius-bottomleft:3px;border-bottom-left-radius:3px;} +.pagination-mini ul>li:last-child>a,.pagination-small ul>li:last-child>a,.pagination-mini ul>li:last-child>span,.pagination-small ul>li:last-child>span{-webkit-border-top-right-radius:3px;-moz-border-radius-topright:3px;border-top-right-radius:3px;-webkit-border-bottom-right-radius:3px;-moz-border-radius-bottomright:3px;border-bottom-right-radius:3px;} +.pagination-small ul>li>a,.pagination-small ul>li>span{padding:2px 10px;font-size:11.9px;} +.pagination-mini ul>li>a,.pagination-mini ul>li>span{padding:0 6px;font-size:10.5px;} +.pager{margin:20px 0;list-style:none;text-align:center;*zoom:1;}.pager:before,.pager:after{display:table;content:"";line-height:0;} +.pager:after{clear:both;} +.pager li{display:inline;} +.pager li>a,.pager li>span{display:inline-block;padding:5px 14px;background-color:#fff;border:1px solid #ddd;-webkit-border-radius:15px;-moz-border-radius:15px;border-radius:15px;} +.pager li>a:hover,.pager li>a:focus{text-decoration:none;background-color:#f5f5f5;} +.pager .next>a,.pager .next>span{float:right;} +.pager .previous>a,.pager .previous>span{float:left;} +.pager .disabled>a,.pager .disabled>a:hover,.pager .disabled>a:focus,.pager .disabled>span{color:#999999;background-color:#fff;cursor:default;} +.thumbnails{margin-left:-20px;list-style:none;*zoom:1;}.thumbnails:before,.thumbnails:after{display:table;content:"";line-height:0;} +.thumbnails:after{clear:both;} +.row-fluid .thumbnails{margin-left:0;} +.thumbnails>li{float:left;margin-bottom:20px;margin-left:20px;} 
+.thumbnail{display:block;padding:4px;line-height:20px;border:1px solid #ddd;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;-webkit-box-shadow:0 1px 3px rgba(0, 0, 0, 0.055);-moz-box-shadow:0 1px 3px rgba(0, 0, 0, 0.055);box-shadow:0 1px 3px rgba(0, 0, 0, 0.055);-webkit-transition:all 0.2s ease-in-out;-moz-transition:all 0.2s ease-in-out;-o-transition:all 0.2s ease-in-out;transition:all 0.2s ease-in-out;} +a.thumbnail:hover,a.thumbnail:focus{border-color:#0088cc;-webkit-box-shadow:0 1px 4px rgba(0, 105, 214, 0.25);-moz-box-shadow:0 1px 4px rgba(0, 105, 214, 0.25);box-shadow:0 1px 4px rgba(0, 105, 214, 0.25);} +.thumbnail>img{display:block;max-width:100%;margin-left:auto;margin-right:auto;} +.thumbnail .caption{padding:9px;color:#555555;} +.alert{padding:8px 35px 8px 14px;margin-bottom:20px;text-shadow:0 1px 0 rgba(255, 255, 255, 0.5);background-color:#fcf8e3;border:1px solid #fbeed5;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;} +.alert,.alert h4{color:#c09853;} +.alert h4{margin:0;} +.alert .close{position:relative;top:-2px;right:-21px;line-height:20px;} +.alert-success{background-color:#dff0d8;border-color:#d6e9c6;color:#468847;} +.alert-success h4{color:#468847;} +.alert-danger,.alert-error{background-color:#f2dede;border-color:#eed3d7;color:#b94a48;} +.alert-danger h4,.alert-error h4{color:#b94a48;} +.alert-info{background-color:#d9edf7;border-color:#bce8f1;color:#3a87ad;} +.alert-info h4{color:#3a87ad;} +.alert-block{padding-top:14px;padding-bottom:14px;} +.alert-block>p,.alert-block>ul{margin-bottom:0;} +.alert-block p+p{margin-top:5px;} +@-webkit-keyframes progress-bar-stripes{from{background-position:40px 0;} to{background-position:0 0;}}@-moz-keyframes progress-bar-stripes{from{background-position:40px 0;} to{background-position:0 0;}}@-ms-keyframes progress-bar-stripes{from{background-position:40px 0;} to{background-position:0 0;}}@-o-keyframes progress-bar-stripes{from{background-position:0 0;} to{background-position:40px 0;}}@keyframes progress-bar-stripes{from{background-position:40px 0;} to{background-position:0 0;}}.progress{overflow:hidden;height:20px;margin-bottom:20px;background-color:#f7f7f7;background-image:-moz-linear-gradient(top, #f5f5f5, #f9f9f9);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#f5f5f5), to(#f9f9f9));background-image:-webkit-linear-gradient(top, #f5f5f5, #f9f9f9);background-image:-o-linear-gradient(top, #f5f5f5, #f9f9f9);background-image:linear-gradient(to bottom, #f5f5f5, #f9f9f9);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#fff5f5f5', endColorstr='#fff9f9f9', GradientType=0);-webkit-box-shadow:inset 0 1px 2px rgba(0, 0, 0, 0.1);-moz-box-shadow:inset 0 1px 2px rgba(0, 0, 0, 0.1);box-shadow:inset 0 1px 2px rgba(0, 0, 0, 0.1);-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;} +.progress .bar{width:0%;height:100%;color:#ffffff;float:left;font-size:12px;text-align:center;text-shadow:0 -1px 0 rgba(0, 0, 0, 0.25);background-color:#0e90d2;background-image:-moz-linear-gradient(top, #149bdf, #0480be);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#149bdf), to(#0480be));background-image:-webkit-linear-gradient(top, #149bdf, #0480be);background-image:-o-linear-gradient(top, #149bdf, #0480be);background-image:linear-gradient(to bottom, #149bdf, #0480be);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ff149bdf', endColorstr='#ff0480be', GradientType=0);-webkit-box-shadow:inset 0 -1px 0 
rgba(0, 0, 0, 0.15);-moz-box-shadow:inset 0 -1px 0 rgba(0, 0, 0, 0.15);box-shadow:inset 0 -1px 0 rgba(0, 0, 0, 0.15);-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;-webkit-transition:width 0.6s ease;-moz-transition:width 0.6s ease;-o-transition:width 0.6s ease;transition:width 0.6s ease;} +.progress .bar+.bar{-webkit-box-shadow:inset 1px 0 0 rgba(0,0,0,.15), inset 0 -1px 0 rgba(0,0,0,.15);-moz-box-shadow:inset 1px 0 0 rgba(0,0,0,.15), inset 0 -1px 0 rgba(0,0,0,.15);box-shadow:inset 1px 0 0 rgba(0,0,0,.15), inset 0 -1px 0 rgba(0,0,0,.15);} +.progress-striped .bar{background-color:#149bdf;background-image:-webkit-gradient(linear, 0 100%, 100% 0, color-stop(0.25, rgba(255, 255, 255, 0.15)), color-stop(0.25, transparent), color-stop(0.5, transparent), color-stop(0.5, rgba(255, 255, 255, 0.15)), color-stop(0.75, rgba(255, 255, 255, 0.15)), color-stop(0.75, transparent), to(transparent));background-image:-webkit-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:-moz-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:-o-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);-webkit-background-size:40px 40px;-moz-background-size:40px 40px;-o-background-size:40px 40px;background-size:40px 40px;} +.progress.active .bar{-webkit-animation:progress-bar-stripes 2s linear infinite;-moz-animation:progress-bar-stripes 2s linear infinite;-ms-animation:progress-bar-stripes 2s linear infinite;-o-animation:progress-bar-stripes 2s linear infinite;animation:progress-bar-stripes 2s linear infinite;} +.progress-danger .bar,.progress .bar-danger{background-color:#dd514c;background-image:-moz-linear-gradient(top, #ee5f5b, #c43c35);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#ee5f5b), to(#c43c35));background-image:-webkit-linear-gradient(top, #ee5f5b, #c43c35);background-image:-o-linear-gradient(top, #ee5f5b, #c43c35);background-image:linear-gradient(to bottom, #ee5f5b, #c43c35);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ffee5f5b', endColorstr='#ffc43c35', GradientType=0);} +.progress-danger.progress-striped .bar,.progress-striped .bar-danger{background-color:#ee5f5b;background-image:-webkit-gradient(linear, 0 100%, 100% 0, color-stop(0.25, rgba(255, 255, 255, 0.15)), color-stop(0.25, transparent), color-stop(0.5, transparent), color-stop(0.5, rgba(255, 255, 255, 0.15)), color-stop(0.75, rgba(255, 255, 255, 0.15)), color-stop(0.75, transparent), to(transparent));background-image:-webkit-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:-moz-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:-o-linear-gradient(45deg, rgba(255, 255, 
255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);} +.progress-success .bar,.progress .bar-success{background-color:#5eb95e;background-image:-moz-linear-gradient(top, #62c462, #57a957);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#62c462), to(#57a957));background-image:-webkit-linear-gradient(top, #62c462, #57a957);background-image:-o-linear-gradient(top, #62c462, #57a957);background-image:linear-gradient(to bottom, #62c462, #57a957);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ff62c462', endColorstr='#ff57a957', GradientType=0);} +.progress-success.progress-striped .bar,.progress-striped .bar-success{background-color:#62c462;background-image:-webkit-gradient(linear, 0 100%, 100% 0, color-stop(0.25, rgba(255, 255, 255, 0.15)), color-stop(0.25, transparent), color-stop(0.5, transparent), color-stop(0.5, rgba(255, 255, 255, 0.15)), color-stop(0.75, rgba(255, 255, 255, 0.15)), color-stop(0.75, transparent), to(transparent));background-image:-webkit-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:-moz-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:-o-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);} +.progress-info .bar,.progress .bar-info{background-color:#4bb1cf;background-image:-moz-linear-gradient(top, #5bc0de, #339bb9);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#5bc0de), to(#339bb9));background-image:-webkit-linear-gradient(top, #5bc0de, #339bb9);background-image:-o-linear-gradient(top, #5bc0de, #339bb9);background-image:linear-gradient(to bottom, #5bc0de, #339bb9);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ff5bc0de', endColorstr='#ff339bb9', GradientType=0);} +.progress-info.progress-striped .bar,.progress-striped .bar-info{background-color:#5bc0de;background-image:-webkit-gradient(linear, 0 100%, 100% 0, color-stop(0.25, rgba(255, 255, 255, 0.15)), color-stop(0.25, transparent), color-stop(0.5, transparent), color-stop(0.5, rgba(255, 255, 255, 0.15)), color-stop(0.75, rgba(255, 255, 255, 0.15)), color-stop(0.75, transparent), to(transparent));background-image:-webkit-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:-moz-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:-o-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 
255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);} +.progress-warning .bar,.progress .bar-warning{background-color:#faa732;background-image:-moz-linear-gradient(top, #fbb450, #f89406);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#fbb450), to(#f89406));background-image:-webkit-linear-gradient(top, #fbb450, #f89406);background-image:-o-linear-gradient(top, #fbb450, #f89406);background-image:linear-gradient(to bottom, #fbb450, #f89406);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#fffbb450', endColorstr='#fff89406', GradientType=0);} +.progress-warning.progress-striped .bar,.progress-striped .bar-warning{background-color:#fbb450;background-image:-webkit-gradient(linear, 0 100%, 100% 0, color-stop(0.25, rgba(255, 255, 255, 0.15)), color-stop(0.25, transparent), color-stop(0.5, transparent), color-stop(0.5, rgba(255, 255, 255, 0.15)), color-stop(0.75, rgba(255, 255, 255, 0.15)), color-stop(0.75, transparent), to(transparent));background-image:-webkit-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:-moz-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:-o-linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);background-image:linear-gradient(45deg, rgba(255, 255, 255, 0.15) 25%, transparent 25%, transparent 50%, rgba(255, 255, 255, 0.15) 50%, rgba(255, 255, 255, 0.15) 75%, transparent 75%, transparent);} +.hero-unit{padding:60px;margin-bottom:30px;font-size:18px;font-weight:200;line-height:30px;color:inherit;background-color:#eeeeee;-webkit-border-radius:6px;-moz-border-radius:6px;border-radius:6px;}.hero-unit h1{margin-bottom:0;font-size:60px;line-height:1;color:inherit;letter-spacing:-1px;} +.hero-unit li{line-height:30px;} +.media,.media-body{overflow:hidden;*overflow:visible;zoom:1;} +.media,.media .media{margin-top:15px;} +.media:first-child{margin-top:0;} +.media-object{display:block;} +.media-heading{margin:0 0 5px;} +.media>.pull-left{margin-right:10px;} +.media>.pull-right{margin-left:10px;} +.media-list{margin-left:0;list-style:none;} +.tooltip{position:absolute;z-index:1030;display:block;visibility:visible;font-size:11px;line-height:1.4;opacity:0;filter:alpha(opacity=0);}.tooltip.in{opacity:0.8;filter:alpha(opacity=80);} +.tooltip.top{margin-top:-3px;padding:5px 0;} +.tooltip.right{margin-left:3px;padding:0 5px;} +.tooltip.bottom{margin-top:3px;padding:5px 0;} +.tooltip.left{margin-left:-3px;padding:0 5px;} +.tooltip-inner{max-width:200px;padding:8px;color:#ffffff;text-align:center;text-decoration:none;background-color:#000000;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;} +.tooltip-arrow{position:absolute;width:0;height:0;border-color:transparent;border-style:solid;} +.tooltip.top .tooltip-arrow{bottom:0;left:50%;margin-left:-5px;border-width:5px 5px 0;border-top-color:#000000;} +.tooltip.right .tooltip-arrow{top:50%;left:0;margin-top:-5px;border-width:5px 5px 
5px 0;border-right-color:#000000;} +.tooltip.left .tooltip-arrow{top:50%;right:0;margin-top:-5px;border-width:5px 0 5px 5px;border-left-color:#000000;} +.tooltip.bottom .tooltip-arrow{top:0;left:50%;margin-left:-5px;border-width:0 5px 5px;border-bottom-color:#000000;} +.popover{position:absolute;top:0;left:0;z-index:1010;display:none;max-width:276px;padding:1px;text-align:left;background-color:#ffffff;-webkit-background-clip:padding-box;-moz-background-clip:padding;background-clip:padding-box;border:1px solid #ccc;border:1px solid rgba(0, 0, 0, 0.2);-webkit-border-radius:6px;-moz-border-radius:6px;border-radius:6px;-webkit-box-shadow:0 5px 10px rgba(0, 0, 0, 0.2);-moz-box-shadow:0 5px 10px rgba(0, 0, 0, 0.2);box-shadow:0 5px 10px rgba(0, 0, 0, 0.2);white-space:normal;}.popover.top{margin-top:-10px;} +.popover.right{margin-left:10px;} +.popover.bottom{margin-top:10px;} +.popover.left{margin-left:-10px;} +.popover-title{margin:0;padding:8px 14px;font-size:14px;font-weight:normal;line-height:18px;background-color:#f7f7f7;border-bottom:1px solid #ebebeb;-webkit-border-radius:5px 5px 0 0;-moz-border-radius:5px 5px 0 0;border-radius:5px 5px 0 0;}.popover-title:empty{display:none;} +.popover-content{padding:9px 14px;} +.popover .arrow,.popover .arrow:after{position:absolute;display:block;width:0;height:0;border-color:transparent;border-style:solid;} +.popover .arrow{border-width:11px;} +.popover .arrow:after{border-width:10px;content:"";} +.popover.top .arrow{left:50%;margin-left:-11px;border-bottom-width:0;border-top-color:#999;border-top-color:rgba(0, 0, 0, 0.25);bottom:-11px;}.popover.top .arrow:after{bottom:1px;margin-left:-10px;border-bottom-width:0;border-top-color:#ffffff;} +.popover.right .arrow{top:50%;left:-11px;margin-top:-11px;border-left-width:0;border-right-color:#999;border-right-color:rgba(0, 0, 0, 0.25);}.popover.right .arrow:after{left:1px;bottom:-10px;border-left-width:0;border-right-color:#ffffff;} +.popover.bottom .arrow{left:50%;margin-left:-11px;border-top-width:0;border-bottom-color:#999;border-bottom-color:rgba(0, 0, 0, 0.25);top:-11px;}.popover.bottom .arrow:after{top:1px;margin-left:-10px;border-top-width:0;border-bottom-color:#ffffff;} +.popover.left .arrow{top:50%;right:-11px;margin-top:-11px;border-right-width:0;border-left-color:#999;border-left-color:rgba(0, 0, 0, 0.25);}.popover.left .arrow:after{right:1px;border-right-width:0;border-left-color:#ffffff;bottom:-10px;} +.modal-backdrop{position:fixed;top:0;right:0;bottom:0;left:0;z-index:1040;background-color:#000000;}.modal-backdrop.fade{opacity:0;} +.modal-backdrop,.modal-backdrop.fade.in{opacity:0.8;filter:alpha(opacity=80);} +.modal{position:fixed;top:10%;left:50%;z-index:1050;width:560px;margin-left:-280px;background-color:#ffffff;border:1px solid #999;border:1px solid rgba(0, 0, 0, 0.3);*border:1px solid #999;-webkit-border-radius:6px;-moz-border-radius:6px;border-radius:6px;-webkit-box-shadow:0 3px 7px rgba(0, 0, 0, 0.3);-moz-box-shadow:0 3px 7px rgba(0, 0, 0, 0.3);box-shadow:0 3px 7px rgba(0, 0, 0, 0.3);-webkit-background-clip:padding-box;-moz-background-clip:padding-box;background-clip:padding-box;outline:none;}.modal.fade{-webkit-transition:opacity .3s linear, top .3s ease-out;-moz-transition:opacity .3s linear, top .3s ease-out;-o-transition:opacity .3s linear, top .3s ease-out;transition:opacity .3s linear, top .3s ease-out;top:-25%;} +.modal.fade.in{top:10%;} +.modal-header{padding:9px 15px;border-bottom:1px solid #eee;}.modal-header .close{margin-top:2px;} +.modal-header h3{margin:0;line-height:30px;} 
+.modal-body{position:relative;overflow-y:auto;max-height:400px;padding:15px;} +.modal-form{margin-bottom:0;} +.modal-footer{padding:14px 15px 15px;margin-bottom:0;text-align:right;background-color:#f5f5f5;border-top:1px solid #ddd;-webkit-border-radius:0 0 6px 6px;-moz-border-radius:0 0 6px 6px;border-radius:0 0 6px 6px;-webkit-box-shadow:inset 0 1px 0 #ffffff;-moz-box-shadow:inset 0 1px 0 #ffffff;box-shadow:inset 0 1px 0 #ffffff;*zoom:1;}.modal-footer:before,.modal-footer:after{display:table;content:"";line-height:0;} +.modal-footer:after{clear:both;} +.modal-footer .btn+.btn{margin-left:5px;margin-bottom:0;} +.modal-footer .btn-group .btn+.btn{margin-left:-1px;} +.modal-footer .btn-block+.btn-block{margin-left:0;} +.dropup,.dropdown{position:relative;} +.dropdown-toggle{*margin-bottom:-3px;} +.dropdown-toggle:active,.open .dropdown-toggle{outline:0;} +.caret{display:inline-block;width:0;height:0;vertical-align:top;border-top:4px solid #000000;border-right:4px solid transparent;border-left:4px solid transparent;content:"";} +.dropdown .caret{margin-top:8px;margin-left:2px;} +.dropdown-menu{position:absolute;top:100%;left:0;z-index:1000;display:none;float:left;min-width:160px;padding:5px 0;margin:2px 0 0;list-style:none;background-color:#ffffff;border:1px solid #ccc;border:1px solid rgba(0, 0, 0, 0.2);*border-right-width:2px;*border-bottom-width:2px;-webkit-border-radius:6px;-moz-border-radius:6px;border-radius:6px;-webkit-box-shadow:0 5px 10px rgba(0, 0, 0, 0.2);-moz-box-shadow:0 5px 10px rgba(0, 0, 0, 0.2);box-shadow:0 5px 10px rgba(0, 0, 0, 0.2);-webkit-background-clip:padding-box;-moz-background-clip:padding;background-clip:padding-box;}.dropdown-menu.pull-right{right:0;left:auto;} +.dropdown-menu .divider{*width:100%;height:1px;margin:9px 1px;*margin:-5px 0 5px;overflow:hidden;background-color:#e5e5e5;border-bottom:1px solid #ffffff;} +.dropdown-menu>li>a{display:block;padding:3px 20px;clear:both;font-weight:normal;line-height:20px;color:#333333;white-space:nowrap;} +.dropdown-menu>li>a:hover,.dropdown-menu>li>a:focus,.dropdown-submenu:hover>a,.dropdown-submenu:focus>a{text-decoration:none;color:#ffffff;background-color:#0081c2;background-image:-moz-linear-gradient(top, #0088cc, #0077b3);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#0088cc), to(#0077b3));background-image:-webkit-linear-gradient(top, #0088cc, #0077b3);background-image:-o-linear-gradient(top, #0088cc, #0077b3);background-image:linear-gradient(to bottom, #0088cc, #0077b3);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ff0088cc', endColorstr='#ff0077b3', GradientType=0);} +.dropdown-menu>.active>a,.dropdown-menu>.active>a:hover,.dropdown-menu>.active>a:focus{color:#ffffff;text-decoration:none;outline:0;background-color:#0081c2;background-image:-moz-linear-gradient(top, #0088cc, #0077b3);background-image:-webkit-gradient(linear, 0 0, 0 100%, from(#0088cc), to(#0077b3));background-image:-webkit-linear-gradient(top, #0088cc, #0077b3);background-image:-o-linear-gradient(top, #0088cc, #0077b3);background-image:linear-gradient(to bottom, #0088cc, #0077b3);background-repeat:repeat-x;filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ff0088cc', endColorstr='#ff0077b3', GradientType=0);} +.dropdown-menu>.disabled>a,.dropdown-menu>.disabled>a:hover,.dropdown-menu>.disabled>a:focus{color:#999999;} 
+.dropdown-menu>.disabled>a:hover,.dropdown-menu>.disabled>a:focus{text-decoration:none;background-color:transparent;background-image:none;filter:progid:DXImageTransform.Microsoft.gradient(enabled = false);cursor:default;} +.open{*z-index:1000;}.open>.dropdown-menu{display:block;} +.dropdown-backdrop{position:fixed;left:0;right:0;bottom:0;top:0;z-index:990;} +.pull-right>.dropdown-menu{right:0;left:auto;} +.dropup .caret,.navbar-fixed-bottom .dropdown .caret{border-top:0;border-bottom:4px solid #000000;content:"";} +.dropup .dropdown-menu,.navbar-fixed-bottom .dropdown .dropdown-menu{top:auto;bottom:100%;margin-bottom:1px;} +.dropdown-submenu{position:relative;} +.dropdown-submenu>.dropdown-menu{top:0;left:100%;margin-top:-6px;margin-left:-1px;-webkit-border-radius:0 6px 6px 6px;-moz-border-radius:0 6px 6px 6px;border-radius:0 6px 6px 6px;} +.dropdown-submenu:hover>.dropdown-menu{display:block;} +.dropup .dropdown-submenu>.dropdown-menu{top:auto;bottom:0;margin-top:0;margin-bottom:-2px;-webkit-border-radius:5px 5px 5px 0;-moz-border-radius:5px 5px 5px 0;border-radius:5px 5px 5px 0;} +.dropdown-submenu>a:after{display:block;content:" ";float:right;width:0;height:0;border-color:transparent;border-style:solid;border-width:5px 0 5px 5px;border-left-color:#cccccc;margin-top:5px;margin-right:-10px;} +.dropdown-submenu:hover>a:after{border-left-color:#ffffff;} +.dropdown-submenu.pull-left{float:none;}.dropdown-submenu.pull-left>.dropdown-menu{left:-100%;margin-left:10px;-webkit-border-radius:6px 0 6px 6px;-moz-border-radius:6px 0 6px 6px;border-radius:6px 0 6px 6px;} +.dropdown .dropdown-menu .nav-header{padding-left:20px;padding-right:20px;} +.typeahead{z-index:1051;margin-top:2px;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;} +.accordion{margin-bottom:20px;} +.accordion-group{margin-bottom:2px;border:1px solid #e5e5e5;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;} +.accordion-heading{border-bottom:0;} +.accordion-heading .accordion-toggle{display:block;padding:8px 15px;} +.accordion-toggle{cursor:pointer;} +.accordion-inner{padding:9px 15px;border-top:1px solid #e5e5e5;} +.carousel{position:relative;margin-bottom:20px;line-height:1;} +.carousel-inner{overflow:hidden;width:100%;position:relative;} +.carousel-inner>.item{display:none;position:relative;-webkit-transition:0.6s ease-in-out left;-moz-transition:0.6s ease-in-out left;-o-transition:0.6s ease-in-out left;transition:0.6s ease-in-out left;}.carousel-inner>.item>img,.carousel-inner>.item>a>img{display:block;line-height:1;} +.carousel-inner>.active,.carousel-inner>.next,.carousel-inner>.prev{display:block;} +.carousel-inner>.active{left:0;} +.carousel-inner>.next,.carousel-inner>.prev{position:absolute;top:0;width:100%;} +.carousel-inner>.next{left:100%;} +.carousel-inner>.prev{left:-100%;} +.carousel-inner>.next.left,.carousel-inner>.prev.right{left:0;} +.carousel-inner>.active.left{left:-100%;} +.carousel-inner>.active.right{left:100%;} +.carousel-control{position:absolute;top:40%;left:15px;width:40px;height:40px;margin-top:-20px;font-size:60px;font-weight:100;line-height:30px;color:#ffffff;text-align:center;background:#222222;border:3px solid #ffffff;-webkit-border-radius:23px;-moz-border-radius:23px;border-radius:23px;opacity:0.5;filter:alpha(opacity=50);}.carousel-control.right{left:auto;right:15px;} +.carousel-control:hover,.carousel-control:focus{color:#ffffff;text-decoration:none;opacity:0.9;filter:alpha(opacity=90);} 
+.carousel-indicators{position:absolute;top:15px;right:15px;z-index:5;margin:0;list-style:none;}.carousel-indicators li{display:block;float:left;width:10px;height:10px;margin-left:5px;text-indent:-999px;background-color:#ccc;background-color:rgba(255, 255, 255, 0.25);border-radius:5px;} +.carousel-indicators .active{background-color:#fff;} +.carousel-caption{position:absolute;left:0;right:0;bottom:0;padding:15px;background:#333333;background:rgba(0, 0, 0, 0.75);} +.carousel-caption h4,.carousel-caption p{color:#ffffff;line-height:20px;} +.carousel-caption h4{margin:0 0 5px;} +.carousel-caption p{margin-bottom:0;} +.well{min-height:20px;padding:19px;margin-bottom:20px;background-color:#f5f5f5;border:1px solid #e3e3e3;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;-webkit-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.05);-moz-box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.05);box-shadow:inset 0 1px 1px rgba(0, 0, 0, 0.05);}.well blockquote{border-color:#ddd;border-color:rgba(0, 0, 0, 0.15);} +.well-large{padding:24px;-webkit-border-radius:6px;-moz-border-radius:6px;border-radius:6px;} +.well-small{padding:9px;-webkit-border-radius:3px;-moz-border-radius:3px;border-radius:3px;} +.close{float:right;font-size:20px;font-weight:bold;line-height:20px;color:#000000;text-shadow:0 1px 0 #ffffff;opacity:0.2;filter:alpha(opacity=20);}.close:hover,.close:focus{color:#000000;text-decoration:none;cursor:pointer;opacity:0.4;filter:alpha(opacity=40);} +button.close{padding:0;cursor:pointer;background:transparent;border:0;-webkit-appearance:none;} +.pull-right{float:right;} +.pull-left{float:left;} +.hide{display:none;} +.show{display:block;} +.invisible{visibility:hidden;} +.affix{position:fixed;} +.fade{opacity:0;-webkit-transition:opacity 0.15s linear;-moz-transition:opacity 0.15s linear;-o-transition:opacity 0.15s linear;transition:opacity 0.15s linear;}.fade.in{opacity:1;} +.collapse{position:relative;height:0;overflow:hidden;-webkit-transition:height 0.35s ease;-moz-transition:height 0.35s ease;-o-transition:height 0.35s ease;transition:height 0.35s ease;}.collapse.in{height:auto;} +@-ms-viewport{width:device-width;}.hidden{display:none;visibility:hidden;} +.visible-phone{display:none !important;} +.visible-tablet{display:none !important;} +.hidden-desktop{display:none !important;} +.visible-desktop{display:inherit !important;} +@media (min-width:768px) and (max-width:979px){.hidden-desktop{display:inherit !important;} .visible-desktop{display:none !important ;} .visible-tablet{display:inherit !important;} .hidden-tablet{display:none !important;}}@media (max-width:767px){.hidden-desktop{display:inherit !important;} .visible-desktop{display:none !important;} .visible-phone{display:inherit !important;} .hidden-phone{display:none !important;}}.visible-print{display:none !important;} +@media print{.visible-print{display:inherit !important;} .hidden-print{display:none !important;}}@media (max-width:767px){body{padding-left:20px;padding-right:20px;} .navbar-fixed-top,.navbar-fixed-bottom,.navbar-static-top{margin-left:-20px;margin-right:-20px;} .container-fluid{padding:0;} .dl-horizontal dt{float:none;clear:none;width:auto;text-align:left;} .dl-horizontal dd{margin-left:0;} .container{width:auto;} .row-fluid{width:100%;} .row,.thumbnails{margin-left:0;} .thumbnails>li{float:none;margin-left:0;} [class*="span"],.uneditable-input[class*="span"],.row-fluid 
[class*="span"]{float:none;display:block;width:100%;margin-left:0;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;} .span12,.row-fluid .span12{width:100%;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;} .row-fluid [class*="offset"]:first-child{margin-left:0;} .input-large,.input-xlarge,.input-xxlarge,input[class*="span"],select[class*="span"],textarea[class*="span"],.uneditable-input{display:block;width:100%;min-height:30px;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;} .input-prepend input,.input-append input,.input-prepend input[class*="span"],.input-append input[class*="span"]{display:inline-block;width:auto;} .controls-row [class*="span"]+[class*="span"]{margin-left:0;} .modal{position:fixed;top:20px;left:20px;right:20px;width:auto;margin:0;}.modal.fade{top:-100px;} .modal.fade.in{top:20px;}}@media (max-width:480px){.nav-collapse{-webkit-transform:translate3d(0, 0, 0);} .page-header h1 small{display:block;line-height:20px;} input[type="checkbox"],input[type="radio"]{border:1px solid #ccc;} .form-horizontal .control-label{float:none;width:auto;padding-top:0;text-align:left;} .form-horizontal .controls{margin-left:0;} .form-horizontal .control-list{padding-top:0;} .form-horizontal .form-actions{padding-left:10px;padding-right:10px;} .media .pull-left,.media .pull-right{float:none;display:block;margin-bottom:10px;} .media-object{margin-right:0;margin-left:0;} .modal{top:10px;left:10px;right:10px;} .modal-header .close{padding:10px;margin:-10px;} .carousel-caption{position:static;}}@media (min-width:768px) and (max-width:979px){.row{margin-left:-20px;*zoom:1;}.row:before,.row:after{display:table;content:"";line-height:0;} .row:after{clear:both;} [class*="span"]{float:left;min-height:1px;margin-left:20px;} .container,.navbar-static-top .container,.navbar-fixed-top .container,.navbar-fixed-bottom .container{width:724px;} .span12{width:724px;} .span11{width:662px;} .span10{width:600px;} .span9{width:538px;} .span8{width:476px;} .span7{width:414px;} .span6{width:352px;} .span5{width:290px;} .span4{width:228px;} .span3{width:166px;} .span2{width:104px;} .span1{width:42px;} .offset12{margin-left:764px;} .offset11{margin-left:702px;} .offset10{margin-left:640px;} .offset9{margin-left:578px;} .offset8{margin-left:516px;} .offset7{margin-left:454px;} .offset6{margin-left:392px;} .offset5{margin-left:330px;} .offset4{margin-left:268px;} .offset3{margin-left:206px;} .offset2{margin-left:144px;} .offset1{margin-left:82px;} .row-fluid{width:100%;*zoom:1;}.row-fluid:before,.row-fluid:after{display:table;content:"";line-height:0;} .row-fluid:after{clear:both;} .row-fluid [class*="span"]{display:block;width:100%;min-height:30px;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;float:left;margin-left:2.7624309392265194%;*margin-left:2.709239449864817%;} .row-fluid [class*="span"]:first-child{margin-left:0;} .row-fluid .controls-row [class*="span"]+[class*="span"]{margin-left:2.7624309392265194%;} .row-fluid .span12{width:100%;*width:99.94680851063829%;} .row-fluid .span11{width:91.43646408839778%;*width:91.38327259903608%;} .row-fluid .span10{width:82.87292817679558%;*width:82.81973668743387%;} .row-fluid .span9{width:74.30939226519337%;*width:74.25620077583166%;} .row-fluid .span8{width:65.74585635359117%;*width:65.69266486422946%;} .row-fluid .span7{width:57.18232044198895%;*width:57.12912895262725%;} .row-fluid .span6{width:48.61878453038674%;*width:48.56559304102504%;} 
.row-fluid .span5{width:40.05524861878453%;*width:40.00205712942283%;} .row-fluid .span4{width:31.491712707182323%;*width:31.43852121782062%;} .row-fluid .span3{width:22.92817679558011%;*width:22.87498530621841%;} .row-fluid .span2{width:14.3646408839779%;*width:14.311449394616199%;} .row-fluid .span1{width:5.801104972375691%;*width:5.747913483013988%;} .row-fluid .offset12{margin-left:105.52486187845304%;*margin-left:105.41847889972962%;} .row-fluid .offset12:first-child{margin-left:102.76243093922652%;*margin-left:102.6560479605031%;} .row-fluid .offset11{margin-left:96.96132596685082%;*margin-left:96.8549429881274%;} .row-fluid .offset11:first-child{margin-left:94.1988950276243%;*margin-left:94.09251204890089%;} .row-fluid .offset10{margin-left:88.39779005524862%;*margin-left:88.2914070765252%;} .row-fluid .offset10:first-child{margin-left:85.6353591160221%;*margin-left:85.52897613729868%;} .row-fluid .offset9{margin-left:79.8342541436464%;*margin-left:79.72787116492299%;} .row-fluid .offset9:first-child{margin-left:77.07182320441989%;*margin-left:76.96544022569647%;} .row-fluid .offset8{margin-left:71.2707182320442%;*margin-left:71.16433525332079%;} .row-fluid .offset8:first-child{margin-left:68.50828729281768%;*margin-left:68.40190431409427%;} .row-fluid .offset7{margin-left:62.70718232044199%;*margin-left:62.600799341718584%;} .row-fluid .offset7:first-child{margin-left:59.94475138121547%;*margin-left:59.838368402492065%;} .row-fluid .offset6{margin-left:54.14364640883978%;*margin-left:54.037263430116376%;} .row-fluid .offset6:first-child{margin-left:51.38121546961326%;*margin-left:51.27483249088986%;} .row-fluid .offset5{margin-left:45.58011049723757%;*margin-left:45.47372751851417%;} .row-fluid .offset5:first-child{margin-left:42.81767955801105%;*margin-left:42.71129657928765%;} .row-fluid .offset4{margin-left:37.01657458563536%;*margin-left:36.91019160691196%;} .row-fluid .offset4:first-child{margin-left:34.25414364640884%;*margin-left:34.14776066768544%;} .row-fluid .offset3{margin-left:28.45303867403315%;*margin-left:28.346655695309746%;} .row-fluid .offset3:first-child{margin-left:25.69060773480663%;*margin-left:25.584224756083227%;} .row-fluid .offset2{margin-left:19.88950276243094%;*margin-left:19.783119783707537%;} .row-fluid .offset2:first-child{margin-left:17.12707182320442%;*margin-left:17.02068884448102%;} .row-fluid .offset1{margin-left:11.32596685082873%;*margin-left:11.219583872105325%;} .row-fluid .offset1:first-child{margin-left:8.56353591160221%;*margin-left:8.457152932878806%;} input,textarea,.uneditable-input{margin-left:0;} .controls-row [class*="span"]+[class*="span"]{margin-left:20px;} input.span12,textarea.span12,.uneditable-input.span12{width:710px;} input.span11,textarea.span11,.uneditable-input.span11{width:648px;} input.span10,textarea.span10,.uneditable-input.span10{width:586px;} input.span9,textarea.span9,.uneditable-input.span9{width:524px;} input.span8,textarea.span8,.uneditable-input.span8{width:462px;} input.span7,textarea.span7,.uneditable-input.span7{width:400px;} input.span6,textarea.span6,.uneditable-input.span6{width:338px;} input.span5,textarea.span5,.uneditable-input.span5{width:276px;} input.span4,textarea.span4,.uneditable-input.span4{width:214px;} input.span3,textarea.span3,.uneditable-input.span3{width:152px;} input.span2,textarea.span2,.uneditable-input.span2{width:90px;} input.span1,textarea.span1,.uneditable-input.span1{width:28px;}}@media 
(min-width:1200px){.row{margin-left:-30px;*zoom:1;}.row:before,.row:after{display:table;content:"";line-height:0;} .row:after{clear:both;} [class*="span"]{float:left;min-height:1px;margin-left:30px;} .container,.navbar-static-top .container,.navbar-fixed-top .container,.navbar-fixed-bottom .container{width:1170px;} .span12{width:1170px;} .span11{width:1070px;} .span10{width:970px;} .span9{width:870px;} .span8{width:770px;} .span7{width:670px;} .span6{width:570px;} .span5{width:470px;} .span4{width:370px;} .span3{width:270px;} .span2{width:170px;} .span1{width:70px;} .offset12{margin-left:1230px;} .offset11{margin-left:1130px;} .offset10{margin-left:1030px;} .offset9{margin-left:930px;} .offset8{margin-left:830px;} .offset7{margin-left:730px;} .offset6{margin-left:630px;} .offset5{margin-left:530px;} .offset4{margin-left:430px;} .offset3{margin-left:330px;} .offset2{margin-left:230px;} .offset1{margin-left:130px;} .row-fluid{width:100%;*zoom:1;}.row-fluid:before,.row-fluid:after{display:table;content:"";line-height:0;} .row-fluid:after{clear:both;} .row-fluid [class*="span"]{display:block;width:100%;min-height:30px;-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box;float:left;margin-left:2.564102564102564%;*margin-left:2.5109110747408616%;} .row-fluid [class*="span"]:first-child{margin-left:0;} .row-fluid .controls-row [class*="span"]+[class*="span"]{margin-left:2.564102564102564%;} .row-fluid .span12{width:100%;*width:99.94680851063829%;} .row-fluid .span11{width:91.45299145299145%;*width:91.39979996362975%;} .row-fluid .span10{width:82.90598290598291%;*width:82.8527914166212%;} .row-fluid .span9{width:74.35897435897436%;*width:74.30578286961266%;} .row-fluid .span8{width:65.81196581196582%;*width:65.75877432260411%;} .row-fluid .span7{width:57.26495726495726%;*width:57.21176577559556%;} .row-fluid .span6{width:48.717948717948715%;*width:48.664757228587014%;} .row-fluid .span5{width:40.17094017094017%;*width:40.11774868157847%;} .row-fluid .span4{width:31.623931623931625%;*width:31.570740134569924%;} .row-fluid .span3{width:23.076923076923077%;*width:23.023731587561375%;} .row-fluid .span2{width:14.52991452991453%;*width:14.476723040552828%;} .row-fluid .span1{width:5.982905982905983%;*width:5.929714493544281%;} .row-fluid .offset12{margin-left:105.12820512820512%;*margin-left:105.02182214948171%;} .row-fluid .offset12:first-child{margin-left:102.56410256410257%;*margin-left:102.45771958537915%;} .row-fluid .offset11{margin-left:96.58119658119658%;*margin-left:96.47481360247316%;} .row-fluid .offset11:first-child{margin-left:94.01709401709402%;*margin-left:93.91071103837061%;} .row-fluid .offset10{margin-left:88.03418803418803%;*margin-left:87.92780505546462%;} .row-fluid .offset10:first-child{margin-left:85.47008547008548%;*margin-left:85.36370249136206%;} .row-fluid .offset9{margin-left:79.48717948717949%;*margin-left:79.38079650845607%;} .row-fluid .offset9:first-child{margin-left:76.92307692307693%;*margin-left:76.81669394435352%;} .row-fluid .offset8{margin-left:70.94017094017094%;*margin-left:70.83378796144753%;} .row-fluid .offset8:first-child{margin-left:68.37606837606839%;*margin-left:68.26968539734497%;} .row-fluid .offset7{margin-left:62.393162393162385%;*margin-left:62.28677941443899%;} .row-fluid .offset7:first-child{margin-left:59.82905982905982%;*margin-left:59.72267685033642%;} .row-fluid .offset6{margin-left:53.84615384615384%;*margin-left:53.739770867430444%;} .row-fluid 
.offset6:first-child{margin-left:51.28205128205128%;*margin-left:51.175668303327875%;} .row-fluid .offset5{margin-left:45.299145299145295%;*margin-left:45.1927623204219%;} .row-fluid .offset5:first-child{margin-left:42.73504273504273%;*margin-left:42.62865975631933%;} .row-fluid .offset4{margin-left:36.75213675213675%;*margin-left:36.645753773413354%;} .row-fluid .offset4:first-child{margin-left:34.18803418803419%;*margin-left:34.081651209310785%;} .row-fluid .offset3{margin-left:28.205128205128204%;*margin-left:28.0987452264048%;} .row-fluid .offset3:first-child{margin-left:25.641025641025642%;*margin-left:25.53464266230224%;} .row-fluid .offset2{margin-left:19.65811965811966%;*margin-left:19.551736679396257%;} .row-fluid .offset2:first-child{margin-left:17.094017094017094%;*margin-left:16.98763411529369%;} .row-fluid .offset1{margin-left:11.11111111111111%;*margin-left:11.004728132387708%;} .row-fluid .offset1:first-child{margin-left:8.547008547008547%;*margin-left:8.440625568285142%;} input,textarea,.uneditable-input{margin-left:0;} .controls-row [class*="span"]+[class*="span"]{margin-left:30px;} input.span12,textarea.span12,.uneditable-input.span12{width:1156px;} input.span11,textarea.span11,.uneditable-input.span11{width:1056px;} input.span10,textarea.span10,.uneditable-input.span10{width:956px;} input.span9,textarea.span9,.uneditable-input.span9{width:856px;} input.span8,textarea.span8,.uneditable-input.span8{width:756px;} input.span7,textarea.span7,.uneditable-input.span7{width:656px;} input.span6,textarea.span6,.uneditable-input.span6{width:556px;} input.span5,textarea.span5,.uneditable-input.span5{width:456px;} input.span4,textarea.span4,.uneditable-input.span4{width:356px;} input.span3,textarea.span3,.uneditable-input.span3{width:256px;} input.span2,textarea.span2,.uneditable-input.span2{width:156px;} input.span1,textarea.span1,.uneditable-input.span1{width:56px;} .thumbnails{margin-left:-30px;} .thumbnails>li{margin-left:30px;} .row-fluid .thumbnails{margin-left:0;}}@media (max-width:979px){body{padding-top:0;} .navbar-fixed-top,.navbar-fixed-bottom{position:static;} .navbar-fixed-top{margin-bottom:20px;} .navbar-fixed-bottom{margin-top:20px;} .navbar-fixed-top .navbar-inner,.navbar-fixed-bottom .navbar-inner{padding:5px;} .navbar .container{width:auto;padding:0;} .navbar .brand{padding-left:10px;padding-right:10px;margin:0 0 0 -5px;} .nav-collapse{clear:both;} .nav-collapse .nav{float:none;margin:0 0 10px;} .nav-collapse .nav>li{float:none;} .nav-collapse .nav>li>a{margin-bottom:2px;} .nav-collapse .nav>.divider-vertical{display:none;} .nav-collapse .nav .nav-header{color:#777777;text-shadow:none;} .nav-collapse .nav>li>a,.nav-collapse .dropdown-menu a{padding:9px 15px;font-weight:bold;color:#777777;-webkit-border-radius:3px;-moz-border-radius:3px;border-radius:3px;} .nav-collapse .btn{padding:4px 10px 4px;font-weight:normal;-webkit-border-radius:4px;-moz-border-radius:4px;border-radius:4px;} .nav-collapse .dropdown-menu li+li a{margin-bottom:2px;} .nav-collapse .nav>li>a:hover,.nav-collapse .nav>li>a:focus,.nav-collapse .dropdown-menu a:hover,.nav-collapse .dropdown-menu a:focus{background-color:#f2f2f2;} .navbar-inverse .nav-collapse .nav>li>a,.navbar-inverse .nav-collapse .dropdown-menu a{color:#999999;} .navbar-inverse .nav-collapse .nav>li>a:hover,.navbar-inverse .nav-collapse .nav>li>a:focus,.navbar-inverse .nav-collapse .dropdown-menu a:hover,.navbar-inverse .nav-collapse .dropdown-menu a:focus{background-color:#111111;} .nav-collapse.in 
.btn-group{margin-top:5px;padding:0;} .nav-collapse .dropdown-menu{position:static;top:auto;left:auto;float:none;display:none;max-width:none;margin:0 15px;padding:0;background-color:transparent;border:none;-webkit-border-radius:0;-moz-border-radius:0;border-radius:0;-webkit-box-shadow:none;-moz-box-shadow:none;box-shadow:none;} .nav-collapse .open>.dropdown-menu{display:block;} .nav-collapse .dropdown-menu:before,.nav-collapse .dropdown-menu:after{display:none;} .nav-collapse .dropdown-menu .divider{display:none;} .nav-collapse .nav>li>.dropdown-menu:before,.nav-collapse .nav>li>.dropdown-menu:after{display:none;} .nav-collapse .navbar-form,.nav-collapse .navbar-search{float:none;padding:10px 15px;margin:10px 0;border-top:1px solid #f2f2f2;border-bottom:1px solid #f2f2f2;-webkit-box-shadow:inset 0 1px 0 rgba(255,255,255,.1), 0 1px 0 rgba(255,255,255,.1);-moz-box-shadow:inset 0 1px 0 rgba(255,255,255,.1), 0 1px 0 rgba(255,255,255,.1);box-shadow:inset 0 1px 0 rgba(255,255,255,.1), 0 1px 0 rgba(255,255,255,.1);} .navbar-inverse .nav-collapse .navbar-form,.navbar-inverse .nav-collapse .navbar-search{border-top-color:#111111;border-bottom-color:#111111;} .navbar .nav-collapse .nav.pull-right{float:none;margin-left:0;} .nav-collapse,.nav-collapse.collapse{overflow:hidden;height:0;} .navbar .btn-navbar{display:block;} .navbar-static .navbar-inner{padding-left:10px;padding-right:10px;}}@media (min-width:980px){.nav-collapse.collapse{height:auto !important;overflow:visible !important;}} diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js new file mode 100644 index 0000000000..7abb9011cc --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js @@ -0,0 +1,495 @@ +/* + SortTable + version 2 + 7th April 2007 + Stuart Langridge, http://www.kryogenix.org/code/browser/sorttable/ + + Instructions: + Download this file + Add to your HTML + Add class="sortable" to any table you'd like to make sortable + Click on the headers to sort + + Thanks to many, many people for contributions and suggestions. + Licenced as X11: http://www.kryogenix.org/code/browser/licence.html + This basically means: do what you want with it. +*/ + + +var stIsIE = /*@cc_on!@*/false; + +sorttable = { + init: function() { + // quit if this function has already been called + if (arguments.callee.done) return; + // flag this function so we don't do the same thing twice + arguments.callee.done = true; + // kill the timer + if (_timer) clearInterval(_timer); + + if (!document.createElement || !document.getElementsByTagName) return; + + sorttable.DATE_RE = /^(\d\d?)[\/\.-](\d\d?)[\/\.-]((\d\d)?\d\d)$/; + + forEach(document.getElementsByTagName('table'), function(table) { + if (table.className.search(/\bsortable\b/) != -1) { + sorttable.makeSortable(table); + } + }); + + }, + + makeSortable: function(table) { + if (table.getElementsByTagName('thead').length == 0) { + // table doesn't have a tHead. Since it should have, create one and + // put the first table row in it. + the = document.createElement('thead'); + the.appendChild(table.rows[0]); + table.insertBefore(the,table.firstChild); + } + // Safari doesn't support table.tHead, sigh + if (table.tHead == null) table.tHead = table.getElementsByTagName('thead')[0]; + + if (table.tHead.rows.length != 1) return; // can't cope with two header rows + + // Sorttable v1 put rows with a class of "sortbottom" at the bottom (as + // "total" rows, for example). 
This is B&R, since what you're supposed + // to do is put them in a tfoot. So, if there are sortbottom rows, + // for backwards compatibility, move them to tfoot (creating it if needed). + sortbottomrows = []; + for (var i=0; i5' : ' ▴'; + this.appendChild(sortrevind); + return; + } + if (this.className.search(/\bsorttable_sorted_reverse\b/) != -1) { + // if we're already sorted by this column in reverse, just + // re-reverse the table, which is quicker + sorttable.reverse(this.sorttable_tbody); + this.className = this.className.replace('sorttable_sorted_reverse', + 'sorttable_sorted'); + this.removeChild(document.getElementById('sorttable_sortrevind')); + sortfwdind = document.createElement('span'); + sortfwdind.id = "sorttable_sortfwdind"; + sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾'; + this.appendChild(sortfwdind); + return; + } + + // remove sorttable_sorted classes + theadrow = this.parentNode; + forEach(theadrow.childNodes, function(cell) { + if (cell.nodeType == 1) { // an element + cell.className = cell.className.replace('sorttable_sorted_reverse',''); + cell.className = cell.className.replace('sorttable_sorted',''); + } + }); + sortfwdind = document.getElementById('sorttable_sortfwdind'); + if (sortfwdind) { sortfwdind.parentNode.removeChild(sortfwdind); } + sortrevind = document.getElementById('sorttable_sortrevind'); + if (sortrevind) { sortrevind.parentNode.removeChild(sortrevind); } + + this.className += ' sorttable_sorted'; + sortfwdind = document.createElement('span'); + sortfwdind.id = "sorttable_sortfwdind"; + sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾'; + this.appendChild(sortfwdind); + + // build an array to sort. This is a Schwartzian transform thing, + // i.e., we "decorate" each row with the actual sort key, + // sort based on the sort keys, and then put the rows back in order + // which is a lot faster because you only do getInnerText once per row + row_array = []; + col = this.sorttable_columnindex; + rows = this.sorttable_tbody.rows; + for (var j=0; j 12) { + // definitely dd/mm + return sorttable.sort_ddmm; + } else if (second > 12) { + return sorttable.sort_mmdd; + } else { + // looks like a date, but we can't tell which, so assume + // that it's dd/mm (English imperialism!) and keep looking + sortfn = sorttable.sort_ddmm; + } + } + } + } + return sortfn; + }, + + getInnerText: function(node) { + // gets the text we want to use for sorting for a cell. + // strips leading and trailing whitespace. + // this is *not* a generic getInnerText function; it's special to sorttable. + // for example, you can override the cell text with a customkey attribute. + // it also gets .value for fields. 
+ + if (!node) return ""; + + hasInputs = (typeof node.getElementsByTagName == 'function') && + node.getElementsByTagName('input').length; + + if (node.getAttribute("sorttable_customkey") != null) { + return node.getAttribute("sorttable_customkey"); + } + else if (typeof node.textContent != 'undefined' && !hasInputs) { + return node.textContent.replace(/^\s+|\s+$/g, ''); + } + else if (typeof node.innerText != 'undefined' && !hasInputs) { + return node.innerText.replace(/^\s+|\s+$/g, ''); + } + else if (typeof node.text != 'undefined' && !hasInputs) { + return node.text.replace(/^\s+|\s+$/g, ''); + } + else { + switch (node.nodeType) { + case 3: + if (node.nodeName.toLowerCase() == 'input') { + return node.value.replace(/^\s+|\s+$/g, ''); + } + case 4: + return node.nodeValue.replace(/^\s+|\s+$/g, ''); + break; + case 1: + case 11: + var innerText = ''; + for (var i = 0; i < node.childNodes.length; i++) { + innerText += sorttable.getInnerText(node.childNodes[i]); + } + return innerText.replace(/^\s+|\s+$/g, ''); + break; + default: + return ''; + } + } + }, + + reverse: function(tbody) { + // reverse the rows in a tbody + newrows = []; + for (var i=0; i=0; i--) { + tbody.appendChild(newrows[i]); + } + delete newrows; + }, + + /* sort functions + each sort function takes two parameters, a and b + you are comparing a[0] and b[0] */ + sort_numeric: function(a,b) { + aa = parseFloat(a[0].replace(/[^0-9.-]/g,'')); + if (isNaN(aa)) aa = 0; + bb = parseFloat(b[0].replace(/[^0-9.-]/g,'')); + if (isNaN(bb)) bb = 0; + return aa-bb; + }, + sort_alpha: function(a,b) { + if (a[0]==b[0]) return 0; + if (a[0] 0 ) { + var q = list[i]; list[i] = list[i+1]; list[i+1] = q; + swap = true; + } + } // for + t--; + + if (!swap) break; + + for(var i = t; i > b; --i) { + if ( comp_func(list[i], list[i-1]) < 0 ) { + var q = list[i]; list[i] = list[i-1]; list[i-1] = q; + swap = true; + } + } // for + b++; + + } // while(swap) + } +} + +/* ****************************************************************** + Supporting functions: bundled here to avoid depending on a library + ****************************************************************** */ + +// Dean Edwards/Matthias Miller/John Resig + +/* for Mozilla/Opera9 */ +if (document.addEventListener) { + document.addEventListener("DOMContentLoaded", sorttable.init, false); +} + +/* for Internet Explorer */ +/*@cc_on @*/ +/*@if (@_win32) + document.write(" to your HTML - Add class="sortable" to any table you'd like to make sortable - Click on the headers to sort - - Thanks to many, many people for contributions and suggestions. - Licenced as X11: http://www.kryogenix.org/code/browser/licence.html - This basically means: do what you want with it. -*/ - - -var stIsIE = /*@cc_on!@*/false; - -sorttable = { - init: function() { - // quit if this function has already been called - if (arguments.callee.done) return; - // flag this function so we don't do the same thing twice - arguments.callee.done = true; - // kill the timer - if (_timer) clearInterval(_timer); - - if (!document.createElement || !document.getElementsByTagName) return; - - sorttable.DATE_RE = /^(\d\d?)[\/\.-](\d\d?)[\/\.-]((\d\d)?\d\d)$/; - - forEach(document.getElementsByTagName('table'), function(table) { - if (table.className.search(/\bsortable\b/) != -1) { - sorttable.makeSortable(table); - } - }); - - }, - - makeSortable: function(table) { - if (table.getElementsByTagName('thead').length == 0) { - // table doesn't have a tHead. Since it should have, create one and - // put the first table row in it. 
- the = document.createElement('thead'); - the.appendChild(table.rows[0]); - table.insertBefore(the,table.firstChild); - } - // Safari doesn't support table.tHead, sigh - if (table.tHead == null) table.tHead = table.getElementsByTagName('thead')[0]; - - if (table.tHead.rows.length != 1) return; // can't cope with two header rows - - // Sorttable v1 put rows with a class of "sortbottom" at the bottom (as - // "total" rows, for example). This is B&R, since what you're supposed - // to do is put them in a tfoot. So, if there are sortbottom rows, - // for backwards compatibility, move them to tfoot (creating it if needed). - sortbottomrows = []; - for (var i=0; i5' : ' ▴'; - this.appendChild(sortrevind); - return; - } - if (this.className.search(/\bsorttable_sorted_reverse\b/) != -1) { - // if we're already sorted by this column in reverse, just - // re-reverse the table, which is quicker - sorttable.reverse(this.sorttable_tbody); - this.className = this.className.replace('sorttable_sorted_reverse', - 'sorttable_sorted'); - this.removeChild(document.getElementById('sorttable_sortrevind')); - sortfwdind = document.createElement('span'); - sortfwdind.id = "sorttable_sortfwdind"; - sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾'; - this.appendChild(sortfwdind); - return; - } - - // remove sorttable_sorted classes - theadrow = this.parentNode; - forEach(theadrow.childNodes, function(cell) { - if (cell.nodeType == 1) { // an element - cell.className = cell.className.replace('sorttable_sorted_reverse',''); - cell.className = cell.className.replace('sorttable_sorted',''); - } - }); - sortfwdind = document.getElementById('sorttable_sortfwdind'); - if (sortfwdind) { sortfwdind.parentNode.removeChild(sortfwdind); } - sortrevind = document.getElementById('sorttable_sortrevind'); - if (sortrevind) { sortrevind.parentNode.removeChild(sortrevind); } - - this.className += ' sorttable_sorted'; - sortfwdind = document.createElement('span'); - sortfwdind.id = "sorttable_sortfwdind"; - sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾'; - this.appendChild(sortfwdind); - - // build an array to sort. This is a Schwartzian transform thing, - // i.e., we "decorate" each row with the actual sort key, - // sort based on the sort keys, and then put the rows back in order - // which is a lot faster because you only do getInnerText once per row - row_array = []; - col = this.sorttable_columnindex; - rows = this.sorttable_tbody.rows; - for (var j=0; j 12) { - // definitely dd/mm - return sorttable.sort_ddmm; - } else if (second > 12) { - return sorttable.sort_mmdd; - } else { - // looks like a date, but we can't tell which, so assume - // that it's dd/mm (English imperialism!) and keep looking - sortfn = sorttable.sort_ddmm; - } - } - } - } - return sortfn; - }, - - getInnerText: function(node) { - // gets the text we want to use for sorting for a cell. - // strips leading and trailing whitespace. - // this is *not* a generic getInnerText function; it's special to sorttable. - // for example, you can override the cell text with a customkey attribute. - // it also gets .value for fields. 
- - if (!node) return ""; - - hasInputs = (typeof node.getElementsByTagName == 'function') && - node.getElementsByTagName('input').length; - - if (node.getAttribute("sorttable_customkey") != null) { - return node.getAttribute("sorttable_customkey"); - } - else if (typeof node.textContent != 'undefined' && !hasInputs) { - return node.textContent.replace(/^\s+|\s+$/g, ''); - } - else if (typeof node.innerText != 'undefined' && !hasInputs) { - return node.innerText.replace(/^\s+|\s+$/g, ''); - } - else if (typeof node.text != 'undefined' && !hasInputs) { - return node.text.replace(/^\s+|\s+$/g, ''); - } - else { - switch (node.nodeType) { - case 3: - if (node.nodeName.toLowerCase() == 'input') { - return node.value.replace(/^\s+|\s+$/g, ''); - } - case 4: - return node.nodeValue.replace(/^\s+|\s+$/g, ''); - break; - case 1: - case 11: - var innerText = ''; - for (var i = 0; i < node.childNodes.length; i++) { - innerText += sorttable.getInnerText(node.childNodes[i]); - } - return innerText.replace(/^\s+|\s+$/g, ''); - break; - default: - return ''; - } - } - }, - - reverse: function(tbody) { - // reverse the rows in a tbody - newrows = []; - for (var i=0; i=0; i--) { - tbody.appendChild(newrows[i]); - } - delete newrows; - }, - - /* sort functions - each sort function takes two parameters, a and b - you are comparing a[0] and b[0] */ - sort_numeric: function(a,b) { - aa = parseFloat(a[0].replace(/[^0-9.-]/g,'')); - if (isNaN(aa)) aa = 0; - bb = parseFloat(b[0].replace(/[^0-9.-]/g,'')); - if (isNaN(bb)) bb = 0; - return aa-bb; - }, - sort_alpha: function(a,b) { - if (a[0]==b[0]) return 0; - if (a[0] 0 ) { - var q = list[i]; list[i] = list[i+1]; list[i+1] = q; - swap = true; - } - } // for - t--; - - if (!swap) break; - - for(var i = t; i > b; --i) { - if ( comp_func(list[i], list[i-1]) < 0 ) { - var q = list[i]; list[i] = list[i-1]; list[i-1] = q; - swap = true; - } - } // for - b++; - - } // while(swap) - } -} - -/* ****************************************************************** - Supporting functions: bundled here to avoid depending on a library - ****************************************************************** */ - -// Dean Edwards/Matthias Miller/John Resig - -/* for Mozilla/Opera9 */ -if (document.addEventListener) { - document.addEventListener("DOMContentLoaded", sorttable.init, false); -} - -/* for Internet Explorer */ -/*@cc_on @*/ -/*@if (@_win32) - document.write("
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
new file mode 100644
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import scala.xml.Node
+
+import org.apache.spark.SparkContext
+
+/** Utility functions for generating XML pages with spark content. */
+private[spark] object UIUtils {
+
+  /** Returns a spark page with correctly formatted headers */
+  def headerSparkPage(content: => Seq[Node], sc: SparkContext, title: String, page: Page.Value)
+  : Seq[Node] = {
+    <html>
+      <head>
+        <meta http-equiv="Content-type" content="text/html; charset=utf-8" />
+        <link rel="stylesheet" href="/static/bootstrap.min.css" type="text/css" />
+        <link rel="stylesheet" href="/static/webui.css" type="text/css" />
+        <script src="/static/sorttable.js"></script>
+        <title>{sc.appName} - {title}</title>
+      </head>
+      <body>
+        <div class="container-fluid">
+          <div class="row-fluid">
+            <div class="span12">
+              <h3 style="vertical-align: bottom; display: inline-block;">
+                {title}
+              </h3>
+            </div>
+          </div>
+          {content}
+        </div>
+      </body>
+    </html>
+  }
+
+  /** Returns a page with the spark css/js and a simple format. Used for scheduler UI. */
+  def basicSparkPage(content: => Seq[Node], title: String): Seq[Node] = {
+    <html>
+      <head>
+        <meta http-equiv="Content-type" content="text/html; charset=utf-8" />
+        <link rel="stylesheet" href="/static/bootstrap.min.css" type="text/css" />
+        <link rel="stylesheet" href="/static/webui.css" type="text/css" />
+        <script src="/static/sorttable.js"></script>
+        <title>{title}</title>
+      </head>
+      <body>
+        <div class="container-fluid">
+          <div class="row-fluid">
+            <div class="span12">
+              <h3 style="vertical-align: middle; display: inline-block;">
+                <img src="/static/spark-logo-77x50px-hd.png" style="margin-right: 15px;" />
+                {title}
+              </h3>
+            </div>
+          </div>
+          {content}
+        </div>
+      </body>
+    </html>
+  }
+
+  /** Returns an HTML table constructed by generating a row for each object in a sequence. */
+  def listingTable[T](
+      headers: Seq[String],
+      makeRow: T => Seq[Node],
+      rows: Seq[T],
+      fixedWidth: Boolean = false): Seq[Node] = {
+
+    val colWidth = 100.toDouble / headers.size
+    val colWidthAttr = if (fixedWidth) colWidth + "%" else ""
+    var tableClass = "table table-bordered table-striped table-condensed sortable"
+    if (fixedWidth) {
+      tableClass += " table-fixed"
+    }
+
+    <table class={tableClass}>
+      <thead>
+        <tr>{headers.map(h => <th width={colWidthAttr}>{h}</th>)}</tr>
+      </thead>
+      <tbody>
+        {rows.map(r => makeRow(r))}
+      </tbody>
+    </table>
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
new file mode 100644
index 0000000000..0ecb22d2f9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import scala.util.Random
+
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext._
+import org.apache.spark.scheduler.cluster.SchedulingMode
+
+
+/**
+ * Continuously generates jobs that expose various features of the WebUI (internal testing tool).
+ *
+ * Usage: ./spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]
+ */
+private[spark] object UIWorkloadGenerator {
+  val NUM_PARTITIONS = 100
+  val INTER_JOB_WAIT_MS = 5000
+
+  def main(args: Array[String]) {
+    if (args.length < 2) {
+      println("usage: ./spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]")
+      System.exit(1)
+    }
+    val master = args(0)
+    val schedulingMode = SchedulingMode.withName(args(1))
+    val appName = "Spark UI Tester"
+
+    if (schedulingMode == SchedulingMode.FAIR) {
+      System.setProperty("spark.cluster.schedulingmode", "FAIR")
+    }
+    val sc = new SparkContext(master, appName)
+
+    def setProperties(s: String) = {
+      if (schedulingMode == SchedulingMode.FAIR) {
+        sc.setLocalProperty("spark.scheduler.cluster.fair.pool", s)
+      }
+      sc.setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, s)
+    }
+
+    val baseData = sc.makeRDD(1 to NUM_PARTITIONS * 10, NUM_PARTITIONS)
+    def nextFloat() = (new Random()).nextFloat()
+
+    val jobs = Seq[(String, () => Long)](
+      ("Count", baseData.count),
+      ("Cache and Count", baseData.map(x => x).cache.count),
+      ("Single Shuffle", baseData.map(x => (x % 10, x)).reduceByKey(_ + _).count),
+      ("Entirely failed phase", baseData.map(x => throw new Exception).count),
+      ("Partially failed phase", {
+        baseData.map{x =>
+          val probFailure = (4.0 / NUM_PARTITIONS)
+          if (nextFloat() < probFailure) {
+            throw new Exception("This is a task failure")
+          }
+          1
+        }.count
+      }),
+      ("Partially failed phase (longer tasks)", {
+        baseData.map{x =>
+          val probFailure = (4.0 / NUM_PARTITIONS)
+          if (nextFloat() < probFailure) {
+            Thread.sleep(100)
+            throw new Exception("This is a task failure")
+          }
+          1
+        }.count
+      }),
+      ("Job with delays", baseData.map(x => Thread.sleep(100)).count)
+    )
+
+    while (true) {
+      for ((desc, job) <- jobs) {
+        new Thread {
+          override def run() {
+            try {
+              setProperties(desc)
+              job()
+              println("Job finished: " + desc)
+            } catch {
+              case e: Exception =>
+                println("Job failed: " + desc)
+            }
+          }
+        }.start
+        Thread.sleep(INTER_JOB_WAIT_MS)
+      }
+    }
+  }
+}
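[Editor's sketch, not part of the patch: how a page builds a table with the UIUtils.listingTable helper added above. The object name and sample data are invented; it is placed in package org.apache.spark.ui because listingTable is private[spark].]

    package org.apache.spark.ui

    import scala.xml.Node

    // Illustrative only: render (name, value) pairs the way the UI pages below do.
    object ListingTableExample {
      // Caller-supplied row builder; listingTable wraps the rows in thead/tbody
      // and tags the table "sortable" so sorttable.js makes the headers clickable.
      def propertyRow(kv: (String, String)): Seq[Node] =
        <tr><td>{kv._1}</td><td>{kv._2}</td></tr>

      def main(args: Array[String]) {
        val props = Seq(("spark.master", "local"), ("spark.app.name", "demo"))
        val table = UIUtils.listingTable(Seq("Name", "Value"), propertyRow, props, fixedWidth = true)
        println(table.mkString)
      }
    }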
b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala new file mode 100644 index 0000000000..c5bf2acc9e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.env + +import javax.servlet.http.HttpServletRequest + +import scala.collection.JavaConversions._ +import scala.util.Properties +import scala.xml.Node + +import org.eclipse.jetty.server.Handler + +import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.ui.UIUtils +import org.apache.spark.ui.Page.Environment +import org.apache.spark.SparkContext + + +private[spark] class EnvironmentUI(sc: SparkContext) { + + def getHandlers = Seq[(String, Handler)]( + ("/environment", (request: HttpServletRequest) => envDetails(request)) + ) + + def envDetails(request: HttpServletRequest): Seq[Node] = { + val jvmInformation = Seq( + ("Java Version", "%s (%s)".format(Properties.javaVersion, Properties.javaVendor)), + ("Java Home", Properties.javaHome), + ("Scala Version", Properties.versionString), + ("Scala Home", Properties.scalaHome) + ).sorted + def jvmRow(kv: (String, String)) = {kv._1}{kv._2} + def jvmTable = + UIUtils.listingTable(Seq("Name", "Value"), jvmRow, jvmInformation, fixedWidth = true) + + val properties = System.getProperties.iterator.toSeq + val classPathProperty = properties.find { case (k, v) => + k.contains("java.class.path") + }.getOrElse(("", "")) + val sparkProperties = properties.filter(_._1.startsWith("spark")).sorted + val otherProperties = properties.diff(sparkProperties :+ classPathProperty).sorted + + val propertyHeaders = Seq("Name", "Value") + def propertyRow(kv: (String, String)) = {kv._1}{kv._2} + val sparkPropertyTable = + UIUtils.listingTable(propertyHeaders, propertyRow, sparkProperties, fixedWidth = true) + val otherPropertyTable = + UIUtils.listingTable(propertyHeaders, propertyRow, otherProperties, fixedWidth = true) + + val classPathEntries = classPathProperty._2 + .split(System.getProperty("path.separator", ":")) + .filterNot(e => e.isEmpty) + .map(e => (e, "System Classpath")) + val addedJars = sc.addedJars.iterator.toSeq.map{case (path, time) => (path, "Added By User")} + val addedFiles = sc.addedFiles.iterator.toSeq.map{case (path, time) => (path, "Added By User")} + val classPath = (addedJars ++ addedFiles ++ classPathEntries).sorted + + val classPathHeaders = Seq("Resource", "Source") + def classPathRow(data: (String, String)) = {data._1}{data._2} + val classPathTable = + UIUtils.listingTable(classPathHeaders, classPathRow, classPath, fixedWidth = true) + + val content = + +

+      <span>
+        <h4>Runtime Information</h4> {jvmTable}
+        <h4>Spark Properties</h4>
+        {sparkPropertyTable}
+        <h4>System Properties</h4>
+        {otherPropertyTable}
+        <h4>Classpath Entries</h4>
+        {classPathTable}
+      </span>
+ + UIUtils.headerSparkPage(content, sc, "Environment", Environment) + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala new file mode 100644 index 0000000000..efe6b474e0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala @@ -0,0 +1,136 @@ +package org.apache.spark.ui.exec + +import javax.servlet.http.HttpServletRequest + +import scala.collection.mutable.{HashMap, HashSet} +import scala.xml.Node + +import org.eclipse.jetty.server.Handler + +import org.apache.spark.{ExceptionFailure, Logging, Utils, SparkContext} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler.cluster.TaskInfo +import org.apache.spark.scheduler.{SparkListenerTaskStart, SparkListenerTaskEnd, SparkListener} +import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.ui.Page.Executors +import org.apache.spark.ui.UIUtils + + +private[spark] class ExecutorsUI(val sc: SparkContext) { + + private var _listener: Option[ExecutorsListener] = None + def listener = _listener.get + + def start() { + _listener = Some(new ExecutorsListener) + sc.addSparkListener(listener) + } + + def getHandlers = Seq[(String, Handler)]( + ("/executors", (request: HttpServletRequest) => render(request)) + ) + + def render(request: HttpServletRequest): Seq[Node] = { + val storageStatusList = sc.getExecutorStorageStatus + + val maxMem = storageStatusList.map(_.maxMem).fold(0L)(_+_) + val memUsed = storageStatusList.map(_.memUsed()).fold(0L)(_+_) + val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)).fold(0L)(_+_) + + val execHead = Seq("Executor ID", "Address", "RDD blocks", "Memory used", "Disk used", + "Active tasks", "Failed tasks", "Complete tasks", "Total tasks") + + def execRow(kv: Seq[String]) = { + + {kv(0)} + {kv(1)} + {kv(2)} + + {Utils.bytesToString(kv(3).toLong)} / {Utils.bytesToString(kv(4).toLong)} + + + {Utils.bytesToString(kv(5).toLong)} + + {kv(6)} + {kv(7)} + {kv(8)} + {kv(9)} + + } + + val execInfo = for (b <- 0 until storageStatusList.size) yield getExecInfo(b) + val execTable = UIUtils.listingTable(execHead, execRow, execInfo) + + val content = +
+      <div class="row-fluid">
+        <div class="span12">
+          <ul class="unstyled">
+            <li><strong>Memory:</strong>
+              {Utils.bytesToString(memUsed)} Used
+              ({Utils.bytesToString(maxMem)} Total) </li>
+            <li><strong>Disk:</strong> {Utils.bytesToString(diskSpaceUsed)} Used </li>
+          </ul>
+        </div>
+      </div>
+      <div class="row-fluid">
+        <div class="span12">
+          {execTable}
+        </div>
+      </div>
; + + UIUtils.headerSparkPage(content, sc, "Executors (" + execInfo.size + ")", Executors) + } + + def getExecInfo(a: Int): Seq[String] = { + val execId = sc.getExecutorStorageStatus(a).blockManagerId.executorId + val hostPort = sc.getExecutorStorageStatus(a).blockManagerId.hostPort + val rddBlocks = sc.getExecutorStorageStatus(a).blocks.size.toString + val memUsed = sc.getExecutorStorageStatus(a).memUsed().toString + val maxMem = sc.getExecutorStorageStatus(a).maxMem.toString + val diskUsed = sc.getExecutorStorageStatus(a).diskUsed().toString + val activeTasks = listener.executorToTasksActive.get(a.toString).map(l => l.size).getOrElse(0) + val failedTasks = listener.executorToTasksFailed.getOrElse(a.toString, 0) + val completedTasks = listener.executorToTasksComplete.getOrElse(a.toString, 0) + val totalTasks = activeTasks + failedTasks + completedTasks + + Seq( + execId, + hostPort, + rddBlocks, + memUsed, + maxMem, + diskUsed, + activeTasks.toString, + failedTasks.toString, + completedTasks.toString, + totalTasks.toString + ) + } + + private[spark] class ExecutorsListener extends SparkListener with Logging { + val executorToTasksActive = HashMap[String, HashSet[TaskInfo]]() + val executorToTasksComplete = HashMap[String, Int]() + val executorToTasksFailed = HashMap[String, Int]() + + override def onTaskStart(taskStart: SparkListenerTaskStart) { + val eid = taskStart.taskInfo.executorId + val activeTasks = executorToTasksActive.getOrElseUpdate(eid, new HashSet[TaskInfo]()) + activeTasks += taskStart.taskInfo + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + val eid = taskEnd.taskInfo.executorId + val activeTasks = executorToTasksActive.getOrElseUpdate(eid, new HashSet[TaskInfo]()) + activeTasks -= taskEnd.taskInfo + val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) = + taskEnd.reason match { + case e: ExceptionFailure => + executorToTasksFailed(eid) = executorToTasksFailed.getOrElse(eid, 0) + 1 + (Some(e), e.metrics) + case _ => + executorToTasksComplete(eid) = executorToTasksComplete.getOrElse(eid, 0) + 1 + (None, Option(taskEnd.taskMetrics)) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala new file mode 100644 index 0000000000..3b428effaf --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
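The ExecutorsListener above is the standard pattern for feeding UI state: register a SparkListener with the SparkContext and aggregate counters from task events. A minimal sketch in the same style (the class name and the metric it tracks are illustrative, not part of this patch):

  import scala.collection.mutable.HashMap
  import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

  // Hypothetical listener: counts finished tasks per executor.
  class TaskCountListener extends SparkListener {
    val tasksPerExecutor = HashMap[String, Int]()

    override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
      val eid = taskEnd.taskInfo.executorId
      tasksPerExecutor(eid) = tasksPerExecutor.getOrElse(eid, 0) + 1
    }
  }

  // Wired up the same way ExecutorsUI does it:
  //   sc.addSparkListener(new TaskCountListener)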
+ */ + +package org.apache.spark.ui.jobs + +import javax.servlet.http.HttpServletRequest + +import scala.xml.{NodeSeq, Node} + +import org.apache.spark.scheduler.cluster.SchedulingMode +import org.apache.spark.ui.Page._ +import org.apache.spark.ui.UIUtils._ + + +/** Page showing list of all ongoing and recently finished stages and pools*/ +private[spark] class IndexPage(parent: JobProgressUI) { + def listener = parent.listener + + def render(request: HttpServletRequest): Seq[Node] = { + listener.synchronized { + val activeStages = listener.activeStages.toSeq + val completedStages = listener.completedStages.reverse.toSeq + val failedStages = listener.failedStages.reverse.toSeq + val now = System.currentTimeMillis() + + var activeTime = 0L + for (tasks <- listener.stageToTasksActive.values; t <- tasks) { + activeTime += t.timeRunning(now) + } + + val activeStagesTable = new StageTable(activeStages.sortBy(_.submissionTime).reverse, parent) + val completedStagesTable = new StageTable(completedStages.sortBy(_.submissionTime).reverse, parent) + val failedStagesTable = new StageTable(failedStages.sortBy(_.submissionTime).reverse, parent) + + val pools = listener.sc.getAllPools + val poolTable = new PoolTable(pools, listener) + val summary: NodeSeq = +
+        <div>
+          <ul class="unstyled">
+            <li>
+              <strong>Total Duration: </strong>
+              {parent.formatDuration(now - listener.sc.startTime)}
+            </li>
+            <li><strong>Scheduling Mode: </strong> {parent.sc.getSchedulingMode}</li>
+            <li>
+              <a href="#active"><strong>Active Stages:</strong></a>
+              {activeStages.size}
+            </li>
+            <li>
+              <a href="#completed"><strong>Completed Stages:</strong></a>
+              {completedStages.size}
+            </li>
+            <li>
+              <a href="#failed"><strong>Failed Stages:</strong></a>
+              {failedStages.size}
+            </li>
+          </ul>
+        </div>
+ + val content = summary ++ + {if (listener.sc.getSchedulingMode == SchedulingMode.FAIR) { +

+          <h4>{pools.size} Fair Scheduler Pools</h4> ++ poolTable.toNodeSeq
+        } else {
+          Seq()
+        }} ++
+        <h4 id="active">Active Stages ({activeStages.size})</h4> ++
+        activeStagesTable.toNodeSeq ++
+        <h4 id="completed">Completed Stages ({completedStages.size})</h4> ++
+        completedStagesTable.toNodeSeq ++
+        <h4 id="failed">Failed Stages ({failedStages.size})</h4>
++ + failedStagesTable.toNodeSeq + + headerSparkPage(content, parent.sc, "Spark Stages", Stages) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala new file mode 100644 index 0000000000..ae02226300 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -0,0 +1,156 @@ +package org.apache.spark.ui.jobs + +import scala.Seq +import scala.collection.mutable.{ListBuffer, HashMap, HashSet} + +import org.apache.spark.{ExceptionFailure, SparkContext, Success, Utils} +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.TaskInfo +import org.apache.spark.executor.TaskMetrics +import collection.mutable + +/** + * Tracks task-level information to be displayed in the UI. + * + * All access to the data structures in this class must be synchronized on the + * class, since the UI thread and the DAGScheduler event loop may otherwise + * be reading/updating the internal data structures concurrently. + */ +private[spark] class JobProgressListener(val sc: SparkContext) extends SparkListener { + // How many stages to remember + val RETAINED_STAGES = System.getProperty("spark.ui.retained_stages", "1000").toInt + val DEFAULT_POOL_NAME = "default" + + val stageToPool = new HashMap[Stage, String]() + val stageToDescription = new HashMap[Stage, String]() + val poolToActiveStages = new HashMap[String, HashSet[Stage]]() + + val activeStages = HashSet[Stage]() + val completedStages = ListBuffer[Stage]() + val failedStages = ListBuffer[Stage]() + + // Total metrics reflect metrics only for completed tasks + var totalTime = 0L + var totalShuffleRead = 0L + var totalShuffleWrite = 0L + + val stageToTime = HashMap[Int, Long]() + val stageToShuffleRead = HashMap[Int, Long]() + val stageToShuffleWrite = HashMap[Int, Long]() + val stageToTasksActive = HashMap[Int, HashSet[TaskInfo]]() + val stageToTasksComplete = HashMap[Int, Int]() + val stageToTasksFailed = HashMap[Int, Int]() + val stageToTaskInfos = + HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]() + + override def onJobStart(jobStart: SparkListenerJobStart) {} + + override def onStageCompleted(stageCompleted: StageCompleted) = synchronized { + val stage = stageCompleted.stageInfo.stage + poolToActiveStages(stageToPool(stage)) -= stage + activeStages -= stage + completedStages += stage + trimIfNecessary(completedStages) + } + + /** If stages is too large, remove and garbage collect old stages */ + def trimIfNecessary(stages: ListBuffer[Stage]) = synchronized { + if (stages.size > RETAINED_STAGES) { + val toRemove = RETAINED_STAGES / 10 + stages.takeRight(toRemove).foreach( s => { + stageToTaskInfos.remove(s.id) + stageToTime.remove(s.id) + stageToShuffleRead.remove(s.id) + stageToShuffleWrite.remove(s.id) + stageToTasksActive.remove(s.id) + stageToTasksComplete.remove(s.id) + stageToTasksFailed.remove(s.id) + stageToPool.remove(s) + if (stageToDescription.contains(s)) {stageToDescription.remove(s)} + }) + stages.trimEnd(toRemove) + } + } + + /** For FIFO, all stages are contained by "default" pool but "default" pool here is meaningless */ + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized { + val stage = stageSubmitted.stage + activeStages += stage + + val poolName = Option(stageSubmitted.properties).map { + p => p.getProperty("spark.scheduler.cluster.fair.pool", DEFAULT_POOL_NAME) + }.getOrElse(DEFAULT_POOL_NAME) + 
stageToPool(stage) = poolName + + val description = Option(stageSubmitted.properties).flatMap { + p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) + } + description.map(d => stageToDescription(stage) = d) + + val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[Stage]()) + stages += stage + } + + override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { + val sid = taskStart.task.stageId + val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) + tasksActive += taskStart.taskInfo + val taskList = stageToTaskInfos.getOrElse( + sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]()) + taskList += ((taskStart.taskInfo, None, None)) + stageToTaskInfos(sid) = taskList + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + val sid = taskEnd.task.stageId + val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) + tasksActive -= taskEnd.taskInfo + val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) = + taskEnd.reason match { + case e: ExceptionFailure => + stageToTasksFailed(sid) = stageToTasksFailed.getOrElse(sid, 0) + 1 + (Some(e), e.metrics) + case _ => + stageToTasksComplete(sid) = stageToTasksComplete.getOrElse(sid, 0) + 1 + (None, Option(taskEnd.taskMetrics)) + } + + stageToTime.getOrElseUpdate(sid, 0L) + val time = metrics.map(m => m.executorRunTime).getOrElse(0) + stageToTime(sid) += time + totalTime += time + + stageToShuffleRead.getOrElseUpdate(sid, 0L) + val shuffleRead = metrics.flatMap(m => m.shuffleReadMetrics).map(s => + s.remoteBytesRead).getOrElse(0L) + stageToShuffleRead(sid) += shuffleRead + totalShuffleRead += shuffleRead + + stageToShuffleWrite.getOrElseUpdate(sid, 0L) + val shuffleWrite = metrics.flatMap(m => m.shuffleWriteMetrics).map(s => + s.shuffleBytesWritten).getOrElse(0L) + stageToShuffleWrite(sid) += shuffleWrite + totalShuffleWrite += shuffleWrite + + val taskList = stageToTaskInfos.getOrElse( + sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]()) + taskList -= ((taskEnd.taskInfo, None, None)) + taskList += ((taskEnd.taskInfo, metrics, failureInfo)) + stageToTaskInfos(sid) = taskList + } + + override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized { + jobEnd match { + case end: SparkListenerJobEnd => + end.jobResult match { + case JobFailed(ex, Some(stage)) => + activeStages -= stage + poolToActiveStages(stageToPool(stage)) -= stage + failedStages += stage + trimIfNecessary(failedStages) + case _ => + } + case _ => + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala new file mode 100644 index 0000000000..1bb7638bd9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
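Because JobProgressListener is mutated on the DAGScheduler event loop, every reader must take the lock its class comment documents. A sketch of how a page would read it (hypothetical helper, assuming the listener defined above):

  // All access synchronizes on the listener itself, matching the
  // contract stated in JobProgressListener's class comment.
  def activeStageCount(listener: JobProgressListener): Int =
    listener.synchronized {
      listener.activeStages.size
    }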
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import akka.util.Duration + +import java.text.SimpleDateFormat + +import javax.servlet.http.HttpServletRequest + +import org.eclipse.jetty.server.Handler + +import scala.Seq +import scala.collection.mutable.{HashSet, ListBuffer, HashMap, ArrayBuffer} + +import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.{ExceptionFailure, SparkContext, Success, Utils} +import org.apache.spark.scheduler._ +import collection.mutable +import org.apache.spark.scheduler.cluster.SchedulingMode +import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode + +/** Web UI showing progress status of all jobs in the given SparkContext. */ +private[spark] class JobProgressUI(val sc: SparkContext) { + private var _listener: Option[JobProgressListener] = None + def listener = _listener.get + val dateFmt = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + + private val indexPage = new IndexPage(this) + private val stagePage = new StagePage(this) + private val poolPage = new PoolPage(this) + + def start() { + _listener = Some(new JobProgressListener(sc)) + sc.addSparkListener(listener) + } + + def formatDuration(ms: Long) = Utils.msDurationToString(ms) + + def getHandlers = Seq[(String, Handler)]( + ("/stages/stage", (request: HttpServletRequest) => stagePage.render(request)), + ("/stages/pool", (request: HttpServletRequest) => poolPage.render(request)), + ("/stages", (request: HttpServletRequest) => indexPage.render(request)) + ) +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala new file mode 100644 index 0000000000..ce92b6932b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -0,0 +1,32 @@ +package org.apache.spark.ui.jobs + +import javax.servlet.http.HttpServletRequest + +import scala.xml.{NodeSeq, Node} +import scala.collection.mutable.HashSet + +import org.apache.spark.scheduler.Stage +import org.apache.spark.ui.UIUtils._ +import org.apache.spark.ui.Page._ + +/** Page showing specific pool details */ +private[spark] class PoolPage(parent: JobProgressUI) { + def listener = parent.listener + + def render(request: HttpServletRequest): Seq[Node] = { + listener.synchronized { + val poolName = request.getParameter("poolname") + val poolToActiveStages = listener.poolToActiveStages + val activeStages = poolToActiveStages.get(poolName).toSeq.flatten + val activeStagesTable = new StageTable(activeStages.sortBy(_.submissionTime).reverse, parent) + + val pool = listener.sc.getPoolForName(poolName).get + val poolTable = new PoolTable(Seq(pool), listener) + + val content =

+      <h4>Summary</h4> ++ poolTable.toNodeSeq() ++
+        <h4>{activeStages.size} Active Stages</h4>
++ activeStagesTable.toNodeSeq() + + headerSparkPage(content, parent.sc, "Fair Scheduler Pool: " + poolName, Stages) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala new file mode 100644 index 0000000000..f31465e59d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala @@ -0,0 +1,55 @@ +package org.apache.spark.ui.jobs + +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet +import scala.xml.Node + +import org.apache.spark.scheduler.Stage +import org.apache.spark.scheduler.cluster.Schedulable + +/** Table showing list of pools */ +private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressListener) { + + var poolToActiveStages: HashMap[String, HashSet[Stage]] = listener.poolToActiveStages + + def toNodeSeq(): Seq[Node] = { + listener.synchronized { + poolTable(poolRow, pools) + } + } + + private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[Stage]]) => Seq[Node], + rows: Seq[Schedulable] + ): Seq[Node] = { + + + + + + + + + + + {rows.map(r => makeRow(r, poolToActiveStages))} + +
+ <th>Pool Name</th> <th>Minimum Share</th> <th>Pool Weight</th> <th>Active Stages</th> <th>Running Tasks</th> <th>SchedulingMode</th>
+ } + + private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[Stage]]) + : Seq[Node] = { + val activeStages = poolToActiveStages.get(p.name) match { + case Some(stages) => stages.size + case None => 0 + } + + {p.name} + {p.minShare} + {p.weight} + {activeStages} + {p.runningTasks} + {p.schedulingMode} + + } +} + diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala new file mode 100644 index 0000000000..2fe85bc0cf --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import java.util.Date + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.spark.ui.UIUtils._ +import org.apache.spark.ui.Page._ +import org.apache.spark.util.Distribution +import org.apache.spark.{ExceptionFailure, Utils} +import org.apache.spark.scheduler.cluster.TaskInfo +import org.apache.spark.executor.TaskMetrics + +/** Page showing statistics and task list for a given stage */ +private[spark] class StagePage(parent: JobProgressUI) { + def listener = parent.listener + val dateFmt = parent.dateFmt + + def render(request: HttpServletRequest): Seq[Node] = { + listener.synchronized { + val stageId = request.getParameter("id").toInt + val now = System.currentTimeMillis() + + if (!listener.stageToTaskInfos.contains(stageId)) { + val content = +
+          <div>
+            <h4>Summary Metrics</h4> No tasks have started yet
+            <h4>Tasks</h4> No tasks have started yet
+          </div>
+ return headerSparkPage(content, parent.sc, "Details for Stage %s".format(stageId), Stages) + } + + val tasks = listener.stageToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime) + + val numCompleted = tasks.count(_._1.finished) + val shuffleReadBytes = listener.stageToShuffleRead.getOrElse(stageId, 0L) + val hasShuffleRead = shuffleReadBytes > 0 + val shuffleWriteBytes = listener.stageToShuffleWrite.getOrElse(stageId, 0L) + val hasShuffleWrite = shuffleWriteBytes > 0 + + var activeTime = 0L + listener.stageToTasksActive(stageId).foreach(activeTime += _.timeRunning(now)) + + val summary = +
+        <div>
+          <ul class="unstyled">
+            <li>
+              <strong>CPU time: </strong>
+              {parent.formatDuration(listener.stageToTime.getOrElse(stageId, 0L) + activeTime)}
+            </li>
+            {if (hasShuffleRead)
+              <li>
+                <strong>Shuffle read: </strong>
+                {Utils.bytesToString(shuffleReadBytes)}
+              </li>
+            }
+            {if (hasShuffleWrite)
+              <li>
+                <strong>Shuffle write: </strong>
+                {Utils.bytesToString(shuffleWriteBytes)}
+              </li>
+            }
+          </ul>
+        </div>
+ + val taskHeaders: Seq[String] = + Seq("Task ID", "Status", "Locality Level", "Executor", "Launch Time", "Duration") ++ + Seq("GC Time") ++ + {if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++ + {if (hasShuffleWrite) Seq("Shuffle Write") else Nil} ++ + Seq("Errors") + + val taskTable = listingTable(taskHeaders, taskRow(hasShuffleRead, hasShuffleWrite), tasks) + + // Excludes tasks which failed and have incomplete metrics + val validTasks = tasks.filter(t => t._1.status == "SUCCESS" && (t._2.isDefined)) + + val summaryTable: Option[Seq[Node]] = + if (validTasks.size == 0) { + None + } + else { + val serviceTimes = validTasks.map{case (info, metrics, exception) => + metrics.get.executorRunTime.toDouble} + val serviceQuantiles = "Duration" +: Distribution(serviceTimes).get.getQuantiles().map( + ms => parent.formatDuration(ms.toLong)) + + def getQuantileCols(data: Seq[Double]) = + Distribution(data).get.getQuantiles().map(d => Utils.bytesToString(d.toLong)) + + val shuffleReadSizes = validTasks.map { + case(info, metrics, exception) => + metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble + } + val shuffleReadQuantiles = "Shuffle Read (Remote)" +: getQuantileCols(shuffleReadSizes) + + val shuffleWriteSizes = validTasks.map { + case(info, metrics, exception) => + metrics.get.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L).toDouble + } + val shuffleWriteQuantiles = "Shuffle Write" +: getQuantileCols(shuffleWriteSizes) + + val listings: Seq[Seq[String]] = Seq(serviceQuantiles, + if (hasShuffleRead) shuffleReadQuantiles else Nil, + if (hasShuffleWrite) shuffleWriteQuantiles else Nil) + + val quantileHeaders = Seq("Metric", "Min", "25th percentile", + "Median", "75th percentile", "Max") + def quantileRow(data: Seq[String]): Seq[Node] = {data.map(d => {d})} + Some(listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true)) + } + + val content = + summary ++ +

+        <h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++
+        <div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++
+        <h4>Tasks</h4>
++ taskTable; + + headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages) + } + } + + + def taskRow(shuffleRead: Boolean, shuffleWrite: Boolean) + (taskData: (TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])): Seq[Node] = { + def fmtStackTrace(trace: Seq[StackTraceElement]): Seq[Node] = + trace.map(e => {e.toString}) + val (info, metrics, exception) = taskData + + val duration = if (info.status == "RUNNING") info.timeRunning(System.currentTimeMillis()) + else metrics.map(m => m.executorRunTime).getOrElse(1) + val formatDuration = if (info.status == "RUNNING") parent.formatDuration(duration) + else metrics.map(m => parent.formatDuration(m.executorRunTime)).getOrElse("") + val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L) + + + {info.taskId} + {info.status} + {info.taskLocality} + {info.host} + {dateFmt.format(new Date(info.launchTime))} + + {formatDuration} + + + {if (gcTime > 0) parent.formatDuration(gcTime) else ""} + + {if (shuffleRead) { + {metrics.flatMap{m => m.shuffleReadMetrics}.map{s => + Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")} + }} + {if (shuffleWrite) { + {metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => + Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")} + }} + {exception.map(e => + + {e.className} ({e.description})
+ {fmtStackTrace(e.stackTrace)} +
).getOrElse("")} + + + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala new file mode 100644 index 0000000000..beb0574548 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -0,0 +1,107 @@ +package org.apache.spark.ui.jobs + +import java.util.Date + +import scala.xml.Node +import scala.collection.mutable.HashSet + +import org.apache.spark.Utils +import org.apache.spark.scheduler.cluster.{SchedulingMode, TaskInfo} +import org.apache.spark.scheduler.Stage + + +/** Page showing list of all ongoing and recently finished stages */ +private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressUI) { + + val listener = parent.listener + val dateFmt = parent.dateFmt + val isFairScheduler = listener.sc.getSchedulingMode == SchedulingMode.FAIR + + def toNodeSeq(): Seq[Node] = { + listener.synchronized { + stageTable(stageRow, stages) + } + } + + /** Special table which merges two header cells. */ + private def stageTable[T](makeRow: T => Seq[Node], rows: Seq[T]): Seq[Node] = { + + + + {if (isFairScheduler) {} else {}} + + + + + + + + + {rows.map(r => makeRow(r))} + +
+ <th>Stage Id</th> <th>Pool Name</th> <th>Description</th> <th>Submitted</th> <th>Duration</th> <th>Tasks: Succeeded/Total</th> <th>Shuffle Read</th> <th>Shuffle Write</th>
+ } + + private def makeProgressBar(started: Int, completed: Int, failed: String, total: Int): Seq[Node] = { + val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) + val startWidth = "width: %s%%".format((started.toDouble/total)*100) + +
+    <div class="progress" style="height: 15px; margin-bottom: 0px; position: relative">
+      <span style="text-align:center; position:absolute; width:100%;">
+        {completed}/{total} {failed}
+      </span>
+      <div class="bar bar-completed" style={completeWidth}></div>
+      <div class="bar bar-running" style={startWidth}></div>
+    </div>
+ } + + + private def stageRow(s: Stage): Seq[Node] = { + val submissionTime = s.submissionTime match { + case Some(t) => dateFmt.format(new Date(t)) + case None => "Unknown" + } + + val shuffleRead = listener.stageToShuffleRead.getOrElse(s.id, 0L) match { + case 0 => "" + case b => Utils.bytesToString(b) + } + val shuffleWrite = listener.stageToShuffleWrite.getOrElse(s.id, 0L) match { + case 0 => "" + case b => Utils.bytesToString(b) + } + + val startedTasks = listener.stageToTasksActive.getOrElse(s.id, HashSet[TaskInfo]()).size + val completedTasks = listener.stageToTasksComplete.getOrElse(s.id, 0) + val failedTasks = listener.stageToTasksFailed.getOrElse(s.id, 0) match { + case f if f > 0 => "(%s failed)".format(f) + case _ => "" + } + val totalTasks = s.numPartitions + + val poolName = listener.stageToPool.get(s) + + val nameLink = {s.name} + val description = listener.stageToDescription.get(s) + .map(d =>
+      <div><em>{d}</em></div>
+      <div>{nameLink}</div>
).getOrElse(nameLink) + val finishTime = s.completionTime.getOrElse(System.currentTimeMillis()) + val duration = s.submissionTime.map(t => finishTime - t) + + + {s.id} + {if (isFairScheduler) { + {poolName.get}} + } + {description} + {submissionTime} + + {duration.map(d => parent.formatDuration(d)).getOrElse("Unknown")} + + + {makeProgressBar(startedTasks, completedTasks, failedTasks, totalTasks)} + + {shuffleRead} + {shuffleWrite} + + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala new file mode 100644 index 0000000000..1d633d374a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.storage + +import akka.util.Duration + +import javax.servlet.http.HttpServletRequest + +import org.eclipse.jetty.server.Handler + +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.ui.JettyUtils._ + +/** Web UI showing storage status of all RDD's in the given SparkContext. */ +private[spark] class BlockManagerUI(val sc: SparkContext) extends Logging { + implicit val timeout = Duration.create( + System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + + val indexPage = new IndexPage(this) + val rddPage = new RDDPage(this) + + def getHandlers = Seq[(String, Handler)]( + ("/storage/rdd", (request: HttpServletRequest) => rddPage.render(request)), + ("/storage", (request: HttpServletRequest) => indexPage.render(request)) + ) +} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala new file mode 100644 index 0000000000..1eb4a7a85e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ui.storage + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.spark.storage.{RDDInfo, StorageUtils} +import org.apache.spark.Utils +import org.apache.spark.ui.UIUtils._ +import org.apache.spark.ui.Page._ + +/** Page showing list of RDD's currently stored in the cluster */ +private[spark] class IndexPage(parent: BlockManagerUI) { + val sc = parent.sc + + def render(request: HttpServletRequest): Seq[Node] = { + val storageStatusList = sc.getExecutorStorageStatus + // Calculate macro-level statistics + + val rddHeaders = Seq( + "RDD Name", + "Storage Level", + "Cached Partitions", + "Fraction Cached", + "Size in Memory", + "Size on Disk") + val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc) + val content = listingTable(rddHeaders, rddRow, rdds) + + headerSparkPage(content, parent.sc, "Storage ", Storage) + } + + def rddRow(rdd: RDDInfo): Seq[Node] = { + + + + {rdd.name} + + + {rdd.storageLevel.description} + + {rdd.numCachedPartitions} + {"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)} + {Utils.bytesToString(rdd.memSize)} + {Utils.bytesToString(rdd.diskSize)} + + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala new file mode 100644 index 0000000000..37baf17f7a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.storage + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.spark.Utils +import org.apache.spark.storage.{StorageStatus, StorageUtils} +import org.apache.spark.storage.BlockManagerMasterActor.BlockStatus +import org.apache.spark.ui.UIUtils._ +import org.apache.spark.ui.Page._ + + +/** Page showing storage details for a given RDD */ +private[spark] class RDDPage(parent: BlockManagerUI) { + val sc = parent.sc + + def render(request: HttpServletRequest): Seq[Node] = { + val id = request.getParameter("id") + val prefix = "rdd_" + id.toString + val storageStatusList = sc.getExecutorStorageStatus + val filteredStorageStatusList = StorageUtils. 
+ filterStorageStatusByPrefix(storageStatusList, prefix) + val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head + + val workerHeaders = Seq("Host", "Memory Usage", "Disk Usage") + val workers = filteredStorageStatusList.map((prefix, _)) + val workerTable = listingTable(workerHeaders, workerRow, workers) + + val blockHeaders = Seq("Block Name", "Storage Level", "Size in Memory", "Size on Disk", + "Executors") + + val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).toArray.sortWith(_._1 < _._1) + val blockLocations = StorageUtils.blockLocationsFromStorageStatus(filteredStorageStatusList) + val blocks = blockStatuses.map { + case(id, status) => (id, status, blockLocations.get(id).getOrElse(Seq("UNKNOWN"))) + } + val blockTable = listingTable(blockHeaders, blockRow, blocks) + + val content = +
+      <div class="row">
+        <div class="span12">
+          <ul class="unstyled">
+            <li>
+              <strong>Storage Level:</strong>
+              {rddInfo.storageLevel.description}
+            </li>
+            <li>
+              <strong>Cached Partitions:</strong>
+              {rddInfo.numCachedPartitions}
+            </li>
+            <li>
+              <strong>Total Partitions:</strong>
+              {rddInfo.numPartitions}
+            </li>
+            <li>
+              <strong>Memory Size:</strong>
+              {Utils.bytesToString(rddInfo.memSize)}
+            </li>
+            <li>
+              <strong>Disk Size:</strong>
+              {Utils.bytesToString(rddInfo.diskSize)}
+            </li>
+          </ul>
+        </div>
+      </div>
+      <hr/>
+      <div class="row">
+        <div class="span12">
+          <h4> Data Distribution on {workers.size} Executors </h4>
+          {workerTable}
+        </div>
+      </div>
+      <hr/>
+      <div class="row">
+        <div class="span12">
+          <h4> {blocks.size} Partitions </h4>
+          {blockTable}
+        </div>
+      </div>
; + + headerSparkPage(content, parent.sc, "RDD Storage Info for " + rddInfo.name, Storage) + } + + def blockRow(row: (String, BlockStatus, Seq[String])): Seq[Node] = { + val (id, block, locations) = row + + {id} + + {block.storageLevel.description} + + + {Utils.bytesToString(block.memSize)} + + + {Utils.bytesToString(block.diskSize)} + + + {locations.map(l => {l}
)} + + + } + + def workerRow(worker: (String, StorageStatus)): Seq[Node] = { + val (prefix, status) = worker + + {status.blockManagerId.host + ":" + status.blockManagerId.port} + + {Utils.bytesToString(status.memUsed(prefix))} + ({Utils.bytesToString(status.memRemaining)} Remaining) + + {Utils.bytesToString(status.diskUsed(prefix))} + + } +} diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala new file mode 100644 index 0000000000..d4c5065c3f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import akka.actor.{ActorSystem, ExtendedActorSystem} +import com.typesafe.config.ConfigFactory +import akka.util.duration._ +import akka.remote.RemoteActorRefProvider + + +/** + * Various utility classes for working with Akka. + */ +private[spark] object AkkaUtils { + + /** + * Creates an ActorSystem ready for remoting, with various Spark features. Returns both the + * ActorSystem itself and its port (which is hard to get from Akka). + * + * Note: the `name` parameter is important, as even if a client sends a message to right + * host + port, if the system name is incorrect, Akka will drop the message. + */ + def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = { + val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt + val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt + val akkaTimeout = System.getProperty("spark.akka.timeout", "60").toInt + val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt + val lifecycleEvents = if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off" + // 10 seconds is the default akka timeout, but in a cluster, we need higher by default. 
+ val akkaWriteTimeout = System.getProperty("spark.akka.writeTimeout", "30").toInt + + val akkaConf = ConfigFactory.parseString(""" + akka.daemonic = on + akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"] + akka.stdout-loglevel = "ERROR" + akka.actor.provider = "akka.remote.RemoteActorRefProvider" + akka.remote.transport = "akka.remote.netty.NettyRemoteTransport" + akka.remote.netty.hostname = "%s" + akka.remote.netty.port = %d + akka.remote.netty.connection-timeout = %ds + akka.remote.netty.message-frame-size = %d MiB + akka.remote.netty.execution-pool-size = %d + akka.actor.default-dispatcher.throughput = %d + akka.remote.log-remote-lifecycle-events = %s + akka.remote.netty.write-timeout = %ds + """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize, + lifecycleEvents, akkaWriteTimeout)) + + val actorSystem = ActorSystem(name, akkaConf) + + // Figure out the port number we bound to, in case port was passed as 0. This is a bit of a + // hack because Akka doesn't let you figure out the port through the public API yet. + val provider = actorSystem.asInstanceOf[ExtendedActorSystem].provider + val boundPort = provider.asInstanceOf[RemoteActorRefProvider].transport.address.port.get + return (actorSystem, boundPort) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala new file mode 100644 index 0000000000..0b51c23f7b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.Serializable +import java.util.{PriorityQueue => JPriorityQueue} +import scala.collection.generic.Growable +import scala.collection.JavaConverters._ + +/** + * Bounded priority queue. This class wraps the original PriorityQueue + * class and modifies it such that only the top K elements are retained. + * The top K elements are defined by an implicit Ordering[A]. 
+ */ +class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A]) + extends Iterable[A] with Growable[A] with Serializable { + + private val underlying = new JPriorityQueue[A](maxSize, ord) + + override def iterator: Iterator[A] = underlying.iterator.asScala + + override def ++=(xs: TraversableOnce[A]): this.type = { + xs.foreach { this += _ } + this + } + + override def +=(elem: A): this.type = { + if (size < maxSize) underlying.offer(elem) + else maybeReplaceLowest(elem) + this + } + + override def +=(elem1: A, elem2: A, elems: A*): this.type = { + this += elem1 += elem2 ++= elems + } + + override def clear() { underlying.clear() } + + private def maybeReplaceLowest(a: A): Boolean = { + val head = underlying.peek() + if (head != null && ord.gt(a, head)) { + underlying.poll() + underlying.offer(a) + } else false + } +} + diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala new file mode 100644 index 0000000000..e214d2a519 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.InputStream +import java.nio.ByteBuffer +import org.apache.spark.storage.BlockManager + +/** + * Reads data from a ByteBuffer, and optionally cleans it up using BlockManager.dispose() + * at the end of the stream (e.g. to close a memory-mapped file). + */ +private[spark] +class ByteBufferInputStream(private var buffer: ByteBuffer, dispose: Boolean = false) + extends InputStream { + + override def read(): Int = { + if (buffer == null || buffer.remaining() == 0) { + cleanUp() + -1 + } else { + buffer.get() & 0xFF + } + } + + override def read(dest: Array[Byte]): Int = { + read(dest, 0, dest.length) + } + + override def read(dest: Array[Byte], offset: Int, length: Int): Int = { + if (buffer == null || buffer.remaining() == 0) { + cleanUp() + -1 + } else { + val amountToGet = math.min(buffer.remaining(), length) + buffer.get(dest, offset, amountToGet) + amountToGet + } + } + + override def skip(bytes: Long): Long = { + if (buffer != null) { + val amountToSkip = math.min(bytes, buffer.remaining).toInt + buffer.position(buffer.position + amountToSkip) + if (buffer.remaining() == 0) { + cleanUp() + } + amountToSkip + } else { + 0L + } + } + + /** + * Clean up the buffer, and potentially dispose of it using BlockManager.dispose(). 
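A usage sketch for the BoundedPriorityQueue defined above (the values are illustrative):

  // Keep only the 3 largest of the inserted elements.
  val top3 = new BoundedPriorityQueue[Int](3)
  top3 ++= Seq(5, 1, 9, 7, 3)
  // Iteration follows heap order, not sorted order, so sort before asserting.
  assert(top3.toSeq.sorted == Seq(5, 7, 9))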
+ */ + private def cleanUp() { + if (buffer != null) { + if (dispose) { + BlockManager.dispose(buffer) + } + buffer = null + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/Clock.scala b/core/src/main/scala/org/apache/spark/util/Clock.scala new file mode 100644 index 0000000000..97c2b45aab --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/Clock.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * An interface to represent clocks, so that they can be mocked out in unit tests. + */ +private[spark] trait Clock { + def getTime(): Long +} + +private[spark] object SystemClock extends Clock { + def getTime(): Long = System.currentTimeMillis() +} diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala new file mode 100644 index 0000000000..dc15a38b29 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
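A sketch of reading through the ByteBufferInputStream above; since the class is private[spark], this assumes code inside the org.apache.spark package, and the heap buffer means dispose stays at its default of false:

  import java.nio.ByteBuffer

  val in = new ByteBufferInputStream(ByteBuffer.wrap(Array[Byte](1, 2, 3)))
  assert(in.read() == 1)          // single-byte read
  val rest = new Array[Byte](2)
  assert(in.read(rest) == 2)      // bulk read of the remainder
  assert(in.read() == -1)         // exhausted; cleanUp() has run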
+ */ + +package org.apache.spark.util + +/** + * Wrapper around an iterator which calls a completion method after it successfully iterates through all the elements + */ +abstract class CompletionIterator[+A, +I <: Iterator[A]](sub: I) extends Iterator[A]{ + def next = sub.next + def hasNext = { + val r = sub.hasNext + if (!r) { + completion + } + r + } + + def completion() +} + +object CompletionIterator { + def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit) : CompletionIterator[A,I] = { + new CompletionIterator[A,I](sub) { + def completion() = completionFunction + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/Distribution.scala b/core/src/main/scala/org/apache/spark/util/Distribution.scala new file mode 100644 index 0000000000..33bf3562fe --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/Distribution.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.PrintStream + +/** + * Util for getting some stats from a small sample of numeric values, with some handy summary functions. + * + * Entirely in memory, not intended as a good way to compute stats over large data sets. + * + * Assumes you are giving it a non-empty set of data + */ +class Distribution(val data: Array[Double], val startIdx: Int, val endIdx: Int) { + require(startIdx < endIdx) + def this(data: Traversable[Double]) = this(data.toArray, 0, data.size) + java.util.Arrays.sort(data, startIdx, endIdx) + val length = endIdx - startIdx + + val defaultProbabilities = Array(0,0.25,0.5,0.75,1.0) + + /** + * Get the value of the distribution at the given probabilities. Probabilities should be + * given from 0 to 1 + * @param probabilities + */ + def getQuantiles(probabilities: Traversable[Double] = defaultProbabilities) = { + probabilities.toIndexedSeq.map{p:Double => data(closestIndex(p))} + } + + private def closestIndex(p: Double) = { + math.min((p * length).toInt + startIdx, endIdx - 1) + } + + def showQuantiles(out: PrintStream = System.out) = { + out.println("min\t25%\t50%\t75%\tmax") + getQuantiles(defaultProbabilities).foreach{q => out.print(q + "\t")} + out.println + } + + def statCounter = StatCounter(data.slice(startIdx, endIdx)) + + /** + * print a summary of this distribution to the given PrintStream. 
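A sketch of the CompletionIterator above: the completion function fires exactly once, after the wrapped iterator is exhausted:

  var closed = false
  val it = CompletionIterator[Int, Iterator[Int]](Iterator(1, 2, 3), { closed = true })
  assert(it.sum == 6)   // consuming the iterator...
  assert(closed)        // ...triggers the completion callback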
+ * @param out + */ + def summary(out: PrintStream = System.out) { + out.println(statCounter) + showQuantiles(out) + } +} + +object Distribution { + + def apply(data: Traversable[Double]): Option[Distribution] = { + if (data.size > 0) + Some(new Distribution(data)) + else + None + } + + def showQuantiles(out: PrintStream = System.out, quantiles: Traversable[Double]) { + out.println("min\t25%\t50%\t75%\tmax") + quantiles.foreach{q => out.print(q + "\t")} + out.println + } +} diff --git a/core/src/main/scala/org/apache/spark/util/IdGenerator.scala b/core/src/main/scala/org/apache/spark/util/IdGenerator.scala new file mode 100644 index 0000000000..17e55f7996 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/IdGenerator.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.concurrent.atomic.AtomicInteger + +/** + * A util used to get a unique generation ID. This is a wrapper around Java's + * AtomicInteger. An example usage is in BlockManager, where each BlockManager + * instance would start an Akka actor and we use this utility to assign the Akka + * actors unique names. + */ +private[spark] class IdGenerator { + private var id = new AtomicInteger + def next: Int = id.incrementAndGet +} diff --git a/core/src/main/scala/org/apache/spark/util/IntParam.scala b/core/src/main/scala/org/apache/spark/util/IntParam.scala new file mode 100644 index 0000000000..626bb49eea --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/IntParam.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * An extractor object for parsing strings into integers. 
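A sketch of the Distribution helper above on a small sample (values illustrative):

  val d = Distribution(Seq(1.0, 2.0, 3.0, 4.0, 5.0)).get
  // Default probabilities are (0, 0.25, 0.5, 0.75, 1.0):
  // min, 25th percentile, median, 75th percentile, max.
  assert(d.getQuantiles() == IndexedSeq(1.0, 2.0, 3.0, 4.0, 5.0))
  d.summary()   // prints the StatCounter plus the quantile row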
+ */ +private[spark] object IntParam { + def unapply(str: String): Option[Int] = { + try { + Some(str.toInt) + } catch { + case e: NumberFormatException => None + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/MemoryParam.scala b/core/src/main/scala/org/apache/spark/util/MemoryParam.scala new file mode 100644 index 0000000000..0ee6707826 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/MemoryParam.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import org.apache.spark.Utils + +/** + * An extractor object for parsing JVM memory strings, such as "10g", into an Int representing + * the number of megabytes. Supports the same formats as Utils.memoryStringToMb. + */ +private[spark] object MemoryParam { + def unapply(str: String): Option[Int] = { + try { + Some(Utils.memoryStringToMb(str)) + } catch { + case e: NumberFormatException => None + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala new file mode 100644 index 0000000000..a430a75451 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors} +import java.util.{TimerTask, Timer} +import org.apache.spark.Logging + + +/** + * Runs a timer task to periodically clean up metadata (e.g. 
old files or hashtable entries) + */ +class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { + private val delaySeconds = MetadataCleaner.getDelaySeconds + private val periodSeconds = math.max(10, delaySeconds / 10) + private val timer = new Timer(name + " cleanup timer", true) + + private val task = new TimerTask { + override def run() { + try { + cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) + logInfo("Ran metadata cleaner for " + name) + } catch { + case e: Exception => logError("Error running cleanup task for " + name, e) + } + } + } + + if (delaySeconds > 0) { + logDebug( + "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds " + + "and period of " + periodSeconds + " secs") + timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) + } + + def cancel() { + timer.cancel() + } +} + + +object MetadataCleaner { + def getDelaySeconds = System.getProperty("spark.cleaner.ttl", "-1").toInt + def setDelaySeconds(delay: Int) { System.setProperty("spark.cleaner.ttl", delay.toString) } +} + diff --git a/core/src/main/scala/org/apache/spark/util/MutablePair.scala b/core/src/main/scala/org/apache/spark/util/MutablePair.scala new file mode 100644 index 0000000000..34f1f6606f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/MutablePair.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + + +/** + * A tuple of 2 elements. This can be used as an alternative to Scala's Tuple2 when we want to + * minimize object allocation. + * + * @param _1 Element 1 of this MutablePair + * @param _2 Element 2 of this MutablePair + */ +case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T1, + @specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T2] + (var _1: T1, var _2: T2) + extends Product2[T1, T2] +{ + override def toString = "(" + _1 + "," + _2 + ")" + + override def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[_,_]] +} diff --git a/core/src/main/scala/org/apache/spark/util/NextIterator.scala b/core/src/main/scala/org/apache/spark/util/NextIterator.scala new file mode 100644 index 0000000000..8266e5e495 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/NextIterator.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
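The extractors and the cleanup timer defined above can be exercised roughly as follows; IntParam and MemoryParam are private[spark], so a sketch like this is assumed to live in Spark-internal code, and the argument values are illustrative:

```scala
import org.apache.spark.util.{IntParam, MemoryParam, MetadataCleaner}

// Extractor-based argument parsing: "7077 4g" becomes port 7077 and 4096 MB.
def parse(args: List[String]): Option[(Int, Int)] = args match {
  case IntParam(port) :: MemoryParam(memoryMb) :: Nil => Some((port, memoryMb))
  case _ => None
}
println(parse(List("7077", "4g")))               // Some((7077,4096))

// Periodic metadata cleanup: the timer only starts when spark.cleaner.ttl is positive.
MetadataCleaner.setDelaySeconds(3600)
val cleaner = new MetadataCleaner("demo", cutoff =>
  println("dropping entries older than " + cutoff))
// ... later, on shutdown:
cleaner.cancel()
```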
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** Provides a basic/boilerplate Iterator implementation. */ +private[spark] abstract class NextIterator[U] extends Iterator[U] { + + private var gotNext = false + private var nextValue: U = _ + private var closed = false + protected var finished = false + + /** + * Method for subclasses to implement to provide the next element. + * + * If no next element is available, the subclass should set `finished` + * to `true` and may return any value (it will be ignored). + * + * This convention is required because `null` may be a valid value, + * and using `Option` seems like it might create unnecessary Some/None + * instances, given some iterators might be called in a tight loop. + * + * @return U, or set 'finished' when done + */ + protected def getNext(): U + + /** + * Method for subclasses to implement when all elements have been successfully + * iterated, and the iteration is done. + * + * Note: `NextIterator` cannot guarantee that `close` will be + * called because it has no control over what happens when an exception + * happens in the user code that is calling hasNext/next. + * + * Ideally you should have another try/catch, as in HadoopRDD, that + * ensures any resources are closed should iteration fail. + */ + protected def close() + + /** + * Calls the subclass-defined close method, but only once. + * + * Usually calling `close` multiple times should be fine, but historically + * there have been issues with some InputFormats throwing exceptions. + */ + def closeIfNeeded() { + if (!closed) { + close() + closed = true + } + } + + override def hasNext: Boolean = { + if (!finished) { + if (!gotNext) { + nextValue = getNext() + if (finished) { + closeIfNeeded() + } + gotNext = true + } + } + !finished + } + + override def next(): U = { + if (!hasNext) { + throw new NoSuchElementException("End of stream") + } + gotNext = false + nextValue + } +} diff --git a/core/src/main/scala/org/apache/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/org/apache/spark/util/RateLimitedOutputStream.scala new file mode 100644 index 0000000000..47e1b45004 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/RateLimitedOutputStream.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +import scala.annotation.tailrec + +import java.io.OutputStream +import java.util.concurrent.TimeUnit._ + +class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream { + val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS) + val CHUNK_SIZE = 8192 + var lastSyncTime = System.nanoTime + var bytesWrittenSinceSync: Long = 0 + + override def write(b: Int) { + waitToWrite(1) + out.write(b) + } + + override def write(bytes: Array[Byte]) { + write(bytes, 0, bytes.length) + } + + @tailrec + override final def write(bytes: Array[Byte], offset: Int, length: Int) { + val writeSize = math.min(length - offset, CHUNK_SIZE) + if (writeSize > 0) { + waitToWrite(writeSize) + out.write(bytes, offset, writeSize) + write(bytes, offset + writeSize, length) + } + } + + override def flush() { + out.flush() + } + + override def close() { + out.close() + } + + @tailrec + private def waitToWrite(numBytes: Int) { + val now = System.nanoTime + val elapsedSecs = SECONDS.convert(math.max(now - lastSyncTime, 1), NANOSECONDS) + val rate = bytesWrittenSinceSync.toDouble / elapsedSecs + if (rate < bytesPerSec) { + // It's okay to write; just update some variables and return + bytesWrittenSinceSync += numBytes + if (now > lastSyncTime + SYNC_INTERVAL) { + // Sync interval has passed; let's resync + lastSyncTime = now + bytesWrittenSinceSync = numBytes + } + } else { + // Calculate how much time we should sleep to bring ourselves to the desired rate. + // Based on throttler in Kafka (https://github.com/kafka-dev/kafka/blob/master/core/src/main/scala/kafka/utils/Throttler.scala) + val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS) + if (sleepTime > 0) Thread.sleep(sleepTime) + waitToWrite(numBytes) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala new file mode 100644 index 0000000000..f2b1ad7d0e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.nio.ByteBuffer +import java.io.{IOException, ObjectOutputStream, EOFException, ObjectInputStream} +import java.nio.channels.Channels + +/** + * A wrapper around a java.nio.ByteBuffer that is serializable through Java serialization, to make + * it easier to pass ByteBuffers in case class messages. 
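NextIterator, whose full body appears above, is meant to be subclassed rather than used directly. A minimal subclass might look like the sketch below; since the class is private[spark], the example assumes it is compiled inside the org.apache.spark packages, and LineIterator is a hypothetical name:

```scala
// NextIterator is private[spark], so an implementation has to live under org.apache.spark.
package org.apache.spark.util

import java.io.BufferedReader

/** Iterates over the lines of a reader, closing it exactly once when the stream is exhausted. */
class LineIterator(reader: BufferedReader) extends NextIterator[String] {
  override protected def getNext(): String = {
    val line = reader.readLine()
    if (line == null) {
      finished = true   // tells hasNext to stop and lets it call closeIfNeeded()
    }
    line                // the returned value is ignored once finished is set
  }

  override protected def close() {
    reader.close()
  }
}
// Usage: new LineIterator(reader).foreach(println); call closeIfNeeded() yourself if you stop early.
```

The RateLimitedOutputStream added just above wraps any OutputStream; a brief sketch assuming a 1 MB/s cap and a throwaway file path:

```scala
import java.io.FileOutputStream
import org.apache.spark.util.RateLimitedOutputStream

// Cap writes at roughly 1 MB/s; waitToWrite sleeps whenever the observed rate exceeds bytesPerSec.
val throttled = new RateLimitedOutputStream(new FileOutputStream("/tmp/replay.log"), 1024 * 1024)
try {
  throttled.write("some payload\n".getBytes("UTF-8"))
} finally {
  throttled.close()   // also closes the wrapped stream
}
```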
+ */ +private[spark] +class SerializableBuffer(@transient var buffer: ByteBuffer) extends Serializable { + def value = buffer + + private def readObject(in: ObjectInputStream) { + val length = in.readInt() + buffer = ByteBuffer.allocate(length) + var amountRead = 0 + val channel = Channels.newChannel(in) + while (amountRead < length) { + val ret = channel.read(buffer) + if (ret == -1) { + throw new EOFException("End of file before fully reading buffer") + } + amountRead += ret + } + buffer.rewind() // Allow us to read it later + } + + private def writeObject(out: ObjectOutputStream) { + out.writeInt(buffer.limit()) + if (Channels.newChannel(out).write(buffer) != buffer.limit()) { + throw new IOException("Could not fully write buffer to output stream") + } + buffer.rewind() // Allow us to write it again later + } +} diff --git a/core/src/main/scala/org/apache/spark/util/StatCounter.scala b/core/src/main/scala/org/apache/spark/util/StatCounter.scala new file mode 100644 index 0000000000..020d5edba9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/StatCounter.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * A class for tracking the statistics of a set of numbers (count, mean and variance) in a + * numerically robust way. Includes support for merging two StatCounters. Based on + * [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance Welford and Chan's algorithms for running variance]]. + * + * @constructor Initialize the StatCounter with the given values. + */ +class StatCounter(values: TraversableOnce[Double]) extends Serializable { + private var n: Long = 0 // Running count of our values + private var mu: Double = 0 // Running mean of our values + private var m2: Double = 0 // Running variance numerator (sum of (x - mean)^2) + + merge(values) + + /** Initialize the StatCounter with no values. */ + def this() = this(Nil) + + /** Add a value into this StatCounter, updating the internal statistics. */ + def merge(value: Double): StatCounter = { + val delta = value - mu + n += 1 + mu += delta / n + m2 += delta * (value - mu) + this + } + + /** Add multiple values into this StatCounter, updating the internal statistics. */ + def merge(values: TraversableOnce[Double]): StatCounter = { + values.foreach(v => merge(v)) + this + } + + /** Merge another StatCounter into this one, adding up the internal statistics. 
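A sketch of what the custom readObject/writeObject above buys: a ByteBuffer that survives a plain Java-serialization round trip. SerializableBuffer is private[spark], so this is assumed to run from Spark-internal code:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.nio.ByteBuffer
import org.apache.spark.util.SerializableBuffer

val original = ByteBuffer.wrap("task state".getBytes("UTF-8"))
val wrapped = new SerializableBuffer(original)

// Serialize the wrapper with ordinary Java serialization...
val bytes = new ByteArrayOutputStream()
val oos = new ObjectOutputStream(bytes)
oos.writeObject(wrapped)
oos.close()

// ...and get the buffer contents back on the other side.
val restored = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
  .readObject().asInstanceOf[SerializableBuffer]
println(restored.value.limit())   // same number of bytes as the original buffer
```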
*/ + def merge(other: StatCounter): StatCounter = { + if (other == this) { + merge(other.copy()) // Avoid overwriting fields in a weird order + } else { + if (n == 0) { + mu = other.mu + m2 = other.m2 + n = other.n + } else if (other.n != 0) { + val delta = other.mu - mu + if (other.n * 10 < n) { + mu = mu + (delta * other.n) / (n + other.n) + } else if (n * 10 < other.n) { + mu = other.mu - (delta * n) / (n + other.n) + } else { + mu = (mu * n + other.mu * other.n) / (n + other.n) + } + m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n) + n += other.n + } + this + } + } + + /** Clone this StatCounter */ + def copy(): StatCounter = { + val other = new StatCounter + other.n = n + other.mu = mu + other.m2 = m2 + other + } + + def count: Long = n + + def mean: Double = mu + + def sum: Double = n * mu + + /** Return the variance of the values. */ + def variance: Double = { + if (n == 0) + Double.NaN + else + m2 / n + } + + /** + * Return the sample variance, which corrects for bias in estimating the variance by dividing + * by N-1 instead of N. + */ + def sampleVariance: Double = { + if (n <= 1) + Double.NaN + else + m2 / (n - 1) + } + + /** Return the standard deviation of the values. */ + def stdev: Double = math.sqrt(variance) + + /** + * Return the sample standard deviation of the values, which corrects for bias in estimating the + * variance by dividing by N-1 instead of N. + */ + def sampleStdev: Double = math.sqrt(sampleVariance) + + override def toString: String = { + "(count: %d, mean: %f, stdev: %f)".format(count, mean, stdev) + } +} + +object StatCounter { + /** Build a StatCounter from a list of values. */ + def apply(values: TraversableOnce[Double]) = new StatCounter(values) + + /** Build a StatCounter from a list of values passed as variable-length arguments. */ + def apply(values: Double*) = new StatCounter(values) +} diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala new file mode 100644 index 0000000000..277de2f8a6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.concurrent.ConcurrentHashMap +import scala.collection.JavaConversions +import scala.collection.mutable.Map +import scala.collection.immutable +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.Logging + +/** + * This is a custom implementation of scala.collection.mutable.Map which stores the insertion + * time stamp along with each key-value pair. Key-value pairs that are older than a particular + * threshold time can them be removed using the clearOldValues method. 
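StatCounter's typical use is one counter per partition merged into a global result; a small sketch with made-up values:

```scala
import org.apache.spark.util.StatCounter

val left  = StatCounter(1.0, 2.0, 3.0)      // e.g. statistics from one partition
val right = StatCounter(Seq(10.0, 20.0))    // ...and from another
val total = left.merge(right)               // merges right into left and returns it
println(total.count)                        // 5
println(total.mean)                         // 7.2
println(total.stdev)                        // population standard deviation
println(total.sampleStdev)                  // N-1 corrected version
```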
This is intended to be a drop-in + * replacement of scala.collection.mutable.HashMap. + */ +class TimeStampedHashMap[A, B] extends Map[A, B]() with Logging { + val internalMap = new ConcurrentHashMap[A, (B, Long)]() + + def get(key: A): Option[B] = { + val value = internalMap.get(key) + if (value != null) Some(value._1) else None + } + + def iterator: Iterator[(A, B)] = { + val jIterator = internalMap.entrySet().iterator() + JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue._1)) + } + + override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = { + val newMap = new TimeStampedHashMap[A, B1] + newMap.internalMap.putAll(this.internalMap) + newMap.internalMap.put(kv._1, (kv._2, currentTime)) + newMap + } + + override def - (key: A): Map[A, B] = { + val newMap = new TimeStampedHashMap[A, B] + newMap.internalMap.putAll(this.internalMap) + newMap.internalMap.remove(key) + newMap + } + + override def += (kv: (A, B)): this.type = { + internalMap.put(kv._1, (kv._2, currentTime)) + this + } + + // Should we return previous value directly or as Option ? + def putIfAbsent(key: A, value: B): Option[B] = { + val prev = internalMap.putIfAbsent(key, (value, currentTime)) + if (prev != null) Some(prev._1) else None + } + + + override def -= (key: A): this.type = { + internalMap.remove(key) + this + } + + override def update(key: A, value: B) { + this += ((key, value)) + } + + override def apply(key: A): B = { + val value = internalMap.get(key) + if (value == null) throw new NoSuchElementException() + value._1 + } + + override def filter(p: ((A, B)) => Boolean): Map[A, B] = { + JavaConversions.asScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p) + } + + override def empty: Map[A, B] = new TimeStampedHashMap[A, B]() + + override def size: Int = internalMap.size + + override def foreach[U](f: ((A, B)) => U) { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + val kv = (entry.getKey, entry.getValue._1) + f(kv) + } + } + + def toMap: immutable.Map[A, B] = iterator.toMap + + /** + * Removes old key-value pairs that have timestamp earlier than `threshTime` + */ + def clearOldValues(threshTime: Long) { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + if (entry.getValue._2 < threshTime) { + logDebug("Removing key " + entry.getKey) + iterator.remove() + } + } + } + + private def currentTime: Long = System.currentTimeMillis() + +} diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala new file mode 100644 index 0000000000..26983138ff --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
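A sketch of the intended lifecycle of a TimeStampedHashMap, normally driven by a MetadataCleaner calling clearOldValues; the keys here are made up:

```scala
import org.apache.spark.util.TimeStampedHashMap

val cache = new TimeStampedHashMap[String, Int]()
cache("stage_1") = 42               // every insert records System.currentTimeMillis()
Thread.sleep(10)
val cutoff = System.currentTimeMillis()
cache("stage_2") = 43
cache.clearOldValues(cutoff)        // drops entries inserted before the cutoff
println(cache.get("stage_1"))       // None
println(cache.get("stage_2"))       // Some(43)
```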
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import scala.collection.mutable.Set +import scala.collection.JavaConversions +import java.util.concurrent.ConcurrentHashMap + + +class TimeStampedHashSet[A] extends Set[A] { + val internalMap = new ConcurrentHashMap[A, Long]() + + def contains(key: A): Boolean = { + internalMap.contains(key) + } + + def iterator: Iterator[A] = { + val jIterator = internalMap.entrySet().iterator() + JavaConversions.asScalaIterator(jIterator).map(_.getKey) + } + + override def + (elem: A): Set[A] = { + val newSet = new TimeStampedHashSet[A] + newSet ++= this + newSet += elem + newSet + } + + override def - (elem: A): Set[A] = { + val newSet = new TimeStampedHashSet[A] + newSet ++= this + newSet -= elem + newSet + } + + override def += (key: A): this.type = { + internalMap.put(key, currentTime) + this + } + + override def -= (key: A): this.type = { + internalMap.remove(key) + this + } + + override def empty: Set[A] = new TimeStampedHashSet[A]() + + override def size(): Int = internalMap.size() + + override def foreach[U](f: (A) => U): Unit = { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + f(iterator.next.getKey) + } + } + + /** + * Removes old values that have timestamp earlier than `threshTime` + */ + def clearOldValues(threshTime: Long) { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + if (entry.getValue < threshTime) { + iterator.remove() + } + } + } + + private def currentTime: Long = System.currentTimeMillis() +} diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala new file mode 100644 index 0000000000..fe710c58ac --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/Vector.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
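TimeStampedHashSet follows the same pattern for plain membership; a brief sketch:

```scala
import org.apache.spark.util.TimeStampedHashSet

val recent = new TimeStampedHashSet[Int]
recent += 1
Thread.sleep(10)
val cutoff = System.currentTimeMillis()
recent += 2
recent.clearOldValues(cutoff)
println(recent.iterator.toList)     // List(2): only the element added after the cutoff survives
```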
+ */ + +package org.apache.spark.util + +class Vector(val elements: Array[Double]) extends Serializable { + def length = elements.length + + def apply(index: Int) = elements(index) + + def + (other: Vector): Vector = { + if (length != other.length) + throw new IllegalArgumentException("Vectors of different length") + return Vector(length, i => this(i) + other(i)) + } + + def add(other: Vector) = this + other + + def - (other: Vector): Vector = { + if (length != other.length) + throw new IllegalArgumentException("Vectors of different length") + return Vector(length, i => this(i) - other(i)) + } + + def subtract(other: Vector) = this - other + + def dot(other: Vector): Double = { + if (length != other.length) + throw new IllegalArgumentException("Vectors of different length") + var ans = 0.0 + var i = 0 + while (i < length) { + ans += this(i) * other(i) + i += 1 + } + return ans + } + + /** + * return (this + plus) dot other, but without creating any intermediate storage + * @param plus + * @param other + * @return + */ + def plusDot(plus: Vector, other: Vector): Double = { + if (length != other.length) + throw new IllegalArgumentException("Vectors of different length") + if (length != plus.length) + throw new IllegalArgumentException("Vectors of different length") + var ans = 0.0 + var i = 0 + while (i < length) { + ans += (this(i) + plus(i)) * other(i) + i += 1 + } + return ans + } + + def += (other: Vector): Vector = { + if (length != other.length) + throw new IllegalArgumentException("Vectors of different length") + var i = 0 + while (i < length) { + elements(i) += other(i) + i += 1 + } + this + } + + def addInPlace(other: Vector) = this +=other + + def * (scale: Double): Vector = Vector(length, i => this(i) * scale) + + def multiply (d: Double) = this * d + + def / (d: Double): Vector = this * (1 / d) + + def divide (d: Double) = this / d + + def unary_- = this * -1 + + def sum = elements.reduceLeft(_ + _) + + def squaredDist(other: Vector): Double = { + var ans = 0.0 + var i = 0 + while (i < length) { + ans += (this(i) - other(i)) * (this(i) - other(i)) + i += 1 + } + return ans + } + + def dist(other: Vector): Double = math.sqrt(squaredDist(other)) + + override def toString = elements.mkString("(", ", ", ")") +} + +object Vector { + def apply(elements: Array[Double]) = new Vector(elements) + + def apply(elements: Double*) = new Vector(elements.toArray) + + def apply(length: Int, initializer: Int => Double): Vector = { + val elements: Array[Double] = Array.tabulate(length)(initializer) + return new Vector(elements) + } + + def zeros(length: Int) = new Vector(new Array[Double](length)) + + def ones(length: Int) = Vector(length, _ => 1) + + class Multiplier(num: Double) { + def * (vec: Vector) = vec * num + } + + implicit def doubleToMultiplier(num: Double) = new Multiplier(num) + + implicit object VectorAccumParam extends org.apache.spark.AccumulatorParam[Vector] { + def addInPlace(t1: Vector, t2: Vector) = t1 + t2 + + def zero(initialValue: Vector) = Vector.zeros(initialValue.length) + } + +} diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala deleted file mode 100644 index 6ff92ce833..0000000000 --- a/core/src/main/scala/spark/Accumulators.scala +++ /dev/null @@ -1,256 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
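The Vector class above is a small dense-vector helper backed by a plain Array[Double]; basic arithmetic looks like this:

```scala
import org.apache.spark.util.Vector

val a = Vector(1.0, 2.0, 3.0)
val b = Vector.ones(3)
println(a + b)               // (2.0, 3.0, 4.0)
println(a dot b)             // 6.0
println(a * 2.0)             // (2.0, 4.0, 6.0)
println(a.squaredDist(b))    // 0.0 + 1.0 + 4.0 = 5.0
```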
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.io._ - -import scala.collection.mutable.Map -import scala.collection.generic.Growable - -/** - * A datatype that can be accumulated, i.e. has an commutative and associative "add" operation, - * but where the result type, `R`, may be different from the element type being added, `T`. - * - * You must define how to add data, and how to merge two of these together. For some datatypes, - * such as a counter, these might be the same operation. In that case, you can use the simpler - * [[spark.Accumulator]]. They won't always be the same, though -- e.g., imagine you are - * accumulating a set. You will add items to the set, and you will union two sets together. - * - * @param initialValue initial value of accumulator - * @param param helper object defining how to add elements of type `R` and `T` - * @tparam R the full accumulated data (result type) - * @tparam T partial data that can be added in - */ -class Accumulable[R, T] ( - @transient initialValue: R, - param: AccumulableParam[R, T]) - extends Serializable { - - val id = Accumulators.newId - @transient private var value_ = initialValue // Current value on master - val zero = param.zero(initialValue) // Zero value to be passed to workers - var deserialized = false - - Accumulators.register(this, true) - - /** - * Add more data to this accumulator / accumulable - * @param term the data to add - */ - def += (term: T) { value_ = param.addAccumulator(value_, term) } - - /** - * Add more data to this accumulator / accumulable - * @param term the data to add - */ - def add(term: T) { value_ = param.addAccumulator(value_, term) } - - /** - * Merge two accumulable objects together - * - * Normally, a user will not want to use this version, but will instead call `+=`. - * @param term the other `R` that will get merged with this - */ - def ++= (term: R) { value_ = param.addInPlace(value_, term)} - - /** - * Merge two accumulable objects together - * - * Normally, a user will not want to use this version, but will instead call `add`. - * @param term the other `R` that will get merged with this - */ - def merge(term: R) { value_ = param.addInPlace(value_, term)} - - /** - * Access the accumulator's current value; only allowed on master. - */ - def value: R = { - if (!deserialized) { - value_ - } else { - throw new UnsupportedOperationException("Can't read accumulator value in task") - } - } - - /** - * Get the current value of this accumulator from within a task. - * - * This is NOT the global value of the accumulator. To get the global value after a - * completed operation on the dataset, call `value`. - * - * The typical use of this method is to directly mutate the local value, eg., to add - * an element to a Set. - */ - def localValue = value_ - - /** - * Set the accumulator's value; only allowed on master. 
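The Accumulable/Accumulator classes being removed here (the pre-rename copies) are easiest to follow with a concrete param. In practice an accumulator comes from SparkContext.accumulator(...), but constructing one directly shows the mechanics; SumParam below is a hand-rolled example, not part of the API:

```scala
import spark.{Accumulator, AccumulatorParam}

// A hand-rolled AccumulatorParam that simply sums Ints.
object SumParam extends AccumulatorParam[Int] {
  def addInPlace(t1: Int, t2: Int): Int = t1 + t2
  def zero(initial: Int): Int = 0
}

val acc = new Accumulator(0, SumParam)
acc += 5             // add partial data; in a real job this happens inside tasks
acc.add(7)
println(acc.value)   // 12; the value is only readable on the master
```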
- */ - def value_= (newValue: R) { - if (!deserialized) value_ = newValue - else throw new UnsupportedOperationException("Can't assign accumulator value in task") - } - - /** - * Set the accumulator's value; only allowed on master - */ - def setValue(newValue: R) { - this.value = newValue - } - - // Called by Java when deserializing an object - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() - value_ = zero - deserialized = true - Accumulators.register(this, false) - } - - override def toString = value_.toString -} - -/** - * Helper object defining how to accumulate values of a particular type. An implicit - * AccumulableParam needs to be available when you create Accumulables of a specific type. - * - * @tparam R the full accumulated data (result type) - * @tparam T partial data that can be added in - */ -trait AccumulableParam[R, T] extends Serializable { - /** - * Add additional data to the accumulator value. Is allowed to modify and return `r` - * for efficiency (to avoid allocating objects). - * - * @param r the current value of the accumulator - * @param t the data to be added to the accumulator - * @return the new value of the accumulator - */ - def addAccumulator(r: R, t: T): R - - /** - * Merge two accumulated values together. Is allowed to modify and return the first value - * for efficiency (to avoid allocating objects). - * - * @param r1 one set of accumulated data - * @param r2 another set of accumulated data - * @return both data sets merged together - */ - def addInPlace(r1: R, r2: R): R - - /** - * Return the "zero" (identity) value for an accumulator type, given its initial value. For - * example, if R was a vector of N dimensions, this would return a vector of N zeroes. - */ - def zero(initialValue: R): R -} - -private[spark] -class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T] - extends AccumulableParam[R,T] { - - def addAccumulator(growable: R, elem: T): R = { - growable += elem - growable - } - - def addInPlace(t1: R, t2: R): R = { - t1 ++= t2 - t1 - } - - def zero(initialValue: R): R = { - // We need to clone initialValue, but it's hard to specify that R should also be Cloneable. - // Instead we'll serialize it to a buffer and load it back. - val ser = (new spark.JavaSerializer).newInstance() - val copy = ser.deserialize[R](ser.serialize(initialValue)) - copy.clear() // In case it contained stuff - copy - } -} - -/** - * A simpler value of [[spark.Accumulable]] where the result type being accumulated is the same - * as the types of elements being merged. - * - * @param initialValue initial value of accumulator - * @param param helper object defining how to add elements of type `T` - * @tparam T result type - */ -class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T]) - extends Accumulable[T,T](initialValue, param) - -/** - * A simpler version of [[spark.AccumulableParam]] where the only datatype you can add in is the same type - * as the accumulated value. An implicit AccumulatorParam object needs to be available when you create - * Accumulators of a specific type. - * - * @tparam T type of value to accumulate - */ -trait AccumulatorParam[T] extends AccumulableParam[T, T] { - def addAccumulator(t1: T, t2: T): T = { - addInPlace(t1, t2) - } -} - -// TODO: The multi-thread support in accumulators is kind of lame; check -// if there's a more intuitive way of doing it right -private object Accumulators { - // TODO: Use soft references? 
=> need to make readObject work properly then - val originals = Map[Long, Accumulable[_, _]]() - val localAccums = Map[Thread, Map[Long, Accumulable[_, _]]]() - var lastId: Long = 0 - - def newId: Long = synchronized { - lastId += 1 - return lastId - } - - def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized { - if (original) { - originals(a.id) = a - } else { - val accums = localAccums.getOrElseUpdate(Thread.currentThread, Map()) - accums(a.id) = a - } - } - - // Clear the local (non-original) accumulators for the current thread - def clear() { - synchronized { - localAccums.remove(Thread.currentThread) - } - } - - // Get the values of the local accumulators for the current thread (by ID) - def values: Map[Long, Any] = synchronized { - val ret = Map[Long, Any]() - for ((id, accum) <- localAccums.getOrElse(Thread.currentThread, Map())) { - ret(id) = accum.localValue - } - return ret - } - - // Add values to the original accumulators with some given IDs - def add(values: Map[Long, Any]): Unit = synchronized { - for ((id, value) <- values) { - if (originals.contains(id)) { - originals(id).asInstanceOf[Accumulable[Any, Any]] ++= value - } - } - } -} diff --git a/core/src/main/scala/spark/Aggregator.scala b/core/src/main/scala/spark/Aggregator.scala deleted file mode 100644 index 9af401986d..0000000000 --- a/core/src/main/scala/spark/Aggregator.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.util.{HashMap => JHashMap} - -import scala.collection.JavaConversions._ - -/** A set of functions used to aggregate data. - * - * @param createCombiner function to create the initial value of the aggregation. - * @param mergeValue function to merge a new value into the aggregation result. - * @param mergeCombiners function to merge outputs from multiple mergeValue function. 
- */ -case class Aggregator[K, V, C] ( - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiners: (C, C) => C) { - - def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = { - val combiners = new JHashMap[K, C] - for (kv <- iter) { - val oldC = combiners.get(kv._1) - if (oldC == null) { - combiners.put(kv._1, createCombiner(kv._2)) - } else { - combiners.put(kv._1, mergeValue(oldC, kv._2)) - } - } - combiners.iterator - } - - def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = { - val combiners = new JHashMap[K, C] - iter.foreach { case(k, c) => - val oldC = combiners.get(k) - if (oldC == null) { - combiners.put(k, c) - } else { - combiners.put(k, mergeCombiners(oldC, c)) - } - } - combiners.iterator - } -} - diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala deleted file mode 100644 index 1ec95ed9b8..0000000000 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
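The Aggregator removed above as part of the rename is easiest to read with word-count style combiners; a quick sketch:

```scala
import spark.Aggregator

val agg = Aggregator[String, Int, Int](
  createCombiner = v => v,                 // the first value for a key becomes the initial count
  mergeValue = (c, v) => c + v,            // fold further values into the running count
  mergeCombiners = (c1, c2) => c1 + c2)    // combine partial counts from different map tasks

val perKey = agg.combineValuesByKey(Iterator(("a", 1), ("b", 1), ("a", 1))).toMap
println(perKey)                            // Map(a -> 2, b -> 1) (ordering may differ)
```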
- */ - -package spark - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap - -import spark.executor.{ShuffleReadMetrics, TaskMetrics} -import spark.serializer.Serializer -import spark.storage.BlockManagerId -import spark.util.CompletionIterator - - -private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { - - override def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer) - : Iterator[T] = - { - - logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) - val blockManager = SparkEnv.get.blockManager - - val startTime = System.currentTimeMillis - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId) - logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format( - shuffleId, reduceId, System.currentTimeMillis - startTime)) - - val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]] - for (((address, size), index) <- statuses.zipWithIndex) { - splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size)) - } - - val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map { - case (address, splits) => - (address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2))) - } - - def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[T] = { - val blockId = blockPair._1 - val blockOption = blockPair._2 - blockOption match { - case Some(block) => { - block.asInstanceOf[Iterator[T]] - } - case None => { - val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r - blockId match { - case regex(shufId, mapId, _) => - val address = statuses(mapId.toInt)._1 - throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null) - case _ => - throw new SparkException( - "Failed to get block " + blockId + ", which is not a shuffle block") - } - } - } - } - - val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer) - val itr = blockFetcherItr.flatMap(unpackBlock) - - CompletionIterator[T, Iterator[T]](itr, { - val shuffleMetrics = new ShuffleReadMetrics - shuffleMetrics.shuffleFinishTime = System.currentTimeMillis - shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime - shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime - shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead - shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks - shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks - shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks - metrics.shuffleReadMetrics = Some(shuffleMetrics) - }) - } -} diff --git a/core/src/main/scala/spark/CacheManager.scala b/core/src/main/scala/spark/CacheManager.scala deleted file mode 100644 index 81314805a9..0000000000 --- a/core/src/main/scala/spark/CacheManager.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import scala.collection.mutable.{ArrayBuffer, HashSet} -import spark.storage.{BlockManager, StorageLevel} - - -/** Spark class responsible for passing RDDs split contents to the BlockManager and making - sure a node doesn't load two copies of an RDD at once. - */ -private[spark] class CacheManager(blockManager: BlockManager) extends Logging { - private val loading = new HashSet[String] - - /** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */ - def getOrCompute[T](rdd: RDD[T], split: Partition, context: TaskContext, storageLevel: StorageLevel) - : Iterator[T] = { - val key = "rdd_%d_%d".format(rdd.id, split.index) - logInfo("Cache key is " + key) - blockManager.get(key) match { - case Some(cachedValues) => - // Partition is in cache, so just return its values - logInfo("Found partition in cache!") - return cachedValues.asInstanceOf[Iterator[T]] - - case None => - // Mark the split as loading (unless someone else marks it first) - loading.synchronized { - if (loading.contains(key)) { - logInfo("Loading contains " + key + ", waiting...") - while (loading.contains(key)) { - try {loading.wait()} catch {case _ : Throwable =>} - } - logInfo("Loading no longer contains " + key + ", so returning cached result") - // See whether someone else has successfully loaded it. The main way this would fail - // is for the RDD-level cache eviction policy if someone else has loaded the same RDD - // partition but we didn't want to make space for it. However, that case is unlikely - // because it's unlikely that two threads would work on the same RDD partition. One - // downside of the current code is that threads wait serially if this does happen. - blockManager.get(key) match { - case Some(values) => - return values.asInstanceOf[Iterator[T]] - case None => - logInfo("Whoever was loading " + key + " failed; we'll try it ourselves") - loading.add(key) - } - } else { - loading.add(key) - } - } - try { - // If we got here, we have to load the split - val elements = new ArrayBuffer[Any] - logInfo("Computing partition " + split) - elements ++= rdd.computeOrReadCheckpoint(split, context) - // Try to put this block in the blockManager - blockManager.put(key, elements, storageLevel, true) - return elements.iterator.asInstanceOf[Iterator[T]] - } finally { - loading.synchronized { - loading.remove(key) - loading.notifyAll() - } - } - } - } -} diff --git a/core/src/main/scala/spark/ClosureCleaner.scala b/core/src/main/scala/spark/ClosureCleaner.scala deleted file mode 100644 index 8b39241095..0000000000 --- a/core/src/main/scala/spark/ClosureCleaner.scala +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.lang.reflect.Field - -import scala.collection.mutable.Map -import scala.collection.mutable.Set - -import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} -import org.objectweb.asm.Opcodes._ -import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream} - -private[spark] object ClosureCleaner extends Logging { - // Get an ASM class reader for a given class from the JAR that loaded it - private def getClassReader(cls: Class[_]): ClassReader = { - // Copy data over, before delegating to ClassReader - else we can run out of open file handles. - val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" - val resourceStream = cls.getResourceAsStream(className) - // todo: Fixme - continuing with earlier behavior ... - if (resourceStream == null) return new ClassReader(resourceStream) - - val baos = new ByteArrayOutputStream(128) - Utils.copyStream(resourceStream, baos, true) - new ClassReader(new ByteArrayInputStream(baos.toByteArray)) - } - - // Check whether a class represents a Scala closure - private def isClosure(cls: Class[_]): Boolean = { - cls.getName.contains("$anonfun$") - } - - // Get a list of the classes of the outer objects of a given closure object, obj; - // the outer objects are defined as any closures that obj is nested within, plus - // possibly the class that the outermost closure is in, if any. We stop searching - // for outer objects beyond that because cloning the user's object is probably - // not a good idea (whereas we can clone closure objects just fine since we - // understand how all their fields are used). - private def getOuterClasses(obj: AnyRef): List[Class[_]] = { - for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { - f.setAccessible(true) - if (isClosure(f.getType)) { - return f.getType :: getOuterClasses(f.get(obj)) - } else { - return f.getType :: Nil // Stop at the first $outer that is not a closure - } - } - return Nil - } - - // Get a list of the outer objects for a given closure object. 
- private def getOuterObjects(obj: AnyRef): List[AnyRef] = { - for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { - f.setAccessible(true) - if (isClosure(f.getType)) { - return f.get(obj) :: getOuterObjects(f.get(obj)) - } else { - return f.get(obj) :: Nil // Stop at the first $outer that is not a closure - } - } - return Nil - } - - private def getInnerClasses(obj: AnyRef): List[Class[_]] = { - val seen = Set[Class[_]](obj.getClass) - var stack = List[Class[_]](obj.getClass) - while (!stack.isEmpty) { - val cr = getClassReader(stack.head) - stack = stack.tail - val set = Set[Class[_]]() - cr.accept(new InnerClosureFinder(set), 0) - for (cls <- set -- seen) { - seen += cls - stack = cls :: stack - } - } - return (seen - obj.getClass).toList - } - - private def createNullValue(cls: Class[_]): AnyRef = { - if (cls.isPrimitive) { - new java.lang.Byte(0: Byte) // Should be convertible to any primitive type - } else { - null - } - } - - def clean(func: AnyRef) { - // TODO: cache outerClasses / innerClasses / accessedFields - val outerClasses = getOuterClasses(func) - val innerClasses = getInnerClasses(func) - val outerObjects = getOuterObjects(func) - - val accessedFields = Map[Class[_], Set[String]]() - for (cls <- outerClasses) - accessedFields(cls) = Set[String]() - for (cls <- func.getClass :: innerClasses) - getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0) - //logInfo("accessedFields: " + accessedFields) - - val inInterpreter = { - try { - val interpClass = Class.forName("spark.repl.Main") - interpClass.getMethod("interp").invoke(null) != null - } catch { - case _: ClassNotFoundException => true - } - } - - var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse - var outer: AnyRef = null - if (outerPairs.size > 0 && !isClosure(outerPairs.head._1)) { - // The closure is ultimately nested inside a class; keep the object of that - // class without cloning it since we don't want to clone the user's objects. - outer = outerPairs.head._2 - outerPairs = outerPairs.tail - } - // Clone the closure objects themselves, nulling out any fields that are not - // used in the closure we're working on or any of its inner closures. 
- for ((cls, obj) <- outerPairs) { - outer = instantiateClass(cls, outer, inInterpreter) - for (fieldName <- accessedFields(cls)) { - val field = cls.getDeclaredField(fieldName) - field.setAccessible(true) - val value = field.get(obj) - //logInfo("1: Setting " + fieldName + " on " + cls + " to " + value); - field.set(outer, value) - } - } - - if (outer != null) { - //logInfo("2: Setting $outer on " + func.getClass + " to " + outer); - val field = func.getClass.getDeclaredField("$outer") - field.setAccessible(true) - field.set(func, outer) - } - } - - private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = { - //logInfo("Creating a " + cls + " with outer = " + outer) - if (!inInterpreter) { - // This is a bona fide closure class, whose constructor has no effects - // other than to set its fields, so use its constructor - val cons = cls.getConstructors()(0) - val params = cons.getParameterTypes.map(createNullValue).toArray - if (outer != null) - params(0) = outer // First param is always outer object - return cons.newInstance(params: _*).asInstanceOf[AnyRef] - } else { - // Use reflection to instantiate object without calling constructor - val rf = sun.reflect.ReflectionFactory.getReflectionFactory() - val parentCtor = classOf[java.lang.Object].getDeclaredConstructor() - val newCtor = rf.newConstructorForSerialization(cls, parentCtor) - val obj = newCtor.newInstance().asInstanceOf[AnyRef] - if (outer != null) { - //logInfo("3: Setting $outer on " + cls + " to " + outer); - val field = cls.getDeclaredField("$outer") - field.setAccessible(true) - field.set(obj, outer) - } - return obj - } - } -} - -private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) { - override def visitMethod(access: Int, name: String, desc: String, - sig: String, exceptions: Array[String]): MethodVisitor = { - return new MethodVisitor(ASM4) { - override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) { - if (op == GETFIELD) { - for (cl <- output.keys if cl.getName == owner.replace('/', '.')) { - output(cl) += name - } - } - } - - override def visitMethodInsn(op: Int, owner: String, name: String, - desc: String) { - // Check for calls a getter method for a variable in an interpreter wrapper object. - // This means that the corresponding field will be accessed, so we should save it. - if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) { - for (cl <- output.keys if cl.getName == owner.replace('/', '.')) { - output(cl) += name - } - } - } - } - } -} - -private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) { - var myName: String = null - - override def visit(version: Int, access: Int, name: String, sig: String, - superName: String, interfaces: Array[String]) { - myName = name - } - - override def visitMethod(access: Int, name: String, desc: String, - sig: String, exceptions: Array[String]): MethodVisitor = { - return new MethodVisitor(ASM4) { - override def visitMethodInsn(op: Int, owner: String, name: String, - desc: String) { - val argTypes = Type.getArgumentTypes(desc) - if (op == INVOKESPECIAL && name == "" && argTypes.length > 0 - && argTypes(0).toString.startsWith("L") // is it an object? 
- && argTypes(0).getInternalName == myName) - output += Class.forName( - owner.replace('/', '.'), - false, - Thread.currentThread.getContextClassLoader) - } - } - } -} diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala deleted file mode 100644 index d5a9606570..0000000000 --- a/core/src/main/scala/spark/Dependency.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -/** - * Base class for dependencies. - */ -abstract class Dependency[T](val rdd: RDD[T]) extends Serializable - - -/** - * Base class for dependencies where each partition of the parent RDD is used by at most one - * partition of the child RDD. Narrow dependencies allow for pipelined execution. - */ -abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { - /** - * Get the parent partitions for a child partition. - * @param partitionId a partition of the child RDD - * @return the partitions of the parent RDD that the child partition depends upon - */ - def getParents(partitionId: Int): Seq[Int] -} - - -/** - * Represents a dependency on the output of a shuffle stage. - * @param rdd the parent RDD - * @param partitioner partitioner used to partition the shuffle output - * @param serializerClass class name of the serializer to use - */ -class ShuffleDependency[K, V]( - @transient rdd: RDD[_ <: Product2[K, V]], - val partitioner: Partitioner, - val serializerClass: String = null) - extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { - - val shuffleId: Int = rdd.context.newShuffleId() -} - - -/** - * Represents a one-to-one dependency between partitions of the parent and child RDDs. - */ -class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) { - override def getParents(partitionId: Int) = List(partitionId) -} - - -/** - * Represents a one-to-one dependency between ranges of partitions in the parent and child RDDs. 
- * @param rdd the parent RDD - * @param inStart the start of the range in the parent RDD - * @param outStart the start of the range in the child RDD - * @param length the length of the range - */ -class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) - extends NarrowDependency[T](rdd) { - - override def getParents(partitionId: Int) = { - if (partitionId >= outStart && partitionId < outStart + length) { - List(partitionId - outStart + inStart) - } else { - Nil - } - } -} diff --git a/core/src/main/scala/spark/DoubleRDDFunctions.scala b/core/src/main/scala/spark/DoubleRDDFunctions.scala deleted file mode 100644 index 104168e61c..0000000000 --- a/core/src/main/scala/spark/DoubleRDDFunctions.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import spark.partial.BoundedDouble -import spark.partial.MeanEvaluator -import spark.partial.PartialResult -import spark.partial.SumEvaluator -import spark.util.StatCounter - -/** - * Extra functions available on RDDs of Doubles through an implicit conversion. - * Import `spark.SparkContext._` at the top of your program to use these functions. - */ -class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { - /** Add up the elements in this RDD. */ - def sum(): Double = { - self.reduce(_ + _) - } - - /** - * Return a [[spark.util.StatCounter]] object that captures the mean, variance and count - * of the RDD's elements in one operation. - */ - def stats(): StatCounter = { - self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b)) - } - - /** Compute the mean of this RDD's elements. */ - def mean(): Double = stats().mean - - /** Compute the variance of this RDD's elements. */ - def variance(): Double = stats().variance - - /** Compute the standard deviation of this RDD's elements. */ - def stdev(): Double = stats().stdev - - /** - * Compute the sample standard deviation of this RDD's elements (which corrects for bias in - * estimating the standard deviation by dividing by N-1 instead of N). - */ - def sampleStdev(): Double = stats().sampleStdev - - /** - * Compute the sample variance of this RDD's elements (which corrects for bias in - * estimating the variance by dividing by N-1 instead of N). - */ - def sampleVariance(): Double = stats().sampleVariance - - /** (Experimental) Approximate operation to return the mean within a timeout. 
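A quick sketch of the index arithmetic in the RangeDependency defined above; the dependency is constructed directly only for illustration (RDDs such as UnionRDD create these internally), and the local SparkContext is throwaway:

```scala
import spark.{RangeDependency, SparkContext}

val sc = new SparkContext("local", "RangeDependencyDemo")
val parent = sc.parallelize(1 to 30, 3)          // 3 parent partitions
val dep = new RangeDependency(parent, 0, 5, 3)   // parent partitions 0..2 sit at child positions 5..7

println(dep.getParents(5))   // List(0)
println(dep.getParents(7))   // List(2)
println(dep.getParents(9))   // List(): outside the mapped range
sc.stop()
```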
*/ - def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { - val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new MeanEvaluator(self.partitions.size, confidence) - self.context.runApproximateJob(self, processPartition, evaluator, timeout) - } - - /** (Experimental) Approximate operation to return the sum within a timeout. */ - def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { - val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new SumEvaluator(self.partitions.size, confidence) - self.context.runApproximateJob(self, processPartition, evaluator, timeout) - } -} diff --git a/core/src/main/scala/spark/FetchFailedException.scala b/core/src/main/scala/spark/FetchFailedException.scala deleted file mode 100644 index a2dae6cae9..0000000000 --- a/core/src/main/scala/spark/FetchFailedException.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import spark.storage.BlockManagerId - -private[spark] class FetchFailedException( - taskEndReason: TaskEndReason, - message: String, - cause: Throwable) - extends Exception { - - def this (bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int, cause: Throwable) = - this(FetchFailed(bmAddress, shuffleId, mapId, reduceId), - "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId), - cause) - - def this (shuffleId: Int, reduceId: Int, cause: Throwable) = - this(FetchFailed(null, shuffleId, -1, reduceId), - "Unable to fetch locations from master: %d %d".format(shuffleId, reduceId), cause) - - override def getMessage(): String = message - - - override def getCause(): Throwable = cause - - def toTaskEndReason: TaskEndReason = taskEndReason - -} diff --git a/core/src/main/scala/spark/HttpFileServer.scala b/core/src/main/scala/spark/HttpFileServer.scala deleted file mode 100644 index a13a7a2859..0000000000 --- a/core/src/main/scala/spark/HttpFileServer.scala +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
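For context on the DoubleRDDFunctions file deleted above, a minimal usage sketch; `sc` is a hypothetical SparkContext and the pre-rename `spark.SparkContext._` import supplies the implicit RDD[Double] conversion.

    import spark.SparkContext._

    val xs = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0))
    val st = xs.stats()             // one pass: count, mean and variance together
    println(st.mean)                // 2.5
    println(xs.sum())               // 10.0
    println(xs.sampleStdev())       // divides by N-1 rather than N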
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.io.{File} -import com.google.common.io.Files - -private[spark] class HttpFileServer extends Logging { - - var baseDir : File = null - var fileDir : File = null - var jarDir : File = null - var httpServer : HttpServer = null - var serverUri : String = null - - def initialize() { - baseDir = Utils.createTempDir() - fileDir = new File(baseDir, "files") - jarDir = new File(baseDir, "jars") - fileDir.mkdir() - jarDir.mkdir() - logInfo("HTTP File server directory is " + baseDir) - httpServer = new HttpServer(baseDir) - httpServer.start() - serverUri = httpServer.uri - } - - def stop() { - httpServer.stop() - } - - def addFile(file: File) : String = { - addFileToDir(file, fileDir) - return serverUri + "/files/" + file.getName - } - - def addJar(file: File) : String = { - addFileToDir(file, jarDir) - return serverUri + "/jars/" + file.getName - } - - def addFileToDir(file: File, dir: File) : String = { - Files.copy(file, new File(dir, file.getName)) - return dir + "/" + file.getName - } - -} diff --git a/core/src/main/scala/spark/HttpServer.scala b/core/src/main/scala/spark/HttpServer.scala deleted file mode 100644 index c9dffbc631..0000000000 --- a/core/src/main/scala/spark/HttpServer.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.io.File -import java.net.InetAddress - -import org.eclipse.jetty.server.Server -import org.eclipse.jetty.server.bio.SocketConnector -import org.eclipse.jetty.server.handler.DefaultHandler -import org.eclipse.jetty.server.handler.HandlerList -import org.eclipse.jetty.server.handler.ResourceHandler -import org.eclipse.jetty.util.thread.QueuedThreadPool - -/** - * Exception type thrown by HttpServer when it is in the wrong state for an operation. - */ -private[spark] class ServerStateException(message: String) extends Exception(message) - -/** - * An HTTP server for static content used to allow worker nodes to access JARs added to SparkContext - * as well as classes created by the interpreter when the user types in code. This is just a wrapper - * around a Jetty server. 
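A short lifecycle sketch of the HttpFileServer removed above (file paths hypothetical, private[spark] visibility ignored for illustration):

    val fileServer = new HttpFileServer()
    fileServer.initialize()                                   // temp dir with files/ and jars/, Jetty started
    val fileUri = fileServer.addFile(new java.io.File("/tmp/lookup.txt"))   // serverUri + "/files/lookup.txt"
    val jarUri  = fileServer.addJar(new java.io.File("/tmp/app.jar"))       // serverUri + "/jars/app.jar"
    // fileUri and jarUri are the HTTP URLs that workers later download from.
    fileServer.stop()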
- */ -private[spark] class HttpServer(resourceBase: File) extends Logging { - private var server: Server = null - private var port: Int = -1 - - def start() { - if (server != null) { - throw new ServerStateException("Server is already started") - } else { - server = new Server() - val connector = new SocketConnector - connector.setMaxIdleTime(60*1000) - connector.setSoLingerTime(-1) - connector.setPort(0) - server.addConnector(connector) - - val threadPool = new QueuedThreadPool - threadPool.setDaemon(true) - server.setThreadPool(threadPool) - val resHandler = new ResourceHandler - resHandler.setResourceBase(resourceBase.getAbsolutePath) - val handlerList = new HandlerList - handlerList.setHandlers(Array(resHandler, new DefaultHandler)) - server.setHandler(handlerList) - server.start() - port = server.getConnectors()(0).getLocalPort() - } - } - - def stop() { - if (server == null) { - throw new ServerStateException("Server is already stopped") - } else { - server.stop() - port = -1 - server = null - } - } - - /** - * Get the URI of this HTTP server (http://host:port) - */ - def uri: String = { - if (server == null) { - throw new ServerStateException("Server is not started") - } else { - return "http://" + Utils.localIpAddress + ":" + port - } - } -} diff --git a/core/src/main/scala/spark/JavaSerializer.scala b/core/src/main/scala/spark/JavaSerializer.scala deleted file mode 100644 index 04c5f44e6b..0000000000 --- a/core/src/main/scala/spark/JavaSerializer.scala +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
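And the underlying HttpServer wrapper itself, as a sketch (served directory hypothetical):

    val httpServer = new HttpServer(new java.io.File("/tmp/static"))
    httpServer.start()            // binds an ephemeral port (connector.setPort(0))
    println(httpServer.uri)       // "http://<local-ip>:<port>"; ServerStateException if called before start()
    httpServer.stop()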
- */ - -package spark - -import java.io._ -import java.nio.ByteBuffer - -import serializer.{Serializer, SerializerInstance, DeserializationStream, SerializationStream} -import spark.util.ByteBufferInputStream - -private[spark] class JavaSerializationStream(out: OutputStream) extends SerializationStream { - val objOut = new ObjectOutputStream(out) - def writeObject[T](t: T): SerializationStream = { objOut.writeObject(t); this } - def flush() { objOut.flush() } - def close() { objOut.close() } -} - -private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader) -extends DeserializationStream { - val objIn = new ObjectInputStream(in) { - override def resolveClass(desc: ObjectStreamClass) = - Class.forName(desc.getName, false, loader) - } - - def readObject[T](): T = objIn.readObject().asInstanceOf[T] - def close() { objIn.close() } -} - -private[spark] class JavaSerializerInstance extends SerializerInstance { - def serialize[T](t: T): ByteBuffer = { - val bos = new ByteArrayOutputStream() - val out = serializeStream(bos) - out.writeObject(t) - out.close() - ByteBuffer.wrap(bos.toByteArray) - } - - def deserialize[T](bytes: ByteBuffer): T = { - val bis = new ByteBufferInputStream(bytes) - val in = deserializeStream(bis) - in.readObject().asInstanceOf[T] - } - - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { - val bis = new ByteBufferInputStream(bytes) - val in = deserializeStream(bis, loader) - in.readObject().asInstanceOf[T] - } - - def serializeStream(s: OutputStream): SerializationStream = { - new JavaSerializationStream(s) - } - - def deserializeStream(s: InputStream): DeserializationStream = { - new JavaDeserializationStream(s, Thread.currentThread.getContextClassLoader) - } - - def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = { - new JavaDeserializationStream(s, loader) - } -} - -/** - * A Spark serializer that uses Java's built-in serialization. - */ -class JavaSerializer extends Serializer { - def newInstance(): SerializerInstance = new JavaSerializerInstance -} diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala deleted file mode 100644 index eeb2993d8a..0000000000 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
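A round trip through the JavaSerializer deleted above, as a sketch:

    val ser = new JavaSerializer().newInstance()
    val buf: java.nio.ByteBuffer = ser.serialize(List(1, 2, 3))
    val back = ser.deserialize[List[Int]](buf)
    assert(back == List(1, 2, 3))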
- */ - -package spark - -import java.io._ -import java.nio.ByteBuffer -import com.esotericsoftware.kryo.{Kryo, KryoException} -import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} -import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} -import com.twitter.chill.ScalaKryoInstantiator -import serializer.{SerializerInstance, DeserializationStream, SerializationStream} -import spark.broadcast._ -import spark.storage._ - -private[spark] -class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream { - val output = new KryoOutput(outStream) - - def writeObject[T](t: T): SerializationStream = { - kryo.writeClassAndObject(output, t) - this - } - - def flush() { output.flush() } - def close() { output.close() } -} - -private[spark] -class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream { - val input = new KryoInput(inStream) - - def readObject[T](): T = { - try { - kryo.readClassAndObject(input).asInstanceOf[T] - } catch { - // DeserializationStream uses the EOF exception to indicate stopping condition. - case _: KryoException => throw new EOFException - } - } - - def close() { - // Kryo's Input automatically closes the input stream it is using. - input.close() - } -} - -private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { - val kryo = ks.newKryo() - val output = ks.newKryoOutput() - val input = ks.newKryoInput() - - def serialize[T](t: T): ByteBuffer = { - output.clear() - kryo.writeClassAndObject(output, t) - ByteBuffer.wrap(output.toBytes) - } - - def deserialize[T](bytes: ByteBuffer): T = { - input.setBuffer(bytes.array) - kryo.readClassAndObject(input).asInstanceOf[T] - } - - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { - val oldClassLoader = kryo.getClassLoader - kryo.setClassLoader(loader) - input.setBuffer(bytes.array) - val obj = kryo.readClassAndObject(input).asInstanceOf[T] - kryo.setClassLoader(oldClassLoader) - obj - } - - def serializeStream(s: OutputStream): SerializationStream = { - new KryoSerializationStream(kryo, s) - } - - def deserializeStream(s: InputStream): DeserializationStream = { - new KryoDeserializationStream(kryo, s) - } -} - -/** - * Interface implemented by clients to register their classes with Kryo when using Kryo - * serialization. - */ -trait KryoRegistrator { - def registerClasses(kryo: Kryo) -} - -/** - * A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]]. 
- */ -class KryoSerializer extends spark.serializer.Serializer with Logging { - private val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024 - - def newKryoOutput() = new KryoOutput(bufferSize) - - def newKryoInput() = new KryoInput(bufferSize) - - def newKryo(): Kryo = { - val instantiator = new ScalaKryoInstantiator - val kryo = instantiator.newKryo() - val classLoader = Thread.currentThread.getContextClassLoader - - // Register some commonly used classes - val toRegister: Seq[AnyRef] = Seq( - ByteBuffer.allocate(1), - StorageLevel.MEMORY_ONLY, - PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY), - GotBlock("1", ByteBuffer.allocate(1)), - GetBlock("1") - ) - - for (obj <- toRegister) kryo.register(obj.getClass) - - // Allow sending SerializableWritable - kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) - kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) - - // Allow the user to register their own classes by setting spark.kryo.registrator - try { - Option(System.getProperty("spark.kryo.registrator")).foreach { regCls => - logDebug("Running user registrator: " + regCls) - val reg = Class.forName(regCls, true, classLoader).newInstance().asInstanceOf[KryoRegistrator] - reg.registerClasses(kryo) - } - } catch { - case _: Exception => println("Failed to register spark.kryo.registrator") - } - - kryo.setClassLoader(classLoader) - - // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops - kryo.setReferences(System.getProperty("spark.kryo.referenceTracking", "true").toBoolean) - - kryo - } - - def newInstance(): SerializerInstance = { - new KryoSerializerInstance(this) - } -} \ No newline at end of file diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala deleted file mode 100644 index 79b0362830..0000000000 --- a/core/src/main/scala/spark/Logging.scala +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.slf4j.Logger -import org.slf4j.LoggerFactory - -/** - * Utility trait for classes that want to log data. Creates a SLF4J logger for the class and allows - * logging messages at different levels using methods that only evaluate parameters lazily if the - * log level is enabled. 
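To show the registrator hook removed above in use: a hypothetical user type and registrator, wired up through system properties (the serializer class itself is assumed to be selected elsewhere via the spark.serializer property).

    case class MyRecord(id: Int, payload: Array[Byte])          // hypothetical user type

    class MyRegistrator extends KryoRegistrator {
      override def registerClasses(kryo: com.esotericsoftware.kryo.Kryo) {
        kryo.register(classOf[MyRecord])
      }
    }

    System.setProperty("spark.serializer", "spark.KryoSerializer")
    System.setProperty("spark.kryo.registrator", "MyRegistrator")
    System.setProperty("spark.kryoserializer.buffer.mb", "8")   // raise the 2 MB default for large objects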
- */ -trait Logging { - // Make the log field transient so that objects with Logging can - // be serialized and used on another machine - @transient private var log_ : Logger = null - - // Method to get or create the logger for this object - protected def log: Logger = { - if (log_ == null) { - var className = this.getClass.getName - // Ignore trailing $'s in the class names for Scala objects - if (className.endsWith("$")) { - className = className.substring(0, className.length - 1) - } - log_ = LoggerFactory.getLogger(className) - } - return log_ - } - - // Log methods that take only a String - protected def logInfo(msg: => String) { - if (log.isInfoEnabled) log.info(msg) - } - - protected def logDebug(msg: => String) { - if (log.isDebugEnabled) log.debug(msg) - } - - protected def logTrace(msg: => String) { - if (log.isTraceEnabled) log.trace(msg) - } - - protected def logWarning(msg: => String) { - if (log.isWarnEnabled) log.warn(msg) - } - - protected def logError(msg: => String) { - if (log.isErrorEnabled) log.error(msg) - } - - // Log methods that take Throwables (Exceptions/Errors) too - protected def logInfo(msg: => String, throwable: Throwable) { - if (log.isInfoEnabled) log.info(msg, throwable) - } - - protected def logDebug(msg: => String, throwable: Throwable) { - if (log.isDebugEnabled) log.debug(msg, throwable) - } - - protected def logTrace(msg: => String, throwable: Throwable) { - if (log.isTraceEnabled) log.trace(msg, throwable) - } - - protected def logWarning(msg: => String, throwable: Throwable) { - if (log.isWarnEnabled) log.warn(msg, throwable) - } - - protected def logError(msg: => String, throwable: Throwable) { - if (log.isErrorEnabled) log.error(msg, throwable) - } - - protected def isTraceEnabled(): Boolean = { - log.isTraceEnabled - } - - // Method for ensuring that logging is initialized, to avoid having multiple - // threads do it concurrently (as SLF4J initialization is not thread safe). - protected def initLogging() { log } -} diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala deleted file mode 100644 index 0cd0341a72..0000000000 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ /dev/null @@ -1,338 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
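A sketch of mixing in the Logging trait above; because `msg` is by-name, the string building in logDebug only happens when DEBUG is actually enabled.

    class ShuffleFetchMonitor extends Logging {                 // hypothetical class
      def report(bytes: Long) {
        logInfo("fetch finished")
        logDebug("fetched " + bytes + " bytes: " + expensiveBreakdown())
      }
      private def expensiveBreakdown(): String = "..."          // stands in for costly formatting
    }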
- */ - -package spark - -import java.io._ -import java.util.zip.{GZIPInputStream, GZIPOutputStream} - -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet - -import akka.actor._ -import akka.dispatch._ -import akka.pattern.ask -import akka.remote._ -import akka.util.Duration - - -import spark.scheduler.MapStatus -import spark.storage.BlockManagerId -import spark.util.{MetadataCleaner, TimeStampedHashMap} - - -private[spark] sealed trait MapOutputTrackerMessage -private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String) - extends MapOutputTrackerMessage -private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage - -private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging { - def receive = { - case GetMapOutputStatuses(shuffleId: Int, requester: String) => - logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester) - sender ! tracker.getSerializedLocations(shuffleId) - - case StopMapOutputTracker => - logInfo("MapOutputTrackerActor stopped!") - sender ! true - context.stop(self) - } -} - -private[spark] class MapOutputTracker extends Logging { - - private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") - - // Set to the MapOutputTrackerActor living on the driver - var trackerActor: ActorRef = _ - - private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] - - // Incremented every time a fetch fails so that client nodes know to clear - // their cache of map output locations if this happens. - private var epoch: Long = 0 - private val epochLock = new java.lang.Object - - // Cache a serialized version of the output statuses for each shuffle to send them out faster - var cacheEpoch = epoch - private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] - - val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup) - - // Send a message to the trackerActor and get its result within a default timeout, or - // throw a SparkException if this fails. - def askTracker(message: Any): Any = { - try { - val future = trackerActor.ask(message)(timeout) - return Await.result(future, timeout) - } catch { - case e: Exception => - throw new SparkException("Error communicating with MapOutputTracker", e) - } - } - - // Send a one-way message to the trackerActor, to which we expect it to reply with true. 
- def communicate(message: Any) { - if (askTracker(message) != true) { - throw new SparkException("Error reply received from MapOutputTracker") - } - } - - def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) { - throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") - } - } - - def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { - var array = mapStatuses(shuffleId) - array.synchronized { - array(mapId) = status - } - } - - def registerMapOutputs( - shuffleId: Int, - statuses: Array[MapStatus], - changeEpoch: Boolean = false) { - mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses) - if (changeEpoch) { - incrementEpoch() - } - } - - def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - var arrayOpt = mapStatuses.get(shuffleId) - if (arrayOpt.isDefined && arrayOpt.get != null) { - var array = arrayOpt.get - array.synchronized { - if (array(mapId) != null && array(mapId).location == bmAddress) { - array(mapId) = null - } - } - incrementEpoch() - } else { - throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") - } - } - - // Remembers which map output locations are currently being fetched on a worker - private val fetching = new HashSet[Int] - - // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle - def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { - val statuses = mapStatuses.get(shuffleId).orNull - if (statuses == null) { - logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") - var fetchedStatuses: Array[MapStatus] = null - fetching.synchronized { - if (fetching.contains(shuffleId)) { - // Someone else is fetching it; wait for them to be done - while (fetching.contains(shuffleId)) { - try { - fetching.wait() - } catch { - case e: InterruptedException => - } - } - } - - // Either while we waited the fetch happened successfully, or - // someone fetched it in between the get and the fetching.synchronized. - fetchedStatuses = mapStatuses.get(shuffleId).orNull - if (fetchedStatuses == null) { - // We have to do the fetch, get others to wait for us. 
- fetching += shuffleId - } - } - - if (fetchedStatuses == null) { - // We won the race to fetch the output locs; do so - logInfo("Doing the fetch; tracker actor = " + trackerActor) - val hostPort = Utils.localHostPort() - // This try-finally prevents hangs due to timeouts: - try { - val fetchedBytes = - askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]] - fetchedStatuses = deserializeStatuses(fetchedBytes) - logInfo("Got the output locations") - mapStatuses.put(shuffleId, fetchedStatuses) - } finally { - fetching.synchronized { - fetching -= shuffleId - fetching.notifyAll() - } - } - } - if (fetchedStatuses != null) { - fetchedStatuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) - } - } - else{ - throw new FetchFailedException(null, shuffleId, -1, reduceId, - new Exception("Missing all output locations for shuffle " + shuffleId)) - } - } else { - statuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) - } - } - } - - private def cleanup(cleanupTime: Long) { - mapStatuses.clearOldValues(cleanupTime) - cachedSerializedStatuses.clearOldValues(cleanupTime) - } - - def stop() { - communicate(StopMapOutputTracker) - mapStatuses.clear() - metadataCleaner.cancel() - trackerActor = null - } - - // Called on master to increment the epoch number - def incrementEpoch() { - epochLock.synchronized { - epoch += 1 - logDebug("Increasing epoch to " + epoch) - } - } - - // Called on master or workers to get current epoch number - def getEpoch: Long = { - epochLock.synchronized { - return epoch - } - } - - // Called on workers to update the epoch number, potentially clearing old outputs - // because of a fetch failure. (Each worker task calls this with the latest epoch - // number on the master at the time it was created.) - def updateEpoch(newEpoch: Long) { - epochLock.synchronized { - if (newEpoch > epoch) { - logInfo("Updating epoch to " + newEpoch + " and clearing cache") - // mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] - mapStatuses.clear() - epoch = newEpoch - } - } - } - - def getSerializedLocations(shuffleId: Int): Array[Byte] = { - var statuses: Array[MapStatus] = null - var epochGotten: Long = -1 - epochLock.synchronized { - if (epoch > cacheEpoch) { - cachedSerializedStatuses.clear() - cacheEpoch = epoch - } - cachedSerializedStatuses.get(shuffleId) match { - case Some(bytes) => - return bytes - case None => - statuses = mapStatuses(shuffleId) - epochGotten = epoch - } - } - // If we got here, we failed to find the serialized locations in the cache, so we pulled - // out a snapshot of the locations as "locs"; let's serialize and return that - val bytes = serializeStatuses(statuses) - logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) - // Add them into the table only if the epoch hasn't changed while we were working - epochLock.synchronized { - if (epoch == epochGotten) { - cachedSerializedStatuses(shuffleId) = bytes - } - } - return bytes - } - - // Serialize an array of map output locations into an efficient byte format so that we can send - // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will - // generally be pretty compressible because many map outputs will be on the same hostname. 
- private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = { - val out = new ByteArrayOutputStream - val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) - // Since statuses can be modified in parallel, sync on it - statuses.synchronized { - objOut.writeObject(statuses) - } - objOut.close() - out.toByteArray - } - - // Opposite of serializeStatuses. - def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = { - val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes))) - objIn.readObject(). - // // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present - // comment this out - nulls could be due to missing location ? - asInstanceOf[Array[MapStatus]] // .filter( _ != null ) - } -} - -private[spark] object MapOutputTracker { - private val LOG_BASE = 1.1 - - // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If - // any of the statuses is null (indicating a missing location due to a failed mapper), - // throw a FetchFailedException. - private def convertMapStatuses( - shuffleId: Int, - reduceId: Int, - statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { - assert (statuses != null) - statuses.map { - status => - if (status == null) { - throw new FetchFailedException(null, shuffleId, -1, reduceId, - new Exception("Missing an output location for shuffle " + shuffleId)) - } else { - (status.location, decompressSize(status.compressedSizes(reduceId))) - } - } - } - - /** - * Compress a size in bytes to 8 bits for efficient reporting of map output sizes. - * We do this by encoding the log base 1.1 of the size as an integer, which can support - * sizes up to 35 GB with at most 10% error. - */ - def compressSize(size: Long): Byte = { - if (size == 0) { - 0 - } else if (size <= 1L) { - 1 - } else { - math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte - } - } - - /** - * Decompress an 8-bit encoded block size, using the reverse operation of compressSize. - */ - def decompressSize(compressedSize: Byte): Long = { - if (compressedSize == 0) { - 0 - } else { - math.pow(LOG_BASE, (compressedSize & 0xFF)).toLong - } - } -} diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala deleted file mode 100644 index cc1285dd95..0000000000 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ /dev/null @@ -1,703 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
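The one-byte map-output size encoding above (LOG_BASE = 1.1), reproduced as a standalone sketch to show the accuracy trade-off:

    def compress(size: Long): Byte =
      if (size == 0) 0
      else if (size <= 1L) 1
      else math.min(255, math.ceil(math.log(size) / math.log(1.1)).toInt).toByte

    def decompress(b: Byte): Long =
      if (b == 0) 0L else math.pow(1.1, b & 0xFF).toLong

    val size   = 1L << 30                      // a 1 GB map output
    val approx = decompress(compress(size))    // ~1.17e9: within 10% of the true size, biased up by ceil()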
- */ - -package spark - -import java.nio.ByteBuffer -import java.util.{Date, HashMap => JHashMap} -import java.text.SimpleDateFormat - -import scala.collection.{mutable, Map} -import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.hadoop.io.SequenceFile.CompressionType -import org.apache.hadoop.mapred.FileOutputCommitter -import org.apache.hadoop.mapred.FileOutputFormat -import org.apache.hadoop.mapred.SparkHadoopWriter -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.OutputFormat - -import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} -import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, - RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, SparkHadoopMapReduceUtil} -import org.apache.hadoop.security.UserGroupInformation - -import spark.partial.BoundedDouble -import spark.partial.PartialResult -import spark.rdd._ -import spark.SparkContext._ -import spark.Partitioner._ - -/** - * Extra functions available on RDDs of (key, value) pairs through an implicit conversion. - * Import `spark.SparkContext._` at the top of your program to use these functions. - */ -class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) - extends Logging - with SparkHadoopMapReduceUtil - with Serializable { - - /** - * Generic function to combine the elements for each key using a custom set of aggregation - * functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C - * Note that V and C can be different -- for example, one might group an RDD of type - * (Int, Int) into an RDD of type (Int, Seq[Int]). Users provide three functions: - * - * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) - * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) - * - `mergeCombiners`, to combine two C's into a single one. - * - * In addition, users can control the partitioning of the output RDD, and whether to perform - * map-side aggregation (if a mapper can produce multiple items with the same key). - */ - def combineByKey[C](createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiners: (C, C) => C, - partitioner: Partitioner, - mapSideCombine: Boolean = true, - serializerClass: String = null): RDD[(K, C)] = { - if (getKeyClass().isArray) { - if (mapSideCombine) { - throw new SparkException("Cannot use map-side combining with array keys.") - } - if (partitioner.isInstanceOf[HashPartitioner]) { - throw new SparkException("Default partitioner cannot partition array keys.") - } - } - val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) - if (self.partitioner == Some(partitioner)) { - self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) - } else if (mapSideCombine) { - val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) - val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner) - .setSerializer(serializerClass) - partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true) - } else { - // Don't apply map-side combiner. - // A sanity check to make sure mergeCombiners is not defined. 
- assert(mergeCombiners == null) - val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass) - values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) - } - } - - /** - * Simplified version of combineByKey that hash-partitions the output RDD. - */ - def combineByKey[C](createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiners: (C, C) => C, - numPartitions: Int): RDD[(K, C)] = { - combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions)) - } - - /** - * Merge the values for each key using an associative function and a neutral "zero value" which may - * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for - * list concatenation, 0 for addition, or 1 for multiplication.). - */ - def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = { - // Serialize the zero value to a byte array so that we can get a new clone of it on each key - val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue) - val zeroArray = new Array[Byte](zeroBuffer.limit) - zeroBuffer.get(zeroArray) - - // When deserializing, use a lazy val to create just one instance of the serializer per task - lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance() - def createZero() = cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray)) - - combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner) - } - - /** - * Merge the values for each key using an associative function and a neutral "zero value" which may - * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for - * list concatenation, 0 for addition, or 1 for multiplication.). - */ - def foldByKey(zeroValue: V, numPartitions: Int)(func: (V, V) => V): RDD[(K, V)] = { - foldByKey(zeroValue, new HashPartitioner(numPartitions))(func) - } - - /** - * Merge the values for each key using an associative function and a neutral "zero value" which may - * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for - * list concatenation, 0 for addition, or 1 for multiplication.). - */ - def foldByKey(zeroValue: V)(func: (V, V) => V): RDD[(K, V)] = { - foldByKey(zeroValue, defaultPartitioner(self))(func) - } - - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. - */ - def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = { - combineByKey[V]((v: V) => v, func, func, partitioner) - } - - /** - * Merge the values for each key using an associative reduce function, but return the results - * immediately to the master as a Map. This will also perform the merging locally on each mapper - * before sending results to a reducer, similarly to a "combiner" in MapReduce. 
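A representative use of the combineByKey primitive above: a per-key mean with (sum, count) as the combiner type C. `pairs` is a hypothetical RDD[(String, Double)] with the SparkContext._ implicits in scope.

    val sumCounts = pairs.combineByKey[(Double, Long)](
      (v: Double) => (v, 1L),                                               // createCombiner
      (c: (Double, Long), v: Double) => (c._1 + v, c._2 + 1),               // mergeValue
      (a: (Double, Long), b: (Double, Long)) => (a._1 + b._1, a._2 + b._2), // mergeCombiners
      new HashPartitioner(4))                                               // shuffle into 4 partitions
    val meanByKey = sumCounts.mapValues { case (sum, n) => sum / n }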
- */ - def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = { - - if (getKeyClass().isArray) { - throw new SparkException("reduceByKeyLocally() does not support array keys") - } - - def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = { - val map = new JHashMap[K, V] - iter.foreach { case (k, v) => - val old = map.get(k) - map.put(k, if (old == null) v else func(old, v)) - } - Iterator(map) - } - - def mergeMaps(m1: JHashMap[K, V], m2: JHashMap[K, V]): JHashMap[K, V] = { - m2.foreach { case (k, v) => - val old = m1.get(k) - m1.put(k, if (old == null) v else func(old, v)) - } - m1 - } - - self.mapPartitions(reducePartition).reduce(mergeMaps) - } - - /** Alias for reduceByKeyLocally */ - def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func) - - /** Count the number of elements for each key, and return the result to the master as a Map. */ - def countByKey(): Map[K, Long] = self.map(_._1).countByValue() - - /** - * (Experimental) Approximate version of countByKey that can return a partial result if it does - * not finish within a timeout. - */ - def countByKeyApprox(timeout: Long, confidence: Double = 0.95) - : PartialResult[Map[K, BoundedDouble]] = { - self.map(_._1).countByValueApprox(timeout, confidence) - } - - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. - */ - def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = { - reduceByKey(new HashPartitioner(numPartitions), func) - } - - /** - * Group the values for each key in the RDD into a single sequence. Allows controlling the - * partitioning of the resulting key-value pair RDD by passing a Partitioner. - */ - def groupByKey(partitioner: Partitioner): RDD[(K, Seq[V])] = { - // groupByKey shouldn't use map side combine because map side combine does not - // reduce the amount of data shuffled and requires all map side data be inserted - // into a hash table, leading to more objects in the old gen. - def createCombiner(v: V) = ArrayBuffer(v) - def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v - val bufs = combineByKey[ArrayBuffer[V]]( - createCombiner _, mergeValue _, null, partitioner, mapSideCombine=false) - bufs.asInstanceOf[RDD[(K, Seq[V])]] - } - - /** - * Group the values for each key in the RDD into a single sequence. Hash-partitions the - * resulting RDD with into `numPartitions` partitions. - */ - def groupByKey(numPartitions: Int): RDD[(K, Seq[V])] = { - groupByKey(new HashPartitioner(numPartitions)) - } - - /** - * Return a copy of the RDD partitioned using the specified partitioner. - */ - def partitionBy(partitioner: Partitioner): RDD[(K, V)] = { - if (getKeyClass().isArray && partitioner.isInstanceOf[HashPartitioner]) { - throw new SparkException("Default partitioner cannot partition array keys.") - } - new ShuffledRDD[K, V, (K, V)](self, partitioner) - } - - /** - * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each - * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and - * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD. 
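The higher-level aggregations above in their usual word-count shape (`words` is a hypothetical RDD[String]):

    val counts = words.map(w => (w, 1)).reduceByKey(_ + _)       // combines map-side before the shuffle
    val byLen  = words.map(w => (w.length, w)).groupByKey(4)     // no map-side combine, 4 hash partitions
    val totals = words.map(w => (w, ())).countByKey()            // Map[String, Long], collected to the driver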
- */ - def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = { - this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => - for (v <- vs.iterator; w <- ws.iterator) yield (v, w) - } - } - - /** - * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the - * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the - * pair (k, (v, None)) if no elements in `other` have key k. Uses the given Partitioner to - * partition the output RDD. - */ - def leftOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, Option[W]))] = { - this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => - if (ws.isEmpty) { - vs.iterator.map(v => (v, None)) - } else { - for (v <- vs.iterator; w <- ws.iterator) yield (v, Some(w)) - } - } - } - - /** - * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the - * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the - * pair (k, (None, w)) if no elements in `this` have key k. Uses the given Partitioner to - * partition the output RDD. - */ - def rightOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner) - : RDD[(K, (Option[V], W))] = { - this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => - if (vs.isEmpty) { - ws.iterator.map(w => (None, w)) - } else { - for (v <- vs.iterator; w <- ws.iterator) yield (Some(v), w) - } - } - } - - /** - * Simplified version of combineByKey that hash-partitions the resulting RDD using the - * existing partitioner/parallelism level. - */ - def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) - : RDD[(K, C)] = { - combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) - } - - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ - * parallelism level. - */ - def reduceByKey(func: (V, V) => V): RDD[(K, V)] = { - reduceByKey(defaultPartitioner(self), func) - } - - /** - * Group the values for each key in the RDD into a single sequence. Hash-partitions the - * resulting RDD with the existing partitioner/parallelism level. - */ - def groupByKey(): RDD[(K, Seq[V])] = { - groupByKey(defaultPartitioner(self)) - } - - /** - * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each - * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and - * (k, v2) is in `other`. Performs a hash join across the cluster. - */ - def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = { - join(other, defaultPartitioner(self, other)) - } - - /** - * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each - * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and - * (k, v2) is in `other`. Performs a hash join across the cluster. - */ - def join[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, W))] = { - join(other, new HashPartitioner(numPartitions)) - } - - /** - * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the - * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the - * pair (k, (v, None)) if no elements in `other` have key k. 
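A small join sketch over the methods above, on hypothetical data (`sc` is a SparkContext):

    val users  = sc.parallelize(Seq((1, "ada"), (2, "grace")))
    val orders = sc.parallelize(Seq((1, 9.99)))

    val inner = users.join(orders)            // (1, ("ada", 9.99)) only
    val left  = users.leftOuterJoin(orders)   // also keeps (2, ("grace", None))
    val right = orders.rightOuterJoin(users)  // keeps every user, with an Option on the order side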
Hash-partitions the output - * using the existing partitioner/parallelism level. - */ - def leftOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (V, Option[W]))] = { - leftOuterJoin(other, defaultPartitioner(self, other)) - } - - /** - * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the - * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the - * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output - * into `numPartitions` partitions. - */ - def leftOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, Option[W]))] = { - leftOuterJoin(other, new HashPartitioner(numPartitions)) - } - - /** - * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the - * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the - * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting - * RDD using the existing partitioner/parallelism level. - */ - def rightOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], W))] = { - rightOuterJoin(other, defaultPartitioner(self, other)) - } - - /** - * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the - * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the - * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting - * RDD into the given number of partitions. - */ - def rightOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Option[V], W))] = { - rightOuterJoin(other, new HashPartitioner(numPartitions)) - } - - /** - * Return the key-value pairs in this RDD to the master as a Map. - */ - def collectAsMap(): Map[K, V] = { - val data = self.toArray() - val map = new mutable.HashMap[K, V] - map.sizeHint(data.length) - data.foreach { case (k, v) => map.put(k, v) } - map - } - - /** - * Pass each value in the key-value pair RDD through a map function without changing the keys; - * this also retains the original RDD's partitioning. - */ - def mapValues[U](f: V => U): RDD[(K, U)] = { - val cleanF = self.context.clean(f) - new MappedValuesRDD(self, cleanF) - } - - /** - * Pass each value in the key-value pair RDD through a flatMap function without changing the - * keys; this also retains the original RDD's partitioning. - */ - def flatMapValues[U](f: V => TraversableOnce[U]): RDD[(K, U)] = { - val cleanF = self.context.clean(f) - new FlatMappedValuesRDD(self, cleanF) - } - - /** - * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the - * list of values for that key in `this` as well as `other`. - */ - def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (Seq[V], Seq[W]))] = { - if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) { - throw new SparkException("Default partitioner cannot partition array keys.") - } - val cg = new CoGroupedRDD[K](Seq(self, other), partitioner) - val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest) - prfs.mapValues { case Seq(vs, ws) => - (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]]) - } - } - - /** - * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a - * tuple with the list of values for that key in `this`, `other1` and `other2`. 
- */ - def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner) - : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { - if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) { - throw new SparkException("Default partitioner cannot partition array keys.") - } - val cg = new CoGroupedRDD[K](Seq(self, other1, other2), partitioner) - val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest) - prfs.mapValues { case Seq(vs, w1s, w2s) => - (vs.asInstanceOf[Seq[V]], w1s.asInstanceOf[Seq[W1]], w2s.asInstanceOf[Seq[W2]]) - } - } - - /** - * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the - * list of values for that key in `this` as well as `other`. - */ - def cogroup[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = { - cogroup(other, defaultPartitioner(self, other)) - } - - /** - * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a - * tuple with the list of values for that key in `this`, `other1` and `other2`. - */ - def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)]) - : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { - cogroup(other1, other2, defaultPartitioner(self, other1, other2)) - } - - /** - * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the - * list of values for that key in `this` as well as `other`. - */ - def cogroup[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Seq[V], Seq[W]))] = { - cogroup(other, new HashPartitioner(numPartitions)) - } - - /** - * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a - * tuple with the list of values for that key in `this`, `other1` and `other2`. - */ - def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], numPartitions: Int) - : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { - cogroup(other1, other2, new HashPartitioner(numPartitions)) - } - - /** Alias for cogroup. */ - def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = { - cogroup(other, defaultPartitioner(self, other)) - } - - /** Alias for cogroup. */ - def groupWith[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)]) - : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { - cogroup(other1, other2, defaultPartitioner(self, other1, other2)) - } - - /** - * Return an RDD with the pairs from `this` whose keys are not in `other`. - * - * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. - */ - def subtractByKey[W: ClassManifest](other: RDD[(K, W)]): RDD[(K, V)] = - subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size))) - - /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ - def subtractByKey[W: ClassManifest](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] = - subtractByKey(other, new HashPartitioner(numPartitions)) - - /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ - def subtractByKey[W: ClassManifest](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] = - new SubtractedRDD[K, V, W](self, other, p) - - /** - * Return the list of values in the RDD for key `key`. This operation is done efficiently if the - * RDD has a known partitioner by only searching the partition that the key maps to. 
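Continuing the same hypothetical `users`/`orders` pairs for the cogroup family above:

    val grouped = users.cogroup(orders)        // (K, (Seq[String], Seq[Double])): all values from both sides
    val noOrder = users.subtractByKey(orders)  // users whose key never appears in `orders`
    val first: Seq[Double] = orders.lookup(1)  // cheap if `orders` has a known partitioner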
- */ - def lookup(key: K): Seq[V] = { - self.partitioner match { - case Some(p) => - val index = p.getPartition(key) - def process(it: Iterator[(K, V)]): Seq[V] = { - val buf = new ArrayBuffer[V] - for ((k, v) <- it if k == key) { - buf += v - } - buf - } - val res = self.context.runJob(self, process _, Array(index), false) - res(0) - case None => - self.filter(_._1 == key).map(_._2).collect() - } - } - - /** - * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class - * supporting the key and value types K and V in this RDD. - */ - def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) { - saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) - } - - /** - * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class - * supporting the key and value types K and V in this RDD. Compress the result with the - * supplied codec. - */ - def saveAsHadoopFile[F <: OutputFormat[K, V]]( - path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassManifest[F]) { - saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]], codec) - } - - /** - * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` - * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. - */ - def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) { - saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) - } - - /** - * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` - * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. - */ - def saveAsNewAPIHadoopFile( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: NewOutputFormat[_, _]], - conf: Configuration = self.context.hadoopConfiguration) { - val job = new NewAPIHadoopJob(conf) - job.setOutputKeyClass(keyClass) - job.setOutputValueClass(valueClass) - val wrappedConf = new SerializableWritable(job.getConfiguration) - NewFileOutputFormat.setOutputPath(job, new Path(path)) - val formatter = new SimpleDateFormat("yyyyMMddHHmm") - val jobtrackerID = formatter.format(new Date()) - val stageId = self.id - def writeShard(context: spark.TaskContext, iter: Iterator[(K,V)]): Int = { - // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it - // around by taking a mod. We expect that no task will be attempted 2 billion times. - val attemptNumber = (context.attemptId % Int.MaxValue).toInt - /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber) - val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) - val format = outputFormatClass.newInstance - val committer = format.getOutputCommitter(hadoopContext) - committer.setupTask(hadoopContext) - val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] - while (iter.hasNext) { - val (k, v) = iter.next - writer.write(k, v) - } - writer.close(hadoopContext) - committer.commitTask(hadoopContext) - return 1 - } - val jobFormat = outputFormatClass.newInstance - /* apparently we need a TaskAttemptID to construct an OutputCommitter; - * however we're only going to use this local OutputCommitter for - * setupJob/commitJob, so we just use a dummy "map" task. 
- */ - val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, true, 0, 0) - val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) - val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) - jobCommitter.setupJob(jobTaskContext) - val count = self.context.runJob(self, writeShard _).sum - jobCommitter.commitJob(jobTaskContext) - jobCommitter.cleanupJob(jobTaskContext) - } - - /** - * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class - * supporting the key and value types K and V in this RDD. Compress with the supplied codec. - */ - def saveAsHadoopFile( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: OutputFormat[_, _]], - codec: Class[_ <: CompressionCodec]) { - saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, - new JobConf(self.context.hadoopConfiguration), Some(codec)) - } - - /** - * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class - * supporting the key and value types K and V in this RDD. - */ - def saveAsHadoopFile( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: OutputFormat[_, _]], - conf: JobConf = new JobConf(self.context.hadoopConfiguration), - codec: Option[Class[_ <: CompressionCodec]] = None) { - conf.setOutputKeyClass(keyClass) - conf.setOutputValueClass(valueClass) - // conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug - conf.set("mapred.output.format.class", outputFormatClass.getName) - for (c <- codec) { - conf.setCompressMapOutput(true) - conf.set("mapred.output.compress", "true") - conf.setMapOutputCompressorClass(c) - conf.set("mapred.output.compression.codec", c.getCanonicalName) - conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) - } - conf.setOutputCommitter(classOf[FileOutputCommitter]) - FileOutputFormat.setOutputPath(conf, SparkHadoopWriter.createPathFromString(path, conf)) - saveAsHadoopDataset(conf) - } - - /** - * Output the RDD to any Hadoop-supported storage system, using a Hadoop JobConf object for - * that storage system. The JobConf should set an OutputFormat and any output paths required - * (e.g. a table name to write to) in the same way as it would be configured for a Hadoop - * MapReduce job. - */ - def saveAsHadoopDataset(conf: JobConf) { - val outputFormatClass = conf.getOutputFormat - val keyClass = conf.getOutputKeyClass - val valueClass = conf.getOutputValueClass - if (outputFormatClass == null) { - throw new SparkException("Output format class not set") - } - if (keyClass == null) { - throw new SparkException("Output key class not set") - } - if (valueClass == null) { - throw new SparkException("Output value class not set") - } - - logInfo("Saving as hadoop file of type (" + keyClass.getSimpleName+ ", " + valueClass.getSimpleName+ ")") - - val writer = new SparkHadoopWriter(conf) - writer.preSetup() - - def writeToFile(context: TaskContext, iter: Iterator[(K, V)]) { - // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it - // around by taking a mod. We expect that no task will be attempted 2 billion times. 
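A save-path sketch for the old Hadoop API methods above, writing pairs as text (output path hypothetical):

    import org.apache.hadoop.mapred.TextOutputFormat

    val counts = sc.parallelize(Seq(("a", 1), ("b", 2)))
    counts.saveAsHadoopFile[TextOutputFormat[String, Int]]("/tmp/counts-out")
    // Long form, when the JobConf needs customizing:
    //   counts.saveAsHadoopFile("/tmp/counts-out", classOf[String], classOf[Int],
    //                           classOf[TextOutputFormat[String, Int]])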
- val attemptNumber = (context.attemptId % Int.MaxValue).toInt - - writer.setup(context.stageId, context.splitId, attemptNumber) - writer.open() - - var count = 0 - while(iter.hasNext) { - val record = iter.next() - count += 1 - writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) - } - - writer.close() - writer.commit() - } - - self.context.runJob(self, writeToFile _) - writer.commitJob() - writer.cleanup() - } - - /** - * Return an RDD with the keys of each tuple. - */ - def keys: RDD[K] = self.map(_._1) - - /** - * Return an RDD with the values of each tuple. - */ - def values: RDD[V] = self.map(_._2) - - private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure - - private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure -} - - -private[spark] object Manifests { - val seqSeqManifest = classManifest[Seq[Seq[_]]] -} diff --git a/core/src/main/scala/spark/Partition.scala b/core/src/main/scala/spark/Partition.scala deleted file mode 100644 index 2a4edcec98..0000000000 --- a/core/src/main/scala/spark/Partition.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -/** - * A partition of an RDD. - */ -trait Partition extends Serializable { - /** - * Get the split's index within its parent RDD - */ - def index: Int - - // A better default implementation of HashCode - override def hashCode(): Int = index -} diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala deleted file mode 100644 index 65da8235d7..0000000000 --- a/core/src/main/scala/spark/Partitioner.scala +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -/** - * An object that defines how the elements in a key-value pair RDD are partitioned by key. - * Maps each key to a partition ID, from 0 to `numPartitions - 1`. 
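A minimal sketch of that contract (not part of this patch), using the Partitioner abstract class defined just below: a custom partitioner only has to report a partition count and map any key into [0, numPartitions). This toy example routes keys with an even hashCode to partition 0 and everything else to partition 1.

    class EvenOddPartitioner extends Partitioner {
      override def numPartitions: Int = 2
      override def getPartition(key: Any): Int =
        if (key != null && key.hashCode % 2 == 0) 0 else 1
    }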
- */
-abstract class Partitioner extends Serializable {
-  def numPartitions: Int
-  def getPartition(key: Any): Int
-}
-
-object Partitioner {
-  /**
-   * Choose a partitioner to use for a cogroup-like operation between a number of RDDs.
-   *
-   * If any of the RDDs already has a partitioner, choose that one.
-   *
-   * Otherwise, we use a default HashPartitioner. For the number of partitions, if
-   * spark.default.parallelism is set, then we'll use the value from SparkContext
-   * defaultParallelism, otherwise we'll use the max number of upstream partitions.
-   *
-   * Unless spark.default.parallelism is set, the number of partitions will be the
-   * same as the number of partitions in the largest upstream RDD, as this should
-   * be least likely to cause out-of-memory errors.
-   *
-   * We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD.
-   */
-  def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = {
-    val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.size).reverse
-    for (r <- bySize if r.partitioner != None) {
-      return r.partitioner.get
-    }
-    if (System.getProperty("spark.default.parallelism") != null) {
-      return new HashPartitioner(rdd.context.defaultParallelism)
-    } else {
-      return new HashPartitioner(bySize.head.partitions.size)
-    }
-  }
-}
-
-/**
- * A [[spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`.
- *
- * Java arrays have hashCodes that are based on the arrays' identities rather than their contents,
- * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will
- * produce an unexpected or incorrect result.
- */
-class HashPartitioner(partitions: Int) extends Partitioner {
-  def numPartitions = partitions
-
-  def getPartition(key: Any): Int = key match {
-    case null => 0
-    case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
-  }
-
-  override def equals(other: Any): Boolean = other match {
-    case h: HashPartitioner =>
-      h.numPartitions == numPartitions
-    case _ =>
-      false
-  }
-}
-
-/**
- * A [[spark.Partitioner]] that partitions sortable records by range into roughly equal ranges.
- * Determines the ranges by sampling the RDD passed in.
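To make the HashPartitioner placement above concrete, here is a small worked sketch (not part of this patch) of the non-negative modulus it relies on, reimplemented locally since Utils.nonNegativeMod lives elsewhere in this package:

    object HashPlacementExample extends App {
      // Same idea as Utils.nonNegativeMod: keep negative hashCodes in [0, mod).
      def nonNegativeMod(x: Int, mod: Int): Int = {
        val r = x % mod
        if (r < 0) r + mod else r
      }
      val numPartitions = 4
      Seq("a", "b", "spark").foreach { key =>
        println(key + " -> partition " + nonNegativeMod(key.hashCode, numPartitions))
      }
    }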
- */ -class RangePartitioner[K <% Ordered[K]: ClassManifest, V]( - partitions: Int, - @transient rdd: RDD[_ <: Product2[K,V]], - private val ascending: Boolean = true) - extends Partitioner { - - // An array of upper bounds for the first (partitions - 1) partitions - private val rangeBounds: Array[K] = { - if (partitions == 1) { - Array() - } else { - val rddSize = rdd.count() - val maxSampleSize = partitions * 20.0 - val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0) - val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sortWith(_ < _) - if (rddSample.length == 0) { - Array() - } else { - val bounds = new Array[K](partitions - 1) - for (i <- 0 until partitions - 1) { - val index = (rddSample.length - 1) * (i + 1) / partitions - bounds(i) = rddSample(index) - } - bounds - } - } - } - - def numPartitions = partitions - - def getPartition(key: Any): Int = { - // TODO: Use a binary search here if number of partitions is large - val k = key.asInstanceOf[K] - var partition = 0 - while (partition < rangeBounds.length && k > rangeBounds(partition)) { - partition += 1 - } - if (ascending) { - partition - } else { - rangeBounds.length - partition - } - } - - override def equals(other: Any): Boolean = other match { - case r: RangePartitioner[_,_] => - r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending - case _ => - false - } -} diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala deleted file mode 100644 index 25a6951732..0000000000 --- a/core/src/main/scala/spark/RDD.scala +++ /dev/null @@ -1,957 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
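Returning to RangePartitioner for a moment, its getPartition scan is easy to trace by hand. A worked sketch (not part of this patch), with assumed bounds Array(10, 20) for three partitions:

    object RangePlacementExample extends App {
      val rangeBounds = Array(10, 20)  // upper bounds for the first two partitions
      def place(k: Int): Int = {
        var partition = 0
        while (partition < rangeBounds.length && k > rangeBounds(partition)) partition += 1
        partition
      }
      // 3 -> 0, 10 -> 0 (bounds are inclusive), 11 -> 1, 25 -> 2
      Seq(3, 10, 11, 25).foreach(k => println(k + " -> partition " + place(k)))
    }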
- */
-
-package spark
-
-import java.util.Random
-
-import scala.collection.Map
-import scala.collection.JavaConversions.mapAsScalaMap
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.hadoop.io.BytesWritable
-import org.apache.hadoop.io.compress.CompressionCodec
-import org.apache.hadoop.io.NullWritable
-import org.apache.hadoop.io.Text
-import org.apache.hadoop.mapred.TextOutputFormat
-
-import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
-
-import spark.Partitioner._
-import spark.api.java.JavaRDD
-import spark.partial.BoundedDouble
-import spark.partial.CountEvaluator
-import spark.partial.GroupedCountEvaluator
-import spark.partial.PartialResult
-import spark.rdd.CoalescedRDD
-import spark.rdd.CartesianRDD
-import spark.rdd.FilteredRDD
-import spark.rdd.FlatMappedRDD
-import spark.rdd.GlommedRDD
-import spark.rdd.MappedRDD
-import spark.rdd.MapPartitionsRDD
-import spark.rdd.MapPartitionsWithIndexRDD
-import spark.rdd.PipedRDD
-import spark.rdd.SampledRDD
-import spark.rdd.ShuffledRDD
-import spark.rdd.UnionRDD
-import spark.rdd.ZippedRDD
-import spark.rdd.ZippedPartitionsRDD2
-import spark.rdd.ZippedPartitionsRDD3
-import spark.rdd.ZippedPartitionsRDD4
-import spark.storage.StorageLevel
-import spark.util.BoundedPriorityQueue
-
-import SparkContext._
-
-/**
- * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
- * partitioned collection of elements that can be operated on in parallel. This class contains the
- * basic operations available on all RDDs, such as `map`, `filter`, and `persist`. In addition,
- * [[spark.PairRDDFunctions]] contains operations available only on RDDs of key-value pairs, such
- * as `groupByKey` and `join`; [[spark.DoubleRDDFunctions]] contains operations available only on
- * RDDs of Doubles; and [[spark.SequenceFileRDDFunctions]] contains operations available on RDDs
- * that can be saved as SequenceFiles. These operations are automatically available on any RDD of
- * the right type (e.g. RDD[(Int, Int)]) through implicit conversions when you
- * `import spark.SparkContext._`.
- *
- * Internally, each RDD is characterized by five main properties:
- *
- *  - A list of partitions
- *  - A function for computing each split
- *  - A list of dependencies on other RDDs
- *  - Optionally, a Partitioner for key-value RDDs (e.g. to say that the RDD is hash-partitioned)
- *  - Optionally, a list of preferred locations to compute each split on (e.g. block locations for
- *    an HDFS file)
- *
- * All of the scheduling and execution in Spark is done based on these methods, allowing each RDD
- * to implement its own way of computing itself. Indeed, users can implement custom RDDs (e.g. for
- * reading data from a new storage system) by overriding these functions. Please refer to the
- * [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details
- * on RDD internals.
- */
-abstract class RDD[T: ClassManifest](
-    @transient private var sc: SparkContext,
-    @transient private var deps: Seq[Dependency[_]]
-  ) extends Serializable with Logging {
-
-  /** Construct an RDD with just a one-to-one dependency on one parent */
-  def this(@transient oneParent: RDD[_]) =
-    this(oneParent.context, List(new OneToOneDependency(oneParent)))
-
-  // =======================================================================
-  // Methods that should be implemented by subclasses of RDD
-  // =======================================================================
-
-  /** Implemented by subclasses to compute a given partition. */
-  def compute(split: Partition, context: TaskContext): Iterator[T]
-
-  /**
-   * Implemented by subclasses to return the set of partitions in this RDD. This method will only
-   * be called once, so it is safe to implement a time-consuming computation in it.
-   */
-  protected def getPartitions: Array[Partition]
-
-  /**
-   * Implemented by subclasses to return how this RDD depends on parent RDDs. This method will only
-   * be called once, so it is safe to implement a time-consuming computation in it.
-   */
-  protected def getDependencies: Seq[Dependency[_]] = deps
-
-  /** Optionally overridden by subclasses to specify placement preferences. */
-  protected def getPreferredLocations(split: Partition): Seq[String] = Nil
-
-  /** Optionally overridden by subclasses to specify how they are partitioned. */
-  val partitioner: Option[Partitioner] = None
-
-  // =======================================================================
-  // Methods and fields available on all RDDs
-  // =======================================================================
-
-  /** The SparkContext that created this RDD. */
-  def sparkContext: SparkContext = sc
-
-  /** A unique ID for this RDD (within its SparkContext). */
-  val id: Int = sc.newRddId()
-
-  /** A friendly name for this RDD */
-  var name: String = null
-
-  /** Assign a name to this RDD */
-  def setName(_name: String) = {
-    name = _name
-    this
-  }
-
-  /** User-defined generator of this RDD */
-  var generator = Utils.getCallSiteInfo.firstUserClass
-
-  /** Reset generator */
-  def setGenerator(_generator: String) = {
-    generator = _generator
-  }
-
-  /**
-   * Set this RDD's storage level to persist its values across operations after the first time
-   * it is computed. This can only be used to assign a new storage level if the RDD does not
-   * have a storage level set yet.
-   */
-  def persist(newLevel: StorageLevel): RDD[T] = {
-    // TODO: Handle changes of StorageLevel
-    if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) {
-      throw new UnsupportedOperationException(
-        "Cannot change storage level of an RDD after it was already assigned a level")
-    }
-    storageLevel = newLevel
-    // Register the RDD with the SparkContext
-    sc.persistentRdds(id) = this
-    this
-  }
-
-  /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
-  def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY)
-
-  /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
-  def cache(): RDD[T] = persist()
-
-  /**
-   * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
-   *
-   * @param blocking Whether to block until all blocks are deleted.
-   * @return This RDD.
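The subclass contract described above is small: compute plus getPartitions. A minimal sketch of a custom RDD (not part of this patch; the class name, partition count, and values are invented for illustration), assuming a live SparkContext named sc:

    class ConstantRDD(sc: SparkContext, value: Int) extends RDD[Int](sc, Nil) {
      // Every partition yields three copies of `value`.
      override def compute(split: Partition, context: TaskContext): Iterator[Int] =
        Iterator.fill(3)(value)
      // Two partitions, each knowing only its index.
      override protected def getPartitions: Array[Partition] =
        Array.tabulate(2)(i => new Partition { override def index: Int = i })
    }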
- */ - def unpersist(blocking: Boolean = true): RDD[T] = { - logInfo("Removing RDD " + id + " from persistence list") - sc.env.blockManager.master.removeRdd(id, blocking) - sc.persistentRdds.remove(id) - storageLevel = StorageLevel.NONE - this - } - - /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ - def getStorageLevel = storageLevel - - // Our dependencies and partitions will be gotten by calling subclass's methods below, and will - // be overwritten when we're checkpointed - private var dependencies_ : Seq[Dependency[_]] = null - @transient private var partitions_ : Array[Partition] = null - - /** An Option holding our checkpoint RDD, if we are checkpointed */ - private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD) - - /** - * Get the list of dependencies of this RDD, taking into account whether the - * RDD is checkpointed or not. - */ - final def dependencies: Seq[Dependency[_]] = { - checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse { - if (dependencies_ == null) { - dependencies_ = getDependencies - } - dependencies_ - } - } - - /** - * Get the array of partitions of this RDD, taking into account whether the - * RDD is checkpointed or not. - */ - final def partitions: Array[Partition] = { - checkpointRDD.map(_.partitions).getOrElse { - if (partitions_ == null) { - partitions_ = getPartitions - } - partitions_ - } - } - - /** - * Get the preferred locations of a partition (as hostnames), taking into account whether the - * RDD is checkpointed. - */ - final def preferredLocations(split: Partition): Seq[String] = { - checkpointRDD.map(_.getPreferredLocations(split)).getOrElse { - getPreferredLocations(split) - } - } - - /** - * Internal method to this RDD; will read from cache if applicable, or otherwise compute it. - * This should ''not'' be called by users directly, but is available for implementors of custom - * subclasses of RDD. - */ - final def iterator(split: Partition, context: TaskContext): Iterator[T] = { - if (storageLevel != StorageLevel.NONE) { - SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel) - } else { - computeOrReadCheckpoint(split, context) - } - } - - /** - * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing. - */ - private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = { - if (isCheckpointed) { - firstParent[T].iterator(split, context) - } else { - compute(split, context) - } - } - - // Transformations (return a new RDD) - - /** - * Return a new RDD by applying a function to all elements of this RDD. - */ - def map[U: ClassManifest](f: T => U): RDD[U] = new MappedRDD(this, sc.clean(f)) - - /** - * Return a new RDD by first applying a function to all elements of this - * RDD, and then flattening the results. - */ - def flatMap[U: ClassManifest](f: T => TraversableOnce[U]): RDD[U] = - new FlatMappedRDD(this, sc.clean(f)) - - /** - * Return a new RDD containing only the elements that satisfy a predicate. - */ - def filter(f: T => Boolean): RDD[T] = new FilteredRDD(this, sc.clean(f)) - - /** - * Return a new RDD containing the distinct elements in this RDD. - */ - def distinct(numPartitions: Int): RDD[T] = - map(x => (x, null)).reduceByKey((x, y) => x, numPartitions).map(_._1) - - def distinct(): RDD[T] = distinct(partitions.size) - - /** - * Return a new RDD that is reduced into `numPartitions` partitions. 
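Note how distinct above is built entirely from the keyed-shuffle primitives. An illustrative sketch (not part of this patch), assuming a live SparkContext named sc and `import spark.SparkContext._` for the pair-RDD implicits:

    val data = sc.parallelize(Seq(1, 2, 2, 3, 3, 3))
    val viaDistinct = data.distinct()
    // The same trick by hand: key each element, collapse duplicate keys, drop the dummy value.
    val byHand = data.map(x => (x, null)).reduceByKey((x, y) => x).map(_._1)
    // Both collect() to Array(1, 2, 3), up to ordering.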
- */
-  def coalesce(numPartitions: Int, shuffle: Boolean = false): RDD[T] = {
-    if (shuffle) {
-      // include a shuffle step so that our upstream tasks are still distributed
-      new CoalescedRDD(
-        new ShuffledRDD[T, Null, (T, Null)](map(x => (x, null)),
-        new HashPartitioner(numPartitions)),
-        numPartitions).keys
-    } else {
-      new CoalescedRDD(this, numPartitions)
-    }
-  }
-
-  /**
-   * Return a sampled subset of this RDD.
-   */
-  def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] =
-    new SampledRDD(this, withReplacement, fraction, seed)
-
-  def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
-    var fraction = 0.0
-    var total = 0
-    val multiplier = 3.0
-    val initialCount = this.count()
-    var maxSelected = 0
-
-    if (num < 0) {
-      throw new IllegalArgumentException("Negative number of elements requested")
-    }
-
-    if (initialCount > Integer.MAX_VALUE - 1) {
-      maxSelected = Integer.MAX_VALUE - 1
-    } else {
-      maxSelected = initialCount.toInt
-    }
-
-    if (num > initialCount && !withReplacement) {
-      total = maxSelected
-      fraction = multiplier * (maxSelected + 1) / initialCount
-    } else {
-      fraction = multiplier * (num + 1) / initialCount
-      total = num
-    }
-
-    val rand = new Random(seed)
-    var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
-
-    // If the first sample didn't turn out large enough, keep trying to take samples;
-    // this shouldn't happen often because we use a big multiplier for the initial size
-    while (samples.length < total) {
-      samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
-    }
-
-    Utils.randomizeInPlace(samples, rand).take(total)
-  }
-
-  /**
-   * Return the union of this RDD and another one. Any identical elements will appear multiple
-   * times (use `.distinct()` to eliminate them).
-   */
-  def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other))
-
-  /**
-   * Return the union of this RDD and another one. Any identical elements will appear multiple
-   * times (use `.distinct()` to eliminate them).
-   */
-  def ++(other: RDD[T]): RDD[T] = this.union(other)
-
-  /**
-   * Return an RDD created by coalescing all elements within each partition into an array.
-   */
-  def glom(): RDD[Array[T]] = new GlommedRDD(this)
-
-  /**
-   * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of
-   * elements (a, b) where a is in `this` and b is in `other`.
-   */
-  def cartesian[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other)
-
-  /**
-   * Return an RDD of grouped items.
-   */
-  def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] =
-    groupBy[K](f, defaultPartitioner(this))
-
-  /**
-   * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
-   * mapping to that key.
-   */
-  def groupBy[K: ClassManifest](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] =
-    groupBy(f, new HashPartitioner(numPartitions))
-
-  /**
-   * Return an RDD of grouped items.
-   */
-  def groupBy[K: ClassManifest](f: T => K, p: Partitioner): RDD[(K, Seq[T])] = {
-    val cleanF = sc.clean(f)
-    this.map(t => (cleanF(t), t)).groupByKey(p)
-  }
-
-  /**
-   * Return an RDD created by piping elements to a forked external process.
-   */
-  def pipe(command: String): RDD[String] = new PipedRDD(this, command)
-
-  /**
-   * Return an RDD created by piping elements to a forked external process.
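The over-sampling arithmetic in takeSample above is worth tracing once. A worked sketch of that calculation (not part of this patch), with invented counts:

    object TakeSampleFractionExample extends App {
      val multiplier = 3.0
      val initialCount = 10000L  // assumed RDD size
      val num = 100              // elements requested
      val fraction = multiplier * (num + 1) / initialCount
      // 0.0303: about 303 expected elements for 100 requested, so the retry
      // loop above should almost never run.
      println(fraction)
    }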
- */
-  def pipe(command: String, env: Map[String, String]): RDD[String] =
-    new PipedRDD(this, command, env)
-
-  /**
-   * Return an RDD created by piping elements to a forked external process.
-   * The print behavior can be customized by providing two functions.
-   *
-   * @param command command to run in forked process.
-   * @param env environment variables to set.
-   * @param printPipeContext Before piping elements, this function is called as an opportunity
-   *                         to pipe context data. Print line function (like out.println) will be
-   *                         passed as printPipeContext's parameter.
-   * @param printRDDElement Use this function to customize how to pipe elements. This function
-   *                        will be called with each RDD element as the 1st parameter, and the
-   *                        print line function (like out.println()) as the 2nd parameter.
-   *                        An example of piping the RDD data of groupBy() in a streaming way,
-   *                        instead of constructing a huge String to concat all the elements:
-   *                        def printRDDElement(record: (String, Seq[String]), f: String => Unit) =
-   *                          for (e <- record._2) { f(e) }
-   * @return the result RDD
-   */
-  def pipe(
-      command: Seq[String],
-      env: Map[String, String] = Map(),
-      printPipeContext: (String => Unit) => Unit = null,
-      printRDDElement: (T, String => Unit) => Unit = null): RDD[String] =
-    new PipedRDD(this, command, env,
-      if (printPipeContext ne null) sc.clean(printPipeContext) else null,
-      if (printRDDElement ne null) sc.clean(printRDDElement) else null)
-
-  /**
-   * Return a new RDD by applying a function to each partition of this RDD.
-   */
-  def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U],
-    preservesPartitioning: Boolean = false): RDD[U] =
-    new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
-
-  /**
-   * Return a new RDD by applying a function to each partition of this RDD, while tracking the
-   * index of the original partition.
-   */
-  def mapPartitionsWithIndex[U: ClassManifest](
-      f: (Int, Iterator[T]) => Iterator[U],
-      preservesPartitioning: Boolean = false): RDD[U] =
-    new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
-
-  /**
-   * Return a new RDD by applying a function to each partition of this RDD, while tracking the
-   * index of the original partition.
-   */
-  @deprecated("use mapPartitionsWithIndex", "0.7.0")
-  def mapPartitionsWithSplit[U: ClassManifest](
-      f: (Int, Iterator[T]) => Iterator[U],
-      preservesPartitioning: Boolean = false): RDD[U] =
-    new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
-
-  /**
-   * Maps f over this RDD, where f takes an additional parameter of type A. This
-   * additional parameter is produced by constructA, which is called in each
-   * partition with the index of that partition.
-   */
-  def mapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false)
-      (f: (T, A) => U): RDD[U] = {
-    def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
-      val a = constructA(index)
-      iter.map(t => f(t, a))
-    }
-    new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
-  }
-
-  /**
-   * FlatMaps f over this RDD, where f takes an additional parameter of type A. This
-   * additional parameter is produced by constructA, which is called in each
-   * partition with the index of that partition.
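The streaming style that the printRDDElement doc above recommends looks like this in practice. A hypothetical usage sketch (not part of this patch), assuming a live SparkContext named sc and that `cat` exists on the worker hosts:

    val grouped = sc.parallelize(Seq(("k1", "a"), ("k1", "b"), ("k2", "c"))).groupByKey()
    val piped = grouped.pipe(
      command = Seq("cat"),
      // Stream each grouped value as its own line instead of building one huge String.
      printRDDElement = (record: (String, Seq[String]), out: String => Unit) =>
        record._2.foreach(out))
    // piped.collect() yields one line per streamed element: "a", "b", "c".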
- */ - def flatMapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false) - (f:(T, A) => Seq[U]): RDD[U] = { - def iterF(index: Int, iter: Iterator[T]): Iterator[U] = { - val a = constructA(index) - iter.flatMap(t => f(t, a)) - } - new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning) - } - - /** - * Applies f to each element of this RDD, where f takes an additional parameter of type A. - * This additional parameter is produced by constructA, which is called in each - * partition with the index of that partition. - */ - def foreachWith[A: ClassManifest](constructA: Int => A) - (f:(T, A) => Unit) { - def iterF(index: Int, iter: Iterator[T]): Iterator[T] = { - val a = constructA(index) - iter.map(t => {f(t, a); t}) - } - (new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {}) - } - - /** - * Filters this RDD with p, where p takes an additional parameter of type A. This - * additional parameter is produced by constructA, which is called in each - * partition with the index of that partition. - */ - def filterWith[A: ClassManifest](constructA: Int => A) - (p:(T, A) => Boolean): RDD[T] = { - def iterF(index: Int, iter: Iterator[T]): Iterator[T] = { - val a = constructA(index) - iter.filter(t => p(t, a)) - } - new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true) - } - - /** - * Zips this RDD with another one, returning key-value pairs with the first element in each RDD, - * second element in each RDD, etc. Assumes that the two RDDs have the *same number of - * partitions* and the *same number of elements in each partition* (e.g. one was made through - * a map on the other). - */ - def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other) - - /** - * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by - * applying a function to the zipped partitions. Assumes that all the RDDs have the - * *same number of partitions*, but does *not* require them to have the same number - * of elements in each partition. - */ - def zipPartitions[B: ClassManifest, V: ClassManifest] - (rdd2: RDD[B]) - (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2) - - def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest] - (rdd2: RDD[B], rdd3: RDD[C]) - (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3) - - def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest] - (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D]) - (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4) - - - // Actions (launch a job to return a value to the user program) - - /** - * Applies a function f to all elements of this RDD. - */ - def foreach(f: T => Unit) { - val cleanF = sc.clean(f) - sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF)) - } - - /** - * Applies a function f to each partition of this RDD. - */ - def foreachPartition(f: Iterator[T] => Unit) { - val cleanF = sc.clean(f) - sc.runJob(this, (iter: Iterator[T]) => cleanF(iter)) - } - - /** - * Return an array that contains all of the elements in this RDD. 
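zipPartitions, unlike zip above, only requires matching partition counts, not matching element counts per partition. An illustrative sketch (not part of this patch), assuming a live SparkContext named sc; both inputs are built with two partitions:

    val nums = sc.parallelize(1 to 6, 2)
    val tags = sc.parallelize(Seq("a", "b"), 2)
    val summary = nums.zipPartitions(tags) { (numIter, tagIter) =>
      // The two iterators for a partition may have different lengths.
      Iterator(numIter.size + " numbers zipped with tags " + tagIter.mkString(","))
    }
    // summary.collect() has one descriptive line per partition.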
- */ - def collect(): Array[T] = { - val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray) - Array.concat(results: _*) - } - - /** - * Return an array that contains all of the elements in this RDD. - */ - def toArray(): Array[T] = collect() - - /** - * Return an RDD that contains all matching values by applying `f`. - */ - def collect[U: ClassManifest](f: PartialFunction[T, U]): RDD[U] = { - filter(f.isDefinedAt).map(f) - } - - /** - * Return an RDD with the elements from `this` that are not in `other`. - * - * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. - */ - def subtract(other: RDD[T]): RDD[T] = - subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size))) - - /** - * Return an RDD with the elements from `this` that are not in `other`. - */ - def subtract(other: RDD[T], numPartitions: Int): RDD[T] = - subtract(other, new HashPartitioner(numPartitions)) - - /** - * Return an RDD with the elements from `this` that are not in `other`. - */ - def subtract(other: RDD[T], p: Partitioner): RDD[T] = { - if (partitioner == Some(p)) { - // Our partitioner knows how to handle T (which, since we have a partitioner, is - // really (K, V)) so make a new Partitioner that will de-tuple our fake tuples - val p2 = new Partitioner() { - override def numPartitions = p.numPartitions - override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1) - } - // Unfortunately, since we're making a new p2, we'll get ShuffleDependencies - // anyway, and when calling .keys, will not have a partitioner set, even though - // the SubtractedRDD will, thanks to p2's de-tupled partitioning, already be - // partitioned by the right/real keys (e.g. p). - this.map(x => (x, null)).subtractByKey(other.map((_, null)), p2).keys - } else { - this.map(x => (x, null)).subtractByKey(other.map((_, null)), p).keys - } - } - - /** - * Reduces the elements of this RDD using the specified commutative and associative binary operator. - */ - def reduce(f: (T, T) => T): T = { - val cleanF = sc.clean(f) - val reducePartition: Iterator[T] => Option[T] = iter => { - if (iter.hasNext) { - Some(iter.reduceLeft(cleanF)) - } else { - None - } - } - var jobResult: Option[T] = None - val mergeResult = (index: Int, taskResult: Option[T]) => { - if (taskResult != None) { - jobResult = jobResult match { - case Some(value) => Some(f(value, taskResult.get)) - case None => taskResult - } - } - } - sc.runJob(this, reducePartition, mergeResult) - // Get the final result out of our Option, or throw an exception if the RDD was empty - jobResult.getOrElse(throw new UnsupportedOperationException("empty collection")) - } - - /** - * Aggregate the elements of each partition, and then the results for all the partitions, using a - * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to - * modify t1 and return it as its result value to avoid object allocation; however, it should not - * modify t2. 
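reduce above runs in two stages: each partition is reduced locally (reducePartition), then the per-partition results are merged on the driver (mergeResult). A small usage sketch (not part of this patch), assuming a live SparkContext named sc; the operator must be commutative and associative for a deterministic result:

    val values = sc.parallelize(Seq(3, 1, 4, 1, 5), 2)
    val total = values.reduce(_ + _)                       // 14
    val biggest = values.reduce((a, b) => math.max(a, b))  // 5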
- */ - def fold(zeroValue: T)(op: (T, T) => T): T = { - // Clone the zero value since we will also be serializing it as part of tasks - var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) - val cleanOp = sc.clean(op) - val foldPartition = (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp) - val mergeResult = (index: Int, taskResult: T) => jobResult = op(jobResult, taskResult) - sc.runJob(this, foldPartition, mergeResult) - jobResult - } - - /** - * Aggregate the elements of each partition, and then the results for all the partitions, using - * given combine functions and a neutral "zero value". This function can return a different result - * type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into an U - * and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are - * allowed to modify and return their first argument instead of creating a new U to avoid memory - * allocation. - */ - def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = { - // Clone the zero value since we will also be serializing it as part of tasks - var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) - val cleanSeqOp = sc.clean(seqOp) - val cleanCombOp = sc.clean(combOp) - val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) - val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult) - sc.runJob(this, aggregatePartition, mergeResult) - jobResult - } - - /** - * Return the number of elements in the RDD. - */ - def count(): Long = { - sc.runJob(this, (iter: Iterator[T]) => { - var result = 0L - while (iter.hasNext) { - result += 1L - iter.next() - } - result - }).sum - } - - /** - * (Experimental) Approximate version of count() that returns a potentially incomplete result - * within a timeout, even if not all tasks have finished. - */ - def countApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { - val countElements: (TaskContext, Iterator[T]) => Long = { (ctx, iter) => - var result = 0L - while (iter.hasNext) { - result += 1L - iter.next() - } - result - } - val evaluator = new CountEvaluator(partitions.size, confidence) - sc.runApproximateJob(this, countElements, evaluator, timeout) - } - - /** - * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final - * combine step happens locally on the master, equivalent to running a single reduce task. - */ - def countByValue(): Map[T, Long] = { - if (elementClassManifest.erasure.isArray) { - throw new SparkException("countByValue() does not support arrays") - } - // TODO: This should perhaps be distributed by default. - def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = { - val map = new OLMap[T] - while (iter.hasNext) { - val v = iter.next() - map.put(v, map.getLong(v) + 1L) - } - Iterator(map) - } - def mergeMaps(m1: OLMap[T], m2: OLMap[T]): OLMap[T] = { - val iter = m2.object2LongEntrySet.fastIterator() - while (iter.hasNext) { - val entry = iter.next() - m1.put(entry.getKey, m1.getLong(entry.getKey) + entry.getLongValue) - } - return m1 - } - val myResult = mapPartitions(countPartition).reduce(mergeMaps) - myResult.asInstanceOf[java.util.Map[T, Long]] // Will be wrapped as a Scala mutable Map - } - - /** - * (Experimental) Approximate version of countByValue(). 
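aggregate above is the general form: the result type U may differ from the element type T. A classic usage sketch (not part of this patch) computing a mean via a (sum, count) pair, assuming a live SparkContext named sc:

    val xs = sc.parallelize(Seq(1.0, 2.0, 4.0))
    val (sum, count) = xs.aggregate((0.0, 0L))(
      seqOp  = (acc, x) => (acc._1 + x, acc._2 + 1),    // fold one element into a partial result
      combOp = (a, b) => (a._1 + b._1, a._2 + b._2))    // merge two partial results
    val mean = sum / count  // 2.333...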
- */ - def countByValueApprox( - timeout: Long, - confidence: Double = 0.95 - ): PartialResult[Map[T, BoundedDouble]] = { - if (elementClassManifest.erasure.isArray) { - throw new SparkException("countByValueApprox() does not support arrays") - } - val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) => - val map = new OLMap[T] - while (iter.hasNext) { - val v = iter.next() - map.put(v, map.getLong(v) + 1L) - } - map - } - val evaluator = new GroupedCountEvaluator[T](partitions.size, confidence) - sc.runApproximateJob(this, countPartition, evaluator, timeout) - } - - /** - * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so - * it will be slow if a lot of partitions are required. In that case, use collect() to get the - * whole RDD instead. - */ - def take(num: Int): Array[T] = { - if (num == 0) { - return new Array[T](0) - } - val buf = new ArrayBuffer[T] - var p = 0 - while (buf.size < num && p < partitions.size) { - val left = num - buf.size - val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, Array(p), true) - buf ++= res(0) - if (buf.size == num) - return buf.toArray - p += 1 - } - return buf.toArray - } - - /** - * Return the first element in this RDD. - */ - def first(): T = take(1) match { - case Array(t) => t - case _ => throw new UnsupportedOperationException("empty collection") - } - - /** - * Returns the top K elements from this RDD as defined by - * the specified implicit Ordering[T]. - * @param num the number of top elements to return - * @param ord the implicit ordering for T - * @return an array of top elements - */ - def top(num: Int)(implicit ord: Ordering[T]): Array[T] = { - mapPartitions { items => - val queue = new BoundedPriorityQueue[T](num) - queue ++= items - Iterator.single(queue) - }.reduce { (queue1, queue2) => - queue1 ++= queue2 - queue1 - }.toArray.sorted(ord.reverse) - } - - /** - * Returns the first K elements from this RDD as defined by - * the specified implicit Ordering[T] and maintains the - * ordering. - * @param num the number of top elements to return - * @param ord the implicit ordering for T - * @return an array of top elements - */ - def takeOrdered(num: Int)(implicit ord: Ordering[T]): Array[T] = top(num)(ord.reverse) - - /** - * Save this RDD as a text file, using string representations of elements. - */ - def saveAsTextFile(path: String) { - this.map(x => (NullWritable.get(), new Text(x.toString))) - .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path) - } - - /** - * Save this RDD as a compressed text file, using string representations of elements. - */ - def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) { - this.map(x => (NullWritable.get(), new Text(x.toString))) - .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path, codec) - } - - /** - * Save this RDD as a SequenceFile of serialized objects. - */ - def saveAsObjectFile(path: String) { - this.mapPartitions(iter => iter.grouped(10).map(_.toArray)) - .map(x => (NullWritable.get(), new BytesWritable(Utils.serialize(x)))) - .saveAsSequenceFile(path) - } - - /** - * Creates tuples of the elements in this RDD by applying `f`. - */ - def keyBy[K](f: T => K): RDD[(K, T)] = { - map(x => (f(x), x)) - } - - /** A private method for tests, to look at the contents of each partition */ - private[spark] def collectPartitions(): Array[Array[T]] = { - sc.runJob(this, (iter: Iterator[T]) => iter.toArray) - } - - /** - * Mark this RDD for checkpointing. 
It will be saved to a file inside the checkpoint - * directory set with SparkContext.setCheckpointDir() and all references to its parent - * RDDs will be removed. This function must be called before any job has been - * executed on this RDD. It is strongly recommended that this RDD is persisted in - * memory, otherwise saving it on a file will require recomputation. - */ - def checkpoint() { - if (context.checkpointDir.isEmpty) { - throw new Exception("Checkpoint directory has not been set in the SparkContext") - } else if (checkpointData.isEmpty) { - checkpointData = Some(new RDDCheckpointData(this)) - checkpointData.get.markForCheckpoint() - } - } - - /** - * Return whether this RDD has been checkpointed or not - */ - def isCheckpointed: Boolean = { - checkpointData.map(_.isCheckpointed).getOrElse(false) - } - - /** - * Gets the name of the file to which this RDD was checkpointed - */ - def getCheckpointFile: Option[String] = { - checkpointData.flatMap(_.getCheckpointFile) - } - - // ======================================================================= - // Other internal methods and fields - // ======================================================================= - - private var storageLevel: StorageLevel = StorageLevel.NONE - - /** Record user function generating this RDD. */ - private[spark] val origin = Utils.formatSparkCallSite - - private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] - - private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None - - /** Returns the first parent RDD */ - protected[spark] def firstParent[U: ClassManifest] = { - dependencies.head.rdd.asInstanceOf[RDD[U]] - } - - /** The [[spark.SparkContext]] that this RDD was created on. */ - def context = sc - - // Avoid handling doCheckpoint multiple times to prevent excessive recursion - private var doCheckpointCalled = false - - /** - * Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler - * after a job using this RDD has completed (therefore the RDD has been materialized and - * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs. - */ - private[spark] def doCheckpoint() { - if (!doCheckpointCalled) { - doCheckpointCalled = true - if (checkpointData.isDefined) { - checkpointData.get.doCheckpoint() - } else { - dependencies.foreach(_.rdd.doCheckpoint()) - } - } - } - - /** - * Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`) - * created from the checkpoint file, and forget its old dependencies and partitions. - */ - private[spark] def markCheckpointed(checkpointRDD: RDD[_]) { - clearDependencies() - partitions_ = null - deps = null // Forget the constructor argument for dependencies too - } - - /** - * Clears the dependencies of this RDD. This method must ensure that all references - * to the original parent RDDs is removed to enable the parent RDDs to be garbage - * collected. Subclasses of RDD may override this method for implementing their own cleaning - * logic. See [[spark.rdd.UnionRDD]] for an example. - */ - protected def clearDependencies() { - dependencies_ = null - } - - /** A description of this RDD and its recursive dependencies for debugging. 
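The checkpoint workflow described above has a strict ordering: set the directory, mark the RDD before any job runs on it, then let the first job trigger the actual save. A hypothetical usage sketch (not part of this patch), assuming a live SparkContext named sc; the directory path is a placeholder:

    sc.setCheckpointDir("/tmp/spark-checkpoints")
    val rdd = sc.parallelize(1 to 1000).map(_ * 2).persist()  // persist to avoid recomputation
    rdd.checkpoint()             // only marks the RDD; nothing is saved yet
    rdd.count()                  // the first job materializes the RDD and triggers the save
    println(rdd.isCheckpointed)  // true
    println(rdd.getCheckpointFile)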
 */
-  def toDebugString: String = {
-    def debugString(rdd: RDD[_], prefix: String = ""): Seq[String] = {
-      Seq(prefix + rdd + " (" + rdd.partitions.size + " partitions)") ++
-        rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + "  "))
-    }
-    debugString(this).mkString("\n")
-  }
-
-  override def toString: String = "%s%s[%d] at %s".format(
-    Option(name).map(_ + " ").getOrElse(""),
-    getClass.getSimpleName,
-    id,
-    origin)
-
-  def toJavaRDD(): JavaRDD[T] = {
-    new JavaRDD(this)(elementClassManifest)
-  }
-
-}
diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala
deleted file mode 100644
index b615f820eb..0000000000
--- a/core/src/main/scala/spark/RDDCheckpointData.scala
+++ /dev/null
@@ -1,130 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package spark
-
-import org.apache.hadoop.fs.Path
-import org.apache.hadoop.conf.Configuration
-import rdd.{CheckpointRDD, CoalescedRDD}
-import scheduler.{ResultTask, ShuffleMapTask}
-
-/**
- * Enumeration to manage state transitions of an RDD through checkpointing
- * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ]
- */
-private[spark] object CheckpointState extends Enumeration {
-  type CheckpointState = Value
-  val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value
-}
-
-/**
- * This class contains all the information related to RDD checkpointing. Each instance of this
- * class is associated with an RDD. It manages the process of checkpointing the associated RDD,
- * as well as the post-checkpoint state, by providing the updated partitions, iterator and
- * preferred locations of the checkpointed RDD.
- */
-private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
-  extends Logging with Serializable {
-
-  import CheckpointState._
-
-  // The checkpoint state of the associated RDD.
-  var cpState = Initialized
-
-  // The file to which the associated RDD has been checkpointed
-  @transient var cpFile: Option[String] = None
-
-  // The CheckpointRDD created from the checkpoint file, that is, the new parent of the associated RDD.
-  var cpRDD: Option[RDD[T]] = None
-
-  // Mark the RDD for checkpointing
-  def markForCheckpoint() {
-    RDDCheckpointData.synchronized {
-      if (cpState == Initialized) cpState = MarkedForCheckpoint
-    }
-  }
-
-  // Whether the RDD has already been checkpointed
-  def isCheckpointed: Boolean = {
-    RDDCheckpointData.synchronized { cpState == Checkpointed }
-  }
-
-  // Get the file to which this RDD was checkpointed, as an Option
-  def getCheckpointFile: Option[String] = {
-    RDDCheckpointData.synchronized { cpFile }
-  }
-
-  // Do the checkpointing of the RDD. Called after the first job using that RDD is over.
- def doCheckpoint() { - // If it is marked for checkpointing AND checkpointing is not already in progress, - // then set it to be in progress, else return - RDDCheckpointData.synchronized { - if (cpState == MarkedForCheckpoint) { - cpState = CheckpointingInProgress - } else { - return - } - } - - // Create the output path for the checkpoint - val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id) - val fs = path.getFileSystem(new Configuration()) - if (!fs.mkdirs(path)) { - throw new SparkException("Failed to create checkpoint path " + path) - } - - // Save to file, and reload it as an RDD - rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString) _) - val newRDD = new CheckpointRDD[T](rdd.context, path.toString) - - // Change the dependencies and partitions of the RDD - RDDCheckpointData.synchronized { - cpFile = Some(path.toString) - cpRDD = Some(newRDD) - rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions - cpState = Checkpointed - RDDCheckpointData.clearTaskCaches() - logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) - } - } - - // Get preferred location of a split after checkpointing - def getPreferredLocations(split: Partition): Seq[String] = { - RDDCheckpointData.synchronized { - cpRDD.get.preferredLocations(split) - } - } - - def getPartitions: Array[Partition] = { - RDDCheckpointData.synchronized { - cpRDD.get.partitions - } - } - - def checkpointRDD: Option[RDD[T]] = { - RDDCheckpointData.synchronized { - cpRDD - } - } -} - -private[spark] object RDDCheckpointData { - def clearTaskCaches() { - ShuffleMapTask.clearCache() - ResultTask.clearCache() - } -} diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala deleted file mode 100644 index 9f30b7f22f..0000000000 --- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark - -import java.io.EOFException -import java.net.URL -import java.io.ObjectInputStream -import java.util.concurrent.atomic.AtomicLong -import java.util.HashSet -import java.util.Random -import java.util.Date - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.Map -import scala.collection.mutable.HashMap - -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.OutputFormat -import org.apache.hadoop.mapred.TextOutputFormat -import org.apache.hadoop.mapred.SequenceFileOutputFormat -import org.apache.hadoop.mapred.OutputCommitter -import org.apache.hadoop.mapred.FileOutputCommitter -import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.hadoop.io.Writable -import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.io.BytesWritable -import org.apache.hadoop.io.Text - -import spark.SparkContext._ - -/** - * Extra functions available on RDDs of (key, value) pairs to create a Hadoop SequenceFile, - * through an implicit conversion. Note that this can't be part of PairRDDFunctions because - * we need more implicit parameters to convert our keys and values to Writable. - * - * Users should import `spark.SparkContext._` at the top of their program to use these functions. - */ -class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : ClassManifest]( - self: RDD[(K, V)]) - extends Logging - with Serializable { - - private def getWritableClass[T <% Writable: ClassManifest](): Class[_ <: Writable] = { - val c = { - if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) { - classManifest[T].erasure - } else { - // We get the type of the Writable class by looking at the apply method which converts - // from T to Writable. Since we have two apply methods we filter out the one which - // is not of the form "java.lang.Object apply(java.lang.Object)" - implicitly[T => Writable].getClass.getDeclaredMethods().filter( - m => m.getReturnType().toString != "class java.lang.Object" && - m.getName() == "apply")(0).getReturnType - - } - // TODO: use something like WritableConverter to avoid reflection - } - c.asInstanceOf[Class[_ <: Writable]] - } - - /** - * Output the RDD as a Hadoop SequenceFile using the Writable types we infer from the RDD's key - * and value types. If the key or value are Writable, then we use their classes directly; - * otherwise we map primitive types such as Int and Double to IntWritable, DoubleWritable, etc, - * byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported - * file system. 
- */ - def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) { - def anyToWritable[U <% Writable](u: U): Writable = u - - val keyClass = getWritableClass[K] - val valueClass = getWritableClass[V] - val convertKey = !classOf[Writable].isAssignableFrom(self.getKeyClass) - val convertValue = !classOf[Writable].isAssignableFrom(self.getValueClass) - - logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," + valueClass.getSimpleName + ")" ) - val format = classOf[SequenceFileOutputFormat[Writable, Writable]] - val jobConf = new JobConf(self.context.hadoopConfiguration) - if (!convertKey && !convertValue) { - self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec) - } else if (!convertKey && convertValue) { - self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile( - path, keyClass, valueClass, format, jobConf, codec) - } else if (convertKey && !convertValue) { - self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile( - path, keyClass, valueClass, format, jobConf, codec) - } else if (convertKey && convertValue) { - self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile( - path, keyClass, valueClass, format, jobConf, codec) - } - } -} diff --git a/core/src/main/scala/spark/SerializableWritable.scala b/core/src/main/scala/spark/SerializableWritable.scala deleted file mode 100644 index 936d8e6241..0000000000 --- a/core/src/main/scala/spark/SerializableWritable.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.io._ - -import org.apache.hadoop.io.ObjectWritable -import org.apache.hadoop.io.Writable -import org.apache.hadoop.conf.Configuration - -class SerializableWritable[T <: Writable](@transient var t: T) extends Serializable { - def value = t - override def toString = t.toString - - private def writeObject(out: ObjectOutputStream) { - out.defaultWriteObject() - new ObjectWritable(t).write(out) - } - - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() - val ow = new ObjectWritable() - ow.setConf(new Configuration()) - ow.readFields(in) - t = ow.get().asInstanceOf[T] - } -} diff --git a/core/src/main/scala/spark/ShuffleFetcher.scala b/core/src/main/scala/spark/ShuffleFetcher.scala deleted file mode 100644 index a6839cf7a4..0000000000 --- a/core/src/main/scala/spark/ShuffleFetcher.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
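SerializableWritable, shown above, exists because Hadoop Writables are not java.io.Serializable and so cannot ride along in serialized task closures on their own. A small illustrative sketch (not part of this patch):

    import org.apache.hadoop.io.Text

    // Wrapping makes the Text usable inside serialized closures; on
    // deserialization, readObject rebuilds the Writable and value returns it.
    val wrapped = new SerializableWritable(new Text("payload"))
    println(wrapped.value)  // payload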
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import spark.executor.TaskMetrics -import spark.serializer.Serializer - - -private[spark] abstract class ShuffleFetcher { - - /** - * Fetch the shuffle outputs for a given ShuffleDependency. - * @return An iterator over the elements of the fetched shuffle outputs. - */ - def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, - serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T] - - /** Stop the fetcher */ - def stop() {} -} diff --git a/core/src/main/scala/spark/SizeEstimator.scala b/core/src/main/scala/spark/SizeEstimator.scala deleted file mode 100644 index 6cc57566d7..0000000000 --- a/core/src/main/scala/spark/SizeEstimator.scala +++ /dev/null @@ -1,283 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.lang.reflect.Field -import java.lang.reflect.Modifier -import java.lang.reflect.{Array => JArray} -import java.util.IdentityHashMap -import java.util.concurrent.ConcurrentHashMap -import java.util.Random - -import javax.management.MBeanServer -import java.lang.management.ManagementFactory - -import scala.collection.mutable.ArrayBuffer - -import it.unimi.dsi.fastutil.ints.IntOpenHashSet - -/** - * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in - * memory-aware caches. - * - * Based on the following JavaWorld article: - * http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html - */ -private[spark] object SizeEstimator extends Logging { - - // Sizes of primitive types - private val BYTE_SIZE = 1 - private val BOOLEAN_SIZE = 1 - private val CHAR_SIZE = 2 - private val SHORT_SIZE = 2 - private val INT_SIZE = 4 - private val LONG_SIZE = 8 - private val FLOAT_SIZE = 4 - private val DOUBLE_SIZE = 8 - - // Alignment boundary for objects - // TODO: Is this arch dependent ? 
- private val ALIGN_SIZE = 8 - - // A cache of ClassInfo objects for each class - private val classInfos = new ConcurrentHashMap[Class[_], ClassInfo] - - // Object and pointer sizes are arch dependent - private var is64bit = false - - // Size of an object reference - // Based on https://wikis.oracle.com/display/HotSpotInternals/CompressedOops - private var isCompressedOops = false - private var pointerSize = 4 - - // Minimum size of a java.lang.Object - private var objectSize = 8 - - initialize() - - // Sets object size, pointer size based on architecture and CompressedOops settings - // from the JVM. - private def initialize() { - is64bit = System.getProperty("os.arch").contains("64") - isCompressedOops = getIsCompressedOops - - objectSize = if (!is64bit) 8 else { - if(!isCompressedOops) { - 16 - } else { - 12 - } - } - pointerSize = if (is64bit && !isCompressedOops) 8 else 4 - classInfos.clear() - classInfos.put(classOf[Object], new ClassInfo(objectSize, Nil)) - } - - private def getIsCompressedOops : Boolean = { - if (System.getProperty("spark.test.useCompressedOops") != null) { - return System.getProperty("spark.test.useCompressedOops").toBoolean - } - - try { - val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic" - val server = ManagementFactory.getPlatformMBeanServer() - - // NOTE: This should throw an exception in non-Sun JVMs - val hotSpotMBeanClass = Class.forName("com.sun.management.HotSpotDiagnosticMXBean") - val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption", - Class.forName("java.lang.String")) - - val bean = ManagementFactory.newPlatformMXBeanProxy(server, - hotSpotMBeanName, hotSpotMBeanClass) - // TODO: We could use reflection on the VMOption returned ? - return getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true") - } catch { - case e: Exception => { - // Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB - val guess = Runtime.getRuntime.maxMemory < (32L*1024*1024*1024) - val guessInWords = if (guess) "yes" else "not" - logWarning("Failed to check whether UseCompressedOops is set; assuming " + guessInWords) - return guess - } - } - } - - /** - * The state of an ongoing size estimation. Contains a stack of objects to visit as well as an - * IdentityHashMap of visited objects, and provides utility methods for enqueueing new objects - * to visit. - */ - private class SearchState(val visited: IdentityHashMap[AnyRef, AnyRef]) { - val stack = new ArrayBuffer[AnyRef] - var size = 0L - - def enqueue(obj: AnyRef) { - if (obj != null && !visited.containsKey(obj)) { - visited.put(obj, null) - stack += obj - } - } - - def isFinished(): Boolean = stack.isEmpty - - def dequeue(): AnyRef = { - val elem = stack.last - stack.trimEnd(1) - return elem - } - } - - /** - * Cached information about each class. We remember two things: the "shell size" of the class - * (size of all non-static fields plus the java.lang.Object size), and any fields that are - * pointers to objects. 
- */ - private class ClassInfo( - val shellSize: Long, - val pointerFields: List[Field]) {} - - def estimate(obj: AnyRef): Long = estimate(obj, new IdentityHashMap[AnyRef, AnyRef]) - - private def estimate(obj: AnyRef, visited: IdentityHashMap[AnyRef, AnyRef]): Long = { - val state = new SearchState(visited) - state.enqueue(obj) - while (!state.isFinished) { - visitSingleObject(state.dequeue(), state) - } - return state.size - } - - private def visitSingleObject(obj: AnyRef, state: SearchState) { - val cls = obj.getClass - if (cls.isArray) { - visitArray(obj, cls, state) - } else if (obj.isInstanceOf[ClassLoader] || obj.isInstanceOf[Class[_]]) { - // Hadoop JobConfs created in the interpreter have a ClassLoader, which greatly confuses - // the size estimator since it references the whole REPL. Do nothing in this case. In - // general all ClassLoaders and Classes will be shared between objects anyway. - } else { - val classInfo = getClassInfo(cls) - state.size += classInfo.shellSize - for (field <- classInfo.pointerFields) { - state.enqueue(field.get(obj)) - } - } - } - - // Estimat the size of arrays larger than ARRAY_SIZE_FOR_SAMPLING by sampling. - private val ARRAY_SIZE_FOR_SAMPLING = 200 - private val ARRAY_SAMPLE_SIZE = 100 // should be lower than ARRAY_SIZE_FOR_SAMPLING - - private def visitArray(array: AnyRef, cls: Class[_], state: SearchState) { - val length = JArray.getLength(array) - val elementClass = cls.getComponentType - - // Arrays have object header and length field which is an integer - var arrSize: Long = alignSize(objectSize + INT_SIZE) - - if (elementClass.isPrimitive) { - arrSize += alignSize(length * primitiveSize(elementClass)) - state.size += arrSize - } else { - arrSize += alignSize(length * pointerSize) - state.size += arrSize - - if (length <= ARRAY_SIZE_FOR_SAMPLING) { - for (i <- 0 until length) { - state.enqueue(JArray.get(array, i)) - } - } else { - // Estimate the size of a large array by sampling elements without replacement. - var size = 0.0 - val rand = new Random(42) - val drawn = new IntOpenHashSet(ARRAY_SAMPLE_SIZE) - for (i <- 0 until ARRAY_SAMPLE_SIZE) { - var index = 0 - do { - index = rand.nextInt(length) - } while (drawn.contains(index)) - drawn.add(index) - val elem = JArray.get(array, index) - size += SizeEstimator.estimate(elem, state.visited) - } - state.size += ((length / (ARRAY_SAMPLE_SIZE * 1.0)) * size).toLong - } - } - } - - private def primitiveSize(cls: Class[_]): Long = { - if (cls == classOf[Byte]) - BYTE_SIZE - else if (cls == classOf[Boolean]) - BOOLEAN_SIZE - else if (cls == classOf[Char]) - CHAR_SIZE - else if (cls == classOf[Short]) - SHORT_SIZE - else if (cls == classOf[Int]) - INT_SIZE - else if (cls == classOf[Long]) - LONG_SIZE - else if (cls == classOf[Float]) - FLOAT_SIZE - else if (cls == classOf[Double]) - DOUBLE_SIZE - else throw new IllegalArgumentException( - "Non-primitive class " + cls + " passed to primitiveSize()") - } - - /** - * Get or compute the ClassInfo for a given class. 
- */ - private def getClassInfo(cls: Class[_]): ClassInfo = { - // Check whether we've already cached a ClassInfo for this class - val info = classInfos.get(cls) - if (info != null) { - return info - } - - val parent = getClassInfo(cls.getSuperclass) - var shellSize = parent.shellSize - var pointerFields = parent.pointerFields - - for (field <- cls.getDeclaredFields) { - if (!Modifier.isStatic(field.getModifiers)) { - val fieldClass = field.getType - if (fieldClass.isPrimitive) { - shellSize += primitiveSize(fieldClass) - } else { - field.setAccessible(true) // Enable future get()'s on this field - shellSize += pointerSize - pointerFields = field :: pointerFields - } - } - } - - shellSize = alignSize(shellSize) - - // Create and cache a new ClassInfo - val newInfo = new ClassInfo(shellSize, pointerFields) - classInfos.put(cls, newInfo) - return newInfo - } - - private def alignSize(size: Long): Long = { - val rem = size % ALIGN_SIZE - return if (rem == 0) size else (size + ALIGN_SIZE - rem) - } -} diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala deleted file mode 100644 index 7ce9505b9c..0000000000 --- a/core/src/main/scala/spark/SparkContext.scala +++ /dev/null @@ -1,995 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark - -import java.io._ -import java.net.URI -import java.util.Properties -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.Map -import scala.collection.generic.Growable -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.util.DynamicVariable - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.ArrayWritable -import org.apache.hadoop.io.BooleanWritable -import org.apache.hadoop.io.BytesWritable -import org.apache.hadoop.io.DoubleWritable -import org.apache.hadoop.io.FloatWritable -import org.apache.hadoop.io.IntWritable -import org.apache.hadoop.io.LongWritable -import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.io.Text -import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapred.FileInputFormat -import org.apache.hadoop.mapred.InputFormat -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.SequenceFileInputFormat -import org.apache.hadoop.mapred.TextInputFormat -import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import org.apache.hadoop.mapreduce.{Job => NewHadoopJob} -import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} - -import org.apache.mesos.MesosNativeLibrary - -import spark.deploy.LocalSparkCluster -import spark.partial.{ApproximateEvaluator, PartialResult} -import spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD, - OrderedRDDFunctions} -import spark.scheduler._ -import spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend, - ClusterScheduler, Schedulable, SchedulingMode} -import spark.scheduler.local.LocalScheduler -import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} -import spark.storage.{StorageStatus, StorageUtils, RDDInfo, BlockManagerSource} -import spark.ui.SparkUI -import spark.util.{MetadataCleaner, TimeStampedHashMap} -import scala.Some -import spark.scheduler.StageInfo -import spark.storage.RDDInfo -import spark.storage.StorageStatus - -/** - * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark - * cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster. - * - * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - * @param appName A name for your application, to display on the cluster web UI. - * @param sparkHome Location where Spark is installed on cluster nodes. - * @param jars Collection of JARs to send to the cluster. These can be paths on the local file - * system or HDFS, HTTP, HTTPS, or FTP URLs. - * @param environment Environment variables to set on worker nodes. - */ -class SparkContext( - val master: String, - val appName: String, - val sparkHome: String = null, - val jars: Seq[String] = Nil, - val environment: Map[String, String] = Map(), - // This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc) too. - // This is typically generated from InputFormatInfo.computePreferredLocations .. 
host, set of data-local splits on host - val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map()) - extends Logging { - - // Ensure logging is initialized before we spawn any threads - initLogging() - - // Set Spark driver host and port system properties - if (System.getProperty("spark.driver.host") == null) { - System.setProperty("spark.driver.host", Utils.localHostName()) - } - if (System.getProperty("spark.driver.port") == null) { - System.setProperty("spark.driver.port", "0") - } - - val isLocal = (master == "local" || master.startsWith("local[")) - - // Create the Spark execution environment (cache, map output tracker, etc) - private[spark] val env = SparkEnv.createFromSystemProperties( - "", - System.getProperty("spark.driver.host"), - System.getProperty("spark.driver.port").toInt, - true, - isLocal) - SparkEnv.set(env) - - // Used to store a URL for each static file/jar together with the file's local timestamp - private[spark] val addedFiles = HashMap[String, Long]() - private[spark] val addedJars = HashMap[String, Long]() - - // Keeps track of all persisted RDDs - private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]] - private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup) - - // Initalize the Spark UI - private[spark] val ui = new SparkUI(this) - ui.bind() - - val startTime = System.currentTimeMillis() - - // Add each JAR given through the constructor - if (jars != null) { - jars.foreach { addJar(_) } - } - - // Environment variables to pass to our executors - private[spark] val executorEnvs = HashMap[String, String]() - // Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner - for (key <- Seq("SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", "SPARK_TESTING")) { - val value = System.getenv(key) - if (value != null) { - executorEnvs(key) = value - } - } - // Since memory can be set with a system property too, use that - executorEnvs("SPARK_MEM") = SparkContext.executorMemoryRequested + "m" - if (environment != null) { - executorEnvs ++= environment - } - - // Create and start the scheduler - private var taskScheduler: TaskScheduler = { - // Regular expression used for local[N] master format - val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r - // Regular expression for local[N, maxRetries], used in tests with failing tasks - val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r - // Regular expression for simulating a Spark cluster of [N, cores, memory] locally - val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r - // Regular expression for connecting to Spark deploy clusters - val SPARK_REGEX = """(spark://.*)""".r - //Regular expression for connection to Mesos cluster - val MESOS_REGEX = """(mesos://.*)""".r - - master match { - case "local" => - new LocalScheduler(1, 0, this) - - case LOCAL_N_REGEX(threads) => - new LocalScheduler(threads.toInt, 0, this) - - case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => - new LocalScheduler(threads.toInt, maxFailures.toInt, this) - - case SPARK_REGEX(sparkUrl) => - val scheduler = new ClusterScheduler(this) - val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName) - scheduler.initialize(backend) - scheduler - - case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => - // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang. 
- val memoryPerSlaveInt = memoryPerSlave.toInt - if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) { - throw new SparkException( - "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format( - memoryPerSlaveInt, SparkContext.executorMemoryRequested)) - } - - val scheduler = new ClusterScheduler(this) - val localCluster = new LocalSparkCluster( - numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) - val sparkUrl = localCluster.start() - val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName) - scheduler.initialize(backend) - backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => { - localCluster.stop() - } - scheduler - - case "yarn-standalone" => - val scheduler = try { - val clazz = Class.forName("spark.scheduler.cluster.YarnClusterScheduler") - val cons = clazz.getConstructor(classOf[SparkContext]) - cons.newInstance(this).asInstanceOf[ClusterScheduler] - } catch { - // TODO: Enumerate the exact reasons why it can fail - // But irrespective of it, it means we cannot proceed ! - case th: Throwable => { - throw new SparkException("YARN mode not available ?", th) - } - } - val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem) - scheduler.initialize(backend) - scheduler - - case _ => - if (MESOS_REGEX.findFirstIn(master).isEmpty) { - logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master)) - } - MesosNativeLibrary.load() - val scheduler = new ClusterScheduler(this) - val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean - val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos:// - val backend = if (coarseGrained) { - new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName) - } else { - new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName) - } - scheduler.initialize(backend) - scheduler - } - } - taskScheduler.start() - - @volatile private var dagScheduler = new DAGScheduler(taskScheduler) - dagScheduler.start() - - ui.start() - - /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. 
*/ - val hadoopConfiguration = { - val env = SparkEnv.get - val conf = env.hadoop.newConfiguration() - // Explicitly check for S3 environment variables - if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) { - conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) - conf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) - conf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) - conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) - } - // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" - for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) { - conf.set(key.substring("spark.hadoop.".length), System.getProperty(key)) - } - val bufferSize = System.getProperty("spark.buffer.size", "65536") - conf.set("io.file.buffer.size", bufferSize) - conf - } - - private[spark] var checkpointDir: Option[String] = None - - // Thread Local variable that can be used by users to pass information down the stack - private val localProperties = new DynamicVariable[Properties](null) - - def initLocalProperties() { - localProperties.value = new Properties() - } - - def setLocalProperty(key: String, value: String) { - if (localProperties.value == null) { - localProperties.value = new Properties() - } - if (value == null) { - localProperties.value.remove(key) - } else { - localProperties.value.setProperty(key, value) - } - } - - /** Set a human readable description of the current job. */ - def setJobDescription(value: String) { - setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value) - } - - // Post init - taskScheduler.postStartHook() - - val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler) - val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager) - - def initDriverMetrics() { - SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource) - SparkEnv.get.metricsSystem.registerSource(blockManagerSource) - } - - initDriverMetrics() - - // Methods for creating RDDs - - /** Distribute a local Scala collection to form an RDD. */ - def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { - new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]()) - } - - /** Distribute a local Scala collection to form an RDD. */ - def makeRDD[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { - parallelize(seq, numSlices) - } - - /** Distribute a local Scala collection to form an RDD, with one or more - * location preferences (hostnames of Spark nodes) for each object. - * Create a new partition for each collection item. */ - def makeRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = { - val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap - new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs) - } - - /** - * Read a text file from HDFS, a local file system (available on all nodes), or any - * Hadoop-supported file system URI, and return it as an RDD of Strings. - */ - def textFile(path: String, minSplits: Int = defaultMinSplits): RDD[String] = { - hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], minSplits) - .map(pair => pair._2.toString) - } - - /** - * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf giving its InputFormat and any - * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable, - * etc). 
- */ - def hadoopRDD[K, V]( - conf: JobConf, - inputFormatClass: Class[_ <: InputFormat[K, V]], - keyClass: Class[K], - valueClass: Class[V], - minSplits: Int = defaultMinSplits - ): RDD[(K, V)] = { - new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits) - } - - /** Get an RDD for a Hadoop file with an arbitrary InputFormat */ - def hadoopFile[K, V]( - path: String, - inputFormatClass: Class[_ <: InputFormat[K, V]], - keyClass: Class[K], - valueClass: Class[V], - minSplits: Int = defaultMinSplits - ) : RDD[(K, V)] = { - val conf = new JobConf(hadoopConfiguration) - FileInputFormat.setInputPaths(conf, path) - new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits) - } - - /** - * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys, - * values and the InputFormat so that users don't need to pass them directly. Instead, callers - * can just write, for example, - * {{{ - * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path, minSplits) - * }}} - */ - def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, minSplits: Int) - (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]) - : RDD[(K, V)] = { - hadoopFile(path, - fm.erasure.asInstanceOf[Class[F]], - km.erasure.asInstanceOf[Class[K]], - vm.erasure.asInstanceOf[Class[V]], - minSplits) - } - - /** - * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys, - * values and the InputFormat so that users don't need to pass them directly. Instead, callers - * can just write, for example, - * {{{ - * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path) - * }}} - */ - def hadoopFile[K, V, F <: InputFormat[K, V]](path: String) - (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] = - hadoopFile[K, V, F](path, defaultMinSplits) - - /** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */ - def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String) - (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] = { - newAPIHadoopFile( - path, - fm.erasure.asInstanceOf[Class[F]], - km.erasure.asInstanceOf[Class[K]], - vm.erasure.asInstanceOf[Class[V]]) - } - - /** - * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat - * and extra configuration options to pass to the input format. - */ - def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]]( - path: String, - fClass: Class[F], - kClass: Class[K], - vClass: Class[V], - conf: Configuration = hadoopConfiguration): RDD[(K, V)] = { - val job = new NewHadoopJob(conf) - NewFileInputFormat.addInputPath(job, new Path(path)) - val updatedConf = job.getConfiguration - new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf) - } - - /** - * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat - * and extra configuration options to pass to the input format. - */ - def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]]( - conf: Configuration = hadoopConfiguration, - fClass: Class[F], - kClass: Class[K], - vClass: Class[V]): RDD[(K, V)] = { - new NewHadoopRDD(this, fClass, kClass, vClass, conf) - } - - /** Get an RDD for a Hadoop SequenceFile with given key and value types. 
*/ - def sequenceFile[K, V](path: String, - keyClass: Class[K], - valueClass: Class[V], - minSplits: Int - ): RDD[(K, V)] = { - val inputFormatClass = classOf[SequenceFileInputFormat[K, V]] - hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits) - } - - /** Get an RDD for a Hadoop SequenceFile with given key and value types. */ - def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = - sequenceFile(path, keyClass, valueClass, defaultMinSplits) - - /** - * Version of sequenceFile() for types implicitly convertible to Writables through a - * WritableConverter. For example, to access a SequenceFile where the keys are Text and the - * values are IntWritable, you could simply write - * {{{ - * sparkContext.sequenceFile[String, Int](path, ...) - * }}} - * - * WritableConverters are provided in a somewhat strange way (by an implicit function) to support - * both subclasses of Writable and types for which we define a converter (e.g. Int to - * IntWritable). The most natural thing would've been to have implicit objects for the - * converters, but then we couldn't have an object for every subclass of Writable (you can't - * have a parameterized singleton object). We use functions instead to create a new converter - * for the appropriate type. In addition, we pass the converter a ClassManifest of its type to - * allow it to figure out the Writable class to use in the subclass case. - */ - def sequenceFile[K, V](path: String, minSplits: Int = defaultMinSplits) - (implicit km: ClassManifest[K], vm: ClassManifest[V], - kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]) - : RDD[(K, V)] = { - val kc = kcf() - val vc = vcf() - val format = classOf[SequenceFileInputFormat[Writable, Writable]] - val writables = hadoopFile(path, format, - kc.writableClass(km).asInstanceOf[Class[Writable]], - vc.writableClass(vm).asInstanceOf[Class[Writable]], minSplits) - writables.map{case (k,v) => (kc.convert(k), vc.convert(v))} - } - - /** - * Load an RDD saved as a SequenceFile containing serialized objects, with NullWritable keys and - * BytesWritable values that contain a serialized partition. This is still an experimental storage - * format and may not be supported exactly as is in future Spark releases. It will also be pretty - * slow if you use the default serializer (Java serialization), though the nice thing about it is - * that there's very little effort required to save arbitrary objects. - */ - def objectFile[T: ClassManifest]( - path: String, - minSplits: Int = defaultMinSplits - ): RDD[T] = { - sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minSplits) - .flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes)) - } - - - protected[spark] def checkpointFile[T: ClassManifest]( - path: String - ): RDD[T] = { - new CheckpointRDD[T](this, path) - } - - /** Build the union of a list of RDDs. */ - def union[T: ClassManifest](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds) - - /** Build the union of a list of RDDs passed as variable-length arguments. */ - def union[T: ClassManifest](first: RDD[T], rest: RDD[T]*): RDD[T] = - new UnionRDD(this, Seq(first) ++ rest) - - // Methods for creating shared variables - - /** - * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values - * to using the `+=` method. Only the driver can access the accumulator's `value`. 
- */ - def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = - new Accumulator(initialValue, param) - - /** - * Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`. - * Only the driver can access the accumuable's `value`. - * @tparam T accumulator type - * @tparam R type that can be added to the accumulator - */ - def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) = - new Accumulable(initialValue, param) - - /** - * Create an accumulator from a "mutable collection" type. - * - * Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by - * standard mutable collections. So you can use this with mutable Map, Set, etc. - */ - def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T](initialValue: R) = { - val param = new GrowableAccumulableParam[R,T] - new Accumulable(initialValue, param) - } - - /** - * Broadcast a read-only variable to the cluster, returning a [[spark.broadcast.Broadcast]] object for - * reading it in distributed functions. The variable will be sent to each cluster only once. - */ - def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) - - /** - * Add a file to be downloaded with this Spark job on every node. - * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, - * use `SparkFiles.get(path)` to find its download location. - */ - def addFile(path: String) { - val uri = new URI(path) - val key = uri.getScheme match { - case null | "file" => env.httpFileServer.addFile(new File(uri.getPath)) - case _ => path - } - addedFiles(key) = System.currentTimeMillis - - // Fetch the file locally in case a job is executed locally. - // Jobs that run through LocalScheduler will already fetch the required dependencies, - // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here. - Utils.fetchFile(path, new File(SparkFiles.getRootDirectory)) - - logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) - } - - def addSparkListener(listener: SparkListener) { - dagScheduler.addSparkListener(listener) - } - - /** - * Return a map from the slave to the max memory available for caching and the remaining - * memory available for caching. - */ - def getExecutorMemoryStatus: Map[String, (Long, Long)] = { - env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => - (blockManagerId.host + ":" + blockManagerId.port, mem) - } - } - - /** - * Return information about what RDDs are cached, if they are in mem or on disk, how much space - * they take, etc. - */ - def getRDDStorageInfo: Array[RDDInfo] = { - StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this) - } - - /** - * Returns an immutable map of RDDs that have marked themselves as persistent via cache() call. - * Note that this does not necessarily mean the caching or computation was successful. 
- */ - def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap - - def getStageInfo: Map[Stage,StageInfo] = { - dagScheduler.stageToInfos - } - - /** - * Return information about blocks stored in all of the slaves - */ - def getExecutorStorageStatus: Array[StorageStatus] = { - env.blockManager.master.getStorageStatus - } - - /** - * Return pools for fair scheduler - * TODO(xiajunluan): We should take nested pools into account - */ - def getAllPools: ArrayBuffer[Schedulable] = { - taskScheduler.rootPool.schedulableQueue - } - - /** - * Return the pool associated with the given name, if one exists - */ - def getPoolForName(pool: String): Option[Schedulable] = { - taskScheduler.rootPool.schedulableNameToSchedulable.get(pool) - } - - /** - * Return current scheduling mode - */ - def getSchedulingMode: SchedulingMode.SchedulingMode = { - taskScheduler.schedulingMode - } - - /** - * Clear the job's list of files added by `addFile` so that they do not get downloaded to - * any new nodes. - */ - def clearFiles() { - addedFiles.clear() - } - - /** - * Gets the locality information associated with the partition in a particular rdd - * @param rdd of interest - * @param partition to be looked up for locality - * @return list of preferred locations for the partition - */ - private [spark] def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = { - dagScheduler.getPreferredLocs(rdd, partition) - } - - /** - * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future. - * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), or an HTTP, HTTPS or FTP URI. - */ - def addJar(path: String) { - if (null == path) { - logWarning("null specified as parameter to addJar", - new SparkException("null specified as parameter to addJar")) - } else { - val env = SparkEnv.get - val uri = new URI(path) - val key = uri.getScheme match { - case null | "file" => - if (env.hadoop.isYarnMode()) { - logWarning("local jar specified as parameter to addJar under Yarn mode") - return - } - env.httpFileServer.addJar(new File(uri.getPath)) - case _ => path - } - addedJars(key) = System.currentTimeMillis - logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key)) - } - } - - /** - * Clear the job's list of JARs added by `addJar` so that they do not get downloaded to - * any new nodes. - */ - def clearJars() { - addedJars.clear() - } - - /** Shut down the SparkContext. */ - def stop() { - ui.stop() - // Do this only if not stopped already - best case effort. - // prevent NPE if stopped more than once. - val dagSchedulerCopy = dagScheduler - dagScheduler = null - if (dagSchedulerCopy != null) { - metadataCleaner.cancel() - dagSchedulerCopy.stop() - taskScheduler = null - // TODO: Cache.stop()? - env.stop() - // Clean up locally linked files - clearFiles() - clearJars() - SparkEnv.set(null) - ShuffleMapTask.clearCache() - ResultTask.clearCache() - logInfo("Successfully stopped SparkContext") - } else { - logInfo("SparkContext already stopped") - } - } - - - /** - * Get Spark's home location from either a value set through the constructor, - * or the spark.home Java property, or the SPARK_HOME environment variable - * (in that order of preference). If neither of these is set, return None. 
- */ - private[spark] def getSparkHome(): Option[String] = { - if (sparkHome != null) { - Some(sparkHome) - } else if (System.getProperty("spark.home") != null) { - Some(System.getProperty("spark.home")) - } else if (System.getenv("SPARK_HOME") != null) { - Some(System.getenv("SPARK_HOME")) - } else { - None - } - } - - /** - * Run a function on a given set of partitions in an RDD and pass the results to the given - * handler function. This is the main entry point for all actions in Spark. The allowLocal - * flag specifies whether the scheduler can run the computation on the driver rather than - * shipping it out to the cluster, for short actions like first(). - */ - def runJob[T, U: ClassManifest]( - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - partitions: Seq[Int], - allowLocal: Boolean, - resultHandler: (Int, U) => Unit) { - val callSite = Utils.formatSparkCallSite - logInfo("Starting job: " + callSite) - val start = System.nanoTime - val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler, localProperties.value) - logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") - rdd.doCheckpoint() - result - } - - /** - * Run a function on a given set of partitions in an RDD and return the results as an array. The - * allowLocal flag specifies whether the scheduler can run the computation on the driver rather - * than shipping it out to the cluster, for short actions like first(). - */ - def runJob[T, U: ClassManifest]( - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - partitions: Seq[Int], - allowLocal: Boolean - ): Array[U] = { - val results = new Array[U](partitions.size) - runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res) - results - } - - /** - * Run a job on a given set of partitions of an RDD, but take a function of type - * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`. - */ - def runJob[T, U: ClassManifest]( - rdd: RDD[T], - func: Iterator[T] => U, - partitions: Seq[Int], - allowLocal: Boolean - ): Array[U] = { - runJob(rdd, (context: TaskContext, iter: Iterator[T]) => func(iter), partitions, allowLocal) - } - - /** - * Run a job on all partitions in an RDD and return the results in an array. - */ - def runJob[T, U: ClassManifest](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = { - runJob(rdd, func, 0 until rdd.partitions.size, false) - } - - /** - * Run a job on all partitions in an RDD and return the results in an array. - */ - def runJob[T, U: ClassManifest](rdd: RDD[T], func: Iterator[T] => U): Array[U] = { - runJob(rdd, func, 0 until rdd.partitions.size, false) - } - - /** - * Run a job on all partitions in an RDD and pass the results to a handler function. - */ - def runJob[T, U: ClassManifest]( - rdd: RDD[T], - processPartition: (TaskContext, Iterator[T]) => U, - resultHandler: (Int, U) => Unit) - { - runJob[T, U](rdd, processPartition, 0 until rdd.partitions.size, false, resultHandler) - } - - /** - * Run a job on all partitions in an RDD and pass the results to a handler function. - */ - def runJob[T, U: ClassManifest]( - rdd: RDD[T], - processPartition: Iterator[T] => U, - resultHandler: (Int, U) => Unit) - { - val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter) - runJob[T, U](rdd, processFunc, 0 until rdd.partitions.size, false, resultHandler) - } - - /** - * Run a job that can return approximate results. 
- */ - def runApproximateJob[T, U, R]( - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - evaluator: ApproximateEvaluator[U, R], - timeout: Long): PartialResult[R] = { - val callSite = Utils.formatSparkCallSite - logInfo("Starting job: " + callSite) - val start = System.nanoTime - val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.value) - logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") - result - } - - /** - * Clean a closure to make it ready to serialized and send to tasks - * (removes unreferenced variables in $outer's, updates REPL variables) - */ - private[spark] def clean[F <: AnyRef](f: F): F = { - ClosureCleaner.clean(f) - return f - } - - /** - * Set the directory under which RDDs are going to be checkpointed. The directory must - * be a HDFS path if running on a cluster. If the directory does not exist, it will - * be created. If the directory exists and useExisting is set to true, then the - * exisiting directory will be used. Otherwise an exception will be thrown to - * prevent accidental overriding of checkpoint files in the existing directory. - */ - def setCheckpointDir(dir: String, useExisting: Boolean = false) { - val env = SparkEnv.get - val path = new Path(dir) - val fs = path.getFileSystem(env.hadoop.newConfiguration()) - if (!useExisting) { - if (fs.exists(path)) { - throw new Exception("Checkpoint directory '" + path + "' already exists.") - } else { - fs.mkdirs(path) - } - } - checkpointDir = Some(dir) - } - - /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */ - def defaultParallelism: Int = taskScheduler.defaultParallelism - - /** Default min number of partitions for Hadoop RDDs when not given by user */ - def defaultMinSplits: Int = math.min(defaultParallelism, 2) - - private val nextShuffleId = new AtomicInteger(0) - - private[spark] def newShuffleId(): Int = nextShuffleId.getAndIncrement() - - private val nextRddId = new AtomicInteger(0) - - /** Register a new RDD, returning its RDD ID */ - private[spark] def newRddId(): Int = nextRddId.getAndIncrement() - - /** Called by MetadataCleaner to clean up the persistentRdds map periodically */ - private[spark] def cleanup(cleanupTime: Long) { - persistentRdds.clearOldValues(cleanupTime) - } -} - -/** - * The SparkContext object contains a number of implicit conversions and parameters for use with - * various Spark features. - */ -object SparkContext { - val SPARK_JOB_DESCRIPTION = "spark.job.description" - - implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { - def addInPlace(t1: Double, t2: Double): Double = t1 + t2 - def zero(initialValue: Double) = 0.0 - } - - implicit object IntAccumulatorParam extends AccumulatorParam[Int] { - def addInPlace(t1: Int, t2: Int): Int = t1 + t2 - def zero(initialValue: Int) = 0 - } - - implicit object LongAccumulatorParam extends AccumulatorParam[Long] { - def addInPlace(t1: Long, t2: Long) = t1 + t2 - def zero(initialValue: Long) = 0l - } - - implicit object FloatAccumulatorParam extends AccumulatorParam[Float] { - def addInPlace(t1: Float, t2: Float) = t1 + t2 - def zero(initialValue: Float) = 0f - } - - // TODO: Add AccumulatorParams for other types, e.g. 
lists and strings - - implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) = - new PairRDDFunctions(rdd) - - implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest]( - rdd: RDD[(K, V)]) = - new SequenceFileRDDFunctions(rdd) - - implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest]( - rdd: RDD[(K, V)]) = - new OrderedRDDFunctions[K, V, (K, V)](rdd) - - implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd) - - implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = - new DoubleRDDFunctions(rdd.map(x => num.toDouble(x))) - - // Implicit conversions to common Writable types, for saveAsSequenceFile - - implicit def intToIntWritable(i: Int) = new IntWritable(i) - - implicit def longToLongWritable(l: Long) = new LongWritable(l) - - implicit def floatToFloatWritable(f: Float) = new FloatWritable(f) - - implicit def doubleToDoubleWritable(d: Double) = new DoubleWritable(d) - - implicit def boolToBoolWritable (b: Boolean) = new BooleanWritable(b) - - implicit def bytesToBytesWritable (aob: Array[Byte]) = new BytesWritable(aob) - - implicit def stringToText(s: String) = new Text(s) - - private implicit def arrayToArrayWritable[T <% Writable: ClassManifest](arr: Traversable[T]): ArrayWritable = { - def anyToWritable[U <% Writable](u: U): Writable = u - - new ArrayWritable(classManifest[T].erasure.asInstanceOf[Class[Writable]], - arr.map(x => anyToWritable(x)).toArray) - } - - // Helper objects for converting common types to Writable - private def simpleWritableConverter[T, W <: Writable: ClassManifest](convert: W => T) = { - val wClass = classManifest[W].erasure.asInstanceOf[Class[W]] - new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W])) - } - - implicit def intWritableConverter() = simpleWritableConverter[Int, IntWritable](_.get) - - implicit def longWritableConverter() = simpleWritableConverter[Long, LongWritable](_.get) - - implicit def doubleWritableConverter() = simpleWritableConverter[Double, DoubleWritable](_.get) - - implicit def floatWritableConverter() = simpleWritableConverter[Float, FloatWritable](_.get) - - implicit def booleanWritableConverter() = simpleWritableConverter[Boolean, BooleanWritable](_.get) - - implicit def bytesWritableConverter() = simpleWritableConverter[Array[Byte], BytesWritable](_.getBytes) - - implicit def stringWritableConverter() = simpleWritableConverter[String, Text](_.toString) - - implicit def writableWritableConverter[T <: Writable]() = - new WritableConverter[T](_.erasure.asInstanceOf[Class[T]], _.asInstanceOf[T]) - - /** - * Find the JAR from which a given class was loaded, to make it easy for users to pass - * their JARs to SparkContext - */ - def jarOfClass(cls: Class[_]): Seq[String] = { - val uri = cls.getResource("/" + cls.getName.replace('.', '/') + ".class") - if (uri != null) { - val uriStr = uri.toString - if (uriStr.startsWith("jar:file:")) { - // URI will be of the form "jar:file:/path/foo.jar!/package/cls.class", so pull out the /path/foo.jar - List(uriStr.substring("jar:file:".length, uriStr.indexOf('!'))) - } else { - Nil - } - } else { - Nil - } - } - - /** Find the JAR that contains the class of a particular object */ - def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass) - - /** Get the amount of memory per executor requested through system properties or SPARK_MEM */ - private[spark] val executorMemoryRequested = { - // TODO: 
Might need to add some extra memory for the non-heap parts of the JVM - Option(System.getProperty("spark.executor.memory")) - .orElse(Option(System.getenv("SPARK_MEM"))) - .map(Utils.memoryStringToMb) - .getOrElse(512) - } -} - -/** - * A class encapsulating how to convert some type T to Writable. It stores both the Writable class - * corresponding to T (e.g. IntWritable for Int) and a function for doing the conversion. - * The getter for the writable class takes a ClassManifest[T] in case this is a generic object - * that doesn't know the type of T when it is created. This sounds strange but is necessary to - * support converting subclasses of Writable to themselves (writableWritableConverter). - */ -private[spark] class WritableConverter[T]( - val writableClass: ClassManifest[T] => Class[_ <: Writable], - val convert: Writable => T) - extends Serializable - diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala deleted file mode 100644 index 1f66e9cc7f..0000000000 --- a/core/src/main/scala/spark/SparkEnv.scala +++ /dev/null @@ -1,241 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import collection.mutable -import serializer.Serializer - -import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem} -import akka.remote.RemoteActorRefProvider - -import spark.broadcast.BroadcastManager -import spark.metrics.MetricsSystem -import spark.deploy.SparkHadoopUtil -import spark.storage.BlockManager -import spark.storage.BlockManagerMaster -import spark.network.ConnectionManager -import spark.serializer.{Serializer, SerializerManager} -import spark.util.AkkaUtils -import spark.api.python.PythonWorkerFactory - - -/** - * Holds all the runtime environment objects for a running Spark instance (either master or worker), - * including the serializer, Akka actor system, block manager, map output tracker, etc. Currently - * Spark code finds the SparkEnv through a thread-local variable, so each thread that accesses these - * objects needs to have the right SparkEnv set. You can get the current environment with - * SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set. 
- */ -class SparkEnv ( - val executorId: String, - val actorSystem: ActorSystem, - val serializerManager: SerializerManager, - val serializer: Serializer, - val closureSerializer: Serializer, - val cacheManager: CacheManager, - val mapOutputTracker: MapOutputTracker, - val shuffleFetcher: ShuffleFetcher, - val broadcastManager: BroadcastManager, - val blockManager: BlockManager, - val connectionManager: ConnectionManager, - val httpFileServer: HttpFileServer, - val sparkFilesDir: String, - val metricsSystem: MetricsSystem) { - - private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() - - val hadoop = { - val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) - if(yarnMode) { - try { - Class.forName("spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil] - } catch { - case th: Throwable => throw new SparkException("Unable to load YARN support", th) - } - } else { - new SparkHadoopUtil - } - } - - def stop() { - pythonWorkers.foreach { case(key, worker) => worker.stop() } - httpFileServer.stop() - mapOutputTracker.stop() - shuffleFetcher.stop() - broadcastManager.stop() - blockManager.stop() - blockManager.master.stop() - metricsSystem.stop() - actorSystem.shutdown() - // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut - // down, but let's call it anyway in case it gets fixed in a later release - actorSystem.awaitTermination() - } - - def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = { - synchronized { - val key = (pythonExec, envVars) - pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create() - } - } -} - -object SparkEnv extends Logging { - private val env = new ThreadLocal[SparkEnv] - @volatile private var lastSetSparkEnv : SparkEnv = _ - - def set(e: SparkEnv) { - lastSetSparkEnv = e - env.set(e) - } - - /** - * Returns the ThreadLocal SparkEnv, if non-null. Else returns the SparkEnv - * previously set in any thread. - */ - def get: SparkEnv = { - Option(env.get()).getOrElse(lastSetSparkEnv) - } - - /** - * Returns the ThreadLocal SparkEnv. - */ - def getThreadLocal : SparkEnv = { - env.get() - } - - def createFromSystemProperties( - executorId: String, - hostname: String, - port: Int, - isDriver: Boolean, - isLocal: Boolean): SparkEnv = { - - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port) - - // Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port), - // figure out which port number Akka actually bound to and set spark.driver.port to it. - if (isDriver && port == 0) { - System.setProperty("spark.driver.port", boundPort.toString) - } - - // set only if unset until now. 
- if (System.getProperty("spark.hostPort", null) == null) { - if (!isDriver){ - // unexpected - Utils.logErrorWithStack("Unexpected NOT to have spark.hostPort set") - } - Utils.checkHost(hostname) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) - } - - val classLoader = Thread.currentThread.getContextClassLoader - - // Create an instance of the class named by the given Java system property, or by - // defaultClassName if the property is not set, and return it as a T - def instantiateClass[T](propertyName: String, defaultClassName: String): T = { - val name = System.getProperty(propertyName, defaultClassName) - Class.forName(name, true, classLoader).newInstance().asInstanceOf[T] - } - - val serializerManager = new SerializerManager - - val serializer = serializerManager.setDefault( - System.getProperty("spark.serializer", "spark.JavaSerializer")) - - val closureSerializer = serializerManager.get( - System.getProperty("spark.closure.serializer", "spark.JavaSerializer")) - - def registerOrLookup(name: String, newActor: => Actor): ActorRef = { - if (isDriver) { - logInfo("Registering " + name) - actorSystem.actorOf(Props(newActor), name = name) - } else { - val driverHost: String = System.getProperty("spark.driver.host", "localhost") - val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt - Utils.checkHost(driverHost, "Expected hostname") - val url = "akka://spark@%s:%s/user/%s".format(driverHost, driverPort, name) - logInfo("Connecting to " + name + ": " + url) - actorSystem.actorFor(url) - } - } - - val blockManagerMaster = new BlockManagerMaster(registerOrLookup( - "BlockManagerMaster", - new spark.storage.BlockManagerMasterActor(isLocal))) - val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer) - - val connectionManager = blockManager.connectionManager - - val broadcastManager = new BroadcastManager(isDriver) - - val cacheManager = new CacheManager(blockManager) - - // Have to assign trackerActor after initialization as MapOutputTrackerActor - // requires the MapOutputTracker itself - val mapOutputTracker = new MapOutputTracker() - mapOutputTracker.trackerActor = registerOrLookup( - "MapOutputTracker", - new MapOutputTrackerActor(mapOutputTracker)) - - val shuffleFetcher = instantiateClass[ShuffleFetcher]( - "spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher") - - val httpFileServer = new HttpFileServer() - httpFileServer.initialize() - System.setProperty("spark.fileserver.uri", httpFileServer.serverUri) - - val metricsSystem = if (isDriver) { - MetricsSystem.createMetricsSystem("driver") - } else { - MetricsSystem.createMetricsSystem("executor") - } - metricsSystem.start() - - // Set the sparkFiles directory, used when downloading dependencies. In local mode, - // this is a temporary directory; in distributed mode, this is the executor's current working - // directory. - val sparkFilesDir: String = if (isDriver) { - Utils.createTempDir().getAbsolutePath - } else { - "." - } - - // Warn about deprecated spark.cache.class property - if (System.getProperty("spark.cache.class") != null) { - logWarning("The spark.cache.class property is no longer being used! 
Specify storage " + - "levels using the RDD.persist() method instead.") - } - - new SparkEnv( - executorId, - actorSystem, - serializerManager, - serializer, - closureSerializer, - cacheManager, - mapOutputTracker, - shuffleFetcher, - broadcastManager, - blockManager, - connectionManager, - httpFileServer, - sparkFilesDir, - metricsSystem) - } -} diff --git a/core/src/main/scala/spark/SparkException.scala b/core/src/main/scala/spark/SparkException.scala deleted file mode 100644 index b7045eea63..0000000000 --- a/core/src/main/scala/spark/SparkException.scala +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -class SparkException(message: String, cause: Throwable) - extends Exception(message, cause) { - - def this(message: String) = this(message, null) -} diff --git a/core/src/main/scala/spark/SparkFiles.java b/core/src/main/scala/spark/SparkFiles.java deleted file mode 100644 index f9b3f7965e..0000000000 --- a/core/src/main/scala/spark/SparkFiles.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark; - -import java.io.File; - -/** - * Resolves paths to files added through `SparkContext.addFile()`. - */ -public class SparkFiles { - - private SparkFiles() {} - - /** - * Get the absolute path of a file added through `SparkContext.addFile()`. - */ - public static String get(String filename) { - return new File(getRootDirectory(), filename).getAbsolutePath(); - } - - /** - * Get the root directory that contains files added through `SparkContext.addFile()`. 
- */ - public static String getRootDirectory() { - return SparkEnv.get().sparkFilesDir(); - } -} diff --git a/core/src/main/scala/spark/SparkHadoopWriter.scala b/core/src/main/scala/spark/SparkHadoopWriter.scala deleted file mode 100644 index 6b330ef572..0000000000 --- a/core/src/main/scala/spark/SparkHadoopWriter.scala +++ /dev/null @@ -1,201 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.hadoop.mapred - -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.fs.Path - -import java.text.SimpleDateFormat -import java.text.NumberFormat -import java.io.IOException -import java.util.Date - -import spark.Logging -import spark.SerializableWritable - -/** - * Internal helper class that saves an RDD using a Hadoop OutputFormat. This is only public - * because we need to access this class from the `spark` package to use some package-private Hadoop - * functions, but this class should not be used directly by users. - * - * Saves the RDD using a JobConf, which should contain an output key class, an output value class, - * a filename to write to, etc, exactly like in a Hadoop MapReduce job. 
- */ -class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkHadoopMapRedUtil with Serializable { - - private val now = new Date() - private val conf = new SerializableWritable(jobConf) - - private var jobID = 0 - private var splitID = 0 - private var attemptID = 0 - private var jID: SerializableWritable[JobID] = null - private var taID: SerializableWritable[TaskAttemptID] = null - - @transient private var writer: RecordWriter[AnyRef,AnyRef] = null - @transient private var format: OutputFormat[AnyRef,AnyRef] = null - @transient private var committer: OutputCommitter = null - @transient private var jobContext: JobContext = null - @transient private var taskContext: TaskAttemptContext = null - - def preSetup() { - setIDs(0, 0, 0) - setConfParams() - - val jCtxt = getJobContext() - getOutputCommitter().setupJob(jCtxt) - } - - - def setup(jobid: Int, splitid: Int, attemptid: Int) { - setIDs(jobid, splitid, attemptid) - setConfParams() - } - - def open() { - val numfmt = NumberFormat.getInstance() - numfmt.setMinimumIntegerDigits(5) - numfmt.setGroupingUsed(false) - - val outputName = "part-" + numfmt.format(splitID) - val path = FileOutputFormat.getOutputPath(conf.value) - val fs: FileSystem = { - if (path != null) { - path.getFileSystem(conf.value) - } else { - FileSystem.get(conf.value) - } - } - - getOutputCommitter().setupTask(getTaskContext()) - writer = getOutputFormat().getRecordWriter( - fs, conf.value, outputName, Reporter.NULL) - } - - def write(key: AnyRef, value: AnyRef) { - if (writer!=null) { - //println (">>> Writing ("+key.toString+": " + key.getClass.toString + ", " + value.toString + ": " + value.getClass.toString + ")") - writer.write(key, value) - } else { - throw new IOException("Writer is null, open() has not been called") - } - } - - def close() { - writer.close(Reporter.NULL) - } - - def commit() { - val taCtxt = getTaskContext() - val cmtr = getOutputCommitter() - if (cmtr.needsTaskCommit(taCtxt)) { - try { - cmtr.commitTask(taCtxt) - logInfo (taID + ": Committed") - } catch { - case e: IOException => { - logError("Error committing the output of task: " + taID.value, e) - cmtr.abortTask(taCtxt) - throw e - } - } - } else { - logWarning ("No need to commit output of task: " + taID.value) - } - } - - def commitJob() { - // always ? Or if cmtr.needsTaskCommit ? 
- val cmtr = getOutputCommitter() - cmtr.commitJob(getJobContext()) - } - - def cleanup() { - getOutputCommitter().cleanupJob(getJobContext()) - } - - // ********* Private Functions ********* - - private def getOutputFormat(): OutputFormat[AnyRef,AnyRef] = { - if (format == null) { - format = conf.value.getOutputFormat() - .asInstanceOf[OutputFormat[AnyRef,AnyRef]] - } - return format - } - - private def getOutputCommitter(): OutputCommitter = { - if (committer == null) { - committer = conf.value.getOutputCommitter - } - return committer - } - - private def getJobContext(): JobContext = { - if (jobContext == null) { - jobContext = newJobContext(conf.value, jID.value) - } - return jobContext - } - - private def getTaskContext(): TaskAttemptContext = { - if (taskContext == null) { - taskContext = newTaskAttemptContext(conf.value, taID.value) - } - return taskContext - } - - private def setIDs(jobid: Int, splitid: Int, attemptid: Int) { - jobID = jobid - splitID = splitid - attemptID = attemptid - - jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobid)) - taID = new SerializableWritable[TaskAttemptID]( - new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) - } - - private def setConfParams() { - conf.value.set("mapred.job.id", jID.value.toString) - conf.value.set("mapred.tip.id", taID.value.getTaskID.toString) - conf.value.set("mapred.task.id", taID.value.toString) - conf.value.setBoolean("mapred.task.is.map", true) - conf.value.setInt("mapred.task.partition", splitID) - } -} - -object SparkHadoopWriter { - def createJobID(time: Date, id: Int): JobID = { - val formatter = new SimpleDateFormat("yyyyMMddHHmm") - val jobtrackerID = formatter.format(new Date()) - return new JobID(jobtrackerID, id) - } - - def createPathFromString(path: String, conf: JobConf): Path = { - if (path == null) { - throw new IllegalArgumentException("Output path is null") - } - var outputPath = new Path(path) - val fs = outputPath.getFileSystem(conf) - if (outputPath == null || fs == null) { - throw new IllegalArgumentException("Incorrectly formatted output path") - } - outputPath = outputPath.makeQualified(fs) - return outputPath - } -} diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala deleted file mode 100644 index b79f4ca813..0000000000 --- a/core/src/main/scala/spark/TaskContext.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark - -import executor.TaskMetrics -import scala.collection.mutable.ArrayBuffer - -class TaskContext( - val stageId: Int, - val splitId: Int, - val attemptId: Long, - val taskMetrics: TaskMetrics = TaskMetrics.empty() -) extends Serializable { - - @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit] - - // Add a callback function to be executed on task completion. An example use - // is for HadoopRDD to register a callback to close the input stream. - def addOnCompleteCallback(f: () => Unit) { - onCompleteCallbacks += f - } - - def executeOnCompleteCallbacks() { - onCompleteCallbacks.foreach{_()} - } -} diff --git a/core/src/main/scala/spark/TaskEndReason.scala b/core/src/main/scala/spark/TaskEndReason.scala deleted file mode 100644 index 3ad665da34..0000000000 --- a/core/src/main/scala/spark/TaskEndReason.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import spark.executor.TaskMetrics -import spark.storage.BlockManagerId - -/** - * Various possible reasons why a task ended. The low-level TaskScheduler is supposed to retry - * tasks several times for "ephemeral" failures, and only report back failures that require some - * old stages to be resubmitted, such as shuffle map fetch failures. - */ -private[spark] sealed trait TaskEndReason - -private[spark] case object Success extends TaskEndReason - -private[spark] -case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it - -private[spark] case class FetchFailed( - bmAddress: BlockManagerId, - shuffleId: Int, - mapId: Int, - reduceId: Int) - extends TaskEndReason - -private[spark] case class ExceptionFailure( - className: String, - description: String, - stackTrace: Array[StackTraceElement], - metrics: Option[TaskMetrics]) - extends TaskEndReason - -private[spark] case class OtherFailure(message: String) extends TaskEndReason - -private[spark] case class TaskResultTooBigFailure() extends TaskEndReason diff --git a/core/src/main/scala/spark/TaskState.scala b/core/src/main/scala/spark/TaskState.scala deleted file mode 100644 index bf75753056..0000000000 --- a/core/src/main/scala/spark/TaskState.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
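A small, self-contained analogue of the completion-callback pattern the TaskContext comment above describes (plain Scala, not the class itself; the byte-array stream below stands in for the Hadoop input stream the comment mentions):

    import java.io.ByteArrayInputStream
    import scala.collection.mutable.ArrayBuffer

    // Callbacks are buffered while the "task" runs and all invoked once it completes,
    // e.g. closing an input stream the task body had opened.
    val onCompleteCallbacks = new ArrayBuffer[() => Unit]
    val stream = new ByteArrayInputStream(Array[Byte](1, 2, 3))
    onCompleteCallbacks += (() => stream.close())
    val bytesRead = Iterator.continually(stream.read()).takeWhile(_ != -1).length
    onCompleteCallbacks.foreach(_())   // mirrors executeOnCompleteCallbacks()
    assert(bytesRead == 3)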
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.apache.mesos.Protos.{TaskState => MesosTaskState} - -private[spark] object TaskState - extends Enumeration("LAUNCHING", "RUNNING", "FINISHED", "FAILED", "KILLED", "LOST") { - - val LAUNCHING, RUNNING, FINISHED, FAILED, KILLED, LOST = Value - - val FINISHED_STATES = Set(FINISHED, FAILED, KILLED, LOST) - - type TaskState = Value - - def isFinished(state: TaskState) = FINISHED_STATES.contains(state) - - def toMesos(state: TaskState): MesosTaskState = state match { - case LAUNCHING => MesosTaskState.TASK_STARTING - case RUNNING => MesosTaskState.TASK_RUNNING - case FINISHED => MesosTaskState.TASK_FINISHED - case FAILED => MesosTaskState.TASK_FAILED - case KILLED => MesosTaskState.TASK_KILLED - case LOST => MesosTaskState.TASK_LOST - } - - def fromMesos(mesosState: MesosTaskState): TaskState = mesosState match { - case MesosTaskState.TASK_STAGING => LAUNCHING - case MesosTaskState.TASK_STARTING => LAUNCHING - case MesosTaskState.TASK_RUNNING => RUNNING - case MesosTaskState.TASK_FINISHED => FINISHED - case MesosTaskState.TASK_FAILED => FAILED - case MesosTaskState.TASK_KILLED => KILLED - case MesosTaskState.TASK_LOST => LOST - } -} diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala deleted file mode 100644 index bb8aad3f4c..0000000000 --- a/core/src/main/scala/spark/Utils.scala +++ /dev/null @@ -1,780 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.io._ -import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket} -import java.util.{Locale, Random, UUID} -import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor} -import java.util.regex.Pattern - -import scala.collection.Map -import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.collection.JavaConversions._ -import scala.io.Source - -import com.google.common.io.Files -import com.google.common.util.concurrent.ThreadFactoryBuilder - -import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} - -import spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} -import spark.deploy.SparkHadoopUtil -import java.nio.ByteBuffer - - -/** - * Various utility methods used by Spark. 
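A minimal, Mesos-free analogue of the TaskState enumeration just shown, only to make the terminal-state check concrete (the object name below is invented for the sketch):

    object MiniTaskState extends Enumeration {
      val LAUNCHING, RUNNING, FINISHED, FAILED, KILLED, LOST = Value
      private val finishedStates = Set(FINISHED, FAILED, KILLED, LOST)
      def isFinished(state: Value): Boolean = finishedStates.contains(state)
    }

    assert(!MiniTaskState.isFinished(MiniTaskState.RUNNING))
    assert(MiniTaskState.isFinished(MiniTaskState.KILLED))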
- */ -private object Utils extends Logging { - - /** Serialize an object using Java serialization */ - def serialize[T](o: T): Array[Byte] = { - val bos = new ByteArrayOutputStream() - val oos = new ObjectOutputStream(bos) - oos.writeObject(o) - oos.close() - return bos.toByteArray - } - - /** Deserialize an object using Java serialization */ - def deserialize[T](bytes: Array[Byte]): T = { - val bis = new ByteArrayInputStream(bytes) - val ois = new ObjectInputStream(bis) - return ois.readObject.asInstanceOf[T] - } - - /** Deserialize an object using Java serialization and the given ClassLoader */ - def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { - val bis = new ByteArrayInputStream(bytes) - val ois = new ObjectInputStream(bis) { - override def resolveClass(desc: ObjectStreamClass) = - Class.forName(desc.getName, false, loader) - } - return ois.readObject.asInstanceOf[T] - } - - /** Serialize via nested stream using specific serializer */ - def serializeViaNestedStream(os: OutputStream, ser: SerializerInstance)(f: SerializationStream => Unit) = { - val osWrapper = ser.serializeStream(new OutputStream { - def write(b: Int) = os.write(b) - - override def write(b: Array[Byte], off: Int, len: Int) = os.write(b, off, len) - }) - try { - f(osWrapper) - } finally { - osWrapper.close() - } - } - - /** Deserialize via nested stream using specific serializer */ - def deserializeViaNestedStream(is: InputStream, ser: SerializerInstance)(f: DeserializationStream => Unit) = { - val isWrapper = ser.deserializeStream(new InputStream { - def read(): Int = is.read() - - override def read(b: Array[Byte], off: Int, len: Int): Int = is.read(b, off, len) - }) - try { - f(isWrapper) - } finally { - isWrapper.close() - } - } - - /** - * Primitive often used when writing {@link java.nio.ByteBuffer} to {@link java.io.DataOutput}. - */ - def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput) = { - if (bb.hasArray) { - out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) - } else { - val bbval = new Array[Byte](bb.remaining()) - bb.get(bbval) - out.write(bbval) - } - } - - def isAlpha(c: Char): Boolean = { - (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') - } - - /** Split a string into words at non-alphabetic characters */ - def splitWords(s: String): Seq[String] = { - val buf = new ArrayBuffer[String] - var i = 0 - while (i < s.length) { - var j = i - while (j < s.length && isAlpha(s.charAt(j))) { - j += 1 - } - if (j > i) { - buf += s.substring(i, j) - } - i = j - while (i < s.length && !isAlpha(s.charAt(i))) { - i += 1 - } - } - return buf - } - - private val shutdownDeletePaths = new collection.mutable.HashSet[String]() - - // Register the path to be deleted via shutdown hook - def registerShutdownDeleteDir(file: File) { - val absolutePath = file.getAbsolutePath() - shutdownDeletePaths.synchronized { - shutdownDeletePaths += absolutePath - } - } - - // Is the path already registered to be deleted via a shutdown hook ? - def hasShutdownDeleteDir(file: File): Boolean = { - val absolutePath = file.getAbsolutePath() - shutdownDeletePaths.synchronized { - shutdownDeletePaths.contains(absolutePath) - } - } - - // Note: if file is child of some registered path, while not equal to it, then return true; - // else false. This is to ensure that two shutdown hooks do not try to delete each others - // paths - resulting in IOException and incomplete cleanup. 
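To make the parent-path rule in that comment concrete, a tiny self-contained sketch (the paths below are invented): a path already registered for deletion covers any strictly nested path, so a second shutdown hook will not try to delete it again.

    val registered = Set("/tmp/spark-abc")

    // A path counts as covered if it sits strictly inside a registered root.
    def hasRegisteredRoot(p: String): Boolean =
      registered.exists(root => p != root && p.startsWith(root))

    assert(hasRegisteredRoot("/tmp/spark-abc/work"))   // child of a registered dir
    assert(!hasRegisteredRoot("/tmp/spark-abc"))       // the registered dir itself
    assert(!hasRegisteredRoot("/tmp/other"))           // unrelated path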
- def hasRootAsShutdownDeleteDir(file: File): Boolean = { - val absolutePath = file.getAbsolutePath() - val retval = shutdownDeletePaths.synchronized { - shutdownDeletePaths.find { path => - !absolutePath.equals(path) && absolutePath.startsWith(path) - }.isDefined - } - if (retval) { - logInfo("path = " + file + ", already present as root for deletion.") - } - retval - } - - /** Create a temporary directory inside the given parent directory */ - def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = { - var attempts = 0 - val maxAttempts = 10 - var dir: File = null - while (dir == null) { - attempts += 1 - if (attempts > maxAttempts) { - throw new IOException("Failed to create a temp directory (under " + root + ") after " + - maxAttempts + " attempts!") - } - try { - dir = new File(root, "spark-" + UUID.randomUUID.toString) - if (dir.exists() || !dir.mkdirs()) { - dir = null - } - } catch { case e: IOException => ; } - } - - registerShutdownDeleteDir(dir) - - // Add a shutdown hook to delete the temp dir when the JVM exits - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) { - override def run() { - // Attempt to delete if some patch which is parent of this is not already registered. - if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir) - } - }) - dir - } - - /** Copy all data from an InputStream to an OutputStream */ - def copyStream(in: InputStream, - out: OutputStream, - closeStreams: Boolean = false) - { - val buf = new Array[Byte](8192) - var n = 0 - while (n != -1) { - n = in.read(buf) - if (n != -1) { - out.write(buf, 0, n) - } - } - if (closeStreams) { - in.close() - out.close() - } - } - - /** - * Download a file requested by the executor. Supports fetching the file in a variety of ways, - * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. - * - * Throws SparkException if the target file already exists and has different contents than - * the requested file. - */ - def fetchFile(url: String, targetDir: File) { - val filename = url.split("/").last - val tempDir = getLocalDir - val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir)) - val targetFile = new File(targetDir, filename) - val uri = new URI(url) - uri.getScheme match { - case "http" | "https" | "ftp" => - logInfo("Fetching " + url + " to " + tempFile) - val in = new URL(url).openStream() - val out = new FileOutputStream(tempFile) - Utils.copyStream(in, out, true) - if (targetFile.exists && !Files.equal(tempFile, targetFile)) { - tempFile.delete() - throw new SparkException( - "File " + targetFile + " exists and does not match contents of" + " " + url) - } else { - Files.move(tempFile, targetFile) - } - case "file" | null => - // In the case of a local file, copy the local file to the target directory. - // Note the difference between uri vs url. - val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url) - if (targetFile.exists) { - // If the target file already exists, warn the user if - if (!Files.equal(sourceFile, targetFile)) { - throw new SparkException( - "File " + targetFile + " exists and does not match contents of" + " " + url) - } else { - // Do nothing if the file contents are the same, i.e. this file has been copied - // previously. - logInfo(sourceFile.getAbsolutePath + " has been previously copied to " - + targetFile.getAbsolutePath) - } - } else { - // The file does not exist in the target directory. Copy it there. 
- logInfo("Copying " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath) - Files.copy(sourceFile, targetFile) - } - case _ => - // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others - val env = SparkEnv.get - val uri = new URI(url) - val conf = env.hadoop.newConfiguration() - val fs = FileSystem.get(uri, conf) - val in = fs.open(new Path(uri)) - val out = new FileOutputStream(tempFile) - Utils.copyStream(in, out, true) - if (targetFile.exists && !Files.equal(tempFile, targetFile)) { - tempFile.delete() - throw new SparkException("File " + targetFile + " exists and does not match contents of" + - " " + url) - } else { - Files.move(tempFile, targetFile) - } - } - // Decompress the file if it's a .tar or .tar.gz - if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) { - logInfo("Untarring " + filename) - Utils.execute(Seq("tar", "-xzf", filename), targetDir) - } else if (filename.endsWith(".tar")) { - logInfo("Untarring " + filename) - Utils.execute(Seq("tar", "-xf", filename), targetDir) - } - // Make the file executable - That's necessary for scripts - FileUtil.chmod(targetFile.getAbsolutePath, "a+x") - } - - /** - * Get a temporary directory using Spark's spark.local.dir property, if set. This will always - * return a single directory, even though the spark.local.dir property might be a list of - * multiple paths. - */ - def getLocalDir: String = { - System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0) - } - - /** - * Shuffle the elements of a collection into a random order, returning the - * result in a new collection. Unlike scala.util.Random.shuffle, this method - * uses a local random number generator, avoiding inter-thread contention. - */ - def randomize[T: ClassManifest](seq: TraversableOnce[T]): Seq[T] = { - randomizeInPlace(seq.toArray) - } - - /** - * Shuffle the elements of an array into a random order, modifying the - * original array. Returns the original array. - */ - def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = { - for (i <- (arr.length - 1) to 1 by -1) { - val j = rand.nextInt(i) - val tmp = arr(j) - arr(j) = arr(i) - arr(i) = tmp - } - arr - } - - /** - * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4). - * Note, this is typically not used from within core spark. - */ - lazy val localIpAddress: String = findLocalIpAddress() - lazy val localIpAddressHostname: String = getAddressHostName(localIpAddress) - - private def findLocalIpAddress(): String = { - val defaultIpOverride = System.getenv("SPARK_LOCAL_IP") - if (defaultIpOverride != null) { - defaultIpOverride - } else { - val address = InetAddress.getLocalHost - if (address.isLoopbackAddress) { - // Address resolves to something like 127.0.1.1, which happens on Debian; try to find - // a better address using the local network interfaces - for (ni <- NetworkInterface.getNetworkInterfaces) { - for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && - !addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) { - // We've found an address that looks reasonable! 
- logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + - " a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress + - " instead (on interface " + ni.getName + ")") - logWarning("Set SPARK_LOCAL_IP if you need to bind to another address") - return addr.getHostAddress - } - } - logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + - " a loopback address: " + address.getHostAddress + ", but we couldn't find any" + - " external IP address!") - logWarning("Set SPARK_LOCAL_IP if you need to bind to another address") - } - address.getHostAddress - } - } - - private var customHostname: Option[String] = None - - /** - * Allow setting a custom host name because when we run on Mesos we need to use the same - * hostname it reports to the master. - */ - def setCustomHostname(hostname: String) { - // DEBUG code - Utils.checkHost(hostname) - customHostname = Some(hostname) - } - - /** - * Get the local machine's hostname. - */ - def localHostName(): String = { - customHostname.getOrElse(localIpAddressHostname) - } - - def getAddressHostName(address: String): String = { - InetAddress.getByName(address).getHostName - } - - def localHostPort(): String = { - val retval = System.getProperty("spark.hostPort", null) - if (retval == null) { - logErrorWithStack("spark.hostPort not set but invoking localHostPort") - return localHostName() - } - - retval - } - - def checkHost(host: String, message: String = "") { - assert(host.indexOf(':') == -1, message) - } - - def checkHostPort(hostPort: String, message: String = "") { - assert(hostPort.indexOf(':') != -1, message) - } - - // Used by DEBUG code : remove when all testing done - def logErrorWithStack(msg: String) { - try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } } - } - - // Typically, this will be of order of number of nodes in cluster - // If not, we should change it to LRUCache or something. - private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]() - - def parseHostPort(hostPort: String): (String, Int) = { - { - // Check cache first. - var cached = hostPortParseResults.get(hostPort) - if (cached != null) return cached - } - - val indx: Int = hostPort.lastIndexOf(':') - // This is potentially broken - when dealing with ipv6 addresses for example, sigh ... - // but then hadoop does not support ipv6 right now. - // For now, we assume that if port exists, then it is valid - not check if it is an int > 0 - if (-1 == indx) { - val retval = (hostPort, 0) - hostPortParseResults.put(hostPort, retval) - return retval - } - - val retval = (hostPort.substring(0, indx).trim(), hostPort.substring(indx + 1).trim().toInt) - hostPortParseResults.putIfAbsent(hostPort, retval) - hostPortParseResults.get(hostPort) - } - - private[spark] val daemonThreadFactory: ThreadFactory = - new ThreadFactoryBuilder().setDaemon(true).build() - - /** - * Wrapper over newCachedThreadPool. - */ - def newDaemonCachedThreadPool(): ThreadPoolExecutor = - Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] - - /** - * Return the string to tell how long has passed in seconds. The passing parameter should be in - * millisecond. - */ - def getUsedTimeMs(startTimeMs: Long): String = { - return " " + (System.currentTimeMillis - startTimeMs) + " ms" - } - - /** - * Wrapper over newFixedThreadPool. 
- */ - def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = - Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] - - /** - * Delete a file or directory and its contents recursively. - */ - def deleteRecursively(file: File) { - if (file.isDirectory) { - for (child <- file.listFiles()) { - deleteRecursively(child) - } - } - if (!file.delete()) { - throw new IOException("Failed to delete: " + file) - } - } - - /** - * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. - * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM - * environment variable. - */ - def memoryStringToMb(str: String): Int = { - val lower = str.toLowerCase - if (lower.endsWith("k")) { - (lower.substring(0, lower.length-1).toLong / 1024).toInt - } else if (lower.endsWith("m")) { - lower.substring(0, lower.length-1).toInt - } else if (lower.endsWith("g")) { - lower.substring(0, lower.length-1).toInt * 1024 - } else if (lower.endsWith("t")) { - lower.substring(0, lower.length-1).toInt * 1024 * 1024 - } else {// no suffix, so it's just a number in bytes - (lower.toLong / 1024 / 1024).toInt - } - } - - /** - * Convert a quantity in bytes to a human-readable string such as "4.0 MB". - */ - def bytesToString(size: Long): String = { - val TB = 1L << 40 - val GB = 1L << 30 - val MB = 1L << 20 - val KB = 1L << 10 - - val (value, unit) = { - if (size >= 2*TB) { - (size.asInstanceOf[Double] / TB, "TB") - } else if (size >= 2*GB) { - (size.asInstanceOf[Double] / GB, "GB") - } else if (size >= 2*MB) { - (size.asInstanceOf[Double] / MB, "MB") - } else if (size >= 2*KB) { - (size.asInstanceOf[Double] / KB, "KB") - } else { - (size.asInstanceOf[Double], "B") - } - } - "%.1f %s".formatLocal(Locale.US, value, unit) - } - - /** - * Returns a human-readable string representing a duration such as "35ms" - */ - def msDurationToString(ms: Long): String = { - val second = 1000 - val minute = 60 * second - val hour = 60 * minute - - ms match { - case t if t < second => - "%d ms".format(t) - case t if t < minute => - "%.1f s".format(t.toFloat / second) - case t if t < hour => - "%.1f m".format(t.toFloat / minute) - case t => - "%.2f h".format(t.toFloat / hour) - } - } - - /** - * Convert a quantity in megabytes to a human-readable string such as "4.0 MB". - */ - def megabytesToString(megabytes: Long): String = { - bytesToString(megabytes * 1024L * 1024L) - } - - /** - * Execute a command in the given working directory, throwing an exception if it completes - * with an exit code other than 0. - */ - def execute(command: Seq[String], workingDir: File) { - val process = new ProcessBuilder(command: _*) - .directory(workingDir) - .redirectErrorStream(true) - .start() - new Thread("read stdout for " + command(0)) { - override def run() { - for (line <- Source.fromInputStream(process.getInputStream).getLines) { - System.err.println(line) - } - } - }.start() - val exitCode = process.waitFor() - if (exitCode != 0) { - throw new SparkException("Process " + command + " exited with code " + exitCode) - } - } - - /** - * Execute a command in the current working directory, throwing an exception if it completes - * with an exit code other than 0. - */ - def execute(command: Seq[String]) { - execute(command, new File(".")) - } - - /** - * Execute a command and get its output, throwing an exception if it yields a code other than 0. 
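A condensed copy of the -Xmx-style memory conversion above (the "t" and no-suffix branches omitted), re-stated only so a few expected values can be checked standalone:

    def memoryStringToMbSimple(str: String): Int = {
      val lower = str.toLowerCase
      if (lower.endsWith("k")) (lower.dropRight(1).toLong / 1024).toInt
      else if (lower.endsWith("m")) lower.dropRight(1).toInt
      else lower.dropRight(1).toInt * 1024   // "g" suffix
    }

    assert(memoryStringToMbSimple("300m") == 300)
    assert(memoryStringToMbSimple("2g") == 2048)
    assert(memoryStringToMbSimple("512k") == 0)   // rounds down to whole megabytes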
- */ - def executeAndGetOutput(command: Seq[String], workingDir: File = new File("."), - extraEnvironment: Map[String, String] = Map.empty): String = { - val builder = new ProcessBuilder(command: _*) - .directory(workingDir) - val environment = builder.environment() - for ((key, value) <- extraEnvironment) { - environment.put(key, value) - } - val process = builder.start() - new Thread("read stderr for " + command(0)) { - override def run() { - for (line <- Source.fromInputStream(process.getErrorStream).getLines) { - System.err.println(line) - } - } - }.start() - val output = new StringBuffer - val stdoutThread = new Thread("read stdout for " + command(0)) { - override def run() { - for (line <- Source.fromInputStream(process.getInputStream).getLines) { - output.append(line) - } - } - } - stdoutThread.start() - val exitCode = process.waitFor() - stdoutThread.join() // Wait for it to finish reading output - if (exitCode != 0) { - throw new SparkException("Process " + command + " exited with code " + exitCode) - } - output.toString - } - - /** - * A regular expression to match classes of the "core" Spark API that we want to skip when - * finding the call site of a method. - */ - private val SPARK_CLASS_REGEX = """^spark(\.api\.java)?(\.rdd)?\.[A-Z]""".r - - private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String, - val firstUserLine: Int, val firstUserClass: String) - - /** - * When called inside a class in the spark package, returns the name of the user code class - * (outside the spark package) that called into Spark, as well as which Spark method they called. - * This is used, for example, to tell users where in their code each RDD got created. - */ - def getCallSiteInfo: CallSiteInfo = { - val trace = Thread.currentThread.getStackTrace().filter( el => - (!el.getMethodName.contains("getStackTrace"))) - - // Keep crawling up the stack trace until we find the first function not inside of the spark - // package. We track the last (shallowest) contiguous Spark method. This might be an RDD - // transformation, a SparkContext function (such as parallelize), or anything else that leads - // to instantiation of an RDD. We also track the first (deepest) user method, file, and line. - var lastSparkMethod = "" - var firstUserFile = "" - var firstUserLine = 0 - var finished = false - var firstUserClass = "" - - for (el <- trace) { - if (!finished) { - if (SPARK_CLASS_REGEX.findFirstIn(el.getClassName) != None) { - lastSparkMethod = if (el.getMethodName == "") { - // Spark method is a constructor; get its class name - el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1) - } else { - el.getMethodName - } - } - else { - firstUserLine = el.getLineNumber - firstUserFile = el.getFileName - firstUserClass = el.getClassName - finished = true - } - } - } - new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass) - } - - def formatSparkCallSite = { - val callSiteInfo = getCallSiteInfo - "%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile, - callSiteInfo.firstUserLine) - } - - /** Return a string containing part of a file from byte 'start' to 'end'. 
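The call-site regex above is easiest to read through a few probes (the last class name below is invented): it only treats classes directly under spark, spark.api.java or spark.rdd with a capitalized simple name as "core Spark" frames.

    val sparkClass = """^spark(\.api\.java)?(\.rdd)?\.[A-Z]""".r

    assert(sparkClass.findFirstIn("spark.rdd.MappedRDD").isDefined)
    assert(sparkClass.findFirstIn("spark.SparkContext").isDefined)
    assert(sparkClass.findFirstIn("spark.util.Vector").isEmpty)   // other sub-packages do not match
    assert(sparkClass.findFirstIn("myapp.driver.Main").isEmpty)   // user code is left alone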
*/ - def offsetBytes(path: String, start: Long, end: Long): String = { - val file = new File(path) - val length = file.length() - val effectiveEnd = math.min(length, end) - val effectiveStart = math.max(0, start) - val buff = new Array[Byte]((effectiveEnd-effectiveStart).toInt) - val stream = new FileInputStream(file) - - stream.skip(effectiveStart) - stream.read(buff) - stream.close() - Source.fromBytes(buff).mkString - } - - /** - * Clone an object using a Spark serializer. - */ - def clone[T](value: T, serializer: SerializerInstance): T = { - serializer.deserialize[T](serializer.serialize(value)) - } - - /** - * Detect whether this thread might be executing a shutdown hook. Will always return true if - * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g. - * if System.exit was just called by a concurrent thread). - * - * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing - * an IllegalStateException. - */ - def inShutdown(): Boolean = { - try { - val hook = new Thread { - override def run() {} - } - Runtime.getRuntime.addShutdownHook(hook) - Runtime.getRuntime.removeShutdownHook(hook) - } catch { - case ise: IllegalStateException => return true - } - return false - } - - def isSpace(c: Char): Boolean = { - " \t\r\n".indexOf(c) != -1 - } - - /** - * Split a string of potentially quoted arguments from the command line the way that a shell - * would do it to determine arguments to a command. For example, if the string is 'a "b c" d', - * then it would be parsed as three arguments: 'a', 'b c' and 'd'. - */ - def splitCommandString(s: String): Seq[String] = { - val buf = new ArrayBuffer[String] - var inWord = false - var inSingleQuote = false - var inDoubleQuote = false - var curWord = new StringBuilder - def endWord() { - buf += curWord.toString - curWord.clear() - } - var i = 0 - while (i < s.length) { - var nextChar = s.charAt(i) - if (inDoubleQuote) { - if (nextChar == '"') { - inDoubleQuote = false - } else if (nextChar == '\\') { - if (i < s.length - 1) { - // Append the next character directly, because only " and \ may be escaped in - // double quotes after the shell's own expansion - curWord.append(s.charAt(i + 1)) - i += 1 - } - } else { - curWord.append(nextChar) - } - } else if (inSingleQuote) { - if (nextChar == '\'') { - inSingleQuote = false - } else { - curWord.append(nextChar) - } - // Backslashes are not treated specially in single quotes - } else if (nextChar == '"') { - inWord = true - inDoubleQuote = true - } else if (nextChar == '\'') { - inWord = true - inSingleQuote = true - } else if (!isSpace(nextChar)) { - curWord.append(nextChar) - inWord = true - } else if (inWord && isSpace(nextChar)) { - endWord() - inWord = false - } - i += 1 - } - if (inWord || inDoubleQuote || inSingleQuote) { - endWord() - } - return buf - } - - /* Calculates 'x' modulo 'mod', takes to consideration sign of x, - * i.e. if 'x' is negative, than 'x' % 'mod' is negative too - * so function return (x % mod) + mod in that case. 
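The sign-aware modulo described in that comment, and the quoted-argument splitting documented a little above, come down to a couple of expected values (the modulo is re-stated in plain Scala only so the checks run standalone):

    def nonNegativeModSimple(x: Int, mod: Int): Int = {
      val rawMod = x % mod
      rawMod + (if (rawMod < 0) mod else 0)
    }

    assert(nonNegativeModSimple(7, 5) == 2)
    assert(nonNegativeModSimple(-3, 5) == 2)   // plain -3 % 5 would give -2
    // Per the docstring above, splitCommandString("""a "b c" d""") is expected to
    // yield the three arguments "a", "b c" and "d".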
- */ - def nonNegativeMod(x: Int, mod: Int): Int = { - val rawMod = x % mod - rawMod + (if (rawMod < 0) mod else 0) - } -} diff --git a/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala deleted file mode 100644 index 8ce7df6213..0000000000 --- a/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala +++ /dev/null @@ -1,167 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java - -import spark.RDD -import spark.SparkContext.doubleRDDToDoubleRDDFunctions -import spark.api.java.function.{Function => JFunction} -import spark.util.StatCounter -import spark.partial.{BoundedDouble, PartialResult} -import spark.storage.StorageLevel -import java.lang.Double -import spark.Partitioner - -class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, JavaDoubleRDD] { - - override val classManifest: ClassManifest[Double] = implicitly[ClassManifest[Double]] - - override val rdd: RDD[Double] = srdd.map(x => Double.valueOf(x)) - - override def wrapRDD(rdd: RDD[Double]): JavaDoubleRDD = - new JavaDoubleRDD(rdd.map(_.doubleValue)) - - // Common RDD functions - - import JavaDoubleRDD.fromRDD - - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ - def cache(): JavaDoubleRDD = fromRDD(srdd.cache()) - - /** - * Set this RDD's storage level to persist its values across operations after the first time - * it is computed. Can only be called once on each RDD. - */ - def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel)) - - // first() has to be overriden here in order for its return type to be Double instead of Object. - override def first(): Double = srdd.first() - - // Transformations (return a new RDD) - - /** - * Return a new RDD containing the distinct elements in this RDD. - */ - def distinct(): JavaDoubleRDD = fromRDD(srdd.distinct()) - - /** - * Return a new RDD containing the distinct elements in this RDD. - */ - def distinct(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.distinct(numPartitions)) - - /** - * Return a new RDD containing only the elements that satisfy a predicate. - */ - def filter(f: JFunction[Double, java.lang.Boolean]): JavaDoubleRDD = - fromRDD(srdd.filter(x => f(x).booleanValue())) - - /** - * Return a new RDD that is reduced into `numPartitions` partitions. - */ - def coalesce(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.coalesce(numPartitions)) - - /** - * Return a new RDD that is reduced into `numPartitions` partitions. - */ - def coalesce(numPartitions: Int, shuffle: Boolean): JavaDoubleRDD = - fromRDD(srdd.coalesce(numPartitions, shuffle)) - - /** - * Return an RDD with the elements from `this` that are not in `other`. 
- * - * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. - */ - def subtract(other: JavaDoubleRDD): JavaDoubleRDD = - fromRDD(srdd.subtract(other)) - - /** - * Return an RDD with the elements from `this` that are not in `other`. - */ - def subtract(other: JavaDoubleRDD, numPartitions: Int): JavaDoubleRDD = - fromRDD(srdd.subtract(other, numPartitions)) - - /** - * Return an RDD with the elements from `this` that are not in `other`. - */ - def subtract(other: JavaDoubleRDD, p: Partitioner): JavaDoubleRDD = - fromRDD(srdd.subtract(other, p)) - - /** - * Return a sampled subset of this RDD. - */ - def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaDoubleRDD = - fromRDD(srdd.sample(withReplacement, fraction, seed)) - - /** - * Return the union of this RDD and another one. Any identical elements will appear multiple - * times (use `.distinct()` to eliminate them). - */ - def union(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.union(other.srdd)) - - // Double RDD functions - - /** Add up the elements in this RDD. */ - def sum(): Double = srdd.sum() - - /** - * Return a [[spark.util.StatCounter]] object that captures the mean, variance and count - * of the RDD's elements in one operation. - */ - def stats(): StatCounter = srdd.stats() - - /** Compute the mean of this RDD's elements. */ - def mean(): Double = srdd.mean() - - /** Compute the variance of this RDD's elements. */ - def variance(): Double = srdd.variance() - - /** Compute the standard deviation of this RDD's elements. */ - def stdev(): Double = srdd.stdev() - - /** - * Compute the sample standard deviation of this RDD's elements (which corrects for bias in - * estimating the standard deviation by dividing by N-1 instead of N). - */ - def sampleStdev(): Double = srdd.sampleStdev() - - /** - * Compute the sample variance of this RDD's elements (which corrects for bias in - * estimating the standard variance by dividing by N-1 instead of N). - */ - def sampleVariance(): Double = srdd.sampleVariance() - - /** Return the approximate mean of the elements in this RDD. */ - def meanApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] = - srdd.meanApprox(timeout, confidence) - - /** (Experimental) Approximate operation to return the mean within a timeout. */ - def meanApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.meanApprox(timeout) - - /** (Experimental) Approximate operation to return the sum within a timeout. */ - def sumApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] = - srdd.sumApprox(timeout, confidence) - - /** (Experimental) Approximate operation to return the sum within a timeout. */ - def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout) -} - -object JavaDoubleRDD { - def fromRDD(rdd: RDD[scala.Double]): JavaDoubleRDD = new JavaDoubleRDD(rdd) - - implicit def toRDD(rdd: JavaDoubleRDD): RDD[scala.Double] = rdd.srdd -} diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala deleted file mode 100644 index effe6e5e0d..0000000000 --- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala +++ /dev/null @@ -1,601 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java - -import java.util.{List => JList} -import java.util.Comparator - -import scala.Tuple2 -import scala.collection.JavaConversions._ - -import com.google.common.base.Optional -import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.OutputFormat -import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} -import org.apache.hadoop.conf.Configuration - -import spark.HashPartitioner -import spark.Partitioner -import spark.Partitioner._ -import spark.RDD -import spark.SparkContext.rddToPairRDDFunctions -import spark.api.java.function.{Function2 => JFunction2} -import spark.api.java.function.{Function => JFunction} -import spark.partial.BoundedDouble -import spark.partial.PartialResult -import spark.rdd.OrderedRDDFunctions -import spark.storage.StorageLevel - - -class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManifest[K], - implicit val vManifest: ClassManifest[V]) extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] { - - override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd) - - override val classManifest: ClassManifest[(K, V)] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]] - - import JavaPairRDD._ - - // Common RDD functions - - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ - def cache(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.cache()) - - /** - * Set this RDD's storage level to persist its values across operations after the first time - * it is computed. Can only be called once on each RDD. - */ - def persist(newLevel: StorageLevel): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.persist(newLevel)) - - // Transformations (return a new RDD) - - /** - * Return a new RDD containing the distinct elements in this RDD. - */ - def distinct(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct()) - - /** - * Return a new RDD containing the distinct elements in this RDD. - */ - def distinct(numPartitions: Int): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct(numPartitions)) - - /** - * Return a new RDD containing only the elements that satisfy a predicate. - */ - def filter(f: JFunction[(K, V), java.lang.Boolean]): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.filter(x => f(x).booleanValue())) - - /** - * Return a new RDD that is reduced into `numPartitions` partitions. - */ - def coalesce(numPartitions: Int): JavaPairRDD[K, V] = fromRDD(rdd.coalesce(numPartitions)) - - /** - * Return a new RDD that is reduced into `numPartitions` partitions. - */ - def coalesce(numPartitions: Int, shuffle: Boolean): JavaPairRDD[K, V] = - fromRDD(rdd.coalesce(numPartitions, shuffle)) - - /** - * Return a sampled subset of this RDD. 
- */ - def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed)) - - /** - * Return the union of this RDD and another one. Any identical elements will appear multiple - * times (use `.distinct()` to eliminate them). - */ - def union(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.union(other.rdd)) - - // first() has to be overridden here so that the generated method has the signature - // 'public scala.Tuple2 first()'; if the trait's definition is used, - // then the method has the signature 'public java.lang.Object first()', - // causing NoSuchMethodErrors at runtime. - override def first(): (K, V) = rdd.first() - - // Pair RDD functions - - /** - * Generic function to combine the elements for each key using a custom set of aggregation - * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C * Note that V and C can be different -- for example, one might group an - * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three - * functions: - * - * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) - * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) - * - `mergeCombiners`, to combine two C's into a single one. - * - * In addition, users can control the partitioning of the output RDD, and whether to perform - * map-side aggregation (if a mapper can produce multiple items with the same key). - */ - def combineByKey[C](createCombiner: JFunction[V, C], - mergeValue: JFunction2[C, V, C], - mergeCombiners: JFunction2[C, C, C], - partitioner: Partitioner): JavaPairRDD[K, C] = { - implicit val cm: ClassManifest[C] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]] - fromRDD(rdd.combineByKey( - createCombiner, - mergeValue, - mergeCombiners, - partitioner - )) - } - - /** - * Simplified version of combineByKey that hash-partitions the output RDD. - */ - def combineByKey[C](createCombiner: JFunction[V, C], - mergeValue: JFunction2[C, V, C], - mergeCombiners: JFunction2[C, C, C], - numPartitions: Int): JavaPairRDD[K, C] = - combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions)) - - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. - */ - def reduceByKey(partitioner: Partitioner, func: JFunction2[V, V, V]): JavaPairRDD[K, V] = - fromRDD(rdd.reduceByKey(partitioner, func)) - - /** - * Merge the values for each key using an associative reduce function, but return the results - * immediately to the master as a Map. This will also perform the merging locally on each mapper - * before sending results to a reducer, similarly to a "combiner" in MapReduce. - */ - def reduceByKeyLocally(func: JFunction2[V, V, V]): java.util.Map[K, V] = - mapAsJavaMap(rdd.reduceByKeyLocally(func)) - - /** Count the number of elements for each key, and return the result to the master as a Map. */ - def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey()) - - /** - * (Experimental) Approximate version of countByKey that can return a partial result if it does - * not finish within a timeout. 
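A hedged, collections-only illustration of the combineByKey contract spelled out above: createCombiner starts a C from a V, mergeValue folds a further V into a C, and mergeCombiners joins two partial C's (as would happen across partitions). Nothing below touches the Spark API.

    val pairs = Seq("a" -> 1, "b" -> 2, "a" -> 3)
    val createCombiner = (v: Int) => List(v)
    val mergeValue     = (c: List[Int], v: Int) => v :: c
    val mergeCombiners = (c1: List[Int], c2: List[Int]) => c1 ::: c2

    // Per key: first value through createCombiner, the rest through mergeValue.
    val combined = pairs.groupBy(_._1).map { case (k, kvs) =>
      val vs = kvs.map(_._2)
      k -> vs.tail.foldLeft(createCombiner(vs.head))(mergeValue)
    }

    assert(combined("a").sorted == List(1, 3))
    assert(combined("b") == List(2))
    assert(mergeCombiners(List(1), List(3)) == List(1, 3))   // joining two partitions' results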
- */ - def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] = - rdd.countByKeyApprox(timeout).map(mapAsJavaMap) - - /** - * (Experimental) Approximate version of countByKey that can return a partial result if it does - * not finish within a timeout. - */ - def countByKeyApprox(timeout: Long, confidence: Double = 0.95) - : PartialResult[java.util.Map[K, BoundedDouble]] = - rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap) - - /** - * Merge the values for each key using an associative function and a neutral "zero value" which may - * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for - * list concatenation, 0 for addition, or 1 for multiplication.). - */ - def foldByKey(zeroValue: V, partitioner: Partitioner, func: JFunction2[V, V, V]): JavaPairRDD[K, V] = - fromRDD(rdd.foldByKey(zeroValue, partitioner)(func)) - - /** - * Merge the values for each key using an associative function and a neutral "zero value" which may - * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for - * list concatenation, 0 for addition, or 1 for multiplication.). - */ - def foldByKey(zeroValue: V, numPartitions: Int, func: JFunction2[V, V, V]): JavaPairRDD[K, V] = - fromRDD(rdd.foldByKey(zeroValue, numPartitions)(func)) - - /** - * Merge the values for each key using an associative function and a neutral "zero value" which may - * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for - * list concatenation, 0 for addition, or 1 for multiplication.). - */ - def foldByKey(zeroValue: V, func: JFunction2[V, V, V]): JavaPairRDD[K, V] = - fromRDD(rdd.foldByKey(zeroValue)(func)) - - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. - */ - def reduceByKey(func: JFunction2[V, V, V], numPartitions: Int): JavaPairRDD[K, V] = - fromRDD(rdd.reduceByKey(func, numPartitions)) - - /** - * Group the values for each key in the RDD into a single sequence. Allows controlling the - * partitioning of the resulting key-value pair RDD by passing a Partitioner. - */ - def groupByKey(partitioner: Partitioner): JavaPairRDD[K, JList[V]] = - fromRDD(groupByResultToJava(rdd.groupByKey(partitioner))) - - /** - * Group the values for each key in the RDD into a single sequence. Hash-partitions the - * resulting RDD with into `numPartitions` partitions. - */ - def groupByKey(numPartitions: Int): JavaPairRDD[K, JList[V]] = - fromRDD(groupByResultToJava(rdd.groupByKey(numPartitions))) - - /** - * Return an RDD with the elements from `this` that are not in `other`. - * - * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. - */ - def subtract(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] = - fromRDD(rdd.subtract(other)) - - /** - * Return an RDD with the elements from `this` that are not in `other`. - */ - def subtract(other: JavaPairRDD[K, V], numPartitions: Int): JavaPairRDD[K, V] = - fromRDD(rdd.subtract(other, numPartitions)) - - /** - * Return an RDD with the elements from `this` that are not in `other`. 
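A small illustration of the "neutral zero value" requirement spelled out for foldByKey above (plain Scala): folding one key's values with 0 and + gives the same total whether the zero value is applied once or once per partition.

    val values = Seq(4, 5, 6)
    val direct       = values.foldLeft(0)(_ + _)
    val perPartition = Seq(Seq(4), Seq(5, 6)).map(_.foldLeft(0)(_ + _)).foldLeft(0)(_ + _)
    assert(direct == 15 && perPartition == 15)   // zero mixed in once vs. three times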
- */ - def subtract(other: JavaPairRDD[K, V], p: Partitioner): JavaPairRDD[K, V] = - fromRDD(rdd.subtract(other, p)) - - /** - * Return a copy of the RDD partitioned using the specified partitioner. - */ - def partitionBy(partitioner: Partitioner): JavaPairRDD[K, V] = - fromRDD(rdd.partitionBy(partitioner)) - - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. - */ - def join[W](other: JavaPairRDD[K, W], partitioner: Partitioner): JavaPairRDD[K, (V, W)] = - fromRDD(rdd.join(other, partitioner)) - - /** - * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the - * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the - * pair (k, (v, None)) if no elements in `other` have key k. Uses the given Partitioner to - * partition the output RDD. - */ - def leftOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner) - : JavaPairRDD[K, (V, Optional[W])] = { - val joinResult = rdd.leftOuterJoin(other, partitioner) - fromRDD(joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))}) - } - - /** - * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the - * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the - * pair (k, (None, w)) if no elements in `this` have key k. Uses the given Partitioner to - * partition the output RDD. - */ - def rightOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner) - : JavaPairRDD[K, (Optional[V], W)] = { - val joinResult = rdd.rightOuterJoin(other, partitioner) - fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)}) - } - - /** - * Simplified version of combineByKey that hash-partitions the resulting RDD using the existing - * partitioner/parallelism level. - */ - def combineByKey[C](createCombiner: JFunction[V, C], - mergeValue: JFunction2[C, V, C], - mergeCombiners: JFunction2[C, C, C]): JavaPairRDD[K, C] = { - implicit val cm: ClassManifest[C] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]] - fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(rdd))) - } - - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ - * parallelism level. - */ - def reduceByKey(func: JFunction2[V, V, V]): JavaPairRDD[K, V] = { - fromRDD(reduceByKey(defaultPartitioner(rdd), func)) - } - - /** - * Group the values for each key in the RDD into a single sequence. Hash-partitions the - * resulting RDD with the existing partitioner/parallelism level. - */ - def groupByKey(): JavaPairRDD[K, JList[V]] = - fromRDD(groupByResultToJava(rdd.groupByKey())) - - /** - * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each - * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and - * (k, v2) is in `other`. Performs a hash join across the cluster. - */ - def join[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, W)] = - fromRDD(rdd.join(other)) - - /** - * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. 
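A hedged, collections-only sketch of the left outer join shape described above, with scala Option standing in for the Guava Optional that the Java API returns:

    val left  = Seq("a" -> 1, "b" -> 2)
    val right = Map("a" -> "x")

    // Every left element survives; missing right matches become None.
    val joined = left.map { case (k, v) => k -> (v, right.get(k)) }
    assert(joined == Seq("a" -> (1, Some("x")), "b" -> (2, None)))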
Each - * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and - * (k, v2) is in `other`. Performs a hash join across the cluster. - */ - def join[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (V, W)] = - fromRDD(rdd.join(other, numPartitions)) - - /** - * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the - * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the - * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output - * using the existing partitioner/parallelism level. - */ - def leftOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, Optional[W])] = { - val joinResult = rdd.leftOuterJoin(other) - fromRDD(joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))}) - } - - /** - * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the - * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the - * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output - * into `numPartitions` partitions. - */ - def leftOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (V, Optional[W])] = { - val joinResult = rdd.leftOuterJoin(other, numPartitions) - fromRDD(joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))}) - } - - /** - * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the - * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the - * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting - * RDD using the existing partitioner/parallelism level. - */ - def rightOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (Optional[V], W)] = { - val joinResult = rdd.rightOuterJoin(other) - fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)}) - } - - /** - * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the - * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the - * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting - * RDD into the given number of partitions. - */ - def rightOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (Optional[V], W)] = { - val joinResult = rdd.rightOuterJoin(other, numPartitions) - fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)}) - } - - /** - * Return the key-value pairs in this RDD to the master as a Map. - */ - def collectAsMap(): java.util.Map[K, V] = mapAsJavaMap(rdd.collectAsMap()) - - /** - * Pass each value in the key-value pair RDD through a map function without changing the keys; - * this also retains the original RDD's partitioning. - */ - def mapValues[U](f: JFunction[V, U]): JavaPairRDD[K, U] = { - implicit val cm: ClassManifest[U] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] - fromRDD(rdd.mapValues(f)) - } - - /** - * Pass each value in the key-value pair RDD through a flatMap function without changing the - * keys; this also retains the original RDD's partitioning. 
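A quick, collections-only sketch of the value-side transforms above: only values change, keys are untouched, which is why the original partitioning can be kept.

    val kv = Seq("a" -> 1, "b" -> 2)
    assert(kv.map { case (k, v) => k -> (v * 10) } == Seq("a" -> 10, "b" -> 20))
    assert(kv.flatMap { case (k, v) => Seq(v, -v).map(x => k -> x) } ==
           Seq("a" -> 1, "a" -> -1, "b" -> 2, "b" -> -2))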
- */ - def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = { - import scala.collection.JavaConverters._ - def fn = (x: V) => f.apply(x).asScala - implicit val cm: ClassManifest[U] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] - fromRDD(rdd.flatMapValues(fn)) - } - - /** - * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the - * list of values for that key in `this` as well as `other`. - */ - def cogroup[W](other: JavaPairRDD[K, W], partitioner: Partitioner) - : JavaPairRDD[K, (JList[V], JList[W])] = - fromRDD(cogroupResultToJava(rdd.cogroup(other, partitioner))) - - /** - * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a - * tuple with the list of values for that key in `this`, `other1` and `other2`. - */ - def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], partitioner: Partitioner) - : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] = - fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, partitioner))) - - /** - * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the - * list of values for that key in `this` as well as `other`. - */ - def cogroup[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] = - fromRDD(cogroupResultToJava(rdd.cogroup(other))) - - /** - * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a - * tuple with the list of values for that key in `this`, `other1` and `other2`. - */ - def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2]) - : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] = - fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2))) - - /** - * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the - * list of values for that key in `this` as well as `other`. - */ - def cogroup[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (JList[V], JList[W])] - = fromRDD(cogroupResultToJava(rdd.cogroup(other, numPartitions))) - - /** - * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a - * tuple with the list of values for that key in `this`, `other1` and `other2`. - */ - def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], numPartitions: Int) - : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] = - fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numPartitions))) - - /** Alias for cogroup. */ - def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] = - fromRDD(cogroupResultToJava(rdd.groupWith(other))) - - /** Alias for cogroup. */ - def groupWith[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2]) - : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] = - fromRDD(cogroupResult2ToJava(rdd.groupWith(other1, other2))) - - /** - * Return the list of values in the RDD for key `key`. This operation is done efficiently if the - * RDD has a known partitioner by only searching the partition that the key maps to. - */ - def lookup(key: K): JList[V] = seqAsJavaList(rdd.lookup(key)) - - /** Output the RDD to any Hadoop-supported file system. */ - def saveAsHadoopFile[F <: OutputFormat[_, _]]( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[F], - conf: JobConf) { - rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, conf) - } - - /** Output the RDD to any Hadoop-supported file system. 
*/ - def saveAsHadoopFile[F <: OutputFormat[_, _]]( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[F]) { - rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass) - } - - /** Output the RDD to any Hadoop-supported file system, compressing with the supplied codec. */ - def saveAsHadoopFile[F <: OutputFormat[_, _]]( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[F], - codec: Class[_ <: CompressionCodec]) { - rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, codec) - } - - /** Output the RDD to any Hadoop-supported file system. */ - def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]]( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[F], - conf: Configuration) { - rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, conf) - } - - /** Output the RDD to any Hadoop-supported file system. */ - def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]]( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[F]) { - rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass) - } - - /** - * Output the RDD to any Hadoop-supported storage system, using a Hadoop JobConf object for - * that storage system. The JobConf should set an OutputFormat and any output paths required - * (e.g. a table name to write to) in the same way as it would be configured for a Hadoop - * MapReduce job. - */ - def saveAsHadoopDataset(conf: JobConf) { - rdd.saveAsHadoopDataset(conf) - } - - /** - * Sort the RDD by key, so that each partition contains a sorted range of the elements in - * ascending order. Calling `collect` or `save` on the resulting RDD will return or output an - * ordered list of records (in the `save` case, they will be written to multiple `part-X` files - * in the filesystem, in order of the keys). - */ - def sortByKey(): JavaPairRDD[K, V] = sortByKey(true) - - /** - * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling - * `collect` or `save` on the resulting RDD will return or output an ordered list of records - * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in - * order of the keys). - */ - def sortByKey(ascending: Boolean): JavaPairRDD[K, V] = { - val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]] - sortByKey(comp, ascending) - } - - /** - * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling - * `collect` or `save` on the resulting RDD will return or output an ordered list of records - * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in - * order of the keys). - */ - def sortByKey(comp: Comparator[K]): JavaPairRDD[K, V] = sortByKey(comp, true) - - /** - * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling - * `collect` or `save` on the resulting RDD will return or output an ordered list of records - * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in - * order of the keys). 
- */
-  def sortByKey(comp: Comparator[K], ascending: Boolean): JavaPairRDD[K, V] = {
-    class KeyOrdering(val a: K) extends Ordered[K] {
-      override def compare(b: K) = comp.compare(a, b)
-    }
-    implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x)
-    fromRDD(new OrderedRDDFunctions[K, V, (K, V)](rdd).sortByKey(ascending))
-  }
-
-  /**
-   * Return an RDD with the keys of each tuple.
-   */
-  def keys(): JavaRDD[K] = JavaRDD.fromRDD[K](rdd.map(_._1))
-
-  /**
-   * Return an RDD with the values of each tuple.
-   */
-  def values(): JavaRDD[V] = JavaRDD.fromRDD[V](rdd.map(_._2))
-}
-
-object JavaPairRDD {
-  def groupByResultToJava[K, T](rdd: RDD[(K, Seq[T])])(implicit kcm: ClassManifest[K],
-    vcm: ClassManifest[T]): RDD[(K, JList[T])] =
-    rddToPairRDDFunctions(rdd).mapValues(seqAsJavaList _)
-
-  def cogroupResultToJava[W, K, V](rdd: RDD[(K, (Seq[V], Seq[W]))])(implicit kcm: ClassManifest[K],
-    vcm: ClassManifest[V]): RDD[(K, (JList[V], JList[W]))] = rddToPairRDDFunctions(rdd).mapValues((x: (Seq[V],
-    Seq[W])) => (seqAsJavaList(x._1), seqAsJavaList(x._2)))
-
-  def cogroupResult2ToJava[W1, W2, K, V](rdd: RDD[(K, (Seq[V], Seq[W1],
-    Seq[W2]))])(implicit kcm: ClassManifest[K]) : RDD[(K, (JList[V], JList[W1],
-    JList[W2]))] = rddToPairRDDFunctions(rdd).mapValues(
-    (x: (Seq[V], Seq[W1], Seq[W2])) => (seqAsJavaList(x._1),
-      seqAsJavaList(x._2),
-      seqAsJavaList(x._3)))
-
-  def fromRDD[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]): JavaPairRDD[K, V] =
-    new JavaPairRDD[K, V](rdd)
-
-  implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd
-}
diff --git a/core/src/main/scala/spark/api/java/JavaRDD.scala b/core/src/main/scala/spark/api/java/JavaRDD.scala
deleted file mode 100644
index c0bf2cf568..0000000000
--- a/core/src/main/scala/spark/api/java/JavaRDD.scala
+++ /dev/null
@@ -1,114 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package spark.api.java
-
-import spark._
-import spark.api.java.function.{Function => JFunction}
-import spark.storage.StorageLevel
-
-class JavaRDD[T](val rdd: RDD[T])(implicit val classManifest: ClassManifest[T]) extends
-JavaRDDLike[T, JavaRDD[T]] {
-
-  override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd)
-
-  // Common RDD functions
-
-  /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
-  def cache(): JavaRDD[T] = wrapRDD(rdd.cache())
-
-  /**
-   * Set this RDD's storage level to persist its values across operations after the first time
-   * it is computed. This can only be used to assign a new storage level if the RDD does not
-   * have a storage level set yet.
-   */
-  def persist(newLevel: StorageLevel): JavaRDD[T] = wrapRDD(rdd.persist(newLevel))
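(A caching sketch for the persist contract documented above, assuming a spark-shell session; StorageLevel is the Scala-side counterpart of the StorageLevels constants that appear later in this patch, and the data is illustrative.)

import org.apache.spark.storage.StorageLevel

val data = sc.parallelize(1 to 1000)
data.persist(StorageLevel.MEMORY_AND_DISK)  // legal only while no level is set yet
println(data.count())   // the first action computes and caches the partitions
println(data.count())   // later actions reuse the cached copy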
-
-  /**
-   * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
-   */
-  def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist())
-
-  // Transformations (return a new RDD)
-
-  /**
-   * Return a new RDD containing the distinct elements in this RDD.
-   */
-  def distinct(): JavaRDD[T] = wrapRDD(rdd.distinct())
-
-  /**
-   * Return a new RDD containing the distinct elements in this RDD.
-   */
-  def distinct(numPartitions: Int): JavaRDD[T] = wrapRDD(rdd.distinct(numPartitions))
-
-  /**
-   * Return a new RDD containing only the elements that satisfy a predicate.
-   */
-  def filter(f: JFunction[T, java.lang.Boolean]): JavaRDD[T] =
-    wrapRDD(rdd.filter((x => f(x).booleanValue())))
-
-  /**
-   * Return a new RDD that is reduced into `numPartitions` partitions.
-   */
-  def coalesce(numPartitions: Int): JavaRDD[T] = rdd.coalesce(numPartitions)
-
-  /**
-   * Return a new RDD that is reduced into `numPartitions` partitions.
-   */
-  def coalesce(numPartitions: Int, shuffle: Boolean): JavaRDD[T] =
-    rdd.coalesce(numPartitions, shuffle)
-
-  /**
-   * Return a sampled subset of this RDD.
-   */
-  def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
-    wrapRDD(rdd.sample(withReplacement, fraction, seed))
-
-  /**
-   * Return the union of this RDD and another one. Any identical elements will appear multiple
-   * times (use `.distinct()` to eliminate them).
-   */
-  def union(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.union(other.rdd))
-
-  /**
-   * Return an RDD with the elements from `this` that are not in `other`.
-   *
-   * Uses `this` RDD's partitioner and number of partitions, because even if `other` is huge,
-   * the resulting RDD will be no larger than `this`.
-   */
-  def subtract(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.subtract(other))
-
-  /**
-   * Return an RDD with the elements from `this` that are not in `other`.
-   */
-  def subtract(other: JavaRDD[T], numPartitions: Int): JavaRDD[T] =
-    wrapRDD(rdd.subtract(other, numPartitions))
-
-  /**
-   * Return an RDD with the elements from `this` that are not in `other`.
-   */
-  def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] =
-    wrapRDD(rdd.subtract(other, p))
-}
-
-object JavaRDD {
-
-  implicit def fromRDD[T: ClassManifest](rdd: RDD[T]): JavaRDD[T] = new JavaRDD[T](rdd)
-
-  implicit def toRDD[T](rdd: JavaRDD[T]): RDD[T] = rdd.rdd
-}
-
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
deleted file mode 100644
index 2c2b138f16..0000000000
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ /dev/null
@@ -1,426 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -package spark.api.java - -import java.util.{List => JList, Comparator} -import scala.Tuple2 -import scala.collection.JavaConversions._ - -import org.apache.hadoop.io.compress.CompressionCodec -import spark.{SparkContext, Partition, RDD, TaskContext} -import spark.api.java.JavaPairRDD._ -import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _} -import spark.partial.{PartialResult, BoundedDouble} -import spark.storage.StorageLevel -import com.google.common.base.Optional - - -trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { - def wrapRDD(rdd: RDD[T]): This - - implicit val classManifest: ClassManifest[T] - - def rdd: RDD[T] - - /** Set of partitions in this RDD. */ - def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) - - /** The [[spark.SparkContext]] that this RDD was created on. */ - def context: SparkContext = rdd.context - - /** A unique ID for this RDD (within its SparkContext). */ - def id: Int = rdd.id - - /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ - def getStorageLevel: StorageLevel = rdd.getStorageLevel - - /** - * Internal method to this RDD; will read from cache if applicable, or otherwise compute it. - * This should ''not'' be called by users directly, but is available for implementors of custom - * subclasses of RDD. - */ - def iterator(split: Partition, taskContext: TaskContext): java.util.Iterator[T] = - asJavaIterator(rdd.iterator(split, taskContext)) - - // Transformations (return a new RDD) - - /** - * Return a new RDD by applying a function to all elements of this RDD. - */ - def map[R](f: JFunction[T, R]): JavaRDD[R] = - new JavaRDD(rdd.map(f)(f.returnType()))(f.returnType()) - - /** - * Return a new RDD by applying a function to all elements of this RDD. - */ - def map[R](f: DoubleFunction[T]): JavaDoubleRDD = - new JavaDoubleRDD(rdd.map(x => f(x).doubleValue())) - - /** - * Return a new RDD by applying a function to all elements of this RDD. - */ - def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { - def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]] - new JavaPairRDD(rdd.map(f)(cm))(f.keyType(), f.valueType()) - } - - /** - * Return a new RDD by first applying a function to all elements of this - * RDD, and then flattening the results. - */ - def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = { - import scala.collection.JavaConverters._ - def fn = (x: T) => f.apply(x).asScala - JavaRDD.fromRDD(rdd.flatMap(fn)(f.elementType()))(f.elementType()) - } - - /** - * Return a new RDD by first applying a function to all elements of this - * RDD, and then flattening the results. - */ - def flatMap(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = { - import scala.collection.JavaConverters._ - def fn = (x: T) => f.apply(x).asScala - new JavaDoubleRDD(rdd.flatMap(fn).map((x: java.lang.Double) => x.doubleValue())) - } - - /** - * Return a new RDD by first applying a function to all elements of this - * RDD, and then flattening the results. - */ - def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { - import scala.collection.JavaConverters._ - def fn = (x: T) => f.apply(x).asScala - def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]] - JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType()) - } - - /** - * Return a new RDD by applying a function to each partition of this RDD. 
- */ - def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) - JavaRDD.fromRDD(rdd.mapPartitions(fn)(f.elementType()))(f.elementType()) - } - - /** - * Return a new RDD by applying a function to each partition of this RDD. - */ - def mapPartitions(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = { - def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) - new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: java.lang.Double) => x.doubleValue())) - } - - /** - * Return a new RDD by applying a function to each partition of this RDD. - */ - def mapPartitions[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]): - JavaPairRDD[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) - JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(f.keyType(), f.valueType()) - } - - /** - * Return an RDD created by coalescing all elements within each partition into an array. - */ - def glom(): JavaRDD[JList[T]] = - new JavaRDD(rdd.glom().map(x => new java.util.ArrayList[T](x.toSeq))) - - /** - * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of - * elements (a, b) where a is in `this` and b is in `other`. - */ - def cartesian[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] = - JavaPairRDD.fromRDD(rdd.cartesian(other.rdd)(other.classManifest))(classManifest, - other.classManifest) - - /** - * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements - * mapping to that key. - */ - def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JList[T]] = { - implicit val kcm: ClassManifest[K] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] - implicit val vcm: ClassManifest[JList[T]] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]] - JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(f.returnType)))(kcm, vcm) - } - - /** - * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements - * mapping to that key. - */ - def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JList[T]] = { - implicit val kcm: ClassManifest[K] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] - implicit val vcm: ClassManifest[JList[T]] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]] - JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(f.returnType)))(kcm, vcm) - } - - /** - * Return an RDD created by piping elements to a forked external process. - */ - def pipe(command: String): JavaRDD[String] = rdd.pipe(command) - - /** - * Return an RDD created by piping elements to a forked external process. - */ - def pipe(command: JList[String]): JavaRDD[String] = - rdd.pipe(asScalaBuffer(command)) - - /** - * Return an RDD created by piping elements to a forked external process. - */ - def pipe(command: JList[String], env: java.util.Map[String, String]): JavaRDD[String] = - rdd.pipe(asScalaBuffer(command), mapAsScalaMap(env)) - - /** - * Zips this RDD with another one, returning key-value pairs with the first element in each RDD, - * second element in each RDD, etc. Assumes that the two RDDs have the *same number of - * partitions* and the *same number of elements in each partition* (e.g. one was made through - * a map on the other). 
- */
-  def zip[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] = {
-    JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classManifest))(classManifest, other.classManifest)
-  }
-
-  /**
-   * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
-   * applying a function to the zipped partitions. Assumes that all the RDDs have the
-   * *same number of partitions*, but does *not* require them to have the same number
-   * of elements in each partition.
-   */
-  def zipPartitions[U, V](
-    other: JavaRDDLike[U, _],
-    f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V]): JavaRDD[V] = {
-    def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator(
-      f.apply(asJavaIterator(x), asJavaIterator(y)).iterator())
-    JavaRDD.fromRDD(
-      rdd.zipPartitions(other.rdd)(fn)(other.classManifest, f.elementType()))(f.elementType())
-  }
-
-  // Actions (launch a job to return a value to the user program)
-
-  /**
-   * Applies a function f to all elements of this RDD.
-   */
-  def foreach(f: VoidFunction[T]) {
-    val cleanF = rdd.context.clean(f)
-    rdd.foreach(cleanF)
-  }
-
-  /**
-   * Return an array that contains all of the elements in this RDD.
-   */
-  def collect(): JList[T] = {
-    import scala.collection.JavaConversions._
-    val arr: java.util.Collection[T] = rdd.collect().toSeq
-    new java.util.ArrayList(arr)
-  }
-
-  /**
-   * Reduces the elements of this RDD using the specified commutative and associative binary operator.
-   */
-  def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f)
-
-  /**
-   * Aggregate the elements of each partition, and then the results for all the partitions, using a
-   * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
-   * modify t1 and return it as its result value to avoid object allocation; however, it should not
-   * modify t2.
-   */
-  def fold(zeroValue: T)(f: JFunction2[T, T, T]): T =
-    rdd.fold(zeroValue)(f)
-
-  /**
-   * Aggregate the elements of each partition, and then the results for all the partitions, using
-   * given combine functions and a neutral "zero value". This function can return a different result
-   * type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into a U
-   * and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are
-   * allowed to modify and return their first argument instead of creating a new U to avoid memory
-   * allocation.
-   */
-  def aggregate[U](zeroValue: U)(seqOp: JFunction2[U, T, U],
-    combOp: JFunction2[U, U, U]): U =
-    rdd.aggregate(zeroValue)(seqOp, combOp)(seqOp.returnType)
-
-  /**
-   * Return the number of elements in the RDD.
-   */
-  def count(): Long = rdd.count()
-
-  /**
-   * (Experimental) Approximate version of count() that returns a potentially incomplete result
-   * within a timeout, even if not all tasks have finished.
-   */
-  def countApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] =
-    rdd.countApprox(timeout, confidence)
-
-  /**
-   * (Experimental) Approximate version of count() that returns a potentially incomplete result
-   * within a timeout, even if not all tasks have finished.
-   */
-  def countApprox(timeout: Long): PartialResult[BoundedDouble] =
-    rdd.countApprox(timeout)
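(A sketch of the fold/aggregate contract documented above, spark-shell assumed; the values are illustrative.)

val nums = sc.parallelize(Seq(1, 2, 3, 4))
// fold keeps the element type; the zero value must be neutral for the operator
println(nums.fold(0)(_ + _))                 // 10
// aggregate may change the result type: here (sum, count) is computed in one pass
val (sum, count) = nums.aggregate((0, 0))(
  (acc, n) => (acc._1 + n, acc._2 + 1),      // merge one element into the accumulator
  (a, b) => (a._1 + b._1, a._2 + b._2))      // merge two partition accumulators
println(sum.toDouble / count)                // 2.5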
-
-  /**
-   * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final
-   * combine step happens locally on the master, equivalent to running a single reduce task.
-   */
-  def countByValue(): java.util.Map[T, java.lang.Long] =
-    mapAsJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2)))))
-
-  /**
-   * (Experimental) Approximate version of countByValue().
-   */
-  def countByValueApprox(
-    timeout: Long,
-    confidence: Double
-    ): PartialResult[java.util.Map[T, BoundedDouble]] =
-    rdd.countByValueApprox(timeout, confidence).map(mapAsJavaMap)
-
-  /**
-   * (Experimental) Approximate version of countByValue().
-   */
-  def countByValueApprox(timeout: Long): PartialResult[java.util.Map[T, BoundedDouble]] =
-    rdd.countByValueApprox(timeout).map(mapAsJavaMap)
-
-  /**
-   * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
-   * it will be slow if a lot of partitions are required. In that case, use collect() to get the
-   * whole RDD instead.
-   */
-  def take(num: Int): JList[T] = {
-    import scala.collection.JavaConversions._
-    val arr: java.util.Collection[T] = rdd.take(num).toSeq
-    new java.util.ArrayList(arr)
-  }
-
-  def takeSample(withReplacement: Boolean, num: Int, seed: Int): JList[T] = {
-    import scala.collection.JavaConversions._
-    val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq
-    new java.util.ArrayList(arr)
-  }
-
-  /**
-   * Return the first element in this RDD.
-   */
-  def first(): T = rdd.first()
-
-  /**
-   * Save this RDD as a text file, using string representations of elements.
-   */
-  def saveAsTextFile(path: String) = rdd.saveAsTextFile(path)
-
-  /**
-   * Save this RDD as a compressed text file, using string representations of elements.
-   */
-  def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) =
-    rdd.saveAsTextFile(path, codec)
-
-  /**
-   * Save this RDD as a SequenceFile of serialized objects.
-   */
-  def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path)
-
-  /**
-   * Creates tuples of the elements in this RDD by applying `f`.
-   */
-  def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = {
-    implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
-    JavaPairRDD.fromRDD(rdd.keyBy(f))
-  }
-
-  /**
-   * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
-   * directory set with SparkContext.setCheckpointDir() and all references to its parent
-   * RDDs will be removed. This function must be called before any job has been
-   * executed on this RDD. It is strongly recommended that this RDD is persisted in
-   * memory, otherwise saving it on a file will require recomputation.
-   */
-  def checkpoint() = rdd.checkpoint()
-
-  /**
-   * Return whether this RDD has been checkpointed or not.
-   */
-  def isCheckpointed: Boolean = rdd.isCheckpointed
-
-  /**
-   * Gets the name of the file to which this RDD was checkpointed.
-   */
-  def getCheckpointFile(): Optional[String] = {
-    JavaUtils.optionToOptional(rdd.getCheckpointFile)
-  }
-
-  /** A description of this RDD and its recursive dependencies for debugging. */
-  def toDebugString(): String = {
-    rdd.toDebugString
-  }
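(A sketch of the counting and retrieval actions above, spark-shell assumed; the data is illustrative.)

val letters = sc.parallelize(Seq("a", "b", "a", "c", "a"))
// countByValue merges per-partition counts locally on the driver
println(letters.countByValue())              // Map(a -> 3, b -> 1, c -> 1)
// take scans partitions one by one, so it is cheap for a small num
println(letters.take(2).toSeq)               // the first two elements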
-
-  /**
-   * Returns the top K elements from this RDD as defined by
-   * the specified Comparator[T].
-   * @param num the number of top elements to return
-   * @param comp the comparator that defines the order
-   * @return an array of top elements
-   */
-  def top(num: Int, comp: Comparator[T]): JList[T] = {
-    import scala.collection.JavaConversions._
-    val topElems = rdd.top(num)(Ordering.comparatorToOrdering(comp))
-    val arr: java.util.Collection[T] = topElems.toSeq
-    new java.util.ArrayList(arr)
-  }
-
-  /**
-   * Returns the top K elements from this RDD using the
-   * natural ordering for T.
-   * @param num the number of top elements to return
-   * @return an array of top elements
-   */
-  def top(num: Int): JList[T] = {
-    val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]]
-    top(num, comp)
-  }
-
-  /**
-   * Returns the first K elements from this RDD as defined by
-   * the specified Comparator[T] and maintains the order.
-   * @param num the number of top elements to return
-   * @param comp the comparator that defines the order
-   * @return an array of top elements
-   */
-  def takeOrdered(num: Int, comp: Comparator[T]): JList[T] = {
-    import scala.collection.JavaConversions._
-    val topElems = rdd.takeOrdered(num)(Ordering.comparatorToOrdering(comp))
-    val arr: java.util.Collection[T] = topElems.toSeq
-    new java.util.ArrayList(arr)
-  }
-
-  /**
-   * Returns the first K elements from this RDD using the
-   * natural ordering for T while maintaining the order.
-   * @param num the number of top elements to return
-   * @return an array of top elements
-   */
-  def takeOrdered(num: Int): JList[T] = {
-    val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]]
-    takeOrdered(num, comp)
-  }
-}
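(The top/takeOrdered pair above in sketch form, spark-shell assumed; note that top returns the largest elements first while takeOrdered returns the smallest in ascending order.)

val nums = sc.parallelize(Seq(5, 1, 4, 2, 3))
println(nums.top(2).toSeq)          // 5, 4 - the largest elements, descending
println(nums.takeOrdered(2).toSeq)  // 1, 2 - the smallest, in sorted order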
diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
deleted file mode 100644
index 29d57004b5..0000000000
--- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala
+++ /dev/null
@@ -1,418 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package spark.api.java
-
-import java.util.{Map => JMap}
-
-import scala.collection.JavaConversions
-import scala.collection.JavaConversions._
-
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.mapred.InputFormat
-import org.apache.hadoop.mapred.JobConf
-import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
-
-import spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, RDD, SparkContext}
-import spark.SparkContext.IntAccumulatorParam
-import spark.SparkContext.DoubleAccumulatorParam
-import spark.broadcast.Broadcast
-
-import com.google.common.base.Optional
-
-/**
- * A Java-friendly version of [[spark.SparkContext]] that returns [[spark.api.java.JavaRDD]]s and
- * works with Java collections instead of Scala ones.
- */
-class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround {
-
-  /**
-   * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
-   * @param appName A name for your application, to display on the cluster web UI
-   */
-  def this(master: String, appName: String) = this(new SparkContext(master, appName))
-
-  /**
-   * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
-   * @param appName A name for your application, to display on the cluster web UI
-   * @param sparkHome The SPARK_HOME directory on the slave nodes
-   * @param jarFile JAR file to send to the cluster. This can be a path on the local file system
-   *                or an HDFS, HTTP, HTTPS, or FTP URL.
-   */
-  def this(master: String, appName: String, sparkHome: String, jarFile: String) =
-    this(new SparkContext(master, appName, sparkHome, Seq(jarFile)))
-
-  /**
-   * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
-   * @param appName A name for your application, to display on the cluster web UI
-   * @param sparkHome The SPARK_HOME directory on the slave nodes
-   * @param jars Collection of JARs to send to the cluster. These can be paths on the local file
-   *             system or HDFS, HTTP, HTTPS, or FTP URLs.
-   */
-  def this(master: String, appName: String, sparkHome: String, jars: Array[String]) =
-    this(new SparkContext(master, appName, sparkHome, jars.toSeq))
-
-  /**
-   * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
-   * @param appName A name for your application, to display on the cluster web UI
-   * @param sparkHome The SPARK_HOME directory on the slave nodes
-   * @param jars Collection of JARs to send to the cluster. These can be paths on the local file
-   *             system or HDFS, HTTP, HTTPS, or FTP URLs.
-   * @param environment Environment variables to set on worker nodes
-   */
-  def this(master: String, appName: String, sparkHome: String, jars: Array[String],
-    environment: JMap[String, String]) =
-    this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment))
-
-  private[spark] val env = sc.env
-
-  /** Distribute a local Java collection to form an RDD. */
-  def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = {
-    implicit val cm: ClassManifest[T] =
-      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
-    sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices)
-  }
-
-  /** Distribute a local Java collection to form an RDD. */
-  def parallelize[T](list: java.util.List[T]): JavaRDD[T] =
-    parallelize(list, sc.defaultParallelism)
-
-  /** Distribute a local Java collection to form an RDD. */
-  def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]], numSlices: Int)
-  : JavaPairRDD[K, V] = {
-    implicit val kcm: ClassManifest[K] =
-      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
-    implicit val vcm: ClassManifest[V] =
-      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
-    JavaPairRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices))
-  }
-
-  /** Distribute a local Java collection to form an RDD. */
-  def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]]): JavaPairRDD[K, V] =
-    parallelizePairs(list, sc.defaultParallelism)
-
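(A sketch of constructing the Java-friendly context and parallelizing a Java collection, written in Scala for consistency with the other sketches here; the package name assumes this patch's rename, and the snippet is meant to run standalone rather than inside a shell that already owns a context.)

import java.util.Arrays
import org.apache.spark.api.java.JavaSparkContext

val jsc = new JavaSparkContext("local", "JavaApiSketch")
// parallelize accepts a java.util.List directly
val distData = jsc.parallelize(Arrays.asList[java.lang.Integer](1, 2, 3, 4))
println(distData.count())   // 4
jsc.stop()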
-  /** Distribute a local Java collection to form an RDD. */
-  def parallelizeDoubles(list: java.util.List[java.lang.Double], numSlices: Int): JavaDoubleRDD =
-    JavaDoubleRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list).map(_.doubleValue()),
-      numSlices))
-
-  /** Distribute a local Java collection to form an RDD. */
-  def parallelizeDoubles(list: java.util.List[java.lang.Double]): JavaDoubleRDD =
-    parallelizeDoubles(list, sc.defaultParallelism)
-
-  /**
-   * Read a text file from HDFS, a local file system (available on all nodes), or any
-   * Hadoop-supported file system URI, and return it as an RDD of Strings.
-   */
-  def textFile(path: String): JavaRDD[String] = sc.textFile(path)
-
-  /**
-   * Read a text file from HDFS, a local file system (available on all nodes), or any
-   * Hadoop-supported file system URI, and return it as an RDD of Strings.
-   */
-  def textFile(path: String, minSplits: Int): JavaRDD[String] = sc.textFile(path, minSplits)
-
-  /** Get an RDD for a Hadoop SequenceFile with given key and value types. */
-  def sequenceFile[K, V](path: String,
-    keyClass: Class[K],
-    valueClass: Class[V],
-    minSplits: Int
-    ): JavaPairRDD[K, V] = {
-    implicit val kcm = ClassManifest.fromClass(keyClass)
-    implicit val vcm = ClassManifest.fromClass(valueClass)
-    new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass, minSplits))
-  }
-
-  /** Get an RDD for a Hadoop SequenceFile. */
-  def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]):
-  JavaPairRDD[K, V] = {
-    implicit val kcm = ClassManifest.fromClass(keyClass)
-    implicit val vcm = ClassManifest.fromClass(valueClass)
-    new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass))
-  }
-
-  /**
-   * Load an RDD saved as a SequenceFile containing serialized objects, with NullWritable keys and
-   * BytesWritable values that contain a serialized partition. This is still an experimental storage
-   * format and may not be supported exactly as is in future Spark releases. It will also be pretty
-   * slow if you use the default serializer (Java serialization), though the nice thing about it is
-   * that there's very little effort required to save arbitrary objects.
-   */
-  def objectFile[T](path: String, minSplits: Int): JavaRDD[T] = {
-    implicit val cm: ClassManifest[T] =
-      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
-    sc.objectFile(path, minSplits)(cm)
-  }
-
-  /**
-   * Load an RDD saved as a SequenceFile containing serialized objects, with NullWritable keys and
-   * BytesWritable values that contain a serialized partition. This is still an experimental storage
-   * format and may not be supported exactly as is in future Spark releases. It will also be pretty
-   * slow if you use the default serializer (Java serialization), though the nice thing about it is
-   * that there's very little effort required to save arbitrary objects.
-   */
-  def objectFile[T](path: String): JavaRDD[T] = {
-    implicit val cm: ClassManifest[T] =
-      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
-    sc.objectFile(path)(cm)
-  }
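(The text-file and object-file round trips above as a sketch; the paths are illustrative assumptions, spark-shell assumed.)

val lines = sc.textFile("hdfs://namenode:9000/data/input.txt")
println(lines.count())
// saveAsObjectFile/objectFile round-trip an RDD through serialized SequenceFiles
sc.parallelize(1 to 100).saveAsObjectFile("/tmp/ints")
println(sc.objectFile[Int]("/tmp/ints").count())   // 100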
-
-  /**
-   * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf giving its InputFormat and any
-   * other necessary info (e.g. file name for a filesystem-based dataset, table name for Hypertable,
-   * etc).
-   */
-  def hadoopRDD[K, V, F <: InputFormat[K, V]](
-    conf: JobConf,
-    inputFormatClass: Class[F],
-    keyClass: Class[K],
-    valueClass: Class[V],
-    minSplits: Int
-    ): JavaPairRDD[K, V] = {
-    implicit val kcm = ClassManifest.fromClass(keyClass)
-    implicit val vcm = ClassManifest.fromClass(valueClass)
-    new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minSplits))
-  }
-
-  /**
-   * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf giving its InputFormat and any
-   * other necessary info (e.g. file name for a filesystem-based dataset, table name for Hypertable,
-   * etc).
-   */
-  def hadoopRDD[K, V, F <: InputFormat[K, V]](
-    conf: JobConf,
-    inputFormatClass: Class[F],
-    keyClass: Class[K],
-    valueClass: Class[V]
-    ): JavaPairRDD[K, V] = {
-    implicit val kcm = ClassManifest.fromClass(keyClass)
-    implicit val vcm = ClassManifest.fromClass(valueClass)
-    new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass))
-  }
-
-  /** Get an RDD for a Hadoop file with an arbitrary InputFormat */
-  def hadoopFile[K, V, F <: InputFormat[K, V]](
-    path: String,
-    inputFormatClass: Class[F],
-    keyClass: Class[K],
-    valueClass: Class[V],
-    minSplits: Int
-    ): JavaPairRDD[K, V] = {
-    implicit val kcm = ClassManifest.fromClass(keyClass)
-    implicit val vcm = ClassManifest.fromClass(valueClass)
-    new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits))
-  }
-
-  /** Get an RDD for a Hadoop file with an arbitrary InputFormat */
-  def hadoopFile[K, V, F <: InputFormat[K, V]](
-    path: String,
-    inputFormatClass: Class[F],
-    keyClass: Class[K],
-    valueClass: Class[V]
-    ): JavaPairRDD[K, V] = {
-    implicit val kcm = ClassManifest.fromClass(keyClass)
-    implicit val vcm = ClassManifest.fromClass(valueClass)
-    new JavaPairRDD(sc.hadoopFile(path,
-      inputFormatClass, keyClass, valueClass))
-  }
-
-  /**
-   * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
-   * and extra configuration options to pass to the input format.
-   */
-  def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](
-    path: String,
-    fClass: Class[F],
-    kClass: Class[K],
-    vClass: Class[V],
-    conf: Configuration): JavaPairRDD[K, V] = {
-    implicit val kcm = ClassManifest.fromClass(kClass)
-    implicit val vcm = ClassManifest.fromClass(vClass)
-    new JavaPairRDD(sc.newAPIHadoopFile(path, fClass, kClass, vClass, conf))
-  }
-
-  /**
-   * Get an RDD for a given Hadoop dataset with an arbitrary new API InputFormat,
-   * using a Configuration that carries any extra options to pass to the input format.
-   */
-  def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]](
-    conf: Configuration,
-    fClass: Class[F],
-    kClass: Class[K],
-    vClass: Class[V]): JavaPairRDD[K, V] = {
-    implicit val kcm = ClassManifest.fromClass(kClass)
-    implicit val vcm = ClassManifest.fromClass(vClass)
-    new JavaPairRDD(sc.newAPIHadoopRDD(conf, fClass, kClass, vClass))
-  }
-
-  /** Build the union of two or more RDDs. */
-  override def union[T](first: JavaRDD[T], rest: java.util.List[JavaRDD[T]]): JavaRDD[T] = {
-    val rdds: Seq[RDD[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd)
-    implicit val cm: ClassManifest[T] = first.classManifest
-    sc.union(rdds)(cm)
-  }
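(A sketch of the classic-API hadoopFile path above, using TextInputFormat; the path and split count are illustrative assumptions, spark-shell assumed.)

import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.TextInputFormat

val records = sc.hadoopFile[LongWritable, Text, TextInputFormat](
  "hdfs://namenode:9000/data/logs", 4)   // minSplits = 4
// Keys are byte offsets, values are lines; Text should be converted before reuse
println(records.map(_._2.toString).first())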
-  /** Build the union of two or more RDDs. */
-  override def union[K, V](first: JavaPairRDD[K, V], rest: java.util.List[JavaPairRDD[K, V]])
-  : JavaPairRDD[K, V] = {
-    val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd)
-    implicit val cm: ClassManifest[(K, V)] = first.classManifest
-    implicit val kcm: ClassManifest[K] = first.kManifest
-    implicit val vcm: ClassManifest[V] = first.vManifest
-    new JavaPairRDD(sc.union(rdds)(cm))(kcm, vcm)
-  }
-
-  /** Build the union of two or more RDDs. */
-  override def union(first: JavaDoubleRDD, rest: java.util.List[JavaDoubleRDD]): JavaDoubleRDD = {
-    val rdds: Seq[RDD[Double]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.srdd)
-    new JavaDoubleRDD(sc.union(rdds))
-  }
-
-  /**
-   * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values
-   * to using the `add` method. Only the master can access the accumulator's `value`.
-   */
-  def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] =
-    sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]]
-
-  /**
-   * Create an [[spark.Accumulator]] double variable, which tasks can "add" values
-   * to using the `add` method. Only the master can access the accumulator's `value`.
-   */
-  def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] =
-    sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]]
-
-  /**
-   * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values
-   * to using the `add` method. Only the master can access the accumulator's `value`.
-   */
-  def accumulator(initialValue: Int): Accumulator[java.lang.Integer] = intAccumulator(initialValue)
-
-  /**
-   * Create an [[spark.Accumulator]] double variable, which tasks can "add" values
-   * to using the `add` method. Only the master can access the accumulator's `value`.
-   */
-  def accumulator(initialValue: Double): Accumulator[java.lang.Double] =
-    doubleAccumulator(initialValue)
-
-  /**
-   * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values
-   * to using the `add` method. Only the master can access the accumulator's `value`.
-   */
-  def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] =
-    sc.accumulator(initialValue)(accumulatorParam)
-
-  /**
-   * Create an [[spark.Accumulable]] shared variable of the given type, to which tasks can
-   * "add" values with `add`. Only the master can access the accumulable's `value`.
-   */
-  def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] =
-    sc.accumulable(initialValue)(param)
-
-  /**
-   * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for
-   * reading it in distributed functions. The variable will be sent to each node only once.
-   */
-  def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value)
-
-  /** Shut down the SparkContext. */
-  def stop() {
-    sc.stop()
-  }
-
-  /**
-   * Get Spark's home location from either a value set through the constructor,
-   * or the spark.home Java property, or the SPARK_HOME environment variable
-   * (in that order of preference). If none of these is set, return an absent Optional.
-   */
-  def getSparkHome(): Optional[String] = JavaUtils.optionToOptional(sc.getSparkHome())
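(A sketch of accumulators and broadcast variables as documented above, spark-shell assumed; the data is illustrative.)

import org.apache.spark.SparkContext._

// Tasks may only add to an accumulator; only the driver reads its value
val errorCount = sc.accumulator(0)
// A broadcast value is shipped to each node once rather than with every closure
val lookup = sc.broadcast(Map("a" -> 1, "b" -> 2))
sc.parallelize(Seq("a", "b", "x")).foreach { key =>
  if (!lookup.value.contains(key)) errorCount += 1
}
println(errorCount.value)   // 1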
-
-  /**
-   * Add a file to be downloaded with this Spark job on every node.
-   * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
-   * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
-   * use `SparkFiles.get(path)` to find its download location.
-   */
-  def addFile(path: String) {
-    sc.addFile(path)
-  }
-
-  /**
-   * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
-   * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
-   * filesystems), or an HTTP, HTTPS or FTP URI.
-   */
-  def addJar(path: String) {
-    sc.addJar(path)
-  }
-
-  /**
-   * Clear the job's list of JARs added by `addJar` so that they do not get downloaded to
-   * any new nodes.
-   */
-  def clearJars() {
-    sc.clearJars()
-  }
-
-  /**
-   * Clear the job's list of files added by `addFile` so that they do not get downloaded to
-   * any new nodes.
-   */
-  def clearFiles() {
-    sc.clearFiles()
-  }
-
-  /**
-   * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse.
-   */
-  def hadoopConfiguration(): Configuration = {
-    sc.hadoopConfiguration
-  }
-
-  /**
-   * Set the directory under which RDDs are going to be checkpointed. The directory must
-   * be an HDFS path if running on a cluster. If the directory does not exist, it will
-   * be created. If the directory exists and useExisting is set to true, then the
-   * existing directory will be used. Otherwise an exception will be thrown to
-   * prevent accidental overwriting of checkpoint files in the existing directory.
-   */
-  def setCheckpointDir(dir: String, useExisting: Boolean) {
-    sc.setCheckpointDir(dir, useExisting)
-  }
-
-  /**
-   * Set the directory under which RDDs are going to be checkpointed. The directory must
-   * be an HDFS path if running on a cluster. If the directory does not exist, it will
-   * be created. If the directory exists, an exception will be thrown to prevent accidental
-   * overwriting of checkpoint files.
-   */
-  def setCheckpointDir(dir: String) {
-    sc.setCheckpointDir(dir)
-  }
-
-  protected def checkpointFile[T](path: String): JavaRDD[T] = {
-    implicit val cm: ClassManifest[T] =
-      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
-    new JavaRDD(sc.checkpointFile(path))
-  }
-}
-
-object JavaSparkContext {
-  implicit def fromSparkContext(sc: SparkContext): JavaSparkContext = new JavaSparkContext(sc)
-
-  implicit def toSparkContext(jsc: JavaSparkContext): SparkContext = jsc.sc
-}
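(The checkpointing flow above in sketch form; the directory is an illustrative assumption, spark-shell assumed.)

import org.apache.spark.SparkContext._

sc.setCheckpointDir("hdfs://namenode:9000/checkpoints")
val counts = sc.parallelize(1 to 1000).map(i => (i % 10, 1)).reduceByKey(_ + _)
counts.checkpoint()     // must be called before any job runs on this RDD
println(counts.count()) // the action materializes and saves the checkpoint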
diff --git a/core/src/main/scala/spark/api/java/JavaSparkContextVarargsWorkaround.java b/core/src/main/scala/spark/api/java/JavaSparkContextVarargsWorkaround.java
deleted file mode 100644
index 42b1de01b1..0000000000
--- a/core/src/main/scala/spark/api/java/JavaSparkContextVarargsWorkaround.java
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package spark.api.java;
-
-import java.util.Arrays;
-import java.util.ArrayList;
-import java.util.List;
-
-// See
-// http://scala-programming-language.1934581.n4.nabble.com/Workaround-for-implementing-java-varargs-in-2-7-2-final-tp1944767p1944772.html
-abstract class JavaSparkContextVarargsWorkaround {
-  public <T> JavaRDD<T> union(JavaRDD<T>... rdds) {
-    if (rdds.length == 0) {
-      throw new IllegalArgumentException("Union called on empty list");
-    }
-    ArrayList<JavaRDD<T>> rest = new ArrayList<JavaRDD<T>>(rdds.length - 1);
-    for (int i = 1; i < rdds.length; i++) {
-      rest.add(rdds[i]);
-    }
-    return union(rdds[0], rest);
-  }
-
-  public JavaDoubleRDD union(JavaDoubleRDD... rdds) {
-    if (rdds.length == 0) {
-      throw new IllegalArgumentException("Union called on empty list");
-    }
-    ArrayList<JavaDoubleRDD> rest = new ArrayList<JavaDoubleRDD>(rdds.length - 1);
-    for (int i = 1; i < rdds.length; i++) {
-      rest.add(rdds[i]);
-    }
-    return union(rdds[0], rest);
-  }
-
-  public <K, V> JavaPairRDD<K, V> union(JavaPairRDD<K, V>... rdds) {
-    if (rdds.length == 0) {
-      throw new IllegalArgumentException("Union called on empty list");
-    }
-    ArrayList<JavaPairRDD<K, V>> rest = new ArrayList<JavaPairRDD<K, V>>(rdds.length - 1);
-    for (int i = 1; i < rdds.length; i++) {
-      rest.add(rdds[i]);
-    }
-    return union(rdds[0], rest);
-  }
-
-  // These methods take separate "first" and "rest" elements to avoid having the same type erasure
-  abstract public <T> JavaRDD<T> union(JavaRDD<T> first, List<JavaRDD<T>> rest);
-  abstract public JavaDoubleRDD union(JavaDoubleRDD first, List<JavaDoubleRDD> rest);
-  abstract public <K, V> JavaPairRDD<K, V> union(JavaPairRDD<K, V> first, List<JavaPairRDD<K, V>> rest);
-}
diff --git a/core/src/main/scala/spark/api/java/JavaUtils.scala b/core/src/main/scala/spark/api/java/JavaUtils.scala
deleted file mode 100644
index ffc131ac83..0000000000
--- a/core/src/main/scala/spark/api/java/JavaUtils.scala
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package spark.api.java
-
-import com.google.common.base.Optional
-
-object JavaUtils {
-  def optionToOptional[T](option: Option[T]): Optional[T] =
-    option match {
-      case Some(value) => Optional.of(value)
-      case None => Optional.absent()
-    }
-}
diff --git a/core/src/main/scala/spark/api/java/StorageLevels.java b/core/src/main/scala/spark/api/java/StorageLevels.java
deleted file mode 100644
index f385636e83..0000000000
--- a/core/src/main/scala/spark/api/java/StorageLevels.java
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java; - -import spark.storage.StorageLevel; - -/** - * Expose some commonly useful storage level constants. - */ -public class StorageLevels { - public static final StorageLevel NONE = new StorageLevel(false, false, false, 1); - public static final StorageLevel DISK_ONLY = new StorageLevel(true, false, false, 1); - public static final StorageLevel DISK_ONLY_2 = new StorageLevel(true, false, false, 2); - public static final StorageLevel MEMORY_ONLY = new StorageLevel(false, true, true, 1); - public static final StorageLevel MEMORY_ONLY_2 = new StorageLevel(false, true, true, 2); - public static final StorageLevel MEMORY_ONLY_SER = new StorageLevel(false, true, false, 1); - public static final StorageLevel MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, 2); - public static final StorageLevel MEMORY_AND_DISK = new StorageLevel(true, true, true, 1); - public static final StorageLevel MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2); - public static final StorageLevel MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, 1); - public static final StorageLevel MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2); - - /** - * Create a new StorageLevel object. - * @param useDisk saved to disk, if true - * @param useMemory saved to memory, if true - * @param deserialized saved as deserialized objects, if true - * @param replication replication factor - */ - public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized, int replication) { - return StorageLevel.apply(useDisk, useMemory, deserialized, replication); - } -} diff --git a/core/src/main/scala/spark/api/java/function/DoubleFlatMapFunction.java b/core/src/main/scala/spark/api/java/function/DoubleFlatMapFunction.java deleted file mode 100644 index 8bc88d757f..0000000000 --- a/core/src/main/scala/spark/api/java/function/DoubleFlatMapFunction.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java.function; - - -import scala.runtime.AbstractFunction1; - -import java.io.Serializable; - -/** - * A function that returns zero or more records of type Double from each input record. - */ -// DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is -// overloaded for both FlatMapFunction and DoubleFlatMapFunction. 
-public abstract class DoubleFlatMapFunction<T> extends AbstractFunction1<T, Iterable<Double>>
-  implements Serializable {
-
-  public abstract Iterable<Double> call(T t);
-
-  @Override
-  public final Iterable<Double> apply(T t) { return call(t); }
-}
diff --git a/core/src/main/scala/spark/api/java/function/DoubleFunction.java b/core/src/main/scala/spark/api/java/function/DoubleFunction.java
deleted file mode 100644
index 1aa1e5dae0..0000000000
--- a/core/src/main/scala/spark/api/java/function/DoubleFunction.java
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package spark.api.java.function;
-
-
-import scala.runtime.AbstractFunction1;
-
-import java.io.Serializable;
-
-/**
- * A function that returns Doubles, and can be used to construct DoubleRDDs.
- */
-// DoubleFunction does not extend Function because some UDF functions, like map,
-// are overloaded for both Function and DoubleFunction.
-public abstract class DoubleFunction<T> extends WrappedFunction1<T, Double>
-  implements Serializable {
-
-  public abstract Double call(T t) throws Exception;
-}
diff --git a/core/src/main/scala/spark/api/java/function/FlatMapFunction.scala b/core/src/main/scala/spark/api/java/function/FlatMapFunction.scala
deleted file mode 100644
index 9eb0cfe3f9..0000000000
--- a/core/src/main/scala/spark/api/java/function/FlatMapFunction.scala
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package spark.api.java.function
-
-/**
- * A function that returns zero or more output records from each input record.
- */ -abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] { - @throws(classOf[Exception]) - def call(x: T) : java.lang.Iterable[R] - - def elementType() : ClassManifest[R] = ClassManifest.Any.asInstanceOf[ClassManifest[R]] -} diff --git a/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala deleted file mode 100644 index dda98710c2..0000000000 --- a/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java.function - -/** - * A function that takes two inputs and returns zero or more output records. - */ -abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] { - @throws(classOf[Exception]) - def call(a: A, b:B) : java.lang.Iterable[C] - - def elementType() : ClassManifest[C] = ClassManifest.Any.asInstanceOf[ClassManifest[C]] -} diff --git a/core/src/main/scala/spark/api/java/function/Function.java b/core/src/main/scala/spark/api/java/function/Function.java deleted file mode 100644 index 2a2ea0aacf..0000000000 --- a/core/src/main/scala/spark/api/java/function/Function.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java.function; - -import scala.reflect.ClassManifest; -import scala.reflect.ClassManifest$; -import scala.runtime.AbstractFunction1; - -import java.io.Serializable; - - -/** - * Base class for functions whose return types do not create special RDDs. PairFunction and - * DoubleFunction are handled separately, to allow PairRDDs and DoubleRDDs to be constructed - * when mapping RDDs of other types. 
- */
-public abstract class Function<T, R> extends WrappedFunction1<T, R> implements Serializable {
-  public abstract R call(T t) throws Exception;
-
-  public ClassManifest<R> returnType() {
-    return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
-  }
-}
-
diff --git a/core/src/main/scala/spark/api/java/function/Function2.java b/core/src/main/scala/spark/api/java/function/Function2.java
deleted file mode 100644
index 952d31ece4..0000000000
--- a/core/src/main/scala/spark/api/java/function/Function2.java
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package spark.api.java.function;
-
-import scala.reflect.ClassManifest;
-import scala.reflect.ClassManifest$;
-import scala.runtime.AbstractFunction2;
-
-import java.io.Serializable;
-
-/**
- * A two-argument function that takes arguments of type T1 and T2 and returns an R.
- */
-public abstract class Function2<T1, T2, R> extends WrappedFunction2<T1, T2, R>
-  implements Serializable {
-
-  public abstract R call(T1 t1, T2 t2) throws Exception;
-
-  public ClassManifest<R> returnType() {
-    return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
-  }
-}
-
diff --git a/core/src/main/scala/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/scala/spark/api/java/function/PairFlatMapFunction.java
deleted file mode 100644
index 4aad602da3..0000000000
--- a/core/src/main/scala/spark/api/java/function/PairFlatMapFunction.java
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package spark.api.java.function;
-
-import scala.Tuple2;
-import scala.reflect.ClassManifest;
-import scala.reflect.ClassManifest$;
-import scala.runtime.AbstractFunction1;
-
-import java.io.Serializable;
-
-/**
- * A function that returns zero or more key-value pair records from each input record. The
- * key-value pairs are represented as scala.Tuple2 objects.
- */
-// PairFlatMapFunction does not extend FlatMapFunction because flatMap is
-// overloaded for both FlatMapFunction and PairFlatMapFunction.
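
// [Editorial illustration, not part of the original patch.] The Java API function
// classes in this file family all rely on one bridging trick: a Java-friendly
// abstract call() that may throw, re-exposed as Scala's Function1.apply(). A minimal
// self-contained sketch of that pattern (names here are illustrative, not Spark's):
object FunctionBridgeSketch {
  abstract class JFunction[T, R] extends (T => R) with Serializable {
    @throws(classOf[Exception])
    def call(t: T): R
    // Scala code sees an ordinary Function1; Java subclasses only implement call().
    final def apply(t: T): R = call(t)
  }

  def main(args: Array[String]) {
    val plusOne = new JFunction[Int, Int] { def call(t: Int) = t + 1 }
    println(List(1, 2, 3).map(plusOne)) // List(2, 3, 4)
  }
}
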
-public abstract class PairFlatMapFunction<T, K, V>
-  extends WrappedFunction1<T, Iterable<Tuple2<K, V>>>
-  implements Serializable {
-
-  public abstract Iterable<Tuple2<K, V>> call(T t) throws Exception;
-
-  public ClassManifest<K> keyType() {
-    return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class);
-  }
-
-  public ClassManifest<V> valueType() {
-    return (ClassManifest<V>) ClassManifest$.MODULE$.fromClass(Object.class);
-  }
-}
diff --git a/core/src/main/scala/spark/api/java/function/PairFunction.java b/core/src/main/scala/spark/api/java/function/PairFunction.java
deleted file mode 100644
index ccfe64ecf1..0000000000
--- a/core/src/main/scala/spark/api/java/function/PairFunction.java
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package spark.api.java.function;
-
-import scala.Tuple2;
-import scala.reflect.ClassManifest;
-import scala.reflect.ClassManifest$;
-import scala.runtime.AbstractFunction1;
-
-import java.io.Serializable;
-
-/**
- * A function that returns key-value pairs (Tuple2<K, V>), and can be used to construct PairRDDs.
- */
-// PairFunction does not extend Function because some UDF functions, like map,
-// are overloaded for both Function and PairFunction.
-public abstract class PairFunction<T, K, V>
-  extends WrappedFunction1<T, Tuple2<K, V>>
-  implements Serializable {
-
-  public abstract Tuple2<K, V> call(T t) throws Exception;
-
-  public ClassManifest<K> keyType() {
-    return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class);
-  }
-
-  public ClassManifest<V> valueType() {
-    return (ClassManifest<V>) ClassManifest$.MODULE$.fromClass(Object.class);
-  }
-}
diff --git a/core/src/main/scala/spark/api/java/function/VoidFunction.scala b/core/src/main/scala/spark/api/java/function/VoidFunction.scala
deleted file mode 100644
index f6fc0b0f7d..0000000000
--- a/core/src/main/scala/spark/api/java/function/VoidFunction.scala
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package spark.api.java.function
-
-/**
- * A function with no return value.
- */ -// This allows Java users to write void methods without having to return Unit. -abstract class VoidFunction[T] extends Serializable { - @throws(classOf[Exception]) - def call(t: T) : Unit -} - -// VoidFunction cannot extend AbstractFunction1 (because that would force users to explicitly -// return Unit), so it is implicitly converted to a Function1[T, Unit]: -object VoidFunction { - implicit def toFunction[T](f: VoidFunction[T]) : Function1[T, Unit] = ((x : T) => f.call(x)) -} diff --git a/core/src/main/scala/spark/api/java/function/WrappedFunction1.scala b/core/src/main/scala/spark/api/java/function/WrappedFunction1.scala deleted file mode 100644 index 1758a38c4e..0000000000 --- a/core/src/main/scala/spark/api/java/function/WrappedFunction1.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java.function - -import scala.runtime.AbstractFunction1 - -/** - * Subclass of Function1 for ease of calling from Java. The main thing it does is re-expose the - * apply() method as call() and declare that it can throw Exception (since AbstractFunction1.apply - * isn't marked to allow that). - */ -private[spark] abstract class WrappedFunction1[T, R] extends AbstractFunction1[T, R] { - @throws(classOf[Exception]) - def call(t: T): R - - final def apply(t: T): R = call(t) -} diff --git a/core/src/main/scala/spark/api/java/function/WrappedFunction2.scala b/core/src/main/scala/spark/api/java/function/WrappedFunction2.scala deleted file mode 100644 index b093567d2c..0000000000 --- a/core/src/main/scala/spark/api/java/function/WrappedFunction2.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.java.function - -import scala.runtime.AbstractFunction2 - -/** - * Subclass of Function2 for ease of calling from Java. The main thing it does is re-expose the - * apply() method as call() and declare that it can throw Exception (since AbstractFunction2.apply - * isn't marked to allow that). 
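
// [Editorial illustration, not part of the original patch.] Sketch of the implicit
// conversion used by VoidFunction above: extending Function1 directly would force
// Java subclasses to return Unit explicitly, so the conversion is supplied separately.
object VoidFunctionSketch {
  abstract class VFunction[T] extends Serializable {
    @throws(classOf[Exception])
    def call(t: T): Unit
  }

  // Lets a VFunction be passed wherever a T => Unit is expected.
  implicit def toFunction1[T](f: VFunction[T]): T => Unit = (x: T) => f.call(x)

  def main(args: Array[String]) {
    val printer = new VFunction[String] { def call(t: String) { println(t) } }
    List("a", "b").foreach(printer) // the implicit conversion is applied here
  }
}
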
- */ -private[spark] abstract class WrappedFunction2[T1, T2, R] extends AbstractFunction2[T1, T2, R] { - @throws(classOf[Exception]) - def call(t1: T1, t2: T2): R - - final def apply(t1: T1, t2: T2): R = call(t1, t2) -} diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala deleted file mode 100644 index ac112b8c2c..0000000000 --- a/core/src/main/scala/spark/api/python/PythonPartitioner.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.python - -import spark.Partitioner -import spark.Utils -import java.util.Arrays - -/** - * A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API. - * - * Stores the unique id() of the Python-side partitioning function so that it is incorporated into - * equality comparisons. Correctness requires that the id is a unique identifier for the - * lifetime of the program (i.e. that it is not re-used as the id of a different partitioning - * function). This can be ensured by using the Python id() function and maintaining a reference - * to the Python partitioning function so that its id() is not reused. - */ -private[spark] class PythonPartitioner( - override val numPartitions: Int, - val pyPartitionFunctionId: Long) - extends Partitioner { - - override def getPartition(key: Any): Int = key match { - case null => 0 - case key: Array[Byte] => Utils.nonNegativeMod(Arrays.hashCode(key), numPartitions) - case _ => Utils.nonNegativeMod(key.hashCode(), numPartitions) - } - - override def equals(other: Any): Boolean = other match { - case h: PythonPartitioner => - h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId - case _ => - false - } -} diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala deleted file mode 100644 index 49671437d0..0000000000 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ /dev/null @@ -1,344 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.python - -import java.io._ -import java.net._ -import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} - -import scala.collection.JavaConversions._ - -import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} -import spark.broadcast.Broadcast -import spark._ -import spark.rdd.PipedRDD - - -private[spark] class PythonRDD[T: ClassManifest]( - parent: RDD[T], - command: Seq[String], - envVars: JMap[String, String], - pythonIncludes: JList[String], - preservePartitoning: Boolean, - pythonExec: String, - broadcastVars: JList[Broadcast[Array[Byte]]], - accumulator: Accumulator[JList[Array[Byte]]]) - extends RDD[Array[Byte]](parent) { - - val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - - // Similar to Runtime.exec(), if we are given a single string, split it into words - // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, envVars: JMap[String, String], - pythonIncludes: JList[String], - preservePartitoning: Boolean, pythonExec: String, - broadcastVars: JList[Broadcast[Array[Byte]]], - accumulator: Accumulator[JList[Array[Byte]]]) = - this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec, - broadcastVars, accumulator) - - override def getPartitions = parent.partitions - - override val partitioner = if (preservePartitoning) parent.partitioner else None - - - override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { - val startTime = System.currentTimeMillis - val env = SparkEnv.get - val worker = env.createPythonWorker(pythonExec, envVars.toMap) - - // Start a thread to feed the process input from our parent's iterator - new Thread("stdin writer for " + pythonExec) { - override def run() { - try { - SparkEnv.set(env) - val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) - val dataOut = new DataOutputStream(stream) - val printOut = new PrintWriter(stream) - // Partition index - dataOut.writeInt(split.index) - // sparkFilesDir - PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut) - // Broadcast variables - dataOut.writeInt(broadcastVars.length) - for (broadcast <- broadcastVars) { - dataOut.writeLong(broadcast.id) - dataOut.writeInt(broadcast.value.length) - dataOut.write(broadcast.value) - } - // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.length) - for (f <- pythonIncludes) { - PythonRDD.writeAsPickle(f, dataOut) - } - dataOut.flush() - // Serialized user code - for (elem <- command) { - printOut.println(elem) - } - printOut.flush() - // Data values - for (elem <- parent.iterator(split, context)) { - PythonRDD.writeAsPickle(elem, dataOut) - } - dataOut.flush() - printOut.flush() - worker.shutdownOutput() - } catch { - case e: IOException => - // This can happen for legitimate reasons if the Python code stops returning data before we are done - // passing elements through, e.g., for take(). Just log a message to say it happened. 
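
// [Editorial illustration, not part of the original patch.] The stdin-writer thread
// above speaks a simple length-prefixed framing to the Python worker: a 32-bit length
// header followed by that many payload bytes. A self-contained sketch of the same
// framing (hypothetical helper, not PySpark's actual protocol code):
object FramingSketch {
  import java.io.{ByteArrayOutputStream, DataOutputStream}

  def frame(payloads: Seq[Array[Byte]]): Array[Byte] = {
    val bytes = new ByteArrayOutputStream()
    val out = new DataOutputStream(bytes)
    for (p <- payloads) {
      out.writeInt(p.length) // big-endian 32-bit length header
      out.write(p)           // raw payload (pickled data, in PySpark's case)
    }
    out.flush()
    bytes.toByteArray
  }

  def main(args: Array[String]) {
    println(frame(Seq("hi".getBytes("UTF-8"))).length) // 4 + 2 = 6
  }
}
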
- logInfo("stdin writer to Python finished early") - logDebug("stdin writer to Python finished early", e) - } - } - }.start() - - // Return an iterator that read lines from the process's stdout - val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) - return new Iterator[Array[Byte]] { - def next(): Array[Byte] = { - val obj = _nextObj - if (hasNext) { - // FIXME: can deadlock if worker is waiting for us to - // respond to current message (currently irrelevant because - // output is shutdown before we read any input) - _nextObj = read() - } - obj - } - - private def read(): Array[Byte] = { - try { - stream.readInt() match { - case length if length > 0 => - val obj = new Array[Byte](length) - stream.readFully(obj) - obj - case -3 => - // Timing data from worker - val bootTime = stream.readLong() - val initTime = stream.readLong() - val finishTime = stream.readLong() - val boot = bootTime - startTime - val init = initTime - bootTime - val finish = finishTime - initTime - val total = finishTime - startTime - logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish)) - read - case -2 => - // Signals that an exception has been thrown in python - val exLength = stream.readInt() - val obj = new Array[Byte](exLength) - stream.readFully(obj) - throw new PythonException(new String(obj)) - case -1 => - // We've finished the data section of the output, but we can still - // read some accumulator updates; let's do that, breaking when we - // get a negative length record. - var len2 = stream.readInt() - while (len2 >= 0) { - val update = new Array[Byte](len2) - stream.readFully(update) - accumulator += Collections.singletonList(update) - len2 = stream.readInt() - } - new Array[Byte](0) - } - } catch { - case eof: EOFException => { - throw new SparkException("Python worker exited unexpectedly (crashed)", eof) - } - case e => throw e - } - } - - var _nextObj = read() - - def hasNext = _nextObj.length != 0 - } - } - - val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) -} - -/** Thrown for exceptions in user Python code. */ -private class PythonException(msg: String) extends Exception(msg) - -/** - * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. - * This is used by PySpark's shuffle operations. - */ -private class PairwiseRDD(prev: RDD[Array[Byte]]) extends - RDD[(Array[Byte], Array[Byte])](prev) { - override def getPartitions = prev.partitions - override def compute(split: Partition, context: TaskContext) = - prev.iterator(split, context).grouped(2).map { - case Seq(a, b) => (a, b) - case x => throw new SparkException("PairwiseRDD: unexpected value: " + x) - } - val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) -} - -private[spark] object PythonRDD { - - /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */ - def stripPickle(arr: Array[Byte]) : Array[Byte] = { - arr.slice(2, arr.length - 1) - } - - /** - * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream. - * The data format is a 32-bit integer representing the pickled object's length (in bytes), - * followed by the pickled data. 
- * - * Pickle module: - * - * http://docs.python.org/2/library/pickle.html - * - * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules: - * - * http://hg.python.org/cpython/file/2.6/Lib/pickle.py - * http://hg.python.org/cpython/file/2.6/Lib/pickletools.py - * - * @param elem the object to write - * @param dOut a data output stream - */ - def writeAsPickle(elem: Any, dOut: DataOutputStream) { - if (elem.isInstanceOf[Array[Byte]]) { - val arr = elem.asInstanceOf[Array[Byte]] - dOut.writeInt(arr.length) - dOut.write(arr) - } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) { - val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]] - val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.write(PythonRDD.stripPickle(t._1)) - dOut.write(PythonRDD.stripPickle(t._2)) - dOut.writeByte(Pickle.TUPLE2) - dOut.writeByte(Pickle.STOP) - } else if (elem.isInstanceOf[String]) { - // For uniformity, strings are wrapped into Pickles. - val s = elem.asInstanceOf[String].getBytes("UTF-8") - val length = 2 + 1 + 4 + s.length + 1 - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.write(Pickle.BINUNICODE) - dOut.writeInt(Integer.reverseBytes(s.length)) - dOut.write(s) - dOut.writeByte(Pickle.STOP) - } else { - throw new SparkException("Unexpected RDD type") - } - } - - def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) : - JavaRDD[Array[Byte]] = { - val file = new DataInputStream(new FileInputStream(filename)) - val objs = new collection.mutable.ArrayBuffer[Array[Byte]] - try { - while (true) { - val length = file.readInt() - val obj = new Array[Byte](length) - file.readFully(obj) - objs.append(obj) - } - } catch { - case eof: EOFException => {} - case e => throw e - } - JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) - } - - def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) { - import scala.collection.JavaConverters._ - writeIteratorToPickleFile(items.asScala, filename) - } - - def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) { - val file = new DataOutputStream(new FileOutputStream(filename)) - for (item <- items) { - writeAsPickle(item, file) - } - file.close() - } - - def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = { - implicit val cm : ClassManifest[T] = rdd.elementClassManifest - rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator - } -} - -private object Pickle { - val PROTO: Byte = 0x80.toByte - val TWO: Byte = 0x02.toByte - val BINUNICODE: Byte = 'X' - val STOP: Byte = '.' - val TUPLE2: Byte = 0x86.toByte - val EMPTY_LIST: Byte = ']' - val MARK: Byte = '(' - val APPENDS: Byte = 'e' -} - -private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] { - override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") -} - -/** - * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it - * collects a list of pickled strings that we pass to Python through a socket. 
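
// [Editorial illustration, not part of the original patch.] Worked check of the
// tuple-length arithmetic in writeAsPickle above: stripPickle() drops the 2-byte
// protocol header and the 1-byte STOP from each element's pickle (3 bytes each),
// while the rebuilt pickle adds PROTO, TWO, TUPLE2 and STOP (4 bytes):
object TuplePickleLengthSketch {
  def tupleLength(len1: Int, len2: Int): Int = len1 + len2 - 3 - 3 + 4

  def main(args: Array[String]) {
    println(tupleLength(10, 10)) // two 10-byte pickles frame as one 18-byte tuple pickle
  }
}
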
- */ -class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) - extends AccumulatorParam[JList[Array[Byte]]] { - - Utils.checkHost(serverHost, "Expected hostname") - - val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - - override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList - - override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]]) - : JList[Array[Byte]] = { - if (serverHost == null) { - // This happens on the worker node, where we just want to remember all the updates - val1.addAll(val2) - val1 - } else { - // This happens on the master, where we pass the updates to Python through a socket - val socket = new Socket(serverHost, serverPort) - val in = socket.getInputStream - val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize)) - out.writeInt(val2.size) - for (array <- val2) { - out.writeInt(array.length) - out.write(array) - } - out.flush() - // Wait for a byte from the Python side as an acknowledgement - val byteRead = in.read() - if (byteRead == -1) { - throw new SparkException("EOF reached before Python server acknowledged") - } - socket.close() - null - } - } -} diff --git a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala deleted file mode 100644 index 14f8320678..0000000000 --- a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.api.python - -import java.io.{File, DataInputStream, IOException} -import java.net.{Socket, SocketException, InetAddress} - -import scala.collection.JavaConversions._ - -import spark._ - -private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String]) - extends Logging { - var daemon: Process = null - val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) - var daemonPort: Int = 0 - - def create(): Socket = { - synchronized { - // Start the daemon if it hasn't been started - startDaemon() - - // Attempt to connect, restart and retry once if it fails - try { - new Socket(daemonHost, daemonPort) - } catch { - case exc: SocketException => { - logWarning("Python daemon unexpectedly quit, attempting to restart") - stopDaemon() - startDaemon() - new Socket(daemonHost, daemonPort) - } - case e => throw e - } - } - } - - def stop() { - stopDaemon() - } - - private def startDaemon() { - synchronized { - // Is it already running? 
- if (daemon != null) { - return - } - - try { - // Create and start the daemon - val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME") - val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py")) - val workerEnv = pb.environment() - workerEnv.putAll(envVars) - val pythonPath = sparkHome + "/python/" + File.pathSeparator + workerEnv.get("PYTHONPATH") - workerEnv.put("PYTHONPATH", pythonPath) - daemon = pb.start() - - // Redirect the stderr to ours - new Thread("stderr reader for " + pythonExec) { - override def run() { - scala.util.control.Exception.ignoring(classOf[IOException]) { - // FIXME HACK: We copy the stream on the level of bytes to - // attempt to dodge encoding problems. - val in = daemon.getErrorStream - var buf = new Array[Byte](1024) - var len = in.read(buf) - while (len != -1) { - System.err.write(buf, 0, len) - len = in.read(buf) - } - } - } - }.start() - - val in = new DataInputStream(daemon.getInputStream) - daemonPort = in.readInt() - - // Redirect further stdout output to our stderr - new Thread("stdout reader for " + pythonExec) { - override def run() { - scala.util.control.Exception.ignoring(classOf[IOException]) { - // FIXME HACK: We copy the stream on the level of bytes to - // attempt to dodge encoding problems. - var buf = new Array[Byte](1024) - var len = in.read(buf) - while (len != -1) { - System.err.write(buf, 0, len) - len = in.read(buf) - } - } - } - }.start() - } catch { - case e => { - stopDaemon() - throw e - } - } - - // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly - // detect our disappearance. - } - } - - private def stopDaemon() { - synchronized { - // Request shutdown of existing daemon by sending SIGTERM - if (daemon != null) { - daemon.destroy() - } - - daemon = null - daemonPort = 0 - } - } -} diff --git a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala deleted file mode 100644 index 6f7d385379..0000000000 --- a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala +++ /dev/null @@ -1,1057 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.broadcast - -import java.io._ -import java.net._ -import java.util.{BitSet, Comparator, Timer, TimerTask, UUID} -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.mutable.{ListBuffer, Map, Set} -import scala.math - -import spark._ -import spark.storage.StorageLevel - -private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) - extends Broadcast[T](id) - with Logging - with Serializable { - - def value = value_ - - def blockId: String = "broadcast_" + id - - MultiTracker.synchronized { - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } - - @transient var arrayOfBlocks: Array[BroadcastBlock] = null - @transient var hasBlocksBitVector: BitSet = null - @transient var numCopiesSent: Array[Int] = null - @transient var totalBytes = -1 - @transient var totalBlocks = -1 - @transient var hasBlocks = new AtomicInteger(0) - - // Used ONLY by driver to track how many unique blocks have been sent out - @transient var sentBlocks = new AtomicInteger(0) - - @transient var listenPortLock = new Object - @transient var guidePortLock = new Object - @transient var totalBlocksLock = new Object - - @transient var listOfSources = ListBuffer[SourceInfo]() - - @transient var serveMR: ServeMultipleRequests = null - - // Used only in driver - @transient var guideMR: GuideMultipleRequests = null - - // Used only in Workers - @transient var ttGuide: TalkToGuide = null - - @transient var hostAddress = Utils.localIpAddress - @transient var listenPort = -1 - @transient var guidePort = -1 - - @transient var stopBroadcast = false - - // Must call this after all the variables have been created/initialized - if (!isLocal) { - sendBroadcast() - } - - def sendBroadcast() { - logInfo("Local host address: " + hostAddress) - - // Create a variableInfo object and store it in valueInfos - var variableInfo = MultiTracker.blockifyObject(value_) - - // Prepare the value being broadcasted - arrayOfBlocks = variableInfo.arrayOfBlocks - totalBytes = variableInfo.totalBytes - totalBlocks = variableInfo.totalBlocks - hasBlocks.set(variableInfo.totalBlocks) - - // Guide has all the blocks - hasBlocksBitVector = new BitSet(totalBlocks) - hasBlocksBitVector.set(0, totalBlocks) - - // Guide still hasn't sent any block - numCopiesSent = new Array[Int](totalBlocks) - - guideMR = new GuideMultipleRequests - guideMR.setDaemon(true) - guideMR.start() - logInfo("GuideMultipleRequests started...") - - // Must always come AFTER guideMR is created - while (guidePort == -1) { - guidePortLock.synchronized { guidePortLock.wait() } - } - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - // Must always come AFTER serveMR is created - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Must always come AFTER listenPort is created - val driverSource = - SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes) - hasBlocksBitVector.synchronized { - driverSource.hasBlocksBitVector = hasBlocksBitVector - } - - // In the beginning, this is the only known source to Guide - listOfSources += driverSource - - // Register with the Tracker - MultiTracker.registerBroadcast(id, - SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes)) - } - - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() - MultiTracker.synchronized { - SparkEnv.get.blockManager.getSingle(blockId) match { - case Some(x) 
=> - value_ = x.asInstanceOf[T] - - case None => - logInfo("Started reading broadcast variable " + id) - // Initializing everything because driver will only send null/0 values - // Only the 1st worker in a node can be here. Others will get from cache - initializeWorkerVariables() - - logInfo("Local host address: " + hostAddress) - - // Start local ServeMultipleRequests thread first - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - val start = System.nanoTime - - val receptionSucceeded = receiveBroadcast(id) - if (receptionSucceeded) { - value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - SparkEnv.get.blockManager.putSingle( - blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } else { - logError("Reading broadcast variable " + id + " failed") - } - - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading broadcast variable " + id + " took " + time + " s") - } - } - } - - // Initialize variables in the worker node. Driver sends everything as 0/null - private def initializeWorkerVariables() { - arrayOfBlocks = null - hasBlocksBitVector = null - numCopiesSent = null - totalBytes = -1 - totalBlocks = -1 - hasBlocks = new AtomicInteger(0) - - listenPortLock = new Object - totalBlocksLock = new Object - - serveMR = null - ttGuide = null - - hostAddress = Utils.localIpAddress - listenPort = -1 - - listOfSources = ListBuffer[SourceInfo]() - - stopBroadcast = false - } - - private def getLocalSourceInfo: SourceInfo = { - // Wait till hostName and listenPort are OK - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Wait till totalBlocks and totalBytes are OK - while (totalBlocks == -1) { - totalBlocksLock.synchronized { totalBlocksLock.wait() } - } - - var localSourceInfo = SourceInfo( - hostAddress, listenPort, totalBlocks, totalBytes) - - localSourceInfo.hasBlocks = hasBlocks.get - - hasBlocksBitVector.synchronized { - localSourceInfo.hasBlocksBitVector = hasBlocksBitVector - } - - return localSourceInfo - } - - // Add new SourceInfo to the listOfSources. Update if it exists already. 
- // Optimizing just by OR-ing the BitVectors was BAD for performance - private def addToListOfSources(newSourceInfo: SourceInfo) { - listOfSources.synchronized { - if (listOfSources.contains(newSourceInfo)) { - listOfSources = listOfSources - newSourceInfo - } - listOfSources += newSourceInfo - } - } - - private def addToListOfSources(newSourceInfos: ListBuffer[SourceInfo]) { - newSourceInfos.foreach { newSourceInfo => - addToListOfSources(newSourceInfo) - } - } - - class TalkToGuide(gInfo: SourceInfo) - extends Thread with Logging { - override def run() { - - // Keep exchaning information until all blocks have been received - while (hasBlocks.get < totalBlocks) { - talkOnce - Thread.sleep(MultiTracker.ranGen.nextInt( - MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) + - MultiTracker.MinKnockInterval) - } - - // Talk one more time to let the Guide know of reception completion - talkOnce - } - - // Connect to Guide and send this worker's information - private def talkOnce { - var clientSocketToGuide: Socket = null - var oosGuide: ObjectOutputStream = null - var oisGuide: ObjectInputStream = null - - clientSocketToGuide = new Socket(gInfo.hostAddress, gInfo.listenPort) - oosGuide = new ObjectOutputStream(clientSocketToGuide.getOutputStream) - oosGuide.flush() - oisGuide = new ObjectInputStream(clientSocketToGuide.getInputStream) - - // Send local information - oosGuide.writeObject(getLocalSourceInfo) - oosGuide.flush() - - // Receive source information from Guide - var suitableSources = - oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]] - logDebug("Received suitableSources from Driver " + suitableSources) - - addToListOfSources(suitableSources) - - oisGuide.close() - oosGuide.close() - clientSocketToGuide.close() - } - } - - def receiveBroadcast(variableID: Long): Boolean = { - val gInfo = MultiTracker.getGuideInfo(variableID) - - if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) { - return false - } - - // Wait until hostAddress and listenPort are created by the - // ServeMultipleRequests thread - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Setup initial states of variables - totalBlocks = gInfo.totalBlocks - arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) - hasBlocksBitVector = new BitSet(totalBlocks) - numCopiesSent = new Array[Int](totalBlocks) - totalBlocksLock.synchronized { totalBlocksLock.notifyAll() } - totalBytes = gInfo.totalBytes - - // Start ttGuide to periodically talk to the Guide - var ttGuide = new TalkToGuide(gInfo) - ttGuide.setDaemon(true) - ttGuide.start() - logInfo("TalkToGuide started...") - - // Start pController to run TalkToPeer threads - var pcController = new PeerChatterController - pcController.setDaemon(true) - pcController.start() - logInfo("PeerChatterController started...") - - // FIXME: Must fix this. This might never break if broadcast fails. - // We should be able to break and send false. Also need to kill threads - while (hasBlocks.get < totalBlocks) { - Thread.sleep(MultiTracker.MaxKnockInterval) - } - - return true - } - - class PeerChatterController - extends Thread with Logging { - private var peersNowTalking = ListBuffer[SourceInfo]() - // TODO: There is a possible bug with blocksInRequestBitVector when a - // certain bit is NOT unset upon failure resulting in an infinite loop. 
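
// [Editorial illustration, not part of the original patch.] Sketch of the BitSet
// bookkeeping this controller and the pickBlock* helpers below rely on:
// wanted = ~(have | inRequest) & theirs.
object BlockPickSketch {
  import java.util.BitSet

  def candidates(have: BitSet, inRequest: BitSet, theirs: BitSet, total: Int): BitSet = {
    val need = have.clone.asInstanceOf[BitSet]
    need.or(inRequest)  // blocks held locally or already being fetched ...
    need.flip(0, total) // ... complemented: blocks still needed ...
    need.and(theirs)    // ... restricted to what this peer can serve
    need
  }

  def main(args: Array[String]) {
    val have = new BitSet(4);   have.set(0)
    val req = new BitSet(4);    req.set(1)
    val theirs = new BitSet(4); theirs.set(0, 4)
    println(candidates(have, req, theirs, 4)) // {2, 3}
  }
}
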
- private var blocksInRequestBitVector = new BitSet(totalBlocks) - - override def run() { - var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots) - - while (hasBlocks.get < totalBlocks) { - var numThreadsToCreate = 0 - listOfSources.synchronized { - numThreadsToCreate = math.min(listOfSources.size, MultiTracker.MaxChatSlots) - - threadPool.getActiveCount - } - - while (hasBlocks.get < totalBlocks && numThreadsToCreate > 0) { - var peerToTalkTo = pickPeerToTalkToRandom - - if (peerToTalkTo != null) - logDebug("Peer chosen: " + peerToTalkTo + " with " + peerToTalkTo.hasBlocksBitVector) - else - logDebug("No peer chosen...") - - if (peerToTalkTo != null) { - threadPool.execute(new TalkToPeer(peerToTalkTo)) - - // Add to peersNowTalking. Remove in the thread. We have to do this - // ASAP, otherwise pickPeerToTalkTo picks the same peer more than once - peersNowTalking.synchronized { peersNowTalking += peerToTalkTo } - } - - numThreadsToCreate = numThreadsToCreate - 1 - } - - // Sleep for a while before starting some more threads - Thread.sleep(MultiTracker.MinKnockInterval) - } - // Shutdown the thread pool - threadPool.shutdown() - } - - // Right now picking the one that has the most blocks this peer wants - // Also picking peer randomly if no one has anything interesting - private def pickPeerToTalkToRandom: SourceInfo = { - var curPeer: SourceInfo = null - var curMax = 0 - - logDebug("Picking peers to talk to...") - - // Find peers that are not connected right now - var peersNotInUse = ListBuffer[SourceInfo]() - listOfSources.synchronized { - peersNowTalking.synchronized { - peersNotInUse = listOfSources -- peersNowTalking - } - } - - // Select the peer that has the most blocks that this receiver does not - peersNotInUse.foreach { eachSource => - var tempHasBlocksBitVector: BitSet = null - hasBlocksBitVector.synchronized { - tempHasBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] - } - tempHasBlocksBitVector.flip(0, tempHasBlocksBitVector.size) - tempHasBlocksBitVector.and(eachSource.hasBlocksBitVector) - - if (tempHasBlocksBitVector.cardinality > curMax) { - curPeer = eachSource - curMax = tempHasBlocksBitVector.cardinality - } - } - - // Always picking randomly - if (curPeer == null && peersNotInUse.size > 0) { - // Pick uniformly the i'th required peer - var i = MultiTracker.ranGen.nextInt(peersNotInUse.size) - - var peerIter = peersNotInUse.iterator - curPeer = peerIter.next - - while (i > 0) { - curPeer = peerIter.next - i = i - 1 - } - } - - return curPeer - } - - // Picking peer with the weight of rare blocks it has - private def pickPeerToTalkToRarestFirst: SourceInfo = { - // Find peers that are not connected right now - var peersNotInUse = ListBuffer[SourceInfo]() - listOfSources.synchronized { - peersNowTalking.synchronized { - peersNotInUse = listOfSources -- peersNowTalking - } - } - - // Count the number of copies of each block in the neighborhood - var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0) - - listOfSources.synchronized { - listOfSources.foreach { eachSource => - for (i <- 0 until totalBlocks) { - numCopiesPerBlock(i) += - ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 ) - } - } - } - - // A block is considered rare if there are at most 2 copies of that block - // This CONSTANT could be a function of the neighborhood size - var rareBlocksIndices = ListBuffer[Int]() - for (i <- 0 until totalBlocks) { - if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) <= 2) { - rareBlocksIndices += i - } - } - - // Find peers 
with rare blocks - var peersWithRareBlocks = ListBuffer[(SourceInfo, Int)]() - var totalRareBlocks = 0 - - peersNotInUse.foreach { eachPeer => - var hasRareBlocks = 0 - rareBlocksIndices.foreach { rareBlock => - if (eachPeer.hasBlocksBitVector.get(rareBlock)) { - hasRareBlocks += 1 - } - } - - if (hasRareBlocks > 0) { - peersWithRareBlocks += ((eachPeer, hasRareBlocks)) - } - totalRareBlocks += hasRareBlocks - } - - // Select a peer from peersWithRareBlocks based on weight calculated from - // unique rare blocks - var selectedPeerToTalkTo: SourceInfo = null - - if (peersWithRareBlocks.size > 0) { - // Sort the peers based on how many rare blocks they have - peersWithRareBlocks.sortBy(_._2) - - var randomNumber = MultiTracker.ranGen.nextDouble - var tempSum = 0.0 - - var i = 0 - do { - tempSum += (1.0 * peersWithRareBlocks(i)._2 / totalRareBlocks) - if (tempSum >= randomNumber) { - selectedPeerToTalkTo = peersWithRareBlocks(i)._1 - } - i += 1 - } while (i < peersWithRareBlocks.size && selectedPeerToTalkTo == null) - } - - if (selectedPeerToTalkTo == null) { - selectedPeerToTalkTo = pickPeerToTalkToRandom - } - - return selectedPeerToTalkTo - } - - class TalkToPeer(peerToTalkTo: SourceInfo) - extends Thread with Logging { - private var peerSocketToSource: Socket = null - private var oosSource: ObjectOutputStream = null - private var oisSource: ObjectInputStream = null - - override def run() { - // TODO: There is a possible bug here regarding blocksInRequestBitVector - var blockToAskFor = -1 - - // Setup the timeout mechanism - var timeOutTask = new TimerTask { - override def run() { - cleanUpConnections() - } - } - - var timeOutTimer = new Timer - timeOutTimer.schedule(timeOutTask, MultiTracker.MaxKnockInterval) - - logInfo("TalkToPeer started... => " + peerToTalkTo) - - try { - // Connect to the source - peerSocketToSource = - new Socket(peerToTalkTo.hostAddress, peerToTalkTo.listenPort) - oosSource = - new ObjectOutputStream(peerSocketToSource.getOutputStream) - oosSource.flush() - oisSource = - new ObjectInputStream(peerSocketToSource.getInputStream) - - // Receive latest SourceInfo from peerToTalkTo - var newPeerToTalkTo = oisSource.readObject.asInstanceOf[SourceInfo] - // Update listOfSources - addToListOfSources(newPeerToTalkTo) - - // Turn the timer OFF, if the sender responds before timeout - timeOutTimer.cancel() - - // Send the latest SourceInfo - oosSource.writeObject(getLocalSourceInfo) - oosSource.flush() - - var keepReceiving = true - - while (hasBlocks.get < totalBlocks && keepReceiving) { - blockToAskFor = - pickBlockRandom(newPeerToTalkTo.hasBlocksBitVector) - - // No block to request - if (blockToAskFor < 0) { - // Nothing to receive from newPeerToTalkTo - keepReceiving = false - } else { - // Let other threads know that blockToAskFor is being requested - blocksInRequestBitVector.synchronized { - blocksInRequestBitVector.set(blockToAskFor) - } - - // Start with sending the blockID - oosSource.writeObject(blockToAskFor) - oosSource.flush() - - // CHANGED: Driver might send some other block than the one - // requested to ensure fast spreading of all blocks. 
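
// [Editorial illustration, not part of the original patch.] Sketch of the weighted
// selection performed by pickPeerToTalkToRarestFirst above: each candidate peer is
// chosen with probability proportional to how many rare blocks it holds.
object WeightedPickSketch {
  def pick[T](weighted: Seq[(T, Int)], rand: Double): Option[T] = {
    val total = weighted.map(_._2).sum.toDouble
    var acc = 0.0
    for ((item, weight) <- weighted) {
      acc += weight / total
      if (acc >= rand) return Some(item) // cumulative weight crossed the sample point
    }
    None
  }

  def main(args: Array[String]) {
    // "b" is picked roughly three times as often as "a".
    println(pick(Seq(("a", 1), ("b", 3)), scala.util.Random.nextDouble))
  }
}
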
- val recvStartTime = System.currentTimeMillis - val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] - val receptionTime = (System.currentTimeMillis - recvStartTime) - - logDebug("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.") - - if (!hasBlocksBitVector.get(bcBlock.blockID)) { - arrayOfBlocks(bcBlock.blockID) = bcBlock - - // Update the hasBlocksBitVector first - hasBlocksBitVector.synchronized { - hasBlocksBitVector.set(bcBlock.blockID) - hasBlocks.getAndIncrement - } - - // Some block(may NOT be blockToAskFor) has arrived. - // In any case, blockToAskFor is not in request any more - blocksInRequestBitVector.synchronized { - blocksInRequestBitVector.set(blockToAskFor, false) - } - - // Reset blockToAskFor to -1. Else it will be considered missing - blockToAskFor = -1 - } - - // Send the latest SourceInfo - oosSource.writeObject(getLocalSourceInfo) - oosSource.flush() - } - } - } catch { - // EOFException is expected to happen because sender can break - // connection due to timeout - case eofe: java.io.EOFException => { } - case e: Exception => { - logError("TalktoPeer had a " + e) - // FIXME: Remove 'newPeerToTalkTo' from listOfSources - // We probably should have the following in some form, but not - // really here. This exception can happen if the sender just breaks connection - // listOfSources.synchronized { - // logInfo("Exception in TalkToPeer. Removing source: " + peerToTalkTo) - // listOfSources = listOfSources - peerToTalkTo - // } - } - } finally { - // blockToAskFor != -1 => there was an exception - if (blockToAskFor != -1) { - blocksInRequestBitVector.synchronized { - blocksInRequestBitVector.set(blockToAskFor, false) - } - } - - cleanUpConnections() - } - } - - // Right now it picks a block uniformly that this peer does not have - private def pickBlockRandom(txHasBlocksBitVector: BitSet): Int = { - var needBlocksBitVector: BitSet = null - - // Blocks already present - hasBlocksBitVector.synchronized { - needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] - } - - // Include blocks already in transmission ONLY IF - // MultiTracker.EndGameFraction has NOT been achieved - if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) { - blocksInRequestBitVector.synchronized { - needBlocksBitVector.or(blocksInRequestBitVector) - } - } - - // Find blocks that are neither here nor in transit - needBlocksBitVector.flip(0, needBlocksBitVector.size) - - // Blocks that should/can be requested - needBlocksBitVector.and(txHasBlocksBitVector) - - if (needBlocksBitVector.cardinality == 0) { - return -1 - } else { - // Pick uniformly the i'th required block - var i = MultiTracker.ranGen.nextInt(needBlocksBitVector.cardinality) - var pickedBlockIndex = needBlocksBitVector.nextSetBit(0) - - while (i > 0) { - pickedBlockIndex = - needBlocksBitVector.nextSetBit(pickedBlockIndex + 1) - i -= 1 - } - - return pickedBlockIndex - } - } - - // Pick the block that seems to be the rarest across sources - private def pickBlockRarestFirst(txHasBlocksBitVector: BitSet): Int = { - var needBlocksBitVector: BitSet = null - - // Blocks already present - hasBlocksBitVector.synchronized { - needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] - } - - // Include blocks already in transmission ONLY IF - // MultiTracker.EndGameFraction has NOT been achieved - if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) { - blocksInRequestBitVector.synchronized { - 
needBlocksBitVector.or(blocksInRequestBitVector) - } - } - - // Find blocks that are neither here nor in transit - needBlocksBitVector.flip(0, needBlocksBitVector.size) - - // Blocks that should/can be requested - needBlocksBitVector.and(txHasBlocksBitVector) - - if (needBlocksBitVector.cardinality == 0) { - return -1 - } else { - // Count the number of copies for each block across all sources - var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0) - - listOfSources.synchronized { - listOfSources.foreach { eachSource => - for (i <- 0 until totalBlocks) { - numCopiesPerBlock(i) += - ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 ) - } - } - } - - // Find the minimum - var minVal = Integer.MAX_VALUE - for (i <- 0 until totalBlocks) { - if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) < minVal) { - minVal = numCopiesPerBlock(i) - } - } - - // Find the blocks with the least copies that this peer does not have - var minBlocksIndices = ListBuffer[Int]() - for (i <- 0 until totalBlocks) { - if (needBlocksBitVector.get(i) && numCopiesPerBlock(i) == minVal) { - minBlocksIndices += i - } - } - - // Now select a random index from minBlocksIndices - if (minBlocksIndices.size == 0) { - return -1 - } else { - // Pick uniformly the i'th index - var i = MultiTracker.ranGen.nextInt(minBlocksIndices.size) - return minBlocksIndices(i) - } - } - } - - private def cleanUpConnections() { - if (oisSource != null) { - oisSource.close() - } - if (oosSource != null) { - oosSource.close() - } - if (peerSocketToSource != null) { - peerSocketToSource.close() - } - - // Delete from peersNowTalking - peersNowTalking.synchronized { peersNowTalking -= peerToTalkTo } - } - } - } - - class GuideMultipleRequests - extends Thread with Logging { - // Keep track of sources that have completed reception - private var setOfCompletedSources = Set[SourceInfo]() - - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(0) - guidePort = serverSocket.getLocalPort - logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) - - guidePortLock.synchronized { guidePortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { - // Stop broadcast if at least one worker has connected and - // everyone connected so far are done. Comparing with - // listOfSources.size - 1, because it includes the Guide itself - listOfSources.synchronized { - setOfCompletedSources.synchronized { - if (listOfSources.size > 1 && - setOfCompletedSources.size == listOfSources.size - 1) { - stopBroadcast = true - logInfo("GuideMultipleRequests Timeout. 
stopBroadcast == true.") - } - } - } - } - } - if (clientSocket != null) { - logDebug("Guide: Accepted new client connection:" + clientSocket) - try { - threadPool.execute(new GuideSingleRequest(clientSocket)) - } catch { - // In failure, close the socket here; else, thread will close it - case ioe: IOException => { - clientSocket.close() - } - } - } - } - - // Shutdown the thread pool - threadPool.shutdown() - - logInfo("Sending stopBroadcast notifications...") - sendStopBroadcastNotifications - - MultiTracker.unregisterBroadcast(id) - } finally { - if (serverSocket != null) { - logInfo("GuideMultipleRequests now stopping...") - serverSocket.close() - } - } - } - - private def sendStopBroadcastNotifications() { - listOfSources.synchronized { - listOfSources.foreach { sourceInfo => - - var guideSocketToSource: Socket = null - var gosSource: ObjectOutputStream = null - var gisSource: ObjectInputStream = null - - try { - // Connect to the source - guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream) - gosSource.flush() - gisSource = new ObjectInputStream(guideSocketToSource.getInputStream) - - // Throw away whatever comes in - gisSource.readObject.asInstanceOf[SourceInfo] - - // Send stopBroadcast signal. listenPort = SourceInfo.StopBroadcast - gosSource.writeObject(SourceInfo("", SourceInfo.StopBroadcast)) - gosSource.flush() - } catch { - case e: Exception => { - logError("sendStopBroadcastNotifications had a " + e) - } - } finally { - if (gisSource != null) { - gisSource.close() - } - if (gosSource != null) { - gosSource.close() - } - if (guideSocketToSource != null) { - guideSocketToSource.close() - } - } - } - } - } - - class GuideSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - private var sourceInfo: SourceInfo = null - private var selectedSources: ListBuffer[SourceInfo] = null - - override def run() { - try { - logInfo("new GuideSingleRequest is running") - // Connecting worker is sending in its information - sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - // Select a suitable source and send it back to the worker - selectedSources = selectSuitableSources(sourceInfo) - logDebug("Sending selectedSources:" + selectedSources) - oos.writeObject(selectedSources) - oos.flush() - - // Add this source to the listOfSources - addToListOfSources(sourceInfo) - } catch { - case e: Exception => { - // Assuming exception caused by receiver failure: remove - if (listOfSources != null) { - listOfSources.synchronized { listOfSources -= sourceInfo } - } - } - } finally { - logInfo("GuideSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - // Randomly select some sources to send back - private def selectSuitableSources(skipSourceInfo: SourceInfo): ListBuffer[SourceInfo] = { - var selectedSources = ListBuffer[SourceInfo]() - - // If skipSourceInfo.hasBlocksBitVector has all bits set to 'true' - // then add skipSourceInfo to setOfCompletedSources. Return blank. 
- if (skipSourceInfo.hasBlocks == totalBlocks) { - setOfCompletedSources.synchronized { setOfCompletedSources += skipSourceInfo } - return selectedSources - } - - listOfSources.synchronized { - if (listOfSources.size <= MultiTracker.MaxPeersInGuideResponse) { - selectedSources = listOfSources.clone - } else { - var picksLeft = MultiTracker.MaxPeersInGuideResponse - var alreadyPicked = new BitSet(listOfSources.size) - - while (picksLeft > 0) { - var i = -1 - - do { - i = MultiTracker.ranGen.nextInt(listOfSources.size) - } while (alreadyPicked.get(i)) - - var peerIter = listOfSources.iterator - var curPeer = peerIter.next - - // Set the BitSet before i is decremented - alreadyPicked.set(i) - - while (i > 0) { - curPeer = peerIter.next - i = i - 1 - } - - selectedSources += curPeer - - picksLeft = picksLeft - 1 - } - } - } - - // Remove the receiving source (if present) - selectedSources = selectedSources - skipSourceInfo - - return selectedSources - } - } - } - - class ServeMultipleRequests - extends Thread with Logging { - // Server at most MultiTracker.MaxChatSlots peers - var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots) - - override def run() { - var serverSocket = new ServerSocket(0) - listenPort = serverSocket.getLocalPort - - logInfo("ServeMultipleRequests started with " + serverSocket) - - listenPortLock.synchronized { listenPortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { } - } - if (clientSocket != null) { - logDebug("Serve: Accepted new client connection:" + clientSocket) - try { - threadPool.execute(new ServeSingleRequest(clientSocket)) - } catch { - // In failure, close socket here; else, the thread will close it - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - if (serverSocket != null) { - logInfo("ServeMultipleRequests now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - class ServeSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - logInfo("new ServeSingleRequest is running") - - override def run() { - try { - // Send latest local SourceInfo to the receiver - // In the case of receiver timeout and connection close, this will - // throw a java.net.SocketException: Broken pipe - oos.writeObject(getLocalSourceInfo) - oos.flush() - - // Receive latest SourceInfo from the receiver - var rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) { - stopBroadcast = true - } else { - addToListOfSources(rxSourceInfo) - } - - val startTime = System.currentTimeMillis - var curTime = startTime - var keepSending = true - var numBlocksToSend = MultiTracker.MaxChatBlocks - - while (!stopBroadcast && keepSending && numBlocksToSend > 0) { - // Receive which block to send - var blockToSend = ois.readObject.asInstanceOf[Int] - - // If it is driver AND at least one copy of each block has not been - // sent out already, MODIFY blockToSend - if (MultiTracker.isDriver && sentBlocks.get < totalBlocks) { - blockToSend = sentBlocks.getAndIncrement - } - - // Send the block - sendBlock(blockToSend) - rxSourceInfo.hasBlocksBitVector.set(blockToSend) - - numBlocksToSend -= 
1 - - // Receive latest SourceInfo from the receiver - rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo] - logDebug("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector) - addToListOfSources(rxSourceInfo) - - curTime = System.currentTimeMillis - // Revoke sending only if there is anyone waiting in the queue - if (curTime - startTime >= MultiTracker.MaxChatTime && - threadPool.getQueue.size > 0) { - keepSending = false - } - } - } catch { - case e: Exception => logError("ServeSingleRequest had a " + e) - } finally { - logInfo("ServeSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - private def sendBlock(blockToSend: Int) { - try { - oos.writeObject(arrayOfBlocks(blockToSend)) - oos.flush() - } catch { - case e: Exception => logError("sendBlock had a " + e) - } - logDebug("Sent block: " + blockToSend + " to " + clientSocket) - } - } - } -} - -private[spark] class BitTorrentBroadcastFactory -extends BroadcastFactory { - def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new BitTorrentBroadcast[T](value_, isLocal, id) - - def stop() { MultiTracker.stop() } -} diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala deleted file mode 100644 index aba56a60ca..0000000000 --- a/core/src/main/scala/spark/broadcast/Broadcast.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.broadcast - -import java.io._ -import java.util.concurrent.atomic.AtomicLong - -import spark._ - -abstract class Broadcast[T](private[spark] val id: Long) extends Serializable { - def value: T - - // We cannot have an abstract readObject here due to some weird issues with - // readObject having to be 'private' in sub-classes. 
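// A hedged illustration of the note above: Java serialization only invokes
// readObject when it is declared private in the concrete class, so each
// Broadcast subclass re-declares the hook itself instead of overriding an
// abstract one. ExampleBroadcast is hypothetical, not part of this patch:
import java.io.ObjectInputStream
class ExampleBroadcast[T](@transient var value_ : T, id: Long) extends Broadcast[T](id) {
  def value = value_
  private def readObject(in: ObjectInputStream) {
    in.defaultReadObject()
    // Typically re-fetches value_ from block storage or the network here,
    // as HttpBroadcast and TreeBroadcast do in their own readObject hooks.
  }
}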
- - override def toString = "spark.Broadcast(" + id + ")" -} - -private[spark] -class BroadcastManager(val _isDriver: Boolean) extends Logging with Serializable { - - private var initialized = false - private var broadcastFactory: BroadcastFactory = null - - initialize() - - // Called by SparkContext or Executor before using Broadcast - private def initialize() { - synchronized { - if (!initialized) { - val broadcastFactoryClass = System.getProperty( - "spark.broadcast.factory", "spark.broadcast.HttpBroadcastFactory") - - broadcastFactory = - Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] - - // Initialize appropriate BroadcastFactory and BroadcastObject - broadcastFactory.initialize(isDriver) - - initialized = true - } - } - } - - def stop() { - broadcastFactory.stop() - } - - private val nextBroadcastId = new AtomicLong(0) - - def newBroadcast[T](value_ : T, isLocal: Boolean) = - broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) - - def isDriver = _isDriver -} diff --git a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala deleted file mode 100644 index d33d95c7d9..0000000000 --- a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.broadcast - -/** - * An interface for all the broadcast implementations in Spark (to allow - * multiple broadcast implementations). SparkContext uses a user-specified - * BroadcastFactory implementation to instantiate a particular broadcast for the - * entire Spark job. - */ -private[spark] trait BroadcastFactory { - def initialize(isDriver: Boolean): Unit - def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] - def stop(): Unit -} diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala deleted file mode 100644 index 138a8c21bc..0000000000 --- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.broadcast - -import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream} -import java.net.URL - -import it.unimi.dsi.fastutil.io.FastBufferedInputStream -import it.unimi.dsi.fastutil.io.FastBufferedOutputStream - -import spark.{HttpServer, Logging, SparkEnv, Utils} -import spark.io.CompressionCodec -import spark.storage.StorageLevel -import spark.util.{MetadataCleaner, TimeStampedHashSet} - - -private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) - extends Broadcast[T](id) with Logging with Serializable { - - def value = value_ - - def blockId: String = "broadcast_" + id - - HttpBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } - - if (!isLocal) { - HttpBroadcast.write(id, value_) - } - - // Called by JVM when deserializing an object - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() - HttpBroadcast.synchronized { - SparkEnv.get.blockManager.getSingle(blockId) match { - case Some(x) => value_ = x.asInstanceOf[T] - case None => { - logInfo("Started reading broadcast variable " + id) - val start = System.nanoTime - value_ = HttpBroadcast.read[T](id) - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading broadcast variable " + id + " took " + time + " s") - } - } - } - } -} - -private[spark] class HttpBroadcastFactory extends BroadcastFactory { - def initialize(isDriver: Boolean) { HttpBroadcast.initialize(isDriver) } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new HttpBroadcast[T](value_, isLocal, id) - - def stop() { HttpBroadcast.stop() } -} - -private object HttpBroadcast extends Logging { - private var initialized = false - - private var broadcastDir: File = null - private var compress: Boolean = false - private var bufferSize: Int = 65536 - private var serverUri: String = null - private var server: HttpServer = null - - private val files = new TimeStampedHashSet[String] - private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup) - - private lazy val compressionCodec = CompressionCodec.createCodec() - - def initialize(isDriver: Boolean) { - synchronized { - if (!initialized) { - bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - compress = System.getProperty("spark.broadcast.compress", "true").toBoolean - if (isDriver) { - createServer() - } - serverUri = System.getProperty("spark.httpBroadcast.uri") - initialized = true - } - } - } - - def stop() { - synchronized { - if (server != null) { - server.stop() - server = null - } - initialized = false - cleaner.cancel() - } - } - - private def createServer() { - broadcastDir = Utils.createTempDir(Utils.getLocalDir) - server = new HttpServer(broadcastDir) - server.start() - serverUri = server.uri - System.setProperty("spark.httpBroadcast.uri", serverUri) - logInfo("Broadcast server started at " + serverUri) - } - - def write(id: Long, value: Any) { - val file = new File(broadcastDir, "broadcast-" + id) - val out: 
OutputStream = { - if (compress) { - compressionCodec.compressedOutputStream(new FileOutputStream(file)) - } else { - new FastBufferedOutputStream(new FileOutputStream(file), bufferSize) - } - } - val ser = SparkEnv.get.serializer.newInstance() - val serOut = ser.serializeStream(out) - serOut.writeObject(value) - serOut.close() - files += file.getAbsolutePath - } - - def read[T](id: Long): T = { - val url = serverUri + "/broadcast-" + id - val in = { - if (compress) { - compressionCodec.compressedInputStream(new URL(url).openStream()) - } else { - new FastBufferedInputStream(new URL(url).openStream(), bufferSize) - } - } - val ser = SparkEnv.get.serializer.newInstance() - val serIn = ser.deserializeStream(in) - val obj = serIn.readObject[T]() - serIn.close() - obj - } - - def cleanup(cleanupTime: Long) { - val iterator = files.internalMap.entrySet().iterator() - while(iterator.hasNext) { - val entry = iterator.next() - val (file, time) = (entry.getKey, entry.getValue) - if (time < cleanupTime) { - try { - iterator.remove() - new File(file.toString).delete() - logInfo("Deleted broadcast file '" + file + "'") - } catch { - case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e) - } - } - } - } -} diff --git a/core/src/main/scala/spark/broadcast/MultiTracker.scala b/core/src/main/scala/spark/broadcast/MultiTracker.scala deleted file mode 100644 index 7855d44e9b..0000000000 --- a/core/src/main/scala/spark/broadcast/MultiTracker.scala +++ /dev/null @@ -1,409 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.broadcast - -import java.io._ -import java.net._ -import java.util.Random - -import scala.collection.mutable.Map - -import spark._ - -private object MultiTracker -extends Logging { - - // Tracker Messages - val REGISTER_BROADCAST_TRACKER = 0 - val UNREGISTER_BROADCAST_TRACKER = 1 - val FIND_BROADCAST_TRACKER = 2 - - // Map to keep track of guides of ongoing broadcasts - var valueToGuideMap = Map[Long, SourceInfo]() - - // Random number generator - var ranGen = new Random - - private var initialized = false - private var _isDriver = false - - private var stopBroadcast = false - - private var trackMV: TrackMultipleValues = null - - def initialize(__isDriver: Boolean) { - synchronized { - if (!initialized) { - _isDriver = __isDriver - - if (isDriver) { - trackMV = new TrackMultipleValues - trackMV.setDaemon(true) - trackMV.start() - - // Set DriverHostAddress to the driver's IP address for the slaves to read - System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress) - } - - initialized = true - } - } - } - - def stop() { - stopBroadcast = true - } - - // Load common parameters - private var DriverHostAddress_ = System.getProperty( - "spark.MultiTracker.DriverHostAddress", "") - private var DriverTrackerPort_ = System.getProperty( - "spark.broadcast.driverTrackerPort", "11111").toInt - private var BlockSize_ = System.getProperty( - "spark.broadcast.blockSize", "4096").toInt * 1024 - private var MaxRetryCount_ = System.getProperty( - "spark.broadcast.maxRetryCount", "2").toInt - - private var TrackerSocketTimeout_ = System.getProperty( - "spark.broadcast.trackerSocketTimeout", "50000").toInt - private var ServerSocketTimeout_ = System.getProperty( - "spark.broadcast.serverSocketTimeout", "10000").toInt - - private var MinKnockInterval_ = System.getProperty( - "spark.broadcast.minKnockInterval", "500").toInt - private var MaxKnockInterval_ = System.getProperty( - "spark.broadcast.maxKnockInterval", "999").toInt - - // Load TreeBroadcast config params - private var MaxDegree_ = System.getProperty( - "spark.broadcast.maxDegree", "2").toInt - - // Load BitTorrentBroadcast config params - private var MaxPeersInGuideResponse_ = System.getProperty( - "spark.broadcast.maxPeersInGuideResponse", "4").toInt - - private var MaxChatSlots_ = System.getProperty( - "spark.broadcast.maxChatSlots", "4").toInt - private var MaxChatTime_ = System.getProperty( - "spark.broadcast.maxChatTime", "500").toInt - private var MaxChatBlocks_ = System.getProperty( - "spark.broadcast.maxChatBlocks", "1024").toInt - - private var EndGameFraction_ = System.getProperty( - "spark.broadcast.endGameFraction", "0.95").toDouble - - def isDriver = _isDriver - - // Common config params - def DriverHostAddress = DriverHostAddress_ - def DriverTrackerPort = DriverTrackerPort_ - def BlockSize = BlockSize_ - def MaxRetryCount = MaxRetryCount_ - - def TrackerSocketTimeout = TrackerSocketTimeout_ - def ServerSocketTimeout = ServerSocketTimeout_ - - def MinKnockInterval = MinKnockInterval_ - def MaxKnockInterval = MaxKnockInterval_ - - // TreeBroadcast configs - def MaxDegree = MaxDegree_ - - // BitTorrentBroadcast configs - def MaxPeersInGuideResponse = MaxPeersInGuideResponse_ - - def MaxChatSlots = MaxChatSlots_ - def MaxChatTime = MaxChatTime_ - def MaxChatBlocks = MaxChatBlocks_ - - def EndGameFraction = EndGameFraction_ - - class TrackMultipleValues - extends Thread with Logging { - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = 
null - - serverSocket = new ServerSocket(DriverTrackerPort) - logInfo("TrackMultipleValues started at " + serverSocket) - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(TrackerSocketTimeout) - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { - if (stopBroadcast) { - logInfo("Stopping TrackMultipleValues...") - } - } - } - - if (clientSocket != null) { - try { - threadPool.execute(new Thread { - override def run() { - val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - val ois = new ObjectInputStream(clientSocket.getInputStream) - - try { - // First, read message type - val messageType = ois.readObject.asInstanceOf[Int] - - if (messageType == REGISTER_BROADCAST_TRACKER) { - // Receive Long - val id = ois.readObject.asInstanceOf[Long] - // Receive hostAddress and listenPort - val gInfo = ois.readObject.asInstanceOf[SourceInfo] - - // Add to the map - valueToGuideMap.synchronized { - valueToGuideMap += (id -> gInfo) - } - - logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap) - - // Send dummy ACK - oos.writeObject(-1) - oos.flush() - } else if (messageType == UNREGISTER_BROADCAST_TRACKER) { - // Receive Long - val id = ois.readObject.asInstanceOf[Long] - - // Remove from the map - valueToGuideMap.synchronized { - valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault) - } - - logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap) - - // Send dummy ACK - oos.writeObject(-1) - oos.flush() - } else if (messageType == FIND_BROADCAST_TRACKER) { - // Receive Long - val id = ois.readObject.asInstanceOf[Long] - - var gInfo = - if (valueToGuideMap.contains(id)) valueToGuideMap(id) - else SourceInfo("", SourceInfo.TxNotStartedRetry) - - logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort) - - // Send reply back - oos.writeObject(gInfo) - oos.flush() - } else { - throw new SparkException("Undefined messageType at TrackMultipleValues") - } - } catch { - case e: Exception => { - logError("TrackMultipleValues had a " + e) - } - } finally { - ois.close() - oos.close() - clientSocket.close() - } - } - }) - } catch { - // In failure, close socket here; else, client thread will close - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - serverSocket.close() - } - // Shutdown the thread pool - threadPool.shutdown() - } - } - - def getGuideInfo(variableLong: Long): SourceInfo = { - var clientSocketToTracker: Socket = null - var oosTracker: ObjectOutputStream = null - var oisTracker: ObjectInputStream = null - - var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxNotStartedRetry) - - var retriesLeft = MultiTracker.MaxRetryCount - do { - try { - // Connect to the tracker to find out GuideInfo - clientSocketToTracker = - new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort) - oosTracker = - new ObjectOutputStream(clientSocketToTracker.getOutputStream) - oosTracker.flush() - oisTracker = - new ObjectInputStream(clientSocketToTracker.getInputStream) - - // Send messageType/intention - oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER) - oosTracker.flush() - - // Send Long and receive GuideInfo - oosTracker.writeObject(variableLong) - oosTracker.flush() - gInfo = oisTracker.readObject.asInstanceOf[SourceInfo] - } catch { - case e: Exception => logError("getGuideInfo had a " + e) - } finally { - if (oisTracker != 
null) { - oisTracker.close() - } - if (oosTracker != null) { - oosTracker.close() - } - if (clientSocketToTracker != null) { - clientSocketToTracker.close() - } - } - - Thread.sleep(MultiTracker.ranGen.nextInt( - MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) + - MultiTracker.MinKnockInterval) - - retriesLeft -= 1 - } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry) - - logDebug("Got this guidePort from Tracker: " + gInfo.listenPort) - return gInfo - } - - def registerBroadcast(id: Long, gInfo: SourceInfo) { - val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort) - val oosST = new ObjectOutputStream(socket.getOutputStream) - oosST.flush() - val oisST = new ObjectInputStream(socket.getInputStream) - - // Send messageType/intention - oosST.writeObject(REGISTER_BROADCAST_TRACKER) - oosST.flush() - - // Send Long of this broadcast - oosST.writeObject(id) - oosST.flush() - - // Send this tracker's information - oosST.writeObject(gInfo) - oosST.flush() - - // Receive ACK and throw it away - oisST.readObject.asInstanceOf[Int] - - // Shut stuff down - oisST.close() - oosST.close() - socket.close() - } - - def unregisterBroadcast(id: Long) { - val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort) - val oosST = new ObjectOutputStream(socket.getOutputStream) - oosST.flush() - val oisST = new ObjectInputStream(socket.getInputStream) - - // Send messageType/intention - oosST.writeObject(UNREGISTER_BROADCAST_TRACKER) - oosST.flush() - - // Send Long of this broadcast - oosST.writeObject(id) - oosST.flush() - - // Receive ACK and throw it away - oisST.readObject.asInstanceOf[Int] - - // Shut stuff down - oisST.close() - oosST.close() - socket.close() - } - - // Helper method to convert an object to Array[BroadcastBlock] - def blockifyObject[IN](obj: IN): VariableInfo = { - val baos = new ByteArrayOutputStream - val oos = new ObjectOutputStream(baos) - oos.writeObject(obj) - oos.close() - baos.close() - val byteArray = baos.toByteArray - val bais = new ByteArrayInputStream(byteArray) - - var blockNum = (byteArray.length / BlockSize) - if (byteArray.length % BlockSize != 0) - blockNum += 1 - - var retVal = new Array[BroadcastBlock](blockNum) - var blockID = 0 - - for (i <- 0 until (byteArray.length, BlockSize)) { - val thisBlockSize = math.min(BlockSize, byteArray.length - i) - var tempByteArray = new Array[Byte](thisBlockSize) - val hasRead = bais.read(tempByteArray, 0, thisBlockSize) - - retVal(blockID) = new BroadcastBlock(blockID, tempByteArray) - blockID += 1 - } - bais.close() - - var variableInfo = VariableInfo(retVal, blockNum, byteArray.length) - variableInfo.hasBlocks = blockNum - - return variableInfo - } - - // Helper method to convert Array[BroadcastBlock] to object - def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock], - totalBytes: Int, - totalBlocks: Int): OUT = { - - var retByteArray = new Array[Byte](totalBytes) - for (i <- 0 until totalBlocks) { - System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, - i * BlockSize, arrayOfBlocks(i).byteArray.length) - } - byteArrayToObject(retByteArray) - } - - private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = { - val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){ - override def resolveClass(desc: ObjectStreamClass) = - Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader) - } - val retVal = in.readObject.asInstanceOf[OUT] - in.close() - return retVal - } -} - -private[spark] case class 
BroadcastBlock(blockID: Int, byteArray: Array[Byte]) -extends Serializable - -private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock], - totalBlocks: Int, - totalBytes: Int) -extends Serializable { - @transient var hasBlocks = 0 -} diff --git a/core/src/main/scala/spark/broadcast/SourceInfo.scala b/core/src/main/scala/spark/broadcast/SourceInfo.scala deleted file mode 100644 index b17ae63b5c..0000000000 --- a/core/src/main/scala/spark/broadcast/SourceInfo.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.broadcast - -import java.util.BitSet - -import spark._ - -/** - * Used to keep and pass around information of peers involved in a broadcast - */ -private[spark] case class SourceInfo (hostAddress: String, - listenPort: Int, - totalBlocks: Int = SourceInfo.UnusedParam, - totalBytes: Int = SourceInfo.UnusedParam) -extends Comparable[SourceInfo] with Logging { - - var currentLeechers = 0 - var receptionFailed = false - - var hasBlocks = 0 - var hasBlocksBitVector: BitSet = new BitSet (totalBlocks) - - // Ascending sort based on leecher count - def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers) -} - -/** - * Helper Object of SourceInfo for its constants - */ -private[spark] object SourceInfo { - // Broadcast has not started yet! Should never happen. - val TxNotStartedRetry = -1 - // Broadcast has already finished. Try default mechanism. - val TxOverGoToDefault = -3 - // Other constants - val StopBroadcast = -2 - val UnusedParam = 0 -} diff --git a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala deleted file mode 100644 index ea1e9a12c1..0000000000 --- a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala +++ /dev/null @@ -1,602 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.broadcast - -import java.io._ -import java.net._ -import java.util.{Comparator, Random, UUID} - -import scala.collection.mutable.{ListBuffer, Map, Set} -import scala.math - -import spark._ -import spark.storage.StorageLevel - -private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) -extends Broadcast[T](id) with Logging with Serializable { - - def value = value_ - - def blockId = "broadcast_" + id - - MultiTracker.synchronized { - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } - - @transient var arrayOfBlocks: Array[BroadcastBlock] = null - @transient var totalBytes = -1 - @transient var totalBlocks = -1 - @transient var hasBlocks = 0 - - @transient var listenPortLock = new Object - @transient var guidePortLock = new Object - @transient var totalBlocksLock = new Object - @transient var hasBlocksLock = new Object - - @transient var listOfSources = ListBuffer[SourceInfo]() - - @transient var serveMR: ServeMultipleRequests = null - @transient var guideMR: GuideMultipleRequests = null - - @transient var hostAddress = Utils.localIpAddress - @transient var listenPort = -1 - @transient var guidePort = -1 - - @transient var stopBroadcast = false - - // Must call this after all the variables have been created/initialized - if (!isLocal) { - sendBroadcast() - } - - def sendBroadcast() { - logInfo("Local host address: " + hostAddress) - - // Create a variableInfo object and store it in valueInfos - var variableInfo = MultiTracker.blockifyObject(value_) - - // Prepare the value being broadcasted - arrayOfBlocks = variableInfo.arrayOfBlocks - totalBytes = variableInfo.totalBytes - totalBlocks = variableInfo.totalBlocks - hasBlocks = variableInfo.totalBlocks - - guideMR = new GuideMultipleRequests - guideMR.setDaemon(true) - guideMR.start() - logInfo("GuideMultipleRequests started...") - - // Must always come AFTER guideMR is created - while (guidePort == -1) { - guidePortLock.synchronized { guidePortLock.wait() } - } - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - // Must always come AFTER serveMR is created - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Must always come AFTER listenPort is created - val masterSource = - SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes) - listOfSources += masterSource - - // Register with the Tracker - MultiTracker.registerBroadcast(id, - SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes)) - } - - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() - MultiTracker.synchronized { - SparkEnv.get.blockManager.getSingle(blockId) match { - case Some(x) => - value_ = x.asInstanceOf[T] - - case None => - logInfo("Started reading broadcast variable " + id) - // Initializing everything because Driver will only send null/0 values - // Only the 1st worker in a node can be here. 
Others will get from cache - initializeWorkerVariables() - - logInfo("Local host address: " + hostAddress) - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - val start = System.nanoTime - - val receptionSucceeded = receiveBroadcast(id) - if (receptionSucceeded) { - value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - SparkEnv.get.blockManager.putSingle( - blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } else { - logError("Reading broadcast variable " + id + " failed") - } - - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading broadcast variable " + id + " took " + time + " s") - } - } - } - - private def initializeWorkerVariables() { - arrayOfBlocks = null - totalBytes = -1 - totalBlocks = -1 - hasBlocks = 0 - - listenPortLock = new Object - totalBlocksLock = new Object - hasBlocksLock = new Object - - serveMR = null - - hostAddress = Utils.localIpAddress - listenPort = -1 - - stopBroadcast = false - } - - def receiveBroadcast(variableID: Long): Boolean = { - val gInfo = MultiTracker.getGuideInfo(variableID) - - if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) { - return false - } - - // Wait until hostAddress and listenPort are created by the - // ServeMultipleRequests thread - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - var clientSocketToDriver: Socket = null - var oosDriver: ObjectOutputStream = null - var oisDriver: ObjectInputStream = null - - // Connect and receive broadcast from the specified source, retrying the - // specified number of times in case of failures - var retriesLeft = MultiTracker.MaxRetryCount - do { - // Connect to Driver and send this worker's Information - clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort) - oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream) - oosDriver.flush() - oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream) - - logDebug("Connected to Driver's guiding object") - - // Send local source information - oosDriver.writeObject(SourceInfo(hostAddress, listenPort)) - oosDriver.flush() - - // Receive source information from Driver - var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo] - totalBlocks = sourceInfo.totalBlocks - arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) - totalBlocksLock.synchronized { totalBlocksLock.notifyAll() } - totalBytes = sourceInfo.totalBytes - - logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort) - - val start = System.nanoTime - val receptionSucceeded = receiveSingleTransmission(sourceInfo) - val time = (System.nanoTime - start) / 1e9 - - // Updating some statistics in sourceInfo. Driver will be using them later - if (!receptionSucceeded) { - sourceInfo.receptionFailed = true - } - - // Send back statistics to the Driver - oosDriver.writeObject(sourceInfo) - - if (oisDriver != null) { - oisDriver.close() - } - if (oosDriver != null) { - oosDriver.close() - } - if (clientSocketToDriver != null) { - clientSocketToDriver.close() - } - - retriesLeft -= 1 - } while (retriesLeft > 0 && hasBlocks < totalBlocks) - - return (hasBlocks == totalBlocks) - } - - /** - * Tries to receive broadcast from the source and returns Boolean status. - * This might be called multiple times to retry a defined number of times. 
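 *
 * A minimal sketch of the bounded-retry driver around this call (hedged;
 * names such as source are illustrative, not from this patch):
 * {{{
 *   var retriesLeft = MultiTracker.MaxRetryCount
 *   var done = false
 *   while (!done && retriesLeft > 0) {
 *     done = receiveSingleTransmission(source)
 *     retriesLeft -= 1
 *   }
 * }}}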
- */ - private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = { - var clientSocketToSource: Socket = null - var oosSource: ObjectOutputStream = null - var oisSource: ObjectInputStream = null - - var receptionSucceeded = false - try { - // Connect to the source to get the object itself - clientSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - oosSource = new ObjectOutputStream(clientSocketToSource.getOutputStream) - oosSource.flush() - oisSource = new ObjectInputStream(clientSocketToSource.getInputStream) - - logDebug("Inside receiveSingleTransmission") - logDebug("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks) - - // Send the range - oosSource.writeObject((hasBlocks, totalBlocks)) - oosSource.flush() - - for (i <- hasBlocks until totalBlocks) { - val recvStartTime = System.currentTimeMillis - val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] - val receptionTime = (System.currentTimeMillis - recvStartTime) - - logDebug("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.") - - arrayOfBlocks(hasBlocks) = bcBlock - hasBlocks += 1 - - // Set to true if at least one block is received - receptionSucceeded = true - hasBlocksLock.synchronized { hasBlocksLock.notifyAll() } - } - } catch { - case e: Exception => logError("receiveSingleTransmission had a " + e) - } finally { - if (oisSource != null) { - oisSource.close() - } - if (oosSource != null) { - oosSource.close() - } - if (clientSocketToSource != null) { - clientSocketToSource.close() - } - } - - return receptionSucceeded - } - - class GuideMultipleRequests - extends Thread with Logging { - // Keep track of sources that have completed reception - private var setOfCompletedSources = Set[SourceInfo]() - - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(0) - guidePort = serverSocket.getLocalPort - logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) - - guidePortLock.synchronized { guidePortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { - // Stop broadcast if at least one worker has connected and - // everyone connected so far are done. Comparing with - // listOfSources.size - 1, because it includes the Guide itself - listOfSources.synchronized { - setOfCompletedSources.synchronized { - if (listOfSources.size > 1 && - setOfCompletedSources.size == listOfSources.size - 1) { - stopBroadcast = true - logInfo("GuideMultipleRequests Timeout. 
stopBroadcast == true.") - } - } - } - } - } - if (clientSocket != null) { - logDebug("Guide: Accepted new client connection: " + clientSocket) - try { - threadPool.execute(new GuideSingleRequest(clientSocket)) - } catch { - // In failure, close() the socket here; else, the thread will close() it - case ioe: IOException => clientSocket.close() - } - } - } - - logInfo("Sending stopBroadcast notifications...") - sendStopBroadcastNotifications - - MultiTracker.unregisterBroadcast(id) - } finally { - if (serverSocket != null) { - logInfo("GuideMultipleRequests now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - private def sendStopBroadcastNotifications() { - listOfSources.synchronized { - var listIter = listOfSources.iterator - while (listIter.hasNext) { - var sourceInfo = listIter.next - - var guideSocketToSource: Socket = null - var gosSource: ObjectOutputStream = null - var gisSource: ObjectInputStream = null - - try { - // Connect to the source - guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream) - gosSource.flush() - gisSource = new ObjectInputStream(guideSocketToSource.getInputStream) - - // Send stopBroadcast signal - gosSource.writeObject((SourceInfo.StopBroadcast, SourceInfo.StopBroadcast)) - gosSource.flush() - } catch { - case e: Exception => { - logError("sendStopBroadcastNotifications had a " + e) - } - } finally { - if (gisSource != null) { - gisSource.close() - } - if (gosSource != null) { - gosSource.close() - } - if (guideSocketToSource != null) { - guideSocketToSource.close() - } - } - } - } - } - - class GuideSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - private var selectedSourceInfo: SourceInfo = null - private var thisWorkerInfo:SourceInfo = null - - override def run() { - try { - logInfo("new GuideSingleRequest is running") - // Connecting worker is sending in its hostAddress and listenPort it will - // be listening to. Other fields are invalid (SourceInfo.UnusedParam) - var sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - listOfSources.synchronized { - // Select a suitable source and send it back to the worker - selectedSourceInfo = selectSuitableSource(sourceInfo) - logDebug("Sending selectedSourceInfo: " + selectedSourceInfo) - oos.writeObject(selectedSourceInfo) - oos.flush() - - // Add this new (if it can finish) source to the list of sources - thisWorkerInfo = SourceInfo(sourceInfo.hostAddress, - sourceInfo.listenPort, totalBlocks, totalBytes) - logDebug("Adding possible new source to listOfSources: " + thisWorkerInfo) - listOfSources += thisWorkerInfo - } - - // Wait till the whole transfer is done. Then receive and update source - // statistics in listOfSources - sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - listOfSources.synchronized { - // This should work since SourceInfo is a case class - assert(listOfSources.contains(selectedSourceInfo)) - - // Remove first - // (Currently removing a source based on just one failure notification!) 
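// Hedged context for the bookkeeping below: selectSuitableSource (defined
// further down) hands each new worker the source with the most leechers that
// still has spare capacity, so the broadcast tree fills level by level. A
// compact equivalent using only the standard library (pickSource is an
// illustrative name, not part of this patch):
def pickSource(sources: ListBuffer[SourceInfo], skip: SourceInfo): Option[SourceInfo] =
  sources
    .filter(s => s.hostAddress != skip.hostAddress || s.listenPort != skip.listenPort)
    .filter(_.currentLeechers < MultiTracker.MaxDegree)
    .sortBy(-_.currentLeechers)
    .headOption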
- listOfSources = listOfSources - selectedSourceInfo - - // Update sourceInfo and put it back in, IF reception succeeded - if (!sourceInfo.receptionFailed) { - // Add thisWorkerInfo to sources that have completed reception - setOfCompletedSources.synchronized { - setOfCompletedSources += thisWorkerInfo - } - - // Update leecher count and put it back in - selectedSourceInfo.currentLeechers -= 1 - listOfSources += selectedSourceInfo - } - } - } catch { - case e: Exception => { - // Remove failed worker from listOfSources and update leecherCount of - // corresponding source worker - listOfSources.synchronized { - if (selectedSourceInfo != null) { - // Remove first - listOfSources = listOfSources - selectedSourceInfo - // Update leecher count and put it back in - selectedSourceInfo.currentLeechers -= 1 - listOfSources += selectedSourceInfo - } - - // Remove thisWorkerInfo - if (listOfSources != null) { - listOfSources = listOfSources - thisWorkerInfo - } - } - } - } finally { - logInfo("GuideSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - // Assuming the caller to have a synchronized block on listOfSources - // Select one with the most leechers. This will level-wise fill the tree - private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = { - var maxLeechers = -1 - var selectedSource: SourceInfo = null - - listOfSources.foreach { source => - if ((source.hostAddress != skipSourceInfo.hostAddress || - source.listenPort != skipSourceInfo.listenPort) && - source.currentLeechers < MultiTracker.MaxDegree && - source.currentLeechers > maxLeechers) { - selectedSource = source - maxLeechers = source.currentLeechers - } - } - - // Update leecher count - selectedSource.currentLeechers += 1 - return selectedSource - } - } - } - - class ServeMultipleRequests - extends Thread with Logging { - - var threadPool = Utils.newDaemonCachedThreadPool() - - override def run() { - var serverSocket = new ServerSocket(0) - listenPort = serverSocket.getLocalPort - - logInfo("ServeMultipleRequests started with " + serverSocket) - - listenPortLock.synchronized { listenPortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { } - } - - if (clientSocket != null) { - logDebug("Serve: Accepted new client connection: " + clientSocket) - try { - threadPool.execute(new ServeSingleRequest(clientSocket)) - } catch { - // In failure, close socket here; else, the thread will close it - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - if (serverSocket != null) { - logInfo("ServeMultipleRequests now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - class ServeSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - private var sendFrom = 0 - private var sendUntil = totalBlocks - - override def run() { - try { - logInfo("new ServeSingleRequest is running") - - // Receive range to send - var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)] - sendFrom = rangeToSend._1 - sendUntil = rangeToSend._2 - - // If not a valid range, stop broadcast - if (sendFrom == SourceInfo.StopBroadcast && sendUntil == SourceInfo.StopBroadcast) { 
- stopBroadcast = true - } else { - sendObject - } - } catch { - case e: Exception => logError("ServeSingleRequest had a " + e) - } finally { - logInfo("ServeSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - private def sendObject() { - // Wait till receiving the SourceInfo from Driver - while (totalBlocks == -1) { - totalBlocksLock.synchronized { totalBlocksLock.wait() } - } - - for (i <- sendFrom until sendUntil) { - while (i == hasBlocks) { - hasBlocksLock.synchronized { hasBlocksLock.wait() } - } - try { - oos.writeObject(arrayOfBlocks(i)) - oos.flush() - } catch { - case e: Exception => logError("sendObject had a " + e) - } - logDebug("Sent block: " + i + " to " + clientSocket) - } - } - } - } -} - -private[spark] class TreeBroadcastFactory -extends BroadcastFactory { - def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new TreeBroadcast[T](value_, isLocal, id) - - def stop() { MultiTracker.stop() } -} diff --git a/core/src/main/scala/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/spark/deploy/ApplicationDescription.scala deleted file mode 100644 index a8b22fbef8..0000000000 --- a/core/src/main/scala/spark/deploy/ApplicationDescription.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy - -private[spark] class ApplicationDescription( - val name: String, - val maxCores: Int, /* Integer.MAX_VALUE denotes an unlimited number of cores */ - val memoryPerSlave: Int, - val command: Command, - val sparkHome: String, - val appUiUrl: String) - extends Serializable { - - val user = System.getProperty("user.name", "") - - override def toString: String = "ApplicationDescription(" + name + ")" -} diff --git a/core/src/main/scala/spark/deploy/Command.scala b/core/src/main/scala/spark/deploy/Command.scala deleted file mode 100644 index bad629e965..0000000000 --- a/core/src/main/scala/spark/deploy/Command.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy - -import scala.collection.Map - -private[spark] case class Command( - mainClass: String, - arguments: Seq[String], - environment: Map[String, String]) { -} diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala deleted file mode 100644 index 0db13ffc98..0000000000 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy - -import scala.collection.immutable.List - -import spark.Utils -import spark.deploy.ExecutorState.ExecutorState -import spark.deploy.master.{WorkerInfo, ApplicationInfo} -import spark.deploy.worker.ExecutorRunner - - -private[deploy] sealed trait DeployMessage extends Serializable - -private[deploy] object DeployMessages { - - // Worker to Master - - case class RegisterWorker( - id: String, - host: String, - port: Int, - cores: Int, - memory: Int, - webUiPort: Int, - publicAddress: String) - extends DeployMessage { - Utils.checkHost(host, "Required hostname") - assert (port > 0) - } - - case class ExecutorStateChanged( - appId: String, - execId: Int, - state: ExecutorState, - message: Option[String], - exitStatus: Option[Int]) - extends DeployMessage - - case class Heartbeat(workerId: String) extends DeployMessage - - // Master to Worker - - case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage - - case class RegisterWorkerFailed(message: String) extends DeployMessage - - case class KillExecutor(appId: String, execId: Int) extends DeployMessage - - case class LaunchExecutor( - appId: String, - execId: Int, - appDesc: ApplicationDescription, - cores: Int, - memory: Int, - sparkHome: String) - extends DeployMessage - - // Client to Master - - case class RegisterApplication(appDescription: ApplicationDescription) - extends DeployMessage - - // Master to Client - - case class RegisteredApplication(appId: String) extends DeployMessage - - // TODO(matei): replace hostPort with host - case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) { - Utils.checkHostPort(hostPort, "Required hostport") - } - - case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String], - exitStatus: Option[Int]) - - case class ApplicationRemoved(message: String) - - // Internal message in Client - - case object StopClient - - // MasterWebUI To Master - - case object RequestMasterState - - // Master to MasterWebUI - - case class MasterStateResponse(host: String, port: Int, workers: Array[WorkerInfo], - activeApps: Array[ApplicationInfo], completedApps: 
Array[ApplicationInfo]) { - - Utils.checkHost(host, "Required hostname") - assert (port > 0) - - def uri = "spark://" + host + ":" + port - } - - // WorkerWebUI to Worker - - case object RequestWorkerState - - // Worker to WorkerWebUI - - case class WorkerStateResponse(host: String, port: Int, workerId: String, - executors: List[ExecutorRunner], finishedExecutors: List[ExecutorRunner], masterUrl: String, - cores: Int, memory: Int, coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) { - - Utils.checkHost(host, "Required hostname") - assert (port > 0) - } - - // Actor System to Master - - case object CheckForWorkerTimeOut - -} diff --git a/core/src/main/scala/spark/deploy/ExecutorState.scala b/core/src/main/scala/spark/deploy/ExecutorState.scala deleted file mode 100644 index 08c9a3b725..0000000000 --- a/core/src/main/scala/spark/deploy/ExecutorState.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy - -private[spark] object ExecutorState - extends Enumeration("LAUNCHING", "LOADING", "RUNNING", "KILLED", "FAILED", "LOST") { - - val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST = Value - - type ExecutorState = Value - - def isFinished(state: ExecutorState): Boolean = Seq(KILLED, FAILED, LOST).contains(state) -} diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala deleted file mode 100644 index f8dcf025b4..0000000000 --- a/core/src/main/scala/spark/deploy/JsonProtocol.scala +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.deploy - -import net.liftweb.json.JsonDSL._ - -import spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} -import spark.deploy.master.{ApplicationInfo, WorkerInfo} -import spark.deploy.worker.ExecutorRunner - - -private[spark] object JsonProtocol { - def writeWorkerInfo(obj: WorkerInfo) = { - ("id" -> obj.id) ~ - ("host" -> obj.host) ~ - ("port" -> obj.port) ~ - ("webuiaddress" -> obj.webUiAddress) ~ - ("cores" -> obj.cores) ~ - ("coresused" -> obj.coresUsed) ~ - ("memory" -> obj.memory) ~ - ("memoryused" -> obj.memoryUsed) ~ - ("state" -> obj.state.toString) - } - - def writeApplicationInfo(obj: ApplicationInfo) = { - ("starttime" -> obj.startTime) ~ - ("id" -> obj.id) ~ - ("name" -> obj.desc.name) ~ - ("cores" -> obj.desc.maxCores) ~ - ("user" -> obj.desc.user) ~ - ("memoryperslave" -> obj.desc.memoryPerSlave) ~ - ("submitdate" -> obj.submitDate.toString) - } - - def writeApplicationDescription(obj: ApplicationDescription) = { - ("name" -> obj.name) ~ - ("cores" -> obj.maxCores) ~ - ("memoryperslave" -> obj.memoryPerSlave) ~ - ("user" -> obj.user) - } - - def writeExecutorRunner(obj: ExecutorRunner) = { - ("id" -> obj.execId) ~ - ("memory" -> obj.memory) ~ - ("appid" -> obj.appId) ~ - ("appdesc" -> writeApplicationDescription(obj.appDesc)) - } - - def writeMasterState(obj: MasterStateResponse) = { - ("url" -> ("spark://" + obj.uri)) ~ - ("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~ - ("cores" -> obj.workers.map(_.cores).sum) ~ - ("coresused" -> obj.workers.map(_.coresUsed).sum) ~ - ("memory" -> obj.workers.map(_.memory).sum) ~ - ("memoryused" -> obj.workers.map(_.memoryUsed).sum) ~ - ("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~ - ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) - } - - def writeWorkerState(obj: WorkerStateResponse) = { - ("id" -> obj.workerId) ~ - ("masterurl" -> obj.masterUrl) ~ - ("masterwebuiurl" -> obj.masterWebUiUrl) ~ - ("cores" -> obj.cores) ~ - ("coresused" -> obj.coresUsed) ~ - ("memory" -> obj.memory) ~ - ("memoryused" -> obj.memoryUsed) ~ - ("executors" -> obj.executors.toList.map(writeExecutorRunner)) ~ - ("finishedexecutors" -> obj.finishedExecutors.toList.map(writeExecutorRunner)) - } -} diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala deleted file mode 100644 index 6b8e9f27af..0000000000 --- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.deploy - -import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated} - -import spark.deploy.worker.Worker -import spark.deploy.master.Master -import spark.util.AkkaUtils -import spark.{Logging, Utils} - -import scala.collection.mutable.ArrayBuffer - -/** - * Testing class that creates a Spark standalone process in-cluster (that is, running the - * spark.deploy.master.Master and spark.deploy.worker.Workers in the same JVMs). Executors launched - * by the Workers still run in separate JVMs. This can be used to test distributed operation and - * fault recovery without spinning up a lot of processes. - */ -private[spark] -class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging { - - private val localHostname = Utils.localHostName() - private val masterActorSystems = ArrayBuffer[ActorSystem]() - private val workerActorSystems = ArrayBuffer[ActorSystem]() - - def start(): String = { - logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") - - /* Start the Master */ - val (masterSystem, masterPort) = Master.startSystemAndActor(localHostname, 0, 0) - masterActorSystems += masterSystem - val masterUrl = "spark://" + localHostname + ":" + masterPort - - /* Start the Workers */ - for (workerNum <- 1 to numWorkers) { - val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker, - memoryPerWorker, masterUrl, null, Some(workerNum)) - workerActorSystems += workerSystem - } - - return masterUrl - } - - def stop() { - logInfo("Shutting down local Spark cluster.") - // Stop the workers before the master so they don't get upset that it disconnected - workerActorSystems.foreach(_.shutdown()) - workerActorSystems.foreach(_.awaitTermination()) - - masterActorSystems.foreach(_.shutdown()) - masterActorSystems.foreach(_.awaitTermination()) - } -} diff --git a/core/src/main/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/spark/deploy/SparkHadoopUtil.scala deleted file mode 100644 index 882161e669..0000000000 --- a/core/src/main/scala/spark/deploy/SparkHadoopUtil.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.mapred.JobConf - - -/** - * Contains util methods to interact with Hadoop from spark. - */ -class SparkHadoopUtil { - - // Return an appropriate (subclass) of Configuration. 
Creating config can initializes some hadoop subsystems - def newConfiguration(): Configuration = new Configuration() - - // add any user credentials to the job conf which are necessary for running on a secure Hadoop cluster - def addCredentials(conf: JobConf) {} - - def isYarnMode(): Boolean = { false } - -} diff --git a/core/src/main/scala/spark/deploy/WebUI.scala b/core/src/main/scala/spark/deploy/WebUI.scala deleted file mode 100644 index 8ea7792ef4..0000000000 --- a/core/src/main/scala/spark/deploy/WebUI.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy - -import java.text.SimpleDateFormat -import java.util.Date - -/** - * Utilities used throughout the web UI. - */ -private[spark] object DeployWebUI { - val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") - - def formatDate(date: Date): String = DATE_FORMAT.format(date) - - def formatDate(timestamp: Long): String = DATE_FORMAT.format(new Date(timestamp)) - - def formatDuration(milliseconds: Long): String = { - val seconds = milliseconds.toDouble / 1000 - if (seconds < 60) { - return "%.0f s".format(seconds) - } - val minutes = seconds / 60 - if (minutes < 10) { - return "%.1f min".format(minutes) - } else if (minutes < 60) { - return "%.0f min".format(minutes) - } - val hours = minutes / 60 - return "%.1f h".format(hours) - } -} diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala deleted file mode 100644 index 9d5ba8a796..0000000000 --- a/core/src/main/scala/spark/deploy/client/Client.scala +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.deploy.client - -import java.util.concurrent.TimeoutException - -import akka.actor._ -import akka.actor.Terminated -import akka.pattern.ask -import akka.util.Duration -import akka.remote.RemoteClientDisconnected -import akka.remote.RemoteClientLifeCycleEvent -import akka.remote.RemoteClientShutdown -import akka.dispatch.Await - -import spark.Logging -import spark.deploy.{ApplicationDescription, ExecutorState} -import spark.deploy.DeployMessages._ -import spark.deploy.master.Master - - -/** - * The main class used to talk to a Spark deploy cluster. Takes a master URL, an app description, - * and a listener for cluster events, and calls back the listener when various events occur. - */ -private[spark] class Client( - actorSystem: ActorSystem, - masterUrl: String, - appDescription: ApplicationDescription, - listener: ClientListener) - extends Logging { - - var actor: ActorRef = null - var appId: String = null - - class ClientActor extends Actor with Logging { - var master: ActorRef = null - var masterAddress: Address = null - var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times - - override def preStart() { - logInfo("Connecting to master " + masterUrl) - try { - master = context.actorFor(Master.toAkkaUrl(masterUrl)) - masterAddress = master.path.address - master ! RegisterApplication(appDescription) - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - context.watch(master) // Doesn't work with remote actors, but useful for testing - } catch { - case e: Exception => - logError("Failed to connect to master", e) - markDisconnected() - context.stop(self) - } - } - - override def receive = { - case RegisteredApplication(appId_) => - appId = appId_ - listener.connected(appId) - - case ApplicationRemoved(message) => - logError("Master removed our application: %s; stopping client".format(message)) - markDisconnected() - context.stop(self) - - case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => - val fullId = appId + "/" + id - logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores)) - listener.executorAdded(fullId, workerId, hostPort, cores, memory) - - case ExecutorUpdated(id, state, message, exitStatus) => - val fullId = appId + "/" + id - val messageText = message.map(s => " (" + s + ")").getOrElse("") - logInfo("Executor updated: %s is now %s%s".format(fullId, state, messageText)) - if (ExecutorState.isFinished(state)) { - listener.executorRemoved(fullId, message.getOrElse(""), exitStatus) - } - - case Terminated(actor_) if actor_ == master => - logError("Connection to master failed; stopping client") - markDisconnected() - context.stop(self) - - case RemoteClientDisconnected(transport, address) if address == masterAddress => - logError("Connection to master failed; stopping client") - markDisconnected() - context.stop(self) - - case RemoteClientShutdown(transport, address) if address == masterAddress => - logError("Connection to master failed; stopping client") - markDisconnected() - context.stop(self) - - case StopClient => - markDisconnected() - sender ! true - context.stop(self) - } - - /** - * Notify the listener that we disconnected, if we hadn't already done so before. - */ - def markDisconnected() { - if (!alreadyDisconnected) { - listener.disconnected() - alreadyDisconnected = true - } - } - } - - def start() { - // Just launch an actor; it will call back into the listener. 
- actor = actorSystem.actorOf(Props(new ClientActor)) - } - - def stop() { - if (actor != null) { - try { - val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") - val future = actor.ask(StopClient)(timeout) - Await.result(future, timeout) - } catch { - case e: TimeoutException => - logInfo("Stop request to Master timed out; it may already be shut down.") - } - actor = null - } - } -} diff --git a/core/src/main/scala/spark/deploy/client/ClientListener.scala b/core/src/main/scala/spark/deploy/client/ClientListener.scala deleted file mode 100644 index 064024455e..0000000000 --- a/core/src/main/scala/spark/deploy/client/ClientListener.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.client - -/** - * Callbacks invoked by deploy client when various events happen. There are currently four events: - * connecting to the cluster, disconnecting, being given an executor, and having an executor - * removed (either due to failure or due to revocation). - * - * Users of this API should *not* block inside the callback methods. - */ -private[spark] trait ClientListener { - def connected(appId: String): Unit - - def disconnected(): Unit - - def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit - - def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit -} diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala deleted file mode 100644 index 4f4daa141a..0000000000 --- a/core/src/main/scala/spark/deploy/client/TestClient.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.deploy.client - -import spark.util.AkkaUtils -import spark.{Logging, Utils} -import spark.deploy.{Command, ApplicationDescription} - -private[spark] object TestClient { - - class TestListener extends ClientListener with Logging { - def connected(id: String) { - logInfo("Connected to master, got app ID " + id) - } - - def disconnected() { - logInfo("Disconnected from master") - System.exit(0) - } - - def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {} - - def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {} - } - - def main(args: Array[String]) { - val url = args(0) - val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0) - val desc = new ApplicationDescription( - "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "ignored") - val listener = new TestListener - val client = new Client(actorSystem, url, desc, listener) - client.start() - actorSystem.awaitTermination() - } -} diff --git a/core/src/main/scala/spark/deploy/client/TestExecutor.scala b/core/src/main/scala/spark/deploy/client/TestExecutor.scala deleted file mode 100644 index 8a22b6b89f..0000000000 --- a/core/src/main/scala/spark/deploy/client/TestExecutor.scala +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.client - -private[spark] object TestExecutor { - def main(args: Array[String]) { - println("Hello world!") - while (true) { - Thread.sleep(1000) - } - } -} diff --git a/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala deleted file mode 100644 index 6dd2f06126..0000000000 --- a/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.deploy.master - -import spark.deploy.ApplicationDescription -import java.util.Date -import akka.actor.ActorRef -import scala.collection.mutable - -private[spark] class ApplicationInfo( - val startTime: Long, - val id: String, - val desc: ApplicationDescription, - val submitDate: Date, - val driver: ActorRef, - val appUiUrl: String) -{ - var state = ApplicationState.WAITING - var executors = new mutable.HashMap[Int, ExecutorInfo] - var coresGranted = 0 - var endTime = -1L - val appSource = new ApplicationSource(this) - - private var nextExecutorId = 0 - - def newExecutorId(): Int = { - val id = nextExecutorId - nextExecutorId += 1 - id - } - - def addExecutor(worker: WorkerInfo, cores: Int): ExecutorInfo = { - val exec = new ExecutorInfo(newExecutorId(), this, worker, cores, desc.memoryPerSlave) - executors(exec.id) = exec - coresGranted += cores - exec - } - - def removeExecutor(exec: ExecutorInfo) { - if (executors.contains(exec.id)) { - executors -= exec.id - coresGranted -= exec.cores - } - } - - def coresLeft: Int = desc.maxCores - coresGranted - - private var _retryCount = 0 - - def retryCount = _retryCount - - def incrementRetryCount = { - _retryCount += 1 - _retryCount - } - - def markFinished(endState: ApplicationState.Value) { - state = endState - endTime = System.currentTimeMillis() - } - - def duration: Long = { - if (endTime != -1) { - endTime - startTime - } else { - System.currentTimeMillis() - startTime - } - } - -} diff --git a/core/src/main/scala/spark/deploy/master/ApplicationSource.scala b/core/src/main/scala/spark/deploy/master/ApplicationSource.scala deleted file mode 100644 index 4df2b6bfdd..0000000000 --- a/core/src/main/scala/spark/deploy/master/ApplicationSource.scala +++ /dev/null @@ -1,24 +0,0 @@ -package spark.deploy.master - -import com.codahale.metrics.{Gauge, MetricRegistry} - -import spark.metrics.source.Source - -class ApplicationSource(val application: ApplicationInfo) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "%s.%s.%s".format("application", application.desc.name, - System.currentTimeMillis()) - - metricRegistry.register(MetricRegistry.name("status"), new Gauge[String] { - override def getValue: String = application.state.toString - }) - - metricRegistry.register(MetricRegistry.name("runtime_ms"), new Gauge[Long] { - override def getValue: Long = application.duration - }) - - metricRegistry.register(MetricRegistry.name("cores", "number"), new Gauge[Int] { - override def getValue: Int = application.coresGranted - }) - -} diff --git a/core/src/main/scala/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/spark/deploy/master/ApplicationState.scala deleted file mode 100644 index 94f0ad8bae..0000000000 --- a/core/src/main/scala/spark/deploy/master/ApplicationState.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.master - -private[spark] object ApplicationState - extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED") { - - type ApplicationState = Value - - val WAITING, RUNNING, FINISHED, FAILED = Value - - val MAX_NUM_RETRY = 10 -} diff --git a/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala b/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala deleted file mode 100644 index 99b60f7d09..0000000000 --- a/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.master - -import spark.deploy.ExecutorState - -private[spark] class ExecutorInfo( - val id: Int, - val application: ApplicationInfo, - val worker: WorkerInfo, - val cores: Int, - val memory: Int) { - - var state = ExecutorState.LAUNCHING - - def fullId: String = application.id + "/" + id -} diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala deleted file mode 100644 index 04af5e149c..0000000000 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ /dev/null @@ -1,386 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.deploy.master - -import java.text.SimpleDateFormat -import java.util.Date - -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} - -import akka.actor._ -import akka.actor.Terminated -import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, RemoteClientShutdown} -import akka.util.duration._ - -import spark.{Logging, SparkException, Utils} -import spark.deploy.{ApplicationDescription, ExecutorState} -import spark.deploy.DeployMessages._ -import spark.deploy.master.ui.MasterWebUI -import spark.metrics.MetricsSystem -import spark.util.AkkaUtils - - -private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging { - val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs - val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000 - val RETAINED_APPLICATIONS = System.getProperty("spark.deploy.retainedApplications", "200").toInt - val REAPER_ITERATIONS = System.getProperty("spark.dead.worker.persistence", "15").toInt - - var nextAppNumber = 0 - val workers = new HashSet[WorkerInfo] - val idToWorker = new HashMap[String, WorkerInfo] - val actorToWorker = new HashMap[ActorRef, WorkerInfo] - val addressToWorker = new HashMap[Address, WorkerInfo] - - val apps = new HashSet[ApplicationInfo] - val idToApp = new HashMap[String, ApplicationInfo] - val actorToApp = new HashMap[ActorRef, ApplicationInfo] - val addressToApp = new HashMap[Address, ApplicationInfo] - - val waitingApps = new ArrayBuffer[ApplicationInfo] - val completedApps = new ArrayBuffer[ApplicationInfo] - - var firstApp: Option[ApplicationInfo] = None - - Utils.checkHost(host, "Expected hostname") - - val masterMetricsSystem = MetricsSystem.createMetricsSystem("master") - val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications") - val masterSource = new MasterSource(this) - - val webUi = new MasterWebUI(this, webUiPort) - - val masterPublicAddress = { - val envVar = System.getenv("SPARK_PUBLIC_DNS") - if (envVar != null) envVar else host - } - - // As a temporary workaround before better ways of configuring memory, we allow users to set - // a flag that will perform round-robin scheduling across the nodes (spreading out each app - // among all the nodes) instead of trying to consolidate each app onto a small # of nodes. - val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean - - override def preStart() { - logInfo("Starting Spark master at spark://" + host + ":" + port) - // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - webUi.start() - context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut) - - masterMetricsSystem.registerSource(masterSource) - masterMetricsSystem.start() - applicationMetricsSystem.start() - } - - override def postStop() { - webUi.stop() - masterMetricsSystem.stop() - applicationMetricsSystem.stop() - } - - override def receive = { - case RegisterWorker(id, host, workerPort, cores, memory, worker_webUiPort, publicAddress) => { - logInfo("Registering worker %s:%d with %d cores, %s RAM".format( - host, workerPort, cores, Utils.megabytesToString(memory))) - if (idToWorker.contains(id)) { - sender ! 
RegisterWorkerFailed("Duplicate worker ID") - } else { - addWorker(id, host, workerPort, cores, memory, worker_webUiPort, publicAddress) - context.watch(sender) // This doesn't work with remote actors but helps for testing - sender ! RegisteredWorker("http://" + masterPublicAddress + ":" + webUi.boundPort.get) - schedule() - } - } - - case RegisterApplication(description) => { - logInfo("Registering app " + description.name) - val app = addApplication(description, sender) - logInfo("Registered app " + description.name + " with ID " + app.id) - waitingApps += app - context.watch(sender) // This doesn't work with remote actors but helps for testing - sender ! RegisteredApplication(app.id) - schedule() - } - - case ExecutorStateChanged(appId, execId, state, message, exitStatus) => { - val execOption = idToApp.get(appId).flatMap(app => app.executors.get(execId)) - execOption match { - case Some(exec) => { - exec.state = state - exec.application.driver ! ExecutorUpdated(execId, state, message, exitStatus) - if (ExecutorState.isFinished(state)) { - val appInfo = idToApp(appId) - // Remove this executor from the worker and app - logInfo("Removing executor " + exec.fullId + " because it is " + state) - appInfo.removeExecutor(exec) - exec.worker.removeExecutor(exec) - - // Only retry certain number of times so we don't go into an infinite loop. - if (appInfo.incrementRetryCount < ApplicationState.MAX_NUM_RETRY) { - schedule() - } else { - logError("Application %s with ID %s failed %d times, removing it".format( - appInfo.desc.name, appInfo.id, appInfo.retryCount)) - removeApplication(appInfo, ApplicationState.FAILED) - } - } - } - case None => - logWarning("Got status update for unknown executor " + appId + "/" + execId) - } - } - - case Heartbeat(workerId) => { - idToWorker.get(workerId) match { - case Some(workerInfo) => - workerInfo.lastHeartbeat = System.currentTimeMillis() - case None => - logWarning("Got heartbeat from unregistered worker " + workerId) - } - } - - case Terminated(actor) => { - // The disconnected actor could've been either a worker or an app; remove whichever of - // those we have an entry for in the corresponding actor hashmap - actorToWorker.get(actor).foreach(removeWorker) - actorToApp.get(actor).foreach(finishApplication) - } - - case RemoteClientDisconnected(transport, address) => { - // The disconnected client could've been either a worker or an app; remove whichever it was - addressToWorker.get(address).foreach(removeWorker) - addressToApp.get(address).foreach(finishApplication) - } - - case RemoteClientShutdown(transport, address) => { - // The disconnected client could've been either a worker or an app; remove whichever it was - addressToWorker.get(address).foreach(removeWorker) - addressToApp.get(address).foreach(finishApplication) - } - - case RequestMasterState => { - sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray) - } - - case CheckForWorkerTimeOut => { - timeOutDeadWorkers() - } - } - - /** - * Can an app use the given worker? True if the worker has enough memory and we haven't already - * launched an executor for the app on it (right now the standalone backend doesn't like having - * two executors on the same worker). - */ - def canUse(app: ApplicationInfo, worker: WorkerInfo): Boolean = { - worker.memoryFree >= app.desc.memoryPerSlave && !worker.hasExecutor(app) - } - - /** - * Schedule the currently available resources among waiting apps. 
This method will be called - * every time a new app joins or resource availability changes. - */ - def schedule() { - // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app - // in the queue, then the second app, etc. - if (spreadOutApps) { - // Try to spread out each app among all the nodes, until it has all its cores - for (app <- waitingApps if app.coresLeft > 0) { - val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) - .filter(canUse(app, _)).sortBy(_.coresFree).reverse - val numUsable = usableWorkers.length - val assigned = new Array[Int](numUsable) // Number of cores to give on each node - var toAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum) - var pos = 0 - while (toAssign > 0) { - if (usableWorkers(pos).coresFree - assigned(pos) > 0) { - toAssign -= 1 - assigned(pos) += 1 - } - pos = (pos + 1) % numUsable - } - // Now that we've decided how many cores to give on each node, let's actually give them - for (pos <- 0 until numUsable) { - if (assigned(pos) > 0) { - val exec = app.addExecutor(usableWorkers(pos), assigned(pos)) - launchExecutor(usableWorkers(pos), exec, app.desc.sparkHome) - app.state = ApplicationState.RUNNING - } - } - } - } else { - // Pack each app into as few nodes as possible until we've assigned all its cores - for (worker <- workers if worker.coresFree > 0 && worker.state == WorkerState.ALIVE) { - for (app <- waitingApps if app.coresLeft > 0) { - if (canUse(app, worker)) { - val coresToUse = math.min(worker.coresFree, app.coresLeft) - if (coresToUse > 0) { - val exec = app.addExecutor(worker, coresToUse) - launchExecutor(worker, exec, app.desc.sparkHome) - app.state = ApplicationState.RUNNING - } - } - } - } - } - } - - def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) { - logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) - worker.addExecutor(exec) - worker.actor ! LaunchExecutor( - exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome) - exec.application.driver ! ExecutorAdded( - exec.id, worker.id, worker.hostPort, exec.cores, exec.memory) - } - - def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int, - publicAddress: String): WorkerInfo = { - // There may be one or more refs to dead workers on this same node (w/ different ID's), - // remove them. - workers.filter { w => - (w.host == host && w.port == port) && (w.state == WorkerState.DEAD) - }.foreach { w => - workers -= w - } - val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress) - workers += worker - idToWorker(worker.id) = worker - actorToWorker(sender) = worker - addressToWorker(sender.path.address) = worker - worker - } - - def removeWorker(worker: WorkerInfo) { - logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port) - worker.setState(WorkerState.DEAD) - idToWorker -= worker.id - actorToWorker -= worker.actor - addressToWorker -= worker.actor.path.address - for (exec <- worker.executors.values) { - logInfo("Telling app of lost executor: " + exec.id) - exec.application.driver ! 
ExecutorUpdated( - exec.id, ExecutorState.LOST, Some("worker lost"), None) - exec.application.removeExecutor(exec) - } - } - - def addApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = { - val now = System.currentTimeMillis() - val date = new Date(now) - val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl) - applicationMetricsSystem.registerSource(app.appSource) - apps += app - idToApp(app.id) = app - actorToApp(driver) = app - addressToApp(driver.path.address) = app - if (firstApp == None) { - firstApp = Some(app) - } - val workersAlive = workers.filter(_.state == WorkerState.ALIVE).toArray - if (workersAlive.size > 0 && !workersAlive.exists(_.memoryFree >= desc.memoryPerSlave)) { - logWarning("Could not find any workers with enough memory for " + firstApp.get.id) - } - app - } - - def finishApplication(app: ApplicationInfo) { - removeApplication(app, ApplicationState.FINISHED) - } - - def removeApplication(app: ApplicationInfo, state: ApplicationState.Value) { - if (apps.contains(app)) { - logInfo("Removing app " + app.id) - apps -= app - idToApp -= app.id - actorToApp -= app.driver - addressToApp -= app.driver.path.address - if (completedApps.size >= RETAINED_APPLICATIONS) { - val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1) - completedApps.take(toRemove).foreach( a => { - applicationMetricsSystem.removeSource(a.appSource) - }) - completedApps.trimStart(toRemove) - } - completedApps += app // Remember it in our history - waitingApps -= app - for (exec <- app.executors.values) { - exec.worker.removeExecutor(exec) - exec.worker.actor ! KillExecutor(exec.application.id, exec.id) - exec.state = ExecutorState.KILLED - } - app.markFinished(state) - if (state != ApplicationState.FINISHED) { - app.driver ! ApplicationRemoved(state.toString) - } - schedule() - } - } - - /** Generate a new app ID given a app's submission date */ - def newApplicationId(submitDate: Date): String = { - val appId = "app-%s-%04d".format(DATE_FORMAT.format(submitDate), nextAppNumber) - nextAppNumber += 1 - appId - } - - /** Check for, and remove, any timed-out workers */ - def timeOutDeadWorkers() { - // Copy the workers into an array so we don't modify the hashset while iterating through it - val currentTime = System.currentTimeMillis() - val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT).toArray - for (worker <- toRemove) { - if (worker.state != WorkerState.DEAD) { - logWarning("Removing %s because we got no heartbeat in %d seconds".format( - worker.id, WORKER_TIMEOUT/1000)) - removeWorker(worker) - } else { - if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT)) - workers -= worker // we've seen this DEAD worker in the UI, etc. for long enough; cull it - } - } - } -} - -private[spark] object Master { - private val systemName = "sparkMaster" - private val actorName = "Master" - private val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r - - def main(argStrings: Array[String]) { - val args = new MasterArguments(argStrings) - val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort) - actorSystem.awaitTermination() - } - - /** Returns an `akka://...` URL for the Master actor given a sparkUrl `spark://host:ip`. 
*/ - def toAkkaUrl(sparkUrl: String): String = { - sparkUrl match { - case sparkUrlRegex(host, port) => - "akka://%s@%s:%s/user/%s".format(systemName, host, port, actorName) - case _ => - throw new SparkException("Invalid master URL: " + sparkUrl) - } - } - - def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int) = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port) - val actor = actorSystem.actorOf(Props(new Master(host, boundPort, webUiPort)), name = actorName) - (actorSystem, boundPort) - } -} diff --git a/core/src/main/scala/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/spark/deploy/master/MasterArguments.scala deleted file mode 100644 index 0ae0160767..0000000000 --- a/core/src/main/scala/spark/deploy/master/MasterArguments.scala +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.master - -import spark.util.IntParam -import spark.Utils - -/** - * Command-line parser for the master. - */ -private[spark] class MasterArguments(args: Array[String]) { - var host = Utils.localHostName() - var port = 7077 - var webUiPort = 8080 - - // Check for settings in environment variables - if (System.getenv("SPARK_MASTER_HOST") != null) { - host = System.getenv("SPARK_MASTER_HOST") - } - if (System.getenv("SPARK_MASTER_PORT") != null) { - port = System.getenv("SPARK_MASTER_PORT").toInt - } - if (System.getenv("SPARK_MASTER_WEBUI_PORT") != null) { - webUiPort = System.getenv("SPARK_MASTER_WEBUI_PORT").toInt - } - if (System.getProperty("master.ui.port") != null) { - webUiPort = System.getProperty("master.ui.port").toInt - } - - parse(args.toList) - - def parse(args: List[String]): Unit = args match { - case ("--ip" | "-i") :: value :: tail => - Utils.checkHost(value, "ip no longer supported, please use hostname " + value) - host = value - parse(tail) - - case ("--host" | "-h") :: value :: tail => - Utils.checkHost(value, "Please use hostname " + value) - host = value - parse(tail) - - case ("--port" | "-p") :: IntParam(value) :: tail => - port = value - parse(tail) - - case "--webui-port" :: IntParam(value) :: tail => - webUiPort = value - parse(tail) - - case ("--help" | "-h") :: tail => - printUsageAndExit(0) - - case Nil => {} - - case _ => - printUsageAndExit(1) - } - - /** - * Print usage and exit JVM with the given exit code. 
- */ - def printUsageAndExit(exitCode: Int) { - System.err.println( - "Usage: Master [options]\n" + - "\n" + - "Options:\n" + - " -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) \n" + - " -h HOST, --host HOST Hostname to listen on\n" + - " -p PORT, --port PORT Port to listen on (default: 7077)\n" + - " --webui-port PORT Port for web UI (default: 8080)") - System.exit(exitCode) - } -} diff --git a/core/src/main/scala/spark/deploy/master/MasterSource.scala b/core/src/main/scala/spark/deploy/master/MasterSource.scala deleted file mode 100644 index b8cfa6a773..0000000000 --- a/core/src/main/scala/spark/deploy/master/MasterSource.scala +++ /dev/null @@ -1,25 +0,0 @@ -package spark.deploy.master - -import com.codahale.metrics.{Gauge, MetricRegistry} - -import spark.metrics.source.Source - -private[spark] class MasterSource(val master: Master) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "master" - - // Gauge for worker numbers in cluster - metricRegistry.register(MetricRegistry.name("workers","number"), new Gauge[Int] { - override def getValue: Int = master.workers.size - }) - - // Gauge for application numbers in cluster - metricRegistry.register(MetricRegistry.name("apps", "number"), new Gauge[Int] { - override def getValue: Int = master.apps.size - }) - - // Gauge for waiting application numbers in cluster - metricRegistry.register(MetricRegistry.name("waitingApps", "number"), new Gauge[Int] { - override def getValue: Int = master.waitingApps.size - }) -} diff --git a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala deleted file mode 100644 index 4135cfeb28..0000000000 --- a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.deploy.master - -import akka.actor.ActorRef -import scala.collection.mutable -import spark.Utils - -private[spark] class WorkerInfo( - val id: String, - val host: String, - val port: Int, - val cores: Int, - val memory: Int, - val actor: ActorRef, - val webUiPort: Int, - val publicAddress: String) { - - Utils.checkHost(host, "Expected hostname") - assert (port > 0) - - var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info - var state: WorkerState.Value = WorkerState.ALIVE - var coresUsed = 0 - var memoryUsed = 0 - - var lastHeartbeat = System.currentTimeMillis() - - def coresFree: Int = cores - coresUsed - def memoryFree: Int = memory - memoryUsed - - def hostPort: String = { - assert (port > 0) - host + ":" + port - } - - def addExecutor(exec: ExecutorInfo) { - executors(exec.fullId) = exec - coresUsed += exec.cores - memoryUsed += exec.memory - } - - def removeExecutor(exec: ExecutorInfo) { - if (executors.contains(exec.fullId)) { - executors -= exec.fullId - coresUsed -= exec.cores - memoryUsed -= exec.memory - } - } - - def hasExecutor(app: ApplicationInfo): Boolean = { - executors.values.exists(_.application == app) - } - - def webUiAddress : String = { - "http://" + this.publicAddress + ":" + this.webUiPort - } - - def setState(state: WorkerState.Value) = { - this.state = state - } -} diff --git a/core/src/main/scala/spark/deploy/master/WorkerState.scala b/core/src/main/scala/spark/deploy/master/WorkerState.scala deleted file mode 100644 index 3e50b7748d..0000000000 --- a/core/src/main/scala/spark/deploy/master/WorkerState.scala +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.master - -private[spark] object WorkerState extends Enumeration("ALIVE", "DEAD", "DECOMMISSIONED") { - type WorkerState = Value - - val ALIVE, DEAD, DECOMMISSIONED = Value -} diff --git a/core/src/main/scala/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/spark/deploy/master/ui/ApplicationPage.scala deleted file mode 100644 index 2ad98f759c..0000000000 --- a/core/src/main/scala/spark/deploy/master/ui/ApplicationPage.scala +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.master.ui - -import scala.xml.Node - -import akka.dispatch.Await -import akka.pattern.ask -import akka.util.duration._ - -import javax.servlet.http.HttpServletRequest - -import net.liftweb.json.JsonAST.JValue - -import spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} -import spark.deploy.JsonProtocol -import spark.deploy.master.ExecutorInfo -import spark.ui.UIUtils -import spark.Utils - -private[spark] class ApplicationPage(parent: MasterWebUI) { - val master = parent.masterActorRef - implicit val timeout = parent.timeout - - /** Executor details for a particular application */ - def renderJson(request: HttpServletRequest): JValue = { - val appId = request.getParameter("appId") - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, 30 seconds) - val app = state.activeApps.find(_.id == appId).getOrElse({ - state.completedApps.find(_.id == appId).getOrElse(null) - }) - JsonProtocol.writeApplicationInfo(app) - } - - /** Executor details for a particular application */ - def render(request: HttpServletRequest): Seq[Node] = { - val appId = request.getParameter("appId") - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, 30 seconds) - val app = state.activeApps.find(_.id == appId).getOrElse({ - state.completedApps.find(_.id == appId).getOrElse(null) - }) - - val executorHeaders = Seq("ExecutorID", "Worker", "Cores", "Memory", "State", "Logs") - val executors = app.executors.values.toSeq - val executorTable = UIUtils.listingTable(executorHeaders, executorRow, executors) - - val content = -
-      <div class="row-fluid">
-        <div class="span12">
-          <ul class="unstyled">
-            <li><strong>ID:</strong> {app.id}</li>
-            <li><strong>Name:</strong> {app.desc.name}</li>
-            <li><strong>User:</strong> {app.desc.user}</li>
-            <li><strong>Cores:</strong>
-            {
-              if (app.desc.maxCores == Integer.MAX_VALUE) {
-                "Unlimited (%s granted)".format(app.coresGranted)
-              } else {
-                "%s (%s granted, %s left)".format(
-                  app.desc.maxCores, app.coresGranted, app.coresLeft)
-              }
-            }
-            </li>
-            <li>
-              <strong>Executor Memory:</strong>
-              {Utils.megabytesToString(app.desc.memoryPerSlave)}
-            </li>
-            <li><strong>Submit Date:</strong> {app.submitDate}</li>
-            <li><strong>State:</strong> {app.state}</li>
-            <li><strong><a href={app.appUiUrl}>Application Detail UI</a></strong></li>
-          </ul>
-        </div>
-      </div>
-
-      <div class="row-fluid">
-        <div class="span12">
-          <h4> Executor Summary </h4>
-          {executorTable}
-        </div>
; - UIUtils.basicSparkPage(content, "Application: " + app.desc.name) - } - - def executorRow(executor: ExecutorInfo): Seq[Node] = { - - {executor.id} - - {executor.worker.id} - - {executor.cores} - {executor.memory} - {executor.state} - - stdout - stderr - - - } -} diff --git a/core/src/main/scala/spark/deploy/master/ui/IndexPage.scala b/core/src/main/scala/spark/deploy/master/ui/IndexPage.scala deleted file mode 100644 index 093e523e23..0000000000 --- a/core/src/main/scala/spark/deploy/master/ui/IndexPage.scala +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.master.ui - -import javax.servlet.http.HttpServletRequest - -import scala.xml.Node - -import akka.dispatch.Await -import akka.pattern.ask -import akka.util.duration._ - -import net.liftweb.json.JsonAST.JValue - -import spark.Utils -import spark.deploy.DeployWebUI -import spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} -import spark.deploy.JsonProtocol -import spark.deploy.master.{ApplicationInfo, WorkerInfo} -import spark.ui.UIUtils - -private[spark] class IndexPage(parent: MasterWebUI) { - val master = parent.masterActorRef - implicit val timeout = parent.timeout - - def renderJson(request: HttpServletRequest): JValue = { - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, 30 seconds) - JsonProtocol.writeMasterState(state) - } - - /** Index view listing applications and executors */ - def render(request: HttpServletRequest): Seq[Node] = { - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, 30 seconds) - - val workerHeaders = Seq("Id", "Address", "State", "Cores", "Memory") - val workers = state.workers.sortBy(_.id) - val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) - - val appHeaders = Seq("ID", "Name", "Cores", "Memory per Node", "Submitted Time", "User", - "State", "Duration") - val activeApps = state.activeApps.sortBy(_.startTime).reverse - val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps) - val completedApps = state.completedApps.sortBy(_.endTime).reverse - val completedAppsTable = UIUtils.listingTable(appHeaders, appRow, completedApps) - - val content = -
-      <div class="row-fluid">
-        <div class="span12">
-          <ul class="unstyled">
-            <li><strong>URL:</strong> {state.uri}</li>
-            <li><strong>Workers:</strong> {state.workers.size}</li>
-            <li><strong>Cores:</strong> {state.workers.map(_.cores).sum} Total,
-              {state.workers.map(_.coresUsed).sum} Used</li>
-            <li><strong>Memory:</strong>
-              {Utils.megabytesToString(state.workers.map(_.memory).sum)} Total,
-              {Utils.megabytesToString(state.workers.map(_.memoryUsed).sum)} Used</li>
-            <li><strong>Applications:</strong>
-              {state.activeApps.size} Running,
-              {state.completedApps.size} Completed</li>
-          </ul>
-        </div>
-      </div>
-
-      <div class="row-fluid">
-        <div class="span12">
-          <h4> Workers </h4>
-          {workerTable}
-        </div>
-      </div>
-
-      <div class="row-fluid">
-        <div class="span12">
-          <h4> Running Applications </h4>
-          {activeAppsTable}
-        </div>
-      </div>
-
-      <div class="row-fluid">
-        <div class="span12">
-          <h4> Completed Applications </h4>
-          {completedAppsTable}
-        </div>
; - UIUtils.basicSparkPage(content, "Spark Master at " + state.uri) - } - - def workerRow(worker: WorkerInfo): Seq[Node] = { - - - {worker.id} - - {worker.host}:{worker.port} - {worker.state} - {worker.cores} ({worker.coresUsed} Used) - - {Utils.megabytesToString(worker.memory)} - ({Utils.megabytesToString(worker.memoryUsed)} Used) - - - } - - - def appRow(app: ApplicationInfo): Seq[Node] = { - - - {app.id} - - - {app.desc.name} - - - {app.coresGranted} - - - {Utils.megabytesToString(app.desc.memoryPerSlave)} - - {DeployWebUI.formatDate(app.submitDate)} - {app.desc.user} - {app.state.toString} - {DeployWebUI.formatDuration(app.duration)} - - } -} diff --git a/core/src/main/scala/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/ui/MasterWebUI.scala deleted file mode 100644 index c91e1db9f2..0000000000 --- a/core/src/main/scala/spark/deploy/master/ui/MasterWebUI.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.master.ui - -import akka.util.Duration - -import javax.servlet.http.HttpServletRequest - -import org.eclipse.jetty.server.{Handler, Server} - -import spark.{Logging, Utils} -import spark.deploy.master.Master -import spark.ui.JettyUtils -import spark.ui.JettyUtils._ - -/** - * Web UI server for the standalone master. 
- */ -private[spark] -class MasterWebUI(val master: Master, requestedPort: Int) extends Logging { - implicit val timeout = Duration.create( - System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") - val host = Utils.localHostName() - val port = requestedPort - - val masterActorRef = master.self - - var server: Option[Server] = None - var boundPort: Option[Int] = None - - val applicationPage = new ApplicationPage(this) - val indexPage = new IndexPage(this) - - def start() { - try { - val (srv, bPort) = JettyUtils.startJettyServer("0.0.0.0", port, handlers) - server = Some(srv) - boundPort = Some(bPort) - logInfo("Started Master web UI at http://%s:%d".format(host, boundPort.get)) - } catch { - case e: Exception => - logError("Failed to create Master JettyUtils", e) - System.exit(1) - } - } - - val metricsHandlers = master.masterMetricsSystem.getServletHandlers ++ - master.applicationMetricsSystem.getServletHandlers - - val handlers = metricsHandlers ++ Array[(String, Handler)]( - ("/static", createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR)), - ("/app/json", (request: HttpServletRequest) => applicationPage.renderJson(request)), - ("/app", (request: HttpServletRequest) => applicationPage.render(request)), - ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)), - ("*", (request: HttpServletRequest) => indexPage.render(request)) - ) - - def stop() { - server.foreach(_.stop()) - } -} - -private[spark] object MasterWebUI { - val STATIC_RESOURCE_DIR = "spark/ui/static" -} diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala deleted file mode 100644 index 34665ce451..0000000000 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ /dev/null @@ -1,199 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.worker - -import java.io._ -import java.lang.System.getenv - -import akka.actor.ActorRef - -import com.google.common.base.Charsets -import com.google.common.io.Files - -import spark.{Utils, Logging} -import spark.deploy.{ExecutorState, ApplicationDescription} -import spark.deploy.DeployMessages.ExecutorStateChanged - -/** - * Manages the execution of one executor process. 
- */ -private[spark] class ExecutorRunner( - val appId: String, - val execId: Int, - val appDesc: ApplicationDescription, - val cores: Int, - val memory: Int, - val worker: ActorRef, - val workerId: String, - val host: String, - val sparkHome: File, - val workDir: File) - extends Logging { - - val fullId = appId + "/" + execId - var workerThread: Thread = null - var process: Process = null - var shutdownHook: Thread = null - - private def getAppEnv(key: String): Option[String] = - appDesc.command.environment.get(key).orElse(Option(getenv(key))) - - def start() { - workerThread = new Thread("ExecutorRunner for " + fullId) { - override def run() { fetchAndRunExecutor() } - } - workerThread.start() - - // Shutdown hook that kills actors on shutdown. - shutdownHook = new Thread() { - override def run() { - if (process != null) { - logInfo("Shutdown hook killing child process.") - process.destroy() - process.waitFor() - } - } - } - Runtime.getRuntime.addShutdownHook(shutdownHook) - } - - /** Stop this executor runner, including killing the process it launched */ - def kill() { - if (workerThread != null) { - workerThread.interrupt() - workerThread = null - if (process != null) { - logInfo("Killing process!") - process.destroy() - process.waitFor() - } - worker ! ExecutorStateChanged(appId, execId, ExecutorState.KILLED, None, None) - Runtime.getRuntime.removeShutdownHook(shutdownHook) - } - } - - /** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */ - def substituteVariables(argument: String): String = argument match { - case "{{EXECUTOR_ID}}" => execId.toString - case "{{HOSTNAME}}" => host - case "{{CORES}}" => cores.toString - case other => other - } - - def buildCommandSeq(): Seq[String] = { - val command = appDesc.command - val runner = getAppEnv("JAVA_HOME").map(_ + "/bin/java").getOrElse("java") - // SPARK-698: do not call the run.cmd script, as process.destroy() - // fails to kill a process tree on Windows - Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++ - command.arguments.map(substituteVariables) - } - - /** - * Attention: this must always be aligned with the environment variables in the run scripts and - * the way the JAVA_OPTS are assembled there. 
- */ - def buildJavaOpts(): Seq[String] = { - val libraryOpts = getAppEnv("SPARK_LIBRARY_PATH") - .map(p => List("-Djava.library.path=" + p)) - .getOrElse(Nil) - val workerLocalOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString).getOrElse(Nil) - val userOpts = getAppEnv("SPARK_JAVA_OPTS").map(Utils.splitCommandString).getOrElse(Nil) - val memoryOpts = Seq("-Xms" + memory + "M", "-Xmx" + memory + "M") - - // Figure out our classpath with the external compute-classpath script - val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh" - val classPath = Utils.executeAndGetOutput( - Seq(sparkHome + "/bin/compute-classpath" + ext), - extraEnvironment=appDesc.command.environment) - - Seq("-cp", classPath) ++ libraryOpts ++ workerLocalOpts ++ userOpts ++ memoryOpts - } - - /** Spawn a thread that will redirect a given stream to a file */ - def redirectStream(in: InputStream, file: File) { - val out = new FileOutputStream(file, true) - new Thread("redirect output to " + file) { - override def run() { - try { - Utils.copyStream(in, out, true) - } catch { - case e: IOException => - logInfo("Redirection to " + file + " closed: " + e.getMessage) - } - } - }.start() - } - - /** - * Download and run the executor described in our ApplicationDescription - */ - def fetchAndRunExecutor() { - try { - // Create the executor's working directory - val executorDir = new File(workDir, appId + "/" + execId) - if (!executorDir.mkdirs()) { - throw new IOException("Failed to create directory " + executorDir) - } - - // Launch the process - val command = buildCommandSeq() - logInfo("Launch command: " + command.mkString("\"", "\" \"", "\"")) - val builder = new ProcessBuilder(command: _*).directory(executorDir) - val env = builder.environment() - for ((key, value) <- appDesc.command.environment) { - env.put(key, value) - } - // In case we are running this from within the Spark Shell, avoid creating a "scala" - // parent process for the executor command - env.put("SPARK_LAUNCH_WITH_SCALA", "0") - process = builder.start() - - val header = "Spark Executor Command: %s\n%s\n\n".format( - command.mkString("\"", "\" \"", "\""), "=" * 40) - - // Redirect its stdout and stderr to files - val stdout = new File(executorDir, "stdout") - redirectStream(process.getInputStream, stdout) - - val stderr = new File(executorDir, "stderr") - Files.write(header, stderr, Charsets.UTF_8) - redirectStream(process.getErrorStream, stderr) - - // Wait for it to exit; this is actually a bad thing if it happens, because we expect to run - // long-lived processes only. However, in the future, we might restart the executor a few - // times on the same machine. - val exitCode = process.waitFor() - val message = "Command exited with code " + exitCode - worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message), - Some(exitCode)) - } catch { - case interrupted: InterruptedException => - logInfo("Runner thread for executor " + fullId + " interrupted") - - case e: Exception => { - logError("Error running executor", e) - if (process != null) { - process.destroy() - } - val message = e.getClass + ": " + e.getMessage - worker ! 
ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message), None) - } - } - } -} diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala deleted file mode 100644 index 053ac55226..0000000000 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ /dev/null @@ -1,213 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.worker - -import java.text.SimpleDateFormat -import java.util.Date -import java.io.File - -import scala.collection.mutable.HashMap - -import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated} -import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected} -import akka.util.duration._ - -import spark.{Logging, Utils} -import spark.deploy.ExecutorState -import spark.deploy.DeployMessages._ -import spark.deploy.master.Master -import spark.deploy.worker.ui.WorkerWebUI -import spark.metrics.MetricsSystem -import spark.util.AkkaUtils - - -private[spark] class Worker( - host: String, - port: Int, - webUiPort: Int, - cores: Int, - memory: Int, - masterUrl: String, - workDirPath: String = null) - extends Actor with Logging { - - Utils.checkHost(host, "Expected hostname") - assert (port > 0) - - val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs - - // Send a heartbeat every (heartbeat timeout) / 4 milliseconds - val HEARTBEAT_MILLIS = System.getProperty("spark.worker.timeout", "60").toLong * 1000 / 4 - - var master: ActorRef = null - var masterWebUiUrl : String = "" - val workerId = generateWorkerId() - var sparkHome: File = null - var workDir: File = null - val executors = new HashMap[String, ExecutorRunner] - val finishedExecutors = new HashMap[String, ExecutorRunner] - val publicAddress = { - val envVar = System.getenv("SPARK_PUBLIC_DNS") - if (envVar != null) envVar else host - } - var webUi: WorkerWebUI = null - - var coresUsed = 0 - var memoryUsed = 0 - - val metricsSystem = MetricsSystem.createMetricsSystem("worker") - val workerSource = new WorkerSource(this) - - def coresFree: Int = cores - coresUsed - def memoryFree: Int = memory - memoryUsed - - def createWorkDir() { - workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work")) - try { - // This sporadically fails - not sure why ... !workDir.exists() && !workDir.mkdirs() - // So attempting to create and then check if directory was created or not. 
- workDir.mkdirs() - if ( !workDir.exists() || !workDir.isDirectory) { - logError("Failed to create work directory " + workDir) - System.exit(1) - } - assert (workDir.isDirectory) - } catch { - case e: Exception => - logError("Failed to create work directory " + workDir, e) - System.exit(1) - } - } - - override def preStart() { - logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( - host, port, cores, Utils.megabytesToString(memory))) - sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse(".")) - logInfo("Spark home: " + sparkHome) - createWorkDir() - webUi = new WorkerWebUI(this, workDir, Some(webUiPort)) - - webUi.start() - connectToMaster() - - metricsSystem.registerSource(workerSource) - metricsSystem.start() - } - - def connectToMaster() { - logInfo("Connecting to master " + masterUrl) - master = context.actorFor(Master.toAkkaUrl(masterUrl)) - master ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort.get, publicAddress) - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - context.watch(master) // Doesn't work with remote actors, but useful for testing - } - - override def receive = { - case RegisteredWorker(url) => - masterWebUiUrl = url - logInfo("Successfully registered with master") - context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis) { - master ! Heartbeat(workerId) - } - - case RegisterWorkerFailed(message) => - logError("Worker registration failed: " + message) - System.exit(1) - - case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) => - logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) - val manager = new ExecutorRunner( - appId, execId, appDesc, cores_, memory_, self, workerId, host, new File(execSparkHome_), workDir) - executors(appId + "/" + execId) = manager - manager.start() - coresUsed += cores_ - memoryUsed += memory_ - master ! ExecutorStateChanged(appId, execId, ExecutorState.RUNNING, None, None) - - case ExecutorStateChanged(appId, execId, state, message, exitStatus) => - master ! ExecutorStateChanged(appId, execId, state, message, exitStatus) - val fullId = appId + "/" + execId - if (ExecutorState.isFinished(state)) { - val executor = executors(fullId) - logInfo("Executor " + fullId + " finished with state " + state + - message.map(" message " + _).getOrElse("") + - exitStatus.map(" exitStatus " + _).getOrElse("")) - finishedExecutors(fullId) = executor - executors -= fullId - coresUsed -= executor.cores - memoryUsed -= executor.memory - } - - case KillExecutor(appId, execId) => - val fullId = appId + "/" + execId - executors.get(fullId) match { - case Some(executor) => - logInfo("Asked to kill executor " + fullId) - executor.kill() - case None => - logInfo("Asked to kill unknown executor " + fullId) - } - - case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => - masterDisconnected() - - case RequestWorkerState => { - sender ! WorkerStateResponse(host, port, workerId, executors.values.toList, - finishedExecutors.values.toList, masterUrl, cores, memory, - coresUsed, memoryUsed, masterWebUiUrl) - } - } - - def masterDisconnected() { - // TODO: It would be nice to try to reconnect to the master, but just shut down for now. - // (Note that if reconnecting we would also need to assign IDs differently.) - logError("Connection to master failed! 
Shutting down.") - executors.values.foreach(_.kill()) - System.exit(1) - } - - def generateWorkerId(): String = { - "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), host, port) - } - - override def postStop() { - executors.values.foreach(_.kill()) - webUi.stop() - metricsSystem.stop() - } -} - -private[spark] object Worker { - def main(argStrings: Array[String]) { - val args = new WorkerArguments(argStrings) - val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, - args.memory, args.master, args.workDir) - actorSystem.awaitTermination() - } - - def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int, - masterUrl: String, workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = { - // The LocalSparkCluster runs multiple local sparkWorkerX actor systems - val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port) - val actor = actorSystem.actorOf(Props(new Worker(host, boundPort, webUiPort, cores, memory, - masterUrl, workDir)), name = "Worker") - (actorSystem, boundPort) - } - -} diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala deleted file mode 100644 index 9fcd3260ca..0000000000 --- a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.worker - -import spark.util.IntParam -import spark.util.MemoryParam -import spark.Utils -import java.lang.management.ManagementFactory - -/** - * Command-line parser for the master. 
- */ -private[spark] class WorkerArguments(args: Array[String]) { - var host = Utils.localHostName() - var port = 0 - var webUiPort = 8081 - var cores = inferDefaultCores() - var memory = inferDefaultMemory() - var master: String = null - var workDir: String = null - - // Check for settings in environment variables - if (System.getenv("SPARK_WORKER_PORT") != null) { - port = System.getenv("SPARK_WORKER_PORT").toInt - } - if (System.getenv("SPARK_WORKER_CORES") != null) { - cores = System.getenv("SPARK_WORKER_CORES").toInt - } - if (System.getenv("SPARK_WORKER_MEMORY") != null) { - memory = Utils.memoryStringToMb(System.getenv("SPARK_WORKER_MEMORY")) - } - if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) { - webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt - } - if (System.getenv("SPARK_WORKER_DIR") != null) { - workDir = System.getenv("SPARK_WORKER_DIR") - } - - parse(args.toList) - - def parse(args: List[String]): Unit = args match { - case ("--ip" | "-i") :: value :: tail => - Utils.checkHost(value, "ip no longer supported, please use hostname " + value) - host = value - parse(tail) - - case ("--host" | "-h") :: value :: tail => - Utils.checkHost(value, "Please use hostname " + value) - host = value - parse(tail) - - case ("--port" | "-p") :: IntParam(value) :: tail => - port = value - parse(tail) - - case ("--cores" | "-c") :: IntParam(value) :: tail => - cores = value - parse(tail) - - case ("--memory" | "-m") :: MemoryParam(value) :: tail => - memory = value - parse(tail) - - case ("--work-dir" | "-d") :: value :: tail => - workDir = value - parse(tail) - - case "--webui-port" :: IntParam(value) :: tail => - webUiPort = value - parse(tail) - - case ("--help" | "-h") :: tail => - printUsageAndExit(0) - - case value :: tail => - if (master != null) { // Two positional arguments were given - printUsageAndExit(1) - } - master = value - parse(tail) - - case Nil => - if (master == null) { // No positional argument was given - printUsageAndExit(1) - } - - case _ => - printUsageAndExit(1) - } - - /** - * Print usage and exit JVM with the given exit code. - */ - def printUsageAndExit(exitCode: Int) { - System.err.println( - "Usage: Worker [options] \n" + - "\n" + - "Master must be a URL of the form spark://hostname:port\n" + - "\n" + - "Options:\n" + - " -c CORES, --cores CORES Number of cores to use\n" + - " -m MEM, --memory MEM Amount of memory to use (e.g. 
1000M, 2G)\n" + - " -d DIR, --work-dir DIR Directory to run apps in (default: SPARK_HOME/work)\n" + - " -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h)\n" + - " -h HOST, --host HOST Hostname to listen on\n" + - " -p PORT, --port PORT Port to listen on (default: random)\n" + - " --webui-port PORT Port for web UI (default: 8081)") - System.exit(exitCode) - } - - def inferDefaultCores(): Int = { - Runtime.getRuntime.availableProcessors() - } - - def inferDefaultMemory(): Int = { - val ibmVendor = System.getProperty("java.vendor").contains("IBM") - var totalMb = 0 - try { - val bean = ManagementFactory.getOperatingSystemMXBean() - if (ibmVendor) { - val beanClass = Class.forName("com.ibm.lang.management.OperatingSystemMXBean") - val method = beanClass.getDeclaredMethod("getTotalPhysicalMemory") - totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt - } else { - val beanClass = Class.forName("com.sun.management.OperatingSystemMXBean") - val method = beanClass.getDeclaredMethod("getTotalPhysicalMemorySize") - totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt - } - } catch { - case e: Exception => { - totalMb = 2*1024 - System.out.println("Failed to get total physical memory. Using " + totalMb + " MB") - } - } - // Leave out 1 GB for the operating system, but don't return a negative memory size - math.max(totalMb - 1024, 512) - } -} diff --git a/core/src/main/scala/spark/deploy/worker/WorkerSource.scala b/core/src/main/scala/spark/deploy/worker/WorkerSource.scala deleted file mode 100644 index 39cb8e5690..0000000000 --- a/core/src/main/scala/spark/deploy/worker/WorkerSource.scala +++ /dev/null @@ -1,34 +0,0 @@ -package spark.deploy.worker - -import com.codahale.metrics.{Gauge, MetricRegistry} - -import spark.metrics.source.Source - -private[spark] class WorkerSource(val worker: Worker) extends Source { - val sourceName = "worker" - val metricRegistry = new MetricRegistry() - - metricRegistry.register(MetricRegistry.name("executors", "number"), new Gauge[Int] { - override def getValue: Int = worker.executors.size - }) - - // Gauge for cores used of this worker - metricRegistry.register(MetricRegistry.name("coresUsed", "number"), new Gauge[Int] { - override def getValue: Int = worker.coresUsed - }) - - // Gauge for memory used of this worker - metricRegistry.register(MetricRegistry.name("memUsed", "MBytes"), new Gauge[Int] { - override def getValue: Int = worker.memoryUsed - }) - - // Gauge for cores free of this worker - metricRegistry.register(MetricRegistry.name("coresFree", "number"), new Gauge[Int] { - override def getValue: Int = worker.coresFree - }) - - // Gauge for memory free of this worker - metricRegistry.register(MetricRegistry.name("memFree", "MBytes"), new Gauge[Int] { - override def getValue: Int = worker.memoryFree - }) -} diff --git a/core/src/main/scala/spark/deploy/worker/ui/IndexPage.scala b/core/src/main/scala/spark/deploy/worker/ui/IndexPage.scala deleted file mode 100644 index 243e0765cb..0000000000 --- a/core/src/main/scala/spark/deploy/worker/ui/IndexPage.scala +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.worker.ui - -import javax.servlet.http.HttpServletRequest - -import scala.xml.Node - -import akka.dispatch.Await -import akka.pattern.ask -import akka.util.duration._ - -import net.liftweb.json.JsonAST.JValue - -import spark.Utils -import spark.deploy.JsonProtocol -import spark.deploy.DeployMessages.{RequestWorkerState, WorkerStateResponse} -import spark.deploy.worker.ExecutorRunner -import spark.ui.UIUtils - - -private[spark] class IndexPage(parent: WorkerWebUI) { - val workerActor = parent.worker.self - val worker = parent.worker - val timeout = parent.timeout - - def renderJson(request: HttpServletRequest): JValue = { - val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] - val workerState = Await.result(stateFuture, 30 seconds) - JsonProtocol.writeWorkerState(workerState) - } - - def render(request: HttpServletRequest): Seq[Node] = { - val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] - val workerState = Await.result(stateFuture, 30 seconds) - - val executorHeaders = Seq("ExecutorID", "Cores", "Memory", "Job Details", "Logs") - val runningExecutorTable = - UIUtils.listingTable(executorHeaders, executorRow, workerState.executors) - val finishedExecutorTable = - UIUtils.listingTable(executorHeaders, executorRow, workerState.finishedExecutors) - - val content = -
-      <div class="row-fluid"> <!-- Worker Details -->
-        <div class="span12">
-          <ul class="unstyled">
-            <li><strong>ID:</strong> {workerState.workerId}</li>
-            <li><strong>
-              Master URL:</strong> {workerState.masterUrl}
-            </li>
-            <li><strong>Cores:</strong> {workerState.cores} ({workerState.coresUsed} Used)</li>
-            <li><strong>Memory:</strong> {Utils.megabytesToString(workerState.memory)}
-              ({Utils.megabytesToString(workerState.memoryUsed)} Used)</li>
-          </ul>
-          <p><a href={workerState.masterWebUiUrl}>Back to Master</a></p>
-        </div>
-      </div>
-
-      <div class="row-fluid"> <!-- Running Executors -->
-        <div class="span12">
-          <h4> Running Executors {workerState.executors.size} </h4>
-          {runningExecutorTable}
-        </div>
-      </div>
-
-      <div class="row-fluid"> <!-- Finished Executors -->
-        <div class="span12">
-          <h4> Finished Executors </h4>
-          {finishedExecutorTable}
-        </div>
-      </div>
; - - UIUtils.basicSparkPage(content, "Spark Worker at %s:%s".format( - workerState.host, workerState.port)) - } - - def executorRow(executor: ExecutorRunner): Seq[Node] = { - - {executor.execId} - {executor.cores} - - {Utils.megabytesToString(executor.memory)} - - -
-        <ul class="unstyled">
-          <li><strong>ID:</strong> {executor.appId}</li>
-          <li><strong>Name:</strong> {executor.appDesc.name}</li>
-          <li><strong>User:</strong> {executor.appDesc.user}</li>
-        </ul>
- - - stdout - stderr - - - } - -} diff --git a/core/src/main/scala/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/ui/WorkerWebUI.scala deleted file mode 100644 index 0a75ad8cf4..0000000000 --- a/core/src/main/scala/spark/deploy/worker/ui/WorkerWebUI.scala +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.deploy.worker.ui - -import akka.util.{Duration, Timeout} - -import java.io.{FileInputStream, File} - -import javax.servlet.http.HttpServletRequest - -import org.eclipse.jetty.server.{Handler, Server} - -import spark.deploy.worker.Worker -import spark.{Utils, Logging} -import spark.ui.JettyUtils -import spark.ui.JettyUtils._ -import spark.ui.UIUtils - -/** - * Web UI server for the standalone worker. - */ -private[spark] -class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[Int] = None) - extends Logging { - implicit val timeout = Timeout( - Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")) - val host = Utils.localHostName() - val port = requestedPort.getOrElse( - System.getProperty("worker.ui.port", WorkerWebUI.DEFAULT_PORT).toInt) - - var server: Option[Server] = None - var boundPort: Option[Int] = None - - val indexPage = new IndexPage(this) - - val metricsHandlers = worker.metricsSystem.getServletHandlers - - val handlers = metricsHandlers ++ Array[(String, Handler)]( - ("/static", createStaticHandler(WorkerWebUI.STATIC_RESOURCE_DIR)), - ("/log", (request: HttpServletRequest) => log(request)), - ("/logPage", (request: HttpServletRequest) => logPage(request)), - ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)), - ("*", (request: HttpServletRequest) => indexPage.render(request)) - ) - - def start() { - try { - val (srv, bPort) = JettyUtils.startJettyServer("0.0.0.0", port, handlers) - server = Some(srv) - boundPort = Some(bPort) - logInfo("Started Worker web UI at http://%s:%d".format(host, bPort)) - } catch { - case e: Exception => - logError("Failed to create Worker JettyUtils", e) - System.exit(1) - } - } - - def log(request: HttpServletRequest): String = { - val defaultBytes = 100 * 1024 - val appId = request.getParameter("appId") - val executorId = request.getParameter("executorId") - val logType = request.getParameter("logType") - val offset = Option(request.getParameter("offset")).map(_.toLong) - val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) - val path = "%s/%s/%s/%s".format(workDir.getPath, appId, executorId, logType) - - val (startByte, endByte) = getByteRange(path, offset, byteLength) - val file = new File(path) - val logLength = file.length - - val pre = "==== Bytes %s-%s of %s of %s/%s/%s ====\n" - 
.format(startByte, endByte, logLength, appId, executorId, logType) - pre + Utils.offsetBytes(path, startByte, endByte) - } - - def logPage(request: HttpServletRequest): Seq[scala.xml.Node] = { - val defaultBytes = 100 * 1024 - val appId = request.getParameter("appId") - val executorId = request.getParameter("executorId") - val logType = request.getParameter("logType") - val offset = Option(request.getParameter("offset")).map(_.toLong) - val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) - val path = "%s/%s/%s/%s".format(workDir.getPath, appId, executorId, logType) - - val (startByte, endByte) = getByteRange(path, offset, byteLength) - val file = new File(path) - val logLength = file.length - - val logText = {Utils.offsetBytes(path, startByte, endByte)} - - val linkToMaster =

<p><a href={worker.masterWebUiUrl}>Back to Master</a></p>
- - val range = Bytes {startByte.toString} - {endByte.toString} of {logLength} - - val backButton = - if (startByte > 0) { - - - - } - else { - - } - - val nextButton = - if (endByte < logLength) { - - - - } - else { - - } - - val content = - - - {linkToMaster} -
-          <hr />
-          <div>
-            <div>{backButton}</div>
-            <div>{range}</div>
-            <div>{nextButton}</div>
-          </div>
-          <br />
-          <div>
-            <pre>{logText}</pre>
-          </div>
- - - UIUtils.basicSparkPage(content, logType + " log page for " + appId) - } - - /** Determine the byte range for a log or log page. */ - def getByteRange(path: String, offset: Option[Long], byteLength: Int) - : (Long, Long) = { - val defaultBytes = 100 * 1024 - val maxBytes = 1024 * 1024 - - val file = new File(path) - val logLength = file.length() - val getOffset = offset.getOrElse(logLength-defaultBytes) - - val startByte = - if (getOffset < 0) 0L - else if (getOffset > logLength) logLength - else getOffset - - val logPageLength = math.min(byteLength, maxBytes) - - val endByte = math.min(startByte+logPageLength, logLength) - - (startByte, endByte) - } - - def stop() { - server.foreach(_.stop()) - } -} - -private[spark] object WorkerWebUI { - val STATIC_RESOURCE_DIR = "spark/ui/static" - val DEFAULT_PORT="8081" -} diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala deleted file mode 100644 index fa82d2b324..0000000000 --- a/core/src/main/scala/spark/executor/Executor.scala +++ /dev/null @@ -1,269 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.executor - -import java.io.{File} -import java.lang.management.ManagementFactory -import java.nio.ByteBuffer -import java.util.concurrent._ - -import scala.collection.JavaConversions._ -import scala.collection.mutable.HashMap - -import spark.scheduler._ -import spark._ - - -/** - * The Mesos executor for Spark. - */ -private[spark] class Executor( - executorId: String, - slaveHostname: String, - properties: Seq[(String, String)]) - extends Logging -{ - // Application dependencies (added through SparkContext) that we've fetched so far on this node. - // Each map holds the master's timestamp for the version of that file or JAR we got. - private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() - private val currentJars: HashMap[String, Long] = new HashMap[String, Long]() - - private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0)) - - initLogging() - - // No ip or host:port - just hostname - Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname") - // must not have port specified. - assert (0 == Utils.parseHostPort(slaveHostname)._2) - - // Make sure the local hostname we report matches the cluster scheduler's name for this host - Utils.setCustomHostname(slaveHostname) - - // Set spark.* system properties from executor arg - for ((key, value) <- properties) { - System.setProperty(key, value) - } - - // If we are in yarn mode, systems can have different disk layouts so we must set it - // to what Yarn on this system said was available. This will be used later when SparkEnv - // created. 
- if (java.lang.Boolean.valueOf(System.getenv("SPARK_YARN_MODE"))) { - System.setProperty("spark.local.dir", getYarnLocalDirs()) - } - - // Create our ClassLoader and set it on this thread - private val urlClassLoader = createClassLoader() - private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) - Thread.currentThread.setContextClassLoader(replClassLoader) - - // Make any thread terminations due to uncaught exceptions kill the entire - // executor process to avoid surprising stalls. - Thread.setDefaultUncaughtExceptionHandler( - new Thread.UncaughtExceptionHandler { - override def uncaughtException(thread: Thread, exception: Throwable) { - try { - logError("Uncaught exception in thread " + thread, exception) - - // We may have been called from a shutdown hook. If so, we must not call System.exit(). - // (If we do, we will deadlock.) - if (!Utils.inShutdown()) { - if (exception.isInstanceOf[OutOfMemoryError]) { - System.exit(ExecutorExitCode.OOM) - } else { - System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION) - } - } - } catch { - case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM) - case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE) - } - } - } - ) - - val executorSource = new ExecutorSource(this) - - // Initialize Spark environment (using system properties read above) - val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false) - SparkEnv.set(env) - env.metricsSystem.registerSource(executorSource) - - private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size") - - // Start worker thread pool - val threadPool = new ThreadPoolExecutor( - 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable]) - - def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) { - threadPool.execute(new TaskRunner(context, taskId, serializedTask)) - } - - /** Get the Yarn approved local directories. */ - private def getYarnLocalDirs(): String = { - // Hadoop 0.23 and 2.x have different Environment variable names for the - // local dirs, so lets check both. We assume one of the 2 is set. 
- // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X - val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) - .getOrElse(Option(System.getenv("LOCAL_DIRS")) - .getOrElse("")) - - if (localDirs.isEmpty()) { - throw new Exception("Yarn Local dirs can't be empty") - } - return localDirs - } - - class TaskRunner(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) - extends Runnable { - - override def run() { - val startTime = System.currentTimeMillis() - SparkEnv.set(env) - Thread.currentThread.setContextClassLoader(replClassLoader) - val ser = SparkEnv.get.closureSerializer.newInstance() - logInfo("Running task ID " + taskId) - context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) - var attemptedTask: Option[Task[Any]] = None - var taskStart: Long = 0 - def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum - val startGCTime = getTotalGCTime - - try { - SparkEnv.set(env) - Accumulators.clear() - val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) - updateDependencies(taskFiles, taskJars) - val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) - attemptedTask = Some(task) - logInfo("Its epoch is " + task.epoch) - env.mapOutputTracker.updateEpoch(task.epoch) - taskStart = System.currentTimeMillis() - val value = task.run(taskId.toInt) - val taskFinish = System.currentTimeMillis() - for (m <- task.metrics) { - m.hostname = Utils.localHostName - m.executorDeserializeTime = (taskStart - startTime).toInt - m.executorRunTime = (taskFinish - taskStart).toInt - m.jvmGCTime = getTotalGCTime - startGCTime - } - //TODO I'd also like to track the time it takes to serialize the task results, but that is huge headache, b/c - // we need to serialize the task metrics first. If TaskMetrics had a custom serialized format, we could - // just change the relevants bytes in the byte buffer - val accumUpdates = Accumulators.values - val result = new TaskResult(value, accumUpdates, task.metrics.getOrElse(null)) - val serializedResult = ser.serialize(result) - logInfo("Serialized size of result for " + taskId + " is " + serializedResult.limit) - if (serializedResult.limit >= (akkaFrameSize - 1024)) { - context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(TaskResultTooBigFailure())) - return - } - context.statusUpdate(taskId, TaskState.FINISHED, serializedResult) - logInfo("Finished task ID " + taskId) - } catch { - case ffe: FetchFailedException => { - val reason = ffe.toTaskEndReason - context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - } - - case t: Throwable => { - val serviceTime = (System.currentTimeMillis() - taskStart).toInt - val metrics = attemptedTask.flatMap(t => t.metrics) - for (m <- metrics) { - m.executorRunTime = serviceTime - m.jvmGCTime = getTotalGCTime - startGCTime - } - val reason = ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics) - context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - - // TODO: Should we exit the whole executor here? On the one hand, the failed task may - // have left some weird state around depending on when the exception was thrown, but on - // the other hand, maybe we could detect that when future tasks fail and exit then. 
- logError("Exception in task ID " + taskId, t) - //System.exit(1) - } - } - } - } - - /** - * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes - * created by the interpreter to the search path - */ - private def createClassLoader(): ExecutorURLClassLoader = { - var loader = this.getClass.getClassLoader - - // For each of the jars in the jarSet, add them to the class loader. - // We assume each of the files has already been fetched. - val urls = currentJars.keySet.map { uri => - new File(uri.split("/").last).toURI.toURL - }.toArray - new ExecutorURLClassLoader(urls, loader) - } - - /** - * If the REPL is in use, add another ClassLoader that will read - * new classes defined by the REPL as the user types code - */ - private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = { - val classUri = System.getProperty("spark.repl.class.uri") - if (classUri != null) { - logInfo("Using REPL class URI: " + classUri) - try { - val klass = Class.forName("spark.repl.ExecutorClassLoader") - .asInstanceOf[Class[_ <: ClassLoader]] - val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader]) - return constructor.newInstance(classUri, parent) - } catch { - case _: ClassNotFoundException => - logError("Could not find spark.repl.ExecutorClassLoader on classpath!") - System.exit(1) - null - } - } else { - return parent - } - } - - /** - * Download any missing dependencies if we receive a new set of files and JARs from the - * SparkContext. Also adds any new JARs we fetched to the class loader. - */ - private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { - synchronized { - // Fetch missing dependencies - for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) - currentFiles(name) = timestamp - } - for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) - currentJars(name) = timestamp - // Add it to our class loader - val localName = name.split("/").last - val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL - if (!urlClassLoader.getURLs.contains(url)) { - logInfo("Adding " + url + " to class loader") - urlClassLoader.addURL(url) - } - } - } - } -} diff --git a/core/src/main/scala/spark/executor/ExecutorBackend.scala b/core/src/main/scala/spark/executor/ExecutorBackend.scala deleted file mode 100644 index 33a6f8a824..0000000000 --- a/core/src/main/scala/spark/executor/ExecutorBackend.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.executor - -import java.nio.ByteBuffer -import spark.TaskState.TaskState - -/** - * A pluggable interface used by the Executor to send updates to the cluster scheduler. - */ -private[spark] trait ExecutorBackend { - def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) -} diff --git a/core/src/main/scala/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/spark/executor/ExecutorExitCode.scala deleted file mode 100644 index 64b9fb88f8..0000000000 --- a/core/src/main/scala/spark/executor/ExecutorExitCode.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.executor - -/** - * These are exit codes that executors should use to provide the master with information about - * executor failures assuming that cluster management framework can capture the exit codes (but - * perhaps not log files). The exit code constants here are chosen to be unlikely to conflict - * with "natural" exit statuses that may be caused by the JVM or user code. In particular, - * exit codes 128+ arise on some Unix-likes as a result of signals, and it appears that the - * OpenJDK JVM may use exit code 1 in some of its own "last chance" code. - */ -private[spark] -object ExecutorExitCode { - /** The default uncaught exception handler was reached. */ - val UNCAUGHT_EXCEPTION = 50 - - /** The default uncaught exception handler was called and an exception was encountered while - logging the exception. */ - val UNCAUGHT_EXCEPTION_TWICE = 51 - - /** The default uncaught exception handler was reached, and the uncaught exception was an - OutOfMemoryError. */ - val OOM = 52 - - /** DiskStore failed to create a local temporary directory after many attempts. 
*/ - val DISK_STORE_FAILED_TO_CREATE_DIR = 53 - - def explainExitCode(exitCode: Int): String = { - exitCode match { - case UNCAUGHT_EXCEPTION => "Uncaught exception" - case UNCAUGHT_EXCEPTION_TWICE => "Uncaught exception, and logging the exception failed" - case OOM => "OutOfMemoryError" - case DISK_STORE_FAILED_TO_CREATE_DIR => - "Failed to create local directory (bad spark.local.dir?)" - case _ => - "Unknown executor exit code (" + exitCode + ")" + ( - if (exitCode > 128) - " (died from signal " + (exitCode - 128) + "?)" - else - "" - ) - } - } -} diff --git a/core/src/main/scala/spark/executor/ExecutorSource.scala b/core/src/main/scala/spark/executor/ExecutorSource.scala deleted file mode 100644 index d491a3c0c9..0000000000 --- a/core/src/main/scala/spark/executor/ExecutorSource.scala +++ /dev/null @@ -1,55 +0,0 @@ -package spark.executor - -import com.codahale.metrics.{Gauge, MetricRegistry} - -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.hdfs.DistributedFileSystem -import org.apache.hadoop.fs.LocalFileSystem - -import scala.collection.JavaConversions._ - -import spark.metrics.source.Source - -class ExecutorSource(val executor: Executor) extends Source { - private def fileStats(scheme: String) : Option[FileSystem.Statistics] = - FileSystem.getAllStatistics().filter(s => s.getScheme.equals(scheme)).headOption - - private def registerFileSystemStat[T]( - scheme: String, name: String, f: FileSystem.Statistics => T, defaultValue: T) = { - metricRegistry.register(MetricRegistry.name("filesystem", scheme, name), new Gauge[T] { - override def getValue: T = fileStats(scheme).map(f).getOrElse(defaultValue) - }) - } - - val metricRegistry = new MetricRegistry() - val sourceName = "executor" - - // Gauge for executor thread pool's actively executing task counts - metricRegistry.register(MetricRegistry.name("threadpool", "activeTask", "count"), new Gauge[Int] { - override def getValue: Int = executor.threadPool.getActiveCount() - }) - - // Gauge for executor thread pool's approximate total number of tasks that have been completed - metricRegistry.register(MetricRegistry.name("threadpool", "completeTask", "count"), new Gauge[Long] { - override def getValue: Long = executor.threadPool.getCompletedTaskCount() - }) - - // Gauge for executor thread pool's current number of threads - metricRegistry.register(MetricRegistry.name("threadpool", "currentPool", "size"), new Gauge[Int] { - override def getValue: Int = executor.threadPool.getPoolSize() - }) - - // Gauge got executor thread pool's largest number of threads that have ever simultaneously been in th pool - metricRegistry.register(MetricRegistry.name("threadpool", "maxPool", "size"), new Gauge[Int] { - override def getValue: Int = executor.threadPool.getMaximumPoolSize() - }) - - // Gauge for file system stats of this executor - for (scheme <- Array("hdfs", "file")) { - registerFileSystemStat(scheme, "bytesRead", _.getBytesRead(), 0L) - registerFileSystemStat(scheme, "bytesWritten", _.getBytesWritten(), 0L) - registerFileSystemStat(scheme, "readOps", _.getReadOps(), 0) - registerFileSystemStat(scheme, "largeReadOps", _.getLargeReadOps(), 0) - registerFileSystemStat(scheme, "writeOps", _.getWriteOps(), 0) - } -} diff --git a/core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala b/core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala deleted file mode 100644 index 09d12fb65b..0000000000 --- a/core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the 
Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.executor - -import java.net.{URLClassLoader, URL} - -/** - * The addURL method in URLClassLoader is protected. We subclass it to make this accessible. - */ -private[spark] class ExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader) - extends URLClassLoader(urls, parent) { - - override def addURL(url: URL) { - super.addURL(url) - } -} diff --git a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala deleted file mode 100644 index 4961c42fad..0000000000 --- a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.executor - -import java.nio.ByteBuffer -import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver} -import org.apache.mesos.Protos.{TaskState => MesosTaskState, TaskStatus => MesosTaskStatus, _} -import spark.TaskState.TaskState -import com.google.protobuf.ByteString -import spark.{Utils, Logging} -import spark.TaskState - -private[spark] class MesosExecutorBackend - extends MesosExecutor - with ExecutorBackend - with Logging { - - var executor: Executor = null - var driver: ExecutorDriver = null - - override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { - val mesosTaskId = TaskID.newBuilder().setValue(taskId.toString).build() - driver.sendStatusUpdate(MesosTaskStatus.newBuilder() - .setTaskId(mesosTaskId) - .setState(TaskState.toMesos(state)) - .setData(ByteString.copyFrom(data)) - .build()) - } - - override def registered( - driver: ExecutorDriver, - executorInfo: ExecutorInfo, - frameworkInfo: FrameworkInfo, - slaveInfo: SlaveInfo) { - logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue) - this.driver = driver - val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) - executor = new Executor( - executorInfo.getExecutorId.getValue, - slaveInfo.getHostname, - properties) - } - - override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) { - val taskId = taskInfo.getTaskId.getValue.toLong - if (executor == null) { - logError("Received launchTask but executor was null") - } else { - executor.launchTask(this, taskId, taskInfo.getData.asReadOnlyByteBuffer) - } - } - - override def error(d: ExecutorDriver, message: String) { - logError("Error from Mesos: " + message) - } - - override def killTask(d: ExecutorDriver, t: TaskID) { - logWarning("Mesos asked us to kill task " + t.getValue + "; ignoring (not yet implemented)") - } - - override def reregistered(d: ExecutorDriver, p2: SlaveInfo) {} - - override def disconnected(d: ExecutorDriver) {} - - override def frameworkMessage(d: ExecutorDriver, data: Array[Byte]) {} - - override def shutdown(d: ExecutorDriver) {} -} - -/** - * Entry point for Mesos executor. - */ -private[spark] object MesosExecutorBackend { - def main(args: Array[String]) { - MesosNativeLibrary.load() - // Create a new Executor and start it running - val runner = new MesosExecutorBackend() - new MesosExecutorDriver(runner).run() - } -} diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala deleted file mode 100644 index b5fb6dbe29..0000000000 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.executor - -import java.nio.ByteBuffer - -import akka.actor.{ActorRef, Actor, Props, Terminated} -import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected} - -import spark.{Logging, Utils, SparkEnv} -import spark.TaskState.TaskState -import spark.scheduler.cluster.StandaloneClusterMessages._ -import spark.util.AkkaUtils - - -private[spark] class StandaloneExecutorBackend( - driverUrl: String, - executorId: String, - hostPort: String, - cores: Int) - extends Actor - with ExecutorBackend - with Logging { - - Utils.checkHostPort(hostPort, "Expected hostport") - - var executor: Executor = null - var driver: ActorRef = null - - override def preStart() { - logInfo("Connecting to driver: " + driverUrl) - driver = context.actorFor(driverUrl) - driver ! RegisterExecutor(executorId, hostPort, cores) - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - context.watch(driver) // Doesn't work with remote actors, but useful for testing - } - - override def receive = { - case RegisteredExecutor(sparkProperties) => - logInfo("Successfully registered with driver") - // Make this host instead of hostPort ? - executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties) - - case RegisterExecutorFailed(message) => - logError("Slave registration failed: " + message) - System.exit(1) - - case LaunchTask(taskDesc) => - logInfo("Got assigned task " + taskDesc.taskId) - if (executor == null) { - logError("Received launchTask but executor was null") - System.exit(1) - } else { - executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask) - } - - case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => - logError("Driver terminated or disconnected! Shutting down.") - System.exit(1) - } - - override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { - driver ! StatusUpdate(executorId, taskId, state, data) - } -} - -private[spark] object StandaloneExecutorBackend { - def run(driverUrl: String, executorId: String, hostname: String, cores: Int) { - // Debug code - Utils.checkHost(hostname) - - // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor - // before getting started with all our system properties, etc - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0) - // set it - val sparkHostPort = hostname + ":" + boundPort - System.setProperty("spark.hostPort", sparkHostPort) - val actor = actorSystem.actorOf( - Props(new StandaloneExecutorBackend(driverUrl, executorId, sparkHostPort, cores)), - name = "Executor") - actorSystem.awaitTermination() - } - - def main(args: Array[String]) { - if (args.length < 4) { - //the reason we allow the last frameworkId argument is to make it easy to kill rogue executors - System.err.println("Usage: StandaloneExecutorBackend []") - System.exit(1) - } - run(args(0), args(1), args(2), args(3).toInt) - } -} diff --git a/core/src/main/scala/spark/executor/TaskMetrics.scala b/core/src/main/scala/spark/executor/TaskMetrics.scala deleted file mode 100644 index 47b8890bee..0000000000 --- a/core/src/main/scala/spark/executor/TaskMetrics.scala +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.executor - -class TaskMetrics extends Serializable { - /** - * Host's name the task runs on - */ - var hostname: String = _ - - /** - * Time taken on the executor to deserialize this task - */ - var executorDeserializeTime: Int = _ - - /** - * Time the executor spends actually running the task (including fetching shuffle data) - */ - var executorRunTime: Int = _ - - /** - * The number of bytes this task transmitted back to the driver as the TaskResult - */ - var resultSize: Long = _ - - /** - * Amount of time the JVM spent in garbage collection while executing this task - */ - var jvmGCTime: Long = _ - - /** - * If this task reads from shuffle output, metrics on getting shuffle data will be collected here - */ - var shuffleReadMetrics: Option[ShuffleReadMetrics] = None - - /** - * If this task writes to shuffle output, metrics on the written shuffle data will be collected here - */ - var shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None -} - -object TaskMetrics { - private[spark] def empty(): TaskMetrics = new TaskMetrics -} - - -class ShuffleReadMetrics extends Serializable { - /** - * Time when shuffle finishs - */ - var shuffleFinishTime: Long = _ - - /** - * Total number of blocks fetched in a shuffle (remote or local) - */ - var totalBlocksFetched: Int = _ - - /** - * Number of remote blocks fetched in a shuffle - */ - var remoteBlocksFetched: Int = _ - - /** - * Local blocks fetched in a shuffle - */ - var localBlocksFetched: Int = _ - - /** - * Total time that is spent blocked waiting for shuffle to fetch data - */ - var fetchWaitTime: Long = _ - - /** - * The total amount of time for all the shuffle fetches. This adds up time from overlapping - * shuffles, so can be longer than task time - */ - var remoteFetchTime: Long = _ - - /** - * Total number of remote bytes read from a shuffle - */ - var remoteBytesRead: Long = _ -} - -class ShuffleWriteMetrics extends Serializable { - /** - * Number of bytes written for a shuffle - */ - var shuffleBytesWritten: Long = _ -} diff --git a/core/src/main/scala/spark/io/CompressionCodec.scala b/core/src/main/scala/spark/io/CompressionCodec.scala deleted file mode 100644 index 0adebecadb..0000000000 --- a/core/src/main/scala/spark/io/CompressionCodec.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.io - -import java.io.{InputStream, OutputStream} - -import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} - -import org.xerial.snappy.{SnappyInputStream, SnappyOutputStream} - - -/** - * CompressionCodec allows the customization of choosing different compression implementations - * to be used in block storage. - */ -trait CompressionCodec { - - def compressedOutputStream(s: OutputStream): OutputStream - - def compressedInputStream(s: InputStream): InputStream -} - - -private[spark] object CompressionCodec { - - def createCodec(): CompressionCodec = { - // Set the default codec to Snappy since the LZF implementation initializes a pretty large - // buffer for every stream, which results in a lot of memory overhead when the number of - // shuffle reduce buckets are large. - createCodec(classOf[SnappyCompressionCodec].getName) - } - - def createCodec(codecName: String): CompressionCodec = { - Class.forName( - System.getProperty("spark.io.compression.codec", codecName), - true, - Thread.currentThread.getContextClassLoader).newInstance().asInstanceOf[CompressionCodec] - } -} - - -/** - * LZF implementation of [[spark.io.CompressionCodec]]. - */ -class LZFCompressionCodec extends CompressionCodec { - - override def compressedOutputStream(s: OutputStream): OutputStream = { - new LZFOutputStream(s).setFinishBlockOnFlush(true) - } - - override def compressedInputStream(s: InputStream): InputStream = new LZFInputStream(s) -} - - -/** - * Snappy implementation of [[spark.io.CompressionCodec]]. - * Block size can be configured by spark.io.compression.snappy.block.size. - */ -class SnappyCompressionCodec extends CompressionCodec { - - override def compressedOutputStream(s: OutputStream): OutputStream = { - val blockSize = System.getProperty("spark.io.compression.snappy.block.size", "32768").toInt - new SnappyOutputStream(s, blockSize) - } - - override def compressedInputStream(s: InputStream): InputStream = new SnappyInputStream(s) -} diff --git a/core/src/main/scala/spark/metrics/MetricsConfig.scala b/core/src/main/scala/spark/metrics/MetricsConfig.scala deleted file mode 100644 index d7fb5378a4..0000000000 --- a/core/src/main/scala/spark/metrics/MetricsConfig.scala +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
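
// A self-contained sketch (editor's illustration, not part of this patch) of the
// reflection pattern CompressionCodec.createCodec uses above: read the class name
// from a system property, fall back to a default, and instantiate through the
// no-arg constructor. GzipCodecSketch and the "sketch.io.compression.codec"
// property are hypothetical stand-ins so the example runs without LZF or Snappy.
import java.io.{InputStream, OutputStream}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

trait CodecSketch {
  def compressedOutputStream(s: OutputStream): OutputStream
  def compressedInputStream(s: InputStream): InputStream
}

class GzipCodecSketch extends CodecSketch {
  override def compressedOutputStream(s: OutputStream): OutputStream = new GZIPOutputStream(s)
  override def compressedInputStream(s: InputStream): InputStream = new GZIPInputStream(s)
}

object CodecSketch {
  def createCodec(): CodecSketch = {
    // Same shape as createCodec above: the property wins, the argument is the default.
    Class.forName(
      System.getProperty("sketch.io.compression.codec", classOf[GzipCodecSketch].getName),
      true,
      Thread.currentThread.getContextClassLoader).newInstance().asInstanceOf[CodecSketch]
  }
}
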
- */ - -package spark.metrics - -import java.util.Properties -import java.io.{File, FileInputStream, InputStream, IOException} - -import scala.collection.mutable -import scala.util.matching.Regex - -import spark.Logging - -private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging { - initLogging() - - val DEFAULT_PREFIX = "*" - val INSTANCE_REGEX = "^(\\*|[a-zA-Z]+)\\.(.+)".r - val METRICS_CONF = "metrics.properties" - - val properties = new Properties() - var propertyCategories: mutable.HashMap[String, Properties] = null - - private def setDefaultProperties(prop: Properties) { - prop.setProperty("*.sink.servlet.class", "spark.metrics.sink.MetricsServlet") - prop.setProperty("*.sink.servlet.uri", "/metrics/json") - prop.setProperty("*.sink.servlet.sample", "false") - prop.setProperty("master.sink.servlet.uri", "/metrics/master/json") - prop.setProperty("applications.sink.servlet.uri", "/metrics/applications/json") - } - - def initialize() { - //Add default properties in case there's no properties file - setDefaultProperties(properties) - - // If spark.metrics.conf is not set, try to get file in class path - var is: InputStream = null - try { - is = configFile match { - case Some(f) => new FileInputStream(f) - case None => getClass.getClassLoader.getResourceAsStream(METRICS_CONF) - } - - if (is != null) { - properties.load(is) - } - } catch { - case e: Exception => logError("Error loading configure file", e) - } finally { - if (is != null) is.close() - } - - propertyCategories = subProperties(properties, INSTANCE_REGEX) - if (propertyCategories.contains(DEFAULT_PREFIX)) { - import scala.collection.JavaConversions._ - - val defaultProperty = propertyCategories(DEFAULT_PREFIX) - for { (inst, prop) <- propertyCategories - if (inst != DEFAULT_PREFIX) - (k, v) <- defaultProperty - if (prop.getProperty(k) == null) } { - prop.setProperty(k, v) - } - } - } - - def subProperties(prop: Properties, regex: Regex): mutable.HashMap[String, Properties] = { - val subProperties = new mutable.HashMap[String, Properties] - import scala.collection.JavaConversions._ - prop.foreach { kv => - if (regex.findPrefixOf(kv._1) != None) { - val regex(prefix, suffix) = kv._1 - subProperties.getOrElseUpdate(prefix, new Properties).setProperty(suffix, kv._2) - } - } - subProperties - } - - def getInstance(inst: String): Properties = { - propertyCategories.get(inst) match { - case Some(s) => s - case None => propertyCategories.getOrElse(DEFAULT_PREFIX, new Properties) - } - } -} - diff --git a/core/src/main/scala/spark/metrics/MetricsSystem.scala b/core/src/main/scala/spark/metrics/MetricsSystem.scala deleted file mode 100644 index 4e6c6b26c8..0000000000 --- a/core/src/main/scala/spark/metrics/MetricsSystem.scala +++ /dev/null @@ -1,163 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
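
// A runnable sketch (editor's illustration, not part of this patch) of the
// grouping trick MetricsConfig.subProperties implements above: a key such as
// "master.sink.console.period" is split into its instance prefix ("master")
// and the remainder ("sink.console.period"), and each prefix accumulates its
// own Properties object.
import java.util.Properties

object SubPropertiesSketch {
  val InstanceRegex = "^(\\*|[a-zA-Z]+)\\.(.+)".r

  def group(conf: Map[String, String]): Map[String, Properties] =
    conf.foldLeft(Map.empty[String, Properties]) { case (acc, (key, value)) =>
      key match {
        case InstanceRegex(prefix, suffix) =>
          val props = acc.getOrElse(prefix, new Properties)
          props.setProperty(suffix, value)
          acc + (prefix -> props)
        case _ => acc // keys without an instance prefix are skipped, as above
      }
    }
}
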
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package spark.metrics
-
-import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry}
-
-import java.util.Properties
-import java.util.concurrent.TimeUnit
-
-import scala.collection.mutable
-
-import spark.Logging
-import spark.metrics.sink.{MetricsServlet, Sink}
-import spark.metrics.source.Source
-
-/**
- * Spark Metrics System. It is created for a specific "instance", combines sources and
- * sinks, and periodically polls metrics data from the sources to the sink destinations.
- *
- * "instance" specifies "who" (the role) uses the metrics system. In Spark, several roles,
- * such as master, worker, executor and client driver, create a metrics system for
- * monitoring, so "instance" represents one of these roles. Currently several instances
- * are implemented: master, worker, executor, driver, applications.
- *
- * "source" specifies "where" (the source) to collect metrics data from. Two kinds of
- * source exist in the metrics system:
- * 1. Spark internal sources, like MasterSource, WorkerSource, etc, which collect
- *    a Spark component's internal state. These sources are tied to an instance and are
- *    added after the specific metrics system is created.
- * 2. Common sources, like JvmSource, which collect low-level state. These are set up
- *    through the configuration and loaded via reflection.
- *
- * "sink" specifies "where" (the destination) to output metrics data to. Several sinks
- * can coexist, and metrics are flushed to all of them.
- *
- * The metrics configuration format is:
- *   [instance].[sink|source].[name].[options] = xxxx
- *
- * [instance] can be "master", "worker", "executor", "driver" or "applications", meaning
- * that only the specified instance has this property.
- * The wildcard "*" can be used in place of an instance name, meaning that all instances
- * have this property.
- *
- * [sink|source] indicates whether this property belongs to a sink or a source, and can
- * only take one of those two values.
- *
- * [name] specifies the name of the sink or source; it is custom defined.
- *
- * [options] is the specific property of this source or sink.
- */
-private[spark] class MetricsSystem private (val instance: String) extends Logging {
-  initLogging()
-
-  val confFile = System.getProperty("spark.metrics.conf")
-  val metricsConfig = new MetricsConfig(Option(confFile))
-
-  val sinks = new mutable.ArrayBuffer[Sink]
-  val sources = new mutable.ArrayBuffer[Source]
-  val registry = new MetricRegistry()
-
-  // Treat MetricsServlet as a special sink as it should be exposed to add handlers to web ui
-  private var metricsServlet: Option[MetricsServlet] = None
-
-  /** Get any UI handlers used by this metrics system. */
-  def getServletHandlers = metricsServlet.map(_.getHandlers).getOrElse(Array())
-
-  metricsConfig.initialize()
-  registerSources()
-  registerSinks()
-
-  def start() {
-    sinks.foreach(_.start)
-  }
-
-  def stop() {
-    sinks.foreach(_.stop)
-  }
-
-  def registerSource(source: Source) {
-    sources += source
-    try {
-      registry.register(source.sourceName, source.metricRegistry)
-    } catch {
-      case e: IllegalArgumentException => logInfo("Metrics already registered", e)
-    }
-  }
-
-  def removeSource(source: Source) {
-    sources -= source
-    registry.removeMatching(new MetricFilter {
-      def matches(name: String, metric: Metric): Boolean = name.startsWith(source.sourceName)
-    })
-  }
-
-  def registerSources() {
-    val instConfig = metricsConfig.getInstance(instance)
-    val sourceConfigs = metricsConfig.subProperties(instConfig, MetricsSystem.SOURCE_REGEX)
-
-    // Register all the sources related to instance
-    sourceConfigs.foreach { kv =>
-      val classPath = kv._2.getProperty("class")
-      try {
-        val source = Class.forName(classPath).newInstance()
-        registerSource(source.asInstanceOf[Source])
-      } catch {
-        case e: Exception => logError("Source class " + classPath + " cannot be instantiated", e)
-      }
-    }
-  }
-
-  def registerSinks() {
-    val instConfig = metricsConfig.getInstance(instance)
-    val sinkConfigs = metricsConfig.subProperties(instConfig, MetricsSystem.SINK_REGEX)
-
-    sinkConfigs.foreach { kv =>
-      val classPath = kv._2.getProperty("class")
-      try {
-        val sink = Class.forName(classPath)
-          .getConstructor(classOf[Properties], classOf[MetricRegistry])
-          .newInstance(kv._2, registry)
-        if (kv._1 == "servlet") {
-          metricsServlet = Some(sink.asInstanceOf[MetricsServlet])
-        } else {
-          sinks += sink.asInstanceOf[Sink]
-        }
-      } catch {
-        case e: Exception => logError("Sink class " + classPath + " cannot be instantiated", e)
-      }
-    }
-  }
-}
-
-private[spark] object MetricsSystem {
-  val SINK_REGEX = "^sink\\.(.+)\\.(.+)".r
-  val SOURCE_REGEX = "^source\\.(.+)\\.(.+)".r
-
-  val MINIMAL_POLL_UNIT = TimeUnit.SECONDS
-  val MINIMAL_POLL_PERIOD = 1
-
-  def checkMinimalPollingPeriod(pollUnit: TimeUnit, pollPeriod: Int) {
-    val period = MINIMAL_POLL_UNIT.convert(pollPeriod, pollUnit)
-    if (period < MINIMAL_POLL_PERIOD) {
-      throw new IllegalArgumentException("Polling period " + pollPeriod + " " + pollUnit +
-        " is below the minimal polling period")
-    }
-  }
-
-  def createMetricsSystem(instance: String): MetricsSystem = new MetricsSystem(instance)
-}
diff --git a/core/src/main/scala/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/spark/metrics/sink/ConsoleSink.scala
deleted file mode 100644
index 966ba37c20..0000000000
--- a/core/src/main/scala/spark/metrics/sink/ConsoleSink.scala
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
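
// A short sketch (editor's illustration, not part of this patch) of how
// MetricsSystem.SINK_REGEX above carves a per-instance key such as
// "sink.console.period" into the sink name and the option name, which is what
// registerSinks relies on when it looks up each sink's "class" property.
object SinkKeySketch {
  val SinkRegex = "^sink\\.(.+)\\.(.+)".r

  def split(key: String): Option[(String, String)] = key match {
    case SinkRegex(name, option) => Some((name, option)) // e.g. ("console", "period")
    case _ => None
  }
}
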
- */ - -package spark.metrics.sink - -import com.codahale.metrics.{ConsoleReporter, MetricRegistry} - -import java.util.Properties -import java.util.concurrent.TimeUnit - -import spark.metrics.MetricsSystem - -class ConsoleSink(val property: Properties, val registry: MetricRegistry) extends Sink { - val CONSOLE_DEFAULT_PERIOD = 10 - val CONSOLE_DEFAULT_UNIT = "SECONDS" - - val CONSOLE_KEY_PERIOD = "period" - val CONSOLE_KEY_UNIT = "unit" - - val pollPeriod = Option(property.getProperty(CONSOLE_KEY_PERIOD)) match { - case Some(s) => s.toInt - case None => CONSOLE_DEFAULT_PERIOD - } - - val pollUnit = Option(property.getProperty(CONSOLE_KEY_UNIT)) match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) - case None => TimeUnit.valueOf(CONSOLE_DEFAULT_UNIT) - } - - MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) - - val reporter: ConsoleReporter = ConsoleReporter.forRegistry(registry) - .convertDurationsTo(TimeUnit.MILLISECONDS) - .convertRatesTo(TimeUnit.SECONDS) - .build() - - override def start() { - reporter.start(pollPeriod, pollUnit) - } - - override def stop() { - reporter.stop() - } -} - diff --git a/core/src/main/scala/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/spark/metrics/sink/CsvSink.scala deleted file mode 100644 index cb990afdef..0000000000 --- a/core/src/main/scala/spark/metrics/sink/CsvSink.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
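
// A minimal sketch (editor's illustration, not part of this patch) of the
// period/unit parsing pattern ConsoleSink uses above: optional properties with
// hard-coded defaults, and TimeUnit.valueOf on the upper-cased unit name.
import java.util.Properties
import java.util.concurrent.TimeUnit

object PollConfigSketch {
  def pollSettings(property: Properties): (Int, TimeUnit) = {
    val period = Option(property.getProperty("period")).map(_.toInt).getOrElse(10)
    val unit = Option(property.getProperty("unit"))
      .map(s => TimeUnit.valueOf(s.toUpperCase))
      .getOrElse(TimeUnit.SECONDS)
    (period, unit)
  }
}
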
- */ - -package spark.metrics.sink - -import com.codahale.metrics.{CsvReporter, MetricRegistry} - -import java.io.File -import java.util.{Locale, Properties} -import java.util.concurrent.TimeUnit - -import spark.metrics.MetricsSystem - -class CsvSink(val property: Properties, val registry: MetricRegistry) extends Sink { - val CSV_KEY_PERIOD = "period" - val CSV_KEY_UNIT = "unit" - val CSV_KEY_DIR = "directory" - - val CSV_DEFAULT_PERIOD = 10 - val CSV_DEFAULT_UNIT = "SECONDS" - val CSV_DEFAULT_DIR = "/tmp/" - - val pollPeriod = Option(property.getProperty(CSV_KEY_PERIOD)) match { - case Some(s) => s.toInt - case None => CSV_DEFAULT_PERIOD - } - - val pollUnit = Option(property.getProperty(CSV_KEY_UNIT)) match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) - case None => TimeUnit.valueOf(CSV_DEFAULT_UNIT) - } - - MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) - - val pollDir = Option(property.getProperty(CSV_KEY_DIR)) match { - case Some(s) => s - case None => CSV_DEFAULT_DIR - } - - val reporter: CsvReporter = CsvReporter.forRegistry(registry) - .formatFor(Locale.US) - .convertDurationsTo(TimeUnit.MILLISECONDS) - .convertRatesTo(TimeUnit.SECONDS) - .build(new File(pollDir)) - - override def start() { - reporter.start(pollPeriod, pollUnit) - } - - override def stop() { - reporter.stop() - } -} - diff --git a/core/src/main/scala/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/spark/metrics/sink/JmxSink.scala deleted file mode 100644 index ee04544c0e..0000000000 --- a/core/src/main/scala/spark/metrics/sink/JmxSink.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.metrics.sink - -import com.codahale.metrics.{JmxReporter, MetricRegistry} - -import java.util.Properties - -class JmxSink(val property: Properties, val registry: MetricRegistry) extends Sink { - val reporter: JmxReporter = JmxReporter.forRegistry(registry).build() - - override def start() { - reporter.start() - } - - override def stop() { - reporter.stop() - } - -} diff --git a/core/src/main/scala/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/spark/metrics/sink/MetricsServlet.scala deleted file mode 100644 index 17432b1ed1..0000000000 --- a/core/src/main/scala/spark/metrics/sink/MetricsServlet.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.metrics.sink - -import com.codahale.metrics.MetricRegistry -import com.codahale.metrics.json.MetricsModule - -import com.fasterxml.jackson.databind.ObjectMapper - -import java.util.Properties -import java.util.concurrent.TimeUnit -import javax.servlet.http.HttpServletRequest - -import org.eclipse.jetty.server.Handler - -import spark.ui.JettyUtils - -class MetricsServlet(val property: Properties, val registry: MetricRegistry) extends Sink { - val SERVLET_KEY_URI = "uri" - val SERVLET_KEY_SAMPLE = "sample" - - val servletURI = property.getProperty(SERVLET_KEY_URI) - - val servletShowSample = property.getProperty(SERVLET_KEY_SAMPLE).toBoolean - - val mapper = new ObjectMapper().registerModule( - new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample)) - - def getHandlers = Array[(String, Handler)]( - (servletURI, JettyUtils.createHandler(request => getMetricsSnapshot(request), "text/json")) - ) - - def getMetricsSnapshot(request: HttpServletRequest): String = { - mapper.writeValueAsString(registry) - } - - override def start() { } - - override def stop() { } -} diff --git a/core/src/main/scala/spark/metrics/sink/Sink.scala b/core/src/main/scala/spark/metrics/sink/Sink.scala deleted file mode 100644 index dad1a7f0fe..0000000000 --- a/core/src/main/scala/spark/metrics/sink/Sink.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.metrics.sink - -trait Sink { - def start: Unit - def stop: Unit -} \ No newline at end of file diff --git a/core/src/main/scala/spark/metrics/source/JvmSource.scala b/core/src/main/scala/spark/metrics/source/JvmSource.scala deleted file mode 100644 index e771008557..0000000000 --- a/core/src/main/scala/spark/metrics/source/JvmSource.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.metrics.source - -import com.codahale.metrics.MetricRegistry -import com.codahale.metrics.jvm.{GarbageCollectorMetricSet, MemoryUsageGaugeSet} - -class JvmSource extends Source { - val sourceName = "jvm" - val metricRegistry = new MetricRegistry() - - val gcMetricSet = new GarbageCollectorMetricSet - val memGaugeSet = new MemoryUsageGaugeSet - - metricRegistry.registerAll(gcMetricSet) - metricRegistry.registerAll(memGaugeSet) -} diff --git a/core/src/main/scala/spark/metrics/source/Source.scala b/core/src/main/scala/spark/metrics/source/Source.scala deleted file mode 100644 index 76199a004b..0000000000 --- a/core/src/main/scala/spark/metrics/source/Source.scala +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.metrics.source - -import com.codahale.metrics.MetricRegistry - -trait Source { - def sourceName: String - def metricRegistry: MetricRegistry -} diff --git a/core/src/main/scala/spark/network/BufferMessage.scala b/core/src/main/scala/spark/network/BufferMessage.scala deleted file mode 100644 index e566aeac13..0000000000 --- a/core/src/main/scala/spark/network/BufferMessage.scala +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
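
// A sketch (editor's illustration, not part of this patch) of a user-defined
// Source following the trait above: pick a sourceName, create a registry, and
// register gauges into it. MetricsSystem.registerSources would pick it up from
// a config line like "*.source.uptime.class=UptimeSourceSketch" (hypothetical
// names; only the Gauge and MetricRegistry APIs come from metrics-core).
import com.codahale.metrics.{Gauge, MetricRegistry}

class UptimeSourceSketch extends Source {
  val sourceName = "uptime"
  val metricRegistry = new MetricRegistry()
  private val started = System.currentTimeMillis

  metricRegistry.register(MetricRegistry.name("uptime", "millis"), new Gauge[Long] {
    override def getValue: Long = System.currentTimeMillis - started
  })
}
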
- */ - -package spark.network - -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - -import spark.storage.BlockManager - - -private[spark] -class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) - extends Message(Message.BUFFER_MESSAGE, id_) { - - val initialSize = currentSize() - var gotChunkForSendingOnce = false - - def size = initialSize - - def currentSize() = { - if (buffers == null || buffers.isEmpty) { - 0 - } else { - buffers.map(_.remaining).reduceLeft(_ + _) - } - } - - def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = { - if (maxChunkSize <= 0) { - throw new Exception("Max chunk size is " + maxChunkSize) - } - - if (size == 0 && gotChunkForSendingOnce == false) { - val newChunk = new MessageChunk( - new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null) - gotChunkForSendingOnce = true - return Some(newChunk) - } - - while(!buffers.isEmpty) { - val buffer = buffers(0) - if (buffer.remaining == 0) { - BlockManager.dispose(buffer) - buffers -= buffer - } else { - val newBuffer = if (buffer.remaining <= maxChunkSize) { - buffer.duplicate() - } else { - buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer] - } - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) - gotChunkForSendingOnce = true - return Some(newChunk) - } - } - None - } - - def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = { - // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer - if (buffers.size > 1) { - throw new Exception("Attempting to get chunk from message with multiple data buffers") - } - val buffer = buffers(0) - if (buffer.remaining > 0) { - if (buffer.remaining < chunkSize) { - throw new Exception("Not enough space in data buffer for receiving chunk") - } - val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) - return Some(newChunk) - } - None - } - - def flip() { - buffers.foreach(_.flip) - } - - def hasAckId() = (ackId != 0) - - def isCompletelyReceived() = !buffers(0).hasRemaining - - override def toString = { - if (hasAckId) { - "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" - } else { - "BufferMessage(id = " + id + ", size = " + size + ")" - } - } -} diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala deleted file mode 100644 index 1e571d39ae..0000000000 --- a/core/src/main/scala/spark/network/Connection.scala +++ /dev/null @@ -1,586 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
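
// A dependency-free sketch (editor's illustration, not part of this patch) of
// the slicing trick BufferMessage.getChunkForSending uses above: carve at most
// maxChunkSize bytes off the front of a buffer without copying, then advance
// the original buffer's position past the slice.
import java.nio.ByteBuffer

object ChunkingSketch {
  def nextChunk(buffer: ByteBuffer, maxChunkSize: Int): ByteBuffer = {
    val chunk =
      if (buffer.remaining <= maxChunkSize) buffer.duplicate()
      else buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
    buffer.position(buffer.position + chunk.remaining) // consume what the chunk covers
    chunk
  }
}
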
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network - -import spark._ - -import scala.collection.mutable.{HashMap, Queue, ArrayBuffer} - -import java.io._ -import java.nio._ -import java.nio.channels._ -import java.nio.channels.spi._ -import java.net._ - - -private[spark] -abstract class Connection(val channel: SocketChannel, val selector: Selector, - val socketRemoteConnectionManagerId: ConnectionManagerId) - extends Logging { - - def this(channel_ : SocketChannel, selector_ : Selector) = { - this(channel_, selector_, - ConnectionManagerId.fromSocketAddress( - channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress])) - } - - channel.configureBlocking(false) - channel.socket.setTcpNoDelay(true) - channel.socket.setReuseAddress(true) - channel.socket.setKeepAlive(true) - /*channel.socket.setReceiveBufferSize(32768) */ - - @volatile private var closed = false - var onCloseCallback: Connection => Unit = null - var onExceptionCallback: (Connection, Exception) => Unit = null - var onKeyInterestChangeCallback: (Connection, Int) => Unit = null - - val remoteAddress = getRemoteAddress() - - def resetForceReregister(): Boolean - - // Read channels typically do not register for write and write does not for read - // Now, we do have write registering for read too (temporarily), but this is to detect - // channel close NOT to actually read/consume data on it ! - // How does this work if/when we move to SSL ? - - // What is the interest to register with selector for when we want this connection to be selected - def registerInterest() - - // What is the interest to register with selector for when we want this connection to - // be de-selected - // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack, - // it will be SelectionKey.OP_READ (until we fix it properly) - def unregisterInterest() - - // On receiving a read event, should we change the interest for this channel or not ? - // Will be true for ReceivingConnection, false for SendingConnection. - def changeInterestForRead(): Boolean - - // On receiving a write event, should we change the interest for this channel or not ? - // Will be false for ReceivingConnection, true for SendingConnection. - // Actually, for now, should not get triggered for ReceivingConnection - def changeInterestForWrite(): Boolean - - def getRemoteConnectionManagerId(): ConnectionManagerId = { - socketRemoteConnectionManagerId - } - - def key() = channel.keyFor(selector) - - def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] - - // Returns whether we have to register for further reads or not. - def read(): Boolean = { - throw new UnsupportedOperationException( - "Cannot read on connection of type " + this.getClass.toString) - } - - // Returns whether we have to register for further writes or not. 
- def write(): Boolean = { - throw new UnsupportedOperationException( - "Cannot write on connection of type " + this.getClass.toString) - } - - def close() { - closed = true - val k = key() - if (k != null) { - k.cancel() - } - channel.close() - callOnCloseCallback() - } - - protected def isClosed: Boolean = closed - - def onClose(callback: Connection => Unit) { - onCloseCallback = callback - } - - def onException(callback: (Connection, Exception) => Unit) { - onExceptionCallback = callback - } - - def onKeyInterestChange(callback: (Connection, Int) => Unit) { - onKeyInterestChangeCallback = callback - } - - def callOnExceptionCallback(e: Exception) { - if (onExceptionCallback != null) { - onExceptionCallback(this, e) - } else { - logError("Error in connection to " + getRemoteConnectionManagerId() + - " and OnExceptionCallback not registered", e) - } - } - - def callOnCloseCallback() { - if (onCloseCallback != null) { - onCloseCallback(this) - } else { - logWarning("Connection to " + getRemoteConnectionManagerId() + - " closed and OnExceptionCallback not registered") - } - - } - - def changeConnectionKeyInterest(ops: Int) { - if (onKeyInterestChangeCallback != null) { - onKeyInterestChangeCallback(this, ops) - } else { - throw new Exception("OnKeyInterestChangeCallback not registered") - } - } - - def printRemainingBuffer(buffer: ByteBuffer) { - val bytes = new Array[Byte](buffer.remaining) - val curPosition = buffer.position - buffer.get(bytes) - bytes.foreach(x => print(x + " ")) - buffer.position(curPosition) - print(" (" + bytes.size + ")") - } - - def printBuffer(buffer: ByteBuffer, position: Int, length: Int) { - val bytes = new Array[Byte](length) - val curPosition = buffer.position - buffer.position(position) - buffer.get(bytes) - bytes.foreach(x => print(x + " ")) - print(" (" + position + ", " + length + ")") - buffer.position(curPosition) - } -} - - -private[spark] -class SendingConnection(val address: InetSocketAddress, selector_ : Selector, - remoteId_ : ConnectionManagerId) - extends Connection(SocketChannel.open, selector_, remoteId_) { - - private class Outbox(fair: Int = 0) { - val messages = new Queue[Message]() - val defaultChunkSize = 65536 //32768 //16384 - var nextMessageToBeUsed = 0 - - def addMessage(message: Message) { - messages.synchronized{ - /*messages += message*/ - messages.enqueue(message) - logDebug("Added [" + message + "] to outbox for sending to " + - "[" + getRemoteConnectionManagerId() + "]") - } - } - - def getChunk(): Option[MessageChunk] = { - fair match { - case 0 => getChunkFIFO() - case 1 => getChunkRR() - case _ => throw new Exception("Unexpected fairness policy in outbox") - } - } - - private def getChunkFIFO(): Option[MessageChunk] = { - /*logInfo("Using FIFO")*/ - messages.synchronized { - while (!messages.isEmpty) { - val message = messages(0) - val chunk = message.getChunkForSending(defaultChunkSize) - if (chunk.isDefined) { - messages += message // this is probably incorrect, it wont work as fifo - if (!message.started) { - logDebug("Starting to send [" + message + "]") - message.started = true - message.startTime = System.currentTimeMillis - } - return chunk - } else { - /*logInfo("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "]")*/ - message.finishTime = System.currentTimeMillis - logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + - "] in " + message.timeTaken ) - } - } - } - None - } - - private def getChunkRR(): Option[MessageChunk] = { - messages.synchronized { - while 
(!messages.isEmpty) { - /*nextMessageToBeUsed = nextMessageToBeUsed % messages.size */ - /*val message = messages(nextMessageToBeUsed)*/ - val message = messages.dequeue - val chunk = message.getChunkForSending(defaultChunkSize) - if (chunk.isDefined) { - messages.enqueue(message) - nextMessageToBeUsed = nextMessageToBeUsed + 1 - if (!message.started) { - logDebug( - "Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]") - message.started = true - message.startTime = System.currentTimeMillis - } - logTrace( - "Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]") - return chunk - } else { - message.finishTime = System.currentTimeMillis - logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + - "] in " + message.timeTaken ) - } - } - } - None - } - } - - // outbox is used as a lock - ensure that it is always used as a leaf (since methods which - // lock it are invoked in context of other locks) - private val outbox = new Outbox(1) - /* - This is orthogonal to whether we have pending bytes to write or not - and satisfies a slightly - different purpose. This flag is to see if we need to force reregister for write even when we - do not have any pending bytes to write to socket. - This can happen due to a race between adding pending buffers, and checking for existing of - data as detailed in https://github.com/mesos/spark/pull/791 - */ - private var needForceReregister = false - val currentBuffers = new ArrayBuffer[ByteBuffer]() - - /*channel.socket.setSendBufferSize(256 * 1024)*/ - - override def getRemoteAddress() = address - - val DEFAULT_INTEREST = SelectionKey.OP_READ - - override def registerInterest() { - // Registering read too - does not really help in most cases, but for some - // it does - so let us keep it for now. - changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST) - } - - override def unregisterInterest() { - changeConnectionKeyInterest(DEFAULT_INTEREST) - } - - def send(message: Message) { - outbox.synchronized { - outbox.addMessage(message) - needForceReregister = true - } - if (channel.isConnected) { - registerInterest() - } - } - - // return previous value after resetting it. - def resetForceReregister(): Boolean = { - outbox.synchronized { - val result = needForceReregister - needForceReregister = false - result - } - } - - // MUST be called within the selector loop - def connect() { - try{ - channel.register(selector, SelectionKey.OP_CONNECT) - channel.connect(address) - logInfo("Initiating connection to [" + address + "]") - } catch { - case e: Exception => { - logError("Error connecting to " + address, e) - callOnExceptionCallback(e) - } - } - } - - def finishConnect(force: Boolean): Boolean = { - try { - // Typically, this should finish immediately since it was triggered by a connect - // selection - though need not necessarily always complete successfully. 
- val connected = channel.finishConnect - if (!force && !connected) { - logInfo( - "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending") - return false - } - - // Fallback to previous behavior - assume finishConnect completed - // This will happen only when finishConnect failed for some repeated number of times - // (10 or so) - // Is highly unlikely unless there was an unclean close of socket, etc - registerInterest() - logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending") - return true - } catch { - case e: Exception => { - logWarning("Error finishing connection to " + address, e) - callOnExceptionCallback(e) - // ignore - return true - } - } - } - - override def write(): Boolean = { - try { - while (true) { - if (currentBuffers.size == 0) { - outbox.synchronized { - outbox.getChunk() match { - case Some(chunk) => { - val buffers = chunk.buffers - // If we have 'seen' pending messages, then reset flag - since we handle that as normal - // registering of event (below) - if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister() - currentBuffers ++= buffers - } - case None => { - // changeConnectionKeyInterest(0) - /*key.interestOps(0)*/ - return false - } - } - } - } - - if (currentBuffers.size > 0) { - val buffer = currentBuffers(0) - val remainingBytes = buffer.remaining - val writtenBytes = channel.write(buffer) - if (buffer.remaining == 0) { - currentBuffers -= buffer - } - if (writtenBytes < remainingBytes) { - // re-register for write. - return true - } - } - } - } catch { - case e: Exception => { - logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) - close() - return false - } - } - // should not happen - to keep scala compiler happy - return true - } - - // This is a hack to determine if remote socket was closed or not. - // SendingConnection DOES NOT expect to receive any data - if it does, it is an error - // For a bunch of cases, read will return -1 in case remote socket is closed : hence we - // register for reads to determine that. - override def read(): Boolean = { - // We don't expect the other side to send anything; so, we just read to detect an error or EOF. - try { - val length = channel.read(ByteBuffer.allocate(1)) - if (length == -1) { // EOF - close() - } else if (length > 0) { - logWarning( - "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId()) - } - } catch { - case e: Exception => - logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) - close() - } - - false - } - - override def changeInterestForRead(): Boolean = false - - override def changeInterestForWrite(): Boolean = ! 
isClosed -} - - -// Must be created within selector loop - else deadlock -private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) - extends Connection(channel_, selector_) { - - class Inbox() { - val messages = new HashMap[Int, BufferMessage]() - - def getChunk(header: MessageChunkHeader): Option[MessageChunk] = { - - def createNewMessage: BufferMessage = { - val newMessage = Message.create(header).asInstanceOf[BufferMessage] - newMessage.started = true - newMessage.startTime = System.currentTimeMillis - logDebug( - "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") - messages += ((newMessage.id, newMessage)) - newMessage - } - - val message = messages.getOrElseUpdate(header.id, createNewMessage) - logTrace( - "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]") - message.getChunkForReceiving(header.chunkSize) - } - - def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = { - messages.get(chunk.header.id) - } - - def removeMessage(message: Message) { - messages -= message.id - } - } - - @volatile private var inferredRemoteManagerId: ConnectionManagerId = null - - override def getRemoteConnectionManagerId(): ConnectionManagerId = { - val currId = inferredRemoteManagerId - if (currId != null) currId else super.getRemoteConnectionManagerId() - } - - // The reciever's remote address is the local socket on remote side : which is NOT - // the connection manager id of the receiver. - // We infer that from the messages we receive on the receiver socket. - private def processConnectionManagerId(header: MessageChunkHeader) { - val currId = inferredRemoteManagerId - if (header.address == null || currId != null) return - - val managerId = ConnectionManagerId.fromSocketAddress(header.address) - - if (managerId != null) { - inferredRemoteManagerId = managerId - } - } - - - val inbox = new Inbox() - val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE) - var onReceiveCallback: (Connection , Message) => Unit = null - var currentChunk: MessageChunk = null - - channel.register(selector, SelectionKey.OP_READ) - - override def read(): Boolean = { - try { - while (true) { - if (currentChunk == null) { - val headerBytesRead = channel.read(headerBuffer) - if (headerBytesRead == -1) { - close() - return false - } - if (headerBuffer.remaining > 0) { - // re-register for read event ... - return true - } - headerBuffer.flip - if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) { - throw new Exception( - "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header") - } - val header = MessageChunkHeader.create(headerBuffer) - headerBuffer.clear() - - processConnectionManagerId(header) - - header.typ match { - case Message.BUFFER_MESSAGE => { - if (header.totalSize == 0) { - if (onReceiveCallback != null) { - onReceiveCallback(this, Message.create(header)) - } - currentChunk = null - // re-register for read event ... - return true - } else { - currentChunk = inbox.getChunk(header).orNull - } - } - case _ => throw new Exception("Message of unknown type received") - } - } - - if (currentChunk == null) throw new Exception("No message chunk to receive data") - - val bytesRead = channel.read(currentChunk.buffer) - if (bytesRead == 0) { - // re-register for read event ... 
- return true - } else if (bytesRead == -1) { - close() - return false - } - - /*logDebug("Read " + bytesRead + " bytes for the buffer")*/ - - if (currentChunk.buffer.remaining == 0) { - /*println("Filled buffer at " + System.currentTimeMillis)*/ - val bufferMessage = inbox.getMessageForChunk(currentChunk).get - if (bufferMessage.isCompletelyReceived) { - bufferMessage.flip - bufferMessage.finishTime = System.currentTimeMillis - logDebug("Finished receiving [" + bufferMessage + "] from " + - "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken) - if (onReceiveCallback != null) { - onReceiveCallback(this, bufferMessage) - } - inbox.removeMessage(bufferMessage) - } - currentChunk = null - } - } - } catch { - case e: Exception => { - logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) - close() - return false - } - } - // should not happen - to keep scala compiler happy - return true - } - - def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback} - - // override def changeInterestForRead(): Boolean = ! isClosed - override def changeInterestForRead(): Boolean = true - - override def changeInterestForWrite(): Boolean = { - throw new IllegalStateException("Unexpected invocation right now") - } - - override def registerInterest() { - // Registering read too - does not really help in most cases, but for some - // it does - so let us keep it for now. - changeConnectionKeyInterest(SelectionKey.OP_READ) - } - - override def unregisterInterest() { - changeConnectionKeyInterest(0) - } - - // For read conn, always false. - override def resetForceReregister(): Boolean = false -} diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala deleted file mode 100644 index 8b9f3ae18c..0000000000 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ /dev/null @@ -1,720 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
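
// A self-contained sketch (editor's illustration, not part of this patch) of
// the framing scheme the read() loop above implements: accumulate a fixed-size
// header first; only once it is complete do we know how large the payload
// buffer must be. The 8-byte layout (int id, int size) is invented for the
// example; the real MessageChunkHeader carries more fields.
import java.nio.ByteBuffer

object FramingSketch {
  val HeaderSize = 8

  def decode(wire: ByteBuffer): (Int, ByteBuffer) = {
    require(wire.remaining >= HeaderSize, "header not complete yet; keep reading")
    val id = wire.getInt()
    val size = wire.getInt()
    val payload = ByteBuffer.allocate(size)
    while (payload.hasRemaining && wire.hasRemaining) payload.put(wire.get())
    payload.flip()
    (id, payload)
  }
}
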
- */ - -package spark.network - -import spark._ - -import java.nio._ -import java.nio.channels._ -import java.nio.channels.spi._ -import java.net._ -import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor} - -import scala.collection.mutable.HashSet -import scala.collection.mutable.HashMap -import scala.collection.mutable.SynchronizedMap -import scala.collection.mutable.SynchronizedQueue -import scala.collection.mutable.ArrayBuffer - -import akka.dispatch.{Await, Promise, ExecutionContext, Future} -import akka.util.Duration -import akka.util.duration._ - - -private[spark] class ConnectionManager(port: Int) extends Logging { - - class MessageStatus( - val message: Message, - val connectionManagerId: ConnectionManagerId, - completionHandler: MessageStatus => Unit) { - - var ackMessage: Option[Message] = None - var attempted = false - var acked = false - - def markDone() { completionHandler(this) } - } - - private val selector = SelectorProvider.provider.openSelector() - - private val handleMessageExecutor = new ThreadPoolExecutor( - System.getProperty("spark.core.connection.handler.threads.min","20").toInt, - System.getProperty("spark.core.connection.handler.threads.max","60").toInt, - System.getProperty("spark.core.connection.handler.threads.keepalive","60").toInt, TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable]()) - - private val handleReadWriteExecutor = new ThreadPoolExecutor( - System.getProperty("spark.core.connection.io.threads.min","4").toInt, - System.getProperty("spark.core.connection.io.threads.max","32").toInt, - System.getProperty("spark.core.connection.io.threads.keepalive","60").toInt, TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable]()) - - // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : which should be executed asap - private val handleConnectExecutor = new ThreadPoolExecutor( - System.getProperty("spark.core.connection.connect.threads.min","1").toInt, - System.getProperty("spark.core.connection.connect.threads.max","8").toInt, - System.getProperty("spark.core.connection.connect.threads.keepalive","60").toInt, TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable]()) - - private val serverChannel = ServerSocketChannel.open() - private val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] - private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] - private val messageStatuses = new HashMap[Int, MessageStatus] - private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] - private val registerRequests = new SynchronizedQueue[SendingConnection] - - implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool()) - - private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null - - serverChannel.configureBlocking(false) - serverChannel.socket.setReuseAddress(true) - serverChannel.socket.setReceiveBufferSize(256 * 1024) - - serverChannel.socket.bind(new InetSocketAddress(port)) - serverChannel.register(selector, SelectionKey.OP_ACCEPT) - - val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) - logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id) - - private val selectorThread = new Thread("connection-manager-thread") { - override def run() = ConnectionManager.this.run() - } - selectorThread.setDaemon(true) 
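
// A sketch (editor's illustration, not part of this patch) of the
// property-driven pool sizing used above; the property prefix is shortened for
// the example. Note that with an unbounded queue like LinkedBlockingDeque the
// max size never takes effect: extra threads are only created once the queue
// is full, so the core size is the real bound.
import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}

object PoolSketch {
  def boundedPool(prefix: String, defMin: Int, defMax: Int): ThreadPoolExecutor =
    new ThreadPoolExecutor(
      System.getProperty(prefix + ".min", defMin.toString).toInt,
      System.getProperty(prefix + ".max", defMax.toString).toInt,
      System.getProperty(prefix + ".keepalive", "60").toLong, TimeUnit.SECONDS,
      new LinkedBlockingDeque[Runnable]())
}
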
- selectorThread.start() - - private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() - - private def triggerWrite(key: SelectionKey) { - val conn = connectionsByKey.getOrElse(key, null) - if (conn == null) return - - writeRunnableStarted.synchronized { - // So that we do not trigger more write events while processing this one. - // The write method will re-register when done. - if (conn.changeInterestForWrite()) conn.unregisterInterest() - if (writeRunnableStarted.contains(key)) { - // key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE) - return - } - - writeRunnableStarted += key - } - handleReadWriteExecutor.execute(new Runnable { - override def run() { - var register: Boolean = false - try { - register = conn.write() - } finally { - writeRunnableStarted.synchronized { - writeRunnableStarted -= key - val needReregister = register || conn.resetForceReregister() - if (needReregister && conn.changeInterestForWrite()) { - conn.registerInterest() - } - } - } - } - } ) - } - - private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() - - private def triggerRead(key: SelectionKey) { - val conn = connectionsByKey.getOrElse(key, null) - if (conn == null) return - - readRunnableStarted.synchronized { - // So that we do not trigger more read events while processing this one. - // The read method will re-register when done. - if (conn.changeInterestForRead())conn.unregisterInterest() - if (readRunnableStarted.contains(key)) { - return - } - - readRunnableStarted += key - } - handleReadWriteExecutor.execute(new Runnable { - override def run() { - var register: Boolean = false - try { - register = conn.read() - } finally { - readRunnableStarted.synchronized { - readRunnableStarted -= key - if (register && conn.changeInterestForRead()) { - conn.registerInterest() - } - } - } - } - } ) - } - - private def triggerConnect(key: SelectionKey) { - val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection] - if (conn == null) return - - // prevent other events from being triggered - // Since we are still trying to connect, we do not need to do the additional steps in triggerWrite - conn.changeConnectionKeyInterest(0) - - handleConnectExecutor.execute(new Runnable { - override def run() { - - var tries: Int = 10 - while (tries >= 0) { - if (conn.finishConnect(false)) return - // Sleep ? - Thread.sleep(1) - tries -= 1 - } - - // fallback to previous behavior : we should not really come here since this method was - // triggered since channel became connectable : but at times, the first finishConnect need not - // succeed : hence the loop to retry a few 'times'. - conn.finishConnect(true) - } - } ) - } - - // MUST be called within selector loop - else deadlock. - private def triggerForceCloseByException(key: SelectionKey, e: Exception) { - try { - key.interestOps(0) - } catch { - // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) - } - - val conn = connectionsByKey.getOrElse(key, null) - if (conn == null) return - - // Pushing to connect threadpool - handleConnectExecutor.execute(new Runnable { - override def run() { - try { - conn.callOnExceptionCallback(e) - } catch { - // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) - } - try { - conn.close() - } catch { - // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) - } - } - }) - } - - - def run() { - try { - while(!selectorThread.isInterrupted) { - while (! 
registerRequests.isEmpty) { - val conn: SendingConnection = registerRequests.dequeue - addListeners(conn) - conn.connect() - addConnection(conn) - } - - while(!keyInterestChangeRequests.isEmpty) { - val (key, ops) = keyInterestChangeRequests.dequeue - - try { - if (key.isValid) { - val connection = connectionsByKey.getOrElse(key, null) - if (connection != null) { - val lastOps = key.interestOps() - key.interestOps(ops) - - // hot loop - prevent materialization of string if trace not enabled. - if (isTraceEnabled()) { - def intToOpStr(op: Int): String = { - val opStrs = ArrayBuffer[String]() - if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" - if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" - if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" - if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" - if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " - } - - logTrace("Changed key for connection to [" + connection.getRemoteConnectionManagerId() + - "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") - } - } - } else { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - case e: CancelledKeyException => { - logInfo("key already cancelled ? " + key, e) - triggerForceCloseByException(key, e) - } - case e: Exception => { - logError("Exception processing key " + key, e) - triggerForceCloseByException(key, e) - } - } - } - - val selectedKeysCount = - try { - selector.select() - } catch { - // Explicitly only dealing with CancelledKeyException here since other exceptions should be dealt with differently. - case e: CancelledKeyException => { - // Some keys within the selectors list are invalid/closed. clear them. - val allKeys = selector.keys().iterator() - - while (allKeys.hasNext()) { - val key = allKeys.next() - try { - if (! key.isValid) { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - case e: CancelledKeyException => { - logInfo("key already cancelled ? " + key, e) - triggerForceCloseByException(key, e) - } - case e: Exception => { - logError("Exception processing key " + key, e) - triggerForceCloseByException(key, e) - } - } - } - } - 0 - } - - if (selectedKeysCount == 0) { - logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys") - } - if (selectorThread.isInterrupted) { - logInfo("Selector thread was interrupted!") - return - } - - if (0 != selectedKeysCount) { - val selectedKeys = selector.selectedKeys().iterator() - while (selectedKeys.hasNext()) { - val key = selectedKeys.next - selectedKeys.remove() - try { - if (key.isValid) { - if (key.isAcceptable) { - acceptConnection(key) - } else - if (key.isConnectable) { - triggerConnect(key) - } else - if (key.isReadable) { - triggerRead(key) - } else - if (key.isWritable) { - triggerWrite(key) - } - } else { - logInfo("Key not valid ? " + key) - throw new CancelledKeyException() - } - } catch { - // weird, but we saw this happening - even though key.isValid was true, key.isAcceptable would throw CancelledKeyException. - case e: CancelledKeyException => { - logInfo("key already cancelled ? 
" + key, e) - triggerForceCloseByException(key, e) - } - case e: Exception => { - logError("Exception processing key " + key, e) - triggerForceCloseByException(key, e) - } - } - } - } - } - } catch { - case e: Exception => logError("Error in select loop", e) - } - } - - def acceptConnection(key: SelectionKey) { - val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] - - var newChannel = serverChannel.accept() - - // accept them all in a tight loop. non blocking accept with no processing, should be fine - while (newChannel != null) { - try { - val newConnection = new ReceivingConnection(newChannel, selector) - newConnection.onReceive(receiveMessage) - addListeners(newConnection) - addConnection(newConnection) - logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]") - } catch { - // might happen in case of issues with registering with selector - case e: Exception => logError("Error in accept loop", e) - } - - newChannel = serverChannel.accept() - } - } - - private def addListeners(connection: Connection) { - connection.onKeyInterestChange(changeConnectionKeyInterest) - connection.onException(handleConnectionError) - connection.onClose(removeConnection) - } - - def addConnection(connection: Connection) { - connectionsByKey += ((connection.key, connection)) - } - - def removeConnection(connection: Connection) { - connectionsByKey -= connection.key - - try { - if (connection.isInstanceOf[SendingConnection]) { - val sendingConnection = connection.asInstanceOf[SendingConnection] - val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() - logInfo("Removing SendingConnection to " + sendingConnectionManagerId) - - connectionsById -= sendingConnectionManagerId - - messageStatuses.synchronized { - messageStatuses - .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => { - logInfo("Notifying " + status) - status.synchronized { - status.attempted = true - status.acked = false - status.markDone() - } - }) - - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } - } else if (connection.isInstanceOf[ReceivingConnection]) { - val receivingConnection = connection.asInstanceOf[ReceivingConnection] - val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId() - logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId) - - val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId) - if (! sendingConnectionOpt.isDefined) { - logError("Corresponding SendingConnectionManagerId not found") - return - } - - val sendingConnection = sendingConnectionOpt.get - connectionsById -= remoteConnectionManagerId - sendingConnection.close() - - val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() - - assert (sendingConnectionManagerId == remoteConnectionManagerId) - - messageStatuses.synchronized { - for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { - logInfo("Notifying " + s) - s.synchronized { - s.attempted = true - s.acked = false - s.markDone() - } - } - - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } - } - } finally { - // So that the selection keys can be removed. 
- wakeupSelector() - } - } - - def handleConnectionError(connection: Connection, e: Exception) { - logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId()) - removeConnection(connection) - } - - def changeConnectionKeyInterest(connection: Connection, ops: Int) { - keyInterestChangeRequests += ((connection.key, ops)) - // So that registrations happen! - wakeupSelector() - } - - def receiveMessage(connection: Connection, message: Message) { - val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress) - logDebug("Received [" + message + "] from [" + connectionManagerId + "]") - val runnable = new Runnable() { - val creationTime = System.currentTimeMillis - def run() { - logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") - handleMessage(connectionManagerId, message) - logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") - } - } - handleMessageExecutor.execute(runnable) - /*handleMessage(connection, message)*/ - } - - private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) { - logDebug("Handling [" + message + "] from [" + connectionManagerId + "]") - message match { - case bufferMessage: BufferMessage => { - if (bufferMessage.hasAckId) { - val sentMessageStatus = messageStatuses.synchronized { - messageStatuses.get(bufferMessage.ackId) match { - case Some(status) => { - messageStatuses -= bufferMessage.ackId - status - } - case None => { - throw new Exception("Could not find reference for received ack message " + message.id) - } - } - } - sentMessageStatus.synchronized { - sentMessageStatus.ackMessage = Some(message) - sentMessageStatus.attempted = true - sentMessageStatus.acked = true - sentMessageStatus.markDone() - } - } else { - val ackMessage = if (onReceiveCallback != null) { - logDebug("Calling back") - onReceiveCallback(bufferMessage, connectionManagerId) - } else { - logDebug("Not calling back as callback is null") - None - } - - if (ackMessage.isDefined) { - if (!ackMessage.get.isInstanceOf[BufferMessage]) { - logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass()) - } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { - logDebug("Response to " + bufferMessage + " does not have ack id set") - ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id - } - } - - sendMessage(connectionManagerId, ackMessage.getOrElse { - Message.createBufferMessage(bufferMessage.id) - }) - } - } - case _ => throw new Exception("Unknown message type received") - } - } - - private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { - def startNewConnection(): SendingConnection = { - val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) - val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId) - registerRequests.enqueue(newConnection) - - newConnection - } - // I removed the lookupKey stuff as part of the merge ... should I re-add it? We did not find it useful in our test env ... - // If we do re-add it, we should use it consistently everywhere, I guess?
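- // getOrElseUpdate reuses an existing SendingConnection for this remote id if one exists; - // otherwise it enqueues a new one on registerRequests, which the selector thread dequeues - // and connects in run().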
- val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection()) - message.senderAddress = id.toSocketAddress() - logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") - connection.send(message) - - wakeupSelector() - } - - private def wakeupSelector() { - selector.wakeup() - } - - def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message) - : Future[Option[Message]] = { - val promise = Promise[Option[Message]] - val status = new MessageStatus(message, connectionManagerId, s => promise.success(s.ackMessage)) - messageStatuses.synchronized { - messageStatuses += ((message.id, status)) - } - sendMessage(connectionManagerId, message) - promise.future - } - - def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, message: Message): Option[Message] = { - Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf) - } - - def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) { - onReceiveCallback = callback - } - - def stop() { - selectorThread.interrupt() - selectorThread.join() - selector.close() - val connections = connectionsByKey.values - connections.foreach(_.close()) - if (connectionsByKey.size != 0) { - logWarning("All connections not cleaned up") - } - handleMessageExecutor.shutdown() - handleReadWriteExecutor.shutdown() - handleConnectExecutor.shutdown() - logInfo("ConnectionManager stopped") - } -} - - -private[spark] object ConnectionManager { - - def main(args: Array[String]) { - val manager = new ConnectionManager(9999) - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - println("Received [" + msg + "] from [" + id + "]") - None - }) - - /*testSequentialSending(manager)*/ - /*System.gc()*/ - - /*testParallelSending(manager)*/ - /*System.gc()*/ - - /*testParallelDecreasingSending(manager)*/ - /*System.gc()*/ - - testContinuousSending(manager) - System.gc() - } - - def testSequentialSending(manager: ConnectionManager) { - println("--------------------------") - println("Sequential Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliablySync(manager.id, bufferMessage) - }) - println("--------------------------") - println() - } - - def testParallelSending(manager: ConnectionManager) { - println("--------------------------") - println("Parallel Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => { - val g = Await.result(f, 1 second) - if (!g.isDefined) println("Failed") - }) - val finishTime = System.currentTimeMillis - - val mb = size * count / 1024.0 / 1024.0 - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("--------------------------") - println("Started at " + startTime + ", finished at " + finishTime) - println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)") - println("--------------------------") - println() - } - - def 
testParallelDecreasingSending(manager: ConnectionManager) { - println("--------------------------") - println("Parallel Decreasing Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - val buffers = Array.tabulate(count)(i => ByteBuffer.allocate(size * (i + 1)).put(Array.tabulate[Byte](size * (i + 1))(x => x.toByte))) - buffers.foreach(_.flip) - val mb = buffers.map(_.remaining).reduceLeft(_ + _) / 1024.0 / 1024.0 - - val startTime = System.currentTimeMillis - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => { - val g = Await.result(f, 1 second) - if (!g.isDefined) println("Failed") - }) - val finishTime = System.currentTimeMillis - - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("--------------------------") - /*println("Started at " + startTime + ", finished at " + finishTime) */ - println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") - println("--------------------------") - println() - } - - def testContinuousSending(manager: ConnectionManager) { - println("--------------------------") - println("Continuous Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - while(true) { - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => { - val g = Await.result(f, 1 second) - if (!g.isDefined) println("Failed") - }) - val finishTime = System.currentTimeMillis - Thread.sleep(1000) - val mb = size * count / 1024.0 / 1024.0 - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") - println("--------------------------") - println() - } - } -} diff --git a/core/src/main/scala/spark/network/ConnectionManagerId.scala b/core/src/main/scala/spark/network/ConnectionManagerId.scala deleted file mode 100644 index 9d5c518293..0000000000 --- a/core/src/main/scala/spark/network/ConnectionManagerId.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.network - -import java.net.InetSocketAddress - -import spark.Utils - - -private[spark] case class ConnectionManagerId(host: String, port: Int) { - // DEBUG code - Utils.checkHost(host) - assert(port > 0) - - def toSocketAddress() = new InetSocketAddress(host, port) -} - - -private[spark] object ConnectionManagerId { - def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { - new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort()) - } -} diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala deleted file mode 100644 index 9e3827aaf5..0000000000 --- a/core/src/main/scala/spark/network/ConnectionManagerTest.scala +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network - -import spark._ -import spark.SparkContext._ - -import scala.io.Source - -import java.nio.ByteBuffer -import java.net.InetAddress - -import akka.dispatch.Await -import akka.util.duration._ - -private[spark] object ConnectionManagerTest extends Logging { - def main(args: Array[String]) { - // <master URL> - the master URL - // <slaves file> - a list of slaves to run connectionTest on - // [num of tasks] - the number of parallel tasks to be initiated, default is number of slave hosts - // [size of msg in MB (integer)] - the size of messages to be sent in each task, default is 10 - // [count] - how many times to run, default is 3 - // [await time in seconds] - await time (in seconds), default is 600 - if (args.length < 2) { - println("Usage: ConnectionManagerTest <master URL> <slaves file> [num of tasks] [size of msg in MB (integer)] [count] [await time in seconds]") - System.exit(1) - } - - if (args(0).startsWith("local")) { - println("This runs only on a mesos cluster") - } - - val sc = new SparkContext(args(0), "ConnectionManagerTest") - val slavesFile = Source.fromFile(args(1)) - val slaves = slavesFile.mkString.split("\n") - slavesFile.close() - - /*println("Slaves")*/ - /*slaves.foreach(println)*/ - val tasknum = if (args.length > 2) args(2).toInt else slaves.length - val size = (if (args.length > 3) args(3).toInt else 10) * 1024 * 1024 - val count = if (args.length > 4) args(4).toInt else 3 - val awaitTime = (if (args.length > 5) args(5).toInt else 600).second - println("Running " + count + " rounds of test: parallel tasks = " + tasknum + ", msg size = " + size / 1024 / 1024 + " MB, awaitTime = " + awaitTime) - val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map( - i => SparkEnv.get.connectionManager.id).collect() - println("\nSlave ConnectionManagerIds") - slaveConnManagerIds.foreach(println) - println - - (0 until count).foreach(i => { - val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => { - val
connManager = SparkEnv.get.connectionManager - val thisConnManagerId = connManager.id - connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - logInfo("Received [" + msg + "] from [" + id + "]") - None - }) - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - val futures = slaveConnManagerIds.filter(_ != thisConnManagerId).map(slaveConnManagerId => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") - connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) - }) - val results = futures.map(f => Await.result(f, awaitTime)) - val finishTime = System.currentTimeMillis - Thread.sleep(5000) - - val mb = size * results.size / 1024.0 / 1024.0 - val ms = finishTime - startTime - val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s" - logInfo(resultStr) - resultStr - }).collect() - - println("---------------------") - println("Run " + i) - resultStrs.foreach(println) - println("---------------------") - }) - } -} - diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala deleted file mode 100644 index a25457ea35..0000000000 --- a/core/src/main/scala/spark/network/Message.scala +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.network - -import java.nio.ByteBuffer -import java.net.InetSocketAddress - -import scala.collection.mutable.ArrayBuffer - - -private[spark] abstract class Message(val typ: Long, val id: Int) { - var senderAddress: InetSocketAddress = null - var started = false - var startTime = -1L - var finishTime = -1L - - def size: Int - - def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] - - def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] - - def timeTaken(): String = (finishTime - startTime).toString + " ms" - - override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" -} - - -private[spark] object Message { - val BUFFER_MESSAGE = 1111111111L - - var lastId = 1 - - def getNewId() = synchronized { - lastId += 1 - if (lastId == 0) { - lastId += 1 - } - lastId - } - - def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = { - if (dataBuffers == null) { - return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId) - } - if (dataBuffers.exists(_ == null)) { - throw new Exception("Attempting to create buffer message with null buffer") - } - return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId) - } - - def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage = - createBufferMessage(dataBuffers, 0) - - def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = { - if (dataBuffer == null) { - return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId) - } else { - return createBufferMessage(Array(dataBuffer), ackId) - } - } - - def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage = - createBufferMessage(dataBuffer, 0) - - def createBufferMessage(ackId: Int): BufferMessage = { - createBufferMessage(new Array[ByteBuffer](0), ackId) - } - - def create(header: MessageChunkHeader): Message = { - val newMessage: Message = header.typ match { - case BUFFER_MESSAGE => new BufferMessage(header.id, - ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) - } - newMessage.senderAddress = header.address - newMessage - } -} diff --git a/core/src/main/scala/spark/network/MessageChunk.scala b/core/src/main/scala/spark/network/MessageChunk.scala deleted file mode 100644 index 784db5ab62..0000000000 --- a/core/src/main/scala/spark/network/MessageChunk.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.network - -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - - -private[network] -class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { - - val size = if (buffer == null) 0 else buffer.remaining - - lazy val buffers = { - val ab = new ArrayBuffer[ByteBuffer]() - ab += header.buffer - if (buffer != null) { - ab += buffer - } - ab - } - - override def toString = { - "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" - } -} diff --git a/core/src/main/scala/spark/network/MessageChunkHeader.scala b/core/src/main/scala/spark/network/MessageChunkHeader.scala deleted file mode 100644 index 18d0cbcc14..0000000000 --- a/core/src/main/scala/spark/network/MessageChunkHeader.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network - -import java.net.InetAddress -import java.net.InetSocketAddress -import java.nio.ByteBuffer - - -private[spark] class MessageChunkHeader( - val typ: Long, - val id: Int, - val totalSize: Int, - val chunkSize: Int, - val other: Int, - val address: InetSocketAddress) { - lazy val buffer = { - // No need to change this, at 'use' time, we do a reverse lookup of the hostname. - // Refer to network.Connection - val ip = address.getAddress.getAddress() - val port = address.getPort() - ByteBuffer. - allocate(MessageChunkHeader.HEADER_SIZE). - putLong(typ). - putInt(id). - putInt(totalSize). - putInt(chunkSize). - putInt(other). - putInt(ip.size). - put(ip). - putInt(port). - position(MessageChunkHeader.HEADER_SIZE). 
- flip.asInstanceOf[ByteBuffer] - } - - override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + - " and sizes " + totalSize + " / " + chunkSize + " bytes" -} - - -private[spark] object MessageChunkHeader { - val HEADER_SIZE = 40 - - def create(buffer: ByteBuffer): MessageChunkHeader = { - if (buffer.remaining != HEADER_SIZE) { - throw new IllegalArgumentException("Cannot convert buffer data to Message") - } - val typ = buffer.getLong() - val id = buffer.getInt() - val totalSize = buffer.getInt() - val chunkSize = buffer.getInt() - val other = buffer.getInt() - val ipSize = buffer.getInt() - val ipBytes = new Array[Byte](ipSize) - buffer.get(ipBytes) - val ip = InetAddress.getByAddress(ipBytes) - val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port)) - } -} diff --git a/core/src/main/scala/spark/network/ReceiverTest.scala b/core/src/main/scala/spark/network/ReceiverTest.scala deleted file mode 100644 index 2bbc736f40..0000000000 --- a/core/src/main/scala/spark/network/ReceiverTest.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network - -import java.nio.ByteBuffer -import java.net.InetAddress - -private[spark] object ReceiverTest { - - def main(args: Array[String]) { - val manager = new ConnectionManager(9999) - println("Started connection manager with id = " + manager.id) - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - /*println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis)*/ - val buffer = ByteBuffer.wrap("response".getBytes()) - Some(Message.createBufferMessage(buffer, msg.id)) - }) - Thread.currentThread.join() - } -} - diff --git a/core/src/main/scala/spark/network/SenderTest.scala b/core/src/main/scala/spark/network/SenderTest.scala deleted file mode 100644 index 542c54c36b..0000000000 --- a/core/src/main/scala/spark/network/SenderTest.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network - -import java.nio.ByteBuffer -import java.net.InetAddress - -private[spark] object SenderTest { - - def main(args: Array[String]) { - - if (args.length < 2) { - println("Usage: SenderTest <target host> <target port>") - System.exit(1) - } - - val targetHost = args(0) - val targetPort = args(1).toInt - val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort) - - val manager = new ConnectionManager(0) - println("Started connection manager with id = " + manager.id) - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - println("Received [" + msg + "] from [" + id + "]") - None - }) - - val size = 100 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val targetServer = args(0) - - val count = 100 - (0 until count).foreach(i => { - val dataMessage = Message.createBufferMessage(buffer.duplicate) - val startTime = System.currentTimeMillis - /*println("Started timer at " + startTime)*/ - val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) match { - case Some(response) => - val buffer = response.asInstanceOf[BufferMessage].buffers(0) - new String(buffer.array) - case None => "none" - } - val finishTime = System.currentTimeMillis - val mb = size / 1024.0 / 1024.0 - val ms = finishTime - startTime - /*val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"*/ - val resultStr = "Sent " + mb + " MB to " + targetServer + " in " + ms + " ms (" + (mb / ms * 1000.0).toInt + " MB/s) | Response = " + responseStr - println(resultStr) - }) - } -} - diff --git a/core/src/main/scala/spark/network/netty/FileHeader.scala b/core/src/main/scala/spark/network/netty/FileHeader.scala deleted file mode 100644 index bf46d32aa3..0000000000 --- a/core/src/main/scala/spark/network/netty/FileHeader.scala +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -package spark.network.netty - -import io.netty.buffer._ - -import spark.Logging - -private[spark] class FileHeader( - val fileLen: Int, - val blockId: String) extends Logging { - - lazy val buffer = { - val buf = Unpooled.buffer() - buf.capacity(FileHeader.HEADER_SIZE) - buf.writeInt(fileLen) - buf.writeInt(blockId.length) - blockId.foreach((x: Char) => buf.writeByte(x)) - // Pad the rest of the header. - if (FileHeader.HEADER_SIZE - buf.readableBytes > 0) { - buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes) - } else { - throw new Exception("Header too long: " + buf.readableBytes + " bytes") - } - buf - } - -} - -private[spark] object FileHeader { - - val HEADER_SIZE = 40 - - def getFileLenOffset = 0 - def getFileLenSize = Integer.SIZE / 8 - - def create(buf: ByteBuf): FileHeader = { - val length = buf.readInt - val idLength = buf.readInt - val idBuilder = new StringBuilder(idLength) - for (i <- 1 to idLength) { - idBuilder += buf.readByte().asInstanceOf[Char] - } - val blockId = idBuilder.toString() - new FileHeader(length, blockId) - } - - - def main(args: Array[String]) { - val header = new FileHeader(25, "block_0") - val buf = header.buffer - val newheader = FileHeader.create(buf) - println("id=" + newheader.blockId + ", size=" + newheader.fileLen) - } -} - diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala deleted file mode 100644 index b01f6369f6..0000000000 --- a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -package spark.network.netty - -import java.util.concurrent.Executors - -import io.netty.buffer.ByteBuf -import io.netty.channel.ChannelHandlerContext -import io.netty.util.CharsetUtil - -import spark.Logging -import spark.network.ConnectionManagerId - -import scala.collection.JavaConverters._ - - -private[spark] class ShuffleCopier extends Logging { - - def getBlock(host: String, port: Int, blockId: String, - resultCollectCallback: (String, Long, ByteBuf) => Unit) { - - val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) - val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt - val fc = new FileClient(handler, connectTimeout) - - try { - fc.init() - fc.connect(host, port) - fc.sendRequest(blockId) - fc.waitForClose() - fc.close() - } catch { - // Handle any socket-related exceptions in FileClient - case e: Exception => { - logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e) - handler.handleError(blockId) - } - } - } - - def getBlock(cmId: ConnectionManagerId, blockId: String, - resultCollectCallback: (String, Long, ByteBuf) => Unit) { - getBlock(cmId.host, cmId.port, blockId, resultCollectCallback) - } - - def getBlocks(cmId: ConnectionManagerId, - blocks: Seq[(String, Long)], - resultCollectCallback: (String, Long, ByteBuf) => Unit) { - - for ((blockId, size) <- blocks) { - getBlock(cmId, blockId, resultCollectCallback) - } - } -} - - -private[spark] object ShuffleCopier extends Logging { - - private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit) - extends FileClientHandler with Logging { - - override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { - logDebug("Received block: " + header.blockId + " (" + header.fileLen + " B)") - resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) - } - - override def handleError(blockId: String) { - if (!isComplete) { - resultCollectCallBack(blockId, -1, null) - } - } - } - - def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) { - if (size != -1) { - logInfo("File: " + blockId + " content is: \"" + content.toString(CharsetUtil.UTF_8) + "\"") - } - } - - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: ShuffleCopier <host> <port> <block id> [num of threads]") - System.exit(1) - } - val host = args(0) - val port = args(1).toInt - val file = args(2) - val threads = if (args.length > 3) args(3).toInt else 10 - - val copiers = Executors.newFixedThreadPool(80) - val tasks = (for (i <- Range(0, threads)) yield { - Executors.callable(new Runnable() { - def run() { - val copier = new ShuffleCopier() - copier.getBlock(host, port, file, echoResultCollectCallBack) - } - }) - }).asJava - copiers.invokeAll(tasks) - copiers.shutdown() - System.exit(0) - } -} diff --git a/core/src/main/scala/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/spark/network/netty/ShuffleSender.scala deleted file mode 100644 index cdf88b03a0..0000000000 --- a/core/src/main/scala/spark/network/netty/ShuffleSender.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License.
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.network.netty - -import java.io.File - -import spark.Logging - - -private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging { - - val server = new FileServer(pResolver, portIn) - server.start() - - def stop() { - server.stop() - } - - def port: Int = server.getPort() -} - - -/** - * An application for testing the shuffle sender as a standalone program. - */ -private[spark] object ShuffleSender { - - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println( - "Usage: ShuffleSender <port> <subDirsPerLocalDir> <list of shuffle_block_directories>") - System.exit(1) - } - - val port = args(0).toInt - val subDirsPerLocalDir = args(1).toInt - val localDirs = args.drop(2).map(new File(_)) - - val pResolver = new PathResolver { - override def getAbsolutePath(blockId: String): String = { - if (!blockId.startsWith("shuffle_")) { - throw new Exception("Block " + blockId + " is not a shuffle block") - } - // Figure out which local directory it hashes to, and which subdirectory in that - val hash = math.abs(blockId.hashCode) - val dirId = hash % localDirs.length - val subDirId = (hash / localDirs.length) % subDirsPerLocalDir - val subDir = new File(localDirs(dirId), "%02x".format(subDirId)) - val file = new File(subDir, blockId) - return file.getAbsolutePath - } - } - val sender = new ShuffleSender(port, pResolver) - } -} diff --git a/core/src/main/scala/spark/package.scala b/core/src/main/scala/spark/package.scala deleted file mode 100644 index b244bfbf06..0000000000 --- a/core/src/main/scala/spark/package.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * Core Spark functionality. [[spark.SparkContext]] serves as the main entry point to Spark, while - * [[spark.RDD]] is the data type representing a distributed collection, and provides most - * parallel operations. - * - * In addition, [[spark.PairRDDFunctions]] contains operations available only on RDDs of key-value - * pairs, such as `groupByKey` and `join`; [[spark.DoubleRDDFunctions]] contains operations - * available only on RDDs of Doubles; and [[spark.SequenceFileRDDFunctions]] contains operations - * available on RDDs that can be saved as SequenceFiles. These operations are automatically - * available on any RDD of the right type (e.g. RDD[(Int, Int)]) through implicit conversions when - * you `import spark.SparkContext._`.
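- * - * A minimal sketch of the implicit enrichment (assuming an existing SparkContext `sc`; illustrative values only): - * {{{ - * import spark.SparkContext._ - * val pairs = sc.parallelize(Seq((1, 2), (1, 3))) // an RDD[(Int, Int)] - * pairs.groupByKey() // resolved via the implicit conversion to PairRDDFunctions - * }}}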
- */ -package object spark { - // For package docs only -} diff --git a/core/src/main/scala/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/spark/partial/ApproximateActionListener.scala deleted file mode 100644 index 691d939150..0000000000 --- a/core/src/main/scala/spark/partial/ApproximateActionListener.scala +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.partial - -import spark._ -import spark.scheduler.JobListener - -/** - * A JobListener for an approximate single-result action, such as count() or non-parallel reduce(). - * This listener waits up to timeout milliseconds and will return a partial answer even if the - * complete answer is not available by then. - * - * This class assumes that the action is performed on an entire RDD[T] via a function that computes - * a result of type U for each partition, and that the action returns a partial or complete result - * of type R. Note that the type R must *include* any error bars on it (e.g. see BoundedInt). - */ -private[spark] class ApproximateActionListener[T, U, R]( - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - evaluator: ApproximateEvaluator[U, R], - timeout: Long) - extends JobListener { - - val startTime = System.currentTimeMillis() - val totalTasks = rdd.partitions.size - var finishedTasks = 0 - var failure: Option[Exception] = None // Set if the job has failed (permanently) - var resultObject: Option[PartialResult[R]] = None // Set if we've already returned a PartialResult - - override def taskSucceeded(index: Int, result: Any) { - synchronized { - evaluator.merge(index, result.asInstanceOf[U]) - finishedTasks += 1 - if (finishedTasks == totalTasks) { - // If we had already returned a PartialResult, set its final value - resultObject.foreach(r => r.setFinalValue(evaluator.currentResult())) - // Notify any waiting thread that may have called awaitResult - this.notifyAll() - } - } - } - - override def jobFailed(exception: Exception) { - synchronized { - failure = Some(exception) - this.notifyAll() - } - } - - /** - * Waits for up to timeout milliseconds since the listener was created and then returns a - * PartialResult with the result so far. This may be complete if the whole job is done. 
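- * Note that the deadline is measured from the listener's creation (startTime + timeout), - * not from the call to awaitResult().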
- */ - def awaitResult(): PartialResult[R] = synchronized { - val finishTime = startTime + timeout - while (true) { - val time = System.currentTimeMillis() - if (failure != None) { - throw failure.get - } else if (finishedTasks == totalTasks) { - return new PartialResult(evaluator.currentResult(), true) - } else if (time >= finishTime) { - resultObject = Some(new PartialResult(evaluator.currentResult(), false)) - return resultObject.get - } else { - this.wait(finishTime - time) - } - } - // Should never be reached, but required to keep the compiler happy - return null - } -} diff --git a/core/src/main/scala/spark/partial/ApproximateEvaluator.scala b/core/src/main/scala/spark/partial/ApproximateEvaluator.scala deleted file mode 100644 index 5eae144dfb..0000000000 --- a/core/src/main/scala/spark/partial/ApproximateEvaluator.scala +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.partial - -/** - * An object that computes a function incrementally by merging in results of type U from multiple - * tasks. Allows partial evaluation at any point by calling currentResult(). - */ -private[spark] trait ApproximateEvaluator[U, R] { - def merge(outputId: Int, taskResult: U): Unit - def currentResult(): R -} diff --git a/core/src/main/scala/spark/partial/BoundedDouble.scala b/core/src/main/scala/spark/partial/BoundedDouble.scala deleted file mode 100644 index 8bdbe6c012..0000000000 --- a/core/src/main/scala/spark/partial/BoundedDouble.scala +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.partial - -/** - * A Double with error bars on it. 
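- * For example, an approximate count might be reported as a BoundedDouble with mean 1000.0, - * confidence 0.95 and interval [950.0, 1050.0] (illustrative values only).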
- */ -class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) { - override def toString(): String = "[%.3f, %.3f]".format(low, high) -} diff --git a/core/src/main/scala/spark/partial/CountEvaluator.scala b/core/src/main/scala/spark/partial/CountEvaluator.scala deleted file mode 100644 index 6aa92094eb..0000000000 --- a/core/src/main/scala/spark/partial/CountEvaluator.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.partial - -import cern.jet.stat.Probability - -/** - * An ApproximateEvaluator for counts. - * - * TODO: There's currently a lot of shared code between this and GroupedCountEvaluator. It might - * be best to make this a special case of GroupedCountEvaluator with one group. - */ -private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[Long, BoundedDouble] { - - var outputsMerged = 0 - var sum: Long = 0 - - override def merge(outputId: Int, taskResult: Long) { - outputsMerged += 1 - sum += taskResult - } - - override def currentResult(): BoundedDouble = { - if (outputsMerged == totalOutputs) { - new BoundedDouble(sum, 1.0, sum, sum) - } else if (outputsMerged == 0) { - new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) - } else { - val p = outputsMerged.toDouble / totalOutputs - val mean = (sum + 1 - p) / p - val variance = (sum + 1) * (1 - p) / (p * p) - val stdev = math.sqrt(variance) - val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2) - val low = mean - confFactor * stdev - val high = mean + confFactor * stdev - new BoundedDouble(mean, confidence, low, high) - } - } -} diff --git a/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala deleted file mode 100644 index ebe2e5a1e3..0000000000 --- a/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.partial - -import java.util.{HashMap => JHashMap} -import java.util.{Map => JMap} - -import scala.collection.Map -import scala.collection.mutable.HashMap -import scala.collection.JavaConversions.mapAsScalaMap - -import cern.jet.stat.Probability - -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} - -/** - * An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval. - */ -private[spark] class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[OLMap[T], Map[T, BoundedDouble]] { - - var outputsMerged = 0 - var sums = new OLMap[T] // Sum of counts for each key - - override def merge(outputId: Int, taskResult: OLMap[T]) { - outputsMerged += 1 - val iter = taskResult.object2LongEntrySet.fastIterator() - while (iter.hasNext) { - val entry = iter.next() - sums.put(entry.getKey, sums.getLong(entry.getKey) + entry.getLongValue) - } - } - - override def currentResult(): Map[T, BoundedDouble] = { - if (outputsMerged == totalOutputs) { - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.object2LongEntrySet.fastIterator() - while (iter.hasNext) { - val entry = iter.next() - val sum = entry.getLongValue() - result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum) - } - result - } else if (outputsMerged == 0) { - new HashMap[T, BoundedDouble] - } else { - val p = outputsMerged.toDouble / totalOutputs - val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2) - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.object2LongEntrySet.fastIterator() - while (iter.hasNext) { - val entry = iter.next() - val sum = entry.getLongValue - val mean = (sum + 1 - p) / p - val variance = (sum + 1) * (1 - p) / (p * p) - val stdev = math.sqrt(variance) - val low = mean - confFactor * stdev - val high = mean + confFactor * stdev - result(entry.getKey) = new BoundedDouble(mean, confidence, low, high) - } - result - } - } -} diff --git a/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala b/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala deleted file mode 100644 index 2dadbbd5fb..0000000000 --- a/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.partial - -import java.util.{HashMap => JHashMap} -import java.util.{Map => JMap} - -import scala.collection.mutable.HashMap -import scala.collection.Map -import scala.collection.JavaConversions.mapAsScalaMap - -import spark.util.StatCounter - -/** - * An ApproximateEvaluator for means by key. 
Returns a map of key to confidence interval. - */ -private[spark] class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] { - - var outputsMerged = 0 - var sums = new JHashMap[T, StatCounter] // StatCounter for each key - - override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) { - outputsMerged += 1 - val iter = taskResult.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val old = sums.get(entry.getKey) - if (old != null) { - old.merge(entry.getValue) - } else { - sums.put(entry.getKey, entry.getValue) - } - } - } - - override def currentResult(): Map[T, BoundedDouble] = { - if (outputsMerged == totalOutputs) { - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val mean = entry.getValue.mean - result(entry.getKey) = new BoundedDouble(mean, 1.0, mean, mean) - } - result - } else if (outputsMerged == 0) { - new HashMap[T, BoundedDouble] - } else { - val p = outputsMerged.toDouble / totalOutputs - val studentTCacher = new StudentTCacher(confidence) - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val counter = entry.getValue - val mean = counter.mean - val stdev = math.sqrt(counter.sampleVariance / counter.count) - val confFactor = studentTCacher.get(counter.count) - val low = mean - confFactor * stdev - val high = mean + confFactor * stdev - result(entry.getKey) = new BoundedDouble(mean, confidence, low, high) - } - result - } - } -} diff --git a/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala b/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala deleted file mode 100644 index ae2b63f7cb..0000000000 --- a/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.partial - -import java.util.{HashMap => JHashMap} -import java.util.{Map => JMap} - -import scala.collection.mutable.HashMap -import scala.collection.Map -import scala.collection.JavaConversions.mapAsScalaMap - -import spark.util.StatCounter - -/** - * An ApproximateEvaluator for sums by key. Returns a map of key to confidence interval.
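- * - * With a fraction p = outputsMerged / totalOutputs of the outputs seen, each key's sum is - * extrapolated as meanEstimate * countEstimate, where countEstimate = (count + 1 - p) / p; - * the variances of the mean and count estimates are combined to size the interval.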
- */ -private[spark] class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] { - - var outputsMerged = 0 - var sums = new JHashMap[T, StatCounter] // StatCounter for each key - - override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) { - outputsMerged += 1 - val iter = taskResult.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val old = sums.get(entry.getKey) - if (old != null) { - old.merge(entry.getValue) - } else { - sums.put(entry.getKey, entry.getValue) - } - } - } - - override def currentResult(): Map[T, BoundedDouble] = { - if (outputsMerged == totalOutputs) { - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val sum = entry.getValue.sum - result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum) - } - result - } else if (outputsMerged == 0) { - new HashMap[T, BoundedDouble] - } else { - val p = outputsMerged.toDouble / totalOutputs - val studentTCacher = new StudentTCacher(confidence) - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val counter = entry.getValue - val meanEstimate = counter.mean - val meanVar = counter.sampleVariance / counter.count - val countEstimate = (counter.count + 1 - p) / p - val countVar = (counter.count + 1) * (1 - p) / (p * p) - val sumEstimate = meanEstimate * countEstimate - val sumVar = (meanEstimate * meanEstimate * countVar) + - (countEstimate * countEstimate * meanVar) + - (meanVar * countVar) - val sumStdev = math.sqrt(sumVar) - val confFactor = studentTCacher.get(counter.count) - val low = sumEstimate - confFactor * sumStdev - val high = sumEstimate + confFactor * sumStdev - result(entry.getKey) = new BoundedDouble(sumEstimate, confidence, low, high) - } - result - } - } -} diff --git a/core/src/main/scala/spark/partial/MeanEvaluator.scala b/core/src/main/scala/spark/partial/MeanEvaluator.scala deleted file mode 100644 index 5ddcad7075..0000000000 --- a/core/src/main/scala/spark/partial/MeanEvaluator.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.partial - -import cern.jet.stat.Probability - -import spark.util.StatCounter - -/** - * An ApproximateEvaluator for means.
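- * - * Uses the normal approximation for the confidence interval once more than 100 values have - * been merged, and a Student's t distribution otherwise.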
- */ -private[spark] class MeanEvaluator(totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[StatCounter, BoundedDouble] { - - var outputsMerged = 0 - var counter = new StatCounter - - override def merge(outputId: Int, taskResult: StatCounter) { - outputsMerged += 1 - counter.merge(taskResult) - } - - override def currentResult(): BoundedDouble = { - if (outputsMerged == totalOutputs) { - new BoundedDouble(counter.mean, 1.0, counter.mean, counter.mean) - } else if (outputsMerged == 0) { - new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) - } else { - val mean = counter.mean - val stdev = math.sqrt(counter.sampleVariance / counter.count) - val confFactor = { - if (counter.count > 100) { - Probability.normalInverse(1 - (1 - confidence) / 2) - } else { - Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt) - } - } - val low = mean - confFactor * stdev - val high = mean + confFactor * stdev - new BoundedDouble(mean, confidence, low, high) - } - } -} diff --git a/core/src/main/scala/spark/partial/PartialResult.scala b/core/src/main/scala/spark/partial/PartialResult.scala deleted file mode 100644 index 922a9f9bc6..0000000000 --- a/core/src/main/scala/spark/partial/PartialResult.scala +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.partial - -class PartialResult[R](initialVal: R, isFinal: Boolean) { - private var finalValue: Option[R] = if (isFinal) Some(initialVal) else None - private var failure: Option[Exception] = None - private var completionHandler: Option[R => Unit] = None - private var failureHandler: Option[Exception => Unit] = None - - def initialValue: R = initialVal - - def isInitialValueFinal: Boolean = isFinal - - /** - * Blocking method to wait for and return the final value. - */ - def getFinalValue(): R = synchronized { - while (finalValue == None && failure == None) { - this.wait() - } - if (finalValue != None) { - return finalValue.get - } else { - throw failure.get - } - } - - /** - * Set a handler to be called when this PartialResult completes. Only one completion handler - * is supported per PartialResult. - */ - def onComplete(handler: R => Unit): PartialResult[R] = synchronized { - if (completionHandler != None) { - throw new UnsupportedOperationException("onComplete cannot be called twice") - } - completionHandler = Some(handler) - if (finalValue != None) { - // We already have a final value, so let's call the handler - handler(finalValue.get) - } - return this - } - - /** - * Set a handler to be called if this PartialResult's job fails. Only one failure handler - * is supported per PartialResult. 
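Stepping back to MeanEvaluator above: its interval arithmetic can be exercised on its own. A minimal sketch using the same Colt calls (cern.jet.stat.Probability) the file already imports; the worked numbers in the trailing comment are hypothetical.

```scala
import cern.jet.stat.Probability

object MeanIntervalSketch {
  // Same rule as MeanEvaluator: normal approximation for large samples,
  // Student's t otherwise.
  def bounds(count: Long, mean: Double, sampleVariance: Double,
             confidence: Double): (Double, Double) = {
    val stdev = math.sqrt(sampleVariance / count)
    val confFactor =
      if (count > 100) Probability.normalInverse(1 - (1 - confidence) / 2)
      else Probability.studentTInverse(1 - confidence, (count - 1).toInt)
    (mean - confFactor * stdev, mean + confFactor * stdev)
  }
}

// bounds(400, 10.0, 25.0, 0.95) gives roughly (9.51, 10.49):
// stdev = sqrt(25 / 400) = 0.25 and the normal factor is about 1.96.
```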
- */ - def onFail(handler: Exception => Unit) { - synchronized { - if (failureHandler != None) { - throw new UnsupportedOperationException("onFail cannot be called twice") - } - failureHandler = Some(handler) - if (failure != None) { - // We already have a failure, so let's call the handler - handler(failure.get) - } - } - } - - /** - * Transform this PartialResult into a PartialResult of type T. - */ - def map[T](f: R => T) : PartialResult[T] = { - new PartialResult[T](f(initialVal), isFinal) { - override def getFinalValue() : T = synchronized { - f(PartialResult.this.getFinalValue()) - } - override def onComplete(handler: T => Unit): PartialResult[T] = synchronized { - PartialResult.this.onComplete(handler.compose(f)).map(f) - } - override def onFail(handler: Exception => Unit) { - synchronized { - PartialResult.this.onFail(handler) - } - } - override def toString : String = synchronized { - PartialResult.this.getFinalValueInternal() match { - case Some(value) => "(final: " + f(value) + ")" - case None => "(partial: " + initialValue + ")" - } - } - def getFinalValueInternal() = PartialResult.this.getFinalValueInternal().map(f) - } - } - - private[spark] def setFinalValue(value: R) { - synchronized { - if (finalValue != None) { - throw new UnsupportedOperationException("setFinalValue called twice on a PartialResult") - } - finalValue = Some(value) - // Call the completion handler if it was set - completionHandler.foreach(h => h(value)) - // Notify any threads that may be calling getFinalValue() - this.notifyAll() - } - } - - private def getFinalValueInternal() = finalValue - - private[spark] def setFailure(exception: Exception) { - synchronized { - if (failure != None) { - throw new UnsupportedOperationException("setFailure called twice on a PartialResult") - } - failure = Some(exception) - // Call the failure handler if it was set - failureHandler.foreach(h => h(exception)) - // Notify any threads that may be calling getFinalValue() - this.notifyAll() - } - } - - override def toString: String = synchronized { - finalValue match { - case Some(value) => "(final: " + value + ")" - case None => "(partial: " + initialValue + ")" - } - } -} diff --git a/core/src/main/scala/spark/partial/StudentTCacher.scala b/core/src/main/scala/spark/partial/StudentTCacher.scala deleted file mode 100644 index f3bb987d46..0000000000 --- a/core/src/main/scala/spark/partial/StudentTCacher.scala +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.partial - -import cern.jet.stat.Probability - -/** - * A utility class for caching Student's T distribution values for a given confidence level - * and various sample sizes. This is used by the MeanEvaluator to efficiently calculate - * confidence intervals for many keys. 
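As for how callers consume a PartialResult like the one defined above: the approximate actions hand one back immediately. A usage sketch, assuming a live SparkContext named `sc` and this codebase's RDD.countApprox(timeout, confidence).

```scala
// Kick off an approximate count; after `timeout` ms the caller can read the
// current estimate while the job keeps running toward the exact answer.
val partial = sc.parallelize(1 to 1000000, 100).countApprox(500L, 0.95)

// At most one handler of each kind may be registered.
partial.onComplete(bound => println("final count: " + bound.mean))
partial.onFail(e => println("approximate count failed: " + e))

// Blocks until setFinalValue()/setFailure() fires, then returns or throws.
println(partial.getFinalValue())
```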
- */ -private[spark] class StudentTCacher(confidence: Double) { - val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation - val normalApprox = Probability.normalInverse(1 - (1 - confidence) / 2) - val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0) - - def get(sampleSize: Long): Double = { - if (sampleSize >= NORMAL_APPROX_SAMPLE_SIZE) { - normalApprox - } else { - val size = sampleSize.toInt - if (cache(size) < 0) { - cache(size) = Probability.studentTInverse(1 - confidence, size - 1) - } - cache(size) - } - } -} diff --git a/core/src/main/scala/spark/partial/SumEvaluator.scala b/core/src/main/scala/spark/partial/SumEvaluator.scala deleted file mode 100644 index 4083abef03..0000000000 --- a/core/src/main/scala/spark/partial/SumEvaluator.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.partial - -import cern.jet.stat.Probability - -import spark.util.StatCounter - -/** - * An ApproximateEvaluator for sums. It estimates the mean and the count and multiplies them - * together, then uses the formula for the variance of two independent random variables to get - * a variance for the result and compute a confidence interval.
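The comment above compresses a fair amount of algebra; restated as a standalone sketch with the same formulas (the trailing numbers are a hypothetical check, not Spark output).

```scala
object SumExtrapolationSketch {
  // p = fraction of output partitions merged so far.
  // Returns (sum estimate, standard deviation of that estimate).
  def estimate(p: Double, count: Long, mean: Double,
               sampleVariance: Double): (Double, Double) = {
    val meanVar = sampleVariance / count
    val countEstimate = (count + 1 - p) / p        // extrapolated total count
    val countVar = (count + 1) * (1 - p) / (p * p)
    val sumEstimate = mean * countEstimate
    // Var(X * Y) for independent X, Y:
    // E[X]^2 Var(Y) + E[Y]^2 Var(X) + Var(X) Var(Y)
    val sumVar = (mean * mean * countVar) +
      (countEstimate * countEstimate * meanVar) +
      (meanVar * countVar)
    (sumEstimate, math.sqrt(sumVar))
  }
}

// E.g. with half the partitions seen (p = 0.5), count = 1000, mean = 2.0:
// countEstimate = (1000 + 1 - 0.5) / 0.5 = 2001, so the sum estimate is 4002.
```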
- */ -private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[StatCounter, BoundedDouble] { - - var outputsMerged = 0 - var counter = new StatCounter - - override def merge(outputId: Int, taskResult: StatCounter) { - outputsMerged += 1 - counter.merge(taskResult) - } - - override def currentResult(): BoundedDouble = { - if (outputsMerged == totalOutputs) { - new BoundedDouble(counter.sum, 1.0, counter.sum, counter.sum) - } else if (outputsMerged == 0) { - new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) - } else { - val p = outputsMerged.toDouble / totalOutputs - val meanEstimate = counter.mean - val meanVar = counter.sampleVariance / counter.count - val countEstimate = (counter.count + 1 - p) / p - val countVar = (counter.count + 1) * (1 - p) / (p * p) - val sumEstimate = meanEstimate * countEstimate - val sumVar = (meanEstimate * meanEstimate * countVar) + - (countEstimate * countEstimate * meanVar) + - (meanVar * countVar) - val sumStdev = math.sqrt(sumVar) - val confFactor = { - if (counter.count > 100) { - Probability.normalInverse(1 - (1 - confidence) / 2) - } else { - Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt) - } - } - val low = sumEstimate - confFactor * sumStdev - val high = sumEstimate + confFactor * sumStdev - new BoundedDouble(sumEstimate, confidence, low, high) - } - } -} diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala deleted file mode 100644 index 03800584ae..0000000000 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.rdd - -import spark.{RDD, SparkContext, SparkEnv, Partition, TaskContext} -import spark.storage.BlockManager - -private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition { - val index = idx -} - -private[spark] -class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String]) - extends RDD[T](sc, Nil) { - - @transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get) - - override def getPartitions: Array[Partition] = (0 until blockIds.size).map(i => { - new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition] - }).toArray - - override def compute(split: Partition, context: TaskContext): Iterator[T] = { - val blockManager = SparkEnv.get.blockManager - val blockId = split.asInstanceOf[BlockRDDPartition].blockId - blockManager.get(blockId) match { - case Some(block) => block.asInstanceOf[Iterator[T]] - case None => - throw new Exception("Could not compute split, block " + blockId + " not found") - } - } - - override def getPreferredLocations(split: Partition): Seq[String] = { - locations_(split.asInstanceOf[BlockRDDPartition].blockId) - } -} - diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala deleted file mode 100644 index 91b3e69d6f..0000000000 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.rdd - -import java.io.{ObjectOutputStream, IOException} -import spark._ - - -private[spark] -class CartesianPartition( - idx: Int, - @transient rdd1: RDD[_], - @transient rdd2: RDD[_], - s1Index: Int, - s2Index: Int - ) extends Partition { - var s1 = rdd1.partitions(s1Index) - var s2 = rdd2.partitions(s2Index) - override val index: Int = idx - - @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { - // Update the reference to parent split at the time of task serialization - s1 = rdd1.partitions(s1Index) - s2 = rdd2.partitions(s2Index) - oos.defaultWriteObject() - } -} - -private[spark] -class CartesianRDD[T: ClassManifest, U:ClassManifest]( - sc: SparkContext, - var rdd1 : RDD[T], - var rdd2 : RDD[U]) - extends RDD[Pair[T, U]](sc, Nil) - with Serializable { - - val numPartitionsInRdd2 = rdd2.partitions.size - - override def getPartitions: Array[Partition] = { - // create the cross product split - val array = new Array[Partition](rdd1.partitions.size * rdd2.partitions.size) - for (s1 <- rdd1.partitions; s2 <- rdd2.partitions) { - val idx = s1.index * numPartitionsInRdd2 + s2.index - array(idx) = new CartesianPartition(idx, rdd1, rdd2, s1.index, s2.index) - } - array - } - - override def getPreferredLocations(split: Partition): Seq[String] = { - val currSplit = split.asInstanceOf[CartesianPartition] - (rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)).distinct - } - - override def compute(split: Partition, context: TaskContext) = { - val currSplit = split.asInstanceOf[CartesianPartition] - for (x <- rdd1.iterator(currSplit.s1, context); - y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) - } - - override def getDependencies: Seq[Dependency[_]] = List( - new NarrowDependency(rdd1) { - def getParents(id: Int): Seq[Int] = List(id / numPartitionsInRdd2) - }, - new NarrowDependency(rdd2) { - def getParents(id: Int): Seq[Int] = List(id % numPartitionsInRdd2) - } - ) - - override def clearDependencies() { - super.clearDependencies() - rdd1 = null - rdd2 = null - } -} diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala deleted file mode 100644 index 1ad5fe6539..0000000000 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
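The split bookkeeping in CartesianRDD above encodes a pair of parent indices in a single integer; the round trip is plain arithmetic, shown here as a standalone sketch.

```scala
object CartesianIndexSketch {
  // Forward mapping used in getPartitions.
  def toIndex(s1Index: Int, s2Index: Int, numPartitionsInRdd2: Int): Int =
    s1Index * numPartitionsInRdd2 + s2Index

  // Inverse mapping used by the two NarrowDependency.getParents overrides.
  def toParents(idx: Int, numPartitionsInRdd2: Int): (Int, Int) =
    (idx / numPartitionsInRdd2, idx % numPartitionsInRdd2)
}

// toIndex(2, 1, 4) == 9 and toParents(9, 4) == (2, 1)
```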
- */ - -package spark.rdd - -import spark._ -import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter} -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.io.{NullWritable, BytesWritable} -import org.apache.hadoop.util.ReflectionUtils -import org.apache.hadoop.fs.Path -import java.io.{File, IOException, EOFException} -import java.text.NumberFormat - -private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} - -/** - * This RDD represents a RDD checkpoint file (similar to HadoopRDD). - */ -private[spark] -class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: String) - extends RDD[T](sc, Nil) { - - @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) - - override def getPartitions: Array[Partition] = { - val cpath = new Path(checkpointPath) - val numPartitions = - // listStatus can throw exception if path does not exist. - if (fs.exists(cpath)) { - val dirContents = fs.listStatus(cpath) - val partitionFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted - val numPart = partitionFiles.size - if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || - ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) { - throw new SparkException("Invalid checkpoint directory: " + checkpointPath) - } - numPart - } else 0 - - Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i)) - } - - checkpointData = Some(new RDDCheckpointData[T](this)) - checkpointData.get.cpFile = Some(checkpointPath) - - override def getPreferredLocations(split: Partition): Seq[String] = { - val status = fs.getFileStatus(new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))) - val locations = fs.getFileBlockLocations(status, 0, status.getLen) - locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost") - } - - override def compute(split: Partition, context: TaskContext): Iterator[T] = { - val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index)) - CheckpointRDD.readFromFile(file, context) - } - - override def checkpoint() { - // Do nothing. CheckpointRDD should not be checkpointed. - } -} - -private[spark] object CheckpointRDD extends Logging { - - def splitIdToFile(splitId: Int): String = { - "part-%05d".format(splitId) - } - - def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) { - val env = SparkEnv.get - val outputDir = new Path(path) - val fs = outputDir.getFileSystem(env.hadoop.newConfiguration()) - - val finalOutputName = splitIdToFile(ctx.splitId) - val finalOutputPath = new Path(outputDir, finalOutputName) - val tempOutputPath = new Path(outputDir, "." 
+ finalOutputName + "-attempt-" + ctx.attemptId) - - if (fs.exists(tempOutputPath)) { - throw new IOException("Checkpoint failed: temporary path " + - tempOutputPath + " already exists") - } - val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - - val fileOutputStream = if (blockSize < 0) { - fs.create(tempOutputPath, false, bufferSize) - } else { - // This is mainly for testing purposes - fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize) - } - val serializer = env.serializer.newInstance() - val serializeStream = serializer.serializeStream(fileOutputStream) - serializeStream.writeAll(iterator) - serializeStream.close() - - if (!fs.rename(tempOutputPath, finalOutputPath)) { - if (!fs.exists(finalOutputPath)) { - logInfo("Deleting tempOutputPath " + tempOutputPath) - fs.delete(tempOutputPath, false) - throw new IOException("Checkpoint failed: failed to save output of task: " - + ctx.attemptId + " and final output path does not exist") - } else { - // Some other copy of this task must've finished before us and renamed it - logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it") - fs.delete(tempOutputPath, false) - } - } - } - - def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = { - val env = SparkEnv.get - val fs = path.getFileSystem(env.hadoop.newConfiguration()) - val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - val fileInputStream = fs.open(path, bufferSize) - val serializer = env.serializer.newInstance() - val deserializeStream = serializer.deserializeStream(fileInputStream) - - // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback(() => deserializeStream.close()) - - deserializeStream.asIterator.asInstanceOf[Iterator[T]] - } - - // Test whether CheckpointRDD generates the expected number of partitions despite - // each split file having multiple blocks. This needs to be run on a - // cluster (mesos or standalone) using HDFS. - def main(args: Array[String]) { - import spark._ - - val Array(cluster, hdfsPath) = args - val env = SparkEnv.get - val sc = new SparkContext(cluster, "CheckpointRDD Test") - val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000) - val path = new Path(hdfsPath, "temp") - val fs = path.getFileSystem(env.hadoop.newConfiguration()) - sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _) - val cpRDD = new CheckpointRDD[Int](sc, path.toString) - assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same") - assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same") - fs.delete(path, true) - } -} diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala deleted file mode 100644 index 01b6c23dcc..0000000000 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License.
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import java.io.{ObjectOutputStream, IOException} -import java.util.{HashMap => JHashMap} - -import scala.collection.JavaConversions -import scala.collection.mutable.ArrayBuffer - -import spark.{Partition, Partitioner, RDD, SparkEnv, TaskContext} -import spark.{Dependency, OneToOneDependency, ShuffleDependency} - - -private[spark] sealed trait CoGroupSplitDep extends Serializable - -private[spark] case class NarrowCoGroupSplitDep( - rdd: RDD[_], - splitIndex: Int, - var split: Partition - ) extends CoGroupSplitDep { - - @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { - // Update the reference to parent split at the time of task serialization - split = rdd.partitions(splitIndex) - oos.defaultWriteObject() - } -} - -private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep - -private[spark] -class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) - extends Partition with Serializable { - override val index: Int = idx - override def hashCode(): Int = idx -} - - -/** - * A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a - * tuple with the list of values for that key. - * - * @param rdds parent RDDs. - * @param part partitioner used to partition the shuffle output. - */ -class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner) - extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) { - - private var serializerClass: String = null - - def setSerializer(cls: String): CoGroupedRDD[K] = { - serializerClass = cls - this - } - - override def getDependencies: Seq[Dependency[_]] = { - rdds.map { rdd: RDD[_ <: Product2[K, _]] => - if (rdd.partitioner == Some(part)) { - logDebug("Adding one-to-one dependency with " + rdd) - new OneToOneDependency(rdd) - } else { - logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency[Any, Any](rdd, part, serializerClass) - } - } - } - - override def getPartitions: Array[Partition] = { - val array = new Array[Partition](part.numPartitions) - for (i <- 0 until array.size) { - // Each CoGroupPartition will have a dependency per contributing RDD - array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) => - // Assume each RDD contributed a single dependency, and get it - dependencies(j) match { - case s: ShuffleDependency[_, _] => - new ShuffleCoGroupSplitDep(s.shuffleId) - case _ => - new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) - } - }.toArray) - } - array - } - - override val partitioner = Some(part) - - override def compute(s: Partition, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = { - val split = s.asInstanceOf[CoGroupPartition] - val numRdds = split.deps.size - // e.g. 
for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs) - val map = new JHashMap[K, Seq[ArrayBuffer[Any]]] - - def getSeq(k: K): Seq[ArrayBuffer[Any]] = { - val seq = map.get(k) - if (seq != null) { - seq - } else { - val seq = Array.fill(numRdds)(new ArrayBuffer[Any]) - map.put(k, seq) - seq - } - } - - val ser = SparkEnv.get.serializerManager.get(serializerClass) - for ((dep, depNum) <- split.deps.zipWithIndex) dep match { - case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { - // Read them from the parent - rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]].foreach { kv => - getSeq(kv._1)(depNum) += kv._2 - } - } - case ShuffleCoGroupSplitDep(shuffleId) => { - // Read map outputs of shuffle - val fetcher = SparkEnv.get.shuffleFetcher - fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context.taskMetrics, ser).foreach { - kv => getSeq(kv._1)(depNum) += kv._2 - } - } - } - JavaConversions.mapAsScalaMap(map).iterator - } - - override def clearDependencies() { - super.clearDependencies() - rdds = null - } -} diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala deleted file mode 100644 index e612d026b2..0000000000 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ /dev/null @@ -1,342 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
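The (K, Seq[Seq[_]]) value shape that CoGroupedRDD above produces is easiest to see with a local mock of the semantics; this sketch uses plain Scala collections, not the RDD code path.

```scala
object CoGroupShapeSketch {
  // One Seq of values per parent, in parent order, for every key that
  // appears in any parent.
  def cogroupLocal[K, V](parents: Seq[Seq[(K, V)]]): Map[K, Seq[Seq[V]]] = {
    val keys = parents.flatMap(_.map(_._1)).distinct
    keys.map { k =>
      k -> parents.map(_.collect { case (`k`, v) => v })
    }.toMap
  }
}

// cogroupLocal(Seq(Seq(("a", 1), ("a", 2)), Seq(("a", 9), ("b", 3)))) ==
// Map("a" -> Seq(Seq(1, 2), Seq(9)), "b" -> Seq(Seq(), Seq(3)))
```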
- */ - -package spark.rdd - -import spark._ -import java.io.{ObjectOutputStream, IOException} -import scala.collection.mutable -import scala.Some -import scala.collection.mutable.ArrayBuffer - -/** - * Class that captures a coalesced RDD by essentially keeping track of parent partitions - * @param index of this coalesced partition - * @param rdd which it belongs to - * @param parentsIndices list of indices in the parent that have been coalesced into this partition - * @param preferredLocation the preferred location for this partition - */ -case class CoalescedRDDPartition( - index: Int, - @transient rdd: RDD[_], - parentsIndices: Array[Int], - @transient preferredLocation: String = "" - ) extends Partition { - var parents: Seq[Partition] = parentsIndices.map(rdd.partitions(_)) - - @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { - // Update the reference to parent partition at the time of task serialization - parents = parentsIndices.map(rdd.partitions(_)) - oos.defaultWriteObject() - } - - /** - * Computes how many of the parents partitions have getPreferredLocation - * as one of their preferredLocations - * @return locality of this coalesced partition between 0 and 1 - */ - def localFraction: Double = { - val loc = parents.count(p => - rdd.context.getPreferredLocs(rdd, p.index).map(tl => tl.host).contains(preferredLocation)) - - if (parents.size == 0) 0.0 else (loc.toDouble / parents.size.toDouble) - } -} - -/** - * Represents a coalesced RDD that has fewer partitions than its parent RDD - * This class uses the PartitionCoalescer class to find a good partitioning of the parent RDD - * so that each new partition has roughly the same number of parent partitions and that - * the preferred location of each new partition overlaps with as many preferred locations of its - * parent partitions - * @param prev RDD to be coalesced - * @param maxPartitions number of desired partitions in the coalesced RDD - * @param balanceSlack used to trade-off balance and locality. 1.0 is all locality, 0 is all balance - */ -class CoalescedRDD[T: ClassManifest]( - @transient var prev: RDD[T], - maxPartitions: Int, - balanceSlack: Double = 0.10) - extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies - - override def getPartitions: Array[Partition] = { - val pc = new PartitionCoalescer(maxPartitions, prev, balanceSlack) - - pc.run().zipWithIndex.map { - case (pg, i) => - val ids = pg.arr.map(_.index).toArray - new CoalescedRDDPartition(i, prev, ids, pg.prefLoc) - } - } - - override def compute(partition: Partition, context: TaskContext): Iterator[T] = { - partition.asInstanceOf[CoalescedRDDPartition].parents.iterator.flatMap { parentPartition => - firstParent[T].iterator(parentPartition, context) - } - } - - override def getDependencies: Seq[Dependency[_]] = { - Seq(new NarrowDependency(prev) { - def getParents(id: Int): Seq[Int] = - partitions(id).asInstanceOf[CoalescedRDDPartition].parentsIndices - }) - } - - override def clearDependencies() { - super.clearDependencies() - prev = null - } - - /** - * Returns the preferred machine for the partition. If split is of type CoalescedRDDPartition, - * then the preferred machine will be one which most parent splits prefer too. 
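The localFraction method above reduces to a count over the parents' preferred hosts. A standalone restatement, where `parentPrefs` stands in for the per-partition host lists that the real code asks the scheduler for:

```scala
object LocalFractionSketch {
  // Share of parent partitions that list `preferredLocation` among their
  // preferred hosts; 0.0 for an empty group, exactly as above.
  def localFraction(parentPrefs: Seq[Seq[String]],
                    preferredLocation: String): Double = {
    if (parentPrefs.isEmpty) 0.0
    else parentPrefs.count(_.contains(preferredLocation)).toDouble / parentPrefs.size
  }
}

// localFraction(Seq(Seq("host1"), Seq("host2"), Seq("host1", "host3")), "host1")
// == 2.0 / 3
```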
- * @param partition - * @return the machine most preferred by split - */ - override def getPreferredLocations(partition: Partition): Seq[String] = { - List(partition.asInstanceOf[CoalescedRDDPartition].preferredLocation) - } -} - -/** - * Coalesce the partitions of a parent RDD (`prev`) into fewer partitions, so that each partition of - * this RDD computes one or more of the parent ones. It will produce exactly `maxPartitions` if the - * parent had more than maxPartitions, or fewer if the parent had fewer. - * - * This transformation is useful when an RDD with many partitions gets filtered into a smaller one, - * or to avoid having a large number of small tasks when processing a directory with many files. - * - * If there is no locality information (no preferredLocations) in the parent, then the coalescing - * is very simple: chunk parents that are close in the Array in chunks. - * If there is locality information, it proceeds to pack them with the following four goals: - * - * (1) Balance the groups so they roughly have the same number of parent partitions - * (2) Achieve locality per partition, i.e. find one machine which most parent partitions prefer - * (3) Be efficient, i.e. O(n) algorithm for n parent partitions (problem is likely NP-hard) - * (4) Balance preferred machines, i.e. avoid as much as possible picking the same preferred machine - * - * Furthermore, it is assumed that the parent RDD may have many partitions, e.g. 100 000. - * We assume the final number of desired partitions is small, e.g. less than 1000. - * - * The algorithm tries to assign unique preferred machines to each partition. If the number of - * desired partitions is greater than the number of preferred machines (can happen), it needs to - * start picking duplicate preferred machines. This is determined using coupon collector estimation - * (2n log(n)). The load balancing is done using power-of-two randomized bins-balls with one twist: - * it tries to also achieve locality. This is done by allowing a slack (balanceSlack) between two - * bins. If two bins are within the slack in terms of balance, the algorithm will assign partitions - * according to locality. (contact alig for questions) - * - */ - -private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: Double) { - - def compare(o1: PartitionGroup, o2: PartitionGroup): Boolean = o1.size < o2.size - def compare(o1: Option[PartitionGroup], o2: Option[PartitionGroup]): Boolean = - if (o1 == None) false else if (o2 == None) true else compare(o1.get, o2.get) - - val rnd = new scala.util.Random(7919) // keep this class deterministic - - // each element of groupArr represents one coalesced partition - val groupArr = ArrayBuffer[PartitionGroup]() - - // hash used to check whether some machine is already in groupArr - val groupHash = mutable.Map[String, ArrayBuffer[PartitionGroup]]() - - // hash used for the first maxPartitions (to avoid duplicates) - val initialHash = mutable.Set[Partition]() - - // determines the tradeoff between load-balancing the partitions sizes and their locality - // e.g. 
balanceSlack=0.10 means that it allows up to 10% imbalance in favor of locality - val slack = (balanceSlack * prev.partitions.size).toInt - - var noLocality = true // true if no preferredLocations exist for the parent RDD - - // gets the *current* preferred locations from the DAGScheduler (as opposed to the static ones) - def currPrefLocs(part: Partition): Seq[String] = { - prev.context.getPreferredLocs(prev, part.index).map(tl => tl.host) - } - - // this class just keeps iterating and rotating infinitely over the partitions of the RDD - // next() returns the next preferred machine that a partition is replicated on - // the rotator first goes through the first replica copy of each partition, then second, third - // the iterator's return type is a tuple: (replicaString, partition) - class LocationIterator(prev: RDD[_]) extends Iterator[(String, Partition)] { - - var it: Iterator[(String, Partition)] = resetIterator() - - override val isEmpty = !it.hasNext - - // initializes/resets to start iterating from the beginning - def resetIterator() = { - val iterators = (0 to 2).map( x => - prev.partitions.iterator.flatMap(p => { - if (currPrefLocs(p).size > x) Some((currPrefLocs(p)(x), p)) else None - } ) - ) - iterators.reduceLeft((x, y) => x ++ y) - } - - // hasNext() is false iff there are no preferredLocations for any of the partitions of the RDD - def hasNext(): Boolean = { !isEmpty } - - // return the next preferredLocation of some partition of the RDD - def next(): (String, Partition) = { - if (it.hasNext) - it.next() - else { - it = resetIterator() // ran out of preferred locations, reset and rotate to the beginning - it.next() - } - } - } - - /** - * Sorts and gets the least element of the list associated with key in groupHash - * The returned PartitionGroup is the least loaded of all groups that represent the machine "key" - * @param key string representing a partitioned group on preferred machine key - * @return Option of PartitionGroup that has least elements for key - */ - def getLeastGroupHash(key: String): Option[PartitionGroup] = { - groupHash.get(key).map(_.sortWith(compare).head) - } - - def addPartToPGroup(part: Partition, pgroup: PartitionGroup): Boolean = { - if (!initialHash.contains(part)) { - pgroup.arr += part // already assign this element - initialHash += part // needed to avoid assigning partitions to multiple buckets - true - } else { false } - } - - /** - * Initializes targetLen partition groups and assigns a preferredLocation - * This uses coupon collector to estimate how many preferredLocations it must rotate through - * until it has seen most of the preferred locations (2 * n log(n)) - * @param targetLen - */ - def setupGroups(targetLen: Int) { - val rotIt = new LocationIterator(prev) - - // deal with empty case, just create targetLen partition groups with no preferred location - if (!rotIt.hasNext()) { - (1 to targetLen).foreach(x => groupArr += PartitionGroup()) - return - } - - noLocality = false - - // number of iterations needed to be certain that we've seen most preferred locations - val expectedCoupons2 = 2 * (math.log(targetLen)*targetLen + targetLen + 0.5).toInt - var numCreated = 0 - var tries = 0 - - // rotate through until either targetLen unique/distinct preferred locations have been created - // OR we've rotated expectedCoupons2, in which case we have likely seen all preferred locations, - // i.e.
likely targetLen >> number of preferred locations (more buckets than there are machines) - while (numCreated < targetLen && tries < expectedCoupons2) { - tries += 1 - val (nxt_replica, nxt_part) = rotIt.next() - if (!groupHash.contains(nxt_replica)) { - val pgroup = PartitionGroup(nxt_replica) - groupArr += pgroup - addPartToPGroup(nxt_part, pgroup) - groupHash += (nxt_replica -> (ArrayBuffer(pgroup))) // list in case we have multiple - numCreated += 1 - } - } - - while (numCreated < targetLen) { // if we don't have enough partition groups, create duplicates - var (nxt_replica, nxt_part) = rotIt.next() - val pgroup = PartitionGroup(nxt_replica) - groupArr += pgroup - groupHash.get(nxt_replica).get += pgroup - var tries = 0 - while (!addPartToPGroup(nxt_part, pgroup) && tries < targetLen) { // ensure at least one part - nxt_part = rotIt.next()._2 - tries += 1 - } - numCreated += 1 - } - - } - - /** - * Takes a parent RDD partition and decides which of the partition groups to put it in - * Takes locality into account, but also uses power of 2 choices to load balance - * It strikes a balance between the two using the balanceSlack variable - * @param p partition (ball to be thrown) - * @return partition group (bin to be put in) - */ - def pickBin(p: Partition): PartitionGroup = { - val pref = currPrefLocs(p).map(getLeastGroupHash(_)).sortWith(compare) // least loaded pref locs - val prefPart = if (pref == Nil) None else pref.head - - val r1 = rnd.nextInt(groupArr.size) - val r2 = rnd.nextInt(groupArr.size) - val minPowerOfTwo = if (groupArr(r1).size < groupArr(r2).size) groupArr(r1) else groupArr(r2) - if (prefPart == None) // if no preferred locations, just use basic power of two - return minPowerOfTwo - - val prefPartActual = prefPart.get - - if (minPowerOfTwo.size + slack <= prefPartActual.size) // more imbalance than the slack allows - return minPowerOfTwo // prefer balance over locality - else { - return prefPartActual // prefer locality over balance - } - } - - def throwBalls() { - if (noLocality) { // no preferredLocations in parent RDD, no randomization needed - if (maxPartitions > groupArr.size) { // just return prev.partitions - for ((p,i) <- prev.partitions.zipWithIndex) { - groupArr(i).arr += p - } - } else { // no locality available, then simply split partitions based on positions in array - for(i <- 0 until maxPartitions) { - val rangeStart = ((i.toLong * prev.partitions.length) / maxPartitions).toInt - val rangeEnd = (((i.toLong + 1) * prev.partitions.length) / maxPartitions).toInt - (rangeStart until rangeEnd).foreach{ j => groupArr(i).arr += prev.partitions(j) } - } - } - } else { - for (p <- prev.partitions if (!initialHash.contains(p))) { // throw every partition into group - pickBin(p).arr += p - } - } - } - - def getPartitions: Array[PartitionGroup] = groupArr.filter( pg => pg.size > 0).toArray - - /** - * Runs the packing algorithm and returns an array of PartitionGroups that if possible are - * load balanced and grouped by locality - * @return array of partition groups - */ - def run(): Array[PartitionGroup] = { - setupGroups(math.min(prev.partitions.length, maxPartitions)) // setup the groups (bins) - throwBalls() // assign partitions (balls) to each group (bins) - getPartitions - } -} - -private[spark] case class PartitionGroup(prefLoc: String = "") { - var arr = mutable.ArrayBuffer[Partition]() - - def size = arr.size -} diff --git a/core/src/main/scala/spark/rdd/EmptyRDD.scala b/core/src/main/scala/spark/rdd/EmptyRDD.scala deleted file mode 100644 index
d7d4db5d30..0000000000 --- a/core/src/main/scala/spark/rdd/EmptyRDD.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{RDD, SparkContext, SparkEnv, Partition, TaskContext} - - -/** - * An RDD that is empty, i.e. has no element in it. - */ -class EmptyRDD[T: ClassManifest](sc: SparkContext) extends RDD[T](sc, Nil) { - - override def getPartitions: Array[Partition] = Array.empty - - override def compute(split: Partition, context: TaskContext): Iterator[T] = { - throw new UnsupportedOperationException("empty RDD") - } -} diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala deleted file mode 100644 index 783508cfd1..0000000000 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{OneToOneDependency, RDD, Partition, TaskContext} - -private[spark] class FilteredRDD[T: ClassManifest]( - prev: RDD[T], - f: T => Boolean) - extends RDD[T](prev) { - - override def getPartitions: Array[Partition] = firstParent[T].partitions - - override val partitioner = prev.partitioner // Since filter cannot change a partition's keys - - override def compute(split: Partition, context: TaskContext) = - firstParent[T].iterator(split, context).filter(f) -} diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala deleted file mode 100644 index ed75eac3ff..0000000000 --- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
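The one-line `override val partitioner = prev.partitioner` in FilteredRDD above is what lets a filter sit between two co-partitioned operations without forcing a shuffle. A sketch of the contrast, assuming a live SparkContext `sc` and the pre-rename `spark` package names:

```scala
import spark.SparkContext._  // pair-RDD implicits
import spark.HashPartitioner

val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))
  .partitionBy(new HashPartitioner(4))

// filter cannot change keys, so the partitioner survives and a later
// reduceByKey over the same partitioner is shuffle-free.
val kept = pairs.filter { case (_, v) => v > 1 }

// map may change keys, so its result carries no partitioner.
val remapped = pairs.map { case (k, v) => (k, v * 10) }
```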
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{RDD, Partition, TaskContext} - - -private[spark] -class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], - f: T => TraversableOnce[U]) - extends RDD[U](prev) { - - override def getPartitions: Array[Partition] = firstParent[T].partitions - - override def compute(split: Partition, context: TaskContext) = - firstParent[T].iterator(split, context).flatMap(f) -} diff --git a/core/src/main/scala/spark/rdd/FlatMappedValuesRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedValuesRDD.scala deleted file mode 100644 index a6bdce89d8..0000000000 --- a/core/src/main/scala/spark/rdd/FlatMappedValuesRDD.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{TaskContext, Partition, RDD} - - -private[spark] -class FlatMappedValuesRDD[K, V, U](prev: RDD[_ <: Product2[K, V]], f: V => TraversableOnce[U]) - extends RDD[(K, U)](prev) { - - override def getPartitions = firstParent[Product2[K, V]].partitions - - override val partitioner = firstParent[Product2[K, V]].partitioner - - override def compute(split: Partition, context: TaskContext) = { - firstParent[Product2[K, V]].iterator(split, context).flatMap { case Product2(k, v) => - f(v).map(x => (k, x)) - } - } -} diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala deleted file mode 100644 index 1573f8a289..0000000000 --- a/core/src/main/scala/spark/rdd/GlommedRDD.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{RDD, Partition, TaskContext} - -private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T]) - extends RDD[Array[T]](prev) { - - override def getPartitions: Array[Partition] = firstParent[T].partitions - - override def compute(split: Partition, context: TaskContext) = - Array(firstParent[T].iterator(split, context).toArray).iterator -} diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala deleted file mode 100644 index e512423fd6..0000000000 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import java.io.EOFException -import java.util.NoSuchElementException - -import org.apache.hadoop.io.LongWritable -import org.apache.hadoop.io.NullWritable -import org.apache.hadoop.io.Text -import org.apache.hadoop.mapred.FileInputFormat -import org.apache.hadoop.mapred.InputFormat -import org.apache.hadoop.mapred.InputSplit -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.TextInputFormat -import org.apache.hadoop.mapred.RecordReader -import org.apache.hadoop.mapred.Reporter -import org.apache.hadoop.util.ReflectionUtils - -import spark.{Dependency, Logging, Partition, RDD, SerializableWritable, SparkContext, SparkEnv, TaskContext} -import spark.util.NextIterator -import org.apache.hadoop.conf.{Configuration, Configurable} - - -/** - * A Spark split class that wraps around a Hadoop InputSplit. - */ -private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSplit) - extends Partition { - - val inputSplit = new SerializableWritable[InputSplit](s) - - override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt - - override val index: Int = idx -} - -/** - * An RDD that reads a Hadoop dataset as specified by a JobConf (e.g. files in HDFS, the local file - * system, or S3, tables in HBase, etc). 
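HadoopRDD is normally reached through SparkContext helpers rather than constructed directly. A sketch with the classic mapred API, assuming a live SparkContext `sc`; the HDFS path is hypothetical:

```scala
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.TextInputFormat

// Roughly what sc.textFile does under the hood: (byte offset, line)
// records from a HadoopRDD, here with at least 8 input splits requested.
val lines = sc.hadoopFile("hdfs:///tmp/input", classOf[TextInputFormat],
                          classOf[LongWritable], classOf[Text], 8)
  .map { case (_, text) => text.toString }
```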
- */ -class HadoopRDD[K, V]( - sc: SparkContext, - @transient conf: JobConf, - inputFormatClass: Class[_ <: InputFormat[K, V]], - keyClass: Class[K], - valueClass: Class[V], - minSplits: Int) - extends RDD[(K, V)](sc, Nil) with Logging { - - // A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it - private val confBroadcast = sc.broadcast(new SerializableWritable(conf)) - - override def getPartitions: Array[Partition] = { - val env = SparkEnv.get - env.hadoop.addCredentials(conf) - val inputFormat = createInputFormat(conf) - if (inputFormat.isInstanceOf[Configurable]) { - inputFormat.asInstanceOf[Configurable].setConf(conf) - } - val inputSplits = inputFormat.getSplits(conf, minSplits) - val array = new Array[Partition](inputSplits.size) - for (i <- 0 until inputSplits.size) { - array(i) = new HadoopPartition(id, i, inputSplits(i)) - } - array - } - - def createInputFormat(conf: JobConf): InputFormat[K, V] = { - ReflectionUtils.newInstance(inputFormatClass.asInstanceOf[Class[_]], conf) - .asInstanceOf[InputFormat[K, V]] - } - - override def compute(theSplit: Partition, context: TaskContext) = new NextIterator[(K, V)] { - val split = theSplit.asInstanceOf[HadoopPartition] - logInfo("Input split: " + split.inputSplit) - var reader: RecordReader[K, V] = null - - val conf = confBroadcast.value.value - val fmt = createInputFormat(conf) - if (fmt.isInstanceOf[Configurable]) { - fmt.asInstanceOf[Configurable].setConf(conf) - } - reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL) - - // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback{ () => closeIfNeeded() } - - val key: K = reader.createKey() - val value: V = reader.createValue() - - override def getNext() = { - try { - finished = !reader.next(key, value) - } catch { - case eof: EOFException => - finished = true - } - (key, value) - } - - override def close() { - try { - reader.close() - } catch { - case e: Exception => logWarning("Exception in RecordReader.close()", e) - } - } - } - - override def getPreferredLocations(split: Partition): Seq[String] = { - // TODO: Filtering out "localhost" in case of file:// URLs - val hadoopSplit = split.asInstanceOf[HadoopPartition] - hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") - } - - override def checkpoint() { - // Do nothing. Hadoop RDD should not be checkpointed. - } - - def getConf: Configuration = confBroadcast.value.value -} diff --git a/core/src/main/scala/spark/rdd/JdbcRDD.scala b/core/src/main/scala/spark/rdd/JdbcRDD.scala deleted file mode 100644 index 59132437d2..0000000000 --- a/core/src/main/scala/spark/rdd/JdbcRDD.scala +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
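Backing up briefly: GlommedRDD above is the implementation behind the public glom() transformation. A quick use of it to inspect partition sizes, again assuming a live `sc`:

```scala
// glom() turns each partition into a single Array, so per-partition
// counts become an ordinary map.
val sizes = sc.parallelize(1 to 100, 4).glom().map(_.length).collect()
// sizes == Array(25, 25, 25, 25)
```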
- */ - -package spark.rdd - -import java.sql.{Connection, ResultSet} - -import spark.{Logging, Partition, RDD, SparkContext, TaskContext} -import spark.util.NextIterator - -private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { - override def index = idx -} - -/** - * An RDD that executes an SQL query on a JDBC connection and reads results. - * For usage example, see test case JdbcRDDSuite. - * - * @param getConnection a function that returns an open Connection. - * The RDD takes care of closing the connection. - * @param sql the text of the query. - * The query must contain two ? placeholders for parameters used to partition the results. - * E.g. "select title, author from books where ? <= id and id <= ?" - * @param lowerBound the minimum value of the first placeholder - * @param upperBound the maximum value of the second placeholder - * The lower and upper bounds are inclusive. - * @param numPartitions the number of partitions. - * Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2, - * the query would be executed twice, once with (1, 10) and once with (11, 20) - * @param mapRow a function from a ResultSet to a single row of the desired result type(s). - * This should only call getInt, getString, etc; the RDD takes care of calling next. - * The default maps a ResultSet to an array of Object. - */ -class JdbcRDD[T: ClassManifest]( - sc: SparkContext, - getConnection: () => Connection, - sql: String, - lowerBound: Long, - upperBound: Long, - numPartitions: Int, - mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _) - extends RDD[T](sc, Nil) with Logging { - - override def getPartitions: Array[Partition] = { - // bounds are inclusive, hence the + 1 here and - 1 on end - val length = 1 + upperBound - lowerBound - (0 until numPartitions).map(i => { - val start = lowerBound + ((i * length) / numPartitions).toLong - val end = lowerBound + (((i + 1) * length) / numPartitions).toLong - 1 - new JdbcPartition(i, start, end) - }).toArray - } - - override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] { - context.addOnCompleteCallback{ () => closeIfNeeded() } - val part = thePart.asInstanceOf[JdbcPartition] - val conn = getConnection() - val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) - - // setFetchSize(Integer.MIN_VALUE) is a mysql driver specific way to force streaming results, - // rather than pulling entire resultset into memory. - // see http://dev.mysql.com/doc/refman/5.0/en/connector-j-reference-implementation-notes.html - if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) { - stmt.setFetchSize(Integer.MIN_VALUE) - logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ") - } - - stmt.setLong(1, part.lower) - stmt.setLong(2, part.upper) - val rs = stmt.executeQuery() - - override def getNext: T = { - if (rs.next()) { - mapRow(rs) - } else { - finished = true - null.asInstanceOf[T] - } - } - - override def close() { - try { - if (null != rs && ! rs.isClosed()) rs.close() - } catch { - case e: Exception => logWarning("Exception closing resultset", e) - } - try { - if (null != stmt && ! stmt.isClosed()) stmt.close() - } catch { - case e: Exception => logWarning("Exception closing statement", e) - } - try { - if (null != conn && ! 
conn.isClosed()) conn.close() - logInfo("closed connection") - } catch { - case e: Exception => logWarning("Exception closing connection", e) - } - } - } -} - -object JdbcRDD { - def resultSetToObjectArray(rs: ResultSet) = { - Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1)) - } -} diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala deleted file mode 100644 index af8f0a112f..0000000000 --- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{RDD, Partition, TaskContext} - - -private[spark] -class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], - f: Iterator[T] => Iterator[U], - preservesPartitioning: Boolean = false) - extends RDD[U](prev) { - - override val partitioner = - if (preservesPartitioning) firstParent[T].partitioner else None - - override def getPartitions: Array[Partition] = firstParent[T].partitions - - override def compute(split: Partition, context: TaskContext) = - f(firstParent[T].iterator(split, context)) -} diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithIndexRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithIndexRDD.scala deleted file mode 100644 index 3b4e9518fd..0000000000 --- a/core/src/main/scala/spark/rdd/MapPartitionsWithIndexRDD.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{RDD, Partition, TaskContext} - - -/** - * A variant of the MapPartitionsRDD that passes the partition index into the - * closure. This can be used to generate or collect partition-specific - * information such as the number of tuples in a partition.
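The doc comment above is the contract behind RDD.mapPartitionsWithIndex. A minimal sketch of the use case it names, counting the tuples in each partition, assuming a live SparkContext sc:

    // One (partitionIndex, elementCount) pair per partition.
    val counts = sc.parallelize(1 to 1000, 4)
      .mapPartitionsWithIndex { (idx, iter) => Iterator((idx, iter.size)) }
    println(counts.collect().mkString(", ")) // e.g. (0,250), (1,250), (2,250), (3,250)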
- */ -private[spark] -class MapPartitionsWithIndexRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], - f: (Int, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean - ) extends RDD[U](prev) { - - override def getPartitions: Array[Partition] = firstParent[T].partitions - - override val partitioner = if (preservesPartitioning) prev.partitioner else None - - override def compute(split: Partition, context: TaskContext) = - f(split.index, firstParent[T].iterator(split, context)) -} diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala deleted file mode 100644 index 8b411dd85d..0000000000 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{RDD, Partition, TaskContext} - -private[spark] -class MappedRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: T => U) - extends RDD[U](prev) { - - override def getPartitions: Array[Partition] = firstParent[T].partitions - - override def compute(split: Partition, context: TaskContext) = - firstParent[T].iterator(split, context).map(f) -} diff --git a/core/src/main/scala/spark/rdd/MappedValuesRDD.scala b/core/src/main/scala/spark/rdd/MappedValuesRDD.scala deleted file mode 100644 index 8334e3b557..0000000000 --- a/core/src/main/scala/spark/rdd/MappedValuesRDD.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
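The next deleted file, MappedValuesRDD, backs PairRDDFunctions.mapValues. The detail worth noticing in its code below is that it keeps the parent's partitioner, so mapping over values never forces a reshuffle. A sketch, assuming sc and the pre-rename spark.SparkContext._ implicits (org.apache.spark after this patch):

    import spark.SparkContext._
    import spark.HashPartitioner

    val pairs = sc.parallelize(Seq((1, "a"), (2, "bb"))).partitionBy(new HashPartitioner(2))
    val lengths = pairs.mapValues(_.length)           // values change; keys stay put
    println(lengths.partitioner == pairs.partitioner) // true: partitioning survives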
- */ - -package spark.rdd - - -import spark.{TaskContext, Partition, RDD} - -private[spark] -class MappedValuesRDD[K, V, U](prev: RDD[_ <: Product2[K, V]], f: V => U) - extends RDD[(K, U)](prev) { - - override def getPartitions = firstParent[Product2[K, U]].partitions - - override val partitioner = firstParent[Product2[K, U]].partitioner - - override def compute(split: Partition, context: TaskContext): Iterator[(K, U)] = { - firstParent[Product2[K, V]].iterator(split, context).map { case Product2(k ,v) => (k, f(v)) } - } -} diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala deleted file mode 100644 index b1877dc06e..0000000000 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import java.text.SimpleDateFormat -import java.util.Date - -import org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapreduce._ - -import spark.{Dependency, Logging, Partition, RDD, SerializableWritable, SparkContext, TaskContext} - - -private[spark] -class NewHadoopPartition(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable) - extends Partition { - - val serializableHadoopSplit = new SerializableWritable(rawSplit) - - override def hashCode(): Int = (41 * (41 + rddId) + index) -} - -class NewHadoopRDD[K, V]( - sc : SparkContext, - inputFormatClass: Class[_ <: InputFormat[K, V]], - keyClass: Class[K], - valueClass: Class[V], - @transient conf: Configuration) - extends RDD[(K, V)](sc, Nil) - with SparkHadoopMapReduceUtil - with Logging { - - // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it - private val confBroadcast = sc.broadcast(new SerializableWritable(conf)) - // private val serializableConf = new SerializableWritable(conf) - - private val jobtrackerId: String = { - val formatter = new SimpleDateFormat("yyyyMMddHHmm") - formatter.format(new Date()) - } - - @transient private val jobId = new JobID(jobtrackerId, id) - - override def getPartitions: Array[Partition] = { - val inputFormat = inputFormatClass.newInstance - if (inputFormat.isInstanceOf[Configurable]) { - inputFormat.asInstanceOf[Configurable].setConf(conf) - } - val jobContext = newJobContext(conf, jobId) - val rawSplits = inputFormat.getSplits(jobContext).toArray - val result = new Array[Partition](rawSplits.size) - for (i <- 0 until rawSplits.size) { - result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) - } - result - } - - override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] { - val split = theSplit.asInstanceOf[NewHadoopPartition] - 
logInfo("Input split: " + split.serializableHadoopSplit) - val conf = confBroadcast.value.value - val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) - val format = inputFormatClass.newInstance - if (format.isInstanceOf[Configurable]) { - format.asInstanceOf[Configurable].setConf(conf) - } - val reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) - - // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback(() => close()) - - var havePair = false - var finished = false - - override def hasNext: Boolean = { - if (!finished && !havePair) { - finished = !reader.nextKeyValue - havePair = !finished - } - !finished - } - - override def next: (K, V) = { - if (!hasNext) { - throw new java.util.NoSuchElementException("End of stream") - } - havePair = false - return (reader.getCurrentKey, reader.getCurrentValue) - } - - private def close() { - try { - reader.close() - } catch { - case e: Exception => logWarning("Exception in RecordReader.close()", e) - } - } - } - - override def getPreferredLocations(split: Partition): Seq[String] = { - val theSplit = split.asInstanceOf[NewHadoopPartition] - theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") - } - - def getConf: Configuration = confBroadcast.value.value -} - diff --git a/core/src/main/scala/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/spark/rdd/OrderedRDDFunctions.scala deleted file mode 100644 index 9154b76035..0000000000 --- a/core/src/main/scala/spark/rdd/OrderedRDDFunctions.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{RangePartitioner, Logging, RDD} - -/** - * Extra functions available on RDDs of (key, value) pairs where the key is sortable through - * an implicit conversion. Import `spark.SparkContext._` at the top of your program to use these - * functions. They will work with any key type that has a `scala.math.Ordered` implementation. - */ -class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, - V: ClassManifest, - P <: Product2[K, V] : ClassManifest]( - self: RDD[P]) - extends Logging with Serializable { - - /** - * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling - * `collect` or `save` on the resulting RDD will return or output an ordered list of records - * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in - * order of the keys). 
- */ - def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = { - val part = new RangePartitioner(numPartitions, self, ascending) - val shuffled = new ShuffledRDD[K, V, P](self, part) - shuffled.mapPartitions(iter => { - val buf = iter.toArray - if (ascending) { - buf.sortWith((x, y) => x._1 < y._1).iterator - } else { - buf.sortWith((x, y) => x._1 > y._1).iterator - } - }, preservesPartitioning = true) - } -} diff --git a/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala deleted file mode 100644 index 33079cd539..0000000000 --- a/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala +++ /dev/null @@ -1,151 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import scala.collection.immutable.NumericRange -import scala.collection.mutable.ArrayBuffer -import scala.collection.Map -import spark._ -import java.io._ -import scala.Serializable - -private[spark] class ParallelCollectionPartition[T: ClassManifest]( - var rddId: Long, - var slice: Int, - var values: Seq[T]) - extends Partition with Serializable { - - def iterator: Iterator[T] = values.iterator - - override def hashCode(): Int = (41 * (41 + rddId) + slice).toInt - - override def equals(other: Any): Boolean = other match { - case that: ParallelCollectionPartition[_] => (this.rddId == that.rddId && this.slice == that.slice) - case _ => false - } - - override def index: Int = slice - - @throws(classOf[IOException]) - private def writeObject(out: ObjectOutputStream): Unit = { - - val sfactory = SparkEnv.get.serializer - - // Treat java serializer with default action rather than going thru serialization, to avoid a - // separate serialization header. - - sfactory match { - case js: JavaSerializer => out.defaultWriteObject() - case _ => - out.writeLong(rddId) - out.writeInt(slice) - - val ser = sfactory.newInstance() - Utils.serializeViaNestedStream(out, ser)(_.writeObject(values)) - } - } - - @throws(classOf[IOException]) - private def readObject(in: ObjectInputStream): Unit = { - - val sfactory = SparkEnv.get.serializer - sfactory match { - case js: JavaSerializer => in.defaultReadObject() - case _ => - rddId = in.readLong() - slice = in.readInt() - - val ser = sfactory.newInstance() - Utils.deserializeViaNestedStream(in, ser)(ds => values = ds.readObject()) - } - } -} - -private[spark] class ParallelCollectionRDD[T: ClassManifest]( - @transient sc: SparkContext, - @transient data: Seq[T], - numSlices: Int, - locationPrefs: Map[Int, Seq[String]]) - extends RDD[T](sc, Nil) { - // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets - // cached. 
It might be worthwhile to write the data to a file in the DFS and read it in the split - // instead. - // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal. - - override def getPartitions: Array[Partition] = { - val slices = ParallelCollectionRDD.slice(data, numSlices).toArray - slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray - } - - override def compute(s: Partition, context: TaskContext) = - s.asInstanceOf[ParallelCollectionPartition[T]].iterator - - override def getPreferredLocations(s: Partition): Seq[String] = { - locationPrefs.getOrElse(s.index, Nil) - } -} - -private object ParallelCollectionRDD { - /** - * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range - * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes - * it efficient to run Spark over RDDs representing large sets of numbers. - */ - def slice[T: ClassManifest](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = { - if (numSlices < 1) { - throw new IllegalArgumentException("Positive number of slices required") - } - seq match { - case r: Range.Inclusive => { - val sign = if (r.step < 0) { - -1 - } else { - 1 - } - slice(new Range( - r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices) - } - case r: Range => { - (0 until numSlices).map(i => { - val start = ((i * r.length.toLong) / numSlices).toInt - val end = (((i + 1) * r.length.toLong) / numSlices).toInt - new Range(r.start + start * r.step, r.start + end * r.step, r.step) - }).asInstanceOf[Seq[Seq[T]]] - } - case nr: NumericRange[_] => { - // For ranges of Long, Double, BigInteger, etc - val slices = new ArrayBuffer[Seq[T]](numSlices) - val sliceSize = (nr.size + numSlices - 1) / numSlices // Round up to catch everything - var r = nr - for (i <- 0 until numSlices) { - slices += r.take(sliceSize).asInstanceOf[Seq[T]] - r = r.drop(sliceSize) - } - slices - } - case _ => { - val array = seq.toArray // To prevent O(n^2) operations for List etc - (0 until numSlices).map(i => { - val start = ((i * array.length.toLong) / numSlices).toInt - val end = (((i + 1) * array.length.toLong) / numSlices).toInt - array.slice(start, end).toSeq - }) - } - } - } -} diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala deleted file mode 100644 index d8700becb0..0000000000 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
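Since slice is private, here is a standalone sketch of its index arithmetic, showing how the Range fast path above splits 1 to 10 into three sub-ranges without materializing any elements (later slices absorb the remainder):

    // Mirrors the start/end computation in ParallelCollectionRDD.slice.
    def sliceBounds(length: Long, numSlices: Int): Seq[(Int, Int)] =
      (0 until numSlices).map { i =>
        (((i * length) / numSlices).toInt, (((i + 1) * length) / numSlices).toInt)
      }

    println(sliceBounds(10, 3)) // Vector((0,3), (3,6), (6,10))
    // Applied to 1 to 10: Range(1,4), Range(4,7), Range(7,11) -> 1..3, 4..6, 7..10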
- */ - -package spark.rdd - -import spark.{NarrowDependency, RDD, SparkEnv, Partition, TaskContext} - - -class PartitionPruningRDDPartition(idx: Int, val parentSplit: Partition) extends Partition { - override val index = idx -} - - -/** - * Represents a dependency between the PartitionPruningRDD and its parent. In this - * case, the child RDD contains a subset of the parent's partitions. - */ -class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean) - extends NarrowDependency[T](rdd) { - - @transient - val partitions: Array[Partition] = rdd.partitions.zipWithIndex - .filter(s => partitionFilterFunc(s._2)) - .map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition } - - override def getParents(partitionId: Int) = List(partitions(partitionId).index) -} - - -/** - * An RDD used to prune RDD partitions so we can avoid launching tasks on - * all partitions. An example use case: if we know the RDD is partitioned by range, - * and the execution DAG has a filter on the key, we can avoid launching tasks - * on partitions that don't have the range covering the key. - */ -class PartitionPruningRDD[T: ClassManifest]( - @transient prev: RDD[T], - @transient partitionFilterFunc: Int => Boolean) - extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { - - override def compute(split: Partition, context: TaskContext) = firstParent[T].iterator( - split.asInstanceOf[PartitionPruningRDDPartition].parentSplit, context) - - override protected def getPartitions: Array[Partition] = - getDependencies.head.asInstanceOf[PruneDependency[T]].partitions -} - - -object PartitionPruningRDD { - - /** - * Create a PartitionPruningRDD. This function can be used to create the PartitionPruningRDD - * when its type T is not known at compile time. - */ - def create[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean) = { - new PartitionPruningRDD[T](rdd, partitionFilterFunc)(rdd.elementClassManifest) - } -} diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala deleted file mode 100644 index 2cefdc78b0..0000000000 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import java.io.PrintWriter -import java.util.StringTokenizer - -import scala.collection.Map -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer -import scala.io.Source - -import spark.{RDD, SparkEnv, Partition, TaskContext} -import spark.broadcast.Broadcast - - -/** - * An RDD that pipes the contents of each parent partition through an external command - * (printing them one per line) and returns the output as a collection of strings.
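A usage sketch for the piping described above, assuming sc and a POSIX environment providing tr: each element is printed to the subprocess on its own line, and each line of subprocess output becomes one element of the result.

    val upper = sc.parallelize(Seq("hello", "world"), 2).pipe("tr a-z A-Z")
    println(upper.collect().mkString(", ")) // HELLO, WORLD (one tr process per partition)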
- */ -class PipedRDD[T: ClassManifest]( - prev: RDD[T], - command: Seq[String], - envVars: Map[String, String], - printPipeContext: (String => Unit) => Unit, - printRDDElement: (T, String => Unit) => Unit) - extends RDD[String](prev) { - - // Similar to Runtime.exec(), if we are given a single string, split it into words - // using a standard StringTokenizer (i.e. by spaces) - def this( - prev: RDD[T], - command: String, - envVars: Map[String, String] = Map(), - printPipeContext: (String => Unit) => Unit = null, - printRDDElement: (T, String => Unit) => Unit = null) = - this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement) - - - override def getPartitions: Array[Partition] = firstParent[T].partitions - - override def compute(split: Partition, context: TaskContext): Iterator[String] = { - val pb = new ProcessBuilder(command) - // Add the environment variables to the process. - val currentEnvVars = pb.environment() - envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) } - - val proc = pb.start() - val env = SparkEnv.get - - // Start a thread to print the process's stderr to ours - new Thread("stderr reader for " + command) { - override def run() { - for (line <- Source.fromInputStream(proc.getErrorStream).getLines) { - System.err.println(line) - } - } - }.start() - - // Start a thread to feed the process input from our parent's iterator - new Thread("stdin writer for " + command) { - override def run() { - SparkEnv.set(env) - val out = new PrintWriter(proc.getOutputStream) - - // Print the pipe context first, if any - if (printPipeContext != null) { - printPipeContext(out.println(_)) - } - for (elem <- firstParent[T].iterator(split, context)) { - if (printRDDElement != null) { - printRDDElement(elem, out.println(_)) - } else { - out.println(elem) - } - } - out.close() - } - }.start() - - // Return an iterator that reads lines from the process's stdout - val lines = Source.fromInputStream(proc.getInputStream).getLines - return new Iterator[String] { - def next() = lines.next() - def hasNext = { - if (lines.hasNext) { - true - } else { - val exitStatus = proc.waitFor() - if (exitStatus != 0) { - throw new Exception("Subprocess exited with status " + exitStatus) - } - false - } - } - } - } -} - -object PipedRDD { - // Split a string into words using a standard StringTokenizer - def tokenize(command: String): Seq[String] = { - val buf = new ArrayBuffer[String] - val tok = new StringTokenizer(command) - while(tok.hasMoreElements) { - buf += tok.nextToken() - } - buf - } -} diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala deleted file mode 100644 index 574c9b141d..0000000000 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
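The SampledRDD being deleted next backs RDD.sample. A sketch, assuming sc (the fraction and seed are arbitrary): with replacement each element contributes a Poisson(frac)-distributed number of copies, without replacement each element survives with probability frac.

    val data = sc.parallelize(1 to 10000, 8)
    val sampled = data.sample(false, 0.01, 42) // withReplacement = false, frac = 0.01, seed = 42
    println(sampled.count()) // roughly 100, varying with the seed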
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import java.util.Random - -import cern.jet.random.Poisson -import cern.jet.random.engine.DRand - -import spark.{RDD, Partition, TaskContext} - -private[spark] -class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable { - override val index: Int = prev.index -} - -class SampledRDD[T: ClassManifest]( - prev: RDD[T], - withReplacement: Boolean, - frac: Double, - seed: Int) - extends RDD[T](prev) { - - override def getPartitions: Array[Partition] = { - val rg = new Random(seed) - firstParent[T].partitions.map(x => new SampledRDDPartition(x, rg.nextInt)) - } - - override def getPreferredLocations(split: Partition): Seq[String] = - firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDPartition].prev) - - override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = { - val split = splitIn.asInstanceOf[SampledRDDPartition] - if (withReplacement) { - // For large datasets, the expected number of occurrences of each element in a sample with - // replacement is Poisson(frac). We use that to get a count for each element. - val poisson = new Poisson(frac, new DRand(split.seed)) - firstParent[T].iterator(split.prev, context).flatMap { element => - val count = poisson.nextInt() - if (count == 0) { - Iterator.empty // Avoid object allocation when we return 0 items, which is quite often - } else { - Iterator.fill(count)(element) - } - } - } else { // Sampling without replacement - val rand = new Random(split.seed) - firstParent[T].iterator(split.prev, context).filter(x => (rand.nextDouble <= frac)) - } - } -} diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala deleted file mode 100644 index 51c05af064..0000000000 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{Dependency, Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext} - - -private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { - override val index = idx - override def hashCode(): Int = idx -} - -/** - * The resulting RDD from a shuffle (e.g. repartitioning of data). - * @param prev the parent RDD. - * @param part the partitioner used to partition the RDD - * @tparam K the key class. - * @tparam V the value class. 
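A sketch of how the ShuffledRDD declared below typically comes into being, assuming sc and the SparkContext._ implicits: partitionBy with a partitioner different from the parent's inserts a shuffle with one output partition per bucket.

    import spark.SparkContext._
    import spark.HashPartitioner

    val pairs = sc.parallelize(Seq((1, "a"), (2, "b"), (1, "c")), 3)
    val shuffled = pairs.partitionBy(new HashPartitioner(4)) // builds a ShuffledRDD
    println(shuffled.partitions.size) // 4 = part.numPartitions, as in getPartitions below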
- */ -class ShuffledRDD[K, V, P <: Product2[K, V] : ClassManifest]( - @transient var prev: RDD[P], - part: Partitioner) - extends RDD[P](prev.context, Nil) { - - private var serializerClass: String = null - - def setSerializer(cls: String): ShuffledRDD[K, V, P] = { - serializerClass = cls - this - } - - override def getDependencies: Seq[Dependency[_]] = { - List(new ShuffleDependency(prev, part, serializerClass)) - } - - override val partitioner = Some(part) - - override def getPartitions: Array[Partition] = { - Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i)) - } - - override def compute(split: Partition, context: TaskContext): Iterator[P] = { - val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId - SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context.taskMetrics, - SparkEnv.get.serializerManager.get(serializerClass)) - } - - override def clearDependencies() { - super.clearDependencies() - prev = null - } -} diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala deleted file mode 100644 index dadef5e17d..0000000000 --- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer -import spark.RDD -import spark.Partitioner -import spark.Dependency -import spark.TaskContext -import spark.Partition -import spark.SparkEnv -import spark.ShuffleDependency -import spark.OneToOneDependency - - -/** - * An optimized version of cogroup for set difference/subtraction. - * - * It is possible to implement this operation with just `cogroup`, but - * that is less efficient because all of the entries from `rdd2`, for - * both matching and non-matching values in `rdd1`, are kept in the - * JHashMap until the end. - * - * With this implementation, only the entries from `rdd1` are kept in-memory, - * and the entries from `rdd2` are essentially streamed, as we only need to - * touch each once to decide if the value needs to be removed. - * - * This is particularly helpful when `rdd1` is much smaller than `rdd2`, as - * you can use `rdd1`'s partitioner/partition size and not worry about running - * out of memory because of the size of `rdd2`. 
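A usage sketch for the streaming subtraction described above, assuming sc and the SparkContext._ implicits:

    import spark.SparkContext._

    val rdd1 = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))
    val rdd2 = sc.parallelize(Seq(("b", 99)))
    // Only rdd1's entries are buffered; rdd2's are streamed past the map.
    println(rdd1.subtractByKey(rdd2).collect().mkString(", ")) // (a,1), (c,3)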
- */ -private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest]( - @transient var rdd1: RDD[_ <: Product2[K, V]], - @transient var rdd2: RDD[_ <: Product2[K, W]], - part: Partitioner) - extends RDD[(K, V)](rdd1.context, Nil) { - - private var serializerClass: String = null - - def setSerializer(cls: String): SubtractedRDD[K, V, W] = { - serializerClass = cls - this - } - - override def getDependencies: Seq[Dependency[_]] = { - Seq(rdd1, rdd2).map { rdd => - if (rdd.partitioner == Some(part)) { - logDebug("Adding one-to-one dependency with " + rdd) - new OneToOneDependency(rdd) - } else { - logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency(rdd, part, serializerClass) - } - } - } - - override def getPartitions: Array[Partition] = { - val array = new Array[Partition](part.numPartitions) - for (i <- 0 until array.size) { - // Each CoGroupPartition will depend on rdd1 and rdd2 - array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) => - dependencies(j) match { - case s: ShuffleDependency[_, _] => - new ShuffleCoGroupSplitDep(s.shuffleId) - case _ => - new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) - } - }.toArray) - } - array - } - - override val partitioner = Some(part) - - override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = { - val partition = p.asInstanceOf[CoGroupPartition] - val serializer = SparkEnv.get.serializerManager.get(serializerClass) - val map = new JHashMap[K, ArrayBuffer[V]] - def getSeq(k: K): ArrayBuffer[V] = { - val seq = map.get(k) - if (seq != null) { - seq - } else { - val seq = new ArrayBuffer[V]() - map.put(k, seq) - seq - } - } - def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit) = dep match { - case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { - rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op) - } - case ShuffleCoGroupSplitDep(shuffleId) => { - val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index, - context.taskMetrics, serializer) - iter.foreach(op) - } - } - // the first dep is rdd1; add all values to the map - integrate(partition.deps(0), t => getSeq(t._1) += t._2) - // the second dep is rdd2; remove all of its keys - integrate(partition.deps(1), t => map.remove(t._1)) - map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten - } - - override def clearDependencies() { - super.clearDependencies() - rdd1 = null - rdd2 = null - } - -} diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala deleted file mode 100644 index 2776826f18..0000000000 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
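The UnionRDD deleted next implements RDD.union by concatenating partitions, so no shuffle occurs and the result has the sum of the parents' partition counts. A sketch, assuming sc:

    val a = sc.parallelize(1 to 3, 2)
    val b = sc.parallelize(4 to 6, 3)
    println(a.union(b).partitions.size) // 5 = 2 + 3, matching getPartitions below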
- */ - -package spark.rdd - -import scala.collection.mutable.ArrayBuffer -import spark.{Dependency, RangeDependency, RDD, SparkContext, Partition, TaskContext} -import java.io.{ObjectOutputStream, IOException} - -private[spark] class UnionPartition[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int) - extends Partition { - - var split: Partition = rdd.partitions(splitIndex) - - def iterator(context: TaskContext) = rdd.iterator(split, context) - - def preferredLocations() = rdd.preferredLocations(split) - - override val index: Int = idx - - @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { - // Update the reference to parent split at the time of task serialization - split = rdd.partitions(splitIndex) - oos.defaultWriteObject() - } -} - -class UnionRDD[T: ClassManifest]( - sc: SparkContext, - @transient var rdds: Seq[RDD[T]]) - extends RDD[T](sc, Nil) { // Nil since we implement getDependencies - - override def getPartitions: Array[Partition] = { - val array = new Array[Partition](rdds.map(_.partitions.size).sum) - var pos = 0 - for (rdd <- rdds; split <- rdd.partitions) { - array(pos) = new UnionPartition(pos, rdd, split.index) - pos += 1 - } - array - } - - override def getDependencies: Seq[Dependency[_]] = { - val deps = new ArrayBuffer[Dependency[_]] - var pos = 0 - for (rdd <- rdds) { - deps += new RangeDependency(rdd, 0, pos, rdd.partitions.size) - pos += rdd.partitions.size - } - deps - } - - override def compute(s: Partition, context: TaskContext): Iterator[T] = - s.asInstanceOf[UnionPartition[T]].iterator(context) - - override def getPreferredLocations(s: Partition): Seq[String] = - s.asInstanceOf[UnionPartition[T]].preferredLocations() -} diff --git a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala deleted file mode 100644 index 9a0831bd89..0000000000 --- a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
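The ZippedPartitionsRDD family deleted next backs RDD.zipPartitions, which pairs up same-numbered partitions and applies f per partition without materializing either side. A sketch assuming sc; the exact zipPartitions signature has shifted across releases (early versions took the function before the RDD arguments, later ones curried it), so treat the call shape as illustrative:

    val xs = sc.parallelize(1 to 4, 2)
    val ys = sc.parallelize(Seq("a", "b", "c", "d"), 2) // partition counts must match
    val zipped = xs.zipPartitions((ix: Iterator[Int], iy: Iterator[String]) => ix.zip(iy), ys)
    println(zipped.collect().mkString(", ")) // (1,a), (2,b), (3,c), (4,d)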
- */ - -package spark.rdd - -import spark.{Utils, OneToOneDependency, RDD, SparkContext, Partition, TaskContext} -import java.io.{ObjectOutputStream, IOException} - -private[spark] class ZippedPartitionsPartition( - idx: Int, - @transient rdds: Seq[RDD[_]]) - extends Partition { - - override val index: Int = idx - var partitionValues = rdds.map(rdd => rdd.partitions(idx)) - def partitions = partitionValues - - @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { - // Update the reference to parent split at the time of task serialization - partitionValues = rdds.map(rdd => rdd.partitions(idx)) - oos.defaultWriteObject() - } -} - -abstract class ZippedPartitionsBaseRDD[V: ClassManifest]( - sc: SparkContext, - var rdds: Seq[RDD[_]]) - extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) { - - override def getPartitions: Array[Partition] = { - val sizes = rdds.map(x => x.partitions.size) - if (!sizes.forall(x => x == sizes(0))) { - throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") - } - val array = new Array[Partition](sizes(0)) - for (i <- 0 until sizes(0)) { - array(i) = new ZippedPartitionsPartition(i, rdds) - } - array - } - - override def getPreferredLocations(s: Partition): Seq[String] = { - val parts = s.asInstanceOf[ZippedPartitionsPartition].partitions - val prefs = rdds.zip(parts).map { case (rdd, p) => rdd.preferredLocations(p) } - // Check whether there are any hosts that match all RDDs; otherwise return the union - val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y)) - if (!exactMatchLocations.isEmpty) { - exactMatchLocations - } else { - prefs.flatten.distinct - } - } - - override def clearDependencies() { - super.clearDependencies() - rdds = null - } -} - -class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest]( - sc: SparkContext, - f: (Iterator[A], Iterator[B]) => Iterator[V], - var rdd1: RDD[A], - var rdd2: RDD[B]) - extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) { - - override def compute(s: Partition, context: TaskContext): Iterator[V] = { - val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions - f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context)) - } - - override def clearDependencies() { - super.clearDependencies() - rdd1 = null - rdd2 = null - } -} - -class ZippedPartitionsRDD3 - [A: ClassManifest, B: ClassManifest, C: ClassManifest, V: ClassManifest]( - sc: SparkContext, - f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V], - var rdd1: RDD[A], - var rdd2: RDD[B], - var rdd3: RDD[C]) - extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) { - - override def compute(s: Partition, context: TaskContext): Iterator[V] = { - val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions - f(rdd1.iterator(partitions(0), context), - rdd2.iterator(partitions(1), context), - rdd3.iterator(partitions(2), context)) - } - - override def clearDependencies() { - super.clearDependencies() - rdd1 = null - rdd2 = null - rdd3 = null - } -} - -class ZippedPartitionsRDD4 - [A: ClassManifest, B: ClassManifest, C: ClassManifest, D:ClassManifest, V: ClassManifest]( - sc: SparkContext, - f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], - var rdd1: RDD[A], - var rdd2: RDD[B], - var rdd3: RDD[C], - var rdd4: RDD[D]) - extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) { - - override def compute(s: Partition, context: TaskContext): Iterator[V] = { - val partitions = 
s.asInstanceOf[ZippedPartitionsPartition].partitions - f(rdd1.iterator(partitions(0), context), - rdd2.iterator(partitions(1), context), - rdd3.iterator(partitions(2), context), - rdd4.iterator(partitions(3), context)) - } - - override def clearDependencies() { - super.clearDependencies() - rdd1 = null - rdd2 = null - rdd3 = null - rdd4 = null - } -} diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala deleted file mode 100644 index 4074e50e44..0000000000 --- a/core/src/main/scala/spark/rdd/ZippedRDD.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.rdd - -import spark.{Utils, OneToOneDependency, RDD, SparkContext, Partition, TaskContext} -import java.io.{ObjectOutputStream, IOException} - - -private[spark] class ZippedPartition[T: ClassManifest, U: ClassManifest]( - idx: Int, - @transient rdd1: RDD[T], - @transient rdd2: RDD[U] - ) extends Partition { - - var partition1 = rdd1.partitions(idx) - var partition2 = rdd2.partitions(idx) - override val index: Int = idx - - def partitions = (partition1, partition2) - - @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { - // Update the reference to parent partition at the time of task serialization - partition1 = rdd1.partitions(idx) - partition2 = rdd2.partitions(idx) - oos.defaultWriteObject() - } -} - -class ZippedRDD[T: ClassManifest, U: ClassManifest]( - sc: SparkContext, - var rdd1: RDD[T], - var rdd2: RDD[U]) - extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))) { - - override def getPartitions: Array[Partition] = { - if (rdd1.partitions.size != rdd2.partitions.size) { - throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") - } - val array = new Array[Partition](rdd1.partitions.size) - for (i <- 0 until rdd1.partitions.size) { - array(i) = new ZippedPartition(i, rdd1, rdd2) - } - array - } - - override def compute(s: Partition, context: TaskContext): Iterator[(T, U)] = { - val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions - rdd1.iterator(partition1, context).zip(rdd2.iterator(partition2, context)) - } - - override def getPreferredLocations(s: Partition): Seq[String] = { - val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions - val pref1 = rdd1.preferredLocations(partition1) - val pref2 = rdd2.preferredLocations(partition2) - // Check whether there are any hosts that match both RDDs; otherwise return the union - val exactMatchLocations = pref1.intersect(pref2) - if (!exactMatchLocations.isEmpty) { - exactMatchLocations - } else { - (pref1 ++ pref2).distinct - } - } - - override def clearDependencies() { - super.clearDependencies() 
- rdd1 = null - rdd2 = null - } -} diff --git a/core/src/main/scala/spark/scheduler/ActiveJob.scala b/core/src/main/scala/spark/scheduler/ActiveJob.scala deleted file mode 100644 index fecc3e9648..0000000000 --- a/core/src/main/scala/spark/scheduler/ActiveJob.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import spark.TaskContext - -import java.util.Properties - -/** - * Tracks information about an active job in the DAGScheduler. - */ -private[spark] class ActiveJob( - val jobId: Int, - val finalStage: Stage, - val func: (TaskContext, Iterator[_]) => _, - val partitions: Array[Int], - val callSite: String, - val listener: JobListener, - val properties: Properties) { - - val numPartitions = partitions.length - val finished = Array.fill[Boolean](numPartitions)(false) - var numFinished = 0 -} diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala deleted file mode 100644 index 7275bd346a..0000000000 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ /dev/null @@ -1,849 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import java.io.NotSerializableException -import java.util.Properties -import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} - -import spark._ -import spark.executor.TaskMetrics -import spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} -import spark.scheduler.cluster.TaskInfo -import spark.storage.{BlockManager, BlockManagerMaster} -import spark.util.{MetadataCleaner, TimeStampedHashMap} - -/** - * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of - * stages for each job, keeps track of which RDDs and stage outputs are materialized, and finds a - * minimal schedule to run the job. 
It then submits stages as TaskSets to an underlying - * TaskScheduler implementation that runs them on the cluster. - * - * In addition to coming up with a DAG of stages, this class also determines the preferred - * locations to run each task on, based on the current cache status, and passes these to the - * low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being - * lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are - * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task - * a small number of times before cancelling the whole stage. - * - * THREADING: This class runs all its logic in a single thread executing the run() method, to which - * events are submitted using a synchronized queue (eventQueue). The public API methods, such as - * runJob, taskEnded and executorLost, post events asynchronously to this queue. All other methods - * should be private. - */ -private[spark] -class DAGScheduler( - taskSched: TaskScheduler, - mapOutputTracker: MapOutputTracker, - blockManagerMaster: BlockManagerMaster, - env: SparkEnv) - extends TaskSchedulerListener with Logging { - - def this(taskSched: TaskScheduler) { - this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get) - } - taskSched.setListener(this) - - // Called by TaskScheduler to report that a task has started. - override def taskStarted(task: Task[_], taskInfo: TaskInfo) { - eventQueue.put(BeginEvent(task, taskInfo)) - } - - // Called by TaskScheduler to report task completions or failures. - override def taskEnded( - task: Task[_], - reason: TaskEndReason, - result: Any, - accumUpdates: Map[Long, Any], - taskInfo: TaskInfo, - taskMetrics: TaskMetrics) { - eventQueue.put(CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)) - } - - // Called by TaskScheduler when an executor fails. - override def executorLost(execId: String) { - eventQueue.put(ExecutorLost(execId)) - } - - // Called by TaskScheduler when a host is added. - override def executorGained(execId: String, host: String) { - eventQueue.put(ExecutorGained(execId, host)) - } - - // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures. - override def taskSetFailed(taskSet: TaskSet, reason: String) { - eventQueue.put(TaskSetFailed(taskSet, reason)) - } - - // The time, in millis, to wait for fetch failure events to stop coming in after one is detected; - // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one - // as more failure events come in - val RESUBMIT_TIMEOUT = 50L - - // The time, in millis, to wake up between polls of the completion queue in order to potentially - // resubmit failed stages - val POLL_TIMEOUT = 10L - - private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent] - - val nextJobId = new AtomicInteger(0) - - val nextStageId = new AtomicInteger(0) - - val stageIdToStage = new TimeStampedHashMap[Int, Stage] - - val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] - - private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo] - - private val listenerBus = new SparkListenerBus() - - // Contains the locations that each RDD's partitions are cached on - private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]] - - // For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with - // every task.
When we detect a node failing, we note the current epoch number and failed - // executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask results. - // - // TODO: Garbage collect information about failure epochs when we know there are no more - // stray messages to detect. - val failedEpoch = new HashMap[String, Long] - - val idToActiveJob = new HashMap[Int, ActiveJob] - - val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done - val running = new HashSet[Stage] // Stages we are running right now - val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures - val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage - var lastFetchFailureTime: Long = 0 // Used to wait a bit to avoid repeated resubmits - - val activeJobs = new HashSet[ActiveJob] - val resultStageToJob = new HashMap[Stage, ActiveJob] - - val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) - - // Start a thread to run the DAGScheduler event loop - def start() { - new Thread("DAGScheduler") { - setDaemon(true) - override def run() { - DAGScheduler.this.run() - } - }.start() - } - - def addSparkListener(listener: SparkListener) { - listenerBus.addListener(listener) - } - - private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = { - if (!cacheLocs.contains(rdd.id)) { - val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray - val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster) - cacheLocs(rdd.id) = blockIds.map { id => - locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId)) - } - } - cacheLocs(rdd.id) - } - - private def clearCacheLocs() { - cacheLocs.clear() - } - - /** - * Get or create a shuffle map stage for the given shuffle dependency's map side. - * The jobId value passed in will be used if the stage doesn't already exist with - * a lower jobId (jobId always increases across jobs.) - */ - private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], jobId: Int): Stage = { - shuffleToMapStage.get(shuffleDep.shuffleId) match { - case Some(stage) => stage - case None => - val stage = newStage(shuffleDep.rdd, Some(shuffleDep), jobId) - shuffleToMapStage(shuffleDep.shuffleId) = stage - stage - } - } - - /** - * Create a Stage for the given RDD, either as a shuffle map stage (for a ShuffleDependency) or - * as a result stage for the final RDD used directly in an action. The stage will also be - * associated with the provided jobId. - */ - private def newStage( - rdd: RDD[_], - shuffleDep: Option[ShuffleDependency[_,_]], - jobId: Int, - callSite: Option[String] = None) - : Stage = - { - if (shuffleDep != None) { - // Kind of ugly: need to register RDDs with the cache and map output tracker here - // since we can't do it in the RDD constructor because # of partitions is unknown - logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") - mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size) - } - val id = nextStageId.getAndIncrement() - val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, jobId), jobId, callSite) - stageIdToStage(id) = stage - stageToInfos(stage) = StageInfo(stage) - stage - } - - /** - * Get or create the list of parent stages for a given RDD. The stages will be assigned the - * provided jobId if they haven't already been created with a lower jobId. 
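A sketch of what getParentStages below walks over, assuming sc and the SparkContext._ implicits: narrow dependencies are folded into the current stage, while each ShuffleDependency becomes a separate parent map stage.

    import spark.SparkContext._

    val words = sc.parallelize(Seq("a", "b", "a"), 2)
    val counts = words.map(w => (w, 1))   // narrow dependency: stays in the same stage
                      .reduceByKey(_ + _) // ShuffleDependency: a parent map stage is created
    counts.collect() // runs two stages: the shuffle's map side, then the result stage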
- */ - private def getParentStages(rdd: RDD[_], jobId: Int): List[Stage] = { - val parents = new HashSet[Stage] - val visited = new HashSet[RDD[_]] - def visit(r: RDD[_]) { - if (!visited(r)) { - visited += r - // Kind of ugly: need to register RDDs with the cache here since - // we can't do it in its constructor because # of partitions is unknown - for (dep <- r.dependencies) { - dep match { - case shufDep: ShuffleDependency[_,_] => - parents += getShuffleMapStage(shufDep, jobId) - case _ => - visit(dep.rdd) - } - } - } - } - visit(rdd) - parents.toList - } - - private def getMissingParentStages(stage: Stage): List[Stage] = { - val missing = new HashSet[Stage] - val visited = new HashSet[RDD[_]] - def visit(rdd: RDD[_]) { - if (!visited(rdd)) { - visited += rdd - if (getCacheLocs(rdd).contains(Nil)) { - for (dep <- rdd.dependencies) { - dep match { - case shufDep: ShuffleDependency[_,_] => - val mapStage = getShuffleMapStage(shufDep, stage.jobId) - if (!mapStage.isAvailable) { - missing += mapStage - } - case narrowDep: NarrowDependency[_] => - visit(narrowDep.rdd) - } - } - } - } - } - visit(stage.rdd) - missing.toList - } - - /** - * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a - * JobWaiter whose getResult() method will return the result of the job when it is complete. - * - * The job is assumed to have at least one partition; zero partition jobs should be handled - * without a JobSubmitted event. - */ - private[scheduler] def prepareJob[T, U: ClassManifest]( - finalRdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - partitions: Seq[Int], - callSite: String, - allowLocal: Boolean, - resultHandler: (Int, U) => Unit, - properties: Properties = null) - : (JobSubmitted, JobWaiter[U]) = - { - assert(partitions.size > 0) - val waiter = new JobWaiter(partitions.size, resultHandler) - val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter, - properties) - (toSubmit, waiter) - } - - def runJob[T, U: ClassManifest]( - finalRdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - partitions: Seq[Int], - callSite: String, - allowLocal: Boolean, - resultHandler: (Int, U) => Unit, - properties: Properties = null) - { - if (partitions.size == 0) { - return - } - - // Check to make sure we are not launching a task on a partition that does not exist. - val maxPartitions = finalRdd.partitions.length - partitions.find(p => p >= maxPartitions).foreach { p => - throw new IllegalArgumentException( - "Attempting to access a non-existent partition: " + p + ". 
" + - "Total number of partitions: " + maxPartitions) - } - - val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob( - finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties) - eventQueue.put(toSubmit) - waiter.awaitResult() match { - case JobSucceeded => {} - case JobFailed(exception: Exception, _) => - logInfo("Failed to run " + callSite) - throw exception - } - } - - def runApproximateJob[T, U, R]( - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - evaluator: ApproximateEvaluator[U, R], - callSite: String, - timeout: Long, - properties: Properties = null) - : PartialResult[R] = - { - val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) - val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - val partitions = (0 until rdd.partitions.size).toArray - eventQueue.put(JobSubmitted(rdd, func2, partitions, allowLocal = false, callSite, listener, properties)) - listener.awaitResult() // Will throw an exception if the job fails - } - - /** - * Process one event retrieved from the event queue. - * Returns true if we should stop the event loop. - */ - private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = { - event match { - case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) => - val jobId = nextJobId.getAndIncrement() - val finalStage = newStage(finalRDD, None, jobId, Some(callSite)) - val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) - clearCacheLocs() - logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length + - " output partitions (allowLocal=" + allowLocal + ")") - logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")") - logInfo("Parents of final stage: " + finalStage.parents) - logInfo("Missing parents: " + getMissingParentStages(finalStage)) - if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) { - // Compute very short actions like first() or take() with no parent stages locally. - runLocally(job) - } else { - listenerBus.post(SparkListenerJobStart(job, properties)) - idToActiveJob(jobId) = job - activeJobs += job - resultStageToJob(finalStage) = job - submitStage(finalStage) - } - - case ExecutorGained(execId, host) => - handleExecutorGained(execId, host) - - case ExecutorLost(execId) => - handleExecutorLost(execId) - - case begin: BeginEvent => - listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo)) - - case completion: CompletionEvent => - listenerBus.post(SparkListenerTaskEnd( - completion.task, completion.reason, completion.taskInfo, completion.taskMetrics)) - handleTaskCompletion(completion) - - case TaskSetFailed(taskSet, reason) => - abortStage(stageIdToStage(taskSet.stageId), reason) - - case StopDAGScheduler => - // Cancel any active jobs - for (job <- activeJobs) { - val error = new SparkException("Job cancelled because SparkContext was shut down") - job.listener.jobFailed(error) - listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, None))) - } - return true - } - false - } - - /** - * Resubmit any failed stages. Ordinarily called after a small amount of time has passed since - * the last fetch failure. - */ - private[scheduler] def resubmitFailedStages() { - logInfo("Resubmitting failed stages") - clearCacheLocs() - val failed2 = failed.toArray - failed.clear() - for (stage <- failed2.sortBy(_.jobId)) { - submitStage(stage) - } - } - - /** - * Check for waiting or failed stages which are now eligible for resubmission. 
- * Ordinarily run on every iteration of the event loop.
-   */
-  private[scheduler] def submitWaitingStages() {
-    // TODO: We might want to run this less often, when we are sure that something has become
-    // runnable that wasn't before.
-    logTrace("Checking for newly runnable parent stages")
-    logTrace("running: " + running)
-    logTrace("waiting: " + waiting)
-    logTrace("failed: " + failed)
-    val waiting2 = waiting.toArray
-    waiting.clear()
-    for (stage <- waiting2.sortBy(_.jobId)) {
-      submitStage(stage)
-    }
-  }
-
-  /**
-   * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
-   * events and responds by launching tasks. This runs in a dedicated thread and receives events
-   * via the eventQueue.
-   */
-  private def run() {
-    SparkEnv.set(env)
-
-    while (true) {
-      val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
-      if (event != null) {
-        logDebug("Got event of type " + event.getClass.getName)
-      }
-      this.synchronized { // needed in case other threads make calls into methods of this class
-        if (event != null) {
-          if (processEvent(event)) {
-            return
-          }
-        }
-
-        val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability (see the sketch below submitStage)
-        // Periodically resubmit failed stages if some map output fetches have failed and we have
-        // waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
-        // tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
-        // the same time, so we want to make sure we've identified all the reduce tasks that depend
-        // on the failed node.
-        if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
-          resubmitFailedStages()
-        } else {
-          submitWaitingStages()
-        }
-      }
-    }
-  }
-
-  /**
-   * Run a job on an RDD locally, assuming it has only a single partition and no dependencies.
-   * We run the operation in a separate thread in case it takes a long time, so that we
-   * don't block the DAGScheduler event loop or other concurrent jobs.
-   */
-  protected def runLocally(job: ActiveJob) {
-    logInfo("Computing the requested partition locally")
-    new Thread("Local computation of job " + job.jobId) {
-      override def run() {
-        runLocallyWithinThread(job)
-      }
-    }.start()
-  }
-
-  // Broken out for easier testing in DAGSchedulerSuite.
-  protected def runLocallyWithinThread(job: ActiveJob) {
-    try {
-      SparkEnv.set(env)
-      val rdd = job.finalStage.rdd
-      val split = rdd.partitions(job.partitions(0))
-      val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
-      try {
-        val result = job.func(taskContext, rdd.iterator(split, taskContext))
-        job.listener.taskSucceeded(0, result)
-      } finally {
-        taskContext.executeOnCompleteCallbacks()
-      }
-    } catch {
-      case e: Exception =>
-        job.listener.jobFailed(e)
-    }
-  }
-
-  /** Submits a stage, but first recursively submits any missing parents. */
-  private def submitStage(stage: Stage) {
-    logDebug("submitStage(" + stage + ")")
-    if (!waiting(stage) && !running(stage) && !failed(stage)) {
-      val missing = getMissingParentStages(stage).sortBy(_.id)
-      logDebug("missing: " + missing)
-      if (missing == Nil) {
-        logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
-        submitMissingTasks(stage)
-        running += stage
-      } else {
-        for (parent <- missing) {
-          submitStage(parent)
-        }
-        waiting += stage
-      }
-    }
-  }
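The TODO in run() above asks for a pluggable clock. A minimal sketch of that idea, with hypothetical names that are not part of this patch, showing how the RESUBMIT_TIMEOUT comparison could be made deterministic in tests:

    trait Clock {
      def currentTime: Long
    }

    // Production implementation: delegate to the system clock.
    object SystemClock extends Clock {
      def currentTime: Long = System.currentTimeMillis()
    }

    // Test implementation: time only advances when the test says so.
    class ManualClock(private var now: Long = 0L) extends Clock {
      def currentTime: Long = now
      def advance(ms: Long): Unit = { now += ms }
    }

    // A scheduler taking a Clock could then be driven deterministically, e.g.:
    //   val clock = new ManualClock()
    //   clock.advance(51L)  // exceeds RESUBMIT_TIMEOUT, forcing the resubmit branch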
-  /** Called when a stage's parents are available and we can now run its tasks. */
-  private def submitMissingTasks(stage: Stage) {
-    logDebug("submitMissingTasks(" + stage + ")")
-    // Get our pending tasks and remember them in our pendingTasks entry
-    val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
-    myPending.clear()
-    var tasks = ArrayBuffer[Task[_]]()
-    if (stage.isShuffleMap) {
-      for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
-        val locs = getPreferredLocs(stage.rdd, p)
-        tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
-      }
-    } else {
-      // This is a final stage; figure out its job's missing partitions
-      val job = resultStageToJob(stage)
-      for (id <- 0 until job.numPartitions if !job.finished(id)) {
-        val partition = job.partitions(id)
-        val locs = getPreferredLocs(stage.rdd, partition)
-        tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id)
-      }
-    }
-    // The listener event must be posted before a possible NotSerializableException below,
-    // so that listeners see "StageSubmitted" first and then "JobEnded".
-    val properties = idToActiveJob(stage.jobId).properties
-    listenerBus.post(SparkListenerStageSubmitted(stage, tasks.size, properties))
-
-    if (tasks.size > 0) {
-      // Preemptively serialize a task to make sure it can be serialized. We are catching this
-      // exception here because it would be fairly hard to catch the non-serializable exception
-      // down the road, where we have several different implementations for local scheduler and
-      // cluster schedulers.
-      try {
-        SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head)
-      } catch {
-        case e: NotSerializableException =>
-          abortStage(stage, e.toString)
-          running -= stage
-          return
-      }
-
-      logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
-      myPending ++= tasks
-      logDebug("New pending tasks: " + myPending)
-      taskSched.submitTasks(
-        new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
-      if (!stage.submissionTime.isDefined) {
-        stage.submissionTime = Some(System.currentTimeMillis())
-      }
-    } else {
-      logDebug("Stage " + stage + " is actually done; %b %d %d".format(
-        stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
-      running -= stage
-    }
-  }
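The preemptive serialization above fails a stage at submission time rather than after doomed tasks ship to executors. The same check, sketched with plain JDK serialization standing in for Spark's pluggable closure serializer; ensureSerializable is a hypothetical helper, not a Spark API:

    import java.io.{ByteArrayOutputStream, ObjectOutputStream}

    def ensureSerializable(task: AnyRef): Unit = {
      val out = new ObjectOutputStream(new ByteArrayOutputStream())
      try {
        // A NotSerializableException here means the task (or something its closure
        // captured) cannot be serialized; it surfaces now, at submission time.
        out.writeObject(task)
      } finally {
        out.close()
      }
    }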
-  /**
-   * Responds to a task finishing. This is called inside the event loop, so it assumes that it can
-   * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
-   */
-  private def handleTaskCompletion(event: CompletionEvent) {
-    val task = event.task
-    val stage = stageIdToStage(task.stageId)
-
-    def markStageAsFinished(stage: Stage) = {
-      val serviceTime = stage.submissionTime match {
-        case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0)
-        case _ => "Unknown"
-      }
-      logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
-      stage.completionTime = Some(System.currentTimeMillis)
-      listenerBus.post(StageCompleted(stageToInfos(stage)))
-      running -= stage
-    }
-    event.reason match {
-      case Success =>
-        logInfo("Completed " + task)
-        if (event.accumUpdates != null) {
-          Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
-        }
-        pendingTasks(stage) -= task
-        stageToInfos(stage).taskInfos += event.taskInfo -> event.taskMetrics
-        task match {
-          case rt: ResultTask[_, _] =>
-            resultStageToJob.get(stage) match {
-              case Some(job) =>
-                if (!job.finished(rt.outputId)) {
-                  job.finished(rt.outputId) = true
-                  job.numFinished += 1
-                  // If the whole job has finished, remove it
-                  if (job.numFinished == job.numPartitions) {
-                    idToActiveJob -= stage.jobId
-                    activeJobs -= job
-                    resultStageToJob -= stage
-                    markStageAsFinished(stage)
-                    listenerBus.post(SparkListenerJobEnd(job, JobSucceeded))
-                  }
-                  job.listener.taskSucceeded(rt.outputId, event.result)
-                }
-              case None =>
-                logInfo("Ignoring result from " + rt + " because its job has finished")
-            }
-
-          case smt: ShuffleMapTask =>
-            val status = event.result.asInstanceOf[MapStatus]
-            val execId = status.location.executorId
-            logDebug("ShuffleMapTask finished on " + execId)
-            if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
-              logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
-            } else {
-              stage.addOutputLoc(smt.partition, status)
-            }
-            if (running.contains(stage) && pendingTasks(stage).isEmpty) {
-              markStageAsFinished(stage)
-              logInfo("Looking for newly runnable stages")
-              logInfo("running: " + running)
-              logInfo("waiting: " + waiting)
-              logInfo("failed: " + failed)
-              if (stage.shuffleDep != None) {
-                // We supply true to increment the epoch number here in case this is a
-                // recomputation of the map outputs. In that case, some nodes may have cached
-                // locations with holes (from when we detected the error) and will need the
-                // epoch incremented to refetch them.
-                // TODO: Only increment the epoch number if this is not the first time
-                // we registered these map outputs.
-                mapOutputTracker.registerMapOutputs(
-                  stage.shuffleDep.get.shuffleId,
-                  stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray,
-                  changeEpoch = true)
-              }
-              clearCacheLocs()
-              if (stage.outputLocs.count(_ == Nil) != 0) {
-                // Some tasks had failed; let's resubmit this stage
-                // TODO: Lower-level scheduler should also deal with this
-                logInfo("Resubmitting " + stage + " (" + stage.name +
-                  ") because some of its tasks had failed: " +
-                  stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", "))
-                submitStage(stage)
-              } else {
-                val newlyRunnable = new ArrayBuffer[Stage]
-                for (stage <- waiting) {
-                  logInfo("Missing parents for " + stage + ": " + getMissingParentStages(stage))
-                }
-                for (stage <- waiting if getMissingParentStages(stage) == Nil) {
-                  newlyRunnable += stage
-                }
-                waiting --= newlyRunnable
-                running ++= newlyRunnable
-                for (stage <- newlyRunnable.sortBy(_.id)) {
-                  logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable")
-                  submitMissingTasks(stage)
-                }
-              }
-            }
-        }
-
-      case Resubmitted =>
-        logInfo("Resubmitted " + task + ", so marking it as still running")
-        pendingTasks(stage) += task
-
-      case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
-        // Mark the stage that the reducer was in as unrunnable
-        val failedStage = stageIdToStage(task.stageId)
-        running -= failedStage
-        failed += failedStage
-        // TODO: Cancel running tasks in the stage
-        logInfo("Marking " + failedStage + " (" + failedStage.name +
-          ") for resubmission due to a fetch failure")
-        // Mark the map whose fetch failed as broken in the map stage
-        val mapStage = shuffleToMapStage(shuffleId)
-        if (mapId != -1) {
-          mapStage.removeOutputLoc(mapId, bmAddress)
-          mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
-        }
-        logInfo("The failed fetch was from " + mapStage + " (" + mapStage.name +
-          "); marking it for resubmission")
-        failed += mapStage
-        // Remember that a fetch failed now; this is used to resubmit the broken
-        // stages later, after a small wait (to give other tasks the chance to fail)
-        lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock
-        // TODO: mark the executor as failed only if there were lots of fetch failures on it
-        if (bmAddress != null) {
-          handleExecutorLost(bmAddress.executorId, Some(task.epoch))
-        }
-
-      case ExceptionFailure(className, description, stackTrace, metrics) =>
-        // Do nothing here; it is up to the TaskScheduler to decide how to handle user failures.
-
-      case other =>
-        // Unrecognized failure - abort all jobs depending on this stage
-        abortStage(stageIdToStage(task.stageId), task + " failed: " + other)
-    }
-  }
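The failedEpoch checks woven through handleTaskCompletion and handleExecutorLost implement one rule: a completion from an executor is ignored if the task was launched at or before the epoch at which that executor was marked failed. Distilled into a standalone sketch with hypothetical names, not part of this patch:

    import scala.collection.mutable

    class EpochFilter {
      private val failedEpoch = mutable.Map[String, Long]()
      private var currentEpoch = 0L

      // Note the epoch at which the executor failed; tasks launched from now on
      // carry a strictly larger epoch.
      def markLost(execId: String): Unit = {
        failedEpoch(execId) = currentEpoch
        currentEpoch += 1
      }

      // Accept a completion only if the task was launched after the failure was noted.
      def accept(execId: String, taskEpoch: Long): Boolean =
        failedEpoch.get(execId).forall(taskEpoch > _)
    }

    // e.g. val f = new EpochFilter; f.markLost("exec-1")
    //      f.accept("exec-1", taskEpoch = 0)  // false: a stray, pre-failure result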
-  /**
-   * Responds to an executor being lost. This is called inside the event loop, so it assumes it can
-   * modify the scheduler's internal state. Use executorLost() to post a loss event from outside.
-   *
-   * Optionally, the epoch during which the failure was caught can be passed to avoid allowing
-   * stray fetch failures from possibly retriggering the detection of a node as lost.
-   */
-  private def handleExecutorLost(execId: String, maybeEpoch: Option[Long] = None) {
-    val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch)
-    if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) {
-      failedEpoch(execId) = currentEpoch
-      logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch))
-      blockManagerMaster.removeExecutor(execId)
-      // TODO: This will be really slow if we keep accumulating shuffle map stages
-      for ((shuffleId, stage) <- shuffleToMapStage) {
-        stage.removeOutputsOnExecutor(execId)
-        val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
-        mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true)
-      }
-      if (shuffleToMapStage.isEmpty) {
-        mapOutputTracker.incrementEpoch()
-      }
-      clearCacheLocs()
-    } else {
-      logDebug("Additional executor lost message for " + execId +
-        " (epoch " + currentEpoch + ")")
-    }
-  }
-
-  private def handleExecutorGained(execId: String, host: String) {
-    // Remove the executor from failedEpoch, since it is alive again.
-    if (failedEpoch.contains(execId)) {
-      logInfo("Host gained that was earlier marked as lost: " + host)
-      failedEpoch -= execId
-    }
-  }
-
-  /**
-   * Aborts all jobs depending on a particular Stage. This is called in response to a task set
-   * being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
-   */
-  private def abortStage(failedStage: Stage, reason: String) {
-    val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
-    failedStage.completionTime = Some(System.currentTimeMillis())
-    for (resultStage <- dependentStages) {
-      val job = resultStageToJob(resultStage)
-      val error = new SparkException("Job failed: " + reason)
-      job.listener.jobFailed(error)
-      listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage))))
-      idToActiveJob -= resultStage.jobId
-      activeJobs -= job
-      resultStageToJob -= resultStage
-    }
-    if (dependentStages.isEmpty) {
-      logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
-    }
-  }
-
-  /**
-   * Returns true if one of the stage's ancestors is the target stage.
-   */
-  private def stageDependsOn(stage: Stage, target: Stage): Boolean = {
-    if (stage == target) {
-      return true
-    }
-    val visitedRdds = new HashSet[RDD[_]]
-    val visitedStages = new HashSet[Stage]
-    def visit(rdd: RDD[_]) {
-      if (!visitedRdds(rdd)) {
-        visitedRdds += rdd
-        for (dep <- rdd.dependencies) {
-          dep match {
-            case shufDep: ShuffleDependency[_,_] =>
-              val mapStage = getShuffleMapStage(shufDep, stage.jobId)
-              if (!mapStage.isAvailable) {
-                visitedStages += mapStage
-                visit(mapStage.rdd)
-              } // Otherwise there's no need to follow the dependency back
-            case narrowDep: NarrowDependency[_] =>
-              visit(narrowDep.rdd)
-          }
-        }
-      }
-    }
-    visit(stage.rdd)
-    visitedRdds.contains(target.rdd)
-  }
-
-  /**
-   * Synchronized method that might be called from other threads.
- * @param rdd whose partitions are to be looked at - * @param partition to lookup locality information for - * @return list of machines that are preferred by the partition - */ - private[spark] - def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = synchronized { - // If the partition is cached, return the cache locations - val cached = getCacheLocs(rdd)(partition) - if (!cached.isEmpty) { - return cached - } - // If the RDD has some placement preferences (as is the case for input RDDs), get those - val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList - if (!rddPrefs.isEmpty) { - return rddPrefs.map(host => TaskLocation(host)) - } - // If the RDD has narrow dependencies, pick the first partition of the first narrow dep - // that has any placement preferences. Ideally we would choose based on transfer sizes, - // but this will do for now. - rdd.dependencies.foreach(_ match { - case n: NarrowDependency[_] => - for (inPart <- n.getParents(partition)) { - val locs = getPreferredLocs(n.rdd, inPart) - if (locs != Nil) - return locs - } - case _ => - }) - Nil - } - - private def cleanup(cleanupTime: Long) { - var sizeBefore = stageIdToStage.size - stageIdToStage.clearOldValues(cleanupTime) - logInfo("stageIdToStage " + sizeBefore + " --> " + stageIdToStage.size) - - sizeBefore = shuffleToMapStage.size - shuffleToMapStage.clearOldValues(cleanupTime) - logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size) - - sizeBefore = pendingTasks.size - pendingTasks.clearOldValues(cleanupTime) - logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size) - - sizeBefore = stageToInfos.size - stageToInfos.clearOldValues(cleanupTime) - logInfo("stageToInfos " + sizeBefore + " --> " + stageToInfos.size) - } - - def stop() { - eventQueue.put(StopDAGScheduler) - metadataCleaner.cancel() - taskSched.stop() - } -} diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala deleted file mode 100644 index b8ba0e9239..0000000000 --- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import java.util.Properties - -import spark.scheduler.cluster.TaskInfo -import scala.collection.mutable.Map - -import spark._ -import spark.executor.TaskMetrics - -/** - * Types of events that can be handled by the DAGScheduler. The DAGScheduler uses an event queue - * architecture where any thread can post an event (e.g. a task finishing or a new job being - * submitted) but there is a single "logic" thread that reads these events and takes decisions. - * This greatly simplifies synchronization. 
- */ -private[spark] sealed trait DAGSchedulerEvent - -private[spark] case class JobSubmitted( - finalRDD: RDD[_], - func: (TaskContext, Iterator[_]) => _, - partitions: Array[Int], - allowLocal: Boolean, - callSite: String, - listener: JobListener, - properties: Properties = null) - extends DAGSchedulerEvent - -private[spark] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent - -private[spark] case class CompletionEvent( - task: Task[_], - reason: TaskEndReason, - result: Any, - accumUpdates: Map[Long, Any], - taskInfo: TaskInfo, - taskMetrics: TaskMetrics) - extends DAGSchedulerEvent - -private[spark] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent - -private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent - -private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent - -private[spark] case object StopDAGScheduler extends DAGSchedulerEvent diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerSource.scala deleted file mode 100644 index 98c4fb7e59..0000000000 --- a/core/src/main/scala/spark/scheduler/DAGSchedulerSource.scala +++ /dev/null @@ -1,30 +0,0 @@ -package spark.scheduler - -import com.codahale.metrics.{Gauge,MetricRegistry} - -import spark.metrics.source.Source - -private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "DAGScheduler" - - metricRegistry.register(MetricRegistry.name("stage", "failedStages", "number"), new Gauge[Int] { - override def getValue: Int = dagScheduler.failed.size - }) - - metricRegistry.register(MetricRegistry.name("stage", "runningStages", "number"), new Gauge[Int] { - override def getValue: Int = dagScheduler.running.size - }) - - metricRegistry.register(MetricRegistry.name("stage", "waitingStages", "number"), new Gauge[Int] { - override def getValue: Int = dagScheduler.waiting.size - }) - - metricRegistry.register(MetricRegistry.name("job", "allJobs", "number"), new Gauge[Int] { - override def getValue: Int = dagScheduler.nextJobId.get() - }) - - metricRegistry.register(MetricRegistry.name("job", "activeJobs", "number"), new Gauge[Int] { - override def getValue: Int = dagScheduler.activeJobs.size - }) -} diff --git a/core/src/main/scala/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/spark/scheduler/InputFormatInfo.scala deleted file mode 100644 index 8f1b9b29b5..0000000000 --- a/core/src/main/scala/spark/scheduler/InputFormatInfo.scala +++ /dev/null @@ -1,178 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.scheduler - -import spark.{Logging, SparkEnv} -import scala.collection.immutable.Set -import org.apache.hadoop.mapred.{FileInputFormat, JobConf} -import org.apache.hadoop.security.UserGroupInformation -import org.apache.hadoop.util.ReflectionUtils -import org.apache.hadoop.mapreduce.Job -import org.apache.hadoop.conf.Configuration -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.collection.JavaConversions._ - - -/** - * Parses and holds information about inputFormat (and files) specified as a parameter. - */ -class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Class[_], - val path: String) extends Logging { - - var mapreduceInputFormat: Boolean = false - var mapredInputFormat: Boolean = false - - validate() - - override def toString(): String = { - "InputFormatInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz + ", path : " + path - } - - override def hashCode(): Int = { - var hashCode = inputFormatClazz.hashCode - hashCode = hashCode * 31 + path.hashCode - hashCode - } - - // Since we are not doing canonicalization of path, this can be wrong : like relative vs absolute path - // .. which is fine, this is best case effort to remove duplicates - right ? - override def equals(other: Any): Boolean = other match { - case that: InputFormatInfo => { - // not checking config - that should be fine, right ? - this.inputFormatClazz == that.inputFormatClazz && - this.path == that.path - } - case _ => false - } - - private def validate() { - logDebug("validate InputFormatInfo : " + inputFormatClazz + ", path " + path) - - try { - if (classOf[org.apache.hadoop.mapreduce.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) { - logDebug("inputformat is from mapreduce package") - mapreduceInputFormat = true - } - else if (classOf[org.apache.hadoop.mapred.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) { - logDebug("inputformat is from mapred package") - mapredInputFormat = true - } - else { - throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + - " is NOT a supported input format ? does not implement either of the supported hadoop api's") - } - } - catch { - case e: ClassNotFoundException => { - throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + " cannot be found ?", e) - } - } - } - - - // This method does not expect failures, since validate has already passed ... - private def prefLocsFromMapreduceInputFormat(): Set[SplitInfo] = { - val env = SparkEnv.get - val conf = new JobConf(configuration) - env.hadoop.addCredentials(conf) - FileInputFormat.setInputPaths(conf, path) - - val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] = - ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], conf).asInstanceOf[ - org.apache.hadoop.mapreduce.InputFormat[_, _]] - val job = new Job(conf) - - val retval = new ArrayBuffer[SplitInfo]() - val list = instance.getSplits(job) - for (split <- list) { - retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, split) - } - - return retval.toSet - } - - // This method does not expect failures, since validate has already passed ... 
- private def prefLocsFromMapredInputFormat(): Set[SplitInfo] = { - val env = SparkEnv.get - val jobConf = new JobConf(configuration) - env.hadoop.addCredentials(jobConf) - FileInputFormat.setInputPaths(jobConf, path) - - val instance: org.apache.hadoop.mapred.InputFormat[_, _] = - ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], jobConf).asInstanceOf[ - org.apache.hadoop.mapred.InputFormat[_, _]] - - val retval = new ArrayBuffer[SplitInfo]() - instance.getSplits(jobConf, jobConf.getNumMapTasks()).foreach( - elem => retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, elem) - ) - - return retval.toSet - } - - private def findPreferredLocations(): Set[SplitInfo] = { - logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat + - ", inputFormatClazz : " + inputFormatClazz) - if (mapreduceInputFormat) { - return prefLocsFromMapreduceInputFormat() - } - else { - assert(mapredInputFormat) - return prefLocsFromMapredInputFormat() - } - } -} - - - - -object InputFormatInfo { - /** - Computes the preferred locations based on input(s) and returned a location to block map. - Typical use of this method for allocation would follow some algo like this - (which is what we currently do in YARN branch) : - a) For each host, count number of splits hosted on that host. - b) Decrement the currently allocated containers on that host. - c) Compute rack info for each host and update rack -> count map based on (b). - d) Allocate nodes based on (c) - e) On the allocation result, ensure that we dont allocate "too many" jobs on a single node - (even if data locality on that is very high) : this is to prevent fragility of job if a single - (or small set of) hosts go down. - - go to (a) until required nodes are allocated. - - If a node 'dies', follow same procedure. - - PS: I know the wording here is weird, hopefully it makes some sense ! - */ - def computePreferredLocations(formats: Seq[InputFormatInfo]): HashMap[String, HashSet[SplitInfo]] = { - - val nodeToSplit = new HashMap[String, HashSet[SplitInfo]] - for (inputSplit <- formats) { - val splits = inputSplit.findPreferredLocations() - - for (split <- splits){ - val location = split.hostLocation - val set = nodeToSplit.getOrElseUpdate(location, new HashSet[SplitInfo]) - set += split - } - } - - nodeToSplit - } -} diff --git a/core/src/main/scala/spark/scheduler/JobListener.scala b/core/src/main/scala/spark/scheduler/JobListener.scala deleted file mode 100644 index af108b8fec..0000000000 --- a/core/src/main/scala/spark/scheduler/JobListener.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.scheduler - -/** - * Interface used to listen for job completion or failure events after submitting a job to the - * DAGScheduler. The listener is notified each time a task succeeds, as well as if the whole - * job fails (and no further taskSucceeded events will happen). - */ -private[spark] trait JobListener { - def taskSucceeded(index: Int, result: Any) - def jobFailed(exception: Exception) -} diff --git a/core/src/main/scala/spark/scheduler/JobLogger.scala b/core/src/main/scala/spark/scheduler/JobLogger.scala deleted file mode 100644 index 1bc9fabdff..0000000000 --- a/core/src/main/scala/spark/scheduler/JobLogger.scala +++ /dev/null @@ -1,292 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import java.io.PrintWriter -import java.io.File -import java.io.FileNotFoundException -import java.text.SimpleDateFormat -import java.util.{Date, Properties} -import java.util.concurrent.LinkedBlockingQueue - -import scala.collection.mutable.{Map, HashMap, ListBuffer} -import scala.io.Source - -import spark._ -import spark.executor.TaskMetrics -import spark.scheduler.cluster.TaskInfo - -// Used to record runtime information for each job, including RDD graph -// tasks' start/stop shuffle information and information from outside - -class JobLogger(val logDirName: String) extends SparkListener with Logging { - private val logDir = - if (System.getenv("SPARK_LOG_DIR") != null) - System.getenv("SPARK_LOG_DIR") - else - "/tmp/spark" - private val jobIDToPrintWriter = new HashMap[Int, PrintWriter] - private val stageIDToJobID = new HashMap[Int, Int] - private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]] - private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") - private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents] - - createLogDir() - def this() = this(String.valueOf(System.currentTimeMillis())) - - def getLogDir = logDir - def getJobIDtoPrintWriter = jobIDToPrintWriter - def getStageIDToJobID = stageIDToJobID - def getJobIDToStages = jobIDToStages - def getEventQueue = eventQueue - - // Create a folder for log files, the folder's name is the creation time of the jobLogger - protected def createLogDir() { - val dir = new File(logDir + "/" + logDirName + "/") - if (dir.exists()) { - return - } - if (dir.mkdirs() == false) { - logError("create log directory error:" + logDir + "/" + logDirName + "/") - } - } - - // Create a log file for one job, the file name is the jobID - protected def createLogWriter(jobID: Int) { - try{ - val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID) - jobIDToPrintWriter += (jobID -> fileWriter) - } catch { - case e: FileNotFoundException => e.printStackTrace() - } - } - - // Close log file, and clean the stage 
relationship in stageIDToJobID.
-  protected def closeLogWriter(jobID: Int) =
-    jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
-      fileWriter.close()
-      jobIDToStages.get(jobID).foreach(_.foreach { stage =>
-        stageIDToJobID -= stage.id
-      })
-      jobIDToPrintWriter -= jobID
-      jobIDToStages -= jobID
-    }
-
-  // Write log information to the log file; the withTime parameter controls whether to record
-  // a timestamp with the information.
-  protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
-    var writeInfo = info
-    if (withTime) {
-      val date = new Date(System.currentTimeMillis())
-      writeInfo = DATE_FORMAT.format(date) + ": " + info
-    }
-    jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
-  }
-
-  protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) =
-    stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
-
-  protected def buildJobDep(jobID: Int, stage: Stage) {
-    if (stage.jobId == jobID) {
-      jobIDToStages.get(jobID) match {
-        case Some(stageList) => stageList += stage
-        case None =>
-          val stageList = new ListBuffer[Stage]
-          stageList += stage
-          jobIDToStages += (jobID -> stageList)
-      }
-      stageIDToJobID += (stage.id -> jobID)
-      stage.parents.foreach(buildJobDep(jobID, _))
-    }
-  }
-
-  protected def recordStageDep(jobID: Int) {
-    def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
-      var rddList = new ListBuffer[RDD[_]]
-      rddList += rdd
-      rdd.dependencies.foreach { dep =>
-        dep match {
-          case shufDep: ShuffleDependency[_,_] =>
-          case _ => rddList ++= getRddsInStage(dep.rdd)
-        }
-      }
-      rddList
-    }
-    jobIDToStages.get(jobID).foreach { _.foreach { stage =>
-      var depRddDesc: String = ""
-      getRddsInStage(stage.rdd).foreach { rdd =>
-        depRddDesc += rdd.id + ","
-      }
-      var depStageDesc: String = ""
-      stage.parents.foreach { stage =>
-        depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
-      }
-      jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
-        depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
-        " STAGE_DEP=" + depStageDesc, false)
-    }}
-  }
-
-  // Generate indents and convert to String
-  protected def indentString(indent: Int) = {
-    val sb = new StringBuilder()
-    for (i <- 1 to indent) {
-      sb.append(" ")
-    }
-    sb.toString()
-  }
-
-  protected def getRddName(rdd: RDD[_]) = {
-    var rddName = rdd.getClass.getName
-    if (rdd.name != null) {
-      rddName = rdd.name
-    }
-    rddName
-  }
-
-  protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
-    val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
-    jobLogInfo(jobID, indentString(indent) + rddInfo, false)
-    rdd.dependencies.foreach { dep =>
-      dep match {
-        case shufDep: ShuffleDependency[_,_] =>
-          val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
-          jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
-        case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
-      }
-    }
-  }
-
-  protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
-    var stageInfo: String = ""
-    if (stage.isShuffleMap) {
-      stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" +
-        stage.shuffleDep.get.shuffleId
-    } else {
-      stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE"
-    }
-    if (stage.jobId == jobID) {
-      jobLogInfo(jobID, indentString(indent) + stageInfo, false)
-      recordRddInStageGraph(jobID, stage.rdd, indent)
-      stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
-    } else {
-      jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
-    }
-  }
-
-  // Record task
metrics into job log files - protected def recordTaskMetrics(stageID: Int, status: String, - taskInfo: TaskInfo, taskMetrics: TaskMetrics) { - val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID + - " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime + - " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname - val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime - val readMetrics = - taskMetrics.shuffleReadMetrics match { - case Some(metrics) => - " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime + - " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched + - " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched + - " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched + - " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime + - " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime + - " REMOTE_BYTES_READ=" + metrics.remoteBytesRead - case None => "" - } - val writeMetrics = - taskMetrics.shuffleWriteMetrics match { - case Some(metrics) => - " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten - case None => "" - } - stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics) - } - - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { - stageLogInfo( - stageSubmitted.stage.id, - "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format( - stageSubmitted.stage.id, stageSubmitted.taskSize)) - } - - override def onStageCompleted(stageCompleted: StageCompleted) { - stageLogInfo( - stageCompleted.stageInfo.stage.id, - "STAGE_ID=%d STATUS=COMPLETED".format(stageCompleted.stageInfo.stage.id)) - - } - - override def onTaskStart(taskStart: SparkListenerTaskStart) { } - - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { - val task = taskEnd.task - val taskInfo = taskEnd.taskInfo - var taskStatus = "" - task match { - case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK" - case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK" - } - taskEnd.reason match { - case Success => taskStatus += " STATUS=SUCCESS" - recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics) - case Resubmitted => - taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId + - " STAGE_ID=" + task.stageId - stageLogInfo(task.stageId, taskStatus) - case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => - taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" + - task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + - mapId + " REDUCE_ID=" + reduceId - stageLogInfo(task.stageId, taskStatus) - case OtherFailure(message) => - taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId + - " STAGE_ID=" + task.stageId + " INFO=" + message - stageLogInfo(task.stageId, taskStatus) - case _ => - } - } - - override def onJobEnd(jobEnd: SparkListenerJobEnd) { - val job = jobEnd.job - var info = "JOB_ID=" + job.jobId - jobEnd.jobResult match { - case JobSucceeded => info += " STATUS=SUCCESS" - case JobFailed(exception, _) => - info += " STATUS=FAILED REASON=" - exception.getMessage.split("\\s+").foreach(info += _ + "_") - case _ => - } - jobLogInfo(job.jobId, info.substring(0, info.length - 1).toUpperCase) - closeLogWriter(job.jobId) - } - - protected def recordJobProperties(jobID: Int, properties: Properties) { - if(properties != null) { - val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "") - jobLogInfo(jobID, description, false) - } - } - - override def onJobStart(jobStart: SparkListenerJobStart) { - val job = 
jobStart.job - val properties = jobStart.properties - createLogWriter(job.jobId) - recordJobProperties(job.jobId, properties) - buildJobDep(job.jobId, job.finalStage) - recordStageDep(job.jobId) - recordStageDepGraph(job.jobId, job.finalStage) - jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED") - } -} diff --git a/core/src/main/scala/spark/scheduler/JobResult.scala b/core/src/main/scala/spark/scheduler/JobResult.scala deleted file mode 100644 index a61b335152..0000000000 --- a/core/src/main/scala/spark/scheduler/JobResult.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -/** - * A result of a job in the DAGScheduler. - */ -private[spark] sealed trait JobResult - -private[spark] case object JobSucceeded extends JobResult -private[spark] case class JobFailed(exception: Exception, failedStage: Option[Stage]) extends JobResult diff --git a/core/src/main/scala/spark/scheduler/JobWaiter.scala b/core/src/main/scala/spark/scheduler/JobWaiter.scala deleted file mode 100644 index 69cd161c1f..0000000000 --- a/core/src/main/scala/spark/scheduler/JobWaiter.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import scala.collection.mutable.ArrayBuffer - -/** - * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their - * results to the given handler function. - */ -private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit) - extends JobListener { - - private var finishedTasks = 0 - - private var jobFinished = false // Is the job as a whole finished (succeeded or failed)? 
- private var jobResult: JobResult = null // If the job is finished, this will be its result - - override def taskSucceeded(index: Int, result: Any) { - synchronized { - if (jobFinished) { - throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter") - } - resultHandler(index, result.asInstanceOf[T]) - finishedTasks += 1 - if (finishedTasks == totalTasks) { - jobFinished = true - jobResult = JobSucceeded - this.notifyAll() - } - } - } - - override def jobFailed(exception: Exception) { - synchronized { - if (jobFinished) { - throw new UnsupportedOperationException("jobFailed() called on a finished JobWaiter") - } - jobFinished = true - jobResult = JobFailed(exception, None) - this.notifyAll() - } - } - - def awaitResult(): JobResult = synchronized { - while (!jobFinished) { - this.wait() - } - return jobResult - } -} diff --git a/core/src/main/scala/spark/scheduler/MapStatus.scala b/core/src/main/scala/spark/scheduler/MapStatus.scala deleted file mode 100644 index 2f6a68ee85..0000000000 --- a/core/src/main/scala/spark/scheduler/MapStatus.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import spark.storage.BlockManagerId -import java.io.{ObjectOutput, ObjectInput, Externalizable} - -/** - * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the - * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks. - * The map output sizes are compressed using MapOutputTracker.compressSize. - */ -private[spark] class MapStatus(var location: BlockManagerId, var compressedSizes: Array[Byte]) - extends Externalizable { - - def this() = this(null, null) // For deserialization only - - def writeExternal(out: ObjectOutput) { - location.writeExternal(out) - out.writeInt(compressedSizes.length) - out.write(compressedSizes) - } - - def readExternal(in: ObjectInput) { - location = BlockManagerId(in) - compressedSizes = new Array[Byte](in.readInt()) - in.readFully(compressedSizes) - } -} diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala deleted file mode 100644 index d066df5dc1..0000000000 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import spark._ -import java.io._ -import util.{MetadataCleaner, TimeStampedHashMap} -import java.util.zip.{GZIPInputStream, GZIPOutputStream} - -private[spark] object ResultTask { - - // A simple map between the stage id to the serialized byte array of a task. - // Served as a cache for task serialization because serialization can be - // expensive on the master node if it needs to launch thousands of tasks. - val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - - val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.clearOldValues) - - def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = { - synchronized { - val old = serializedInfoCache.get(stageId).orNull - if (old != null) { - return old - } else { - val out = new ByteArrayOutputStream - val ser = SparkEnv.get.closureSerializer.newInstance - val objOut = ser.serializeStream(new GZIPOutputStream(out)) - objOut.writeObject(rdd) - objOut.writeObject(func) - objOut.close() - val bytes = out.toByteArray - serializedInfoCache.put(stageId, bytes) - return bytes - } - } - } - - def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = { - val loader = Thread.currentThread.getContextClassLoader - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance - val objIn = ser.deserializeStream(in) - val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _] - return (rdd, func) - } - - def clearCache() { - synchronized { - serializedInfoCache.clear() - } - } -} - - -private[spark] class ResultTask[T, U]( - stageId: Int, - var rdd: RDD[T], - var func: (TaskContext, Iterator[T]) => U, - var partition: Int, - @transient locs: Seq[TaskLocation], - val outputId: Int) - extends Task[U](stageId) with Externalizable { - - def this() = this(0, null, null, 0, null, 0) - - var split = if (rdd == null) { - null - } else { - rdd.partitions(partition) - } - - @transient private val preferredLocs: Seq[TaskLocation] = { - if (locs == null) Nil else locs.toSet.toSeq - } - - override def run(attemptId: Long): U = { - val context = new TaskContext(stageId, partition, attemptId) - metrics = Some(context.taskMetrics) - try { - func(context, rdd.iterator(split, context)) - } finally { - context.executeOnCompleteCallbacks() - } - } - - override def preferredLocations: Seq[TaskLocation] = preferredLocs - - override def toString = "ResultTask(" + stageId + ", " + partition + ")" - - override def writeExternal(out: ObjectOutput) { - RDDCheckpointData.synchronized { - split = rdd.partitions(partition) - out.writeInt(stageId) - val bytes = ResultTask.serializeInfo( - stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _]) - out.writeInt(bytes.length) - out.write(bytes) - out.writeInt(partition) - out.writeInt(outputId) - out.writeLong(epoch) - out.writeObject(split) - } - } - - override def readExternal(in: ObjectInput) { - val stageId = in.readInt() - val numBytes = 
in.readInt() - val bytes = new Array[Byte](numBytes) - in.readFully(bytes) - val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes) - rdd = rdd_.asInstanceOf[RDD[T]] - func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U] - partition = in.readInt() - val outputId = in.readInt() - epoch = in.readLong() - split = in.readObject().asInstanceOf[Partition] - } -} diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala deleted file mode 100644 index f2a038576b..0000000000 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ /dev/null @@ -1,189 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import java.io._ -import java.util.zip.{GZIPInputStream, GZIPOutputStream} - -import scala.collection.mutable.HashMap - -import spark._ -import spark.executor.ShuffleWriteMetrics -import spark.storage._ -import spark.util.{TimeStampedHashMap, MetadataCleaner} - - -private[spark] object ShuffleMapTask { - - // A simple map between the stage id to the serialized byte array of a task. - // Served as a cache for task serialization because serialization can be - // expensive on the master node if it needs to launch thousands of tasks. - val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - - val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.clearOldValues) - - def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { - synchronized { - val old = serializedInfoCache.get(stageId).orNull - if (old != null) { - return old - } else { - val out = new ByteArrayOutputStream - val ser = SparkEnv.get.closureSerializer.newInstance() - val objOut = ser.serializeStream(new GZIPOutputStream(out)) - objOut.writeObject(rdd) - objOut.writeObject(dep) - objOut.close() - val bytes = out.toByteArray - serializedInfoCache.put(stageId, bytes) - return bytes - } - } - } - - def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_]) = { - synchronized { - val loader = Thread.currentThread.getContextClassLoader - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance() - val objIn = ser.deserializeStream(in) - val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]] - return (rdd, dep) - } - } - - // Since both the JarSet and FileSet have the same format this is used for both. 
- def deserializeFileSet(bytes: Array[Byte]) : HashMap[String, Long] = { - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val objIn = new ObjectInputStream(in) - val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap - return (HashMap(set.toSeq: _*)) - } - - def clearCache() { - synchronized { - serializedInfoCache.clear() - } - } -} - -private[spark] class ShuffleMapTask( - stageId: Int, - var rdd: RDD[_], - var dep: ShuffleDependency[_,_], - var partition: Int, - @transient private var locs: Seq[TaskLocation]) - extends Task[MapStatus](stageId) - with Externalizable - with Logging { - - protected def this() = this(0, null, null, 0, null) - - @transient private val preferredLocs: Seq[TaskLocation] = { - if (locs == null) Nil else locs.toSet.toSeq - } - - var split = if (rdd == null) null else rdd.partitions(partition) - - override def writeExternal(out: ObjectOutput) { - RDDCheckpointData.synchronized { - split = rdd.partitions(partition) - out.writeInt(stageId) - val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) - out.writeInt(bytes.length) - out.write(bytes) - out.writeInt(partition) - out.writeLong(epoch) - out.writeObject(split) - } - } - - override def readExternal(in: ObjectInput) { - val stageId = in.readInt() - val numBytes = in.readInt() - val bytes = new Array[Byte](numBytes) - in.readFully(bytes) - val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes) - rdd = rdd_ - dep = dep_ - partition = in.readInt() - epoch = in.readLong() - split = in.readObject().asInstanceOf[Partition] - } - - override def run(attemptId: Long): MapStatus = { - val numOutputSplits = dep.partitioner.numPartitions - - val taskContext = new TaskContext(stageId, partition, attemptId) - metrics = Some(taskContext.taskMetrics) - - val blockManager = SparkEnv.get.blockManager - var shuffle: ShuffleBlocks = null - var buckets: ShuffleWriterGroup = null - - try { - // Obtain all the block writers for shuffle blocks. - val ser = SparkEnv.get.serializerManager.get(dep.serializerClass) - shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser) - buckets = shuffle.acquireWriters(partition) - - // Write the map output to its associated buckets. - for (elem <- rdd.iterator(split, taskContext)) { - val pair = elem.asInstanceOf[Product2[Any, Any]] - val bucketId = dep.partitioner.getPartition(pair._1) - buckets.writers(bucketId).write(pair) - } - - // Commit the writes. Get the size of each bucket block (total block size). - var totalBytes = 0L - val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter => - writer.commit() - writer.close() - val size = writer.size() - totalBytes += size - MapOutputTracker.compressSize(size) - } - - // Update shuffle metrics. - val shuffleMetrics = new ShuffleWriteMetrics - shuffleMetrics.shuffleBytesWritten = totalBytes - metrics.get.shuffleWriteMetrics = Some(shuffleMetrics) - - return new MapStatus(blockManager.blockManagerId, compressedSizes) - } catch { case e: Exception => - // If there is an exception from running the task, revert the partial writes - // and throw the exception upstream to Spark. - if (buckets != null) { - buckets.writers.foreach(_.revertPartialWrites()) - } - throw e - } finally { - // Release the writers back to the shuffle block manager. - if (shuffle != null && buckets != null) { - shuffle.releaseWriters(buckets) - } - // Execute the callbacks on task completion. 
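The MapStatus returned by run records each bucket's size as a single byte (compressedSizes) via MapOutputTracker.compressSize, presumably to keep the status small when there are many reduce partitions. The exact encoding lives in MapOutputTracker; the sketch below shows one plausible log-scale scheme (base 1.1 is an assumption) purely to illustrate how one byte can cover a wide range of sizes at the cost of precision:

    object SizeCompression {
      private val LOG_BASE = 1.1  // assumed base; 255 steps then span sizes up to tens of GB

      // Encode a block size into a single byte on a logarithmic scale (lossy).
      def compress(size: Long): Byte = {
        if (size <= 0L) 0.toByte
        else if (size <= 1L) 1.toByte
        else math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte
      }

      // Decode back to an upper-bound estimate of the original size.
      def decompress(compressed: Byte): Long =
        if (compressed == 0) 0L else math.pow(LOG_BASE, compressed & 0xFF).toLong
    }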
- taskContext.executeOnCompleteCallbacks() - } - } - - override def preferredLocations: Seq[TaskLocation] = preferredLocs - - override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition) -} diff --git a/core/src/main/scala/spark/scheduler/SparkListener.scala b/core/src/main/scala/spark/scheduler/SparkListener.scala deleted file mode 100644 index e5531011c2..0000000000 --- a/core/src/main/scala/spark/scheduler/SparkListener.scala +++ /dev/null @@ -1,204 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import java.util.Properties -import spark.scheduler.cluster.TaskInfo -import spark.util.Distribution -import spark.{Logging, SparkContext, TaskEndReason, Utils} -import spark.executor.TaskMetrics - -sealed trait SparkListenerEvents - -case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int, properties: Properties) - extends SparkListenerEvents - -case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents - -case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents - -case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo, - taskMetrics: TaskMetrics) extends SparkListenerEvents - -case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null) - extends SparkListenerEvents - -case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult) - extends SparkListenerEvents - -trait SparkListener { - /** - * Called when a stage is completed, with information on the completed stage - */ - def onStageCompleted(stageCompleted: StageCompleted) { } - - /** - * Called when a stage is submitted - */ - def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { } - - /** - * Called when a task starts - */ - def onTaskStart(taskEnd: SparkListenerTaskStart) { } - - /** - * Called when a task ends - */ - def onTaskEnd(taskEnd: SparkListenerTaskEnd) { } - - /** - * Called when a job starts - */ - def onJobStart(jobStart: SparkListenerJobStart) { } - - /** - * Called when a job ends - */ - def onJobEnd(jobEnd: SparkListenerJobEnd) { } - -} - -/** - * Simple SparkListener that logs a few summary statistics when each stage completes - */ -class StatsReportListener extends SparkListener with Logging { - override def onStageCompleted(stageCompleted: StageCompleted) { - import spark.scheduler.StatsReportListener._ - implicit val sc = stageCompleted - this.logInfo("Finished stage: " + stageCompleted.stageInfo) - showMillisDistribution("task runtime:", (info, _) => Some(info.duration)) - - //shuffle write - showBytesDistribution("shuffle bytes written:",(_,metric) => metric.shuffleWriteMetrics.map{_.shuffleBytesWritten}) - - //fetch & io - showMillisDistribution("fetch wait time:",(_, 
metric) => metric.shuffleReadMetrics.map{_.fetchWaitTime}) - showBytesDistribution("remote bytes read:", (_, metric) => metric.shuffleReadMetrics.map{_.remoteBytesRead}) - showBytesDistribution("task result size:", (_, metric) => Some(metric.resultSize)) - - //runtime breakdown - - val runtimePcts = stageCompleted.stageInfo.taskInfos.map{ - case (info, metrics) => RuntimePercentage(info.duration, metrics) - } - showDistribution("executor (non-fetch) time pct: ", Distribution(runtimePcts.map{_.executorPct * 100}), "%2.0f %%") - showDistribution("fetch wait time pct: ", Distribution(runtimePcts.flatMap{_.fetchPct.map{_ * 100}}), "%2.0f %%") - showDistribution("other time pct: ", Distribution(runtimePcts.map{_.other * 100}), "%2.0f %%") - } - -} - -object StatsReportListener extends Logging { - - //for profiling, the extremes are more interesting - val percentiles = Array[Int](0,5,10,25,50,75,90,95,100) - val probabilities = percentiles.map{_ / 100.0} - val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%" - - def extractDoubleDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Double]): Option[Distribution] = { - Distribution(stage.stageInfo.taskInfos.flatMap{ - case ((info,metric)) => getMetric(info, metric)}) - } - - //is there some way to setup the types that I can get rid of this completely? - def extractLongDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Long]): Option[Distribution] = { - extractDoubleDistribution(stage, (info, metric) => getMetric(info,metric).map{_.toDouble}) - } - - def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) { - val stats = d.statCounter - logInfo(heading + stats) - val quantiles = d.getQuantiles(probabilities).map{formatNumber} - logInfo(percentilesHeader) - logInfo("\t" + quantiles.mkString("\t")) - } - - def showDistribution(heading: String, dOpt: Option[Distribution], formatNumber: Double => String) { - dOpt.foreach { d => showDistribution(heading, d, formatNumber)} - } - - def showDistribution(heading: String, dOpt: Option[Distribution], format:String) { - def f(d:Double) = format.format(d) - showDistribution(heading, dOpt, f _) - } - - def showDistribution(heading:String, format: String, getMetric: (TaskInfo,TaskMetrics) => Option[Double]) - (implicit stage: StageCompleted) { - showDistribution(heading, extractDoubleDistribution(stage, getMetric), format) - } - - def showBytesDistribution(heading:String, getMetric: (TaskInfo,TaskMetrics) => Option[Long]) - (implicit stage: StageCompleted) { - showBytesDistribution(heading, extractLongDistribution(stage, getMetric)) - } - - def showBytesDistribution(heading: String, dOpt: Option[Distribution]) { - dOpt.foreach{dist => showBytesDistribution(heading, dist)} - } - - def showBytesDistribution(heading: String, dist: Distribution) { - showDistribution(heading, dist, (d => Utils.bytesToString(d.toLong)): Double => String) - } - - def showMillisDistribution(heading: String, dOpt: Option[Distribution]) { - showDistribution(heading, dOpt, (d => StatsReportListener.millisToString(d.toLong)): Double => String) - } - - def showMillisDistribution(heading: String, getMetric: (TaskInfo, TaskMetrics) => Option[Long]) - (implicit stage: StageCompleted) { - showMillisDistribution(heading, extractLongDistribution(stage, getMetric)) - } - - - - val seconds = 1000L - val minutes = seconds * 60 - val hours = minutes * 60 - - /** - * reformat a time interval in milliseconds to a prettier format for output - */ - def 
millisToString(ms: Long) = { - val (size, units) = - if (ms > hours) { - (ms.toDouble / hours, "hours") - } else if (ms > minutes) { - (ms.toDouble / minutes, "min") - } else if (ms > seconds) { - (ms.toDouble / seconds, "s") - } else { - (ms.toDouble, "ms") - } - "%.1f %s".format(size, units) - } -} - - - -case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double) -object RuntimePercentage { - def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = { - val denom = totalTime.toDouble - val fetchTime = metrics.shuffleReadMetrics.map{_.fetchWaitTime} - val fetch = fetchTime.map{_ / denom} - val exec = (metrics.executorRunTime - fetchTime.getOrElse(0l)) / denom - val other = 1.0 - (exec + fetch.getOrElse(0d)) - RuntimePercentage(exec, fetch, other) - } -} diff --git a/core/src/main/scala/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/spark/scheduler/SparkListenerBus.scala deleted file mode 100644 index f55ed455ed..0000000000 --- a/core/src/main/scala/spark/scheduler/SparkListenerBus.scala +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import java.util.concurrent.LinkedBlockingQueue - -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} - -import spark.Logging - -/** Asynchronously passes SparkListenerEvents to registered SparkListeners. */ -private[spark] class SparkListenerBus() extends Logging { - private val sparkListeners = new ArrayBuffer[SparkListener]() with SynchronizedBuffer[SparkListener] - - /* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than - * an OOM exception) if it's perpetually being added to more quickly than it's being drained. 
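Any component can watch scheduler activity by extending the SparkListener trait above; every callback has an empty default body, so an implementation overrides only the events it cares about, exactly as StatsReportListener does. A minimal sketch, placed in the spark.scheduler package because parts of this API (TaskInfo in particular) are still package-private in this tree; the registration hook shown at the end is assumed:

    package spark.scheduler

    import java.util.concurrent.atomic.AtomicInteger

    // Counts task completions and prints one summary line per finished stage.
    class TaskCountListener extends SparkListener {
      private val finishedTasks = new AtomicInteger(0)

      override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
        finishedTasks.incrementAndGet()
      }

      override def onStageCompleted(stageCompleted: StageCompleted) {
        val durations = stageCompleted.stageInfo.taskInfos.map(_._1.duration)
        println("Finished " + stageCompleted.stageInfo + ": " +
          durations.size + " tasks, " + durations.sum + " ms total")
      }
    }

    // Assumed registration hook on the driver:
    //   sc.addSparkListener(new TaskCountListener)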
*/ - private val EVENT_QUEUE_CAPACITY = 10000 - private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents](EVENT_QUEUE_CAPACITY) - private var queueFullErrorMessageLogged = false - - new Thread("SparkListenerBus") { - setDaemon(true) - override def run() { - while (true) { - val event = eventQueue.take - event match { - case stageSubmitted: SparkListenerStageSubmitted => - sparkListeners.foreach(_.onStageSubmitted(stageSubmitted)) - case stageCompleted: StageCompleted => - sparkListeners.foreach(_.onStageCompleted(stageCompleted)) - case jobStart: SparkListenerJobStart => - sparkListeners.foreach(_.onJobStart(jobStart)) - case jobEnd: SparkListenerJobEnd => - sparkListeners.foreach(_.onJobEnd(jobEnd)) - case taskStart: SparkListenerTaskStart => - sparkListeners.foreach(_.onTaskStart(taskStart)) - case taskEnd: SparkListenerTaskEnd => - sparkListeners.foreach(_.onTaskEnd(taskEnd)) - case _ => - } - } - } - }.start() - - def addListener(listener: SparkListener) { - sparkListeners += listener - } - - def post(event: SparkListenerEvents) { - val eventAdded = eventQueue.offer(event) - if (!eventAdded && !queueFullErrorMessageLogged) { - logError("Dropping SparkListenerEvent because no remaining room in event queue. " + - "This likely means one of the SparkListeners is too slow and cannot keep up with the " + - "rate at which tasks are being started by the scheduler.") - queueFullErrorMessageLogged = true - } - } -} - diff --git a/core/src/main/scala/spark/scheduler/SplitInfo.scala b/core/src/main/scala/spark/scheduler/SplitInfo.scala deleted file mode 100644 index 4e3661ec5d..0000000000 --- a/core/src/main/scala/spark/scheduler/SplitInfo.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import collection.mutable.ArrayBuffer - -// information about a specific split instance : handles both split instances. -// So that we do not need to worry about the differences. -class SplitInfo(val inputFormatClazz: Class[_], val hostLocation: String, val path: String, - val length: Long, val underlyingSplit: Any) { - override def toString(): String = { - "SplitInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz + - ", hostLocation : " + hostLocation + ", path : " + path + - ", length : " + length + ", underlyingSplit " + underlyingSplit - } - - override def hashCode(): Int = { - var hashCode = inputFormatClazz.hashCode - hashCode = hashCode * 31 + hostLocation.hashCode - hashCode = hashCode * 31 + path.hashCode - // ignore overflow ? It is hashcode anyway ! 
- hashCode = hashCode * 31 + (length & 0x7fffffff).toInt - hashCode - } - - // This is practically useless since most of the Split impl's dont seem to implement equals :-( - // So unless there is identity equality between underlyingSplits, it will always fail even if it - // is pointing to same block. - override def equals(other: Any): Boolean = other match { - case that: SplitInfo => { - this.hostLocation == that.hostLocation && - this.inputFormatClazz == that.inputFormatClazz && - this.path == that.path && - this.length == that.length && - // other split specific checks (like start for FileSplit) - this.underlyingSplit == that.underlyingSplit - } - case _ => false - } -} - -object SplitInfo { - - def toSplitInfo(inputFormatClazz: Class[_], path: String, - mapredSplit: org.apache.hadoop.mapred.InputSplit): Seq[SplitInfo] = { - val retval = new ArrayBuffer[SplitInfo]() - val length = mapredSplit.getLength - for (host <- mapredSplit.getLocations) { - retval += new SplitInfo(inputFormatClazz, host, path, length, mapredSplit) - } - retval - } - - def toSplitInfo(inputFormatClazz: Class[_], path: String, - mapreduceSplit: org.apache.hadoop.mapreduce.InputSplit): Seq[SplitInfo] = { - val retval = new ArrayBuffer[SplitInfo]() - val length = mapreduceSplit.getLength - for (host <- mapreduceSplit.getLocations) { - retval += new SplitInfo(inputFormatClazz, host, path, length, mapreduceSplit) - } - retval - } -} diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala deleted file mode 100644 index c599c00ac4..0000000000 --- a/core/src/main/scala/spark/scheduler/Stage.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import java.net.URI - -import spark._ -import spark.storage.BlockManagerId - -/** - * A stage is a set of independent tasks all computing the same function that need to run as part - * of a Spark job, where all the tasks have the same shuffle dependencies. Each DAG of tasks run - * by the scheduler is split up into stages at the boundaries where shuffle occurs, and then the - * DAGScheduler runs these stages in topological order. - * - * Each Stage can either be a shuffle map stage, in which case its tasks' results are input for - * another stage, or a result stage, in which case its tasks directly compute the action that - * initiated a job (e.g. count(), save(), etc). For shuffle map stages, we also track the nodes - * that each output partition is on. - * - * Each Stage also has a jobId, identifying the job that first submitted the stage. When FIFO - * scheduling is used, this allows Stages from earlier jobs to be computed first or recovered - * faster on failure. 
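SplitInfo wraps either a mapred or a mapreduce split behind one locality-carrying class, as its comment notes, and toSplitInfo fans a single Hadoop split out into one SplitInfo per preferred host. For illustration (the path and host names are made up):

    import org.apache.hadoop.fs.Path
    import org.apache.hadoop.mapred.{FileSplit, TextInputFormat}

    import spark.scheduler.SplitInfo

    // Hypothetical split: 64 MB of /data/part-00000 with two preferred hosts.
    val hadoopSplit = new FileSplit(
      new Path("/data/part-00000"), 0L, 64L * 1024 * 1024, Array("host-a", "host-b"))

    // Produces two SplitInfo entries, one per host, both with the same path and length.
    val infos: Seq[SplitInfo] =
      SplitInfo.toSplitInfo(classOf[TextInputFormat], "/data/part-00000", hadoopSplit)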
- */ -private[spark] class Stage( - val id: Int, - val rdd: RDD[_], - val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage - val parents: List[Stage], - val jobId: Int, - callSite: Option[String]) - extends Logging { - - val isShuffleMap = shuffleDep != None - val numPartitions = rdd.partitions.size - val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) - var numAvailableOutputs = 0 - - /** When first task was submitted to scheduler. */ - var submissionTime: Option[Long] = None - var completionTime: Option[Long] = None - - private var nextAttemptId = 0 - - def isAvailable: Boolean = { - if (!isShuffleMap) { - true - } else { - numAvailableOutputs == numPartitions - } - } - - def addOutputLoc(partition: Int, status: MapStatus) { - val prevList = outputLocs(partition) - outputLocs(partition) = status :: prevList - if (prevList == Nil) - numAvailableOutputs += 1 - } - - def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location == bmAddress) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - numAvailableOutputs -= 1 - } - } - - def removeOutputsOnExecutor(execId: String) { - var becameUnavailable = false - for (partition <- 0 until numPartitions) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location.executorId == execId) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - becameUnavailable = true - numAvailableOutputs -= 1 - } - } - if (becameUnavailable) { - logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( - this, execId, numAvailableOutputs, numPartitions, isAvailable)) - } - } - - def newAttemptId(): Int = { - val id = nextAttemptId - nextAttemptId += 1 - return id - } - - val name = callSite.getOrElse(rdd.origin) - - override def toString = "Stage " + id - - override def hashCode(): Int = id -} diff --git a/core/src/main/scala/spark/scheduler/StageInfo.scala b/core/src/main/scala/spark/scheduler/StageInfo.scala deleted file mode 100644 index c4026f995a..0000000000 --- a/core/src/main/scala/spark/scheduler/StageInfo.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.scheduler - -import spark.scheduler.cluster.TaskInfo -import scala.collection._ -import spark.executor.TaskMetrics - -case class StageInfo( - val stage: Stage, - val taskInfos: mutable.Buffer[(TaskInfo, TaskMetrics)] = mutable.Buffer[(TaskInfo, TaskMetrics)]() -) { - override def toString = stage.rdd.toString -} diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala deleted file mode 100644 index 0ab2ae6cfe..0000000000 --- a/core/src/main/scala/spark/scheduler/Task.scala +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import spark.serializer.SerializerInstance -import java.io.{DataInputStream, DataOutputStream} -import java.nio.ByteBuffer -import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream -import spark.util.ByteBufferInputStream -import scala.collection.mutable.HashMap -import spark.executor.TaskMetrics - -/** - * A task to execute on a worker node. - */ -private[spark] abstract class Task[T](val stageId: Int) extends Serializable { - def run(attemptId: Long): T - def preferredLocations: Seq[TaskLocation] = Nil - - var epoch: Long = -1 // Map output tracker epoch. Will be set by TaskScheduler. - - var metrics: Option[TaskMetrics] = None - -} - -/** - * Handles transmission of tasks and their dependencies, because this can be slightly tricky. We - * need to send the list of JARs and files added to the SparkContext with each task to ensure that - * worker nodes find out about it, but we can't make it part of the Task because the user's code in - * the task might depend on one of the JARs. Thus we serialize each task as multiple objects, by - * first writing out its dependencies. 
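Task[T] is the unit the schedulers ship to executors: a subclass supplies run and may advertise preferred locations, while the scheduler fills in epoch and metrics. ResultTask and ShuffleMapTask above are the real implementations; the toy subclass below exists only to show the contract, and it sits in the spark.scheduler package because Task is declared private[spark]:

    package spark.scheduler

    // A trivial task that echoes its attempt id and prefers a single host.
    private[spark] class EchoTask(stageId: Int, host: String) extends Task[String](stageId) {
      override def run(attemptId: Long): String = "ran attempt " + attemptId

      override def preferredLocations: Seq[TaskLocation] = Seq(TaskLocation(host))
    }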
- */ -private[spark] object Task { - /** - * Serialize a task and the current app dependencies (files and JARs added to the SparkContext) - */ - def serializeWithDependencies( - task: Task[_], - currentFiles: HashMap[String, Long], - currentJars: HashMap[String, Long], - serializer: SerializerInstance) - : ByteBuffer = { - - val out = new FastByteArrayOutputStream(4096) - val dataOut = new DataOutputStream(out) - - // Write currentFiles - dataOut.writeInt(currentFiles.size) - for ((name, timestamp) <- currentFiles) { - dataOut.writeUTF(name) - dataOut.writeLong(timestamp) - } - - // Write currentJars - dataOut.writeInt(currentJars.size) - for ((name, timestamp) <- currentJars) { - dataOut.writeUTF(name) - dataOut.writeLong(timestamp) - } - - // Write the task itself and finish - dataOut.flush() - val taskBytes = serializer.serialize(task).array() - out.write(taskBytes) - out.trim() - ByteBuffer.wrap(out.array) - } - - /** - * Deserialize the list of dependencies in a task serialized with serializeWithDependencies, - * and return the task itself as a serialized ByteBuffer. The caller can then update its - * ClassLoaders and deserialize the task. - * - * @return (taskFiles, taskJars, taskBytes) - */ - def deserializeWithDependencies(serializedTask: ByteBuffer) - : (HashMap[String, Long], HashMap[String, Long], ByteBuffer) = { - - val in = new ByteBufferInputStream(serializedTask) - val dataIn = new DataInputStream(in) - - // Read task's files - val taskFiles = new HashMap[String, Long]() - val numFiles = dataIn.readInt() - for (i <- 0 until numFiles) { - taskFiles(dataIn.readUTF()) = dataIn.readLong() - } - - // Read task's JARs - val taskJars = new HashMap[String, Long]() - val numJars = dataIn.readInt() - for (i <- 0 until numJars) { - taskJars(dataIn.readUTF()) = dataIn.readLong() - } - - // Create a sub-buffer for the rest of the data, which is the serialized Task object - val subBuffer = serializedTask.slice() // ByteBufferInputStream will have read just up to task - (taskFiles, taskJars, subBuffer) - } -} diff --git a/core/src/main/scala/spark/scheduler/TaskLocation.scala b/core/src/main/scala/spark/scheduler/TaskLocation.scala deleted file mode 100644 index fea117e956..0000000000 --- a/core/src/main/scala/spark/scheduler/TaskLocation.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -/** - * A location where a task should run. This can either be a host or a (host, executorID) pair. - * In the latter case, we will prefer to launch the task on that executorID, but our next level - * of preference will be executors on the same host if this is not possible. 
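The two halves of that framing are symmetric: the driver writes the file and JAR maps ahead of the task bytes, and the executor peels them off (updating its class loaders from taskFiles and taskJars) before deserializing the task proper. A round-trip sketch, callable only from inside the spark package since the Task companion is private[spark]; JavaSerializer stands in for whatever closure serializer is configured, and the JAR name is made up:

    import java.nio.ByteBuffer
    import scala.collection.mutable.HashMap

    import spark.JavaSerializer
    import spark.scheduler.Task

    def shipAndUnpack(task: Task[_]): Task[_] = {
      val ser = new JavaSerializer().newInstance()
      val files = HashMap[String, Long]()            // nothing added via addFile here
      val jars  = HashMap("app.jar" -> 1L)           // hypothetical jar name and timestamp

      val wire: ByteBuffer = Task.serializeWithDependencies(task, files, jars, ser)

      // Executor side: recover the dependency maps, then deserialize the task itself.
      val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(wire)
      ser.deserialize[Task[_]](taskBytes)
    }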
- */ -private[spark] -class TaskLocation private (val host: String, val executorId: Option[String]) extends Serializable { - override def toString: String = "TaskLocation(" + host + ", " + executorId + ")" -} - -private[spark] object TaskLocation { - def apply(host: String, executorId: String) = new TaskLocation(host, Some(executorId)) - - def apply(host: String) = new TaskLocation(host, None) -} diff --git a/core/src/main/scala/spark/scheduler/TaskResult.scala b/core/src/main/scala/spark/scheduler/TaskResult.scala deleted file mode 100644 index fc4856756b..0000000000 --- a/core/src/main/scala/spark/scheduler/TaskResult.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import java.io._ - -import scala.collection.mutable.Map -import spark.executor.TaskMetrics -import spark.{Utils, SparkEnv} -import java.nio.ByteBuffer - -// Task result. Also contains updates to accumulator variables. -// TODO: Use of distributed cache to return result is a hack to get around -// what seems to be a bug with messages over 60KB in libprocess; fix it -private[spark] -class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: TaskMetrics) - extends Externalizable -{ - def this() = this(null.asInstanceOf[T], null, null) - - override def writeExternal(out: ObjectOutput) { - - val objectSer = SparkEnv.get.serializer.newInstance() - val bb = objectSer.serialize(value) - - out.writeInt(bb.remaining()) - Utils.writeByteBuffer(bb, out) - - out.writeInt(accumUpdates.size) - for ((key, value) <- accumUpdates) { - out.writeLong(key) - out.writeObject(value) - } - out.writeObject(metrics) - } - - override def readExternal(in: ObjectInput) { - - val objectSer = SparkEnv.get.serializer.newInstance() - - val blen = in.readInt() - val byteVal = new Array[Byte](blen) - in.readFully(byteVal) - value = objectSer.deserialize(ByteBuffer.wrap(byteVal)) - - val numUpdates = in.readInt - if (numUpdates == 0) { - accumUpdates = null - } else { - accumUpdates = Map() - for (i <- 0 until numUpdates) { - accumUpdates(in.readLong()) = in.readObject() - } - } - metrics = in.readObject().asInstanceOf[TaskMetrics] - } -} diff --git a/core/src/main/scala/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/spark/scheduler/TaskScheduler.scala deleted file mode 100644 index 4943d58e25..0000000000 --- a/core/src/main/scala/spark/scheduler/TaskScheduler.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import spark.scheduler.cluster.Pool -import spark.scheduler.cluster.SchedulingMode.SchedulingMode -/** - * Low-level task scheduler interface, implemented by both ClusterScheduler and LocalScheduler. - * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage, - * and are responsible for sending the tasks to the cluster, running them, retrying if there - * are failures, and mitigating stragglers. They return events to the DAGScheduler through - * the TaskSchedulerListener interface. - */ -private[spark] trait TaskScheduler { - - def rootPool: Pool - - def schedulingMode: SchedulingMode - - def start(): Unit - - // Invoked after system has successfully initialized (typically in spark context). - // Yarn uses this to bootstrap allocation of resources based on preferred locations, wait for slave registerations, etc. - def postStartHook() { } - - // Disconnect from the cluster. - def stop(): Unit - - // Submit a sequence of tasks to run. - def submitTasks(taskSet: TaskSet): Unit - - // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called. - def setListener(listener: TaskSchedulerListener): Unit - - // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs. - def defaultParallelism(): Int -} diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala deleted file mode 100644 index 64be50b2d0..0000000000 --- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import spark.scheduler.cluster.TaskInfo -import scala.collection.mutable.Map - -import spark.TaskEndReason -import spark.executor.TaskMetrics - -/** - * Interface for getting events back from the TaskScheduler. - */ -private[spark] trait TaskSchedulerListener { - // A task has started. - def taskStarted(task: Task[_], taskInfo: TaskInfo) - - // A task has finished or failed. 
- def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any], - taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit - - // A node was added to the cluster. - def executorGained(execId: String, host: String): Unit - - // A node was lost from the cluster. - def executorLost(execId: String): Unit - - // The TaskScheduler wants to abort an entire task set. - def taskSetFailed(taskSet: TaskSet, reason: String): Unit -} diff --git a/core/src/main/scala/spark/scheduler/TaskSet.scala b/core/src/main/scala/spark/scheduler/TaskSet.scala deleted file mode 100644 index dc3550dd0b..0000000000 --- a/core/src/main/scala/spark/scheduler/TaskSet.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import java.util.Properties - -/** - * A set of tasks submitted together to the low-level TaskScheduler, usually representing - * missing partitions of a particular stage. - */ -private[spark] class TaskSet( - val tasks: Array[Task[_]], - val stageId: Int, - val attempt: Int, - val priority: Int, - val properties: Properties) { - val id: String = stageId + "." + attempt - - override def toString: String = "TaskSet " + id -} diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala deleted file mode 100644 index 679d899b47..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ /dev/null @@ -1,440 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.scheduler.cluster - -import java.lang.{Boolean => JBoolean} - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet - -import spark._ -import spark.TaskState.TaskState -import spark.scheduler._ -import spark.scheduler.cluster.SchedulingMode.SchedulingMode -import java.nio.ByteBuffer -import java.util.concurrent.atomic.AtomicLong -import java.util.{TimerTask, Timer} - -/** - * The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call - * initialize() and start(), then submit task sets through the runTasks method. - * - * This class can work with multiple types of clusters by acting through a SchedulerBackend. - * It handles common logic, like determining a scheduling order across jobs, waking up to launch - * speculative tasks, etc. - * - * THREADING: SchedulerBackends and task-submitting clients can call this class from multiple - * threads, so it needs locks in public API methods to maintain its state. In addition, some - * SchedulerBackends sycnchronize on themselves when they want to send events here, and then - * acquire a lock on us, so we need to make sure that we don't try to lock the backend while - * we are holding a lock on ourselves. - */ -private[spark] class ClusterScheduler(val sc: SparkContext) - extends TaskScheduler - with Logging -{ - // How often to check for speculative tasks - val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong - - // Threshold above which we warn user initial TaskSet may be starved - val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong - - val activeTaskSets = new HashMap[String, TaskSetManager] - - val taskIdToTaskSetId = new HashMap[Long, String] - val taskIdToExecutorId = new HashMap[Long, String] - val taskSetTaskIds = new HashMap[String, HashSet[Long]] - - @volatile private var hasReceivedTask = false - @volatile private var hasLaunchedTask = false - private val starvationTimer = new Timer(true) - - // Incrementing Mesos task IDs - val nextTaskId = new AtomicLong(0) - - // Which executor IDs we have executors on - val activeExecutorIds = new HashSet[String] - - // The set of executors we have on each host; this is used to compute hostsAlive, which - // in turn is used to decide when we can attain data locality on a given host - private val executorsByHost = new HashMap[String, HashSet[String]] - - private val executorIdToHost = new HashMap[String, String] - - // JAR server, if any JARs were added by the user to the SparkContext - var jarServer: HttpServer = null - - // URIs of JARs to pass to executor - var jarUris: String = "" - - // Listener object to pass upcalls into - var listener: TaskSchedulerListener = null - - var backend: SchedulerBackend = null - - val mapOutputTracker = SparkEnv.get.mapOutputTracker - - var schedulableBuilder: SchedulableBuilder = null - var rootPool: Pool = null - // default scheduler is FIFO - val schedulingMode: SchedulingMode = SchedulingMode.withName( - System.getProperty("spark.cluster.schedulingmode", "FIFO")) - - override def setListener(listener: TaskSchedulerListener) { - this.listener = listener - } - - def initialize(context: SchedulerBackend) { - backend = context - // temporarily set rootPool name to empty - rootPool = new Pool("", schedulingMode, 0, 0) - schedulableBuilder = { - schedulingMode match { - case SchedulingMode.FIFO => - new FIFOSchedulableBuilder(rootPool) - case SchedulingMode.FAIR => - 
new FairSchedulableBuilder(rootPool) - } - } - schedulableBuilder.buildPools() - } - - def newTaskId(): Long = nextTaskId.getAndIncrement() - - override def start() { - backend.start() - - if (System.getProperty("spark.speculation", "false").toBoolean) { - new Thread("ClusterScheduler speculation check") { - setDaemon(true) - - override def run() { - logInfo("Starting speculative execution thread") - while (true) { - try { - Thread.sleep(SPECULATION_INTERVAL) - } catch { - case e: InterruptedException => {} - } - checkSpeculatableTasks() - } - } - }.start() - } - } - - override def submitTasks(taskSet: TaskSet) { - val tasks = taskSet.tasks - logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") - this.synchronized { - val manager = new ClusterTaskSetManager(this, taskSet) - activeTaskSets(taskSet.id) = manager - schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) - taskSetTaskIds(taskSet.id) = new HashSet[Long]() - - if (!hasReceivedTask) { - starvationTimer.scheduleAtFixedRate(new TimerTask() { - override def run() { - if (!hasLaunchedTask) { - logWarning("Initial job has not accepted any resources; " + - "check your cluster UI to ensure that workers are registered " + - "and have sufficient memory") - } else { - this.cancel() - } - } - }, STARVATION_TIMEOUT, STARVATION_TIMEOUT) - } - hasReceivedTask = true - } - backend.reviveOffers() - } - - def taskSetFinished(manager: TaskSetManager) { - this.synchronized { - activeTaskSets -= manager.taskSet.id - manager.parent.removeSchedulable(manager) - logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) - taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) - taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id) - taskSetTaskIds.remove(manager.taskSet.id) - } - } - - /** - * Called by cluster manager to offer resources on slaves. We respond by asking our active task - * sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so - * that tasks are balanced across the cluster. - */ - def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized { - SparkEnv.set(sc.env) - - // Mark each slave as alive and remember its hostname - for (o <- offers) { - executorIdToHost(o.executorId) = o.host - if (!executorsByHost.contains(o.host)) { - executorsByHost(o.host) = new HashSet[String]() - executorGained(o.executorId, o.host) - } - } - - // Build a list of tasks to assign to each worker - val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) - val availableCpus = offers.map(o => o.cores).toArray - val sortedTaskSets = rootPool.getSortedTaskSetQueue() - for (taskSet <- sortedTaskSets) { - logDebug("parentName: %s, name: %s, runningTasks: %s".format( - taskSet.parent.name, taskSet.name, taskSet.runningTasks)) - } - - // Take each TaskSet in our scheduling order, and then offer it each node in increasing order - // of locality levels so that it gets a chance to launch local tasks on all of them. 
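Which SchedulableBuilder gets used, and whether the speculation thread runs at all, is controlled entirely by system properties read in this file. An illustrative driver-side setup using the property names above (the values and master URL are examples only, and the properties must be set before the SparkContext, and hence the scheduler, is created):

    // Use pool-based fair sharing instead of the default FIFO queue.
    System.setProperty("spark.cluster.schedulingmode", "FAIR")

    // Opt in to speculative execution and re-check for stragglers every 100 ms.
    System.setProperty("spark.speculation", "true")
    System.setProperty("spark.speculation.interval", "100")

    // Warn if the first task set receives no resources within 15 seconds.
    System.setProperty("spark.starvation.timeout", "15000")

    val sc = new spark.SparkContext("spark://master:7077", "scheduling-demo")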
- var launchedTask = false - for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) { - do { - launchedTask = false - for (i <- 0 until offers.size) { - val execId = offers(i).executorId - val host = offers(i).host - for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) { - tasks(i) += task - val tid = task.taskId - taskIdToTaskSetId(tid) = taskSet.taskSet.id - taskSetTaskIds(taskSet.taskSet.id) += tid - taskIdToExecutorId(tid) = execId - activeExecutorIds += execId - executorsByHost(host) += execId - availableCpus(i) -= 1 - launchedTask = true - } - } - } while (launchedTask) - } - - if (tasks.size > 0) { - hasLaunchedTask = true - } - return tasks - } - - def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { - var taskSetToUpdate: Option[TaskSetManager] = None - var failedExecutor: Option[String] = None - var taskFailed = false - synchronized { - try { - if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { - // We lost this entire executor, so remember that it's gone - val execId = taskIdToExecutorId(tid) - if (activeExecutorIds.contains(execId)) { - removeExecutor(execId) - failedExecutor = Some(execId) - } - } - taskIdToTaskSetId.get(tid) match { - case Some(taskSetId) => - if (activeTaskSets.contains(taskSetId)) { - taskSetToUpdate = Some(activeTaskSets(taskSetId)) - } - if (TaskState.isFinished(state)) { - taskIdToTaskSetId.remove(tid) - if (taskSetTaskIds.contains(taskSetId)) { - taskSetTaskIds(taskSetId) -= tid - } - taskIdToExecutorId.remove(tid) - } - if (state == TaskState.FAILED) { - taskFailed = true - } - case None => - logInfo("Ignoring update from TID " + tid + " because its task set is gone") - } - } catch { - case e: Exception => logError("Exception in statusUpdate", e) - } - } - // Update the task set and DAGScheduler without holding a lock on this, since that can deadlock - if (taskSetToUpdate != None) { - taskSetToUpdate.get.statusUpdate(tid, state, serializedData) - } - if (failedExecutor != None) { - listener.executorLost(failedExecutor.get) - backend.reviveOffers() - } - if (taskFailed) { - // Also revive offers if a task had failed for some reason other than host lost - backend.reviveOffers() - } - } - - def error(message: String) { - synchronized { - if (activeTaskSets.size > 0) { - // Have each task set throw a SparkException with the error - for ((taskSetId, manager) <- activeTaskSets) { - try { - manager.error(message) - } catch { - case e: Exception => logError("Exception in error callback", e) - } - } - } else { - // No task sets are active but we still got an error. Just exit since this - // must mean the error is during registration. - // It might be good to do something smarter here in the future. - logError("Exiting due to error from cluster scheduler: " + message) - System.exit(1) - } - } - } - - override def stop() { - if (backend != null) { - backend.stop() - } - if (jarServer != null) { - jarServer.stop() - } - - // sleeping for an arbitrary 5 seconds : to ensure that messages are sent out. - // TODO: Do something better ! - Thread.sleep(5000L) - } - - override def defaultParallelism() = backend.defaultParallelism() - - - // Check for speculatable tasks in all our active jobs. - def checkSpeculatableTasks() { - var shouldRevive = false - synchronized { - shouldRevive = rootPool.checkSpeculatableTasks() - } - if (shouldRevive) { - backend.reviveOffers() - } - } - - // Check for pending tasks in all our active jobs. 
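statusUpdate and executorLost both follow the discipline described in the class comment: update internal bookkeeping while synchronized, but call the listener and backend only after the lock is released, so a callback that re-enters the scheduler cannot deadlock. Reduced to its essentials (Listener here is a stand-in, not the real TaskSchedulerListener):

    trait Listener { def executorLost(execId: String): Unit }

    class LockThenCallOut(listener: Listener) {
      private val activeExecutors = scala.collection.mutable.HashSet[String]()

      def executorLost(execId: String): Unit = {
        // Phase 1: mutate state under the lock and record what needs reporting.
        val toReport = synchronized {
          if (activeExecutors.remove(execId)) Some(execId) else None
        }
        // Phase 2: call out with no locks held; re-entrant callbacks cannot deadlock.
        toReport.foreach(listener.executorLost)
      }
    }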
- def hasPendingTasks: Boolean = { - synchronized { - rootPool.hasPendingTasks() - } - } - - def executorLost(executorId: String, reason: ExecutorLossReason) { - var failedExecutor: Option[String] = None - - synchronized { - if (activeExecutorIds.contains(executorId)) { - val hostPort = executorIdToHost(executorId) - logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) - removeExecutor(executorId) - failedExecutor = Some(executorId) - } else { - // We may get multiple executorLost() calls with different loss reasons. For example, one - // may be triggered by a dropped connection from the slave while another may be a report - // of executor termination from Mesos. We produce log messages for both so we eventually - // report the termination reason. - logError("Lost an executor " + executorId + " (already removed): " + reason) - } - } - // Call listener.executorLost without holding the lock on this to prevent deadlock - if (failedExecutor != None) { - listener.executorLost(failedExecutor.get) - backend.reviveOffers() - } - } - - /** Remove an executor from all our data structures and mark it as lost */ - private def removeExecutor(executorId: String) { - activeExecutorIds -= executorId - val host = executorIdToHost(executorId) - val execs = executorsByHost.getOrElse(host, new HashSet) - execs -= executorId - if (execs.isEmpty) { - executorsByHost -= host - } - executorIdToHost -= executorId - rootPool.executorLost(executorId, host) - } - - def executorGained(execId: String, host: String) { - listener.executorGained(execId, host) - } - - def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized { - executorsByHost.get(host).map(_.toSet) - } - - def hasExecutorsAliveOnHost(host: String): Boolean = synchronized { - executorsByHost.contains(host) - } - - def isExecutorAlive(execId: String): Boolean = synchronized { - activeExecutorIds.contains(execId) - } - - // By default, rack is unknown - def getRackForHost(value: String): Option[String] = None -} - - -object ClusterScheduler { - /** - * Used to balance containers across hosts. - * - * Accepts a map of hosts to resource offers for that host, and returns a prioritized list of - * resource offers representing the order in which the offers should be used. The resource - * offers are ordered such that we'll allocate one container on each host before allocating a - * second container on any host, and so on, in order to reduce the damage if a host fails. 
- * - * For example, given , , , returns - * [o1, o5, o4, 02, o6, o3] - */ - def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = { - val _keyList = new ArrayBuffer[K](map.size) - _keyList ++= map.keys - - // order keyList based on population of value in map - val keyList = _keyList.sortWith( - (left, right) => map(left).size > map(right).size - ) - - val retval = new ArrayBuffer[T](keyList.size * 2) - var index = 0 - var found = true - - while (found) { - found = false - for (key <- keyList) { - val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null) - assert(containerList != null) - // Get the index'th entry for this host - if present - if (index < containerList.size){ - retval += containerList.apply(index) - found = true - } - } - index += 1 - } - - retval.toList - } -} diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala deleted file mode 100644 index a4d6880abb..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ /dev/null @@ -1,712 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import java.nio.ByteBuffer -import java.util.{Arrays, NoSuchElementException} - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.math.max -import scala.math.min - -import spark.{FetchFailed, Logging, Resubmitted, SparkEnv, Success, TaskEndReason, TaskState, Utils} -import spark.{ExceptionFailure, SparkException, TaskResultTooBigFailure} -import spark.TaskState.TaskState -import spark.scheduler._ -import scala.Some -import spark.FetchFailed -import spark.ExceptionFailure -import spark.TaskResultTooBigFailure -import spark.util.{SystemClock, Clock} - - -/** - * Schedules the tasks within a single TaskSet in the ClusterScheduler. This class keeps track of - * the status of each task, retries tasks if they fail (up to a limited number of times), and - * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces - * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, - * and statusUpdate, which tells it that one of its tasks changed state (e.g. finished). - * - * THREADING: This class is designed to only be called from code with a lock on the - * ClusterScheduler (e.g. its event handlers). It should not be called from other threads. 
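Concretely (host and offer names below are illustrative), prioritizeContainers visits hosts in order of how many offers they still have and takes one offer per host per round, so every host contributes a first container before any host contributes a second:

    import scala.collection.mutable.{ArrayBuffer, HashMap}

    val offers = HashMap(
      "host1" -> ArrayBuffer("o1", "o2", "o3"),
      "host2" -> ArrayBuffer("o4"),
      "host3" -> ArrayBuffer("o5", "o6"))

    // Round-robin across hosts, most offers first: List(o1, o5, o4, o2, o6, o3)
    val ordered = spark.scheduler.cluster.ClusterScheduler.prioritizeContainers(offers)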
- */ -private[spark] class ClusterTaskSetManager( - sched: ClusterScheduler, - val taskSet: TaskSet, - clock: Clock = SystemClock) - extends TaskSetManager - with Logging -{ - // CPUs to request per task - val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt - - // Maximum times a task is allowed to fail before failing the job - val MAX_TASK_FAILURES = System.getProperty("spark.task.maxFailures", "4").toInt - - // Quantile of tasks at which to start speculation - val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble - val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble - - // Serializer for closures and tasks. - val env = SparkEnv.get - val ser = env.closureSerializer.newInstance() - - val tasks = taskSet.tasks - val numTasks = tasks.length - val copiesRunning = new Array[Int](numTasks) - val finished = new Array[Boolean](numTasks) - val numFailures = new Array[Int](numTasks) - val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) - var tasksFinished = 0 - - var weight = 1 - var minShare = 0 - var runningTasks = 0 - var priority = taskSet.priority - var stageId = taskSet.stageId - var name = "TaskSet_"+taskSet.stageId.toString - var parent: Schedulable = null - - // Set of pending tasks for each executor. These collections are actually - // treated as stacks, in which new tasks are added to the end of the - // ArrayBuffer and removed from the end. This makes it faster to detect - // tasks that repeatedly fail because whenever a task failed, it is put - // back at the head of the stack. They are also only cleaned up lazily; - // when a task is launched, it remains in all the pending lists except - // the one that it was launched from, but gets removed from them later. - private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]] - - // Set of pending tasks for each host. Similar to pendingTasksForExecutor, - // but at host level. - private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] - - // Set of pending tasks for each rack -- similar to the above. - private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]] - - // Set containing pending tasks with no locality preferences. - val pendingTasksWithNoPrefs = new ArrayBuffer[Int] - - // Set containing all pending tasks (also used as a stack, as above). - val allPendingTasks = new ArrayBuffer[Int] - - // Tasks that can be speculated. Since these will be a small fraction of total - // tasks, we'll just hold them in a HashSet. - val speculatableTasks = new HashSet[Int] - - // Task index, start and finish time for each task attempt (indexed by task ID) - val taskInfos = new HashMap[Long, TaskInfo] - - // Did the TaskSet fail? - var failed = false - var causeOfFailure = "" - - // How frequently to reprint duplicate exceptions in full, in milliseconds - val EXCEPTION_PRINT_INTERVAL = - System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong - - // Map of recent exceptions (identified by string representation and top stack frame) to - // duplicate count (how many times the same exception has appeared) and time the full exception - // was printed. This should ideally be an LRU map that can drop old exceptions automatically. 
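All of the constants read at the top of ClusterTaskSetManager come from system properties, so retry and speculation behaviour can be tuned per driver. An illustrative, non-default configuration (values are examples only; like the scheduler properties, these are read when the manager is constructed, so set them before creating the SparkContext):

    // Allow a task to fail up to 8 times before the whole job is failed.
    System.setProperty("spark.task.maxFailures", "8")

    // Reserve two CPU slots per task, e.g. for tasks that run their own threads.
    System.setProperty("spark.task.cpus", "2")

    // Only start speculating once 90% of the task set has finished, and only for
    // tasks running well beyond (2x) the typical task time.
    System.setProperty("spark.speculation.quantile", "0.90")
    System.setProperty("spark.speculation.multiplier", "2")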
- val recentExceptions = HashMap[String, (Int, Long)]() - - // Figure out the current map output tracker epoch and set it on all tasks - val epoch = sched.mapOutputTracker.getEpoch - logDebug("Epoch for " + taskSet + ": " + epoch) - for (t <- tasks) { - t.epoch = epoch - } - - // Add all our tasks to the pending lists. We do this in reverse order - // of task index so that tasks with low indices get launched first. - for (i <- (0 until numTasks).reverse) { - addPendingTask(i) - } - - // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling - val myLocalityLevels = computeValidLocalityLevels() - val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level - - // Delay scheduling variables: we keep track of our current locality level and the time we - // last launched a task at that level, and move up a level when localityWaits[curLevel] expires. - // We then move down if we manage to launch a "more local" task. - var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels - var lastLaunchTime = clock.getTime() // Time we last launched a task at this level - - /** - * Add a task to all the pending-task lists that it should be on. If readding is set, we are - * re-adding the task so only include it in each list if it's not already there. - */ - private def addPendingTask(index: Int, readding: Boolean = false) { - // Utility method that adds `index` to a list only if readding=false or it's not already there - def addTo(list: ArrayBuffer[Int]) { - if (!readding || !list.contains(index)) { - list += index - } - } - - var hadAliveLocations = false - for (loc <- tasks(index).preferredLocations) { - for (execId <- loc.executorId) { - if (sched.isExecutorAlive(execId)) { - addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer)) - hadAliveLocations = true - } - } - if (sched.hasExecutorsAliveOnHost(loc.host)) { - addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer)) - for (rack <- sched.getRackForHost(loc.host)) { - addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer)) - } - hadAliveLocations = true - } - } - - if (!hadAliveLocations) { - // Even though the task might've had preferred locations, all of those hosts or executors - // are dead; put it in the no-prefs list so we can schedule it elsewhere right away. - addTo(pendingTasksWithNoPrefs) - } - - if (!readding) { - allPendingTasks += index // No point scanning this whole list to find the old task there - } - } - - /** - * Return the pending tasks list for a given executor ID, or an empty list if - * there is no map entry for that host - */ - private def getPendingTasksForExecutor(executorId: String): ArrayBuffer[Int] = { - pendingTasksForExecutor.getOrElse(executorId, ArrayBuffer()) - } - - /** - * Return the pending tasks list for a given host, or an empty list if - * there is no map entry for that host - */ - private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = { - pendingTasksForHost.getOrElse(host, ArrayBuffer()) - } - - /** - * Return the pending rack-local task list for a given rack, or an empty list if - * there is no map entry for that rack - */ - private def getPendingTasksForRack(rack: String): ArrayBuffer[Int] = { - pendingTasksForRack.getOrElse(rack, ArrayBuffer()) - } - - /** - * Dequeue a pending task from the given list and return its index. - * Return None if the list is empty. 
- * This method also cleans up any tasks in the list that have already - * been launched, since we want that to happen lazily. - */ - private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { - while (!list.isEmpty) { - val index = list.last - list.trimEnd(1) - if (copiesRunning(index) == 0 && !finished(index)) { - return Some(index) - } - } - return None - } - - /** Check whether a task is currently running an attempt on a given host */ - private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = { - !taskAttempts(taskIndex).exists(_.host == host) - } - - /** - * Return a speculative task for a given executor if any are available. The task should not have - * an attempt running on this host, in case the host is slow. In addition, the task should meet - * the given locality constraint. - */ - private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value) - : Option[(Int, TaskLocality.Value)] = - { - speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set - - if (!speculatableTasks.isEmpty) { - // Check for process-local or preference-less tasks; note that tasks can be process-local - // on multiple nodes when we replicate cached blocks, as in Spark Streaming - for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { - val prefs = tasks(index).preferredLocations - val executors = prefs.flatMap(_.executorId) - if (prefs.size == 0 || executors.contains(execId)) { - speculatableTasks -= index - return Some((index, TaskLocality.PROCESS_LOCAL)) - } - } - - // Check for node-local tasks - if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { - for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { - val locations = tasks(index).preferredLocations.map(_.host) - if (locations.contains(host)) { - speculatableTasks -= index - return Some((index, TaskLocality.NODE_LOCAL)) - } - } - } - - // Check for rack-local tasks - if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { - for (rack <- sched.getRackForHost(host)) { - for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { - val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost) - if (racks.contains(rack)) { - speculatableTasks -= index - return Some((index, TaskLocality.RACK_LOCAL)) - } - } - } - } - - // Check for non-local tasks - if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { - for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) { - speculatableTasks -= index - return Some((index, TaskLocality.ANY)) - } - } - } - - return None - } - - /** - * Dequeue a pending task for a given node and return its index and locality level. - * Only search for tasks matching the given locality constraint. - */ - private def findTask(execId: String, host: String, locality: TaskLocality.Value) - : Option[(Int, TaskLocality.Value)] = - { - for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) { - return Some((index, TaskLocality.PROCESS_LOCAL)) - } - - if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { - for (index <- findTaskFromList(getPendingTasksForHost(host))) { - return Some((index, TaskLocality.NODE_LOCAL)) - } - } - - if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { - for { - rack <- sched.getRackForHost(host) - index <- findTaskFromList(getPendingTasksForRack(rack)) - } { - return Some((index, TaskLocality.RACK_LOCAL)) - } - } - - // Look for no-pref tasks after rack-local tasks since they can run anywhere. 
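findTaskFromList treats each pending list as a stack: indices are popped from the end, and entries for tasks that already have a running copy or have finished are discarded lazily during the scan. A self-contained sketch of that behaviour with made-up task states:

import scala.collection.mutable.ArrayBuffer

object PendingStackSketch {
  def findTaskFromList(list: ArrayBuffer[Int],
                       copiesRunning: Array[Int],
                       finished: Array[Boolean]): Option[Int] = {
    while (!list.isEmpty) {
      val index = list.last
      list.trimEnd(1)                           // pop from the end (LIFO)
      if (copiesRunning(index) == 0 && !finished(index)) {
        return Some(index)
      }                                         // otherwise the entry is stale; keep popping
    }
    None
  }

  def main(args: Array[String]) {
    val pending  = ArrayBuffer(0, 1, 2, 3)      // task 3 was pushed last
    val running  = Array(0, 0, 0, 1)            // task 3 already has a copy running
    val finished = Array(false, false, true, false)
    println(findTaskFromList(pending, running, finished))  // Some(1): tasks 3 and 2 are skipped
    println(pending)                            // ArrayBuffer(0): stale entries were removed
  }
}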
- for (index <- findTaskFromList(pendingTasksWithNoPrefs)) { - return Some((index, TaskLocality.PROCESS_LOCAL)) - } - - if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { - for (index <- findTaskFromList(allPendingTasks)) { - return Some((index, TaskLocality.ANY)) - } - } - - // Finally, if all else has failed, find a speculative task - return findSpeculativeTask(execId, host, locality) - } - - /** - * Respond to an offer of a single slave from the scheduler by finding a task - */ - override def resourceOffer( - execId: String, - host: String, - availableCpus: Int, - maxLocality: TaskLocality.TaskLocality) - : Option[TaskDescription] = - { - if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { - val curTime = clock.getTime() - - var allowedLocality = getAllowedLocalityLevel(curTime) - if (allowedLocality > maxLocality) { - allowedLocality = maxLocality // We're not allowed to search for farther-away tasks - } - - findTask(execId, host, allowedLocality) match { - case Some((index, taskLocality)) => { - // Found a task; do some bookkeeping and return a task description - val task = tasks(index) - val taskId = sched.newTaskId() - // Figure out whether this should count as a preferred launch - logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( - taskSet.id, index, taskId, execId, host, taskLocality)) - // Do various bookkeeping - copiesRunning(index) += 1 - val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality) - taskInfos(taskId) = info - taskAttempts(index) = info :: taskAttempts(index) - // Update our locality level for delay scheduling - currentLocalityIndex = getLocalityIndex(taskLocality) - lastLaunchTime = curTime - // Serialize and return the task - val startTime = clock.getTime() - // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here - // we assume the task can be serialized without exceptions. - val serializedTask = Task.serializeWithDependencies( - task, sched.sc.addedFiles, sched.sc.addedJars, ser) - val timeTaken = clock.getTime() - startTime - increaseRunningTasks(1) - logInfo("Serialized task %s:%d as %d bytes in %d ms".format( - taskSet.id, index, serializedTask.limit, timeTaken)) - val taskName = "task %s:%d".format(taskSet.id, index) - if (taskAttempts(index).size == 1) - taskStarted(task,info) - return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask)) - } - case _ => - } - } - return None - } - - /** - * Get the level we can launch tasks according to delay scheduling, based on current wait time. - */ - private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = { - while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) && - currentLocalityIndex < myLocalityLevels.length - 1) - { - // Jump to the next locality level, and remove our waiting time for the current one since - // we don't want to count it again on the next one - lastLaunchTime += localityWaits(currentLocalityIndex) - currentLocalityIndex += 1 - } - myLocalityLevels(currentLocalityIndex) - } - - /** - * Find the index in myLocalityLevels for a given locality. This is also designed to work with - * localities that are not in myLocalityLevels (in case we somehow get those) by returning the - * next-biggest level we have. Uses the fact that the last value in myLocalityLevels is ANY. 
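getAllowedLocalityLevel implements the delay-scheduling rule: insist on the current locality level until its wait expires, then relax one level at a time, crediting the time already spent at the previous level. A sketch with hypothetical wait values:

object DelaySchedulingSketch {
  val levels = Array("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY")
  val waits  = Array(3000L, 3000L, 3000L, 0L)   // ms to wait at each level (illustrative)

  var currentIndex   = 0
  var lastLaunchTime = 0L

  /** Most restrictive locality level we may still insist on at time `now`. */
  def allowedLevel(now: Long): String = {
    while (now - lastLaunchTime >= waits(currentIndex) && currentIndex < levels.length - 1) {
      lastLaunchTime += waits(currentIndex)     // don't count the same wait twice
      currentIndex += 1
    }
    levels(currentIndex)
  }

  def main(args: Array[String]) {
    println(allowedLevel(1000))   // PROCESS_LOCAL: still inside the first 3 s wait
    println(allowedLevel(6500))   // RACK_LOCAL: two 3 s waits have elapsed without a launch
  }
}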
- */ - def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = { - var index = 0 - while (locality > myLocalityLevels(index)) { - index += 1 - } - index - } - - /** Called by cluster scheduler when one of our tasks changes state */ - override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { - SparkEnv.set(env) - state match { - case TaskState.FINISHED => - taskFinished(tid, state, serializedData) - case TaskState.LOST => - taskLost(tid, state, serializedData) - case TaskState.FAILED => - taskLost(tid, state, serializedData) - case TaskState.KILLED => - taskLost(tid, state, serializedData) - case _ => - } - } - - def taskStarted(task: Task[_], info: TaskInfo) { - sched.listener.taskStarted(task, info) - } - - def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - if (info.failed) { - // We might get two task-lost messages for the same task in coarse-grained Mesos mode, - // or even from Mesos itself when acks get delayed. - return - } - val index = info.index - info.markSuccessful() - decreaseRunningTasks(1) - if (!finished(index)) { - tasksFinished += 1 - logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format( - tid, info.duration, info.host, tasksFinished, numTasks)) - // Deserialize task result and pass it to the scheduler - try { - val result = ser.deserialize[TaskResult[_]](serializedData) - result.metrics.resultSize = serializedData.limit() - sched.listener.taskEnded( - tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) - } catch { - case cnf: ClassNotFoundException => - val loader = Thread.currentThread().getContextClassLoader - throw new SparkException("ClassNotFound with classloader: " + loader, cnf) - case ex => throw ex - } - // Mark finished and stop if we've finished all the tasks - finished(index) = true - if (tasksFinished == numTasks) { - sched.taskSetFinished(this) - } - } else { - logInfo("Ignoring task-finished event for TID " + tid + - " because task " + index + " is already finished") - } - } - - def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - if (info.failed) { - // We might get two task-lost messages for the same task in coarse-grained Mesos mode, - // or even from Mesos itself when acks get delayed. - return - } - val index = info.index - info.markFailed() - decreaseRunningTasks(1) - if (!finished(index)) { - logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) - copiesRunning(index) -= 1 - // Check if the problem is a map output fetch failure. In that case, this - // task will never succeed on any node, so tell the scheduler about it. 
- if (serializedData != null && serializedData.limit() > 0) { - val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader) - reason match { - case fetchFailed: FetchFailed => - logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress) - sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null) - finished(index) = true - tasksFinished += 1 - sched.taskSetFinished(this) - decreaseRunningTasks(runningTasks) - return - - case taskResultTooBig: TaskResultTooBigFailure => - logInfo("Loss was due to task %s result exceeding Akka frame size; aborting job".format( - tid)) - abort("Task %s result exceeded Akka frame size".format(tid)) - return - - case ef: ExceptionFailure => - sched.listener.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) - val key = ef.description - val now = clock.getTime() - val (printFull, dupCount) = { - if (recentExceptions.contains(key)) { - val (dupCount, printTime) = recentExceptions(key) - if (now - printTime > EXCEPTION_PRINT_INTERVAL) { - recentExceptions(key) = (0, now) - (true, 0) - } else { - recentExceptions(key) = (dupCount + 1, printTime) - (false, dupCount + 1) - } - } else { - recentExceptions(key) = (0, now) - (true, 0) - } - } - if (printFull) { - val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logInfo("Loss was due to %s\n%s\n%s".format( - ef.className, ef.description, locs.mkString("\n"))) - } else { - logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) - } - - case _ => {} - } - } - // On non-fetch failures, re-enqueue the task as pending for a max number of retries - addPendingTask(index) - // Count failed attempts only on FAILED and LOST state (not on KILLED) - if (state == TaskState.FAILED || state == TaskState.LOST) { - numFailures(index) += 1 - if (numFailures(index) > MAX_TASK_FAILURES) { - logError("Task %s:%d failed more than %d times; aborting job".format( - taskSet.id, index, MAX_TASK_FAILURES)) - abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) - } - } - } else { - logInfo("Ignoring task-lost event for TID " + tid + - " because task " + index + " is already finished") - } - } - - override def error(message: String) { - // Save the error message - abort("Error: " + message) - } - - def abort(message: String) { - failed = true - causeOfFailure = message - // TODO: Kill running tasks if we were not terminated due to a Mesos error - sched.listener.taskSetFailed(taskSet, message) - decreaseRunningTasks(runningTasks) - sched.taskSetFinished(this) - } - - override def increaseRunningTasks(taskNum: Int) { - runningTasks += taskNum - if (parent != null) { - parent.increaseRunningTasks(taskNum) - } - } - - override def decreaseRunningTasks(taskNum: Int) { - runningTasks -= taskNum - if (parent != null) { - parent.decreaseRunningTasks(taskNum) - } - } - - override def getSchedulableByName(name: String): Schedulable = { - return null - } - - override def addSchedulable(schedulable: Schedulable) {} - - override def removeSchedulable(schedulable: Schedulable) {} - - override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { - var sortedTaskSetQueue = ArrayBuffer[TaskSetManager](this) - sortedTaskSetQueue += this - return sortedTaskSetQueue - } - - /** Called by cluster scheduler when an executor is lost so we can re-enqueue our tasks */ - override def executorLost(execId: String, host: String) { - logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) - - // 
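The recentExceptions map above throttles log output: a recurring failure is printed in full at most once per EXCEPTION_PRINT_INTERVAL and merely counted in between. A standalone sketch of that bookkeeping; the interval and the exception key are illustrative.

import scala.collection.mutable.HashMap

object ExceptionThrottleSketch {
  val printInterval = 10000L                         // ms, like spark.logging.exceptionPrintInterval
  val recentExceptions = HashMap[String, (Int, Long)]()

  /** Returns (printFull, duplicateCount) for an exception key seen at time `now`. */
  def shouldPrint(key: String, now: Long): (Boolean, Int) = {
    recentExceptions.get(key) match {
      case Some((dupCount, printTime)) if now - printTime > printInterval =>
        recentExceptions(key) = (0, now)             // window expired: print in full again
        (true, 0)
      case Some((dupCount, printTime)) =>
        recentExceptions(key) = (dupCount + 1, printTime)
        (false, dupCount + 1)                        // suppressed duplicate
      case None =>
        recentExceptions(key) = (0, now)             // first sighting: print in full
        (true, 0)
    }
  }

  def main(args: Array[String]) {
    println(shouldPrint("java.io.IOException: disk full", 0L))      // (true,0)
    println(shouldPrint("java.io.IOException: disk full", 4000L))   // (false,1)
    println(shouldPrint("java.io.IOException: disk full", 15000L))  // (true,0)
  }
}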
Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a - // task that used to have locations on only this host might now go to the no-prefs list. Note - // that it's okay if we add a task to the same queue twice (if it had multiple preferred - // locations), because findTaskFromList will skip already-running tasks. - for (index <- getPendingTasksForExecutor(execId)) { - addPendingTask(index, readding=true) - } - for (index <- getPendingTasksForHost(host)) { - addPendingTask(index, readding=true) - } - - // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage - if (tasks(0).isInstanceOf[ShuffleMapTask]) { - for ((tid, info) <- taskInfos if info.executorId == execId) { - val index = taskInfos(tid).index - if (finished(index)) { - finished(index) = false - copiesRunning(index) -= 1 - tasksFinished -= 1 - addPendingTask(index) - // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our - // stage finishes when a total of tasks.size tasks finish. - sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null) - } - } - } - // Also re-enqueue any tasks that were running on the node - for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { - taskLost(tid, TaskState.KILLED, null) - } - } - - /** - * Check for tasks to be speculated and return true if there are any. This is called periodically - * by the ClusterScheduler. - * - * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that - * we don't scan the whole task set. It might also help to make this sorted by launch time. - */ - override def checkSpeculatableTasks(): Boolean = { - // Can't speculate if we only have one task, or if all tasks have finished. - if (numTasks == 1 || tasksFinished == numTasks) { - return false - } - var foundTasks = false - val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt - logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) - if (tasksFinished >= minFinishedForSpeculation) { - val time = clock.getTime() - val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray - Arrays.sort(durations) - val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1)) - val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100) - // TODO: Threshold should also look at standard deviation of task durations and have a lower - // bound based on that. 
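checkSpeculatableTasks only starts speculating once SPECULATION_QUANTILE of the tasks have finished, and then flags copies that have been running longer than SPECULATION_MULTIPLIER times the median successful duration, with a 100 ms floor. A sketch of that threshold computation with made-up durations:

import scala.math.{max, min}

object SpeculationThresholdSketch {
  val quantile   = 0.75   // like spark.speculation.quantile
  val multiplier = 1.5    // like spark.speculation.multiplier

  def threshold(finishedDurations: Array[Long], numTasks: Int): Option[Double] = {
    val minFinished = (quantile * numTasks).floor.toInt
    if (finishedDurations.length < minFinished) return None   // not enough evidence yet
    val sorted = finishedDurations.sorted
    val median = sorted(min((0.5 * numTasks).round.toInt, sorted.length - 1))
    Some(max(multiplier * median, 100))
  }

  def main(args: Array[String]) {
    // 8 of 10 tasks are done; any copy still running longer than the threshold is a candidate.
    val done = Array(200L, 220L, 210L, 250L, 230L, 240L, 260L, 205L)
    println(threshold(done, numTasks = 10))   // Some(360.0): 1.5 times the 240 ms median
  }
}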
- logDebug("Task length threshold for speculation: " + threshold) - for ((tid, info) <- taskInfos) { - val index = info.index - if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && - !speculatableTasks.contains(index)) { - logInfo( - "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( - taskSet.id, index, info.host, threshold)) - speculatableTasks += index - foundTasks = true - } - } - } - return foundTasks - } - - override def hasPendingTasks(): Boolean = { - numTasks > 0 && tasksFinished < numTasks - } - - private def getLocalityWait(level: TaskLocality.TaskLocality): Long = { - val defaultWait = System.getProperty("spark.locality.wait", "3000") - level match { - case TaskLocality.PROCESS_LOCAL => - System.getProperty("spark.locality.wait.process", defaultWait).toLong - case TaskLocality.NODE_LOCAL => - System.getProperty("spark.locality.wait.node", defaultWait).toLong - case TaskLocality.RACK_LOCAL => - System.getProperty("spark.locality.wait.rack", defaultWait).toLong - case TaskLocality.ANY => - 0L - } - } - - /** - * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been - * added to queues using addPendingTask. - */ - private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = { - import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY} - val levels = new ArrayBuffer[TaskLocality.TaskLocality] - if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) { - levels += PROCESS_LOCAL - } - if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) { - levels += NODE_LOCAL - } - if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) { - levels += RACK_LOCAL - } - levels += ANY - logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", ")) - levels.toArray - } -} diff --git a/core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala b/core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala deleted file mode 100644 index 8825f2dd24..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import spark.executor.ExecutorExitCode - -/** - * Represents an explanation for a executor or whole slave failing or exiting. 
- */ -private[spark] -class ExecutorLossReason(val message: String) { - override def toString: String = message -} - -private[spark] -case class ExecutorExited(val exitCode: Int) - extends ExecutorLossReason(ExecutorExitCode.explainExitCode(exitCode)) { -} - -private[spark] -case class SlaveLost(_message: String = "Slave lost") - extends ExecutorLossReason(_message) { -} diff --git a/core/src/main/scala/spark/scheduler/cluster/Pool.scala b/core/src/main/scala/spark/scheduler/cluster/Pool.scala deleted file mode 100644 index 83708f07e1..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/Pool.scala +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap - -import spark.Logging -import spark.scheduler.cluster.SchedulingMode.SchedulingMode - -/** - * An Schedulable entity that represent collection of Pools or TaskSetManagers - */ - -private[spark] class Pool( - val poolName: String, - val schedulingMode: SchedulingMode, - initMinShare: Int, - initWeight: Int) - extends Schedulable - with Logging { - - var schedulableQueue = new ArrayBuffer[Schedulable] - var schedulableNameToSchedulable = new HashMap[String, Schedulable] - - var weight = initWeight - var minShare = initMinShare - var runningTasks = 0 - - var priority = 0 - var stageId = 0 - var name = poolName - var parent:Schedulable = null - - var taskSetSchedulingAlgorithm: SchedulingAlgorithm = { - schedulingMode match { - case SchedulingMode.FAIR => - new FairSchedulingAlgorithm() - case SchedulingMode.FIFO => - new FIFOSchedulingAlgorithm() - } - } - - override def addSchedulable(schedulable: Schedulable) { - schedulableQueue += schedulable - schedulableNameToSchedulable(schedulable.name) = schedulable - schedulable.parent= this - } - - override def removeSchedulable(schedulable: Schedulable) { - schedulableQueue -= schedulable - schedulableNameToSchedulable -= schedulable.name - } - - override def getSchedulableByName(schedulableName: String): Schedulable = { - if (schedulableNameToSchedulable.contains(schedulableName)) { - return schedulableNameToSchedulable(schedulableName) - } - for (schedulable <- schedulableQueue) { - var sched = schedulable.getSchedulableByName(schedulableName) - if (sched != null) { - return sched - } - } - return null - } - - override def executorLost(executorId: String, host: String) { - schedulableQueue.foreach(_.executorLost(executorId, host)) - } - - override def checkSpeculatableTasks(): Boolean = { - var shouldRevive = false - for (schedulable <- schedulableQueue) { - shouldRevive |= schedulable.checkSpeculatableTasks() - } - return shouldRevive - } - - override def getSortedTaskSetQueue(): 
ArrayBuffer[TaskSetManager] = { - var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] - val sortedSchedulableQueue = schedulableQueue.sortWith(taskSetSchedulingAlgorithm.comparator) - for (schedulable <- sortedSchedulableQueue) { - sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue() - } - return sortedTaskSetQueue - } - - override def increaseRunningTasks(taskNum: Int) { - runningTasks += taskNum - if (parent != null) { - parent.increaseRunningTasks(taskNum) - } - } - - override def decreaseRunningTasks(taskNum: Int) { - runningTasks -= taskNum - if (parent != null) { - parent.decreaseRunningTasks(taskNum) - } - } - - override def hasPendingTasks(): Boolean = { - schedulableQueue.exists(_.hasPendingTasks()) - } -} diff --git a/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala b/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala deleted file mode 100644 index e77e8e4162..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import spark.scheduler.cluster.SchedulingMode.SchedulingMode - -import scala.collection.mutable.ArrayBuffer -/** - * An interface for schedulable entities. - * there are two type of Schedulable entities(Pools and TaskSetManagers) - */ -private[spark] trait Schedulable { - var parent: Schedulable - // child queues - def schedulableQueue: ArrayBuffer[Schedulable] - def schedulingMode: SchedulingMode - def weight: Int - def minShare: Int - def runningTasks: Int - def priority: Int - def stageId: Int - def name: String - - def increaseRunningTasks(taskNum: Int): Unit - def decreaseRunningTasks(taskNum: Int): Unit - def addSchedulable(schedulable: Schedulable): Unit - def removeSchedulable(schedulable: Schedulable): Unit - def getSchedulableByName(name: String): Schedulable - def executorLost(executorId: String, host: String): Unit - def checkSpeculatableTasks(): Boolean - def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] - def hasPendingTasks(): Boolean -} diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala deleted file mode 100644 index 2fc8a76a05..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import java.io.{File, FileInputStream, FileOutputStream, FileNotFoundException} -import java.util.Properties - -import scala.xml.XML - -import spark.Logging -import spark.scheduler.cluster.SchedulingMode.SchedulingMode - - -/** - * An interface to build Schedulable tree - * buildPools: build the tree nodes(pools) - * addTaskSetManager: build the leaf nodes(TaskSetManagers) - */ -private[spark] trait SchedulableBuilder { - def buildPools() - def addTaskSetManager(manager: Schedulable, properties: Properties) -} - -private[spark] class FIFOSchedulableBuilder(val rootPool: Pool) - extends SchedulableBuilder with Logging { - - override def buildPools() { - // nothing - } - - override def addTaskSetManager(manager: Schedulable, properties: Properties) { - rootPool.addSchedulable(manager) - } -} - -private[spark] class FairSchedulableBuilder(val rootPool: Pool) - extends SchedulableBuilder with Logging { - - val schedulerAllocFile = System.getProperty("spark.fairscheduler.allocation.file") - val FAIR_SCHEDULER_PROPERTIES = "spark.scheduler.cluster.fair.pool" - val DEFAULT_POOL_NAME = "default" - val MINIMUM_SHARES_PROPERTY = "minShare" - val SCHEDULING_MODE_PROPERTY = "schedulingMode" - val WEIGHT_PROPERTY = "weight" - val POOL_NAME_PROPERTY = "@name" - val POOLS_PROPERTY = "pool" - val DEFAULT_SCHEDULING_MODE = SchedulingMode.FIFO - val DEFAULT_MINIMUM_SHARE = 2 - val DEFAULT_WEIGHT = 1 - - override def buildPools() { - if (schedulerAllocFile != null) { - val file = new File(schedulerAllocFile) - if (file.exists()) { - val xml = XML.loadFile(file) - for (poolNode <- (xml \\ POOLS_PROPERTY)) { - - val poolName = (poolNode \ POOL_NAME_PROPERTY).text - var schedulingMode = DEFAULT_SCHEDULING_MODE - var minShare = DEFAULT_MINIMUM_SHARE - var weight = DEFAULT_WEIGHT - - val xmlSchedulingMode = (poolNode \ SCHEDULING_MODE_PROPERTY).text - if (xmlSchedulingMode != "") { - try { - schedulingMode = SchedulingMode.withName(xmlSchedulingMode) - } catch { - case e: Exception => logInfo("Error xml schedulingMode, using default schedulingMode") - } - } - - val xmlMinShare = (poolNode \ MINIMUM_SHARES_PROPERTY).text - if (xmlMinShare != "") { - minShare = xmlMinShare.toInt - } - - val xmlWeight = (poolNode \ WEIGHT_PROPERTY).text - if (xmlWeight != "") { - weight = xmlWeight.toInt - } - - val pool = new Pool(poolName, schedulingMode, minShare, weight) - rootPool.addSchedulable(pool) - logInfo("Created pool %s, schedulingMode: %s, minShare: %d, weight: %d".format( - poolName, schedulingMode, minShare, weight)) - } - } else { - throw new java.io.FileNotFoundException( - "Fair scheduler allocation file not found: " + schedulerAllocFile) - } - } - - // finally create "default" pool - if (rootPool.getSchedulableByName(DEFAULT_POOL_NAME) == null) { - val pool = new Pool(DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, - DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) - rootPool.addSchedulable(pool) - logInfo("Created default pool %s, schedulingMode: %s, minShare: %d, weight: %d".format( - DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)) - 
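For reference, this is the allocation-file shape that FairSchedulableBuilder parses, reconstructed from the element and attribute names used above and loaded here from an inline string; the root element, pool names, and values are invented for illustration. In the code above the file path comes from spark.fairscheduler.allocation.file, blank fields fall back to the FIFO/minShare/weight defaults, and a job selects its pool through the spark.scheduler.cluster.fair.pool property.

import scala.xml.XML

object FairPoolConfigSketch {
  val allocations =
    """<allocations>
      |  <pool name="production">
      |    <schedulingMode>FAIR</schedulingMode>
      |    <minShare>2</minShare>
      |    <weight>3</weight>
      |  </pool>
      |  <pool name="adhoc">
      |    <schedulingMode>FIFO</schedulingMode>
      |  </pool>
      |</allocations>""".stripMargin

  def main(args: Array[String]) {
    val xml = XML.loadString(allocations)
    for (poolNode <- (xml \\ "pool")) {
      val name     = (poolNode \ "@name").text
      val mode     = (poolNode \ "schedulingMode").text
      val minShare = (poolNode \ "minShare").text      // empty for "adhoc": builder uses its default
      val weight   = (poolNode \ "weight").text
      println("pool=" + name + " mode=" + mode + " minShare=" + minShare + " weight=" + weight)
    }
  }
}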
} - } - - override def addTaskSetManager(manager: Schedulable, properties: Properties) { - var poolName = DEFAULT_POOL_NAME - var parentPool = rootPool.getSchedulableByName(poolName) - if (properties != null) { - poolName = properties.getProperty(FAIR_SCHEDULER_PROPERTIES, DEFAULT_POOL_NAME) - parentPool = rootPool.getSchedulableByName(poolName) - if (parentPool == null) { - // we will create a new pool that user has configured in app - // instead of being defined in xml file - parentPool = new Pool(poolName, DEFAULT_SCHEDULING_MODE, - DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) - rootPool.addSchedulable(parentPool) - logInfo("Created pool %s, schedulingMode: %s, minShare: %d, weight: %d".format( - poolName, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)) - } - } - parentPool.addSchedulable(manager) - logInfo("Added task set " + manager.name + " tasks to pool "+poolName) - } -} diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala deleted file mode 100644 index 4431744ec3..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import spark.{SparkContext, Utils} - -/** - * A backend interface for cluster scheduling systems that allows plugging in different ones under - * ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as - * machines become available and can launch tasks on them. - */ -private[spark] trait SchedulerBackend { - def start(): Unit - def stop(): Unit - def reviveOffers(): Unit - def defaultParallelism(): Int - - // Memory used by each executor (in megabytes) - protected val executorMemory: Int = SparkContext.executorMemoryRequested - - // TODO: Probably want to add a killTask too -} diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala deleted file mode 100644 index 69e0ac2a6b..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -/** - * An interface for sort algorithm - * FIFO: FIFO algorithm between TaskSetManagers - * FS: FS algorithm between Pools, and FIFO or FS within Pools - */ -private[spark] trait SchedulingAlgorithm { - def comparator(s1: Schedulable, s2: Schedulable): Boolean -} - -private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm { - override def comparator(s1: Schedulable, s2: Schedulable): Boolean = { - val priority1 = s1.priority - val priority2 = s2.priority - var res = math.signum(priority1 - priority2) - if (res == 0) { - val stageId1 = s1.stageId - val stageId2 = s2.stageId - res = math.signum(stageId1 - stageId2) - } - if (res < 0) { - return true - } else { - return false - } - } -} - -private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { - override def comparator(s1: Schedulable, s2: Schedulable): Boolean = { - val minShare1 = s1.minShare - val minShare2 = s2.minShare - val runningTasks1 = s1.runningTasks - val runningTasks2 = s2.runningTasks - val s1Needy = runningTasks1 < minShare1 - val s2Needy = runningTasks2 < minShare2 - val minShareRatio1 = runningTasks1.toDouble / math.max(minShare1, 1.0).toDouble - val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0).toDouble - val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble - val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble - var res:Boolean = true - var compare:Int = 0 - - if (s1Needy && !s2Needy) { - return true - } else if (!s1Needy && s2Needy) { - return false - } else if (s1Needy && s2Needy) { - compare = minShareRatio1.compareTo(minShareRatio2) - } else { - compare = taskToWeightRatio1.compareTo(taskToWeightRatio2) - } - - if (compare < 0) { - return true - } else if (compare > 0) { - return false - } else { - return s1.name < s2.name - } - } -} - diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala deleted file mode 100644 index 55cdf4791f..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
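The fair comparator above prefers whichever schedulable is furthest below its minimum share; among saturated schedulables the lower running-tasks-to-weight ratio wins, and ties fall back to name order. A sketch of the same ordering with concrete numbers (pool names and counts are invented):

object FairComparatorSketch {
  case class Sched(name: String, minShare: Int, weight: Int, runningTasks: Int)

  def lessThan(s1: Sched, s2: Sched): Boolean = {
    val s1Needy = s1.runningTasks < s1.minShare
    val s2Needy = s2.runningTasks < s2.minShare
    val minShareRatio1 = s1.runningTasks.toDouble / math.max(s1.minShare, 1.0)
    val minShareRatio2 = s2.runningTasks.toDouble / math.max(s2.minShare, 1.0)
    val taskToWeight1  = s1.runningTasks.toDouble / s1.weight
    val taskToWeight2  = s2.runningTasks.toDouble / s2.weight
    if (s1Needy && !s2Needy) true                     // below minShare always wins
    else if (!s1Needy && s2Needy) false
    else if (s1Needy && s2Needy) {
      if (minShareRatio1 != minShareRatio2) minShareRatio1 < minShareRatio2 else s1.name < s2.name
    } else {
      if (taskToWeight1 != taskToWeight2) taskToWeight1 < taskToWeight2 else s1.name < s2.name
    }
  }

  def main(args: Array[String]) {
    val starved = Sched("starved", minShare = 2, weight = 1, runningTasks = 0)
    val busy    = Sched("busy",    minShare = 2, weight = 4, runningTasks = 10)
    val heavy   = Sched("heavy",   minShare = 2, weight = 1, runningTasks = 10)
    // starved is below its minShare, so it schedules first; between the saturated pools,
    // busy wins because its tasks-to-weight ratio (2.5) is lower than heavy's (10.0).
    println(List(busy, heavy, starved).sortWith(lessThan).map(_.name))
  }
}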
- */ - -package spark.scheduler.cluster - -/** - * "FAIR" and "FIFO" determines which policy is used - * to order tasks amongst a Schedulable's sub-queues - * "NONE" is used when the a Schedulable has no sub-queues. - */ -object SchedulingMode extends Enumeration("FAIR", "FIFO", "NONE") { - - type SchedulingMode = Value - val FAIR,FIFO,NONE = Value -} diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala deleted file mode 100644 index 7ac574bdc8..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import spark.{Utils, Logging, SparkContext} -import spark.deploy.client.{Client, ClientListener} -import spark.deploy.{Command, ApplicationDescription} -import scala.collection.mutable.HashMap - -private[spark] class SparkDeploySchedulerBackend( - scheduler: ClusterScheduler, - sc: SparkContext, - master: String, - appName: String) - extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem) - with ClientListener - with Logging { - - var client: Client = null - var stopping = false - var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ - - val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt - - override def start() { - super.start() - - // The endpoint for executors to talk to us - val driverUrl = "akka://spark@%s:%s/user/%s".format( - System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), - StandaloneSchedulerBackend.ACTOR_NAME) - val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}") - val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) - val sparkHome = sc.getSparkHome().getOrElse(null) - val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome, - sc.ui.appUIAddress) - - client = new Client(sc.env.actorSystem, master, appDesc, this) - client.start() - } - - override def stop() { - stopping = true - super.stop() - client.stop() - if (shutdownCallback != null) { - shutdownCallback(this) - } - } - - override def connected(appId: String) { - logInfo("Connected to Spark cluster with app ID " + appId) - } - - override def disconnected() { - if (!stopping) { - logError("Disconnected from Spark cluster!") - scheduler.error("Disconnected from Spark cluster") - } - } - - override def executorAdded(executorId: String, workerId: String, hostPort: String, cores: Int, memory: Int) { - logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format( - executorId, hostPort, cores, 
Utils.megabytesToString(memory))) - } - - override def executorRemoved(executorId: String, message: String, exitStatus: Option[Int]) { - val reason: ExecutorLossReason = exitStatus match { - case Some(code) => ExecutorExited(code) - case None => SlaveLost(message) - } - logInfo("Executor %s removed: %s".format(executorId, message)) - removeExecutor(executorId, reason.toString) - } -} diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala deleted file mode 100644 index 05c29eb72f..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import java.nio.ByteBuffer - -import spark.TaskState.TaskState -import spark.Utils -import spark.util.SerializableBuffer - - -private[spark] sealed trait StandaloneClusterMessage extends Serializable - -private[spark] object StandaloneClusterMessages { - - // Driver to executors - case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage - - case class RegisteredExecutor(sparkProperties: Seq[(String, String)]) - extends StandaloneClusterMessage - - case class RegisterExecutorFailed(message: String) extends StandaloneClusterMessage - - // Executors to driver - case class RegisterExecutor(executorId: String, hostPort: String, cores: Int) - extends StandaloneClusterMessage { - Utils.checkHostPort(hostPort, "Expected host port") - } - - case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, - data: SerializableBuffer) extends StandaloneClusterMessage - - object StatusUpdate { - /** Alternate factory method that takes a ByteBuffer directly for the data field */ - def apply(executorId: String, taskId: Long, state: TaskState, data: ByteBuffer) - : StatusUpdate = { - StatusUpdate(executorId, taskId, state, new SerializableBuffer(data)) - } - } - - // Internal messages in driver - case object ReviveOffers extends StandaloneClusterMessage - - case object StopDriver extends StandaloneClusterMessage - - case class RemoveExecutor(executorId: String, reason: String) extends StandaloneClusterMessage - -} diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala deleted file mode 100644 index 3203be1029..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} - -import akka.actor._ -import akka.dispatch.Await -import akka.pattern.ask -import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent} -import akka.util.Duration -import akka.util.duration._ - -import spark.{Utils, SparkException, Logging, TaskState} -import spark.scheduler.cluster.StandaloneClusterMessages._ - -/** - * A standalone scheduler backend, which waits for standalone executors to connect to it through - * Akka. These may be executed in a variety of ways, such as Mesos tasks for the coarse-grained - * Mesos mode or standalone processes for Spark's standalone deploy mode (spark.deploy.*). - */ -private[spark] -class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem) - extends SchedulerBackend with Logging -{ - // Use an atomic variable to track total number of cores in the cluster for simplicity and speed - var totalCoreCount = new AtomicInteger(0) - - class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor { - private val executorActor = new HashMap[String, ActorRef] - private val executorAddress = new HashMap[String, Address] - private val executorHost = new HashMap[String, String] - private val freeCores = new HashMap[String, Int] - private val actorToExecutorId = new HashMap[ActorRef, String] - private val addressToExecutorId = new HashMap[Address, String] - - override def preStart() { - // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - - // Periodically revive offers to allow delay scheduling to work - val reviveInterval = System.getProperty("spark.scheduler.revive.interval", "1000").toLong - context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) - } - - def receive = { - case RegisterExecutor(executorId, hostPort, cores) => - Utils.checkHostPort(hostPort, "Host port expected " + hostPort) - if (executorActor.contains(executorId)) { - sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) - } else { - logInfo("Registered executor: " + sender + " with ID " + executorId) - sender ! 
RegisteredExecutor(sparkProperties) - context.watch(sender) - executorActor(executorId) = sender - executorHost(executorId) = Utils.parseHostPort(hostPort)._1 - freeCores(executorId) = cores - executorAddress(executorId) = sender.path.address - actorToExecutorId(sender) = executorId - addressToExecutorId(sender.path.address) = executorId - totalCoreCount.addAndGet(cores) - makeOffers() - } - - case StatusUpdate(executorId, taskId, state, data) => - scheduler.statusUpdate(taskId, state, data.value) - if (TaskState.isFinished(state)) { - freeCores(executorId) += 1 - makeOffers(executorId) - } - - case ReviveOffers => - makeOffers() - - case StopDriver => - sender ! true - context.stop(self) - - case RemoveExecutor(executorId, reason) => - removeExecutor(executorId, reason) - sender ! true - - case Terminated(actor) => - actorToExecutorId.get(actor).foreach(removeExecutor(_, "Akka actor terminated")) - - case RemoteClientDisconnected(transport, address) => - addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disconnected")) - - case RemoteClientShutdown(transport, address) => - addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client shutdown")) - } - - // Make fake resource offers on all executors - def makeOffers() { - launchTasks(scheduler.resourceOffers( - executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))})) - } - - // Make fake resource offers on just one executor - def makeOffers(executorId: String) { - launchTasks(scheduler.resourceOffers( - Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId))))) - } - - // Launch tasks returned by a set of resource offers - def launchTasks(tasks: Seq[Seq[TaskDescription]]) { - for (task <- tasks.flatten) { - freeCores(task.executorId) -= 1 - executorActor(task.executorId) ! LaunchTask(task) - } - } - - // Remove a disconnected slave from the cluster - def removeExecutor(executorId: String, reason: String) { - if (executorActor.contains(executorId)) { - logInfo("Executor " + executorId + " disconnected, so removing it") - val numCores = freeCores(executorId) - actorToExecutorId -= executorActor(executorId) - addressToExecutorId -= executorAddress(executorId) - executorActor -= executorId - executorHost -= executorId - freeCores -= executorId - totalCoreCount.addAndGet(-numCores) - scheduler.executorLost(executorId, SlaveLost(reason)) - } - } - } - - var driverActor: ActorRef = null - val taskIdsOnSlave = new HashMap[String, HashSet[String]] - - override def start() { - val properties = new ArrayBuffer[(String, String)] - val iterator = System.getProperties.entrySet.iterator - while (iterator.hasNext) { - val entry = iterator.next - val (key, value) = (entry.getKey.toString, entry.getValue.toString) - if (key.startsWith("spark.") && !key.equals("spark.hostPort")) { - properties += ((key, value)) - } - } - driverActor = actorSystem.actorOf( - Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME) - } - - private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") - - override def stop() { - try { - if (driverActor != null) { - val future = driverActor.ask(StopDriver)(timeout) - Await.result(future, timeout) - } - } catch { - case e: Exception => - throw new SparkException("Error stopping standalone scheduler's driver actor", e) - } - } - - override def reviveOffers() { - driverActor ! 
ReviveOffers - } - - override def defaultParallelism() = Option(System.getProperty("spark.default.parallelism")) - .map(_.toInt).getOrElse(math.max(totalCoreCount.get(), 2)) - - // Called by subclasses when notified of a lost worker - def removeExecutor(executorId: String, reason: String) { - try { - val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout) - Await.result(future, timeout) - } catch { - case e: Exception => - throw new SparkException("Error notifying standalone scheduler's driver actor", e) - } - } -} - -private[spark] object StandaloneSchedulerBackend { - val ACTOR_NAME = "StandaloneScheduler" -} diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala deleted file mode 100644 index 187553233f..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import java.nio.ByteBuffer -import spark.util.SerializableBuffer - -private[spark] class TaskDescription( - val taskId: Long, - val executorId: String, - val name: String, - val index: Int, // Index within this task's TaskSet - _serializedTask: ByteBuffer) - extends Serializable { - - // Because ByteBuffers are not serializable, wrap the task in a SerializableBuffer - private val buffer = new SerializableBuffer(_serializedTask) - - def serializedTask: ByteBuffer = buffer.value - - override def toString: String = "TaskDescription(TID=%d, index=%d)".format(taskId, index) -} diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala deleted file mode 100644 index c2c5522686..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.scheduler.cluster - -import spark.Utils - -/** - * Information about a running task attempt inside a TaskSet. - */ -private[spark] -class TaskInfo( - val taskId: Long, - val index: Int, - val launchTime: Long, - val executorId: String, - val host: String, - val taskLocality: TaskLocality.TaskLocality) { - - var finishTime: Long = 0 - var failed = false - - def markSuccessful(time: Long = System.currentTimeMillis) { - finishTime = time - } - - def markFailed(time: Long = System.currentTimeMillis) { - finishTime = time - failed = true - } - - def finished: Boolean = finishTime != 0 - - def successful: Boolean = finished && !failed - - def running: Boolean = !finished - - def status: String = { - if (running) - "RUNNING" - else if (failed) - "FAILED" - else if (successful) - "SUCCESS" - else - "UNKNOWN" - } - - def duration: Long = { - if (!finished) { - throw new UnsupportedOperationException("duration() called on unfinished tasks") - } else { - finishTime - launchTime - } - } - - def timeRunning(currentTime: Long): Long = currentTime - launchTime -} diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskLocality.scala b/core/src/main/scala/spark/scheduler/cluster/TaskLocality.scala deleted file mode 100644 index 1c33e41f87..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/TaskLocality.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - - -private[spark] object TaskLocality - extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") -{ - // process local is expected to be used ONLY within tasksetmanager for now. - val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value - - type TaskLocality = Value - - def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = { - condition <= constraint - } -} diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala deleted file mode 100644 index 0248830b7a..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
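TaskLocality relies on Enumeration values being ordered by declaration, so isAllowed reduces to condition <= constraint: a relaxed constraint such as ANY admits every stricter level. A small sketch of that ordering trick, written with the plain Enumeration constructor rather than the string-naming one used above:

object TaskLocalitySketch {
  object TaskLocality extends Enumeration {
    val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
    // A task may run under `constraint` if its own locality `condition` is at least as strict.
    def isAllowed(constraint: Value, condition: Value): Boolean = condition <= constraint
  }
  import TaskLocality._

  def main(args: Array[String]) {
    println(isAllowed(constraint = RACK_LOCAL, condition = NODE_LOCAL)) // true
    println(isAllowed(constraint = NODE_LOCAL, condition = RACK_LOCAL)) // false: would relax
                                                                        // locality beyond the caller's limit
  }
}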
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import java.nio.ByteBuffer - -import spark.TaskState.TaskState -import spark.scheduler.TaskSet - -/** - * Tracks and schedules the tasks within a single TaskSet. This class keeps track of the status of - * each task and is responsible for retries on failure and locality. The main interfaces to it - * are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, and - * statusUpdate, which tells it that one of its tasks changed state (e.g. finished). - * - * THREADING: This class is designed to only be called from code with a lock on the TaskScheduler - * (e.g. its event handlers). It should not be called from other threads. - */ -private[spark] trait TaskSetManager extends Schedulable { - def schedulableQueue = null - - def schedulingMode = SchedulingMode.NONE - - def taskSet: TaskSet - - def resourceOffer( - execId: String, - host: String, - availableCpus: Int, - maxLocality: TaskLocality.TaskLocality) - : Option[TaskDescription] - - def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) - - def error(message: String) -} diff --git a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala deleted file mode 100644 index 1d09bd9b03..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -/** - * Represents free resources available on an executor. - */ -private[spark] -class WorkerOffer(val executorId: String, val host: String, val cores: Int) diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala deleted file mode 100644 index 5be4dbd9f0..0000000000 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ /dev/null @@ -1,272 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.local - -import java.io.File -import java.lang.management.ManagementFactory -import java.util.concurrent.atomic.AtomicInteger -import java.nio.ByteBuffer - -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet - -import spark._ -import spark.TaskState.TaskState -import spark.executor.ExecutorURLClassLoader -import spark.scheduler._ -import spark.scheduler.cluster._ -import spark.scheduler.cluster.SchedulingMode.SchedulingMode -import akka.actor._ - -/** - * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally - * the scheduler also allows each task to fail up to maxFailures times, which is useful for - * testing fault recovery. - */ - -private[spark] -case class LocalReviveOffers() - -private[spark] -case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) - -private[spark] -class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging { - - def receive = { - case LocalReviveOffers => - launchTask(localScheduler.resourceOffer(freeCores)) - case LocalStatusUpdate(taskId, state, serializeData) => - freeCores += 1 - localScheduler.statusUpdate(taskId, state, serializeData) - launchTask(localScheduler.resourceOffer(freeCores)) - } - - def launchTask(tasks : Seq[TaskDescription]) { - for (task <- tasks) { - freeCores -= 1 - localScheduler.threadPool.submit(new Runnable { - def run() { - localScheduler.runTask(task.taskId, task.serializedTask) - } - }) - } - } -} - -private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext) - extends TaskScheduler - with Logging { - - var attemptId = new AtomicInteger(0) - var threadPool = Utils.newDaemonFixedThreadPool(threads) - val env = SparkEnv.get - var listener: TaskSchedulerListener = null - - // Application dependencies (added through SparkContext) that we've fetched so far on this node. - // Each map holds the master's timestamp for the version of that file or JAR we got. 
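// [Editor's sketch, not part of the original patch] The two timestamp maps declared
// below implement a simple "fetch only if newer" cache: a dependency is re-fetched only
// when the master advertises a strictly newer timestamp than the one recorded locally,
// mirroring the getOrElse(-1L) comparison in updateDependencies() further down. A
// minimal standalone illustration, with illustrative names:
import scala.collection.mutable.HashMap

val fetchedTimestamps = new HashMap[String, Long]() // dependency name -> timestamp we hold

def needsFetch(name: String, masterTimestamp: Long): Boolean =
  fetchedTimestamps.getOrElse(name, -1L) < masterTimestamp // -1L means "never fetched"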
- val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() - val currentJars: HashMap[String, Long] = new HashMap[String, Long]() - - val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader) - - var schedulableBuilder: SchedulableBuilder = null - var rootPool: Pool = null - val schedulingMode: SchedulingMode = SchedulingMode.withName( - System.getProperty("spark.cluster.schedulingmode", "FIFO")) - val activeTaskSets = new HashMap[String, TaskSetManager] - val taskIdToTaskSetId = new HashMap[Long, String] - val taskSetTaskIds = new HashMap[String, HashSet[Long]] - - var localActor: ActorRef = null - - override def start() { - // temporarily set rootPool name to empty - rootPool = new Pool("", schedulingMode, 0, 0) - schedulableBuilder = { - schedulingMode match { - case SchedulingMode.FIFO => - new FIFOSchedulableBuilder(rootPool) - case SchedulingMode.FAIR => - new FairSchedulableBuilder(rootPool) - } - } - schedulableBuilder.buildPools() - - localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test") - } - - override def setListener(listener: TaskSchedulerListener) { - this.listener = listener - } - - override def submitTasks(taskSet: TaskSet) { - synchronized { - val manager = new LocalTaskSetManager(this, taskSet) - schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) - activeTaskSets(taskSet.id) = manager - taskSetTaskIds(taskSet.id) = new HashSet[Long]() - localActor ! LocalReviveOffers - } - } - - def resourceOffer(freeCores: Int): Seq[TaskDescription] = { - synchronized { - var freeCpuCores = freeCores - val tasks = new ArrayBuffer[TaskDescription](freeCores) - val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue() - for (manager <- sortedTaskSetQueue) { - logDebug("parentName:%s,name:%s,runningTasks:%s".format( - manager.parent.name, manager.name, manager.runningTasks)) - } - - var launchTask = false - for (manager <- sortedTaskSetQueue) { - do { - launchTask = false - manager.resourceOffer(null, null, freeCpuCores, null) match { - case Some(task) => - tasks += task - taskIdToTaskSetId(task.taskId) = manager.taskSet.id - taskSetTaskIds(manager.taskSet.id) += task.taskId - freeCpuCores -= 1 - launchTask = true - case None => {} - } - } while(launchTask) - } - return tasks - } - } - - def taskSetFinished(manager: TaskSetManager) { - synchronized { - activeTaskSets -= manager.taskSet.id - manager.parent.removeSchedulable(manager) - logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) - taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) - taskSetTaskIds -= manager.taskSet.id - } - } - - def runTask(taskId: Long, bytes: ByteBuffer) { - logInfo("Running " + taskId) - val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) - // Set the Spark execution environment for the worker thread - SparkEnv.set(env) - val ser = SparkEnv.get.closureSerializer.newInstance() - val objectSer = SparkEnv.get.serializer.newInstance() - var attemptedTask: Option[Task[_]] = None - val start = System.currentTimeMillis() - var taskStart: Long = 0 - def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum - val startGCTime = getTotalGCTime - - try { - Accumulators.clear() - Thread.currentThread().setContextClassLoader(classLoader) - - // Serialize and deserialize the task so that accumulators are changed to thread-local ones; - // this adds a bit of unnecessary overhead but 
matches how the Mesos Executor works. - val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes) - updateDependencies(taskFiles, taskJars) // Download any files added with addFile - val deserializedTask = ser.deserialize[Task[_]]( - taskBytes, Thread.currentThread.getContextClassLoader) - attemptedTask = Some(deserializedTask) - val deserTime = System.currentTimeMillis() - start - taskStart = System.currentTimeMillis() - - // Run it - val result: Any = deserializedTask.run(taskId) - - // Serialize and deserialize the result to emulate what the Mesos - // executor does. This is useful to catch serialization errors early - // on in development (so when users move their local Spark programs - // to the cluster, they don't get surprised by serialization errors). - val serResult = objectSer.serialize(result) - deserializedTask.metrics.get.resultSize = serResult.limit() - val resultToReturn = objectSer.deserialize[Any](serResult) - val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( - ser.serialize(Accumulators.values)) - val serviceTime = System.currentTimeMillis() - taskStart - logInfo("Finished " + taskId) - deserializedTask.metrics.get.executorRunTime = serviceTime.toInt - deserializedTask.metrics.get.jvmGCTime = getTotalGCTime - startGCTime - deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt - val taskResult = new TaskResult(result, accumUpdates, deserializedTask.metrics.getOrElse(null)) - val serializedResult = ser.serialize(taskResult) - localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult) - } catch { - case t: Throwable => { - val serviceTime = System.currentTimeMillis() - taskStart - val metrics = attemptedTask.flatMap(t => t.metrics) - for (m <- metrics) { - m.executorRunTime = serviceTime.toInt - m.jvmGCTime = getTotalGCTime - startGCTime - } - val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics) - localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure)) - } - } - } - - /** - * Download any missing dependencies if we receive a new set of files and JARs from the - * SparkContext. Also adds any new JARs we fetched to the class loader. 
- */ - private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { - synchronized { - // Fetch missing dependencies - for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) - currentFiles(name) = timestamp - } - - for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) - currentJars(name) = timestamp - // Add it to our class loader - val localName = name.split("/").last - val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL - if (!classLoader.getURLs.contains(url)) { - logInfo("Adding " + url + " to class loader") - classLoader.addURL(url) - } - } - } - } - - def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) { - synchronized { - val taskSetId = taskIdToTaskSetId(taskId) - val taskSetManager = activeTaskSets(taskSetId) - taskSetTaskIds(taskSetId) -= taskId - taskSetManager.statusUpdate(taskId, state, serializedData) - } - } - - override def stop() { - threadPool.shutdownNow() - } - - override def defaultParallelism() = threads -} diff --git a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala deleted file mode 100644 index e237f289e3..0000000000 --- a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.scheduler.local - -import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap - -import spark.{ExceptionFailure, Logging, SparkEnv, Success, TaskState} -import spark.TaskState.TaskState -import spark.scheduler.{Task, TaskResult, TaskSet} -import spark.scheduler.cluster.{Schedulable, TaskDescription, TaskInfo, TaskLocality, TaskSetManager} - - -private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) - extends TaskSetManager with Logging { - - var parent: Schedulable = null - var weight: Int = 1 - var minShare: Int = 0 - var runningTasks: Int = 0 - var priority: Int = taskSet.priority - var stageId: Int = taskSet.stageId - var name: String = "TaskSet_" + taskSet.stageId.toString - - var failCount = new Array[Int](taskSet.tasks.size) - val taskInfos = new HashMap[Long, TaskInfo] - val numTasks = taskSet.tasks.size - var numFinished = 0 - val env = SparkEnv.get - val ser = env.closureSerializer.newInstance() - val copiesRunning = new Array[Int](numTasks) - val finished = new Array[Boolean](numTasks) - val numFailures = new Array[Int](numTasks) - val MAX_TASK_FAILURES = sched.maxFailures - - override def increaseRunningTasks(taskNum: Int): Unit = { - runningTasks += taskNum - if (parent != null) { - parent.increaseRunningTasks(taskNum) - } - } - - override def decreaseRunningTasks(taskNum: Int): Unit = { - runningTasks -= taskNum - if (parent != null) { - parent.decreaseRunningTasks(taskNum) - } - } - - override def addSchedulable(schedulable: Schedulable): Unit = { - // nothing - } - - override def removeSchedulable(schedulable: Schedulable): Unit = { - // nothing - } - - override def getSchedulableByName(name: String): Schedulable = { - return null - } - - override def executorLost(executorId: String, host: String): Unit = { - // nothing - } - - override def checkSpeculatableTasks() = true - - override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { - var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] - sortedTaskSetQueue += this - return sortedTaskSetQueue - } - - override def hasPendingTasks() = true - - def findTask(): Option[Int] = { - for (i <- 0 to numTasks-1) { - if (copiesRunning(i) == 0 && !finished(i)) { - return Some(i) - } - } - return None - } - - override def resourceOffer( - execId: String, - host: String, - availableCpus: Int, - maxLocality: TaskLocality.TaskLocality) - : Option[TaskDescription] = - { - SparkEnv.set(sched.env) - logDebug("availableCpus:%d, numFinished:%d, numTasks:%d".format( - availableCpus.toInt, numFinished, numTasks)) - if (availableCpus > 0 && numFinished < numTasks) { - findTask() match { - case Some(index) => - val taskId = sched.attemptId.getAndIncrement() - val task = taskSet.tasks(index) - val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", - TaskLocality.NODE_LOCAL) - taskInfos(taskId) = info - // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here - // we assume the task can be serialized without exceptions. 
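// [Editor's sketch, not part of the original patch] The assumption above, that the task
// "can be serialized without exceptions", can be sanity-checked with plain java.io
// serialization. A hedged, self-contained illustration (the helper name is made up):
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

def isJavaSerializable(obj: AnyRef): Boolean =
  try {
    new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(obj)
    true // the whole object graph serialized cleanly
  } catch {
    case _: NotSerializableException => false
  }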
- val bytes = Task.serializeWithDependencies( - task, sched.sc.addedFiles, sched.sc.addedJars, ser) - logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes") - val taskName = "task %s:%d".format(taskSet.id, index) - copiesRunning(index) += 1 - increaseRunningTasks(1) - taskStarted(task, info) - return Some(new TaskDescription(taskId, null, taskName, index, bytes)) - case None => {} - } - } - return None - } - - override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { - SparkEnv.set(env) - state match { - case TaskState.FINISHED => - taskEnded(tid, state, serializedData) - case TaskState.FAILED => - taskFailed(tid, state, serializedData) - case _ => {} - } - } - - def taskStarted(task: Task[_], info: TaskInfo) { - sched.listener.taskStarted(task, info) - } - - def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - val index = info.index - val task = taskSet.tasks(index) - info.markSuccessful() - val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) - result.metrics.resultSize = serializedData.limit() - sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics) - numFinished += 1 - decreaseRunningTasks(1) - finished(index) = true - if (numFinished == numTasks) { - sched.taskSetFinished(this) - } - } - - def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) { - val info = taskInfos(tid) - val index = info.index - val task = taskSet.tasks(index) - info.markFailed() - decreaseRunningTasks(1) - val reason: ExceptionFailure = ser.deserialize[ExceptionFailure]( - serializedData, getClass.getClassLoader) - sched.listener.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null)) - if (!finished(index)) { - copiesRunning(index) -= 1 - numFailures(index) += 1 - val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logInfo("Loss was due to %s\n%s\n%s".format( - reason.className, reason.description, locs.mkString("\n"))) - if (numFailures(index) > MAX_TASK_FAILURES) { - val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format( - taskSet.id, index, MAX_TASK_FAILURES, reason.description) - decreaseRunningTasks(runningTasks) - sched.listener.taskSetFailed(taskSet, errorMessage) - // Need to remove the failed TaskSet from the scheduling queue - sched.taskSetFinished(this) - } - } - } - - override def error(message: String) { - } -} diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala deleted file mode 100644 index eef3ee1425..0000000000 --- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala +++ /dev/null @@ -1,284 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.mesos - -import com.google.protobuf.ByteString - -import org.apache.mesos.{Scheduler => MScheduler} -import org.apache.mesos._ -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} - -import spark.{SparkException, Utils, Logging, SparkContext} -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.collection.JavaConversions._ -import java.io.File -import spark.scheduler.cluster._ -import java.util.{ArrayList => JArrayList, List => JList} -import java.util.Collections -import spark.TaskState - -/** - * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds - * onto each Mesos node for the duration of the Spark job instead of relinquishing cores whenever - * a task is done. It launches Spark tasks within the coarse-grained Mesos tasks using the - * StandaloneBackend mechanism. This class is useful for lower and more predictable latency. - * - * Unfortunately this has a bit of duplication from MesosSchedulerBackend, but it seems hard to - * remove this. - */ -private[spark] class CoarseMesosSchedulerBackend( - scheduler: ClusterScheduler, - sc: SparkContext, - master: String, - appName: String) - extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem) - with MScheduler - with Logging { - - val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures - - // Lock used to wait for scheduler to be registered - var isRegistered = false - val registeredLock = new Object() - - // Driver for talking to Mesos - var driver: SchedulerDriver = null - - // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) - val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt - - // Cores we have acquired with each Mesos task ID - val coresByTaskId = new HashMap[Int, Int] - var totalCoresAcquired = 0 - - val slaveIdsWithExecutors = new HashSet[String] - - val taskIdToSlaveId = new HashMap[Int, String] - val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed - - val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException( - "Spark home is not set; set it through the spark.home system " + - "property, the SPARK_HOME environment variable or the SparkContext constructor")) - - val extraCoresPerSlave = System.getProperty("spark.mesos.extra.cores", "0").toInt - - var nextMesosTaskId = 0 - - def newMesosTaskId(): Int = { - val id = nextMesosTaskId - nextMesosTaskId += 1 - id - } - - override def start() { - super.start() - - synchronized { - new Thread("CoarseMesosSchedulerBackend driver") { - setDaemon(true) - override def run() { - val scheduler = CoarseMesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(appName).build() - driver = new MesosSchedulerDriver(scheduler, fwInfo, master) - try { - val ret = driver.run() - logInfo("driver.run() returned with code " + ret) - } catch { - case e: Exception => logError("driver.run() failed", e) - } - } - }.start() - - waitForRegister() - } - } - - def createCommand(offer: Offer, numCores: Int): CommandInfo = { - val environment = Environment.newBuilder() - sc.executorEnvs.foreach { case (key, value) => - environment.addVariables(Environment.Variable.newBuilder() - .setName(key) - .setValue(value) - .build()) - } - val command = CommandInfo.newBuilder() -
.setEnvironment(environment) - val driverUrl = "akka://spark@%s:%s/user/%s".format( - System.getProperty("spark.driver.host"), - System.getProperty("spark.driver.port"), - StandaloneSchedulerBackend.ACTOR_NAME) - val uri = System.getProperty("spark.executor.uri") - if (uri == null) { - val runScript = new File(sparkHome, "spark-class").getCanonicalPath - command.setValue("\"%s\" spark.executor.StandaloneExecutorBackend %s %s %s %d".format( - runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) - } else { - // Grab everything to the first '.'. We'll use that and '*' to - // glob the directory "correctly". - val basename = uri.split('/').last.split('.').head - command.setValue("cd %s*; ./spark-class spark.executor.StandaloneExecutorBackend %s %s %s %d".format( - basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) - command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) - } - return command.build() - } - - override def offerRescinded(d: SchedulerDriver, o: OfferID) {} - - override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { - logInfo("Registered as framework ID " + frameworkId.getValue) - registeredLock.synchronized { - isRegistered = true - registeredLock.notifyAll() - } - } - - def waitForRegister() { - registeredLock.synchronized { - while (!isRegistered) { - registeredLock.wait() - } - } - } - - override def disconnected(d: SchedulerDriver) {} - - override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} - - /** - * Method called by Mesos to offer resources on slaves. We respond by launching an executor, - * unless we've already launched more than we wanted to. - */ - override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { - synchronized { - val filters = Filters.newBuilder().setRefuseSeconds(-1).build() - - for (offer <- offers) { - val slaveId = offer.getSlaveId.toString - val mem = getResource(offer.getResourcesList, "mem") - val cpus = getResource(offer.getResourcesList, "cpus").toInt - if (totalCoresAcquired < maxCores && mem >= executorMemory && cpus >= 1 && - failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && - !slaveIdsWithExecutors.contains(slaveId)) { - // Launch an executor on the slave - val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) - val taskId = newMesosTaskId() - taskIdToSlaveId(taskId) = slaveId - slaveIdsWithExecutors += slaveId - coresByTaskId(taskId) = cpusToUse - val task = MesosTaskInfo.newBuilder() - .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) - .setSlaveId(offer.getSlaveId) - .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave)) - .setName("Task " + taskId) - .addResources(createResource("cpus", cpusToUse)) - .addResources(createResource("mem", executorMemory)) - .build() - d.launchTasks(offer.getId, Collections.singletonList(task), filters) - } else { - // Filter it out - d.launchTasks(offer.getId, Collections.emptyList[MesosTaskInfo](), filters) - } - } - } - } - - /** Helper function to pull out a resource from a Mesos Resources protobuf */ - private def getResource(res: JList[Resource], name: String): Double = { - for (r <- res if r.getName == name) { - return r.getScalar.getValue - } - // If we reached here, no resource with the required name was present - throw new IllegalArgumentException("No resource called " + name + " in " + res) - } - - /** Build a Mesos resource protobuf object */ - private def createResource(resourceName: String, quantity: Double): 
Protos.Resource = { - Resource.newBuilder() - .setName(resourceName) - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(quantity).build()) - .build() - } - - /** Check whether a Mesos task state represents a finished task */ - private def isFinished(state: MesosTaskState) = { - state == MesosTaskState.TASK_FINISHED || - state == MesosTaskState.TASK_FAILED || - state == MesosTaskState.TASK_KILLED || - state == MesosTaskState.TASK_LOST - } - - override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { - val taskId = status.getTaskId.getValue.toInt - val state = status.getState - logInfo("Mesos task " + taskId + " is now " + state) - synchronized { - if (isFinished(state)) { - val slaveId = taskIdToSlaveId(taskId) - slaveIdsWithExecutors -= slaveId - taskIdToSlaveId -= taskId - // Remove the cores we have remembered for this task, if it's in the hashmap - for (cores <- coresByTaskId.get(taskId)) { - totalCoresAcquired -= cores - coresByTaskId -= taskId - } - // If it was a failure, mark the slave as failed for blacklisting purposes - if (state == MesosTaskState.TASK_FAILED || state == MesosTaskState.TASK_LOST) { - failuresBySlaveId(slaveId) = failuresBySlaveId.getOrElse(slaveId, 0) + 1 - if (failuresBySlaveId(slaveId) >= MAX_SLAVE_FAILURES) { - logInfo("Blacklisting Mesos slave " + slaveId + " due to too many failures; " + - "is Spark installed on it?") - } - } - driver.reviveOffers() // In case we'd rejected everything before but have now lost a node - } - } - } - - override def error(d: SchedulerDriver, message: String) { - logError("Mesos error: " + message) - scheduler.error(message) - } - - override def stop() { - super.stop() - if (driver != null) { - driver.stop() - } - } - - override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} - - override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) { - logInfo("Mesos slave lost: " + slaveId.getValue) - synchronized { - if (slaveIdsWithExecutors.contains(slaveId.getValue)) { - // Note that the slave ID corresponds to the executor ID on that slave - slaveIdsWithExecutors -= slaveId.getValue - removeExecutor(slaveId.getValue, "Mesos slave lost") - } - } - } - - override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) { - logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue)) - slaveLost(d, s) - } -} diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala deleted file mode 100644 index f6069a5775..0000000000 --- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala +++ /dev/null @@ -1,342 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.scheduler.mesos - -import com.google.protobuf.ByteString - -import org.apache.mesos.{Scheduler => MScheduler} -import org.apache.mesos._ -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} - -import spark.{SparkException, Utils, Logging, SparkContext} -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.collection.JavaConversions._ -import java.io.File -import spark.scheduler.cluster._ -import java.util.{ArrayList => JArrayList, List => JList} -import java.util.Collections -import spark.TaskState - -/** - * A SchedulerBackend for running fine-grained tasks on Mesos. Each Spark task is mapped to a - * separate Mesos task, allowing multiple applications to share cluster nodes both in space (tasks - * from multiple apps can run on different cores) and in time (a core can switch ownership). - */ -private[spark] class MesosSchedulerBackend( - scheduler: ClusterScheduler, - sc: SparkContext, - master: String, - appName: String) - extends SchedulerBackend - with MScheduler - with Logging { - - // Lock used to wait for scheduler to be registered - var isRegistered = false - val registeredLock = new Object() - - // Driver for talking to Mesos - var driver: SchedulerDriver = null - - // Which slave IDs we have executors on - val slaveIdsWithExecutors = new HashSet[String] - val taskIdToSlaveId = new HashMap[Long, String] - - // An ExecutorInfo for our tasks - var execArgs: Array[Byte] = null - - var classLoader: ClassLoader = null - - override def start() { - synchronized { - classLoader = Thread.currentThread.getContextClassLoader - - new Thread("MesosSchedulerBackend driver") { - setDaemon(true) - override def run() { - val scheduler = MesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(appName).build() - driver = new MesosSchedulerDriver(scheduler, fwInfo, master) - try { - val ret = driver.run() - logInfo("driver.run() returned with code " + ret) - } catch { - case e: Exception => logError("driver.run() failed", e) - } - } - }.start() - - waitForRegister() - } - } - - def createExecutorInfo(execId: String): ExecutorInfo = { - val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException( - "Spark home is not set; set it through the spark.home system " + - "property, the SPARK_HOME environment variable or the SparkContext constructor")) - val environment = Environment.newBuilder() - sc.executorEnvs.foreach { case (key, value) => - environment.addVariables(Environment.Variable.newBuilder() - .setName(key) - .setValue(value) - .build()) - } - val command = CommandInfo.newBuilder() - .setEnvironment(environment) - val uri = System.getProperty("spark.executor.uri") - if (uri == null) { - command.setValue(new File(sparkHome, "spark-executor").getCanonicalPath) - } else { - // Grab everything to the first '.'. We'll use that and '*' to - // glob the directory "correctly". 
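// [Editor's sketch, not part of the original patch] What the "glob" trick below
// computes, on a made-up URI: everything up to the first '.' of the last path segment
// becomes the prefix that "cd <prefix>*" relies on.
val exampleUri = "hdfs://host/path/spark-0.8.0.tgz" // hypothetical value
val examplePrefix = exampleUri.split('/').last.split('.').head // => "spark-0"
// ...so the launched command would be "cd spark-0*; ./spark-executor".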
- val basename = uri.split('/').last.split('.').head - command.setValue("cd %s*; ./spark-executor".format(basename)) - command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) - } - val memory = Resource.newBuilder() - .setName("mem") - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(executorMemory).build()) - .build() - ExecutorInfo.newBuilder() - .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) - .setCommand(command) - .setData(ByteString.copyFrom(createExecArg())) - .addResources(memory) - .build() - } - - /** - * Create and serialize the executor argument to pass to Mesos. Our executor arg is an array - * containing all the spark.* system properties in the form of (String, String) pairs. - */ - private def createExecArg(): Array[Byte] = { - if (execArgs == null) { - val props = new HashMap[String, String] - val iterator = System.getProperties.entrySet.iterator - while (iterator.hasNext) { - val entry = iterator.next - val (key, value) = (entry.getKey.toString, entry.getValue.toString) - if (key.startsWith("spark.")) { - props(key) = value - } - } - // Serialize the map as an array of (String, String) pairs - execArgs = Utils.serialize(props.toArray) - } - return execArgs - } - - private def setClassLoader(): ClassLoader = { - val oldClassLoader = Thread.currentThread.getContextClassLoader - Thread.currentThread.setContextClassLoader(classLoader) - return oldClassLoader - } - - private def restoreClassLoader(oldClassLoader: ClassLoader) { - Thread.currentThread.setContextClassLoader(oldClassLoader) - } - - override def offerRescinded(d: SchedulerDriver, o: OfferID) {} - - override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { - val oldClassLoader = setClassLoader() - try { - logInfo("Registered as framework ID " + frameworkId.getValue) - registeredLock.synchronized { - isRegistered = true - registeredLock.notifyAll() - } - } finally { - restoreClassLoader(oldClassLoader) - } - } - - def waitForRegister() { - registeredLock.synchronized { - while (!isRegistered) { - registeredLock.wait() - } - } - } - - override def disconnected(d: SchedulerDriver) {} - - override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} - - /** - * Method called by Mesos to offer resources on slaves. We respond by asking our active task sets - * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that - * tasks are balanced across the cluster.
- */ - override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { - val oldClassLoader = setClassLoader() - try { - synchronized { - // Build a big list of the offerable workers, and remember their indices so that we can - // figure out which Offer to reply to for each worker - val offerableIndices = new ArrayBuffer[Int] - val offerableWorkers = new ArrayBuffer[WorkerOffer] - - def enoughMemory(o: Offer) = { - val mem = getResource(o.getResourcesList, "mem") - val slaveId = o.getSlaveId.getValue - mem >= executorMemory || slaveIdsWithExecutors.contains(slaveId) - } - - for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) { - offerableIndices += index - offerableWorkers += new WorkerOffer( - offer.getSlaveId.getValue, - offer.getHostname, - getResource(offer.getResourcesList, "cpus").toInt) - } - - // Call into the ClusterScheduler - val taskLists = scheduler.resourceOffers(offerableWorkers) - - // Build a list of Mesos tasks for each slave - val mesosTasks = offers.map(o => Collections.emptyList[MesosTaskInfo]()) - for ((taskList, index) <- taskLists.zipWithIndex) { - if (!taskList.isEmpty) { - val offerNum = offerableIndices(index) - val slaveId = offers(offerNum).getSlaveId.getValue - slaveIdsWithExecutors += slaveId - mesosTasks(offerNum) = new JArrayList[MesosTaskInfo](taskList.size) - for (taskDesc <- taskList) { - taskIdToSlaveId(taskDesc.taskId) = slaveId - mesosTasks(offerNum).add(createMesosTask(taskDesc, slaveId)) - } - } - } - - // Reply to the offers - val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? - for (i <- 0 until offers.size) { - d.launchTasks(offers(i).getId, mesosTasks(i), filters) - } - } - } finally { - restoreClassLoader(oldClassLoader) - } - } - - /** Helper function to pull out a resource from a Mesos Resources protobuf */ - def getResource(res: JList[Resource], name: String): Double = { - for (r <- res if r.getName == name) { - return r.getScalar.getValue - } - // If we reached here, no resource with the required name was present - throw new IllegalArgumentException("No resource called " + name + " in " + res) - } - - /** Turn a Spark TaskDescription into a Mesos task */ - def createMesosTask(task: TaskDescription, slaveId: String): MesosTaskInfo = { - val taskId = TaskID.newBuilder().setValue(task.taskId.toString).build() - val cpuResource = Resource.newBuilder() - .setName("cpus") - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(1).build()) - .build() - return MesosTaskInfo.newBuilder() - .setTaskId(taskId) - .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build()) - .setExecutor(createExecutorInfo(slaveId)) - .setName(task.name) - .addResources(cpuResource) - .setData(ByteString.copyFrom(task.serializedTask)) - .build() - } - - /** Check whether a Mesos task state represents a finished task */ - def isFinished(state: MesosTaskState) = { - state == MesosTaskState.TASK_FINISHED || - state == MesosTaskState.TASK_FAILED || - state == MesosTaskState.TASK_KILLED || - state == MesosTaskState.TASK_LOST - } - - override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { - val oldClassLoader = setClassLoader() - try { - val tid = status.getTaskId.getValue.toLong - val state = TaskState.fromMesos(status.getState) - synchronized { - if (status.getState == MesosTaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) { - // We lost the executor on this slave, so remember that it's gone - slaveIdsWithExecutors -= taskIdToSlaveId(tid) - } - if (isFinished(status.getState)) { 
- taskIdToSlaveId.remove(tid) - } - } - scheduler.statusUpdate(tid, state, status.getData.asReadOnlyByteBuffer) - } finally { - restoreClassLoader(oldClassLoader) - } - } - - override def error(d: SchedulerDriver, message: String) { - val oldClassLoader = setClassLoader() - try { - logError("Mesos error: " + message) - scheduler.error(message) - } finally { - restoreClassLoader(oldClassLoader) - } - } - - override def stop() { - if (driver != null) { - driver.stop() - } - } - - override def reviveOffers() { - driver.reviveOffers() - } - - override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} - - private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) { - val oldClassLoader = setClassLoader() - try { - logInfo("Mesos slave lost: " + slaveId.getValue) - synchronized { - slaveIdsWithExecutors -= slaveId.getValue - } - scheduler.executorLost(slaveId.getValue, reason) - } finally { - restoreClassLoader(oldClassLoader) - } - } - - override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) { - recordSlaveLost(d, slaveId, SlaveLost()) - } - - override def executorLost(d: SchedulerDriver, executorId: ExecutorID, - slaveId: SlaveID, status: Int) { - logInfo("Executor lost: %s, marking slave %s as lost".format(executorId.getValue, - slaveId.getValue)) - recordSlaveLost(d, slaveId, ExecutorExited(status)) - } - - // TODO: query Mesos for number of cores - override def defaultParallelism() = System.getProperty("spark.default.parallelism", "8").toInt -} diff --git a/core/src/main/scala/spark/serializer/Serializer.scala b/core/src/main/scala/spark/serializer/Serializer.scala deleted file mode 100644 index dc94d42bb6..0000000000 --- a/core/src/main/scala/spark/serializer/Serializer.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.serializer - -import java.io.{EOFException, InputStream, OutputStream} -import java.nio.ByteBuffer - -import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream - -import spark.util.ByteBufferInputStream - - -/** - * A serializer. Because some serialization libraries are not thread safe, this class is used to - * create [[spark.serializer.SerializerInstance]] objects that do the actual serialization and are - * guaranteed to only be called from one thread at a time. - */ -trait Serializer { - def newInstance(): SerializerInstance -} - - -/** - * An instance of a serializer, for use by one thread at a time. 
- */ -trait SerializerInstance { - def serialize[T](t: T): ByteBuffer - - def deserialize[T](bytes: ByteBuffer): T - - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T - - def serializeStream(s: OutputStream): SerializationStream - - def deserializeStream(s: InputStream): DeserializationStream - - def serializeMany[T](iterator: Iterator[T]): ByteBuffer = { - // Default implementation uses serializeStream - val stream = new FastByteArrayOutputStream() - serializeStream(stream).writeAll(iterator) - val buffer = ByteBuffer.allocate(stream.position.toInt) - buffer.put(stream.array, 0, stream.position.toInt) - buffer.flip() - buffer - } - - def deserializeMany(buffer: ByteBuffer): Iterator[Any] = { - // Default implementation uses deserializeStream - buffer.rewind() - deserializeStream(new ByteBufferInputStream(buffer)).asIterator - } -} - - -/** - * A stream for writing serialized objects. - */ -trait SerializationStream { - def writeObject[T](t: T): SerializationStream - def flush(): Unit - def close(): Unit - - def writeAll[T](iter: Iterator[T]): SerializationStream = { - while (iter.hasNext) { - writeObject(iter.next()) - } - this - } -} - - -/** - * A stream for reading serialized objects. - */ -trait DeserializationStream { - def readObject[T](): T - def close(): Unit - - /** - * Read the elements of this stream through an iterator. This can only be called once, as - * reading each element will consume data from the input source. - */ - def asIterator: Iterator[Any] = new spark.util.NextIterator[Any] { - override protected def getNext() = { - try { - readObject[Any]() - } catch { - case eof: EOFException => - finished = true - } - } - - override protected def close() { - DeserializationStream.this.close() - } - } -} diff --git a/core/src/main/scala/spark/serializer/SerializerManager.scala b/core/src/main/scala/spark/serializer/SerializerManager.scala deleted file mode 100644 index b7b24705a2..0000000000 --- a/core/src/main/scala/spark/serializer/SerializerManager.scala +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.serializer - -import java.util.concurrent.ConcurrentHashMap - - -/** - * A service that returns a serializer object given the serializer's class name. If a previous - * instance of the serializer object has been created, the get method returns that instead of - * creating a new one. 
- */ -private[spark] class SerializerManager { - - private val serializers = new ConcurrentHashMap[String, Serializer] - private var _default: Serializer = _ - - def default = _default - - def setDefault(clsName: String): Serializer = { - _default = get(clsName) - _default - } - - def get(clsName: String): Serializer = { - if (clsName == null) { - default - } else { - var serializer = serializers.get(clsName) - if (serializer != null) { - // If the serializer has been created previously, reuse that. - serializer - } else this.synchronized { - // Otherwise, create a new one. But make sure no other thread has attempted - // to create another new one at the same time. - serializer = serializers.get(clsName) - if (serializer == null) { - val clsLoader = Thread.currentThread.getContextClassLoader - serializer = - Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer] - serializers.put(clsName, serializer) - } - serializer - } - } - } -} diff --git a/core/src/main/scala/spark/storage/BlockException.scala b/core/src/main/scala/spark/storage/BlockException.scala deleted file mode 100644 index 8ebfaf3cbf..0000000000 --- a/core/src/main/scala/spark/storage/BlockException.scala +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -private[spark] -case class BlockException(blockId: String, message: String) extends Exception(message) - diff --git a/core/src/main/scala/spark/storage/BlockFetchTracker.scala b/core/src/main/scala/spark/storage/BlockFetchTracker.scala deleted file mode 100644 index 265e554ad8..0000000000 --- a/core/src/main/scala/spark/storage/BlockFetchTracker.scala +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.storage - -private[spark] trait BlockFetchTracker { - def totalBlocks: Int - def numLocalBlocks: Int - def numRemoteBlocks: Int - def remoteFetchTime: Long - def fetchWaitTime: Long - def remoteBytesRead: Long -} diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala deleted file mode 100644 index 568783d893..0000000000 --- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala +++ /dev/null @@ -1,348 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.nio.ByteBuffer -import java.util.concurrent.LinkedBlockingQueue - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashSet -import scala.collection.mutable.Queue - -import io.netty.buffer.ByteBuf - -import spark.Logging -import spark.Utils -import spark.SparkException -import spark.network.BufferMessage -import spark.network.ConnectionManagerId -import spark.network.netty.ShuffleCopier -import spark.serializer.Serializer - - -/** - * A block fetcher iterator interface. There are two implementations: - * - * BasicBlockFetcherIterator: uses a custom-built NIO communication layer. - * NettyBlockFetcherIterator: uses Netty (OIO) as the communication layer. - * - * Eventually we would like the two to converge and use a single NIO-based communication layer, - * but extensive tests show that under some circumstances (e.g. large shuffles with lots of cores), - * NIO would perform poorly and thus the need for the Netty OIO one. - */ - -private[storage] -trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])] - with Logging with BlockFetchTracker { - def initialize() -} - - -private[storage] -object BlockFetcherIterator { - - // A request to fetch one or more blocks, complete with their sizes - class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) { - val size = blocks.map(_._2).sum - } - - // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize - // the block (since we want all deserialization to happen in the calling thread); can also - // represent a fetch failure if size == -1.
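// [Editor's sketch, not part of the original patch] A self-contained miniature of the
// size == -1 failure sentinel described above; since no real block can be -1 bytes, the
// value is free to encode "fetch failed". MiniFetchResult and the block ID are
// illustrative only:
class MiniFetchResult(val blockId: String, val size: Long) {
  def failed: Boolean = size == -1
}
assert(new MiniFetchResult("shuffle_0_1_3", -1L).failed)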
- class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) { - def failed: Boolean = size == -1 - } - - class BasicBlockFetcherIterator( - private val blockManager: BlockManager, - val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], - serializer: Serializer) - extends BlockFetcherIterator { - - import blockManager._ - - private var _remoteBytesRead = 0L - private var _remoteFetchTime = 0L - private var _fetchWaitTime = 0L - - if (blocksByAddress == null) { - throw new IllegalArgumentException("BlocksByAddress is null") - } - - // Total number of blocks fetched (local + remote). Also the number of FetchResults expected - protected var _numBlocksToFetch = 0 - - protected var startTime = System.currentTimeMillis - - // This represents the number of local blocks, also counting zero-sized blocks - private var numLocal = 0 - // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks - protected val localBlocksToFetch = new ArrayBuffer[String]() - - // This represents the number of remote blocks, also counting zero-sized blocks - private var numRemote = 0 - // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks - protected val remoteBlocksToFetch = new HashSet[String]() - - // A queue to hold our results. - protected val results = new LinkedBlockingQueue[FetchResult] - - // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that - // the number of bytes in flight is limited to maxBytesInFlight - private val fetchRequests = new Queue[FetchRequest] - - // Current bytes in flight from our requests - private var bytesInFlight = 0L - - protected def sendRequest(req: FetchRequest) { - logDebug("Sending request for %d blocks (%s) from %s".format( - req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) - val cmId = new ConnectionManagerId(req.address.host, req.address.port) - val blockMessageArray = new BlockMessageArray(req.blocks.map { - case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId)) - }) - bytesInFlight += req.size - val sizeMap = req.blocks.toMap // so we can look up the size of each blockID - val fetchStart = System.currentTimeMillis() - val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) - future.onSuccess { - case Some(message) => { - val fetchDone = System.currentTimeMillis() - _remoteFetchTime += fetchDone - fetchStart - val bufferMessage = message.asInstanceOf[BufferMessage] - val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - for (blockMessage <- blockMessageArray) { - if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { - throw new SparkException( - "Unexpected message " + blockMessage.getType + " received from " + cmId) - } - val blockId = blockMessage.getId - val networkSize = blockMessage.getData.limit() - results.put(new FetchResult(blockId, sizeMap(blockId), - () => dataDeserialize(blockId, blockMessage.getData, serializer))) - _remoteBytesRead += networkSize - logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) - } - } - case None => { - logError("Could not get block(s) from " + cmId) - for ((blockId, size) <- req.blocks) { - results.put(new FetchResult(blockId, -1, null)) - } - } - } - } - - protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { - // Split local and remote blocks.
Remote blocks are further split into FetchRequests of size - // at most maxBytesInFlight in order to limit the amount of data in flight. - val remoteRequests = new ArrayBuffer[FetchRequest] - for ((address, blockInfos) <- blocksByAddress) { - if (address == blockManagerId) { - numLocal = blockInfos.size - // Filter out zero-sized blocks - localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1) - _numBlocksToFetch += localBlocksToFetch.size - } else { - numRemote += blockInfos.size - // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them - // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 - // nodes, rather than blocking on reading output from one node. - val minRequestSize = math.max(maxBytesInFlight / 5, 1L) - logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) - val iterator = blockInfos.iterator - var curRequestSize = 0L - var curBlocks = new ArrayBuffer[(String, Long)] - while (iterator.hasNext) { - val (blockId, size) = iterator.next() - // Skip empty blocks - if (size > 0) { - curBlocks += ((blockId, size)) - remoteBlocksToFetch += blockId - _numBlocksToFetch += 1 - curRequestSize += size - } else if (size < 0) { - throw new BlockException(blockId, "Negative block size " + size) - } - if (curRequestSize >= minRequestSize) { - // Add this FetchRequest - remoteRequests += new FetchRequest(address, curBlocks) - curRequestSize = 0 - curBlocks = new ArrayBuffer[(String, Long)] - } - } - // Add in the final request - if (!curBlocks.isEmpty) { - remoteRequests += new FetchRequest(address, curBlocks) - } - } - } - logInfo("Getting " + _numBlocksToFetch + " non-zero-bytes blocks out of " + - totalBlocks + " blocks") - remoteRequests - } - - protected def getLocalBlocks() { - // Get the local blocks while remote blocks are being fetched. Note that it's okay to do - // these all at once because they will just memory-map some files, so they won't consume - // any memory that might exceed our maxBytesInFlight - for (id <- localBlocksToFetch) { - getLocalFromDisk(id, serializer) match { - case Some(iter) => { - // Pass 0 as size since it's not in flight - results.put(new FetchResult(id, 0, () => iter)) - logDebug("Got local block " + id) - } - case None => { - throw new BlockException(id, "Could not get block " + id + " from local machine") - } - } - } - } - - override def initialize() { - // Split local and remote blocks. - val remoteRequests = splitLocalRemoteBlocks() - // Add the remote requests into our queue in a random order - fetchRequests ++= Utils.randomize(remoteRequests) - - // Send out initial requests for blocks, up to our maxBytesInFlight - while (!fetchRequests.isEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) - } - - val numGets = remoteRequests.size - fetchRequests.size - logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime)) - - // Get Local Blocks - startTime = System.currentTimeMillis - getLocalBlocks() - logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") - } - - //an iterator that will read fetched blocks off the queue as they arrive. 
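// [Editor's sketch, not part of the original patch] The hasNext/next pair below is the
// standard "iterator over a blocking queue with a known element count" pattern: hasNext
// counts, next() blocks until a result arrives. A self-contained miniature with
// illustrative names:
import java.util.concurrent.LinkedBlockingQueue

class QueueBackedIterator[T](expected: Int, queue: LinkedBlockingQueue[T]) extends Iterator[T] {
  private var taken = 0
  override def hasNext: Boolean = taken < expected
  override def next(): T = {
    taken += 1
    queue.take() // blocks until the producer puts the next element
  }
}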
- @volatile protected var resultsGotten = 0 - - override def hasNext: Boolean = resultsGotten < _numBlocksToFetch - - override def next(): (String, Option[Iterator[Any]]) = { - resultsGotten += 1 - val startFetchWait = System.currentTimeMillis() - val result = results.take() - val stopFetchWait = System.currentTimeMillis() - _fetchWaitTime += (stopFetchWait - startFetchWait) - if (! result.failed) bytesInFlight -= result.size - while (!fetchRequests.isEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) - } - (result.blockId, if (result.failed) None else Some(result.deserialize())) - } - - // Implementing BlockFetchTracker trait. - override def totalBlocks: Int = numLocal + numRemote - override def numLocalBlocks: Int = numLocal - override def numRemoteBlocks: Int = numRemote - override def remoteFetchTime: Long = _remoteFetchTime - override def fetchWaitTime: Long = _fetchWaitTime - override def remoteBytesRead: Long = _remoteBytesRead - } - // End of BasicBlockFetcherIterator - - class NettyBlockFetcherIterator( - blockManager: BlockManager, - blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], - serializer: Serializer) - extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) { - - import blockManager._ - - val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest] - - private def startCopiers(numCopiers: Int): List[_ <: Thread] = { - (for ( i <- Range(0,numCopiers) ) yield { - val copier = new Thread { - override def run(){ - try { - while(!isInterrupted && !fetchRequestsSync.isEmpty) { - sendRequest(fetchRequestsSync.take()) - } - } catch { - case x: InterruptedException => logInfo("Copier Interrupted") - //case _ => throw new SparkException("Exception Throw in Shuffle Copier") - } - } - } - copier.start - copier - }).toList - } - - // keep this to interrupt the threads when necessary - private def stopCopiers() { - for (copier <- copiers) { - copier.interrupt() - } - } - - override protected def sendRequest(req: FetchRequest) { - - def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) { - val fetchResult = new FetchResult(blockId, blockSize, - () => dataDeserialize(blockId, blockData.nioBuffer, serializer)) - results.put(fetchResult) - } - - logDebug("Sending request for %d blocks (%s) from %s".format( - req.blocks.size, Utils.bytesToString(req.size), req.address.host)) - val cmId = new ConnectionManagerId(req.address.host, req.address.nettyPort) - val cpier = new ShuffleCopier - cpier.getBlocks(cmId, req.blocks, putResult) - logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host ) - } - - private var copiers: List[_ <: Thread] = null - - override def initialize() { - // Split Local Remote Blocks and set numBlocksToFetch - val remoteRequests = splitLocalRemoteBlocks() - // Add the remote requests into our queue in a random order - for (request <- Utils.randomize(remoteRequests)) { - fetchRequestsSync.put(request) - } - - copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt) - logInfo("Started " + fetchRequestsSync.size + " remote gets in " + - Utils.getUsedTimeMs(startTime)) - - // Get Local Blocks - startTime = System.currentTimeMillis - getLocalBlocks() - logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") - } - - override def next(): (String, Option[Iterator[Any]]) = { - resultsGotten += 1 - val result = results.take() - // If all the results has been retrieved, 
copiers will exit automatically - (result.blockId, if (result.failed) None else Some(result.deserialize())) - } - } - // End of NettyBlockFetcherIterator -} diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala deleted file mode 100644 index 2a6ec2a55d..0000000000 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ /dev/null @@ -1,1046 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.io.{InputStream, OutputStream} -import java.nio.{ByteBuffer, MappedByteBuffer} - -import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet} - -import akka.actor.{ActorSystem, Cancellable, Props} -import akka.dispatch.{Await, Future} -import akka.util.Duration -import akka.util.duration._ - -import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream - -import spark.{Logging, SparkEnv, SparkException, Utils} -import spark.io.CompressionCodec -import spark.network._ -import spark.serializer.Serializer -import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap} - -import sun.nio.ch.DirectBuffer - - -private[spark] class BlockManager( - executorId: String, - actorSystem: ActorSystem, - val master: BlockManagerMaster, - val defaultSerializer: Serializer, - maxMemory: Long) - extends Logging { - - private class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { - @volatile var pending: Boolean = true - @volatile var size: Long = -1L - @volatile var initThread: Thread = null - @volatile var failed = false - - setInitThread() - - private def setInitThread() { - // Set current thread as init thread - waitForReady will not block this thread - // (in case there is non trivial initialization which ends up calling waitForReady as part of - // initialization itself) - this.initThread = Thread.currentThread() - } - - /** - * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing). - * Return true if the block is available, false otherwise. - */ - def waitForReady(): Boolean = { - if (initThread != Thread.currentThread() && pending) { - synchronized { - while (pending) this.wait() - } - } - !failed - } - - /** Mark this BlockInfo as ready (i.e. 
block is finished writing) */ - def markReady(sizeInBytes: Long) { - assert (pending) - size = sizeInBytes - initThread = null - failed = false - initThread = null - pending = false - synchronized { - this.notifyAll() - } - } - - /** Mark this BlockInfo as ready but failed */ - def markFailure() { - assert (pending) - size = 0 - initThread = null - failed = true - initThread = null - pending = false - synchronized { - this.notifyAll() - } - } - } - - val shuffleBlockManager = new ShuffleBlockManager(this) - - private val blockInfo = new TimeStampedHashMap[String, BlockInfo] - - private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) - private[storage] val diskStore: DiskStore = - new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) - - // If we use Netty for shuffle, start a new Netty-based shuffle sender service. - private val nettyPort: Int = { - val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean - val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt - if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0 - } - - val connectionManager = new ConnectionManager(0) - implicit val futureExecContext = connectionManager.futureExecContext - - val blockManagerId = BlockManagerId( - executorId, connectionManager.id.host, connectionManager.id.port, nettyPort) - - // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory - // for receiving shuffle outputs) - val maxBytesInFlight = - System.getProperty("spark.reducer.maxMbInFlight", "48").toLong * 1024 * 1024 - - // Whether to compress broadcast variables that are stored - val compressBroadcast = System.getProperty("spark.broadcast.compress", "true").toBoolean - // Whether to compress shuffle output that are stored - val compressShuffle = System.getProperty("spark.shuffle.compress", "true").toBoolean - // Whether to compress RDD partitions that are stored serialized - val compressRdds = System.getProperty("spark.rdd.compress", "false").toBoolean - - val heartBeatFrequency = BlockManager.getHeartBeatFrequencyFromSystemProperties - - val hostPort = Utils.localHostPort() - - val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)), - name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) - - // Pending reregistration action being executed asynchronously or null if none - // is pending. Accesses should synchronize on asyncReregisterLock. - var asyncReregisterTask: Future[Unit] = null - val asyncReregisterLock = new Object - - private def heartBeat() { - if (!master.sendHeartBeat(blockManagerId)) { - reregister() - } - } - - var heartBeatTask: Cancellable = null - - val metadataCleaner = new MetadataCleaner("BlockManager", this.dropOldBlocks) - initialize() - - // The compression codec to use. Note that the "lazy" val is necessary because we want to delay - // the initialization of the compression codec until it is first used. The reason is that a Spark - // program could be using a user-defined codec in a third party jar, which is loaded in - // Executor.updateDependencies. When the BlockManager is initialized, user level jars hasn't been - // loaded yet. - private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec() - - /** - * Construct a BlockManager with a memory limit set based on system properties. 
- */ - def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster, - serializer: Serializer) = { - this(execId, actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties) - } - - /** - * Initialize the BlockManager. Register to the BlockManagerMaster, and start the - * BlockManagerWorker actor. - */ - private def initialize() { - master.registerBlockManager(blockManagerId, maxMemory, slaveActor) - BlockManagerWorker.startBlockManagerWorker(this) - if (!BlockManager.getDisableHeartBeatsForTesting) { - heartBeatTask = actorSystem.scheduler.schedule(0.seconds, heartBeatFrequency.milliseconds) { - heartBeat() - } - } - } - - /** - * Report all blocks to the BlockManager again. This may be necessary if we are dropped - * by the BlockManager and come back or if we become capable of recovering blocks on disk after - * an executor crash. - * - * This function deliberately fails silently if the master returns false (indicating that - * the slave needs to reregister). The error condition will be detected again by the next - * heart beat attempt or new block registration and another try to reregister all blocks - * will be made then. - */ - private def reportAllBlocks() { - logInfo("Reporting " + blockInfo.size + " blocks to the master.") - for ((blockId, info) <- blockInfo) { - if (!tryToReportBlockStatus(blockId, info)) { - logError("Failed to report " + blockId + " to master; giving up.") - return - } - } - } - - /** - * Reregister with the master and report all blocks to it. This will be called by the heart beat - * thread if our heartbeat to the block amnager indicates that we were not registered. - * - * Note that this method must be called without any BlockInfo locks held. - */ - def reregister() { - // TODO: We might need to rate limit reregistering. - logInfo("BlockManager reregistering with master") - master.registerBlockManager(blockManagerId, maxMemory, slaveActor) - reportAllBlocks() - } - - /** - * Reregister with the master sometime soon. - */ - def asyncReregister() { - asyncReregisterLock.synchronized { - if (asyncReregisterTask == null) { - asyncReregisterTask = Future[Unit] { - reregister() - asyncReregisterLock.synchronized { - asyncReregisterTask = null - } - } - } - } - } - - /** - * For testing. Wait for any pending asynchronous reregistration; otherwise, do nothing. - */ - def waitForAsyncReregister() { - val task = asyncReregisterTask - if (task != null) { - Await.ready(task, Duration.Inf) - } - } - - /** - * Get storage level of local block. If no info exists for the block, then returns null. - */ - def getLevel(blockId: String): StorageLevel = blockInfo.get(blockId).map(_.level).orNull - - /** - * Tell the master about the current storage status of a block. This will send a block update - * message reflecting the current status, *not* the desired storage level in its block info. - * For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk. - * - * droppedMemorySize exists to account for when block is dropped from memory to disk (so it is still valid). - * This ensures that update in master will compensate for the increase in memory on slave. - */ - def reportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L) { - val needReregister = !tryToReportBlockStatus(blockId, info, droppedMemorySize) - if (needReregister) { - logInfo("Got told to reregister updating block " + blockId) - // Reregistering will report our new block for free. 
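// A compact sketch, under the same assumptions as the asyncReregister machinery above, of the
// "at most one pending asynchronous re-registration" pattern: a nullable Future guarded by a
// lock so that concurrent heartbeat failures schedule only one background re-register.
// ReregisterSketch and doReregister are illustrative stand-ins.
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

class ReregisterSketch(doReregister: () => Unit)(implicit ec: ExecutionContext) {
  private val lock = new Object
  private var pending: Future[Unit] = null   // at most one in-flight re-registration

  def asyncReregister(): Unit = lock.synchronized {
    if (pending == null) {
      pending = Future(doReregister()).andThen { case _ =>
        lock.synchronized { pending = null }  // let the next failure schedule a fresh attempt
      }
    }
  }

  // For tests: block until any pending re-registration has completed.
  def waitForPending(): Unit = {
    val task = lock.synchronized(pending)
    if (task != null) Await.ready(task, Duration.Inf)
  }
}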
- asyncReregister() - } - logDebug("Told master about block " + blockId) - } - - /** - * Actually send a UpdateBlockInfo message. Returns the mater's response, - * which will be true if the block was successfully recorded and false if - * the slave needs to re-register. - */ - private def tryToReportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = { - val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized { - info.level match { - case null => - (StorageLevel.NONE, 0L, 0L, false) - case level => - val inMem = level.useMemory && memoryStore.contains(blockId) - val onDisk = level.useDisk && diskStore.contains(blockId) - val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication) - val memSize = if (inMem) memoryStore.getSize(blockId) else droppedMemorySize - val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L - (storageLevel, memSize, diskSize, info.tellMaster) - } - } - - if (tellMaster) { - master.updateBlockInfo(blockManagerId, blockId, curLevel, inMemSize, onDiskSize) - } else { - true - } - } - - /** - * Get locations of an array of blocks. - */ - def getLocationBlockIds(blockIds: Array[String]): Array[Seq[BlockManagerId]] = { - val startTimeMs = System.currentTimeMillis - val locations = master.getLocations(blockIds).toArray - logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) - locations - } - - /** - * A short-circuited method to get blocks directly from disk. This is used for getting - * shuffle blocks. It is safe to do so without a lock on block info since disk store - * never deletes (recent) items. - */ - def getLocalFromDisk(blockId: String, serializer: Serializer): Option[Iterator[Any]] = { - diskStore.getValues(blockId, serializer).orElse( - sys.error("Block " + blockId + " not found on disk, though it should be")) - } - - /** - * Get block from local block manager. - */ - def getLocal(blockId: String): Option[Iterator[Any]] = { - logDebug("Getting local block " + blockId) - val info = blockInfo.get(blockId).orNull - if (info != null) { - info.synchronized { - - // In the another thread is writing the block, wait for it to become ready. - if (!info.waitForReady()) { - // If we get here, the block write failed. - logWarning("Block " + blockId + " was marked as failure.") - return None - } - - val level = info.level - logDebug("Level for block " + blockId + " is " + level) - - // Look for the block in memory - if (level.useMemory) { - logDebug("Getting block " + blockId + " from memory") - memoryStore.getValues(blockId) match { - case Some(iterator) => - return Some(iterator) - case None => - logDebug("Block " + blockId + " not found in memory") - } - } - - // Look for block on disk, potentially loading it back into memory if required - if (level.useDisk) { - logDebug("Getting block " + blockId + " from disk") - if (level.useMemory && level.deserialized) { - diskStore.getValues(blockId) match { - case Some(iterator) => - // Put the block back in memory before returning it - // TODO: Consider creating a putValues that also takes in a iterator ? 
- val elements = new ArrayBuffer[Any] - elements ++= iterator - memoryStore.putValues(blockId, elements, level, true).data match { - case Left(iterator2) => - return Some(iterator2) - case _ => - throw new Exception("Memory store did not return back an iterator") - } - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } - } else if (level.useMemory && !level.deserialized) { - // Read it as a byte buffer into memory first, then return it - diskStore.getBytes(blockId) match { - case Some(bytes) => - // Put a copy of the block back in memory before returning it. Note that we can't - // put the ByteBuffer returned by the disk store as that's a memory-mapped file. - // The use of rewind assumes this. - assert (0 == bytes.position()) - val copyForMemory = ByteBuffer.allocate(bytes.limit) - copyForMemory.put(bytes) - memoryStore.putBytes(blockId, copyForMemory, level) - bytes.rewind() - return Some(dataDeserialize(blockId, bytes)) - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } - } else { - diskStore.getValues(blockId) match { - case Some(iterator) => - return Some(iterator) - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } - } - } - } - } else { - logDebug("Block " + blockId + " not registered locally") - } - return None - } - - /** - * Get block from the local block manager as serialized bytes. - */ - def getLocalBytes(blockId: String): Option[ByteBuffer] = { - // TODO: This whole thing is very similar to getLocal; we need to refactor it somehow - logDebug("Getting local block " + blockId + " as bytes") - - // As an optimization for map output fetches, if the block is for a shuffle, return it - // without acquiring a lock; the disk store never deletes (recent) items so this should work - if (ShuffleBlockManager.isShuffle(blockId)) { - return diskStore.getBytes(blockId) match { - case Some(bytes) => - Some(bytes) - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } - } - - val info = blockInfo.get(blockId).orNull - if (info != null) { - info.synchronized { - - // In the another thread is writing the block, wait for it to become ready. - if (!info.waitForReady()) { - // If we get here, the block write failed. 
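// A small sketch of the caveat noted above: a ByteBuffer handed back by the disk store may be
// a memory-mapped file, so before caching it in memory we allocate a plain heap buffer and
// copy into it, then rewind the original for the caller. Illustrative names only.
import java.nio.ByteBuffer

object MappedBufferCopySketch {
  def copyForMemory(diskBytes: ByteBuffer): ByteBuffer = {
    assert(diskBytes.position() == 0)
    val copy = ByteBuffer.allocate(diskBytes.limit())
    copy.put(diskBytes)   // advances diskBytes to its limit
    diskBytes.rewind()    // restore the original so the caller can still read it
    copy.flip()           // make the heap copy readable from position 0
    copy
  }
}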
- logWarning("Block " + blockId + " was marked as failure.") - return None - } - - val level = info.level - logDebug("Level for block " + blockId + " is " + level) - - // Look for the block in memory - if (level.useMemory) { - logDebug("Getting block " + blockId + " from memory") - memoryStore.getBytes(blockId) match { - case Some(bytes) => - return Some(bytes) - case None => - logDebug("Block " + blockId + " not found in memory") - } - } - - // Look for block on disk - if (level.useDisk) { - // Read it as a byte buffer into memory first, then return it - diskStore.getBytes(blockId) match { - case Some(bytes) => - assert (0 == bytes.position()) - if (level.useMemory) { - if (level.deserialized) { - memoryStore.putBytes(blockId, bytes, level) - } else { - // The memory store will hang onto the ByteBuffer, so give it a copy instead of - // the memory-mapped file buffer we got from the disk store - val copyForMemory = ByteBuffer.allocate(bytes.limit) - copyForMemory.put(bytes) - memoryStore.putBytes(blockId, copyForMemory, level) - } - } - bytes.rewind() - return Some(bytes) - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } - } - } - } else { - logDebug("Block " + blockId + " not registered locally") - } - return None - } - - /** - * Get block from remote block managers. - */ - def getRemote(blockId: String): Option[Iterator[Any]] = { - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } - logDebug("Getting remote block " + blockId) - // Get locations of block - val locations = master.getLocations(blockId) - - // Get block from remote locations - for (loc <- locations) { - logDebug("Getting remote block " + blockId + " from " + loc) - val data = BlockManagerWorker.syncGetBlock( - GetBlock(blockId), ConnectionManagerId(loc.host, loc.port)) - if (data != null) { - return Some(dataDeserialize(blockId, data)) - } - logDebug("The value of block " + blockId + " is null") - } - logDebug("Block " + blockId + " not found") - return None - } - - /** - * Get a block from the block manager (either local or remote). - */ - def get(blockId: String): Option[Iterator[Any]] = { - getLocal(blockId).orElse(getRemote(blockId)) - } - - /** - * Get multiple blocks from local and remote block manager using their BlockManagerIds. Returns - * an Iterator of (block ID, value) pairs so that clients may handle blocks in a pipelined - * fashion as they're received. Expects a size in bytes to be provided for each block fetched, - * so that we can control the maxMegabytesInFlight for the fetch. - */ - def getMultiple( - blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer) - : BlockFetcherIterator = { - - val iter = - if (System.getProperty("spark.shuffle.use.netty", "false").toBoolean) { - new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer) - } else { - new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer) - } - - iter.initialize() - iter - } - - def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean) - : Long = { - val elements = new ArrayBuffer[Any] - elements ++= values - put(blockId, elements, level, tellMaster) - } - - /** - * A short circuited method to get a block writer that can write data directly to disk. - * This is currently used for writing shuffle files out. Callers should handle error - * cases. 
- */ - def getDiskBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int) - : BlockObjectWriter = { - val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize) - writer.registerCloseEventHandler(() => { - val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false) - blockInfo.put(blockId, myInfo) - myInfo.markReady(writer.size()) - }) - writer - } - - /** - * Put a new block of values to the block manager. Returns its (estimated) size in bytes. - */ - def put(blockId: String, values: ArrayBuffer[Any], level: StorageLevel, - tellMaster: Boolean = true) : Long = { - - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } - if (values == null) { - throw new IllegalArgumentException("Values is null") - } - if (level == null || !level.isValid) { - throw new IllegalArgumentException("Storage level is null or invalid") - } - - // Remember the block's storage level so that we can correctly drop it to disk if it needs - // to be dropped right after it got put into memory. Note, however, that other threads will - // not be able to get() this block until we call markReady on its BlockInfo. - val myInfo = { - val tinfo = new BlockInfo(level, tellMaster) - // Do atomically ! - val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo) - - if (oldBlockOpt.isDefined) { - if (oldBlockOpt.get.waitForReady()) { - logWarning("Block " + blockId + " already exists on this machine; not re-adding it") - return oldBlockOpt.get.size - } - - // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ? - oldBlockOpt.get - } else { - tinfo - } - } - - val startTimeMs = System.currentTimeMillis - - // If we need to replicate the data, we'll want access to the values, but because our - // put will read the whole iterator, there will be no values left. For the case where - // the put serializes data, we'll remember the bytes, above; but for the case where it - // doesn't, such as deserialized storage, let's rely on the put returning an Iterator. - var valuesAfterPut: Iterator[Any] = null - - // Ditto for the bytes after the put - var bytesAfterPut: ByteBuffer = null - - // Size of the block in bytes (to return to caller) - var size = 0L - - myInfo.synchronized { - logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) - + " to get into synchronized block") - - var marked = false - try { - if (level.useMemory) { - // Save it just to memory first, even if it also has useDisk set to true; we will later - // drop it to disk if the memory store can't hold it. - val res = memoryStore.putValues(blockId, values, level, true) - size = res.size - res.data match { - case Right(newBytes) => bytesAfterPut = newBytes - case Left(newIterator) => valuesAfterPut = newIterator - } - } else { - // Save directly to disk. - // Don't get back the bytes unless we replicate them. - val askForBytes = level.replication > 1 - val res = diskStore.putValues(blockId, values, level, askForBytes) - size = res.size - res.data match { - case Right(newBytes) => bytesAfterPut = newBytes - case _ => - } - } - - // Now that the block is in either the memory or disk store, let other threads read it, - // and tell the master about it. - marked = true - myInfo.markReady(size) - if (tellMaster) { - reportBlockStatus(blockId, myInfo) - } - } finally { - // If we failed at putting the block to memory/disk, notify other possible readers - // that it has failed, and then remove it from the block info map. - if (! 
marked) { - // Note that the remove must happen before markFailure otherwise another thread - // could've inserted a new BlockInfo before we remove it. - blockInfo.remove(blockId) - myInfo.markFailure() - logWarning("Putting block " + blockId + " failed") - } - } - } - logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs)) - - // Replicate block if required - if (level.replication > 1) { - val remoteStartTime = System.currentTimeMillis - // Serialize the block if not already done - if (bytesAfterPut == null) { - if (valuesAfterPut == null) { - throw new SparkException( - "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.") - } - bytesAfterPut = dataSerialize(blockId, valuesAfterPut) - } - replicate(blockId, bytesAfterPut, level) - logDebug("Put block " + blockId + " remotely took " + Utils.getUsedTimeMs(remoteStartTime)) - } - BlockManager.dispose(bytesAfterPut) - - return size - } - - - /** - * Put a new block of serialized bytes to the block manager. - */ - def putBytes( - blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) { - - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } - if (bytes == null) { - throw new IllegalArgumentException("Bytes is null") - } - if (level == null || !level.isValid) { - throw new IllegalArgumentException("Storage level is null or invalid") - } - - // Remember the block's storage level so that we can correctly drop it to disk if it needs - // to be dropped right after it got put into memory. Note, however, that other threads will - // not be able to get() this block until we call markReady on its BlockInfo. - val myInfo = { - val tinfo = new BlockInfo(level, tellMaster) - // Do atomically ! - val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo) - - if (oldBlockOpt.isDefined) { - if (oldBlockOpt.get.waitForReady()) { - logWarning("Block " + blockId + " already exists on this machine; not re-adding it") - return - } - - // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ? - oldBlockOpt.get - } else { - tinfo - } - } - - val startTimeMs = System.currentTimeMillis - - // Initiate the replication before storing it locally. This is faster as - // data is already serialized and ready for sending - val replicationFuture = if (level.replication > 1) { - val bufferView = bytes.duplicate() // Doesn't copy the bytes, just creates a wrapper - Future { - replicate(blockId, bufferView, level) - } - } else { - null - } - - myInfo.synchronized { - logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) - + " to get into synchronized block") - - var marked = false - try { - if (level.useMemory) { - // Store it only in memory at first, even if useDisk is also set to true - bytes.rewind() - memoryStore.putBytes(blockId, bytes, level) - } else { - bytes.rewind() - diskStore.putBytes(blockId, bytes, level) - } - - // assert (0 == bytes.position(), "" + bytes) - - // Now that the block is in either the memory or disk store, let other threads read it, - // and tell the master about it. - marked = true - myInfo.markReady(bytes.limit) - if (tellMaster) { - reportBlockStatus(blockId, myInfo) - } - } finally { - // If we failed at putting the block to memory/disk, notify other possible readers - // that it has failed, and then remove it from the block info map. - if (! 
marked) { - // Note that the remove must happen before markFailure otherwise another thread - // could've inserted a new BlockInfo before we remove it. - blockInfo.remove(blockId) - myInfo.markFailure() - logWarning("Putting block " + blockId + " failed") - } - } - } - - // If replication had started, then wait for it to finish - if (level.replication > 1) { - Await.ready(replicationFuture, Duration.Inf) - } - - if (level.replication > 1) { - logDebug("PutBytes for block " + blockId + " with replication took " + - Utils.getUsedTimeMs(startTimeMs)) - } else { - logDebug("PutBytes for block " + blockId + " without replication took " + - Utils.getUsedTimeMs(startTimeMs)) - } - } - - /** - * Replicate block to another node. - */ - var cachedPeers: Seq[BlockManagerId] = null - private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) { - val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) - if (cachedPeers == null) { - cachedPeers = master.getPeers(blockManagerId, level.replication - 1) - } - for (peer: BlockManagerId <- cachedPeers) { - val start = System.nanoTime - data.rewind() - logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is " - + data.limit() + " Bytes. To node: " + peer) - if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel), - new ConnectionManagerId(peer.host, peer.port))) { - logError("Failed to call syncPutBlock to " + peer) - } - logDebug("Replicated BlockId " + blockId + " once used " + - (System.nanoTime - start) / 1e6 + " s; The size of the data is " + - data.limit() + " bytes.") - } - } - - /** - * Read a block consisting of a single object. - */ - def getSingle(blockId: String): Option[Any] = { - get(blockId).map(_.next()) - } - - /** - * Write a block consisting of a single object. - */ - def putSingle(blockId: String, value: Any, level: StorageLevel, tellMaster: Boolean = true) { - put(blockId, Iterator(value), level, tellMaster) - } - - /** - * Drop a block from memory, possibly putting it on disk if applicable. Called when the memory - * store reaches its limit and needs to free up space. - */ - def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) { - logInfo("Dropping block " + blockId + " from memory") - val info = blockInfo.get(blockId).orNull - if (info != null) { - info.synchronized { - // required ? As of now, this will be invoked only for blocks which are ready - // But in case this changes in future, adding for consistency sake. - if (! info.waitForReady() ) { - // If we get here, the block write failed. - logWarning("Block " + blockId + " was marked as failure. Nothing to drop") - return - } - - val level = info.level - if (level.useDisk && !diskStore.contains(blockId)) { - logInfo("Writing block " + blockId + " to disk") - data match { - case Left(elements) => - diskStore.putValues(blockId, elements, level, false) - case Right(bytes) => - diskStore.putBytes(blockId, bytes, level) - } - } - val droppedMemorySize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L - val blockWasRemoved = memoryStore.remove(blockId) - if (!blockWasRemoved) { - logWarning("Block " + blockId + " could not be dropped from memory as it does not exist") - } - if (info.tellMaster) { - reportBlockStatus(blockId, info, droppedMemorySize) - } - if (!level.useDisk) { - // The block is completely gone from this node; forget it so we can put() it again later. 
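// A rough sketch of the replication idea used above, assuming illustrative names
// (SketchLevel, sendToPeer): when a block is stored with replication > 1, each peer receives
// the bytes with an otherwise identical storage level whose replication factor is forced to 1,
// so the peers do not re-replicate in turn.
object ReplicationSketch {
  final case class SketchLevel(useDisk: Boolean, useMemory: Boolean,
                               deserialized: Boolean, replication: Int)

  def replicate(blockId: String, level: SketchLevel, peers: Seq[String])
               (sendToPeer: (String, String, SketchLevel) => Boolean): Unit = {
    val targetLevel = level.copy(replication = 1)   // peers store a single copy each
    for (peer <- peers) {
      if (!sendToPeer(peer, blockId, targetLevel)) {
        println(s"Failed to replicate $blockId to $peer")
      }
    }
  }
}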
- blockInfo.remove(blockId) - } - } - } else { - // The block has already been dropped - } - } - - /** - * Remove all blocks belonging to the given RDD. - * @return The number of blocks removed. - */ - def removeRdd(rddId: Int): Int = { - // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps - // from RDD.id to blocks. - logInfo("Removing RDD " + rddId) - val rddPrefix = "rdd_" + rddId + "_" - val blocksToRemove = blockInfo.filter(_._1.startsWith(rddPrefix)).map(_._1) - blocksToRemove.foreach(blockId => removeBlock(blockId, false)) - blocksToRemove.size - } - - /** - * Remove a block from both memory and disk. - */ - def removeBlock(blockId: String, tellMaster: Boolean = true) { - logInfo("Removing block " + blockId) - val info = blockInfo.get(blockId).orNull - if (info != null) info.synchronized { - // Removals are idempotent in disk store and memory store. At worst, we get a warning. - val removedFromMemory = memoryStore.remove(blockId) - val removedFromDisk = diskStore.remove(blockId) - if (!removedFromMemory && !removedFromDisk) { - logWarning("Block " + blockId + " could not be removed as it was not found in either " + - "the disk or memory store") - } - blockInfo.remove(blockId) - if (tellMaster && info.tellMaster) { - reportBlockStatus(blockId, info) - } - } else { - // The block has already been removed; do nothing. - logWarning("Asked to remove block " + blockId + ", which does not exist") - } - } - - def dropOldBlocks(cleanupTime: Long) { - logInfo("Dropping blocks older than " + cleanupTime) - val iterator = blockInfo.internalMap.entrySet().iterator() - while (iterator.hasNext) { - val entry = iterator.next() - val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2) - if (time < cleanupTime) { - info.synchronized { - val level = info.level - if (level.useMemory) { - memoryStore.remove(id) - } - if (level.useDisk) { - diskStore.remove(id) - } - iterator.remove() - logInfo("Dropped block " + id) - } - reportBlockStatus(id, info) - } - } - } - - def shouldCompress(blockId: String): Boolean = { - if (ShuffleBlockManager.isShuffle(blockId)) { - compressShuffle - } else if (blockId.startsWith("broadcast_")) { - compressBroadcast - } else if (blockId.startsWith("rdd_")) { - compressRdds - } else { - false // Won't happen in a real cluster, but it can in tests - } - } - - /** - * Wrap an output stream for compression if block compression is enabled for its block type - */ - def wrapForCompression(blockId: String, s: OutputStream): OutputStream = { - if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s - } - - /** - * Wrap an input stream for compression if block compression is enabled for its block type - */ - def wrapForCompression(blockId: String, s: InputStream): InputStream = { - if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s - } - - def dataSerialize( - blockId: String, - values: Iterator[Any], - serializer: Serializer = defaultSerializer): ByteBuffer = { - val byteStream = new FastByteArrayOutputStream(4096) - val ser = serializer.newInstance() - ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() - byteStream.trim() - ByteBuffer.wrap(byteStream.array) - } - - /** - * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of - * the iterator is reached. 
- */ - def dataDeserialize( - blockId: String, - bytes: ByteBuffer, - serializer: Serializer = defaultSerializer): Iterator[Any] = { - bytes.rewind() - val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true)) - serializer.newInstance().deserializeStream(stream).asIterator - } - - def stop() { - if (heartBeatTask != null) { - heartBeatTask.cancel() - } - connectionManager.stop() - actorSystem.stop(slaveActor) - blockInfo.clear() - memoryStore.clear() - diskStore.clear() - metadataCleaner.cancel() - logInfo("BlockManager stopped") - } -} - - -private[spark] object BlockManager extends Logging { - - val ID_GENERATOR = new IdGenerator - - def getMaxMemoryFromSystemProperties: Long = { - val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble - (Runtime.getRuntime.maxMemory * memoryFraction).toLong - } - - def getHeartBeatFrequencyFromSystemProperties: Long = - System.getProperty("spark.storage.blockManagerTimeoutIntervalMs", "60000").toLong / 4 - - def getDisableHeartBeatsForTesting: Boolean = - System.getProperty("spark.test.disableBlockManagerHeartBeat", "false").toBoolean - - /** - * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that - * might cause errors if one attempts to read from the unmapped buffer, but it's better than - * waiting for the GC to find it because that could lead to huge numbers of open files. There's - * unfortunately no standard API to do this. - */ - def dispose(buffer: ByteBuffer) { - if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { - logTrace("Unmapping " + buffer) - if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) { - buffer.asInstanceOf[DirectBuffer].cleaner().clean() - } - } - } - - def blockIdsToBlockManagers( - blockIds: Array[String], - env: SparkEnv, - blockManagerMaster: BlockManagerMaster = null) - : Map[String, Seq[BlockManagerId]] = - { - // env == null and blockManagerMaster != null is used in tests - assert (env != null || blockManagerMaster != null) - val blockLocations: Seq[Seq[BlockManagerId]] = if (env != null) { - env.blockManager.getLocationBlockIds(blockIds) - } else { - blockManagerMaster.getLocations(blockIds) - } - - val blockManagers = new HashMap[String, Seq[BlockManagerId]] - for (i <- 0 until blockIds.length) { - blockManagers(blockIds(i)) = blockLocations(i) - } - blockManagers.toMap - } - - def blockIdsToExecutorIds( - blockIds: Array[String], - env: SparkEnv, - blockManagerMaster: BlockManagerMaster = null) - : Map[String, Seq[String]] = - { - blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.executorId)) - } - - def blockIdsToHosts( - blockIds: Array[String], - env: SparkEnv, - blockManagerMaster: BlockManagerMaster = null) - : Map[String, Seq[String]] = - { - blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host)) - } -} - diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala deleted file mode 100644 index b36a6176c0..0000000000 --- a/core/src/main/scala/spark/storage/BlockManagerId.scala +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} -import java.util.concurrent.ConcurrentHashMap -import spark.Utils - -/** - * This class represent an unique identifier for a BlockManager. - * The first 2 constructors of this class is made private to ensure that - * BlockManagerId objects can be created only using the apply method in - * the companion object. This allows de-duplication of ID objects. - * Also, constructor parameters are private to ensure that parameters cannot - * be modified from outside this class. - */ -private[spark] class BlockManagerId private ( - private var executorId_ : String, - private var host_ : String, - private var port_ : Int, - private var nettyPort_ : Int - ) extends Externalizable { - - private def this() = this(null, null, 0, 0) // For deserialization only - - def executorId: String = executorId_ - - if (null != host_){ - Utils.checkHost(host_, "Expected hostname") - assert (port_ > 0) - } - - def hostPort: String = { - // DEBUG code - Utils.checkHost(host) - assert (port > 0) - - host + ":" + port - } - - def host: String = host_ - - def port: Int = port_ - - def nettyPort: Int = nettyPort_ - - override def writeExternal(out: ObjectOutput) { - out.writeUTF(executorId_) - out.writeUTF(host_) - out.writeInt(port_) - out.writeInt(nettyPort_) - } - - override def readExternal(in: ObjectInput) { - executorId_ = in.readUTF() - host_ = in.readUTF() - port_ = in.readInt() - nettyPort_ = in.readInt() - } - - @throws(classOf[IOException]) - private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) - - override def toString = "BlockManagerId(%s, %s, %d, %d)".format(executorId, host, port, nettyPort) - - override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port + nettyPort - - override def equals(that: Any) = that match { - case id: BlockManagerId => - executorId == id.executorId && port == id.port && host == id.host && nettyPort == id.nettyPort - case _ => - false - } -} - - -private[spark] object BlockManagerId { - - /** - * Returns a [[spark.storage.BlockManagerId]] for the given configuraiton. - * - * @param execId ID of the executor. - * @param host Host name of the block manager. - * @param port Port of the block manager. - * @param nettyPort Optional port for the Netty-based shuffle sender. - * @return A new [[spark.storage.BlockManagerId]]. 
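// A minimal sketch of the de-duplication scheme the class comment describes: construction goes
// through a factory backed by a ConcurrentHashMap, and readResolve() routes deserialized
// instances through the same cache so equal ids share one canonical object. InternedId is an
// illustrative name, not the Spark class.
import java.util.concurrent.ConcurrentHashMap

class InternedId private (val executorId: String, val host: String, val port: Int)
  extends Serializable {

  override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port
  override def equals(that: Any): Boolean = that match {
    case id: InternedId => executorId == id.executorId && host == id.host && port == id.port
    case _ => false
  }
  // Called by Java serialization after deserialization; returns the cached canonical instance.
  private def readResolve(): Object = InternedId.cache(this)
}

object InternedId {
  private val interned = new ConcurrentHashMap[InternedId, InternedId]()

  def apply(executorId: String, host: String, port: Int): InternedId =
    cache(new InternedId(executorId, host, port))

  private def cache(id: InternedId): InternedId = {
    interned.putIfAbsent(id, id)   // first writer wins; later callers get the same object
    interned.get(id)
  }
}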
- */ - def apply(execId: String, host: String, port: Int, nettyPort: Int) = - getCachedBlockManagerId(new BlockManagerId(execId, host, port, nettyPort)) - - def apply(in: ObjectInput) = { - val obj = new BlockManagerId() - obj.readExternal(in) - getCachedBlockManagerId(obj) - } - - val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() - - def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { - blockManagerIdCache.putIfAbsent(id, id) - blockManagerIdCache.get(id) - } -} diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala deleted file mode 100644 index 76128e8cff..0000000000 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ /dev/null @@ -1,178 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import akka.actor.ActorRef -import akka.dispatch.{Await, Future} -import akka.pattern.ask -import akka.util.Duration - -import spark.{Logging, SparkException} -import spark.storage.BlockManagerMessages._ - - -private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Logging { - - val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt - val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt - - val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" - - val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") - - /** Remove a dead executor from the driver actor. This is only called on the driver side. */ - def removeExecutor(execId: String) { - tell(RemoveExecutor(execId)) - logInfo("Removed " + execId + " successfully in removeExecutor") - } - - /** - * Send the driver actor a heart beat from the slave. Returns true if everything works out, - * false if the driver does not know about the given block manager, which means the block - * manager should re-register. - */ - def sendHeartBeat(blockManagerId: BlockManagerId): Boolean = { - askDriverWithReply[Boolean](HeartBeat(blockManagerId)) - } - - /** Register the BlockManager's id with the driver. 
*/ - def registerBlockManager( - blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { - logInfo("Trying to register BlockManager") - tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor)) - logInfo("Registered BlockManager") - } - - def updateBlockInfo( - blockManagerId: BlockManagerId, - blockId: String, - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long): Boolean = { - val res = askDriverWithReply[Boolean]( - UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)) - logInfo("Updated info of block " + blockId) - res - } - - /** Get locations of the blockId from the driver */ - def getLocations(blockId: String): Seq[BlockManagerId] = { - askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId)) - } - - /** Get locations of multiple blockIds from the driver */ - def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { - askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) - } - - /** Get ids of other nodes in the cluster from the driver */ - def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = { - val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers)) - if (result.length != numPeers) { - throw new SparkException( - "Error getting peers, only got " + result.size + " instead of " + numPeers) - } - result - } - - /** - * Remove a block from the slaves that have it. This can only be used to remove - * blocks that the driver knows about. - */ - def removeBlock(blockId: String) { - askDriverWithReply(RemoveBlock(blockId)) - } - - /** - * Remove all blocks belonging to the given RDD. - */ - def removeRdd(rddId: Int, blocking: Boolean) { - val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) - future onFailure { - case e: Throwable => logError("Failed to remove RDD " + rddId, e) - } - if (blocking) { - Await.result(future, timeout) - } - } - - /** - * Return the memory status for each block manager, in the form of a map from - * the block manager's id to two long values. The first value is the maximum - * amount of memory allocated for the block manager, while the second is the - * amount of remaining memory. - */ - def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { - askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) - } - - def getStorageStatus: Array[StorageStatus] = { - askDriverWithReply[Array[StorageStatus]](GetStorageStatus) - } - - /** Stop the driver actor, called only on the Spark driver node */ - def stop() { - if (driverActor != null) { - tell(StopBlockManagerMaster) - driverActor = null - logInfo("BlockManagerMaster stopped") - } - } - - /** Send a one-way message to the master actor, to which we expect it to reply with true. */ - private def tell(message: Any) { - if (!askDriverWithReply[Boolean](message)) { - throw new SparkException("BlockManagerMasterActor returned false, expected true.") - } - } - - /** - * Send a message to the driver actor and get its result within a default timeout, or - * throw a SparkException if this fails. 
- */ - private def askDriverWithReply[T](message: Any): T = { - // TODO: Consider removing multiple attempts - if (driverActor == null) { - throw new SparkException("Error sending message to BlockManager as driverActor is null " + - "[message = " + message + "]") - } - var attempts = 0 - var lastException: Exception = null - while (attempts < AKKA_RETRY_ATTEMPTS) { - attempts += 1 - try { - val future = driverActor.ask(message)(timeout) - val result = Await.result(future, timeout) - if (result == null) { - throw new SparkException("BlockManagerMaster returned null") - } - return result.asInstanceOf[T] - } catch { - case ie: InterruptedException => throw ie - case e: Exception => - lastException = e - logWarning("Error sending message to BlockManagerMaster in " + attempts + " attempts", e) - } - Thread.sleep(AKKA_RETRY_INTERVAL_MS) - } - - throw new SparkException( - "Error sending message to BlockManagerMaster [message = " + message + "]", lastException) - } - -} diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala deleted file mode 100644 index b7a981d101..0000000000 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ /dev/null @@ -1,404 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.util.{HashMap => JHashMap} - -import scala.collection.mutable -import scala.collection.JavaConversions._ - -import akka.actor.{Actor, ActorRef, Cancellable} -import akka.dispatch.Future -import akka.pattern.ask -import akka.util.Duration -import akka.util.duration._ - -import spark.{Logging, Utils, SparkException} -import spark.storage.BlockManagerMessages._ - - -/** - * BlockManagerMasterActor is an actor on the master node to track statuses of - * all slaves' block managers. - */ -private[spark] -class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { - - // Mapping from block manager id to the block manager's information. - private val blockManagerInfo = - new mutable.HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo] - - // Mapping from executor ID to block manager ID. - private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId] - - // Mapping from block id to the set of block managers that have the block. 
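// A stripped-down sketch of the retry-around-ask pattern in askDriverWithReply above: a bounded
// number of attempts, each expecting a non-null reply, sleeping between failures and rethrowing
// interrupts immediately. askOnce is an illustrative stand-in for the Akka ask + Await.result.
object AskRetrySketch {
  def askWithRetry[T](message: Any, attempts: Int, retryWaitMs: Long)(askOnce: Any => T): T = {
    var lastException: Exception = null
    var attempt = 0
    while (attempt < attempts) {
      attempt += 1
      try {
        val result = askOnce(message)
        if (result == null) throw new NullPointerException("driver returned null")
        return result
      } catch {
        case ie: InterruptedException => throw ie
        case e: Exception =>
          lastException = e
          Thread.sleep(retryWaitMs)   // back off before the next attempt
      }
    }
    throw new RuntimeException(s"Giving up on $message after $attempts attempts", lastException)
  }
}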
- private val blockLocations = new JHashMap[String, mutable.HashSet[BlockManagerId]] - - val akkaTimeout = Duration.create( - System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") - - initLogging() - - val slaveTimeout = System.getProperty("spark.storage.blockManagerSlaveTimeoutMs", - "" + (BlockManager.getHeartBeatFrequencyFromSystemProperties * 3)).toLong - - val checkTimeoutInterval = System.getProperty("spark.storage.blockManagerTimeoutIntervalMs", - "60000").toLong - - var timeoutCheckingTask: Cancellable = null - - override def preStart() { - if (!BlockManager.getDisableHeartBeatsForTesting) { - timeoutCheckingTask = context.system.scheduler.schedule( - 0.seconds, checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) - } - super.preStart() - } - - def receive = { - case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => - register(blockManagerId, maxMemSize, slaveActor) - sender ! true - - case UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => - // TODO: Ideally we want to handle all the message replies in receive instead of in the - // individual private methods. - updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) - - case GetLocations(blockId) => - sender ! getLocations(blockId) - - case GetLocationsMultipleBlockIds(blockIds) => - sender ! getLocationsMultipleBlockIds(blockIds) - - case GetPeers(blockManagerId, size) => - sender ! getPeers(blockManagerId, size) - - case GetMemoryStatus => - sender ! memoryStatus - - case GetStorageStatus => - sender ! storageStatus - - case RemoveRdd(rddId) => - sender ! removeRdd(rddId) - - case RemoveBlock(blockId) => - removeBlockFromWorkers(blockId) - sender ! true - - case RemoveExecutor(execId) => - removeExecutor(execId) - sender ! true - - case StopBlockManagerMaster => - logInfo("Stopping BlockManagerMaster") - sender ! true - if (timeoutCheckingTask != null) { - timeoutCheckingTask.cancel() - } - context.stop(self) - - case ExpireDeadHosts => - expireDeadHosts() - - case HeartBeat(blockManagerId) => - sender ! heartBeat(blockManagerId) - - case other => - logWarning("Got unknown message: " + other) - } - - private def removeRdd(rddId: Int): Future[Seq[Int]] = { - // First remove the metadata for the given RDD, and then asynchronously remove the blocks - // from the slaves. - - val prefix = "rdd_" + rddId + "_" - // Find all blocks for the given RDD, remove the block from both blockLocations and - // the blockManagerInfo that is tracking the blocks. - val blocks = blockLocations.keySet().filter(_.startsWith(prefix)) - blocks.foreach { blockId => - val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId) - bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId))) - blockLocations.remove(blockId) - } - - // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. - // The dispatcher is used as an implicit argument into the Future sequence construction. - import context.dispatcher - val removeMsg = RemoveRdd(rddId) - Future.sequence(blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] - }.toSeq) - } - - private def removeBlockManager(blockManagerId: BlockManagerId) { - val info = blockManagerInfo(blockManagerId) - - // Remove the block manager from blockManagerIdByExecutor. - blockManagerIdByExecutor -= blockManagerId.executorId - - // Remove it from blockManagerInfo and remove all the blocks. 
- blockManagerInfo.remove(blockManagerId) - val iterator = info.blocks.keySet.iterator - while (iterator.hasNext) { - val blockId = iterator.next - val locations = blockLocations.get(blockId) - locations -= blockManagerId - if (locations.size == 0) { - blockLocations.remove(locations) - } - } - } - - private def expireDeadHosts() { - logTrace("Checking for hosts with no recent heart beats in BlockManagerMaster.") - val now = System.currentTimeMillis() - val minSeenTime = now - slaveTimeout - val toRemove = new mutable.HashSet[BlockManagerId] - for (info <- blockManagerInfo.values) { - if (info.lastSeenMs < minSeenTime) { - logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats: " + - (now - info.lastSeenMs) + "ms exceeds " + slaveTimeout + "ms") - toRemove += info.blockManagerId - } - } - toRemove.foreach(removeBlockManager) - } - - private def removeExecutor(execId: String) { - logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.") - blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) - } - - private def heartBeat(blockManagerId: BlockManagerId): Boolean = { - if (!blockManagerInfo.contains(blockManagerId)) { - blockManagerId.executorId == "" && !isLocal - } else { - blockManagerInfo(blockManagerId).updateLastSeenMs() - true - } - } - - // Remove a block from the slaves that have it. This can only be used to remove - // blocks that the master knows about. - private def removeBlockFromWorkers(blockId: String) { - val locations = blockLocations.get(blockId) - if (locations != null) { - locations.foreach { blockManagerId: BlockManagerId => - val blockManager = blockManagerInfo.get(blockManagerId) - if (blockManager.isDefined) { - // Remove the block from the slave's BlockManager. - // Doesn't actually wait for a confirmation and the message might get lost. - // If message loss becomes frequent, we should add retry logic here. - blockManager.get.slaveActor ! RemoveBlock(blockId) - } - } - } - } - - // Return a map from the block manager id to max memory and remaining memory. - private def memoryStatus: Map[BlockManagerId, (Long, Long)] = { - blockManagerInfo.map { case(blockManagerId, info) => - (blockManagerId, (info.maxMem, info.remainingMem)) - }.toMap - } - - private def storageStatus: Array[StorageStatus] = { - blockManagerInfo.map { case(blockManagerId, info) => - import collection.JavaConverters._ - StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala.toMap) - }.toArray - } - - private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { - if (id.executorId == "" && !isLocal) { - // Got a register message from the master node; don't register it - } else if (!blockManagerInfo.contains(id)) { - blockManagerIdByExecutor.get(id.executorId) match { - case Some(manager) => - // A block manager of the same executor already exists. - // This should never happen. Let's just quit. 
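// A small sketch of the heartbeat-expiry check in expireDeadHosts above: any block manager whose
// last-seen timestamp is older than now - slaveTimeoutMs is collected for removal. Names are
// illustrative; ids stand in for BlockManagerId.
object ExpireSketch {
  def expiredIds(lastSeenMs: Map[String, Long], nowMs: Long, slaveTimeoutMs: Long): Seq[String] = {
    val minSeenTime = nowMs - slaveTimeoutMs
    lastSeenMs.collect { case (id, seen) if seen < minSeenTime => id }.toSeq
  }
}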
- logError("Got two different block manager registrations on " + id.executorId) - System.exit(1) - case None => - blockManagerIdByExecutor(id.executorId) = id - } - blockManagerInfo(id) = new BlockManagerMasterActor.BlockManagerInfo( - id, System.currentTimeMillis(), maxMemSize, slaveActor) - } - } - - private def updateBlockInfo( - blockManagerId: BlockManagerId, - blockId: String, - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long) { - - if (!blockManagerInfo.contains(blockManagerId)) { - if (blockManagerId.executorId == "" && !isLocal) { - // We intentionally do not register the master (except in local mode), - // so we should not indicate failure. - sender ! true - } else { - sender ! false - } - return - } - - if (blockId == null) { - blockManagerInfo(blockManagerId).updateLastSeenMs() - sender ! true - return - } - - blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize) - - var locations: mutable.HashSet[BlockManagerId] = null - if (blockLocations.containsKey(blockId)) { - locations = blockLocations.get(blockId) - } else { - locations = new mutable.HashSet[BlockManagerId] - blockLocations.put(blockId, locations) - } - - if (storageLevel.isValid) { - locations.add(blockManagerId) - } else { - locations.remove(blockManagerId) - } - - // Remove the block from master tracking if it has been removed on all slaves. - if (locations.size == 0) { - blockLocations.remove(blockId) - } - sender ! true - } - - private def getLocations(blockId: String): Seq[BlockManagerId] = { - if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty - } - - private def getLocationsMultipleBlockIds(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { - blockIds.map(blockId => getLocations(blockId)) - } - - private def getPeers(blockManagerId: BlockManagerId, size: Int): Seq[BlockManagerId] = { - val peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray - - val selfIndex = peers.indexOf(blockManagerId) - if (selfIndex == -1) { - throw new SparkException("Self index for " + blockManagerId + " not found") - } - - // Note that this logic will select the same node multiple times if there aren't enough peers - Array.tabulate[BlockManagerId](size) { i => peers((selfIndex + i + 1) % peers.length) }.toSeq - } -} - - -private[spark] -object BlockManagerMasterActor { - - case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) - - class BlockManagerInfo( - val blockManagerId: BlockManagerId, - timeMs: Long, - val maxMem: Long, - val slaveActor: ActorRef) - extends Logging { - - private var _lastSeenMs: Long = timeMs - private var _remainingMem: Long = maxMem - - // Mapping from block id to its status. - private val _blocks = new JHashMap[String, BlockStatus] - - logInfo("Registering block manager %s with %s RAM".format( - blockManagerId.hostPort, Utils.bytesToString(maxMem))) - - def updateLastSeenMs() { - _lastSeenMs = System.currentTimeMillis() - } - - def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, - diskSize: Long) { - - updateLastSeenMs() - - if (_blocks.containsKey(blockId)) { - // The block exists on the slave already. - val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel - - if (originalLevel.useMemory) { - _remainingMem += memSize - } - } - - if (storageLevel.isValid) { - // isValid means it is either stored in-memory or on-disk. 
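The getPeers method above picks replication targets by walking the registered block managers round-robin, starting just after the requesting node, and wraps around when fewer peers exist than requested. A self-contained sketch of the same selection over plain strings (names are illustrative, not part of this patch):

    object PeerSelectionExample {
      // Same wrap-around selection as getPeers: the `size` entries following `self`,
      // repeating nodes when the cluster has fewer than `size` other members.
      def selectPeers(peers: Array[String], self: String, size: Int): Seq[String] = {
        val selfIndex = peers.indexOf(self)
        require(selfIndex != -1, "self not found among peers")
        Array.tabulate(size) { i => peers((selfIndex + i + 1) % peers.length) }.toSeq
      }

      def main(args: Array[String]) {
        val peers = Array("bm-0", "bm-1", "bm-2")
        println(selectPeers(peers, "bm-1", 2))  // bm-2, bm-0
        println(selectPeers(peers, "bm-2", 4))  // bm-0, bm-1, bm-2, bm-0 (wraps around)
      }
    }
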
- _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize)) - if (storageLevel.useMemory) { - _remainingMem -= memSize - logInfo("Added %s in memory on %s (size: %s, free: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(memSize), - Utils.bytesToString(_remainingMem))) - } - if (storageLevel.useDisk) { - logInfo("Added %s on disk on %s (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) - } - } else if (_blocks.containsKey(blockId)) { - // If isValid is not true, drop the block. - val blockStatus: BlockStatus = _blocks.get(blockId) - _blocks.remove(blockId) - if (blockStatus.storageLevel.useMemory) { - _remainingMem += blockStatus.memSize - logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), - Utils.bytesToString(_remainingMem))) - } - if (blockStatus.storageLevel.useDisk) { - logInfo("Removed %s on %s on disk (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize))) - } - } - } - - def removeBlock(blockId: String) { - if (_blocks.containsKey(blockId)) { - _remainingMem += _blocks.get(blockId).memSize - _blocks.remove(blockId) - } - } - - def remainingMem: Long = _remainingMem - - def lastSeenMs: Long = _lastSeenMs - - def blocks: JHashMap[String, BlockStatus] = _blocks - - override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem - - def clear() { - _blocks.clear() - } - } -} diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala deleted file mode 100644 index 9375a9ca54..0000000000 --- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.io.{Externalizable, ObjectInput, ObjectOutput} - -import akka.actor.ActorRef - - -private[storage] object BlockManagerMessages { - ////////////////////////////////////////////////////////////////////////////////// - // Messages from the master to slaves. - ////////////////////////////////////////////////////////////////////////////////// - sealed trait ToBlockManagerSlave - - // Remove a block from the slaves that have it. This can only be used to remove - // blocks that the master knows about. - case class RemoveBlock(blockId: String) extends ToBlockManagerSlave - - // Remove all blocks belonging to a specific RDD. - case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave - - - ////////////////////////////////////////////////////////////////////////////////// - // Messages from slaves to the master. 
- ////////////////////////////////////////////////////////////////////////////////// - sealed trait ToBlockManagerMaster - - case class RegisterBlockManager( - blockManagerId: BlockManagerId, - maxMemSize: Long, - sender: ActorRef) - extends ToBlockManagerMaster - - case class HeartBeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - - class UpdateBlockInfo( - var blockManagerId: BlockManagerId, - var blockId: String, - var storageLevel: StorageLevel, - var memSize: Long, - var diskSize: Long) - extends ToBlockManagerMaster - with Externalizable { - - def this() = this(null, null, null, 0, 0) // For deserialization only - - override def writeExternal(out: ObjectOutput) { - blockManagerId.writeExternal(out) - out.writeUTF(blockId) - storageLevel.writeExternal(out) - out.writeLong(memSize) - out.writeLong(diskSize) - } - - override def readExternal(in: ObjectInput) { - blockManagerId = BlockManagerId(in) - blockId = in.readUTF() - storageLevel = StorageLevel(in) - memSize = in.readLong() - diskSize = in.readLong() - } - } - - object UpdateBlockInfo { - def apply(blockManagerId: BlockManagerId, - blockId: String, - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long): UpdateBlockInfo = { - new UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize) - } - - // For pattern-matching - def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { - Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize)) - } - } - - case class GetLocations(blockId: String) extends ToBlockManagerMaster - - case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster - - case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster - - case class RemoveExecutor(execId: String) extends ToBlockManagerMaster - - case object StopBlockManagerMaster extends ToBlockManagerMaster - - case object GetMemoryStatus extends ToBlockManagerMaster - - case object ExpireDeadHosts extends ToBlockManagerMaster - - case object GetStorageStatus extends ToBlockManagerMaster -} diff --git a/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala deleted file mode 100644 index 6e5fb43732..0000000000 --- a/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import akka.actor.Actor - -import spark.storage.BlockManagerMessages._ - - -/** - * An actor to take commands from the master to execute options. For example, - * this is used to remove blocks from the slave's BlockManager. 
- */ -class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor { - override def receive = { - - case RemoveBlock(blockId) => - blockManager.removeBlock(blockId) - - case RemoveRdd(rddId) => - val numBlocksRemoved = blockManager.removeRdd(rddId) - sender ! numBlocksRemoved - } -} diff --git a/core/src/main/scala/spark/storage/BlockManagerSource.scala b/core/src/main/scala/spark/storage/BlockManagerSource.scala deleted file mode 100644 index 2aecd1ea71..0000000000 --- a/core/src/main/scala/spark/storage/BlockManagerSource.scala +++ /dev/null @@ -1,48 +0,0 @@ -package spark.storage - -import com.codahale.metrics.{Gauge,MetricRegistry} - -import spark.metrics.source.Source - - -private[spark] class BlockManagerSource(val blockManager: BlockManager) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "BlockManager" - - metricRegistry.register(MetricRegistry.name("memory", "maxMem", "MBytes"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val maxMem = storageStatusList.map(_.maxMem).reduce(_ + _) - maxMem / 1024 / 1024 - } - }) - - metricRegistry.register(MetricRegistry.name("memory", "remainingMem", "MBytes"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val remainingMem = storageStatusList.map(_.memRemaining).reduce(_ + _) - remainingMem / 1024 / 1024 - } - }) - - metricRegistry.register(MetricRegistry.name("memory", "memUsed", "MBytes"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val maxMem = storageStatusList.map(_.maxMem).reduce(_ + _) - val remainingMem = storageStatusList.map(_.memRemaining).reduce(_ + _) - (maxMem - remainingMem) / 1024 / 1024 - } - }) - - metricRegistry.register(MetricRegistry.name("disk", "diskSpaceUsed", "MBytes"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val diskSpaceUsed = storageStatusList - .flatMap(_.blocks.values.map(_.diskSize)) - .reduceOption(_ + _) - .getOrElse(0L) - - diskSpaceUsed / 1024 / 1024 - } - }) -} diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala deleted file mode 100644 index 39064bce92..0000000000 --- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.nio.ByteBuffer - -import spark.{Logging, Utils} -import spark.network._ - -/** - * A network interface for BlockManager. Each slave should have one - * BlockManagerWorker. - * - * TODO: Use event model. 
- */ -private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends Logging { - initLogging() - - blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive) - - def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = { - logDebug("Handling message " + msg) - msg match { - case bufferMessage: BufferMessage => { - try { - logDebug("Handling as a buffer message " + bufferMessage) - val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage) - logDebug("Parsed as a block message array") - val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) - return Some(new BlockMessageArray(responseMessages).toBufferMessage) - } catch { - case e: Exception => logError("Exception handling buffer message", e) - return None - } - } - case otherMessage: Any => { - logError("Unknown type message received: " + otherMessage) - return None - } - } - } - - def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = { - blockMessage.getType match { - case BlockMessage.TYPE_PUT_BLOCK => { - val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) - logDebug("Received [" + pB + "]") - putBlock(pB.id, pB.data, pB.level) - return None - } - case BlockMessage.TYPE_GET_BLOCK => { - val gB = new GetBlock(blockMessage.getId) - logDebug("Received [" + gB + "]") - val buffer = getBlock(gB.id) - if (buffer == null) { - return None - } - return Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer))) - } - case _ => return None - } - } - - private def putBlock(id: String, bytes: ByteBuffer, level: StorageLevel) { - val startTimeMs = System.currentTimeMillis() - logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes) - blockManager.putBytes(id, bytes, level) - logDebug("PutBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) - + " with data size: " + bytes.limit) - } - - private def getBlock(id: String): ByteBuffer = { - val startTimeMs = System.currentTimeMillis() - logDebug("GetBlock " + id + " started from " + startTimeMs) - val buffer = blockManager.getLocalBytes(id) match { - case Some(bytes) => bytes - case None => null - } - logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) - + " and got buffer " + buffer) - return buffer - } -} - -private[spark] object BlockManagerWorker extends Logging { - private var blockManagerWorker: BlockManagerWorker = null - - initLogging() - - def startBlockManagerWorker(manager: BlockManager) { - blockManagerWorker = new BlockManagerWorker(manager) - } - - def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = { - val blockManager = blockManagerWorker.blockManager - val connectionManager = blockManager.connectionManager - val blockMessage = BlockMessage.fromPutBlock(msg) - val blockMessageArray = new BlockMessageArray(blockMessage) - val resultMessage = connectionManager.sendMessageReliablySync( - toConnManagerId, blockMessageArray.toBufferMessage) - return (resultMessage != None) - } - - def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = { - val blockManager = blockManagerWorker.blockManager - val connectionManager = blockManager.connectionManager - val blockMessage = BlockMessage.fromGetBlock(msg) - val blockMessageArray = new BlockMessageArray(blockMessage) - val responseMessage = connectionManager.sendMessageReliablySync( - toConnManagerId, blockMessageArray.toBufferMessage) - responseMessage match { - case Some(message) => { - val 
bufferMessage = message.asInstanceOf[BufferMessage] - logDebug("Response message received " + bufferMessage) - BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => { - logDebug("Found " + blockMessage) - return blockMessage.getData - }) - } - case None => logDebug("No response message received"); return null - } - return null - } -} diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala deleted file mode 100644 index bcce26b7c1..0000000000 --- a/core/src/main/scala/spark/storage/BlockMessage.scala +++ /dev/null @@ -1,223 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.nio.ByteBuffer - -import scala.collection.mutable.StringBuilder -import scala.collection.mutable.ArrayBuffer - -import spark.network._ - -private[spark] case class GetBlock(id: String) -private[spark] case class GotBlock(id: String, data: ByteBuffer) -private[spark] case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel) - -private[spark] class BlockMessage() { - // Un-initialized: typ = 0 - // GetBlock: typ = 1 - // GotBlock: typ = 2 - // PutBlock: typ = 3 - private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED - private var id: String = null - private var data: ByteBuffer = null - private var level: StorageLevel = null - - def set(getBlock: GetBlock) { - typ = BlockMessage.TYPE_GET_BLOCK - id = getBlock.id - } - - def set(gotBlock: GotBlock) { - typ = BlockMessage.TYPE_GOT_BLOCK - id = gotBlock.id - data = gotBlock.data - } - - def set(putBlock: PutBlock) { - typ = BlockMessage.TYPE_PUT_BLOCK - id = putBlock.id - data = putBlock.data - level = putBlock.level - } - - def set(buffer: ByteBuffer) { - val startTime = System.currentTimeMillis - /* - println() - println("BlockMessage: ") - while(buffer.remaining > 0) { - print(buffer.get()) - } - buffer.rewind() - println() - println() - */ - typ = buffer.getInt() - val idLength = buffer.getInt() - val idBuilder = new StringBuilder(idLength) - for (i <- 1 to idLength) { - idBuilder += buffer.getChar() - } - id = idBuilder.toString() - - if (typ == BlockMessage.TYPE_PUT_BLOCK) { - - val booleanInt = buffer.getInt() - val replication = buffer.getInt() - level = StorageLevel(booleanInt, replication) - - val dataLength = buffer.getInt() - data = ByteBuffer.allocate(dataLength) - if (dataLength != buffer.remaining) { - throw new Exception("Error parsing buffer") - } - data.put(buffer) - data.flip() - } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { - - val dataLength = buffer.getInt() - data = ByteBuffer.allocate(dataLength) - if (dataLength != buffer.remaining) { - throw new Exception("Error parsing buffer") - } - data.put(buffer) - data.flip() - } - - val finishTime = 
System.currentTimeMillis - } - - def set(bufferMsg: BufferMessage) { - val buffer = bufferMsg.buffers.apply(0) - buffer.clear() - set(buffer) - } - - def getType: Int = { - return typ - } - - def getId: String = { - return id - } - - def getData: ByteBuffer = { - return data - } - - def getLevel: StorageLevel = { - return level - } - - def toBufferMessage: BufferMessage = { - val startTime = System.currentTimeMillis - val buffers = new ArrayBuffer[ByteBuffer]() - var buffer = ByteBuffer.allocate(4 + 4 + id.length() * 2) - buffer.putInt(typ).putInt(id.length()) - id.foreach((x: Char) => buffer.putChar(x)) - buffer.flip() - buffers += buffer - - if (typ == BlockMessage.TYPE_PUT_BLOCK) { - buffer = ByteBuffer.allocate(8).putInt(level.toInt).putInt(level.replication) - buffer.flip() - buffers += buffer - - buffer = ByteBuffer.allocate(4).putInt(data.remaining) - buffer.flip() - buffers += buffer - - buffers += data - } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { - buffer = ByteBuffer.allocate(4).putInt(data.remaining) - buffer.flip() - buffers += buffer - - buffers += data - } - - /* - println() - println("BlockMessage: ") - buffers.foreach(b => { - while(b.remaining > 0) { - print(b.get()) - } - b.rewind() - }) - println() - println() - */ - val finishTime = System.currentTimeMillis - return Message.createBufferMessage(buffers) - } - - override def toString: String = { - "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level + - ", data = " + (if (data != null) data.remaining.toString else "null") + "]" - } -} - -private[spark] object BlockMessage { - val TYPE_NON_INITIALIZED: Int = 0 - val TYPE_GET_BLOCK: Int = 1 - val TYPE_GOT_BLOCK: Int = 2 - val TYPE_PUT_BLOCK: Int = 3 - - def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(bufferMessage) - newBlockMessage - } - - def fromByteBuffer(buffer: ByteBuffer): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(buffer) - newBlockMessage - } - - def fromGetBlock(getBlock: GetBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(getBlock) - newBlockMessage - } - - def fromGotBlock(gotBlock: GotBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(gotBlock) - newBlockMessage - } - - def fromPutBlock(putBlock: PutBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(putBlock) - newBlockMessage - } - - def main(args: Array[String]) { - val B = new BlockMessage() - B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2)) - val bMsg = B.toBufferMessage - val C = new BlockMessage() - C.set(bMsg) - - println(B.getId + " " + B.getLevel) - println(C.getId + " " + C.getLevel) - } -} diff --git a/core/src/main/scala/spark/storage/BlockMessageArray.scala b/core/src/main/scala/spark/storage/BlockMessageArray.scala deleted file mode 100644 index ee2fc167d5..0000000000 --- a/core/src/main/scala/spark/storage/BlockMessageArray.scala +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - -import spark._ -import spark.network._ - -private[spark] -class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockMessage] with Logging { - - def this(bm: BlockMessage) = this(Array(bm)) - - def this() = this(null.asInstanceOf[Seq[BlockMessage]]) - - def apply(i: Int) = blockMessages(i) - - def iterator = blockMessages.iterator - - def length = blockMessages.length - - initLogging() - - def set(bufferMessage: BufferMessage) { - val startTime = System.currentTimeMillis - val newBlockMessages = new ArrayBuffer[BlockMessage]() - val buffer = bufferMessage.buffers(0) - buffer.clear() - /* - println() - println("BlockMessageArray: ") - while(buffer.remaining > 0) { - print(buffer.get()) - } - buffer.rewind() - println() - println() - */ - while (buffer.remaining() > 0) { - val size = buffer.getInt() - logDebug("Creating block message of size " + size + " bytes") - val newBuffer = buffer.slice() - newBuffer.clear() - newBuffer.limit(size) - logDebug("Trying to convert buffer " + newBuffer + " to block message") - val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer) - logDebug("Created " + newBlockMessage) - newBlockMessages += newBlockMessage - buffer.position(buffer.position() + size) - } - val finishTime = System.currentTimeMillis - logDebug("Converted block message array from buffer message in " + (finishTime - startTime) / 1000.0 + " s") - this.blockMessages = newBlockMessages - } - - def toBufferMessage: BufferMessage = { - val buffers = new ArrayBuffer[ByteBuffer]() - - blockMessages.foreach(blockMessage => { - val bufferMessage = blockMessage.toBufferMessage - logDebug("Adding " + blockMessage) - val sizeBuffer = ByteBuffer.allocate(4).putInt(bufferMessage.size) - sizeBuffer.flip - buffers += sizeBuffer - buffers ++= bufferMessage.buffers - logDebug("Added " + bufferMessage) - }) - - logDebug("Buffer list:") - buffers.foreach((x: ByteBuffer) => logDebug("" + x)) - /* - println() - println("BlockMessageArray: ") - buffers.foreach(b => { - while(b.remaining > 0) { - print(b.get()) - } - b.rewind() - }) - println() - println() - */ - return Message.createBufferMessage(buffers) - } -} - -private[spark] object BlockMessageArray { - - def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = { - val newBlockMessageArray = new BlockMessageArray() - newBlockMessageArray.set(bufferMessage) - newBlockMessageArray - } - - def main(args: Array[String]) { - val blockMessages = - (0 until 10).map { i => - if (i % 2 == 0) { - val buffer = ByteBuffer.allocate(100) - buffer.clear - BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY_SER)) - } else { - BlockMessage.fromGetBlock(GetBlock(i.toString)) - } - } - val blockMessageArray = new BlockMessageArray(blockMessages) - println("Block message array created") - - val bufferMessage = blockMessageArray.toBufferMessage - println("Converted to buffer message") - - val totalSize = bufferMessage.size - val newBuffer = ByteBuffer.allocate(totalSize) - newBuffer.clear() 
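To make the BlockMessageArray wire format above easier to follow: each serialized BlockMessage is preceded by a 4-byte length, and set() walks the combined buffer by repeatedly reading a length, slicing off that many bytes, and advancing. A minimal sketch of the same length-prefixed framing over raw byte payloads (all names here are illustrative, not part of this patch):

    import java.nio.ByteBuffer
    import scala.collection.mutable.ArrayBuffer

    object LengthPrefixedFramingExample {
      // Concatenate payloads, each preceded by its 4-byte length.
      def frame(payloads: Seq[Array[Byte]]): ByteBuffer = {
        val total = payloads.map(4 + _.length).sum
        val buf = ByteBuffer.allocate(total)
        payloads.foreach { p => buf.putInt(p.length); buf.put(p) }
        buf.flip()
        buf
      }

      // Walk the buffer the way BlockMessageArray.set does: read a length,
      // copy that many bytes out, repeat until nothing remains.
      def unframe(buffer: ByteBuffer): Seq[Array[Byte]] = {
        val out = new ArrayBuffer[Array[Byte]]()
        while (buffer.remaining() > 0) {
          val size = buffer.getInt()
          val payload = new Array[Byte](size)
          buffer.get(payload)
          out += payload
        }
        out
      }

      def main(args: Array[String]) {
        val framed = frame(Seq("abc".getBytes, "de".getBytes))
        println(unframe(framed).map(new String(_)))  // abc, de
      }
    }
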
- bufferMessage.buffers.foreach(buffer => { - assert (0 == buffer.position()) - newBuffer.put(buffer) - buffer.rewind() - }) - newBuffer.flip - val newBufferMessage = Message.createBufferMessage(newBuffer) - println("Copied to new buffer message, size = " + newBufferMessage.size) - - val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage) - println("Converted back to block message array") - newBlockMessageArray.foreach(blockMessage => { - blockMessage.getType match { - case BlockMessage.TYPE_PUT_BLOCK => { - val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) - println(pB) - } - case BlockMessage.TYPE_GET_BLOCK => { - val gB = new GetBlock(blockMessage.getId) - println(gB) - } - } - }) - } -} - - diff --git a/core/src/main/scala/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/spark/storage/BlockObjectWriter.scala deleted file mode 100644 index 3812009ca1..0000000000 --- a/core/src/main/scala/spark/storage/BlockObjectWriter.scala +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - - -/** - * An interface for writing JVM objects to some underlying storage. This interface allows - * appending data to an existing block, and can guarantee atomicity in the case of faults - * as it allows the caller to revert partial writes. - * - * This interface does not support concurrent writes. - */ -abstract class BlockObjectWriter(val blockId: String) { - - var closeEventHandler: () => Unit = _ - - def open(): BlockObjectWriter - - def close() { - closeEventHandler() - } - - def isOpen: Boolean - - def registerCloseEventHandler(handler: () => Unit) { - closeEventHandler = handler - } - - /** - * Flush the partial writes and commit them as a single atomic block. Return the - * number of bytes written for this commit. - */ - def commit(): Long - - /** - * Reverts writes that haven't been flushed yet. Callers should invoke this function - * when there are runtime exceptions. - */ - def revertPartialWrites() - - /** - * Writes an object. - */ - def write(value: Any) - - /** - * Size of the valid writes, in bytes. - */ - def size(): Long -} diff --git a/core/src/main/scala/spark/storage/BlockStore.scala b/core/src/main/scala/spark/storage/BlockStore.scala deleted file mode 100644 index c8db0022b0..0000000000 --- a/core/src/main/scala/spark/storage/BlockStore.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer - -import spark.Logging - -/** - * Abstract class to store blocks - */ -private[spark] -abstract class BlockStore(val blockManager: BlockManager) extends Logging { - def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) - - /** - * Put in a block and, possibly, also return its content as either bytes or another Iterator. - * This is used to efficiently write the values to multiple locations (e.g. for replication). - * - * @return a PutResult that contains the size of the data, as well as the values put if - * returnValues is true (if not, the result's data field can be null) - */ - def putValues(blockId: String, values: ArrayBuffer[Any], level: StorageLevel, - returnValues: Boolean) : PutResult - - /** - * Return the size of a block in bytes. - */ - def getSize(blockId: String): Long - - def getBytes(blockId: String): Option[ByteBuffer] - - def getValues(blockId: String): Option[Iterator[Any]] - - /** - * Remove a block, if it exists. - * @param blockId the block to remove. - * @return True if the block was found and removed, False otherwise. - */ - def remove(blockId: String): Boolean - - def contains(blockId: String): Boolean - - def clear() { } -} diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala deleted file mode 100644 index b14497157e..0000000000 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ /dev/null @@ -1,329 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.storage - -import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile} -import java.nio.ByteBuffer -import java.nio.channels.FileChannel -import java.nio.channels.FileChannel.MapMode -import java.util.{Random, Date} -import java.text.SimpleDateFormat - -import scala.collection.mutable.ArrayBuffer - -import it.unimi.dsi.fastutil.io.FastBufferedOutputStream - -import spark.Utils -import spark.executor.ExecutorExitCode -import spark.serializer.{Serializer, SerializationStream} -import spark.Logging -import spark.network.netty.ShuffleSender -import spark.network.netty.PathResolver - - -/** - * Stores BlockManager blocks on disk. - */ -private class DiskStore(blockManager: BlockManager, rootDirs: String) - extends BlockStore(blockManager) with Logging { - - class DiskBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int) - extends BlockObjectWriter(blockId) { - - private val f: File = createFile(blockId /*, allowAppendExisting */) - - // The file channel, used for repositioning / truncating the file. - private var channel: FileChannel = null - private var bs: OutputStream = null - private var objOut: SerializationStream = null - private var lastValidPosition = 0L - private var initialized = false - - override def open(): DiskBlockObjectWriter = { - val fos = new FileOutputStream(f, true) - channel = fos.getChannel() - bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize)) - objOut = serializer.newInstance().serializeStream(bs) - initialized = true - this - } - - override def close() { - if (initialized) { - objOut.close() - channel = null - bs = null - objOut = null - } - // Invoke the close callback handler. - super.close() - } - - override def isOpen: Boolean = objOut != null - - // Flush the partial writes, and set valid length to be the length of the entire file. - // Return the number of bytes written for this commit. - override def commit(): Long = { - if (initialized) { - // NOTE: Flush the serializer first and then the compressed/buffered output stream - objOut.flush() - bs.flush() - val prevPos = lastValidPosition - lastValidPosition = channel.position() - lastValidPosition - prevPos - } else { - // lastValidPosition is zero if stream is uninitialized - lastValidPosition - } - } - - override def revertPartialWrites() { - if (initialized) { - // Discard current writes. We do this by flushing the outstanding writes and - // truncate the file to the last valid position. - objOut.flush() - bs.flush() - channel.truncate(lastValidPosition) - } - } - - override def write(value: Any) { - if (!initialized) { - open() - } - objOut.writeObject(value) - } - - override def size(): Long = lastValidPosition - } - - private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt - - private var shuffleSender : ShuffleSender = null - // Create one local directory for each path mentioned in spark.local.dir; then, inside this - // directory, create multiple subdirectories that we will hash files into, in order to avoid - // having really large inodes at the top level. 
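As context for the DiskStore code that follows: getFile (further down in this file) hashes a block id to pick one of the spark.local.dir roots and then one of subDirsPerLocalDir (default 64) subdirectories under it, so files spread evenly and no single directory grows huge. A standalone sketch of that placement arithmetic, with illustrative directory counts:

    object BlockFilePlacementExample {
      // Same arithmetic as DiskStore.getFile: hash the block id, choose a root
      // directory, then one of subDirsPerLocalDir subdirectories beneath it.
      def place(blockId: String, numLocalDirs: Int, subDirsPerLocalDir: Int): (Int, Int) = {
        val hash = math.abs(blockId.hashCode)
        val dirId = hash % numLocalDirs
        val subDirId = (hash / numLocalDirs) % subDirsPerLocalDir
        (dirId, subDirId)
      }

      def main(args: Array[String]) {
        // Two local dirs, 64 subdirs each; the same block id always maps to the same slot.
        println(place("shuffle_0_1_2", 2, 64))
        println(place("rdd_3_0", 2, 64))
      }
    }
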
- private val localDirs: Array[File] = createLocalDirs() - private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) - - addShutdownHook() - - def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int) - : BlockObjectWriter = { - new DiskBlockObjectWriter(blockId, serializer, bufferSize) - } - - override def getSize(blockId: String): Long = { - getFile(blockId).length() - } - - override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) { - // So that we do not modify the input offsets ! - // duplicate does not copy buffer, so inexpensive - val bytes = _bytes.duplicate() - logDebug("Attempting to put block " + blockId) - val startTime = System.currentTimeMillis - val file = createFile(blockId) - val channel = new RandomAccessFile(file, "rw").getChannel() - while (bytes.remaining > 0) { - channel.write(bytes) - } - channel.close() - val finishTime = System.currentTimeMillis - logDebug("Block %s stored as %s file on disk in %d ms".format( - blockId, Utils.bytesToString(bytes.limit), (finishTime - startTime))) - } - - private def getFileBytes(file: File): ByteBuffer = { - val length = file.length() - val channel = new RandomAccessFile(file, "r").getChannel() - val buffer = try { - channel.map(MapMode.READ_ONLY, 0, length) - } finally { - channel.close() - } - - buffer - } - - override def putValues( - blockId: String, - values: ArrayBuffer[Any], - level: StorageLevel, - returnValues: Boolean) - : PutResult = { - - logDebug("Attempting to write values for block " + blockId) - val startTime = System.currentTimeMillis - val file = createFile(blockId) - val fileOut = blockManager.wrapForCompression(blockId, - new FastBufferedOutputStream(new FileOutputStream(file))) - val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut) - objOut.writeAll(values.iterator) - objOut.close() - val length = file.length() - - val timeTaken = System.currentTimeMillis - startTime - logDebug("Block %s stored as %s file on disk in %d ms".format( - blockId, Utils.bytesToString(length), timeTaken)) - - if (returnValues) { - // Return a byte buffer for the contents of the file - val buffer = getFileBytes(file) - PutResult(length, Right(buffer)) - } else { - PutResult(length, null) - } - } - - override def getBytes(blockId: String): Option[ByteBuffer] = { - val file = getFile(blockId) - val bytes = getFileBytes(file) - Some(bytes) - } - - override def getValues(blockId: String): Option[Iterator[Any]] = { - getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes)) - } - - /** - * A version of getValues that allows a custom serializer. This is used as part of the - * shuffle short-circuit code. - */ - def getValues(blockId: String, serializer: Serializer): Option[Iterator[Any]] = { - getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer)) - } - - override def remove(blockId: String): Boolean = { - val file = getFile(blockId) - if (file.exists()) { - file.delete() - } else { - false - } - } - - override def contains(blockId: String): Boolean = { - getFile(blockId).exists() - } - - private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = { - val file = getFile(blockId) - if (!allowAppendExisting && file.exists()) { - // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task - // was rescheduled on the same machine as the old task. - logWarning("File for block " + blockId + " already exists on disk: " + file + ". 
Deleting") - file.delete() - } - file - } - - private def getFile(blockId: String): File = { - logDebug("Getting file for block " + blockId) - - // Figure out which local directory it hashes to, and which subdirectory in that - val hash = math.abs(blockId.hashCode) - val dirId = hash % localDirs.length - val subDirId = (hash / localDirs.length) % subDirsPerLocalDir - - // Create the subdirectory if it doesn't already exist - var subDir = subDirs(dirId)(subDirId) - if (subDir == null) { - subDir = subDirs(dirId).synchronized { - val old = subDirs(dirId)(subDirId) - if (old != null) { - old - } else { - val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) - newDir.mkdir() - subDirs(dirId)(subDirId) = newDir - newDir - } - } - } - - new File(subDir, blockId) - } - - private def createLocalDirs(): Array[File] = { - logDebug("Creating local directories at root dirs '" + rootDirs + "'") - val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") - rootDirs.split(",").map { rootDir => - var foundLocalDir = false - var localDir: File = null - var localDirId: String = null - var tries = 0 - val rand = new Random() - while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) { - tries += 1 - try { - localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) - localDir = new File(rootDir, "spark-local-" + localDirId) - if (!localDir.exists) { - foundLocalDir = localDir.mkdirs() - } - } catch { - case e: Exception => - logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e) - } - } - if (!foundLocalDir) { - logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + - " attempts to create local dir in " + rootDir) - System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) - } - logInfo("Created local directory at " + localDir) - localDir - } - } - - private def addShutdownHook() { - localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir)) - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { - override def run() { - logDebug("Shutdown hook called") - localDirs.foreach { localDir => - try { - if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) - } catch { - case t: Throwable => - logError("Exception while deleting local spark dir: " + localDir, t) - } - } - if (shuffleSender != null) { - shuffleSender.stop - } - } - }) - } - - private[storage] def startShuffleBlockSender(port: Int): Int = { - val pResolver = new PathResolver { - override def getAbsolutePath(blockId: String): String = { - if (!blockId.startsWith("shuffle_")) { - return null - } - DiskStore.this.getFile(blockId).getAbsolutePath() - } - } - shuffleSender = new ShuffleSender(port, pResolver) - logInfo("Created ShuffleSender binding to port : "+ shuffleSender.port) - shuffleSender.port - } -} diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala deleted file mode 100644 index 5a51f5cf31..0000000000 --- a/core/src/main/scala/spark/storage/MemoryStore.scala +++ /dev/null @@ -1,257 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.util.LinkedHashMap -import java.util.concurrent.ArrayBlockingQueue -import spark.{SizeEstimator, Utils} -import java.nio.ByteBuffer -import collection.mutable.ArrayBuffer - -/** - * Stores blocks in memory, either as ArrayBuffers of deserialized Java objects or as - * serialized ByteBuffers. - */ -private class MemoryStore(blockManager: BlockManager, maxMemory: Long) - extends BlockStore(blockManager) { - - case class Entry(value: Any, size: Long, deserialized: Boolean, var dropPending: Boolean = false) - - private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true) - private var currentMemory = 0L - // Object used to ensure that only one thread is putting blocks and if necessary, dropping - // blocks from the memory store. - private val putLock = new Object() - - logInfo("MemoryStore started with capacity %s.".format(Utils.bytesToString(maxMemory))) - - def freeMemory: Long = maxMemory - currentMemory - - override def getSize(blockId: String): Long = { - entries.synchronized { - entries.get(blockId).size - } - } - - override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) { - // Work on a duplicate - since the original input might be used elsewhere. - val bytes = _bytes.duplicate() - bytes.rewind() - if (level.deserialized) { - val values = blockManager.dataDeserialize(blockId, bytes) - val elements = new ArrayBuffer[Any] - elements ++= values - val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef]) - tryToPut(blockId, elements, sizeEstimate, true) - } else { - tryToPut(blockId, bytes, bytes.limit, false) - } - } - - override def putValues( - blockId: String, - values: ArrayBuffer[Any], - level: StorageLevel, - returnValues: Boolean) - : PutResult = { - - if (level.deserialized) { - val sizeEstimate = SizeEstimator.estimate(values.asInstanceOf[AnyRef]) - tryToPut(blockId, values, sizeEstimate, true) - PutResult(sizeEstimate, Left(values.iterator)) - } else { - val bytes = blockManager.dataSerialize(blockId, values.iterator) - tryToPut(blockId, bytes, bytes.limit, false) - PutResult(bytes.limit(), Right(bytes.duplicate())) - } - } - - override def getBytes(blockId: String): Option[ByteBuffer] = { - val entry = entries.synchronized { - entries.get(blockId) - } - if (entry == null) { - None - } else if (entry.deserialized) { - Some(blockManager.dataSerialize(blockId, entry.value.asInstanceOf[ArrayBuffer[Any]].iterator)) - } else { - Some(entry.value.asInstanceOf[ByteBuffer].duplicate()) // Doesn't actually copy the data - } - } - - override def getValues(blockId: String): Option[Iterator[Any]] = { - val entry = entries.synchronized { - entries.get(blockId) - } - if (entry == null) { - None - } else if (entry.deserialized) { - Some(entry.value.asInstanceOf[ArrayBuffer[Any]].iterator) - } else { - val buffer = entry.value.asInstanceOf[ByteBuffer].duplicate() // Doesn't actually copy data - Some(blockManager.dataDeserialize(blockId, buffer)) - } - } - - override def remove(blockId: String): Boolean = { - entries.synchronized { - val entry = entries.get(blockId) - if (entry != null) { - 
entries.remove(blockId) - currentMemory -= entry.size - logInfo("Block %s of size %d dropped from memory (free %d)".format( - blockId, entry.size, freeMemory)) - true - } else { - false - } - } - } - - override def clear() { - entries.synchronized { - entries.clear() - } - logInfo("MemoryStore cleared") - } - - /** - * Return the RDD ID that a given block ID is from, or null if it is not an RDD block. - */ - private def getRddId(blockId: String): String = { - if (blockId.startsWith("rdd_")) { - blockId.split('_')(1) - } else { - null - } - } - - /** - * Try to put in a set of values, if we can free up enough space. The value should either be - * an ArrayBuffer if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated) - * size must also be passed by the caller. - * - * Locks on the object putLock to ensure that all the put requests and its associated block - * dropping is done by only on thread at a time. Otherwise while one thread is dropping - * blocks to free memory for one block, another thread may use up the freed space for - * another block. - */ - private def tryToPut(blockId: String, value: Any, size: Long, deserialized: Boolean): Boolean = { - // TODO: Its possible to optimize the locking by locking entries only when selecting blocks - // to be dropped. Once the to-be-dropped blocks have been selected, and lock on entries has been - // released, it must be ensured that those to-be-dropped blocks are not double counted for - // freeing up more space for another block that needs to be put. Only then the actually dropping - // of blocks (and writing to disk if necessary) can proceed in parallel. - putLock.synchronized { - if (ensureFreeSpace(blockId, size)) { - val entry = new Entry(value, size, deserialized) - entries.synchronized { entries.put(blockId, entry) } - currentMemory += size - if (deserialized) { - logInfo("Block %s stored as values to memory (estimated size %s, free %s)".format( - blockId, Utils.bytesToString(size), Utils.bytesToString(freeMemory))) - } else { - logInfo("Block %s stored as bytes to memory (size %s, free %s)".format( - blockId, Utils.bytesToString(size), Utils.bytesToString(freeMemory))) - } - true - } else { - // Tell the block manager that we couldn't put it in memory so that it can drop it to - // disk if the block allows disk storage. - val data = if (deserialized) { - Left(value.asInstanceOf[ArrayBuffer[Any]]) - } else { - Right(value.asInstanceOf[ByteBuffer].duplicate()) - } - blockManager.dropFromMemory(blockId, data) - false - } - } - } - - /** - * Tries to free up a given amount of space to store a particular block, but can fail and return - * false if either the block is bigger than our memory or it would require replacing another - * block from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that - * don't fit into memory that we want to avoid). - * - * Assumes that a lock is held by the caller to ensure only one thread is dropping blocks. - * Otherwise, the freed space may fill up before the caller puts in their new value. 
- */ - private def ensureFreeSpace(blockIdToAdd: String, space: Long): Boolean = { - - logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format( - space, currentMemory, maxMemory)) - - if (space > maxMemory) { - logInfo("Will not store " + blockIdToAdd + " as it is larger than our memory limit") - return false - } - - if (maxMemory - currentMemory < space) { - val rddToAdd = getRddId(blockIdToAdd) - val selectedBlocks = new ArrayBuffer[String]() - var selectedMemory = 0L - - // This is synchronized to ensure that the set of entries is not changed - // (because of getValue or getBytes) while traversing the iterator, as that - // can lead to exceptions. - entries.synchronized { - val iterator = entries.entrySet().iterator() - while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) { - val pair = iterator.next() - val blockId = pair.getKey - if (rddToAdd != null && rddToAdd == getRddId(blockId)) { - logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " + - "block from the same RDD") - return false - } - selectedBlocks += blockId - selectedMemory += pair.getValue.size - } - } - - if (maxMemory - (currentMemory - selectedMemory) >= space) { - logInfo(selectedBlocks.size + " blocks selected for dropping") - for (blockId <- selectedBlocks) { - val entry = entries.synchronized { entries.get(blockId) } - // This should never be null as only one thread should be dropping - // blocks and removing entries. However the check is still here for - // future safety. - if (entry != null) { - val data = if (entry.deserialized) { - Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) - } else { - Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) - } - blockManager.dropFromMemory(blockId, data) - } - } - return true - } else { - return false - } - } - return true - } - - override def contains(blockId: String): Boolean = { - entries.synchronized { entries.containsKey(blockId) } - } -} - diff --git a/core/src/main/scala/spark/storage/PutResult.scala b/core/src/main/scala/spark/storage/PutResult.scala deleted file mode 100644 index 3a0974fe15..0000000000 --- a/core/src/main/scala/spark/storage/PutResult.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.nio.ByteBuffer - -/** - * Result of adding a block into a BlockStore. Contains its estimated size, and possibly the - * values put if the caller asked for them to be returned (e.g. 
for chaining replication) - */ -private[spark] case class PutResult(size: Long, data: Either[Iterator[_], ByteBuffer]) diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala deleted file mode 100644 index 8a7a6f9ed3..0000000000 --- a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import spark.serializer.Serializer - - -private[spark] -class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter]) - - -private[spark] -trait ShuffleBlocks { - def acquireWriters(mapId: Int): ShuffleWriterGroup - def releaseWriters(group: ShuffleWriterGroup) -} - - -private[spark] -class ShuffleBlockManager(blockManager: BlockManager) { - - def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = { - new ShuffleBlocks { - // Get a group of writers for a map task. - override def acquireWriters(mapId: Int): ShuffleWriterGroup = { - val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 - val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => - val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId) - blockManager.getDiskBlockWriter(blockId, serializer, bufferSize) - } - new ShuffleWriterGroup(mapId, writers) - } - - override def releaseWriters(group: ShuffleWriterGroup) = { - // Nothing really to release here. - } - } - } -} - - -private[spark] -object ShuffleBlockManager { - - // Returns the block id for a given shuffle block. - def blockId(shuffleId: Int, bucketId: Int, groupId: Int): String = { - "shuffle_" + shuffleId + "_" + groupId + "_" + bucketId - } - - // Returns true if the block is a shuffle block. - def isShuffle(blockId: String): Boolean = blockId.startsWith("shuffle_") -} diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala deleted file mode 100644 index f52650988c..0000000000 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} - -/** - * Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory, - * whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory - * in a serialized format, and whether to replicate the RDD partitions on multiple nodes. - * The [[spark.storage.StorageLevel$]] singleton object contains some static constants for - * commonly useful storage levels. To create your own storage level object, use the factor method - * of the singleton object (`StorageLevel(...)`). - */ -class StorageLevel private( - private var useDisk_ : Boolean, - private var useMemory_ : Boolean, - private var deserialized_ : Boolean, - private var replication_ : Int = 1) - extends Externalizable { - - // TODO: Also add fields for caching priority, dataset ID, and flushing. - private def this(flags: Int, replication: Int) { - this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) - } - - def this() = this(false, true, false) // For deserialization - - def useDisk = useDisk_ - def useMemory = useMemory_ - def deserialized = deserialized_ - def replication = replication_ - - assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes") - - override def clone(): StorageLevel = new StorageLevel( - this.useDisk, this.useMemory, this.deserialized, this.replication) - - override def equals(other: Any): Boolean = other match { - case s: StorageLevel => - s.useDisk == useDisk && - s.useMemory == useMemory && - s.deserialized == deserialized && - s.replication == replication - case _ => - false - } - - def isValid = ((useMemory || useDisk) && (replication > 0)) - - def toInt: Int = { - var ret = 0 - if (useDisk_) { - ret |= 4 - } - if (useMemory_) { - ret |= 2 - } - if (deserialized_) { - ret |= 1 - } - return ret - } - - override def writeExternal(out: ObjectOutput) { - out.writeByte(toInt) - out.writeByte(replication_) - } - - override def readExternal(in: ObjectInput) { - val flags = in.readByte() - useDisk_ = (flags & 4) != 0 - useMemory_ = (flags & 2) != 0 - deserialized_ = (flags & 1) != 0 - replication_ = in.readByte() - } - - @throws(classOf[IOException]) - private def readResolve(): Object = StorageLevel.getCachedStorageLevel(this) - - override def toString: String = - "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication) - - override def hashCode(): Int = toInt * 41 + replication - def description : String = { - var result = "" - result += (if (useDisk) "Disk " else "") - result += (if (useMemory) "Memory " else "") - result += (if (deserialized) "Deserialized " else "Serialized") - result += "%sx Replicated".format(replication) - result - } -} - - -object StorageLevel { - val NONE = new StorageLevel(false, false, false) - val DISK_ONLY = new StorageLevel(true, false, false) - val DISK_ONLY_2 = new StorageLevel(true, false, false, 2) - val MEMORY_ONLY = new StorageLevel(false, true, true) - val MEMORY_ONLY_2 = new StorageLevel(false, true, true, 2) - val 
MEMORY_ONLY_SER = new StorageLevel(false, true, false) - val MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, 2) - val MEMORY_AND_DISK = new StorageLevel(true, true, true) - val MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2) - val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false) - val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2) - - /** Create a new StorageLevel object */ - def apply(useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, replication: Int = 1) = - getCachedStorageLevel(new StorageLevel(useDisk, useMemory, deserialized, replication)) - - /** Create a new StorageLevel object from its integer representation */ - def apply(flags: Int, replication: Int) = - getCachedStorageLevel(new StorageLevel(flags, replication)) - - /** Read StorageLevel object from ObjectInput stream */ - def apply(in: ObjectInput) = { - val obj = new StorageLevel() - obj.readExternal(in) - getCachedStorageLevel(obj) - } - - private[spark] - val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() - - private[spark] def getCachedStorageLevel(level: StorageLevel): StorageLevel = { - storageLevelCache.putIfAbsent(level, level) - storageLevelCache.get(level) - } -} diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala deleted file mode 100644 index 123b8f6345..0000000000 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import spark.{Utils, SparkContext} -import BlockManagerMasterActor.BlockStatus - -private[spark] -case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, - blocks: Map[String, BlockStatus]) { - - def memUsed(blockPrefix: String = "") = { - blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize). - reduceOption(_+_).getOrElse(0l) - } - - def diskUsed(blockPrefix: String = "") = { - blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.diskSize). 
- reduceOption(_+_).getOrElse(0l) - } - - def memRemaining : Long = maxMem - memUsed() - -} - -case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, - numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long) - extends Ordered[RDDInfo] { - override def toString = { - import Utils.bytesToString - "RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; DiskSize: %s".format(name, id, - storageLevel.toString, numCachedPartitions, numPartitions, bytesToString(memSize), bytesToString(diskSize)) - } - - override def compare(that: RDDInfo) = { - this.id - that.id - } -} - -/* Helper methods for storage-related objects */ -private[spark] -object StorageUtils { - - /* Returns RDD-level information, compiled from a list of StorageStatus objects */ - def rddInfoFromStorageStatus(storageStatusList: Seq[StorageStatus], - sc: SparkContext) : Array[RDDInfo] = { - rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc) - } - - /* Returns a map of blocks to their locations, compiled from a list of StorageStatus objects */ - def blockLocationsFromStorageStatus(storageStatusList: Seq[StorageStatus]) = { - val blockLocationPairs = storageStatusList - .flatMap(s => s.blocks.map(b => (b._1, s.blockManagerId.hostPort))) - blockLocationPairs.groupBy(_._1).map{case (k, v) => (k, v.unzip._2)}.toMap - } - - /* Given a list of BlockStatus objets, returns information for each RDD */ - def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus], - sc: SparkContext) : Array[RDDInfo] = { - - // Group by rddId, ignore the partition name - val groupedRddBlocks = infos.filterKeys(_.startsWith("rdd_")).groupBy { case(k, v) => - k.substring(0,k.lastIndexOf('_')) - }.mapValues(_.values.toArray) - - // For each RDD, generate an RDDInfo object - val rddInfos = groupedRddBlocks.map { case (rddKey, rddBlocks) => - // Add up memory and disk sizes - val memSize = rddBlocks.map(_.memSize).reduce(_ + _) - val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _) - - // Find the id of the RDD, e.g. rdd_1 => 1 - val rddId = rddKey.split("_").last.toInt - - // Get the friendly name and storage level for the RDD, if available - sc.persistentRdds.get(rddId).map { r => - val rddName = Option(r.name).getOrElse(rddKey) - val rddStorageLevel = r.getStorageLevel - RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size, memSize, diskSize) - } - }.flatten.toArray - - scala.util.Sorting.quickSort(rddInfos) - - rddInfos - } - - /* Removes all BlockStatus object that are not part of a block prefix */ - def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus], - prefix: String) : Array[StorageStatus] = { - - storageStatusList.map { status => - val newBlocks = status.blocks.filterKeys(_.startsWith(prefix)) - //val newRemainingMem = status.maxMem - newBlocks.values.map(_.memSize).reduce(_ + _) - StorageStatus(status.blockManagerId, status.maxMem, newBlocks) - } - - } - -} diff --git a/core/src/main/scala/spark/storage/ThreadingTest.scala b/core/src/main/scala/spark/storage/ThreadingTest.scala deleted file mode 100644 index b3ab1ff4b4..0000000000 --- a/core/src/main/scala/spark/storage/ThreadingTest.scala +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.storage - -import akka.actor._ - -import spark.KryoSerializer -import java.util.concurrent.ArrayBlockingQueue -import util.Random - -/** - * This class tests the BlockManager and MemoryStore for thread safety and - * deadlocks. It spawns a number of producer and consumer threads. Producer - * threads continuously pushes blocks into the BlockManager and consumer - * threads continuously retrieves the blocks form the BlockManager and tests - * whether the block is correct or not. - */ -private[spark] object ThreadingTest { - - val numProducers = 5 - val numBlocksPerProducer = 20000 - - private[spark] class ProducerThread(manager: BlockManager, id: Int) extends Thread { - val queue = new ArrayBlockingQueue[(String, Seq[Int])](100) - - override def run() { - for (i <- 1 to numBlocksPerProducer) { - val blockId = "b-" + id + "-" + i - val blockSize = Random.nextInt(1000) - val block = (1 to blockSize).map(_ => Random.nextInt()) - val level = randomLevel() - val startTime = System.currentTimeMillis() - manager.put(blockId, block.iterator, level, true) - println("Pushed block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms") - queue.add((blockId, block)) - } - println("Producer thread " + id + " terminated") - } - - def randomLevel(): StorageLevel = { - math.abs(Random.nextInt()) % 4 match { - case 0 => StorageLevel.MEMORY_ONLY - case 1 => StorageLevel.MEMORY_ONLY_SER - case 2 => StorageLevel.MEMORY_AND_DISK - case 3 => StorageLevel.MEMORY_AND_DISK_SER - } - } - } - - private[spark] class ConsumerThread( - manager: BlockManager, - queue: ArrayBlockingQueue[(String, Seq[Int])] - ) extends Thread { - var numBlockConsumed = 0 - - override def run() { - println("Consumer thread started") - while(numBlockConsumed < numBlocksPerProducer) { - val (blockId, block) = queue.take() - val startTime = System.currentTimeMillis() - manager.get(blockId) match { - case Some(retrievedBlock) => - assert(retrievedBlock.toList.asInstanceOf[List[Int]] == block.toList, - "Block " + blockId + " did not match") - println("Got block " + blockId + " in " + - (System.currentTimeMillis - startTime) + " ms") - case None => - assert(false, "Block " + blockId + " could not be retrieved") - } - numBlockConsumed += 1 - } - println("Consumer thread terminated") - } - } - - def main(args: Array[String]) { - System.setProperty("spark.kryoserializer.buffer.mb", "1") - val actorSystem = ActorSystem("test") - val serializer = new KryoSerializer - val blockManagerMaster = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true)))) - val blockManager = new BlockManager( - "", actorSystem, blockManagerMaster, serializer, 1024 * 1024) - val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) - val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) - producers.foreach(_.start) - consumers.foreach(_.start) - producers.foreach(_.join) - 
consumers.foreach(_.join) - blockManager.stop() - blockManagerMaster.stop() - actorSystem.shutdown() - actorSystem.awaitTermination() - println("Everything stopped.") - println( - "It will take sometime for the JVM to clean all temporary files and shutdown. Sit tight.") - } -} diff --git a/core/src/main/scala/spark/ui/JettyUtils.scala b/core/src/main/scala/spark/ui/JettyUtils.scala deleted file mode 100644 index f66fe39905..0000000000 --- a/core/src/main/scala/spark/ui/JettyUtils.scala +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.ui - -import javax.servlet.http.{HttpServletResponse, HttpServletRequest} - -import scala.annotation.tailrec -import scala.util.{Try, Success, Failure} -import scala.xml.Node - -import net.liftweb.json.{JValue, pretty, render} - -import org.eclipse.jetty.server.{Server, Request, Handler} -import org.eclipse.jetty.server.handler.{ResourceHandler, HandlerList, ContextHandler, AbstractHandler} -import org.eclipse.jetty.util.thread.QueuedThreadPool - -import spark.Logging - - -/** Utilities for launching a web server using Jetty's HTTP Server class */ -private[spark] object JettyUtils extends Logging { - // Base type for a function that returns something based on an HTTP request. Allows for - // implicit conversion from many types of functions to jetty Handlers. 
- type Responder[T] = HttpServletRequest => T - - // Conversions from various types of Responder's to jetty Handlers - implicit def jsonResponderToHandler(responder: Responder[JValue]): Handler = - createHandler(responder, "text/json", (in: JValue) => pretty(render(in))) - - implicit def htmlResponderToHandler(responder: Responder[Seq[Node]]): Handler = - createHandler(responder, "text/html", (in: Seq[Node]) => "" + in.toString) - - implicit def textResponderToHandler(responder: Responder[String]): Handler = - createHandler(responder, "text/plain") - - def createHandler[T <% AnyRef](responder: Responder[T], contentType: String, - extractFn: T => String = (in: Any) => in.toString): Handler = { - new AbstractHandler { - def handle(target: String, - baseRequest: Request, - request: HttpServletRequest, - response: HttpServletResponse) { - response.setContentType("%s;charset=utf-8".format(contentType)) - response.setStatus(HttpServletResponse.SC_OK) - baseRequest.setHandled(true) - val result = responder(request) - response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - response.getWriter().println(extractFn(result)) - } - } - } - - /** Creates a handler that always redirects the user to a given path */ - def createRedirectHandler(newPath: String): Handler = { - new AbstractHandler { - def handle(target: String, - baseRequest: Request, - request: HttpServletRequest, - response: HttpServletResponse) { - response.setStatus(302) - response.setHeader("Location", baseRequest.getRootURL + newPath) - baseRequest.setHandled(true) - } - } - } - - /** Creates a handler for serving files from a static directory */ - def createStaticHandler(resourceBase: String): ResourceHandler = { - val staticHandler = new ResourceHandler - Option(getClass.getClassLoader.getResource(resourceBase)) match { - case Some(res) => - staticHandler.setResourceBase(res.toString) - case None => - throw new Exception("Could not find resource path for Web UI: " + resourceBase) - } - staticHandler - } - - /** - * Attempts to start a Jetty server at the supplied ip:port which uses the supplied handlers. - * - * If the desired port number is contented, continues incrementing ports until a free port is - * found. Returns the chosen port and the jetty Server object. - */ - def startJettyServer(ip: String, port: Int, handlers: Seq[(String, Handler)]): (Server, Int) = { - val handlersToRegister = handlers.map { case(path, handler) => - val contextHandler = new ContextHandler(path) - contextHandler.setHandler(handler) - contextHandler.asInstanceOf[org.eclipse.jetty.server.Handler] - } - - val handlerList = new HandlerList - handlerList.setHandlers(handlersToRegister.toArray) - - @tailrec - def connect(currentPort: Int): (Server, Int) = { - val server = new Server(currentPort) - val pool = new QueuedThreadPool - pool.setDaemon(true) - server.setThreadPool(pool) - server.setHandler(handlerList) - - Try { server.start() } match { - case s: Success[_] => - sys.addShutdownHook(server.stop()) // Be kind, un-bind - (server, server.getConnectors.head.getLocalPort) - case f: Failure[_] => - server.stop() - logInfo("Failed to create UI at port, %s. 
Trying again.".format(currentPort)) - logInfo("Error was: " + f.toString) - connect((currentPort + 1) % 65536) - } - } - - connect(port) - } -} diff --git a/core/src/main/scala/spark/ui/Page.scala b/core/src/main/scala/spark/ui/Page.scala deleted file mode 100644 index 87376a19d8..0000000000 --- a/core/src/main/scala/spark/ui/Page.scala +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.ui - -private[spark] object Page extends Enumeration { - val Stages, Storage, Environment, Executors = Value -} diff --git a/core/src/main/scala/spark/ui/SparkUI.scala b/core/src/main/scala/spark/ui/SparkUI.scala deleted file mode 100644 index 23ded44ba3..0000000000 --- a/core/src/main/scala/spark/ui/SparkUI.scala +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.ui - -import javax.servlet.http.HttpServletRequest - -import org.eclipse.jetty.server.{Handler, Server} - -import spark.{Logging, SparkContext, SparkEnv, Utils} -import spark.ui.env.EnvironmentUI -import spark.ui.exec.ExecutorsUI -import spark.ui.storage.BlockManagerUI -import spark.ui.jobs.JobProgressUI -import spark.ui.JettyUtils._ - -/** Top level user interface for Spark */ -private[spark] class SparkUI(sc: SparkContext) extends Logging { - val host = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(Utils.localHostName()) - val port = Option(System.getProperty("spark.ui.port")).getOrElse(SparkUI.DEFAULT_PORT).toInt - var boundPort: Option[Int] = None - var server: Option[Server] = None - - val handlers = Seq[(String, Handler)]( - ("/static", createStaticHandler(SparkUI.STATIC_RESOURCE_DIR)), - ("/", createRedirectHandler("/stages")) - ) - val storage = new BlockManagerUI(sc) - val jobs = new JobProgressUI(sc) - val env = new EnvironmentUI(sc) - val exec = new ExecutorsUI(sc) - - // Add MetricsServlet handlers by default - val metricsServletHandlers = SparkEnv.get.metricsSystem.getServletHandlers - - val allHandlers = storage.getHandlers ++ jobs.getHandlers ++ env.getHandlers ++ - exec.getHandlers ++ metricsServletHandlers ++ handlers - - /** Bind the HTTP server which backs this web interface */ - def bind() { - try { - val (srv, usedPort) = JettyUtils.startJettyServer("0.0.0.0", port, allHandlers) - logInfo("Started Spark Web UI at http://%s:%d".format(host, usedPort)) - server = Some(srv) - boundPort = Some(usedPort) - } catch { - case e: Exception => - logError("Failed to create Spark JettyUtils", e) - System.exit(1) - } - } - - /** Initialize all components of the server */ - def start() { - // NOTE: This is decoupled from bind() because of the following dependency cycle: - // DAGScheduler() requires that the port of this server is known - // This server must register all handlers, including JobProgressUI, before binding - // JobProgressUI registers a listener with SparkContext, which requires sc to initialize - jobs.start() - exec.start() - } - - def stop() { - server.foreach(_.stop()) - } - - private[spark] def appUIAddress = "http://" + host + ":" + boundPort.getOrElse("-1") -} - -private[spark] object SparkUI { - val DEFAULT_PORT = "3030" - val STATIC_RESOURCE_DIR = "spark/ui/static" -} diff --git a/core/src/main/scala/spark/ui/UIUtils.scala b/core/src/main/scala/spark/ui/UIUtils.scala deleted file mode 100644 index 51bb18d888..0000000000 --- a/core/src/main/scala/spark/ui/UIUtils.scala +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.ui - -import scala.xml.Node - -import spark.SparkContext - -/** Utility functions for generating XML pages with spark content. */ -private[spark] object UIUtils { - import Page._ - - /** Returns a spark page with correctly formatted headers */ - def headerSparkPage(content: => Seq[Node], sc: SparkContext, title: String, page: Page.Value) - : Seq[Node] = { - val jobs = page match { - case Stages =>
<li class="active"><a href="/stages">Stages</a></li>
-      case _ => <li><a href="/stages">Stages</a></li>
-    }
-    val storage = page match {
-      case Storage => <li class="active"><a href="/storage">Storage</a></li>
-      case _ => <li><a href="/storage">Storage</a></li>
-    }
-    val environment = page match {
-      case Environment => <li class="active"><a href="/environment">Environment</a></li>
-      case _ => <li><a href="/environment">Environment</a></li>
-    }
-    val executors = page match {
-      case Executors => <li class="active"><a href="/executors">Executors</a></li>
-      case _ => <li><a href="/executors">Executors</a></li>
-    }
-
-    // [page markup not fully recoverable: an <html> document pulling in the Spark CSS/JS
-    // assets, titled {sc.appName} - {title}, with a navbar built from the four items above,
-    // a {title} heading, and {content} as the body]
-  }
-
-  /** Returns a page with the spark css/js and a simple format. Used for scheduler UI. */
-  def basicSparkPage(content: => Seq[Node], title: String): Seq[Node] = {
-    // [page markup not fully recoverable: a plain <html> document with the Spark assets,
-    // {title} as its title and heading, and {content} as the body]
-  }
-
-  /** Returns an HTML table constructed by generating a row for each object in a sequence. */
-  def listingTable[T](
-      headers: Seq[String],
-      makeRow: T => Seq[Node],
-      rows: Seq[T],
-      fixedWidth: Boolean = false): Seq[Node] = {
-
-    val colWidth = 100.toDouble / headers.size
-    val colWidthAttr = if (fixedWidth) colWidth + "%" else ""
-    var tableClass = "table table-bordered table-striped table-condensed sortable"
-    if (fixedWidth) {
-      tableClass += " table-fixed"
-    }
-
-    <table class={tableClass}>
-      <thead>{headers.map(h => <th width={colWidthAttr}>{h}</th>)}</thead>
-      <tbody>
-        {rows.map(r => makeRow(r))}
-      </tbody>
-    </table>
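A minimal usage sketch, not part of this patch, of how the pages in this commit drive listingTable (mirroring the property tables in EnvironmentUI below); the two property rows are only illustrative:

    import scala.xml.Node
    import spark.ui.UIUtils

    // Column titles plus a function that renders one row of cells.
    val headers = Seq("Name", "Value")
    def propertyRow(kv: (String, String)): Seq[Node] = <tr><td>{kv._1}</td><td>{kv._2}</td></tr>

    // Illustrative rows; the real pages pass JVM/Spark properties or executor stats.
    val rows = Seq("spark.ui.port" -> "3030", "spark.cluster.schedulingmode" -> "FAIR")
    val table: Seq[Node] = UIUtils.listingTable(headers, propertyRow, rows, fixedWidth = true)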
    - } -} diff --git a/core/src/main/scala/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/spark/ui/UIWorkloadGenerator.scala deleted file mode 100644 index 5ff0572f0a..0000000000 --- a/core/src/main/scala/spark/ui/UIWorkloadGenerator.scala +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.ui - -import scala.util.Random - -import spark.SparkContext -import spark.SparkContext._ -import spark.scheduler.cluster.SchedulingMode - - -/** - * Continuously generates jobs that expose various features of the WebUI (internal testing tool). - * - * Usage: ./run spark.ui.UIWorkloadGenerator [master] - */ -private[spark] object UIWorkloadGenerator { - val NUM_PARTITIONS = 100 - val INTER_JOB_WAIT_MS = 5000 - - def main(args: Array[String]) { - if (args.length < 2) { - println("usage: ./spark-class spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]") - System.exit(1) - } - val master = args(0) - val schedulingMode = SchedulingMode.withName(args(1)) - val appName = "Spark UI Tester" - - if (schedulingMode == SchedulingMode.FAIR) { - System.setProperty("spark.cluster.schedulingmode", "FAIR") - } - val sc = new SparkContext(master, appName) - - def setProperties(s: String) = { - if(schedulingMode == SchedulingMode.FAIR) { - sc.setLocalProperty("spark.scheduler.cluster.fair.pool", s) - } - sc.setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, s) - } - - val baseData = sc.makeRDD(1 to NUM_PARTITIONS * 10, NUM_PARTITIONS) - def nextFloat() = (new Random()).nextFloat() - - val jobs = Seq[(String, () => Long)]( - ("Count", baseData.count), - ("Cache and Count", baseData.map(x => x).cache.count), - ("Single Shuffle", baseData.map(x => (x % 10, x)).reduceByKey(_ + _).count), - ("Entirely failed phase", baseData.map(x => throw new Exception).count), - ("Partially failed phase", { - baseData.map{x => - val probFailure = (4.0 / NUM_PARTITIONS) - if (nextFloat() < probFailure) { - throw new Exception("This is a task failure") - } - 1 - }.count - }), - ("Partially failed phase (longer tasks)", { - baseData.map{x => - val probFailure = (4.0 / NUM_PARTITIONS) - if (nextFloat() < probFailure) { - Thread.sleep(100) - throw new Exception("This is a task failure") - } - 1 - }.count - }), - ("Job with delays", baseData.map(x => Thread.sleep(100)).count) - ) - - while (true) { - for ((desc, job) <- jobs) { - new Thread { - override def run() { - try { - setProperties(desc) - job() - println("Job funished: " + desc) - } catch { - case e: Exception => - println("Job Failed: " + desc) - } - } - }.start - Thread.sleep(INTER_JOB_WAIT_MS) - } - } - } -} diff --git a/core/src/main/scala/spark/ui/env/EnvironmentUI.scala b/core/src/main/scala/spark/ui/env/EnvironmentUI.scala deleted file mode 100644 index b1be1a27ef..0000000000 
--- a/core/src/main/scala/spark/ui/env/EnvironmentUI.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.ui.env - -import javax.servlet.http.HttpServletRequest - -import scala.collection.JavaConversions._ -import scala.util.Properties -import scala.xml.Node - -import org.eclipse.jetty.server.Handler - -import spark.ui.JettyUtils._ -import spark.ui.UIUtils -import spark.ui.Page.Environment -import spark.SparkContext - - -private[spark] class EnvironmentUI(sc: SparkContext) { - - def getHandlers = Seq[(String, Handler)]( - ("/environment", (request: HttpServletRequest) => envDetails(request)) - ) - - def envDetails(request: HttpServletRequest): Seq[Node] = { - val jvmInformation = Seq( - ("Java Version", "%s (%s)".format(Properties.javaVersion, Properties.javaVendor)), - ("Java Home", Properties.javaHome), - ("Scala Version", Properties.versionString), - ("Scala Home", Properties.scalaHome) - ).sorted - def jvmRow(kv: (String, String)) = {kv._1}{kv._2} - def jvmTable = - UIUtils.listingTable(Seq("Name", "Value"), jvmRow, jvmInformation, fixedWidth = true) - - val properties = System.getProperties.iterator.toSeq - val classPathProperty = properties.find { case (k, v) => - k.contains("java.class.path") - }.getOrElse(("", "")) - val sparkProperties = properties.filter(_._1.startsWith("spark")).sorted - val otherProperties = properties.diff(sparkProperties :+ classPathProperty).sorted - - val propertyHeaders = Seq("Name", "Value") - def propertyRow(kv: (String, String)) = {kv._1}{kv._2} - val sparkPropertyTable = - UIUtils.listingTable(propertyHeaders, propertyRow, sparkProperties, fixedWidth = true) - val otherPropertyTable = - UIUtils.listingTable(propertyHeaders, propertyRow, otherProperties, fixedWidth = true) - - val classPathEntries = classPathProperty._2 - .split(System.getProperty("path.separator", ":")) - .filterNot(e => e.isEmpty) - .map(e => (e, "System Classpath")) - val addedJars = sc.addedJars.iterator.toSeq.map{case (path, time) => (path, "Added By User")} - val addedFiles = sc.addedFiles.iterator.toSeq.map{case (path, time) => (path, "Added By User")} - val classPath = (addedJars ++ addedFiles ++ classPathEntries).sorted - - val classPathHeaders = Seq("Resource", "Source") - def classPathRow(data: (String, String)) = {data._1}{data._2} - val classPathTable = - UIUtils.listingTable(classPathHeaders, classPathRow, classPath, fixedWidth = true) - - val content = - -

<span>
-        <h4>Runtime Information</h4> {jvmTable}
-        <h4>Spark Properties</h4>
-        {sparkPropertyTable}
-        <h4>System Properties</h4>
-        {otherPropertyTable}
-        <h4>Classpath Entries</h4>
-        {classPathTable}
-      </span>
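As a hedged sketch (not part of the patch), the pattern each tab follows is visible here: assemble the body as scala.xml nodes, then hand it to UIUtils.headerSparkPage, which adds the shared navigation chrome; the page title below is hypothetical:

    import scala.xml.Node
    import spark.SparkContext
    import spark.ui.{Page, UIUtils}

    def renderMiniPage(sc: SparkContext): Seq[Node] = {
      // Any Seq[Node] works as the body; real pages plug in tables like the ones above.
      val body: Seq[Node] = <h4>Demo section</h4>
      UIUtils.headerSparkPage(body, sc, "Demo", Page.Environment)
    }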
    - - UIUtils.headerSparkPage(content, sc, "Environment", Environment) - } -} diff --git a/core/src/main/scala/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/spark/ui/exec/ExecutorsUI.scala deleted file mode 100644 index 0a7021fbf8..0000000000 --- a/core/src/main/scala/spark/ui/exec/ExecutorsUI.scala +++ /dev/null @@ -1,136 +0,0 @@ -package spark.ui.exec - -import javax.servlet.http.HttpServletRequest - -import scala.collection.mutable.{HashMap, HashSet} -import scala.xml.Node - -import org.eclipse.jetty.server.Handler - -import spark.{ExceptionFailure, Logging, Utils, SparkContext} -import spark.executor.TaskMetrics -import spark.scheduler.cluster.TaskInfo -import spark.scheduler.{SparkListenerTaskStart, SparkListenerTaskEnd, SparkListener} -import spark.ui.JettyUtils._ -import spark.ui.Page.Executors -import spark.ui.UIUtils - - -private[spark] class ExecutorsUI(val sc: SparkContext) { - - private var _listener: Option[ExecutorsListener] = None - def listener = _listener.get - - def start() { - _listener = Some(new ExecutorsListener) - sc.addSparkListener(listener) - } - - def getHandlers = Seq[(String, Handler)]( - ("/executors", (request: HttpServletRequest) => render(request)) - ) - - def render(request: HttpServletRequest): Seq[Node] = { - val storageStatusList = sc.getExecutorStorageStatus - - val maxMem = storageStatusList.map(_.maxMem).fold(0L)(_+_) - val memUsed = storageStatusList.map(_.memUsed()).fold(0L)(_+_) - val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)).fold(0L)(_+_) - - val execHead = Seq("Executor ID", "Address", "RDD blocks", "Memory used", "Disk used", - "Active tasks", "Failed tasks", "Complete tasks", "Total tasks") - - def execRow(kv: Seq[String]) = { - - {kv(0)} - {kv(1)} - {kv(2)} - - {Utils.bytesToString(kv(3).toLong)} / {Utils.bytesToString(kv(4).toLong)} - - - {Utils.bytesToString(kv(5).toLong)} - - {kv(6)} - {kv(7)} - {kv(8)} - {kv(9)} - - } - - val execInfo = for (b <- 0 until storageStatusList.size) yield getExecInfo(b) - val execTable = UIUtils.listingTable(execHead, execRow, execInfo) - - val content = -
    -
    -
      -
<li><strong>Memory:</strong>
-            {Utils.bytesToString(memUsed)} Used
-            ({Utils.bytesToString(maxMem)} Total)</li>
-          <li><strong>Disk:</strong> {Utils.bytesToString(diskSpaceUsed)} Used</li>
-        </ul>
-      </div>
-      <div class="span12">
-        {execTable}
-      </div>
- */ - -package spark.ui.jobs - -import javax.servlet.http.HttpServletRequest - -import scala.xml.{NodeSeq, Node} - -import spark.scheduler.cluster.SchedulingMode -import spark.ui.Page._ -import spark.ui.UIUtils._ - - -/** Page showing list of all ongoing and recently finished stages and pools*/ -private[spark] class IndexPage(parent: JobProgressUI) { - def listener = parent.listener - - def render(request: HttpServletRequest): Seq[Node] = { - listener.synchronized { - val activeStages = listener.activeStages.toSeq - val completedStages = listener.completedStages.reverse.toSeq - val failedStages = listener.failedStages.reverse.toSeq - val now = System.currentTimeMillis() - - var activeTime = 0L - for (tasks <- listener.stageToTasksActive.values; t <- tasks) { - activeTime += t.timeRunning(now) - } - - val activeStagesTable = new StageTable(activeStages.sortBy(_.submissionTime).reverse, parent) - val completedStagesTable = new StageTable(completedStages.sortBy(_.submissionTime).reverse, parent) - val failedStagesTable = new StageTable(failedStages.sortBy(_.submissionTime).reverse, parent) - - val pools = listener.sc.getAllPools - val poolTable = new PoolTable(pools, listener) - val summary: NodeSeq = -
    -
      -
<li>
-            <strong>Total Duration: </strong>
-            {parent.formatDuration(now - listener.sc.startTime)}
-          </li>
-          <li><strong>Scheduling Mode:</strong> {parent.sc.getSchedulingMode}</li>
-          <li>
-            <strong>Active Stages:</strong>
-            {activeStages.size}
-          </li>
-          <li>
-            <strong>Completed Stages:</strong>
-            {completedStages.size}
-          </li>
-          <li>
-            <strong>Failed Stages:</strong>
-            {failedStages.size}
-          </li>
-        </ul>
-      </div>
-
-      val content = summary ++
-        {if (listener.sc.getSchedulingMode == SchedulingMode.FAIR) {
-          <h4>{pools.size} Fair Scheduler Pools</h4> ++ poolTable.toNodeSeq
-        } else {
-          Seq()
-        }} ++
-        <h4>Active Stages ({activeStages.size})</h4> ++
-        activeStagesTable.toNodeSeq++
-        <h4>Completed Stages ({completedStages.size})</h4> ++
-        completedStagesTable.toNodeSeq++
-        <h4>Failed Stages ({failedStages.size})</h4>

    ++ - failedStagesTable.toNodeSeq - - headerSparkPage(content, parent.sc, "Spark Stages", Stages) - } - } -} diff --git a/core/src/main/scala/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/spark/ui/jobs/JobProgressListener.scala deleted file mode 100644 index 1d9767a83c..0000000000 --- a/core/src/main/scala/spark/ui/jobs/JobProgressListener.scala +++ /dev/null @@ -1,156 +0,0 @@ -package spark.ui.jobs - -import scala.Seq -import scala.collection.mutable.{ListBuffer, HashMap, HashSet} - -import spark.{ExceptionFailure, SparkContext, Success, Utils} -import spark.scheduler._ -import spark.scheduler.cluster.TaskInfo -import spark.executor.TaskMetrics -import collection.mutable - -/** - * Tracks task-level information to be displayed in the UI. - * - * All access to the data structures in this class must be synchronized on the - * class, since the UI thread and the DAGScheduler event loop may otherwise - * be reading/updating the internal data structures concurrently. - */ -private[spark] class JobProgressListener(val sc: SparkContext) extends SparkListener { - // How many stages to remember - val RETAINED_STAGES = System.getProperty("spark.ui.retained_stages", "1000").toInt - val DEFAULT_POOL_NAME = "default" - - val stageToPool = new HashMap[Stage, String]() - val stageToDescription = new HashMap[Stage, String]() - val poolToActiveStages = new HashMap[String, HashSet[Stage]]() - - val activeStages = HashSet[Stage]() - val completedStages = ListBuffer[Stage]() - val failedStages = ListBuffer[Stage]() - - // Total metrics reflect metrics only for completed tasks - var totalTime = 0L - var totalShuffleRead = 0L - var totalShuffleWrite = 0L - - val stageToTime = HashMap[Int, Long]() - val stageToShuffleRead = HashMap[Int, Long]() - val stageToShuffleWrite = HashMap[Int, Long]() - val stageToTasksActive = HashMap[Int, HashSet[TaskInfo]]() - val stageToTasksComplete = HashMap[Int, Int]() - val stageToTasksFailed = HashMap[Int, Int]() - val stageToTaskInfos = - HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]() - - override def onJobStart(jobStart: SparkListenerJobStart) {} - - override def onStageCompleted(stageCompleted: StageCompleted) = synchronized { - val stage = stageCompleted.stageInfo.stage - poolToActiveStages(stageToPool(stage)) -= stage - activeStages -= stage - completedStages += stage - trimIfNecessary(completedStages) - } - - /** If stages is too large, remove and garbage collect old stages */ - def trimIfNecessary(stages: ListBuffer[Stage]) = synchronized { - if (stages.size > RETAINED_STAGES) { - val toRemove = RETAINED_STAGES / 10 - stages.takeRight(toRemove).foreach( s => { - stageToTaskInfos.remove(s.id) - stageToTime.remove(s.id) - stageToShuffleRead.remove(s.id) - stageToShuffleWrite.remove(s.id) - stageToTasksActive.remove(s.id) - stageToTasksComplete.remove(s.id) - stageToTasksFailed.remove(s.id) - stageToPool.remove(s) - if (stageToDescription.contains(s)) {stageToDescription.remove(s)} - }) - stages.trimEnd(toRemove) - } - } - - /** For FIFO, all stages are contained by "default" pool but "default" pool here is meaningless */ - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized { - val stage = stageSubmitted.stage - activeStages += stage - - val poolName = Option(stageSubmitted.properties).map { - p => p.getProperty("spark.scheduler.cluster.fair.pool", DEFAULT_POOL_NAME) - }.getOrElse(DEFAULT_POOL_NAME) - stageToPool(stage) = poolName - - val description = 
Option(stageSubmitted.properties).flatMap { - p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) - } - description.map(d => stageToDescription(stage) = d) - - val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[Stage]()) - stages += stage - } - - override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { - val sid = taskStart.task.stageId - val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) - tasksActive += taskStart.taskInfo - val taskList = stageToTaskInfos.getOrElse( - sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]()) - taskList += ((taskStart.taskInfo, None, None)) - stageToTaskInfos(sid) = taskList - } - - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { - val sid = taskEnd.task.stageId - val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) - tasksActive -= taskEnd.taskInfo - val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) = - taskEnd.reason match { - case e: ExceptionFailure => - stageToTasksFailed(sid) = stageToTasksFailed.getOrElse(sid, 0) + 1 - (Some(e), e.metrics) - case _ => - stageToTasksComplete(sid) = stageToTasksComplete.getOrElse(sid, 0) + 1 - (None, Option(taskEnd.taskMetrics)) - } - - stageToTime.getOrElseUpdate(sid, 0L) - val time = metrics.map(m => m.executorRunTime).getOrElse(0) - stageToTime(sid) += time - totalTime += time - - stageToShuffleRead.getOrElseUpdate(sid, 0L) - val shuffleRead = metrics.flatMap(m => m.shuffleReadMetrics).map(s => - s.remoteBytesRead).getOrElse(0L) - stageToShuffleRead(sid) += shuffleRead - totalShuffleRead += shuffleRead - - stageToShuffleWrite.getOrElseUpdate(sid, 0L) - val shuffleWrite = metrics.flatMap(m => m.shuffleWriteMetrics).map(s => - s.shuffleBytesWritten).getOrElse(0L) - stageToShuffleWrite(sid) += shuffleWrite - totalShuffleWrite += shuffleWrite - - val taskList = stageToTaskInfos.getOrElse( - sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]()) - taskList -= ((taskEnd.taskInfo, None, None)) - taskList += ((taskEnd.taskInfo, metrics, failureInfo)) - stageToTaskInfos(sid) = taskList - } - - override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized { - jobEnd match { - case end: SparkListenerJobEnd => - end.jobResult match { - case JobFailed(ex, Some(stage)) => - activeStages -= stage - poolToActiveStages(stageToPool(stage)) -= stage - failedStages += stage - trimIfNecessary(failedStages) - case _ => - } - case _ => - } - } -} diff --git a/core/src/main/scala/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/spark/ui/jobs/JobProgressUI.scala deleted file mode 100644 index c83f102ff3..0000000000 --- a/core/src/main/scala/spark/ui/jobs/JobProgressUI.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.ui.jobs - -import akka.util.Duration - -import java.text.SimpleDateFormat - -import javax.servlet.http.HttpServletRequest - -import org.eclipse.jetty.server.Handler - -import scala.Seq -import scala.collection.mutable.{HashSet, ListBuffer, HashMap, ArrayBuffer} - -import spark.ui.JettyUtils._ -import spark.{ExceptionFailure, SparkContext, Success, Utils} -import spark.scheduler._ -import collection.mutable -import spark.scheduler.cluster.SchedulingMode -import spark.scheduler.cluster.SchedulingMode.SchedulingMode - -/** Web UI showing progress status of all jobs in the given SparkContext. */ -private[spark] class JobProgressUI(val sc: SparkContext) { - private var _listener: Option[JobProgressListener] = None - def listener = _listener.get - val dateFmt = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") - - private val indexPage = new IndexPage(this) - private val stagePage = new StagePage(this) - private val poolPage = new PoolPage(this) - - def start() { - _listener = Some(new JobProgressListener(sc)) - sc.addSparkListener(listener) - } - - def formatDuration(ms: Long) = Utils.msDurationToString(ms) - - def getHandlers = Seq[(String, Handler)]( - ("/stages/stage", (request: HttpServletRequest) => stagePage.render(request)), - ("/stages/pool", (request: HttpServletRequest) => poolPage.render(request)), - ("/stages", (request: HttpServletRequest) => indexPage.render(request)) - ) -} diff --git a/core/src/main/scala/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/spark/ui/jobs/PoolPage.scala deleted file mode 100644 index 7fb74dce40..0000000000 --- a/core/src/main/scala/spark/ui/jobs/PoolPage.scala +++ /dev/null @@ -1,32 +0,0 @@ -package spark.ui.jobs - -import javax.servlet.http.HttpServletRequest - -import scala.xml.{NodeSeq, Node} -import scala.collection.mutable.HashSet - -import spark.scheduler.Stage -import spark.ui.UIUtils._ -import spark.ui.Page._ - -/** Page showing specific pool details */ -private[spark] class PoolPage(parent: JobProgressUI) { - def listener = parent.listener - - def render(request: HttpServletRequest): Seq[Node] = { - listener.synchronized { - val poolName = request.getParameter("poolname") - val poolToActiveStages = listener.poolToActiveStages - val activeStages = poolToActiveStages.get(poolName).toSeq.flatten - val activeStagesTable = new StageTable(activeStages.sortBy(_.submissionTime).reverse, parent) - - val pool = listener.sc.getPoolForName(poolName).get - val poolTable = new PoolTable(Seq(pool), listener) - - val content =

<h4>Summary</h4> ++ poolTable.toNodeSeq() ++
-                  <h4>{activeStages.size} Active Stages</h4>

    ++ activeStagesTable.toNodeSeq() - - headerSparkPage(content, parent.sc, "Fair Scheduler Pool: " + poolName, Stages) - } - } -} diff --git a/core/src/main/scala/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/spark/ui/jobs/PoolTable.scala deleted file mode 100644 index 621828f9c3..0000000000 --- a/core/src/main/scala/spark/ui/jobs/PoolTable.scala +++ /dev/null @@ -1,55 +0,0 @@ -package spark.ui.jobs - -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.xml.Node - -import spark.scheduler.Stage -import spark.scheduler.cluster.Schedulable - -/** Table showing list of pools */ -private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressListener) { - - var poolToActiveStages: HashMap[String, HashSet[Stage]] = listener.poolToActiveStages - - def toNodeSeq(): Seq[Node] = { - listener.synchronized { - poolTable(poolRow, pools) - } - } - - private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[Stage]]) => Seq[Node], - rows: Seq[Schedulable] - ): Seq[Node] = { - - - - - - - - - - - {rows.map(r => makeRow(r, poolToActiveStages))} - -
<th>Pool Name</th> <th>Minimum Share</th> <th>Pool Weight</th> <th>Active Stages</th> <th>Running Tasks</th> <th>SchedulingMode</th>
    - } - - private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[Stage]]) - : Seq[Node] = { - val activeStages = poolToActiveStages.get(p.name) match { - case Some(stages) => stages.size - case None => 0 - } - - {p.name} - {p.minShare} - {p.weight} - {activeStages} - {p.runningTasks} - {p.schedulingMode} - - } -} - diff --git a/core/src/main/scala/spark/ui/jobs/StagePage.scala b/core/src/main/scala/spark/ui/jobs/StagePage.scala deleted file mode 100644 index c2341475c7..0000000000 --- a/core/src/main/scala/spark/ui/jobs/StagePage.scala +++ /dev/null @@ -1,183 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.ui.jobs - -import java.util.Date - -import javax.servlet.http.HttpServletRequest - -import scala.xml.Node - -import spark.ui.UIUtils._ -import spark.ui.Page._ -import spark.util.Distribution -import spark.{ExceptionFailure, Utils} -import spark.scheduler.cluster.TaskInfo -import spark.executor.TaskMetrics - -/** Page showing statistics and task list for a given stage */ -private[spark] class StagePage(parent: JobProgressUI) { - def listener = parent.listener - val dateFmt = parent.dateFmt - - def render(request: HttpServletRequest): Seq[Node] = { - listener.synchronized { - val stageId = request.getParameter("id").toInt - val now = System.currentTimeMillis() - - if (!listener.stageToTaskInfos.contains(stageId)) { - val content = -
    -

<div>
-            <h4>Summary Metrics</h4> No tasks have started yet
-            <h4>Tasks</h4> No tasks have started yet
-          </div>
    - return headerSparkPage(content, parent.sc, "Details for Stage %s".format(stageId), Stages) - } - - val tasks = listener.stageToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime) - - val numCompleted = tasks.count(_._1.finished) - val shuffleReadBytes = listener.stageToShuffleRead.getOrElse(stageId, 0L) - val hasShuffleRead = shuffleReadBytes > 0 - val shuffleWriteBytes = listener.stageToShuffleWrite.getOrElse(stageId, 0L) - val hasShuffleWrite = shuffleWriteBytes > 0 - - var activeTime = 0L - listener.stageToTasksActive(stageId).foreach(activeTime += _.timeRunning(now)) - - val summary = -
    -
      -
<li>
-            <strong>CPU time: </strong>
-            {parent.formatDuration(listener.stageToTime.getOrElse(stageId, 0L) + activeTime)}
-          </li>
-          {if (hasShuffleRead)
-            <li>
-              <strong>Shuffle read: </strong>
-              {Utils.bytesToString(shuffleReadBytes)}
-            </li>
-          }
-          {if (hasShuffleWrite)
-            <li>
-              <strong>Shuffle write: </strong>
-              {Utils.bytesToString(shuffleWriteBytes)}
-            </li>
-          }
-        </ul>
-      </div>
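Aside (not part of the patch): the quantile row built below uses spark.util.Distribution exactly as elsewhere in this file; a small example with made-up task durations:

    import spark.util.Distribution

    // Five hypothetical task run times in milliseconds.
    val runTimesMs = Seq(12.0, 48.0, 95.0, 110.0, 640.0)
    // With no arguments, getQuantiles() yields the values behind the
    // Min / 25th percentile / Median / 75th percentile / Max columns.
    val quantiles = Distribution(runTimesMs).get.getQuantiles()
    println(quantiles.map(_.toLong).mkString(", "))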
    - - val taskHeaders: Seq[String] = - Seq("Task ID", "Status", "Locality Level", "Executor", "Launch Time", "Duration") ++ - Seq("GC Time") ++ - {if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++ - {if (hasShuffleWrite) Seq("Shuffle Write") else Nil} ++ - Seq("Errors") - - val taskTable = listingTable(taskHeaders, taskRow(hasShuffleRead, hasShuffleWrite), tasks) - - // Excludes tasks which failed and have incomplete metrics - val validTasks = tasks.filter(t => t._1.status == "SUCCESS" && (t._2.isDefined)) - - val summaryTable: Option[Seq[Node]] = - if (validTasks.size == 0) { - None - } - else { - val serviceTimes = validTasks.map{case (info, metrics, exception) => - metrics.get.executorRunTime.toDouble} - val serviceQuantiles = "Duration" +: Distribution(serviceTimes).get.getQuantiles().map( - ms => parent.formatDuration(ms.toLong)) - - def getQuantileCols(data: Seq[Double]) = - Distribution(data).get.getQuantiles().map(d => Utils.bytesToString(d.toLong)) - - val shuffleReadSizes = validTasks.map { - case(info, metrics, exception) => - metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble - } - val shuffleReadQuantiles = "Shuffle Read (Remote)" +: getQuantileCols(shuffleReadSizes) - - val shuffleWriteSizes = validTasks.map { - case(info, metrics, exception) => - metrics.get.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L).toDouble - } - val shuffleWriteQuantiles = "Shuffle Write" +: getQuantileCols(shuffleWriteSizes) - - val listings: Seq[Seq[String]] = Seq(serviceQuantiles, - if (hasShuffleRead) shuffleReadQuantiles else Nil, - if (hasShuffleWrite) shuffleWriteQuantiles else Nil) - - val quantileHeaders = Seq("Metric", "Min", "25th percentile", - "Median", "75th percentile", "Max") - def quantileRow(data: Seq[String]): Seq[Node] = {data.map(d => {d})} - Some(listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true)) - } - - val content = - summary ++ -

<h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++
-        {summaryTable.getOrElse("No tasks have reported metrics yet.")} ++
-        <h4>Tasks</h4>
    ++ taskTable; - - headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages) - } - } - - - def taskRow(shuffleRead: Boolean, shuffleWrite: Boolean) - (taskData: (TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])): Seq[Node] = { - def fmtStackTrace(trace: Seq[StackTraceElement]): Seq[Node] = - trace.map(e => {e.toString}) - val (info, metrics, exception) = taskData - - val duration = if (info.status == "RUNNING") info.timeRunning(System.currentTimeMillis()) - else metrics.map(m => m.executorRunTime).getOrElse(1) - val formatDuration = if (info.status == "RUNNING") parent.formatDuration(duration) - else metrics.map(m => parent.formatDuration(m.executorRunTime)).getOrElse("") - val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L) - - - {info.taskId} - {info.status} - {info.taskLocality} - {info.host} - {dateFmt.format(new Date(info.launchTime))} - - {formatDuration} - - - {if (gcTime > 0) parent.formatDuration(gcTime) else ""} - - {if (shuffleRead) { - {metrics.flatMap{m => m.shuffleReadMetrics}.map{s => - Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")} - }} - {if (shuffleWrite) { - {metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => - Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")} - }} - {exception.map(e => - - {e.className} ({e.description})
    - {fmtStackTrace(e.stackTrace)} -
    ).getOrElse("")} - - - } -} diff --git a/core/src/main/scala/spark/ui/jobs/StageTable.scala b/core/src/main/scala/spark/ui/jobs/StageTable.scala deleted file mode 100644 index 2b1bc984fc..0000000000 --- a/core/src/main/scala/spark/ui/jobs/StageTable.scala +++ /dev/null @@ -1,107 +0,0 @@ -package spark.ui.jobs - -import java.util.Date - -import scala.xml.Node -import scala.collection.mutable.HashSet - -import spark.Utils -import spark.scheduler.cluster.{SchedulingMode, TaskInfo} -import spark.scheduler.Stage - - -/** Page showing list of all ongoing and recently finished stages */ -private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressUI) { - - val listener = parent.listener - val dateFmt = parent.dateFmt - val isFairScheduler = listener.sc.getSchedulingMode == SchedulingMode.FAIR - - def toNodeSeq(): Seq[Node] = { - listener.synchronized { - stageTable(stageRow, stages) - } - } - - /** Special table which merges two header cells. */ - private def stageTable[T](makeRow: T => Seq[Node], rows: Seq[T]): Seq[Node] = { - - - - {if (isFairScheduler) {} else {}} - - - - - - - - - {rows.map(r => makeRow(r))} - -
    Stage IdPool NameDescriptionSubmittedDurationTasks: Succeeded/TotalShuffle ReadShuffle Write
    - } - - private def makeProgressBar(started: Int, completed: Int, failed: String, total: Int): Seq[Node] = { - val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) - val startWidth = "width: %s%%".format((started.toDouble/total)*100) - -
    - - {completed}/{total} {failed} - -
    -
    -
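To make the width arithmetic in makeProgressBar above concrete, here is a small sketch with hypothetical task counts; only the width strings are computed, the surrounding markup is unchanged:

    // Hypothetical counts run through the same expressions used by makeProgressBar.
    val total = 8; val completed = 3; val started = 2
    val completeWidth = "width: %s%%".format((completed.toDouble / total) * 100)  // "width: 37.5%"
    val startWidth = "width: %s%%".format((started.toDouble / total) * 100)       // "width: 25.0%"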
    - } - - - private def stageRow(s: Stage): Seq[Node] = { - val submissionTime = s.submissionTime match { - case Some(t) => dateFmt.format(new Date(t)) - case None => "Unknown" - } - - val shuffleRead = listener.stageToShuffleRead.getOrElse(s.id, 0L) match { - case 0 => "" - case b => Utils.bytesToString(b) - } - val shuffleWrite = listener.stageToShuffleWrite.getOrElse(s.id, 0L) match { - case 0 => "" - case b => Utils.bytesToString(b) - } - - val startedTasks = listener.stageToTasksActive.getOrElse(s.id, HashSet[TaskInfo]()).size - val completedTasks = listener.stageToTasksComplete.getOrElse(s.id, 0) - val failedTasks = listener.stageToTasksFailed.getOrElse(s.id, 0) match { - case f if f > 0 => "(%s failed)".format(f) - case _ => "" - } - val totalTasks = s.numPartitions - - val poolName = listener.stageToPool.get(s) - - val nameLink = {s.name} - val description = listener.stageToDescription.get(s) - .map(d =>
    {d}
    {nameLink}
    ).getOrElse(nameLink) - val finishTime = s.completionTime.getOrElse(System.currentTimeMillis()) - val duration = s.submissionTime.map(t => finishTime - t) - - - {s.id} - {if (isFairScheduler) { - {poolName.get}} - } - {description} - {submissionTime} - - {duration.map(d => parent.formatDuration(d)).getOrElse("Unknown")} - - - {makeProgressBar(startedTasks, completedTasks, failedTasks, totalTasks)} - - {shuffleRead} - {shuffleWrite} - - } -} diff --git a/core/src/main/scala/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/spark/ui/storage/BlockManagerUI.scala deleted file mode 100644 index 49ed069c75..0000000000 --- a/core/src/main/scala/spark/ui/storage/BlockManagerUI.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.ui.storage - -import akka.util.Duration - -import javax.servlet.http.HttpServletRequest - -import org.eclipse.jetty.server.Handler - -import spark.{Logging, SparkContext} -import spark.ui.JettyUtils._ - -/** Web UI showing storage status of all RDD's in the given SparkContext. */ -private[spark] class BlockManagerUI(val sc: SparkContext) extends Logging { - implicit val timeout = Duration.create( - System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") - - val indexPage = new IndexPage(this) - val rddPage = new RDDPage(this) - - def getHandlers = Seq[(String, Handler)]( - ("/storage/rdd", (request: HttpServletRequest) => rddPage.render(request)), - ("/storage", (request: HttpServletRequest) => indexPage.render(request)) - ) -} diff --git a/core/src/main/scala/spark/ui/storage/IndexPage.scala b/core/src/main/scala/spark/ui/storage/IndexPage.scala deleted file mode 100644 index fc6273c694..0000000000 --- a/core/src/main/scala/spark/ui/storage/IndexPage.scala +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.ui.storage - -import javax.servlet.http.HttpServletRequest - -import scala.xml.Node - -import spark.storage.{RDDInfo, StorageUtils} -import spark.Utils -import spark.ui.UIUtils._ -import spark.ui.Page._ - -/** Page showing list of RDD's currently stored in the cluster */ -private[spark] class IndexPage(parent: BlockManagerUI) { - val sc = parent.sc - - def render(request: HttpServletRequest): Seq[Node] = { - val storageStatusList = sc.getExecutorStorageStatus - // Calculate macro-level statistics - - val rddHeaders = Seq( - "RDD Name", - "Storage Level", - "Cached Partitions", - "Fraction Cached", - "Size in Memory", - "Size on Disk") - val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc) - val content = listingTable(rddHeaders, rddRow, rdds) - - headerSparkPage(content, parent.sc, "Storage ", Storage) - } - - def rddRow(rdd: RDDInfo): Seq[Node] = { - - - - {rdd.name} - - - {rdd.storageLevel.description} - - {rdd.numCachedPartitions} - {"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)} - {Utils.bytesToString(rdd.memSize)} - {Utils.bytesToString(rdd.diskSize)} - - } -} diff --git a/core/src/main/scala/spark/ui/storage/RDDPage.scala b/core/src/main/scala/spark/ui/storage/RDDPage.scala deleted file mode 100644 index b128a5614d..0000000000 --- a/core/src/main/scala/spark/ui/storage/RDDPage.scala +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.ui.storage - -import javax.servlet.http.HttpServletRequest - -import scala.xml.Node - -import spark.Utils -import spark.storage.{StorageStatus, StorageUtils} -import spark.storage.BlockManagerMasterActor.BlockStatus -import spark.ui.UIUtils._ -import spark.ui.Page._ - - -/** Page showing storage details for a given RDD */ -private[spark] class RDDPage(parent: BlockManagerUI) { - val sc = parent.sc - - def render(request: HttpServletRequest): Seq[Node] = { - val id = request.getParameter("id") - val prefix = "rdd_" + id.toString - val storageStatusList = sc.getExecutorStorageStatus - val filteredStorageStatusList = StorageUtils. 
- filterStorageStatusByPrefix(storageStatusList, prefix) - val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head - - val workerHeaders = Seq("Host", "Memory Usage", "Disk Usage") - val workers = filteredStorageStatusList.map((prefix, _)) - val workerTable = listingTable(workerHeaders, workerRow, workers) - - val blockHeaders = Seq("Block Name", "Storage Level", "Size in Memory", "Size on Disk", - "Executors") - - val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).toArray.sortWith(_._1 < _._1) - val blockLocations = StorageUtils.blockLocationsFromStorageStatus(filteredStorageStatusList) - val blocks = blockStatuses.map { - case(id, status) => (id, status, blockLocations.get(id).getOrElse(Seq("UNKNOWN"))) - } - val blockTable = listingTable(blockHeaders, blockRow, blocks) - - val content = -
    -
    -
      -
-          <li>
-            <strong>Storage Level:</strong>
-            {rddInfo.storageLevel.description}
-          </li>
-          <li>
-            <strong>Cached Partitions:</strong>
-            {rddInfo.numCachedPartitions}
-          </li>
-          <li>
-            <strong>Total Partitions:</strong>
-            {rddInfo.numPartitions}
-          </li>
-          <li>
-            <strong>Memory Size:</strong>
-            {Utils.bytesToString(rddInfo.memSize)}
-          </li>
-          <li>
-            <strong>Disk Size:</strong>
-            {Utils.bytesToString(rddInfo.diskSize)}
-          </li>
    -
    -
    - -
    -
    -

-            <h4> Data Distribution on {workers.size} Executors </h4>
-            {workerTable}
    -
    - -
    -
    -

-            <h4> {blocks.size} Partitions </h4>
-            {blockTable}
    -
    ; - - headerSparkPage(content, parent.sc, "RDD Storage Info for " + rddInfo.name, Storage) - } - - def blockRow(row: (String, BlockStatus, Seq[String])): Seq[Node] = { - val (id, block, locations) = row - - {id} - - {block.storageLevel.description} - - - {Utils.bytesToString(block.memSize)} - - - {Utils.bytesToString(block.diskSize)} - - - {locations.map(l => {l}
    )} - - - } - - def workerRow(worker: (String, StorageStatus)): Seq[Node] = { - val (prefix, status) = worker - - {status.blockManagerId.host + ":" + status.blockManagerId.port} - - {Utils.bytesToString(status.memUsed(prefix))} - ({Utils.bytesToString(status.memRemaining)} Remaining) - - {Utils.bytesToString(status.diskUsed(prefix))} - - } -} diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala deleted file mode 100644 index 9233277bdb..0000000000 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import akka.actor.{ActorSystem, ExtendedActorSystem} -import com.typesafe.config.ConfigFactory -import akka.util.duration._ -import akka.remote.RemoteActorRefProvider - - -/** - * Various utility classes for working with Akka. - */ -private[spark] object AkkaUtils { - - /** - * Creates an ActorSystem ready for remoting, with various Spark features. Returns both the - * ActorSystem itself and its port (which is hard to get from Akka). - * - * Note: the `name` parameter is important, as even if a client sends a message to right - * host + port, if the system name is incorrect, Akka will drop the message. - */ - def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = { - val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt - val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt - val akkaTimeout = System.getProperty("spark.akka.timeout", "60").toInt - val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt - val lifecycleEvents = if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off" - // 10 seconds is the default akka timeout, but in a cluster, we need higher by default. 
- val akkaWriteTimeout = System.getProperty("spark.akka.writeTimeout", "30").toInt - - val akkaConf = ConfigFactory.parseString(""" - akka.daemonic = on - akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"] - akka.stdout-loglevel = "ERROR" - akka.actor.provider = "akka.remote.RemoteActorRefProvider" - akka.remote.transport = "akka.remote.netty.NettyRemoteTransport" - akka.remote.netty.hostname = "%s" - akka.remote.netty.port = %d - akka.remote.netty.connection-timeout = %ds - akka.remote.netty.message-frame-size = %d MiB - akka.remote.netty.execution-pool-size = %d - akka.actor.default-dispatcher.throughput = %d - akka.remote.log-remote-lifecycle-events = %s - akka.remote.netty.write-timeout = %ds - """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize, - lifecycleEvents, akkaWriteTimeout)) - - val actorSystem = ActorSystem(name, akkaConf) - - // Figure out the port number we bound to, in case port was passed as 0. This is a bit of a - // hack because Akka doesn't let you figure out the port through the public API yet. - val provider = actorSystem.asInstanceOf[ExtendedActorSystem].provider - val boundPort = provider.asInstanceOf[RemoteActorRefProvider].transport.address.port.get - return (actorSystem, boundPort) - } -} diff --git a/core/src/main/scala/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala deleted file mode 100644 index 0575497f5d..0000000000 --- a/core/src/main/scala/spark/util/BoundedPriorityQueue.scala +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import java.io.Serializable -import java.util.{PriorityQueue => JPriorityQueue} -import scala.collection.generic.Growable -import scala.collection.JavaConverters._ - -/** - * Bounded priority queue. This class wraps the original PriorityQueue - * class and modifies it such that only the top K elements are retained. - * The top K elements are defined by an implicit Ordering[A]. 
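A minimal usage sketch for the AkkaUtils.createActorSystem helper above; the property values and host are hypothetical, and port 0 lets the OS pick a free port:

    // Tune the Akka transport through system properties before creating the ActorSystem.
    System.setProperty("spark.akka.frameSize", "20")   // frame size in MiB (default "10")
    System.setProperty("spark.akka.timeout", "120")    // connection timeout in seconds (default "60")
    val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", "localhost", 0)
    println("Actor system bound to port " + boundPort)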
- */ -class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A]) - extends Iterable[A] with Growable[A] with Serializable { - - private val underlying = new JPriorityQueue[A](maxSize, ord) - - override def iterator: Iterator[A] = underlying.iterator.asScala - - override def ++=(xs: TraversableOnce[A]): this.type = { - xs.foreach { this += _ } - this - } - - override def +=(elem: A): this.type = { - if (size < maxSize) underlying.offer(elem) - else maybeReplaceLowest(elem) - this - } - - override def +=(elem1: A, elem2: A, elems: A*): this.type = { - this += elem1 += elem2 ++= elems - } - - override def clear() { underlying.clear() } - - private def maybeReplaceLowest(a: A): Boolean = { - val head = underlying.peek() - if (head != null && ord.gt(a, head)) { - underlying.poll() - underlying.offer(a) - } else false - } -} - diff --git a/core/src/main/scala/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/spark/util/ByteBufferInputStream.scala deleted file mode 100644 index 47a28e2f76..0000000000 --- a/core/src/main/scala/spark/util/ByteBufferInputStream.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import java.io.InputStream -import java.nio.ByteBuffer -import spark.storage.BlockManager - -/** - * Reads data from a ByteBuffer, and optionally cleans it up using BlockManager.dispose() - * at the end of the stream (e.g. to close a memory-mapped file). - */ -private[spark] -class ByteBufferInputStream(private var buffer: ByteBuffer, dispose: Boolean = false) - extends InputStream { - - override def read(): Int = { - if (buffer == null || buffer.remaining() == 0) { - cleanUp() - -1 - } else { - buffer.get() & 0xFF - } - } - - override def read(dest: Array[Byte]): Int = { - read(dest, 0, dest.length) - } - - override def read(dest: Array[Byte], offset: Int, length: Int): Int = { - if (buffer == null || buffer.remaining() == 0) { - cleanUp() - -1 - } else { - val amountToGet = math.min(buffer.remaining(), length) - buffer.get(dest, offset, amountToGet) - amountToGet - } - } - - override def skip(bytes: Long): Long = { - if (buffer != null) { - val amountToSkip = math.min(bytes, buffer.remaining).toInt - buffer.position(buffer.position + amountToSkip) - if (buffer.remaining() == 0) { - cleanUp() - } - amountToSkip - } else { - 0L - } - } - - /** - * Clean up the buffer, and potentially dispose of it using BlockManager.dispose(). 
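A small usage sketch for the BoundedPriorityQueue defined above; the values are hypothetical:

    // Keep only the 3 largest integers seen; smaller elements are evicted by maybeReplaceLowest.
    val topK = new BoundedPriorityQueue[Int](3)
    topK ++= Seq(5, 1, 9, 7, 3)
    println(topK.toSeq.sorted)   // List(5, 7, 9)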
- */ - private def cleanUp() { - if (buffer != null) { - if (dispose) { - BlockManager.dispose(buffer) - } - buffer = null - } - } -} diff --git a/core/src/main/scala/spark/util/Clock.scala b/core/src/main/scala/spark/util/Clock.scala deleted file mode 100644 index aa71a5b442..0000000000 --- a/core/src/main/scala/spark/util/Clock.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -/** - * An interface to represent clocks, so that they can be mocked out in unit tests. - */ -private[spark] trait Clock { - def getTime(): Long -} - -private[spark] object SystemClock extends Clock { - def getTime(): Long = System.currentTimeMillis() -} diff --git a/core/src/main/scala/spark/util/CompletionIterator.scala b/core/src/main/scala/spark/util/CompletionIterator.scala deleted file mode 100644 index 210450892b..0000000000 --- a/core/src/main/scala/spark/util/CompletionIterator.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -/** - * Wrapper around an iterator which calls a completion method after it successfully iterates through all the elements - */ -abstract class CompletionIterator[+A, +I <: Iterator[A]](sub: I) extends Iterator[A]{ - def next = sub.next - def hasNext = { - val r = sub.hasNext - if (!r) { - completion - } - r - } - - def completion() -} - -object CompletionIterator { - def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit) : CompletionIterator[A,I] = { - new CompletionIterator[A,I](sub) { - def completion() = completionFunction - } - } -} diff --git a/core/src/main/scala/spark/util/Distribution.scala b/core/src/main/scala/spark/util/Distribution.scala deleted file mode 100644 index 5d4d7a6c50..0000000000 --- a/core/src/main/scala/spark/util/Distribution.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
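A usage sketch for the CompletionIterator above, which runs a callback exactly once after the wrapped iterator is exhausted; the callback body is hypothetical:

    val source = Iterator(1, 2, 3)
    val withCleanup = CompletionIterator[Int, Iterator[Int]](source, println("done reading"))
    withCleanup.foreach(println)   // prints 1, 2, 3 and then "done reading"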
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import java.io.PrintStream - -/** - * Util for getting some stats from a small sample of numeric values, with some handy summary functions. - * - * Entirely in memory, not intended as a good way to compute stats over large data sets. - * - * Assumes you are giving it a non-empty set of data - */ -class Distribution(val data: Array[Double], val startIdx: Int, val endIdx: Int) { - require(startIdx < endIdx) - def this(data: Traversable[Double]) = this(data.toArray, 0, data.size) - java.util.Arrays.sort(data, startIdx, endIdx) - val length = endIdx - startIdx - - val defaultProbabilities = Array(0,0.25,0.5,0.75,1.0) - - /** - * Get the value of the distribution at the given probabilities. Probabilities should be - * given from 0 to 1 - * @param probabilities - */ - def getQuantiles(probabilities: Traversable[Double] = defaultProbabilities) = { - probabilities.toIndexedSeq.map{p:Double => data(closestIndex(p))} - } - - private def closestIndex(p: Double) = { - math.min((p * length).toInt + startIdx, endIdx - 1) - } - - def showQuantiles(out: PrintStream = System.out) = { - out.println("min\t25%\t50%\t75%\tmax") - getQuantiles(defaultProbabilities).foreach{q => out.print(q + "\t")} - out.println - } - - def statCounter = StatCounter(data.slice(startIdx, endIdx)) - - /** - * print a summary of this distribution to the given PrintStream. - * @param out - */ - def summary(out: PrintStream = System.out) { - out.println(statCounter) - showQuantiles(out) - } -} - -object Distribution { - - def apply(data: Traversable[Double]): Option[Distribution] = { - if (data.size > 0) - Some(new Distribution(data)) - else - None - } - - def showQuantiles(out: PrintStream = System.out, quantiles: Traversable[Double]) { - out.println("min\t25%\t50%\t75%\tmax") - quantiles.foreach{q => out.print(q + "\t")} - out.println - } -} diff --git a/core/src/main/scala/spark/util/IdGenerator.scala b/core/src/main/scala/spark/util/IdGenerator.scala deleted file mode 100644 index 3422280559..0000000000 --- a/core/src/main/scala/spark/util/IdGenerator.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
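A brief sketch of the Distribution helper above on a hypothetical five-value sample:

    val d = new Distribution(Seq(1.0, 2.0, 3.0, 4.0, 5.0))
    // The default probabilities are 0, 0.25, 0.5, 0.75 and 1.0.
    println(d.getQuantiles().mkString("\t"))   // 1.0  2.0  3.0  4.0  5.0
    d.summary()   // prints the StatCounter line followed by min/25%/50%/75%/max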
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import java.util.concurrent.atomic.AtomicInteger - -/** - * A util used to get a unique generation ID. This is a wrapper around Java's - * AtomicInteger. An example usage is in BlockManager, where each BlockManager - * instance would start an Akka actor and we use this utility to assign the Akka - * actors unique names. - */ -private[spark] class IdGenerator { - private var id = new AtomicInteger - def next: Int = id.incrementAndGet -} diff --git a/core/src/main/scala/spark/util/IntParam.scala b/core/src/main/scala/spark/util/IntParam.scala deleted file mode 100644 index daf0d58fa2..0000000000 --- a/core/src/main/scala/spark/util/IntParam.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -/** - * An extractor object for parsing strings into integers. - */ -private[spark] object IntParam { - def unapply(str: String): Option[Int] = { - try { - Some(str.toInt) - } catch { - case e: NumberFormatException => None - } - } -} diff --git a/core/src/main/scala/spark/util/MemoryParam.scala b/core/src/main/scala/spark/util/MemoryParam.scala deleted file mode 100644 index 298562323a..0000000000 --- a/core/src/main/scala/spark/util/MemoryParam.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import spark.Utils - -/** - * An extractor object for parsing JVM memory strings, such as "10g", into an Int representing - * the number of megabytes. Supports the same formats as Utils.memoryStringToMb. 
- */ -private[spark] object MemoryParam { - def unapply(str: String): Option[Int] = { - try { - Some(Utils.memoryStringToMb(str)) - } catch { - case e: NumberFormatException => None - } - } -} diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala deleted file mode 100644 index 92909e0959..0000000000 --- a/core/src/main/scala/spark/util/MetadataCleaner.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors} -import java.util.{TimerTask, Timer} -import spark.Logging - - -/** - * Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries) - */ -class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { - private val delaySeconds = MetadataCleaner.getDelaySeconds - private val periodSeconds = math.max(10, delaySeconds / 10) - private val timer = new Timer(name + " cleanup timer", true) - - private val task = new TimerTask { - override def run() { - try { - cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) - logInfo("Ran metadata cleaner for " + name) - } catch { - case e: Exception => logError("Error running cleanup task for " + name, e) - } - } - } - - if (delaySeconds > 0) { - logDebug( - "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds " + - "and period of " + periodSeconds + " secs") - timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) - } - - def cancel() { - timer.cancel() - } -} - - -object MetadataCleaner { - def getDelaySeconds = System.getProperty("spark.cleaner.ttl", "-1").toInt - def setDelaySeconds(delay: Int) { System.setProperty("spark.cleaner.ttl", delay.toString) } -} - diff --git a/core/src/main/scala/spark/util/MutablePair.scala b/core/src/main/scala/spark/util/MutablePair.scala deleted file mode 100644 index 78d404e66b..0000000000 --- a/core/src/main/scala/spark/util/MutablePair.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
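The IntParam and MemoryParam extractors above are meant for command-line argument parsing; a hypothetical match showing both:

    def describe(arg: String): String = arg match {
      case IntParam(n)     => n + " (plain integer)"
      case MemoryParam(mb) => mb + " MB"
      case other           => "unrecognized: " + other
    }
    describe("8")     // "8 (plain integer)"
    describe("10g")   // "10240 MB", converted by Utils.memoryStringToMb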
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - - -/** - * A tuple of 2 elements. This can be used as an alternative to Scala's Tuple2 when we want to - * minimize object allocation. - * - * @param _1 Element 1 of this MutablePair - * @param _2 Element 2 of this MutablePair - */ -case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T1, - @specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T2] - (var _1: T1, var _2: T2) - extends Product2[T1, T2] -{ - override def toString = "(" + _1 + "," + _2 + ")" - - override def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[_,_]] -} diff --git a/core/src/main/scala/spark/util/NextIterator.scala b/core/src/main/scala/spark/util/NextIterator.scala deleted file mode 100644 index 22163ece8d..0000000000 --- a/core/src/main/scala/spark/util/NextIterator.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -/** Provides a basic/boilerplate Iterator implementation. */ -private[spark] abstract class NextIterator[U] extends Iterator[U] { - - private var gotNext = false - private var nextValue: U = _ - private var closed = false - protected var finished = false - - /** - * Method for subclasses to implement to provide the next element. - * - * If no next element is available, the subclass should set `finished` - * to `true` and may return any value (it will be ignored). - * - * This convention is required because `null` may be a valid value, - * and using `Option` seems like it might create unnecessary Some/None - * instances, given some iterators might be called in a tight loop. - * - * @return U, or set 'finished' when done - */ - protected def getNext(): U - - /** - * Method for subclasses to implement when all elements have been successfully - * iterated, and the iteration is done. - * - * Note: `NextIterator` cannot guarantee that `close` will be - * called because it has no control over what happens when an exception - * happens in the user code that is calling hasNext/next. - * - * Ideally you should have another try/catch, as in HadoopRDD, that - * ensures any resources are closed should iteration fail. - */ - protected def close() - - /** - * Calls the subclass-defined close method, but only once. - * - * Usually calling `close` multiple times should be fine, but historically - * there have been issues with some InputFormats throwing exceptions. 
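A hypothetical NextIterator subclass, showing how the `getNext`/`close`/`finished` contract described above fits together:

    class LineIterator(reader: java.io.BufferedReader) extends NextIterator[String] {
      override protected def getNext(): String = {
        val line = reader.readLine()
        if (line == null) finished = true   // signal end of stream; the returned value is then ignored
        line
      }
      override protected def close() { reader.close() }
    }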
- */ - def closeIfNeeded() { - if (!closed) { - close() - closed = true - } - } - - override def hasNext: Boolean = { - if (!finished) { - if (!gotNext) { - nextValue = getNext() - if (finished) { - closeIfNeeded() - } - gotNext = true - } - } - !finished - } - - override def next(): U = { - if (!hasNext) { - throw new NoSuchElementException("End of stream") - } - gotNext = false - nextValue - } -} diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala deleted file mode 100644 index 00f782bbe7..0000000000 --- a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import scala.annotation.tailrec - -import java.io.OutputStream -import java.util.concurrent.TimeUnit._ - -class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream { - val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS) - val CHUNK_SIZE = 8192 - var lastSyncTime = System.nanoTime - var bytesWrittenSinceSync: Long = 0 - - override def write(b: Int) { - waitToWrite(1) - out.write(b) - } - - override def write(bytes: Array[Byte]) { - write(bytes, 0, bytes.length) - } - - @tailrec - override final def write(bytes: Array[Byte], offset: Int, length: Int) { - val writeSize = math.min(length - offset, CHUNK_SIZE) - if (writeSize > 0) { - waitToWrite(writeSize) - out.write(bytes, offset, writeSize) - write(bytes, offset + writeSize, length) - } - } - - override def flush() { - out.flush() - } - - override def close() { - out.close() - } - - @tailrec - private def waitToWrite(numBytes: Int) { - val now = System.nanoTime - val elapsedSecs = SECONDS.convert(math.max(now - lastSyncTime, 1), NANOSECONDS) - val rate = bytesWrittenSinceSync.toDouble / elapsedSecs - if (rate < bytesPerSec) { - // It's okay to write; just update some variables and return - bytesWrittenSinceSync += numBytes - if (now > lastSyncTime + SYNC_INTERVAL) { - // Sync interval has passed; let's resync - lastSyncTime = now - bytesWrittenSinceSync = numBytes - } - } else { - // Calculate how much time we should sleep to bring ourselves to the desired rate. 
- // Based on throttler in Kafka (https://github.com/kafka-dev/kafka/blob/master/core/src/main/scala/kafka/utils/Throttler.scala) - val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS) - if (sleepTime > 0) Thread.sleep(sleepTime) - waitToWrite(numBytes) - } - } -} diff --git a/core/src/main/scala/spark/util/SerializableBuffer.scala b/core/src/main/scala/spark/util/SerializableBuffer.scala deleted file mode 100644 index 7e6842628a..0000000000 --- a/core/src/main/scala/spark/util/SerializableBuffer.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import java.nio.ByteBuffer -import java.io.{IOException, ObjectOutputStream, EOFException, ObjectInputStream} -import java.nio.channels.Channels - -/** - * A wrapper around a java.nio.ByteBuffer that is serializable through Java serialization, to make - * it easier to pass ByteBuffers in case class messages. - */ -private[spark] -class SerializableBuffer(@transient var buffer: ByteBuffer) extends Serializable { - def value = buffer - - private def readObject(in: ObjectInputStream) { - val length = in.readInt() - buffer = ByteBuffer.allocate(length) - var amountRead = 0 - val channel = Channels.newChannel(in) - while (amountRead < length) { - val ret = channel.read(buffer) - if (ret == -1) { - throw new EOFException("End of file before fully reading buffer") - } - amountRead += ret - } - buffer.rewind() // Allow us to read it later - } - - private def writeObject(out: ObjectOutputStream) { - out.writeInt(buffer.limit()) - if (Channels.newChannel(out).write(buffer) != buffer.limit()) { - throw new IOException("Could not fully write buffer to output stream") - } - buffer.rewind() // Allow us to write it again later - } -} diff --git a/core/src/main/scala/spark/util/StatCounter.scala b/core/src/main/scala/spark/util/StatCounter.scala deleted file mode 100644 index 76358d4151..0000000000 --- a/core/src/main/scala/spark/util/StatCounter.scala +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
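A usage sketch for the RateLimitedOutputStream above; the file path and rate are hypothetical:

    // Throttle writes to roughly 1 MB/s, so a 4 MB write should take about four seconds.
    val out = new RateLimitedOutputStream(new java.io.FileOutputStream("/tmp/throttled.bin"), 1024 * 1024)
    try {
      out.write(new Array[Byte](4 * 1024 * 1024))
    } finally {
      out.close()
    }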
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -/** - * A class for tracking the statistics of a set of numbers (count, mean and variance) in a - * numerically robust way. Includes support for merging two StatCounters. Based on - * [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance Welford and Chan's algorithms for running variance]]. - * - * @constructor Initialize the StatCounter with the given values. - */ -class StatCounter(values: TraversableOnce[Double]) extends Serializable { - private var n: Long = 0 // Running count of our values - private var mu: Double = 0 // Running mean of our values - private var m2: Double = 0 // Running variance numerator (sum of (x - mean)^2) - - merge(values) - - /** Initialize the StatCounter with no values. */ - def this() = this(Nil) - - /** Add a value into this StatCounter, updating the internal statistics. */ - def merge(value: Double): StatCounter = { - val delta = value - mu - n += 1 - mu += delta / n - m2 += delta * (value - mu) - this - } - - /** Add multiple values into this StatCounter, updating the internal statistics. */ - def merge(values: TraversableOnce[Double]): StatCounter = { - values.foreach(v => merge(v)) - this - } - - /** Merge another StatCounter into this one, adding up the internal statistics. */ - def merge(other: StatCounter): StatCounter = { - if (other == this) { - merge(other.copy()) // Avoid overwriting fields in a weird order - } else { - if (n == 0) { - mu = other.mu - m2 = other.m2 - n = other.n - } else if (other.n != 0) { - val delta = other.mu - mu - if (other.n * 10 < n) { - mu = mu + (delta * other.n) / (n + other.n) - } else if (n * 10 < other.n) { - mu = other.mu - (delta * n) / (n + other.n) - } else { - mu = (mu * n + other.mu * other.n) / (n + other.n) - } - m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n) - n += other.n - } - this - } - } - - /** Clone this StatCounter */ - def copy(): StatCounter = { - val other = new StatCounter - other.n = n - other.mu = mu - other.m2 = m2 - other - } - - def count: Long = n - - def mean: Double = mu - - def sum: Double = n * mu - - /** Return the variance of the values. */ - def variance: Double = { - if (n == 0) - Double.NaN - else - m2 / n - } - - /** - * Return the sample variance, which corrects for bias in estimating the variance by dividing - * by N-1 instead of N. - */ - def sampleVariance: Double = { - if (n <= 1) - Double.NaN - else - m2 / (n - 1) - } - - /** Return the standard deviation of the values. */ - def stdev: Double = math.sqrt(variance) - - /** - * Return the sample standard deviation of the values, which corrects for bias in estimating the - * variance by dividing by N-1 instead of N. - */ - def sampleStdev: Double = math.sqrt(sampleVariance) - - override def toString: String = { - "(count: %d, mean: %f, stdev: %f)".format(count, mean, stdev) - } -} - -object StatCounter { - /** Build a StatCounter from a list of values. */ - def apply(values: TraversableOnce[Double]) = new StatCounter(values) - - /** Build a StatCounter from a list of values passed as variable-length arguments. 
*/ - def apply(values: Double*) = new StatCounter(values) -} diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala deleted file mode 100644 index 07772a0afb..0000000000 --- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -import java.util.concurrent.ConcurrentHashMap -import scala.collection.JavaConversions -import scala.collection.mutable.Map -import scala.collection.immutable -import spark.scheduler.MapStatus - -/** - * This is a custom implementation of scala.collection.mutable.Map which stores the insertion - * time stamp along with each key-value pair. Key-value pairs that are older than a particular - * threshold time can them be removed using the clearOldValues method. This is intended to be a drop-in - * replacement of scala.collection.mutable.HashMap. - */ -class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging { - val internalMap = new ConcurrentHashMap[A, (B, Long)]() - - def get(key: A): Option[B] = { - val value = internalMap.get(key) - if (value != null) Some(value._1) else None - } - - def iterator: Iterator[(A, B)] = { - val jIterator = internalMap.entrySet().iterator() - JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue._1)) - } - - override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = { - val newMap = new TimeStampedHashMap[A, B1] - newMap.internalMap.putAll(this.internalMap) - newMap.internalMap.put(kv._1, (kv._2, currentTime)) - newMap - } - - override def - (key: A): Map[A, B] = { - val newMap = new TimeStampedHashMap[A, B] - newMap.internalMap.putAll(this.internalMap) - newMap.internalMap.remove(key) - newMap - } - - override def += (kv: (A, B)): this.type = { - internalMap.put(kv._1, (kv._2, currentTime)) - this - } - - // Should we return previous value directly or as Option ? 
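A quick check of the StatCounter merge logic above, using the varargs factory; the numbers are hypothetical:

    val a = StatCounter(1.0, 2.0, 3.0)
    val b = StatCounter(4.0, 5.0)
    a.merge(b)            // merges b's statistics into a
    println(a.count)      // 5
    println(a.mean)       // 3.0
    println(a.variance)   // 2.0, the population variance of 1.0 through 5.0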
- def putIfAbsent(key: A, value: B): Option[B] = { - val prev = internalMap.putIfAbsent(key, (value, currentTime)) - if (prev != null) Some(prev._1) else None - } - - - override def -= (key: A): this.type = { - internalMap.remove(key) - this - } - - override def update(key: A, value: B) { - this += ((key, value)) - } - - override def apply(key: A): B = { - val value = internalMap.get(key) - if (value == null) throw new NoSuchElementException() - value._1 - } - - override def filter(p: ((A, B)) => Boolean): Map[A, B] = { - JavaConversions.asScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p) - } - - override def empty: Map[A, B] = new TimeStampedHashMap[A, B]() - - override def size: Int = internalMap.size - - override def foreach[U](f: ((A, B)) => U) { - val iterator = internalMap.entrySet().iterator() - while(iterator.hasNext) { - val entry = iterator.next() - val kv = (entry.getKey, entry.getValue._1) - f(kv) - } - } - - def toMap: immutable.Map[A, B] = iterator.toMap - - /** - * Removes old key-value pairs that have timestamp earlier than `threshTime` - */ - def clearOldValues(threshTime: Long) { - val iterator = internalMap.entrySet().iterator() - while(iterator.hasNext) { - val entry = iterator.next() - if (entry.getValue._2 < threshTime) { - logDebug("Removing key " + entry.getKey) - iterator.remove() - } - } - } - - private def currentTime: Long = System.currentTimeMillis() - -} diff --git a/core/src/main/scala/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/spark/util/TimeStampedHashSet.scala deleted file mode 100644 index 41e3fd8cba..0000000000 --- a/core/src/main/scala/spark/util/TimeStampedHashSet.scala +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
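A sketch of the clearOldValues behaviour of the TimeStampedHashMap above; the keys and the sleep are hypothetical:

    val cache = new TimeStampedHashMap[String, Int]
    cache("old") = 1
    Thread.sleep(10)
    val cutoff = System.currentTimeMillis()   // entries inserted before this moment are stale
    cache("fresh") = 2
    cache.clearOldValues(cutoff)              // drops "old", keeps "fresh"
    println(cache.toMap)                      // Map(fresh -> 2)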
- */ - -package spark.util - -import scala.collection.mutable.Set -import scala.collection.JavaConversions -import java.util.concurrent.ConcurrentHashMap - - -class TimeStampedHashSet[A] extends Set[A] { - val internalMap = new ConcurrentHashMap[A, Long]() - - def contains(key: A): Boolean = { - internalMap.contains(key) - } - - def iterator: Iterator[A] = { - val jIterator = internalMap.entrySet().iterator() - JavaConversions.asScalaIterator(jIterator).map(_.getKey) - } - - override def + (elem: A): Set[A] = { - val newSet = new TimeStampedHashSet[A] - newSet ++= this - newSet += elem - newSet - } - - override def - (elem: A): Set[A] = { - val newSet = new TimeStampedHashSet[A] - newSet ++= this - newSet -= elem - newSet - } - - override def += (key: A): this.type = { - internalMap.put(key, currentTime) - this - } - - override def -= (key: A): this.type = { - internalMap.remove(key) - this - } - - override def empty: Set[A] = new TimeStampedHashSet[A]() - - override def size(): Int = internalMap.size() - - override def foreach[U](f: (A) => U): Unit = { - val iterator = internalMap.entrySet().iterator() - while(iterator.hasNext) { - f(iterator.next.getKey) - } - } - - /** - * Removes old values that have timestamp earlier than `threshTime` - */ - def clearOldValues(threshTime: Long) { - val iterator = internalMap.entrySet().iterator() - while(iterator.hasNext) { - val entry = iterator.next() - if (entry.getValue < threshTime) { - iterator.remove() - } - } - } - - private def currentTime: Long = System.currentTimeMillis() -} diff --git a/core/src/main/scala/spark/util/Vector.scala b/core/src/main/scala/spark/util/Vector.scala deleted file mode 100644 index a47cac3b96..0000000000 --- a/core/src/main/scala/spark/util/Vector.scala +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.util - -class Vector(val elements: Array[Double]) extends Serializable { - def length = elements.length - - def apply(index: Int) = elements(index) - - def + (other: Vector): Vector = { - if (length != other.length) - throw new IllegalArgumentException("Vectors of different length") - return Vector(length, i => this(i) + other(i)) - } - - def add(other: Vector) = this + other - - def - (other: Vector): Vector = { - if (length != other.length) - throw new IllegalArgumentException("Vectors of different length") - return Vector(length, i => this(i) - other(i)) - } - - def subtract(other: Vector) = this - other - - def dot(other: Vector): Double = { - if (length != other.length) - throw new IllegalArgumentException("Vectors of different length") - var ans = 0.0 - var i = 0 - while (i < length) { - ans += this(i) * other(i) - i += 1 - } - return ans - } - - /** - * return (this + plus) dot other, but without creating any intermediate storage - * @param plus - * @param other - * @return - */ - def plusDot(plus: Vector, other: Vector): Double = { - if (length != other.length) - throw new IllegalArgumentException("Vectors of different length") - if (length != plus.length) - throw new IllegalArgumentException("Vectors of different length") - var ans = 0.0 - var i = 0 - while (i < length) { - ans += (this(i) + plus(i)) * other(i) - i += 1 - } - return ans - } - - def += (other: Vector): Vector = { - if (length != other.length) - throw new IllegalArgumentException("Vectors of different length") - var i = 0 - while (i < length) { - elements(i) += other(i) - i += 1 - } - this - } - - def addInPlace(other: Vector) = this +=other - - def * (scale: Double): Vector = Vector(length, i => this(i) * scale) - - def multiply (d: Double) = this * d - - def / (d: Double): Vector = this * (1 / d) - - def divide (d: Double) = this / d - - def unary_- = this * -1 - - def sum = elements.reduceLeft(_ + _) - - def squaredDist(other: Vector): Double = { - var ans = 0.0 - var i = 0 - while (i < length) { - ans += (this(i) - other(i)) * (this(i) - other(i)) - i += 1 - } - return ans - } - - def dist(other: Vector): Double = math.sqrt(squaredDist(other)) - - override def toString = elements.mkString("(", ", ", ")") -} - -object Vector { - def apply(elements: Array[Double]) = new Vector(elements) - - def apply(elements: Double*) = new Vector(elements.toArray) - - def apply(length: Int, initializer: Int => Double): Vector = { - val elements: Array[Double] = Array.tabulate(length)(initializer) - return new Vector(elements) - } - - def zeros(length: Int) = new Vector(new Array[Double](length)) - - def ones(length: Int) = Vector(length, _ => 1) - - class Multiplier(num: Double) { - def * (vec: Vector) = vec * num - } - - implicit def doubleToMultiplier(num: Double) = new Multiplier(num) - - implicit object VectorAccumParam extends spark.AccumulatorParam[Vector] { - def addInPlace(t1: Vector, t2: Vector) = t1 + t2 - - def zero(initialValue: Vector) = Vector.zeros(initialValue.length) - } - -} diff --git a/core/src/test/resources/test_metrics_config.properties b/core/src/test/resources/test_metrics_config.properties index 2b31ddf2eb..056a158456 100644 --- a/core/src/test/resources/test_metrics_config.properties +++ b/core/src/test/resources/test_metrics_config.properties @@ -1,6 +1,6 @@ *.sink.console.period = 10 *.sink.console.unit = seconds -*.source.jvm.class = spark.metrics.source.JvmSource +*.source.jvm.class = org.apache.spark.metrics.source.JvmSource master.sink.console.period = 20 
master.sink.console.unit = minutes diff --git a/core/src/test/resources/test_metrics_system.properties b/core/src/test/resources/test_metrics_system.properties index d5479f0298..6f5ecea93a 100644 --- a/core/src/test/resources/test_metrics_system.properties +++ b/core/src/test/resources/test_metrics_system.properties @@ -1,7 +1,7 @@ *.sink.console.period = 10 *.sink.console.unit = seconds -test.sink.console.class = spark.metrics.sink.ConsoleSink -test.sink.dummy.class = spark.metrics.sink.DummySink -test.source.dummy.class = spark.metrics.source.DummySource +test.sink.console.class = org.apache.spark.metrics.sink.ConsoleSink +test.sink.dummy.class = org.apache.spark.metrics.sink.DummySink +test.source.dummy.class = org.apache.spark.metrics.source.DummySource test.sink.console.period = 20 test.sink.console.unit = minutes diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala new file mode 100644 index 0000000000..4434f3b87c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import collection.mutable +import java.util.Random +import scala.math.exp +import scala.math.signum +import org.apache.spark.SparkContext._ + +class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkContext { + + test ("basic accumulation"){ + sc = new SparkContext("local", "test") + val acc : Accumulator[Int] = sc.accumulator(0) + + val d = sc.parallelize(1 to 20) + d.foreach{x => acc += x} + acc.value should be (210) + + + val longAcc = sc.accumulator(0l) + val maxInt = Integer.MAX_VALUE.toLong + d.foreach{x => longAcc += maxInt + x} + longAcc.value should be (210l + maxInt * 20) + } + + test ("value not assignable from tasks") { + sc = new SparkContext("local", "test") + val acc : Accumulator[Int] = sc.accumulator(0) + + val d = sc.parallelize(1 to 20) + evaluating {d.foreach{x => acc.value = x}} should produce [Exception] + } + + test ("add value to collection accumulators") { + import SetAccum._ + val maxI = 1000 + for (nThreads <- List(1, 10)) { //test single & multi-threaded + sc = new SparkContext("local[" + nThreads + "]", "test") + val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) + val d = sc.parallelize(1 to maxI) + d.foreach { + x => acc += x + } + val v = acc.value.asInstanceOf[mutable.Set[Int]] + for (i <- 1 to maxI) { + v should contain(i) + } + resetSparkContext() + } + } + + implicit object SetAccum extends AccumulableParam[mutable.Set[Any], Any] { + def addInPlace(t1: mutable.Set[Any], t2: mutable.Set[Any]) : mutable.Set[Any] = { + t1 ++= t2 + t1 + } + def addAccumulator(t1: mutable.Set[Any], t2: Any) : mutable.Set[Any] = { + t1 += t2 + t1 + } + def zero(t: mutable.Set[Any]) : mutable.Set[Any] = { + new mutable.HashSet[Any]() + } + } + + test ("value not readable in tasks") { + import SetAccum._ + val maxI = 1000 + for (nThreads <- List(1, 10)) { //test single & multi-threaded + sc = new SparkContext("local[" + nThreads + "]", "test") + val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) + val d = sc.parallelize(1 to maxI) + evaluating { + d.foreach { + x => acc.value += x + } + } should produce [SparkException] + resetSparkContext() + } + } + + test ("collection accumulators") { + val maxI = 1000 + for (nThreads <- List(1, 10)) { + // test single & multi-threaded + sc = new SparkContext("local[" + nThreads + "]", "test") + val setAcc = sc.accumulableCollection(mutable.HashSet[Int]()) + val bufferAcc = sc.accumulableCollection(mutable.ArrayBuffer[Int]()) + val mapAcc = sc.accumulableCollection(mutable.HashMap[Int,String]()) + val d = sc.parallelize((1 to maxI) ++ (1 to maxI)) + d.foreach { + x => {setAcc += x; bufferAcc += x; mapAcc += (x -> x.toString)} + } + + // Note that this is typed correctly -- no casts necessary + setAcc.value.size should be (maxI) + bufferAcc.value.size should be (2 * maxI) + mapAcc.value.size should be (maxI) + for (i <- 1 to maxI) { + setAcc.value should contain(i) + bufferAcc.value should contain(i) + mapAcc.value should contain (i -> i.toString) + } + resetSparkContext() + } + } + + test ("localValue readable in tasks") { + import SetAccum._ + val maxI = 1000 + for (nThreads <- List(1, 10)) { //test single & multi-threaded + sc = new SparkContext("local[" + nThreads + "]", "test") + val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) + val groupedInts = (1 to (maxI/20)).map {x => (20 * (x - 1) to 20 * 
x).toSet} + val d = sc.parallelize(groupedInts) + d.foreach { + x => acc.localValue ++= x + } + acc.value should be ( (0 to maxI).toSet) + resetSparkContext() + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala new file mode 100644 index 0000000000..b3a53d928b --- /dev/null +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite + +class BroadcastSuite extends FunSuite with LocalSparkContext { + + test("basic broadcast") { + sc = new SparkContext("local", "test") + val list = List(1, 2, 3, 4) + val listBroadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum)) + assert(results.collect.toSet === Set((1, 10), (2, 10))) + } + + test("broadcast variables accessed in multiple threads") { + sc = new SparkContext("local[10]", "test") + val list = List(1, 2, 3, 4) + val listBroadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) + assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) + } +} diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala new file mode 100644 index 0000000000..23b14f4245 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -0,0 +1,392 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import org.scalatest.FunSuite +import java.io.File +import org.apache.spark.rdd._ +import org.apache.spark.SparkContext._ +import storage.StorageLevel + +class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { + initLogging() + + var checkpointDir: File = _ + val partitioner = new HashPartitioner(2) + + override def beforeEach() { + super.beforeEach() + checkpointDir = File.createTempFile("temp", "") + checkpointDir.delete() + sc = new SparkContext("local", "test") + sc.setCheckpointDir(checkpointDir.toString) + } + + override def afterEach() { + super.afterEach() + if (checkpointDir != null) { + checkpointDir.delete() + } + } + + test("basic checkpointing") { + val parCollection = sc.makeRDD(1 to 4) + val flatMappedRDD = parCollection.flatMap(x => 1 to x) + flatMappedRDD.checkpoint() + assert(flatMappedRDD.dependencies.head.rdd == parCollection) + val result = flatMappedRDD.collect() + assert(flatMappedRDD.dependencies.head.rdd != parCollection) + assert(flatMappedRDD.collect() === result) + } + + test("RDDs with one-to-one dependencies") { + testCheckpointing(_.map(x => x.toString)) + testCheckpointing(_.flatMap(x => 1 to x)) + testCheckpointing(_.filter(_ % 2 == 0)) + testCheckpointing(_.sample(false, 0.5, 0)) + testCheckpointing(_.glom()) + testCheckpointing(_.mapPartitions(_.map(_.toString))) + testCheckpointing(r => new MapPartitionsWithIndexRDD(r, + (i: Int, iter: Iterator[Int]) => iter.map(_.toString), false )) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) + testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) + testCheckpointing(_.pipe(Seq("cat"))) + } + + test("ParallelCollection") { + val parCollection = sc.makeRDD(1 to 4, 2) + val numPartitions = parCollection.partitions.size + parCollection.checkpoint() + assert(parCollection.dependencies === Nil) + val result = parCollection.collect() + assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result) + assert(parCollection.dependencies != Nil) + assert(parCollection.partitions.length === numPartitions) + assert(parCollection.partitions.toList === parCollection.checkpointData.get.getPartitions.toList) + assert(parCollection.collect() === result) + } + + test("BlockRDD") { + val blockId = "id" + val blockManager = SparkEnv.get.blockManager + blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY) + val blockRDD = new BlockRDD[String](sc, Array(blockId)) + val numPartitions = blockRDD.partitions.size + blockRDD.checkpoint() + val result = blockRDD.collect() + assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result) + assert(blockRDD.dependencies != Nil) + assert(blockRDD.partitions.length === numPartitions) + assert(blockRDD.partitions.toList === blockRDD.checkpointData.get.getPartitions.toList) + assert(blockRDD.collect() === result) + } + + test("ShuffledRDD") { + testCheckpointing(rdd => { + // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD + new ShuffledRDD[Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner) + }) + } + + test("UnionRDD") { + def otherRDD = sc.makeRDD(1 to 10, 1) + + // Test whether the size of UnionRDDPartitions reduce in size after parent RDD is checkpointed. + // Current implementation of UnionRDD has transient reference to parent RDDs, + // so only the partitions will reduce in serialized size, not the RDD. 
+ testCheckpointing(_.union(otherRDD), false, true) + testParentCheckpointing(_.union(otherRDD), false, true) + } + + test("CartesianRDD") { + def otherRDD = sc.makeRDD(1 to 10, 1) + testCheckpointing(new CartesianRDD(sc, _, otherRDD)) + + // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed + // Current implementation of CoalescedRDDPartition has transient reference to parent RDD, + // so only the RDD will reduce in serialized size, not the partitions. + testParentCheckpointing(new CartesianRDD(sc, _, otherRDD), true, false) + + // Test that the CartesianRDD updates parent partitions (CartesianRDD.s1/s2) after + // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions. + // Note that this test is very specific to the current implementation of CartesianRDD. + val ones = sc.makeRDD(1 to 100, 10).map(x => x) + ones.checkpoint() // checkpoint that MappedRDD + val cartesian = new CartesianRDD(sc, ones, ones) + val splitBeforeCheckpoint = + serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition]) + cartesian.count() // do the checkpointing + val splitAfterCheckpoint = + serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition]) + assert( + (splitAfterCheckpoint.s1 != splitBeforeCheckpoint.s1) && + (splitAfterCheckpoint.s2 != splitBeforeCheckpoint.s2), + "CartesianRDD.parents not updated after parent RDD checkpointed" + ) + } + + test("CoalescedRDD") { + testCheckpointing(_.coalesce(2)) + + // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed + // Current implementation of CoalescedRDDPartition has transient reference to parent RDD, + // so only the RDD will reduce in serialized size, not the partitions. + testParentCheckpointing(_.coalesce(2), true, false) + + // Test that the CoalescedRDDPartition updates parent partitions (CoalescedRDDPartition.parents) after + // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions. + // Note that this test is very specific to the current implementation of CoalescedRDDPartitions + val ones = sc.makeRDD(1 to 100, 10).map(x => x) + ones.checkpoint() // checkpoint that MappedRDD + val coalesced = new CoalescedRDD(ones, 2) + val splitBeforeCheckpoint = + serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition]) + coalesced.count() // do the checkpointing + val splitAfterCheckpoint = + serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition]) + assert( + splitAfterCheckpoint.parents.head != splitBeforeCheckpoint.parents.head, + "CoalescedRDDPartition.parents not updated after parent RDD checkpointed" + ) + } + + test("CoGroupedRDD") { + val longLineageRDD1 = generateLongLineageRDDForCoGroupedRDD() + testCheckpointing(rdd => { + CheckpointSuite.cogroup(longLineageRDD1, rdd.map(x => (x % 2, 1)), partitioner) + }, false, true) + + val longLineageRDD2 = generateLongLineageRDDForCoGroupedRDD() + testParentCheckpointing(rdd => { + CheckpointSuite.cogroup( + longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner) + }, false, true) + } + + test("ZippedRDD") { + testCheckpointing( + rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false) + + // Test whether size of ZippedRDD reduce in size after parent RDD is checkpointed + // Current implementation of ZippedRDDPartitions has transient references to parent RDDs, + // so only the RDD will reduce in serialized size, not the partitions. 
+ testParentCheckpointing( + rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false) + } + + test("CheckpointRDD with zero partitions") { + val rdd = new BlockRDD[Int](sc, Array[String]()) + assert(rdd.partitions.size === 0) + assert(rdd.isCheckpointed === false) + rdd.checkpoint() + assert(rdd.count() === 0) + assert(rdd.isCheckpointed === true) + assert(rdd.partitions.size === 0) + } + + /** + * Test checkpointing of the final RDD generated by the given operation. By default, + * this method tests whether the size of serialized RDD has reduced after checkpointing or not. + * It can also test whether the size of serialized RDD partitions has reduced after checkpointing or + * not, but this is not done by default as usually the partitions do not refer to any RDD and + * therefore never store the lineage. + */ + def testCheckpointing[U: ClassManifest]( + op: (RDD[Int]) => RDD[U], + testRDDSize: Boolean = true, + testRDDPartitionSize: Boolean = false + ) { + // Generate the final RDD using given RDD operation + val baseRDD = generateLongLineageRDD() + val operatedRDD = op(baseRDD) + val parentRDD = operatedRDD.dependencies.headOption.orNull + val rddType = operatedRDD.getClass.getSimpleName + val numPartitions = operatedRDD.partitions.length + + // Find serialized sizes before and after the checkpoint + val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) + operatedRDD.checkpoint() + val result = operatedRDD.collect() + val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + + // Test whether the checkpoint file has been created + assert(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get).collect() === result) + + // Test whether dependencies have been changed from its earlier parent RDD + assert(operatedRDD.dependencies.head.rdd != parentRDD) + + // Test whether the partitions have been changed to the new Hadoop partitions + assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList) + + // Test whether the number of partitions is same as before + assert(operatedRDD.partitions.length === numPartitions) + + // Test whether the data in the checkpointed RDD is same as original + assert(operatedRDD.collect() === result) + + // Test whether serialized size of the RDD has reduced. If the RDD + // does not have any dependency to another RDD (e.g., ParallelCollection, + // ShuffleRDD with ShuffleDependency), it may not reduce in size after checkpointing. + if (testRDDSize) { + logInfo("Size of " + rddType + + "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") + assert( + rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, + "Size of " + rddType + " did not reduce after checkpointing " + + "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" + ) + } + + // Test whether serialized size of the partitions has reduced. If the partitions + // do not have any non-transient reference to another RDD or another RDD's partitions, it + // does not refer to a lineage and therefore may not reduce in size after checkpointing. + // However, if the original partitions before checkpointing do refer to a parent RDD, the partitions + // must be forgotten after checkpointing (to remove all reference to parent RDDs) and + // replaced with the HadoopPartitions of the checkpointed RDD.
+ if (testRDDPartitionSize) { + logInfo("Size of " + rddType + " partitions " + + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]") + assert( + splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint, + "Size of " + rddType + " partitions did not reduce after checkpointing " + + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]" + ) + } + } + + /** + * Test whether checkpointing of the parent of the generated RDD also + * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent + * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed, + * this RDD will remember the partitions and therefore potentially the whole lineage. + */ + def testParentCheckpointing[U: ClassManifest]( + op: (RDD[Int]) => RDD[U], + testRDDSize: Boolean, + testRDDPartitionSize: Boolean + ) { + // Generate the final RDD using given RDD operation + val baseRDD = generateLongLineageRDD() + val operatedRDD = op(baseRDD) + val parentRDD = operatedRDD.dependencies.head.rdd + val rddType = operatedRDD.getClass.getSimpleName + val parentRDDType = parentRDD.getClass.getSimpleName + + // Get the partitions and dependencies of the parent in case they're lazily computed + parentRDD.dependencies + parentRDD.partitions + + // Find serialized sizes before and after the checkpoint + val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) + parentRDD.checkpoint() // checkpoint the parent RDD, not the generated one + val result = operatedRDD.collect() + val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) + + // Test whether the data in the checkpointed RDD is same as original + assert(operatedRDD.collect() === result) + + // Test whether serialized size of the RDD has reduced because of its parent being + // checkpointed. If this RDD or its parent RDD do not have any dependency + // to another RDD (e.g., ParallelCollection, ShuffleRDD with ShuffleDependency), it may + // not reduce in size after checkpointing. + if (testRDDSize) { + assert( + rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, + "Size of " + rddType + " did not reduce after checkpointing parent " + parentRDDType + + "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" + ) + } + + // Test whether serialized size of the partitions has reduced because of its parent being + // checkpointed. If the partitions do not have any non-transient reference to another RDD + // or another RDD's partitions, it does not refer to a lineage and therefore may not reduce + // in size after checkpointing. However, if the partitions do refer to the *partitions* of a parent + // RDD, then these partitions must update reference to the parent RDD partitions as the parent RDD's + // partitions must have changed after checkpointing. + if (testRDDPartitionSize) { + assert( + splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint, + "Size of " + rddType + " partitions did not reduce after checkpointing parent " + parentRDDType + + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]" + ) + } + + } + + /** + * Generate an RDD with a long lineage of one-to-one dependencies. + */ + def generateLongLineageRDD(): RDD[Int] = { + var rdd = sc.makeRDD(1 to 100, 4) + for (i <- 1 to 50) { + rdd = rdd.map(x => x + 1) + } + rdd + } + + /** + * Generate an RDD with a long lineage specifically for CoGroupedRDD. 
+ * A CoGroupedRDD can have a long lineage only if one of its parents has a long lineage + * and a narrow dependency with this RDD. This method generates such an RDD by a sequence + * of cogroups and mapValues which creates a long lineage of narrow dependencies. + */ + def generateLongLineageRDDForCoGroupedRDD() = { + val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _) + + def ones: RDD[(Int, Int)] = sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _) + + var cogrouped: RDD[(Int, (Seq[Int], Seq[Int]))] = ones.cogroup(ones) + for(i <- 1 to 10) { + cogrouped = cogrouped.mapValues(add).cogroup(ones) + } + cogrouped.mapValues(add) + } + + /** + * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks + * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. + */ + def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { + (Utils.serialize(rdd).length - Utils.serialize(rdd.checkpointData).length, + Utils.serialize(rdd.partitions).length) + } + + /** + * Serialize and deserialize an object. This is useful to verify the object's + * contents after deserialization (e.g., the contents of an RDD split after + * it is sent to a slave along with a task) + */ + def serializeDeserialize[T](obj: T): T = { + val bytes = Utils.serialize(obj) + Utils.deserialize[T](bytes) + } +} + + +object CheckpointSuite { + // This is a custom cogroup function that does not use mapValues like + // the PairRDDFunctions.cogroup() + def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = { + //println("First = " + first + ", second = " + second) + new CoGroupedRDD[K]( + Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]), + part + ).asInstanceOf[RDD[(K, Seq[Seq[V]])]] + } + +} diff --git a/core/src/test/scala/org/apache/spark/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ClosureCleanerSuite.scala new file mode 100644 index 0000000000..8494899b98 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ClosureCleanerSuite.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark + +import java.io.NotSerializableException + +import org.scalatest.FunSuite +import org.apache.spark.LocalSparkContext._ +import SparkContext._ + +class ClosureCleanerSuite extends FunSuite { + test("closures inside an object") { + assert(TestObject.run() === 30) // 6 + 7 + 8 + 9 + } + + test("closures inside a class") { + val obj = new TestClass + assert(obj.run() === 30) // 6 + 7 + 8 + 9 + } + + test("closures inside a class with no default constructor") { + val obj = new TestClassWithoutDefaultConstructor(5) + assert(obj.run() === 30) // 6 + 7 + 8 + 9 + } + + test("closures that don't use fields of the outer class") { + val obj = new TestClassWithoutFieldAccess + assert(obj.run() === 30) // 6 + 7 + 8 + 9 + } + + test("nested closures inside an object") { + assert(TestObjectWithNesting.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1 + } + + test("nested closures inside a class") { + val obj = new TestClassWithNesting(1) + assert(obj.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1 + } +} + +// A non-serializable class we create in closures to make sure that we aren't +// keeping references to unneeded variables from our outer closures. +class NonSerializable {} + +object TestObject { + def run(): Int = { + var nonSer = new NonSerializable + var x = 5 + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + x).reduce(_ + _) + } + } +} + +class TestClass extends Serializable { + var x = 5 + + def getX = x + + def run(): Int = { + var nonSer = new NonSerializable + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + getX).reduce(_ + _) + } + } +} + +class TestClassWithoutDefaultConstructor(x: Int) extends Serializable { + def getX = x + + def run(): Int = { + var nonSer = new NonSerializable + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + getX).reduce(_ + _) + } + } +} + +// This class is not serializable, but we aren't using any of its fields in our +// closures, so they won't have a $outer pointing to it and should still work. 
+class TestClassWithoutFieldAccess { + var nonSer = new NonSerializable + + def run(): Int = { + var nonSer2 = new NonSerializable + var x = 5 + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + x).reduce(_ + _) + } + } +} + + +object TestObjectWithNesting { + def run(): Int = { + var nonSer = new NonSerializable + var answer = 0 + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + var y = 1 + for (i <- 1 to 4) { + var nonSer2 = new NonSerializable + var x = i + answer += nums.map(_ + x + y).reduce(_ + _) + } + answer + } + } +} + +class TestClassWithNesting(val y: Int) extends Serializable { + def getY = y + + def run(): Int = { + var nonSer = new NonSerializable + var answer = 0 + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + for (i <- 1 to 4) { + var nonSer2 = new NonSerializable + var x = i + answer += nums.map(_ + x + getY).reduce(_ + _) + } + answer + } + } +} diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala new file mode 100644 index 0000000000..7a856d4081 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -0,0 +1,362 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import network.ConnectionManagerId +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.matchers.ShouldMatchers +import org.scalatest.prop.Checkers +import org.scalatest.time.{Span, Millis} +import org.scalacheck.Arbitrary._ +import org.scalacheck.Gen +import org.scalacheck.Prop._ +import org.eclipse.jetty.server.{Server, Request, Handler} + +import com.google.common.io.Files + +import scala.collection.mutable.ArrayBuffer + +import SparkContext._ +import storage.{GetBlock, BlockManagerWorker, StorageLevel} +import ui.JettyUtils + + +class NotSerializableClass +class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} + + +class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter + with LocalSparkContext { + + val clusterUrl = "local-cluster[2,1,512]" + + after { + System.clearProperty("spark.reducer.maxMbInFlight") + System.clearProperty("spark.storage.memoryFraction") + } + + test("task throws not serializable exception") { + // Ensures that executors do not crash when an exn is not serializable. If executors crash, + // this test will hang. Correct behavior is that executors don't crash but fail tasks + // and the scheduler throws a SparkException. 
+ + // numSlaves must be less than numPartitions + val numSlaves = 3 + val numPartitions = 10 + + sc = new SparkContext("local-cluster[%s,1,512]".format(numSlaves), "test") + val data = sc.parallelize(1 to 100, numPartitions). + map(x => throw new NotSerializableExn(new NotSerializableClass)) + intercept[SparkException] { + data.count() + } + resetSparkContext() + } + + test("local-cluster format") { + sc = new SparkContext("local-cluster[2,1,512]", "test") + assert(sc.parallelize(1 to 2, 2).count() == 2) + resetSparkContext() + sc = new SparkContext("local-cluster[2 , 1 , 512]", "test") + assert(sc.parallelize(1 to 2, 2).count() == 2) + resetSparkContext() + sc = new SparkContext("local-cluster[2, 1, 512]", "test") + assert(sc.parallelize(1 to 2, 2).count() == 2) + resetSparkContext() + sc = new SparkContext("local-cluster[ 2, 1, 512 ]", "test") + assert(sc.parallelize(1 to 2, 2).count() == 2) + resetSparkContext() + } + + test("simple groupByKey") { + sc = new SparkContext(clusterUrl, "test") + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 5) + val groups = pairs.groupByKey(5).collect() + assert(groups.size === 2) + val valuesFor1 = groups.find(_._1 == 1).get._2 + assert(valuesFor1.toList.sorted === List(1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } + + test("groupByKey where map output sizes exceed maxMbInFlight") { + System.setProperty("spark.reducer.maxMbInFlight", "1") + sc = new SparkContext(clusterUrl, "test") + // This data should be around 20 MB, so even with 4 mappers and 2 reducers, each map output + // file should be about 2.5 MB + val pairs = sc.parallelize(1 to 2000, 4).map(x => (x % 16, new Array[Byte](10000))) + val groups = pairs.groupByKey(2).map(x => (x._1, x._2.size)).collect() + assert(groups.length === 16) + assert(groups.map(_._2).sum === 2000) + // Note that spark.reducer.maxMbInFlight will be cleared in the test suite's after{} block + } + + test("accumulators") { + sc = new SparkContext(clusterUrl, "test") + val accum = sc.accumulator(0) + sc.parallelize(1 to 10, 10).foreach(x => accum += x) + assert(accum.value === 55) + } + + test("broadcast variables") { + sc = new SparkContext(clusterUrl, "test") + val array = new Array[Int](100) + val bv = sc.broadcast(array) + array(2) = 3 // Change the array -- this should not be seen on workers + val rdd = sc.parallelize(1 to 10, 10) + val sum = rdd.map(x => bv.value.sum).reduce(_ + _) + assert(sum === 0) + } + + test("repeatedly failing task") { + sc = new SparkContext(clusterUrl, "test") + val accum = sc.accumulator(0) + val thrown = intercept[SparkException] { + sc.parallelize(1 to 10, 10).foreach(x => println(x / 0)) + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getMessage.contains("more than 4 times")) + } + + test("caching") { + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(1 to 1000, 10).cache() + assert(data.count() === 1000) + assert(data.count() === 1000) + assert(data.count() === 1000) + } + + test("caching on disk") { + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.DISK_ONLY) + assert(data.count() === 1000) + assert(data.count() === 1000) + assert(data.count() === 1000) + } + + test("caching in memory, replicated") { + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_ONLY_2) + assert(data.count() === 1000) + assert(data.count() === 1000) + 
assert(data.count() === 1000) + } + + test("caching in memory, serialized, replicated") { + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_ONLY_SER_2) + assert(data.count() === 1000) + assert(data.count() === 1000) + assert(data.count() === 1000) + } + + test("caching on disk, replicated") { + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.DISK_ONLY_2) + assert(data.count() === 1000) + assert(data.count() === 1000) + assert(data.count() === 1000) + } + + test("caching in memory and disk, replicated") { + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_2) + assert(data.count() === 1000) + assert(data.count() === 1000) + assert(data.count() === 1000) + } + + test("caching in memory and disk, serialized, replicated") { + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_SER_2) + + assert(data.count() === 1000) + assert(data.count() === 1000) + assert(data.count() === 1000) + + // Get all the locations of the first partition and try to fetch the partitions + // from those locations. + val blockIds = data.partitions.indices.map(index => "rdd_%d_%d".format(data.id, index)).toArray + val blockId = blockIds(0) + val blockManager = SparkEnv.get.blockManager + blockManager.master.getLocations(blockId).foreach(id => { + val bytes = BlockManagerWorker.syncGetBlock( + GetBlock(blockId), ConnectionManagerId(id.host, id.port)) + val deserialized = blockManager.dataDeserialize(blockId, bytes).asInstanceOf[Iterator[Int]].toList + assert(deserialized === (1 to 100).toList) + }) + } + + test("compute without caching when no partitions fit in memory") { + System.setProperty("spark.storage.memoryFraction", "0.0001") + sc = new SparkContext(clusterUrl, "test") + // data will be 4 million * 4 bytes = 16 MB in size, but our memoryFraction set the cache + // to only 50 KB (0.0001 of 512 MB), so no partitions should fit in memory + val data = sc.parallelize(1 to 4000000, 2).persist(StorageLevel.MEMORY_ONLY_SER) + assert(data.count() === 4000000) + assert(data.count() === 4000000) + assert(data.count() === 4000000) + System.clearProperty("spark.storage.memoryFraction") + } + + test("compute when only some partitions fit in memory") { + System.setProperty("spark.storage.memoryFraction", "0.01") + sc = new SparkContext(clusterUrl, "test") + // data will be 4 million * 4 bytes = 16 MB in size, but our memoryFraction set the cache + // to only 5 MB (0.01 of 512 MB), so not all of it will fit in memory; we use 20 partitions + // to make sure that *some* of them do fit though + val data = sc.parallelize(1 to 4000000, 20).persist(StorageLevel.MEMORY_ONLY_SER) + assert(data.count() === 4000000) + assert(data.count() === 4000000) + assert(data.count() === 4000000) + System.clearProperty("spark.storage.memoryFraction") + } + + test("passing environment variables to cluster") { + sc = new SparkContext(clusterUrl, "test", null, Nil, Map("TEST_VAR" -> "TEST_VALUE")) + val values = sc.parallelize(1 to 2, 2).map(x => System.getenv("TEST_VAR")).collect() + assert(values.toSeq === Seq("TEST_VALUE", "TEST_VALUE")) + } + + test("recover from node failures") { + import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + DistributedSuite.amMaster = true + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(Seq(true, true), 2) + 
assert(data.count === 2) // force executors to start + assert(data.map(markNodeIfIdentity).collect.size === 2) + assert(data.map(failOnMarkedIdentity).collect.size === 2) + } + + test("recover from repeated node failures during shuffle-map") { + import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + DistributedSuite.amMaster = true + sc = new SparkContext(clusterUrl, "test") + for (i <- 1 to 3) { + val data = sc.parallelize(Seq(true, false), 2) + assert(data.count === 2) + assert(data.map(markNodeIfIdentity).collect.size === 2) + assert(data.map(failOnMarkedIdentity).map(x => x -> x).groupByKey.count === 2) + } + } + + test("recover from repeated node failures during shuffle-reduce") { + import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + DistributedSuite.amMaster = true + sc = new SparkContext(clusterUrl, "test") + for (i <- 1 to 3) { + val data = sc.parallelize(Seq(true, true), 2) + assert(data.count === 2) + assert(data.map(markNodeIfIdentity).collect.size === 2) + // This relies on mergeCombiners being used to perform the actual reduce for this + // test to actually be testing what it claims. + val grouped = data.map(x => x -> x).combineByKey( + x => x, + (x: Boolean, y: Boolean) => x, + (x: Boolean, y: Boolean) => failOnMarkedIdentity(x) + ) + assert(grouped.collect.size === 1) + } + } + + test("recover from node failures with replication") { + import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + DistributedSuite.amMaster = true + // Using more than two nodes so we don't have a symmetric communication pattern and might + // cache a partially correct list of peers. + sc = new SparkContext("local-cluster[3,1,512]", "test") + for (i <- 1 to 3) { + val data = sc.parallelize(Seq(true, false, false, false), 4) + data.persist(StorageLevel.MEMORY_ONLY_2) + + assert(data.count === 4) + assert(data.map(markNodeIfIdentity).collect.size === 4) + assert(data.map(failOnMarkedIdentity).collect.size === 4) + + // Create a new replicated RDD to make sure that cached peer information doesn't cause + // problems. + val data2 = sc.parallelize(Seq(true, true), 2).persist(StorageLevel.MEMORY_ONLY_2) + assert(data2.count === 2) + } + } + + test("unpersist RDDs") { + DistributedSuite.amMaster = true + sc = new SparkContext("local-cluster[3,1,512]", "test") + val data = sc.parallelize(Seq(true, false, false, false), 4) + data.persist(StorageLevel.MEMORY_ONLY_2) + data.count + assert(sc.persistentRdds.isEmpty === false) + data.unpersist() + assert(sc.persistentRdds.isEmpty === true) + + failAfter(Span(3000, Millis)) { + try { + while (! sc.getRDDStorageInfo.isEmpty) { + Thread.sleep(200) + } + } catch { + case _ => { Thread.sleep(10) } + // Do nothing. We might see exceptions because block manager + // is racing this thread to remove entries from the driver. + } + } + } + + test("job should fail if TaskResult exceeds Akka frame size") { + // We must use local-cluster mode since results are returned differently + // when running under LocalScheduler: + sc = new SparkContext("local-cluster[1,1,512]", "test") + val akkaFrameSize = + sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt + val rdd = sc.parallelize(Seq(1)).map{x => new Array[Byte](akkaFrameSize)} + val exception = intercept[SparkException] { + rdd.reduce((x, y) => x) + } + exception.getMessage should endWith("result exceeded Akka frame size") + } +} + +object DistributedSuite { + // Indicates whether this JVM is marked for failure. 
+ var mark = false + + // Set by test to remember if we are in the driver program so we can assert + // that we are not. + var amMaster = false + + // Act like an identity function, but if the argument is true, set mark to true. + def markNodeIfIdentity(item: Boolean): Boolean = { + if (item) { + assert(!amMaster) + mark = true + } + item + } + + // Act like an identity function, but if mark was set to true previously, fail, + // crashing the entire JVM. + def failOnMarkedIdentity(item: Boolean): Boolean = { + if (mark) { + System.exit(42) + } + item + } +} diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala new file mode 100644 index 0000000000..b08aad1a6f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.io.File + +import org.apache.log4j.Logger +import org.apache.log4j.Level + +import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts +import org.scalatest.prop.TableDrivenPropertyChecks._ +import org.scalatest.time.SpanSugar._ + +class DriverSuite extends FunSuite with Timeouts { + test("driver should exit after finishing") { + assert(System.getenv("SPARK_HOME") != null) + // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" + val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) + forAll(masters) { (master: String) => + failAfter(30 seconds) { + Utils.execute(Seq("./spark-class", "org.apache.spark.DriverWithoutCleanup", master), + new File(System.getenv("SPARK_HOME"))) + } + } + } +} + +/** + * Program that creates a Spark driver but doesn't call SparkContext.stop() or + * Sys.exit() after finishing. + */ +object DriverWithoutCleanup { + def main(args: Array[String]) { + Logger.getRootLogger().setLevel(Level.WARN) + val sc = new SparkContext(args(0), "DriverWithoutCleanup") + sc.parallelize(1 to 100, 4).count() + } +} diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala new file mode 100644 index 0000000000..ee89a7a387 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite + +import SparkContext._ + +// Common state shared by FailureSuite-launched tasks. We use a global object +// for this because any local variables used in the task closures will rightfully +// be copied for each task, so there's no other way for them to share state. +object FailureSuiteState { + var tasksRun = 0 + var tasksFailed = 0 + + def clear() { + synchronized { + tasksRun = 0 + tasksFailed = 0 + } + } +} + +class FailureSuite extends FunSuite with LocalSparkContext { + + // Run a 3-task map job in which task 1 deterministically fails once, and check + // whether the job completes successfully and we ran 4 tasks in total. + test("failure in a single-stage job") { + sc = new SparkContext("local[1,1]", "test") + val results = sc.makeRDD(1 to 3, 3).map { x => + FailureSuiteState.synchronized { + FailureSuiteState.tasksRun += 1 + if (x == 1 && FailureSuiteState.tasksFailed == 0) { + FailureSuiteState.tasksFailed += 1 + throw new Exception("Intentional task failure") + } + } + x * x + }.collect() + FailureSuiteState.synchronized { + assert(FailureSuiteState.tasksRun === 4) + } + assert(results.toList === List(1,4,9)) + FailureSuiteState.clear() + } + + // Run a map-reduce job in which a reduce task deterministically fails once. + test("failure in a two-stage job") { + sc = new SparkContext("local[1,1]", "test") + val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map { + case (k, v) => + FailureSuiteState.synchronized { + FailureSuiteState.tasksRun += 1 + if (k == 1 && FailureSuiteState.tasksFailed == 0) { + FailureSuiteState.tasksFailed += 1 + throw new Exception("Intentional task failure") + } + } + (k, v(0) * v(0)) + }.collect() + FailureSuiteState.synchronized { + assert(FailureSuiteState.tasksRun === 4) + } + assert(results.toSet === Set((1, 1), (2, 4), (3, 9))) + FailureSuiteState.clear() + } + + test("failure because task results are not serializable") { + sc = new SparkContext("local[1,1]", "test") + val results = sc.makeRDD(1 to 3).map(x => new NonSerializable) + + val thrown = intercept[SparkException] { + results.collect() + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getMessage.contains("NotSerializableException")) + + FailureSuiteState.clear() + } + + test("failure because task closure is not serializable") { + sc = new SparkContext("local[1,1]", "test") + val a = new NonSerializable + + // Non-serializable closure in the final result stage + val thrown = intercept[SparkException] { + sc.parallelize(1 to 10, 2).map(x => a).count() + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getMessage.contains("NotSerializableException")) + + // Non-serializable closure in an earlier stage + val thrown1 = intercept[SparkException] { + sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count() + } + assert(thrown1.getClass === classOf[SparkException]) + assert(thrown1.getMessage.contains("NotSerializableException")) + + // Non-serializable closure in foreach function + val thrown2 = intercept[SparkException] { + 
sc.parallelize(1 to 10, 2).foreach(x => println(a)) + } + assert(thrown2.getClass === classOf[SparkException]) + assert(thrown2.getMessage.contains("NotSerializableException")) + + FailureSuiteState.clear() + } + + // TODO: Need to add tests with shuffle fetch failures. +} + + diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala new file mode 100644 index 0000000000..35d1d41af1 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import com.google.common.io.Files +import org.scalatest.FunSuite +import java.io.{File, PrintWriter, FileReader, BufferedReader} +import SparkContext._ + +class FileServerSuite extends FunSuite with LocalSparkContext { + + @transient var tmpFile: File = _ + @transient var testJarFile: File = _ + + override def beforeEach() { + super.beforeEach() + // Create a sample text file + val tmpdir = new File(Files.createTempDir(), "test") + tmpdir.mkdir() + tmpFile = new File(tmpdir, "FileServerSuite.txt") + val pw = new PrintWriter(tmpFile) + pw.println("100") + pw.close() + } + + override def afterEach() { + super.afterEach() + // Clean up downloaded file + if (tmpFile.exists) { + tmpFile.delete() + } + } + + test("Distributing files locally") { + sc = new SparkContext("local[4]", "test") + sc.addFile(tmpFile.toString) + val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) + val result = sc.parallelize(testData).reduceByKey { + val path = SparkFiles.get("FileServerSuite.txt") + val in = new BufferedReader(new FileReader(path)) + val fileVal = in.readLine().toInt + in.close() + _ * fileVal + _ * fileVal + }.collect() + assert(result.toSet === Set((1,200), (2,300), (3,500))) + } + + test("Distributing files locally using URL as input") { + // addFile("file:///....") + sc = new SparkContext("local[4]", "test") + sc.addFile(new File(tmpFile.toString).toURI.toString) + val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) + val result = sc.parallelize(testData).reduceByKey { + val path = SparkFiles.get("FileServerSuite.txt") + val in = new BufferedReader(new FileReader(path)) + val fileVal = in.readLine().toInt + in.close() + _ * fileVal + _ * fileVal + }.collect() + assert(result.toSet === Set((1,200), (2,300), (3,500))) + } + + test ("Dynamically adding JARS locally") { + sc = new SparkContext("local[4]", "test") + val sampleJarFile = getClass.getClassLoader.getResource("uncommons-maths-1.2.2.jar").getFile() + sc.addJar(sampleJarFile) + val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0)) + val result = sc.parallelize(testData).reduceByKey { (x,y) => + val fac = 
Thread.currentThread.getContextClassLoader() + .loadClass("org.uncommons.maths.Maths") + .getDeclaredMethod("factorial", classOf[Int]) + val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt + val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt + a + b + }.collect() + assert(result.toSet === Set((1,2), (2,7), (3,121))) + } + + test("Distributing files on a standalone cluster") { + sc = new SparkContext("local-cluster[1,1,512]", "test") + sc.addFile(tmpFile.toString) + val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) + val result = sc.parallelize(testData).reduceByKey { + val path = SparkFiles.get("FileServerSuite.txt") + val in = new BufferedReader(new FileReader(path)) + val fileVal = in.readLine().toInt + in.close() + _ * fileVal + _ * fileVal + }.collect() + assert(result.toSet === Set((1,200), (2,300), (3,500))) + } + + test ("Dynamically adding JARS on a standalone cluster") { + sc = new SparkContext("local-cluster[1,1,512]", "test") + val sampleJarFile = getClass.getClassLoader.getResource("uncommons-maths-1.2.2.jar").getFile() + sc.addJar(sampleJarFile) + val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0)) + val result = sc.parallelize(testData).reduceByKey { (x,y) => + val fac = Thread.currentThread.getContextClassLoader() + .loadClass("org.uncommons.maths.Maths") + .getDeclaredMethod("factorial", classOf[Int]) + val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt + val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt + a + b + }.collect() + assert(result.toSet === Set((1,2), (2,7), (3,121))) + } +} diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala new file mode 100644 index 0000000000..7b82a4cdd9 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import java.io.{FileWriter, PrintWriter, File} + +import scala.io.Source + +import com.google.common.io.Files +import org.scalatest.FunSuite +import org.apache.hadoop.io._ +import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodec, GzipCodec} + + +import SparkContext._ + +class FileSuite extends FunSuite with LocalSparkContext { + + test("text files") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val outputDir = new File(tempDir, "output").getAbsolutePath + val nums = sc.makeRDD(1 to 4) + nums.saveAsTextFile(outputDir) + // Read the plain text file and check it's OK + val outputFile = new File(outputDir, "part-00000") + val content = Source.fromFile(outputFile).mkString + assert(content === "1\n2\n3\n4\n") + // Also try reading it in as a text file RDD + assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4")) + } + + test("text files (compressed)") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val normalDir = new File(tempDir, "output_normal").getAbsolutePath + val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath + val codec = new DefaultCodec() + + val data = sc.parallelize("a" * 10000, 1) + data.saveAsTextFile(normalDir) + data.saveAsTextFile(compressedOutputDir, classOf[DefaultCodec]) + + val normalFile = new File(normalDir, "part-00000") + val normalContent = sc.textFile(normalDir).collect + assert(normalContent === Array.fill(10000)("a")) + + val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension) + val compressedContent = sc.textFile(compressedOutputDir).collect + assert(compressedContent === Array.fill(10000)("a")) + + assert(compressedFile.length < normalFile.length) + } + + test("SequenceFiles") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val outputDir = new File(tempDir, "output").getAbsolutePath + val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x)) // (1,a), (2,aa), (3,aaa) + nums.saveAsSequenceFile(outputDir) + // Try reading the output back as a SequenceFile + val output = sc.sequenceFile[IntWritable, Text](outputDir) + assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) + } + + test("SequenceFile (compressed)") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val normalDir = new File(tempDir, "output_normal").getAbsolutePath + val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath + val codec = new DefaultCodec() + + val data = sc.parallelize(Seq.fill(100)("abc"), 1).map(x => (x, x)) + data.saveAsSequenceFile(normalDir) + data.saveAsSequenceFile(compressedOutputDir, Some(classOf[DefaultCodec])) + + val normalFile = new File(normalDir, "part-00000") + val normalContent = sc.sequenceFile[String, String](normalDir).collect + assert(normalContent === Array.fill(100)("abc", "abc")) + + val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension) + val compressedContent = sc.sequenceFile[String, String](compressedOutputDir).collect + assert(compressedContent === Array.fill(100)("abc", "abc")) + + assert(compressedFile.length < normalFile.length) + } + + test("SequenceFile with writable key") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val outputDir = new File(tempDir, "output").getAbsolutePath + val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), "a" * x)) + 
nums.saveAsSequenceFile(outputDir) + // Try reading the output back as a SequenceFile + val output = sc.sequenceFile[IntWritable, Text](outputDir) + assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) + } + + test("SequenceFile with writable value") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val outputDir = new File(tempDir, "output").getAbsolutePath + val nums = sc.makeRDD(1 to 3).map(x => (x, new Text("a" * x))) + nums.saveAsSequenceFile(outputDir) + // Try reading the output back as a SequenceFile + val output = sc.sequenceFile[IntWritable, Text](outputDir) + assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) + } + + test("SequenceFile with writable key and value") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val outputDir = new File(tempDir, "output").getAbsolutePath + val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x))) + nums.saveAsSequenceFile(outputDir) + // Try reading the output back as a SequenceFile + val output = sc.sequenceFile[IntWritable, Text](outputDir) + assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) + } + + test("implicit conversions in reading SequenceFiles") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val outputDir = new File(tempDir, "output").getAbsolutePath + val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x)) // (1,a), (2,aa), (3,aaa) + nums.saveAsSequenceFile(outputDir) + // Similar to the tests above, we read a SequenceFile, but this time we pass type params + // that are convertable to Writable instead of calling sequenceFile[IntWritable, Text] + val output1 = sc.sequenceFile[Int, String](outputDir) + assert(output1.collect().toList === List((1, "a"), (2, "aa"), (3, "aaa"))) + // Also try having one type be a subclass of Writable and one not + val output2 = sc.sequenceFile[Int, Text](outputDir) + assert(output2.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) + val output3 = sc.sequenceFile[IntWritable, String](outputDir) + assert(output3.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) + } + + test("object files of ints") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val outputDir = new File(tempDir, "output").getAbsolutePath + val nums = sc.makeRDD(1 to 4) + nums.saveAsObjectFile(outputDir) + // Try reading the output back as an object file + val output = sc.objectFile[Int](outputDir) + assert(output.collect().toList === List(1, 2, 3, 4)) + } + + test("object files of complex types") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val outputDir = new File(tempDir, "output").getAbsolutePath + val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x)) + nums.saveAsObjectFile(outputDir) + // Try reading the output back as an object file + val output = sc.objectFile[(Int, String)](outputDir) + assert(output.collect().toList === List((1, "a"), (2, "aa"), (3, "aaa"))) + } + + test("write SequenceFile using new Hadoop API") { + import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val outputDir = new File(tempDir, "output").getAbsolutePath + val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x))) + nums.saveAsNewAPIHadoopFile[SequenceFileOutputFormat[IntWritable, Text]]( + outputDir) 
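// The file written above through the new org.apache.hadoop.mapreduce OutputFormat is read back
// just below with the classic sequenceFile reader. For comparison, a hedged sketch of the
// equivalent classic-API write ("classicDir" is a hypothetical extra path used only here):
//   val classicDir = new File(tempDir, "output_classic").getAbsolutePath
//   nums.saveAsSequenceFile(classicDir)
//   assert(sc.sequenceFile[IntWritable, Text](classicDir).count() === 3)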
+ val output = sc.sequenceFile[IntWritable, Text](outputDir) + assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) + } + + test("read SequenceFile using new Hadoop API") { + import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val outputDir = new File(tempDir, "output").getAbsolutePath + val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x))) + nums.saveAsSequenceFile(outputDir) + val output = + sc.newAPIHadoopFile[IntWritable, Text, SequenceFileInputFormat[IntWritable, Text]](outputDir) + assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) + } + + test("file caching") { + sc = new SparkContext("local", "test") + val tempDir = Files.createTempDir() + val out = new FileWriter(tempDir + "/input") + out.write("Hello world!\n") + out.write("What's up?\n") + out.write("Goodbye\n") + out.close() + val rdd = sc.textFile(tempDir + "/input").cache() + assert(rdd.count() === 3) + assert(rdd.count() === 3) + assert(rdd.count() === 3) + } +} diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java new file mode 100644 index 0000000000..8a869c9005 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java @@ -0,0 +1,865 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import java.io.File; +import java.io.IOException; +import java.io.Serializable; +import java.util.*; + +import com.google.common.base.Optional; +import scala.Tuple2; + +import com.google.common.base.Charsets; +import org.apache.hadoop.io.compress.DefaultCodec; +import com.google.common.io.Files; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapred.SequenceFileInputFormat; +import org.apache.hadoop.mapred.SequenceFileOutputFormat; +import org.apache.hadoop.mapreduce.Job; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.*; +import org.apache.spark.partial.BoundedDouble; +import org.apache.spark.partial.PartialResult; +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.util.StatCounter; + + +// The test suite itself is Serializable so that anonymous Function implementations can be +// serialized, as an alternative to converting these anonymous classes to static inner classes; +// see http://stackoverflow.com/questions/758570/. 
+public class JavaAPISuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaAPISuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.driver.port"); + } + + static class ReverseIntComparator implements Comparator, Serializable { + + @Override + public int compare(Integer a, Integer b) { + if (a > b) return -1; + else if (a < b) return 1; + else return 0; + } + }; + + @Test + public void sparkContextUnion() { + // Union of non-specialized JavaRDDs + List strings = Arrays.asList("Hello", "World"); + JavaRDD s1 = sc.parallelize(strings); + JavaRDD s2 = sc.parallelize(strings); + // Varargs + JavaRDD sUnion = sc.union(s1, s2); + Assert.assertEquals(4, sUnion.count()); + // List + List> list = new ArrayList>(); + list.add(s2); + sUnion = sc.union(s1, list); + Assert.assertEquals(4, sUnion.count()); + + // Union of JavaDoubleRDDs + List doubles = Arrays.asList(1.0, 2.0); + JavaDoubleRDD d1 = sc.parallelizeDoubles(doubles); + JavaDoubleRDD d2 = sc.parallelizeDoubles(doubles); + JavaDoubleRDD dUnion = sc.union(d1, d2); + Assert.assertEquals(4, dUnion.count()); + + // Union of JavaPairRDDs + List> pairs = new ArrayList>(); + pairs.add(new Tuple2(1, 2)); + pairs.add(new Tuple2(3, 4)); + JavaPairRDD p1 = sc.parallelizePairs(pairs); + JavaPairRDD p2 = sc.parallelizePairs(pairs); + JavaPairRDD pUnion = sc.union(p1, p2); + Assert.assertEquals(4, pUnion.count()); + } + + @Test + public void sortByKey() { + List> pairs = new ArrayList>(); + pairs.add(new Tuple2(0, 4)); + pairs.add(new Tuple2(3, 2)); + pairs.add(new Tuple2(-1, 1)); + + JavaPairRDD rdd = sc.parallelizePairs(pairs); + + // Default comparator + JavaPairRDD sortedRDD = rdd.sortByKey(); + Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + List> sortedPairs = sortedRDD.collect(); + Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); + + // Custom comparator + sortedRDD = rdd.sortByKey(new ReverseIntComparator(), false); + Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + sortedPairs = sortedRDD.collect(); + Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); + } + + static int foreachCalls = 0; + + @Test + public void foreach() { + foreachCalls = 0; + JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); + rdd.foreach(new VoidFunction() { + @Override + public void call(String s) { + foreachCalls++; + } + }); + Assert.assertEquals(2, foreachCalls); + } + + @Test + public void lookup() { + JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( + new Tuple2("Apples", "Fruit"), + new Tuple2("Oranges", "Fruit"), + new Tuple2("Oranges", "Citrus") + )); + Assert.assertEquals(2, categories.lookup("Oranges").size()); + Assert.assertEquals(2, categories.groupByKey().lookup("Oranges").get(0).size()); + } + + @Test + public void groupBy() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); + Function isOdd = new Function() { + @Override + public Boolean call(Integer x) { + return x % 2 == 0; + } + }; + JavaPairRDD> oddsAndEvens = rdd.groupBy(isOdd); + Assert.assertEquals(2, oddsAndEvens.count()); + Assert.assertEquals(2, oddsAndEvens.lookup(true).get(0).size()); // Evens + Assert.assertEquals(5, 
oddsAndEvens.lookup(false).get(0).size()); // Odds + + oddsAndEvens = rdd.groupBy(isOdd, 1); + Assert.assertEquals(2, oddsAndEvens.count()); + Assert.assertEquals(2, oddsAndEvens.lookup(true).get(0).size()); // Evens + Assert.assertEquals(5, oddsAndEvens.lookup(false).get(0).size()); // Odds + } + + @Test + public void cogroup() { + JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( + new Tuple2("Apples", "Fruit"), + new Tuple2("Oranges", "Fruit"), + new Tuple2("Oranges", "Citrus") + )); + JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( + new Tuple2("Oranges", 2), + new Tuple2("Apples", 3) + )); + JavaPairRDD, List>> cogrouped = categories.cogroup(prices); + Assert.assertEquals("[Fruit, Citrus]", cogrouped.lookup("Oranges").get(0)._1().toString()); + Assert.assertEquals("[2]", cogrouped.lookup("Oranges").get(0)._2().toString()); + + cogrouped.collect(); + } + + @Test + public void leftOuterJoin() { + JavaPairRDD rdd1 = sc.parallelizePairs(Arrays.asList( + new Tuple2(1, 1), + new Tuple2(1, 2), + new Tuple2(2, 1), + new Tuple2(3, 1) + )); + JavaPairRDD rdd2 = sc.parallelizePairs(Arrays.asList( + new Tuple2(1, 'x'), + new Tuple2(2, 'y'), + new Tuple2(2, 'z'), + new Tuple2(4, 'w') + )); + List>>> joined = + rdd1.leftOuterJoin(rdd2).collect(); + Assert.assertEquals(5, joined.size()); + Tuple2>> firstUnmatched = + rdd1.leftOuterJoin(rdd2).filter( + new Function>>, Boolean>() { + @Override + public Boolean call(Tuple2>> tup) + throws Exception { + return !tup._2()._2().isPresent(); + } + }).first(); + Assert.assertEquals(3, firstUnmatched._1().intValue()); + } + + @Test + public void foldReduce() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); + Function2 add = new Function2() { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + }; + + int sum = rdd.fold(0, add); + Assert.assertEquals(33, sum); + + sum = rdd.reduce(add); + Assert.assertEquals(33, sum); + } + + @Test + public void foldByKey() { + List> pairs = Arrays.asList( + new Tuple2(2, 1), + new Tuple2(2, 1), + new Tuple2(1, 1), + new Tuple2(3, 2), + new Tuple2(3, 1) + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + JavaPairRDD sums = rdd.foldByKey(0, + new Function2() { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + }); + Assert.assertEquals(1, sums.lookup(1).get(0).intValue()); + Assert.assertEquals(2, sums.lookup(2).get(0).intValue()); + Assert.assertEquals(3, sums.lookup(3).get(0).intValue()); + } + + @Test + public void reduceByKey() { + List> pairs = Arrays.asList( + new Tuple2(2, 1), + new Tuple2(2, 1), + new Tuple2(1, 1), + new Tuple2(3, 2), + new Tuple2(3, 1) + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + JavaPairRDD counts = rdd.reduceByKey( + new Function2() { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + }); + Assert.assertEquals(1, counts.lookup(1).get(0).intValue()); + Assert.assertEquals(2, counts.lookup(2).get(0).intValue()); + Assert.assertEquals(3, counts.lookup(3).get(0).intValue()); + + Map localCounts = counts.collectAsMap(); + Assert.assertEquals(1, localCounts.get(1).intValue()); + Assert.assertEquals(2, localCounts.get(2).intValue()); + Assert.assertEquals(3, localCounts.get(3).intValue()); + + localCounts = rdd.reduceByKeyLocally(new Function2() { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + }); + Assert.assertEquals(1, localCounts.get(1).intValue()); + Assert.assertEquals(2, localCounts.get(2).intValue()); + Assert.assertEquals(3, 
localCounts.get(3).intValue()); + } + + @Test + public void approximateResults() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); + Map countsByValue = rdd.countByValue(); + Assert.assertEquals(2, countsByValue.get(1).longValue()); + Assert.assertEquals(1, countsByValue.get(13).longValue()); + + PartialResult> approx = rdd.countByValueApprox(1); + Map finalValue = approx.getFinalValue(); + Assert.assertEquals(2.0, finalValue.get(1).mean(), 0.01); + Assert.assertEquals(1.0, finalValue.get(13).mean(), 0.01); + } + + @Test + public void take() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); + Assert.assertEquals(1, rdd.first().intValue()); + List firstTwo = rdd.take(2); + List sample = rdd.takeSample(false, 2, 42); + } + + @Test + public void cartesian() { + JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); + JavaRDD stringRDD = sc.parallelize(Arrays.asList("Hello", "World")); + JavaPairRDD cartesian = stringRDD.cartesian(doubleRDD); + Assert.assertEquals(new Tuple2("Hello", 1.0), cartesian.first()); + } + + @Test + public void javaDoubleRDD() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); + JavaDoubleRDD distinct = rdd.distinct(); + Assert.assertEquals(5, distinct.count()); + JavaDoubleRDD filter = rdd.filter(new Function() { + @Override + public Boolean call(Double x) { + return x > 2.0; + } + }); + Assert.assertEquals(3, filter.count()); + JavaDoubleRDD union = rdd.union(rdd); + Assert.assertEquals(12, union.count()); + union = union.cache(); + Assert.assertEquals(12, union.count()); + + Assert.assertEquals(20, rdd.sum(), 0.01); + StatCounter stats = rdd.stats(); + Assert.assertEquals(20, stats.sum(), 0.01); + Assert.assertEquals(20/6.0, rdd.mean(), 0.01); + Assert.assertEquals(20/6.0, rdd.mean(), 0.01); + Assert.assertEquals(6.22222, rdd.variance(), 0.01); + Assert.assertEquals(7.46667, rdd.sampleVariance(), 0.01); + Assert.assertEquals(2.49444, rdd.stdev(), 0.01); + Assert.assertEquals(2.73252, rdd.sampleStdev(), 0.01); + + Double first = rdd.first(); + List take = rdd.take(5); + } + + @Test + public void map() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + JavaDoubleRDD doubles = rdd.map(new DoubleFunction() { + @Override + public Double call(Integer x) { + return 1.0 * x; + } + }).cache(); + JavaPairRDD pairs = rdd.map(new PairFunction() { + @Override + public Tuple2 call(Integer x) { + return new Tuple2(x, x); + } + }).cache(); + JavaRDD strings = rdd.map(new Function() { + @Override + public String call(Integer x) { + return x.toString(); + } + }).cache(); + } + + @Test + public void flatMap() { + JavaRDD rdd = sc.parallelize(Arrays.asList("Hello World!", + "The quick brown fox jumps over the lazy dog.")); + JavaRDD words = rdd.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Arrays.asList(x.split(" ")); + } + }); + Assert.assertEquals("Hello", words.first()); + Assert.assertEquals(11, words.count()); + + JavaPairRDD pairs = rdd.flatMap( + new PairFlatMapFunction() { + + @Override + public Iterable> call(String s) { + List> pairs = new LinkedList>(); + for (String word : s.split(" ")) pairs.add(new Tuple2(word, word)); + return pairs; + } + } + ); + Assert.assertEquals(new Tuple2("Hello", "Hello"), pairs.first()); + Assert.assertEquals(11, pairs.count()); + + JavaDoubleRDD doubles = rdd.flatMap(new DoubleFlatMapFunction() { + @Override + public Iterable call(String s) { + List lengths = new 
LinkedList(); + for (String word : s.split(" ")) lengths.add(word.length() * 1.0); + return lengths; + } + }); + Double x = doubles.first(); + Assert.assertEquals(5.0, doubles.first().doubleValue(), 0.01); + Assert.assertEquals(11, pairs.count()); + } + + @Test + public void mapsFromPairsToPairs() { + List> pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD pairRDD = sc.parallelizePairs(pairs); + + // Regression test for SPARK-668: + JavaPairRDD swapped = pairRDD.flatMap( + new PairFlatMapFunction, String, Integer>() { + @Override + public Iterable> call(Tuple2 item) throws Exception { + return Collections.singletonList(item.swap()); + } + }); + swapped.collect(); + + // There was never a bug here, but it's worth testing: + pairRDD.map(new PairFunction, String, Integer>() { + @Override + public Tuple2 call(Tuple2 item) throws Exception { + return item.swap(); + } + }).collect(); + } + + @Test + public void mapPartitions() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); + JavaRDD partitionSums = rdd.mapPartitions( + new FlatMapFunction, Integer>() { + @Override + public Iterable call(Iterator iter) { + int sum = 0; + while (iter.hasNext()) { + sum += iter.next(); + } + return Collections.singletonList(sum); + } + }); + Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); + } + + @Test + public void persist() { + JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); + doubleRDD = doubleRDD.persist(StorageLevel.DISK_ONLY()); + Assert.assertEquals(20, doubleRDD.sum(), 0.1); + + List> pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD pairRDD = sc.parallelizePairs(pairs); + pairRDD = pairRDD.persist(StorageLevel.DISK_ONLY()); + Assert.assertEquals("a", pairRDD.first()._2()); + + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + rdd = rdd.persist(StorageLevel.DISK_ONLY()); + Assert.assertEquals(1, rdd.first().intValue()); + } + + @Test + public void iterator() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); + TaskContext context = new TaskContext(0, 0, 0, null); + Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue()); + } + + @Test + public void glom() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); + Assert.assertEquals("[1, 2]", rdd.glom().first().toString()); + } + + // File input / output tests are largely adapted from FileSuite: + + @Test + public void textFiles() throws IOException { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output").getAbsolutePath(); + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + rdd.saveAsTextFile(outputDir); + // Read the plain text file and check it's OK + File outputFile = new File(outputDir, "part-00000"); + String content = Files.toString(outputFile, Charsets.UTF_8); + Assert.assertEquals("1\n2\n3\n4\n", content); + // Also try reading it in as a text file RDD + List expected = Arrays.asList("1", "2", "3", "4"); + JavaRDD readRDD = sc.textFile(outputDir); + Assert.assertEquals(expected, readRDD.collect()); + } + + @Test + public void textFilesCompressed() throws IOException { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output").getAbsolutePath(); + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + rdd.saveAsTextFile(outputDir, DefaultCodec.class); + + // Try reading it in as a text file RDD + List expected = 
Arrays.asList("1", "2", "3", "4"); + JavaRDD readRDD = sc.textFile(outputDir); + Assert.assertEquals(expected, readRDD.collect()); + } + + @Test + public void sequenceFile() { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output").getAbsolutePath(); + List> pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + + rdd.map(new PairFunction, IntWritable, Text>() { + @Override + public Tuple2 call(Tuple2 pair) { + return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + } + }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); + + // Try reading the output back as an object file + JavaPairRDD readRDD = sc.sequenceFile(outputDir, IntWritable.class, + Text.class).map(new PairFunction, Integer, String>() { + @Override + public Tuple2 call(Tuple2 pair) { + return new Tuple2(pair._1().get(), pair._2().toString()); + } + }); + Assert.assertEquals(pairs, readRDD.collect()); + } + + @Test + public void writeWithNewAPIHadoopFile() { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output").getAbsolutePath(); + List> pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + + rdd.map(new PairFunction, IntWritable, Text>() { + @Override + public Tuple2 call(Tuple2 pair) { + return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + } + }).saveAsNewAPIHadoopFile(outputDir, IntWritable.class, Text.class, + org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); + + JavaPairRDD output = sc.sequenceFile(outputDir, IntWritable.class, + Text.class); + Assert.assertEquals(pairs.toString(), output.map(new Function, + String>() { + @Override + public String call(Tuple2 x) { + return x.toString(); + } + }).collect().toString()); + } + + @Test + public void readWithNewAPIHadoopFile() throws IOException { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output").getAbsolutePath(); + List> pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + + rdd.map(new PairFunction, IntWritable, Text>() { + @Override + public Tuple2 call(Tuple2 pair) { + return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + } + }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); + + JavaPairRDD output = sc.newAPIHadoopFile(outputDir, + org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, IntWritable.class, + Text.class, new Job().getConfiguration()); + Assert.assertEquals(pairs.toString(), output.map(new Function, + String>() { + @Override + public String call(Tuple2 x) { + return x.toString(); + } + }).collect().toString()); + } + + @Test + public void objectFilesOfInts() { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output").getAbsolutePath(); + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + rdd.saveAsObjectFile(outputDir); + // Try reading the output back as an object file + List expected = Arrays.asList(1, 2, 3, 4); + JavaRDD readRDD = sc.objectFile(outputDir); + Assert.assertEquals(expected, readRDD.collect()); + } + + @Test + public void objectFilesOfComplexTypes() { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output").getAbsolutePath(); + List> 
pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + rdd.saveAsObjectFile(outputDir); + // Try reading the output back as an object file + JavaRDD> readRDD = sc.objectFile(outputDir); + Assert.assertEquals(pairs, readRDD.collect()); + } + + @Test + public void hadoopFile() { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output").getAbsolutePath(); + List> pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + + rdd.map(new PairFunction, IntWritable, Text>() { + @Override + public Tuple2 call(Tuple2 pair) { + return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + } + }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); + + JavaPairRDD output = sc.hadoopFile(outputDir, + SequenceFileInputFormat.class, IntWritable.class, Text.class); + Assert.assertEquals(pairs.toString(), output.map(new Function, + String>() { + @Override + public String call(Tuple2 x) { + return x.toString(); + } + }).collect().toString()); + } + + @Test + public void hadoopFileCompressed() { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output_compressed").getAbsolutePath(); + List> pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + + rdd.map(new PairFunction, IntWritable, Text>() { + @Override + public Tuple2 call(Tuple2 pair) { + return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + } + }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class, + DefaultCodec.class); + + JavaPairRDD output = sc.hadoopFile(outputDir, + SequenceFileInputFormat.class, IntWritable.class, Text.class); + + Assert.assertEquals(pairs.toString(), output.map(new Function, + String>() { + @Override + public String call(Tuple2 x) { + return x.toString(); + } + }).collect().toString()); + } + + @Test + public void zip() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + JavaDoubleRDD doubles = rdd.map(new DoubleFunction() { + @Override + public Double call(Integer x) { + return 1.0 * x; + } + }); + JavaPairRDD zipped = rdd.zip(doubles); + zipped.count(); + } + + @Test + public void zipPartitions() { + JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2); + JavaRDD rdd2 = sc.parallelize(Arrays.asList("1", "2", "3", "4"), 2); + FlatMapFunction2, Iterator, Integer> sizesFn = + new FlatMapFunction2, Iterator, Integer>() { + @Override + public Iterable call(Iterator i, Iterator s) { + int sizeI = 0; + int sizeS = 0; + while (i.hasNext()) { + sizeI += 1; + i.next(); + } + while (s.hasNext()) { + sizeS += 1; + s.next(); + } + return Arrays.asList(sizeI, sizeS); + } + }; + + JavaRDD sizes = rdd1.zipPartitions(rdd2, sizesFn); + Assert.assertEquals("[3, 2, 3, 2]", sizes.collect().toString()); + } + + @Test + public void accumulators() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + + final Accumulator intAccum = sc.intAccumulator(10); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + intAccum.add(x); + } + }); + Assert.assertEquals((Integer) 25, intAccum.value()); + + final Accumulator doubleAccum = sc.doubleAccumulator(10.0); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + doubleAccum.add((double) x); + } + }); + 
Assert.assertEquals((Double) 25.0, doubleAccum.value()); + + // Try a custom accumulator type + AccumulatorParam floatAccumulatorParam = new AccumulatorParam() { + public Float addInPlace(Float r, Float t) { + return r + t; + } + + public Float addAccumulator(Float r, Float t) { + return r + t; + } + + public Float zero(Float initialValue) { + return 0.0f; + } + }; + + final Accumulator floatAccum = sc.accumulator((Float) 10.0f, floatAccumulatorParam); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + floatAccum.add((float) x); + } + }); + Assert.assertEquals((Float) 25.0f, floatAccum.value()); + + // Test the setValue method + floatAccum.setValue(5.0f); + Assert.assertEquals((Float) 5.0f, floatAccum.value()); + } + + @Test + public void keyBy() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2)); + List> s = rdd.keyBy(new Function() { + public String call(Integer t) throws Exception { + return t.toString(); + } + }).collect(); + Assert.assertEquals(new Tuple2("1", 1), s.get(0)); + Assert.assertEquals(new Tuple2("2", 2), s.get(1)); + } + + @Test + public void checkpointAndComputation() { + File tempDir = Files.createTempDir(); + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + sc.setCheckpointDir(tempDir.getAbsolutePath(), true); + Assert.assertEquals(false, rdd.isCheckpointed()); + rdd.checkpoint(); + rdd.count(); // Forces the DAG to cause a checkpoint + Assert.assertEquals(true, rdd.isCheckpointed()); + Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), rdd.collect()); + } + + @Test + public void checkpointAndRestore() { + File tempDir = Files.createTempDir(); + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + sc.setCheckpointDir(tempDir.getAbsolutePath(), true); + Assert.assertEquals(false, rdd.isCheckpointed()); + rdd.checkpoint(); + rdd.count(); // Forces the DAG to cause a checkpoint + Assert.assertEquals(true, rdd.isCheckpointed()); + + Assert.assertTrue(rdd.getCheckpointFile().isPresent()); + JavaRDD recovered = sc.checkpointFile(rdd.getCheckpointFile().get()); + Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect()); + } + + @Test + public void mapOnPairRDD() { + JavaRDD rdd1 = sc.parallelize(Arrays.asList(1,2,3,4)); + JavaPairRDD rdd2 = rdd1.map(new PairFunction() { + @Override + public Tuple2 call(Integer i) throws Exception { + return new Tuple2(i, i % 2); + } + }); + JavaPairRDD rdd3 = rdd2.map( + new PairFunction, Integer, Integer>() { + @Override + public Tuple2 call(Tuple2 in) throws Exception { + return new Tuple2(in._2(), in._1()); + } + }); + Assert.assertEquals(Arrays.asList( + new Tuple2(1, 1), + new Tuple2(0, 2), + new Tuple2(1, 3), + new Tuple2(0, 4)), rdd3.collect()); + + } +} diff --git a/core/src/test/scala/org/apache/spark/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/KryoSerializerSuite.scala new file mode 100644 index 0000000000..d7b23c93fe --- /dev/null +++ b/core/src/test/scala/org/apache/spark/KryoSerializerSuite.scala @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.mutable + +import org.scalatest.FunSuite +import com.esotericsoftware.kryo._ + +import KryoTest._ + +class KryoSerializerSuite extends FunSuite with SharedSparkContext { + test("basic types") { + val ser = (new KryoSerializer).newInstance() + def check[T](t: T) { + assert(ser.deserialize[T](ser.serialize(t)) === t) + } + check(1) + check(1L) + check(1.0f) + check(1.0) + check(1.toByte) + check(1.toShort) + check("") + check("hello") + check(Integer.MAX_VALUE) + check(Integer.MIN_VALUE) + check(java.lang.Long.MAX_VALUE) + check(java.lang.Long.MIN_VALUE) + check[String](null) + check(Array(1, 2, 3)) + check(Array(1L, 2L, 3L)) + check(Array(1.0, 2.0, 3.0)) + check(Array(1.0f, 2.9f, 3.9f)) + check(Array("aaa", "bbb", "ccc")) + check(Array("aaa", "bbb", null)) + check(Array(true, false, true)) + check(Array('a', 'b', 'c')) + check(Array[Int]()) + check(Array(Array("1", "2"), Array("1", "2", "3", "4"))) + } + + test("pairs") { + val ser = (new KryoSerializer).newInstance() + def check[T](t: T) { + assert(ser.deserialize[T](ser.serialize(t)) === t) + } + check((1, 1)) + check((1, 1L)) + check((1L, 1)) + check((1L, 1L)) + check((1.0, 1)) + check((1, 1.0)) + check((1.0, 1.0)) + check((1.0, 1L)) + check((1L, 1.0)) + check((1.0, 1L)) + check(("x", 1)) + check(("x", 1.0)) + check(("x", 1L)) + check((1, "x")) + check((1.0, "x")) + check((1L, "x")) + check(("x", "x")) + } + + test("Scala data structures") { + val ser = (new KryoSerializer).newInstance() + def check[T](t: T) { + assert(ser.deserialize[T](ser.serialize(t)) === t) + } + check(List[Int]()) + check(List[Int](1, 2, 3)) + check(List[String]()) + check(List[String]("x", "y", "z")) + check(None) + check(Some(1)) + check(Some("hi")) + check(mutable.ArrayBuffer(1, 2, 3)) + check(mutable.ArrayBuffer("1", "2", "3")) + check(mutable.Map()) + check(mutable.Map(1 -> "one", 2 -> "two")) + check(mutable.Map("one" -> 1, "two" -> 2)) + check(mutable.HashMap(1 -> "one", 2 -> "two")) + check(mutable.HashMap("one" -> 1, "two" -> 2)) + check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) + check(List(mutable.HashMap("one" -> 1, "two" -> 2),mutable.HashMap(1->"one",2->"two",3->"three"))) + } + + test("custom registrator") { + import KryoTest._ + System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName) + + val ser = (new KryoSerializer).newInstance() + def check[T](t: T) { + assert(ser.deserialize[T](ser.serialize(t)) === t) + } + + check(CaseClass(17, "hello")) + + val c1 = new ClassWithNoArgConstructor + c1.x = 32 + check(c1) + + val c2 = new ClassWithoutNoArgConstructor(47) + check(c2) + + val hashMap = new java.util.HashMap[String, String] + hashMap.put("foo", "bar") + check(hashMap) + + System.clearProperty("spark.kryo.registrator") + } + + test("kryo with collect") { + val control = 1 :: 2 :: Nil + val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)).collect().map(_.x) + assert(control === result.toSeq) + } + + test("kryo with parallelize") { + val control = 1 :: 2 :: Nil + val result = 
sc.parallelize(control.map(new ClassWithoutNoArgConstructor(_))).map(_.x).collect() + assert (control === result.toSeq) + } + + test("kryo with parallelize for specialized tuples") { + assert (sc.parallelize( Array((1, 11), (2, 22), (3, 33)) ).count === 3) + } + + test("kryo with parallelize for primitive arrays") { + assert (sc.parallelize( Array(1, 2, 3) ).count === 3) + } + + test("kryo with collect for specialized tuples") { + assert (sc.parallelize( Array((1, 11), (2, 22), (3, 33)) ).collect().head === (1, 11)) + } + + test("kryo with reduce") { + val control = 1 :: 2 :: Nil + val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)) + .reduce((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x + assert(control.sum === result) + } + + // TODO: this still doesn't work + ignore("kryo with fold") { + val control = 1 :: 2 :: Nil + val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)) + .fold(new ClassWithoutNoArgConstructor(10))((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x + assert(10 + control.sum === result) + } + + override def beforeAll() { + System.setProperty("spark.serializer", "org.apache.spark.KryoSerializer") + System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName) + super.beforeAll() + } + + override def afterAll() { + super.afterAll() + System.clearProperty("spark.kryo.registrator") + System.clearProperty("spark.serializer") + } +} + +object KryoTest { + case class CaseClass(i: Int, s: String) {} + + class ClassWithNoArgConstructor { + var x: Int = 0 + override def equals(other: Any) = other match { + case c: ClassWithNoArgConstructor => x == c.x + case _ => false + } + } + + class ClassWithoutNoArgConstructor(val x: Int) { + override def equals(other: Any) = other match { + case c: ClassWithoutNoArgConstructor => x == c.x + case _ => false + } + } + + class MyRegistrator extends KryoRegistrator { + override def registerClasses(k: Kryo) { + k.register(classOf[CaseClass]) + k.register(classOf[ClassWithNoArgConstructor]) + k.register(classOf[ClassWithoutNoArgConstructor]) + k.register(classOf[java.util.HashMap[_, _]]) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala new file mode 100644 index 0000000000..6ec124da9c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
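// A minimal usage sketch of the Kryo hooks exercised above, mirroring what
// KryoSerializerSuite.beforeAll does; the application name is a hypothetical placeholder.
// Both properties must be set before the SparkContext is created.
System.setProperty("spark.serializer", "org.apache.spark.KryoSerializer")
System.setProperty("spark.kryo.registrator", classOf[KryoTest.MyRegistrator].getName)
val sc = new SparkContext("local", "kryo-example")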
+ */ + +package org.apache.spark + +import org.scalatest.Suite +import org.scalatest.BeforeAndAfterEach +import org.scalatest.BeforeAndAfterAll + +import org.jboss.netty.logging.InternalLoggerFactory +import org.jboss.netty.logging.Slf4JLoggerFactory + +/** Manages a local `sc` {@link SparkContext} variable, correctly stopping it after each test. */ +trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self: Suite => + + @transient var sc: SparkContext = _ + + override def beforeAll() { + InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()); + super.beforeAll() + } + + override def afterEach() { + resetSparkContext() + super.afterEach() + } + + def resetSparkContext() = { + if (sc != null) { + LocalSparkContext.stop(sc) + sc = null + } + } + +} + +object LocalSparkContext { + def stop(sc: SparkContext) { + sc.stop() + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + } + + /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */ + def withSpark[T](sc: SparkContext)(f: SparkContext => T) = { + try { + f(sc) + } finally { + stop(sc) + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala new file mode 100644 index 0000000000..6013320eaa --- /dev/null +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
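// A brief usage sketch of the withSpark helper defined above; the job body here is a
// hypothetical example. The helper guarantees the context is stopped even if that body throws.
val total = LocalSparkContext.withSpark(new SparkContext("local", "example")) { sc =>
  sc.parallelize(1 to 10).reduce(_ + _)
}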
+ */ + +package org.apache.spark + +import org.scalatest.FunSuite + +import akka.actor._ +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.AkkaUtils + +class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { + + test("compressSize") { + assert(MapOutputTracker.compressSize(0L) === 0) + assert(MapOutputTracker.compressSize(1L) === 1) + assert(MapOutputTracker.compressSize(2L) === 8) + assert(MapOutputTracker.compressSize(10L) === 25) + assert((MapOutputTracker.compressSize(1000000L) & 0xFF) === 145) + assert((MapOutputTracker.compressSize(1000000000L) & 0xFF) === 218) + // This last size is bigger than we can encode in a byte, so check that we just return 255 + assert((MapOutputTracker.compressSize(1000000000000000000L) & 0xFF) === 255) + } + + test("decompressSize") { + assert(MapOutputTracker.decompressSize(0) === 0) + for (size <- Seq(2L, 10L, 100L, 50000L, 1000000L, 1000000000L)) { + val size2 = MapOutputTracker.decompressSize(MapOutputTracker.compressSize(size)) + assert(size2 >= 0.99 * size && size2 <= 1.11 * size, + "size " + size + " decompressed to " + size2 + ", which is out of range") + } + } + + test("master start and stop") { + val actorSystem = ActorSystem("test") + val tracker = new MapOutputTracker() + tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker))) + tracker.stop() + } + + test("master register and fetch") { + val actorSystem = ActorSystem("test") + val tracker = new MapOutputTracker() + tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker))) + tracker.registerShuffle(10, 2) + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val compressedSize10000 = MapOutputTracker.compressSize(10000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + val size10000 = MapOutputTracker.decompressSize(compressedSize10000) + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0), + Array(compressedSize1000, compressedSize10000))) + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0), + Array(compressedSize10000, compressedSize1000))) + val statuses = tracker.getServerStatuses(10, 0) + assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000, 0), size1000), + (BlockManagerId("b", "hostB", 1000, 0), size10000))) + tracker.stop() + } + + test("master register and unregister and fetch") { + val actorSystem = ActorSystem("test") + val tracker = new MapOutputTracker() + tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker))) + tracker.registerShuffle(10, 2) + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val compressedSize10000 = MapOutputTracker.compressSize(10000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + val size10000 = MapOutputTracker.decompressSize(compressedSize10000) + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0), + Array(compressedSize1000, compressedSize1000, compressedSize1000))) + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0), + Array(compressedSize10000, compressedSize1000, compressedSize1000))) + + // As if we had two simulatenous fetch failures + tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0)) + tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0)) + + // The remaining reduce task might try to grab the output despite the 
shuffle failure; + // this should cause it to fail, and the scheduler will ignore the failure due to the + // stage already being aborted. + intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } + } + + test("remote fetch") { + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0) + System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + + val masterTracker = new MapOutputTracker() + masterTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker") + + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0) + val slaveTracker = new MapOutputTracker() + slaveTracker.trackerActor = slaveSystem.actorFor( + "akka://spark@localhost:" + boundPort + "/user/MapOutputTracker") + + masterTracker.registerShuffle(10, 1) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + masterTracker.registerMapOutput(10, 0, new MapStatus( + BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000))) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + assert(slaveTracker.getServerStatuses(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000, 0), size1000))) + + masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0)) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + + // failure should be cached + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + } +} diff --git a/core/src/test/scala/org/apache/spark/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/PairRDDFunctionsSuite.scala new file mode 100644 index 0000000000..f79752b34e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/PairRDDFunctionsSuite.scala @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
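// An illustrative sketch of the size-compression scheme the tests above rely on: map output
// sizes are packed into a single byte (the compressSize assertions suggest a logarithmic
// encoding), so the round trip is lossy but, per the decompressSize test, stays within roughly
// -1% / +11% of the original value.
val encoded = MapOutputTracker.compressSize(50000L)      // fits in one byte
val decoded = MapOutputTracker.decompressSize(encoded)   // approximately 50000
assert(decoded >= 0.99 * 50000 && decoded <= 1.11 * 50000)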
+ */ + +package org.apache.spark + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashSet + +import org.scalatest.FunSuite + +import com.google.common.io.Files +import org.apache.spark.SparkContext._ + + +class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { + test("groupByKey") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) + val groups = pairs.groupByKey().collect() + assert(groups.size === 2) + val valuesFor1 = groups.find(_._1 == 1).get._2 + assert(valuesFor1.toList.sorted === List(1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } + + test("groupByKey with duplicates") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val groups = pairs.groupByKey().collect() + assert(groups.size === 2) + val valuesFor1 = groups.find(_._1 == 1).get._2 + assert(valuesFor1.toList.sorted === List(1, 1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } + + test("groupByKey with negative key hash codes") { + val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1))) + val groups = pairs.groupByKey().collect() + assert(groups.size === 2) + val valuesForMinus1 = groups.find(_._1 == -1).get._2 + assert(valuesForMinus1.toList.sorted === List(1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } + + test("groupByKey with many output partitions") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) + val groups = pairs.groupByKey(10).collect() + assert(groups.size === 2) + val valuesFor1 = groups.find(_._1 == 1).get._2 + assert(valuesFor1.toList.sorted === List(1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } + + test("reduceByKey") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.reduceByKey(_+_).collect() + assert(sums.toSet === Set((1, 7), (2, 1))) + } + + test("reduceByKey with collectAsMap") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.reduceByKey(_+_).collectAsMap() + assert(sums.size === 2) + assert(sums(1) === 7) + assert(sums(2) === 1) + } + + test("reduceByKey with many output partitons") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.reduceByKey(_+_, 10).collect() + assert(sums.toSet === Set((1, 7), (2, 1))) + } + + test("reduceByKey with partitioner") { + val p = new Partitioner() { + def numPartitions = 2 + def getPartition(key: Any) = key.asInstanceOf[Int] + } + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p) + val sums = pairs.reduceByKey(_+_) + assert(sums.collect().toSet === Set((1, 4), (0, 1))) + assert(sums.partitioner === Some(p)) + // count the dependencies to make sure there is only 1 ShuffledRDD + val deps = new HashSet[RDD[_]]() + def visit(r: RDD[_]) { + for (dep <- r.dependencies) { + deps += dep.rdd + visit(dep.rdd) + } + } + visit(sums) + assert(deps.size === 2) // ShuffledRDD, ParallelCollection + } + + test("join") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.join(rdd2).collect() + assert(joined.size === 4) + assert(joined.toSet === Set( + (1, (1, 'x')), + (1, (2, 'x')), + (2, (1, 'y')), + (2, (1, 'z')) + )) + } + + test("join 
all-to-all") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3))) + val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y'))) + val joined = rdd1.join(rdd2).collect() + assert(joined.size === 6) + assert(joined.toSet === Set( + (1, (1, 'x')), + (1, (1, 'y')), + (1, (2, 'x')), + (1, (2, 'y')), + (1, (3, 'x')), + (1, (3, 'y')) + )) + } + + test("leftOuterJoin") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.leftOuterJoin(rdd2).collect() + assert(joined.size === 5) + assert(joined.toSet === Set( + (1, (1, Some('x'))), + (1, (2, Some('x'))), + (2, (1, Some('y'))), + (2, (1, Some('z'))), + (3, (1, None)) + )) + } + + test("rightOuterJoin") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.rightOuterJoin(rdd2).collect() + assert(joined.size === 5) + assert(joined.toSet === Set( + (1, (Some(1), 'x')), + (1, (Some(2), 'x')), + (2, (Some(1), 'y')), + (2, (Some(1), 'z')), + (4, (None, 'w')) + )) + } + + test("join with no matches") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w'))) + val joined = rdd1.join(rdd2).collect() + assert(joined.size === 0) + } + + test("join with many output partitions") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.join(rdd2, 10).collect() + assert(joined.size === 4) + assert(joined.toSet === Set( + (1, (1, 'x')), + (1, (2, 'x')), + (2, (1, 'y')), + (2, (1, 'z')) + )) + } + + test("groupWith") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.groupWith(rdd2).collect() + assert(joined.size === 4) + assert(joined.toSet === Set( + (1, (ArrayBuffer(1, 2), ArrayBuffer('x'))), + (2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))), + (3, (ArrayBuffer(1), ArrayBuffer())), + (4, (ArrayBuffer(), ArrayBuffer('w'))) + )) + } + + test("zero-partition RDD") { + val emptyDir = Files.createTempDir() + val file = sc.textFile(emptyDir.getAbsolutePath) + assert(file.partitions.size == 0) + assert(file.collect().toList === Nil) + // Test that a shuffle on the file works, because this used to be a bug + assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) + } + + test("keys and values") { + val rdd = sc.parallelize(Array((1, "a"), (2, "b"))) + assert(rdd.keys.collect().toList === List(1, 2)) + assert(rdd.values.collect().toList === List("a", "b")) + } + + test("default partitioner uses partition size") { + // specify 2000 partitions + val a = sc.makeRDD(Array(1, 2, 3, 4), 2000) + // do a map, which loses the partitioner + val b = a.map(a => (a, (a * 2).toString)) + // then a group by, and see we didn't revert to 2 partitions + val c = b.groupByKey() + assert(c.partitions.size === 2000) + } + + test("default partitioner uses largest partitioner") { + val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2) + val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000) + val c = a.join(b) + assert(c.partitions.size === 2000) + } + + test("subtract") { + val a = sc.parallelize(Array(1, 2, 3), 2) + val b = sc.parallelize(Array(2, 3, 4), 4) + val c = a.subtract(b) + assert(c.collect().toSet === Set(1)) + assert(c.partitions.size === 
a.partitions.size) + } + + test("subtract with narrow dependency") { + // use a deterministic partitioner + val p = new Partitioner() { + def numPartitions = 5 + def getPartition(key: Any) = key.asInstanceOf[Int] + } + // partitionBy so we have a narrow dependency + val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p) + // more partitions/no partitioner so a shuffle dependency + val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) + val c = a.subtract(b) + assert(c.collect().toSet === Set((1, "a"), (3, "c"))) + // Ideally we could keep the original partitioner... + assert(c.partitioner === None) + } + + test("subtractByKey") { + val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2) + val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4) + val c = a.subtractByKey(b) + assert(c.collect().toSet === Set((1, "a"), (1, "a"))) + assert(c.partitions.size === a.partitions.size) + } + + test("subtractByKey with narrow dependency") { + // use a deterministic partitioner + val p = new Partitioner() { + def numPartitions = 5 + def getPartition(key: Any) = key.asInstanceOf[Int] + } + // partitionBy so we have a narrow dependency + val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p) + // more partitions/no partitioner so a shuffle dependency + val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) + val c = a.subtractByKey(b) + assert(c.collect().toSet === Set((1, "a"), (1, "a"))) + assert(c.partitioner.get === p) + } + + test("foldByKey") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val sums = pairs.foldByKey(0)(_+_).collect() + assert(sums.toSet === Set((1, 7), (2, 1))) + } + + test("foldByKey with mutable result type") { + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) + val bufs = pairs.mapValues(v => ArrayBuffer(v)).cache() + // Fold the values using in-place mutation + val sums = bufs.foldByKey(new ArrayBuffer[Int])(_ ++= _).collect() + assert(sums.toSet === Set((1, ArrayBuffer(1, 2, 3, 1)), (2, ArrayBuffer(1)))) + // Check that the mutable objects in the original RDD were not changed + assert(bufs.collect().toSet === Set( + (1, ArrayBuffer(1)), + (1, ArrayBuffer(2)), + (1, ArrayBuffer(3)), + (1, ArrayBuffer(1)), + (2, ArrayBuffer(1)))) + } +} diff --git a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala new file mode 100644 index 0000000000..adbe805916 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala @@ -0,0 +1,28 @@ +package org.apache.spark + +import org.scalatest.FunSuite +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.PartitionPruningRDD + + +class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { + + test("Pruned Partitions inherit locality prefs correctly") { + class TestPartition(i: Int) extends Partition { + def index = i + } + val rdd = new RDD[Int](sc, Nil) { + override protected def getPartitions = { + Array[Partition]( + new TestPartition(1), + new TestPartition(2), + new TestPartition(3)) + } + def compute(split: Partition, context: TaskContext) = {Iterator()} + } + val prunedRDD = PartitionPruningRDD.create(rdd, {x => if (x==2) true else false}) + val p = prunedRDD.partitions(0) + assert(p.index == 2) + assert(prunedRDD.partitions.length == 1) + } +} diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala 
b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala new file mode 100644 index 0000000000..7669cf6fb1 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite +import scala.collection.mutable.ArrayBuffer +import SparkContext._ +import org.apache.spark.util.StatCounter +import scala.math.abs + +class PartitioningSuite extends FunSuite with SharedSparkContext { + + test("HashPartitioner equality") { + val p2 = new HashPartitioner(2) + val p4 = new HashPartitioner(4) + val anotherP4 = new HashPartitioner(4) + assert(p2 === p2) + assert(p4 === p4) + assert(p2 != p4) + assert(p4 != p2) + assert(p4 === anotherP4) + assert(anotherP4 === p4) + } + + test("RangePartitioner equality") { + // Make an RDD where all the elements are the same so that the partition range bounds + // are deterministically all the same. + val rdd = sc.parallelize(Seq(1, 1, 1, 1)).map(x => (x, x)) + + val p2 = new RangePartitioner(2, rdd) + val p4 = new RangePartitioner(4, rdd) + val anotherP4 = new RangePartitioner(4, rdd) + val descendingP2 = new RangePartitioner(2, rdd, false) + val descendingP4 = new RangePartitioner(4, rdd, false) + + assert(p2 === p2) + assert(p4 === p4) + assert(p2 != p4) + assert(p4 != p2) + assert(p4 === anotherP4) + assert(anotherP4 === p4) + assert(descendingP2 === descendingP2) + assert(descendingP4 === descendingP4) + assert(descendingP2 != descendingP4) + assert(descendingP4 != descendingP2) + assert(p2 != descendingP2) + assert(p4 != descendingP4) + assert(descendingP2 != p2) + assert(descendingP4 != p4) + } + + test("HashPartitioner not equal to RangePartitioner") { + val rdd = sc.parallelize(1 to 10).map(x => (x, x)) + val rangeP2 = new RangePartitioner(2, rdd) + val hashP2 = new HashPartitioner(2) + assert(rangeP2 === rangeP2) + assert(hashP2 === hashP2) + assert(hashP2 != rangeP2) + assert(rangeP2 != hashP2) + } + + test("partitioner preservation") { + val rdd = sc.parallelize(1 to 10, 4).map(x => (x, x)) + + val grouped2 = rdd.groupByKey(2) + val grouped4 = rdd.groupByKey(4) + val reduced2 = rdd.reduceByKey(_ + _, 2) + val reduced4 = rdd.reduceByKey(_ + _, 4) + + assert(rdd.partitioner === None) + + assert(grouped2.partitioner === Some(new HashPartitioner(2))) + assert(grouped4.partitioner === Some(new HashPartitioner(4))) + assert(reduced2.partitioner === Some(new HashPartitioner(2))) + assert(reduced4.partitioner === Some(new HashPartitioner(4))) + + assert(grouped2.groupByKey().partitioner === grouped2.partitioner) + assert(grouped2.groupByKey(3).partitioner != grouped2.partitioner) + assert(grouped2.groupByKey(2).partitioner === grouped2.partitioner) + 
assert(grouped4.groupByKey().partitioner === grouped4.partitioner) + assert(grouped4.groupByKey(3).partitioner != grouped4.partitioner) + assert(grouped4.groupByKey(4).partitioner === grouped4.partitioner) + + assert(grouped2.join(grouped4).partitioner === grouped4.partitioner) + assert(grouped2.leftOuterJoin(grouped4).partitioner === grouped4.partitioner) + assert(grouped2.rightOuterJoin(grouped4).partitioner === grouped4.partitioner) + assert(grouped2.cogroup(grouped4).partitioner === grouped4.partitioner) + + assert(grouped2.join(reduced2).partitioner === grouped2.partitioner) + assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner) + assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner) + assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner) + + assert(grouped2.map(_ => 1).partitioner === None) + assert(grouped2.mapValues(_ => 1).partitioner === grouped2.partitioner) + assert(grouped2.flatMapValues(_ => Seq(1)).partitioner === grouped2.partitioner) + assert(grouped2.filter(_._1 > 4).partitioner === grouped2.partitioner) + } + + test("partitioning Java arrays should fail") { + val arrs: RDD[Array[Int]] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => Array(x)) + val arrPairs: RDD[(Array[Int], Int)] = + sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x)) + + assert(intercept[SparkException]{ arrs.distinct() }.getMessage.contains("array")) + // We can't catch all usages of arrays, since they might occur inside other collections: + //assert(fails { arrPairs.distinct() }) + assert(intercept[SparkException]{ arrPairs.partitionBy(new HashPartitioner(2)) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.join(arrPairs) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.leftOuterJoin(arrPairs) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.rightOuterJoin(arrPairs) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.groupByKey() }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.countByKey() }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.countByKeyApprox(1) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.cogroup(arrPairs) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array")) + } + + test("zero-length partitions should be correctly handled") { + // Create RDD with some consecutive empty partitions (including the "first" one) + val rdd: RDD[Double] = sc + .parallelize(Array(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8) + .filter(_ >= 0.0) + + // Run the partitions, including the consecutive empty ones, through StatCounter + val stats: StatCounter = rdd.stats(); + assert(abs(6.0 - stats.sum) < 0.01); + assert(abs(6.0/2 - rdd.mean) < 0.01); + assert(abs(1.0 - rdd.variance) < 0.01); + assert(abs(1.0 - rdd.stdev) < 0.01); + + // Add other tests here for classes that should be able to handle empty partitions correctly + } +} diff --git a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala new file mode 100644 index 0000000000..2e851d892d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or 
more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite +import SparkContext._ + +class PipedRDDSuite extends FunSuite with SharedSparkContext { + + test("basic pipe") { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + + val piped = nums.pipe(Seq("cat")) + + val c = piped.collect() + assert(c.size === 4) + assert(c(0) === "1") + assert(c(1) === "2") + assert(c(2) === "3") + assert(c(3) === "4") + } + + test("advanced pipe") { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val bl = sc.broadcast(List("0")) + + val piped = nums.pipe(Seq("cat"), + Map[String, String](), + (f: String => Unit) => {bl.value.map(f(_));f("\u0001")}, + (i:Int, f: String=> Unit) => f(i + "_")) + + val c = piped.collect() + + assert(c.size === 8) + assert(c(0) === "0") + assert(c(1) === "\u0001") + assert(c(2) === "1_") + assert(c(3) === "2_") + assert(c(4) === "0") + assert(c(5) === "\u0001") + assert(c(6) === "3_") + assert(c(7) === "4_") + + val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2) + val d = nums1.groupBy(str=>str.split("\t")(0)). + pipe(Seq("cat"), + Map[String, String](), + (f: String => Unit) => {bl.value.map(f(_));f("\u0001")}, + (i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}).collect() + assert(d.size === 8) + assert(d(0) === "0") + assert(d(1) === "\u0001") + assert(d(2) === "b\t2_") + assert(d(3) === "b\t4_") + assert(d(4) === "0") + assert(d(5) === "\u0001") + assert(d(6) === "a\t1_") + assert(d(7) === "a\t3_") + } + + test("pipe with env variable") { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA")) + val c = piped.collect() + assert(c.size === 2) + assert(c(0) === "LALALA") + assert(c(1) === "LALALA") + } + + test("pipe with non-zero exit status") { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val piped = nums.pipe(Seq("cat nonexistent_file", "2>", "/dev/null")) + intercept[SparkException] { + piped.collect() + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/RDDSuite.scala b/core/src/test/scala/org/apache/spark/RDDSuite.scala new file mode 100644 index 0000000000..342ba8adb2 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/RDDSuite.scala @@ -0,0 +1,389 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.mutable.HashMap +import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.time.{Span, Millis} +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd._ +import scala.collection.parallel.mutable + +class RDDSuite extends FunSuite with SharedSparkContext { + + test("basic operations") { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + assert(nums.collect().toList === List(1, 2, 3, 4)) + val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) + assert(dups.distinct().count() === 4) + assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses? + assert(dups.distinct.collect === dups.distinct().collect) + assert(dups.distinct(2).collect === dups.distinct().collect) + assert(nums.reduce(_ + _) === 10) + assert(nums.fold(0)(_ + _) === 10) + assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4")) + assert(nums.filter(_ > 2).collect().toList === List(3, 4)) + assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) + assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) + assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) + assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4")) + assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4))) + val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) + assert(partitionSums.collect().toList === List(3, 7)) + + val partitionSumsWithSplit = nums.mapPartitionsWithSplit { + case(split, iter) => Iterator((split, iter.reduceLeft(_ + _))) + } + assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7))) + + val partitionSumsWithIndex = nums.mapPartitionsWithIndex { + case(split, iter) => Iterator((split, iter.reduceLeft(_ + _))) + } + assert(partitionSumsWithIndex.collect().toList === List((0, 3), (1, 7))) + + intercept[UnsupportedOperationException] { + nums.filter(_ > 5).reduce(_ + _) + } + } + + test("SparkContext.union") { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + assert(sc.union(nums).collect().toList === List(1, 2, 3, 4)) + assert(sc.union(nums, nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) + assert(sc.union(Seq(nums)).collect().toList === List(1, 2, 3, 4)) + assert(sc.union(Seq(nums, nums)).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) + } + + test("aggregate") { + val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3))) + type StringMap = HashMap[String, Int] + val emptyMap = new StringMap { + override def default(key: String): Int = 0 + } + val mergeElement: (StringMap, (String, Int)) => StringMap = (map, pair) => { + map(pair._1) += pair._2 + map + } + val mergeMaps: (StringMap, StringMap) => StringMap = (map1, map2) => { + for ((key, value) <- map2) { + map1(key) += value + } + map1 + } + val result = pairs.aggregate(emptyMap)(mergeElement, mergeMaps) + assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5))) + } + + test("basic 
caching") { + val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() + assert(rdd.collect().toList === List(1, 2, 3, 4)) + assert(rdd.collect().toList === List(1, 2, 3, 4)) + assert(rdd.collect().toList === List(1, 2, 3, 4)) + } + + test("caching with failures") { + val onlySplit = new Partition { override def index: Int = 0 } + var shouldFail = true + val rdd = new RDD[Int](sc, Nil) { + override def getPartitions: Array[Partition] = Array(onlySplit) + override val getDependencies = List[Dependency[_]]() + override def compute(split: Partition, context: TaskContext): Iterator[Int] = { + if (shouldFail) { + throw new Exception("injected failure") + } else { + return Array(1, 2, 3, 4).iterator + } + } + }.cache() + val thrown = intercept[Exception]{ + rdd.collect() + } + assert(thrown.getMessage.contains("injected failure")) + shouldFail = false + assert(rdd.collect().toList === List(1, 2, 3, 4)) + } + + test("empty RDD") { + val empty = new EmptyRDD[Int](sc) + assert(empty.count === 0) + assert(empty.collect().size === 0) + + val thrown = intercept[UnsupportedOperationException]{ + empty.reduce(_+_) + } + assert(thrown.getMessage.contains("empty")) + + val emptyKv = new EmptyRDD[(Int, Int)](sc) + val rdd = sc.parallelize(1 to 2, 2).map(x => (x, x)) + assert(rdd.join(emptyKv).collect().size === 0) + assert(rdd.rightOuterJoin(emptyKv).collect().size === 0) + assert(rdd.leftOuterJoin(emptyKv).collect().size === 2) + assert(rdd.cogroup(emptyKv).collect().size === 2) + assert(rdd.union(emptyKv).collect().size === 2) + } + + test("cogrouped RDDs") { + val data = sc.parallelize(1 to 10, 10) + + val coalesced1 = data.coalesce(2) + assert(coalesced1.collect().toList === (1 to 10).toList) + assert(coalesced1.glom().collect().map(_.toList).toList === + List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10))) + + // Check that the narrow dependency is also specified correctly + assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === + List(0, 1, 2, 3, 4)) + assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === + List(5, 6, 7, 8, 9)) + + val coalesced2 = data.coalesce(3) + assert(coalesced2.collect().toList === (1 to 10).toList) + assert(coalesced2.glom().collect().map(_.toList).toList === + List(List(1, 2, 3), List(4, 5, 6), List(7, 8, 9, 10))) + + val coalesced3 = data.coalesce(10) + assert(coalesced3.collect().toList === (1 to 10).toList) + assert(coalesced3.glom().collect().map(_.toList).toList === + (1 to 10).map(x => List(x)).toList) + + // If we try to coalesce into more partitions than the original RDD, it should just + // keep the original number of partitions. 
+ val coalesced4 = data.coalesce(20) + assert(coalesced4.collect().toList === (1 to 10).toList) + assert(coalesced4.glom().collect().map(_.toList).toList === + (1 to 10).map(x => List(x)).toList) + + // we can optionally shuffle to keep the upstream parallel + val coalesced5 = data.coalesce(1, shuffle = true) + assert(coalesced5.dependencies.head.rdd.dependencies.head.rdd.asInstanceOf[ShuffledRDD[_, _, _]] != + null) + } + test("cogrouped RDDs with locality") { + val data3 = sc.makeRDD(List((1,List("a","c")), (2,List("a","b","c")), (3,List("b")))) + val coal3 = data3.coalesce(3) + val list3 = coal3.partitions.map(p => p.asInstanceOf[CoalescedRDDPartition].preferredLocation) + assert(list3.sorted === Array("a","b","c"), "Locality preferences are dropped") + + // RDD with locality preferences spread (non-randomly) over 6 machines, m0 through m5 + val data = sc.makeRDD((1 to 9).map(i => (i, (i to (i+2)).map{ j => "m" + (j%6)}))) + val coalesced1 = data.coalesce(3) + assert(coalesced1.collect().toList.sorted === (1 to 9).toList, "Data got *lost* in coalescing") + + val splits = coalesced1.glom().collect().map(_.toList).toList + assert(splits.length === 3, "Supposed to coalesce to 3 but got " + splits.length) + + assert(splits.forall(_.length >= 1) === true, "Some partitions were empty") + + // If we try to coalesce into more partitions than the original RDD, it should just + // keep the original number of partitions. + val coalesced4 = data.coalesce(20) + val listOfLists = coalesced4.glom().collect().map(_.toList).toList + val sortedList = listOfLists.sortWith{ (x, y) => !x.isEmpty && (y.isEmpty || (x(0) < y(0))) } + assert( sortedList === (1 to 9). + map{x => List(x)}.toList, "Tried coalescing 9 partitions to 20 but didn't get 9 back") + } + + test("cogrouped RDDs with locality, large scale (10K partitions)") { + // large scale experiment + import collection.mutable + val rnd = scala.util.Random + val partitions = 10000 + val numMachines = 50 + val machines = mutable.ListBuffer[String]() + (1 to numMachines).foreach(machines += "m"+_) + + val blocks = (1 to partitions).map(i => + { (i, Array.fill(3)(machines(rnd.nextInt(machines.size))).toList) } ) + + val data2 = sc.makeRDD(blocks) + val coalesced2 = data2.coalesce(numMachines*2) + + // test that you get over 90% locality in each group + val minLocality = coalesced2.partitions + .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) + .foldLeft(1.)((perc, loc) => math.min(perc,loc)) + assert(minLocality >= 0.90, "Expected 90% locality but got " + (minLocality*100.).toInt + "%") + + // test that the groups are load balanced with 100 +/- 20 elements in each + val maxImbalance = coalesced2.partitions + .map(part => part.asInstanceOf[CoalescedRDDPartition].parents.size) + .foldLeft(0)((dev, curr) => math.max(math.abs(100-curr),dev)) + assert(maxImbalance <= 20, "Expected 100 +/- 20 per partition, but got " + maxImbalance) + + val data3 = sc.makeRDD(blocks).map(i => i*2) // derived RDD to test *current* pref locs + val coalesced3 = data3.coalesce(numMachines*2) + val minLocality2 = coalesced3.partitions + .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) + .foldLeft(1.)((perc, loc) => math.min(perc,loc)) + assert(minLocality2 >= 0.90, "Expected 90% locality for derived RDD but got " + + (minLocality2*100.).toInt + "%") + } + + test("zipped RDDs") { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val zipped = nums.zip(nums.map(_ + 1.0)) + assert(zipped.glom().map(_.toList).collect().toList === + List(List((1, 2.0), 
(2, 3.0)), List((3, 4.0), (4, 5.0)))) + + intercept[IllegalArgumentException] { + nums.zip(sc.parallelize(1 to 4, 1)).collect() + } + } + + test("partition pruning") { + val data = sc.parallelize(1 to 10, 10) + // Note that split number starts from 0, so > 8 means only 10th partition left. + val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8) + assert(prunedRdd.partitions.size === 1) + val prunedData = prunedRdd.collect() + assert(prunedData.size === 1) + assert(prunedData(0) === 10) + } + + test("mapWith") { + import java.util.Random + val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) + val randoms = ones.mapWith( + (index: Int) => new Random(index + 42)) + {(t: Int, prng: Random) => prng.nextDouble * t}.collect() + val prn42_3 = { + val prng42 = new Random(42) + prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble() + } + val prn43_3 = { + val prng43 = new Random(43) + prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble() + } + assert(randoms(2) === prn42_3) + assert(randoms(5) === prn43_3) + } + + test("flatMapWith") { + import java.util.Random + val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) + val randoms = ones.flatMapWith( + (index: Int) => new Random(index + 42)) + {(t: Int, prng: Random) => + val random = prng.nextDouble() + Seq(random * t, random * t * 10)}. + collect() + val prn42_3 = { + val prng42 = new Random(42) + prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble() + } + val prn43_3 = { + val prng43 = new Random(43) + prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble() + } + assert(randoms(5) === prn42_3 * 10) + assert(randoms(11) === prn43_3 * 10) + } + + test("filterWith") { + import java.util.Random + val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2) + val sample = ints.filterWith( + (index: Int) => new Random(index + 42)) + {(t: Int, prng: Random) => prng.nextInt(3) == 0}. 
+ collect() + val checkSample = { + val prng42 = new Random(42) + val prng43 = new Random(43) + Array(1, 2, 3, 4, 5, 6).filter{i => + if (i < 4) 0 == prng42.nextInt(3) + else 0 == prng43.nextInt(3)} + } + assert(sample.size === checkSample.size) + for (i <- 0 until sample.size) assert(sample(i) === checkSample(i)) + } + + test("top with predefined ordering") { + val nums = Array.range(1, 100000) + val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2) + val topK = ints.top(5) + assert(topK.size === 5) + assert(topK === nums.reverse.take(5)) + } + + test("top with custom ordering") { + val words = Vector("a", "b", "c", "d") + implicit val ord = implicitly[Ordering[String]].reverse + val rdd = sc.makeRDD(words, 2) + val topK = rdd.top(2) + assert(topK.size === 2) + assert(topK.sorted === Array("b", "a")) + } + + test("takeOrdered with predefined ordering") { + val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + val rdd = sc.makeRDD(nums, 2) + val sortedLowerK = rdd.takeOrdered(5) + assert(sortedLowerK.size === 5) + assert(sortedLowerK === Array(1, 2, 3, 4, 5)) + } + + test("takeOrdered with custom ordering") { + val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + implicit val ord = implicitly[Ordering[Int]].reverse + val rdd = sc.makeRDD(nums, 2) + val sortedTopK = rdd.takeOrdered(5) + assert(sortedTopK.size === 5) + assert(sortedTopK === Array(10, 9, 8, 7, 6)) + assert(sortedTopK === nums.sorted(ord).take(5)) + } + + test("takeSample") { + val data = sc.parallelize(1 to 100, 2) + for (seed <- 1 to 5) { + val sample = data.takeSample(withReplacement=false, 20, seed) + assert(sample.size === 20) // Got exactly 20 elements + assert(sample.toSet.size === 20) // Elements are distinct + assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + } + for (seed <- 1 to 5) { + val sample = data.takeSample(withReplacement=false, 200, seed) + assert(sample.size === 100) // Got only 100 elements + assert(sample.toSet.size === 100) // Elements are distinct + assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + } + for (seed <- 1 to 5) { + val sample = data.takeSample(withReplacement=true, 20, seed) + assert(sample.size === 20) // Got exactly 20 elements + assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + } + for (seed <- 1 to 5) { + val sample = data.takeSample(withReplacement=true, 100, seed) + assert(sample.size === 100) // Got exactly 100 elements + // Chance of getting all distinct elements is astronomically low, so test we got < 100 + assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") + } + for (seed <- 1 to 5) { + val sample = data.takeSample(withReplacement=true, 200, seed) + assert(sample.size === 200) // Got exactly 200 elements + // Chance of getting all distinct elements is still quite low, so test we got < 100 + assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") + } + } + + test("runJob on an invalid partition") { + intercept[IllegalArgumentException] { + sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2), false) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala new file mode 100644 index 0000000000..97cbca09bf --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.Suite +import org.scalatest.BeforeAndAfterAll + +/** Shares a local `SparkContext` between all tests in a suite and closes it at the end */ +trait SharedSparkContext extends BeforeAndAfterAll { self: Suite => + + @transient private var _sc: SparkContext = _ + + def sc: SparkContext = _sc + + override def beforeAll() { + _sc = new SparkContext("local", "test") + super.beforeAll() + } + + override def afterAll() { + if (_sc != null) { + LocalSparkContext.stop(_sc) + _sc = null + } + super.afterAll() + } +} diff --git a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala new file mode 100644 index 0000000000..e121b162ad --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.BeforeAndAfterAll + + +class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll { + + // This test suite should run all tests in ShuffleSuite with Netty shuffle mode. + + override def beforeAll(configMap: Map[String, Any]) { + System.setProperty("spark.shuffle.use.netty", "true") + } + + override def afterAll(configMap: Map[String, Any]) { + System.setProperty("spark.shuffle.use.netty", "false") + } +} diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala new file mode 100644 index 0000000000..357175e89e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers + +import org.apache.spark.SparkContext._ +import org.apache.spark.ShuffleSuite.NonJavaSerializableClass +import org.apache.spark.rdd.{SubtractedRDD, CoGroupedRDD, OrderedRDDFunctions, ShuffledRDD} +import org.apache.spark.util.MutablePair + + +class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { + test("groupByKey without compression") { + try { + System.setProperty("spark.shuffle.compress", "false") + sc = new SparkContext("local", "test") + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 4) + val groups = pairs.groupByKey(4).collect() + assert(groups.size === 2) + val valuesFor1 = groups.find(_._1 == 1).get._2 + assert(valuesFor1.toList.sorted === List(1, 2, 3)) + val valuesFor2 = groups.find(_._1 == 2).get._2 + assert(valuesFor2.toList.sorted === List(1)) + } finally { + System.setProperty("spark.shuffle.compress", "true") + } + } + + test("shuffle non-zero block size") { + sc = new SparkContext("local-cluster[2,1,512]", "test") + val NUM_BLOCKS = 3 + + val a = sc.parallelize(1 to 10, 2) + val b = a.map { x => + (x, new NonJavaSerializableClass(x * 2)) + } + // If the Kryo serializer is not used correctly, the shuffle would fail because the + // default Java serializer cannot handle the non serializable class. + val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)]( + b, new HashPartitioner(NUM_BLOCKS)).setSerializer(classOf[KryoSerializer].getName) + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + + assert(c.count === 10) + + // All blocks must have non-zero size + (0 until NUM_BLOCKS).foreach { id => + val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) + assert(statuses.forall(s => s._2 > 0)) + } + } + + test("shuffle serializer") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") + val a = sc.parallelize(1 to 10, 2) + val b = a.map { x => + (x, new NonJavaSerializableClass(x * 2)) + } + // If the Kryo serializer is not used correctly, the shuffle would fail because the + // default Java serializer cannot handle the non serializable class. + val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)]( + b, new HashPartitioner(3)).setSerializer(classOf[KryoSerializer].getName) + assert(c.count === 10) + } + + test("zero sized blocks") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") + + // 10 partitions from 4 keys + val NUM_BLOCKS = 10 + val a = sc.parallelize(1 to 4, NUM_BLOCKS) + val b = a.map(x => (x, x*2)) + + // NOTE: The default Java serializer doesn't create zero-sized blocks. 
+ // So, use Kryo + val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10)) + .setSerializer(classOf[KryoSerializer].getName) + + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + assert(c.count === 4) + + val blockSizes = (0 until NUM_BLOCKS).flatMap { id => + val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) + statuses.map(x => x._2) + } + val nonEmptyBlocks = blockSizes.filter(x => x > 0) + + // We should have at most 4 non-zero sized partitions + assert(nonEmptyBlocks.size <= 4) + } + + test("zero sized blocks without kryo") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") + + // 10 partitions from 4 keys + val NUM_BLOCKS = 10 + val a = sc.parallelize(1 to 4, NUM_BLOCKS) + val b = a.map(x => (x, x*2)) + + // NOTE: The default Java serializer should create zero-sized blocks + val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10)) + + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + assert(c.count === 4) + + val blockSizes = (0 until NUM_BLOCKS).flatMap { id => + val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) + statuses.map(x => x._2) + } + val nonEmptyBlocks = blockSizes.filter(x => x > 0) + + // We should have at most 4 non-zero sized partitions + assert(nonEmptyBlocks.size <= 4) + } + + test("shuffle using mutable pairs") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") + def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) + val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) + val results = new ShuffledRDD[Int, Int, MutablePair[Int, Int]](pairs, new HashPartitioner(2)) + .collect() + + data.foreach { pair => results should contain (pair) } + } + + test("sorting using mutable pairs") { + // This is not in SortingSuite because of the local cluster setup. 
+ // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") + def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + val data = Array(p(1, 11), p(3, 33), p(100, 100), p(2, 22)) + val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) + val results = new OrderedRDDFunctions[Int, Int, MutablePair[Int, Int]](pairs) + .sortByKey().collect() + results(0) should be (p(1, 11)) + results(1) should be (p(2, 22)) + results(2) should be (p(3, 33)) + results(3) should be (p(100, 100)) + } + + test("cogroup using mutable pairs") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") + def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) + val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"), p(3, "3")) + val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2) + val pairs2: RDD[MutablePair[Int, String]] = sc.parallelize(data2, 2) + val results = new CoGroupedRDD[Int](Seq(pairs1, pairs2), new HashPartitioner(2)).collectAsMap() + + assert(results(1)(0).length === 3) + assert(results(1)(0).contains(1)) + assert(results(1)(0).contains(2)) + assert(results(1)(0).contains(3)) + assert(results(1)(1).length === 2) + assert(results(1)(1).contains("11")) + assert(results(1)(1).contains("12")) + assert(results(2)(0).length === 1) + assert(results(2)(0).contains(1)) + assert(results(2)(1).length === 1) + assert(results(2)(1).contains("22")) + assert(results(3)(0).length === 0) + assert(results(3)(1).contains("3")) + } + + test("subtract mutable pairs") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + sc = new SparkContext("local-cluster[2,1,512]", "test") + def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) + val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1), p(3, 33)) + val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22")) + val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2) + val pairs2: RDD[MutablePair[Int, String]] = sc.parallelize(data2, 2) + val results = new SubtractedRDD(pairs1, pairs2, new HashPartitioner(2)).collect() + results should have length (1) + // the subtracted RDD returns results as Tuple2 + results(0) should be ((3, 33)) + } +} + +object ShuffleSuite { + + def mergeCombineException(x: Int, y: Int): Int = { + throw new SparkException("Exception for map-side combine.") + x + y + } + + class NonJavaSerializableClass(val value: Int) +} diff --git a/core/src/test/scala/org/apache/spark/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/SizeEstimatorSuite.scala new file mode 100644 index 0000000000..214ac74898 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SizeEstimatorSuite.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfterAll +import org.scalatest.PrivateMethodTester + +class DummyClass1 {} + +class DummyClass2 { + val x: Int = 0 +} + +class DummyClass3 { + val x: Int = 0 + val y: Double = 0.0 +} + +class DummyClass4(val d: DummyClass3) { + val x: Int = 0 +} + +object DummyString { + def apply(str: String) : DummyString = new DummyString(str.toArray) +} +class DummyString(val arr: Array[Char]) { + override val hashCode: Int = 0 + // JDK-7 has an extra hash32 field http://hg.openjdk.java.net/jdk7u/jdk7u6/jdk/rev/11987e85555f + @transient val hash32: Int = 0 +} + +class SizeEstimatorSuite + extends FunSuite with BeforeAndAfterAll with PrivateMethodTester { + + var oldArch: String = _ + var oldOops: String = _ + + override def beforeAll() { + // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case + oldArch = System.setProperty("os.arch", "amd64") + oldOops = System.setProperty("spark.test.useCompressedOops", "true") + } + + override def afterAll() { + resetOrClear("os.arch", oldArch) + resetOrClear("spark.test.useCompressedOops", oldOops) + } + + test("simple classes") { + assert(SizeEstimator.estimate(new DummyClass1) === 16) + assert(SizeEstimator.estimate(new DummyClass2) === 16) + assert(SizeEstimator.estimate(new DummyClass3) === 24) + assert(SizeEstimator.estimate(new DummyClass4(null)) === 24) + assert(SizeEstimator.estimate(new DummyClass4(new DummyClass3)) === 48) + } + + // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors + // (Sun vs IBM). Use a DummyString class to make tests deterministic. + test("strings") { + assert(SizeEstimator.estimate(DummyString("")) === 40) + assert(SizeEstimator.estimate(DummyString("a")) === 48) + assert(SizeEstimator.estimate(DummyString("ab")) === 48) + assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56) + } + + test("primitive arrays") { + assert(SizeEstimator.estimate(new Array[Byte](10)) === 32) + assert(SizeEstimator.estimate(new Array[Char](10)) === 40) + assert(SizeEstimator.estimate(new Array[Short](10)) === 40) + assert(SizeEstimator.estimate(new Array[Int](10)) === 56) + assert(SizeEstimator.estimate(new Array[Long](10)) === 96) + assert(SizeEstimator.estimate(new Array[Float](10)) === 56) + assert(SizeEstimator.estimate(new Array[Double](10)) === 96) + assert(SizeEstimator.estimate(new Array[Int](1000)) === 4016) + assert(SizeEstimator.estimate(new Array[Long](1000)) === 8016) + } + + test("object arrays") { + // Arrays containing nulls should just have one pointer per element + assert(SizeEstimator.estimate(new Array[String](10)) === 56) + assert(SizeEstimator.estimate(new Array[AnyRef](10)) === 56) + + // For object arrays with non-null elements, each object should take one pointer plus + // however many bytes that class takes. (Note that Array.fill calls the code in its + // second parameter separately for each object, so we get distinct objects.) 
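+    // Where the expected values below come from (assuming the 64-bit, compressed-oops setup configured
+    // in beforeAll): a 10-element object array is 56 bytes (16-byte header plus 10 * 4-byte references,
+    // padded), so ten 16-byte DummyClass1 instances give 56 + 10 * 16 = 216 bytes and ten 24-byte
+    // DummyClass3 instances give 56 + 10 * 24 = 296 bytes.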
+ assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass1)) === 216) + assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass2)) === 216) + assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass3)) === 296) + assert(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)) === 56) + + // Past size 100, we only sample 100 elements, but we should still get the right size. + assert(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)) === 28016) + + // If an array contains the *same* element many times, we should only count it once. + val d1 = new DummyClass1 + assert(SizeEstimator.estimate(Array.fill(10)(d1)) === 72) // 10 pointers plus 8-byte object + assert(SizeEstimator.estimate(Array.fill(100)(d1)) === 432) // 100 pointers plus 8-byte object + + // Same thing with huge array containing the same element many times. Note that this won't + // return exactly 4032 because it can't tell that *all* the elements will equal the first + // one it samples, but it should be close to that. + + // TODO: If we sample 100 elements, this should always be 4176 ? + val estimatedSize = SizeEstimator.estimate(Array.fill(1000)(d1)) + assert(estimatedSize >= 4000, "Estimated size " + estimatedSize + " should be more than 4000") + assert(estimatedSize <= 4200, "Estimated size " + estimatedSize + " should be less than 4200") + } + + test("32-bit arch") { + val arch = System.setProperty("os.arch", "x86") + + val initialize = PrivateMethod[Unit]('initialize) + SizeEstimator invokePrivate initialize() + + assert(SizeEstimator.estimate(DummyString("")) === 40) + assert(SizeEstimator.estimate(DummyString("a")) === 48) + assert(SizeEstimator.estimate(DummyString("ab")) === 48) + assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56) + + resetOrClear("os.arch", arch) + } + + // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors + // (Sun vs IBM). Use a DummyString class to make tests deterministic. + test("64-bit arch with no compressed oops") { + val arch = System.setProperty("os.arch", "amd64") + val oops = System.setProperty("spark.test.useCompressedOops", "false") + + val initialize = PrivateMethod[Unit]('initialize) + SizeEstimator invokePrivate initialize() + + assert(SizeEstimator.estimate(DummyString("")) === 56) + assert(SizeEstimator.estimate(DummyString("a")) === 64) + assert(SizeEstimator.estimate(DummyString("ab")) === 64) + assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 72) + + resetOrClear("os.arch", arch) + resetOrClear("spark.test.useCompressedOops", oops) + } + + def resetOrClear(prop: String, oldValue: String) { + if (oldValue != null) { + System.setProperty(prop, oldValue) + } else { + System.clearProperty(prop) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/SortingSuite.scala b/core/src/test/scala/org/apache/spark/SortingSuite.scala new file mode 100644 index 0000000000..f4fa9511dd --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SortingSuite.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter +import org.scalatest.matchers.ShouldMatchers +import SparkContext._ + +class SortingSuite extends FunSuite with SharedSparkContext with ShouldMatchers with Logging { + + test("sortByKey") { + val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2) + assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0))) + } + + test("large array") { + val rand = new scala.util.Random() + val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 2) + val sorted = pairs.sortByKey() + assert(sorted.partitions.size === 2) + assert(sorted.collect() === pairArr.sortBy(_._1)) + } + + test("large array with one split") { + val rand = new scala.util.Random() + val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 2) + val sorted = pairs.sortByKey(true, 1) + assert(sorted.partitions.size === 1) + assert(sorted.collect() === pairArr.sortBy(_._1)) + } + + test("large array with many partitions") { + val rand = new scala.util.Random() + val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 2) + val sorted = pairs.sortByKey(true, 20) + assert(sorted.partitions.size === 20) + assert(sorted.collect() === pairArr.sortBy(_._1)) + } + + test("sort descending") { + val rand = new scala.util.Random() + val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 2) + assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) + } + + test("sort descending with one split") { + val rand = new scala.util.Random() + val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 1) + assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) + } + + test("sort descending with many partitions") { + val rand = new scala.util.Random() + val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 2) + assert(pairs.sortByKey(false, 20).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) + } + + test("more partitions than elements") { + val rand = new scala.util.Random() + val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) } + val pairs = sc.parallelize(pairArr, 30) + assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) + } + + test("empty RDD") { + val pairArr = new Array[(Int, Int)](0) + val pairs = sc.parallelize(pairArr, 2) + assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) + } + + test("partition balancing") { + val pairArr = (1 to 1000).map(x => (x, x)).toArray + val sorted = sc.parallelize(pairArr, 4).sortByKey() + assert(sorted.collect() === pairArr.sortBy(_._1)) + val partitions = sorted.collectPartitions() + logInfo("Partition lengths: " + partitions.map(_.length).mkString(", ")) + partitions(0).length should be > 180 + partitions(1).length should be > 180 + partitions(2).length should be > 180 + 
partitions(3).length should be > 180 + partitions(0).last should be < partitions(1).head + partitions(1).last should be < partitions(2).head + partitions(2).last should be < partitions(3).head + } + + test("partition balancing for descending sort") { + val pairArr = (1 to 1000).map(x => (x, x)).toArray + val sorted = sc.parallelize(pairArr, 4).sortByKey(false) + assert(sorted.collect() === pairArr.sortBy(_._1).reverse) + val partitions = sorted.collectPartitions() + logInfo("partition lengths: " + partitions.map(_.length).mkString(", ")) + partitions(0).length should be > 180 + partitions(1).length should be > 180 + partitions(2).length should be > 180 + partitions(3).length should be > 180 + partitions(0).last should be > partitions(1).head + partitions(1).last should be > partitions(2).head + partitions(2).last should be > partitions(3).head + } +} + diff --git a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala new file mode 100644 index 0000000000..939fe51801 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import org.scalatest.FunSuite +import org.apache.spark.SparkContext._ + +class SparkContextInfoSuite extends FunSuite with LocalSparkContext { + test("getPersistentRDDs only returns RDDs that are marked as cached") { + sc = new SparkContext("local", "test") + assert(sc.getPersistentRDDs.isEmpty === true) + + val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) + assert(sc.getPersistentRDDs.isEmpty === true) + + rdd.cache() + assert(sc.getPersistentRDDs.size === 1) + assert(sc.getPersistentRDDs.values.head === rdd) + } + + test("getPersistentRDDs returns an immutable map") { + sc = new SparkContext("local", "test") + val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() + + val myRdds = sc.getPersistentRDDs + assert(myRdds.size === 1) + assert(myRdds.values.head === rdd1) + + val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache() + + // getPersistentRDDs should have 2 RDDs, but myRdds should not change + assert(sc.getPersistentRDDs.size === 2) + assert(myRdds.size === 1) + } + + test("getRDDStorageInfo only reports on RDDs that actually persist data") { + sc = new SparkContext("local", "test") + val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() + + assert(sc.getRDDStorageInfo.size === 0) + + rdd.collect() + assert(sc.getRDDStorageInfo.size === 1) + } +} diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala new file mode 100644 index 0000000000..69383ddfb8 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.concurrent.Semaphore +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicInteger + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter + +import SparkContext._ + +/** + * Holds state shared across task threads in some ThreadingSuite tests. 
+ */ +object ThreadingSuiteState { + val runningThreads = new AtomicInteger + val failed = new AtomicBoolean + + def clear() { + runningThreads.set(0) + failed.set(false) + } +} + +class ThreadingSuite extends FunSuite with LocalSparkContext { + + test("accessing SparkContext from a different thread") { + sc = new SparkContext("local", "test") + val nums = sc.parallelize(1 to 10, 2) + val sem = new Semaphore(0) + @volatile var answer1: Int = 0 + @volatile var answer2: Int = 0 + new Thread { + override def run() { + answer1 = nums.reduce(_ + _) + answer2 = nums.first() // This will run "locally" in the current thread + sem.release() + } + }.start() + sem.acquire() + assert(answer1 === 55) + assert(answer2 === 1) + } + + test("accessing SparkContext from multiple threads") { + sc = new SparkContext("local", "test") + val nums = sc.parallelize(1 to 10, 2) + val sem = new Semaphore(0) + @volatile var ok = true + for (i <- 0 until 10) { + new Thread { + override def run() { + val answer1 = nums.reduce(_ + _) + if (answer1 != 55) { + printf("In thread %d: answer1 was %d\n", i, answer1) + ok = false + } + val answer2 = nums.first() // This will run "locally" in the current thread + if (answer2 != 1) { + printf("In thread %d: answer2 was %d\n", i, answer2) + ok = false + } + sem.release() + } + }.start() + } + sem.acquire(10) + if (!ok) { + fail("One or more threads got the wrong answer from an RDD operation") + } + } + + test("accessing multi-threaded SparkContext from multiple threads") { + sc = new SparkContext("local[4]", "test") + val nums = sc.parallelize(1 to 10, 2) + val sem = new Semaphore(0) + @volatile var ok = true + for (i <- 0 until 10) { + new Thread { + override def run() { + val answer1 = nums.reduce(_ + _) + if (answer1 != 55) { + printf("In thread %d: answer1 was %d\n", i, answer1) + ok = false + } + val answer2 = nums.first() // This will run "locally" in the current thread + if (answer2 != 1) { + printf("In thread %d: answer2 was %d\n", i, answer2) + ok = false + } + sem.release() + } + }.start() + } + sem.acquire(10) + if (!ok) { + fail("One or more threads got the wrong answer from an RDD operation") + } + } + + test("parallel job execution") { + // This test launches two jobs with two threads each on a 4-core local cluster. Each thread + // waits until there are 4 threads running at once, to test that both jobs have been launched. 
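+    // Each of the two jobs runs two tasks (one per partition), so all four tasks must be running
+    // concurrently on the local[4] backend for runningThreads to reach 4 before the timeout.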
+ sc = new SparkContext("local[4]", "test") + val nums = sc.parallelize(1 to 2, 2) + val sem = new Semaphore(0) + ThreadingSuiteState.clear() + for (i <- 0 until 2) { + new Thread { + override def run() { + val ans = nums.map(number => { + val running = ThreadingSuiteState.runningThreads + running.getAndIncrement() + val time = System.currentTimeMillis() + while (running.get() != 4 && System.currentTimeMillis() < time + 1000) { + Thread.sleep(100) + } + if (running.get() != 4) { + println("Waited 1 second without seeing runningThreads = 4 (it was " + + running.get() + "); failing test") + ThreadingSuiteState.failed.set(true) + } + number + }).collect() + assert(ans.toList === List(1, 2)) + sem.release() + } + }.start() + } + sem.acquire(2) + if (ThreadingSuiteState.failed.get()) { + fail("One or more threads didn't see runningThreads = 4") + } + } +} diff --git a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala new file mode 100644 index 0000000000..46a2da1724 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.time.{Span, Millis} +import org.apache.spark.SparkContext._ + +class UnpersistSuite extends FunSuite with LocalSparkContext { + test("unpersist RDD") { + sc = new SparkContext("local", "test") + val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() + rdd.count + assert(sc.persistentRdds.isEmpty === false) + rdd.unpersist() + assert(sc.persistentRdds.isEmpty === true) + + failAfter(Span(3000, Millis)) { + try { + while (! sc.getRDDStorageInfo.isEmpty) { + Thread.sleep(200) + } + } catch { + case _ => { Thread.sleep(10) } + // Do nothing. We might see exceptions because block manager + // is racing this thread to remove entries from the driver. + } + } + assert(sc.getRDDStorageInfo.isEmpty === true) + } +} diff --git a/core/src/test/scala/org/apache/spark/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/UtilsSuite.scala new file mode 100644 index 0000000000..3a908720a8 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/UtilsSuite.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import com.google.common.base.Charsets +import com.google.common.io.Files +import java.io.{ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream, File} +import org.scalatest.FunSuite +import org.apache.commons.io.FileUtils +import scala.util.Random + +class UtilsSuite extends FunSuite { + + test("bytesToString") { + assert(Utils.bytesToString(10) === "10.0 B") + assert(Utils.bytesToString(1500) === "1500.0 B") + assert(Utils.bytesToString(2000000) === "1953.1 KB") + assert(Utils.bytesToString(2097152) === "2.0 MB") + assert(Utils.bytesToString(2306867) === "2.2 MB") + assert(Utils.bytesToString(5368709120L) === "5.0 GB") + assert(Utils.bytesToString(5L * 1024L * 1024L * 1024L * 1024L) === "5.0 TB") + } + + test("copyStream") { + //input array initialization + val bytes = Array.ofDim[Byte](9000) + Random.nextBytes(bytes) + + val os = new ByteArrayOutputStream() + Utils.copyStream(new ByteArrayInputStream(bytes), os) + + assert(os.toByteArray.toList.equals(bytes.toList)) + } + + test("memoryStringToMb") { + assert(Utils.memoryStringToMb("1") === 0) + assert(Utils.memoryStringToMb("1048575") === 0) + assert(Utils.memoryStringToMb("3145728") === 3) + + assert(Utils.memoryStringToMb("1024k") === 1) + assert(Utils.memoryStringToMb("5000k") === 4) + assert(Utils.memoryStringToMb("4024k") === Utils.memoryStringToMb("4024K")) + + assert(Utils.memoryStringToMb("1024m") === 1024) + assert(Utils.memoryStringToMb("5000m") === 5000) + assert(Utils.memoryStringToMb("4024m") === Utils.memoryStringToMb("4024M")) + + assert(Utils.memoryStringToMb("2g") === 2048) + assert(Utils.memoryStringToMb("3g") === Utils.memoryStringToMb("3G")) + + assert(Utils.memoryStringToMb("2t") === 2097152) + assert(Utils.memoryStringToMb("3t") === Utils.memoryStringToMb("3T")) + } + + test("splitCommandString") { + assert(Utils.splitCommandString("") === Seq()) + assert(Utils.splitCommandString("a") === Seq("a")) + assert(Utils.splitCommandString("aaa") === Seq("aaa")) + assert(Utils.splitCommandString("a b c") === Seq("a", "b", "c")) + assert(Utils.splitCommandString(" a b\t c ") === Seq("a", "b", "c")) + assert(Utils.splitCommandString("a 'b c'") === Seq("a", "b c")) + assert(Utils.splitCommandString("a 'b c' d") === Seq("a", "b c", "d")) + assert(Utils.splitCommandString("'b c'") === Seq("b c")) + assert(Utils.splitCommandString("a \"b c\"") === Seq("a", "b c")) + assert(Utils.splitCommandString("a \"b c\" d") === Seq("a", "b c", "d")) + assert(Utils.splitCommandString("\"b c\"") === Seq("b c")) + assert(Utils.splitCommandString("a 'b\" c' \"d' e\"") === Seq("a", "b\" c", "d' e")) + assert(Utils.splitCommandString("a\t'b\nc'\nd") === Seq("a", "b\nc", "d")) + assert(Utils.splitCommandString("a \"b\\\\c\"") === Seq("a", "b\\c")) + assert(Utils.splitCommandString("a \"b\\\"c\"") === Seq("a", "b\"c")) + assert(Utils.splitCommandString("a 'b\\\"c'") === Seq("a", "b\\\"c")) + assert(Utils.splitCommandString("'a'b") === Seq("ab")) + assert(Utils.splitCommandString("'a''b'") === Seq("ab")) + assert(Utils.splitCommandString("\"a\"b") === Seq("ab")) + 
assert(Utils.splitCommandString("\"a\"\"b\"") === Seq("ab")) + assert(Utils.splitCommandString("''") === Seq("")) + assert(Utils.splitCommandString("\"\"") === Seq("")) + } + + test("string formatting of time durations") { + val second = 1000 + val minute = second * 60 + val hour = minute * 60 + def str = Utils.msDurationToString(_) + + assert(str(123) === "123 ms") + assert(str(second) === "1.0 s") + assert(str(second + 462) === "1.5 s") + assert(str(hour) === "1.00 h") + assert(str(minute) === "1.0 m") + assert(str(minute + 4 * second + 34) === "1.1 m") + assert(str(10 * hour + minute + 4 * second) === "10.02 h") + assert(str(10 * hour + 59 * minute + 59 * second + 999) === "11.00 h") + } + + test("reading offset bytes of a file") { + val tmpDir2 = Files.createTempDir() + val f1Path = tmpDir2 + "/f1" + val f1 = new FileOutputStream(f1Path) + f1.write("1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(Charsets.UTF_8)) + f1.close() + + // Read first few bytes + assert(Utils.offsetBytes(f1Path, 0, 5) === "1\n2\n3") + + // Read some middle bytes + assert(Utils.offsetBytes(f1Path, 4, 11) === "3\n4\n5\n6") + + // Read last few bytes + assert(Utils.offsetBytes(f1Path, 12, 18) === "7\n8\n9\n") + + // Read some nonexistent bytes in the beginning + assert(Utils.offsetBytes(f1Path, -5, 5) === "1\n2\n3") + + // Read some nonexistent bytes at the end + assert(Utils.offsetBytes(f1Path, 12, 22) === "7\n8\n9\n") + + // Read some nonexistent bytes on both ends + assert(Utils.offsetBytes(f1Path, -3, 25) === "1\n2\n3\n4\n5\n6\n7\n8\n9\n") + + FileUtils.deleteDirectory(tmpDir2) + } +} + diff --git a/core/src/test/scala/org/apache/spark/ZippedPartitionsSuite.scala b/core/src/test/scala/org/apache/spark/ZippedPartitionsSuite.scala new file mode 100644 index 0000000000..618b9c113b --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ZippedPartitionsSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import scala.collection.immutable.NumericRange + +import org.scalatest.FunSuite +import org.scalatest.prop.Checkers +import org.scalacheck.Arbitrary._ +import org.scalacheck.Gen +import org.scalacheck.Prop._ + +import SparkContext._ + + +object ZippedPartitionsSuite { + def procZippedData(i: Iterator[Int], s: Iterator[String], d: Iterator[Double]) : Iterator[Int] = { + Iterator(i.toArray.size, s.toArray.size, d.toArray.size) + } +} + +class ZippedPartitionsSuite extends FunSuite with SharedSparkContext { + test("print sizes") { + val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2) + val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2) + val data3 = sc.makeRDD(Array(1.0, 2.0), 2) + + val zippedRDD = data1.zipPartitions(data2, data3)(ZippedPartitionsSuite.procZippedData) + + val obtainedSizes = zippedRDD.collect() + val expectedSizes = Array(2, 3, 1, 2, 3, 1) + assert(obtainedSizes.size == 6) + assert(obtainedSizes.zip(expectedSizes).forall(x => x._1 == x._2)) + } +} diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala new file mode 100644 index 0000000000..fd6f69041a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.io + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + +import org.scalatest.FunSuite + + +class CompressionCodecSuite extends FunSuite { + + def testCodec(codec: CompressionCodec) { + // Write 1000 integers to the output stream, compressed. + val outputStream = new ByteArrayOutputStream() + val out = codec.compressedOutputStream(outputStream) + for (i <- 1 until 1000) { + out.write(i % 256) + } + out.close() + + // Read the 1000 integers back. 
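// Aside: a self-contained sketch of the same in-memory compress/decompress round trip, using
// java.util.zip GZIP streams as a stand-in codec (illustrative only; the suite itself goes through
// codec.compressedOutputStream and codec.compressedInputStream, and its own read-back code follows).
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

object CompressionRoundTripSketch {
  def main(args: Array[String]) {
    val buffer = new ByteArrayOutputStream()
    val out = new GZIPOutputStream(buffer)
    for (i <- 1 until 1000) out.write(i % 256)            // write 999 byte values, compressed
    out.close()

    val in = new GZIPInputStream(new ByteArrayInputStream(buffer.toByteArray))
    for (i <- 1 until 1000) assert(in.read() == i % 256)  // read the same values back in order
    in.close()
  }
}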
+ val inputStream = new ByteArrayInputStream(outputStream.toByteArray) + val in = codec.compressedInputStream(inputStream) + for (i <- 1 until 1000) { + assert(in.read() === i % 256) + } + in.close() + } + + test("default compression codec") { + val codec = CompressionCodec.createCodec() + assert(codec.getClass === classOf[SnappyCompressionCodec]) + testCodec(codec) + } + + test("lzf compression codec") { + val codec = CompressionCodec.createCodec(classOf[LZFCompressionCodec].getName) + assert(codec.getClass === classOf[LZFCompressionCodec]) + testCodec(codec) + } + + test("snappy compression codec") { + val codec = CompressionCodec.createCodec(classOf[SnappyCompressionCodec].getName) + assert(codec.getClass === classOf[SnappyCompressionCodec]) + testCodec(codec) + } +} diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala new file mode 100644 index 0000000000..58c94a162d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.metrics + +import org.scalatest.{BeforeAndAfter, FunSuite} + +class MetricsConfigSuite extends FunSuite with BeforeAndAfter { + var filePath: String = _ + + before { + filePath = getClass.getClassLoader.getResource("test_metrics_config.properties").getFile() + } + + test("MetricsConfig with default properties") { + val conf = new MetricsConfig(Option("dummy-file")) + conf.initialize() + + assert(conf.properties.size() === 5) + assert(conf.properties.getProperty("test-for-dummy") === null) + + val property = conf.getInstance("random") + assert(property.size() === 3) + assert(property.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet") + assert(property.getProperty("sink.servlet.uri") === "/metrics/json") + assert(property.getProperty("sink.servlet.sample") === "false") + } + + test("MetricsConfig with properties set") { + val conf = new MetricsConfig(Option(filePath)) + conf.initialize() + + val masterProp = conf.getInstance("master") + assert(masterProp.size() === 6) + assert(masterProp.getProperty("sink.console.period") === "20") + assert(masterProp.getProperty("sink.console.unit") === "minutes") + assert(masterProp.getProperty("source.jvm.class") === "org.apache.spark.metrics.source.JvmSource") + assert(masterProp.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet") + assert(masterProp.getProperty("sink.servlet.uri") === "/metrics/master/json") + assert(masterProp.getProperty("sink.servlet.sample") === "false") + + val workerProp = conf.getInstance("worker") + assert(workerProp.size() === 6) + assert(workerProp.getProperty("sink.console.period") === "10") + assert(workerProp.getProperty("sink.console.unit") === "seconds") + assert(workerProp.getProperty("source.jvm.class") === "org.apache.spark.metrics.source.JvmSource") + assert(workerProp.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet") + assert(workerProp.getProperty("sink.servlet.uri") === "/metrics/json") + assert(workerProp.getProperty("sink.servlet.sample") === "false") + } + + test("MetricsConfig with subProperties") { + val conf = new MetricsConfig(Option(filePath)) + conf.initialize() + + val propCategories = conf.propertyCategories + assert(propCategories.size === 3) + + val masterProp = conf.getInstance("master") + val sourceProps = conf.subProperties(masterProp, MetricsSystem.SOURCE_REGEX) + assert(sourceProps.size === 1) + assert(sourceProps("jvm").getProperty("class") === "org.apache.spark.metrics.source.JvmSource") + + val sinkProps = conf.subProperties(masterProp, MetricsSystem.SINK_REGEX) + assert(sinkProps.size === 2) + assert(sinkProps.contains("console")) + assert(sinkProps.contains("servlet")) + + val consoleProps = sinkProps("console") + assert(consoleProps.size() === 2) + + val servletProps = sinkProps("servlet") + assert(servletProps.size() === 3) + } +} diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala new file mode 100644 index 0000000000..7181333adf --- /dev/null +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics + +import org.scalatest.{BeforeAndAfter, FunSuite} +import org.apache.spark.deploy.master.MasterSource + +class MetricsSystemSuite extends FunSuite with BeforeAndAfter { + var filePath: String = _ + + before { + filePath = getClass.getClassLoader.getResource("test_metrics_system.properties").getFile() + System.setProperty("spark.metrics.conf", filePath) + } + + test("MetricsSystem with default config") { + val metricsSystem = MetricsSystem.createMetricsSystem("default") + val sources = metricsSystem.sources + val sinks = metricsSystem.sinks + + assert(sources.length === 0) + assert(sinks.length === 0) + assert(!metricsSystem.getServletHandlers.isEmpty) + } + + test("MetricsSystem with sources add") { + val metricsSystem = MetricsSystem.createMetricsSystem("test") + val sources = metricsSystem.sources + val sinks = metricsSystem.sinks + + assert(sources.length === 0) + assert(sinks.length === 1) + assert(!metricsSystem.getServletHandlers.isEmpty) + + val source = new MasterSource(null) + metricsSystem.registerSource(source) + assert(sources.length === 1) + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala new file mode 100644 index 0000000000..3d39a31252 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import org.scalatest.{ BeforeAndAfter, FunSuite } +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.JdbcRDD +import java.sql._ + +class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { + + before { + Class.forName("org.apache.derby.jdbc.EmbeddedDriver") + val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true") + try { + val create = conn.createStatement + create.execute(""" + CREATE TABLE FOO( + ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1), + DATA INTEGER + )""") + create.close + val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)") + (1 to 100).foreach { i => + insert.setInt(1, i * 2) + insert.executeUpdate + } + insert.close + } catch { + case e: SQLException if e.getSQLState == "X0Y32" => + // table exists + } finally { + conn.close + } + } + + test("basic functionality") { + sc = new SparkContext("local", "test") + val rdd = new JdbcRDD( + sc, + () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") }, + "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?", + 1, 100, 3, + (r: ResultSet) => { r.getInt(1) } ).cache + + assert(rdd.count === 100) + assert(rdd.reduce(_+_) === 10100) + } + + after { + try { + DriverManager.getConnection("jdbc:derby:;shutdown=true") + } catch { + case se: SQLException if se.getSQLState == "XJ015" => + // normal shutdown + } + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala new file mode 100644 index 0000000000..a80afdee7e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import scala.collection.immutable.NumericRange + +import org.scalatest.FunSuite +import org.scalatest.prop.Checkers +import org.scalacheck.Arbitrary._ +import org.scalacheck.Gen +import org.scalacheck.Prop._ + +class ParallelCollectionSplitSuite extends FunSuite with Checkers { + test("one element per slice") { + val data = Array(1, 2, 3) + val slices = ParallelCollectionRDD.slice(data, 3) + assert(slices.size === 3) + assert(slices(0).mkString(",") === "1") + assert(slices(1).mkString(",") === "2") + assert(slices(2).mkString(",") === "3") + } + + test("one slice") { + val data = Array(1, 2, 3) + val slices = ParallelCollectionRDD.slice(data, 1) + assert(slices.size === 1) + assert(slices(0).mkString(",") === "1,2,3") + } + + test("equal slices") { + val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9) + val slices = ParallelCollectionRDD.slice(data, 3) + assert(slices.size === 3) + assert(slices(0).mkString(",") === "1,2,3") + assert(slices(1).mkString(",") === "4,5,6") + assert(slices(2).mkString(",") === "7,8,9") + } + + test("non-equal slices") { + val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + val slices = ParallelCollectionRDD.slice(data, 3) + assert(slices.size === 3) + assert(slices(0).mkString(",") === "1,2,3") + assert(slices(1).mkString(",") === "4,5,6") + assert(slices(2).mkString(",") === "7,8,9,10") + } + + test("splitting exclusive range") { + val data = 0 until 100 + val slices = ParallelCollectionRDD.slice(data, 3) + assert(slices.size === 3) + assert(slices(0).mkString(",") === (0 to 32).mkString(",")) + assert(slices(1).mkString(",") === (33 to 65).mkString(",")) + assert(slices(2).mkString(",") === (66 to 99).mkString(",")) + } + + test("splitting inclusive range") { + val data = 0 to 100 + val slices = ParallelCollectionRDD.slice(data, 3) + assert(slices.size === 3) + assert(slices(0).mkString(",") === (0 to 32).mkString(",")) + assert(slices(1).mkString(",") === (33 to 66).mkString(",")) + assert(slices(2).mkString(",") === (67 to 100).mkString(",")) + } + + test("empty data") { + val data = new Array[Int](0) + val slices = ParallelCollectionRDD.slice(data, 5) + assert(slices.size === 5) + for (slice <- slices) assert(slice.size === 0) + } + + test("zero slices") { + val data = Array(1, 2, 3) + intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, 0) } + } + + test("negative number of slices") { + val data = Array(1, 2, 3) + intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, -5) } + } + + test("exclusive ranges sliced into ranges") { + val data = 1 until 100 + val slices = ParallelCollectionRDD.slice(data, 3) + assert(slices.size === 3) + assert(slices.map(_.size).reduceLeft(_+_) === 99) + assert(slices.forall(_.isInstanceOf[Range])) + } + + test("inclusive ranges sliced into ranges") { + val data = 1 to 100 + val slices = ParallelCollectionRDD.slice(data, 3) + assert(slices.size === 3) + assert(slices.map(_.size).reduceLeft(_+_) === 100) + assert(slices.forall(_.isInstanceOf[Range])) + } + + test("large ranges don't overflow") { + val N = 100 * 1000 * 1000 + val data = 0 until N + val slices = ParallelCollectionRDD.slice(data, 40) + assert(slices.size === 40) + for (i <- 0 until 40) { + assert(slices(i).isInstanceOf[Range]) + val range = slices(i).asInstanceOf[Range] + assert(range.start === i * (N / 40), "slice " + i + " start") + assert(range.end === (i+1) * (N / 40), "slice " + i + " end") + assert(range.step === 1, "slice " + i + " step") + } + } + + test("random array tests") { + val gen = 
for { + d <- arbitrary[List[Int]] + n <- Gen.choose(1, 100) + } yield (d, n) + val prop = forAll(gen) { + (tuple: (List[Int], Int)) => + val d = tuple._1 + val n = tuple._2 + val slices = ParallelCollectionRDD.slice(d, n) + ("n slices" |: slices.size == n) && + ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && + ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) + } + check(prop) + } + + test("random exclusive range tests") { + val gen = for { + a <- Gen.choose(-100, 100) + b <- Gen.choose(-100, 100) + step <- Gen.choose(-5, 5) suchThat (_ != 0) + n <- Gen.choose(1, 100) + } yield (a until b by step, n) + val prop = forAll(gen) { + case (d: Range, n: Int) => + val slices = ParallelCollectionRDD.slice(d, n) + ("n slices" |: slices.size == n) && + ("all ranges" |: slices.forall(_.isInstanceOf[Range])) && + ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && + ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) + } + check(prop) + } + + test("random inclusive range tests") { + val gen = for { + a <- Gen.choose(-100, 100) + b <- Gen.choose(-100, 100) + step <- Gen.choose(-5, 5) suchThat (_ != 0) + n <- Gen.choose(1, 100) + } yield (a to b by step, n) + val prop = forAll(gen) { + case (d: Range, n: Int) => + val slices = ParallelCollectionRDD.slice(d, n) + ("n slices" |: slices.size == n) && + ("all ranges" |: slices.forall(_.isInstanceOf[Range])) && + ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && + ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) + } + check(prop) + } + + test("exclusive ranges of longs") { + val data = 1L until 100L + val slices = ParallelCollectionRDD.slice(data, 3) + assert(slices.size === 3) + assert(slices.map(_.size).reduceLeft(_+_) === 99) + assert(slices.forall(_.isInstanceOf[NumericRange[_]])) + } + + test("inclusive ranges of longs") { + val data = 1L to 100L + val slices = ParallelCollectionRDD.slice(data, 3) + assert(slices.size === 3) + assert(slices.map(_.size).reduceLeft(_+_) === 100) + assert(slices.forall(_.isInstanceOf[NumericRange[_]])) + } + + test("exclusive ranges of doubles") { + val data = 1.0 until 100.0 by 1.0 + val slices = ParallelCollectionRDD.slice(data, 3) + assert(slices.size === 3) + assert(slices.map(_.size).reduceLeft(_+_) === 99) + assert(slices.forall(_.isInstanceOf[NumericRange[_]])) + } + + test("inclusive ranges of doubles") { + val data = 1.0 to 100.0 by 1.0 + val slices = ParallelCollectionRDD.slice(data, 3) + assert(slices.size === 3) + assert(slices.map(_.size).reduceLeft(_+_) === 100) + assert(slices.forall(_.isInstanceOf[NumericRange[_]])) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala new file mode 100644 index 0000000000..94df282b28 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import scala.collection.mutable.{Map, HashMap} + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter + +import org.apache.spark.LocalSparkContext +import org.apache.spark.MapOutputTracker +import org.apache.spark.RDD +import org.apache.spark.SparkContext +import org.apache.spark.Partition +import org.apache.spark.TaskContext +import org.apache.spark.{Dependency, ShuffleDependency, OneToOneDependency} +import org.apache.spark.{FetchFailed, Success, TaskEndReason} +import org.apache.spark.storage.{BlockManagerId, BlockManagerMaster} + +import org.apache.spark.scheduler.cluster.Pool +import org.apache.spark.scheduler.cluster.SchedulingMode +import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode + +/** + * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler + * rather than spawning an event loop thread as happens in the real code. They use hand-written stub + * implementations of two classes that DAGScheduler interacts with: TaskScheduler (to which TaskSets are + * submitted) and BlockManagerMaster (from which cache locations are retrieved and to which dead + * host notifications are sent). In addition, tests may check for side effects on a real (non-stubbed) + * MapOutputTracker instance. + * + * Tests primarily consist of running DAGScheduler#processEvent and + * DAGScheduler#submitWaitingStages (via test utility functions like runEvent or complete) + * and capturing the resulting TaskSets from the stub TaskScheduler. + */ +class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { + + /** Set of TaskSets the DAGScheduler has requested executed. */ + val taskSets = scala.collection.mutable.Buffer[TaskSet]() + val taskScheduler = new TaskScheduler() { + override def rootPool: Pool = null + override def schedulingMode: SchedulingMode = SchedulingMode.NONE + override def start() = {} + override def stop() = {} + override def submitTasks(taskSet: TaskSet) = { + // normally done by TaskSetManager + taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) + taskSets += taskSet + } + override def setListener(listener: TaskSchedulerListener) = {} + override def defaultParallelism() = 2 + } + + var mapOutputTracker: MapOutputTracker = null + var scheduler: DAGScheduler = null + + /** + * Set of cache locations to return from our stub BlockManagerMaster. + * Keys are (rdd ID, partition ID). Anything not present will return an empty + * list of cache locations silently.
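// Aside: a tiny sketch of the lookup performed by the stubbed BlockManagerMaster defined just
// below, assuming block ids of the form "rdd_<rddId>_<splitIndex>" (the convention the stub
// parses); unknown keys silently resolve to an empty location list. Names are illustrative only.
object CacheLocationSketch {
  def locationsFor(blockId: String, cached: Map[(Int, Int), Seq[String]]): Seq[String] = {
    val pieces = blockId.split("_")
    if (pieces(0) == "rdd") cached.getOrElse((pieces(1).toInt, pieces(2).toInt), Seq()) else Seq()
  }

  def main(args: Array[String]) {
    val cached = Map((1, 0) -> Seq("hostA", "hostB"))
    assert(locationsFor("rdd_1_0", cached) == Seq("hostA", "hostB"))
    assert(locationsFor("rdd_1_7", cached) == Seq())      // unknown partition: empty, no error
    assert(locationsFor("broadcast_2", cached) == Seq())  // non-RDD block: empty
  }
}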
+ */ + val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]] + // stub out BlockManagerMaster.getLocations to use our cacheLocations + val blockManagerMaster = new BlockManagerMaster(null) { + override def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { + blockIds.map { name => + val pieces = name.split("_") + if (pieces(0) == "rdd") { + val key = pieces(1).toInt -> pieces(2).toInt + cacheLocations.getOrElse(key, Seq()) + } else { + Seq() + } + }.toSeq + } + override def removeExecutor(execId: String) { + // don't need to propagate to the driver, which we don't have + } + } + + /** The list of results that DAGScheduler has collected. */ + val results = new HashMap[Int, Any]() + var failure: Exception = _ + val listener = new JobListener() { + override def taskSucceeded(index: Int, result: Any) = results.put(index, result) + override def jobFailed(exception: Exception) = { failure = exception } + } + + before { + sc = new SparkContext("local", "DAGSchedulerSuite") + taskSets.clear() + cacheLocations.clear() + results.clear() + mapOutputTracker = new MapOutputTracker() + scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) { + override def runLocally(job: ActiveJob) { + // don't bother with the thread while unit testing + runLocallyWithinThread(job) + } + } + } + + after { + scheduler.stop() + } + + /** + * Type of RDD we use for testing. Note that we should never call the real RDD compute methods. + * This is a pair RDD type so it can always be used in ShuffleDependencies. + */ + type MyRDD = RDD[(Int, Int)] + + /** + * Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and + * preferredLocations (if any) that are passed to them. They are deliberately not executable + * so we can test that DAGScheduler does not try to execute RDDs locally. + */ + private def makeRdd( + numPartitions: Int, + dependencies: List[Dependency[_]], + locations: Seq[Seq[String]] = Nil + ): MyRDD = { + val maxPartition = numPartitions - 1 + return new MyRDD(sc, dependencies) { + override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = + throw new RuntimeException("should not be reached") + override def getPartitions = (0 to maxPartition).map(i => new Partition { + override def index = i + }).toArray + override def getPreferredLocations(split: Partition): Seq[String] = + if (locations.isDefinedAt(split.index)) + locations(split.index) + else + Nil + override def toString: String = "DAGSchedulerSuiteRDD " + id + } + } + + /** + * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting + * the scheduler not to exit. + * + * After processing the event, submit waiting stages as is done on most iterations of the + * DAGScheduler event loop. + */ + private def runEvent(event: DAGSchedulerEvent) { + assert(!scheduler.processEvent(event)) + scheduler.submitWaitingStages() + } + + /** + * When we submit dummy Jobs, this is the compute function we supply. Except in a local test + * below, we do not expect this function to ever be executed; instead, we will return results + * directly through CompletionEvents. + */ + private val jobComputeFunc = (context: TaskContext, it: Iterator[(_)]) => + it.next.asInstanceOf[Tuple2[_, _]]._1 + + /** Send the given CompletionEvent messages for the tasks in the TaskSet. 
*/ + private def complete(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) { + assert(taskSet.tasks.size >= results.size) + for ((result, i) <- results.zipWithIndex) { + if (i < taskSet.tasks.size) { + runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any](), null, null)) + } + } + } + + /** Sends the rdd to the scheduler for scheduling. */ + private def submit( + rdd: RDD[_], + partitions: Array[Int], + func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, + allowLocal: Boolean = false, + listener: JobListener = listener) { + runEvent(JobSubmitted(rdd, func, partitions, allowLocal, null, listener)) + } + + /** Sends TaskSetFailed to the scheduler. */ + private def failed(taskSet: TaskSet, message: String) { + runEvent(TaskSetFailed(taskSet, message)) + } + + test("zero split job") { + val rdd = makeRdd(0, Nil) + var numResults = 0 + val fakeListener = new JobListener() { + override def taskSucceeded(partition: Int, value: Any) = numResults += 1 + override def jobFailed(exception: Exception) = throw exception + } + submit(rdd, Array(), listener = fakeListener) + assert(numResults === 0) + } + + test("run trivial job") { + val rdd = makeRdd(1, Nil) + submit(rdd, Array(0)) + complete(taskSets(0), List((Success, 42))) + assert(results === Map(0 -> 42)) + } + + test("local job") { + val rdd = new MyRDD(sc, Nil) { + override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = + Array(42 -> 0).iterator + override def getPartitions = Array( new Partition { override def index = 0 } ) + override def getPreferredLocations(split: Partition) = Nil + override def toString = "DAGSchedulerSuite Local RDD" + } + runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener)) + assert(results === Map(0 -> 42)) + } + + test("run trivial job w/ dependency") { + val baseRdd = makeRdd(1, Nil) + val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd))) + submit(finalRdd, Array(0)) + complete(taskSets(0), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + } + + test("cache location preferences w/ dependency") { + val baseRdd = makeRdd(1, Nil) + val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd))) + cacheLocations(baseRdd.id -> 0) = + Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")) + submit(finalRdd, Array(0)) + val taskSet = taskSets(0) + assertLocations(taskSet, Seq(Seq("hostA", "hostB"))) + complete(taskSet, Seq((Success, 42))) + assert(results === Map(0 -> 42)) + } + + test("trivial job failure") { + submit(makeRdd(1, Nil), Array(0)) + failed(taskSets(0), "some failure") + assert(failure.getMessage === "Job failed: some failure") + } + + test("run trivial shuffle") { + val shuffleMapRdd = makeRdd(2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = makeRdd(1, List(shuffleDep)) + submit(reduceRdd, Array(0)) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)))) + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + complete(taskSets(1), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + } + + test("run trivial shuffle with fetch failure") { + val shuffleMapRdd = makeRdd(2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = makeRdd(2, List(shuffleDep)) + submit(reduceRdd, Array(0, 1)) + complete(taskSets(0), Seq( + 
(Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)))) + // the 2nd ResultTask failed + complete(taskSets(1), Seq( + (Success, 42), + (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null))) + // this will get called + // blockManagerMaster.removeExecutor("exec-hostA") + // ask the scheduler to try it again + scheduler.resubmitFailedStages() + // have the 2nd attempt pass + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) + // we can see both result blocks now + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) + complete(taskSets(3), Seq((Success, 43))) + assert(results === Map(0 -> 42, 1 -> 43)) + } + + test("ignore late map task completions") { + val shuffleMapRdd = makeRdd(2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = makeRdd(2, List(shuffleDep)) + submit(reduceRdd, Array(0, 1)) + // pretend we were told hostA went away + val oldEpoch = mapOutputTracker.getEpoch + runEvent(ExecutorLost("exec-hostA")) + val newEpoch = mapOutputTracker.getEpoch + assert(newEpoch > oldEpoch) + val noAccum = Map[Long, Any]() + val taskSet = taskSets(0) + // should be ignored for being too old + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null)) + // should work because it's a non-failed host + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum, null, null)) + // should be ignored for being too old + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null)) + // should work because it's a new epoch + taskSet.tasks(1).epoch = newEpoch + runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null, null)) + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) + complete(taskSets(1), Seq((Success, 42), (Success, 43))) + assert(results === Map(0 -> 42, 1 -> 43)) + } + + test("run trivial shuffle with out-of-band failure and retry") { + val shuffleMapRdd = makeRdd(2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = makeRdd(1, List(shuffleDep)) + submit(reduceRdd, Array(0)) + // blockManagerMaster.removeExecutor("exec-hostA") + // pretend we were told hostA went away + runEvent(ExecutorLost("exec-hostA")) + // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks + // rather than marking it is as failed and waiting. 
+ complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)))) + // have hostC complete the resubmitted task + complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + complete(taskSets(2), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + } + + test("recursive shuffle failures") { + val shuffleOneRdd = makeRdd(2, Nil) + val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) + val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne)) + val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) + val finalRdd = makeRdd(1, List(shuffleDepTwo)) + submit(finalRdd, Array(0)) + // have the first stage complete normally + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) + // have the second stage complete normally + complete(taskSets(1), Seq( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostC", 1)))) + // fail the third stage because hostA went down + complete(taskSets(2), Seq( + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null))) + // TODO assert this: + // blockManagerMaster.removeExecutor("exec-hostA") + // have DAGScheduler try again + scheduler.resubmitFailedStages() + complete(taskSets(3), Seq((Success, makeMapStatus("hostA", 2)))) + complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1)))) + complete(taskSets(5), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + } + + test("cached post-shuffle") { + val shuffleOneRdd = makeRdd(2, Nil) + val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) + val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne)) + val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) + val finalRdd = makeRdd(1, List(shuffleDepTwo)) + submit(finalRdd, Array(0)) + cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) + cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) + // complete stage 2 + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) + // complete stage 1 + complete(taskSets(1), Seq( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)))) + // pretend stage 0 failed because hostA went down + complete(taskSets(2), Seq( + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null))) + // TODO assert this: + // blockManagerMaster.removeExecutor("exec-hostA") + // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun. + scheduler.resubmitFailedStages() + assertLocations(taskSets(3), Seq(Seq("hostD"))) + // allow hostD to recover + complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1)))) + complete(taskSets(4), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + } + + /** + * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. + * Note that this checks only the host and not the executor ID. 
+ */ + private def assertLocations(taskSet: TaskSet, hosts: Seq[Seq[String]]) { + assert(hosts.size === taskSet.tasks.size) + for ((taskLocs, expectedLocs) <- taskSet.tasks.map(_.preferredLocations).zip(hosts)) { + assert(taskLocs.map(_.host) === expectedLocs) + } + } + + private def makeMapStatus(host: String, reduces: Int): MapStatus = + new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2)) + + private def makeBlockManagerId(host: String): BlockManagerId = + BlockManagerId("exec-" + host, host, 12345, 0) + +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala new file mode 100644 index 0000000000..f5b3e97222 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.util.Properties +import java.util.concurrent.LinkedBlockingQueue +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import scala.collection.mutable +import org.apache.spark._ +import org.apache.spark.SparkContext._ + + +class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers { + + test("inner method") { + sc = new SparkContext("local", "joblogger") + val joblogger = new JobLogger { + def createLogWriterTest(jobID: Int) = createLogWriter(jobID) + def closeLogWriterTest(jobID: Int) = closeLogWriter(jobID) + def getRddNameTest(rdd: RDD[_]) = getRddName(rdd) + def buildJobDepTest(jobID: Int, stage: Stage) = buildJobDep(jobID, stage) + } + type MyRDD = RDD[(Int, Int)] + def makeRdd( + numPartitions: Int, + dependencies: List[Dependency[_]] + ): MyRDD = { + val maxPartition = numPartitions - 1 + return new MyRDD(sc, dependencies) { + override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = + throw new RuntimeException("should not be reached") + override def getPartitions = (0 to maxPartition).map(i => new Partition { + override def index = i + }).toArray + } + } + val jobID = 5 + val parentRdd = makeRdd(4, Nil) + val shuffleDep = new ShuffleDependency(parentRdd, null) + val rootRdd = makeRdd(4, List(shuffleDep)) + val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID, None) + val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID, None) + + joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4, null)) + joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName) + parentRdd.setName("MyRDD") + joblogger.getRddNameTest(parentRdd) should be ("MyRDD") + joblogger.createLogWriterTest(jobID) + joblogger.getJobIDtoPrintWriter.size should be (1) + joblogger.buildJobDepTest(jobID, 
rootStage) + joblogger.getJobIDToStages.get(jobID).get.size should be (2) + joblogger.getStageIDToJobID.get(0) should be (Some(jobID)) + joblogger.getStageIDToJobID.get(1) should be (Some(jobID)) + joblogger.closeLogWriterTest(jobID) + joblogger.getStageIDToJobID.size should be (0) + joblogger.getJobIDToStages.size should be (0) + joblogger.getJobIDtoPrintWriter.size should be (0) + } + + test("inner variables") { + sc = new SparkContext("local[4]", "joblogger") + val joblogger = new JobLogger { + override protected def closeLogWriter(jobID: Int) = + getJobIDtoPrintWriter.get(jobID).foreach { fileWriter => + fileWriter.close() + } + } + sc.addSparkListener(joblogger) + val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } + rdd.reduceByKey(_+_).collect() + + joblogger.getLogDir should be ("/tmp/spark") + joblogger.getJobIDtoPrintWriter.size should be (1) + joblogger.getStageIDToJobID.size should be (2) + joblogger.getStageIDToJobID.get(0) should be (Some(0)) + joblogger.getStageIDToJobID.get(1) should be (Some(0)) + joblogger.getJobIDToStages.size should be (1) + } + + + test("interface functions") { + sc = new SparkContext("local[4]", "joblogger") + val joblogger = new JobLogger { + var onTaskEndCount = 0 + var onJobEndCount = 0 + var onJobStartCount = 0 + var onStageCompletedCount = 0 + var onStageSubmittedCount = 0 + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = onTaskEndCount += 1 + override def onJobEnd(jobEnd: SparkListenerJobEnd) = onJobEndCount += 1 + override def onJobStart(jobStart: SparkListenerJobStart) = onJobStartCount += 1 + override def onStageCompleted(stageCompleted: StageCompleted) = onStageCompletedCount += 1 + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = onStageSubmittedCount += 1 + } + sc.addSparkListener(joblogger) + val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } + rdd.reduceByKey(_+_).collect() + + joblogger.onJobStartCount should be (1) + joblogger.onJobEndCount should be (1) + joblogger.onTaskEndCount should be (8) + joblogger.onStageSubmittedCount should be (2) + joblogger.onStageCompletedCount should be (2) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala new file mode 100644 index 0000000000..aac7c207cb --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import org.scalatest.FunSuite +import org.apache.spark.{SparkContext, LocalSparkContext} +import scala.collection.mutable +import org.scalatest.matchers.ShouldMatchers +import org.apache.spark.SparkContext._ + +/** + * + */ + +class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers { + + test("local metrics") { + sc = new SparkContext("local[4]", "test") + val listener = new SaveStageInfo + sc.addSparkListener(listener) + sc.addSparkListener(new StatsReportListener) + //just to make sure some of the tasks take a noticeable amount of time + val w = {i:Int => + if (i == 0) + Thread.sleep(100) + i + } + + val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)} + d.count + listener.stageInfos.size should be (1) + + val d2 = d.map{i => w(i) -> i * 2}.setName("shuffle input 1") + + val d3 = d.map{i => w(i) -> (0 to (i % 5))}.setName("shuffle input 2") + + val d4 = d2.cogroup(d3, 64).map{case(k,(v1,v2)) => w(k) -> (v1.size, v2.size)} + d4.setName("A Cogroup") + + d4.collectAsMap + + listener.stageInfos.size should be (4) + listener.stageInfos.foreach {stageInfo => + //small test, so some tasks might take less than 1 millisecond, but average should be greater than 1 ms + checkNonZeroAvg(stageInfo.taskInfos.map{_._1.duration}, stageInfo + " duration") + checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorRunTime.toLong}, stageInfo + " executorRunTime") + checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong}, stageInfo + " executorDeserializeTime") + if (stageInfo.stage.rdd.name == d4.name) { + checkNonZeroAvg(stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime}, stageInfo + " fetchWaitTime") + } + + stageInfo.taskInfos.foreach{case (taskInfo, taskMetrics) => + taskMetrics.resultSize should be > (0l) + if (isStage(stageInfo, Set(d2.name, d3.name), Set(d4.name))) { + taskMetrics.shuffleWriteMetrics should be ('defined) + taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0l) + } + if (stageInfo.stage.rdd.name == d4.name) { + taskMetrics.shuffleReadMetrics should be ('defined) + val sm = taskMetrics.shuffleReadMetrics.get + sm.totalBlocksFetched should be > (0) + sm.localBlocksFetched should be > (0) + sm.remoteBlocksFetched should be (0) + sm.remoteBytesRead should be (0l) + sm.remoteFetchTime should be (0l) + } + } + } + } + + def checkNonZeroAvg(m: Traversable[Long], msg: String) { + assert(m.sum / m.size.toDouble > 0.0, msg) + } + + def isStage(stageInfo: StageInfo, rddNames: Set[String], excludedNames: Set[String]) = { + val names = Set(stageInfo.stage.rdd.name) ++ stageInfo.stage.rdd.dependencies.map{_.rdd.name} + !names.intersect(rddNames).isEmpty && names.intersect(excludedNames).isEmpty + } + + class SaveStageInfo extends SparkListener { + val stageInfos = mutable.Buffer[StageInfo]() + override def onStageCompleted(stage: StageCompleted) { + stageInfos += stage.stageInfo + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala new file mode 100644 index 0000000000..0347cc02d7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter +import org.apache.spark.TaskContext +import org.apache.spark.RDD +import org.apache.spark.SparkContext +import org.apache.spark.Partition +import org.apache.spark.LocalSparkContext + +class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { + + test("Calls executeOnCompleteCallbacks after failure") { + var completed = false + sc = new SparkContext("local", "test") + val rdd = new RDD[String](sc, List()) { + override def getPartitions = Array[Partition](StubPartition(0)) + override def compute(split: Partition, context: TaskContext) = { + context.addOnCompleteCallback(() => completed = true) + sys.error("failed") + } + } + val func = (c: TaskContext, i: Iterator[String]) => i.next + val task = new ResultTask[String, String](0, rdd, func, 0, Seq(), 0) + intercept[RuntimeException] { + task.run(0) + } + assert(completed === true) + } + + case class StubPartition(val index: Int) extends Partition +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala new file mode 100644 index 0000000000..92ad9f09b2 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.cluster + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter + +import org.apache.spark._ +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster._ +import scala.collection.mutable.ArrayBuffer + +import java.util.Properties + +class FakeTaskSetManager( + initPriority: Int, + initStageId: Int, + initNumTasks: Int, + clusterScheduler: ClusterScheduler, + taskSet: TaskSet) + extends ClusterTaskSetManager(clusterScheduler, taskSet) { + + parent = null + weight = 1 + minShare = 2 + runningTasks = 0 + priority = initPriority + stageId = initStageId + name = "TaskSet_"+stageId + override val numTasks = initNumTasks + tasksFinished = 0 + + override def increaseRunningTasks(taskNum: Int) { + runningTasks += taskNum + if (parent != null) { + parent.increaseRunningTasks(taskNum) + } + } + + override def decreaseRunningTasks(taskNum: Int) { + runningTasks -= taskNum + if (parent != null) { + parent.decreaseRunningTasks(taskNum) + } + } + + override def addSchedulable(schedulable: Schedulable) { + } + + override def removeSchedulable(schedulable: Schedulable) { + } + + override def getSchedulableByName(name: String): Schedulable = { + return null + } + + override def executorLost(executorId: String, host: String): Unit = { + } + + override def resourceOffer( + execId: String, + host: String, + availableCpus: Int, + maxLocality: TaskLocality.TaskLocality) + : Option[TaskDescription] = + { + if (tasksFinished + runningTasks < numTasks) { + increaseRunningTasks(1) + return Some(new TaskDescription(0, execId, "task 0:0", 0, null)) + } + return None + } + + override def checkSpeculatableTasks(): Boolean = { + return true + } + + def taskFinished() { + decreaseRunningTasks(1) + tasksFinished +=1 + if (tasksFinished == numTasks) { + parent.removeSchedulable(this) + } + } + + def abort() { + decreaseRunningTasks(runningTasks) + parent.removeSchedulable(this) + } +} + +class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging { + + def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): FakeTaskSetManager = { + new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet) + } + + def resourceOffer(rootPool: Pool): Int = { + val taskSetQueue = rootPool.getSortedTaskSetQueue() + /* Just for Test*/ + for (manager <- taskSetQueue) { + logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks)) + } + for (taskSet <- taskSetQueue) { + taskSet.resourceOffer("execId_1", "hostname_1", 1, TaskLocality.ANY) match { + case Some(task) => + return taskSet.stageId + case None => {} + } + } + -1 + } + + def checkTaskSetId(rootPool: Pool, expectedTaskSetId: Int) { + assert(resourceOffer(rootPool) === expectedTaskSetId) + } + + test("FIFO Scheduler Test") { + sc = new SparkContext("local", "ClusterSchedulerSuite") + val clusterScheduler = new ClusterScheduler(sc) + var tasks = ArrayBuffer[Task[_]]() + val task = new FakeTask(0) + tasks += task + val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + + val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0) + val schedulableBuilder = new FIFOSchedulableBuilder(rootPool) + schedulableBuilder.buildPools() + + val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, clusterScheduler, taskSet) + val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, clusterScheduler, taskSet) + val taskSetManager2 = 
createDummyTaskSetManager(0, 2, 2, clusterScheduler, taskSet) + schedulableBuilder.addTaskSetManager(taskSetManager0, null) + schedulableBuilder.addTaskSetManager(taskSetManager1, null) + schedulableBuilder.addTaskSetManager(taskSetManager2, null) + + checkTaskSetId(rootPool, 0) + resourceOffer(rootPool) + checkTaskSetId(rootPool, 1) + resourceOffer(rootPool) + taskSetManager1.abort() + checkTaskSetId(rootPool, 2) + } + + test("Fair Scheduler Test") { + sc = new SparkContext("local", "ClusterSchedulerSuite") + val clusterScheduler = new ClusterScheduler(sc) + var tasks = ArrayBuffer[Task[_]]() + val task = new FakeTask(0) + tasks += task + val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + + val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() + System.setProperty("spark.fairscheduler.allocation.file", xmlPath) + val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) + val schedulableBuilder = new FairSchedulableBuilder(rootPool) + schedulableBuilder.buildPools() + + assert(rootPool.getSchedulableByName("default") != null) + assert(rootPool.getSchedulableByName("1") != null) + assert(rootPool.getSchedulableByName("2") != null) + assert(rootPool.getSchedulableByName("3") != null) + assert(rootPool.getSchedulableByName("1").minShare === 2) + assert(rootPool.getSchedulableByName("1").weight === 1) + assert(rootPool.getSchedulableByName("2").minShare === 3) + assert(rootPool.getSchedulableByName("2").weight === 1) + assert(rootPool.getSchedulableByName("3").minShare === 2) + assert(rootPool.getSchedulableByName("3").weight === 1) + + val properties1 = new Properties() + properties1.setProperty("spark.scheduler.cluster.fair.pool","1") + val properties2 = new Properties() + properties2.setProperty("spark.scheduler.cluster.fair.pool","2") + + val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, clusterScheduler, taskSet) + val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, clusterScheduler, taskSet) + val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, clusterScheduler, taskSet) + schedulableBuilder.addTaskSetManager(taskSetManager10, properties1) + schedulableBuilder.addTaskSetManager(taskSetManager11, properties1) + schedulableBuilder.addTaskSetManager(taskSetManager12, properties1) + + val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, clusterScheduler, taskSet) + val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, clusterScheduler, taskSet) + schedulableBuilder.addTaskSetManager(taskSetManager23, properties2) + schedulableBuilder.addTaskSetManager(taskSetManager24, properties2) + + checkTaskSetId(rootPool, 0) + checkTaskSetId(rootPool, 3) + checkTaskSetId(rootPool, 3) + checkTaskSetId(rootPool, 1) + checkTaskSetId(rootPool, 4) + checkTaskSetId(rootPool, 2) + checkTaskSetId(rootPool, 2) + checkTaskSetId(rootPool, 4) + + taskSetManager12.taskFinished() + assert(rootPool.getSchedulableByName("1").runningTasks === 3) + taskSetManager24.abort() + assert(rootPool.getSchedulableByName("2").runningTasks === 2) + } + + test("Nested Pool Test") { + sc = new SparkContext("local", "ClusterSchedulerSuite") + val clusterScheduler = new ClusterScheduler(sc) + var tasks = ArrayBuffer[Task[_]]() + val task = new FakeTask(0) + tasks += task + val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + + val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) + val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1) + val pool1 = new Pool("1", SchedulingMode.FAIR, 4, 1) + rootPool.addSchedulable(pool0) + rootPool.addSchedulable(pool1) + + val pool00 = new 
Pool("00", SchedulingMode.FAIR, 2, 2) + val pool01 = new Pool("01", SchedulingMode.FAIR, 1, 1) + pool0.addSchedulable(pool00) + pool0.addSchedulable(pool01) + + val pool10 = new Pool("10", SchedulingMode.FAIR, 2, 2) + val pool11 = new Pool("11", SchedulingMode.FAIR, 2, 1) + pool1.addSchedulable(pool10) + pool1.addSchedulable(pool11) + + val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, clusterScheduler, taskSet) + val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, clusterScheduler, taskSet) + pool00.addSchedulable(taskSetManager000) + pool00.addSchedulable(taskSetManager001) + + val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, clusterScheduler, taskSet) + val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, clusterScheduler, taskSet) + pool01.addSchedulable(taskSetManager010) + pool01.addSchedulable(taskSetManager011) + + val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, clusterScheduler, taskSet) + val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, clusterScheduler, taskSet) + pool10.addSchedulable(taskSetManager100) + pool10.addSchedulable(taskSetManager101) + + val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, clusterScheduler, taskSet) + val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, clusterScheduler, taskSet) + pool11.addSchedulable(taskSetManager110) + pool11.addSchedulable(taskSetManager111) + + checkTaskSetId(rootPool, 0) + checkTaskSetId(rootPool, 4) + checkTaskSetId(rootPool, 6) + checkTaskSetId(rootPool, 2) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala new file mode 100644 index 0000000000..a4f63baf3d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable + +import org.scalatest.FunSuite + +import org.apache.spark._ +import org.apache.spark.scheduler._ +import org.apache.spark.executor.TaskMetrics +import java.nio.ByteBuffer +import org.apache.spark.util.FakeClock + +/** + * A mock ClusterScheduler implementation that just remembers information about tasks started and + * feedback received from the TaskSetManagers. Note that it's important to initialize this with + * a list of "live" executors and their hostnames for isExecutorAlive and hasExecutorsAliveOnHost + * to work, and these are required for locality in ClusterTaskSetManager. 
+ */ +class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */) + extends ClusterScheduler(sc) +{ + val startedTasks = new ArrayBuffer[Long] + val endedTasks = new mutable.HashMap[Long, TaskEndReason] + val finishedManagers = new ArrayBuffer[TaskSetManager] + + val executors = new mutable.HashMap[String, String] ++ liveExecutors + + listener = new TaskSchedulerListener { + def taskStarted(task: Task[_], taskInfo: TaskInfo) { + startedTasks += taskInfo.index + } + + def taskEnded( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: mutable.Map[Long, Any], + taskInfo: TaskInfo, + taskMetrics: TaskMetrics) + { + endedTasks(taskInfo.index) = reason + } + + def executorGained(execId: String, host: String) {} + + def executorLost(execId: String) {} + + def taskSetFailed(taskSet: TaskSet, reason: String) {} + } + + def removeExecutor(execId: String): Unit = executors -= execId + + override def taskSetFinished(manager: TaskSetManager): Unit = finishedManagers += manager + + override def isExecutorAlive(execId: String): Boolean = executors.contains(execId) + + override def hasExecutorsAliveOnHost(host: String): Boolean = executors.values.exists(_ == host) +} + +class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { + import TaskLocality.{ANY, PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL} + + val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong + + test("TaskSet with no preferences") { + sc = new SparkContext("local", "test") + val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) + val taskSet = createTaskSet(1) + val manager = new ClusterTaskSetManager(sched, taskSet) + + // Offer a host with no CPUs + assert(manager.resourceOffer("exec1", "host1", 0, ANY) === None) + + // Offer a host with process-local as the constraint; this should work because the TaskSet + // above won't have any locality preferences + val taskOption = manager.resourceOffer("exec1", "host1", 2, TaskLocality.PROCESS_LOCAL) + assert(taskOption.isDefined) + val task = taskOption.get + assert(task.executorId === "exec1") + assert(sched.startedTasks.contains(0)) + + // Re-offer the host -- now we should get no more tasks + assert(manager.resourceOffer("exec1", "host1", 2, PROCESS_LOCAL) === None) + + // Tell it the task has finished + manager.statusUpdate(0, TaskState.FINISHED, createTaskResult(0)) + assert(sched.endedTasks(0) === Success) + assert(sched.finishedManagers.contains(manager)) + } + + test("multiple offers with no preferences") { + sc = new SparkContext("local", "test") + val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) + val taskSet = createTaskSet(3) + val manager = new ClusterTaskSetManager(sched, taskSet) + + // First three offers should all find tasks + for (i <- 0 until 3) { + val taskOption = manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) + assert(taskOption.isDefined) + val task = taskOption.get + assert(task.executorId === "exec1") + } + assert(sched.startedTasks.toSet === Set(0, 1, 2)) + + // Re-offer the host -- now we should get no more tasks + assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None) + + // Finish the first two tasks + manager.statusUpdate(0, TaskState.FINISHED, createTaskResult(0)) + manager.statusUpdate(1, TaskState.FINISHED, createTaskResult(1)) + assert(sched.endedTasks(0) === Success) + assert(sched.endedTasks(1) === Success) + assert(!sched.finishedManagers.contains(manager)) + + // Finish the last task + 
manager.statusUpdate(2, TaskState.FINISHED, createTaskResult(2)) + assert(sched.endedTasks(2) === Success) + assert(sched.finishedManagers.contains(manager)) + } + + test("basic delay scheduling") { + sc = new SparkContext("local", "test") + val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = createTaskSet(4, + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host2", "exec2")), + Seq(TaskLocation("host1"), TaskLocation("host2", "exec2")), + Seq() // Last task has no locality prefs + ) + val clock = new FakeClock + val manager = new ClusterTaskSetManager(sched, taskSet, clock) + + // First offer host1, exec1: first task should be chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) + + // Offer host1, exec1 again: the last task, which has no prefs, should be chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 3) + + // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None) + + clock.advance(LOCALITY_WAIT) + + // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None) + + // Offer host1, exec1 again, at NODE_LOCAL level: we should choose task 2 + assert(manager.resourceOffer("exec1", "host1", 1, NODE_LOCAL).get.index == 2) + + // Offer host1, exec1 again, at NODE_LOCAL level: nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, NODE_LOCAL) === None) + + // Offer host1, exec1 again, at ANY level: nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) + + clock.advance(LOCALITY_WAIT) + + // Offer host1, exec1 again, at ANY level: task 1 should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1) + + // Offer host1, exec1 again, at ANY level: nothing should be chosen as we've launched all tasks + assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) + } + + test("delay scheduling with fallback") { + sc = new SparkContext("local", "test") + val sched = new FakeClusterScheduler(sc, + ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3")) + val taskSet = createTaskSet(5, + Seq(TaskLocation("host1")), + Seq(TaskLocation("host2")), + Seq(TaskLocation("host2")), + Seq(TaskLocation("host3")), + Seq(TaskLocation("host2")) + ) + val clock = new FakeClock + val manager = new ClusterTaskSetManager(sched, taskSet, clock) + + // First offer host1: first task should be chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) + + // Offer host1 again: nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) + + clock.advance(LOCALITY_WAIT) + + // Offer host1 again: second task (on host2) should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1) + + // Offer host1 again: third task (on host2) should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 2) + + // Offer host2: fifth task (also on host2) should get chosen + assert(manager.resourceOffer("exec2", "host2", 1, ANY).get.index === 4) + + // Now that we've launched a local task, we should no longer launch the task for host3 + assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None) + + clock.advance(LOCALITY_WAIT) + + // After another delay, we can go ahead and launch that task non-locally + 
assert(manager.resourceOffer("exec2", "host2", 1, ANY).get.index === 3) + } + + test("delay scheduling with failed hosts") { + sc = new SparkContext("local", "test") + val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = createTaskSet(3, + Seq(TaskLocation("host1")), + Seq(TaskLocation("host2")), + Seq(TaskLocation("host3")) + ) + val clock = new FakeClock + val manager = new ClusterTaskSetManager(sched, taskSet, clock) + + // First offer host1: first task should be chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) + + // Offer host1 again: third task should be chosen immediately because host3 is not up + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 2) + + // After this, nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) + + // Now mark host2 as dead + sched.removeExecutor("exec2") + manager.executorLost("exec2", "host2") + + // Task 1 should immediately be launched on host1 because its original host is gone + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1) + + // Now that all tasks have launched, nothing new should be launched anywhere else + assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) + assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None) + } + + /** + * Utility method to create a TaskSet, potentially setting a particular sequence of preferred + * locations for each task (given as varargs) if this sequence is not empty. + */ + def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { + if (prefLocs.size != 0 && prefLocs.size != numTasks) { + throw new IllegalArgumentException("Wrong number of task locations") + } + val tasks = Array.tabulate[Task[_]](numTasks) { i => + new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil) + } + new TaskSet(tasks, 0, 0, 0, null) + } + + def createTaskResult(id: Int): ByteBuffer = { + ByteBuffer.wrap(Utils.serialize(new TaskResult[Int](id, mutable.Map.empty, new TaskMetrics))) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala new file mode 100644 index 0000000000..2f12aaed18 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark.scheduler.{TaskLocation, Task}
+
+class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId) {
+  override def run(attemptId: Long): Int = 0
+
+  override def preferredLocations: Seq[TaskLocation] = prefLocs
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala
new file mode 100644
index 0000000000..111340a65c
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.local
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark._
+import org.apache.spark.scheduler._
+import org.apache.spark.scheduler.cluster._
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ConcurrentMap, HashMap}
+import java.util.concurrent.Semaphore
+import java.util.concurrent.CountDownLatch
+import java.util.Properties
+
+class Lock() {
+  var finished = false
+  def jobWait() = {
+    synchronized {
+      while (!finished) {
+        this.wait()
+      }
+    }
+  }
+
+  def jobFinished() = {
+    synchronized {
+      finished = true
+      this.notifyAll()
+    }
+  }
+}
+
+object TaskThreadInfo {
+  val threadToLock = HashMap[Int, Lock]()
+  val threadToRunning = HashMap[Int, Boolean]()
+  val threadToStarted = HashMap[Int, CountDownLatch]()
+}
+
+/*
+ * 1. Each thread contains one job.
+ * 2. Each job contains one stage.
+ * 3. Each stage contains only one task.
+ * 4. Launched tasks must be started in order (enforced via threadToStarted) so that each
+ *    one is guaranteed a CPU core; a launched task then blocks until the test manually
+ *    releases its "Lock", at which point the cluster gets a free core back.
+ * 5. Pending tasks must "sleep" long enough to be added to the taskSetManager queue, so
+ *    that they are scheduled later, once the cluster has free CPU cores.
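+ *
+ * A sketch of the handshake each launched task goes through (see createThread below for
+ * the real version; the names are the helpers defined above):
+ *   TaskThreadInfo.threadToStarted(i).countDown()  // task signals that it holds a core
+ *   TaskThreadInfo.threadToLock(i).jobWait()       // task parks until the test releases it
+ *   TaskThreadInfo.threadToLock(i).jobFinished()   // test frees the core for a pending task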
+ */
+class LocalSchedulerSuite extends FunSuite with LocalSparkContext {
+
+  def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) {
+
+    TaskThreadInfo.threadToRunning(threadIndex) = false
+    val nums = sc.parallelize(threadIndex to threadIndex, 1)
+    TaskThreadInfo.threadToLock(threadIndex) = new Lock()
+    TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1)
+    new Thread {
+      if (poolName != null) {
+        sc.setLocalProperty("spark.scheduler.cluster.fair.pool", poolName)
+      }
+      override def run() {
+        val ans = nums.map(number => {
+          TaskThreadInfo.threadToRunning(number) = true
+          TaskThreadInfo.threadToStarted(number).countDown()
+          TaskThreadInfo.threadToLock(number).jobWait()
+          TaskThreadInfo.threadToRunning(number) = false
+          number
+        }).collect()
+        assert(ans.toList === List(threadIndex))
+        sem.release()
+      }
+    }.start()
+  }
+
+  test("Local FIFO scheduler end-to-end test") {
+    System.setProperty("spark.cluster.schedulingmode", "FIFO")
+    sc = new SparkContext("local[4]", "test")
+    val sem = new Semaphore(0)
+
+    createThread(1, null, sc, sem)
+    TaskThreadInfo.threadToStarted(1).await()
+    createThread(2, null, sc, sem)
+    TaskThreadInfo.threadToStarted(2).await()
+    createThread(3, null, sc, sem)
+    TaskThreadInfo.threadToStarted(3).await()
+    createThread(4, null, sc, sem)
+    TaskThreadInfo.threadToStarted(4).await()
+    // Threads 5 and 6 (whose stages are pending) must satisfy two conditions:
+    // 1. The stages (taskSetManagers) of the jobs in threads 5 and 6 must be added to the
+    //    taskSetManager queue before TaskThreadInfo.threadToLock(1).jobFinished() executes.
+    // 2. The stage in thread 5 must be submitted, and thus prioritized, before the stage
+    //    in thread 6.
+    // Sleeping for 1s after starting each thread is a simple way to guarantee both.
+    // TODO: is there a better solution?
+    createThread(5, null, sc, sem)
+    Thread.sleep(1000)
+    createThread(6, null, sc, sem)
+    Thread.sleep(1000)
+
+    assert(TaskThreadInfo.threadToRunning(1) === true)
+    assert(TaskThreadInfo.threadToRunning(2) === true)
+    assert(TaskThreadInfo.threadToRunning(3) === true)
+    assert(TaskThreadInfo.threadToRunning(4) === true)
+    assert(TaskThreadInfo.threadToRunning(5) === false)
+    assert(TaskThreadInfo.threadToRunning(6) === false)
+
+    TaskThreadInfo.threadToLock(1).jobFinished()
+    TaskThreadInfo.threadToStarted(5).await()
+
+    assert(TaskThreadInfo.threadToRunning(1) === false)
+    assert(TaskThreadInfo.threadToRunning(2) === true)
+    assert(TaskThreadInfo.threadToRunning(3) === true)
+    assert(TaskThreadInfo.threadToRunning(4) === true)
+    assert(TaskThreadInfo.threadToRunning(5) === true)
+    assert(TaskThreadInfo.threadToRunning(6) === false)
+
+    TaskThreadInfo.threadToLock(3).jobFinished()
+    TaskThreadInfo.threadToStarted(6).await()
+
+    assert(TaskThreadInfo.threadToRunning(1) === false)
+    assert(TaskThreadInfo.threadToRunning(2) === true)
+    assert(TaskThreadInfo.threadToRunning(3) === false)
+    assert(TaskThreadInfo.threadToRunning(4) === true)
+    assert(TaskThreadInfo.threadToRunning(5) === true)
+    assert(TaskThreadInfo.threadToRunning(6) === true)
+
+    TaskThreadInfo.threadToLock(2).jobFinished()
+    TaskThreadInfo.threadToLock(4).jobFinished()
+    TaskThreadInfo.threadToLock(5).jobFinished()
+    TaskThreadInfo.threadToLock(6).jobFinished()
+    sem.acquire(6)
+  }
+
+  test("Local fair scheduler end-to-end test") {
+    sc = new SparkContext("local[8]", "LocalSchedulerSuite")
+    val sem = new Semaphore(0)
+    System.setProperty("spark.cluster.schedulingmode", "FAIR")
+    val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
+    System.setProperty("spark.fairscheduler.allocation.file", xmlPath)
+
+    createThread(10, "1", sc, sem)
+    TaskThreadInfo.threadToStarted(10).await()
+    createThread(20, "2", sc, sem)
+    TaskThreadInfo.threadToStarted(20).await()
+    createThread(30, "3", sc, sem)
+    TaskThreadInfo.threadToStarted(30).await()
+
+    assert(TaskThreadInfo.threadToRunning(10) === true)
+    assert(TaskThreadInfo.threadToRunning(20) === true)
+    assert(TaskThreadInfo.threadToRunning(30) === true)
+
+    createThread(11, "1", sc, sem)
+    TaskThreadInfo.threadToStarted(11).await()
+    createThread(21, "2", sc, sem)
+    TaskThreadInfo.threadToStarted(21).await()
+    createThread(31, "3", sc, sem)
+    TaskThreadInfo.threadToStarted(31).await()
+
+    assert(TaskThreadInfo.threadToRunning(11) === true)
+    assert(TaskThreadInfo.threadToRunning(21) === true)
+    assert(TaskThreadInfo.threadToRunning(31) === true)
+
+    createThread(12, "1", sc, sem)
+    TaskThreadInfo.threadToStarted(12).await()
+    createThread(22, "2", sc, sem)
+    TaskThreadInfo.threadToStarted(22).await()
+    createThread(32, "3", sc, sem)
+
+    assert(TaskThreadInfo.threadToRunning(12) === true)
+    assert(TaskThreadInfo.threadToRunning(22) === true)
+    assert(TaskThreadInfo.threadToRunning(32) === false)
+
+    TaskThreadInfo.threadToLock(10).jobFinished()
+    TaskThreadInfo.threadToStarted(32).await()
+
+    assert(TaskThreadInfo.threadToRunning(32) === true)
+
+    // 1. As in the scenario above, sleep 1s so that the stages of jobs 23 and 33 are added
+    //    to the taskSetManager queue; the cluster can then assign a free CPU core to
+    //    stage 23 once stage 11 has finished.
+    // 2. The priorities of 23 and 33 are irrelevant here, since the fair scheduler is in use.
+    createThread(23, "2", sc, sem)
+    createThread(33, "3", sc, sem)
+    Thread.sleep(1000)
+
+    TaskThreadInfo.threadToLock(11).jobFinished()
+    TaskThreadInfo.threadToStarted(23).await()
+
+    assert(TaskThreadInfo.threadToRunning(23) === true)
+    assert(TaskThreadInfo.threadToRunning(33) === false)
+
+    TaskThreadInfo.threadToLock(12).jobFinished()
+    TaskThreadInfo.threadToStarted(33).await()
+
+    assert(TaskThreadInfo.threadToRunning(33) === true)
+
+    TaskThreadInfo.threadToLock(20).jobFinished()
+    TaskThreadInfo.threadToLock(21).jobFinished()
+    TaskThreadInfo.threadToLock(22).jobFinished()
+    TaskThreadInfo.threadToLock(23).jobFinished()
+    TaskThreadInfo.threadToLock(30).jobFinished()
+    TaskThreadInfo.threadToLock(31).jobFinished()
+    TaskThreadInfo.threadToLock(32).jobFinished()
+    TaskThreadInfo.threadToLock(33).jobFinished()
+
+    sem.acquire(11)
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
new file mode 100644
index 0000000000..88ba10f2f2
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -0,0 +1,666 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.nio.ByteBuffer + +import akka.actor._ + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter +import org.scalatest.PrivateMethodTester +import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.matchers.ShouldMatchers._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.JavaSerializer +import org.apache.spark.KryoSerializer +import org.apache.spark.SizeEstimator +import org.apache.spark.Utils +import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.ByteBufferInputStream + + +class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester { + var store: BlockManager = null + var store2: BlockManager = null + var actorSystem: ActorSystem = null + var master: BlockManagerMaster = null + var oldArch: String = null + var oldOops: String = null + var oldHeartBeat: String = null + + // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test + System.setProperty("spark.kryoserializer.buffer.mb", "1") + val serializer = new KryoSerializer + + before { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0) + this.actorSystem = actorSystem + System.setProperty("spark.driver.port", boundPort.toString) + System.setProperty("spark.hostPort", "localhost:" + boundPort) + + master = new BlockManagerMaster( + actorSystem.actorOf(Props(new BlockManagerMasterActor(true)))) + + // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case + oldArch = System.setProperty("os.arch", "amd64") + oldOops = System.setProperty("spark.test.useCompressedOops", "true") + oldHeartBeat = System.setProperty("spark.storage.disableBlockManagerHeartBeat", "true") + val initialize = PrivateMethod[Unit]('initialize) + SizeEstimator invokePrivate initialize() + // Set some value ... 
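+    // (the concrete host:port appears to be arbitrary; what matters for these tests is
+    // that spark.hostPort is set before any BlockManager is constructed)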
+ System.setProperty("spark.hostPort", Utils.localHostName() + ":" + 1111) + } + + after { + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + + if (store != null) { + store.stop() + store = null + } + if (store2 != null) { + store2.stop() + store2 = null + } + actorSystem.shutdown() + actorSystem.awaitTermination() + actorSystem = null + master = null + + if (oldArch != null) { + System.setProperty("os.arch", oldArch) + } else { + System.clearProperty("os.arch") + } + + if (oldOops != null) { + System.setProperty("spark.test.useCompressedOops", oldOops) + } else { + System.clearProperty("spark.test.useCompressedOops") + } + } + + test("StorageLevel object caching") { + val level1 = StorageLevel(false, false, false, 3) + val level2 = StorageLevel(false, false, false, 3) // this should return the same object as level1 + val level3 = StorageLevel(false, false, false, 2) // this should return a different object + assert(level2 === level1, "level2 is not same as level1") + assert(level2.eq(level1), "level2 is not the same object as level1") + assert(level3 != level1, "level3 is same as level1") + val bytes1 = Utils.serialize(level1) + val level1_ = Utils.deserialize[StorageLevel](bytes1) + val bytes2 = Utils.serialize(level2) + val level2_ = Utils.deserialize[StorageLevel](bytes2) + assert(level1_ === level1, "Deserialized level1 not same as original level1") + assert(level1_.eq(level1), "Deserialized level1 not the same object as original level2") + assert(level2_ === level2, "Deserialized level2 not same as original level2") + assert(level2_.eq(level1), "Deserialized level2 not the same object as original level1") + } + + test("BlockManagerId object caching") { + val id1 = BlockManagerId("e1", "XXX", 1, 0) + val id2 = BlockManagerId("e1", "XXX", 1, 0) // this should return the same object as id1 + val id3 = BlockManagerId("e1", "XXX", 2, 0) // this should return a different object + assert(id2 === id1, "id2 is not same as id1") + assert(id2.eq(id1), "id2 is not the same object as id1") + assert(id3 != id1, "id3 is same as id1") + val bytes1 = Utils.serialize(id1) + val id1_ = Utils.deserialize[BlockManagerId](bytes1) + val bytes2 = Utils.serialize(id2) + val id2_ = Utils.deserialize[BlockManagerId](bytes2) + assert(id1_ === id1, "Deserialized id1 is not same as original id1") + assert(id1_.eq(id1), "Deserialized id1 is not the same object as original id1") + assert(id2_ === id2, "Deserialized id2 is not same as original id2") + assert(id2_.eq(id1), "Deserialized id2 is not the same object as original id1") + } + + test("master + 1 manager interaction") { + store = new BlockManager("", actorSystem, master, serializer, 2000) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + + // Putting a1, a2 and a3 in memory and telling master only about a1 and a2 + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) + store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY, tellMaster = false) + + // Checking whether blocks are in memory + assert(store.getSingle("a1") != None, "a1 was not in store") + assert(store.getSingle("a2") != None, "a2 was not in store") + assert(store.getSingle("a3") != None, "a3 was not in store") + + // Checking whether master knows about the blocks or not + assert(master.getLocations("a1").size > 0, "master was not told about a1") + assert(master.getLocations("a2").size > 0, "master was not told about a2") + assert(master.getLocations("a3").size === 
0, "master was told about a3") + + // Drop a1 and a2 from memory; this should be reported back to the master + store.dropFromMemory("a1", null) + store.dropFromMemory("a2", null) + assert(store.getSingle("a1") === None, "a1 not removed from store") + assert(store.getSingle("a2") === None, "a2 not removed from store") + assert(master.getLocations("a1").size === 0, "master did not remove a1") + assert(master.getLocations("a2").size === 0, "master did not remove a2") + } + + test("master + 2 managers interaction") { + store = new BlockManager("exec1", actorSystem, master, serializer, 2000) + store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer, 2000) + + val peers = master.getPeers(store.blockManagerId, 1) + assert(peers.size === 1, "master did not return the other manager as a peer") + assert(peers.head === store2.blockManagerId, "peer returned by master is not the other manager") + + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_2) + store2.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_2) + assert(master.getLocations("a1").size === 2, "master did not report 2 locations for a1") + assert(master.getLocations("a2").size === 2, "master did not report 2 locations for a2") + } + + test("removing block") { + store = new BlockManager("", actorSystem, master, serializer, 2000) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + + // Putting a1, a2 and a3 in memory and telling master only about a1 and a2 + store.putSingle("a1-to-remove", a1, StorageLevel.MEMORY_ONLY) + store.putSingle("a2-to-remove", a2, StorageLevel.MEMORY_ONLY) + store.putSingle("a3-to-remove", a3, StorageLevel.MEMORY_ONLY, tellMaster = false) + + // Checking whether blocks are in memory and memory size + val memStatus = master.getMemoryStatus.head._2 + assert(memStatus._1 == 2000L, "total memory " + memStatus._1 + " should equal 2000") + assert(memStatus._2 <= 1200L, "remaining memory " + memStatus._2 + " should <= 1200") + assert(store.getSingle("a1-to-remove") != None, "a1 was not in store") + assert(store.getSingle("a2-to-remove") != None, "a2 was not in store") + assert(store.getSingle("a3-to-remove") != None, "a3 was not in store") + + // Checking whether master knows about the blocks or not + assert(master.getLocations("a1-to-remove").size > 0, "master was not told about a1") + assert(master.getLocations("a2-to-remove").size > 0, "master was not told about a2") + assert(master.getLocations("a3-to-remove").size === 0, "master was told about a3") + + // Remove a1 and a2 and a3. Should be no-op for a3. 
+ master.removeBlock("a1-to-remove") + master.removeBlock("a2-to-remove") + master.removeBlock("a3-to-remove") + + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + store.getSingle("a1-to-remove") should be (None) + master.getLocations("a1-to-remove") should have size 0 + } + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + store.getSingle("a2-to-remove") should be (None) + master.getLocations("a2-to-remove") should have size 0 + } + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + store.getSingle("a3-to-remove") should not be (None) + master.getLocations("a3-to-remove") should have size 0 + } + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + val memStatus = master.getMemoryStatus.head._2 + memStatus._1 should equal (2000L) + memStatus._2 should equal (2000L) + } + } + + test("removing rdd") { + store = new BlockManager("", actorSystem, master, serializer, 2000) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + // Putting a1, a2 and a3 in memory. + store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY) + store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY) + store.putSingle("nonrddblock", a3, StorageLevel.MEMORY_ONLY) + master.removeRdd(0, blocking = false) + + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + store.getSingle("rdd_0_0") should be (None) + master.getLocations("rdd_0_0") should have size 0 + } + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + store.getSingle("rdd_0_1") should be (None) + master.getLocations("rdd_0_1") should have size 0 + } + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + store.getSingle("nonrddblock") should not be (None) + master.getLocations("nonrddblock") should have size (1) + } + + store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY) + store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY) + master.removeRdd(0, blocking = true) + store.getSingle("rdd_0_0") should be (None) + master.getLocations("rdd_0_0") should have size 0 + store.getSingle("rdd_0_1") should be (None) + master.getLocations("rdd_0_1") should have size 0 + } + + test("reregistration on heart beat") { + val heartBeat = PrivateMethod[Unit]('heartBeat) + store = new BlockManager("", actorSystem, master, serializer, 2000) + val a1 = new Array[Byte](400) + + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + + assert(store.getSingle("a1") != None, "a1 was not in store") + assert(master.getLocations("a1").size > 0, "master was not told about a1") + + master.removeExecutor(store.blockManagerId.executorId) + assert(master.getLocations("a1").size == 0, "a1 was not removed from master") + + store invokePrivate heartBeat() + assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master") + } + + test("reregistration on block update") { + store = new BlockManager("", actorSystem, master, serializer, 2000) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + assert(master.getLocations("a1").size > 0, "master was not told about a1") + + master.removeExecutor(store.blockManagerId.executorId) + assert(master.getLocations("a1").size == 0, "a1 was not removed from master") + + store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) + store.waitForAsyncReregister() + + assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master") + assert(master.getLocations("a2").size > 0, "master was not told 
about a2") + } + + test("reregistration doesn't dead lock") { + val heartBeat = PrivateMethod[Unit]('heartBeat) + store = new BlockManager("", actorSystem, master, serializer, 2000) + val a1 = new Array[Byte](400) + val a2 = List(new Array[Byte](400)) + + // try many times to trigger any deadlocks + for (i <- 1 to 100) { + master.removeExecutor(store.blockManagerId.executorId) + val t1 = new Thread { + override def run() { + store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + } + } + val t2 = new Thread { + override def run() { + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + } + } + val t3 = new Thread { + override def run() { + store invokePrivate heartBeat() + } + } + + t1.start() + t2.start() + t3.start() + t1.join() + t2.join() + t3.join() + + store.dropFromMemory("a1", null) + store.dropFromMemory("a2", null) + store.waitForAsyncReregister() + } + } + + test("in-memory LRU storage") { + store = new BlockManager("", actorSystem, master, serializer, 1200) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) + store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY) + assert(store.getSingle("a2") != None, "a2 was not in store") + assert(store.getSingle("a3") != None, "a3 was not in store") + assert(store.getSingle("a1") === None, "a1 was in store") + assert(store.getSingle("a2") != None, "a2 was not in store") + // At this point a2 was gotten last, so LRU will getSingle rid of a3 + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + assert(store.getSingle("a1") != None, "a1 was not in store") + assert(store.getSingle("a2") != None, "a2 was not in store") + assert(store.getSingle("a3") === None, "a3 was in store") + } + + test("in-memory LRU storage with serialization") { + store = new BlockManager("", actorSystem, master, serializer, 1200) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER) + store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_SER) + store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY_SER) + assert(store.getSingle("a2") != None, "a2 was not in store") + assert(store.getSingle("a3") != None, "a3 was not in store") + assert(store.getSingle("a1") === None, "a1 was in store") + assert(store.getSingle("a2") != None, "a2 was not in store") + // At this point a2 was gotten last, so LRU will getSingle rid of a3 + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER) + assert(store.getSingle("a1") != None, "a1 was not in store") + assert(store.getSingle("a2") != None, "a2 was not in store") + assert(store.getSingle("a3") === None, "a3 was in store") + } + + test("in-memory LRU for partitions of same RDD") { + store = new BlockManager("", actorSystem, master, serializer, 1200) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + store.putSingle("rdd_0_1", a1, StorageLevel.MEMORY_ONLY) + store.putSingle("rdd_0_2", a2, StorageLevel.MEMORY_ONLY) + store.putSingle("rdd_0_3", a3, StorageLevel.MEMORY_ONLY) + // Even though we accessed rdd_0_3 last, it should not have replaced partitions 1 and 2 + // from the same RDD + assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") + assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store") + assert(store.getSingle("rdd_0_1") != None, "rdd_0_1 was not in store") + // Check that rdd_0_3 doesn't 
replace them even after further accesses + assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") + assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") + assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") + } + + test("in-memory LRU for partitions of multiple RDDs") { + store = new BlockManager("", actorSystem, master, serializer, 1200) + store.putSingle("rdd_0_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY) + store.putSingle("rdd_0_2", new Array[Byte](400), StorageLevel.MEMORY_ONLY) + store.putSingle("rdd_1_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY) + // At this point rdd_1_1 should've replaced rdd_0_1 + assert(store.memoryStore.contains("rdd_1_1"), "rdd_1_1 was not in store") + assert(!store.memoryStore.contains("rdd_0_1"), "rdd_0_1 was in store") + assert(store.memoryStore.contains("rdd_0_2"), "rdd_0_2 was not in store") + // Do a get() on rdd_0_2 so that it is the most recently used item + assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store") + // Put in more partitions from RDD 0; they should replace rdd_1_1 + store.putSingle("rdd_0_3", new Array[Byte](400), StorageLevel.MEMORY_ONLY) + store.putSingle("rdd_0_4", new Array[Byte](400), StorageLevel.MEMORY_ONLY) + // Now rdd_1_1 should be dropped to add rdd_0_3, but then rdd_0_2 should *not* be dropped + // when we try to add rdd_0_4. + assert(!store.memoryStore.contains("rdd_1_1"), "rdd_1_1 was in store") + assert(!store.memoryStore.contains("rdd_0_1"), "rdd_0_1 was in store") + assert(!store.memoryStore.contains("rdd_0_4"), "rdd_0_4 was in store") + assert(store.memoryStore.contains("rdd_0_2"), "rdd_0_2 was not in store") + assert(store.memoryStore.contains("rdd_0_3"), "rdd_0_3 was not in store") + } + + test("on-disk storage") { + store = new BlockManager("", actorSystem, master, serializer, 1200) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + store.putSingle("a1", a1, StorageLevel.DISK_ONLY) + store.putSingle("a2", a2, StorageLevel.DISK_ONLY) + store.putSingle("a3", a3, StorageLevel.DISK_ONLY) + assert(store.getSingle("a2") != None, "a2 was in store") + assert(store.getSingle("a3") != None, "a3 was in store") + assert(store.getSingle("a1") != None, "a1 was in store") + } + + test("disk and memory storage") { + store = new BlockManager("", actorSystem, master, serializer, 1200) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK) + store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK) + store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK) + assert(store.getSingle("a2") != None, "a2 was not in store") + assert(store.getSingle("a3") != None, "a3 was not in store") + assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") + assert(store.getSingle("a1") != None, "a1 was not in store") + assert(store.memoryStore.getValues("a1") != None, "a1 was not in memory store") + } + + test("disk and memory storage with getLocalBytes") { + store = new BlockManager("", actorSystem, master, serializer, 1200) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK) + store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK) + store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK) + assert(store.getLocalBytes("a2") != None, "a2 was not in store") + assert(store.getLocalBytes("a3") != None, "a3 was not in store") 
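+    // a1 should have been dropped from memory to make room (it still exists on disk);
+    // fetching its bytes below is expected to pull it back into the memory store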
+    assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store")
+    assert(store.getLocalBytes("a1") != None, "a1 was not in store")
+    assert(store.memoryStore.getValues("a1") != None, "a1 was not in memory store")
+  }
+
+  test("disk and memory storage with serialization") {
+    store = new BlockManager("", actorSystem, master, serializer, 1200)
+    val a1 = new Array[Byte](400)
+    val a2 = new Array[Byte](400)
+    val a3 = new Array[Byte](400)
+    store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK_SER)
+    store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK_SER)
+    store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK_SER)
+    assert(store.getSingle("a2") != None, "a2 was not in store")
+    assert(store.getSingle("a3") != None, "a3 was not in store")
+    assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store")
+    assert(store.getSingle("a1") != None, "a1 was not in store")
+    assert(store.memoryStore.getValues("a1") != None, "a1 was not in memory store")
+  }
+
+  test("disk and memory storage with serialization and getLocalBytes") {
+    store = new BlockManager("", actorSystem, master, serializer, 1200)
+    val a1 = new Array[Byte](400)
+    val a2 = new Array[Byte](400)
+    val a3 = new Array[Byte](400)
+    store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK_SER)
+    store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK_SER)
+    store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK_SER)
+    assert(store.getLocalBytes("a2") != None, "a2 was not in store")
+    assert(store.getLocalBytes("a3") != None, "a3 was not in store")
+    assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store")
+    assert(store.getLocalBytes("a1") != None, "a1 was not in store")
+    assert(store.memoryStore.getValues("a1") != None, "a1 was not in memory store")
+  }
+
+  test("LRU with mixed storage levels") {
+    store = new BlockManager("", actorSystem, master, serializer, 1200)
+    val a1 = new Array[Byte](400)
+    val a2 = new Array[Byte](400)
+    val a3 = new Array[Byte](400)
+    val a4 = new Array[Byte](400)
+    // First store a1 and a2, both in memory, and a3, on disk only
+    store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER)
+    store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_SER)
+    store.putSingle("a3", a3, StorageLevel.DISK_ONLY)
+    // At this point LRU should not kick in because a3 is only on disk
+    assert(store.getSingle("a1") != None, "a1 was not in store")
+    assert(store.getSingle("a2") != None, "a2 was not in store")
+    assert(store.getSingle("a3") != None, "a3 was not in store")
+    assert(store.getSingle("a1") != None, "a1 was not in store")
+    assert(store.getSingle("a2") != None, "a2 was not in store")
+    assert(store.getSingle("a3") != None, "a3 was not in store")
+    // Now let's add in a4, which uses both disk and memory; a1 should drop out
+    store.putSingle("a4", a4, StorageLevel.MEMORY_AND_DISK_SER)
+    assert(store.getSingle("a1") == None, "a1 was in store")
+    assert(store.getSingle("a2") != None, "a2 was not in store")
+    assert(store.getSingle("a3") != None, "a3 was not in store")
+    assert(store.getSingle("a4") != None, "a4 was not in store")
+  }
+
+  test("in-memory LRU with streams") {
+    store = new BlockManager("", actorSystem, master, serializer, 1200)
+    val list1 = List(new Array[Byte](200), new Array[Byte](200))
+    val list2 = List(new Array[Byte](200), new Array[Byte](200))
+    val list3 = List(new Array[Byte](200), new Array[Byte](200))
+    store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
+    store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
+    store.put("list3", list3.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
+    assert(store.get("list2") != None, "list2 was not in store")
+    assert(store.get("list2").get.size == 2)
+    assert(store.get("list3") != None, "list3 was not in store")
+    assert(store.get("list3").get.size == 2)
+    assert(store.get("list1") === None, "list1 was in store")
+    assert(store.get("list2") != None, "list2 was not in store")
+    assert(store.get("list2").get.size == 2)
+    // At this point list2 was accessed last, so putting list1 back in should make LRU evict list3
+    store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
+    assert(store.get("list1") != None, "list1 was not in store")
+    assert(store.get("list1").get.size == 2)
+    assert(store.get("list2") != None, "list2 was not in store")
+    assert(store.get("list2").get.size == 2)
+    assert(store.get("list3") === None, "list3 was in store")
+  }
+
+  test("LRU with mixed storage levels and streams") {
+    store = new BlockManager("", actorSystem, master, serializer, 1200)
+    val list1 = List(new Array[Byte](200), new Array[Byte](200))
+    val list2 = List(new Array[Byte](200), new Array[Byte](200))
+    val list3 = List(new Array[Byte](200), new Array[Byte](200))
+    val list4 = List(new Array[Byte](200), new Array[Byte](200))
+    // First store list1 and list2, both in memory, and list3, on disk only
+    store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true)
+    store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true)
+    store.put("list3", list3.iterator, StorageLevel.DISK_ONLY, tellMaster = true)
+    // At this point LRU should not kick in because list3 is only on disk
+    assert(store.get("list1") != None, "list1 was not in store")
+    assert(store.get("list1").get.size === 2)
+    assert(store.get("list2") != None, "list2 was not in store")
+    assert(store.get("list2").get.size === 2)
+    assert(store.get("list3") != None, "list3 was not in store")
+    assert(store.get("list3").get.size === 2)
+    assert(store.get("list1") != None, "list1 was not in store")
+    assert(store.get("list1").get.size === 2)
+    assert(store.get("list2") != None, "list2 was not in store")
+    assert(store.get("list2").get.size === 2)
+    assert(store.get("list3") != None, "list3 was not in store")
+    assert(store.get("list3").get.size === 2)
+    // Now let's add in list4, which uses both disk and memory; list1 should drop out
+    store.put("list4", list4.iterator, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)
+    assert(store.get("list1") === None, "list1 was in store")
+    assert(store.get("list2") != None, "list2 was not in store")
+    assert(store.get("list2").get.size === 2)
+    assert(store.get("list3") != None, "list3 was not in store")
+    assert(store.get("list3").get.size === 2)
+    assert(store.get("list4") != None, "list4 was not in store")
+    assert(store.get("list4").get.size === 2)
+  }
+
+  test("negative byte values in ByteBufferInputStream") {
+    val buffer = ByteBuffer.wrap(Array[Int](254, 255, 0, 1, 2).map(_.toByte).toArray)
+    val stream = new ByteBufferInputStream(buffer)
+    val temp = new Array[Byte](10)
+    assert(stream.read() === 254, "unexpected byte read")
+    assert(stream.read() === 255, "unexpected byte read")
+    assert(stream.read() === 0, "unexpected byte read")
+    assert(stream.read(temp, 0, temp.length) === 2, "unexpected number of bytes read")
+    assert(stream.read() === -1, "end of stream not signalled")
+    assert(stream.read(temp, 0, temp.length) === -1, "end of stream not signalled")
+  }
+
+  test("overly large block") {
+    store = new
BlockManager("", actorSystem, master, serializer, 500) + store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) + assert(store.getSingle("a1") === None, "a1 was in store") + store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK) + assert(store.memoryStore.getValues("a2") === None, "a2 was in memory store") + assert(store.getSingle("a2") != None, "a2 was not in store") + } + + test("block compression") { + try { + System.setProperty("spark.shuffle.compress", "true") + store = new BlockManager("exec1", actorSystem, master, serializer, 2000) + store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize("shuffle_0_0_0") <= 100, "shuffle_0_0_0 was not compressed") + store.stop() + store = null + + System.setProperty("spark.shuffle.compress", "false") + store = new BlockManager("exec2", actorSystem, master, serializer, 2000) + store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize("shuffle_0_0_0") >= 1000, "shuffle_0_0_0 was compressed") + store.stop() + store = null + + System.setProperty("spark.broadcast.compress", "true") + store = new BlockManager("exec3", actorSystem, master, serializer, 2000) + store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize("broadcast_0") <= 100, "broadcast_0 was not compressed") + store.stop() + store = null + + System.setProperty("spark.broadcast.compress", "false") + store = new BlockManager("exec4", actorSystem, master, serializer, 2000) + store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize("broadcast_0") >= 1000, "broadcast_0 was compressed") + store.stop() + store = null + + System.setProperty("spark.rdd.compress", "true") + store = new BlockManager("exec5", actorSystem, master, serializer, 2000) + store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize("rdd_0_0") <= 100, "rdd_0_0 was not compressed") + store.stop() + store = null + + System.setProperty("spark.rdd.compress", "false") + store = new BlockManager("exec6", actorSystem, master, serializer, 2000) + store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize("rdd_0_0") >= 1000, "rdd_0_0 was compressed") + store.stop() + store = null + + // Check that any other block types are also kept uncompressed + store = new BlockManager("exec7", actorSystem, master, serializer, 2000) + store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) + assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed") + store.stop() + store = null + } finally { + System.clearProperty("spark.shuffle.compress") + System.clearProperty("spark.broadcast.compress") + System.clearProperty("spark.rdd.compress") + } + } + + test("block store put failure") { + // Use Java serializer so we can create an unserializable error. + store = new BlockManager("", actorSystem, master, new JavaSerializer, 1200) + + // The put should fail since a1 is not serializable. + class UnserializableClass + val a1 = new UnserializableClass + intercept[java.io.NotSerializableException] { + store.putSingle("a1", a1, StorageLevel.DISK_ONLY) + } + + // Make sure get a1 doesn't hang and returns None. 
+    failAfter(1 second) {
+      assert(store.getSingle("a1") == None, "a1 should not be in store")
+    }
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
new file mode 100644
index 0000000000..3321fb5eb7
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import scala.util.{Failure, Success, Try}
+import java.net.ServerSocket
+import org.scalatest.FunSuite
+import org.eclipse.jetty.server.Server
+
+class UISuite extends FunSuite {
+  test("jetty port increases under contention") {
+    val startPort = 3030
+    val server = new Server(startPort)
+    server.start()
+    val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("localhost", startPort, Seq())
+    val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("localhost", startPort, Seq())
+
+    // Allow some wiggle room in case ports on the machine are under contention
+    assert(boundPort1 > startPort && boundPort1 < startPort + 10)
+    assert(boundPort2 > boundPort1 && boundPort2 < boundPort1 + 10)
+  }
+
+  test("jetty binds to port 0 correctly") {
+    val (jettyServer, boundPort) = JettyUtils.startJettyServer("localhost", 0, Seq())
+    assert(jettyServer.getState === "STARTED")
+    assert(boundPort != 0)
+    Try { new ServerSocket(boundPort) } match {
+      case Success(s) =>
+        s.close()
+        fail("Port %s does not seem to be in use by the Jetty server".format(boundPort))
+      case Failure(e) => // expected: the port is already bound by the Jetty server
+    }
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala b/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala
new file mode 100644
index 0000000000..63642461e4
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.util + +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers + +/** + * + */ + +class DistributionSuite extends FunSuite with ShouldMatchers { + test("summary") { + val d = new Distribution((1 to 100).toArray.map{_.toDouble}) + val stats = d.statCounter + stats.count should be (100) + stats.mean should be (50.5) + stats.sum should be (50 * 101) + + val quantiles = d.getQuantiles() + quantiles(0) should be (1) + quantiles(1) should be (26) + quantiles(2) should be (51) + quantiles(3) should be (76) + quantiles(4) should be (100) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/FakeClock.scala b/core/src/test/scala/org/apache/spark/util/FakeClock.scala new file mode 100644 index 0000000000..0a45917b08 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/FakeClock.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +class FakeClock extends Clock { + private var time = 0L + + def advance(millis: Long): Unit = time += millis + + def getTime(): Long = time +} diff --git a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala new file mode 100644 index 0000000000..45867463a5 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import scala.collection.mutable.Buffer +import java.util.NoSuchElementException + +class NextIteratorSuite extends FunSuite with ShouldMatchers { + test("one iteration") { + val i = new StubIterator(Buffer(1)) + i.hasNext should be === true + i.next should be === 1 + i.hasNext should be === false + intercept[NoSuchElementException] { i.next() } + } + + test("two iterations") { + val i = new StubIterator(Buffer(1, 2)) + i.hasNext should be === true + i.next should be === 1 + i.hasNext should be === true + i.next should be === 2 + i.hasNext should be === false + intercept[NoSuchElementException] { i.next() } + } + + test("empty iteration") { + val i = new StubIterator(Buffer()) + i.hasNext should be === false + intercept[NoSuchElementException] { i.next() } + } + + test("close is called once for empty iterations") { + val i = new StubIterator(Buffer()) + i.hasNext should be === false + i.hasNext should be === false + i.closeCalled should be === 1 + } + + test("close is called once for non-empty iterations") { + val i = new StubIterator(Buffer(1, 2)) + i.next should be === 1 + i.next should be === 2 + // close isn't called until we check for the next element + i.closeCalled should be === 0 + i.hasNext should be === false + i.closeCalled should be === 1 + i.hasNext should be === false + i.closeCalled should be === 1 + } + + class StubIterator(ints: Buffer[Int]) extends NextIterator[Int] { + var closeCalled = 0 + + override def getNext() = { + if (ints.size == 0) { + finished = true + 0 + } else { + ints.remove(0) + } + } + + override def close() { + closeCalled += 1 + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/RateLimitedOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/RateLimitedOutputStreamSuite.scala new file mode 100644 index 0000000000..a9dd0b1a5b --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/RateLimitedOutputStreamSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +import org.scalatest.FunSuite +import java.io.ByteArrayOutputStream +import java.util.concurrent.TimeUnit._ + +class RateLimitedOutputStreamSuite extends FunSuite { + + private def benchmark[U](f: => U): Long = { + val start = System.nanoTime + f + System.nanoTime - start + } + + test("write") { + val underlying = new ByteArrayOutputStream + val data = "X" * 41000 + val stream = new RateLimitedOutputStream(underlying, 10000) + val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) } + assert(SECONDS.convert(elapsedNs, NANOSECONDS) == 4) + assert(underlying.toString("UTF-8") == data) + } +} diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala deleted file mode 100644 index 0af175f316..0000000000 --- a/core/src/test/scala/spark/AccumulatorSuite.scala +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers -import collection.mutable -import java.util.Random -import scala.math.exp -import scala.math.signum -import spark.SparkContext._ - -class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkContext { - - test ("basic accumulation"){ - sc = new SparkContext("local", "test") - val acc : Accumulator[Int] = sc.accumulator(0) - - val d = sc.parallelize(1 to 20) - d.foreach{x => acc += x} - acc.value should be (210) - - - val longAcc = sc.accumulator(0l) - val maxInt = Integer.MAX_VALUE.toLong - d.foreach{x => longAcc += maxInt + x} - longAcc.value should be (210l + maxInt * 20) - } - - test ("value not assignable from tasks") { - sc = new SparkContext("local", "test") - val acc : Accumulator[Int] = sc.accumulator(0) - - val d = sc.parallelize(1 to 20) - evaluating {d.foreach{x => acc.value = x}} should produce [Exception] - } - - test ("add value to collection accumulators") { - import SetAccum._ - val maxI = 1000 - for (nThreads <- List(1, 10)) { //test single & multi-threaded - sc = new SparkContext("local[" + nThreads + "]", "test") - val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) - val d = sc.parallelize(1 to maxI) - d.foreach { - x => acc += x - } - val v = acc.value.asInstanceOf[mutable.Set[Int]] - for (i <- 1 to maxI) { - v should contain(i) - } - resetSparkContext() - } - } - - implicit object SetAccum extends AccumulableParam[mutable.Set[Any], Any] { - def addInPlace(t1: mutable.Set[Any], t2: mutable.Set[Any]) : mutable.Set[Any] = { - t1 ++= t2 - t1 - } - def addAccumulator(t1: mutable.Set[Any], t2: Any) : mutable.Set[Any] = { - t1 += t2 - t1 - } - def zero(t: mutable.Set[Any]) : mutable.Set[Any] = { - new mutable.HashSet[Any]() - } - } - - test 
("value not readable in tasks") { - import SetAccum._ - val maxI = 1000 - for (nThreads <- List(1, 10)) { //test single & multi-threaded - sc = new SparkContext("local[" + nThreads + "]", "test") - val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) - val d = sc.parallelize(1 to maxI) - evaluating { - d.foreach { - x => acc.value += x - } - } should produce [SparkException] - resetSparkContext() - } - } - - test ("collection accumulators") { - val maxI = 1000 - for (nThreads <- List(1, 10)) { - // test single & multi-threaded - sc = new SparkContext("local[" + nThreads + "]", "test") - val setAcc = sc.accumulableCollection(mutable.HashSet[Int]()) - val bufferAcc = sc.accumulableCollection(mutable.ArrayBuffer[Int]()) - val mapAcc = sc.accumulableCollection(mutable.HashMap[Int,String]()) - val d = sc.parallelize((1 to maxI) ++ (1 to maxI)) - d.foreach { - x => {setAcc += x; bufferAcc += x; mapAcc += (x -> x.toString)} - } - - // Note that this is typed correctly -- no casts necessary - setAcc.value.size should be (maxI) - bufferAcc.value.size should be (2 * maxI) - mapAcc.value.size should be (maxI) - for (i <- 1 to maxI) { - setAcc.value should contain(i) - bufferAcc.value should contain(i) - mapAcc.value should contain (i -> i.toString) - } - resetSparkContext() - } - } - - test ("localValue readable in tasks") { - import SetAccum._ - val maxI = 1000 - for (nThreads <- List(1, 10)) { //test single & multi-threaded - sc = new SparkContext("local[" + nThreads + "]", "test") - val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) - val groupedInts = (1 to (maxI/20)).map {x => (20 * (x - 1) to 20 * x).toSet} - val d = sc.parallelize(groupedInts) - d.foreach { - x => acc.localValue ++= x - } - acc.value should be ( (0 to maxI).toSet) - resetSparkContext() - } - } - -} diff --git a/core/src/test/scala/spark/BroadcastSuite.scala b/core/src/test/scala/spark/BroadcastSuite.scala deleted file mode 100644 index 785721ece8..0000000000 --- a/core/src/test/scala/spark/BroadcastSuite.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark - -import org.scalatest.FunSuite - -class BroadcastSuite extends FunSuite with LocalSparkContext { - - test("basic broadcast") { - sc = new SparkContext("local", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === Set((1, 10), (2, 10))) - } - - test("broadcast variables accessed in multiple threads") { - sc = new SparkContext("local[10]", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) - } -} diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala deleted file mode 100644 index 966dede2be..0000000000 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ /dev/null @@ -1,392 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark - -import org.scalatest.FunSuite -import java.io.File -import spark.rdd._ -import spark.SparkContext._ -import storage.StorageLevel - -class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { - initLogging() - - var checkpointDir: File = _ - val partitioner = new HashPartitioner(2) - - override def beforeEach() { - super.beforeEach() - checkpointDir = File.createTempFile("temp", "") - checkpointDir.delete() - sc = new SparkContext("local", "test") - sc.setCheckpointDir(checkpointDir.toString) - } - - override def afterEach() { - super.afterEach() - if (checkpointDir != null) { - checkpointDir.delete() - } - } - - test("basic checkpointing") { - val parCollection = sc.makeRDD(1 to 4) - val flatMappedRDD = parCollection.flatMap(x => 1 to x) - flatMappedRDD.checkpoint() - assert(flatMappedRDD.dependencies.head.rdd == parCollection) - val result = flatMappedRDD.collect() - assert(flatMappedRDD.dependencies.head.rdd != parCollection) - assert(flatMappedRDD.collect() === result) - } - - test("RDDs with one-to-one dependencies") { - testCheckpointing(_.map(x => x.toString)) - testCheckpointing(_.flatMap(x => 1 to x)) - testCheckpointing(_.filter(_ % 2 == 0)) - testCheckpointing(_.sample(false, 0.5, 0)) - testCheckpointing(_.glom()) - testCheckpointing(_.mapPartitions(_.map(_.toString))) - testCheckpointing(r => new MapPartitionsWithIndexRDD(r, - (i: Int, iter: Iterator[Int]) => iter.map(_.toString), false )) - testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) - testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) - testCheckpointing(_.pipe(Seq("cat"))) - } - - test("ParallelCollection") { - val parCollection = sc.makeRDD(1 to 4, 2) - val numPartitions = parCollection.partitions.size - parCollection.checkpoint() - assert(parCollection.dependencies === Nil) - val result = parCollection.collect() - assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result) - assert(parCollection.dependencies != Nil) - assert(parCollection.partitions.length === numPartitions) - assert(parCollection.partitions.toList === parCollection.checkpointData.get.getPartitions.toList) - assert(parCollection.collect() === result) - } - - test("BlockRDD") { - val blockId = "id" - val blockManager = SparkEnv.get.blockManager - blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY) - val blockRDD = new BlockRDD[String](sc, Array(blockId)) - val numPartitions = blockRDD.partitions.size - blockRDD.checkpoint() - val result = blockRDD.collect() - assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result) - assert(blockRDD.dependencies != Nil) - assert(blockRDD.partitions.length === numPartitions) - assert(blockRDD.partitions.toList === blockRDD.checkpointData.get.getPartitions.toList) - assert(blockRDD.collect() === result) - } - - test("ShuffledRDD") { - testCheckpointing(rdd => { - // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD - new ShuffledRDD[Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner) - }) - } - - test("UnionRDD") { - def otherRDD = sc.makeRDD(1 to 10, 1) - - // Test whether the size of UnionRDDPartitions reduce in size after parent RDD is checkpointed. - // Current implementation of UnionRDD has transient reference to parent RDDs, - // so only the partitions will reduce in serialized size, not the RDD. 
- testCheckpointing(_.union(otherRDD), false, true) - testParentCheckpointing(_.union(otherRDD), false, true) - } - - test("CartesianRDD") { - def otherRDD = sc.makeRDD(1 to 10, 1) - testCheckpointing(new CartesianRDD(sc, _, otherRDD)) - - // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed - // Current implementation of CoalescedRDDPartition has transient reference to parent RDD, - // so only the RDD will reduce in serialized size, not the partitions. - testParentCheckpointing(new CartesianRDD(sc, _, otherRDD), true, false) - - // Test that the CartesianRDD updates parent partitions (CartesianRDD.s1/s2) after - // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions. - // Note that this test is very specific to the current implementation of CartesianRDD. - val ones = sc.makeRDD(1 to 100, 10).map(x => x) - ones.checkpoint() // checkpoint that MappedRDD - val cartesian = new CartesianRDD(sc, ones, ones) - val splitBeforeCheckpoint = - serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition]) - cartesian.count() // do the checkpointing - val splitAfterCheckpoint = - serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition]) - assert( - (splitAfterCheckpoint.s1 != splitBeforeCheckpoint.s1) && - (splitAfterCheckpoint.s2 != splitBeforeCheckpoint.s2), - "CartesianRDD.parents not updated after parent RDD checkpointed" - ) - } - - test("CoalescedRDD") { - testCheckpointing(_.coalesce(2)) - - // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed - // Current implementation of CoalescedRDDPartition has transient reference to parent RDD, - // so only the RDD will reduce in serialized size, not the partitions. - testParentCheckpointing(_.coalesce(2), true, false) - - // Test that the CoalescedRDDPartition updates parent partitions (CoalescedRDDPartition.parents) after - // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions. - // Note that this test is very specific to the current implementation of CoalescedRDDPartitions - val ones = sc.makeRDD(1 to 100, 10).map(x => x) - ones.checkpoint() // checkpoint that MappedRDD - val coalesced = new CoalescedRDD(ones, 2) - val splitBeforeCheckpoint = - serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition]) - coalesced.count() // do the checkpointing - val splitAfterCheckpoint = - serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition]) - assert( - splitAfterCheckpoint.parents.head != splitBeforeCheckpoint.parents.head, - "CoalescedRDDPartition.parents not updated after parent RDD checkpointed" - ) - } - - test("CoGroupedRDD") { - val longLineageRDD1 = generateLongLineageRDDForCoGroupedRDD() - testCheckpointing(rdd => { - CheckpointSuite.cogroup(longLineageRDD1, rdd.map(x => (x % 2, 1)), partitioner) - }, false, true) - - val longLineageRDD2 = generateLongLineageRDDForCoGroupedRDD() - testParentCheckpointing(rdd => { - CheckpointSuite.cogroup( - longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner) - }, false, true) - } - - test("ZippedRDD") { - testCheckpointing( - rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false) - - // Test whether size of ZippedRDD reduce in size after parent RDD is checkpointed - // Current implementation of ZippedRDDPartitions has transient references to parent RDDs, - // so only the RDD will reduce in serialized size, not the partitions. 
- testParentCheckpointing( - rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false) - } - - test("CheckpointRDD with zero partitions") { - val rdd = new BlockRDD[Int](sc, Array[String]()) - assert(rdd.partitions.size === 0) - assert(rdd.isCheckpointed === false) - rdd.checkpoint() - assert(rdd.count() === 0) - assert(rdd.isCheckpointed === true) - assert(rdd.partitions.size === 0) - } - - /** - * Test checkpointing of the final RDD generated by the given operation. By default, - * this method tests whether the size of serialized RDD has reduced after checkpointing or not. - * It can also test whether the size of serialized RDD partitions has reduced after checkpointing or - * not, but this is not done by default as usually the partitions do not refer to any RDD and - * therefore never store the lineage. - */ - def testCheckpointing[U: ClassManifest]( - op: (RDD[Int]) => RDD[U], - testRDDSize: Boolean = true, - testRDDPartitionSize: Boolean = false - ) { - // Generate the final RDD using given RDD operation - val baseRDD = generateLongLineageRDD() - val operatedRDD = op(baseRDD) - val parentRDD = operatedRDD.dependencies.headOption.orNull - val rddType = operatedRDD.getClass.getSimpleName - val numPartitions = operatedRDD.partitions.length - - // Find serialized sizes before and after the checkpoint - val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) - operatedRDD.checkpoint() - val result = operatedRDD.collect() - val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) - - // Test whether the checkpoint file has been created - assert(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get).collect() === result) - - // Test whether dependencies have been changed from its earlier parent RDD - assert(operatedRDD.dependencies.head.rdd != parentRDD) - - // Test whether the partitions have been changed to the new Hadoop partitions - assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList) - - // Test whether the number of partitions is same as before - assert(operatedRDD.partitions.length === numPartitions) - - // Test whether the data in the checkpointed RDD is same as original - assert(operatedRDD.collect() === result) - - // Test whether serialized size of the RDD has reduced. If the RDD - // does not have any dependency to another RDD (e.g., ParallelCollection, - // ShuffleRDD with ShuffleDependency), it may not reduce in size after checkpointing. - if (testRDDSize) { - logInfo("Size of " + rddType + - "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]") - assert( - rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, - "Size of " + rddType + " did not reduce after checkpointing " + - "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" - ) - } - - // Test whether serialized size of the partitions has reduced. If the partitions - // do not have any non-transient reference to another RDD or another RDD's partitions, it - // does not refer to a lineage and therefore may not reduce in size after checkpointing. - // However, if the original partitions before checkpointing do refer to a parent RDD, the partitions - // must be forgotten after checkpointing (to remove all reference to parent RDDs) and - // replaced with the HadooPartitions of the checkpointed RDD. 
- if (testRDDPartitionSize) { - logInfo("Size of " + rddType + " partitions " - + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]") - assert( - splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint, - "Size of " + rddType + " partitions did not reduce after checkpointing " + - "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]" - ) - } - } - - /** - * Test whether checkpointing of the parent of the generated RDD also - * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent - * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed, - * this RDD will remember the partitions and therefore potentially the whole lineage. - */ - def testParentCheckpointing[U: ClassManifest]( - op: (RDD[Int]) => RDD[U], - testRDDSize: Boolean, - testRDDPartitionSize: Boolean - ) { - // Generate the final RDD using given RDD operation - val baseRDD = generateLongLineageRDD() - val operatedRDD = op(baseRDD) - val parentRDD = operatedRDD.dependencies.head.rdd - val rddType = operatedRDD.getClass.getSimpleName - val parentRDDType = parentRDD.getClass.getSimpleName - - // Get the partitions and dependencies of the parent in case they're lazily computed - parentRDD.dependencies - parentRDD.partitions - - // Find serialized sizes before and after the checkpoint - val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) - parentRDD.checkpoint() // checkpoint the parent RDD, not the generated one - val result = operatedRDD.collect() - val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) - - // Test whether the data in the checkpointed RDD is same as original - assert(operatedRDD.collect() === result) - - // Test whether serialized size of the RDD has reduced because of its parent being - // checkpointed. If this RDD or its parent RDD do not have any dependency - // to another RDD (e.g., ParallelCollection, ShuffleRDD with ShuffleDependency), it may - // not reduce in size after checkpointing. - if (testRDDSize) { - assert( - rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, - "Size of " + rddType + " did not reduce after checkpointing parent " + parentRDDType + - "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" - ) - } - - // Test whether serialized size of the partitions has reduced because of its parent being - // checkpointed. If the partitions do not have any non-transient reference to another RDD - // or another RDD's partitions, it does not refer to a lineage and therefore may not reduce - // in size after checkpointing. However, if the partitions do refer to the *partitions* of a parent - // RDD, then these partitions must update reference to the parent RDD partitions as the parent RDD's - // partitions must have changed after checkpointing. - if (testRDDPartitionSize) { - assert( - splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint, - "Size of " + rddType + " partitions did not reduce after checkpointing parent " + parentRDDType + - "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]" - ) - } - - } - - /** - * Generate an RDD with a long lineage of one-to-one dependencies. - */ - def generateLongLineageRDD(): RDD[Int] = { - var rdd = sc.makeRDD(1 to 100, 4) - for (i <- 1 to 50) { - rdd = rdd.map(x => x + 1) - } - rdd - } - - /** - * Generate an RDD with a long lineage specifically for CoGroupedRDD. 
- * A CoGroupedRDD can have a long lineage only one of its parents have a long lineage - * and narrow dependency with this RDD. This method generate such an RDD by a sequence - * of cogroups and mapValues which creates a long lineage of narrow dependencies. - */ - def generateLongLineageRDDForCoGroupedRDD() = { - val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _) - - def ones: RDD[(Int, Int)] = sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _) - - var cogrouped: RDD[(Int, (Seq[Int], Seq[Int]))] = ones.cogroup(ones) - for(i <- 1 to 10) { - cogrouped = cogrouped.mapValues(add).cogroup(ones) - } - cogrouped.mapValues(add) - } - - /** - * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks - * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. - */ - def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { - (Utils.serialize(rdd).length - Utils.serialize(rdd.checkpointData).length, - Utils.serialize(rdd.partitions).length) - } - - /** - * Serialize and deserialize an object. This is useful to verify the objects - * contents after deserialization (e.g., the contents of an RDD split after - * it is sent to a slave along with a task) - */ - def serializeDeserialize[T](obj: T): T = { - val bytes = Utils.serialize(obj) - Utils.deserialize[T](bytes) - } -} - - -object CheckpointSuite { - // This is a custom cogroup function that does not use mapValues like - // the PairRDDFunctions.cogroup() - def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = { - //println("First = " + first + ", second = " + second) - new CoGroupedRDD[K]( - Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]), - part - ).asInstanceOf[RDD[(K, Seq[Seq[V]])]] - } - -} diff --git a/core/src/test/scala/spark/ClosureCleanerSuite.scala b/core/src/test/scala/spark/ClosureCleanerSuite.scala deleted file mode 100644 index 7d2831e19c..0000000000 --- a/core/src/test/scala/spark/ClosureCleanerSuite.scala +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark - -import java.io.NotSerializableException - -import org.scalatest.FunSuite -import spark.LocalSparkContext._ -import SparkContext._ - -class ClosureCleanerSuite extends FunSuite { - test("closures inside an object") { - assert(TestObject.run() === 30) // 6 + 7 + 8 + 9 - } - - test("closures inside a class") { - val obj = new TestClass - assert(obj.run() === 30) // 6 + 7 + 8 + 9 - } - - test("closures inside a class with no default constructor") { - val obj = new TestClassWithoutDefaultConstructor(5) - assert(obj.run() === 30) // 6 + 7 + 8 + 9 - } - - test("closures that don't use fields of the outer class") { - val obj = new TestClassWithoutFieldAccess - assert(obj.run() === 30) // 6 + 7 + 8 + 9 - } - - test("nested closures inside an object") { - assert(TestObjectWithNesting.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1 - } - - test("nested closures inside a class") { - val obj = new TestClassWithNesting(1) - assert(obj.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1 - } -} - -// A non-serializable class we create in closures to make sure that we aren't -// keeping references to unneeded variables from our outer closures. -class NonSerializable {} - -object TestObject { - def run(): Int = { - var nonSer = new NonSerializable - var x = 5 - return withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) - nums.map(_ + x).reduce(_ + _) - } - } -} - -class TestClass extends Serializable { - var x = 5 - - def getX = x - - def run(): Int = { - var nonSer = new NonSerializable - return withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) - nums.map(_ + getX).reduce(_ + _) - } - } -} - -class TestClassWithoutDefaultConstructor(x: Int) extends Serializable { - def getX = x - - def run(): Int = { - var nonSer = new NonSerializable - return withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) - nums.map(_ + getX).reduce(_ + _) - } - } -} - -// This class is not serializable, but we aren't using any of its fields in our -// closures, so they won't have a $outer pointing to it and should still work. 
-class TestClassWithoutFieldAccess { - var nonSer = new NonSerializable - - def run(): Int = { - var nonSer2 = new NonSerializable - var x = 5 - return withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) - nums.map(_ + x).reduce(_ + _) - } - } -} - - -object TestObjectWithNesting { - def run(): Int = { - var nonSer = new NonSerializable - var answer = 0 - return withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) - var y = 1 - for (i <- 1 to 4) { - var nonSer2 = new NonSerializable - var x = i - answer += nums.map(_ + x + y).reduce(_ + _) - } - answer - } - } -} - -class TestClassWithNesting(val y: Int) extends Serializable { - def getY = y - - def run(): Int = { - var nonSer = new NonSerializable - var answer = 0 - return withSpark(new SparkContext("local", "test")) { sc => - val nums = sc.parallelize(Array(1, 2, 3, 4)) - for (i <- 1 to 4) { - var nonSer2 = new NonSerializable - var x = i - answer += nums.map(_ + x + getY).reduce(_ + _) - } - answer - } - } -} diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala deleted file mode 100644 index e11efe459c..0000000000 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ /dev/null @@ -1,362 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import network.ConnectionManagerId -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter -import org.scalatest.concurrent.Timeouts._ -import org.scalatest.matchers.ShouldMatchers -import org.scalatest.prop.Checkers -import org.scalatest.time.{Span, Millis} -import org.scalacheck.Arbitrary._ -import org.scalacheck.Gen -import org.scalacheck.Prop._ -import org.eclipse.jetty.server.{Server, Request, Handler} - -import com.google.common.io.Files - -import scala.collection.mutable.ArrayBuffer - -import SparkContext._ -import storage.{GetBlock, BlockManagerWorker, StorageLevel} -import ui.JettyUtils - - -class NotSerializableClass -class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} - - -class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter - with LocalSparkContext { - - val clusterUrl = "local-cluster[2,1,512]" - - after { - System.clearProperty("spark.reducer.maxMbInFlight") - System.clearProperty("spark.storage.memoryFraction") - } - - test("task throws not serializable exception") { - // Ensures that executors do not crash when an exn is not serializable. If executors crash, - // this test will hang. Correct behavior is that executors don't crash but fail tasks - // and the scheduler throws a SparkException. 
- - // numSlaves must be less than numPartitions - val numSlaves = 3 - val numPartitions = 10 - - sc = new SparkContext("local-cluster[%s,1,512]".format(numSlaves), "test") - val data = sc.parallelize(1 to 100, numPartitions). - map(x => throw new NotSerializableExn(new NotSerializableClass)) - intercept[SparkException] { - data.count() - } - resetSparkContext() - } - - test("local-cluster format") { - sc = new SparkContext("local-cluster[2,1,512]", "test") - assert(sc.parallelize(1 to 2, 2).count() == 2) - resetSparkContext() - sc = new SparkContext("local-cluster[2 , 1 , 512]", "test") - assert(sc.parallelize(1 to 2, 2).count() == 2) - resetSparkContext() - sc = new SparkContext("local-cluster[2, 1, 512]", "test") - assert(sc.parallelize(1 to 2, 2).count() == 2) - resetSparkContext() - sc = new SparkContext("local-cluster[ 2, 1, 512 ]", "test") - assert(sc.parallelize(1 to 2, 2).count() == 2) - resetSparkContext() - } - - test("simple groupByKey") { - sc = new SparkContext(clusterUrl, "test") - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 5) - val groups = pairs.groupByKey(5).collect() - assert(groups.size === 2) - val valuesFor1 = groups.find(_._1 == 1).get._2 - assert(valuesFor1.toList.sorted === List(1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } - - test("groupByKey where map output sizes exceed maxMbInFlight") { - System.setProperty("spark.reducer.maxMbInFlight", "1") - sc = new SparkContext(clusterUrl, "test") - // This data should be around 20 MB, so even with 4 mappers and 2 reducers, each map output - // file should be about 2.5 MB - val pairs = sc.parallelize(1 to 2000, 4).map(x => (x % 16, new Array[Byte](10000))) - val groups = pairs.groupByKey(2).map(x => (x._1, x._2.size)).collect() - assert(groups.length === 16) - assert(groups.map(_._2).sum === 2000) - // Note that spark.reducer.maxMbInFlight will be cleared in the test suite's after{} block - } - - test("accumulators") { - sc = new SparkContext(clusterUrl, "test") - val accum = sc.accumulator(0) - sc.parallelize(1 to 10, 10).foreach(x => accum += x) - assert(accum.value === 55) - } - - test("broadcast variables") { - sc = new SparkContext(clusterUrl, "test") - val array = new Array[Int](100) - val bv = sc.broadcast(array) - array(2) = 3 // Change the array -- this should not be seen on workers - val rdd = sc.parallelize(1 to 10, 10) - val sum = rdd.map(x => bv.value.sum).reduce(_ + _) - assert(sum === 0) - } - - test("repeatedly failing task") { - sc = new SparkContext(clusterUrl, "test") - val accum = sc.accumulator(0) - val thrown = intercept[SparkException] { - sc.parallelize(1 to 10, 10).foreach(x => println(x / 0)) - } - assert(thrown.getClass === classOf[SparkException]) - assert(thrown.getMessage.contains("more than 4 times")) - } - - test("caching") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).cache() - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - } - - test("caching on disk") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.DISK_ONLY) - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - } - - test("caching in memory, replicated") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_ONLY_2) - assert(data.count() === 1000) - assert(data.count() === 1000) - 
assert(data.count() === 1000) - } - - test("caching in memory, serialized, replicated") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_ONLY_SER_2) - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - } - - test("caching on disk, replicated") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.DISK_ONLY_2) - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - } - - test("caching in memory and disk, replicated") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_2) - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - } - - test("caching in memory and disk, serialized, replicated") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_SER_2) - - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - - // Get all the locations of the first partition and try to fetch the partitions - // from those locations. - val blockIds = data.partitions.indices.map(index => "rdd_%d_%d".format(data.id, index)).toArray - val blockId = blockIds(0) - val blockManager = SparkEnv.get.blockManager - blockManager.master.getLocations(blockId).foreach(id => { - val bytes = BlockManagerWorker.syncGetBlock( - GetBlock(blockId), ConnectionManagerId(id.host, id.port)) - val deserialized = blockManager.dataDeserialize(blockId, bytes).asInstanceOf[Iterator[Int]].toList - assert(deserialized === (1 to 100).toList) - }) - } - - test("compute without caching when no partitions fit in memory") { - System.setProperty("spark.storage.memoryFraction", "0.0001") - sc = new SparkContext(clusterUrl, "test") - // data will be 4 million * 4 bytes = 16 MB in size, but our memoryFraction set the cache - // to only 50 KB (0.0001 of 512 MB), so no partitions should fit in memory - val data = sc.parallelize(1 to 4000000, 2).persist(StorageLevel.MEMORY_ONLY_SER) - assert(data.count() === 4000000) - assert(data.count() === 4000000) - assert(data.count() === 4000000) - System.clearProperty("spark.storage.memoryFraction") - } - - test("compute when only some partitions fit in memory") { - System.setProperty("spark.storage.memoryFraction", "0.01") - sc = new SparkContext(clusterUrl, "test") - // data will be 4 million * 4 bytes = 16 MB in size, but our memoryFraction set the cache - // to only 5 MB (0.01 of 512 MB), so not all of it will fit in memory; we use 20 partitions - // to make sure that *some* of them do fit though - val data = sc.parallelize(1 to 4000000, 20).persist(StorageLevel.MEMORY_ONLY_SER) - assert(data.count() === 4000000) - assert(data.count() === 4000000) - assert(data.count() === 4000000) - System.clearProperty("spark.storage.memoryFraction") - } - - test("passing environment variables to cluster") { - sc = new SparkContext(clusterUrl, "test", null, Nil, Map("TEST_VAR" -> "TEST_VALUE")) - val values = sc.parallelize(1 to 2, 2).map(x => System.getenv("TEST_VAR")).collect() - assert(values.toSeq === Seq("TEST_VALUE", "TEST_VALUE")) - } - - test("recover from node failures") { - import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} - DistributedSuite.amMaster = true - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(Seq(true, true), 2) - 
assert(data.count === 2) // force executors to start - assert(data.map(markNodeIfIdentity).collect.size === 2) - assert(data.map(failOnMarkedIdentity).collect.size === 2) - } - - test("recover from repeated node failures during shuffle-map") { - import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} - DistributedSuite.amMaster = true - sc = new SparkContext(clusterUrl, "test") - for (i <- 1 to 3) { - val data = sc.parallelize(Seq(true, false), 2) - assert(data.count === 2) - assert(data.map(markNodeIfIdentity).collect.size === 2) - assert(data.map(failOnMarkedIdentity).map(x => x -> x).groupByKey.count === 2) - } - } - - test("recover from repeated node failures during shuffle-reduce") { - import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} - DistributedSuite.amMaster = true - sc = new SparkContext(clusterUrl, "test") - for (i <- 1 to 3) { - val data = sc.parallelize(Seq(true, true), 2) - assert(data.count === 2) - assert(data.map(markNodeIfIdentity).collect.size === 2) - // This relies on mergeCombiners being used to perform the actual reduce for this - // test to actually be testing what it claims. - val grouped = data.map(x => x -> x).combineByKey( - x => x, - (x: Boolean, y: Boolean) => x, - (x: Boolean, y: Boolean) => failOnMarkedIdentity(x) - ) - assert(grouped.collect.size === 1) - } - } - - test("recover from node failures with replication") { - import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} - DistributedSuite.amMaster = true - // Using more than two nodes so we don't have a symmetric communication pattern and might - // cache a partially correct list of peers. - sc = new SparkContext("local-cluster[3,1,512]", "test") - for (i <- 1 to 3) { - val data = sc.parallelize(Seq(true, false, false, false), 4) - data.persist(StorageLevel.MEMORY_ONLY_2) - - assert(data.count === 4) - assert(data.map(markNodeIfIdentity).collect.size === 4) - assert(data.map(failOnMarkedIdentity).collect.size === 4) - - // Create a new replicated RDD to make sure that cached peer information doesn't cause - // problems. - val data2 = sc.parallelize(Seq(true, true), 2).persist(StorageLevel.MEMORY_ONLY_2) - assert(data2.count === 2) - } - } - - test("unpersist RDDs") { - DistributedSuite.amMaster = true - sc = new SparkContext("local-cluster[3,1,512]", "test") - val data = sc.parallelize(Seq(true, false, false, false), 4) - data.persist(StorageLevel.MEMORY_ONLY_2) - data.count - assert(sc.persistentRdds.isEmpty === false) - data.unpersist() - assert(sc.persistentRdds.isEmpty === true) - - failAfter(Span(3000, Millis)) { - try { - while (! sc.getRDDStorageInfo.isEmpty) { - Thread.sleep(200) - } - } catch { - case _ => { Thread.sleep(10) } - // Do nothing. We might see exceptions because block manager - // is racing this thread to remove entries from the driver. - } - } - } - - test("job should fail if TaskResult exceeds Akka frame size") { - // We must use local-cluster mode since results are returned differently - // when running under LocalScheduler: - sc = new SparkContext("local-cluster[1,1,512]", "test") - val akkaFrameSize = - sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt - val rdd = sc.parallelize(Seq(1)).map{x => new Array[Byte](akkaFrameSize)} - val exception = intercept[SparkException] { - rdd.reduce((x, y) => x) - } - exception.getMessage should endWith("result exceeded Akka frame size") - } -} - -object DistributedSuite { - // Indicates whether this JVM is marked for failure. 
- var mark = false - - // Set by test to remember if we are in the driver program so we can assert - // that we are not. - var amMaster = false - - // Act like an identity function, but if the argument is true, set mark to true. - def markNodeIfIdentity(item: Boolean): Boolean = { - if (item) { - assert(!amMaster) - mark = true - } - item - } - - // Act like an identity function, but if mark was set to true previously, fail, - // crashing the entire JVM. - def failOnMarkedIdentity(item: Boolean): Boolean = { - if (mark) { - System.exit(42) - } - item - } -} diff --git a/core/src/test/scala/spark/DriverSuite.scala b/core/src/test/scala/spark/DriverSuite.scala deleted file mode 100644 index 553c0309f6..0000000000 --- a/core/src/test/scala/spark/DriverSuite.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.io.File - -import org.apache.log4j.Logger -import org.apache.log4j.Level - -import org.scalatest.FunSuite -import org.scalatest.concurrent.Timeouts -import org.scalatest.prop.TableDrivenPropertyChecks._ -import org.scalatest.time.SpanSugar._ - -class DriverSuite extends FunSuite with Timeouts { - test("driver should exit after finishing") { - assert(System.getenv("SPARK_HOME") != null) - // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" - val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) - forAll(masters) { (master: String) => - failAfter(30 seconds) { - Utils.execute(Seq("./spark-class", "spark.DriverWithoutCleanup", master), - new File(System.getenv("SPARK_HOME"))) - } - } - } -} - -/** - * Program that creates a Spark driver but doesn't call SparkContext.stop() or - * Sys.exit() after finishing. - */ -object DriverWithoutCleanup { - def main(args: Array[String]) { - Logger.getRootLogger().setLevel(Level.WARN) - val sc = new SparkContext(args(0), "DriverWithoutCleanup") - sc.parallelize(1 to 100, 4).count() - } -} diff --git a/core/src/test/scala/spark/FailureSuite.scala b/core/src/test/scala/spark/FailureSuite.scala deleted file mode 100644 index 5b133cdd6e..0000000000 --- a/core/src/test/scala/spark/FailureSuite.scala +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.FunSuite - -import SparkContext._ - -// Common state shared by FailureSuite-launched tasks. We use a global object -// for this because any local variables used in the task closures will rightfully -// be copied for each task, so there's no other way for them to share state. -object FailureSuiteState { - var tasksRun = 0 - var tasksFailed = 0 - - def clear() { - synchronized { - tasksRun = 0 - tasksFailed = 0 - } - } -} - -class FailureSuite extends FunSuite with LocalSparkContext { - - // Run a 3-task map job in which task 1 deterministically fails once, and check - // whether the job completes successfully and we ran 4 tasks in total. - test("failure in a single-stage job") { - sc = new SparkContext("local[1,1]", "test") - val results = sc.makeRDD(1 to 3, 3).map { x => - FailureSuiteState.synchronized { - FailureSuiteState.tasksRun += 1 - if (x == 1 && FailureSuiteState.tasksFailed == 0) { - FailureSuiteState.tasksFailed += 1 - throw new Exception("Intentional task failure") - } - } - x * x - }.collect() - FailureSuiteState.synchronized { - assert(FailureSuiteState.tasksRun === 4) - } - assert(results.toList === List(1,4,9)) - FailureSuiteState.clear() - } - - // Run a map-reduce job in which a reduce task deterministically fails once. - test("failure in a two-stage job") { - sc = new SparkContext("local[1,1]", "test") - val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map { - case (k, v) => - FailureSuiteState.synchronized { - FailureSuiteState.tasksRun += 1 - if (k == 1 && FailureSuiteState.tasksFailed == 0) { - FailureSuiteState.tasksFailed += 1 - throw new Exception("Intentional task failure") - } - } - (k, v(0) * v(0)) - }.collect() - FailureSuiteState.synchronized { - assert(FailureSuiteState.tasksRun === 4) - } - assert(results.toSet === Set((1, 1), (2, 4), (3, 9))) - FailureSuiteState.clear() - } - - test("failure because task results are not serializable") { - sc = new SparkContext("local[1,1]", "test") - val results = sc.makeRDD(1 to 3).map(x => new NonSerializable) - - val thrown = intercept[SparkException] { - results.collect() - } - assert(thrown.getClass === classOf[SparkException]) - assert(thrown.getMessage.contains("NotSerializableException")) - - FailureSuiteState.clear() - } - - test("failure because task closure is not serializable") { - sc = new SparkContext("local[1,1]", "test") - val a = new NonSerializable - - // Non-serializable closure in the final result stage - val thrown = intercept[SparkException] { - sc.parallelize(1 to 10, 2).map(x => a).count() - } - assert(thrown.getClass === classOf[SparkException]) - assert(thrown.getMessage.contains("NotSerializableException")) - - // Non-serializable closure in an earlier stage - val thrown1 = intercept[SparkException] { - sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count() - } - assert(thrown1.getClass === classOf[SparkException]) - assert(thrown1.getMessage.contains("NotSerializableException")) - - // Non-serializable closure in foreach function - val thrown2 = intercept[SparkException] { - sc.parallelize(1 
to 10, 2).foreach(x => println(a)) - } - assert(thrown2.getClass === classOf[SparkException]) - assert(thrown2.getMessage.contains("NotSerializableException")) - - FailureSuiteState.clear() - } - - // TODO: Need to add tests with shuffle fetch failures. -} - - diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala deleted file mode 100644 index 242ae971f8..0000000000 --- a/core/src/test/scala/spark/FileServerSuite.scala +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import com.google.common.io.Files -import org.scalatest.FunSuite -import java.io.{File, PrintWriter, FileReader, BufferedReader} -import SparkContext._ - -class FileServerSuite extends FunSuite with LocalSparkContext { - - @transient var tmpFile: File = _ - @transient var testJarFile: File = _ - - override def beforeEach() { - super.beforeEach() - // Create a sample text file - val tmpdir = new File(Files.createTempDir(), "test") - tmpdir.mkdir() - tmpFile = new File(tmpdir, "FileServerSuite.txt") - val pw = new PrintWriter(tmpFile) - pw.println("100") - pw.close() - } - - override def afterEach() { - super.afterEach() - // Clean up downloaded file - if (tmpFile.exists) { - tmpFile.delete() - } - } - - test("Distributing files locally") { - sc = new SparkContext("local[4]", "test") - sc.addFile(tmpFile.toString) - val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) - val result = sc.parallelize(testData).reduceByKey { - val path = SparkFiles.get("FileServerSuite.txt") - val in = new BufferedReader(new FileReader(path)) - val fileVal = in.readLine().toInt - in.close() - _ * fileVal + _ * fileVal - }.collect() - assert(result.toSet === Set((1,200), (2,300), (3,500))) - } - - test("Distributing files locally using URL as input") { - // addFile("file:///....") - sc = new SparkContext("local[4]", "test") - sc.addFile(new File(tmpFile.toString).toURI.toString) - val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) - val result = sc.parallelize(testData).reduceByKey { - val path = SparkFiles.get("FileServerSuite.txt") - val in = new BufferedReader(new FileReader(path)) - val fileVal = in.readLine().toInt - in.close() - _ * fileVal + _ * fileVal - }.collect() - assert(result.toSet === Set((1,200), (2,300), (3,500))) - } - - test ("Dynamically adding JARS locally") { - sc = new SparkContext("local[4]", "test") - val sampleJarFile = getClass.getClassLoader.getResource("uncommons-maths-1.2.2.jar").getFile() - sc.addJar(sampleJarFile) - val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0)) - val result = sc.parallelize(testData).reduceByKey { (x,y) => - val fac = Thread.currentThread.getContextClassLoader() - .loadClass("org.uncommons.maths.Maths") - 
.getDeclaredMethod("factorial", classOf[Int]) - val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt - val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt - a + b - }.collect() - assert(result.toSet === Set((1,2), (2,7), (3,121))) - } - - test("Distributing files on a standalone cluster") { - sc = new SparkContext("local-cluster[1,1,512]", "test") - sc.addFile(tmpFile.toString) - val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) - val result = sc.parallelize(testData).reduceByKey { - val path = SparkFiles.get("FileServerSuite.txt") - val in = new BufferedReader(new FileReader(path)) - val fileVal = in.readLine().toInt - in.close() - _ * fileVal + _ * fileVal - }.collect() - assert(result.toSet === Set((1,200), (2,300), (3,500))) - } - - test ("Dynamically adding JARS on a standalone cluster") { - sc = new SparkContext("local-cluster[1,1,512]", "test") - val sampleJarFile = getClass.getClassLoader.getResource("uncommons-maths-1.2.2.jar").getFile() - sc.addJar(sampleJarFile) - val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0)) - val result = sc.parallelize(testData).reduceByKey { (x,y) => - val fac = Thread.currentThread.getContextClassLoader() - .loadClass("org.uncommons.maths.Maths") - .getDeclaredMethod("factorial", classOf[Int]) - val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt - val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt - a + b - }.collect() - assert(result.toSet === Set((1,2), (2,7), (3,121))) - } -} diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala deleted file mode 100644 index 1e2c257c4b..0000000000 --- a/core/src/test/scala/spark/FileSuite.scala +++ /dev/null @@ -1,212 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark - -import java.io.{FileWriter, PrintWriter, File} - -import scala.io.Source - -import com.google.common.io.Files -import org.scalatest.FunSuite -import org.apache.hadoop.io._ -import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodec, GzipCodec} - - -import SparkContext._ - -class FileSuite extends FunSuite with LocalSparkContext { - - test("text files") { - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val outputDir = new File(tempDir, "output").getAbsolutePath - val nums = sc.makeRDD(1 to 4) - nums.saveAsTextFile(outputDir) - // Read the plain text file and check it's OK - val outputFile = new File(outputDir, "part-00000") - val content = Source.fromFile(outputFile).mkString - assert(content === "1\n2\n3\n4\n") - // Also try reading it in as a text file RDD - assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4")) - } - - test("text files (compressed)") { - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val normalDir = new File(tempDir, "output_normal").getAbsolutePath - val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath - val codec = new DefaultCodec() - - val data = sc.parallelize("a" * 10000, 1) - data.saveAsTextFile(normalDir) - data.saveAsTextFile(compressedOutputDir, classOf[DefaultCodec]) - - val normalFile = new File(normalDir, "part-00000") - val normalContent = sc.textFile(normalDir).collect - assert(normalContent === Array.fill(10000)("a")) - - val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension) - val compressedContent = sc.textFile(compressedOutputDir).collect - assert(compressedContent === Array.fill(10000)("a")) - - assert(compressedFile.length < normalFile.length) - } - - test("SequenceFiles") { - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val outputDir = new File(tempDir, "output").getAbsolutePath - val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x)) // (1,a), (2,aa), (3,aaa) - nums.saveAsSequenceFile(outputDir) - // Try reading the output back as a SequenceFile - val output = sc.sequenceFile[IntWritable, Text](outputDir) - assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - } - - test("SequenceFile (compressed)") { - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val normalDir = new File(tempDir, "output_normal").getAbsolutePath - val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath - val codec = new DefaultCodec() - - val data = sc.parallelize(Seq.fill(100)("abc"), 1).map(x => (x, x)) - data.saveAsSequenceFile(normalDir) - data.saveAsSequenceFile(compressedOutputDir, Some(classOf[DefaultCodec])) - - val normalFile = new File(normalDir, "part-00000") - val normalContent = sc.sequenceFile[String, String](normalDir).collect - assert(normalContent === Array.fill(100)("abc", "abc")) - - val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension) - val compressedContent = sc.sequenceFile[String, String](compressedOutputDir).collect - assert(compressedContent === Array.fill(100)("abc", "abc")) - - assert(compressedFile.length < normalFile.length) - } - - test("SequenceFile with writable key") { - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val outputDir = new File(tempDir, "output").getAbsolutePath - val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), "a" * x)) - 
nums.saveAsSequenceFile(outputDir) - // Try reading the output back as a SequenceFile - val output = sc.sequenceFile[IntWritable, Text](outputDir) - assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - } - - test("SequenceFile with writable value") { - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val outputDir = new File(tempDir, "output").getAbsolutePath - val nums = sc.makeRDD(1 to 3).map(x => (x, new Text("a" * x))) - nums.saveAsSequenceFile(outputDir) - // Try reading the output back as a SequenceFile - val output = sc.sequenceFile[IntWritable, Text](outputDir) - assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - } - - test("SequenceFile with writable key and value") { - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val outputDir = new File(tempDir, "output").getAbsolutePath - val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x))) - nums.saveAsSequenceFile(outputDir) - // Try reading the output back as a SequenceFile - val output = sc.sequenceFile[IntWritable, Text](outputDir) - assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - } - - test("implicit conversions in reading SequenceFiles") { - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val outputDir = new File(tempDir, "output").getAbsolutePath - val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x)) // (1,a), (2,aa), (3,aaa) - nums.saveAsSequenceFile(outputDir) - // Similar to the tests above, we read a SequenceFile, but this time we pass type params - // that are convertible to Writable instead of calling sequenceFile[IntWritable, Text] - val output1 = sc.sequenceFile[Int, String](outputDir) - assert(output1.collect().toList === List((1, "a"), (2, "aa"), (3, "aaa"))) - // Also try having one type be a subclass of Writable and one not - val output2 = sc.sequenceFile[Int, Text](outputDir) - assert(output2.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - val output3 = sc.sequenceFile[IntWritable, String](outputDir) - assert(output3.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - } - - test("object files of ints") { - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val outputDir = new File(tempDir, "output").getAbsolutePath - val nums = sc.makeRDD(1 to 4) - nums.saveAsObjectFile(outputDir) - // Try reading the output back as an object file - val output = sc.objectFile[Int](outputDir) - assert(output.collect().toList === List(1, 2, 3, 4)) - } - - test("object files of complex types") { - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val outputDir = new File(tempDir, "output").getAbsolutePath - val nums = sc.makeRDD(1 to 3).map(x => (x, "a" * x)) - nums.saveAsObjectFile(outputDir) - // Try reading the output back as an object file - val output = sc.objectFile[(Int, String)](outputDir) - assert(output.collect().toList === List((1, "a"), (2, "aa"), (3, "aaa"))) - } - - test("write SequenceFile using new Hadoop API") { - import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val outputDir = new File(tempDir, "output").getAbsolutePath - val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x))) - nums.saveAsNewAPIHadoopFile[SequenceFileOutputFormat[IntWritable, Text]]( - outputDir)
- val output = sc.sequenceFile[IntWritable, Text](outputDir) - assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - } - - test("read SequenceFile using new Hadoop API") { - import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val outputDir = new File(tempDir, "output").getAbsolutePath - val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x))) - nums.saveAsSequenceFile(outputDir) - val output = - sc.newAPIHadoopFile[IntWritable, Text, SequenceFileInputFormat[IntWritable, Text]](outputDir) - assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) - } - - test("file caching") { - sc = new SparkContext("local", "test") - val tempDir = Files.createTempDir() - val out = new FileWriter(tempDir + "/input") - out.write("Hello world!\n") - out.write("What's up?\n") - out.write("Goodbye\n") - out.close() - val rdd = sc.textFile(tempDir + "/input").cache() - assert(rdd.count() === 3) - assert(rdd.count() === 3) - assert(rdd.count() === 3) - } -} diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java deleted file mode 100644 index c337c49268..0000000000 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ /dev/null @@ -1,865 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark; - -import java.io.File; -import java.io.IOException; -import java.io.Serializable; -import java.util.*; - -import com.google.common.base.Optional; -import scala.Tuple2; - -import com.google.common.base.Charsets; -import org.apache.hadoop.io.compress.DefaultCodec; -import com.google.common.io.Files; -import org.apache.hadoop.io.IntWritable; -import org.apache.hadoop.io.Text; -import org.apache.hadoop.mapred.SequenceFileInputFormat; -import org.apache.hadoop.mapred.SequenceFileOutputFormat; -import org.apache.hadoop.mapreduce.Job; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import spark.api.java.JavaDoubleRDD; -import spark.api.java.JavaPairRDD; -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.*; -import spark.partial.BoundedDouble; -import spark.partial.PartialResult; -import spark.storage.StorageLevel; -import spark.util.StatCounter; - - -// The test suite itself is Serializable so that anonymous Function implementations can be -// serialized, as an alternative to converting these anonymous classes to static inner classes; -// see http://stackoverflow.com/questions/758570/. 
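The same capture pitfall that the comment above describes for Java exists on the Scala side: referencing a field inside a closure drags the enclosing instance into the serialized task. A hypothetical sketch, with MySuite and factor invented for illustration:

import org.apache.spark.SparkContext

class MySuite extends Serializable {
  val factor = 2
  def run(sc: SparkContext) {
    // `factor` is a field, so this closure captures `this`; the job only
    // serializes cleanly because MySuite itself is Serializable.
    val doubled = sc.parallelize(1 to 4).map(_ * factor).collect()
    require(doubled.sameElements(Array(2, 4, 6, 8)))
  }
}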
-public class JavaAPISuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaAPISuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.driver.port"); - } - - static class ReverseIntComparator implements Comparator<Integer>, Serializable { - - @Override - public int compare(Integer a, Integer b) { - if (a > b) return -1; - else if (a < b) return 1; - else return 0; - } - }; - - @Test - public void sparkContextUnion() { - // Union of non-specialized JavaRDDs - List<String> strings = Arrays.asList("Hello", "World"); - JavaRDD<String> s1 = sc.parallelize(strings); - JavaRDD<String> s2 = sc.parallelize(strings); - // Varargs - JavaRDD<String> sUnion = sc.union(s1, s2); - Assert.assertEquals(4, sUnion.count()); - // List - List<JavaRDD<String>> list = new ArrayList<JavaRDD<String>>(); - list.add(s2); - sUnion = sc.union(s1, list); - Assert.assertEquals(4, sUnion.count()); - - // Union of JavaDoubleRDDs - List<Double> doubles = Arrays.asList(1.0, 2.0); - JavaDoubleRDD d1 = sc.parallelizeDoubles(doubles); - JavaDoubleRDD d2 = sc.parallelizeDoubles(doubles); - JavaDoubleRDD dUnion = sc.union(d1, d2); - Assert.assertEquals(4, dUnion.count()); - - // Union of JavaPairRDDs - List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>(); - pairs.add(new Tuple2<Integer, Integer>(1, 2)); - pairs.add(new Tuple2<Integer, Integer>(3, 4)); - JavaPairRDD<Integer, Integer> p1 = sc.parallelizePairs(pairs); - JavaPairRDD<Integer, Integer> p2 = sc.parallelizePairs(pairs); - JavaPairRDD<Integer, Integer> pUnion = sc.union(p1, p2); - Assert.assertEquals(4, pUnion.count()); - } - - @Test - public void sortByKey() { - List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>(); - pairs.add(new Tuple2<Integer, Integer>(0, 4)); - pairs.add(new Tuple2<Integer, Integer>(3, 2)); - pairs.add(new Tuple2<Integer, Integer>(-1, 1)); - - JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(pairs); - - // Default comparator - JavaPairRDD<Integer, Integer> sortedRDD = rdd.sortByKey(); - Assert.assertEquals(new Tuple2<Integer, Integer>(-1, 1), sortedRDD.first()); - List<Tuple2<Integer, Integer>> sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2<Integer, Integer>(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2<Integer, Integer>(3, 2), sortedPairs.get(2)); - - // Custom comparator - sortedRDD = rdd.sortByKey(new ReverseIntComparator(), false); - Assert.assertEquals(new Tuple2<Integer, Integer>(-1, 1), sortedRDD.first()); - sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2<Integer, Integer>(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2<Integer, Integer>(3, 2), sortedPairs.get(2)); - } - - static int foreachCalls = 0; - - @Test - public void foreach() { - foreachCalls = 0; - JavaRDD<String> rdd = sc.parallelize(Arrays.asList("Hello", "World")); - rdd.foreach(new VoidFunction<String>() { - @Override - public void call(String s) { - foreachCalls++; - } - }); - Assert.assertEquals(2, foreachCalls); - } - - @Test - public void lookup() { - JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList( - new Tuple2<String, String>("Apples", "Fruit"), - new Tuple2<String, String>("Oranges", "Fruit"), - new Tuple2<String, String>("Oranges", "Citrus") - )); - Assert.assertEquals(2, categories.lookup("Oranges").size()); - Assert.assertEquals(2, categories.groupByKey().lookup("Oranges").get(0).size()); - } - - @Test - public void groupBy() { - JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); - Function<Integer, Boolean> isOdd = new Function<Integer, Boolean>() { - @Override - public Boolean call(Integer x) { - return x % 2 == 0; - } - }; - JavaPairRDD<Boolean, List<Integer>> oddsAndEvens = rdd.groupBy(isOdd); - Assert.assertEquals(2, oddsAndEvens.count()); - Assert.assertEquals(2, oddsAndEvens.lookup(true).get(0).size()); // Evens - Assert.assertEquals(5, oddsAndEvens.lookup(false).get(0).size()); // Odds - - oddsAndEvens = rdd.groupBy(isOdd, 1); - Assert.assertEquals(2, oddsAndEvens.count()); - Assert.assertEquals(2, oddsAndEvens.lookup(true).get(0).size()); // Evens - Assert.assertEquals(5, oddsAndEvens.lookup(false).get(0).size()); // Odds - }
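For comparison with the anonymous Function classes above, the Scala API expresses the same grouping directly with closures. A sketch against the renamed package, assuming an sc: SparkContext already in scope (variable names are illustrative):

// Illustrative only: the Java groupBy test above, written in Scala.
val nums = sc.parallelize(Seq(1, 1, 2, 3, 5, 8, 13))
// Same predicate as `isOdd` above (which, as there, actually tests evenness).
val oddsAndEvens = nums.groupBy((x: Int) => x % 2 == 0)
require(oddsAndEvens.count() == 2)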
- - @Test - public void cogroup() { - JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList( - new Tuple2<String, String>("Apples", "Fruit"), - new Tuple2<String, String>("Oranges", "Fruit"), - new Tuple2<String, String>("Oranges", "Citrus") - )); - JavaPairRDD<String, Integer> prices = sc.parallelizePairs(Arrays.asList( - new Tuple2<String, Integer>("Oranges", 2), - new Tuple2<String, Integer>("Apples", 3) - )); - JavaPairRDD<String, Tuple2<List<String>, List<Integer>>> cogrouped = categories.cogroup(prices); - Assert.assertEquals("[Fruit, Citrus]", cogrouped.lookup("Oranges").get(0)._1().toString()); - Assert.assertEquals("[2]", cogrouped.lookup("Oranges").get(0)._2().toString()); - - cogrouped.collect(); - } - - @Test - public void leftOuterJoin() { - JavaPairRDD<Integer, Integer> rdd1 = sc.parallelizePairs(Arrays.asList( - new Tuple2<Integer, Integer>(1, 1), - new Tuple2<Integer, Integer>(1, 2), - new Tuple2<Integer, Integer>(2, 1), - new Tuple2<Integer, Integer>(3, 1) - )); - JavaPairRDD<Integer, Character> rdd2 = sc.parallelizePairs(Arrays.asList( - new Tuple2<Integer, Character>(1, 'x'), - new Tuple2<Integer, Character>(2, 'y'), - new Tuple2<Integer, Character>(2, 'z'), - new Tuple2<Integer, Character>(4, 'w') - )); - List<Tuple2<Integer, Tuple2<Integer, Optional<Character>>>> joined = - rdd1.leftOuterJoin(rdd2).collect(); - Assert.assertEquals(5, joined.size()); - Tuple2<Integer, Tuple2<Integer, Optional<Character>>> firstUnmatched = - rdd1.leftOuterJoin(rdd2).filter( - new Function<Tuple2<Integer, Tuple2<Integer, Optional<Character>>>, Boolean>() { - @Override - public Boolean call(Tuple2<Integer, Tuple2<Integer, Optional<Character>>> tup) - throws Exception { - return !tup._2()._2().isPresent(); - } - }).first(); - Assert.assertEquals(3, firstUnmatched._1().intValue()); - } - - @Test - public void foldReduce() { - JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); - Function2<Integer, Integer, Integer> add = new Function2<Integer, Integer, Integer>() { - @Override - public Integer call(Integer a, Integer b) { - return a + b; - } - }; - - int sum = rdd.fold(0, add); - Assert.assertEquals(33, sum); - - sum = rdd.reduce(add); - Assert.assertEquals(33, sum); - } - - @Test - public void foldByKey() { - List<Tuple2<Integer, Integer>> pairs = Arrays.asList( - new Tuple2<Integer, Integer>(2, 1), - new Tuple2<Integer, Integer>(2, 1), - new Tuple2<Integer, Integer>(1, 1), - new Tuple2<Integer, Integer>(3, 2), - new Tuple2<Integer, Integer>(3, 1) - ); - JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(pairs); - JavaPairRDD<Integer, Integer> sums = rdd.foldByKey(0, - new Function2<Integer, Integer, Integer>() { - @Override - public Integer call(Integer a, Integer b) { - return a + b; - } - }); - Assert.assertEquals(1, sums.lookup(1).get(0).intValue()); - Assert.assertEquals(2, sums.lookup(2).get(0).intValue()); - Assert.assertEquals(3, sums.lookup(3).get(0).intValue()); - } - - @Test - public void reduceByKey() { - List<Tuple2<Integer, Integer>> pairs = Arrays.asList( - new Tuple2<Integer, Integer>(2, 1), - new Tuple2<Integer, Integer>(2, 1), - new Tuple2<Integer, Integer>(1, 1), - new Tuple2<Integer, Integer>(3, 2), - new Tuple2<Integer, Integer>(3, 1) - ); - JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(pairs); - JavaPairRDD<Integer, Integer> counts = rdd.reduceByKey( - new Function2<Integer, Integer, Integer>() { - @Override - public Integer call(Integer a, Integer b) { - return a + b; - } - }); - Assert.assertEquals(1, counts.lookup(1).get(0).intValue()); - Assert.assertEquals(2, counts.lookup(2).get(0).intValue()); - Assert.assertEquals(3, counts.lookup(3).get(0).intValue()); - - Map<Integer, Integer> localCounts = counts.collectAsMap(); - Assert.assertEquals(1, localCounts.get(1).intValue()); - Assert.assertEquals(2, localCounts.get(2).intValue()); - Assert.assertEquals(3, localCounts.get(3).intValue()); - - localCounts = rdd.reduceByKeyLocally(new Function2<Integer, Integer, Integer>() { - @Override - public Integer call(Integer a, Integer b) { - return a + b; - } - }); - Assert.assertEquals(1, localCounts.get(1).intValue()); - Assert.assertEquals(2, localCounts.get(2).intValue()); - Assert.assertEquals(3, localCounts.get(3).intValue()); - }
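The fold and reduce tests above need a full Function2 class per operator; in Scala the same reductions are one-liners once the pair-RDD implicits are in scope. A sketch, assuming sc in scope (the curried foldByKey shape is an assumption about the Scala signature):

import org.apache.spark.SparkContext._   // brings in the pair-RDD operations

val pairs  = sc.parallelize(Seq((2, 1), (2, 1), (1, 1), (3, 2), (3, 1)))
val sums   = pairs.foldByKey(0)(_ + _).collectAsMap()   // Map(1 -> 1, 2 -> 2, 3 -> 3)
val counts = pairs.reduceByKey(_ + _).collectAsMap()    // same totals here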
- - @Test - public void approximateResults() { - JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); - Map<Integer, Long> countsByValue = rdd.countByValue(); - Assert.assertEquals(2, countsByValue.get(1).longValue()); - Assert.assertEquals(1, countsByValue.get(13).longValue()); - - PartialResult<Map<Integer, BoundedDouble>> approx = rdd.countByValueApprox(1); - Map<Integer, BoundedDouble> finalValue = approx.getFinalValue(); - Assert.assertEquals(2.0, finalValue.get(1).mean(), 0.01); - Assert.assertEquals(1.0, finalValue.get(13).mean(), 0.01); - } - - @Test - public void take() { - JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); - Assert.assertEquals(1, rdd.first().intValue()); - List<Integer> firstTwo = rdd.take(2); - List<Integer> sample = rdd.takeSample(false, 2, 42); - } - - @Test - public void cartesian() { - JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); - JavaRDD<String> stringRDD = sc.parallelize(Arrays.asList("Hello", "World")); - JavaPairRDD<String, Double> cartesian = stringRDD.cartesian(doubleRDD); - Assert.assertEquals(new Tuple2<String, Double>("Hello", 1.0), cartesian.first()); - } - - @Test - public void javaDoubleRDD() { - JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); - JavaDoubleRDD distinct = rdd.distinct(); - Assert.assertEquals(5, distinct.count()); - JavaDoubleRDD filter = rdd.filter(new Function<Double, Boolean>() { - @Override - public Boolean call(Double x) { - return x > 2.0; - } - }); - Assert.assertEquals(3, filter.count()); - JavaDoubleRDD union = rdd.union(rdd); - Assert.assertEquals(12, union.count()); - union = union.cache(); - Assert.assertEquals(12, union.count()); - - Assert.assertEquals(20, rdd.sum(), 0.01); - StatCounter stats = rdd.stats(); - Assert.assertEquals(20, stats.sum(), 0.01); - Assert.assertEquals(20/6.0, rdd.mean(), 0.01); - Assert.assertEquals(20/6.0, rdd.mean(), 0.01); - Assert.assertEquals(6.22222, rdd.variance(), 0.01); - Assert.assertEquals(7.46667, rdd.sampleVariance(), 0.01); - Assert.assertEquals(2.49444, rdd.stdev(), 0.01); - Assert.assertEquals(2.73252, rdd.sampleStdev(), 0.01); - - Double first = rdd.first(); - List<Double> take = rdd.take(5); - } - - @Test - public void map() { - JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() { - @Override - public Double call(Integer x) { - return 1.0 * x; - } - }).cache(); - JavaPairRDD<Integer, Integer> pairs = rdd.map(new PairFunction<Integer, Integer, Integer>() { - @Override - public Tuple2<Integer, Integer> call(Integer x) { - return new Tuple2<Integer, Integer>(x, x); - } - }).cache(); - JavaRDD<String> strings = rdd.map(new Function<Integer, String>() { - @Override - public String call(Integer x) { - return x.toString(); - } - }).cache(); - }
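Where the Java API needs DoubleFunction or PairFunction subclasses to pick the specialized RDD type, Scala gets there through the implicit conversions that the (now renamed) SparkContext companion provides. A sketch, assuming sc in scope:

import org.apache.spark.SparkContext._

val rdd     = sc.parallelize(1 to 5)
val doubles = rdd.map(_ * 1.0)       // RDD[Double]; sum() etc. via DoubleRDDFunctions
val pairs   = rdd.map(x => (x, x))   // RDD[(Int, Int)]; joins etc. via PairRDDFunctions
require(doubles.sum() == 15.0)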
- - @Test - public void flatMap() { - JavaRDD<String> rdd = sc.parallelize(Arrays.asList("Hello World!", - "The quick brown fox jumps over the lazy dog.")); - JavaRDD<String> words = rdd.flatMap(new FlatMapFunction<String, String>() { - @Override - public Iterable<String> call(String x) { - return Arrays.asList(x.split(" ")); - } - }); - Assert.assertEquals("Hello", words.first()); - Assert.assertEquals(11, words.count()); - - JavaPairRDD<String, String> pairs = rdd.flatMap( - new PairFlatMapFunction<String, String, String>() { - - @Override - public Iterable<Tuple2<String, String>> call(String s) { - List<Tuple2<String, String>> pairs = new LinkedList<Tuple2<String, String>>(); - for (String word : s.split(" ")) pairs.add(new Tuple2<String, String>(word, word)); - return pairs; - } - } - ); - Assert.assertEquals(new Tuple2<String, String>("Hello", "Hello"), pairs.first()); - Assert.assertEquals(11, pairs.count()); - - JavaDoubleRDD doubles = rdd.flatMap(new DoubleFlatMapFunction<String>() { - @Override - public Iterable<Double> call(String s) { - List<Double> lengths = new LinkedList<Double>(); - for (String word : s.split(" ")) lengths.add(word.length() * 1.0); - return lengths; - } - }); - Double x = doubles.first(); - Assert.assertEquals(5.0, doubles.first().doubleValue(), 0.01); - Assert.assertEquals(11, pairs.count()); - } - - @Test - public void mapsFromPairsToPairs() { - List<Tuple2<Integer, String>> pairs = Arrays.asList( - new Tuple2<Integer, String>(1, "a"), - new Tuple2<Integer, String>(2, "aa"), - new Tuple2<Integer, String>(3, "aaa") - ); - JavaPairRDD<Integer, String> pairRDD = sc.parallelizePairs(pairs); - - // Regression test for SPARK-668: - JavaPairRDD<String, Integer> swapped = pairRDD.flatMap( - new PairFlatMapFunction<Tuple2<Integer, String>, String, Integer>() { - @Override - public Iterable<Tuple2<String, Integer>> call(Tuple2<Integer, String> item) throws Exception { - return Collections.singletonList(item.swap()); - } - }); - swapped.collect(); - - // There was never a bug here, but it's worth testing: - pairRDD.map(new PairFunction<Tuple2<Integer, String>, String, Integer>() { - @Override - public Tuple2<String, Integer> call(Tuple2<Integer, String> item) throws Exception { - return item.swap(); - } - }).collect(); - } - - @Test - public void mapPartitions() { - JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); - JavaRDD<Integer> partitionSums = rdd.mapPartitions( - new FlatMapFunction<Iterator<Integer>, Integer>() { - @Override - public Iterable<Integer> call(Iterator<Integer> iter) { - int sum = 0; - while (iter.hasNext()) { - sum += iter.next(); - } - return Collections.singletonList(sum); - } - }); - Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); - } - - @Test - public void persist() { - JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); - doubleRDD = doubleRDD.persist(StorageLevel.DISK_ONLY()); - Assert.assertEquals(20, doubleRDD.sum(), 0.1); - - List<Tuple2<Integer, String>> pairs = Arrays.asList( - new Tuple2<Integer, String>(1, "a"), - new Tuple2<Integer, String>(2, "aa"), - new Tuple2<Integer, String>(3, "aaa") - ); - JavaPairRDD<Integer, String> pairRDD = sc.parallelizePairs(pairs); - pairRDD = pairRDD.persist(StorageLevel.DISK_ONLY()); - Assert.assertEquals("a", pairRDD.first()._2()); - - JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - rdd = rdd.persist(StorageLevel.DISK_ONLY()); - Assert.assertEquals(1, rdd.first().intValue()); - } - - @Test - public void iterator() { - JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContext(0, 0, 0, null); - Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue()); - } - - @Test - public void glom() { - JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); - Assert.assertEquals("[1, 2]", rdd.glom().first().toString()); - } - - // File input / output tests are largely adapted from FileSuite: - - @Test - public void textFiles() throws IOException { - File tempDir = Files.createTempDir(); - String outputDir = new File(tempDir, "output").getAbsolutePath(); - JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); - rdd.saveAsTextFile(outputDir); - // Read the plain text file and check it's OK - File outputFile = new File(outputDir, "part-00000"); - String content = Files.toString(outputFile, Charsets.UTF_8); - Assert.assertEquals("1\n2\n3\n4\n", content); - // Also try reading it in as a text file RDD - List<String> expected = Arrays.asList("1", "2", "3", "4"); - JavaRDD<String> readRDD = sc.textFile(outputDir); - Assert.assertEquals(expected, readRDD.collect()); - }
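The mapPartitions test above is the Java rendering of a per-partition fold; the Scala form takes an Iterator-to-Iterator function directly. A sketch, assuming sc in scope:

val rdd = sc.parallelize(1 to 4, 2)
val partitionSums = rdd.mapPartitions(iter => Iterator(iter.sum))
require(partitionSums.collect().sameElements(Array(3, 7)))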
Arrays.asList("1", "2", "3", "4"); - JavaRDD readRDD = sc.textFile(outputDir); - Assert.assertEquals(expected, readRDD.collect()); - } - - @Test - public void sequenceFile() { - File tempDir = Files.createTempDir(); - String outputDir = new File(tempDir, "output").getAbsolutePath(); - List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") - ); - JavaPairRDD rdd = sc.parallelizePairs(pairs); - - rdd.map(new PairFunction, IntWritable, Text>() { - @Override - public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); - } - }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); - - // Try reading the output back as an object file - JavaPairRDD readRDD = sc.sequenceFile(outputDir, IntWritable.class, - Text.class).map(new PairFunction, Integer, String>() { - @Override - public Tuple2 call(Tuple2 pair) { - return new Tuple2(pair._1().get(), pair._2().toString()); - } - }); - Assert.assertEquals(pairs, readRDD.collect()); - } - - @Test - public void writeWithNewAPIHadoopFile() { - File tempDir = Files.createTempDir(); - String outputDir = new File(tempDir, "output").getAbsolutePath(); - List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") - ); - JavaPairRDD rdd = sc.parallelizePairs(pairs); - - rdd.map(new PairFunction, IntWritable, Text>() { - @Override - public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); - } - }).saveAsNewAPIHadoopFile(outputDir, IntWritable.class, Text.class, - org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); - - JavaPairRDD output = sc.sequenceFile(outputDir, IntWritable.class, - Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { - @Override - public String call(Tuple2 x) { - return x.toString(); - } - }).collect().toString()); - } - - @Test - public void readWithNewAPIHadoopFile() throws IOException { - File tempDir = Files.createTempDir(); - String outputDir = new File(tempDir, "output").getAbsolutePath(); - List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") - ); - JavaPairRDD rdd = sc.parallelizePairs(pairs); - - rdd.map(new PairFunction, IntWritable, Text>() { - @Override - public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); - } - }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); - - JavaPairRDD output = sc.newAPIHadoopFile(outputDir, - org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, IntWritable.class, - Text.class, new Job().getConfiguration()); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { - @Override - public String call(Tuple2 x) { - return x.toString(); - } - }).collect().toString()); - } - - @Test - public void objectFilesOfInts() { - File tempDir = Files.createTempDir(); - String outputDir = new File(tempDir, "output").getAbsolutePath(); - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4)); - rdd.saveAsObjectFile(outputDir); - // Try reading the output back as an object file - List expected = Arrays.asList(1, 2, 3, 4); - JavaRDD readRDD = sc.objectFile(outputDir); - Assert.assertEquals(expected, readRDD.collect()); - } - - @Test - public void objectFilesOfComplexTypes() { - File tempDir = Files.createTempDir(); - String outputDir = new File(tempDir, "output").getAbsolutePath(); - List> 
- - @Test - public void objectFilesOfComplexTypes() { - File tempDir = Files.createTempDir(); - String outputDir = new File(tempDir, "output").getAbsolutePath(); - List<Tuple2<Integer, String>> pairs = Arrays.asList( - new Tuple2<Integer, String>(1, "a"), - new Tuple2<Integer, String>(2, "aa"), - new Tuple2<Integer, String>(3, "aaa") - ); - JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs); - rdd.saveAsObjectFile(outputDir); - // Try reading the output back as an object file - JavaRDD<Tuple2<Integer, String>> readRDD = sc.objectFile(outputDir); - Assert.assertEquals(pairs, readRDD.collect()); - } - - @Test - public void hadoopFile() { - File tempDir = Files.createTempDir(); - String outputDir = new File(tempDir, "output").getAbsolutePath(); - List<Tuple2<Integer, String>> pairs = Arrays.asList( - new Tuple2<Integer, String>(1, "a"), - new Tuple2<Integer, String>(2, "aa"), - new Tuple2<Integer, String>(3, "aaa") - ); - JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs); - - rdd.map(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() { - @Override - public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> pair) { - return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2())); - } - }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); - - JavaPairRDD<IntWritable, Text> output = sc.hadoopFile(outputDir, - SequenceFileInputFormat.class, IntWritable.class, Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function<Tuple2<IntWritable, Text>, - String>() { - @Override - public String call(Tuple2<IntWritable, Text> x) { - return x.toString(); - } - }).collect().toString()); - } - - @Test - public void hadoopFileCompressed() { - File tempDir = Files.createTempDir(); - String outputDir = new File(tempDir, "output_compressed").getAbsolutePath(); - List<Tuple2<Integer, String>> pairs = Arrays.asList( - new Tuple2<Integer, String>(1, "a"), - new Tuple2<Integer, String>(2, "aa"), - new Tuple2<Integer, String>(3, "aaa") - ); - JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs); - - rdd.map(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() { - @Override - public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> pair) { - return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2())); - } - }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class, - DefaultCodec.class); - - JavaPairRDD<IntWritable, Text> output = sc.hadoopFile(outputDir, - SequenceFileInputFormat.class, IntWritable.class, Text.class); - - Assert.assertEquals(pairs.toString(), output.map(new Function<Tuple2<IntWritable, Text>, - String>() { - @Override - public String call(Tuple2<IntWritable, Text> x) { - return x.toString(); - } - }).collect().toString()); - } - - @Test - public void zip() { - JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() { - @Override - public Double call(Integer x) { - return 1.0 * x; - } - }); - JavaPairRDD<Integer, Double> zipped = rdd.zip(doubles); - zipped.count(); - } - - @Test - public void zipPartitions() { - JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2); - JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("1", "2", "3", "4"), 2); - FlatMapFunction2<Iterator<Integer>, Iterator<String>, Integer> sizesFn = - new FlatMapFunction2<Iterator<Integer>, Iterator<String>, Integer>() { - @Override - public Iterable<Integer> call(Iterator<Integer> i, Iterator<String> s) { - int sizeI = 0; - int sizeS = 0; - while (i.hasNext()) { - sizeI += 1; - i.next(); - } - while (s.hasNext()) { - sizeS += 1; - s.next(); - } - return Arrays.asList(sizeI, sizeS); - } - }; - - JavaRDD<Integer> sizes = rdd1.zipPartitions(rdd2, sizesFn); - Assert.assertEquals("[3, 2, 3, 2]", sizes.collect().toString()); - }
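zipPartitions pairs up corresponding partitions rather than corresponding elements, which is why the test above gets both partition sizes back per partition pair. A Scala sketch of the same check, assuming sc in scope (the curried call shape is an assumption about the Scala signature at this point in the codebase):

val rdd1 = sc.parallelize(1 to 6, 2)
val rdd2 = sc.parallelize(Seq("1", "2", "3", "4"), 2)
val sizes = rdd1.zipPartitions(rdd2) { (i, s) => Iterator(i.size, s.size) }
require(sizes.collect().sameElements(Array(3, 2, 3, 2)))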
- - @Test - public void accumulators() { - JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - - final Accumulator<Integer> intAccum = sc.intAccumulator(10); - rdd.foreach(new VoidFunction<Integer>() { - public void call(Integer x) { - intAccum.add(x); - } - }); - Assert.assertEquals((Integer) 25, intAccum.value()); - - final Accumulator<Double> doubleAccum = sc.doubleAccumulator(10.0); - rdd.foreach(new VoidFunction<Integer>() { - public void call(Integer x) { - doubleAccum.add((double) x); - } - }); - Assert.assertEquals((Double) 25.0, doubleAccum.value()); - - // Try a custom accumulator type - AccumulatorParam<Float> floatAccumulatorParam = new AccumulatorParam<Float>() { - public Float addInPlace(Float r, Float t) { - return r + t; - } - - public Float addAccumulator(Float r, Float t) { - return r + t; - } - - public Float zero(Float initialValue) { - return 0.0f; - } - }; - - final Accumulator<Float> floatAccum = sc.accumulator((Float) 10.0f, floatAccumulatorParam); - rdd.foreach(new VoidFunction<Integer>() { - public void call(Integer x) { - floatAccum.add((float) x); - } - }); - Assert.assertEquals((Float) 25.0f, floatAccum.value()); - - // Test the setValue method - floatAccum.setValue(5.0f); - Assert.assertEquals((Float) 5.0f, floatAccum.value()); - } - - @Test - public void keyBy() { - JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2)); - List<Tuple2<String, Integer>> s = rdd.keyBy(new Function<Integer, String>() { - public String call(Integer t) throws Exception { - return t.toString(); - } - }).collect(); - Assert.assertEquals(new Tuple2<String, Integer>("1", 1), s.get(0)); - Assert.assertEquals(new Tuple2<String, Integer>("2", 2), s.get(1)); - } - - @Test - public void checkpointAndComputation() { - File tempDir = Files.createTempDir(); - JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - sc.setCheckpointDir(tempDir.getAbsolutePath(), true); - Assert.assertEquals(false, rdd.isCheckpointed()); - rdd.checkpoint(); - rdd.count(); // Forces the DAG to cause a checkpoint - Assert.assertEquals(true, rdd.isCheckpointed()); - Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), rdd.collect()); - } - - @Test - public void checkpointAndRestore() { - File tempDir = Files.createTempDir(); - JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - sc.setCheckpointDir(tempDir.getAbsolutePath(), true); - Assert.assertEquals(false, rdd.isCheckpointed()); - rdd.checkpoint(); - rdd.count(); // Forces the DAG to cause a checkpoint - Assert.assertEquals(true, rdd.isCheckpointed()); - - Assert.assertTrue(rdd.getCheckpointFile().isPresent()); - JavaRDD<Integer> recovered = sc.checkpointFile(rdd.getCheckpointFile().get()); - Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect()); - } - - @Test - public void mapOnPairRDD() { - JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1,2,3,4)); - JavaPairRDD<Integer, Integer> rdd2 = rdd1.map(new PairFunction<Integer, Integer, Integer>() { - @Override - public Tuple2<Integer, Integer> call(Integer i) throws Exception { - return new Tuple2<Integer, Integer>(i, i % 2); - } - }); - JavaPairRDD<Integer, Integer> rdd3 = rdd2.map( - new PairFunction<Tuple2<Integer, Integer>, Integer, Integer>() { - @Override - public Tuple2<Integer, Integer> call(Tuple2<Integer, Integer> in) throws Exception { - return new Tuple2<Integer, Integer>(in._2(), in._1()); - } - }); - Assert.assertEquals(Arrays.asList( - new Tuple2<Integer, Integer>(1, 1), - new Tuple2<Integer, Integer>(0, 2), - new Tuple2<Integer, Integer>(1, 3), - new Tuple2<Integer, Integer>(0, 4)), rdd3.collect()); - - } -} diff --git a/core/src/test/scala/spark/KryoSerializerSuite.scala b/core/src/test/scala/spark/KryoSerializerSuite.scala deleted file mode 100644 index 7568a0bf65..0000000000 --- a/core/src/test/scala/spark/KryoSerializerSuite.scala +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License.
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import scala.collection.mutable - -import org.scalatest.FunSuite -import com.esotericsoftware.kryo._ - -import KryoTest._ - -class KryoSerializerSuite extends FunSuite with SharedSparkContext { - test("basic types") { - val ser = (new KryoSerializer).newInstance() - def check[T](t: T) { - assert(ser.deserialize[T](ser.serialize(t)) === t) - } - check(1) - check(1L) - check(1.0f) - check(1.0) - check(1.toByte) - check(1.toShort) - check("") - check("hello") - check(Integer.MAX_VALUE) - check(Integer.MIN_VALUE) - check(java.lang.Long.MAX_VALUE) - check(java.lang.Long.MIN_VALUE) - check[String](null) - check(Array(1, 2, 3)) - check(Array(1L, 2L, 3L)) - check(Array(1.0, 2.0, 3.0)) - check(Array(1.0f, 2.9f, 3.9f)) - check(Array("aaa", "bbb", "ccc")) - check(Array("aaa", "bbb", null)) - check(Array(true, false, true)) - check(Array('a', 'b', 'c')) - check(Array[Int]()) - check(Array(Array("1", "2"), Array("1", "2", "3", "4"))) - } - - test("pairs") { - val ser = (new KryoSerializer).newInstance() - def check[T](t: T) { - assert(ser.deserialize[T](ser.serialize(t)) === t) - } - check((1, 1)) - check((1, 1L)) - check((1L, 1)) - check((1L, 1L)) - check((1.0, 1)) - check((1, 1.0)) - check((1.0, 1.0)) - check((1.0, 1L)) - check((1L, 1.0)) - check((1.0, 1L)) - check(("x", 1)) - check(("x", 1.0)) - check(("x", 1L)) - check((1, "x")) - check((1.0, "x")) - check((1L, "x")) - check(("x", "x")) - } - - test("Scala data structures") { - val ser = (new KryoSerializer).newInstance() - def check[T](t: T) { - assert(ser.deserialize[T](ser.serialize(t)) === t) - } - check(List[Int]()) - check(List[Int](1, 2, 3)) - check(List[String]()) - check(List[String]("x", "y", "z")) - check(None) - check(Some(1)) - check(Some("hi")) - check(mutable.ArrayBuffer(1, 2, 3)) - check(mutable.ArrayBuffer("1", "2", "3")) - check(mutable.Map()) - check(mutable.Map(1 -> "one", 2 -> "two")) - check(mutable.Map("one" -> 1, "two" -> 2)) - check(mutable.HashMap(1 -> "one", 2 -> "two")) - check(mutable.HashMap("one" -> 1, "two" -> 2)) - check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) - check(List(mutable.HashMap("one" -> 1, "two" -> 2),mutable.HashMap(1->"one",2->"two",3->"three"))) - } - - test("custom registrator") { - import KryoTest._ - System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName) - - val ser = (new KryoSerializer).newInstance() - def check[T](t: T) { - assert(ser.deserialize[T](ser.serialize(t)) === t) - } - - check(CaseClass(17, "hello")) - - val c1 = new ClassWithNoArgConstructor - c1.x = 32 - check(c1) - - val c2 = new ClassWithoutNoArgConstructor(47) - check(c2) - - val hashMap = new java.util.HashMap[String, String] - hashMap.put("foo", "bar") - check(hashMap) - - System.clearProperty("spark.kryo.registrator") - } - - test("kryo with collect") { - val control = 1 :: 2 :: Nil - val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)).collect().map(_.x) - assert(control === result.toSeq) - } - - test("kryo with parallelize") { - val control = 1 :: 2 :: Nil - val result = sc.parallelize(control.map(new 
ClassWithoutNoArgConstructor(_))).map(_.x).collect() - assert (control === result.toSeq) - } - - test("kryo with parallelize for specialized tuples") { - assert (sc.parallelize( Array((1, 11), (2, 22), (3, 33)) ).count === 3) - } - - test("kryo with parallelize for primitive arrays") { - assert (sc.parallelize( Array(1, 2, 3) ).count === 3) - } - - test("kryo with collect for specialized tuples") { - assert (sc.parallelize( Array((1, 11), (2, 22), (3, 33)) ).collect().head === (1, 11)) - } - - test("kryo with reduce") { - val control = 1 :: 2 :: Nil - val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)) - .reduce((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x - assert(control.sum === result) - } - - // TODO: this still doesn't work - ignore("kryo with fold") { - val control = 1 :: 2 :: Nil - val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)) - .fold(new ClassWithoutNoArgConstructor(10))((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x - assert(10 + control.sum === result) - } - - override def beforeAll() { - System.setProperty("spark.serializer", "spark.KryoSerializer") - System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName) - super.beforeAll() - } - - override def afterAll() { - super.afterAll() - System.clearProperty("spark.kryo.registrator") - System.clearProperty("spark.serializer") - } -} - -object KryoTest { - case class CaseClass(i: Int, s: String) {} - - class ClassWithNoArgConstructor { - var x: Int = 0 - override def equals(other: Any) = other match { - case c: ClassWithNoArgConstructor => x == c.x - case _ => false - } - } - - class ClassWithoutNoArgConstructor(val x: Int) { - override def equals(other: Any) = other match { - case c: ClassWithoutNoArgConstructor => x == c.x - case _ => false - } - } - - class MyRegistrator extends KryoRegistrator { - override def registerClasses(k: Kryo) { - k.register(classOf[CaseClass]) - k.register(classOf[ClassWithNoArgConstructor]) - k.register(classOf[ClassWithoutNoArgConstructor]) - k.register(classOf[java.util.HashMap[_, _]]) - } - }
-}
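Because the serializer is selected by a string-valued system property, this rename changes configuration values as well as imports: the suite's beforeAll sets "spark.serializer" to "spark.KryoSerializer", which after this patch becomes the org.apache.spark name. A sketch of the renamed configuration, reusing the MyRegistrator defined above:

// After the rename, the class-name properties point at the new package.
System.setProperty("spark.serializer", "org.apache.spark.KryoSerializer")
System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName)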
diff --git a/core/src/test/scala/spark/LocalSparkContext.scala b/core/src/test/scala/spark/LocalSparkContext.scala deleted file mode 100644 index ddc212d290..0000000000 --- a/core/src/test/scala/spark/LocalSparkContext.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.Suite - import org.scalatest.BeforeAndAfterEach - import org.scalatest.BeforeAndAfterAll - -import org.jboss.netty.logging.InternalLoggerFactory - import org.jboss.netty.logging.Slf4JLoggerFactory - -/** Manages a local `sc` {@link SparkContext} variable, correctly stopping it after each test. */ -trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self: Suite => - - @transient var sc: SparkContext = _ - - override def beforeAll() { - InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()); - super.beforeAll() - } - - override def afterEach() { - resetSparkContext() - super.afterEach() - } - - def resetSparkContext() = { - if (sc != null) { - LocalSparkContext.stop(sc) - sc = null - } - } - -} - -object LocalSparkContext { - def stop(sc: SparkContext) { - sc.stop() - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") - } - - /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */ - def withSpark[T](sc: SparkContext)(f: SparkContext => T) = { - try { - f(sc) - } finally { - stop(sc) - } - } - -}
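withSpark gives callers the same stop-safety as the trait without mixing it in. An illustrative use, with the master string and job name chosen arbitrarily:

import org.apache.spark.SparkContext

LocalSparkContext.withSpark(new SparkContext("local", "demo")) { sc =>
  require(sc.parallelize(1 to 3).count() == 3)
}   // the context is stopped even if the body throws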
diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala deleted file mode 100644 index c21f3331d0..0000000000 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.FunSuite - -import akka.actor._ - import spark.scheduler.MapStatus - import spark.storage.BlockManagerId - import spark.util.AkkaUtils - -class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { - - test("compressSize") { - assert(MapOutputTracker.compressSize(0L) === 0) - assert(MapOutputTracker.compressSize(1L) === 1) - assert(MapOutputTracker.compressSize(2L) === 8) - assert(MapOutputTracker.compressSize(10L) === 25) - assert((MapOutputTracker.compressSize(1000000L) & 0xFF) === 145) - assert((MapOutputTracker.compressSize(1000000000L) & 0xFF) === 218) - // This last size is bigger than we can encode in a byte, so check that we just return 255 - assert((MapOutputTracker.compressSize(1000000000000000000L) & 0xFF) === 255) - } - - test("decompressSize") { - assert(MapOutputTracker.decompressSize(0) === 0) - for (size <- Seq(2L, 10L, 100L, 50000L, 1000000L, 1000000000L)) { - val size2 = MapOutputTracker.decompressSize(MapOutputTracker.compressSize(size)) - assert(size2 >= 0.99 * size && size2 <= 1.11 * size, - "size " + size + " decompressed to " + size2 + ", which is out of range") - } - } - - test("master start and stop") { - val actorSystem = ActorSystem("test") - val tracker = new MapOutputTracker() - tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker))) - tracker.stop() - } - - test("master register and fetch") { - val actorSystem = ActorSystem("test") - val tracker = new MapOutputTracker() - tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker))) - tracker.registerShuffle(10, 2) - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val compressedSize10000 = MapOutputTracker.compressSize(10000L) - val size1000 = MapOutputTracker.decompressSize(compressedSize1000) - val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0), - Array(compressedSize1000, compressedSize10000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0), - Array(compressedSize10000, compressedSize1000))) - val statuses = tracker.getServerStatuses(10, 0) - assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000, 0), size1000), - (BlockManagerId("b", "hostB", 1000, 0), size10000))) - tracker.stop() - } - - test("master register and unregister and fetch") { - val actorSystem = ActorSystem("test") - val tracker = new MapOutputTracker() - tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker))) - tracker.registerShuffle(10, 2) - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val compressedSize10000 = MapOutputTracker.compressSize(10000L) - val size1000 = MapOutputTracker.decompressSize(compressedSize1000) - val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0), - Array(compressedSize1000, compressedSize1000, compressedSize1000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0), - Array(compressedSize10000, compressedSize1000, compressedSize1000))) - - // As if we had two simultaneous fetch failures - tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0)) - tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0)) - - // The remaining reduce task might try to grab the output despite the shuffle failure; - // this should cause it to fail, and the scheduler will ignore the failure due to the - // stage already being aborted. - intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } - }
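The compressSize/decompressSize assertions above pin down a one-byte log encoding: a size is stored as the ceiling of its base-1.1 logarithm, capped at 255, so decompression is accurate to within roughly 10%. A sketch consistent with those asserted values (the real implementation lives in MapOutputTracker and may differ in detail):

def compressSize(size: Long): Byte = {
  if (size == 0) 0
  else if (size <= 1L) 1   // keep log(1) == 0 from colliding with the empty case
  else math.min(255, math.ceil(math.log(size) / math.log(1.1)).toInt).toByte
}

def decompressSize(compressedSize: Byte): Long = {
  if (compressedSize == 0) 0 else math.pow(1.1, compressedSize & 0xFF).toLong
}

require((compressSize(1000000L) & 0xFF) == 145)
require((compressSize(1000000000000000000L) & 0xFF) == 255)   // saturates at one byte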
- - test("remote fetch") { - val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0) - System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext - System.setProperty("spark.hostPort", hostname + ":" + boundPort) - - val masterTracker = new MapOutputTracker() - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker") - - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0) - val slaveTracker = new MapOutputTracker() - slaveTracker.trackerActor = slaveSystem.actorFor( - "akka://spark@localhost:" + boundPort + "/user/MapOutputTracker") - - masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val size1000 = MapOutputTracker.decompressSize(compressedSize1000) - masterTracker.registerMapOutput(10, 0, new MapStatus( - BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000))) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000, 0), size1000))) - - masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0)) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - - // failure should be cached - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - } -} diff --git a/core/src/test/scala/spark/PairRDDFunctionsSuite.scala b/core/src/test/scala/spark/PairRDDFunctionsSuite.scala deleted file mode 100644 index 328b3b5497..0000000000 --- a/core/src/test/scala/spark/PairRDDFunctionsSuite.scala +++ /dev/null @@ -1,299 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -package spark - -import scala.collection.mutable.ArrayBuffer - import scala.collection.mutable.HashSet - -import org.scalatest.FunSuite - -import com.google.common.io.Files - import spark.SparkContext._ - - -class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { - test("groupByKey") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) - val groups = pairs.groupByKey().collect() - assert(groups.size === 2) - val valuesFor1 = groups.find(_._1 == 1).get._2 - assert(valuesFor1.toList.sorted === List(1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } - - test("groupByKey with duplicates") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val groups = pairs.groupByKey().collect() - assert(groups.size === 2) - val valuesFor1 = groups.find(_._1 == 1).get._2 - assert(valuesFor1.toList.sorted === List(1, 1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } - - test("groupByKey with negative key hash codes") { - val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1))) - val groups = pairs.groupByKey().collect() - assert(groups.size === 2) - val valuesForMinus1 = groups.find(_._1 == -1).get._2 - assert(valuesForMinus1.toList.sorted === List(1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } - - test("groupByKey with many output partitions") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) - val groups = pairs.groupByKey(10).collect() - assert(groups.size === 2) - val valuesFor1 = groups.find(_._1 == 1).get._2 - assert(valuesFor1.toList.sorted === List(1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } - - test("reduceByKey") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_).collect() - assert(sums.toSet === Set((1, 7), (2, 1))) - } - - test("reduceByKey with collectAsMap") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_).collectAsMap() - assert(sums.size === 2) - assert(sums(1) === 7) - assert(sums(2) === 1) - } - - test("reduceByKey with many output partitions") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.reduceByKey(_+_, 10).collect() - assert(sums.toSet === Set((1, 7), (2, 1))) - } - - test("reduceByKey with partitioner") { - val p = new Partitioner() { - def numPartitions = 2 - def getPartition(key: Any) = key.asInstanceOf[Int] - } - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p) - val sums = pairs.reduceByKey(_+_) - assert(sums.collect().toSet === Set((1, 4), (0, 1))) - assert(sums.partitioner === Some(p)) - // count the dependencies to make sure there is only 1 ShuffledRDD - val deps = new HashSet[RDD[_]]() - def visit(r: RDD[_]) { - for (dep <- r.dependencies) { - deps += dep.rdd - visit(dep.rdd) - } - } - visit(sums) - assert(deps.size === 2) // ShuffledRDD, ParallelCollection - } - - test("join") { - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.join(rdd2).collect() - assert(joined.size === 4) - assert(joined.toSet === Set( - (1, (1, 'x')), - (1, (2, 'x')), - (2, (1, 'y')), - (2, (1, 'z')) - )) - } - - test("join all-to-all") { - val rdd1 =
sc.parallelize(Array((1, 1), (1, 2), (1, 3))) - val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y'))) - val joined = rdd1.join(rdd2).collect() - assert(joined.size === 6) - assert(joined.toSet === Set( - (1, (1, 'x')), - (1, (1, 'y')), - (1, (2, 'x')), - (1, (2, 'y')), - (1, (3, 'x')), - (1, (3, 'y')) - )) - } - - test("leftOuterJoin") { - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.leftOuterJoin(rdd2).collect() - assert(joined.size === 5) - assert(joined.toSet === Set( - (1, (1, Some('x'))), - (1, (2, Some('x'))), - (2, (1, Some('y'))), - (2, (1, Some('z'))), - (3, (1, None)) - )) - } - - test("rightOuterJoin") { - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.rightOuterJoin(rdd2).collect() - assert(joined.size === 5) - assert(joined.toSet === Set( - (1, (Some(1), 'x')), - (1, (Some(2), 'x')), - (2, (Some(1), 'y')), - (2, (Some(1), 'z')), - (4, (None, 'w')) - )) - } - - test("join with no matches") { - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w'))) - val joined = rdd1.join(rdd2).collect() - assert(joined.size === 0) - } - - test("join with many output partitions") { - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.join(rdd2, 10).collect() - assert(joined.size === 4) - assert(joined.toSet === Set( - (1, (1, 'x')), - (1, (2, 'x')), - (2, (1, 'y')), - (2, (1, 'z')) - )) - } - - test("groupWith") { - val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) - val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) - val joined = rdd1.groupWith(rdd2).collect() - assert(joined.size === 4) - assert(joined.toSet === Set( - (1, (ArrayBuffer(1, 2), ArrayBuffer('x'))), - (2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))), - (3, (ArrayBuffer(1), ArrayBuffer())), - (4, (ArrayBuffer(), ArrayBuffer('w'))) - )) - } - - test("zero-partition RDD") { - val emptyDir = Files.createTempDir() - val file = sc.textFile(emptyDir.getAbsolutePath) - assert(file.partitions.size == 0) - assert(file.collect().toList === Nil) - // Test that a shuffle on the file works, because this used to be a bug - assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) - } - - test("keys and values") { - val rdd = sc.parallelize(Array((1, "a"), (2, "b"))) - assert(rdd.keys.collect().toList === List(1, 2)) - assert(rdd.values.collect().toList === List("a", "b")) - } - - test("default partitioner uses partition size") { - // specify 2000 partitions - val a = sc.makeRDD(Array(1, 2, 3, 4), 2000) - // do a map, which loses the partitioner - val b = a.map(a => (a, (a * 2).toString)) - // then a group by, and see we didn't revert to 2 partitions - val c = b.groupByKey() - assert(c.partitions.size === 2000) - } - - test("default partitioner uses largest partitioner") { - val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2) - val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000) - val c = a.join(b) - assert(c.partitions.size === 2000) - } - - test("subtract") { - val a = sc.parallelize(Array(1, 2, 3), 2) - val b = sc.parallelize(Array(2, 3, 4), 4) - val c = a.subtract(b) - assert(c.collect().toSet === Set(1)) - assert(c.partitions.size === a.partitions.size) - } - - test("subtract 
with narrow dependency") { - // use a deterministic partitioner - val p = new Partitioner() { - def numPartitions = 5 - def getPartition(key: Any) = key.asInstanceOf[Int] - } - // partitionBy so we have a narrow dependency - val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p) - // more partitions/no partitioner so a shuffle dependency - val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) - val c = a.subtract(b) - assert(c.collect().toSet === Set((1, "a"), (3, "c"))) - // Ideally we could keep the original partitioner... - assert(c.partitioner === None) - } - - test("subtractByKey") { - val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2) - val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4) - val c = a.subtractByKey(b) - assert(c.collect().toSet === Set((1, "a"), (1, "a"))) - assert(c.partitions.size === a.partitions.size) - } - - test("subtractByKey with narrow dependency") { - // use a deterministic partitioner - val p = new Partitioner() { - def numPartitions = 5 - def getPartition(key: Any) = key.asInstanceOf[Int] - } - // partitionBy so we have a narrow dependency - val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p) - // more partitions/no partitioner so a shuffle dependency - val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) - val c = a.subtractByKey(b) - assert(c.collect().toSet === Set((1, "a"), (1, "a"))) - assert(c.partitioner.get === p) - } - - test("foldByKey") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val sums = pairs.foldByKey(0)(_+_).collect() - assert(sums.toSet === Set((1, 7), (2, 1))) - } - - test("foldByKey with mutable result type") { - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) - val bufs = pairs.mapValues(v => ArrayBuffer(v)).cache() - // Fold the values using in-place mutation - val sums = bufs.foldByKey(new ArrayBuffer[Int])(_ ++= _).collect() - assert(sums.toSet === Set((1, ArrayBuffer(1, 2, 3, 1)), (2, ArrayBuffer(1)))) - // Check that the mutable objects in the original RDD were not changed - assert(bufs.collect().toSet === Set( - (1, ArrayBuffer(1)), - (1, ArrayBuffer(2)), - (1, ArrayBuffer(3)), - (1, ArrayBuffer(1)), - (2, ArrayBuffer(1)))) - } -} diff --git a/core/src/test/scala/spark/PartitionPruningRDDSuite.scala b/core/src/test/scala/spark/PartitionPruningRDDSuite.scala deleted file mode 100644 index 88352b639f..0000000000 --- a/core/src/test/scala/spark/PartitionPruningRDDSuite.scala +++ /dev/null @@ -1,28 +0,0 @@ -package spark - -import org.scalatest.FunSuite -import spark.SparkContext._ -import spark.rdd.PartitionPruningRDD - - -class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { - - test("Pruned Partitions inherit locality prefs correctly") { - class TestPartition(i: Int) extends Partition { - def index = i - } - val rdd = new RDD[Int](sc, Nil) { - override protected def getPartitions = { - Array[Partition]( - new TestPartition(1), - new TestPartition(2), - new TestPartition(3)) - } - def compute(split: Partition, context: TaskContext) = {Iterator()} - } - val prunedRDD = PartitionPruningRDD.create(rdd, {x => if (x==2) true else false}) - val p = prunedRDD.partitions(0) - assert(p.index == 2) - assert(prunedRDD.partitions.length == 1) - } -} diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala deleted file mode 100644 index b1e0b2b4d0..0000000000 --- 
a/core/src/test/scala/spark/PartitioningSuite.scala +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.FunSuite -import scala.collection.mutable.ArrayBuffer -import SparkContext._ -import spark.util.StatCounter -import scala.math.abs - -class PartitioningSuite extends FunSuite with SharedSparkContext { - - test("HashPartitioner equality") { - val p2 = new HashPartitioner(2) - val p4 = new HashPartitioner(4) - val anotherP4 = new HashPartitioner(4) - assert(p2 === p2) - assert(p4 === p4) - assert(p2 != p4) - assert(p4 != p2) - assert(p4 === anotherP4) - assert(anotherP4 === p4) - } - - test("RangePartitioner equality") { - // Make an RDD where all the elements are the same so that the partition range bounds - // are deterministically all the same. - val rdd = sc.parallelize(Seq(1, 1, 1, 1)).map(x => (x, x)) - - val p2 = new RangePartitioner(2, rdd) - val p4 = new RangePartitioner(4, rdd) - val anotherP4 = new RangePartitioner(4, rdd) - val descendingP2 = new RangePartitioner(2, rdd, false) - val descendingP4 = new RangePartitioner(4, rdd, false) - - assert(p2 === p2) - assert(p4 === p4) - assert(p2 != p4) - assert(p4 != p2) - assert(p4 === anotherP4) - assert(anotherP4 === p4) - assert(descendingP2 === descendingP2) - assert(descendingP4 === descendingP4) - assert(descendingP2 != descendingP4) - assert(descendingP4 != descendingP2) - assert(p2 != descendingP2) - assert(p4 != descendingP4) - assert(descendingP2 != p2) - assert(descendingP4 != p4) - } - - test("HashPartitioner not equal to RangePartitioner") { - val rdd = sc.parallelize(1 to 10).map(x => (x, x)) - val rangeP2 = new RangePartitioner(2, rdd) - val hashP2 = new HashPartitioner(2) - assert(rangeP2 === rangeP2) - assert(hashP2 === hashP2) - assert(hashP2 != rangeP2) - assert(rangeP2 != hashP2) - } - - test("partitioner preservation") { - val rdd = sc.parallelize(1 to 10, 4).map(x => (x, x)) - - val grouped2 = rdd.groupByKey(2) - val grouped4 = rdd.groupByKey(4) - val reduced2 = rdd.reduceByKey(_ + _, 2) - val reduced4 = rdd.reduceByKey(_ + _, 4) - - assert(rdd.partitioner === None) - - assert(grouped2.partitioner === Some(new HashPartitioner(2))) - assert(grouped4.partitioner === Some(new HashPartitioner(4))) - assert(reduced2.partitioner === Some(new HashPartitioner(2))) - assert(reduced4.partitioner === Some(new HashPartitioner(4))) - - assert(grouped2.groupByKey().partitioner === grouped2.partitioner) - assert(grouped2.groupByKey(3).partitioner != grouped2.partitioner) - assert(grouped2.groupByKey(2).partitioner === grouped2.partitioner) - assert(grouped4.groupByKey().partitioner === grouped4.partitioner) - assert(grouped4.groupByKey(3).partitioner != grouped4.partitioner) - 
assert(grouped4.groupByKey(4).partitioner === grouped4.partitioner) - - assert(grouped2.join(grouped4).partitioner === grouped4.partitioner) - assert(grouped2.leftOuterJoin(grouped4).partitioner === grouped4.partitioner) - assert(grouped2.rightOuterJoin(grouped4).partitioner === grouped4.partitioner) - assert(grouped2.cogroup(grouped4).partitioner === grouped4.partitioner) - - assert(grouped2.join(reduced2).partitioner === grouped2.partitioner) - assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner) - assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner) - assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner) - - assert(grouped2.map(_ => 1).partitioner === None) - assert(grouped2.mapValues(_ => 1).partitioner === grouped2.partitioner) - assert(grouped2.flatMapValues(_ => Seq(1)).partitioner === grouped2.partitioner) - assert(grouped2.filter(_._1 > 4).partitioner === grouped2.partitioner) - } - - test("partitioning Java arrays should fail") { - val arrs: RDD[Array[Int]] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => Array(x)) - val arrPairs: RDD[(Array[Int], Int)] = - sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x)) - - assert(intercept[SparkException]{ arrs.distinct() }.getMessage.contains("array")) - // We can't catch all usages of arrays, since they might occur inside other collections: - //assert(fails { arrPairs.distinct() }) - assert(intercept[SparkException]{ arrPairs.partitionBy(new HashPartitioner(2)) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.join(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.leftOuterJoin(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.rightOuterJoin(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.groupByKey() }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.countByKey() }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.countByKeyApprox(1) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.cogroup(arrPairs) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array")) - assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array")) - } - - test("zero-length partitions should be correctly handled") { - // Create RDD with some consecutive empty partitions (including the "first" one) - val rdd: RDD[Double] = sc - .parallelize(Array(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8) - .filter(_ >= 0.0) - - // Run the partitions, including the consecutive empty ones, through StatCounter - val stats: StatCounter = rdd.stats(); - assert(abs(6.0 - stats.sum) < 0.01); - assert(abs(6.0/2 - rdd.mean) < 0.01); - assert(abs(1.0 - rdd.variance) < 0.01); - assert(abs(1.0 - rdd.stdev) < 0.01); - - // Add other tests here for classes that should be able to handle empty partitions correctly - } -} diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala deleted file mode 100644 index 35c04710a3..0000000000 --- a/core/src/test/scala/spark/PipedRDDSuite.scala +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.FunSuite -import SparkContext._ - -class PipedRDDSuite extends FunSuite with SharedSparkContext { - - test("basic pipe") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - - val piped = nums.pipe(Seq("cat")) - - val c = piped.collect() - assert(c.size === 4) - assert(c(0) === "1") - assert(c(1) === "2") - assert(c(2) === "3") - assert(c(3) === "4") - } - - test("advanced pipe") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val bl = sc.broadcast(List("0")) - - val piped = nums.pipe(Seq("cat"), - Map[String, String](), - (f: String => Unit) => {bl.value.map(f(_));f("\u0001")}, - (i:Int, f: String=> Unit) => f(i + "_")) - - val c = piped.collect() - - assert(c.size === 8) - assert(c(0) === "0") - assert(c(1) === "\u0001") - assert(c(2) === "1_") - assert(c(3) === "2_") - assert(c(4) === "0") - assert(c(5) === "\u0001") - assert(c(6) === "3_") - assert(c(7) === "4_") - - val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2) - val d = nums1.groupBy(str=>str.split("\t")(0)). - pipe(Seq("cat"), - Map[String, String](), - (f: String => Unit) => {bl.value.map(f(_));f("\u0001")}, - (i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}).collect() - assert(d.size === 8) - assert(d(0) === "0") - assert(d(1) === "\u0001") - assert(d(2) === "b\t2_") - assert(d(3) === "b\t4_") - assert(d(4) === "0") - assert(d(5) === "\u0001") - assert(d(6) === "a\t1_") - assert(d(7) === "a\t3_") - } - - test("pipe with env variable") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA")) - val c = piped.collect() - assert(c.size === 2) - assert(c(0) === "LALALA") - assert(c(1) === "LALALA") - } - - test("pipe with non-zero exit status") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val piped = nums.pipe(Seq("cat nonexistent_file", "2>", "/dev/null")) - intercept[SparkException] { - piped.collect() - } - } - -} diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala deleted file mode 100644 index e306952bbd..0000000000 --- a/core/src/test/scala/spark/RDDSuite.scala +++ /dev/null @@ -1,389 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import scala.collection.mutable.HashMap -import org.scalatest.FunSuite -import org.scalatest.concurrent.Timeouts._ -import org.scalatest.time.{Span, Millis} -import spark.SparkContext._ -import spark.rdd._ -import scala.collection.parallel.mutable - -class RDDSuite extends FunSuite with SharedSparkContext { - - test("basic operations") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - assert(nums.collect().toList === List(1, 2, 3, 4)) - val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) - assert(dups.distinct().count() === 4) - assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses? - assert(dups.distinct.collect === dups.distinct().collect) - assert(dups.distinct(2).collect === dups.distinct().collect) - assert(nums.reduce(_ + _) === 10) - assert(nums.fold(0)(_ + _) === 10) - assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4")) - assert(nums.filter(_ > 2).collect().toList === List(3, 4)) - assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) - assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) - assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) - assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4")) - assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4))) - val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) - assert(partitionSums.collect().toList === List(3, 7)) - - val partitionSumsWithSplit = nums.mapPartitionsWithSplit { - case(split, iter) => Iterator((split, iter.reduceLeft(_ + _))) - } - assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7))) - - val partitionSumsWithIndex = nums.mapPartitionsWithIndex { - case(split, iter) => Iterator((split, iter.reduceLeft(_ + _))) - } - assert(partitionSumsWithIndex.collect().toList === List((0, 3), (1, 7))) - - intercept[UnsupportedOperationException] { - nums.filter(_ > 5).reduce(_ + _) - } - } - - test("SparkContext.union") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - assert(sc.union(nums).collect().toList === List(1, 2, 3, 4)) - assert(sc.union(nums, nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) - assert(sc.union(Seq(nums)).collect().toList === List(1, 2, 3, 4)) - assert(sc.union(Seq(nums, nums)).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) - } - - test("aggregate") { - val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3))) - type StringMap = HashMap[String, Int] - val emptyMap = new StringMap { - override def default(key: String): Int = 0 - } - val mergeElement: (StringMap, (String, Int)) => StringMap = (map, pair) => { - map(pair._1) += pair._2 - map - } - val mergeMaps: (StringMap, StringMap) => StringMap = (map1, map2) => { - for ((key, value) <- map2) { - map1(key) += value - } - map1 - } - val result = pairs.aggregate(emptyMap)(mergeElement, mergeMaps) - assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5))) - } - - test("basic caching") { - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() - assert(rdd.collect().toList === List(1, 2, 3, 4)) - assert(rdd.collect().toList === List(1, 2, 3, 4)) - assert(rdd.collect().toList === List(1, 2, 3, 4)) - } - - test("caching with failures") { - val onlySplit = new Partition { override def index: Int = 0 } - var shouldFail = true - val 
rdd = new RDD[Int](sc, Nil) { - override def getPartitions: Array[Partition] = Array(onlySplit) - override val getDependencies = List[Dependency[_]]() - override def compute(split: Partition, context: TaskContext): Iterator[Int] = { - if (shouldFail) { - throw new Exception("injected failure") - } else { - return Array(1, 2, 3, 4).iterator - } - } - }.cache() - val thrown = intercept[Exception]{ - rdd.collect() - } - assert(thrown.getMessage.contains("injected failure")) - shouldFail = false - assert(rdd.collect().toList === List(1, 2, 3, 4)) - } - - test("empty RDD") { - val empty = new EmptyRDD[Int](sc) - assert(empty.count === 0) - assert(empty.collect().size === 0) - - val thrown = intercept[UnsupportedOperationException]{ - empty.reduce(_+_) - } - assert(thrown.getMessage.contains("empty")) - - val emptyKv = new EmptyRDD[(Int, Int)](sc) - val rdd = sc.parallelize(1 to 2, 2).map(x => (x, x)) - assert(rdd.join(emptyKv).collect().size === 0) - assert(rdd.rightOuterJoin(emptyKv).collect().size === 0) - assert(rdd.leftOuterJoin(emptyKv).collect().size === 2) - assert(rdd.cogroup(emptyKv).collect().size === 2) - assert(rdd.union(emptyKv).collect().size === 2) - } - - test("cogrouped RDDs") { - val data = sc.parallelize(1 to 10, 10) - - val coalesced1 = data.coalesce(2) - assert(coalesced1.collect().toList === (1 to 10).toList) - assert(coalesced1.glom().collect().map(_.toList).toList === - List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10))) - - // Check that the narrow dependency is also specified correctly - assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === - List(0, 1, 2, 3, 4)) - assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === - List(5, 6, 7, 8, 9)) - - val coalesced2 = data.coalesce(3) - assert(coalesced2.collect().toList === (1 to 10).toList) - assert(coalesced2.glom().collect().map(_.toList).toList === - List(List(1, 2, 3), List(4, 5, 6), List(7, 8, 9, 10))) - - val coalesced3 = data.coalesce(10) - assert(coalesced3.collect().toList === (1 to 10).toList) - assert(coalesced3.glom().collect().map(_.toList).toList === - (1 to 10).map(x => List(x)).toList) - - // If we try to coalesce into more partitions than the original RDD, it should just - // keep the original number of partitions. 
- val coalesced4 = data.coalesce(20) - assert(coalesced4.collect().toList === (1 to 10).toList) - assert(coalesced4.glom().collect().map(_.toList).toList === - (1 to 10).map(x => List(x)).toList) - - // we can optionally shuffle to keep the upstream parallel - val coalesced5 = data.coalesce(1, shuffle = true) - assert(coalesced5.dependencies.head.rdd.dependencies.head.rdd.asInstanceOf[ShuffledRDD[_, _, _]] != - null) - } - test("cogrouped RDDs with locality") { - val data3 = sc.makeRDD(List((1,List("a","c")), (2,List("a","b","c")), (3,List("b")))) - val coal3 = data3.coalesce(3) - val list3 = coal3.partitions.map(p => p.asInstanceOf[CoalescedRDDPartition].preferredLocation) - assert(list3.sorted === Array("a","b","c"), "Locality preferences are dropped") - - // RDD with locality preferences spread (non-randomly) over 6 machines, m0 through m5 - val data = sc.makeRDD((1 to 9).map(i => (i, (i to (i+2)).map{ j => "m" + (j%6)}))) - val coalesced1 = data.coalesce(3) - assert(coalesced1.collect().toList.sorted === (1 to 9).toList, "Data got *lost* in coalescing") - - val splits = coalesced1.glom().collect().map(_.toList).toList - assert(splits.length === 3, "Supposed to coalesce to 3 but got " + splits.length) - - assert(splits.forall(_.length >= 1) === true, "Some partitions were empty") - - // If we try to coalesce into more partitions than the original RDD, it should just - // keep the original number of partitions. - val coalesced4 = data.coalesce(20) - val listOfLists = coalesced4.glom().collect().map(_.toList).toList - val sortedList = listOfLists.sortWith{ (x, y) => !x.isEmpty && (y.isEmpty || (x(0) < y(0))) } - assert( sortedList === (1 to 9). - map{x => List(x)}.toList, "Tried coalescing 9 partitions to 20 but didn't get 9 back") - } - - test("cogrouped RDDs with locality, large scale (10K partitions)") { - // large scale experiment - import collection.mutable - val rnd = scala.util.Random - val partitions = 10000 - val numMachines = 50 - val machines = mutable.ListBuffer[String]() - (1 to numMachines).foreach(machines += "m"+_) - - val blocks = (1 to partitions).map(i => - { (i, Array.fill(3)(machines(rnd.nextInt(machines.size))).toList) } ) - - val data2 = sc.makeRDD(blocks) - val coalesced2 = data2.coalesce(numMachines*2) - - // test that you get over 90% locality in each group - val minLocality = coalesced2.partitions - .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) - .foldLeft(1.)((perc, loc) => math.min(perc,loc)) - assert(minLocality >= 0.90, "Expected 90% locality but got " + (minLocality*100.).toInt + "%") - - // test that the groups are load balanced with 100 +/- 20 elements in each - val maxImbalance = coalesced2.partitions - .map(part => part.asInstanceOf[CoalescedRDDPartition].parents.size) - .foldLeft(0)((dev, curr) => math.max(math.abs(100-curr),dev)) - assert(maxImbalance <= 20, "Expected 100 +/- 20 per partition, but got " + maxImbalance) - - val data3 = sc.makeRDD(blocks).map(i => i*2) // derived RDD to test *current* pref locs - val coalesced3 = data3.coalesce(numMachines*2) - val minLocality2 = coalesced3.partitions - .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) - .foldLeft(1.)((perc, loc) => math.min(perc,loc)) - assert(minLocality2 >= 0.90, "Expected 90% locality for derived RDD but got " + - (minLocality2*100.).toInt + "%") - } - - test("zipped RDDs") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val zipped = nums.zip(nums.map(_ + 1.0)) - assert(zipped.glom().map(_.toList).collect().toList === - List(List((1, 2.0), 
(2, 3.0)), List((3, 4.0), (4, 5.0)))) - - intercept[IllegalArgumentException] { - nums.zip(sc.parallelize(1 to 4, 1)).collect() - } - } - - test("partition pruning") { - val data = sc.parallelize(1 to 10, 10) - // Note that split number starts from 0, so > 8 means only 10th partition left. - val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8) - assert(prunedRdd.partitions.size === 1) - val prunedData = prunedRdd.collect() - assert(prunedData.size === 1) - assert(prunedData(0) === 10) - } - - test("mapWith") { - import java.util.Random - val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) - val randoms = ones.mapWith( - (index: Int) => new Random(index + 42)) - {(t: Int, prng: Random) => prng.nextDouble * t}.collect() - val prn42_3 = { - val prng42 = new Random(42) - prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble() - } - val prn43_3 = { - val prng43 = new Random(43) - prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble() - } - assert(randoms(2) === prn42_3) - assert(randoms(5) === prn43_3) - } - - test("flatMapWith") { - import java.util.Random - val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) - val randoms = ones.flatMapWith( - (index: Int) => new Random(index + 42)) - {(t: Int, prng: Random) => - val random = prng.nextDouble() - Seq(random * t, random * t * 10)}. - collect() - val prn42_3 = { - val prng42 = new Random(42) - prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble() - } - val prn43_3 = { - val prng43 = new Random(43) - prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble() - } - assert(randoms(5) === prn42_3 * 10) - assert(randoms(11) === prn43_3 * 10) - } - - test("filterWith") { - import java.util.Random - val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2) - val sample = ints.filterWith( - (index: Int) => new Random(index + 42)) - {(t: Int, prng: Random) => prng.nextInt(3) == 0}. 
- collect() - val checkSample = { - val prng42 = new Random(42) - val prng43 = new Random(43) - Array(1, 2, 3, 4, 5, 6).filter{i => - if (i < 4) 0 == prng42.nextInt(3) - else 0 == prng43.nextInt(3)} - } - assert(sample.size === checkSample.size) - for (i <- 0 until sample.size) assert(sample(i) === checkSample(i)) - } - - test("top with predefined ordering") { - val nums = Array.range(1, 100000) - val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2) - val topK = ints.top(5) - assert(topK.size === 5) - assert(topK === nums.reverse.take(5)) - } - - test("top with custom ordering") { - val words = Vector("a", "b", "c", "d") - implicit val ord = implicitly[Ordering[String]].reverse - val rdd = sc.makeRDD(words, 2) - val topK = rdd.top(2) - assert(topK.size === 2) - assert(topK.sorted === Array("b", "a")) - } - - test("takeOrdered with predefined ordering") { - val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - val rdd = sc.makeRDD(nums, 2) - val sortedLowerK = rdd.takeOrdered(5) - assert(sortedLowerK.size === 5) - assert(sortedLowerK === Array(1, 2, 3, 4, 5)) - } - - test("takeOrdered with custom ordering") { - val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - implicit val ord = implicitly[Ordering[Int]].reverse - val rdd = sc.makeRDD(nums, 2) - val sortedTopK = rdd.takeOrdered(5) - assert(sortedTopK.size === 5) - assert(sortedTopK === Array(10, 9, 8, 7, 6)) - assert(sortedTopK === nums.sorted(ord).take(5)) - } - - test("takeSample") { - val data = sc.parallelize(1 to 100, 2) - for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=false, 20, seed) - assert(sample.size === 20) // Got exactly 20 elements - assert(sample.toSet.size === 20) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") - } - for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=false, 200, seed) - assert(sample.size === 100) // Got only 100 elements - assert(sample.toSet.size === 100) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") - } - for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 20, seed) - assert(sample.size === 20) // Got exactly 20 elements - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") - } - for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 100, seed) - assert(sample.size === 100) // Got exactly 100 elements - // Chance of getting all distinct elements is astronomically low, so test we got < 100 - assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") - } - for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 200, seed) - assert(sample.size === 200) // Got exactly 200 elements - // Chance of getting all distinct elements is still quite low, so test we got < 100 - assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") - } - } - - test("runJob on an invalid partition") { - intercept[IllegalArgumentException] { - sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2), false) - } - } -} diff --git a/core/src/test/scala/spark/SharedSparkContext.scala b/core/src/test/scala/spark/SharedSparkContext.scala deleted file mode 100644 index 70c24515be..0000000000 --- a/core/src/test/scala/spark/SharedSparkContext.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.Suite -import org.scalatest.BeforeAndAfterAll - -/** Shares a local `SparkContext` between all tests in a suite and closes it at the end */ -trait SharedSparkContext extends BeforeAndAfterAll { self: Suite => - - @transient private var _sc: SparkContext = _ - - def sc: SparkContext = _sc - - override def beforeAll() { - _sc = new SparkContext("local", "test") - super.beforeAll() - } - - override def afterAll() { - if (_sc != null) { - LocalSparkContext.stop(_sc) - _sc = null - } - super.afterAll() - } -} diff --git a/core/src/test/scala/spark/ShuffleNettySuite.scala b/core/src/test/scala/spark/ShuffleNettySuite.scala deleted file mode 100644 index 6bad6c1d13..0000000000 --- a/core/src/test/scala/spark/ShuffleNettySuite.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.BeforeAndAfterAll - - -class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll { - - // This test suite should run all tests in ShuffleSuite with Netty shuffle mode. - - override def beforeAll(configMap: Map[String, Any]) { - System.setProperty("spark.shuffle.use.netty", "true") - } - - override def afterAll(configMap: Map[String, Any]) { - System.setProperty("spark.shuffle.use.netty", "false") - } -} diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala deleted file mode 100644 index 8745689c70..0000000000 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ /dev/null @@ -1,210 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers - -import spark.SparkContext._ -import spark.ShuffleSuite.NonJavaSerializableClass -import spark.rdd.{SubtractedRDD, CoGroupedRDD, OrderedRDDFunctions, ShuffledRDD} -import spark.util.MutablePair - - -class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { - test("groupByKey without compression") { - try { - System.setProperty("spark.shuffle.compress", "false") - sc = new SparkContext("local", "test") - val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 4) - val groups = pairs.groupByKey(4).collect() - assert(groups.size === 2) - val valuesFor1 = groups.find(_._1 == 1).get._2 - assert(valuesFor1.toList.sorted === List(1, 2, 3)) - val valuesFor2 = groups.find(_._1 == 2).get._2 - assert(valuesFor2.toList.sorted === List(1)) - } finally { - System.setProperty("spark.shuffle.compress", "true") - } - } - - test("shuffle non-zero block size") { - sc = new SparkContext("local-cluster[2,1,512]", "test") - val NUM_BLOCKS = 3 - - val a = sc.parallelize(1 to 10, 2) - val b = a.map { x => - (x, new NonJavaSerializableClass(x * 2)) - } - // If the Kryo serializer is not used correctly, the shuffle would fail because the - // default Java serializer cannot handle the non serializable class. - val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)]( - b, new HashPartitioner(NUM_BLOCKS)).setSerializer(classOf[spark.KryoSerializer].getName) - val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId - - assert(c.count === 10) - - // All blocks must have non-zero size - (0 until NUM_BLOCKS).foreach { id => - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) - assert(statuses.forall(s => s._2 > 0)) - } - } - - test("shuffle serializer") { - // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") - val a = sc.parallelize(1 to 10, 2) - val b = a.map { x => - (x, new NonJavaSerializableClass(x * 2)) - } - // If the Kryo serializer is not used correctly, the shuffle would fail because the - // default Java serializer cannot handle the non serializable class. - val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)]( - b, new HashPartitioner(3)).setSerializer(classOf[spark.KryoSerializer].getName) - assert(c.count === 10) - } - - test("zero sized blocks") { - // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") - - // 10 partitions from 4 keys - val NUM_BLOCKS = 10 - val a = sc.parallelize(1 to 4, NUM_BLOCKS) - val b = a.map(x => (x, x*2)) - - // NOTE: The default Java serializer doesn't create zero-sized blocks. 
- // So, use Kryo - val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10)) - .setSerializer(classOf[spark.KryoSerializer].getName) - - val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId - assert(c.count === 4) - - val blockSizes = (0 until NUM_BLOCKS).flatMap { id => - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) - statuses.map(x => x._2) - } - val nonEmptyBlocks = blockSizes.filter(x => x > 0) - - // We should have at most 4 non-zero sized partitions - assert(nonEmptyBlocks.size <= 4) - } - - test("zero sized blocks without kryo") { - // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") - - // 10 partitions from 4 keys - val NUM_BLOCKS = 10 - val a = sc.parallelize(1 to 4, NUM_BLOCKS) - val b = a.map(x => (x, x*2)) - - // NOTE: The default Java serializer should create zero-sized blocks - val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10)) - - val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId - assert(c.count === 4) - - val blockSizes = (0 until NUM_BLOCKS).flatMap { id => - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) - statuses.map(x => x._2) - } - val nonEmptyBlocks = blockSizes.filter(x => x > 0) - - // We should have at most 4 non-zero sized partitions - assert(nonEmptyBlocks.size <= 4) - } - - test("shuffle using mutable pairs") { - // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) - val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) - val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) - val results = new ShuffledRDD[Int, Int, MutablePair[Int, Int]](pairs, new HashPartitioner(2)) - .collect() - - data.foreach { pair => results should contain (pair) } - } - - test("sorting using mutable pairs") { - // This is not in SortingSuite because of the local cluster setup. 
- // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) - val data = Array(p(1, 11), p(3, 33), p(100, 100), p(2, 22)) - val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) - val results = new OrderedRDDFunctions[Int, Int, MutablePair[Int, Int]](pairs) - .sortByKey().collect() - results(0) should be (p(1, 11)) - results(1) should be (p(2, 22)) - results(2) should be (p(3, 33)) - results(3) should be (p(100, 100)) - } - - test("cogroup using mutable pairs") { - // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) - val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) - val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"), p(3, "3")) - val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2) - val pairs2: RDD[MutablePair[Int, String]] = sc.parallelize(data2, 2) - val results = new CoGroupedRDD[Int](Seq(pairs1, pairs2), new HashPartitioner(2)).collectAsMap() - - assert(results(1)(0).length === 3) - assert(results(1)(0).contains(1)) - assert(results(1)(0).contains(2)) - assert(results(1)(0).contains(3)) - assert(results(1)(1).length === 2) - assert(results(1)(1).contains("11")) - assert(results(1)(1).contains("12")) - assert(results(2)(0).length === 1) - assert(results(2)(0).contains(1)) - assert(results(2)(1).length === 1) - assert(results(2)(1).contains("22")) - assert(results(3)(0).length === 0) - assert(results(3)(1).contains("3")) - } - - test("subtract mutable pairs") { - // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test") - def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2) - val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1), p(3, 33)) - val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22")) - val pairs1: RDD[MutablePair[Int, Int]] = sc.parallelize(data1, 2) - val pairs2: RDD[MutablePair[Int, String]] = sc.parallelize(data2, 2) - val results = new SubtractedRDD(pairs1, pairs2, new HashPartitioner(2)).collect() - results should have length (1) - // subtracted rdd returns results as Tuple2 - results(0) should be ((3, 33)) - } -} - -object ShuffleSuite { - - def mergeCombineException(x: Int, y: Int): Int = { - throw new SparkException("Exception for map-side combine.") - x + y - } - - class NonJavaSerializableClass(val value: Int) -} diff --git a/core/src/test/scala/spark/SizeEstimatorSuite.scala b/core/src/test/scala/spark/SizeEstimatorSuite.scala deleted file mode 100644 index 1ef812dfbd..0000000000 --- a/core/src/test/scala/spark/SizeEstimatorSuite.scala +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfterAll -import org.scalatest.PrivateMethodTester - -class DummyClass1 {} - -class DummyClass2 { - val x: Int = 0 -} - -class DummyClass3 { - val x: Int = 0 - val y: Double = 0.0 -} - -class DummyClass4(val d: DummyClass3) { - val x: Int = 0 -} - -object DummyString { - def apply(str: String) : DummyString = new DummyString(str.toArray) -} -class DummyString(val arr: Array[Char]) { - override val hashCode: Int = 0 - // JDK-7 has an extra hash32 field http://hg.openjdk.java.net/jdk7u/jdk7u6/jdk/rev/11987e85555f - @transient val hash32: Int = 0 -} - -class SizeEstimatorSuite - extends FunSuite with BeforeAndAfterAll with PrivateMethodTester { - - var oldArch: String = _ - var oldOops: String = _ - - override def beforeAll() { - // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case - oldArch = System.setProperty("os.arch", "amd64") - oldOops = System.setProperty("spark.test.useCompressedOops", "true") - } - - override def afterAll() { - resetOrClear("os.arch", oldArch) - resetOrClear("spark.test.useCompressedOops", oldOops) - } - - test("simple classes") { - assert(SizeEstimator.estimate(new DummyClass1) === 16) - assert(SizeEstimator.estimate(new DummyClass2) === 16) - assert(SizeEstimator.estimate(new DummyClass3) === 24) - assert(SizeEstimator.estimate(new DummyClass4(null)) === 24) - assert(SizeEstimator.estimate(new DummyClass4(new DummyClass3)) === 48) - } - - // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors - // (Sun vs IBM). Use a DummyString class to make tests deterministic. - test("strings") { - assert(SizeEstimator.estimate(DummyString("")) === 40) - assert(SizeEstimator.estimate(DummyString("a")) === 48) - assert(SizeEstimator.estimate(DummyString("ab")) === 48) - assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56) - } - - test("primitive arrays") { - assert(SizeEstimator.estimate(new Array[Byte](10)) === 32) - assert(SizeEstimator.estimate(new Array[Char](10)) === 40) - assert(SizeEstimator.estimate(new Array[Short](10)) === 40) - assert(SizeEstimator.estimate(new Array[Int](10)) === 56) - assert(SizeEstimator.estimate(new Array[Long](10)) === 96) - assert(SizeEstimator.estimate(new Array[Float](10)) === 56) - assert(SizeEstimator.estimate(new Array[Double](10)) === 96) - assert(SizeEstimator.estimate(new Array[Int](1000)) === 4016) - assert(SizeEstimator.estimate(new Array[Long](1000)) === 8016) - } - - test("object arrays") { - // Arrays containing nulls should just have one pointer per element - assert(SizeEstimator.estimate(new Array[String](10)) === 56) - assert(SizeEstimator.estimate(new Array[AnyRef](10)) === 56) - - // For object arrays with non-null elements, each object should take one pointer plus - // however many bytes that class takes. (Note that Array.fill calls the code in its - // second parameter separately for each object, so we get distinct objects.) - assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass1)) === 216) - assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass2)) === 216) - assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass3)) === 296) - assert(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)) === 56) - - // Past size 100, we sample 100 elements, but we should still get the right size. 
- assert(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)) === 28016) - - // If an array contains the *same* element many times, we should only count it once. - val d1 = new DummyClass1 - assert(SizeEstimator.estimate(Array.fill(10)(d1)) === 72) // 10 pointers plus 8-byte object - assert(SizeEstimator.estimate(Array.fill(100)(d1)) === 432) // 100 pointers plus 8-byte object - - // Same thing with huge array containing the same element many times. Note that this won't - // return exactly 4032 because it can't tell that *all* the elements will equal the first - // one it samples, but it should be close to that. - - // TODO: If we sample 100 elements, this should always be 4176 ? - val estimatedSize = SizeEstimator.estimate(Array.fill(1000)(d1)) - assert(estimatedSize >= 4000, "Estimated size " + estimatedSize + " should be more than 4000") - assert(estimatedSize <= 4200, "Estimated size " + estimatedSize + " should be less than 4200") - } - - test("32-bit arch") { - val arch = System.setProperty("os.arch", "x86") - - val initialize = PrivateMethod[Unit]('initialize) - SizeEstimator invokePrivate initialize() - - assert(SizeEstimator.estimate(DummyString("")) === 40) - assert(SizeEstimator.estimate(DummyString("a")) === 48) - assert(SizeEstimator.estimate(DummyString("ab")) === 48) - assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56) - - resetOrClear("os.arch", arch) - } - - // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors - // (Sun vs IBM). Use a DummyString class to make tests deterministic. - test("64-bit arch with no compressed oops") { - val arch = System.setProperty("os.arch", "amd64") - val oops = System.setProperty("spark.test.useCompressedOops", "false") - - val initialize = PrivateMethod[Unit]('initialize) - SizeEstimator invokePrivate initialize() - - assert(SizeEstimator.estimate(DummyString("")) === 56) - assert(SizeEstimator.estimate(DummyString("a")) === 64) - assert(SizeEstimator.estimate(DummyString("ab")) === 64) - assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 72) - - resetOrClear("os.arch", arch) - resetOrClear("spark.test.useCompressedOops", oops) - } - - def resetOrClear(prop: String, oldValue: String) { - if (oldValue != null) { - System.setProperty(prop, oldValue) - } else { - System.clearProperty(prop) - } - } -} diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala deleted file mode 100644 index b933c4aab8..0000000000 --- a/core/src/test/scala/spark/SortingSuite.scala +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter -import org.scalatest.matchers.ShouldMatchers -import SparkContext._ - -class SortingSuite extends FunSuite with SharedSparkContext with ShouldMatchers with Logging { - - test("sortByKey") { - val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2) - assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0))) - } - - test("large array") { - val rand = new scala.util.Random() - val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 2) - val sorted = pairs.sortByKey() - assert(sorted.partitions.size === 2) - assert(sorted.collect() === pairArr.sortBy(_._1)) - } - - test("large array with one split") { - val rand = new scala.util.Random() - val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 2) - val sorted = pairs.sortByKey(true, 1) - assert(sorted.partitions.size === 1) - assert(sorted.collect() === pairArr.sortBy(_._1)) - } - - test("large array with many partitions") { - val rand = new scala.util.Random() - val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 2) - val sorted = pairs.sortByKey(true, 20) - assert(sorted.partitions.size === 20) - assert(sorted.collect() === pairArr.sortBy(_._1)) - } - - test("sort descending") { - val rand = new scala.util.Random() - val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 2) - assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) - } - - test("sort descending with one split") { - val rand = new scala.util.Random() - val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 1) - assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) - } - - test("sort descending with many partitions") { - val rand = new scala.util.Random() - val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 2) - assert(pairs.sortByKey(false, 20).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) - } - - test("more partitions than elements") { - val rand = new scala.util.Random() - val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) } - val pairs = sc.parallelize(pairArr, 30) - assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) - } - - test("empty RDD") { - val pairArr = new Array[(Int, Int)](0) - val pairs = sc.parallelize(pairArr, 2) - assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1)) - } - - test("partition balancing") { - val pairArr = (1 to 1000).map(x => (x, x)).toArray - val sorted = sc.parallelize(pairArr, 4).sortByKey() - assert(sorted.collect() === pairArr.sortBy(_._1)) - val partitions = sorted.collectPartitions() - logInfo("Partition lengths: " + partitions.map(_.length).mkString(", ")) - partitions(0).length should be > 180 - partitions(1).length should be > 180 - partitions(2).length should be > 180 - partitions(3).length should be > 180 - partitions(0).last should be < partitions(1).head - partitions(1).last should be < partitions(2).head - partitions(2).last should be < partitions(3).head - } - - test("partition balancing for descending sort") { - val pairArr = (1 to 1000).map(x => (x, x)).toArray - val sorted = sc.parallelize(pairArr, 4).sortByKey(false) - assert(sorted.collect() === pairArr.sortBy(_._1).reverse) - val partitions = 
sorted.collectPartitions() - logInfo("partition lengths: " + partitions.map(_.length).mkString(", ")) - partitions(0).length should be > 180 - partitions(1).length should be > 180 - partitions(2).length should be > 180 - partitions(3).length should be > 180 - partitions(0).last should be > partitions(1).head - partitions(1).last should be > partitions(2).head - partitions(2).last should be > partitions(3).head - } -} - diff --git a/core/src/test/scala/spark/SparkContextInfoSuite.scala b/core/src/test/scala/spark/SparkContextInfoSuite.scala deleted file mode 100644 index 6d50bf5e1b..0000000000 --- a/core/src/test/scala/spark/SparkContextInfoSuite.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.FunSuite -import spark.SparkContext._ - -class SparkContextInfoSuite extends FunSuite with LocalSparkContext { - test("getPersistentRDDs only returns RDDs that are marked as cached") { - sc = new SparkContext("local", "test") - assert(sc.getPersistentRDDs.isEmpty === true) - - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) - assert(sc.getPersistentRDDs.isEmpty === true) - - rdd.cache() - assert(sc.getPersistentRDDs.size === 1) - assert(sc.getPersistentRDDs.values.head === rdd) - } - - test("getPersistentRDDs returns an immutable map") { - sc = new SparkContext("local", "test") - val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() - - val myRdds = sc.getPersistentRDDs - assert(myRdds.size === 1) - assert(myRdds.values.head === rdd1) - - val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache() - - // getPersistentRDDs should have 2 RDDs, but myRdds should not change - assert(sc.getPersistentRDDs.size === 2) - assert(myRdds.size === 1) - } - - test("getRDDStorageInfo only reports on RDDs that actually persist data") { - sc = new SparkContext("local", "test") - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() - - assert(sc.getRDDStorageInfo.size === 0) - - rdd.collect() - assert(sc.getRDDStorageInfo.size === 1) - } -} \ No newline at end of file diff --git a/core/src/test/scala/spark/ThreadingSuite.scala b/core/src/test/scala/spark/ThreadingSuite.scala deleted file mode 100644 index f2acd0bd3c..0000000000 --- a/core/src/test/scala/spark/ThreadingSuite.scala +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import java.util.concurrent.Semaphore -import java.util.concurrent.atomic.AtomicBoolean -import java.util.concurrent.atomic.AtomicInteger - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter - -import SparkContext._ - -/** - * Holds state shared across task threads in some ThreadingSuite tests. - */ -object ThreadingSuiteState { - val runningThreads = new AtomicInteger - val failed = new AtomicBoolean - - def clear() { - runningThreads.set(0) - failed.set(false) - } -} - -class ThreadingSuite extends FunSuite with LocalSparkContext { - - test("accessing SparkContext from a different thread") { - sc = new SparkContext("local", "test") - val nums = sc.parallelize(1 to 10, 2) - val sem = new Semaphore(0) - @volatile var answer1: Int = 0 - @volatile var answer2: Int = 0 - new Thread { - override def run() { - answer1 = nums.reduce(_ + _) - answer2 = nums.first() // This will run "locally" in the current thread - sem.release() - } - }.start() - sem.acquire() - assert(answer1 === 55) - assert(answer2 === 1) - } - - test("accessing SparkContext from multiple threads") { - sc = new SparkContext("local", "test") - val nums = sc.parallelize(1 to 10, 2) - val sem = new Semaphore(0) - @volatile var ok = true - for (i <- 0 until 10) { - new Thread { - override def run() { - val answer1 = nums.reduce(_ + _) - if (answer1 != 55) { - printf("In thread %d: answer1 was %d\n", i, answer1) - ok = false - } - val answer2 = nums.first() // This will run "locally" in the current thread - if (answer2 != 1) { - printf("In thread %d: answer2 was %d\n", i, answer2) - ok = false - } - sem.release() - } - }.start() - } - sem.acquire(10) - if (!ok) { - fail("One or more threads got the wrong answer from an RDD operation") - } - } - - test("accessing multi-threaded SparkContext from multiple threads") { - sc = new SparkContext("local[4]", "test") - val nums = sc.parallelize(1 to 10, 2) - val sem = new Semaphore(0) - @volatile var ok = true - for (i <- 0 until 10) { - new Thread { - override def run() { - val answer1 = nums.reduce(_ + _) - if (answer1 != 55) { - printf("In thread %d: answer1 was %d\n", i, answer1) - ok = false - } - val answer2 = nums.first() // This will run "locally" in the current thread - if (answer2 != 1) { - printf("In thread %d: answer2 was %d\n", i, answer2) - ok = false - } - sem.release() - } - }.start() - } - sem.acquire(10) - if (!ok) { - fail("One or more threads got the wrong answer from an RDD operation") - } - } - - test("parallel job execution") { - // This test launches two jobs with two threads each on a 4-core local cluster. Each thread - // waits until there are 4 threads running at once, to test that both jobs have been launched. 
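Illustrative sketch, not part of this patch: the comment above describes the pattern under test, namely two driver-side threads submitting jobs against one shared SparkContext on a local[4] master. A minimal standalone version of that pattern, using only the pre-rename spark.SparkContext calls already exercised in this suite (parallelize, reduce), might look as follows; the object name ConcurrentJobsSketch is hypothetical.

    import spark.SparkContext

    object ConcurrentJobsSketch {
      def main(args: Array[String]) {
        // A 4-core local master lets two jobs make progress at the same time.
        val sc = new SparkContext("local[4]", "concurrent-jobs-sketch")
        val nums = sc.parallelize(1 to 10, 2)
        // Each thread submits an independent job on the shared context.
        val threads = (1 to 2).map { _ =>
          new Thread {
            override def run() {
              println("sum = " + nums.reduce(_ + _)) // expected: 55
            }
          }
        }
        threads.foreach(_.start())
        threads.foreach(_.join())
        sc.stop()
      }
    }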
- sc = new SparkContext("local[4]", "test") - val nums = sc.parallelize(1 to 2, 2) - val sem = new Semaphore(0) - ThreadingSuiteState.clear() - for (i <- 0 until 2) { - new Thread { - override def run() { - val ans = nums.map(number => { - val running = ThreadingSuiteState.runningThreads - running.getAndIncrement() - val time = System.currentTimeMillis() - while (running.get() != 4 && System.currentTimeMillis() < time + 1000) { - Thread.sleep(100) - } - if (running.get() != 4) { - println("Waited 1 second without seeing runningThreads = 4 (it was " + - running.get() + "); failing test") - ThreadingSuiteState.failed.set(true) - } - number - }).collect() - assert(ans.toList === List(1, 2)) - sem.release() - } - }.start() - } - sem.acquire(2) - if (ThreadingSuiteState.failed.get()) { - fail("One or more threads didn't see runningThreads = 4") - } - } -} diff --git a/core/src/test/scala/spark/UnpersistSuite.scala b/core/src/test/scala/spark/UnpersistSuite.scala deleted file mode 100644 index 93977d16f4..0000000000 --- a/core/src/test/scala/spark/UnpersistSuite.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import org.scalatest.FunSuite -import org.scalatest.concurrent.Timeouts._ -import org.scalatest.time.{Span, Millis} -import spark.SparkContext._ - -class UnpersistSuite extends FunSuite with LocalSparkContext { - test("unpersist RDD") { - sc = new SparkContext("local", "test") - val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() - rdd.count - assert(sc.persistentRdds.isEmpty === false) - rdd.unpersist() - assert(sc.persistentRdds.isEmpty === true) - - failAfter(Span(3000, Millis)) { - try { - while (! sc.getRDDStorageInfo.isEmpty) { - Thread.sleep(200) - } - } catch { - case _ => { Thread.sleep(10) } - // Do nothing. We might see exceptions because block manager - // is racing this thread to remove entries from the driver. - } - } - assert(sc.getRDDStorageInfo.isEmpty === true) - } -} diff --git a/core/src/test/scala/spark/UtilsSuite.scala b/core/src/test/scala/spark/UtilsSuite.scala deleted file mode 100644 index 98a6c1a1c9..0000000000 --- a/core/src/test/scala/spark/UtilsSuite.scala +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark - -import com.google.common.base.Charsets -import com.google.common.io.Files -import java.io.{ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream, File} -import org.scalatest.FunSuite -import org.apache.commons.io.FileUtils -import scala.util.Random - -class UtilsSuite extends FunSuite { - - test("bytesToString") { - assert(Utils.bytesToString(10) === "10.0 B") - assert(Utils.bytesToString(1500) === "1500.0 B") - assert(Utils.bytesToString(2000000) === "1953.1 KB") - assert(Utils.bytesToString(2097152) === "2.0 MB") - assert(Utils.bytesToString(2306867) === "2.2 MB") - assert(Utils.bytesToString(5368709120L) === "5.0 GB") - assert(Utils.bytesToString(5L * 1024L * 1024L * 1024L * 1024L) === "5.0 TB") - } - - test("copyStream") { - //input array initialization - val bytes = Array.ofDim[Byte](9000) - Random.nextBytes(bytes) - - val os = new ByteArrayOutputStream() - Utils.copyStream(new ByteArrayInputStream(bytes), os) - - assert(os.toByteArray.toList.equals(bytes.toList)) - } - - test("memoryStringToMb") { - assert(Utils.memoryStringToMb("1") === 0) - assert(Utils.memoryStringToMb("1048575") === 0) - assert(Utils.memoryStringToMb("3145728") === 3) - - assert(Utils.memoryStringToMb("1024k") === 1) - assert(Utils.memoryStringToMb("5000k") === 4) - assert(Utils.memoryStringToMb("4024k") === Utils.memoryStringToMb("4024K")) - - assert(Utils.memoryStringToMb("1024m") === 1024) - assert(Utils.memoryStringToMb("5000m") === 5000) - assert(Utils.memoryStringToMb("4024m") === Utils.memoryStringToMb("4024M")) - - assert(Utils.memoryStringToMb("2g") === 2048) - assert(Utils.memoryStringToMb("3g") === Utils.memoryStringToMb("3G")) - - assert(Utils.memoryStringToMb("2t") === 2097152) - assert(Utils.memoryStringToMb("3t") === Utils.memoryStringToMb("3T")) - } - - test("splitCommandString") { - assert(Utils.splitCommandString("") === Seq()) - assert(Utils.splitCommandString("a") === Seq("a")) - assert(Utils.splitCommandString("aaa") === Seq("aaa")) - assert(Utils.splitCommandString("a b c") === Seq("a", "b", "c")) - assert(Utils.splitCommandString(" a b\t c ") === Seq("a", "b", "c")) - assert(Utils.splitCommandString("a 'b c'") === Seq("a", "b c")) - assert(Utils.splitCommandString("a 'b c' d") === Seq("a", "b c", "d")) - assert(Utils.splitCommandString("'b c'") === Seq("b c")) - assert(Utils.splitCommandString("a \"b c\"") === Seq("a", "b c")) - assert(Utils.splitCommandString("a \"b c\" d") === Seq("a", "b c", "d")) - assert(Utils.splitCommandString("\"b c\"") === Seq("b c")) - assert(Utils.splitCommandString("a 'b\" c' \"d' e\"") === Seq("a", "b\" c", "d' e")) - assert(Utils.splitCommandString("a\t'b\nc'\nd") === Seq("a", "b\nc", "d")) - assert(Utils.splitCommandString("a \"b\\\\c\"") === Seq("a", "b\\c")) - assert(Utils.splitCommandString("a \"b\\\"c\"") === Seq("a", "b\"c")) - assert(Utils.splitCommandString("a 'b\\\"c'") === Seq("a", "b\\\"c")) - assert(Utils.splitCommandString("'a'b") === Seq("ab")) - assert(Utils.splitCommandString("'a''b'") === Seq("ab")) - assert(Utils.splitCommandString("\"a\"b") === Seq("ab")) - 
assert(Utils.splitCommandString("\"a\"\"b\"") === Seq("ab")) - assert(Utils.splitCommandString("''") === Seq("")) - assert(Utils.splitCommandString("\"\"") === Seq("")) - } - - test("string formatting of time durations") { - val second = 1000 - val minute = second * 60 - val hour = minute * 60 - def str = Utils.msDurationToString(_) - - assert(str(123) === "123 ms") - assert(str(second) === "1.0 s") - assert(str(second + 462) === "1.5 s") - assert(str(hour) === "1.00 h") - assert(str(minute) === "1.0 m") - assert(str(minute + 4 * second + 34) === "1.1 m") - assert(str(10 * hour + minute + 4 * second) === "10.02 h") - assert(str(10 * hour + 59 * minute + 59 * second + 999) === "11.00 h") - } - - test("reading offset bytes of a file") { - val tmpDir2 = Files.createTempDir() - val f1Path = tmpDir2 + "/f1" - val f1 = new FileOutputStream(f1Path) - f1.write("1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(Charsets.UTF_8)) - f1.close() - - // Read first few bytes - assert(Utils.offsetBytes(f1Path, 0, 5) === "1\n2\n3") - - // Read some middle bytes - assert(Utils.offsetBytes(f1Path, 4, 11) === "3\n4\n5\n6") - - // Read last few bytes - assert(Utils.offsetBytes(f1Path, 12, 18) === "7\n8\n9\n") - - // Read some nonexistent bytes in the beginning - assert(Utils.offsetBytes(f1Path, -5, 5) === "1\n2\n3") - - // Read some nonexistent bytes at the end - assert(Utils.offsetBytes(f1Path, 12, 22) === "7\n8\n9\n") - - // Read some nonexistent bytes on both ends - assert(Utils.offsetBytes(f1Path, -3, 25) === "1\n2\n3\n4\n5\n6\n7\n8\n9\n") - - FileUtils.deleteDirectory(tmpDir2) - } -} - diff --git a/core/src/test/scala/spark/ZippedPartitionsSuite.scala b/core/src/test/scala/spark/ZippedPartitionsSuite.scala deleted file mode 100644 index bb5d379273..0000000000 --- a/core/src/test/scala/spark/ZippedPartitionsSuite.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark - -import scala.collection.immutable.NumericRange - -import org.scalatest.FunSuite -import org.scalatest.prop.Checkers -import org.scalacheck.Arbitrary._ -import org.scalacheck.Gen -import org.scalacheck.Prop._ - -import SparkContext._ - - -object ZippedPartitionsSuite { - def procZippedData(i: Iterator[Int], s: Iterator[String], d: Iterator[Double]) : Iterator[Int] = { - Iterator(i.toArray.size, s.toArray.size, d.toArray.size) - } -} - -class ZippedPartitionsSuite extends FunSuite with SharedSparkContext { - test("print sizes") { - val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2) - val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2) - val data3 = sc.makeRDD(Array(1.0, 2.0), 2) - - val zippedRDD = data1.zipPartitions(data2, data3)(ZippedPartitionsSuite.procZippedData) - - val obtainedSizes = zippedRDD.collect() - val expectedSizes = Array(2, 3, 1, 2, 3, 1) - assert(obtainedSizes.size == 6) - assert(obtainedSizes.zip(expectedSizes).forall(x => x._1 == x._2)) - } -} diff --git a/core/src/test/scala/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/spark/io/CompressionCodecSuite.scala deleted file mode 100644 index 1ba82fe2b9..0000000000 --- a/core/src/test/scala/spark/io/CompressionCodecSuite.scala +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.io - -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} - -import org.scalatest.FunSuite - - -class CompressionCodecSuite extends FunSuite { - - def testCodec(codec: CompressionCodec) { - // Write 1000 integers to the output stream, compressed. - val outputStream = new ByteArrayOutputStream() - val out = codec.compressedOutputStream(outputStream) - for (i <- 1 until 1000) { - out.write(i % 256) - } - out.close() - - // Read the 1000 integers back. 
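Illustrative sketch, not part of this patch: the testCodec helper here drives a codec's compressedOutputStream/compressedInputStream pair through a write-then-read round trip. A self-contained version of that same round trip outside the test harness might look as follows, assuming the pre-rename spark.io.CompressionCodec API shown in this file; the object name CodecRoundTripSketch is hypothetical.

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
    import spark.io.CompressionCodec

    object CodecRoundTripSketch {
      def main(args: Array[String]) {
        val codec = CompressionCodec.createCodec() // default codec (Snappy, per the assertions below)
        val buffer = new ByteArrayOutputStream()
        val out = codec.compressedOutputStream(buffer)
        out.write("hello".getBytes("UTF-8"))       // compress a few bytes
        out.close()
        val in = codec.compressedInputStream(new ByteArrayInputStream(buffer.toByteArray))
        val decoded = Stream.continually(in.read()).takeWhile(_ != -1).map(_.toByte).toArray
        in.close()
        println(new String(decoded, "UTF-8"))      // prints "hello"
      }
    }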
- val inputStream = new ByteArrayInputStream(outputStream.toByteArray) - val in = codec.compressedInputStream(inputStream) - for (i <- 1 until 1000) { - assert(in.read() === i % 256) - } - in.close() - } - - test("default compression codec") { - val codec = CompressionCodec.createCodec() - assert(codec.getClass === classOf[SnappyCompressionCodec]) - testCodec(codec) - } - - test("lzf compression codec") { - val codec = CompressionCodec.createCodec(classOf[LZFCompressionCodec].getName) - assert(codec.getClass === classOf[LZFCompressionCodec]) - testCodec(codec) - } - - test("snappy compression codec") { - val codec = CompressionCodec.createCodec(classOf[SnappyCompressionCodec].getName) - assert(codec.getClass === classOf[SnappyCompressionCodec]) - testCodec(codec) - } -} diff --git a/core/src/test/scala/spark/metrics/MetricsConfigSuite.scala b/core/src/test/scala/spark/metrics/MetricsConfigSuite.scala deleted file mode 100644 index b0213b62d9..0000000000 --- a/core/src/test/scala/spark/metrics/MetricsConfigSuite.scala +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.metrics - -import org.scalatest.{BeforeAndAfter, FunSuite} - -class MetricsConfigSuite extends FunSuite with BeforeAndAfter { - var filePath: String = _ - - before { - filePath = getClass.getClassLoader.getResource("test_metrics_config.properties").getFile() - } - - test("MetricsConfig with default properties") { - val conf = new MetricsConfig(Option("dummy-file")) - conf.initialize() - - assert(conf.properties.size() === 5) - assert(conf.properties.getProperty("test-for-dummy") === null) - - val property = conf.getInstance("random") - assert(property.size() === 3) - assert(property.getProperty("sink.servlet.class") === "spark.metrics.sink.MetricsServlet") - assert(property.getProperty("sink.servlet.uri") === "/metrics/json") - assert(property.getProperty("sink.servlet.sample") === "false") - } - - test("MetricsConfig with properties set") { - val conf = new MetricsConfig(Option(filePath)) - conf.initialize() - - val masterProp = conf.getInstance("master") - assert(masterProp.size() === 6) - assert(masterProp.getProperty("sink.console.period") === "20") - assert(masterProp.getProperty("sink.console.unit") === "minutes") - assert(masterProp.getProperty("source.jvm.class") === "spark.metrics.source.JvmSource") - assert(masterProp.getProperty("sink.servlet.class") === "spark.metrics.sink.MetricsServlet") - assert(masterProp.getProperty("sink.servlet.uri") === "/metrics/master/json") - assert(masterProp.getProperty("sink.servlet.sample") === "false") - - val workerProp = conf.getInstance("worker") - assert(workerProp.size() === 6) - assert(workerProp.getProperty("sink.console.period") === "10") - assert(workerProp.getProperty("sink.console.unit") === "seconds") - assert(workerProp.getProperty("source.jvm.class") === "spark.metrics.source.JvmSource") - assert(workerProp.getProperty("sink.servlet.class") === "spark.metrics.sink.MetricsServlet") - assert(workerProp.getProperty("sink.servlet.uri") === "/metrics/json") - assert(workerProp.getProperty("sink.servlet.sample") === "false") - } - - test("MetricsConfig with subProperties") { - val conf = new MetricsConfig(Option(filePath)) - conf.initialize() - - val propCategories = conf.propertyCategories - assert(propCategories.size === 3) - - val masterProp = conf.getInstance("master") - val sourceProps = conf.subProperties(masterProp, MetricsSystem.SOURCE_REGEX) - assert(sourceProps.size === 1) - assert(sourceProps("jvm").getProperty("class") === "spark.metrics.source.JvmSource") - - val sinkProps = conf.subProperties(masterProp, MetricsSystem.SINK_REGEX) - assert(sinkProps.size === 2) - assert(sinkProps.contains("console")) - assert(sinkProps.contains("servlet")) - - val consoleProps = sinkProps("console") - assert(consoleProps.size() === 2) - - val servletProps = sinkProps("servlet") - assert(servletProps.size() === 3) - } -} diff --git a/core/src/test/scala/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/spark/metrics/MetricsSystemSuite.scala deleted file mode 100644 index dc65ac6994..0000000000 --- a/core/src/test/scala/spark/metrics/MetricsSystemSuite.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.metrics - -import org.scalatest.{BeforeAndAfter, FunSuite} - -class MetricsSystemSuite extends FunSuite with BeforeAndAfter { - var filePath: String = _ - - before { - filePath = getClass.getClassLoader.getResource("test_metrics_system.properties").getFile() - System.setProperty("spark.metrics.conf", filePath) - } - - test("MetricsSystem with default config") { - val metricsSystem = MetricsSystem.createMetricsSystem("default") - val sources = metricsSystem.sources - val sinks = metricsSystem.sinks - - assert(sources.length === 0) - assert(sinks.length === 0) - assert(!metricsSystem.getServletHandlers.isEmpty) - } - - test("MetricsSystem with sources add") { - val metricsSystem = MetricsSystem.createMetricsSystem("test") - val sources = metricsSystem.sources - val sinks = metricsSystem.sinks - - assert(sources.length === 0) - assert(sinks.length === 1) - assert(!metricsSystem.getServletHandlers.isEmpty) - - val source = new spark.deploy.master.MasterSource(null) - metricsSystem.registerSource(source) - assert(sources.length === 1) - } -} diff --git a/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala deleted file mode 100644 index dc8ca941c1..0000000000 --- a/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark - -import org.scalatest.{ BeforeAndAfter, FunSuite } -import spark.SparkContext._ -import spark.rdd.JdbcRDD -import java.sql._ - -class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { - - before { - Class.forName("org.apache.derby.jdbc.EmbeddedDriver") - val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true") - try { - val create = conn.createStatement - create.execute(""" - CREATE TABLE FOO( - ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1), - DATA INTEGER - )""") - create.close - val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)") - (1 to 100).foreach { i => - insert.setInt(1, i * 2) - insert.executeUpdate - } - insert.close - } catch { - case e: SQLException if e.getSQLState == "X0Y32" => - // table exists - } finally { - conn.close - } - } - - test("basic functionality") { - sc = new SparkContext("local", "test") - val rdd = new JdbcRDD( - sc, - () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") }, - "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?", - 1, 100, 3, - (r: ResultSet) => { r.getInt(1) } ).cache - - assert(rdd.count === 100) - assert(rdd.reduce(_+_) === 10100) - } - - after { - try { - DriverManager.getConnection("jdbc:derby:;shutdown=true") - } catch { - case se: SQLException if se.getSQLState == "XJ015" => - // normal shutdown - } - } -} diff --git a/core/src/test/scala/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/spark/rdd/ParallelCollectionSplitSuite.scala deleted file mode 100644 index d1276d541f..0000000000 --- a/core/src/test/scala/spark/rdd/ParallelCollectionSplitSuite.scala +++ /dev/null @@ -1,212 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.rdd - -import scala.collection.immutable.NumericRange - -import org.scalatest.FunSuite -import org.scalatest.prop.Checkers -import org.scalacheck.Arbitrary._ -import org.scalacheck.Gen -import org.scalacheck.Prop._ - -class ParallelCollectionSplitSuite extends FunSuite with Checkers { - test("one element per slice") { - val data = Array(1, 2, 3) - val slices = ParallelCollectionRDD.slice(data, 3) - assert(slices.size === 3) - assert(slices(0).mkString(",") === "1") - assert(slices(1).mkString(",") === "2") - assert(slices(2).mkString(",") === "3") - } - - test("one slice") { - val data = Array(1, 2, 3) - val slices = ParallelCollectionRDD.slice(data, 1) - assert(slices.size === 1) - assert(slices(0).mkString(",") === "1,2,3") - } - - test("equal slices") { - val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9) - val slices = ParallelCollectionRDD.slice(data, 3) - assert(slices.size === 3) - assert(slices(0).mkString(",") === "1,2,3") - assert(slices(1).mkString(",") === "4,5,6") - assert(slices(2).mkString(",") === "7,8,9") - } - - test("non-equal slices") { - val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - val slices = ParallelCollectionRDD.slice(data, 3) - assert(slices.size === 3) - assert(slices(0).mkString(",") === "1,2,3") - assert(slices(1).mkString(",") === "4,5,6") - assert(slices(2).mkString(",") === "7,8,9,10") - } - - test("splitting exclusive range") { - val data = 0 until 100 - val slices = ParallelCollectionRDD.slice(data, 3) - assert(slices.size === 3) - assert(slices(0).mkString(",") === (0 to 32).mkString(",")) - assert(slices(1).mkString(",") === (33 to 65).mkString(",")) - assert(slices(2).mkString(",") === (66 to 99).mkString(",")) - } - - test("splitting inclusive range") { - val data = 0 to 100 - val slices = ParallelCollectionRDD.slice(data, 3) - assert(slices.size === 3) - assert(slices(0).mkString(",") === (0 to 32).mkString(",")) - assert(slices(1).mkString(",") === (33 to 66).mkString(",")) - assert(slices(2).mkString(",") === (67 to 100).mkString(",")) - } - - test("empty data") { - val data = new Array[Int](0) - val slices = ParallelCollectionRDD.slice(data, 5) - assert(slices.size === 5) - for (slice <- slices) assert(slice.size === 0) - } - - test("zero slices") { - val data = Array(1, 2, 3) - intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, 0) } - } - - test("negative number of slices") { - val data = Array(1, 2, 3) - intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, -5) } - } - - test("exclusive ranges sliced into ranges") { - val data = 1 until 100 - val slices = ParallelCollectionRDD.slice(data, 3) - assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 99) - assert(slices.forall(_.isInstanceOf[Range])) - } - - test("inclusive ranges sliced into ranges") { - val data = 1 to 100 - val slices = ParallelCollectionRDD.slice(data, 3) - assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 100) - assert(slices.forall(_.isInstanceOf[Range])) - } - - test("large ranges don't overflow") { - val N = 100 * 1000 * 1000 - val data = 0 until N - val slices = ParallelCollectionRDD.slice(data, 40) - assert(slices.size === 40) - for (i <- 0 until 40) { - assert(slices(i).isInstanceOf[Range]) - val range = slices(i).asInstanceOf[Range] - assert(range.start === i * (N / 40), "slice " + i + " start") - assert(range.end === (i+1) * (N / 40), "slice " + i + " end") - assert(range.step === 1, "slice " + i + " step") - } - } - - test("random array tests") { - val gen = for { - d 
<- arbitrary[List[Int]] - n <- Gen.choose(1, 100) - } yield (d, n) - val prop = forAll(gen) { - (tuple: (List[Int], Int)) => - val d = tuple._1 - val n = tuple._2 - val slices = ParallelCollectionRDD.slice(d, n) - ("n slices" |: slices.size == n) && - ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && - ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) - } - check(prop) - } - - test("random exclusive range tests") { - val gen = for { - a <- Gen.choose(-100, 100) - b <- Gen.choose(-100, 100) - step <- Gen.choose(-5, 5) suchThat (_ != 0) - n <- Gen.choose(1, 100) - } yield (a until b by step, n) - val prop = forAll(gen) { - case (d: Range, n: Int) => - val slices = ParallelCollectionRDD.slice(d, n) - ("n slices" |: slices.size == n) && - ("all ranges" |: slices.forall(_.isInstanceOf[Range])) && - ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && - ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) - } - check(prop) - } - - test("random inclusive range tests") { - val gen = for { - a <- Gen.choose(-100, 100) - b <- Gen.choose(-100, 100) - step <- Gen.choose(-5, 5) suchThat (_ != 0) - n <- Gen.choose(1, 100) - } yield (a to b by step, n) - val prop = forAll(gen) { - case (d: Range, n: Int) => - val slices = ParallelCollectionRDD.slice(d, n) - ("n slices" |: slices.size == n) && - ("all ranges" |: slices.forall(_.isInstanceOf[Range])) && - ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && - ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) - } - check(prop) - } - - test("exclusive ranges of longs") { - val data = 1L until 100L - val slices = ParallelCollectionRDD.slice(data, 3) - assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 99) - assert(slices.forall(_.isInstanceOf[NumericRange[_]])) - } - - test("inclusive ranges of longs") { - val data = 1L to 100L - val slices = ParallelCollectionRDD.slice(data, 3) - assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 100) - assert(slices.forall(_.isInstanceOf[NumericRange[_]])) - } - - test("exclusive ranges of doubles") { - val data = 1.0 until 100.0 by 1.0 - val slices = ParallelCollectionRDD.slice(data, 3) - assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 99) - assert(slices.forall(_.isInstanceOf[NumericRange[_]])) - } - - test("inclusive ranges of doubles") { - val data = 1.0 to 100.0 by 1.0 - val slices = ParallelCollectionRDD.slice(data, 3) - assert(slices.size === 3) - assert(slices.map(_.size).reduceLeft(_+_) === 100) - assert(slices.forall(_.isInstanceOf[NumericRange[_]])) - } -} diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala deleted file mode 100644 index 3b4a0d52fc..0000000000 --- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala +++ /dev/null @@ -1,421 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import scala.collection.mutable.{Map, HashMap} - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter - -import spark.LocalSparkContext -import spark.MapOutputTracker -import spark.RDD -import spark.SparkContext -import spark.Partition -import spark.TaskContext -import spark.{Dependency, ShuffleDependency, OneToOneDependency} -import spark.{FetchFailed, Success, TaskEndReason} -import spark.storage.{BlockManagerId, BlockManagerMaster} - -import spark.scheduler.cluster.Pool -import spark.scheduler.cluster.SchedulingMode -import spark.scheduler.cluster.SchedulingMode.SchedulingMode - -/** - * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler - * rather than spawning an event loop thread as happens in the real code. They use EasyMock - * to mock out two classes that DAGScheduler interacts with: TaskScheduler (to which TaskSets are - * submitted) and BlockManagerMaster (from which cache locations are retrieved and to which dead - * host notifications are sent). In addition, tests may check for side effects on a non-mocked - * MapOutputTracker instance. - * - * Tests primarily consist of running DAGScheduler#processEvent and - * DAGScheduler#submitWaitingStages (via test utility functions like runEvent or respondToTaskSet) - * and capturing the resulting TaskSets from the mock TaskScheduler. - */ -class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { - - /** Set of TaskSets the DAGScheduler has requested executed. */ - val taskSets = scala.collection.mutable.Buffer[TaskSet]() - val taskScheduler = new TaskScheduler() { - override def rootPool: Pool = null - override def schedulingMode: SchedulingMode = SchedulingMode.NONE - override def start() = {} - override def stop() = {} - override def submitTasks(taskSet: TaskSet) = { - // normally done by TaskSetManager - taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) - taskSets += taskSet - } - override def setListener(listener: TaskSchedulerListener) = {} - override def defaultParallelism() = 2 - } - - var mapOutputTracker: MapOutputTracker = null - var scheduler: DAGScheduler = null - - /** - * Set of cache locations to return from our mock BlockManagerMaster. - * Keys are (rdd ID, partition ID). Anything not present will return an empty - * list of cache locations silently. - */ - val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]] - // stub out BlockManagerMaster.getLocations to use our cacheLocations - val blockManagerMaster = new BlockManagerMaster(null) { - override def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { - blockIds.map { name => - val pieces = name.split("_") - if (pieces(0) == "rdd") { - val key = pieces(1).toInt -> pieces(2).toInt - cacheLocations.getOrElse(key, Seq()) - } else { - Seq() - } - }.toSeq - } - override def removeExecutor(execId: String) { - // don't need to propagate to the driver, which we don't have - } - } - - /** The list of results that DAGScheduler has collected. 
*/ - val results = new HashMap[Int, Any]() - var failure: Exception = _ - val listener = new JobListener() { - override def taskSucceeded(index: Int, result: Any) = results.put(index, result) - override def jobFailed(exception: Exception) = { failure = exception } - } - - before { - sc = new SparkContext("local", "DAGSchedulerSuite") - taskSets.clear() - cacheLocations.clear() - results.clear() - mapOutputTracker = new MapOutputTracker() - scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) { - override def runLocally(job: ActiveJob) { - // don't bother with the thread while unit testing - runLocallyWithinThread(job) - } - } - } - - after { - scheduler.stop() - } - - /** - * Type of RDD we use for testing. Note that we should never call the real RDD compute methods. - * This is a pair RDD type so it can always be used in ShuffleDependencies. - */ - type MyRDD = RDD[(Int, Int)] - - /** - * Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and - * preferredLocations (if any) that are passed to them. They are deliberately not executable - * so we can test that DAGScheduler does not try to execute RDDs locally. - */ - private def makeRdd( - numPartitions: Int, - dependencies: List[Dependency[_]], - locations: Seq[Seq[String]] = Nil - ): MyRDD = { - val maxPartition = numPartitions - 1 - return new MyRDD(sc, dependencies) { - override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = - throw new RuntimeException("should not be reached") - override def getPartitions = (0 to maxPartition).map(i => new Partition { - override def index = i - }).toArray - override def getPreferredLocations(split: Partition): Seq[String] = - if (locations.isDefinedAt(split.index)) - locations(split.index) - else - Nil - override def toString: String = "DAGSchedulerSuiteRDD " + id - } - } - - /** - * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting - * the scheduler not to exit. - * - * After processing the event, submit waiting stages as is done on most iterations of the - * DAGScheduler event loop. - */ - private def runEvent(event: DAGSchedulerEvent) { - assert(!scheduler.processEvent(event)) - scheduler.submitWaitingStages() - } - - /** - * When we submit dummy Jobs, this is the compute function we supply. Except in a local test - * below, we do not expect this function to ever be executed; instead, we will return results - * directly through CompletionEvents. - */ - private val jobComputeFunc = (context: TaskContext, it: Iterator[(_)]) => - it.next.asInstanceOf[Tuple2[_, _]]._1 - - /** Send the given CompletionEvent messages for the tasks in the TaskSet. */ - private def complete(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) { - assert(taskSet.tasks.size >= results.size) - for ((result, i) <- results.zipWithIndex) { - if (i < taskSet.tasks.size) { - runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any](), null, null)) - } - } - } - - /** Sends the rdd to the scheduler for scheduling. */ - private def submit( - rdd: RDD[_], - partitions: Array[Int], - func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, - allowLocal: Boolean = false, - listener: JobListener = listener) { - runEvent(JobSubmitted(rdd, func, partitions, allowLocal, null, listener)) - } - - /** Sends TaskSetFailed to the scheduler. 
*/ - private def failed(taskSet: TaskSet, message: String) { - runEvent(TaskSetFailed(taskSet, message)) - } - - test("zero split job") { - val rdd = makeRdd(0, Nil) - var numResults = 0 - val fakeListener = new JobListener() { - override def taskSucceeded(partition: Int, value: Any) = numResults += 1 - override def jobFailed(exception: Exception) = throw exception - } - submit(rdd, Array(), listener = fakeListener) - assert(numResults === 0) - } - - test("run trivial job") { - val rdd = makeRdd(1, Nil) - submit(rdd, Array(0)) - complete(taskSets(0), List((Success, 42))) - assert(results === Map(0 -> 42)) - } - - test("local job") { - val rdd = new MyRDD(sc, Nil) { - override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = - Array(42 -> 0).iterator - override def getPartitions = Array( new Partition { override def index = 0 } ) - override def getPreferredLocations(split: Partition) = Nil - override def toString = "DAGSchedulerSuite Local RDD" - } - runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener)) - assert(results === Map(0 -> 42)) - } - - test("run trivial job w/ dependency") { - val baseRdd = makeRdd(1, Nil) - val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd))) - submit(finalRdd, Array(0)) - complete(taskSets(0), Seq((Success, 42))) - assert(results === Map(0 -> 42)) - } - - test("cache location preferences w/ dependency") { - val baseRdd = makeRdd(1, Nil) - val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd))) - cacheLocations(baseRdd.id -> 0) = - Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")) - submit(finalRdd, Array(0)) - val taskSet = taskSets(0) - assertLocations(taskSet, Seq(Seq("hostA", "hostB"))) - complete(taskSet, Seq((Success, 42))) - assert(results === Map(0 -> 42)) - } - - test("trivial job failure") { - submit(makeRdd(1, Nil), Array(0)) - failed(taskSets(0), "some failure") - assert(failure.getMessage === "Job failed: some failure") - } - - test("run trivial shuffle") { - val shuffleMapRdd = makeRdd(2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) - val shuffleId = shuffleDep.shuffleId - val reduceRdd = makeRdd(1, List(shuffleDep)) - submit(reduceRdd, Array(0)) - complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) - complete(taskSets(1), Seq((Success, 42))) - assert(results === Map(0 -> 42)) - } - - test("run trivial shuffle with fetch failure") { - val shuffleMapRdd = makeRdd(2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) - val shuffleId = shuffleDep.shuffleId - val reduceRdd = makeRdd(2, List(shuffleDep)) - submit(reduceRdd, Array(0, 1)) - complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) - // the 2nd ResultTask failed - complete(taskSets(1), Seq( - (Success, 42), - (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null))) - // this will get called - // blockManagerMaster.removeExecutor("exec-hostA") - // ask the scheduler to try it again - scheduler.resubmitFailedStages() - // have the 2nd attempt pass - complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) - // we can see both result blocks now - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) - complete(taskSets(3), Seq((Success, 43))) - assert(results === Map(0 -> 
42, 1 -> 43)) - } - - test("ignore late map task completions") { - val shuffleMapRdd = makeRdd(2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) - val shuffleId = shuffleDep.shuffleId - val reduceRdd = makeRdd(2, List(shuffleDep)) - submit(reduceRdd, Array(0, 1)) - // pretend we were told hostA went away - val oldEpoch = mapOutputTracker.getEpoch - runEvent(ExecutorLost("exec-hostA")) - val newEpoch = mapOutputTracker.getEpoch - assert(newEpoch > oldEpoch) - val noAccum = Map[Long, Any]() - val taskSet = taskSets(0) - // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null)) - // should work because it's a non-failed host - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum, null, null)) - // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null)) - // should work because it's a new epoch - taskSet.tasks(1).epoch = newEpoch - runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null, null)) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) - complete(taskSets(1), Seq((Success, 42), (Success, 43))) - assert(results === Map(0 -> 42, 1 -> 43)) - } - - test("run trivial shuffle with out-of-band failure and retry") { - val shuffleMapRdd = makeRdd(2, Nil) - val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) - val shuffleId = shuffleDep.shuffleId - val reduceRdd = makeRdd(1, List(shuffleDep)) - submit(reduceRdd, Array(0)) - // blockManagerMaster.removeExecutor("exec-hostA") - // pretend we were told hostA went away - runEvent(ExecutorLost("exec-hostA")) - // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks - // rather than marking it as failed and waiting. 
- complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) - // have hostC complete the resubmitted task - complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) - complete(taskSets(2), Seq((Success, 42))) - assert(results === Map(0 -> 42)) - } - - test("recursive shuffle failures") { - val shuffleOneRdd = makeRdd(2, Nil) - val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) - val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne)) - val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) - val finalRdd = makeRdd(1, List(shuffleDepTwo)) - submit(finalRdd, Array(0)) - // have the first stage complete normally - complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 2)), - (Success, makeMapStatus("hostB", 2)))) - // have the second stage complete normally - complete(taskSets(1), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostC", 1)))) - // fail the third stage because hostA went down - complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null))) - // TODO assert this: - // blockManagerMaster.removeExecutor("exec-hostA") - // have DAGScheduler try again - scheduler.resubmitFailedStages() - complete(taskSets(3), Seq((Success, makeMapStatus("hostA", 2)))) - complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1)))) - complete(taskSets(5), Seq((Success, 42))) - assert(results === Map(0 -> 42)) - } - - test("cached post-shuffle") { - val shuffleOneRdd = makeRdd(2, Nil) - val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) - val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne)) - val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) - val finalRdd = makeRdd(1, List(shuffleDepTwo)) - submit(finalRdd, Array(0)) - cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) - cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) - // complete stage 2 - complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 2)), - (Success, makeMapStatus("hostB", 2)))) - // complete stage 1 - complete(taskSets(1), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) - // pretend stage 0 failed because hostA went down - complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null))) - // TODO assert this: - // blockManagerMaster.removeExecutor("exec-hostA") - // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun. - scheduler.resubmitFailedStages() - assertLocations(taskSets(3), Seq(Seq("hostD"))) - // allow hostD to recover - complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1)))) - complete(taskSets(4), Seq((Success, 42))) - assert(results === Map(0 -> 42)) - } - - /** - * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. - * Note that this checks only the host and not the executor ID. 
- */ - private def assertLocations(taskSet: TaskSet, hosts: Seq[Seq[String]]) { - assert(hosts.size === taskSet.tasks.size) - for ((taskLocs, expectedLocs) <- taskSet.tasks.map(_.preferredLocations).zip(hosts)) { - assert(taskLocs.map(_.host) === expectedLocs) - } - } - - private def makeMapStatus(host: String, reduces: Int): MapStatus = - new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2)) - - private def makeBlockManagerId(host: String): BlockManagerId = - BlockManagerId("exec-" + host, host, 12345, 0) - -} diff --git a/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala deleted file mode 100644 index bb9e715f95..0000000000 --- a/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import java.util.Properties -import java.util.concurrent.LinkedBlockingQueue -import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers -import scala.collection.mutable -import spark._ -import spark.SparkContext._ - - -class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers { - - test("inner method") { - sc = new SparkContext("local", "joblogger") - val joblogger = new JobLogger { - def createLogWriterTest(jobID: Int) = createLogWriter(jobID) - def closeLogWriterTest(jobID: Int) = closeLogWriter(jobID) - def getRddNameTest(rdd: RDD[_]) = getRddName(rdd) - def buildJobDepTest(jobID: Int, stage: Stage) = buildJobDep(jobID, stage) - } - type MyRDD = RDD[(Int, Int)] - def makeRdd( - numPartitions: Int, - dependencies: List[Dependency[_]] - ): MyRDD = { - val maxPartition = numPartitions - 1 - return new MyRDD(sc, dependencies) { - override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = - throw new RuntimeException("should not be reached") - override def getPartitions = (0 to maxPartition).map(i => new Partition { - override def index = i - }).toArray - } - } - val jobID = 5 - val parentRdd = makeRdd(4, Nil) - val shuffleDep = new ShuffleDependency(parentRdd, null) - val rootRdd = makeRdd(4, List(shuffleDep)) - val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID, None) - val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID, None) - - joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4, null)) - joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName) - parentRdd.setName("MyRDD") - joblogger.getRddNameTest(parentRdd) should be ("MyRDD") - joblogger.createLogWriterTest(jobID) - joblogger.getJobIDtoPrintWriter.size should be (1) - joblogger.buildJobDepTest(jobID, rootStage) - joblogger.getJobIDToStages.get(jobID).get.size should 
be (2) - joblogger.getStageIDToJobID.get(0) should be (Some(jobID)) - joblogger.getStageIDToJobID.get(1) should be (Some(jobID)) - joblogger.closeLogWriterTest(jobID) - joblogger.getStageIDToJobID.size should be (0) - joblogger.getJobIDToStages.size should be (0) - joblogger.getJobIDtoPrintWriter.size should be (0) - } - - test("inner variables") { - sc = new SparkContext("local[4]", "joblogger") - val joblogger = new JobLogger { - override protected def closeLogWriter(jobID: Int) = - getJobIDtoPrintWriter.get(jobID).foreach { fileWriter => - fileWriter.close() - } - } - sc.addSparkListener(joblogger) - val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } - rdd.reduceByKey(_+_).collect() - - joblogger.getLogDir should be ("/tmp/spark") - joblogger.getJobIDtoPrintWriter.size should be (1) - joblogger.getStageIDToJobID.size should be (2) - joblogger.getStageIDToJobID.get(0) should be (Some(0)) - joblogger.getStageIDToJobID.get(1) should be (Some(0)) - joblogger.getJobIDToStages.size should be (1) - } - - - test("interface functions") { - sc = new SparkContext("local[4]", "joblogger") - val joblogger = new JobLogger { - var onTaskEndCount = 0 - var onJobEndCount = 0 - var onJobStartCount = 0 - var onStageCompletedCount = 0 - var onStageSubmittedCount = 0 - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = onTaskEndCount += 1 - override def onJobEnd(jobEnd: SparkListenerJobEnd) = onJobEndCount += 1 - override def onJobStart(jobStart: SparkListenerJobStart) = onJobStartCount += 1 - override def onStageCompleted(stageCompleted: StageCompleted) = onStageCompletedCount += 1 - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = onStageSubmittedCount += 1 - } - sc.addSparkListener(joblogger) - val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } - rdd.reduceByKey(_+_).collect() - - joblogger.onJobStartCount should be (1) - joblogger.onJobEndCount should be (1) - joblogger.onTaskEndCount should be (8) - joblogger.onStageSubmittedCount should be (2) - joblogger.onStageCompletedCount should be (2) - } -} diff --git a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala deleted file mode 100644 index 392d67d67b..0000000000 --- a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.scheduler - -import org.scalatest.FunSuite -import spark.{SparkContext, LocalSparkContext} -import scala.collection.mutable -import org.scalatest.matchers.ShouldMatchers -import spark.SparkContext._ - -/** - * - */ - -class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers { - - test("local metrics") { - sc = new SparkContext("local[4]", "test") - val listener = new SaveStageInfo - sc.addSparkListener(listener) - sc.addSparkListener(new StatsReportListener) - //just to make sure some of the tasks take a noticeable amount of time - val w = {i:Int => - if (i == 0) - Thread.sleep(100) - i - } - - val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)} - d.count - listener.stageInfos.size should be (1) - - val d2 = d.map{i => w(i) -> i * 2}.setName("shuffle input 1") - - val d3 = d.map{i => w(i) -> (0 to (i % 5))}.setName("shuffle input 2") - - val d4 = d2.cogroup(d3, 64).map{case(k,(v1,v2)) => w(k) -> (v1.size, v2.size)} - d4.setName("A Cogroup") - - d4.collectAsMap - - listener.stageInfos.size should be (4) - listener.stageInfos.foreach {stageInfo => - //small test, so some tasks might take less than 1 millisecond, but average should be greater than 1 ms - checkNonZeroAvg(stageInfo.taskInfos.map{_._1.duration}, stageInfo + " duration") - checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorRunTime.toLong}, stageInfo + " executorRunTime") - checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong}, stageInfo + " executorDeserializeTime") - if (stageInfo.stage.rdd.name == d4.name) { - checkNonZeroAvg(stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime}, stageInfo + " fetchWaitTime") - } - - stageInfo.taskInfos.foreach{case (taskInfo, taskMetrics) => - taskMetrics.resultSize should be > (0l) - if (isStage(stageInfo, Set(d2.name, d3.name), Set(d4.name))) { - taskMetrics.shuffleWriteMetrics should be ('defined) - taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0l) - } - if (stageInfo.stage.rdd.name == d4.name) { - taskMetrics.shuffleReadMetrics should be ('defined) - val sm = taskMetrics.shuffleReadMetrics.get - sm.totalBlocksFetched should be > (0) - sm.localBlocksFetched should be > (0) - sm.remoteBlocksFetched should be (0) - sm.remoteBytesRead should be (0l) - sm.remoteFetchTime should be (0l) - } - } - } - } - - def checkNonZeroAvg(m: Traversable[Long], msg: String) { - assert(m.sum / m.size.toDouble > 0.0, msg) - } - - def isStage(stageInfo: StageInfo, rddNames: Set[String], excludedNames: Set[String]) = { - val names = Set(stageInfo.stage.rdd.name) ++ stageInfo.stage.rdd.dependencies.map{_.rdd.name} - !names.intersect(rddNames).isEmpty && names.intersect(excludedNames).isEmpty - } - - class SaveStageInfo extends SparkListener { - val stageInfos = mutable.Buffer[StageInfo]() - override def onStageCompleted(stage: StageCompleted) { - stageInfos += stage.stageInfo - } - } - -} diff --git a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala deleted file mode 100644 index 95a6eee2fc..0000000000 --- a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter -import spark.TaskContext -import spark.RDD -import spark.SparkContext -import spark.Partition -import spark.LocalSparkContext - -class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { - - test("Calls executeOnCompleteCallbacks after failure") { - var completed = false - sc = new SparkContext("local", "test") - val rdd = new RDD[String](sc, List()) { - override def getPartitions = Array[Partition](StubPartition(0)) - override def compute(split: Partition, context: TaskContext) = { - context.addOnCompleteCallback(() => completed = true) - sys.error("failed") - } - } - val func = (c: TaskContext, i: Iterator[String]) => i.next - val task = new ResultTask[String, String](0, rdd, func, 0, Seq(), 0) - intercept[RuntimeException] { - task.run(0) - } - assert(completed === true) - } - - case class StubPartition(val index: Int) extends Partition -} diff --git a/core/src/test/scala/spark/scheduler/cluster/ClusterSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/cluster/ClusterSchedulerSuite.scala deleted file mode 100644 index abfdabf5fe..0000000000 --- a/core/src/test/scala/spark/scheduler/cluster/ClusterSchedulerSuite.scala +++ /dev/null @@ -1,266 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
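// A minimal sketch of the per-partition completion hook verified above: a callback registered
// through TaskContext.addOnCompleteCallback runs when the task finishes, even if compute()
// throws. OnePartition and rddWithCleanup are illustrative names; the types come from the
// spark package imported in the deleted test.
import spark.{SparkContext, RDD, TaskContext, Partition}

case class OnePartition(index: Int) extends Partition

def rddWithCleanup(sc: SparkContext): RDD[String] = new RDD[String](sc, Nil) {
  override def getPartitions = Array[Partition](OnePartition(0))
  override def compute(split: Partition, context: TaskContext) = {
    context.addOnCompleteCallback(() => println("partition " + split.index + " finished"))
    Iterator("done") // real work goes here; a failure here would still trigger the callback
  }
}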
- */ - -package spark.scheduler.cluster - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter - -import spark._ -import spark.scheduler._ -import spark.scheduler.cluster._ -import scala.collection.mutable.ArrayBuffer - -import java.util.Properties - -class FakeTaskSetManager( - initPriority: Int, - initStageId: Int, - initNumTasks: Int, - clusterScheduler: ClusterScheduler, - taskSet: TaskSet) - extends ClusterTaskSetManager(clusterScheduler, taskSet) { - - parent = null - weight = 1 - minShare = 2 - runningTasks = 0 - priority = initPriority - stageId = initStageId - name = "TaskSet_"+stageId - override val numTasks = initNumTasks - tasksFinished = 0 - - override def increaseRunningTasks(taskNum: Int) { - runningTasks += taskNum - if (parent != null) { - parent.increaseRunningTasks(taskNum) - } - } - - override def decreaseRunningTasks(taskNum: Int) { - runningTasks -= taskNum - if (parent != null) { - parent.decreaseRunningTasks(taskNum) - } - } - - override def addSchedulable(schedulable: Schedulable) { - } - - override def removeSchedulable(schedulable: Schedulable) { - } - - override def getSchedulableByName(name: String): Schedulable = { - return null - } - - override def executorLost(executorId: String, host: String): Unit = { - } - - override def resourceOffer( - execId: String, - host: String, - availableCpus: Int, - maxLocality: TaskLocality.TaskLocality) - : Option[TaskDescription] = - { - if (tasksFinished + runningTasks < numTasks) { - increaseRunningTasks(1) - return Some(new TaskDescription(0, execId, "task 0:0", 0, null)) - } - return None - } - - override def checkSpeculatableTasks(): Boolean = { - return true - } - - def taskFinished() { - decreaseRunningTasks(1) - tasksFinished +=1 - if (tasksFinished == numTasks) { - parent.removeSchedulable(this) - } - } - - def abort() { - decreaseRunningTasks(runningTasks) - parent.removeSchedulable(this) - } -} - -class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging { - - def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): FakeTaskSetManager = { - new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet) - } - - def resourceOffer(rootPool: Pool): Int = { - val taskSetQueue = rootPool.getSortedTaskSetQueue() - /* Just for Test*/ - for (manager <- taskSetQueue) { - logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks)) - } - for (taskSet <- taskSetQueue) { - taskSet.resourceOffer("execId_1", "hostname_1", 1, TaskLocality.ANY) match { - case Some(task) => - return taskSet.stageId - case None => {} - } - } - -1 - } - - def checkTaskSetId(rootPool: Pool, expectedTaskSetId: Int) { - assert(resourceOffer(rootPool) === expectedTaskSetId) - } - - test("FIFO Scheduler Test") { - sc = new SparkContext("local", "ClusterSchedulerSuite") - val clusterScheduler = new ClusterScheduler(sc) - var tasks = ArrayBuffer[Task[_]]() - val task = new FakeTask(0) - tasks += task - val taskSet = new TaskSet(tasks.toArray,0,0,0,null) - - val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0) - val schedulableBuilder = new FIFOSchedulableBuilder(rootPool) - schedulableBuilder.buildPools() - - val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, clusterScheduler, taskSet) - val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, clusterScheduler, taskSet) - val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, clusterScheduler, 
taskSet) - schedulableBuilder.addTaskSetManager(taskSetManager0, null) - schedulableBuilder.addTaskSetManager(taskSetManager1, null) - schedulableBuilder.addTaskSetManager(taskSetManager2, null) - - checkTaskSetId(rootPool, 0) - resourceOffer(rootPool) - checkTaskSetId(rootPool, 1) - resourceOffer(rootPool) - taskSetManager1.abort() - checkTaskSetId(rootPool, 2) - } - - test("Fair Scheduler Test") { - sc = new SparkContext("local", "ClusterSchedulerSuite") - val clusterScheduler = new ClusterScheduler(sc) - var tasks = ArrayBuffer[Task[_]]() - val task = new FakeTask(0) - tasks += task - val taskSet = new TaskSet(tasks.toArray,0,0,0,null) - - val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() - System.setProperty("spark.fairscheduler.allocation.file", xmlPath) - val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) - val schedulableBuilder = new FairSchedulableBuilder(rootPool) - schedulableBuilder.buildPools() - - assert(rootPool.getSchedulableByName("default") != null) - assert(rootPool.getSchedulableByName("1") != null) - assert(rootPool.getSchedulableByName("2") != null) - assert(rootPool.getSchedulableByName("3") != null) - assert(rootPool.getSchedulableByName("1").minShare === 2) - assert(rootPool.getSchedulableByName("1").weight === 1) - assert(rootPool.getSchedulableByName("2").minShare === 3) - assert(rootPool.getSchedulableByName("2").weight === 1) - assert(rootPool.getSchedulableByName("3").minShare === 2) - assert(rootPool.getSchedulableByName("3").weight === 1) - - val properties1 = new Properties() - properties1.setProperty("spark.scheduler.cluster.fair.pool","1") - val properties2 = new Properties() - properties2.setProperty("spark.scheduler.cluster.fair.pool","2") - - val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, clusterScheduler, taskSet) - val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, clusterScheduler, taskSet) - val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, clusterScheduler, taskSet) - schedulableBuilder.addTaskSetManager(taskSetManager10, properties1) - schedulableBuilder.addTaskSetManager(taskSetManager11, properties1) - schedulableBuilder.addTaskSetManager(taskSetManager12, properties1) - - val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, clusterScheduler, taskSet) - val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, clusterScheduler, taskSet) - schedulableBuilder.addTaskSetManager(taskSetManager23, properties2) - schedulableBuilder.addTaskSetManager(taskSetManager24, properties2) - - checkTaskSetId(rootPool, 0) - checkTaskSetId(rootPool, 3) - checkTaskSetId(rootPool, 3) - checkTaskSetId(rootPool, 1) - checkTaskSetId(rootPool, 4) - checkTaskSetId(rootPool, 2) - checkTaskSetId(rootPool, 2) - checkTaskSetId(rootPool, 4) - - taskSetManager12.taskFinished() - assert(rootPool.getSchedulableByName("1").runningTasks === 3) - taskSetManager24.abort() - assert(rootPool.getSchedulableByName("2").runningTasks === 2) - } - - test("Nested Pool Test") { - sc = new SparkContext("local", "ClusterSchedulerSuite") - val clusterScheduler = new ClusterScheduler(sc) - var tasks = ArrayBuffer[Task[_]]() - val task = new FakeTask(0) - tasks += task - val taskSet = new TaskSet(tasks.toArray,0,0,0,null) - - val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) - val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1) - val pool1 = new Pool("1", SchedulingMode.FAIR, 4, 1) - rootPool.addSchedulable(pool0) - rootPool.addSchedulable(pool1) - - val pool00 = new Pool("00", SchedulingMode.FAIR, 2, 2) - val pool01 = new 
Pool("01", SchedulingMode.FAIR, 1, 1) - pool0.addSchedulable(pool00) - pool0.addSchedulable(pool01) - - val pool10 = new Pool("10", SchedulingMode.FAIR, 2, 2) - val pool11 = new Pool("11", SchedulingMode.FAIR, 2, 1) - pool1.addSchedulable(pool10) - pool1.addSchedulable(pool11) - - val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, clusterScheduler, taskSet) - val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, clusterScheduler, taskSet) - pool00.addSchedulable(taskSetManager000) - pool00.addSchedulable(taskSetManager001) - - val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, clusterScheduler, taskSet) - val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, clusterScheduler, taskSet) - pool01.addSchedulable(taskSetManager010) - pool01.addSchedulable(taskSetManager011) - - val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, clusterScheduler, taskSet) - val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, clusterScheduler, taskSet) - pool10.addSchedulable(taskSetManager100) - pool10.addSchedulable(taskSetManager101) - - val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, clusterScheduler, taskSet) - val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, clusterScheduler, taskSet) - pool11.addSchedulable(taskSetManager110) - pool11.addSchedulable(taskSetManager111) - - checkTaskSetId(rootPool, 0) - checkTaskSetId(rootPool, 4) - checkTaskSetId(rootPool, 6) - checkTaskSetId(rootPool, 2) - } -} diff --git a/core/src/test/scala/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala deleted file mode 100644 index 5a0b949ef5..0000000000 --- a/core/src/test/scala/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala +++ /dev/null @@ -1,273 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.cluster - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable - -import org.scalatest.FunSuite - -import spark._ -import spark.scheduler._ -import spark.executor.TaskMetrics -import java.nio.ByteBuffer -import spark.util.FakeClock - -/** - * A mock ClusterScheduler implementation that just remembers information about tasks started and - * feedback received from the TaskSetManagers. Note that it's important to initialize this with - * a list of "live" executors and their hostnames for isExecutorAlive and hasExecutorsAliveOnHost - * to work, and these are required for locality in ClusterTaskSetManager. 
- */ -class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */) - extends ClusterScheduler(sc) -{ - val startedTasks = new ArrayBuffer[Long] - val endedTasks = new mutable.HashMap[Long, TaskEndReason] - val finishedManagers = new ArrayBuffer[TaskSetManager] - - val executors = new mutable.HashMap[String, String] ++ liveExecutors - - listener = new TaskSchedulerListener { - def taskStarted(task: Task[_], taskInfo: TaskInfo) { - startedTasks += taskInfo.index - } - - def taskEnded( - task: Task[_], - reason: TaskEndReason, - result: Any, - accumUpdates: mutable.Map[Long, Any], - taskInfo: TaskInfo, - taskMetrics: TaskMetrics) - { - endedTasks(taskInfo.index) = reason - } - - def executorGained(execId: String, host: String) {} - - def executorLost(execId: String) {} - - def taskSetFailed(taskSet: TaskSet, reason: String) {} - } - - def removeExecutor(execId: String): Unit = executors -= execId - - override def taskSetFinished(manager: TaskSetManager): Unit = finishedManagers += manager - - override def isExecutorAlive(execId: String): Boolean = executors.contains(execId) - - override def hasExecutorsAliveOnHost(host: String): Boolean = executors.values.exists(_ == host) -} - -class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { - import TaskLocality.{ANY, PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL} - - val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong - - test("TaskSet with no preferences") { - sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) - val taskSet = createTaskSet(1) - val manager = new ClusterTaskSetManager(sched, taskSet) - - // Offer a host with no CPUs - assert(manager.resourceOffer("exec1", "host1", 0, ANY) === None) - - // Offer a host with process-local as the constraint; this should work because the TaskSet - // above won't have any locality preferences - val taskOption = manager.resourceOffer("exec1", "host1", 2, TaskLocality.PROCESS_LOCAL) - assert(taskOption.isDefined) - val task = taskOption.get - assert(task.executorId === "exec1") - assert(sched.startedTasks.contains(0)) - - // Re-offer the host -- now we should get no more tasks - assert(manager.resourceOffer("exec1", "host1", 2, PROCESS_LOCAL) === None) - - // Tell it the task has finished - manager.statusUpdate(0, TaskState.FINISHED, createTaskResult(0)) - assert(sched.endedTasks(0) === Success) - assert(sched.finishedManagers.contains(manager)) - } - - test("multiple offers with no preferences") { - sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) - val taskSet = createTaskSet(3) - val manager = new ClusterTaskSetManager(sched, taskSet) - - // First three offers should all find tasks - for (i <- 0 until 3) { - val taskOption = manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) - assert(taskOption.isDefined) - val task = taskOption.get - assert(task.executorId === "exec1") - } - assert(sched.startedTasks.toSet === Set(0, 1, 2)) - - // Re-offer the host -- now we should get no more tasks - assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None) - - // Finish the first two tasks - manager.statusUpdate(0, TaskState.FINISHED, createTaskResult(0)) - manager.statusUpdate(1, TaskState.FINISHED, createTaskResult(1)) - assert(sched.endedTasks(0) === Success) - assert(sched.endedTasks(1) === Success) - assert(!sched.finishedManagers.contains(manager)) - - // Finish the last task - 
manager.statusUpdate(2, TaskState.FINISHED, createTaskResult(2)) - assert(sched.endedTasks(2) === Success) - assert(sched.finishedManagers.contains(manager)) - } - - test("basic delay scheduling") { - sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) - val taskSet = createTaskSet(4, - Seq(TaskLocation("host1", "exec1")), - Seq(TaskLocation("host2", "exec2")), - Seq(TaskLocation("host1"), TaskLocation("host2", "exec2")), - Seq() // Last task has no locality prefs - ) - val clock = new FakeClock - val manager = new ClusterTaskSetManager(sched, taskSet, clock) - - // First offer host1, exec1: first task should be chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) - - // Offer host1, exec1 again: the last task, which has no prefs, should be chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 3) - - // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None) - - clock.advance(LOCALITY_WAIT) - - // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None) - - // Offer host1, exec1 again, at NODE_LOCAL level: we should choose task 2 - assert(manager.resourceOffer("exec1", "host1", 1, NODE_LOCAL).get.index == 2) - - // Offer host1, exec1 again, at NODE_LOCAL level: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, NODE_LOCAL) === None) - - // Offer host1, exec1 again, at ANY level: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) - - clock.advance(LOCALITY_WAIT) - - // Offer host1, exec1 again, at ANY level: task 1 should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1) - - // Offer host1, exec1 again, at ANY level: nothing should be chosen as we've launched all tasks - assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) - } - - test("delay scheduling with fallback") { - sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, - ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3")) - val taskSet = createTaskSet(5, - Seq(TaskLocation("host1")), - Seq(TaskLocation("host2")), - Seq(TaskLocation("host2")), - Seq(TaskLocation("host3")), - Seq(TaskLocation("host2")) - ) - val clock = new FakeClock - val manager = new ClusterTaskSetManager(sched, taskSet, clock) - - // First offer host1: first task should be chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) - - // Offer host1 again: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) - - clock.advance(LOCALITY_WAIT) - - // Offer host1 again: second task (on host2) should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1) - - // Offer host1 again: third task (on host2) should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 2) - - // Offer host2: fifth task (also on host2) should get chosen - assert(manager.resourceOffer("exec2", "host2", 1, ANY).get.index === 4) - - // Now that we've launched a local task, we should no longer launch the task for host3 - assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None) - - clock.advance(LOCALITY_WAIT) - - // After another delay, we can go ahead and launch that task non-locally - 
assert(manager.resourceOffer("exec2", "host2", 1, ANY).get.index === 3) - } - - test("delay scheduling with failed hosts") { - sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) - val taskSet = createTaskSet(3, - Seq(TaskLocation("host1")), - Seq(TaskLocation("host2")), - Seq(TaskLocation("host3")) - ) - val clock = new FakeClock - val manager = new ClusterTaskSetManager(sched, taskSet, clock) - - // First offer host1: first task should be chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) - - // Offer host1 again: third task should be chosen immediately because host3 is not up - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 2) - - // After this, nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) - - // Now mark host2 as dead - sched.removeExecutor("exec2") - manager.executorLost("exec2", "host2") - - // Task 1 should immediately be launched on host1 because its original host is gone - assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1) - - // Now that all tasks have launched, nothing new should be launched anywhere else - assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) - assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None) - } - - /** - * Utility method to create a TaskSet, potentially setting a particular sequence of preferred - * locations for each task (given as varargs) if this sequence is not empty. - */ - def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { - if (prefLocs.size != 0 && prefLocs.size != numTasks) { - throw new IllegalArgumentException("Wrong number of task locations") - } - val tasks = Array.tabulate[Task[_]](numTasks) { i => - new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil) - } - new TaskSet(tasks, 0, 0, 0, null) - } - - def createTaskResult(id: Int): ByteBuffer = { - ByteBuffer.wrap(Utils.serialize(new TaskResult[Int](id, mutable.Map.empty, new TaskMetrics))) - } -} diff --git a/core/src/test/scala/spark/scheduler/cluster/FakeTask.scala b/core/src/test/scala/spark/scheduler/cluster/FakeTask.scala deleted file mode 100644 index de9e66be20..0000000000 --- a/core/src/test/scala/spark/scheduler/cluster/FakeTask.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.scheduler.cluster - -import spark.scheduler.{TaskLocation, Task} - -class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId) { - override def run(attemptId: Long): Int = 0 - - override def preferredLocations: Seq[TaskLocation] = prefLocs -} diff --git a/core/src/test/scala/spark/scheduler/local/LocalSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/local/LocalSchedulerSuite.scala deleted file mode 100644 index d28ee47fa3..0000000000 --- a/core/src/test/scala/spark/scheduler/local/LocalSchedulerSuite.scala +++ /dev/null @@ -1,223 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.scheduler.local - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter - -import spark._ -import spark.scheduler._ -import spark.scheduler.cluster._ -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.{ConcurrentMap, HashMap} -import java.util.concurrent.Semaphore -import java.util.concurrent.CountDownLatch -import java.util.Properties - -class Lock() { - var finished = false - def jobWait() = { - synchronized { - while(!finished) { - this.wait() - } - } - } - - def jobFinished() = { - synchronized { - finished = true - this.notifyAll() - } - } -} - -object TaskThreadInfo { - val threadToLock = HashMap[Int, Lock]() - val threadToRunning = HashMap[Int, Boolean]() - val threadToStarted = HashMap[Int, CountDownLatch]() -} - -/* - * 1. each thread contains one job. - * 2. each job contains one stage. - * 3. each stage only contains one task. - * 4. each task(launched) must be lanched orderly(using threadToStarted) to make sure - * it will get cpu core resource, and will wait to finished after user manually - * release "Lock" and then cluster will contain another free cpu cores. - * 5. each task(pending) must use "sleep" to make sure it has been added to taskSetManager queue, - * thus it will be scheduled later when cluster has free cpu cores. 
- */ -class LocalSchedulerSuite extends FunSuite with LocalSparkContext { - - def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) { - - TaskThreadInfo.threadToRunning(threadIndex) = false - val nums = sc.parallelize(threadIndex to threadIndex, 1) - TaskThreadInfo.threadToLock(threadIndex) = new Lock() - TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1) - new Thread { - if (poolName != null) { - sc.setLocalProperty("spark.scheduler.cluster.fair.pool", poolName) - } - override def run() { - val ans = nums.map(number => { - TaskThreadInfo.threadToRunning(number) = true - TaskThreadInfo.threadToStarted(number).countDown() - TaskThreadInfo.threadToLock(number).jobWait() - TaskThreadInfo.threadToRunning(number) = false - number - }).collect() - assert(ans.toList === List(threadIndex)) - sem.release() - } - }.start() - } - - test("Local FIFO scheduler end-to-end test") { - System.setProperty("spark.cluster.schedulingmode", "FIFO") - sc = new SparkContext("local[4]", "test") - val sem = new Semaphore(0) - - createThread(1,null,sc,sem) - TaskThreadInfo.threadToStarted(1).await() - createThread(2,null,sc,sem) - TaskThreadInfo.threadToStarted(2).await() - createThread(3,null,sc,sem) - TaskThreadInfo.threadToStarted(3).await() - createThread(4,null,sc,sem) - TaskThreadInfo.threadToStarted(4).await() - // thread 5 and 6 (stage pending)must meet following two points - // 1. stages (taskSetManager) of jobs in thread 5 and 6 should be add to taskSetManager - // queue before executing TaskThreadInfo.threadToLock(1).jobFinished() - // 2. priority of stage in thread 5 should be prior to priority of stage in thread 6 - // So I just use "sleep" 1s here for each thread. - // TODO: any better solution? - createThread(5,null,sc,sem) - Thread.sleep(1000) - createThread(6,null,sc,sem) - Thread.sleep(1000) - - assert(TaskThreadInfo.threadToRunning(1) === true) - assert(TaskThreadInfo.threadToRunning(2) === true) - assert(TaskThreadInfo.threadToRunning(3) === true) - assert(TaskThreadInfo.threadToRunning(4) === true) - assert(TaskThreadInfo.threadToRunning(5) === false) - assert(TaskThreadInfo.threadToRunning(6) === false) - - TaskThreadInfo.threadToLock(1).jobFinished() - TaskThreadInfo.threadToStarted(5).await() - - assert(TaskThreadInfo.threadToRunning(1) === false) - assert(TaskThreadInfo.threadToRunning(2) === true) - assert(TaskThreadInfo.threadToRunning(3) === true) - assert(TaskThreadInfo.threadToRunning(4) === true) - assert(TaskThreadInfo.threadToRunning(5) === true) - assert(TaskThreadInfo.threadToRunning(6) === false) - - TaskThreadInfo.threadToLock(3).jobFinished() - TaskThreadInfo.threadToStarted(6).await() - - assert(TaskThreadInfo.threadToRunning(1) === false) - assert(TaskThreadInfo.threadToRunning(2) === true) - assert(TaskThreadInfo.threadToRunning(3) === false) - assert(TaskThreadInfo.threadToRunning(4) === true) - assert(TaskThreadInfo.threadToRunning(5) === true) - assert(TaskThreadInfo.threadToRunning(6) === true) - - TaskThreadInfo.threadToLock(2).jobFinished() - TaskThreadInfo.threadToLock(4).jobFinished() - TaskThreadInfo.threadToLock(5).jobFinished() - TaskThreadInfo.threadToLock(6).jobFinished() - sem.acquire(6) - } - - test("Local fair scheduler end-to-end test") { - sc = new SparkContext("local[8]", "LocalSchedulerSuite") - val sem = new Semaphore(0) - System.setProperty("spark.cluster.schedulingmode", "FAIR") - val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() - 
System.setProperty("spark.fairscheduler.allocation.file", xmlPath) - - createThread(10,"1",sc,sem) - TaskThreadInfo.threadToStarted(10).await() - createThread(20,"2",sc,sem) - TaskThreadInfo.threadToStarted(20).await() - createThread(30,"3",sc,sem) - TaskThreadInfo.threadToStarted(30).await() - - assert(TaskThreadInfo.threadToRunning(10) === true) - assert(TaskThreadInfo.threadToRunning(20) === true) - assert(TaskThreadInfo.threadToRunning(30) === true) - - createThread(11,"1",sc,sem) - TaskThreadInfo.threadToStarted(11).await() - createThread(21,"2",sc,sem) - TaskThreadInfo.threadToStarted(21).await() - createThread(31,"3",sc,sem) - TaskThreadInfo.threadToStarted(31).await() - - assert(TaskThreadInfo.threadToRunning(11) === true) - assert(TaskThreadInfo.threadToRunning(21) === true) - assert(TaskThreadInfo.threadToRunning(31) === true) - - createThread(12,"1",sc,sem) - TaskThreadInfo.threadToStarted(12).await() - createThread(22,"2",sc,sem) - TaskThreadInfo.threadToStarted(22).await() - createThread(32,"3",sc,sem) - - assert(TaskThreadInfo.threadToRunning(12) === true) - assert(TaskThreadInfo.threadToRunning(22) === true) - assert(TaskThreadInfo.threadToRunning(32) === false) - - TaskThreadInfo.threadToLock(10).jobFinished() - TaskThreadInfo.threadToStarted(32).await() - - assert(TaskThreadInfo.threadToRunning(32) === true) - - //1. Similar with above scenario, sleep 1s for stage of 23 and 33 to be added to taskSetManager - // queue so that cluster will assign free cpu core to stage 23 after stage 11 finished. - //2. priority of 23 and 33 will be meaningless as using fair scheduler here. - createThread(23,"2",sc,sem) - createThread(33,"3",sc,sem) - Thread.sleep(1000) - - TaskThreadInfo.threadToLock(11).jobFinished() - TaskThreadInfo.threadToStarted(23).await() - - assert(TaskThreadInfo.threadToRunning(23) === true) - assert(TaskThreadInfo.threadToRunning(33) === false) - - TaskThreadInfo.threadToLock(12).jobFinished() - TaskThreadInfo.threadToStarted(33).await() - - assert(TaskThreadInfo.threadToRunning(33) === true) - - TaskThreadInfo.threadToLock(20).jobFinished() - TaskThreadInfo.threadToLock(21).jobFinished() - TaskThreadInfo.threadToLock(22).jobFinished() - TaskThreadInfo.threadToLock(23).jobFinished() - TaskThreadInfo.threadToLock(30).jobFinished() - TaskThreadInfo.threadToLock(31).jobFinished() - TaskThreadInfo.threadToLock(32).jobFinished() - TaskThreadInfo.threadToLock(33).jobFinished() - - sem.acquire(11) - } -} diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala deleted file mode 100644 index b719d65342..0000000000 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ /dev/null @@ -1,665 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.storage - -import java.nio.ByteBuffer - -import akka.actor._ - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter -import org.scalatest.PrivateMethodTester -import org.scalatest.concurrent.Eventually._ -import org.scalatest.concurrent.Timeouts._ -import org.scalatest.matchers.ShouldMatchers._ -import org.scalatest.time.SpanSugar._ - -import spark.JavaSerializer -import spark.KryoSerializer -import spark.SizeEstimator -import spark.util.AkkaUtils -import spark.util.ByteBufferInputStream - - -class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester { - var store: BlockManager = null - var store2: BlockManager = null - var actorSystem: ActorSystem = null - var master: BlockManagerMaster = null - var oldArch: String = null - var oldOops: String = null - var oldHeartBeat: String = null - - // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - System.setProperty("spark.kryoserializer.buffer.mb", "1") - val serializer = new KryoSerializer - - before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0) - this.actorSystem = actorSystem - System.setProperty("spark.driver.port", boundPort.toString) - System.setProperty("spark.hostPort", "localhost:" + boundPort) - - master = new BlockManagerMaster( - actorSystem.actorOf(Props(new spark.storage.BlockManagerMasterActor(true)))) - - // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case - oldArch = System.setProperty("os.arch", "amd64") - oldOops = System.setProperty("spark.test.useCompressedOops", "true") - oldHeartBeat = System.setProperty("spark.storage.disableBlockManagerHeartBeat", "true") - val initialize = PrivateMethod[Unit]('initialize) - SizeEstimator invokePrivate initialize() - // Set some value ... 
- System.setProperty("spark.hostPort", spark.Utils.localHostName() + ":" + 1111) - } - - after { - System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") - - if (store != null) { - store.stop() - store = null - } - if (store2 != null) { - store2.stop() - store2 = null - } - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null - master = null - - if (oldArch != null) { - System.setProperty("os.arch", oldArch) - } else { - System.clearProperty("os.arch") - } - - if (oldOops != null) { - System.setProperty("spark.test.useCompressedOops", oldOops) - } else { - System.clearProperty("spark.test.useCompressedOops") - } - } - - test("StorageLevel object caching") { - val level1 = StorageLevel(false, false, false, 3) - val level2 = StorageLevel(false, false, false, 3) // this should return the same object as level1 - val level3 = StorageLevel(false, false, false, 2) // this should return a different object - assert(level2 === level1, "level2 is not same as level1") - assert(level2.eq(level1), "level2 is not the same object as level1") - assert(level3 != level1, "level3 is same as level1") - val bytes1 = spark.Utils.serialize(level1) - val level1_ = spark.Utils.deserialize[StorageLevel](bytes1) - val bytes2 = spark.Utils.serialize(level2) - val level2_ = spark.Utils.deserialize[StorageLevel](bytes2) - assert(level1_ === level1, "Deserialized level1 not same as original level1") - assert(level1_.eq(level1), "Deserialized level1 not the same object as original level2") - assert(level2_ === level2, "Deserialized level2 not same as original level2") - assert(level2_.eq(level1), "Deserialized level2 not the same object as original level1") - } - - test("BlockManagerId object caching") { - val id1 = BlockManagerId("e1", "XXX", 1, 0) - val id2 = BlockManagerId("e1", "XXX", 1, 0) // this should return the same object as id1 - val id3 = BlockManagerId("e1", "XXX", 2, 0) // this should return a different object - assert(id2 === id1, "id2 is not same as id1") - assert(id2.eq(id1), "id2 is not the same object as id1") - assert(id3 != id1, "id3 is same as id1") - val bytes1 = spark.Utils.serialize(id1) - val id1_ = spark.Utils.deserialize[BlockManagerId](bytes1) - val bytes2 = spark.Utils.serialize(id2) - val id2_ = spark.Utils.deserialize[BlockManagerId](bytes2) - assert(id1_ === id1, "Deserialized id1 is not same as original id1") - assert(id1_.eq(id1), "Deserialized id1 is not the same object as original id1") - assert(id2_ === id2, "Deserialized id2 is not same as original id2") - assert(id2_.eq(id1), "Deserialized id2 is not the same object as original id1") - } - - test("master + 1 manager interaction") { - store = new BlockManager("", actorSystem, master, serializer, 2000) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - - // Putting a1, a2 and a3 in memory and telling master only about a1 and a2 - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY, tellMaster = false) - - // Checking whether blocks are in memory - assert(store.getSingle("a1") != None, "a1 was not in store") - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") != None, "a3 was not in store") - - // Checking whether master knows about the blocks or not - assert(master.getLocations("a1").size > 0, "master was not told about a1") - assert(master.getLocations("a2").size > 0, "master was not told 
about a2") - assert(master.getLocations("a3").size === 0, "master was told about a3") - - // Drop a1 and a2 from memory; this should be reported back to the master - store.dropFromMemory("a1", null) - store.dropFromMemory("a2", null) - assert(store.getSingle("a1") === None, "a1 not removed from store") - assert(store.getSingle("a2") === None, "a2 not removed from store") - assert(master.getLocations("a1").size === 0, "master did not remove a1") - assert(master.getLocations("a2").size === 0, "master did not remove a2") - } - - test("master + 2 managers interaction") { - store = new BlockManager("exec1", actorSystem, master, serializer, 2000) - store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer, 2000) - - val peers = master.getPeers(store.blockManagerId, 1) - assert(peers.size === 1, "master did not return the other manager as a peer") - assert(peers.head === store2.blockManagerId, "peer returned by master is not the other manager") - - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_2) - store2.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_2) - assert(master.getLocations("a1").size === 2, "master did not report 2 locations for a1") - assert(master.getLocations("a2").size === 2, "master did not report 2 locations for a2") - } - - test("removing block") { - store = new BlockManager("", actorSystem, master, serializer, 2000) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - - // Putting a1, a2 and a3 in memory and telling master only about a1 and a2 - store.putSingle("a1-to-remove", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("a2-to-remove", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("a3-to-remove", a3, StorageLevel.MEMORY_ONLY, tellMaster = false) - - // Checking whether blocks are in memory and memory size - val memStatus = master.getMemoryStatus.head._2 - assert(memStatus._1 == 2000L, "total memory " + memStatus._1 + " should equal 2000") - assert(memStatus._2 <= 1200L, "remaining memory " + memStatus._2 + " should <= 1200") - assert(store.getSingle("a1-to-remove") != None, "a1 was not in store") - assert(store.getSingle("a2-to-remove") != None, "a2 was not in store") - assert(store.getSingle("a3-to-remove") != None, "a3 was not in store") - - // Checking whether master knows about the blocks or not - assert(master.getLocations("a1-to-remove").size > 0, "master was not told about a1") - assert(master.getLocations("a2-to-remove").size > 0, "master was not told about a2") - assert(master.getLocations("a3-to-remove").size === 0, "master was told about a3") - - // Remove a1 and a2 and a3. Should be no-op for a3. 
- master.removeBlock("a1-to-remove") - master.removeBlock("a2-to-remove") - master.removeBlock("a3-to-remove") - - eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("a1-to-remove") should be (None) - master.getLocations("a1-to-remove") should have size 0 - } - eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("a2-to-remove") should be (None) - master.getLocations("a2-to-remove") should have size 0 - } - eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("a3-to-remove") should not be (None) - master.getLocations("a3-to-remove") should have size 0 - } - eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - val memStatus = master.getMemoryStatus.head._2 - memStatus._1 should equal (2000L) - memStatus._2 should equal (2000L) - } - } - - test("removing rdd") { - store = new BlockManager("", actorSystem, master, serializer, 2000) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - // Putting a1, a2 and a3 in memory. - store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("nonrddblock", a3, StorageLevel.MEMORY_ONLY) - master.removeRdd(0, blocking = false) - - eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("rdd_0_0") should be (None) - master.getLocations("rdd_0_0") should have size 0 - } - eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("rdd_0_1") should be (None) - master.getLocations("rdd_0_1") should have size 0 - } - eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("nonrddblock") should not be (None) - master.getLocations("nonrddblock") should have size (1) - } - - store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY) - master.removeRdd(0, blocking = true) - store.getSingle("rdd_0_0") should be (None) - master.getLocations("rdd_0_0") should have size 0 - store.getSingle("rdd_0_1") should be (None) - master.getLocations("rdd_0_1") should have size 0 - } - - test("reregistration on heart beat") { - val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager("", actorSystem, master, serializer, 2000) - val a1 = new Array[Byte](400) - - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - - assert(store.getSingle("a1") != None, "a1 was not in store") - assert(master.getLocations("a1").size > 0, "master was not told about a1") - - master.removeExecutor(store.blockManagerId.executorId) - assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - - store invokePrivate heartBeat() - assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master") - } - - test("reregistration on block update") { - store = new BlockManager("", actorSystem, master, serializer, 2000) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - assert(master.getLocations("a1").size > 0, "master was not told about a1") - - master.removeExecutor(store.blockManagerId.executorId) - assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) - store.waitForAsyncReregister() - - assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master") - assert(master.getLocations("a2").size > 0, "master was not told 
about a2") - } - - test("reregistration doesn't dead lock") { - val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager("", actorSystem, master, serializer, 2000) - val a1 = new Array[Byte](400) - val a2 = List(new Array[Byte](400)) - - // try many times to trigger any deadlocks - for (i <- 1 to 100) { - master.removeExecutor(store.blockManagerId.executorId) - val t1 = new Thread { - override def run() { - store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - } - } - val t2 = new Thread { - override def run() { - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - } - } - val t3 = new Thread { - override def run() { - store invokePrivate heartBeat() - } - } - - t1.start() - t2.start() - t3.start() - t1.join() - t2.join() - t3.join() - - store.dropFromMemory("a1", null) - store.dropFromMemory("a2", null) - store.waitForAsyncReregister() - } - } - - test("in-memory LRU storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY) - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") != None, "a3 was not in store") - assert(store.getSingle("a1") === None, "a1 was in store") - assert(store.getSingle("a2") != None, "a2 was not in store") - // At this point a2 was gotten last, so LRU will getSingle rid of a3 - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - assert(store.getSingle("a1") != None, "a1 was not in store") - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") === None, "a3 was in store") - } - - test("in-memory LRU storage with serialization") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_SER) - store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY_SER) - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") != None, "a3 was not in store") - assert(store.getSingle("a1") === None, "a1 was in store") - assert(store.getSingle("a2") != None, "a2 was not in store") - // At this point a2 was gotten last, so LRU will getSingle rid of a3 - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER) - assert(store.getSingle("a1") != None, "a1 was not in store") - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") === None, "a3 was in store") - } - - test("in-memory LRU for partitions of same RDD") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("rdd_0_1", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_2", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_3", a3, StorageLevel.MEMORY_ONLY) - // Even though we accessed rdd_0_3 last, it should not have replaced partitions 1 and 2 - // from the same RDD - assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") - assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store") - assert(store.getSingle("rdd_0_1") != None, "rdd_0_1 was not in store") - // Check that rdd_0_3 doesn't 
replace them even after further accesses - assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") - assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") - assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") - } - - test("in-memory LRU for partitions of multiple RDDs") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - store.putSingle("rdd_0_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_2", new Array[Byte](400), StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_1_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY) - // At this point rdd_1_1 should've replaced rdd_0_1 - assert(store.memoryStore.contains("rdd_1_1"), "rdd_1_1 was not in store") - assert(!store.memoryStore.contains("rdd_0_1"), "rdd_0_1 was in store") - assert(store.memoryStore.contains("rdd_0_2"), "rdd_0_2 was not in store") - // Do a get() on rdd_0_2 so that it is the most recently used item - assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store") - // Put in more partitions from RDD 0; they should replace rdd_1_1 - store.putSingle("rdd_0_3", new Array[Byte](400), StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_4", new Array[Byte](400), StorageLevel.MEMORY_ONLY) - // Now rdd_1_1 should be dropped to add rdd_0_3, but then rdd_0_2 should *not* be dropped - // when we try to add rdd_0_4. - assert(!store.memoryStore.contains("rdd_1_1"), "rdd_1_1 was in store") - assert(!store.memoryStore.contains("rdd_0_1"), "rdd_0_1 was in store") - assert(!store.memoryStore.contains("rdd_0_4"), "rdd_0_4 was in store") - assert(store.memoryStore.contains("rdd_0_2"), "rdd_0_2 was not in store") - assert(store.memoryStore.contains("rdd_0_3"), "rdd_0_3 was not in store") - } - - test("on-disk storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.DISK_ONLY) - store.putSingle("a2", a2, StorageLevel.DISK_ONLY) - store.putSingle("a3", a3, StorageLevel.DISK_ONLY) - assert(store.getSingle("a2") != None, "a2 was in store") - assert(store.getSingle("a3") != None, "a3 was in store") - assert(store.getSingle("a1") != None, "a1 was in store") - } - - test("disk and memory storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK) - store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK) - store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK) - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") != None, "a3 was not in store") - assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") - assert(store.getSingle("a1") != None, "a1 was not in store") - assert(store.memoryStore.getValues("a1") != None, "a1 was not in memory store") - } - - test("disk and memory storage with getLocalBytes") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK) - store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK) - store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK) - assert(store.getLocalBytes("a2") != None, "a2 was not in store") - assert(store.getLocalBytes("a3") != None, "a3 was not in store") 
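// A sketch of the MEMORY_AND_DISK behaviour these tests pin down, reusing the suite's
// actorSystem, master and serializer fields: a block squeezed out of a small memory store
// falls back to disk, and a normal read still succeeds.
val store = new BlockManager("", actorSystem, master, serializer, 1200)
store.putSingle("block-1", new Array[Byte](400), StorageLevel.MEMORY_AND_DISK)
store.putSingle("block-2", new Array[Byte](400), StorageLevel.MEMORY_AND_DISK)
store.putSingle("block-3", new Array[Byte](400), StorageLevel.MEMORY_AND_DISK)
// With only 1200 bytes of memory, block-1 should have been dropped from the memory store...
val inMemory = store.memoryStore.getValues("block-1") // expected: None
// ...but it is still readable, because the MEMORY_AND_DISK level kept a copy on disk.
val fetched  = store.getSingle("block-1")             // expected: Some(...)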
- assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") - assert(store.getLocalBytes("a1") != None, "a1 was not in store") - assert(store.memoryStore.getValues("a1") != None, "a1 was not in memory store") - } - - test("disk and memory storage with serialization") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK_SER) - store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK_SER) - store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK_SER) - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") != None, "a3 was not in store") - assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") - assert(store.getSingle("a1") != None, "a1 was not in store") - assert(store.memoryStore.getValues("a1") != None, "a1 was not in memory store") - } - - test("disk and memory storage with serialization and getLocalBytes") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK_SER) - store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK_SER) - store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK_SER) - assert(store.getLocalBytes("a2") != None, "a2 was not in store") - assert(store.getLocalBytes("a3") != None, "a3 was not in store") - assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") - assert(store.getLocalBytes("a1") != None, "a1 was not in store") - assert(store.memoryStore.getValues("a1") != None, "a1 was not in memory store") - } - - test("LRU with mixed storage levels") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - val a4 = new Array[Byte](400) - // First store a1 and a2, both in memory, and a3, on disk only - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_SER) - store.putSingle("a3", a3, StorageLevel.DISK_ONLY) - // At this point LRU should not kick in because a3 is only on disk - assert(store.getSingle("a1") != None, "a2 was not in store") - assert(store.getSingle("a2") != None, "a3 was not in store") - assert(store.getSingle("a3") != None, "a1 was not in store") - assert(store.getSingle("a1") != None, "a2 was not in store") - assert(store.getSingle("a2") != None, "a3 was not in store") - assert(store.getSingle("a3") != None, "a1 was not in store") - // Now let's add in a4, which uses both disk and memory; a1 should drop out - store.putSingle("a4", a4, StorageLevel.MEMORY_AND_DISK_SER) - assert(store.getSingle("a1") == None, "a1 was in store") - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") != None, "a3 was not in store") - assert(store.getSingle("a4") != None, "a4 was not in store") - } - - test("in-memory LRU with streams") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - val list1 = List(new Array[Byte](200), new Array[Byte](200)) - val list2 = List(new Array[Byte](200), new Array[Byte](200)) - val list3 = List(new Array[Byte](200), new Array[Byte](200)) - store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster 
= true) - store.put("list3", list3.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - assert(store.get("list2") != None, "list2 was not in store") - assert(store.get("list2").get.size == 2) - assert(store.get("list3") != None, "list3 was not in store") - assert(store.get("list3").get.size == 2) - assert(store.get("list1") === None, "list1 was in store") - assert(store.get("list2") != None, "list2 was not in store") - assert(store.get("list2").get.size == 2) - // At this point list2 was gotten last, so LRU will getSingle rid of list3 - store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) - assert(store.get("list1") != None, "list1 was not in store") - assert(store.get("list1").get.size == 2) - assert(store.get("list2") != None, "list2 was not in store") - assert(store.get("list2").get.size == 2) - assert(store.get("list3") === None, "list1 was in store") - } - - test("LRU with mixed storage levels and streams") { - store = new BlockManager("", actorSystem, master, serializer, 1200) - val list1 = List(new Array[Byte](200), new Array[Byte](200)) - val list2 = List(new Array[Byte](200), new Array[Byte](200)) - val list3 = List(new Array[Byte](200), new Array[Byte](200)) - val list4 = List(new Array[Byte](200), new Array[Byte](200)) - // First store list1 and list2, both in memory, and list3, on disk only - store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) - store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true) - store.put("list3", list3.iterator, StorageLevel.DISK_ONLY, tellMaster = true) - // At this point LRU should not kick in because list3 is only on disk - assert(store.get("list1") != None, "list2 was not in store") - assert(store.get("list1").get.size === 2) - assert(store.get("list2") != None, "list3 was not in store") - assert(store.get("list2").get.size === 2) - assert(store.get("list3") != None, "list1 was not in store") - assert(store.get("list3").get.size === 2) - assert(store.get("list1") != None, "list2 was not in store") - assert(store.get("list1").get.size === 2) - assert(store.get("list2") != None, "list3 was not in store") - assert(store.get("list2").get.size === 2) - assert(store.get("list3") != None, "list1 was not in store") - assert(store.get("list3").get.size === 2) - // Now let's add in list4, which uses both disk and memory; list1 should drop out - store.put("list4", list4.iterator, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true) - assert(store.get("list1") === None, "list1 was in store") - assert(store.get("list2") != None, "list3 was not in store") - assert(store.get("list2").get.size === 2) - assert(store.get("list3") != None, "list1 was not in store") - assert(store.get("list3").get.size === 2) - assert(store.get("list4") != None, "list4 was not in store") - assert(store.get("list4").get.size === 2) - } - - test("negative byte values in ByteBufferInputStream") { - val buffer = ByteBuffer.wrap(Array[Int](254, 255, 0, 1, 2).map(_.toByte).toArray) - val stream = new ByteBufferInputStream(buffer) - val temp = new Array[Byte](10) - assert(stream.read() === 254, "unexpected byte read") - assert(stream.read() === 255, "unexpected byte read") - assert(stream.read() === 0, "unexpected byte read") - assert(stream.read(temp, 0, temp.length) === 2, "unexpected number of bytes read") - assert(stream.read() === -1, "end of stream not signalled") - assert(stream.read(temp, 0, temp.length) === -1, "end of stream not signalled") - } - - test("overly large block") { - store = new 
BlockManager("", actorSystem, master, serializer, 500) - store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) - assert(store.getSingle("a1") === None, "a1 was in store") - store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK) - assert(store.memoryStore.getValues("a2") === None, "a2 was in memory store") - assert(store.getSingle("a2") != None, "a2 was not in store") - } - - test("block compression") { - try { - System.setProperty("spark.shuffle.compress", "true") - store = new BlockManager("exec1", actorSystem, master, serializer, 2000) - store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("shuffle_0_0_0") <= 100, "shuffle_0_0_0 was not compressed") - store.stop() - store = null - - System.setProperty("spark.shuffle.compress", "false") - store = new BlockManager("exec2", actorSystem, master, serializer, 2000) - store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("shuffle_0_0_0") >= 1000, "shuffle_0_0_0 was compressed") - store.stop() - store = null - - System.setProperty("spark.broadcast.compress", "true") - store = new BlockManager("exec3", actorSystem, master, serializer, 2000) - store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("broadcast_0") <= 100, "broadcast_0 was not compressed") - store.stop() - store = null - - System.setProperty("spark.broadcast.compress", "false") - store = new BlockManager("exec4", actorSystem, master, serializer, 2000) - store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("broadcast_0") >= 1000, "broadcast_0 was compressed") - store.stop() - store = null - - System.setProperty("spark.rdd.compress", "true") - store = new BlockManager("exec5", actorSystem, master, serializer, 2000) - store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("rdd_0_0") <= 100, "rdd_0_0 was not compressed") - store.stop() - store = null - - System.setProperty("spark.rdd.compress", "false") - store = new BlockManager("exec6", actorSystem, master, serializer, 2000) - store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("rdd_0_0") >= 1000, "rdd_0_0 was compressed") - store.stop() - store = null - - // Check that any other block types are also kept uncompressed - store = new BlockManager("exec7", actorSystem, master, serializer, 2000) - store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) - assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed") - store.stop() - store = null - } finally { - System.clearProperty("spark.shuffle.compress") - System.clearProperty("spark.broadcast.compress") - System.clearProperty("spark.rdd.compress") - } - } - - test("block store put failure") { - // Use Java serializer so we can create an unserializable error. - store = new BlockManager("", actorSystem, master, new JavaSerializer, 1200) - - // The put should fail since a1 is not serializable. - class UnserializableClass - val a1 = new UnserializableClass - intercept[java.io.NotSerializableException] { - store.putSingle("a1", a1, StorageLevel.DISK_ONLY) - } - - // Make sure get a1 doesn't hang and returns None. 
- failAfter(1 second) { - assert(store.getSingle("a1") == None, "a1 should not be in store") - } - } -} diff --git a/core/src/test/scala/spark/ui/UISuite.scala b/core/src/test/scala/spark/ui/UISuite.scala deleted file mode 100644 index 735a794396..0000000000 --- a/core/src/test/scala/spark/ui/UISuite.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.ui - -import scala.util.{Failure, Success, Try} -import java.net.ServerSocket -import org.scalatest.FunSuite -import org.eclipse.jetty.server.Server - -class UISuite extends FunSuite { - test("jetty port increases under contention") { - val startPort = 3030 - val server = new Server(startPort) - server.start() - val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("localhost", startPort, Seq()) - val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("localhost", startPort, Seq()) - - // Allow some wiggle room in case ports on the machine are under contention - assert(boundPort1 > startPort && boundPort1 < startPort + 10) - assert(boundPort2 > boundPort1 && boundPort2 < boundPort1 + 10) - } - - test("jetty binds to port 0 correctly") { - val (jettyServer, boundPort) = JettyUtils.startJettyServer("localhost", 0, Seq()) - assert(jettyServer.getState === "STARTED") - assert(boundPort != 0) - Try {new ServerSocket(boundPort)} match { - case Success(s) => fail("Port %s doesn't seem used by jetty server".format(boundPort)) - case Failure (e) => - } - } -} diff --git a/core/src/test/scala/spark/util/DistributionSuite.scala b/core/src/test/scala/spark/util/DistributionSuite.scala deleted file mode 100644 index 6578b55e82..0000000000 --- a/core/src/test/scala/spark/util/DistributionSuite.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.util - -import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers - -/** - * - */ - -class DistributionSuite extends FunSuite with ShouldMatchers { - test("summary") { - val d = new Distribution((1 to 100).toArray.map{_.toDouble}) - val stats = d.statCounter - stats.count should be (100) - stats.mean should be (50.5) - stats.sum should be (50 * 101) - - val quantiles = d.getQuantiles() - quantiles(0) should be (1) - quantiles(1) should be (26) - quantiles(2) should be (51) - quantiles(3) should be (76) - quantiles(4) should be (100) - } -} diff --git a/core/src/test/scala/spark/util/FakeClock.scala b/core/src/test/scala/spark/util/FakeClock.scala deleted file mode 100644 index 236706317e..0000000000 --- a/core/src/test/scala/spark/util/FakeClock.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.util - -class FakeClock extends Clock { - private var time = 0L - - def advance(millis: Long): Unit = time += millis - - def getTime(): Long = time -} diff --git a/core/src/test/scala/spark/util/NextIteratorSuite.scala b/core/src/test/scala/spark/util/NextIteratorSuite.scala deleted file mode 100644 index fdbd43d941..0000000000 --- a/core/src/test/scala/spark/util/NextIteratorSuite.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.util - -import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers -import scala.collection.mutable.Buffer -import java.util.NoSuchElementException - -class NextIteratorSuite extends FunSuite with ShouldMatchers { - test("one iteration") { - val i = new StubIterator(Buffer(1)) - i.hasNext should be === true - i.next should be === 1 - i.hasNext should be === false - intercept[NoSuchElementException] { i.next() } - } - - test("two iterations") { - val i = new StubIterator(Buffer(1, 2)) - i.hasNext should be === true - i.next should be === 1 - i.hasNext should be === true - i.next should be === 2 - i.hasNext should be === false - intercept[NoSuchElementException] { i.next() } - } - - test("empty iteration") { - val i = new StubIterator(Buffer()) - i.hasNext should be === false - intercept[NoSuchElementException] { i.next() } - } - - test("close is called once for empty iterations") { - val i = new StubIterator(Buffer()) - i.hasNext should be === false - i.hasNext should be === false - i.closeCalled should be === 1 - } - - test("close is called once for non-empty iterations") { - val i = new StubIterator(Buffer(1, 2)) - i.next should be === 1 - i.next should be === 2 - // close isn't called until we check for the next element - i.closeCalled should be === 0 - i.hasNext should be === false - i.closeCalled should be === 1 - i.hasNext should be === false - i.closeCalled should be === 1 - } - - class StubIterator(ints: Buffer[Int]) extends NextIterator[Int] { - var closeCalled = 0 - - override def getNext() = { - if (ints.size == 0) { - finished = true - 0 - } else { - ints.remove(0) - } - } - - override def close() { - closeCalled += 1 - } - } -} diff --git a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala deleted file mode 100644 index 4c0044202f..0000000000 --- a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.util - -import org.scalatest.FunSuite -import java.io.ByteArrayOutputStream -import java.util.concurrent.TimeUnit._ - -class RateLimitedOutputStreamSuite extends FunSuite { - - private def benchmark[U](f: => U): Long = { - val start = System.nanoTime - f - System.nanoTime - start - } - - test("write") { - val underlying = new ByteArrayOutputStream - val data = "X" * 41000 - val stream = new RateLimitedOutputStream(underlying, 10000) - val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) } - assert(SECONDS.convert(elapsedNs, NANOSECONDS) == 4) - assert(underlying.toString("UTF-8") == data) - } -} diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 84749fda4e..349eb92a47 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -100,7 +100,7 @@
  • Tuning Guide
  • Hardware Provisioning
  • Building Spark with Maven
  • - Contributing to Spark
  • + Contributing to Spark
  • diff --git a/examples/pom.xml b/examples/pom.xml index 687fbcca8f..13b5531511 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -19,13 +19,13 @@ 4.0.0 - org.spark-project + org.apache.spark spark-parent 0.8.0-SNAPSHOT ../pom.xml - org.spark-project + org.apache.spark spark-examples jar Spark Project Examples @@ -33,25 +33,25 @@ - org.spark-project + org.apache.spark spark-core ${project.version} provided - org.spark-project + org.apache.spark spark-streaming ${project.version} provided - org.spark-project + org.apache.spark spark-mllib ${project.version} provided - org.spark-project + org.apache.spark spark-bagel ${project.version} provided @@ -132,7 +132,7 @@ hadoop2-yarn - org.spark-project + org.apache.spark spark-yarn ${project.version} provided diff --git a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java new file mode 100644 index 0000000000..be0d38589c --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.StringTokenizer; +import java.util.Random; + +/** + * Logistic regression based classification. 
+ */ +public class JavaHdfsLR { + + static int D = 10; // Number of dimensions + static Random rand = new Random(42); + + static class DataPoint implements Serializable { + public DataPoint(double[] x, double y) { + this.x = x; + this.y = y; + } + + double[] x; + double y; + } + + static class ParsePoint extends Function { + public DataPoint call(String line) { + StringTokenizer tok = new StringTokenizer(line, " "); + double y = Double.parseDouble(tok.nextToken()); + double[] x = new double[D]; + int i = 0; + while (i < D) { + x[i] = Double.parseDouble(tok.nextToken()); + i += 1; + } + return new DataPoint(x, y); + } + } + + static class VectorSum extends Function2 { + public double[] call(double[] a, double[] b) { + double[] result = new double[D]; + for (int j = 0; j < D; j++) { + result[j] = a[j] + b[j]; + } + return result; + } + } + + static class ComputeGradient extends Function { + double[] weights; + + public ComputeGradient(double[] weights) { + this.weights = weights; + } + + public double[] call(DataPoint p) { + double[] gradient = new double[D]; + for (int i = 0; i < D; i++) { + double dot = dot(weights, p.x); + gradient[i] = (1 / (1 + Math.exp(-p.y * dot)) - 1) * p.y * p.x[i]; + } + return gradient; + } + } + + public static double dot(double[] a, double[] b) { + double x = 0; + for (int i = 0; i < D; i++) { + x += a[i] * b[i]; + } + return x; + } + + public static void printWeights(double[] a) { + System.out.println(Arrays.toString(a)); + } + + public static void main(String[] args) { + + if (args.length < 3) { + System.err.println("Usage: JavaHdfsLR "); + System.exit(1); + } + + JavaSparkContext sc = new JavaSparkContext(args[0], "JavaHdfsLR", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + JavaRDD lines = sc.textFile(args[1]); + JavaRDD points = lines.map(new ParsePoint()).cache(); + int ITERATIONS = Integer.parseInt(args[2]); + + // Initialize w to a random value + double[] w = new double[D]; + for (int i = 0; i < D; i++) { + w[i] = 2 * rand.nextDouble() - 1; + } + + System.out.print("Initial w: "); + printWeights(w); + + for (int i = 1; i <= ITERATIONS; i++) { + System.out.println("On iteration " + i); + + double[] gradient = points.map( + new ComputeGradient(w) + ).reduce(new VectorSum()); + + for (int j = 0; j < D; j++) { + w[j] -= gradient[j]; + } + + } + + System.out.print("Final w: "); + printWeights(w); + System.exit(0); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java b/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java new file mode 100644 index 0000000000..5a6afe7eae --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
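For reference, the per-point gradient that JavaHdfsLR's ComputeGradient and VectorSum accumulate, and the update the driver loop then applies, are a direct transcription of the standard logistic-regression gradient step (here x_i and y_i are the i-th point and label and w the weight vector; the example applies no explicit step size, exactly as in the code above):

    w \leftarrow w - \sum_i \left( \frac{1}{1 + e^{-y_i\, w^\top x_i}} - 1 \right) y_i\, x_i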
+ */ + +package org.apache.spark.examples; + +import scala.Tuple2; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.util.Vector; + +import java.util.List; +import java.util.Map; + +/** + * K-means clustering using Java API. + */ +public class JavaKMeans { + + /** Parses numbers split by whitespace to a vector */ + static Vector parseVector(String line) { + String[] splits = line.split(" "); + double[] data = new double[splits.length]; + int i = 0; + for (String s : splits) + data[i] = Double.parseDouble(splits[i++]); + return new Vector(data); + } + + /** Computes the vector to which the input vector is closest using squared distance */ + static int closestPoint(Vector p, List centers) { + int bestIndex = 0; + double closest = Double.POSITIVE_INFINITY; + for (int i = 0; i < centers.size(); i++) { + double tempDist = p.squaredDist(centers.get(i)); + if (tempDist < closest) { + closest = tempDist; + bestIndex = i; + } + } + return bestIndex; + } + + /** Computes the mean across all vectors in the input set of vectors */ + static Vector average(List ps) { + int numVectors = ps.size(); + Vector out = new Vector(ps.get(0).elements()); + // start from i = 1 since we already copied index 0 above + for (int i = 1; i < numVectors; i++) { + out.addInPlace(ps.get(i)); + } + return out.divide(numVectors); + } + + public static void main(String[] args) throws Exception { + if (args.length < 4) { + System.err.println("Usage: JavaKMeans "); + System.exit(1); + } + JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + String path = args[1]; + int K = Integer.parseInt(args[2]); + double convergeDist = Double.parseDouble(args[3]); + + JavaRDD data = sc.textFile(path).map( + new Function() { + @Override + public Vector call(String line) throws Exception { + return parseVector(line); + } + } + ).cache(); + + final List centroids = data.takeSample(false, K, 42); + + double tempDist; + do { + // allocate each vector to closest centroid + JavaPairRDD closest = data.map( + new PairFunction() { + @Override + public Tuple2 call(Vector vector) throws Exception { + return new Tuple2( + closestPoint(vector, centroids), vector); + } + } + ); + + // group by cluster id and average the vectors within each cluster to compute centroids + JavaPairRDD> pointsGroup = closest.groupByKey(); + Map newCentroids = pointsGroup.mapValues( + new Function, Vector>() { + public Vector call(List ps) throws Exception { + return average(ps); + } + }).collectAsMap(); + tempDist = 0.0; + for (int i = 0; i < K; i++) { + tempDist += centroids.get(i).squaredDist(newCentroids.get(i)); + } + for (Map.Entry t: newCentroids.entrySet()) { + centroids.set(t.getKey(), t.getValue()); + } + System.out.println("Finished iteration (delta = " + tempDist + ")"); + } while (tempDist > convergeDist); + + System.out.println("Final centers:"); + for (Vector c : centroids) + System.out.println(c); + + System.exit(0); + + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java new file mode 100644 index 0000000000..152f029213 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples; + +import com.google.common.collect.Lists; +import scala.Tuple2; +import scala.Tuple3; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; + +import java.io.Serializable; +import java.util.Collections; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Executes a roll up-style query against Apache logs. + */ +public class JavaLogQuery { + + public static List exampleApacheLogs = Lists.newArrayList( + "10.10.10.10 - \"FRED\" [18/Jan/2013:17:56:07 +1100] \"GET http://images.com/2013/Generic.jpg " + + "HTTP/1.1\" 304 315 \"http://referall.com/\" \"Mozilla/4.0 (compatible; MSIE 7.0; " + + "Windows NT 5.1; GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; " + + ".NET CLR 3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR " + + "3.5.30729; Release=ARP)\" \"UD-1\" - \"image/jpeg\" \"whatever\" 0.350 \"-\" - \"\" 265 923 934 \"\" " + + "62.24.11.25 images.com 1358492167 - Whatup", + "10.10.10.10 - \"FRED\" [18/Jan/2013:18:02:37 +1100] \"GET http://images.com/2013/Generic.jpg " + + "HTTP/1.1\" 304 306 \"http:/referall.com\" \"Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; " + + "GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; .NET CLR " + + "3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR " + + "3.5.30729; Release=ARP)\" \"UD-1\" - \"image/jpeg\" \"whatever\" 0.352 \"-\" - \"\" 256 977 988 \"\" " + + "0 73.23.2.15 images.com 1358492557 - Whatup"); + + public static Pattern apacheLogRegex = Pattern.compile( + "^([\\d.]+) (\\S+) (\\S+) \\[([\\w\\d:/]+\\s[+\\-]\\d{4})\\] \"(.+?)\" (\\d{3}) ([\\d\\-]+) \"([^\"]+)\" \"([^\"]+)\".*"); + + /** Tracks the total query count and number of aggregate bytes for a particular group. 
*/ + public static class Stats implements Serializable { + + private int count; + private int numBytes; + + public Stats(int count, int numBytes) { + this.count = count; + this.numBytes = numBytes; + } + public Stats merge(Stats other) { + return new Stats(count + other.count, numBytes + other.numBytes); + } + + public String toString() { + return String.format("bytes=%s\tn=%s", numBytes, count); + } + } + + public static Tuple3 extractKey(String line) { + Matcher m = apacheLogRegex.matcher(line); + List key = Collections.emptyList(); + if (m.find()) { + String ip = m.group(1); + String user = m.group(3); + String query = m.group(5); + if (!user.equalsIgnoreCase("-")) { + return new Tuple3(ip, user, query); + } + } + return new Tuple3(null, null, null); + } + + public static Stats extractStats(String line) { + Matcher m = apacheLogRegex.matcher(line); + if (m.find()) { + int bytes = Integer.parseInt(m.group(7)); + return new Stats(1, bytes); + } + else + return new Stats(1, 0); + } + + public static void main(String[] args) throws Exception { + if (args.length == 0) { + System.err.println("Usage: JavaLogQuery [logFile]"); + System.exit(1); + } + + JavaSparkContext jsc = new JavaSparkContext(args[0], "JavaLogQuery", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + + JavaRDD dataSet = (args.length == 2) ? jsc.textFile(args[1]) : jsc.parallelize(exampleApacheLogs); + + JavaPairRDD, Stats> extracted = dataSet.map(new PairFunction, Stats>() { + @Override + public Tuple2, Stats> call(String s) throws Exception { + return new Tuple2, Stats>(extractKey(s), extractStats(s)); + } + }); + + JavaPairRDD, Stats> counts = extracted.reduceByKey(new Function2() { + @Override + public Stats call(Stats stats, Stats stats2) throws Exception { + return stats.merge(stats2); + } + }); + + List, Stats>> output = counts.collect(); + for (Tuple2 t : output) { + System.out.println(t._1 + "\t" + t._2); + } + System.exit(0); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java new file mode 100644 index 0000000000..c5603a639b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples; + +import scala.Tuple2; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFlatMapFunction; +import org.apache.spark.api.java.function.PairFunction; + +import java.util.List; +import java.util.ArrayList; + +/** + * Computes the PageRank of URLs from an input file. Input file should + * be in format of: + * URL neighbor URL + * URL neighbor URL + * URL neighbor URL + * ... + * where URL and their neighbors are separated by space(s). + */ +public class JavaPageRank { + private static class Sum extends Function2 { + @Override + public Double call(Double a, Double b) { + return a + b; + } + } + + public static void main(String[] args) throws Exception { + if (args.length < 3) { + System.err.println("Usage: JavaPageRank "); + System.exit(1); + } + + JavaSparkContext ctx = new JavaSparkContext(args[0], "JavaPageRank", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + + // Loads in input file. It should be in format of: + // URL neighbor URL + // URL neighbor URL + // URL neighbor URL + // ... + JavaRDD lines = ctx.textFile(args[1], 1); + + // Loads all URLs from input file and initialize their neighbors. + JavaPairRDD> links = lines.map(new PairFunction() { + @Override + public Tuple2 call(String s) { + String[] parts = s.split("\\s+"); + return new Tuple2(parts[0], parts[1]); + } + }).distinct().groupByKey().cache(); + + // Loads all URLs with other URL(s) link to from input file and initialize ranks of them to one. + JavaPairRDD ranks = links.mapValues(new Function, Double>() { + @Override + public Double call(List rs) throws Exception { + return 1.0; + } + }); + + // Calculates and updates URL ranks continuously using PageRank algorithm. + for (int current = 0; current < Integer.parseInt(args[2]); current++) { + // Calculates URL contributions to the rank of other URLs. + JavaPairRDD contribs = links.join(ranks).values() + .flatMap(new PairFlatMapFunction, Double>, String, Double>() { + @Override + public Iterable> call(Tuple2, Double> s) { + List> results = new ArrayList>(); + for (String n : s._1) { + results.add(new Tuple2(n, s._2 / s._1.size())); + } + return results; + } + }); + + // Re-calculates URL ranks based on neighbor contributions. + ranks = contribs.reduceByKey(new Sum()).mapValues(new Function() { + @Override + public Double call(Double sum) throws Exception { + return 0.15 + sum * 0.85; + } + }); + } + + // Collects all URL ranks and dump them to console. + List> output = ranks.collect(); + for (Tuple2 tuple : output) { + System.out.println(tuple._1 + " has rank: " + tuple._2 + "."); + } + + System.exit(0); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java new file mode 100644 index 0000000000..4a2380caf5 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
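For reference, each pass of the loop in JavaPageRank above (the join, the flatMap over neighbors, reduceByKey(new Sum()) and the final mapValues) applies the damped PageRank update, where in(u) is the set of pages linking to u and out(v) the set of links leaving v:

    \mathrm{rank}(u) \leftarrow 0.15 + 0.85 \sum_{v \in \mathrm{in}(u)} \frac{\mathrm{rank}(v)}{|\mathrm{out}(v)|}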
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; + +import java.util.ArrayList; +import java.util.List; + +/** Computes an approximation to pi */ +public class JavaSparkPi { + + + public static void main(String[] args) throws Exception { + if (args.length == 0) { + System.err.println("Usage: JavaLogQuery [slices]"); + System.exit(1); + } + + JavaSparkContext jsc = new JavaSparkContext(args[0], "JavaLogQuery", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + + int slices = (args.length == 2) ? Integer.parseInt(args[1]) : 2; + int n = 100000 * slices; + List l = new ArrayList(n); + for (int i = 0; i < n; i++) + l.add(i); + + JavaRDD dataSet = jsc.parallelize(l, slices); + + int count = dataSet.map(new Function() { + @Override + public Integer call(Integer integer) throws Exception { + double x = Math.random() * 2 - 1; + double y = Math.random() * 2 - 1; + return (x * x + y * y < 1) ? 1 : 0; + } + }).reduce(new Function2() { + @Override + public Integer call(Integer integer, Integer integer2) throws Exception { + return integer + integer2; + } + }); + + System.out.println("Pi is roughly " + 4.0 * count / n); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/JavaTC.java b/examples/src/main/java/org/apache/spark/examples/JavaTC.java new file mode 100644 index 0000000000..17f21f6b77 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/JavaTC.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples; + +import scala.Tuple2; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.PairFunction; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Random; +import java.util.Set; + +/** + * Transitive closure on a graph, implemented in Java. 
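For reference, the estimate printed by JavaSparkPi above rests on the fact that a point drawn uniformly from the square [-1, 1]^2 falls inside the unit disc with probability \pi / 4, so with count hits out of n = 100000 \cdot slices samples:

    \pi \approx 4 \cdot \frac{\mathrm{count}}{n}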
+ */ +public class JavaTC { + + static int numEdges = 200; + static int numVertices = 100; + static Random rand = new Random(42); + + static List> generateGraph() { + Set> edges = new HashSet>(numEdges); + while (edges.size() < numEdges) { + int from = rand.nextInt(numVertices); + int to = rand.nextInt(numVertices); + Tuple2 e = new Tuple2(from, to); + if (from != to) edges.add(e); + } + return new ArrayList>(edges); + } + + static class ProjectFn extends PairFunction>, + Integer, Integer> { + static ProjectFn INSTANCE = new ProjectFn(); + + public Tuple2 call(Tuple2> triple) { + return new Tuple2(triple._2()._2(), triple._2()._1()); + } + } + + public static void main(String[] args) { + if (args.length == 0) { + System.err.println("Usage: JavaTC []"); + System.exit(1); + } + + JavaSparkContext sc = new JavaSparkContext(args[0], "JavaTC", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + Integer slices = (args.length > 1) ? Integer.parseInt(args[1]): 2; + JavaPairRDD tc = sc.parallelizePairs(generateGraph(), slices).cache(); + + // Linear transitive closure: each round grows paths by one edge, + // by joining the graph's edges with the already-discovered paths. + // e.g. join the path (y, z) from the TC with the edge (x, y) from + // the graph to obtain the path (x, z). + + // Because join() joins on keys, the edges are stored in reversed order. + JavaPairRDD edges = tc.map( + new PairFunction, Integer, Integer>() { + public Tuple2 call(Tuple2 e) { + return new Tuple2(e._2(), e._1()); + } + }); + + long oldCount = 0; + long nextCount = tc.count(); + do { + oldCount = nextCount; + // Perform the join, obtaining an RDD of (y, (z, x)) pairs, + // then project the result to obtain the new (x, z) paths. + tc = tc.union(tc.join(edges).map(ProjectFn.INSTANCE)).distinct().cache(); + nextCount = tc.count(); + } while (nextCount != oldCount); + + System.out.println("TC has " + tc.count() + " edges."); + System.exit(0); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java new file mode 100644 index 0000000000..07d32ad659 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
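For reference, the do/while loop in JavaTC above is a straightforward fixpoint iteration over the edge set E: the reversed-key edges RDD plus ProjectFn implement one join-and-project step, and new paths are unioned in until tc.count() stops changing, i.e. until the fixpoint of

    TC_{k+1} = TC_k \cup \{ (x, z) : (x, y) \in E,\; (y, z) \in TC_k \}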
+ */ + +package org.apache.spark.examples; + +import scala.Tuple2; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; + +import java.util.Arrays; +import java.util.List; + +public class JavaWordCount { + public static void main(String[] args) throws Exception { + if (args.length < 2) { + System.err.println("Usage: JavaWordCount "); + System.exit(1); + } + + JavaSparkContext ctx = new JavaSparkContext(args[0], "JavaWordCount", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + JavaRDD lines = ctx.textFile(args[1], 1); + + JavaRDD words = lines.flatMap(new FlatMapFunction() { + public Iterable call(String s) { + return Arrays.asList(s.split(" ")); + } + }); + + JavaPairRDD ones = words.map(new PairFunction() { + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } + }); + + JavaPairRDD counts = ones.reduceByKey(new Function2() { + public Integer call(Integer i1, Integer i2) { + return i1 + i2; + } + }); + + List> output = counts.collect(); + for (Tuple2 tuple : output) { + System.out.println(tuple._1 + ": " + tuple._2); + } + System.exit(0); + } +} diff --git a/examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java b/examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java new file mode 100644 index 0000000000..628cb892b6 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.examples; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; + +import org.apache.spark.mllib.recommendation.ALS; +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; +import org.apache.spark.mllib.recommendation.Rating; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.StringTokenizer; + +import scala.Tuple2; + +/** + * Example using MLLib ALS from Java. 
+ */ +public class JavaALS { + + static class ParseRating extends Function { + public Rating call(String line) { + StringTokenizer tok = new StringTokenizer(line, ","); + int x = Integer.parseInt(tok.nextToken()); + int y = Integer.parseInt(tok.nextToken()); + double rating = Double.parseDouble(tok.nextToken()); + return new Rating(x, y, rating); + } + } + + static class FeaturesToString extends Function, String> { + public String call(Tuple2 element) { + return element._1().toString() + "," + Arrays.toString(element._2()); + } + } + + public static void main(String[] args) { + + if (args.length != 5 && args.length != 6) { + System.err.println( + "Usage: JavaALS []"); + System.exit(1); + } + + int rank = Integer.parseInt(args[2]); + int iterations = Integer.parseInt(args[3]); + String outputDir = args[4]; + int blocks = -1; + if (args.length == 6) { + blocks = Integer.parseInt(args[5]); + } + + JavaSparkContext sc = new JavaSparkContext(args[0], "JavaALS", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + JavaRDD lines = sc.textFile(args[1]); + + JavaRDD ratings = lines.map(new ParseRating()); + + MatrixFactorizationModel model = ALS.train(ratings.rdd(), rank, iterations, 0.01, blocks); + + model.userFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile( + outputDir + "/userFeatures"); + model.productFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile( + outputDir + "/productFeatures"); + System.out.println("Final user/product features written to " + outputDir); + + System.exit(0); + } +} diff --git a/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java b/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java new file mode 100644 index 0000000000..cd59a139b9 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.examples; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; + +import org.apache.spark.mllib.clustering.KMeans; +import org.apache.spark.mllib.clustering.KMeansModel; + +import java.util.Arrays; +import java.util.StringTokenizer; + +/** + * Example using MLLib KMeans from Java. 
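For context (background on the model rather than something the JavaALS example above shows): a MatrixFactorizationModel trained by ALS predicts a rating as the dot product of the learned rank-dimensional user and product feature vectors, i.e. the vectors this example writes out to userFeatures and productFeatures:

    \hat{r}_{u,p} \approx x_u^\top y_p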
+ */ +public class JavaKMeans { + + static class ParsePoint extends Function { + public double[] call(String line) { + StringTokenizer tok = new StringTokenizer(line, " "); + int numTokens = tok.countTokens(); + double[] point = new double[numTokens]; + for (int i = 0; i < numTokens; ++i) { + point[i] = Double.parseDouble(tok.nextToken()); + } + return point; + } + } + + public static void main(String[] args) { + + if (args.length < 4) { + System.err.println( + "Usage: JavaKMeans []"); + System.exit(1); + } + + String inputFile = args[1]; + int k = Integer.parseInt(args[2]); + int iterations = Integer.parseInt(args[3]); + int runs = 1; + + if (args.length >= 5) { + runs = Integer.parseInt(args[4]); + } + + JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + JavaRDD lines = sc.textFile(args[1]); + + JavaRDD points = lines.map(new ParsePoint()); + + KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs); + + System.out.println("Cluster centers:"); + for (double[] center : model.clusterCenters()) { + System.out.println(" " + Arrays.toString(center)); + } + double cost = model.computeCost(points.rdd()); + System.out.println("Cost: " + cost); + + System.exit(0); + } +} diff --git a/examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java b/examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java new file mode 100644 index 0000000000..258061c8e6 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.examples; + + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; + +import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.regression.LabeledPoint; + +import java.util.Arrays; +import java.util.StringTokenizer; + +/** + * Logistic regression based classification using ML Lib. 
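For context, the value reported by model.computeCost in the JavaKMeans example above is the within-set sum of squared errors over the training points, i.e. each point's squared Euclidean distance to its nearest trained center c_j (the standard k-means objective, stated here as background rather than taken from the patch itself):

    \mathrm{cost} = \sum_{x \in \mathrm{points}} \min_j \lVert x - c_j \rVert^2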
+ */ +public class JavaLR { + + static class ParsePoint extends Function { + public LabeledPoint call(String line) { + String[] parts = line.split(","); + double y = Double.parseDouble(parts[0]); + StringTokenizer tok = new StringTokenizer(parts[1], " "); + int numTokens = tok.countTokens(); + double[] x = new double[numTokens]; + for (int i = 0; i < numTokens; ++i) { + x[i] = Double.parseDouble(tok.nextToken()); + } + return new LabeledPoint(y, x); + } + } + + public static void printWeights(double[] a) { + System.out.println(Arrays.toString(a)); + } + + public static void main(String[] args) { + if (args.length != 4) { + System.err.println("Usage: JavaLR "); + System.exit(1); + } + + JavaSparkContext sc = new JavaSparkContext(args[0], "JavaLR", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + JavaRDD lines = sc.textFile(args[1]); + JavaRDD points = lines.map(new ParsePoint()).cache(); + double stepSize = Double.parseDouble(args[2]); + int iterations = Integer.parseInt(args[3]); + + // Another way to configure LogisticRegression + // + // LogisticRegressionWithSGD lr = new LogisticRegressionWithSGD(); + // lr.optimizer().setNumIterations(iterations) + // .setStepSize(stepSize) + // .setMiniBatchFraction(1.0); + // lr.setIntercept(true); + // LogisticRegressionModel model = lr.train(points.rdd()); + + LogisticRegressionModel model = LogisticRegressionWithSGD.train(points.rdd(), + iterations, stepSize); + + System.out.print("Final w: "); + printWeights(model.weights()); + + System.exit(0); + } +} diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaFlumeEventCount.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaFlumeEventCount.java new file mode 100644 index 0000000000..261813bf2f --- /dev/null +++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaFlumeEventCount.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples; + +import org.apache.spark.api.java.function.Function; +import org.apache.spark.streaming.*; +import org.apache.spark.streaming.api.java.*; +import org.apache.spark.streaming.dstream.SparkFlumeEvent; + +/** + * Produces a count of events received from Flume. + * + * This should be used in conjunction with an AvroSink in Flume. It will start + * an Avro server on at the request host:port address and listen for requests. + * Your Flume AvroSink should be pointed to this address. + * + * Usage: JavaFlumeEventCount + * + * is a Spark master URL + * is the host the Flume receiver will be started on - a receiver + * creates a server and listens for flume events. + * is the port the Flume receiver will listen on. 
+ */ +public class JavaFlumeEventCount { + public static void main(String[] args) { + if (args.length != 3) { + System.err.println("Usage: JavaFlumeEventCount "); + System.exit(1); + } + + String master = args[0]; + String host = args[1]; + int port = Integer.parseInt(args[2]); + + Duration batchInterval = new Duration(2000); + + JavaStreamingContext sc = new JavaStreamingContext(master, "FlumeEventCount", batchInterval, + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + + JavaDStream flumeStream = sc.flumeStream("localhost", port); + + flumeStream.count(); + + flumeStream.count().map(new Function() { + @Override + public String call(Long in) { + return "Received " + in + " flume events."; + } + }).print(); + + sc.start(); + } +} diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java new file mode 100644 index 0000000000..def87c199b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples; + +import com.google.common.collect.Lists; +import scala.Tuple2; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + * Usage: NetworkWordCount + * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * and describe the TCP server that Spark Streaming would connect to receive data. + * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ ./run spark.streaming.examples.JavaNetworkWordCount local[2] localhost 9999` + */ +public class JavaNetworkWordCount { + public static void main(String[] args) { + if (args.length < 3) { + System.err.println("Usage: NetworkWordCount \n" + + "In local mode, should be 'local[n]' with n > 1"); + System.exit(1); + } + + // Create the context with a 1 second batch size + JavaStreamingContext ssc = new JavaStreamingContext(args[0], "NetworkWordCount", + new Duration(1000), System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + + // Create a NetworkInputDStream on target ip:port and count the + // words in input stream of \n delimited test (eg. 
generated by 'nc') + JavaDStream lines = ssc.socketTextStream(args[1], Integer.parseInt(args[2])); + JavaDStream words = lines.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Lists.newArrayList(x.split(" ")); + } + }); + JavaPairDStream wordCounts = words.map( + new PairFunction() { + @Override + public Tuple2 call(String s) throws Exception { + return new Tuple2(s, 1); + } + }).reduceByKey(new Function2() { + @Override + public Integer call(Integer i1, Integer i2) throws Exception { + return i1 + i2; + } + }); + + wordCounts.print(); + ssc.start(); + + } +} diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaQueueStream.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaQueueStream.java new file mode 100644 index 0000000000..c8c7389dd1 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaQueueStream.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples; + +import com.google.common.collect.Lists; +import scala.Tuple2; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +import java.util.LinkedList; +import java.util.List; +import java.util.Queue; + +public class JavaQueueStream { + public static void main(String[] args) throws InterruptedException { + if (args.length < 1) { + System.err.println("Usage: JavaQueueStream "); + System.exit(1); + } + + // Create the context + JavaStreamingContext ssc = new JavaStreamingContext(args[0], "QueueStream", new Duration(1000), + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + + // Create the queue through which RDDs can be pushed to + // a QueueInputDStream + Queue> rddQueue = new LinkedList>(); + + // Create and push some RDDs into the queue + List list = Lists.newArrayList(); + for (int i = 0; i < 1000; i++) { + list.add(i); + } + + for (int i = 0; i < 30; i++) { + rddQueue.add(ssc.sc().parallelize(list)); + } + + + // Create the QueueInputDStream and use it do some processing + JavaDStream inputStream = ssc.queueStream(rddQueue); + JavaPairDStream mappedStream = inputStream.map( + new PairFunction() { + @Override + public Tuple2 call(Integer i) throws Exception { + return new Tuple2(i % 10, 1); + } + }); + JavaPairDStream reducedStream = mappedStream.reduceByKey( + new Function2() { + @Override + public Integer call(Integer i1, Integer i2) throws Exception { + return i1 + i2; + } 
+ }); + + reducedStream.print(); + ssc.start(); + } +} diff --git a/examples/src/main/java/spark/examples/JavaHdfsLR.java b/examples/src/main/java/spark/examples/JavaHdfsLR.java deleted file mode 100644 index 9485e0cfa9..0000000000 --- a/examples/src/main/java/spark/examples/JavaHdfsLR.java +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples; - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.Function; -import spark.api.java.function.Function2; - -import java.io.Serializable; -import java.util.Arrays; -import java.util.StringTokenizer; -import java.util.Random; - -/** - * Logistic regression based classification. - */ -public class JavaHdfsLR { - - static int D = 10; // Number of dimensions - static Random rand = new Random(42); - - static class DataPoint implements Serializable { - public DataPoint(double[] x, double y) { - this.x = x; - this.y = y; - } - - double[] x; - double y; - } - - static class ParsePoint extends Function { - public DataPoint call(String line) { - StringTokenizer tok = new StringTokenizer(line, " "); - double y = Double.parseDouble(tok.nextToken()); - double[] x = new double[D]; - int i = 0; - while (i < D) { - x[i] = Double.parseDouble(tok.nextToken()); - i += 1; - } - return new DataPoint(x, y); - } - } - - static class VectorSum extends Function2 { - public double[] call(double[] a, double[] b) { - double[] result = new double[D]; - for (int j = 0; j < D; j++) { - result[j] = a[j] + b[j]; - } - return result; - } - } - - static class ComputeGradient extends Function { - double[] weights; - - public ComputeGradient(double[] weights) { - this.weights = weights; - } - - public double[] call(DataPoint p) { - double[] gradient = new double[D]; - for (int i = 0; i < D; i++) { - double dot = dot(weights, p.x); - gradient[i] = (1 / (1 + Math.exp(-p.y * dot)) - 1) * p.y * p.x[i]; - } - return gradient; - } - } - - public static double dot(double[] a, double[] b) { - double x = 0; - for (int i = 0; i < D; i++) { - x += a[i] * b[i]; - } - return x; - } - - public static void printWeights(double[] a) { - System.out.println(Arrays.toString(a)); - } - - public static void main(String[] args) { - - if (args.length < 3) { - System.err.println("Usage: JavaHdfsLR "); - System.exit(1); - } - - JavaSparkContext sc = new JavaSparkContext(args[0], "JavaHdfsLR", - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - JavaRDD lines = sc.textFile(args[1]); - JavaRDD points = lines.map(new ParsePoint()).cache(); - int ITERATIONS = Integer.parseInt(args[2]); - - // Initialize w to a random value - double[] w = new double[D]; - for (int i = 0; i < D; i++) { - w[i] = 2 * rand.nextDouble() - 1; - } - - 
System.out.print("Initial w: "); - printWeights(w); - - for (int i = 1; i <= ITERATIONS; i++) { - System.out.println("On iteration " + i); - - double[] gradient = points.map( - new ComputeGradient(w) - ).reduce(new VectorSum()); - - for (int j = 0; j < D; j++) { - w[j] -= gradient[j]; - } - - } - - System.out.print("Final w: "); - printWeights(w); - System.exit(0); - } -} diff --git a/examples/src/main/java/spark/examples/JavaKMeans.java b/examples/src/main/java/spark/examples/JavaKMeans.java deleted file mode 100644 index 2d34776177..0000000000 --- a/examples/src/main/java/spark/examples/JavaKMeans.java +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples; - -import scala.Tuple2; -import spark.api.java.JavaPairRDD; -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.Function; -import spark.api.java.function.PairFunction; -import spark.util.Vector; - -import java.util.List; -import java.util.Map; - -/** - * K-means clustering using Java API. 
- */ -public class JavaKMeans { - - /** Parses numbers split by whitespace to a vector */ - static Vector parseVector(String line) { - String[] splits = line.split(" "); - double[] data = new double[splits.length]; - int i = 0; - for (String s : splits) - data[i] = Double.parseDouble(splits[i++]); - return new Vector(data); - } - - /** Computes the vector to which the input vector is closest using squared distance */ - static int closestPoint(Vector p, List centers) { - int bestIndex = 0; - double closest = Double.POSITIVE_INFINITY; - for (int i = 0; i < centers.size(); i++) { - double tempDist = p.squaredDist(centers.get(i)); - if (tempDist < closest) { - closest = tempDist; - bestIndex = i; - } - } - return bestIndex; - } - - /** Computes the mean across all vectors in the input set of vectors */ - static Vector average(List ps) { - int numVectors = ps.size(); - Vector out = new Vector(ps.get(0).elements()); - // start from i = 1 since we already copied index 0 above - for (int i = 1; i < numVectors; i++) { - out.addInPlace(ps.get(i)); - } - return out.divide(numVectors); - } - - public static void main(String[] args) throws Exception { - if (args.length < 4) { - System.err.println("Usage: JavaKMeans "); - System.exit(1); - } - JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans", - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - String path = args[1]; - int K = Integer.parseInt(args[2]); - double convergeDist = Double.parseDouble(args[3]); - - JavaRDD data = sc.textFile(path).map( - new Function() { - @Override - public Vector call(String line) throws Exception { - return parseVector(line); - } - } - ).cache(); - - final List centroids = data.takeSample(false, K, 42); - - double tempDist; - do { - // allocate each vector to closest centroid - JavaPairRDD closest = data.map( - new PairFunction() { - @Override - public Tuple2 call(Vector vector) throws Exception { - return new Tuple2( - closestPoint(vector, centroids), vector); - } - } - ); - - // group by cluster id and average the vectors within each cluster to compute centroids - JavaPairRDD> pointsGroup = closest.groupByKey(); - Map newCentroids = pointsGroup.mapValues( - new Function, Vector>() { - public Vector call(List ps) throws Exception { - return average(ps); - } - }).collectAsMap(); - tempDist = 0.0; - for (int i = 0; i < K; i++) { - tempDist += centroids.get(i).squaredDist(newCentroids.get(i)); - } - for (Map.Entry t: newCentroids.entrySet()) { - centroids.set(t.getKey(), t.getValue()); - } - System.out.println("Finished iteration (delta = " + tempDist + ")"); - } while (tempDist > convergeDist); - - System.out.println("Final centers:"); - for (Vector c : centroids) - System.out.println(c); - - System.exit(0); - - } -} diff --git a/examples/src/main/java/spark/examples/JavaLogQuery.java b/examples/src/main/java/spark/examples/JavaLogQuery.java deleted file mode 100644 index d22684d980..0000000000 --- a/examples/src/main/java/spark/examples/JavaLogQuery.java +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples; - -import com.google.common.collect.Lists; -import scala.Tuple2; -import scala.Tuple3; -import spark.api.java.JavaPairRDD; -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.Function2; -import spark.api.java.function.PairFunction; - -import java.io.Serializable; -import java.util.Collections; -import java.util.List; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -/** - * Executes a roll up-style query against Apache logs. - */ -public class JavaLogQuery { - - public static List exampleApacheLogs = Lists.newArrayList( - "10.10.10.10 - \"FRED\" [18/Jan/2013:17:56:07 +1100] \"GET http://images.com/2013/Generic.jpg " + - "HTTP/1.1\" 304 315 \"http://referall.com/\" \"Mozilla/4.0 (compatible; MSIE 7.0; " + - "Windows NT 5.1; GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; " + - ".NET CLR 3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR " + - "3.5.30729; Release=ARP)\" \"UD-1\" - \"image/jpeg\" \"whatever\" 0.350 \"-\" - \"\" 265 923 934 \"\" " + - "62.24.11.25 images.com 1358492167 - Whatup", - "10.10.10.10 - \"FRED\" [18/Jan/2013:18:02:37 +1100] \"GET http://images.com/2013/Generic.jpg " + - "HTTP/1.1\" 304 306 \"http:/referall.com\" \"Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; " + - "GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; .NET CLR " + - "3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR " + - "3.5.30729; Release=ARP)\" \"UD-1\" - \"image/jpeg\" \"whatever\" 0.352 \"-\" - \"\" 256 977 988 \"\" " + - "0 73.23.2.15 images.com 1358492557 - Whatup"); - - public static Pattern apacheLogRegex = Pattern.compile( - "^([\\d.]+) (\\S+) (\\S+) \\[([\\w\\d:/]+\\s[+\\-]\\d{4})\\] \"(.+?)\" (\\d{3}) ([\\d\\-]+) \"([^\"]+)\" \"([^\"]+)\".*"); - - /** Tracks the total query count and number of aggregate bytes for a particular group. 
*/ - public static class Stats implements Serializable { - - private int count; - private int numBytes; - - public Stats(int count, int numBytes) { - this.count = count; - this.numBytes = numBytes; - } - public Stats merge(Stats other) { - return new Stats(count + other.count, numBytes + other.numBytes); - } - - public String toString() { - return String.format("bytes=%s\tn=%s", numBytes, count); - } - } - - public static Tuple3 extractKey(String line) { - Matcher m = apacheLogRegex.matcher(line); - List key = Collections.emptyList(); - if (m.find()) { - String ip = m.group(1); - String user = m.group(3); - String query = m.group(5); - if (!user.equalsIgnoreCase("-")) { - return new Tuple3(ip, user, query); - } - } - return new Tuple3(null, null, null); - } - - public static Stats extractStats(String line) { - Matcher m = apacheLogRegex.matcher(line); - if (m.find()) { - int bytes = Integer.parseInt(m.group(7)); - return new Stats(1, bytes); - } - else - return new Stats(1, 0); - } - - public static void main(String[] args) throws Exception { - if (args.length == 0) { - System.err.println("Usage: JavaLogQuery [logFile]"); - System.exit(1); - } - - JavaSparkContext jsc = new JavaSparkContext(args[0], "JavaLogQuery", - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - - JavaRDD dataSet = (args.length == 2) ? jsc.textFile(args[1]) : jsc.parallelize(exampleApacheLogs); - - JavaPairRDD, Stats> extracted = dataSet.map(new PairFunction, Stats>() { - @Override - public Tuple2, Stats> call(String s) throws Exception { - return new Tuple2, Stats>(extractKey(s), extractStats(s)); - } - }); - - JavaPairRDD, Stats> counts = extracted.reduceByKey(new Function2() { - @Override - public Stats call(Stats stats, Stats stats2) throws Exception { - return stats.merge(stats2); - } - }); - - List, Stats>> output = counts.collect(); - for (Tuple2 t : output) { - System.out.println(t._1 + "\t" + t._2); - } - System.exit(0); - } -} diff --git a/examples/src/main/java/spark/examples/JavaPageRank.java b/examples/src/main/java/spark/examples/JavaPageRank.java deleted file mode 100644 index 75df1af2e3..0000000000 --- a/examples/src/main/java/spark/examples/JavaPageRank.java +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.examples; - -import scala.Tuple2; -import spark.api.java.JavaPairRDD; -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.FlatMapFunction; -import spark.api.java.function.Function; -import spark.api.java.function.Function2; -import spark.api.java.function.PairFlatMapFunction; -import spark.api.java.function.PairFunction; - -import java.util.List; -import java.util.ArrayList; - -/** - * Computes the PageRank of URLs from an input file. Input file should - * be in format of: - * URL neighbor URL - * URL neighbor URL - * URL neighbor URL - * ... - * where URL and their neighbors are separated by space(s). - */ -public class JavaPageRank { - private static class Sum extends Function2 { - @Override - public Double call(Double a, Double b) { - return a + b; - } - } - - public static void main(String[] args) throws Exception { - if (args.length < 3) { - System.err.println("Usage: JavaPageRank "); - System.exit(1); - } - - JavaSparkContext ctx = new JavaSparkContext(args[0], "JavaPageRank", - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - - // Loads in input file. It should be in format of: - // URL neighbor URL - // URL neighbor URL - // URL neighbor URL - // ... - JavaRDD lines = ctx.textFile(args[1], 1); - - // Loads all URLs from input file and initialize their neighbors. - JavaPairRDD> links = lines.map(new PairFunction() { - @Override - public Tuple2 call(String s) { - String[] parts = s.split("\\s+"); - return new Tuple2(parts[0], parts[1]); - } - }).distinct().groupByKey().cache(); - - // Loads all URLs with other URL(s) link to from input file and initialize ranks of them to one. - JavaPairRDD ranks = links.mapValues(new Function, Double>() { - @Override - public Double call(List rs) throws Exception { - return 1.0; - } - }); - - // Calculates and updates URL ranks continuously using PageRank algorithm. - for (int current = 0; current < Integer.parseInt(args[2]); current++) { - // Calculates URL contributions to the rank of other URLs. - JavaPairRDD contribs = links.join(ranks).values() - .flatMap(new PairFlatMapFunction, Double>, String, Double>() { - @Override - public Iterable> call(Tuple2, Double> s) { - List> results = new ArrayList>(); - for (String n : s._1) { - results.add(new Tuple2(n, s._2 / s._1.size())); - } - return results; - } - }); - - // Re-calculates URL ranks based on neighbor contributions. - ranks = contribs.reduceByKey(new Sum()).mapValues(new Function() { - @Override - public Double call(Double sum) throws Exception { - return 0.15 + sum * 0.85; - } - }); - } - - // Collects all URL ranks and dump them to console. - List> output = ranks.collect(); - for (Tuple2 tuple : output) { - System.out.println(tuple._1 + " has rank: " + tuple._2 + "."); - } - - System.exit(0); - } -} diff --git a/examples/src/main/java/spark/examples/JavaSparkPi.java b/examples/src/main/java/spark/examples/JavaSparkPi.java deleted file mode 100644 index d5f42fbb38..0000000000 --- a/examples/src/main/java/spark/examples/JavaSparkPi.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples; - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.Function; -import spark.api.java.function.Function2; - -import java.util.ArrayList; -import java.util.List; - -/** Computes an approximation to pi */ -public class JavaSparkPi { - - - public static void main(String[] args) throws Exception { - if (args.length == 0) { - System.err.println("Usage: JavaLogQuery [slices]"); - System.exit(1); - } - - JavaSparkContext jsc = new JavaSparkContext(args[0], "JavaLogQuery", - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - - int slices = (args.length == 2) ? Integer.parseInt(args[1]) : 2; - int n = 100000 * slices; - List l = new ArrayList(n); - for (int i = 0; i < n; i++) - l.add(i); - - JavaRDD dataSet = jsc.parallelize(l, slices); - - int count = dataSet.map(new Function() { - @Override - public Integer call(Integer integer) throws Exception { - double x = Math.random() * 2 - 1; - double y = Math.random() * 2 - 1; - return (x * x + y * y < 1) ? 1 : 0; - } - }).reduce(new Function2() { - @Override - public Integer call(Integer integer, Integer integer2) throws Exception { - return integer + integer2; - } - }); - - System.out.println("Pi is roughly " + 4.0 * count / n); - } -} diff --git a/examples/src/main/java/spark/examples/JavaTC.java b/examples/src/main/java/spark/examples/JavaTC.java deleted file mode 100644 index 559d7f9e53..0000000000 --- a/examples/src/main/java/spark/examples/JavaTC.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples; - -import scala.Tuple2; -import spark.api.java.JavaPairRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.PairFunction; - -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Random; -import java.util.Set; - -/** - * Transitive closure on a graph, implemented in Java. 
- */ -public class JavaTC { - - static int numEdges = 200; - static int numVertices = 100; - static Random rand = new Random(42); - - static List> generateGraph() { - Set> edges = new HashSet>(numEdges); - while (edges.size() < numEdges) { - int from = rand.nextInt(numVertices); - int to = rand.nextInt(numVertices); - Tuple2 e = new Tuple2(from, to); - if (from != to) edges.add(e); - } - return new ArrayList>(edges); - } - - static class ProjectFn extends PairFunction>, - Integer, Integer> { - static ProjectFn INSTANCE = new ProjectFn(); - - public Tuple2 call(Tuple2> triple) { - return new Tuple2(triple._2()._2(), triple._2()._1()); - } - } - - public static void main(String[] args) { - if (args.length == 0) { - System.err.println("Usage: JavaTC []"); - System.exit(1); - } - - JavaSparkContext sc = new JavaSparkContext(args[0], "JavaTC", - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - Integer slices = (args.length > 1) ? Integer.parseInt(args[1]): 2; - JavaPairRDD tc = sc.parallelizePairs(generateGraph(), slices).cache(); - - // Linear transitive closure: each round grows paths by one edge, - // by joining the graph's edges with the already-discovered paths. - // e.g. join the path (y, z) from the TC with the edge (x, y) from - // the graph to obtain the path (x, z). - - // Because join() joins on keys, the edges are stored in reversed order. - JavaPairRDD edges = tc.map( - new PairFunction, Integer, Integer>() { - public Tuple2 call(Tuple2 e) { - return new Tuple2(e._2(), e._1()); - } - }); - - long oldCount = 0; - long nextCount = tc.count(); - do { - oldCount = nextCount; - // Perform the join, obtaining an RDD of (y, (z, x)) pairs, - // then project the result to obtain the new (x, z) paths. - tc = tc.union(tc.join(edges).map(ProjectFn.INSTANCE)).distinct().cache(); - nextCount = tc.count(); - } while (nextCount != oldCount); - - System.out.println("TC has " + tc.count() + " edges."); - System.exit(0); - } -} diff --git a/examples/src/main/java/spark/examples/JavaWordCount.java b/examples/src/main/java/spark/examples/JavaWordCount.java deleted file mode 100644 index 1af370c1c3..0000000000 --- a/examples/src/main/java/spark/examples/JavaWordCount.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.examples; - -import scala.Tuple2; -import spark.api.java.JavaPairRDD; -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.FlatMapFunction; -import spark.api.java.function.Function2; -import spark.api.java.function.PairFunction; - -import java.util.Arrays; -import java.util.List; - -public class JavaWordCount { - public static void main(String[] args) throws Exception { - if (args.length < 2) { - System.err.println("Usage: JavaWordCount "); - System.exit(1); - } - - JavaSparkContext ctx = new JavaSparkContext(args[0], "JavaWordCount", - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - JavaRDD lines = ctx.textFile(args[1], 1); - - JavaRDD words = lines.flatMap(new FlatMapFunction() { - public Iterable call(String s) { - return Arrays.asList(s.split(" ")); - } - }); - - JavaPairRDD ones = words.map(new PairFunction() { - public Tuple2 call(String s) { - return new Tuple2(s, 1); - } - }); - - JavaPairRDD counts = ones.reduceByKey(new Function2() { - public Integer call(Integer i1, Integer i2) { - return i1 + i2; - } - }); - - List> output = counts.collect(); - for (Tuple2 tuple : output) { - System.out.println(tuple._1 + ": " + tuple._2); - } - System.exit(0); - } -} diff --git a/examples/src/main/java/spark/mllib/examples/JavaALS.java b/examples/src/main/java/spark/mllib/examples/JavaALS.java deleted file mode 100644 index b48f459cb7..0000000000 --- a/examples/src/main/java/spark/mllib/examples/JavaALS.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.examples; - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.Function; - -import spark.mllib.recommendation.ALS; -import spark.mllib.recommendation.MatrixFactorizationModel; -import spark.mllib.recommendation.Rating; - -import java.io.Serializable; -import java.util.Arrays; -import java.util.StringTokenizer; - -import scala.Tuple2; - -/** - * Example using MLLib ALS from Java. 
- */ -public class JavaALS { - - static class ParseRating extends Function { - public Rating call(String line) { - StringTokenizer tok = new StringTokenizer(line, ","); - int x = Integer.parseInt(tok.nextToken()); - int y = Integer.parseInt(tok.nextToken()); - double rating = Double.parseDouble(tok.nextToken()); - return new Rating(x, y, rating); - } - } - - static class FeaturesToString extends Function, String> { - public String call(Tuple2 element) { - return element._1().toString() + "," + Arrays.toString(element._2()); - } - } - - public static void main(String[] args) { - - if (args.length != 5 && args.length != 6) { - System.err.println( - "Usage: JavaALS []"); - System.exit(1); - } - - int rank = Integer.parseInt(args[2]); - int iterations = Integer.parseInt(args[3]); - String outputDir = args[4]; - int blocks = -1; - if (args.length == 6) { - blocks = Integer.parseInt(args[5]); - } - - JavaSparkContext sc = new JavaSparkContext(args[0], "JavaALS", - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - JavaRDD lines = sc.textFile(args[1]); - - JavaRDD ratings = lines.map(new ParseRating()); - - MatrixFactorizationModel model = ALS.train(ratings.rdd(), rank, iterations, 0.01, blocks); - - model.userFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile( - outputDir + "/userFeatures"); - model.productFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile( - outputDir + "/productFeatures"); - System.out.println("Final user/product features written to " + outputDir); - - System.exit(0); - } -} diff --git a/examples/src/main/java/spark/mllib/examples/JavaKMeans.java b/examples/src/main/java/spark/mllib/examples/JavaKMeans.java deleted file mode 100644 index 02f40438b8..0000000000 --- a/examples/src/main/java/spark/mllib/examples/JavaKMeans.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.examples; - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.Function; - -import spark.mllib.clustering.KMeans; -import spark.mllib.clustering.KMeansModel; - -import java.util.Arrays; -import java.util.StringTokenizer; - -/** - * Example using MLLib KMeans from Java. 
- */ -public class JavaKMeans { - - static class ParsePoint extends Function { - public double[] call(String line) { - StringTokenizer tok = new StringTokenizer(line, " "); - int numTokens = tok.countTokens(); - double[] point = new double[numTokens]; - for (int i = 0; i < numTokens; ++i) { - point[i] = Double.parseDouble(tok.nextToken()); - } - return point; - } - } - - public static void main(String[] args) { - - if (args.length < 4) { - System.err.println( - "Usage: JavaKMeans []"); - System.exit(1); - } - - String inputFile = args[1]; - int k = Integer.parseInt(args[2]); - int iterations = Integer.parseInt(args[3]); - int runs = 1; - - if (args.length >= 5) { - runs = Integer.parseInt(args[4]); - } - - JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans", - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - JavaRDD lines = sc.textFile(args[1]); - - JavaRDD points = lines.map(new ParsePoint()); - - KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs); - - System.out.println("Cluster centers:"); - for (double[] center : model.clusterCenters()) { - System.out.println(" " + Arrays.toString(center)); - } - double cost = model.computeCost(points.rdd()); - System.out.println("Cost: " + cost); - - System.exit(0); - } -} diff --git a/examples/src/main/java/spark/mllib/examples/JavaLR.java b/examples/src/main/java/spark/mllib/examples/JavaLR.java deleted file mode 100644 index bf4aeaf40f..0000000000 --- a/examples/src/main/java/spark/mllib/examples/JavaLR.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.examples; - - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.api.java.function.Function; - -import spark.mllib.classification.LogisticRegressionWithSGD; -import spark.mllib.classification.LogisticRegressionModel; -import spark.mllib.regression.LabeledPoint; - -import java.util.Arrays; -import java.util.StringTokenizer; - -/** - * Logistic regression based classification using ML Lib. 
- */ -public class JavaLR { - - static class ParsePoint extends Function { - public LabeledPoint call(String line) { - String[] parts = line.split(","); - double y = Double.parseDouble(parts[0]); - StringTokenizer tok = new StringTokenizer(parts[1], " "); - int numTokens = tok.countTokens(); - double[] x = new double[numTokens]; - for (int i = 0; i < numTokens; ++i) { - x[i] = Double.parseDouble(tok.nextToken()); - } - return new LabeledPoint(y, x); - } - } - - public static void printWeights(double[] a) { - System.out.println(Arrays.toString(a)); - } - - public static void main(String[] args) { - if (args.length != 4) { - System.err.println("Usage: JavaLR "); - System.exit(1); - } - - JavaSparkContext sc = new JavaSparkContext(args[0], "JavaLR", - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - JavaRDD lines = sc.textFile(args[1]); - JavaRDD points = lines.map(new ParsePoint()).cache(); - double stepSize = Double.parseDouble(args[2]); - int iterations = Integer.parseInt(args[3]); - - // Another way to configure LogisticRegression - // - // LogisticRegressionWithSGD lr = new LogisticRegressionWithSGD(); - // lr.optimizer().setNumIterations(iterations) - // .setStepSize(stepSize) - // .setMiniBatchFraction(1.0); - // lr.setIntercept(true); - // LogisticRegressionModel model = lr.train(points.rdd()); - - LogisticRegressionModel model = LogisticRegressionWithSGD.train(points.rdd(), - iterations, stepSize); - - System.out.print("Final w: "); - printWeights(model.weights()); - - System.exit(0); - } -} diff --git a/examples/src/main/java/spark/streaming/examples/JavaFlumeEventCount.java b/examples/src/main/java/spark/streaming/examples/JavaFlumeEventCount.java deleted file mode 100644 index 096a9ae219..0000000000 --- a/examples/src/main/java/spark/streaming/examples/JavaFlumeEventCount.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples; - -import spark.api.java.function.Function; -import spark.streaming.*; -import spark.streaming.api.java.*; -import spark.streaming.dstream.SparkFlumeEvent; - -/** - * Produces a count of events received from Flume. - * - * This should be used in conjunction with an AvroSink in Flume. It will start - * an Avro server on at the request host:port address and listen for requests. - * Your Flume AvroSink should be pointed to this address. - * - * Usage: JavaFlumeEventCount - * - * is a Spark master URL - * is the host the Flume receiver will be started on - a receiver - * creates a server and listens for flume events. - * is the port the Flume receiver will listen on. 
- */ -public class JavaFlumeEventCount { - public static void main(String[] args) { - if (args.length != 3) { - System.err.println("Usage: JavaFlumeEventCount "); - System.exit(1); - } - - String master = args[0]; - String host = args[1]; - int port = Integer.parseInt(args[2]); - - Duration batchInterval = new Duration(2000); - - JavaStreamingContext sc = new JavaStreamingContext(master, "FlumeEventCount", batchInterval, - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - - JavaDStream flumeStream = sc.flumeStream("localhost", port); - - flumeStream.count(); - - flumeStream.count().map(new Function() { - @Override - public String call(Long in) { - return "Received " + in + " flume events."; - } - }).print(); - - sc.start(); - } -} diff --git a/examples/src/main/java/spark/streaming/examples/JavaNetworkWordCount.java b/examples/src/main/java/spark/streaming/examples/JavaNetworkWordCount.java deleted file mode 100644 index c54d3f3d59..0000000000 --- a/examples/src/main/java/spark/streaming/examples/JavaNetworkWordCount.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples; - -import com.google.common.collect.Lists; -import scala.Tuple2; -import spark.api.java.function.FlatMapFunction; -import spark.api.java.function.Function2; -import spark.api.java.function.PairFunction; -import spark.streaming.Duration; -import spark.streaming.api.java.JavaDStream; -import spark.streaming.api.java.JavaPairDStream; -import spark.streaming.api.java.JavaStreamingContext; - -/** - * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. - * Usage: NetworkWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. - * and describe the TCP server that Spark Streaming would connect to receive data. - * - * To run this on your local machine, you need to first run a Netcat server - * `$ nc -lk 9999` - * and then run the example - * `$ ./run spark.streaming.examples.JavaNetworkWordCount local[2] localhost 9999` - */ -public class JavaNetworkWordCount { - public static void main(String[] args) { - if (args.length < 3) { - System.err.println("Usage: NetworkWordCount \n" + - "In local mode, should be 'local[n]' with n > 1"); - System.exit(1); - } - - // Create the context with a 1 second batch size - JavaStreamingContext ssc = new JavaStreamingContext(args[0], "NetworkWordCount", - new Duration(1000), System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - - // Create a NetworkInputDStream on target ip:port and count the - // words in input stream of \n delimited test (eg. 
generated by 'nc') - JavaDStream lines = ssc.socketTextStream(args[1], Integer.parseInt(args[2])); - JavaDStream words = lines.flatMap(new FlatMapFunction() { - @Override - public Iterable call(String x) { - return Lists.newArrayList(x.split(" ")); - } - }); - JavaPairDStream wordCounts = words.map( - new PairFunction() { - @Override - public Tuple2 call(String s) throws Exception { - return new Tuple2(s, 1); - } - }).reduceByKey(new Function2() { - @Override - public Integer call(Integer i1, Integer i2) throws Exception { - return i1 + i2; - } - }); - - wordCounts.print(); - ssc.start(); - - } -} diff --git a/examples/src/main/java/spark/streaming/examples/JavaQueueStream.java b/examples/src/main/java/spark/streaming/examples/JavaQueueStream.java deleted file mode 100644 index 1f4a991542..0000000000 --- a/examples/src/main/java/spark/streaming/examples/JavaQueueStream.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples; - -import com.google.common.collect.Lists; -import scala.Tuple2; -import spark.api.java.JavaRDD; -import spark.api.java.function.Function2; -import spark.api.java.function.PairFunction; -import spark.streaming.Duration; -import spark.streaming.api.java.JavaDStream; -import spark.streaming.api.java.JavaPairDStream; -import spark.streaming.api.java.JavaStreamingContext; - -import java.util.LinkedList; -import java.util.List; -import java.util.Queue; - -public class JavaQueueStream { - public static void main(String[] args) throws InterruptedException { - if (args.length < 1) { - System.err.println("Usage: JavaQueueStream "); - System.exit(1); - } - - // Create the context - JavaStreamingContext ssc = new JavaStreamingContext(args[0], "QueueStream", new Duration(1000), - System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); - - // Create the queue through which RDDs can be pushed to - // a QueueInputDStream - Queue> rddQueue = new LinkedList>(); - - // Create and push some RDDs into the queue - List list = Lists.newArrayList(); - for (int i = 0; i < 1000; i++) { - list.add(i); - } - - for (int i = 0; i < 30; i++) { - rddQueue.add(ssc.sc().parallelize(list)); - } - - - // Create the QueueInputDStream and use it do some processing - JavaDStream inputStream = ssc.queueStream(rddQueue); - JavaPairDStream mappedStream = inputStream.map( - new PairFunction() { - @Override - public Tuple2 call(Integer i) throws Exception { - return new Tuple2(i % 10, 1); - } - }); - JavaPairDStream reducedStream = mappedStream.reduceByKey( - new Function2() { - @Override - public Integer call(Integer i1, Integer i2) throws Exception { - return i1 + i2; - } - }); - - reducedStream.print(); - ssc.start(); - } -} diff --git 
a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
new file mode 100644
index 0000000000..868ff81f67
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples
+
+import org.apache.spark.SparkContext
+
+object BroadcastTest {
+  def main(args: Array[String]) {
+    if (args.length == 0) {
+      System.err.println("Usage: BroadcastTest <master> [<slices>] [numElem]")
+      System.exit(1)
+    }
+
+    val sc = new SparkContext(args(0), "Broadcast Test",
+      System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
+    val slices = if (args.length > 1) args(1).toInt else 2
+    val num = if (args.length > 2) args(2).toInt else 1000000
+
+    var arr1 = new Array[Int](num)
+    for (i <- 0 until arr1.length) {
+      arr1(i) = i
+    }
+
+    for (i <- 0 until 2) {
+      println("Iteration " + i)
+      println("===========")
+      val barr1 = sc.broadcast(arr1)
+      sc.parallelize(1 to 10, slices).foreach {
+        i => println(barr1.value.size)
+      }
+    }
+
+    System.exit(0)
+  }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala
new file mode 100644
index 0000000000..33bf7151a7
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.examples + +import org.apache.hadoop.mapreduce.Job +import org.apache.cassandra.hadoop.ColumnFamilyOutputFormat +import org.apache.cassandra.hadoop.ConfigHelper +import org.apache.cassandra.hadoop.ColumnFamilyInputFormat +import org.apache.cassandra.thrift._ +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ +import java.nio.ByteBuffer +import java.util.SortedMap +import org.apache.cassandra.db.IColumn +import org.apache.cassandra.utils.ByteBufferUtil +import scala.collection.JavaConversions._ + + +/* + * This example demonstrates using Spark with Cassandra with the New Hadoop API and Cassandra + * support for Hadoop. + * + * To run this example, run this file with the following command params - + * + * + * So if you want to run this on localhost this will be, + * local[3] localhost 9160 + * + * The example makes some assumptions: + * 1. You have already created a keyspace called casDemo and it has a column family named Words + * 2. There are column family has a column named "para" which has test content. + * + * You can create the content by running the following script at the bottom of this file with + * cassandra-cli. + * + */ +object CassandraTest { + + def main(args: Array[String]) { + + // Get a SparkContext + val sc = new SparkContext(args(0), "casDemo") + + // Build the job configuration with ConfigHelper provided by Cassandra + val job = new Job() + job.setInputFormatClass(classOf[ColumnFamilyInputFormat]) + + val host: String = args(1) + val port: String = args(2) + + ConfigHelper.setInputInitialAddress(job.getConfiguration(), host) + ConfigHelper.setInputRpcPort(job.getConfiguration(), port) + ConfigHelper.setOutputInitialAddress(job.getConfiguration(), host) + ConfigHelper.setOutputRpcPort(job.getConfiguration(), port) + ConfigHelper.setInputColumnFamily(job.getConfiguration(), "casDemo", "Words") + ConfigHelper.setOutputColumnFamily(job.getConfiguration(), "casDemo", "WordCount") + + val predicate = new SlicePredicate() + val sliceRange = new SliceRange() + sliceRange.setStart(Array.empty[Byte]) + sliceRange.setFinish(Array.empty[Byte]) + predicate.setSlice_range(sliceRange) + ConfigHelper.setInputSlicePredicate(job.getConfiguration(), predicate) + + ConfigHelper.setInputPartitioner(job.getConfiguration(), "Murmur3Partitioner") + ConfigHelper.setOutputPartitioner(job.getConfiguration(), "Murmur3Partitioner") + + // Make a new Hadoop RDD + val casRdd = sc.newAPIHadoopRDD( + job.getConfiguration(), + classOf[ColumnFamilyInputFormat], + classOf[ByteBuffer], + classOf[SortedMap[ByteBuffer, IColumn]]) + + // Let us first get all the paragraphs from the retrieved rows + val paraRdd = casRdd.map { + case (key, value) => { + ByteBufferUtil.string(value.get(ByteBufferUtil.bytes("para")).value()) + } + } + + // Lets get the word count in paras + val counts = paraRdd.flatMap(p => p.split(" ")).map(word => (word, 1)).reduceByKey(_ + _) + + counts.collect().foreach { + case (word, count) => println(word + ":" + count) + } + + counts.map { + case (word, count) => { + val colWord = new org.apache.cassandra.thrift.Column() + colWord.setName(ByteBufferUtil.bytes("word")) + colWord.setValue(ByteBufferUtil.bytes(word)) + colWord.setTimestamp(System.currentTimeMillis) + + val colCount = new org.apache.cassandra.thrift.Column() + colCount.setName(ByteBufferUtil.bytes("wcount")) + colCount.setValue(ByteBufferUtil.bytes(count.toLong)) + colCount.setTimestamp(System.currentTimeMillis) + + val outputkey = ByteBufferUtil.bytes(word + "-COUNT-" 
+ System.currentTimeMillis) + + val mutations: java.util.List[Mutation] = new Mutation() :: new Mutation() :: Nil + mutations.get(0).setColumn_or_supercolumn(new ColumnOrSuperColumn()) + mutations.get(0).column_or_supercolumn.setColumn(colWord) + mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn()) + mutations.get(1).column_or_supercolumn.setColumn(colCount) + (outputkey, mutations) + } + }.saveAsNewAPIHadoopFile("casDemo", classOf[ByteBuffer], classOf[List[Mutation]], + classOf[ColumnFamilyOutputFormat], job.getConfiguration) + } +} + +/* +create keyspace casDemo; +use casDemo; + +create column family WordCount with comparator = UTF8Type; +update column family WordCount with column_metadata = + [{column_name: word, validation_class: UTF8Type}, + {column_name: wcount, validation_class: LongType}]; + +create column family Words with comparator = UTF8Type; +update column family Words with column_metadata = + [{column_name: book, validation_class: UTF8Type}, + {column_name: para, validation_class: UTF8Type}]; + +assume Words keys as utf8; + +set Words['3musk001']['book'] = 'The Three Musketeers'; +set Words['3musk001']['para'] = 'On the first Monday of the month of April, 1625, the market + town of Meung, in which the author of ROMANCE OF THE ROSE was born, appeared to + be in as perfect a state of revolution as if the Huguenots had just made + a second La Rochelle of it. Many citizens, seeing the women flying + toward the High Street, leaving their children crying at the open doors, + hastened to don the cuirass, and supporting their somewhat uncertain + courage with a musket or a partisan, directed their steps toward the + hostelry of the Jolly Miller, before which was gathered, increasing + every minute, a compact group, vociferous and full of curiosity.'; + +set Words['3musk002']['book'] = 'The Three Musketeers'; +set Words['3musk002']['para'] = 'In those times panics were common, and few days passed without + some city or other registering in its archives an event of this kind. There were + nobles, who made war against each other; there was the king, who made + war against the cardinal; there was Spain, which made war against the + king. Then, in addition to these concealed or public, secret or open + wars, there were robbers, mendicants, Huguenots, wolves, and scoundrels, + who made war upon everybody. The citizens always took up arms readily + against thieves, wolves or scoundrels, often against nobles or + Huguenots, sometimes against the king, but never against cardinal or + Spain. It resulted, then, from this habit that on the said first Monday + of April, 1625, the citizens, on hearing the clamor, and seeing neither + the red-and-yellow standard nor the livery of the Duc de Richelieu, + rushed toward the hostel of the Jolly Miller. When arrived there, the + cause of the hubbub was apparent to all'; + +set Words['3musk003']['book'] = 'The Three Musketeers'; +set Words['3musk003']['para'] = 'You ought, I say, then, to husband the means you have, however + large the sum may be; but you ought also to endeavor to perfect yourself in + the exercises becoming a gentleman. I will write a letter today to the + Director of the Royal Academy, and tomorrow he will admit you without + any expense to yourself. Do not refuse this little service. Our + best-born and richest gentlemen sometimes solicit it without being able + to obtain it. You will learn horsemanship, swordsmanship in all its + branches, and dancing. 
You will make some desirable acquaintances; and + from time to time you can call upon me, just to tell me how you are + getting on, and to say whether I can be of further service to you.'; + + +set Words['thelostworld001']['book'] = 'The Lost World'; +set Words['thelostworld001']['para'] = 'She sat with that proud, delicate profile of hers outlined + against the red curtain. How beautiful she was! And yet how aloof! We had been + friends, quite good friends; but never could I get beyond the same + comradeship which I might have established with one of my + fellow-reporters upon the Gazette,--perfectly frank, perfectly kindly, + and perfectly unsexual. My instincts are all against a woman being too + frank and at her ease with me. It is no compliment to a man. Where + the real sex feeling begins, timidity and distrust are its companions, + heritage from old wicked days when love and violence went often hand in + hand. The bent head, the averted eye, the faltering voice, the wincing + figure--these, and not the unshrinking gaze and frank reply, are the + true signals of passion. Even in my short life I had learned as much + as that--or had inherited it in that race memory which we call instinct.'; + +set Words['thelostworld002']['book'] = 'The Lost World'; +set Words['thelostworld002']['para'] = 'I always liked McArdle, the crabbed, old, round-backed, + red-headed news editor, and I rather hoped that he liked me. Of course, Beaumont was + the real boss; but he lived in the rarefied atmosphere of some Olympian + height from which he could distinguish nothing smaller than an + international crisis or a split in the Cabinet. Sometimes we saw him + passing in lonely majesty to his inner sanctum, with his eyes staring + vaguely and his mind hovering over the Balkans or the Persian Gulf. He + was above and beyond us. But McArdle was his first lieutenant, and it + was he that we knew. The old man nodded as I entered the room, and he + pushed his spectacles far up on his bald forehead.'; + +*/ diff --git a/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala new file mode 100644 index 0000000000..92eb96bd8e --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.examples
+
+import org.apache.spark.SparkContext
+
+object ExceptionHandlingTest {
+  def main(args: Array[String]) {
+    if (args.length == 0) {
+      System.err.println("Usage: ExceptionHandlingTest <master>")
+      System.exit(1)
+    }
+
+    val sc = new SparkContext(args(0), "ExceptionHandlingTest",
+      System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
+    sc.parallelize(0 until sc.defaultParallelism).foreach { i =>
+      if (math.random > 0.75)
+        throw new Exception("Testing exception handling")
+    }
+
+    System.exit(0)
+  }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala
new file mode 100644
index 0000000000..42c2e0e8e1
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples
+
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext._
+import java.util.Random
+
+object GroupByTest {
+  def main(args: Array[String]) {
+    if (args.length == 0) {
+      System.err.println("Usage: GroupByTest <master> [numMappers] [numKVPairs] [KeySize] [numReducers]")
+      System.exit(1)
+    }
+
+    var numMappers = if (args.length > 1) args(1).toInt else 2
+    var numKVPairs = if (args.length > 2) args(2).toInt else 1000
+    var valSize = if (args.length > 3) args(3).toInt else 1000
+    var numReducers = if (args.length > 4) args(4).toInt else numMappers
+
+    val sc = new SparkContext(args(0), "GroupBy Test",
+      System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
+
+    val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
+      val ranGen = new Random
+      var arr1 = new Array[(Int, Array[Byte])](numKVPairs)
+      for (i <- 0 until numKVPairs) {
+        val byteArr = new Array[Byte](valSize)
+        ranGen.nextBytes(byteArr)
+        arr1(i) = (ranGen.nextInt(Int.MaxValue), byteArr)
+      }
+      arr1
+    }.cache
+    // Enforce that everything has been calculated and in cache
+    pairs1.count
+
+    println(pairs1.groupByKey(numReducers).count)
+
+    System.exit(0)
+  }
+}
+
diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala
new file mode 100644
index 0000000000..efe2e93b0d
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import org.apache.spark._ +import org.apache.spark.rdd.NewHadoopRDD +import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor} +import org.apache.hadoop.hbase.client.HBaseAdmin +import org.apache.hadoop.hbase.mapreduce.TableInputFormat + +object HBaseTest { + def main(args: Array[String]) { + val sc = new SparkContext(args(0), "HBaseTest", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + val conf = HBaseConfiguration.create() + + // Other options for configuring scan behavior are available. More information available at + // http://hbase.apache.org/apidocs/org/apache/hadoop/hbase/mapreduce/TableInputFormat.html + conf.set(TableInputFormat.INPUT_TABLE, args(1)) + + // Initialize hBase table if necessary + val admin = new HBaseAdmin(conf) + if(!admin.isTableAvailable(args(1))) { + val tableDesc = new HTableDescriptor(args(1)) + admin.createTable(tableDesc) + } + + val hBaseRDD = sc.newAPIHadoopRDD(conf, classOf[TableInputFormat], + classOf[org.apache.hadoop.hbase.io.ImmutableBytesWritable], + classOf[org.apache.hadoop.hbase.client.Result]) + + hBaseRDD.count() + + System.exit(0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala new file mode 100644 index 0000000000..d6a88d3032 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples + +import org.apache.spark._ + +object HdfsTest { + def main(args: Array[String]) { + val sc = new SparkContext(args(0), "HdfsTest", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val file = sc.textFile(args(1)) + val mapped = file.map(s => s.length).cache() + for (iter <- 1 to 10) { + val start = System.currentTimeMillis() + for (x <- mapped) { x + 2 } + // println("Processing: " + x) + val end = System.currentTimeMillis() + println("Iteration " + iter + " took " + (end-start) + " ms") + } + System.exit(0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala new file mode 100644 index 0000000000..4af45b2b4a --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import scala.math.sqrt +import cern.jet.math._ +import cern.colt.matrix._ +import cern.colt.matrix.linalg._ + +/** + * Alternating least squares matrix factorization. 
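+ *
+ * Each iteration fixes one factor matrix and solves a regularized
+ * least-squares problem for the other: updateMovie accumulates
+ * XtX = sum_j u_j * u_j^T and Xty = sum_j R(i, j) * u_j over all users,
+ * adds LAMBDA times the number of ratings to the diagonal, and solves the
+ * resulting system with a Cholesky decomposition (updateUser is the
+ * symmetric step for users).
+ *
+ * Illustrative invocation (the argument values are only an example):
+ *   LocalALS 100 500 10 5
+ * factors a 100 x 500 ratings matrix with 10 features for 5 iterations.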
+ */ +object LocalALS { + // Parameters set through command line arguments + var M = 0 // Number of movies + var U = 0 // Number of users + var F = 0 // Number of features + var ITERATIONS = 0 + + val LAMBDA = 0.01 // Regularization coefficient + + // Some COLT objects + val factory2D = DoubleFactory2D.dense + val factory1D = DoubleFactory1D.dense + val algebra = Algebra.DEFAULT + val blas = SeqBlas.seqBlas + + def generateR(): DoubleMatrix2D = { + val mh = factory2D.random(M, F) + val uh = factory2D.random(U, F) + return algebra.mult(mh, algebra.transpose(uh)) + } + + def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D], + us: Array[DoubleMatrix1D]): Double = + { + val r = factory2D.make(M, U) + for (i <- 0 until M; j <- 0 until U) { + r.set(i, j, blas.ddot(ms(i), us(j))) + } + //println("R: " + r) + blas.daxpy(-1, targetR, r) + val sumSqs = r.aggregate(Functions.plus, Functions.square) + return sqrt(sumSqs / (M * U)) + } + + def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], + R: DoubleMatrix2D) : DoubleMatrix1D = + { + val XtX = factory2D.make(F, F) + val Xty = factory1D.make(F) + // For each user that rated the movie + for (j <- 0 until U) { + val u = us(j) + // Add u * u^t to XtX + blas.dger(1, u, u, XtX) + // Add u * rating to Xty + blas.daxpy(R.get(i, j), u, Xty) + } + // Add regularization coefs to diagonal terms + for (d <- 0 until F) { + XtX.set(d, d, XtX.get(d, d) + LAMBDA * U) + } + // Solve it with Cholesky + val ch = new CholeskyDecomposition(XtX) + val Xty2D = factory2D.make(Xty.toArray, F) + val solved2D = ch.solve(Xty2D) + return solved2D.viewColumn(0) + } + + def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D], + R: DoubleMatrix2D) : DoubleMatrix1D = + { + val XtX = factory2D.make(F, F) + val Xty = factory1D.make(F) + // For each movie that the user rated + for (i <- 0 until M) { + val m = ms(i) + // Add m * m^t to XtX + blas.dger(1, m, m, XtX) + // Add m * rating to Xty + blas.daxpy(R.get(i, j), m, Xty) + } + // Add regularization coefs to diagonal terms + for (d <- 0 until F) { + XtX.set(d, d, XtX.get(d, d) + LAMBDA * M) + } + // Solve it with Cholesky + val ch = new CholeskyDecomposition(XtX) + val Xty2D = factory2D.make(Xty.toArray, F) + val solved2D = ch.solve(Xty2D) + return solved2D.viewColumn(0) + } + + def main(args: Array[String]) { + args match { + case Array(m, u, f, iters) => { + M = m.toInt + U = u.toInt + F = f.toInt + ITERATIONS = iters.toInt + } + case _ => { + System.err.println("Usage: LocalALS ") + System.exit(1) + } + } + printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS); + + val R = generateR() + + // Initialize m and u randomly + var ms = Array.fill(M)(factory1D.random(F)) + var us = Array.fill(U)(factory1D.random(F)) + + // Iteratively update movies then users + for (iter <- 1 to ITERATIONS) { + println("Iteration " + iter + ":") + ms = (0 until M).map(i => updateMovie(i, ms(i), us, R)).toArray + us = (0 until U).map(j => updateUser(j, us(j), ms, R)).toArray + println("RMSE = " + rmse(R, ms, us)) + println() + } + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala new file mode 100644 index 0000000000..fb130ea198 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import java.util.Random +import org.apache.spark.util.Vector + +object LocalFileLR { + val D = 10 // Numer of dimensions + val rand = new Random(42) + + case class DataPoint(x: Vector, y: Double) + + def parsePoint(line: String): DataPoint = { + val nums = line.split(' ').map(_.toDouble) + return DataPoint(new Vector(nums.slice(1, D+1)), nums(0)) + } + + def main(args: Array[String]) { + val lines = scala.io.Source.fromFile(args(0)).getLines().toArray + val points = lines.map(parsePoint _) + val ITERATIONS = args(1).toInt + + // Initialize w to a random value + var w = Vector(D, _ => 2 * rand.nextDouble - 1) + println("Initial w: " + w) + + for (i <- 1 to ITERATIONS) { + println("On iteration " + i) + var gradient = Vector.zeros(D) + for (p <- points) { + val scale = (1 / (1 + math.exp(-p.y * (w dot p.x))) - 1) * p.y + gradient += scale * p.x + } + w -= gradient + } + + println("Final w: " + w) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala new file mode 100644 index 0000000000..f90ea35cd4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import java.util.Random +import org.apache.spark.util.Vector +import org.apache.spark.SparkContext._ +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet + +/** + * K-means clustering. 
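+ *
+ * A single-process Lloyd's iteration over N randomly generated points:
+ * every point is assigned to the closest of the K current centers, each
+ * center is then replaced by the mean of the points assigned to it, and
+ * the loop repeats until the total squared movement of the centers drops
+ * below convergeDist.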
+ */ +object LocalKMeans { + val N = 1000 + val R = 1000 // Scaling factor + val D = 10 + val K = 10 + val convergeDist = 0.001 + val rand = new Random(42) + + def generateData = { + def generatePoint(i: Int) = { + Vector(D, _ => rand.nextDouble * R) + } + Array.tabulate(N)(generatePoint) + } + + def closestPoint(p: Vector, centers: HashMap[Int, Vector]): Int = { + var index = 0 + var bestIndex = 0 + var closest = Double.PositiveInfinity + + for (i <- 1 to centers.size) { + val vCurr = centers.get(i).get + val tempDist = p.squaredDist(vCurr) + if (tempDist < closest) { + closest = tempDist + bestIndex = i + } + } + + return bestIndex + } + + def main(args: Array[String]) { + val data = generateData + var points = new HashSet[Vector] + var kPoints = new HashMap[Int, Vector] + var tempDist = 1.0 + + while (points.size < K) { + points.add(data(rand.nextInt(N))) + } + + val iter = points.iterator + for (i <- 1 to points.size) { + kPoints.put(i, iter.next()) + } + + println("Initial centers: " + kPoints) + + while(tempDist > convergeDist) { + var closest = data.map (p => (closestPoint(p, kPoints), (p, 1))) + + var mappings = closest.groupBy[Int] (x => x._1) + + var pointStats = mappings.map(pair => pair._2.reduceLeft [(Int, (Vector, Int))] {case ((id1, (x1, y1)), (id2, (x2, y2))) => (id1, (x1 + x2, y1+y2))}) + + var newPoints = pointStats.map {mapping => (mapping._1, mapping._2._1/mapping._2._2)} + + tempDist = 0.0 + for (mapping <- newPoints) { + tempDist += kPoints.get(mapping._1).get.squaredDist(mapping._2) + } + + for (newP <- newPoints) { + kPoints.put(newP._1, newP._2) + } + } + + println("Final centers: " + kPoints) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala new file mode 100644 index 0000000000..cd4e9f1af0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import java.util.Random +import org.apache.spark.util.Vector + +/** + * Logistic regression based classification. 
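+ *
+ * Runs ITERATIONS batch gradient-descent steps on the logistic loss over
+ * N generated points: each point contributes
+ *   (1 / (1 + exp(-y * (w dot x))) - 1) * y * x
+ * to the gradient, and the weight vector is updated with w -= gradient
+ * (a fixed unit step size).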
+ */ +object LocalLR { + val N = 10000 // Number of data points + val D = 10 // Number of dimensions + val R = 0.7 // Scaling factor + val ITERATIONS = 5 + val rand = new Random(42) + + case class DataPoint(x: Vector, y: Double) + + def generateData = { + def generatePoint(i: Int) = { + val y = if(i % 2 == 0) -1 else 1 + val x = Vector(D, _ => rand.nextGaussian + y * R) + DataPoint(x, y) + } + Array.tabulate(N)(generatePoint) + } + + def main(args: Array[String]) { + val data = generateData + + // Initialize w to a random value + var w = Vector(D, _ => 2 * rand.nextDouble - 1) + println("Initial w: " + w) + + for (i <- 1 to ITERATIONS) { + println("On iteration " + i) + var gradient = Vector.zeros(D) + for (p <- data) { + val scale = (1 / (1 + math.exp(-p.y * (w dot p.x))) - 1) * p.y + gradient += scale * p.x + } + w -= gradient + } + + println("Final w: " + w) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala new file mode 100644 index 0000000000..bb7f22ec8d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import scala.math.random +import org.apache.spark._ +import SparkContext._ + +object LocalPi { + def main(args: Array[String]) { + var count = 0 + for (i <- 1 to 100000) { + val x = random * 2 - 1 + val y = random * 2 - 1 + if (x*x + y*y < 1) count += 1 + } + println("Pi is roughly " + 4 * count / 100000.0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala new file mode 100644 index 0000000000..17ff3ce764 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ +/** + * Executes a roll up-style query against Apache logs. + */ +object LogQuery { + val exampleApacheLogs = List( + """10.10.10.10 - "FRED" [18/Jan/2013:17:56:07 +1100] "GET http://images.com/2013/Generic.jpg + | HTTP/1.1" 304 315 "http://referall.com/" "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; + | GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; .NET CLR + | 3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR + | 3.5.30729; Release=ARP)" "UD-1" - "image/jpeg" "whatever" 0.350 "-" - "" 265 923 934 "" + | 62.24.11.25 images.com 1358492167 - Whatup""".stripMargin.replace("\n", ""), + """10.10.10.10 - "FRED" [18/Jan/2013:18:02:37 +1100] "GET http://images.com/2013/Generic.jpg + | HTTP/1.1" 304 306 "http:/referall.com" "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; + | GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; .NET CLR + | 3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR + | 3.5.30729; Release=ARP)" "UD-1" - "image/jpeg" "whatever" 0.352 "-" - "" 256 977 988 "" + | 0 73.23.2.15 images.com 1358492557 - Whatup""".stripMargin.replace("\n", "") + ) + + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: LogQuery [logFile]") + System.exit(1) + } + + val sc = new SparkContext(args(0), "Log Query", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + val dataSet = + if (args.length == 2) sc.textFile(args(1)) + else sc.parallelize(exampleApacheLogs) + + val apacheLogRegex = + """^([\d.]+) (\S+) (\S+) \[([\w\d:/]+\s[+\-]\d{4})\] "(.+?)" (\d{3}) ([\d\-]+) "([^"]+)" "([^"]+)".*""".r + + /** Tracks the total query count and number of aggregate bytes for a particular group. */ + class Stats(val count: Int, val numBytes: Int) extends Serializable { + def merge(other: Stats) = new Stats(count + other.count, numBytes + other.numBytes) + override def toString = "bytes=%s\tn=%s".format(numBytes, count) + } + + def extractKey(line: String): (String, String, String) = { + apacheLogRegex.findFirstIn(line) match { + case Some(apacheLogRegex(ip, _, user, dateTime, query, status, bytes, referer, ua)) => + if (user != "\"-\"") (ip, user, query) + else (null, null, null) + case _ => (null, null, null) + } + } + + def extractStats(line: String): Stats = { + apacheLogRegex.findFirstIn(line) match { + case Some(apacheLogRegex(ip, _, user, dateTime, query, status, bytes, referer, ua)) => + new Stats(1, bytes.toInt) + case _ => new Stats(1, 0) + } + } + + dataSet.map(line => (extractKey(line), extractStats(line))) + .reduceByKey((a, b) => a.merge(b)) + .collect().foreach{ + case (user, query) => println("%s\t%s".format(user, query))} + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala new file mode 100644 index 0000000000..f79f0142b8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import org.apache.spark.SparkContext + +object MultiBroadcastTest { + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: BroadcastTest [] [numElem]") + System.exit(1) + } + + val sc = new SparkContext(args(0), "Broadcast Test", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + val slices = if (args.length > 1) args(1).toInt else 2 + val num = if (args.length > 2) args(2).toInt else 1000000 + + var arr1 = new Array[Int](num) + for (i <- 0 until arr1.length) { + arr1(i) = i + } + + var arr2 = new Array[Int](num) + for (i <- 0 until arr2.length) { + arr2(i) = i + } + + val barr1 = sc.broadcast(arr1) + val barr2 = sc.broadcast(arr2) + sc.parallelize(1 to 10, slices).foreach { + i => println(barr1.value.size + barr2.value.size) + } + + System.exit(0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala new file mode 100644 index 0000000000..37ddfb5db7 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ +import java.util.Random + +object SimpleSkewedGroupByTest { + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: SimpleSkewedGroupByTest " + + "[numMappers] [numKVPairs] [valSize] [numReducers] [ratio]") + System.exit(1) + } + + var numMappers = if (args.length > 1) args(1).toInt else 2 + var numKVPairs = if (args.length > 2) args(2).toInt else 1000 + var valSize = if (args.length > 3) args(3).toInt else 1000 + var numReducers = if (args.length > 4) args(4).toInt else numMappers + var ratio = if (args.length > 5) args(5).toInt else 5.0 + + val sc = new SparkContext(args(0), "GroupBy Test", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => + val ranGen = new Random + var result = new Array[(Int, Array[Byte])](numKVPairs) + for (i <- 0 until numKVPairs) { + val byteArr = new Array[Byte](valSize) + ranGen.nextBytes(byteArr) + val offset = ranGen.nextInt(1000) * numReducers + if (ranGen.nextDouble < ratio / (numReducers + ratio - 1)) { + // give ratio times higher chance of generating key 0 (for reducer 0) + result(i) = (offset, byteArr) + } else { + // generate a key for one of the other reducers + val key = 1 + ranGen.nextInt(numReducers-1) + offset + result(i) = (key, byteArr) + } + } + result + }.cache + // Enforce that everything has been calculated and in cache + pairs1.count + + println("RESULT: " + pairs1.groupByKey(numReducers).count) + // Print how many keys each reducer got (for debugging) + //println("RESULT: " + pairs1.groupByKey(numReducers) + // .map{case (k,v) => (k, v.size)} + // .collectAsMap) + + System.exit(0) + } +} + diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala new file mode 100644 index 0000000000..9c954b2b5b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ +import java.util.Random + +object SkewedGroupByTest { + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]") + System.exit(1) + } + + var numMappers = if (args.length > 1) args(1).toInt else 2 + var numKVPairs = if (args.length > 2) args(2).toInt else 1000 + var valSize = if (args.length > 3) args(3).toInt else 1000 + var numReducers = if (args.length > 4) args(4).toInt else numMappers + + val sc = new SparkContext(args(0), "GroupBy Test", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => + val ranGen = new Random + + // map output sizes lineraly increase from the 1st to the last + numKVPairs = (1.0 * (p + 1) / numMappers * numKVPairs).toInt + + var arr1 = new Array[(Int, Array[Byte])](numKVPairs) + for (i <- 0 until numKVPairs) { + val byteArr = new Array[Byte](valSize) + ranGen.nextBytes(byteArr) + arr1(i) = (ranGen.nextInt(Int.MaxValue), byteArr) + } + arr1 + }.cache() + // Enforce that everything has been calculated and in cache + pairs1.count() + + println(pairs1.groupByKey(numReducers).count()) + + System.exit(0) + } +} + diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala new file mode 100644 index 0000000000..814944ba1c --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import scala.math.sqrt +import cern.jet.math._ +import cern.colt.matrix._ +import cern.colt.matrix.linalg._ +import org.apache.spark._ + +/** + * Alternating least squares matrix factorization. 
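+ *
+ * Distributed variant of LocalALS: the ratings matrix R and the current
+ * factor arrays are broadcast, each iteration maps update() over the
+ * movie (and then user) indices with sc.parallelize, and the refreshed
+ * factors are re-broadcast before the next half-step.
+ *
+ * Illustrative invocation (the argument values are only an example):
+ *   SparkALS local[2] 100 500 10 5 2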
+ */ +object SparkALS { + // Parameters set through command line arguments + var M = 0 // Number of movies + var U = 0 // Number of users + var F = 0 // Number of features + var ITERATIONS = 0 + + val LAMBDA = 0.01 // Regularization coefficient + + // Some COLT objects + val factory2D = DoubleFactory2D.dense + val factory1D = DoubleFactory1D.dense + val algebra = Algebra.DEFAULT + val blas = SeqBlas.seqBlas + + def generateR(): DoubleMatrix2D = { + val mh = factory2D.random(M, F) + val uh = factory2D.random(U, F) + return algebra.mult(mh, algebra.transpose(uh)) + } + + def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D], + us: Array[DoubleMatrix1D]): Double = + { + val r = factory2D.make(M, U) + for (i <- 0 until M; j <- 0 until U) { + r.set(i, j, blas.ddot(ms(i), us(j))) + } + //println("R: " + r) + blas.daxpy(-1, targetR, r) + val sumSqs = r.aggregate(Functions.plus, Functions.square) + return sqrt(sumSqs / (M * U)) + } + + def update(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], + R: DoubleMatrix2D) : DoubleMatrix1D = + { + val U = us.size + val F = us(0).size + val XtX = factory2D.make(F, F) + val Xty = factory1D.make(F) + // For each user that rated the movie + for (j <- 0 until U) { + val u = us(j) + // Add u * u^t to XtX + blas.dger(1, u, u, XtX) + // Add u * rating to Xty + blas.daxpy(R.get(i, j), u, Xty) + } + // Add regularization coefs to diagonal terms + for (d <- 0 until F) { + XtX.set(d, d, XtX.get(d, d) + LAMBDA * U) + } + // Solve it with Cholesky + val ch = new CholeskyDecomposition(XtX) + val Xty2D = factory2D.make(Xty.toArray, F) + val solved2D = ch.solve(Xty2D) + return solved2D.viewColumn(0) + } + + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: SparkALS [ ]") + System.exit(1) + } + + var host = "" + var slices = 0 + + val options = (0 to 5).map(i => if (i < args.length) Some(args(i)) else None) + + options.toArray match { + case Array(host_, m, u, f, iters, slices_) => + host = host_.get + M = m.getOrElse("100").toInt + U = u.getOrElse("500").toInt + F = f.getOrElse("10").toInt + ITERATIONS = iters.getOrElse("5").toInt + slices = slices_.getOrElse("2").toInt + case _ => + System.err.println("Usage: SparkALS [ ]") + System.exit(1) + } + printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS) + + val sc = new SparkContext(host, "SparkALS", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + val R = generateR() + + // Initialize m and u randomly + var ms = Array.fill(M)(factory1D.random(F)) + var us = Array.fill(U)(factory1D.random(F)) + + // Iteratively update movies then users + val Rc = sc.broadcast(R) + var msb = sc.broadcast(ms) + var usb = sc.broadcast(us) + for (iter <- 1 to ITERATIONS) { + println("Iteration " + iter + ":") + ms = sc.parallelize(0 until M, slices) + .map(i => update(i, msb.value(i), usb.value, Rc.value)) + .toArray + msb = sc.broadcast(ms) // Re-broadcast ms because it was updated + us = sc.parallelize(0 until U, slices) + .map(i => update(i, usb.value(i), msb.value, algebra.transpose(Rc.value))) + .toArray + usb = sc.broadcast(us) // Re-broadcast us because it was updated + println("RMSE = " + rmse(R, ms, us)) + println() + } + + System.exit(0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala new file mode 100644 index 0000000000..646682878f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -0,0 +1,78 
@@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import java.util.Random +import scala.math.exp +import org.apache.spark.util.Vector +import org.apache.spark._ +import org.apache.spark.scheduler.InputFormatInfo + +/** + * Logistic regression based classification. + */ +object SparkHdfsLR { + val D = 10 // Numer of dimensions + val rand = new Random(42) + + case class DataPoint(x: Vector, y: Double) + + def parsePoint(line: String): DataPoint = { + //val nums = line.split(' ').map(_.toDouble) + //return DataPoint(new Vector(nums.slice(1, D+1)), nums(0)) + val tok = new java.util.StringTokenizer(line, " ") + var y = tok.nextToken.toDouble + var x = new Array[Double](D) + var i = 0 + while (i < D) { + x(i) = tok.nextToken.toDouble; i += 1 + } + return DataPoint(new Vector(x), y) + } + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: SparkHdfsLR ") + System.exit(1) + } + val inputPath = args(1) + val conf = SparkEnv.get.hadoop.newConfiguration() + val sc = new SparkContext(args(0), "SparkHdfsLR", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")), Map(), + InputFormatInfo.computePreferredLocations( + Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath)))) + val lines = sc.textFile(inputPath) + val points = lines.map(parsePoint _).cache() + val ITERATIONS = args(2).toInt + + // Initialize w to a random value + var w = Vector(D, _ => 2 * rand.nextDouble - 1) + println("Initial w: " + w) + + for (i <- 1 to ITERATIONS) { + println("On iteration " + i) + val gradient = points.map { p => + (1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y * p.x + }.reduce(_ + _) + w -= gradient + } + + println("Final w: " + w) + System.exit(0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala new file mode 100644 index 0000000000..f7bf75b4e5 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import java.util.Random +import org.apache.spark.SparkContext +import org.apache.spark.util.Vector +import org.apache.spark.SparkContext._ +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet + +/** + * K-means clustering. + */ +object SparkKMeans { + val R = 1000 // Scaling factor + val rand = new Random(42) + + def parseVector(line: String): Vector = { + return new Vector(line.split(' ').map(_.toDouble)) + } + + def closestPoint(p: Vector, centers: Array[Vector]): Int = { + var index = 0 + var bestIndex = 0 + var closest = Double.PositiveInfinity + + for (i <- 0 until centers.length) { + val tempDist = p.squaredDist(centers(i)) + if (tempDist < closest) { + closest = tempDist + bestIndex = i + } + } + + return bestIndex + } + + def main(args: Array[String]) { + if (args.length < 4) { + System.err.println("Usage: SparkLocalKMeans ") + System.exit(1) + } + val sc = new SparkContext(args(0), "SparkLocalKMeans", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val lines = sc.textFile(args(1)) + val data = lines.map(parseVector _).cache() + val K = args(2).toInt + val convergeDist = args(3).toDouble + + var kPoints = data.takeSample(false, K, 42).toArray + var tempDist = 1.0 + + while(tempDist > convergeDist) { + var closest = data.map (p => (closestPoint(p, kPoints), (p, 1))) + + var pointStats = closest.reduceByKey{case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2)} + + var newPoints = pointStats.map {pair => (pair._1, pair._2._1 / pair._2._2)}.collectAsMap() + + tempDist = 0.0 + for (i <- 0 until K) { + tempDist += kPoints(i).squaredDist(newPoints(i)) + } + + for (newP <- newPoints) { + kPoints(newP._1) = newP._2 + } + println("Finished iteration (delta = " + tempDist + ")") + } + + println("Final centers:") + kPoints.foreach(println) + System.exit(0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala new file mode 100644 index 0000000000..9ed9fe4d76 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples + +import java.util.Random +import scala.math.exp +import org.apache.spark.util.Vector +import org.apache.spark._ + +/** + * Logistic regression based classification. + */ +object SparkLR { + val N = 10000 // Number of data points + val D = 10 // Numer of dimensions + val R = 0.7 // Scaling factor + val ITERATIONS = 5 + val rand = new Random(42) + + case class DataPoint(x: Vector, y: Double) + + def generateData = { + def generatePoint(i: Int) = { + val y = if(i % 2 == 0) -1 else 1 + val x = Vector(D, _ => rand.nextGaussian + y * R) + DataPoint(x, y) + } + Array.tabulate(N)(generatePoint) + } + + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: SparkLR []") + System.exit(1) + } + val sc = new SparkContext(args(0), "SparkLR", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val numSlices = if (args.length > 1) args(1).toInt else 2 + val points = sc.parallelize(generateData, numSlices).cache() + + // Initialize w to a random value + var w = Vector(D, _ => 2 * rand.nextDouble - 1) + println("Initial w: " + w) + + for (i <- 1 to ITERATIONS) { + println("On iteration " + i) + val gradient = points.map { p => + (1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y * p.x + }.reduce(_ + _) + w -= gradient + } + + println("Final w: " + w) + System.exit(0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala new file mode 100644 index 0000000000..2721caf08b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -0,0 +1,46 @@ +package org.apache.spark.examples + +import org.apache.spark.SparkContext._ +import org.apache.spark.SparkContext + + +/** + * Computes the PageRank of URLs from an input file. Input file should + * be in format of: + * URL neighbor URL + * URL neighbor URL + * URL neighbor URL + * ... + * where URL and their neighbors are separated by space(s). + */ +object SparkPageRank { + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: PageRank ") + System.exit(1) + } + var iters = args(2).toInt + val ctx = new SparkContext(args(0), "PageRank", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val lines = ctx.textFile(args(1), 1) + val links = lines.map{ s => + val parts = s.split("\\s+") + (parts(0), parts(1)) + }.distinct().groupByKey().cache() + var ranks = links.mapValues(v => 1.0) + + for (i <- 1 to iters) { + val contribs = links.join(ranks).values.flatMap{ case (urls, rank) => + val size = urls.size + urls.map(url => (url, rank / size)) + } + ranks = contribs.reduceByKey(_ + _).mapValues(0.15 + 0.85 * _) + } + + val output = ranks.collect() + output.foreach(tup => println(tup._1 + " has rank: " + tup._2 + ".")) + + System.exit(0) + } +} + diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala new file mode 100644 index 0000000000..5a2bc9b0d0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import scala.math.random +import org.apache.spark._ +import SparkContext._ + +/** Computes an approximation to pi */ +object SparkPi { + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: SparkPi []") + System.exit(1) + } + val spark = new SparkContext(args(0), "SparkPi", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val slices = if (args.length > 1) args(1).toInt else 2 + val n = 100000 * slices + val count = spark.parallelize(1 to n, slices).map { i => + val x = random * 2 - 1 + val y = random * 2 - 1 + if (x*x + y*y < 1) 1 else 0 + }.reduce(_ + _) + println("Pi is roughly " + 4.0 * count / n) + System.exit(0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala new file mode 100644 index 0000000000..5a7a9d1bd8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import org.apache.spark._ +import SparkContext._ +import scala.util.Random +import scala.collection.mutable + +/** + * Transitive closure on a graph. + */ +object SparkTC { + val numEdges = 200 + val numVertices = 100 + val rand = new Random(42) + + def generateGraph = { + val edges: mutable.Set[(Int, Int)] = mutable.Set.empty + while (edges.size < numEdges) { + val from = rand.nextInt(numVertices) + val to = rand.nextInt(numVertices) + if (from != to) edges.+=((from, to)) + } + edges.toSeq + } + + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: SparkTC []") + System.exit(1) + } + val spark = new SparkContext(args(0), "SparkTC", + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val slices = if (args.length > 1) args(1).toInt else 2 + var tc = spark.parallelize(generateGraph, slices).cache() + + // Linear transitive closure: each round grows paths by one edge, + // by joining the graph's edges with the already-discovered paths. + // e.g. 
join the path (y, z) from the TC with the edge (x, y) from + // the graph to obtain the path (x, z). + + // Because join() joins on keys, the edges are stored in reversed order. + val edges = tc.map(x => (x._2, x._1)) + + // This join is iterated until a fixed point is reached. + var oldCount = 0L + var nextCount = tc.count() + do { + oldCount = nextCount + // Perform the join, obtaining an RDD of (y, (z, x)) pairs, + // then project the result to obtain the new (x, z) paths. + tc = tc.union(tc.join(edges).map(x => (x._2._2, x._2._1))).distinct().cache(); + nextCount = tc.count() + } while (nextCount != oldCount) + + println("TC has " + tc.count() + " edges.") + System.exit(0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala new file mode 100644 index 0000000000..b190e83c4d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.bagel + +import org.apache.spark._ +import org.apache.spark.SparkContext._ + +import org.apache.spark.bagel._ +import org.apache.spark.bagel.Bagel._ + +import scala.collection.mutable.ArrayBuffer + +import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} + +import com.esotericsoftware.kryo._ + +class PageRankUtils extends Serializable { + def computeWithCombiner(numVertices: Long, epsilon: Double)( + self: PRVertex, messageSum: Option[Double], superstep: Int + ): (PRVertex, Array[PRMessage]) = { + val newValue = messageSum match { + case Some(msgSum) if msgSum != 0 => + 0.15 / numVertices + 0.85 * msgSum + case _ => self.value + } + + val terminate = superstep >= 10 + + val outbox: Array[PRMessage] = + if (!terminate) + self.outEdges.map(targetId => + new PRMessage(targetId, newValue / self.outEdges.size)) + else + Array[PRMessage]() + + (new PRVertex(newValue, self.outEdges, !terminate), outbox) + } + + def computeNoCombiner(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[Array[PRMessage]], superstep: Int): (PRVertex, Array[PRMessage]) = + computeWithCombiner(numVertices, epsilon)(self, messages match { + case Some(msgs) => Some(msgs.map(_.value).sum) + case None => None + }, superstep) +} + +class PRCombiner extends Combiner[PRMessage, Double] with Serializable { + def createCombiner(msg: PRMessage): Double = + msg.value + def mergeMsg(combiner: Double, msg: PRMessage): Double = + combiner + msg.value + def mergeCombiners(a: Double, b: Double): Double = + a + b +} + +class PRVertex() extends Vertex with Serializable { + var value: Double = _ + var outEdges: Array[String] = _ + var active: Boolean = _ + + def this(value: Double, outEdges: Array[String], active: Boolean = true) { + this() + this.value = value + this.outEdges = outEdges + this.active = active + } + + override def toString(): String = { + "PRVertex(value=%f, outEdges.length=%d, active=%s)".format(value, outEdges.length, active.toString) + } +} + +class PRMessage() extends Message[String] with Serializable { + var targetId: String = _ + var value: Double = _ + + def this(targetId: String, value: Double) { + this() + this.targetId = targetId + this.value = value + } +} + +class PRKryoRegistrator extends KryoRegistrator { + def registerClasses(kryo: Kryo) { + kryo.register(classOf[PRVertex]) + kryo.register(classOf[PRMessage]) + } +} + +class CustomPartitioner(partitions: Int) extends Partitioner { + def numPartitions = partitions + + def getPartition(key: Any): Int = { + val hash = key match { + case k: Long => (k & 0x00000000FFFFFFFFL).toInt + case _ => key.hashCode + } + + val mod = key.hashCode % partitions + if (mod < 0) mod + partitions else mod + } + + override def equals(other: Any): Boolean = other match { + case c: CustomPartitioner => + c.numPartitions == numPartitions + case _ => false + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala new file mode 100644 index 0000000000..b1f606e48e --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.bagel + +import org.apache.spark._ +import org.apache.spark.SparkContext._ + +import org.apache.spark.bagel._ +import org.apache.spark.bagel.Bagel._ + +import scala.xml.{XML,NodeSeq} + +/** + * Run PageRank on XML Wikipedia dumps from http://wiki.freebase.com/wiki/WEX. Uses the "articles" + * files from there, which contains one line per wiki article in a tab-separated format + * (http://wiki.freebase.com/wiki/WEX/Documentation#articles). + */ +object WikipediaPageRank { + def main(args: Array[String]) { + if (args.length < 5) { + System.err.println("Usage: WikipediaPageRank ") + System.exit(-1) + } + + System.setProperty("spark.serializer", "org.apache.spark.KryoSerializer") + System.setProperty("spark.kryo.registrator", classOf[PRKryoRegistrator].getName) + + val inputFile = args(0) + val threshold = args(1).toDouble + val numPartitions = args(2).toInt + val host = args(3) + val usePartitioner = args(4).toBoolean + val sc = new SparkContext(host, "WikipediaPageRank") + + // Parse the Wikipedia page data into a graph + val input = sc.textFile(inputFile) + + println("Counting vertices...") + val numVertices = input.count() + println("Done counting vertices.") + + println("Parsing input file...") + var vertices = input.map(line => { + val fields = line.split("\t") + val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) + val links = + if (body == "\\N") + NodeSeq.Empty + else + try { + XML.loadString(body) \\ "link" \ "target" + } catch { + case e: org.xml.sax.SAXParseException => + System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body) + NodeSeq.Empty + } + val outEdges = links.map(link => new String(link.text)).toArray + val id = new String(title) + (id, new PRVertex(1.0 / numVertices, outEdges)) + }) + if (usePartitioner) + vertices = vertices.partitionBy(new HashPartitioner(sc.defaultParallelism)).cache + else + vertices = vertices.cache + println("Done parsing input file.") + + // Do the computation + val epsilon = 0.01 / numVertices + val messages = sc.parallelize(Array[(String, PRMessage)]()) + val utils = new PageRankUtils + val result = + Bagel.run( + sc, vertices, messages, combiner = new PRCombiner(), + numPartitions = numPartitions)( + utils.computeWithCombiner(numVertices, epsilon)) + + // Print the result + System.err.println("Articles with PageRank >= "+threshold+":") + val top = + (result + .filter { case (id, vertex) => vertex.value >= threshold } + .map { case (id, vertex) => "%s\t%s\n".format(id, vertex.value) } + .collect.mkString) + println(top) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala new file mode 100644 index 0000000000..3bfa48eaf3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala @@ -0,0 +1,223 @@ 
+/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.bagel + +import org.apache.spark._ +import serializer.{DeserializationStream, SerializationStream, SerializerInstance} +import org.apache.spark.SparkContext._ + +import org.apache.spark.bagel._ +import org.apache.spark.bagel.Bagel._ + +import scala.xml.{XML,NodeSeq} + +import scala.collection.mutable.ArrayBuffer + +import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} +import java.nio.ByteBuffer + +object WikipediaPageRankStandalone { + def main(args: Array[String]) { + if (args.length < 5) { + System.err.println("Usage: WikipediaPageRankStandalone ") + System.exit(-1) + } + + System.setProperty("spark.serializer", "spark.bagel.examples.WPRSerializer") + + val inputFile = args(0) + val threshold = args(1).toDouble + val numIterations = args(2).toInt + val host = args(3) + val usePartitioner = args(4).toBoolean + val sc = new SparkContext(host, "WikipediaPageRankStandalone") + + val input = sc.textFile(inputFile) + val partitioner = new HashPartitioner(sc.defaultParallelism) + val links = + if (usePartitioner) + input.map(parseArticle _).partitionBy(partitioner).cache() + else + input.map(parseArticle _).cache() + val n = links.count() + val defaultRank = 1.0 / n + val a = 0.15 + + // Do the computation + val startTime = System.currentTimeMillis + val ranks = + pageRank(links, numIterations, defaultRank, a, n, partitioner, usePartitioner, sc.defaultParallelism) + + // Print the result + System.err.println("Articles with PageRank >= "+threshold+":") + val top = + (ranks + .filter { case (id, rank) => rank >= threshold } + .map { case (id, rank) => "%s\t%s\n".format(id, rank) } + .collect().mkString) + println(top) + + val time = (System.currentTimeMillis - startTime) / 1000.0 + println("Completed %d iterations in %f seconds: %f seconds per iteration" + .format(numIterations, time, time / numIterations)) + System.exit(0) + } + + def parseArticle(line: String): (String, Array[String]) = { + val fields = line.split("\t") + val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) + val id = new String(title) + val links = + if (body == "\\N") + NodeSeq.Empty + else + try { + XML.loadString(body) \\ "link" \ "target" + } catch { + case e: org.xml.sax.SAXParseException => + System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body) + NodeSeq.Empty + } + val outEdges = links.map(link => new String(link.text)).toArray + (id, outEdges) + } + + def pageRank( + links: RDD[(String, Array[String])], + numIterations: Int, + defaultRank: Double, + a: Double, + n: Long, + partitioner: Partitioner, + usePartitioner: Boolean, + numPartitions: Int + ): RDD[(String, Double)] = { + var ranks = links.mapValues { edges => 
defaultRank } + for (i <- 1 to numIterations) { + val contribs = links.groupWith(ranks).flatMap { + case (id, (linksWrapper, rankWrapper)) => + if (linksWrapper.length > 0) { + if (rankWrapper.length > 0) { + linksWrapper(0).map(dest => (dest, rankWrapper(0) / linksWrapper(0).size)) + } else { + linksWrapper(0).map(dest => (dest, defaultRank / linksWrapper(0).size)) + } + } else { + Array[(String, Double)]() + } + } + ranks = (contribs.combineByKey((x: Double) => x, + (x: Double, y: Double) => x + y, + (x: Double, y: Double) => x + y, + partitioner) + .mapValues(sum => a/n + (1-a)*sum)) + } + ranks + } +} + +class WPRSerializer extends org.apache.spark.serializer.Serializer { + def newInstance(): SerializerInstance = new WPRSerializerInstance() +} + +class WPRSerializerInstance extends SerializerInstance { + def serialize[T](t: T): ByteBuffer = { + throw new UnsupportedOperationException() + } + + def deserialize[T](bytes: ByteBuffer): T = { + throw new UnsupportedOperationException() + } + + def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { + throw new UnsupportedOperationException() + } + + def serializeStream(s: OutputStream): SerializationStream = { + new WPRSerializationStream(s) + } + + def deserializeStream(s: InputStream): DeserializationStream = { + new WPRDeserializationStream(s) + } +} + +class WPRSerializationStream(os: OutputStream) extends SerializationStream { + val dos = new DataOutputStream(os) + + def writeObject[T](t: T): SerializationStream = t match { + case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match { + case links: Array[String] => { + dos.writeInt(0) // links + dos.writeUTF(id) + dos.writeInt(links.length) + for (link <- links) { + dos.writeUTF(link) + } + this + } + case rank: Double => { + dos.writeInt(1) // rank + dos.writeUTF(id) + dos.writeDouble(rank) + this + } + } + case (id: String, rank: Double) => { + dos.writeInt(2) // rank without wrapper + dos.writeUTF(id) + dos.writeDouble(rank) + this + } + } + + def flush() { dos.flush() } + def close() { dos.close() } +} + +class WPRDeserializationStream(is: InputStream) extends DeserializationStream { + val dis = new DataInputStream(is) + + def readObject[T](): T = { + val typeId = dis.readInt() + typeId match { + case 0 => { + val id = dis.readUTF() + val numLinks = dis.readInt() + val links = new Array[String](numLinks) + for (i <- 0 until numLinks) { + val link = dis.readUTF() + links(i) = link + } + (id, ArrayBuffer(links)).asInstanceOf[T] + } + case 1 => { + val id = dis.readUTF() + val rank = dis.readDouble() + (id, ArrayBuffer(rank)).asInstanceOf[T] + } + case 2 => { + val id = dis.readUTF() + val rank = dis.readDouble() + (id, rank).asInstanceOf[T] + } + } + } + + def close() { dis.close() } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala new file mode 100644 index 0000000000..cd3423a07b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import scala.collection.mutable.LinkedList +import scala.util.Random + +import akka.actor.Actor +import akka.actor.ActorRef +import akka.actor.Props +import akka.actor.actorRef2Scala + +import org.apache.spark.streaming.Seconds +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions +import org.apache.spark.streaming.receivers.Receiver +import org.apache.spark.util.AkkaUtils + +case class SubscribeReceiver(receiverActor: ActorRef) +case class UnsubscribeReceiver(receiverActor: ActorRef) + +/** + * Sends the random content to every receiver subscribed with 1/2 + * second delay. + */ +class FeederActor extends Actor { + + val rand = new Random() + var receivers: LinkedList[ActorRef] = new LinkedList[ActorRef]() + + val strings: Array[String] = Array("words ", "may ", "count ") + + def makeMessage(): String = { + val x = rand.nextInt(3) + strings(x) + strings(2 - x) + } + + /* + * A thread to generate random messages + */ + new Thread() { + override def run() { + while (true) { + Thread.sleep(500) + receivers.foreach(_ ! makeMessage) + } + } + }.start() + + def receive: Receive = { + + case SubscribeReceiver(receiverActor: ActorRef) => + println("received subscribe from %s".format(receiverActor.toString)) + receivers = LinkedList(receiverActor) ++ receivers + + case UnsubscribeReceiver(receiverActor: ActorRef) => + println("received unsubscribe from %s".format(receiverActor.toString)) + receivers = receivers.dropWhile(x => x eq receiverActor) + + } +} + +/** + * A sample actor as receiver, is also simplest. This receiver actor + * goes and subscribe to a typical publisher/feeder actor and receives + * data. + * + * @see [[org.apache.spark.streaming.examples.FeederActor]] + */ +class SampleActorReceiver[T: ClassManifest](urlOfPublisher: String) +extends Actor with Receiver { + + lazy private val remotePublisher = context.actorFor(urlOfPublisher) + + override def preStart = remotePublisher ! SubscribeReceiver(context.self) + + def receive = { + case msg ⇒ context.parent ! pushBlock(msg.asInstanceOf[T]) + } + + override def postStop() = remotePublisher ! UnsubscribeReceiver(context.self) + +} + +/** + * A sample feeder actor + * + * Usage: FeederActor + * and describe the AkkaSystem that Spark Sample feeder would start on. + */ +object FeederActor { + + def main(args: Array[String]) { + if(args.length < 2){ + System.err.println( + "Usage: FeederActor \n" + ) + System.exit(1) + } + val Seq(host, port) = args.toSeq + + + val actorSystem = AkkaUtils.createActorSystem("test", host, port.toInt)._1 + val feeder = actorSystem.actorOf(Props[FeederActor], "FeederActor") + + println("Feeder started as:" + feeder) + + actorSystem.awaitTermination(); + } +} + +/** + * A sample word count program demonstrating the use of plugging in + * Actor as Receiver + * Usage: ActorWordCount + * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * and describe the AkkaSystem that Spark Sample feeder is running on. 
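+ *
+ * A custom receiver is just an actor mixed in with the Receiver trait that calls pushBlock on
+ * each message it wants to hand to Spark Streaming. A rough, illustrative sketch (the actor
+ * name below is made up, not part of this example):
+ * {{{
+ *   class MyStringReceiver extends Actor with Receiver {
+ *     def receive = { case s: String => pushBlock(s) }
+ *   }
+ *   // plugged in with: ssc.actorStream[String](Props(new MyStringReceiver), "MyStringReceiver")
+ * }}}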
+ * + * To run this example locally, you may run Feeder Actor as + * `$ ./run-example spark.streaming.examples.FeederActor 127.0.1.1 9999` + * and then run the example + * `$ ./run-example spark.streaming.examples.ActorWordCount local[2] 127.0.1.1 9999` + */ +object ActorWordCount { + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println( + "Usage: ActorWordCount " + + "In local mode, should be 'local[n]' with n > 1") + System.exit(1) + } + + val Seq(master, host, port) = args.toSeq + + // Create the context and set the batch size + val ssc = new StreamingContext(master, "ActorWordCount", Seconds(2), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + /* + * Following is the use of actorStream to plug in custom actor as receiver + * + * An important point to note: + * Since Actor may exist outside the spark framework, It is thus user's responsibility + * to ensure the type safety, i.e type of data received and InputDstream + * should be same. + * + * For example: Both actorStream and SampleActorReceiver are parameterized + * to same type to ensure type safety. + */ + + val lines = ssc.actorStream[String]( + Props(new SampleActorReceiver[String]("akka://test@%s:%s/user/FeederActor".format( + host, port.toInt))), "SampleReceiver") + + //compute wordcount + lines.flatMap(_.split("\\s+")).map(x => (x, 1)).reduceByKey(_ + _).print() + + ssc.start() + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/FlumeEventCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/FlumeEventCount.scala new file mode 100644 index 0000000000..9f6e163454 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/FlumeEventCount.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import org.apache.spark.util.IntParam +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming._ + +/** + * Produces a count of events received from Flume. + * + * This should be used in conjunction with an AvroSink in Flume. It will start + * an Avro server on at the request host:port address and listen for requests. + * Your Flume AvroSink should be pointed to this address. + * + * Usage: FlumeEventCount + * + * is a Spark master URL + * is the host the Flume receiver will be started on - a receiver + * creates a server and listens for flume events. + * is the port the Flume receiver will listen on. 
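+ *
+ * On the Flume side this typically means adding an avro sink to your agent that points at the
+ * same address; a minimal, illustrative snippet (the agent, sink and channel names are made up):
+ * {{{
+ *   agent.sinks.sparkSink.type = avro
+ *   agent.sinks.sparkSink.hostname = <host>
+ *   agent.sinks.sparkSink.port = <port>
+ *   agent.sinks.sparkSink.channel = memoryChannel
+ * }}}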
+ */ +object FlumeEventCount { + def main(args: Array[String]) { + if (args.length != 3) { + System.err.println( + "Usage: FlumeEventCount ") + System.exit(1) + } + + val Array(master, host, IntParam(port)) = args + + val batchInterval = Milliseconds(2000) + // Create the context and set the batch size + val ssc = new StreamingContext(master, "FlumeEventCount", batchInterval, + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + // Create a flume stream + val stream = ssc.flumeStream(host,port,StorageLevel.MEMORY_ONLY) + + // Print out the count of events received from this server in each batch + stream.count().map(cnt => "Received " + cnt + " flume events." ).print() + + ssc.start() + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/HdfsWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/HdfsWordCount.scala new file mode 100644 index 0000000000..bc8564b3ba --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/HdfsWordCount.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.streaming.StreamingContext._ + + +/** + * Counts words in new text files created in the given directory + * Usage: HdfsWordCount + * is the Spark master URL. + * is the directory that Spark Streaming will use to find and read new text files. + * + * To run this on your local machine on directory `localdir`, run this example + * `$ ./run-example spark.streaming.examples.HdfsWordCount local[2] localdir` + * Then create a text file in `localdir` and the words in the file will get counted. 
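+ * e.g. `$ echo "hello hello streaming world" > localdir/file1.txt` (the file name is arbitrary).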
+ */ +object HdfsWordCount { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: HdfsWordCount ") + System.exit(1) + } + + // Create the context + val ssc = new StreamingContext(args(0), "HdfsWordCount", Seconds(2), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + // Create the FileInputDStream on the directory and use the + // stream to count words in new files created + val lines = ssc.textFileStream(args(1)) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + } +} + diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala new file mode 100644 index 0000000000..12f939d5a7 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import java.util.Properties +import kafka.message.Message +import kafka.producer.SyncProducerConfig +import kafka.producer._ +import org.apache.spark.SparkContext +import org.apache.spark.streaming._ +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.util.RawTextHelper._ + +/** + * Consumes messages from one or more topics in Kafka and does wordcount. + * Usage: KafkaWordCount + * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * is a list of one or more zookeeper servers that make quorum + * is the name of kafka consumer group + * is a list of one or more kafka topics to consume from + * is the number of threads the kafka consumer should use + * + * Example: + * `./run-example spark.streaming.examples.KafkaWordCount local[2] zoo01,zoo02,zoo03 my-consumer-group topic1,topic2 1` + */ +object KafkaWordCount { + def main(args: Array[String]) { + + if (args.length < 5) { + System.err.println("Usage: KafkaWordCount ") + System.exit(1) + } + + val Array(master, zkQuorum, group, topics, numThreads) = args + + val ssc = new StreamingContext(master, "KafkaWordCount", Seconds(2), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + ssc.checkpoint("checkpoint") + + val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap + val lines = ssc.kafkaStream(zkQuorum, group, topicpMap) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2) + wordCounts.print() + + ssc.start() + } +} + +// Produces some random words between 1 and 100. 
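+
+// The windowed count in KafkaWordCount above is incremental: with an inverse (subtract) function,
+// reduceByKeyAndWindow only folds in the batches entering the 10-minute window and folds out the
+// batches leaving it, rather than recomputing the whole window every 2 seconds (this is also why
+// the checkpoint directory is required). A rough sketch of the two helpers assumed from
+// RawTextHelper (illustrative; the actual implementation may differ):
+//   def add(v1: Long, v2: Long): Long = v1 + v2
+//   def subtract(v1: Long, v2: Long): Long = v1 - v2
+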
+object KafkaWordCountProducer { + + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: KafkaWordCountProducer ") + System.exit(1) + } + + val Array(zkQuorum, topic, messagesPerSec, wordsPerMessage) = args + + // Zookeper connection properties + val props = new Properties() + props.put("zk.connect", zkQuorum) + props.put("serializer.class", "kafka.serializer.StringEncoder") + + val config = new ProducerConfig(props) + val producer = new Producer[String, String](config) + + // Send some messages + while(true) { + val messages = (1 to messagesPerSec.toInt).map { messageNum => + (1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString).mkString(" ") + }.toArray + println(messages.mkString(",")) + val data = new ProducerData[String, String](topic, messages) + producer.send(data) + Thread.sleep(100) + } + } + +} + diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala new file mode 100644 index 0000000000..e2487dca5f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.streaming.StreamingContext._ + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + * Usage: NetworkWordCount + * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * and describe the TCP server that Spark Streaming would connect to receive data. + * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ ./run-example spark.streaming.examples.NetworkWordCount local[2] localhost 9999` + */ +object NetworkWordCount { + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: NetworkWordCount \n" + + "In local mode, should be 'local[n]' with n > 1") + System.exit(1) + } + + // Create the context with a 1 second batch size + val ssc = new StreamingContext(args(0), "NetworkWordCount", Seconds(1), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + // Create a NetworkInputDStream on target ip:port and count the + // words in input stream of \n delimited test (eg. 
generated by 'nc') + val lines = ssc.socketTextStream(args(1), args(2).toInt) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala new file mode 100644 index 0000000000..822da8c9b5 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import org.apache.spark.RDD +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.streaming.StreamingContext._ + +import scala.collection.mutable.SynchronizedQueue + +object QueueStream { + + def main(args: Array[String]) { + if (args.length < 1) { + System.err.println("Usage: QueueStream ") + System.exit(1) + } + + // Create the context + val ssc = new StreamingContext(args(0), "QueueStream", Seconds(1), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + // Create the queue through which RDDs can be pushed to + // a QueueInputDStream + val rddQueue = new SynchronizedQueue[RDD[Int]]() + + // Create the QueueInputDStream and use it do some processing + val inputStream = ssc.queueStream(rddQueue) + val mappedStream = inputStream.map(x => (x % 10, 1)) + val reducedStream = mappedStream.reduceByKey(_ + _) + reducedStream.print() + ssc.start() + + // Create and push some RDDs into + for (i <- 1 to 30) { + rddQueue += ssc.sparkContext.makeRDD(1 to 1000, 10) + Thread.sleep(1000) + } + ssc.stop() + System.exit(0) + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala new file mode 100644 index 0000000000..2e3d9ccf00 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import org.apache.spark.util.IntParam +import org.apache.spark.storage.StorageLevel + +import org.apache.spark.streaming._ +import org.apache.spark.streaming.util.RawTextHelper + +/** + * Receives text from multiple rawNetworkStreams and counts how many '\n' delimited + * lines have the word 'the' in them. This is useful for benchmarking purposes. This + * will only work with spark.streaming.util.RawTextSender running on all worker nodes + * and with Spark using Kryo serialization (set Java property "spark.serializer" to + * "org.apache.spark.KryoSerializer"). + * Usage: RawNetworkGrep + * is the Spark master URL + * is the number rawNetworkStreams, which should be same as number + * of work nodes in the cluster + * is "localhost". + * is the port on which RawTextSender is running in the worker nodes. + * is the Spark Streaming batch duration in milliseconds. + */ + +object RawNetworkGrep { + def main(args: Array[String]) { + if (args.length != 5) { + System.err.println("Usage: RawNetworkGrep ") + System.exit(1) + } + + val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args + + // Create the context + val ssc = new StreamingContext(master, "RawNetworkGrep", Milliseconds(batchMillis), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + // Warm up the JVMs on master and slave for JIT compilation to kick in + RawTextHelper.warmUp(ssc.sparkContext) + + val rawStreams = (1 to numStreams).map(_ => + ssc.rawSocketStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray + val union = ssc.union(rawStreams) + union.filter(_.contains("the")).count().foreach(r => + println("Grep count: " + r.collect().mkString)) + ssc.start() + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/StatefulNetworkWordCount.scala new file mode 100644 index 0000000000..cb30c4edb3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/StatefulNetworkWordCount.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import org.apache.spark.streaming._ +import org.apache.spark.streaming.StreamingContext._ + +/** + * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every second. + * Usage: StatefulNetworkWordCount + * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. + * and describe the TCP server that Spark Streaming would connect to receive data. 
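+ *
+ * The running total per word is maintained with updateStateByKey: the update function receives
+ * all the new counts for a key in the current batch together with the previous total, e.g.
+ * {{{
+ *   // three new occurrences of a word whose previous total was 5 (illustrative values):
+ *   // updateFunc(Seq(1, 1, 1), Some(5)) => Some(8)
+ * }}}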
+ * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ ./run-example spark.streaming.examples.StatefulNetworkWordCount local[2] localhost 9999` + */ +object StatefulNetworkWordCount { + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: StatefulNetworkWordCount \n" + + "In local mode, should be 'local[n]' with n > 1") + System.exit(1) + } + + val updateFunc = (values: Seq[Int], state: Option[Int]) => { + val currentCount = values.foldLeft(0)(_ + _) + + val previousCount = state.getOrElse(0) + + Some(currentCount + previousCount) + } + + // Create the context with a 1 second batch size + val ssc = new StreamingContext(args(0), "NetworkWordCumulativeCountUpdateStateByKey", Seconds(1), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + ssc.checkpoint(".") + + // Create a NetworkInputDStream on target ip:port and count the + // words in input stream of \n delimited test (eg. generated by 'nc') + val lines = ssc.socketTextStream(args(1), args(2).toInt) + val words = lines.flatMap(_.split(" ")) + val wordDstream = words.map(x => (x, 1)) + + // Update the cumulative count using updateStateByKey + // This will give a Dstream made of state (which is the cumulative count of the words) + val stateDstream = wordDstream.updateStateByKey[Int](updateFunc) + stateDstream.print() + ssc.start() + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdCMS.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdCMS.scala new file mode 100644 index 0000000000..35b6329ab3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdCMS.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.storage.StorageLevel +import com.twitter.algebird._ +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.SparkContext._ + +/** + * Illustrates the use of the Count-Min Sketch, from Twitter's Algebird library, to compute + * windowed and global Top-K estimates of user IDs occurring in a Twitter stream. + *
    + * Note that since Algebird's implementation currently only supports Long inputs, + * the example operates on Long IDs. Once the implementation supports other input types (such as String), + * the same approach could be used, for example, to compute popular topics. + *

    + *

    + * + * This blog post has a good overview of the Count-Min Sketch (CMS). The CMS is a data structure + * for approximate frequency estimation in data streams (e.g. Top-K elements, the frequency of any given element, etc.) + * that uses space sub-linear in the number of elements in the stream. Once elements are added to the CMS, the + * estimated count of an element can be computed, as well as the "heavy hitters" that occur more than a threshold + * percentage of the overall total count. + *
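+ * A rough sketch of how the sketches are built and merged in this example (illustrative only;
+ * EPS, DELTA, SEED and PERC are the parameters defined in the code below):
+ * {{{
+ *   val cms = new CountMinSketchMonoid(EPS, DELTA, SEED, PERC)
+ *   val sketch = cms.create(1L) ++ cms.create(2L) ++ cms.create(2L)  // one CMS per ID, merged with ++
+ *   sketch.frequency(2L).estimate                                    // approximately 2
+ * }}}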

    + * Algebird's implementation is a monoid, so we can succinctly merge two CMS instances in the reduce operation. + */ +object TwitterAlgebirdCMS { + def main(args: Array[String]) { + if (args.length < 1) { + System.err.println("Usage: TwitterAlgebirdCMS " + + " [filter1] [filter2] ... [filter n]") + System.exit(1) + } + + // CMS parameters + val DELTA = 1E-3 + val EPS = 0.01 + val SEED = 1 + val PERC = 0.001 + // K highest frequency elements to take + val TOPK = 10 + + val (master, filters) = (args.head, args.tail) + + val ssc = new StreamingContext(master, "TwitterAlgebirdCMS", Seconds(10), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val stream = ssc.twitterStream(None, filters, StorageLevel.MEMORY_ONLY_SER) + + val users = stream.map(status => status.getUser.getId) + + val cms = new CountMinSketchMonoid(EPS, DELTA, SEED, PERC) + var globalCMS = cms.zero + val mm = new MapMonoid[Long, Int]() + var globalExact = Map[Long, Int]() + + val approxTopUsers = users.mapPartitions(ids => { + ids.map(id => cms.create(id)) + }).reduce(_ ++ _) + + val exactTopUsers = users.map(id => (id, 1)) + .reduceByKey((a, b) => a + b) + + approxTopUsers.foreach(rdd => { + if (rdd.count() != 0) { + val partial = rdd.first() + val partialTopK = partial.heavyHitters.map(id => + (id, partial.frequency(id).estimate)).toSeq.sortBy(_._2).reverse.slice(0, TOPK) + globalCMS ++= partial + val globalTopK = globalCMS.heavyHitters.map(id => + (id, globalCMS.frequency(id).estimate)).toSeq.sortBy(_._2).reverse.slice(0, TOPK) + println("Approx heavy hitters at %2.2f%% threshold this batch: %s".format(PERC, + partialTopK.mkString("[", ",", "]"))) + println("Approx heavy hitters at %2.2f%% threshold overall: %s".format(PERC, + globalTopK.mkString("[", ",", "]"))) + } + }) + + exactTopUsers.foreach(rdd => { + if (rdd.count() != 0) { + val partialMap = rdd.collect().toMap + val partialTopK = rdd.map( + {case (id, count) => (count, id)}) + .sortByKey(ascending = false).take(TOPK) + globalExact = mm.plus(globalExact.toMap, partialMap) + val globalTopK = globalExact.toSeq.sortBy(_._2).reverse.slice(0, TOPK) + println("Exact heavy hitters this batch: %s".format(partialTopK.mkString("[", ",", "]"))) + println("Exact heavy hitters overall: %s".format(globalTopK.mkString("[", ",", "]"))) + } + }) + + ssc.start() + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdHLL.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdHLL.scala new file mode 100644 index 0000000000..8bfde2a829 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdHLL.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.examples + +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.storage.StorageLevel +import com.twitter.algebird.HyperLogLog._ +import com.twitter.algebird.HyperLogLogMonoid +import org.apache.spark.streaming.dstream.TwitterInputDStream + +/** + * Illustrates the use of the HyperLogLog algorithm, from Twitter's Algebird library, to compute + * a windowed and global estimate of the unique user IDs occurring in a Twitter stream. + *

    + *

    + * This + * blog post and this + * blog post + * have good overviews of HyperLogLog (HLL). HLL is a memory-efficient data structure for estimating + * the cardinality of a data stream, i.e. the number of unique elements. + *
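+ * A rough sketch of how the per-ID HLLs are created and merged in this example (illustrative
+ * only; BIT_SIZE is the parameter defined in the code below):
+ * {{{
+ *   val hll = new HyperLogLogMonoid(BIT_SIZE)
+ *   val merged = hll(1L) + hll(2L) + hll(2L)   // one HLL per user ID, merged with +
+ *   merged.estimatedSize                       // approximately 2.0 distinct IDs
+ * }}}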

    + * Algebird's implementation is a monoid, so we can succinctly merge two HLL instances in the reduce operation. + */ +object TwitterAlgebirdHLL { + def main(args: Array[String]) { + if (args.length < 1) { + System.err.println("Usage: TwitterAlgebirdHLL " + + " [filter1] [filter2] ... [filter n]") + System.exit(1) + } + + /** Bit size parameter for HyperLogLog, trades off accuracy vs size */ + val BIT_SIZE = 12 + val (master, filters) = (args.head, args.tail) + + val ssc = new StreamingContext(master, "TwitterAlgebirdHLL", Seconds(5), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val stream = ssc.twitterStream(None, filters, StorageLevel.MEMORY_ONLY_SER) + + val users = stream.map(status => status.getUser.getId) + + val hll = new HyperLogLogMonoid(BIT_SIZE) + var globalHll = hll.zero + var userSet: Set[Long] = Set() + + val approxUsers = users.mapPartitions(ids => { + ids.map(id => hll(id)) + }).reduce(_ + _) + + val exactUsers = users.map(id => Set(id)).reduce(_ ++ _) + + approxUsers.foreach(rdd => { + if (rdd.count() != 0) { + val partial = rdd.first() + globalHll += partial + println("Approx distinct users this batch: %d".format(partial.estimatedSize.toInt)) + println("Approx distinct users overall: %d".format(globalHll.estimatedSize.toInt)) + } + }) + + exactUsers.foreach(rdd => { + if (rdd.count() != 0) { + val partial = rdd.first() + userSet ++= partial + println("Exact distinct users this batch: %d".format(partial.size)) + println("Exact distinct users overall: %d".format(userSet.size)) + println("Error rate: %2.5f%%".format(((globalHll.estimatedSize / userSet.size.toDouble) - 1) * 100)) + } + }) + + ssc.start() + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterPopularTags.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterPopularTags.scala new file mode 100644 index 0000000000..27aa6b14bf --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterPopularTags.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import org.apache.spark.streaming.{Seconds, StreamingContext} +import StreamingContext._ +import org.apache.spark.SparkContext._ + +/** + * Calculates popular hashtags (topics) over sliding 10 and 60 second windows from a Twitter + * stream. The stream is instantiated with credentials and optionally filters supplied by the + * command line arguments. + * + */ +object TwitterPopularTags { + def main(args: Array[String]) { + if (args.length < 1) { + System.err.println("Usage: TwitterPopularTags " + + " [filter1] [filter2] ... 
[filter n]") + System.exit(1) + } + + val (master, filters) = (args.head, args.tail) + + val ssc = new StreamingContext(master, "TwitterPopularTags", Seconds(2), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val stream = ssc.twitterStream(None, filters) + + val hashTags = stream.flatMap(status => status.getText.split(" ").filter(_.startsWith("#"))) + + val topCounts60 = hashTags.map((_, 1)).reduceByKeyAndWindow(_ + _, Seconds(60)) + .map{case (topic, count) => (count, topic)} + .transform(_.sortByKey(false)) + + val topCounts10 = hashTags.map((_, 1)).reduceByKeyAndWindow(_ + _, Seconds(10)) + .map{case (topic, count) => (count, topic)} + .transform(_.sortByKey(false)) + + + // Print popular hashtags + topCounts60.foreach(rdd => { + val topList = rdd.take(5) + println("\nPopular topics in last 60 seconds (%s total):".format(rdd.count())) + topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} + }) + + topCounts10.foreach(rdd => { + val topList = rdd.take(5) + println("\nPopular topics in last 10 seconds (%s total):".format(rdd.count())) + topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} + }) + + ssc.start() + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala new file mode 100644 index 0000000000..c8743b9e25 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import akka.actor.ActorSystem +import akka.actor.actorRef2Scala +import akka.zeromq._ +import org.apache.spark.streaming.{ Seconds, StreamingContext } +import org.apache.spark.streaming.StreamingContext._ +import akka.zeromq.Subscribe + +/** + * A simple publisher for demonstration purposes, repeatedly publishes random Messages + * every one second. + */ +object SimpleZeroMQPublisher { + + def main(args: Array[String]) = { + if (args.length < 2) { + System.err.println("Usage: SimpleZeroMQPublisher ") + System.exit(1) + } + + val Seq(url, topic) = args.toSeq + val acs: ActorSystem = ActorSystem() + + val pubSocket = ZeroMQExtension(acs).newSocket(SocketType.Pub, Bind(url)) + val messages: Array[String] = Array("words ", "may ", "count ") + while (true) { + Thread.sleep(1000) + pubSocket ! ZMQMessage(Frame(topic) :: messages.map(x => Frame(x.getBytes)).toList) + } + acs.awaitTermination() + } +} + +/** + * A sample wordcount with ZeroMQStream stream + * + * To work with zeroMQ, some native libraries have to be installed. + * Install zeroMQ (release 2.1) core libraries. 
[ZeroMQ Install guide](http://www.zeromq.org/intro:get-the-software) + * + * Usage: ZeroMQWordCount + * In local mode, should be 'local[n]' with n > 1 + * and describe where zeroMq publisher is running. + * + * To run this example locally, you may run publisher as + * `$ ./run-example spark.streaming.examples.SimpleZeroMQPublisher tcp://127.0.1.1:1234 foo.bar` + * and run the example as + * `$ ./run-example spark.streaming.examples.ZeroMQWordCount local[2] tcp://127.0.1.1:1234 foo` + */ +object ZeroMQWordCount { + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println( + "Usage: ZeroMQWordCount " + + "In local mode, should be 'local[n]' with n > 1") + System.exit(1) + } + val Seq(master, url, topic) = args.toSeq + + // Create the context and set the batch size + val ssc = new StreamingContext(master, "ZeroMQWordCount", Seconds(2), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + def bytesToStringIterator(x: Seq[Seq[Byte]]) = (x.map(x => new String(x.toArray))).iterator + + //For this stream, a zeroMQ publisher should be running. + val lines = ssc.zeroMQStream(url, Subscribe(topic), bytesToStringIterator) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + } + +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala new file mode 100644 index 0000000000..884d6d6f34 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples.clickstream + +import java.net.{InetAddress,ServerSocket,Socket,SocketException} +import java.io.{InputStreamReader, BufferedReader, PrintWriter} +import util.Random + +/** Represents a page view on a website with associated dimension data.*/ +class PageView(val url : String, val status : Int, val zipCode : Int, val userID : Int) { + override def toString() : String = { + "%s\t%s\t%s\t%s\n".format(url, status, zipCode, userID) + } +} +object PageView { + def fromString(in : String) : PageView = { + val parts = in.split("\t") + new PageView(parts(0), parts(1).toInt, parts(2).toInt, parts(3).toInt) + } +} + +/** Generates streaming events to simulate page views on a website. + * + * This should be used in tandem with PageViewStream.scala. 
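+ * Each generated event is a single tab-separated line of the form url, status, zipCode, userID
+ * (see PageView.toString and PageView.fromString above), e.g. an illustrative record:
+ * {{{
+ *   http://foo.com/news   404   94117   42
+ * }}}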
Example: + * $ ./run-example spark.streaming.examples.clickstream.PageViewGenerator 44444 10 + * $ ./run-example spark.streaming.examples.clickstream.PageViewStream errorRatePerZipCode localhost 44444 + * */ +object PageViewGenerator { + val pages = Map("http://foo.com/" -> .7, + "http://foo.com/news" -> 0.2, + "http://foo.com/contact" -> .1) + val httpStatus = Map(200 -> .95, + 404 -> .05) + val userZipCode = Map(94709 -> .5, + 94117 -> .5) + val userID = Map((1 to 100).map(_ -> .01):_*) + + + def pickFromDistribution[T](inputMap : Map[T, Double]) : T = { + val rand = new Random().nextDouble() + var total = 0.0 + for ((item, prob) <- inputMap) { + total = total + prob + if (total > rand) { + return item + } + } + return inputMap.take(1).head._1 // Shouldn't get here if probabilities add up to 1.0 + } + + def getNextClickEvent() : String = { + val id = pickFromDistribution(userID) + val page = pickFromDistribution(pages) + val status = pickFromDistribution(httpStatus) + val zipCode = pickFromDistribution(userZipCode) + new PageView(page, status, zipCode, id).toString() + } + + def main(args : Array[String]) { + if (args.length != 2) { + System.err.println("Usage: PageViewGenerator ") + System.exit(1) + } + val port = args(0).toInt + val viewsPerSecond = args(1).toFloat + val sleepDelayMs = (1000.0 / viewsPerSecond).toInt + val listener = new ServerSocket(port) + println("Listening on port: " + port) + + while (true) { + val socket = listener.accept() + new Thread() { + override def run = { + println("Got client connected from: " + socket.getInetAddress) + val out = new PrintWriter(socket.getOutputStream(), true) + + while (true) { + Thread.sleep(sleepDelayMs) + out.write(getNextClickEvent()) + out.flush() + } + socket.close() + } + }.start() + } + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewStream.scala new file mode 100644 index 0000000000..8282cc9269 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewStream.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples.clickstream + +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.SparkContext._ + +/** Analyses a streaming dataset of web page views. This class demonstrates several types of + * operators available in Spark streaming. + * + * This should be used in tandem with PageViewStream.scala. 
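+ *
+ * For instance, the errorRatePerZipCode metric windows the last 30 seconds of views, groups them
+ * by zip code and reports the fraction of non-200 statuses:
+ * {{{
+ *   // e.g. (illustrative values) zip 94117 with window statuses Seq(200, 404, 200, 200)
+ *   // => errorRatio = 1.0 / 4 = 0.25, which exceeds 0.05 and is printed as "94117: **0.25**"
+ * }}}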
Example: + * $ ./run-example spark.streaming.examples.clickstream.PageViewGenerator 44444 10 + * $ ./run-example spark.streaming.examples.clickstream.PageViewStream errorRatePerZipCode localhost 44444 + */ +object PageViewStream { + def main(args: Array[String]) { + if (args.length != 3) { + System.err.println("Usage: PageViewStream ") + System.err.println(" must be one of pageCounts, slidingPageCounts," + + " errorRatePerZipCode, activeUserCount, popularUsersSeen") + System.exit(1) + } + val metric = args(0) + val host = args(1) + val port = args(2).toInt + + // Create the context + val ssc = new StreamingContext("local[2]", "PageViewStream", Seconds(1), + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + + // Create a NetworkInputDStream on target host:port and convert each line to a PageView + val pageViews = ssc.socketTextStream(host, port) + .flatMap(_.split("\n")) + .map(PageView.fromString(_)) + + // Return a count of views per URL seen in each batch + val pageCounts = pageViews.map(view => view.url).countByValue() + + // Return a sliding window of page views per URL in the last ten seconds + val slidingPageCounts = pageViews.map(view => view.url) + .countByValueAndWindow(Seconds(10), Seconds(2)) + + + // Return the rate of error pages (a non 200 status) in each zip code over the last 30 seconds + val statusesPerZipCode = pageViews.window(Seconds(30), Seconds(2)) + .map(view => ((view.zipCode, view.status))) + .groupByKey() + val errorRatePerZipCode = statusesPerZipCode.map{ + case(zip, statuses) => + val normalCount = statuses.filter(_ == 200).size + val errorCount = statuses.size - normalCount + val errorRatio = errorCount.toFloat / statuses.size + if (errorRatio > 0.05) {"%s: **%s**".format(zip, errorRatio)} + else {"%s: %s".format(zip, errorRatio)} + } + + // Return the number unique users in last 15 seconds + val activeUserCount = pageViews.window(Seconds(15), Seconds(2)) + .map(view => (view.userID, 1)) + .groupByKey() + .count() + .map("Unique active users: " + _) + + // An external dataset we want to join to this stream + val userList = ssc.sparkContext.parallelize( + Map(1 -> "Patrick Wendell", 2->"Reynold Xin", 3->"Matei Zaharia").toSeq) + + metric match { + case "pageCounts" => pageCounts.print() + case "slidingPageCounts" => slidingPageCounts.print() + case "errorRatePerZipCode" => errorRatePerZipCode.print() + case "activeUserCount" => activeUserCount.print() + case "popularUsersSeen" => + // Look for users in our existing dataset and print it out if we have a match + pageViews.map(view => (view.userID, 1)) + .foreach((rdd, time) => rdd.join(userList) + .map(_._2._2) + .take(10) + .foreach(u => println("Saw user %s at time %s".format(u, time)))) + case _ => println("Invalid metric entered: " + metric) + } + + ssc.start() + } +} diff --git a/examples/src/main/scala/spark/examples/BroadcastTest.scala b/examples/src/main/scala/spark/examples/BroadcastTest.scala deleted file mode 100644 index 911490cb6c..0000000000 --- a/examples/src/main/scala/spark/examples/BroadcastTest.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import spark.SparkContext - -object BroadcastTest { - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: BroadcastTest [] [numElem]") - System.exit(1) - } - - val sc = new SparkContext(args(0), "Broadcast Test", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val slices = if (args.length > 1) args(1).toInt else 2 - val num = if (args.length > 2) args(2).toInt else 1000000 - - var arr1 = new Array[Int](num) - for (i <- 0 until arr1.length) { - arr1(i) = i - } - - for (i <- 0 until 2) { - println("Iteration " + i) - println("===========") - val barr1 = sc.broadcast(arr1) - sc.parallelize(1 to 10, slices).foreach { - i => println(barr1.value.size) - } - } - - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/examples/CassandraTest.scala b/examples/src/main/scala/spark/examples/CassandraTest.scala deleted file mode 100644 index 104bfd5204..0000000000 --- a/examples/src/main/scala/spark/examples/CassandraTest.scala +++ /dev/null @@ -1,213 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import org.apache.hadoop.mapreduce.Job -import org.apache.cassandra.hadoop.ColumnFamilyOutputFormat -import org.apache.cassandra.hadoop.ConfigHelper -import org.apache.cassandra.hadoop.ColumnFamilyInputFormat -import org.apache.cassandra.thrift._ -import spark.SparkContext -import spark.SparkContext._ -import java.nio.ByteBuffer -import java.util.SortedMap -import org.apache.cassandra.db.IColumn -import org.apache.cassandra.utils.ByteBufferUtil -import scala.collection.JavaConversions._ - - -/* - * This example demonstrates using Spark with Cassandra with the New Hadoop API and Cassandra - * support for Hadoop. - * - * To run this example, run this file with the following command params - - * - * - * So if you want to run this on localhost this will be, - * local[3] localhost 9160 - * - * The example makes some assumptions: - * 1. You have already created a keyspace called casDemo and it has a column family named Words - * 2. There are column family has a column named "para" which has test content. - * - * You can create the content by running the following script at the bottom of this file with - * cassandra-cli. 
- * - */ -object CassandraTest { - - def main(args: Array[String]) { - - // Get a SparkContext - val sc = new SparkContext(args(0), "casDemo") - - // Build the job configuration with ConfigHelper provided by Cassandra - val job = new Job() - job.setInputFormatClass(classOf[ColumnFamilyInputFormat]) - - val host: String = args(1) - val port: String = args(2) - - ConfigHelper.setInputInitialAddress(job.getConfiguration(), host) - ConfigHelper.setInputRpcPort(job.getConfiguration(), port) - ConfigHelper.setOutputInitialAddress(job.getConfiguration(), host) - ConfigHelper.setOutputRpcPort(job.getConfiguration(), port) - ConfigHelper.setInputColumnFamily(job.getConfiguration(), "casDemo", "Words") - ConfigHelper.setOutputColumnFamily(job.getConfiguration(), "casDemo", "WordCount") - - val predicate = new SlicePredicate() - val sliceRange = new SliceRange() - sliceRange.setStart(Array.empty[Byte]) - sliceRange.setFinish(Array.empty[Byte]) - predicate.setSlice_range(sliceRange) - ConfigHelper.setInputSlicePredicate(job.getConfiguration(), predicate) - - ConfigHelper.setInputPartitioner(job.getConfiguration(), "Murmur3Partitioner") - ConfigHelper.setOutputPartitioner(job.getConfiguration(), "Murmur3Partitioner") - - // Make a new Hadoop RDD - val casRdd = sc.newAPIHadoopRDD( - job.getConfiguration(), - classOf[ColumnFamilyInputFormat], - classOf[ByteBuffer], - classOf[SortedMap[ByteBuffer, IColumn]]) - - // Let us first get all the paragraphs from the retrieved rows - val paraRdd = casRdd.map { - case (key, value) => { - ByteBufferUtil.string(value.get(ByteBufferUtil.bytes("para")).value()) - } - } - - // Lets get the word count in paras - val counts = paraRdd.flatMap(p => p.split(" ")).map(word => (word, 1)).reduceByKey(_ + _) - - counts.collect().foreach { - case (word, count) => println(word + ":" + count) - } - - counts.map { - case (word, count) => { - val colWord = new org.apache.cassandra.thrift.Column() - colWord.setName(ByteBufferUtil.bytes("word")) - colWord.setValue(ByteBufferUtil.bytes(word)) - colWord.setTimestamp(System.currentTimeMillis) - - val colCount = new org.apache.cassandra.thrift.Column() - colCount.setName(ByteBufferUtil.bytes("wcount")) - colCount.setValue(ByteBufferUtil.bytes(count.toLong)) - colCount.setTimestamp(System.currentTimeMillis) - - val outputkey = ByteBufferUtil.bytes(word + "-COUNT-" + System.currentTimeMillis) - - val mutations: java.util.List[Mutation] = new Mutation() :: new Mutation() :: Nil - mutations.get(0).setColumn_or_supercolumn(new ColumnOrSuperColumn()) - mutations.get(0).column_or_supercolumn.setColumn(colWord) - mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn()) - mutations.get(1).column_or_supercolumn.setColumn(colCount) - (outputkey, mutations) - } - }.saveAsNewAPIHadoopFile("casDemo", classOf[ByteBuffer], classOf[List[Mutation]], - classOf[ColumnFamilyOutputFormat], job.getConfiguration) - } -} - -/* -create keyspace casDemo; -use casDemo; - -create column family WordCount with comparator = UTF8Type; -update column family WordCount with column_metadata = - [{column_name: word, validation_class: UTF8Type}, - {column_name: wcount, validation_class: LongType}]; - -create column family Words with comparator = UTF8Type; -update column family Words with column_metadata = - [{column_name: book, validation_class: UTF8Type}, - {column_name: para, validation_class: UTF8Type}]; - -assume Words keys as utf8; - -set Words['3musk001']['book'] = 'The Three Musketeers'; -set Words['3musk001']['para'] = 'On the first Monday of the month of 
April, 1625, the market - town of Meung, in which the author of ROMANCE OF THE ROSE was born, appeared to - be in as perfect a state of revolution as if the Huguenots had just made - a second La Rochelle of it. Many citizens, seeing the women flying - toward the High Street, leaving their children crying at the open doors, - hastened to don the cuirass, and supporting their somewhat uncertain - courage with a musket or a partisan, directed their steps toward the - hostelry of the Jolly Miller, before which was gathered, increasing - every minute, a compact group, vociferous and full of curiosity.'; - -set Words['3musk002']['book'] = 'The Three Musketeers'; -set Words['3musk002']['para'] = 'In those times panics were common, and few days passed without - some city or other registering in its archives an event of this kind. There were - nobles, who made war against each other; there was the king, who made - war against the cardinal; there was Spain, which made war against the - king. Then, in addition to these concealed or public, secret or open - wars, there were robbers, mendicants, Huguenots, wolves, and scoundrels, - who made war upon everybody. The citizens always took up arms readily - against thieves, wolves or scoundrels, often against nobles or - Huguenots, sometimes against the king, but never against cardinal or - Spain. It resulted, then, from this habit that on the said first Monday - of April, 1625, the citizens, on hearing the clamor, and seeing neither - the red-and-yellow standard nor the livery of the Duc de Richelieu, - rushed toward the hostel of the Jolly Miller. When arrived there, the - cause of the hubbub was apparent to all'; - -set Words['3musk003']['book'] = 'The Three Musketeers'; -set Words['3musk003']['para'] = 'You ought, I say, then, to husband the means you have, however - large the sum may be; but you ought also to endeavor to perfect yourself in - the exercises becoming a gentleman. I will write a letter today to the - Director of the Royal Academy, and tomorrow he will admit you without - any expense to yourself. Do not refuse this little service. Our - best-born and richest gentlemen sometimes solicit it without being able - to obtain it. You will learn horsemanship, swordsmanship in all its - branches, and dancing. You will make some desirable acquaintances; and - from time to time you can call upon me, just to tell me how you are - getting on, and to say whether I can be of further service to you.'; - - -set Words['thelostworld001']['book'] = 'The Lost World'; -set Words['thelostworld001']['para'] = 'She sat with that proud, delicate profile of hers outlined - against the red curtain. How beautiful she was! And yet how aloof! We had been - friends, quite good friends; but never could I get beyond the same - comradeship which I might have established with one of my - fellow-reporters upon the Gazette,--perfectly frank, perfectly kindly, - and perfectly unsexual. My instincts are all against a woman being too - frank and at her ease with me. It is no compliment to a man. Where - the real sex feeling begins, timidity and distrust are its companions, - heritage from old wicked days when love and violence went often hand in - hand. The bent head, the averted eye, the faltering voice, the wincing - figure--these, and not the unshrinking gaze and frank reply, are the - true signals of passion. 
Even in my short life I had learned as much - as that--or had inherited it in that race memory which we call instinct.'; - -set Words['thelostworld002']['book'] = 'The Lost World'; -set Words['thelostworld002']['para'] = 'I always liked McArdle, the crabbed, old, round-backed, - red-headed news editor, and I rather hoped that he liked me. Of course, Beaumont was - the real boss; but he lived in the rarefied atmosphere of some Olympian - height from which he could distinguish nothing smaller than an - international crisis or a split in the Cabinet. Sometimes we saw him - passing in lonely majesty to his inner sanctum, with his eyes staring - vaguely and his mind hovering over the Balkans or the Persian Gulf. He - was above and beyond us. But McArdle was his first lieutenant, and it - was he that we knew. The old man nodded as I entered the room, and he - pushed his spectacles far up on his bald forehead.'; - -*/ diff --git a/examples/src/main/scala/spark/examples/ExceptionHandlingTest.scala b/examples/src/main/scala/spark/examples/ExceptionHandlingTest.scala deleted file mode 100644 index 67ddaec8d2..0000000000 --- a/examples/src/main/scala/spark/examples/ExceptionHandlingTest.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import spark.SparkContext - -object ExceptionHandlingTest { - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: ExceptionHandlingTest ") - System.exit(1) - } - - val sc = new SparkContext(args(0), "ExceptionHandlingTest", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - sc.parallelize(0 until sc.defaultParallelism).foreach { i => - if (math.random > 0.75) - throw new Exception("Testing exception handling") - } - - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/examples/GroupByTest.scala b/examples/src/main/scala/spark/examples/GroupByTest.scala deleted file mode 100644 index 5cee413615..0000000000 --- a/examples/src/main/scala/spark/examples/GroupByTest.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import spark.SparkContext -import spark.SparkContext._ -import java.util.Random - -object GroupByTest { - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]") - System.exit(1) - } - - var numMappers = if (args.length > 1) args(1).toInt else 2 - var numKVPairs = if (args.length > 2) args(2).toInt else 1000 - var valSize = if (args.length > 3) args(3).toInt else 1000 - var numReducers = if (args.length > 4) args(4).toInt else numMappers - - val sc = new SparkContext(args(0), "GroupBy Test", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => - val ranGen = new Random - var arr1 = new Array[(Int, Array[Byte])](numKVPairs) - for (i <- 0 until numKVPairs) { - val byteArr = new Array[Byte](valSize) - ranGen.nextBytes(byteArr) - arr1(i) = (ranGen.nextInt(Int.MaxValue), byteArr) - } - arr1 - }.cache - // Enforce that everything has been calculated and in cache - pairs1.count - - println(pairs1.groupByKey(numReducers).count) - - System.exit(0) - } -} - diff --git a/examples/src/main/scala/spark/examples/HBaseTest.scala b/examples/src/main/scala/spark/examples/HBaseTest.scala deleted file mode 100644 index 4dd6c243ac..0000000000 --- a/examples/src/main/scala/spark/examples/HBaseTest.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import spark._ -import spark.rdd.NewHadoopRDD -import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor} -import org.apache.hadoop.hbase.client.HBaseAdmin -import org.apache.hadoop.hbase.mapreduce.TableInputFormat - -object HBaseTest { - def main(args: Array[String]) { - val sc = new SparkContext(args(0), "HBaseTest", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - val conf = HBaseConfiguration.create() - - // Other options for configuring scan behavior are available. 
More information available at - // http://hbase.apache.org/apidocs/org/apache/hadoop/hbase/mapreduce/TableInputFormat.html - conf.set(TableInputFormat.INPUT_TABLE, args(1)) - - // Initialize hBase table if necessary - val admin = new HBaseAdmin(conf) - if(!admin.isTableAvailable(args(1))) { - val tableDesc = new HTableDescriptor(args(1)) - admin.createTable(tableDesc) - } - - val hBaseRDD = sc.newAPIHadoopRDD(conf, classOf[TableInputFormat], - classOf[org.apache.hadoop.hbase.io.ImmutableBytesWritable], - classOf[org.apache.hadoop.hbase.client.Result]) - - hBaseRDD.count() - - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/examples/HdfsTest.scala b/examples/src/main/scala/spark/examples/HdfsTest.scala deleted file mode 100644 index 23258336e2..0000000000 --- a/examples/src/main/scala/spark/examples/HdfsTest.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import spark._ - -object HdfsTest { - def main(args: Array[String]) { - val sc = new SparkContext(args(0), "HdfsTest", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val file = sc.textFile(args(1)) - val mapped = file.map(s => s.length).cache() - for (iter <- 1 to 10) { - val start = System.currentTimeMillis() - for (x <- mapped) { x + 2 } - // println("Processing: " + x) - val end = System.currentTimeMillis() - println("Iteration " + iter + " took " + (end-start) + " ms") - } - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/examples/LocalALS.scala b/examples/src/main/scala/spark/examples/LocalALS.scala deleted file mode 100644 index 7a449a9d72..0000000000 --- a/examples/src/main/scala/spark/examples/LocalALS.scala +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import scala.math.sqrt -import cern.jet.math._ -import cern.colt.matrix._ -import cern.colt.matrix.linalg._ - -/** - * Alternating least squares matrix factorization. 
- */ -object LocalALS { - // Parameters set through command line arguments - var M = 0 // Number of movies - var U = 0 // Number of users - var F = 0 // Number of features - var ITERATIONS = 0 - - val LAMBDA = 0.01 // Regularization coefficient - - // Some COLT objects - val factory2D = DoubleFactory2D.dense - val factory1D = DoubleFactory1D.dense - val algebra = Algebra.DEFAULT - val blas = SeqBlas.seqBlas - - def generateR(): DoubleMatrix2D = { - val mh = factory2D.random(M, F) - val uh = factory2D.random(U, F) - return algebra.mult(mh, algebra.transpose(uh)) - } - - def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D], - us: Array[DoubleMatrix1D]): Double = - { - val r = factory2D.make(M, U) - for (i <- 0 until M; j <- 0 until U) { - r.set(i, j, blas.ddot(ms(i), us(j))) - } - //println("R: " + r) - blas.daxpy(-1, targetR, r) - val sumSqs = r.aggregate(Functions.plus, Functions.square) - return sqrt(sumSqs / (M * U)) - } - - def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], - R: DoubleMatrix2D) : DoubleMatrix1D = - { - val XtX = factory2D.make(F, F) - val Xty = factory1D.make(F) - // For each user that rated the movie - for (j <- 0 until U) { - val u = us(j) - // Add u * u^t to XtX - blas.dger(1, u, u, XtX) - // Add u * rating to Xty - blas.daxpy(R.get(i, j), u, Xty) - } - // Add regularization coefs to diagonal terms - for (d <- 0 until F) { - XtX.set(d, d, XtX.get(d, d) + LAMBDA * U) - } - // Solve it with Cholesky - val ch = new CholeskyDecomposition(XtX) - val Xty2D = factory2D.make(Xty.toArray, F) - val solved2D = ch.solve(Xty2D) - return solved2D.viewColumn(0) - } - - def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D], - R: DoubleMatrix2D) : DoubleMatrix1D = - { - val XtX = factory2D.make(F, F) - val Xty = factory1D.make(F) - // For each movie that the user rated - for (i <- 0 until M) { - val m = ms(i) - // Add m * m^t to XtX - blas.dger(1, m, m, XtX) - // Add m * rating to Xty - blas.daxpy(R.get(i, j), m, Xty) - } - // Add regularization coefs to diagonal terms - for (d <- 0 until F) { - XtX.set(d, d, XtX.get(d, d) + LAMBDA * M) - } - // Solve it with Cholesky - val ch = new CholeskyDecomposition(XtX) - val Xty2D = factory2D.make(Xty.toArray, F) - val solved2D = ch.solve(Xty2D) - return solved2D.viewColumn(0) - } - - def main(args: Array[String]) { - args match { - case Array(m, u, f, iters) => { - M = m.toInt - U = u.toInt - F = f.toInt - ITERATIONS = iters.toInt - } - case _ => { - System.err.println("Usage: LocalALS ") - System.exit(1) - } - } - printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS); - - val R = generateR() - - // Initialize m and u randomly - var ms = Array.fill(M)(factory1D.random(F)) - var us = Array.fill(U)(factory1D.random(F)) - - // Iteratively update movies then users - for (iter <- 1 to ITERATIONS) { - println("Iteration " + iter + ":") - ms = (0 until M).map(i => updateMovie(i, ms(i), us, R)).toArray - us = (0 until U).map(j => updateUser(j, us(j), ms, R)).toArray - println("RMSE = " + rmse(R, ms, us)) - println() - } - } -} diff --git a/examples/src/main/scala/spark/examples/LocalFileLR.scala b/examples/src/main/scala/spark/examples/LocalFileLR.scala deleted file mode 100644 index c1f8d32aa8..0000000000 --- a/examples/src/main/scala/spark/examples/LocalFileLR.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import java.util.Random -import spark.util.Vector - -object LocalFileLR { - val D = 10 // Numer of dimensions - val rand = new Random(42) - - case class DataPoint(x: Vector, y: Double) - - def parsePoint(line: String): DataPoint = { - val nums = line.split(' ').map(_.toDouble) - return DataPoint(new Vector(nums.slice(1, D+1)), nums(0)) - } - - def main(args: Array[String]) { - val lines = scala.io.Source.fromFile(args(0)).getLines().toArray - val points = lines.map(parsePoint _) - val ITERATIONS = args(1).toInt - - // Initialize w to a random value - var w = Vector(D, _ => 2 * rand.nextDouble - 1) - println("Initial w: " + w) - - for (i <- 1 to ITERATIONS) { - println("On iteration " + i) - var gradient = Vector.zeros(D) - for (p <- points) { - val scale = (1 / (1 + math.exp(-p.y * (w dot p.x))) - 1) * p.y - gradient += scale * p.x - } - w -= gradient - } - - println("Final w: " + w) - } -} diff --git a/examples/src/main/scala/spark/examples/LocalKMeans.scala b/examples/src/main/scala/spark/examples/LocalKMeans.scala deleted file mode 100644 index 0a0bc6f476..0000000000 --- a/examples/src/main/scala/spark/examples/LocalKMeans.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import java.util.Random -import spark.util.Vector -import spark.SparkContext._ -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet - -/** - * K-means clustering. 
- */ -object LocalKMeans { - val N = 1000 - val R = 1000 // Scaling factor - val D = 10 - val K = 10 - val convergeDist = 0.001 - val rand = new Random(42) - - def generateData = { - def generatePoint(i: Int) = { - Vector(D, _ => rand.nextDouble * R) - } - Array.tabulate(N)(generatePoint) - } - - def closestPoint(p: Vector, centers: HashMap[Int, Vector]): Int = { - var index = 0 - var bestIndex = 0 - var closest = Double.PositiveInfinity - - for (i <- 1 to centers.size) { - val vCurr = centers.get(i).get - val tempDist = p.squaredDist(vCurr) - if (tempDist < closest) { - closest = tempDist - bestIndex = i - } - } - - return bestIndex - } - - def main(args: Array[String]) { - val data = generateData - var points = new HashSet[Vector] - var kPoints = new HashMap[Int, Vector] - var tempDist = 1.0 - - while (points.size < K) { - points.add(data(rand.nextInt(N))) - } - - val iter = points.iterator - for (i <- 1 to points.size) { - kPoints.put(i, iter.next()) - } - - println("Initial centers: " + kPoints) - - while(tempDist > convergeDist) { - var closest = data.map (p => (closestPoint(p, kPoints), (p, 1))) - - var mappings = closest.groupBy[Int] (x => x._1) - - var pointStats = mappings.map(pair => pair._2.reduceLeft [(Int, (Vector, Int))] {case ((id1, (x1, y1)), (id2, (x2, y2))) => (id1, (x1 + x2, y1+y2))}) - - var newPoints = pointStats.map {mapping => (mapping._1, mapping._2._1/mapping._2._2)} - - tempDist = 0.0 - for (mapping <- newPoints) { - tempDist += kPoints.get(mapping._1).get.squaredDist(mapping._2) - } - - for (newP <- newPoints) { - kPoints.put(newP._1, newP._2) - } - } - - println("Final centers: " + kPoints) - } -} diff --git a/examples/src/main/scala/spark/examples/LocalLR.scala b/examples/src/main/scala/spark/examples/LocalLR.scala deleted file mode 100644 index ab99bf1fbe..0000000000 --- a/examples/src/main/scala/spark/examples/LocalLR.scala +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import java.util.Random -import spark.util.Vector - -/** - * Logistic regression based classification. 
- */ -object LocalLR { - val N = 10000 // Number of data points - val D = 10 // Number of dimensions - val R = 0.7 // Scaling factor - val ITERATIONS = 5 - val rand = new Random(42) - - case class DataPoint(x: Vector, y: Double) - - def generateData = { - def generatePoint(i: Int) = { - val y = if(i % 2 == 0) -1 else 1 - val x = Vector(D, _ => rand.nextGaussian + y * R) - DataPoint(x, y) - } - Array.tabulate(N)(generatePoint) - } - - def main(args: Array[String]) { - val data = generateData - - // Initialize w to a random value - var w = Vector(D, _ => 2 * rand.nextDouble - 1) - println("Initial w: " + w) - - for (i <- 1 to ITERATIONS) { - println("On iteration " + i) - var gradient = Vector.zeros(D) - for (p <- data) { - val scale = (1 / (1 + math.exp(-p.y * (w dot p.x))) - 1) * p.y - gradient += scale * p.x - } - w -= gradient - } - - println("Final w: " + w) - } -} diff --git a/examples/src/main/scala/spark/examples/LocalPi.scala b/examples/src/main/scala/spark/examples/LocalPi.scala deleted file mode 100644 index ccd69695df..0000000000 --- a/examples/src/main/scala/spark/examples/LocalPi.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import scala.math.random -import spark._ -import SparkContext._ - -object LocalPi { - def main(args: Array[String]) { - var count = 0 - for (i <- 1 to 100000) { - val x = random * 2 - 1 - val y = random * 2 - 1 - if (x*x + y*y < 1) count += 1 - } - println("Pi is roughly " + 4 * count / 100000.0) - } -} diff --git a/examples/src/main/scala/spark/examples/LogQuery.scala b/examples/src/main/scala/spark/examples/LogQuery.scala deleted file mode 100644 index e815ececf7..0000000000 --- a/examples/src/main/scala/spark/examples/LogQuery.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import spark.SparkContext -import spark.SparkContext._ -/** - * Executes a roll up-style query against Apache logs. 
- */ -object LogQuery { - val exampleApacheLogs = List( - """10.10.10.10 - "FRED" [18/Jan/2013:17:56:07 +1100] "GET http://images.com/2013/Generic.jpg - | HTTP/1.1" 304 315 "http://referall.com/" "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; - | GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; .NET CLR - | 3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR - | 3.5.30729; Release=ARP)" "UD-1" - "image/jpeg" "whatever" 0.350 "-" - "" 265 923 934 "" - | 62.24.11.25 images.com 1358492167 - Whatup""".stripMargin.replace("\n", ""), - """10.10.10.10 - "FRED" [18/Jan/2013:18:02:37 +1100] "GET http://images.com/2013/Generic.jpg - | HTTP/1.1" 304 306 "http:/referall.com" "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; - | GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; .NET CLR - | 3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR - | 3.5.30729; Release=ARP)" "UD-1" - "image/jpeg" "whatever" 0.352 "-" - "" 256 977 988 "" - | 0 73.23.2.15 images.com 1358492557 - Whatup""".stripMargin.replace("\n", "") - ) - - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: LogQuery [logFile]") - System.exit(1) - } - - val sc = new SparkContext(args(0), "Log Query", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - val dataSet = - if (args.length == 2) sc.textFile(args(1)) - else sc.parallelize(exampleApacheLogs) - - val apacheLogRegex = - """^([\d.]+) (\S+) (\S+) \[([\w\d:/]+\s[+\-]\d{4})\] "(.+?)" (\d{3}) ([\d\-]+) "([^"]+)" "([^"]+)".*""".r - - /** Tracks the total query count and number of aggregate bytes for a particular group. */ - class Stats(val count: Int, val numBytes: Int) extends Serializable { - def merge(other: Stats) = new Stats(count + other.count, numBytes + other.numBytes) - override def toString = "bytes=%s\tn=%s".format(numBytes, count) - } - - def extractKey(line: String): (String, String, String) = { - apacheLogRegex.findFirstIn(line) match { - case Some(apacheLogRegex(ip, _, user, dateTime, query, status, bytes, referer, ua)) => - if (user != "\"-\"") (ip, user, query) - else (null, null, null) - case _ => (null, null, null) - } - } - - def extractStats(line: String): Stats = { - apacheLogRegex.findFirstIn(line) match { - case Some(apacheLogRegex(ip, _, user, dateTime, query, status, bytes, referer, ua)) => - new Stats(1, bytes.toInt) - case _ => new Stats(1, 0) - } - } - - dataSet.map(line => (extractKey(line), extractStats(line))) - .reduceByKey((a, b) => a.merge(b)) - .collect().foreach{ - case (user, query) => println("%s\t%s".format(user, query))} - } -} diff --git a/examples/src/main/scala/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/spark/examples/MultiBroadcastTest.scala deleted file mode 100644 index d0b1cf06e5..0000000000 --- a/examples/src/main/scala/spark/examples/MultiBroadcastTest.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import spark.SparkContext - -object MultiBroadcastTest { - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: MultiBroadcastTest <master> [<slices>] [numElem]") - System.exit(1) - } - - val sc = new SparkContext(args(0), "Broadcast Test", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - val slices = if (args.length > 1) args(1).toInt else 2 - val num = if (args.length > 2) args(2).toInt else 1000000 - - var arr1 = new Array[Int](num) - for (i <- 0 until arr1.length) { - arr1(i) = i - } - - var arr2 = new Array[Int](num) - for (i <- 0 until arr2.length) { - arr2(i) = i - } - - val barr1 = sc.broadcast(arr1) - val barr2 = sc.broadcast(arr2) - sc.parallelize(1 to 10, slices).foreach { - i => println(barr1.value.size + barr2.value.size) - } - - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/spark/examples/SimpleSkewedGroupByTest.scala deleted file mode 100644 index d197bbaf7c..0000000000 --- a/examples/src/main/scala/spark/examples/SimpleSkewedGroupByTest.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -package spark.examples - -import spark.SparkContext -import spark.SparkContext._ -import java.util.Random - -object SimpleSkewedGroupByTest { - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SimpleSkewedGroupByTest " + - "[numMappers] [numKVPairs] [valSize] [numReducers] [ratio]") - System.exit(1) - } - - var numMappers = if (args.length > 1) args(1).toInt else 2 - var numKVPairs = if (args.length > 2) args(2).toInt else 1000 - var valSize = if (args.length > 3) args(3).toInt else 1000 - var numReducers = if (args.length > 4) args(4).toInt else numMappers - var ratio = if (args.length > 5) args(5).toInt else 5.0 - - val sc = new SparkContext(args(0), "GroupBy Test", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => - val ranGen = new Random - var result = new Array[(Int, Array[Byte])](numKVPairs) - for (i <- 0 until numKVPairs) { - val byteArr = new Array[Byte](valSize) - ranGen.nextBytes(byteArr) - val offset = ranGen.nextInt(1000) * numReducers - if (ranGen.nextDouble < ratio / (numReducers + ratio - 1)) { - // give ratio times higher chance of generating key 0 (for reducer 0) - result(i) = (offset, byteArr) - } else { - // generate a key for one of the other reducers - val key = 1 + ranGen.nextInt(numReducers-1) + offset - result(i) = (key, byteArr) - } - } - result - }.cache - // Enforce that everything has been calculated and in cache - pairs1.count - - println("RESULT: " + pairs1.groupByKey(numReducers).count) - // Print how many keys each reducer got (for debugging) - //println("RESULT: " + pairs1.groupByKey(numReducers) - // .map{case (k,v) => (k, v.size)} - // .collectAsMap) - - System.exit(0) - } -} - diff --git a/examples/src/main/scala/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/spark/examples/SkewedGroupByTest.scala deleted file mode 100644 index 4641b82444..0000000000 --- a/examples/src/main/scala/spark/examples/SkewedGroupByTest.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.examples - -import spark.SparkContext -import spark.SparkContext._ -import java.util.Random - -object SkewedGroupByTest { - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]") - System.exit(1) - } - - var numMappers = if (args.length > 1) args(1).toInt else 2 - var numKVPairs = if (args.length > 2) args(2).toInt else 1000 - var valSize = if (args.length > 3) args(3).toInt else 1000 - var numReducers = if (args.length > 4) args(4).toInt else numMappers - - val sc = new SparkContext(args(0), "GroupBy Test", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => - val ranGen = new Random - - // map output sizes lineraly increase from the 1st to the last - numKVPairs = (1.0 * (p + 1) / numMappers * numKVPairs).toInt - - var arr1 = new Array[(Int, Array[Byte])](numKVPairs) - for (i <- 0 until numKVPairs) { - val byteArr = new Array[Byte](valSize) - ranGen.nextBytes(byteArr) - arr1(i) = (ranGen.nextInt(Int.MaxValue), byteArr) - } - arr1 - }.cache() - // Enforce that everything has been calculated and in cache - pairs1.count() - - println(pairs1.groupByKey(numReducers).count()) - - System.exit(0) - } -} - diff --git a/examples/src/main/scala/spark/examples/SparkALS.scala b/examples/src/main/scala/spark/examples/SparkALS.scala deleted file mode 100644 index ba0dfd8f9b..0000000000 --- a/examples/src/main/scala/spark/examples/SparkALS.scala +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import scala.math.sqrt -import cern.jet.math._ -import cern.colt.matrix._ -import cern.colt.matrix.linalg._ -import spark._ - -/** - * Alternating least squares matrix factorization. 
- */ -object SparkALS { - // Parameters set through command line arguments - var M = 0 // Number of movies - var U = 0 // Number of users - var F = 0 // Number of features - var ITERATIONS = 0 - - val LAMBDA = 0.01 // Regularization coefficient - - // Some COLT objects - val factory2D = DoubleFactory2D.dense - val factory1D = DoubleFactory1D.dense - val algebra = Algebra.DEFAULT - val blas = SeqBlas.seqBlas - - def generateR(): DoubleMatrix2D = { - val mh = factory2D.random(M, F) - val uh = factory2D.random(U, F) - return algebra.mult(mh, algebra.transpose(uh)) - } - - def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D], - us: Array[DoubleMatrix1D]): Double = - { - val r = factory2D.make(M, U) - for (i <- 0 until M; j <- 0 until U) { - r.set(i, j, blas.ddot(ms(i), us(j))) - } - //println("R: " + r) - blas.daxpy(-1, targetR, r) - val sumSqs = r.aggregate(Functions.plus, Functions.square) - return sqrt(sumSqs / (M * U)) - } - - def update(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], - R: DoubleMatrix2D) : DoubleMatrix1D = - { - val U = us.size - val F = us(0).size - val XtX = factory2D.make(F, F) - val Xty = factory1D.make(F) - // For each user that rated the movie - for (j <- 0 until U) { - val u = us(j) - // Add u * u^t to XtX - blas.dger(1, u, u, XtX) - // Add u * rating to Xty - blas.daxpy(R.get(i, j), u, Xty) - } - // Add regularization coefs to diagonal terms - for (d <- 0 until F) { - XtX.set(d, d, XtX.get(d, d) + LAMBDA * U) - } - // Solve it with Cholesky - val ch = new CholeskyDecomposition(XtX) - val Xty2D = factory2D.make(Xty.toArray, F) - val solved2D = ch.solve(Xty2D) - return solved2D.viewColumn(0) - } - - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SparkALS [ ]") - System.exit(1) - } - - var host = "" - var slices = 0 - - val options = (0 to 5).map(i => if (i < args.length) Some(args(i)) else None) - - options.toArray match { - case Array(host_, m, u, f, iters, slices_) => - host = host_.get - M = m.getOrElse("100").toInt - U = u.getOrElse("500").toInt - F = f.getOrElse("10").toInt - ITERATIONS = iters.getOrElse("5").toInt - slices = slices_.getOrElse("2").toInt - case _ => - System.err.println("Usage: SparkALS [ ]") - System.exit(1) - } - printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS) - - val sc = new SparkContext(host, "SparkALS", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - val R = generateR() - - // Initialize m and u randomly - var ms = Array.fill(M)(factory1D.random(F)) - var us = Array.fill(U)(factory1D.random(F)) - - // Iteratively update movies then users - val Rc = sc.broadcast(R) - var msb = sc.broadcast(ms) - var usb = sc.broadcast(us) - for (iter <- 1 to ITERATIONS) { - println("Iteration " + iter + ":") - ms = sc.parallelize(0 until M, slices) - .map(i => update(i, msb.value(i), usb.value, Rc.value)) - .toArray - msb = sc.broadcast(ms) // Re-broadcast ms because it was updated - us = sc.parallelize(0 until U, slices) - .map(i => update(i, usb.value(i), msb.value, algebra.transpose(Rc.value))) - .toArray - usb = sc.broadcast(us) // Re-broadcast us because it was updated - println("RMSE = " + rmse(R, ms, us)) - println() - } - - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/spark/examples/SparkHdfsLR.scala deleted file mode 100644 index 43c9115664..0000000000 --- a/examples/src/main/scala/spark/examples/SparkHdfsLR.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the 
Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import java.util.Random -import scala.math.exp -import spark.util.Vector -import spark._ -import spark.scheduler.InputFormatInfo - -/** - * Logistic regression based classification. - */ -object SparkHdfsLR { - val D = 10 // Numer of dimensions - val rand = new Random(42) - - case class DataPoint(x: Vector, y: Double) - - def parsePoint(line: String): DataPoint = { - //val nums = line.split(' ').map(_.toDouble) - //return DataPoint(new Vector(nums.slice(1, D+1)), nums(0)) - val tok = new java.util.StringTokenizer(line, " ") - var y = tok.nextToken.toDouble - var x = new Array[Double](D) - var i = 0 - while (i < D) { - x(i) = tok.nextToken.toDouble; i += 1 - } - return DataPoint(new Vector(x), y) - } - - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: SparkHdfsLR ") - System.exit(1) - } - val inputPath = args(1) - val conf = SparkEnv.get.hadoop.newConfiguration() - val sc = new SparkContext(args(0), "SparkHdfsLR", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")), Map(), - InputFormatInfo.computePreferredLocations( - Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath)))) - val lines = sc.textFile(inputPath) - val points = lines.map(parsePoint _).cache() - val ITERATIONS = args(2).toInt - - // Initialize w to a random value - var w = Vector(D, _ => 2 * rand.nextDouble - 1) - println("Initial w: " + w) - - for (i <- 1 to ITERATIONS) { - println("On iteration " + i) - val gradient = points.map { p => - (1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y * p.x - }.reduce(_ + _) - w -= gradient - } - - println("Final w: " + w) - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/examples/SparkKMeans.scala b/examples/src/main/scala/spark/examples/SparkKMeans.scala deleted file mode 100644 index 38ed3b149a..0000000000 --- a/examples/src/main/scala/spark/examples/SparkKMeans.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import java.util.Random -import spark.SparkContext -import spark.util.Vector -import spark.SparkContext._ -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet - -/** - * K-means clustering. - */ -object SparkKMeans { - val R = 1000 // Scaling factor - val rand = new Random(42) - - def parseVector(line: String): Vector = { - return new Vector(line.split(' ').map(_.toDouble)) - } - - def closestPoint(p: Vector, centers: Array[Vector]): Int = { - var index = 0 - var bestIndex = 0 - var closest = Double.PositiveInfinity - - for (i <- 0 until centers.length) { - val tempDist = p.squaredDist(centers(i)) - if (tempDist < closest) { - closest = tempDist - bestIndex = i - } - } - - return bestIndex - } - - def main(args: Array[String]) { - if (args.length < 4) { - System.err.println("Usage: SparkLocalKMeans ") - System.exit(1) - } - val sc = new SparkContext(args(0), "SparkLocalKMeans", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val lines = sc.textFile(args(1)) - val data = lines.map(parseVector _).cache() - val K = args(2).toInt - val convergeDist = args(3).toDouble - - var kPoints = data.takeSample(false, K, 42).toArray - var tempDist = 1.0 - - while(tempDist > convergeDist) { - var closest = data.map (p => (closestPoint(p, kPoints), (p, 1))) - - var pointStats = closest.reduceByKey{case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2)} - - var newPoints = pointStats.map {pair => (pair._1, pair._2._1 / pair._2._2)}.collectAsMap() - - tempDist = 0.0 - for (i <- 0 until K) { - tempDist += kPoints(i).squaredDist(newPoints(i)) - } - - for (newP <- newPoints) { - kPoints(newP._1) = newP._2 - } - println("Finished iteration (delta = " + tempDist + ")") - } - - println("Final centers:") - kPoints.foreach(println) - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/examples/SparkLR.scala b/examples/src/main/scala/spark/examples/SparkLR.scala deleted file mode 100644 index 52a0d69744..0000000000 --- a/examples/src/main/scala/spark/examples/SparkLR.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import java.util.Random -import scala.math.exp -import spark.util.Vector -import spark._ - -/** - * Logistic regression based classification. 
- */ -object SparkLR { - val N = 10000 // Number of data points - val D = 10 // Numer of dimensions - val R = 0.7 // Scaling factor - val ITERATIONS = 5 - val rand = new Random(42) - - case class DataPoint(x: Vector, y: Double) - - def generateData = { - def generatePoint(i: Int) = { - val y = if(i % 2 == 0) -1 else 1 - val x = Vector(D, _ => rand.nextGaussian + y * R) - DataPoint(x, y) - } - Array.tabulate(N)(generatePoint) - } - - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SparkLR []") - System.exit(1) - } - val sc = new SparkContext(args(0), "SparkLR", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val numSlices = if (args.length > 1) args(1).toInt else 2 - val points = sc.parallelize(generateData, numSlices).cache() - - // Initialize w to a random value - var w = Vector(D, _ => 2 * rand.nextDouble - 1) - println("Initial w: " + w) - - for (i <- 1 to ITERATIONS) { - println("On iteration " + i) - val gradient = points.map { p => - (1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y * p.x - }.reduce(_ + _) - w -= gradient - } - - println("Final w: " + w) - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/examples/SparkPageRank.scala b/examples/src/main/scala/spark/examples/SparkPageRank.scala deleted file mode 100644 index dedbbd01a3..0000000000 --- a/examples/src/main/scala/spark/examples/SparkPageRank.scala +++ /dev/null @@ -1,46 +0,0 @@ -package spark.examples - -import spark.SparkContext._ -import spark.SparkContext - - -/** - * Computes the PageRank of URLs from an input file. Input file should - * be in format of: - * URL neighbor URL - * URL neighbor URL - * URL neighbor URL - * ... - * where URL and their neighbors are separated by space(s). - */ -object SparkPageRank { - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: PageRank ") - System.exit(1) - } - var iters = args(2).toInt - val ctx = new SparkContext(args(0), "PageRank", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val lines = ctx.textFile(args(1), 1) - val links = lines.map{ s => - val parts = s.split("\\s+") - (parts(0), parts(1)) - }.distinct().groupByKey().cache() - var ranks = links.mapValues(v => 1.0) - - for (i <- 1 to iters) { - val contribs = links.join(ranks).values.flatMap{ case (urls, rank) => - val size = urls.size - urls.map(url => (url, rank / size)) - } - ranks = contribs.reduceByKey(_ + _).mapValues(0.15 + 0.85 * _) - } - - val output = ranks.collect() - output.foreach(tup => println(tup._1 + " has rank: " + tup._2 + ".")) - - System.exit(0) - } -} - diff --git a/examples/src/main/scala/spark/examples/SparkPi.scala b/examples/src/main/scala/spark/examples/SparkPi.scala deleted file mode 100644 index 00560ac9d1..0000000000 --- a/examples/src/main/scala/spark/examples/SparkPi.scala +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import scala.math.random -import spark._ -import SparkContext._ - -/** Computes an approximation to pi */ -object SparkPi { - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SparkPi <master> [<slices>]") - System.exit(1) - } - val spark = new SparkContext(args(0), "SparkPi", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val slices = if (args.length > 1) args(1).toInt else 2 - val n = 100000 * slices - val count = spark.parallelize(1 to n, slices).map { i => - val x = random * 2 - 1 - val y = random * 2 - 1 - if (x*x + y*y < 1) 1 else 0 - }.reduce(_ + _) - println("Pi is roughly " + 4.0 * count / n) - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/examples/SparkTC.scala b/examples/src/main/scala/spark/examples/SparkTC.scala deleted file mode 100644 index bf988a953b..0000000000 --- a/examples/src/main/scala/spark/examples/SparkTC.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples - -import spark._ -import SparkContext._ -import scala.util.Random -import scala.collection.mutable - -/** - * Transitive closure on a graph. - */ -object SparkTC { - val numEdges = 200 - val numVertices = 100 - val rand = new Random(42) - - def generateGraph = { - val edges: mutable.Set[(Int, Int)] = mutable.Set.empty - while (edges.size < numEdges) { - val from = rand.nextInt(numVertices) - val to = rand.nextInt(numVertices) - if (from != to) edges.+=((from, to)) - } - edges.toSeq - } - - def main(args: Array[String]) { - if (args.length == 0) { - System.err.println("Usage: SparkTC <master> [<slices>]") - System.exit(1) - } - val spark = new SparkContext(args(0), "SparkTC", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val slices = if (args.length > 1) args(1).toInt else 2 - var tc = spark.parallelize(generateGraph, slices).cache() - - // Linear transitive closure: each round grows paths by one edge, - // by joining the graph's edges with the already-discovered paths. - // e.g. join the path (y, z) from the TC with the edge (x, y) from - // the graph to obtain the path (x, z). - - // Because join() joins on keys, the edges are stored in reversed order. - val edges = tc.map(x => (x._2, x._1)) - - // This join is iterated until a fixed point is reached.
- var oldCount = 0L - var nextCount = tc.count() - do { - oldCount = nextCount - // Perform the join, obtaining an RDD of (y, (z, x)) pairs, - // then project the result to obtain the new (x, z) paths. - tc = tc.union(tc.join(edges).map(x => (x._2._2, x._2._1))).distinct().cache(); - nextCount = tc.count() - } while (nextCount != oldCount) - - println("TC has " + tc.count() + " edges.") - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/spark/examples/bagel/PageRankUtils.scala deleted file mode 100644 index c23ee9895f..0000000000 --- a/examples/src/main/scala/spark/examples/bagel/PageRankUtils.scala +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples.bagel - -import spark._ -import spark.SparkContext._ - -import spark.bagel._ -import spark.bagel.Bagel._ - -import scala.collection.mutable.ArrayBuffer - -import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} - -import com.esotericsoftware.kryo._ - -class PageRankUtils extends Serializable { - def computeWithCombiner(numVertices: Long, epsilon: Double)( - self: PRVertex, messageSum: Option[Double], superstep: Int - ): (PRVertex, Array[PRMessage]) = { - val newValue = messageSum match { - case Some(msgSum) if msgSum != 0 => - 0.15 / numVertices + 0.85 * msgSum - case _ => self.value - } - - val terminate = superstep >= 10 - - val outbox: Array[PRMessage] = - if (!terminate) - self.outEdges.map(targetId => - new PRMessage(targetId, newValue / self.outEdges.size)) - else - Array[PRMessage]() - - (new PRVertex(newValue, self.outEdges, !terminate), outbox) - } - - def computeNoCombiner(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[Array[PRMessage]], superstep: Int): (PRVertex, Array[PRMessage]) = - computeWithCombiner(numVertices, epsilon)(self, messages match { - case Some(msgs) => Some(msgs.map(_.value).sum) - case None => None - }, superstep) -} - -class PRCombiner extends Combiner[PRMessage, Double] with Serializable { - def createCombiner(msg: PRMessage): Double = - msg.value - def mergeMsg(combiner: Double, msg: PRMessage): Double = - combiner + msg.value - def mergeCombiners(a: Double, b: Double): Double = - a + b -} - -class PRVertex() extends Vertex with Serializable { - var value: Double = _ - var outEdges: Array[String] = _ - var active: Boolean = _ - - def this(value: Double, outEdges: Array[String], active: Boolean = true) { - this() - this.value = value - this.outEdges = outEdges - this.active = active - } - - override def toString(): String = { - "PRVertex(value=%f, outEdges.length=%d, active=%s)".format(value, outEdges.length, active.toString) - } -} - -class PRMessage() extends Message[String] with 
Serializable { - var targetId: String = _ - var value: Double = _ - - def this(targetId: String, value: Double) { - this() - this.targetId = targetId - this.value = value - } -} - -class PRKryoRegistrator extends KryoRegistrator { - def registerClasses(kryo: Kryo) { - kryo.register(classOf[PRVertex]) - kryo.register(classOf[PRMessage]) - } -} - -class CustomPartitioner(partitions: Int) extends Partitioner { - def numPartitions = partitions - - def getPartition(key: Any): Int = { - val hash = key match { - case k: Long => (k & 0x00000000FFFFFFFFL).toInt - case _ => key.hashCode - } - - val mod = key.hashCode % partitions - if (mod < 0) mod + partitions else mod - } - - override def equals(other: Any): Boolean = other match { - case c: CustomPartitioner => - c.numPartitions == numPartitions - case _ => false - } -} diff --git a/examples/src/main/scala/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/spark/examples/bagel/WikipediaPageRank.scala deleted file mode 100644 index 00635a7ffa..0000000000 --- a/examples/src/main/scala/spark/examples/bagel/WikipediaPageRank.scala +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.examples.bagel - -import spark._ -import spark.SparkContext._ - -import spark.bagel._ -import spark.bagel.Bagel._ - -import scala.xml.{XML,NodeSeq} - -/** - * Run PageRank on XML Wikipedia dumps from http://wiki.freebase.com/wiki/WEX. Uses the "articles" - * files from there, which contains one line per wiki article in a tab-separated format - * (http://wiki.freebase.com/wiki/WEX/Documentation#articles). 
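As a rough illustration of the format (made-up values; only the two fields the parser below
actually reads matter here, the title in column 2 and the article XML in column 4), a line such as

    12<TAB>Spark_(cluster_computing)<TAB>2004-01-01<TAB><article><link><target>Hadoop</target></link></article>

would be parsed into the vertex ("Spark_(cluster_computing)", outEdges = Array("Hadoop")).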
- */ -object WikipediaPageRank { - def main(args: Array[String]) { - if (args.length < 5) { - System.err.println("Usage: WikipediaPageRank ") - System.exit(-1) - } - - System.setProperty("spark.serializer", "spark.KryoSerializer") - System.setProperty("spark.kryo.registrator", classOf[PRKryoRegistrator].getName) - - val inputFile = args(0) - val threshold = args(1).toDouble - val numPartitions = args(2).toInt - val host = args(3) - val usePartitioner = args(4).toBoolean - val sc = new SparkContext(host, "WikipediaPageRank") - - // Parse the Wikipedia page data into a graph - val input = sc.textFile(inputFile) - - println("Counting vertices...") - val numVertices = input.count() - println("Done counting vertices.") - - println("Parsing input file...") - var vertices = input.map(line => { - val fields = line.split("\t") - val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) - val links = - if (body == "\\N") - NodeSeq.Empty - else - try { - XML.loadString(body) \\ "link" \ "target" - } catch { - case e: org.xml.sax.SAXParseException => - System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body) - NodeSeq.Empty - } - val outEdges = links.map(link => new String(link.text)).toArray - val id = new String(title) - (id, new PRVertex(1.0 / numVertices, outEdges)) - }) - if (usePartitioner) - vertices = vertices.partitionBy(new HashPartitioner(sc.defaultParallelism)).cache - else - vertices = vertices.cache - println("Done parsing input file.") - - // Do the computation - val epsilon = 0.01 / numVertices - val messages = sc.parallelize(Array[(String, PRMessage)]()) - val utils = new PageRankUtils - val result = - Bagel.run( - sc, vertices, messages, combiner = new PRCombiner(), - numPartitions = numPartitions)( - utils.computeWithCombiner(numVertices, epsilon)) - - // Print the result - System.err.println("Articles with PageRank >= "+threshold+":") - val top = - (result - .filter { case (id, vertex) => vertex.value >= threshold } - .map { case (id, vertex) => "%s\t%s\n".format(id, vertex.value) } - .collect.mkString) - println(top) - } -} diff --git a/examples/src/main/scala/spark/examples/bagel/WikipediaPageRankStandalone.scala b/examples/src/main/scala/spark/examples/bagel/WikipediaPageRankStandalone.scala deleted file mode 100644 index c416ddbc58..0000000000 --- a/examples/src/main/scala/spark/examples/bagel/WikipediaPageRankStandalone.scala +++ /dev/null @@ -1,223 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.examples.bagel - -import spark._ -import serializer.{DeserializationStream, SerializationStream, SerializerInstance} -import spark.SparkContext._ - -import spark.bagel._ -import spark.bagel.Bagel._ - -import scala.xml.{XML,NodeSeq} - -import scala.collection.mutable.ArrayBuffer - -import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} -import java.nio.ByteBuffer - -object WikipediaPageRankStandalone { - def main(args: Array[String]) { - if (args.length < 5) { - System.err.println("Usage: WikipediaPageRankStandalone ") - System.exit(-1) - } - - System.setProperty("spark.serializer", "spark.bagel.examples.WPRSerializer") - - val inputFile = args(0) - val threshold = args(1).toDouble - val numIterations = args(2).toInt - val host = args(3) - val usePartitioner = args(4).toBoolean - val sc = new SparkContext(host, "WikipediaPageRankStandalone") - - val input = sc.textFile(inputFile) - val partitioner = new HashPartitioner(sc.defaultParallelism) - val links = - if (usePartitioner) - input.map(parseArticle _).partitionBy(partitioner).cache() - else - input.map(parseArticle _).cache() - val n = links.count() - val defaultRank = 1.0 / n - val a = 0.15 - - // Do the computation - val startTime = System.currentTimeMillis - val ranks = - pageRank(links, numIterations, defaultRank, a, n, partitioner, usePartitioner, sc.defaultParallelism) - - // Print the result - System.err.println("Articles with PageRank >= "+threshold+":") - val top = - (ranks - .filter { case (id, rank) => rank >= threshold } - .map { case (id, rank) => "%s\t%s\n".format(id, rank) } - .collect().mkString) - println(top) - - val time = (System.currentTimeMillis - startTime) / 1000.0 - println("Completed %d iterations in %f seconds: %f seconds per iteration" - .format(numIterations, time, time / numIterations)) - System.exit(0) - } - - def parseArticle(line: String): (String, Array[String]) = { - val fields = line.split("\t") - val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) - val id = new String(title) - val links = - if (body == "\\N") - NodeSeq.Empty - else - try { - XML.loadString(body) \\ "link" \ "target" - } catch { - case e: org.xml.sax.SAXParseException => - System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body) - NodeSeq.Empty - } - val outEdges = links.map(link => new String(link.text)).toArray - (id, outEdges) - } - - def pageRank( - links: RDD[(String, Array[String])], - numIterations: Int, - defaultRank: Double, - a: Double, - n: Long, - partitioner: Partitioner, - usePartitioner: Boolean, - numPartitions: Int - ): RDD[(String, Double)] = { - var ranks = links.mapValues { edges => defaultRank } - for (i <- 1 to numIterations) { - val contribs = links.groupWith(ranks).flatMap { - case (id, (linksWrapper, rankWrapper)) => - if (linksWrapper.length > 0) { - if (rankWrapper.length > 0) { - linksWrapper(0).map(dest => (dest, rankWrapper(0) / linksWrapper(0).size)) - } else { - linksWrapper(0).map(dest => (dest, defaultRank / linksWrapper(0).size)) - } - } else { - Array[(String, Double)]() - } - } - ranks = (contribs.combineByKey((x: Double) => x, - (x: Double, y: Double) => x + y, - (x: Double, y: Double) => x + y, - partitioner) - .mapValues(sum => a/n + (1-a)*sum)) - } - ranks - } -} - -class WPRSerializer extends spark.serializer.Serializer { - def newInstance(): SerializerInstance = new WPRSerializerInstance() -} - -class WPRSerializerInstance extends SerializerInstance { - def serialize[T](t: T): ByteBuffer = { - throw new 
UnsupportedOperationException() - } - - def deserialize[T](bytes: ByteBuffer): T = { - throw new UnsupportedOperationException() - } - - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { - throw new UnsupportedOperationException() - } - - def serializeStream(s: OutputStream): SerializationStream = { - new WPRSerializationStream(s) - } - - def deserializeStream(s: InputStream): DeserializationStream = { - new WPRDeserializationStream(s) - } -} - -class WPRSerializationStream(os: OutputStream) extends SerializationStream { - val dos = new DataOutputStream(os) - - def writeObject[T](t: T): SerializationStream = t match { - case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match { - case links: Array[String] => { - dos.writeInt(0) // links - dos.writeUTF(id) - dos.writeInt(links.length) - for (link <- links) { - dos.writeUTF(link) - } - this - } - case rank: Double => { - dos.writeInt(1) // rank - dos.writeUTF(id) - dos.writeDouble(rank) - this - } - } - case (id: String, rank: Double) => { - dos.writeInt(2) // rank without wrapper - dos.writeUTF(id) - dos.writeDouble(rank) - this - } - } - - def flush() { dos.flush() } - def close() { dos.close() } -} - -class WPRDeserializationStream(is: InputStream) extends DeserializationStream { - val dis = new DataInputStream(is) - - def readObject[T](): T = { - val typeId = dis.readInt() - typeId match { - case 0 => { - val id = dis.readUTF() - val numLinks = dis.readInt() - val links = new Array[String](numLinks) - for (i <- 0 until numLinks) { - val link = dis.readUTF() - links(i) = link - } - (id, ArrayBuffer(links)).asInstanceOf[T] - } - case 1 => { - val id = dis.readUTF() - val rank = dis.readDouble() - (id, ArrayBuffer(rank)).asInstanceOf[T] - } - case 2 => { - val id = dis.readUTF() - val rank = dis.readDouble() - (id, rank).asInstanceOf[T] - } - } - } - - def close() { dis.close() } -} diff --git a/examples/src/main/scala/spark/streaming/examples/ActorWordCount.scala b/examples/src/main/scala/spark/streaming/examples/ActorWordCount.scala deleted file mode 100644 index 05d3176478..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/ActorWordCount.scala +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.streaming.examples - -import scala.collection.mutable.LinkedList -import scala.util.Random - -import akka.actor.Actor -import akka.actor.ActorRef -import akka.actor.Props -import akka.actor.actorRef2Scala - -import spark.streaming.Seconds -import spark.streaming.StreamingContext -import spark.streaming.StreamingContext.toPairDStreamFunctions -import spark.streaming.receivers.Receiver -import spark.util.AkkaUtils - -case class SubscribeReceiver(receiverActor: ActorRef) -case class UnsubscribeReceiver(receiverActor: ActorRef) - -/** - * Sends the random content to every receiver subscribed with 1/2 - * second delay. - */ -class FeederActor extends Actor { - - val rand = new Random() - var receivers: LinkedList[ActorRef] = new LinkedList[ActorRef]() - - val strings: Array[String] = Array("words ", "may ", "count ") - - def makeMessage(): String = { - val x = rand.nextInt(3) - strings(x) + strings(2 - x) - } - - /* - * A thread to generate random messages - */ - new Thread() { - override def run() { - while (true) { - Thread.sleep(500) - receivers.foreach(_ ! makeMessage) - } - } - }.start() - - def receive: Receive = { - - case SubscribeReceiver(receiverActor: ActorRef) => - println("received subscribe from %s".format(receiverActor.toString)) - receivers = LinkedList(receiverActor) ++ receivers - - case UnsubscribeReceiver(receiverActor: ActorRef) => - println("received unsubscribe from %s".format(receiverActor.toString)) - receivers = receivers.dropWhile(x => x eq receiverActor) - - } -} - -/** - * A sample actor as receiver, is also simplest. This receiver actor - * goes and subscribe to a typical publisher/feeder actor and receives - * data. - * - * @see [[spark.streaming.examples.FeederActor]] - */ -class SampleActorReceiver[T: ClassManifest](urlOfPublisher: String) -extends Actor with Receiver { - - lazy private val remotePublisher = context.actorFor(urlOfPublisher) - - override def preStart = remotePublisher ! SubscribeReceiver(context.self) - - def receive = { - case msg ⇒ context.parent ! pushBlock(msg.asInstanceOf[T]) - } - - override def postStop() = remotePublisher ! UnsubscribeReceiver(context.self) - -} - -/** - * A sample feeder actor - * - * Usage: FeederActor - * and describe the AkkaSystem that Spark Sample feeder would start on. - */ -object FeederActor { - - def main(args: Array[String]) { - if(args.length < 2){ - System.err.println( - "Usage: FeederActor \n" - ) - System.exit(1) - } - val Seq(host, port) = args.toSeq - - - val actorSystem = AkkaUtils.createActorSystem("test", host, port.toInt)._1 - val feeder = actorSystem.actorOf(Props[FeederActor], "FeederActor") - - println("Feeder started as:" + feeder) - - actorSystem.awaitTermination(); - } -} - -/** - * A sample word count program demonstrating the use of plugging in - * Actor as Receiver - * Usage: ActorWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. - * and describe the AkkaSystem that Spark Sample feeder is running on. 
- * - * To run this example locally, you may run Feeder Actor as - * `$ ./run-example spark.streaming.examples.FeederActor 127.0.1.1 9999` - * and then run the example - * `$ ./run-example spark.streaming.examples.ActorWordCount local[2] 127.0.1.1 9999` - */ -object ActorWordCount { - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println( - "Usage: ActorWordCount " + - "In local mode, should be 'local[n]' with n > 1") - System.exit(1) - } - - val Seq(master, host, port) = args.toSeq - - // Create the context and set the batch size - val ssc = new StreamingContext(master, "ActorWordCount", Seconds(2), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - /* - * Following is the use of actorStream to plug in custom actor as receiver - * - * An important point to note: - * Since Actor may exist outside the spark framework, It is thus user's responsibility - * to ensure the type safety, i.e type of data received and InputDstream - * should be same. - * - * For example: Both actorStream and SampleActorReceiver are parameterized - * to same type to ensure type safety. - */ - - val lines = ssc.actorStream[String]( - Props(new SampleActorReceiver[String]("akka://test@%s:%s/user/FeederActor".format( - host, port.toInt))), "SampleReceiver") - - //compute wordcount - lines.flatMap(_.split("\\s+")).map(x => (x, 1)).reduceByKey(_ + _).print() - - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala b/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala deleted file mode 100644 index 3ab4fc2c37..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples - -import spark.util.IntParam -import spark.storage.StorageLevel -import spark.streaming._ - -/** - * Produces a count of events received from Flume. - * - * This should be used in conjunction with an AvroSink in Flume. It will start - * an Avro server on at the request host:port address and listen for requests. - * Your Flume AvroSink should be pointed to this address. - * - * Usage: FlumeEventCount - * - * is a Spark master URL - * is the host the Flume receiver will be started on - a receiver - * creates a server and listens for flume events. - * is the port the Flume receiver will listen on. 
- */ -object FlumeEventCount { - def main(args: Array[String]) { - if (args.length != 3) { - System.err.println( - "Usage: FlumeEventCount ") - System.exit(1) - } - - val Array(master, host, IntParam(port)) = args - - val batchInterval = Milliseconds(2000) - // Create the context and set the batch size - val ssc = new StreamingContext(master, "FlumeEventCount", batchInterval, - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - // Create a flume stream - val stream = ssc.flumeStream(host,port,StorageLevel.MEMORY_ONLY) - - // Print out the count of events received from this server in each batch - stream.count().map(cnt => "Received " + cnt + " flume events." ).print() - - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala b/examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala deleted file mode 100644 index 30af01a26f..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples - -import spark.streaming.{Seconds, StreamingContext} -import spark.streaming.StreamingContext._ - - -/** - * Counts words in new text files created in the given directory - * Usage: HdfsWordCount - * is the Spark master URL. - * is the directory that Spark Streaming will use to find and read new text files. - * - * To run this on your local machine on directory `localdir`, run this example - * `$ ./run-example spark.streaming.examples.HdfsWordCount local[2] localdir` - * Then create a text file in `localdir` and the words in the file will get counted. - */ -object HdfsWordCount { - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: HdfsWordCount ") - System.exit(1) - } - - // Create the context - val ssc = new StreamingContext(args(0), "HdfsWordCount", Seconds(2), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - // Create the FileInputDStream on the directory and use the - // stream to count words in new files created - val lines = ssc.textFileStream(args(1)) - val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - ssc.start() - } -} - diff --git a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala deleted file mode 100644 index d9c76d1a33..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples - -import java.util.Properties -import kafka.message.Message -import kafka.producer.SyncProducerConfig -import kafka.producer._ -import spark.SparkContext -import spark.streaming._ -import spark.streaming.StreamingContext._ -import spark.storage.StorageLevel -import spark.streaming.util.RawTextHelper._ - -/** - * Consumes messages from one or more topics in Kafka and does wordcount. - * Usage: KafkaWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. - * is a list of one or more zookeeper servers that make quorum - * is the name of kafka consumer group - * is a list of one or more kafka topics to consume from - * is the number of threads the kafka consumer should use - * - * Example: - * `./run-example spark.streaming.examples.KafkaWordCount local[2] zoo01,zoo02,zoo03 my-consumer-group topic1,topic2 1` - */ -object KafkaWordCount { - def main(args: Array[String]) { - - if (args.length < 5) { - System.err.println("Usage: KafkaWordCount ") - System.exit(1) - } - - val Array(master, zkQuorum, group, topics, numThreads) = args - - val ssc = new StreamingContext(master, "KafkaWordCount", Seconds(2), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - ssc.checkpoint("checkpoint") - - val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap - val lines = ssc.kafkaStream(zkQuorum, group, topicpMap) - val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2) - wordCounts.print() - - ssc.start() - } -} - -// Produces some random words between 1 and 100. 
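The KafkaWordCount stream above uses the invertible form of reduceByKeyAndWindow (add and
subtract come from RawTextHelper): instead of recounting the whole 10-minute window every
2 seconds, Spark Streaming adds the counts from batches entering the window and subtracts
the counts from batches leaving it, which is why the code above also enables checkpointing.
A minimal sketch of the same idea with the reduce functions written inline:

    // Per-word counts over a 10-minute window sliding every 2 seconds;
    // _ + _ folds in newly arrived counts, _ - _ removes counts that slid out of the window.
    val windowedCounts = words.map(x => (x, 1L))
      .reduceByKeyAndWindow(_ + _, _ - _, Minutes(10), Seconds(2))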
-object KafkaWordCountProducer { - - def main(args: Array[String]) { - if (args.length < 2) { - System.err.println("Usage: KafkaWordCountProducer ") - System.exit(1) - } - - val Array(zkQuorum, topic, messagesPerSec, wordsPerMessage) = args - - // Zookeper connection properties - val props = new Properties() - props.put("zk.connect", zkQuorum) - props.put("serializer.class", "kafka.serializer.StringEncoder") - - val config = new ProducerConfig(props) - val producer = new Producer[String, String](config) - - // Send some messages - while(true) { - val messages = (1 to messagesPerSec.toInt).map { messageNum => - (1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString).mkString(" ") - }.toArray - println(messages.mkString(",")) - val data = new ProducerData[String, String](topic, messages) - producer.send(data) - Thread.sleep(100) - } - } - -} - diff --git a/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala b/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala deleted file mode 100644 index b29d79aac5..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples - -import spark.streaming.{Seconds, StreamingContext} -import spark.streaming.StreamingContext._ - -/** - * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. - * Usage: NetworkWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. - * and describe the TCP server that Spark Streaming would connect to receive data. - * - * To run this on your local machine, you need to first run a Netcat server - * `$ nc -lk 9999` - * and then run the example - * `$ ./run-example spark.streaming.examples.NetworkWordCount local[2] localhost 9999` - */ -object NetworkWordCount { - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: NetworkWordCount \n" + - "In local mode, should be 'local[n]' with n > 1") - System.exit(1) - } - - // Create the context with a 1 second batch size - val ssc = new StreamingContext(args(0), "NetworkWordCount", Seconds(1), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - // Create a NetworkInputDStream on target ip:port and count the - // words in input stream of \n delimited test (eg. 
generated by 'nc') - val lines = ssc.socketTextStream(args(1), args(2).toInt) - val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/QueueStream.scala b/examples/src/main/scala/spark/streaming/examples/QueueStream.scala deleted file mode 100644 index da36c8c23c..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/QueueStream.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples - -import spark.RDD -import spark.streaming.{Seconds, StreamingContext} -import spark.streaming.StreamingContext._ - -import scala.collection.mutable.SynchronizedQueue - -object QueueStream { - - def main(args: Array[String]) { - if (args.length < 1) { - System.err.println("Usage: QueueStream ") - System.exit(1) - } - - // Create the context - val ssc = new StreamingContext(args(0), "QueueStream", Seconds(1), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - // Create the queue through which RDDs can be pushed to - // a QueueInputDStream - val rddQueue = new SynchronizedQueue[RDD[Int]]() - - // Create the QueueInputDStream and use it do some processing - val inputStream = ssc.queueStream(rddQueue) - val mappedStream = inputStream.map(x => (x % 10, 1)) - val reducedStream = mappedStream.reduceByKey(_ + _) - reducedStream.print() - ssc.start() - - // Create and push some RDDs into - for (i <- 1 to 30) { - rddQueue += ssc.sparkContext.makeRDD(1 to 1000, 10) - Thread.sleep(1000) - } - ssc.stop() - System.exit(0) - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala b/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala deleted file mode 100644 index 7fb680bcc3..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples - -import spark.util.IntParam -import spark.storage.StorageLevel - -import spark.streaming._ -import spark.streaming.util.RawTextHelper - -/** - * Receives text from multiple rawNetworkStreams and counts how many '\n' delimited - * lines have the word 'the' in them. This is useful for benchmarking purposes. This - * will only work with spark.streaming.util.RawTextSender running on all worker nodes - * and with Spark using Kryo serialization (set Java property "spark.serializer" to - * "spark.KryoSerializer"). - * Usage: RawNetworkGrep - * is the Spark master URL - * is the number rawNetworkStreams, which should be same as number - * of work nodes in the cluster - * is "localhost". - * is the port on which RawTextSender is running in the worker nodes. - * is the Spark Streaming batch duration in milliseconds. - */ - -object RawNetworkGrep { - def main(args: Array[String]) { - if (args.length != 5) { - System.err.println("Usage: RawNetworkGrep ") - System.exit(1) - } - - val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args - - // Create the context - val ssc = new StreamingContext(master, "RawNetworkGrep", Milliseconds(batchMillis), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - // Warm up the JVMs on master and slave for JIT compilation to kick in - RawTextHelper.warmUp(ssc.sparkContext) - - val rawStreams = (1 to numStreams).map(_ => - ssc.rawSocketStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray - val union = ssc.union(rawStreams) - union.filter(_.contains("the")).count().foreach(r => - println("Grep count: " + r.collect().mkString)) - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala b/examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala deleted file mode 100644 index b709fc3c87..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples - -import spark.streaming._ -import spark.streaming.StreamingContext._ - -/** - * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every second. - * Usage: StatefulNetworkWordCount - * is the Spark master URL. In local mode, should be 'local[n]' with n > 1. - * and describe the TCP server that Spark Streaming would connect to receive data. 
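The cumulative counting in this example is done with updateStateByKey; a small illustrative
trace of the update function defined in main() below (example values, not from the patch):

    // For each key, updateFunc receives the new counts seen in the current batch plus the
    // key's previous running total, and returns the new total as the key's state.
    val updateFunc = (values: Seq[Int], state: Option[Int]) =>
      Some(values.foldLeft(0)(_ + _) + state.getOrElse(0))

    updateFunc(Seq(1, 1, 1), None)   // Some(3) -- a word seen 3 times, no prior state
    updateFunc(Seq(1, 1), Some(3))   // Some(5) -- the running total carried into the next batch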
- * - * To run this on your local machine, you need to first run a Netcat server - * `$ nc -lk 9999` - * and then run the example - * `$ ./run-example spark.streaming.examples.StatefulNetworkWordCount local[2] localhost 9999` - */ -object StatefulNetworkWordCount { - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: StatefulNetworkWordCount \n" + - "In local mode, should be 'local[n]' with n > 1") - System.exit(1) - } - - val updateFunc = (values: Seq[Int], state: Option[Int]) => { - val currentCount = values.foldLeft(0)(_ + _) - - val previousCount = state.getOrElse(0) - - Some(currentCount + previousCount) - } - - // Create the context with a 1 second batch size - val ssc = new StreamingContext(args(0), "NetworkWordCumulativeCountUpdateStateByKey", Seconds(1), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - ssc.checkpoint(".") - - // Create a NetworkInputDStream on target ip:port and count the - // words in input stream of \n delimited test (eg. generated by 'nc') - val lines = ssc.socketTextStream(args(1), args(2).toInt) - val words = lines.flatMap(_.split(" ")) - val wordDstream = words.map(x => (x, 1)) - - // Update the cumulative count using updateStateByKey - // This will give a Dstream made of state (which is the cumulative count of the words) - val stateDstream = wordDstream.updateStateByKey[Int](updateFunc) - stateDstream.print() - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala deleted file mode 100644 index 8770abd57e..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples - -import spark.streaming.{Seconds, StreamingContext} -import spark.storage.StorageLevel -import com.twitter.algebird._ -import spark.streaming.StreamingContext._ -import spark.SparkContext._ - -/** - * Illustrates the use of the Count-Min Sketch, from Twitter's Algebird library, to compute - * windowed and global Top-K estimates of user IDs occurring in a Twitter stream. - *
    - * Note that since Algebird's implementation currently only supports Long inputs, - * the example operates on Long IDs. Once the implementation supports other inputs (such as String), - * the same approach could be used for computing popular topics for example. - *
    - *
    - * - * This blog post has a good overview of the Count-Min Sketch (CMS). The CMS is a datastructure - * for approximate frequency estimation in data streams (e.g. Top-K elements, frequency of any given element, etc), - * that uses space sub-linear in the number of elements in the stream. Once elements are added to the CMS, the - * estimated count of an element can be computed, as well as "heavy-hitters" that occur more than a threshold - * percentage of the overall total count. - *
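For a rough sense of the space/accuracy trade-off (standard CMS sizing, given here as
background; Algebird's internals may differ slightly): a sketch built for error eps and
failure probability delta keeps about ceil(e / eps) counters per row and ceil(ln(1 / delta))
rows. With the EPS = 0.01 and DELTA = 1E-3 used below:

    // width ~ ceil(2.718 / 0.01) = 272 counters per row
    // depth ~ ceil(ln(1000))     = 7 rows
    // => about 272 * 7 = 1904 counters in total, independent of how many user IDs are seen.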

    - * Algebird's implementation is a monoid, so we can succinctly merge two CMS instances in the reduce operation. - */ -object TwitterAlgebirdCMS { - def main(args: Array[String]) { - if (args.length < 1) { - System.err.println("Usage: TwitterAlgebirdCMS " + - " [filter1] [filter2] ... [filter n]") - System.exit(1) - } - - // CMS parameters - val DELTA = 1E-3 - val EPS = 0.01 - val SEED = 1 - val PERC = 0.001 - // K highest frequency elements to take - val TOPK = 10 - - val (master, filters) = (args.head, args.tail) - - val ssc = new StreamingContext(master, "TwitterAlgebirdCMS", Seconds(10), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val stream = ssc.twitterStream(None, filters, StorageLevel.MEMORY_ONLY_SER) - - val users = stream.map(status => status.getUser.getId) - - val cms = new CountMinSketchMonoid(EPS, DELTA, SEED, PERC) - var globalCMS = cms.zero - val mm = new MapMonoid[Long, Int]() - var globalExact = Map[Long, Int]() - - val approxTopUsers = users.mapPartitions(ids => { - ids.map(id => cms.create(id)) - }).reduce(_ ++ _) - - val exactTopUsers = users.map(id => (id, 1)) - .reduceByKey((a, b) => a + b) - - approxTopUsers.foreach(rdd => { - if (rdd.count() != 0) { - val partial = rdd.first() - val partialTopK = partial.heavyHitters.map(id => - (id, partial.frequency(id).estimate)).toSeq.sortBy(_._2).reverse.slice(0, TOPK) - globalCMS ++= partial - val globalTopK = globalCMS.heavyHitters.map(id => - (id, globalCMS.frequency(id).estimate)).toSeq.sortBy(_._2).reverse.slice(0, TOPK) - println("Approx heavy hitters at %2.2f%% threshold this batch: %s".format(PERC, - partialTopK.mkString("[", ",", "]"))) - println("Approx heavy hitters at %2.2f%% threshold overall: %s".format(PERC, - globalTopK.mkString("[", ",", "]"))) - } - }) - - exactTopUsers.foreach(rdd => { - if (rdd.count() != 0) { - val partialMap = rdd.collect().toMap - val partialTopK = rdd.map( - {case (id, count) => (count, id)}) - .sortByKey(ascending = false).take(TOPK) - globalExact = mm.plus(globalExact.toMap, partialMap) - val globalTopK = globalExact.toSeq.sortBy(_._2).reverse.slice(0, TOPK) - println("Exact heavy hitters this batch: %s".format(partialTopK.mkString("[", ",", "]"))) - println("Exact heavy hitters overall: %s".format(globalTopK.mkString("[", ",", "]"))) - } - }) - - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala deleted file mode 100644 index cba5c986be..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.streaming.examples - -import spark.streaming.{Seconds, StreamingContext} -import spark.storage.StorageLevel -import com.twitter.algebird.HyperLogLog._ -import com.twitter.algebird.HyperLogLogMonoid -import spark.streaming.dstream.TwitterInputDStream - -/** - * Illustrates the use of the HyperLogLog algorithm, from Twitter's Algebird library, to compute - * a windowed and global estimate of the unique user IDs occurring in a Twitter stream. - *
    - *
    - * This - * blog post and this - * blog post - * have good overviews of HyperLogLog (HLL). HLL is a memory-efficient datastructure for estimating - * the cardinality of a data stream, i.e. the number of unique elements. - *
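For intuition about the BIT_SIZE parameter used below (standard HLL error bound, given as
background; Algebird's exact guarantee may differ): with b bits of bucket addressing the
sketch keeps m = 2^b registers, and its typical relative error is about 1.04 / sqrt(m).

    // BIT_SIZE = 12  =>  m = 2^12 = 4096 registers
    // expected relative error ~ 1.04 / sqrt(4096) = 1.04 / 64 ~ 1.6%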

    - * Algebird's implementation is a monoid, so we can succinctly merge two HLL instances in the reduce operation. - */ -object TwitterAlgebirdHLL { - def main(args: Array[String]) { - if (args.length < 1) { - System.err.println("Usage: TwitterAlgebirdHLL " + - " [filter1] [filter2] ... [filter n]") - System.exit(1) - } - - /** Bit size parameter for HyperLogLog, trades off accuracy vs size */ - val BIT_SIZE = 12 - val (master, filters) = (args.head, args.tail) - - val ssc = new StreamingContext(master, "TwitterAlgebirdHLL", Seconds(5), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val stream = ssc.twitterStream(None, filters, StorageLevel.MEMORY_ONLY_SER) - - val users = stream.map(status => status.getUser.getId) - - val hll = new HyperLogLogMonoid(BIT_SIZE) - var globalHll = hll.zero - var userSet: Set[Long] = Set() - - val approxUsers = users.mapPartitions(ids => { - ids.map(id => hll(id)) - }).reduce(_ + _) - - val exactUsers = users.map(id => Set(id)).reduce(_ ++ _) - - approxUsers.foreach(rdd => { - if (rdd.count() != 0) { - val partial = rdd.first() - globalHll += partial - println("Approx distinct users this batch: %d".format(partial.estimatedSize.toInt)) - println("Approx distinct users overall: %d".format(globalHll.estimatedSize.toInt)) - } - }) - - exactUsers.foreach(rdd => { - if (rdd.count() != 0) { - val partial = rdd.first() - userSet ++= partial - println("Exact distinct users this batch: %d".format(partial.size)) - println("Exact distinct users overall: %d".format(userSet.size)) - println("Error rate: %2.5f%%".format(((globalHll.estimatedSize / userSet.size.toDouble) - 1) * 100)) - } - }) - - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala b/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala deleted file mode 100644 index 682b99f75e..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples - -import spark.streaming.{Seconds, StreamingContext} -import StreamingContext._ -import spark.SparkContext._ - -/** - * Calculates popular hashtags (topics) over sliding 10 and 60 second windows from a Twitter - * stream. The stream is instantiated with credentials and optionally filters supplied by the - * command line arguments. - * - */ -object TwitterPopularTags { - def main(args: Array[String]) { - if (args.length < 1) { - System.err.println("Usage: TwitterPopularTags " + - " [filter1] [filter2] ... 
[filter n]") - System.exit(1) - } - - val (master, filters) = (args.head, args.tail) - - val ssc = new StreamingContext(master, "TwitterPopularTags", Seconds(2), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val stream = ssc.twitterStream(None, filters) - - val hashTags = stream.flatMap(status => status.getText.split(" ").filter(_.startsWith("#"))) - - val topCounts60 = hashTags.map((_, 1)).reduceByKeyAndWindow(_ + _, Seconds(60)) - .map{case (topic, count) => (count, topic)} - .transform(_.sortByKey(false)) - - val topCounts10 = hashTags.map((_, 1)).reduceByKeyAndWindow(_ + _, Seconds(10)) - .map{case (topic, count) => (count, topic)} - .transform(_.sortByKey(false)) - - - // Print popular hashtags - topCounts60.foreach(rdd => { - val topList = rdd.take(5) - println("\nPopular topics in last 60 seconds (%s total):".format(rdd.count())) - topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} - }) - - topCounts10.foreach(rdd => { - val topList = rdd.take(5) - println("\nPopular topics in last 10 seconds (%s total):".format(rdd.count())) - topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} - }) - - ssc.start() - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/ZeroMQWordCount.scala b/examples/src/main/scala/spark/streaming/examples/ZeroMQWordCount.scala deleted file mode 100644 index a0cae06c30..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/ZeroMQWordCount.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples - -import akka.actor.ActorSystem -import akka.actor.actorRef2Scala -import akka.zeromq._ -import spark.streaming.{ Seconds, StreamingContext } -import spark.streaming.StreamingContext._ -import akka.zeromq.Subscribe - -/** - * A simple publisher for demonstration purposes, repeatedly publishes random Messages - * every one second. - */ -object SimpleZeroMQPublisher { - - def main(args: Array[String]) = { - if (args.length < 2) { - System.err.println("Usage: SimpleZeroMQPublisher ") - System.exit(1) - } - - val Seq(url, topic) = args.toSeq - val acs: ActorSystem = ActorSystem() - - val pubSocket = ZeroMQExtension(acs).newSocket(SocketType.Pub, Bind(url)) - val messages: Array[String] = Array("words ", "may ", "count ") - while (true) { - Thread.sleep(1000) - pubSocket ! ZMQMessage(Frame(topic) :: messages.map(x => Frame(x.getBytes)).toList) - } - acs.awaitTermination() - } -} - -/** - * A sample wordcount with ZeroMQStream stream - * - * To work with zeroMQ, some native libraries have to be installed. - * Install zeroMQ (release 2.1) core libraries. 
[ZeroMQ Install guide](http://www.zeromq.org/intro:get-the-software) - * - * Usage: ZeroMQWordCount - * In local mode, should be 'local[n]' with n > 1 - * and describe where zeroMq publisher is running. - * - * To run this example locally, you may run publisher as - * `$ ./run-example spark.streaming.examples.SimpleZeroMQPublisher tcp://127.0.1.1:1234 foo.bar` - * and run the example as - * `$ ./run-example spark.streaming.examples.ZeroMQWordCount local[2] tcp://127.0.1.1:1234 foo` - */ -object ZeroMQWordCount { - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println( - "Usage: ZeroMQWordCount " + - "In local mode, should be 'local[n]' with n > 1") - System.exit(1) - } - val Seq(master, url, topic) = args.toSeq - - // Create the context and set the batch size - val ssc = new StreamingContext(master, "ZeroMQWordCount", Seconds(2), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - def bytesToStringIterator(x: Seq[Seq[Byte]]) = (x.map(x => new String(x.toArray))).iterator - - //For this stream, a zeroMQ publisher should be running. - val lines = ssc.zeroMQStream(url, Subscribe(topic), bytesToStringIterator) - val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.print() - ssc.start() - } - -} diff --git a/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala b/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala deleted file mode 100644 index dd36bbbf32..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples.clickstream - -import java.net.{InetAddress,ServerSocket,Socket,SocketException} -import java.io.{InputStreamReader, BufferedReader, PrintWriter} -import util.Random - -/** Represents a page view on a website with associated dimension data.*/ -class PageView(val url : String, val status : Int, val zipCode : Int, val userID : Int) { - override def toString() : String = { - "%s\t%s\t%s\t%s\n".format(url, status, zipCode, userID) - } -} -object PageView { - def fromString(in : String) : PageView = { - val parts = in.split("\t") - new PageView(parts(0), parts(1).toInt, parts(2).toInt, parts(3).toInt) - } -} - -/** Generates streaming events to simulate page views on a website. - * - * This should be used in tandem with PageViewStream.scala. 
Example: - * $ ./run-example spark.streaming.examples.clickstream.PageViewGenerator 44444 10 - * $ ./run-example spark.streaming.examples.clickstream.PageViewStream errorRatePerZipCode localhost 44444 - * */ -object PageViewGenerator { - val pages = Map("http://foo.com/" -> .7, - "http://foo.com/news" -> 0.2, - "http://foo.com/contact" -> .1) - val httpStatus = Map(200 -> .95, - 404 -> .05) - val userZipCode = Map(94709 -> .5, - 94117 -> .5) - val userID = Map((1 to 100).map(_ -> .01):_*) - - - def pickFromDistribution[T](inputMap : Map[T, Double]) : T = { - val rand = new Random().nextDouble() - var total = 0.0 - for ((item, prob) <- inputMap) { - total = total + prob - if (total > rand) { - return item - } - } - return inputMap.take(1).head._1 // Shouldn't get here if probabilities add up to 1.0 - } - - def getNextClickEvent() : String = { - val id = pickFromDistribution(userID) - val page = pickFromDistribution(pages) - val status = pickFromDistribution(httpStatus) - val zipCode = pickFromDistribution(userZipCode) - new PageView(page, status, zipCode, id).toString() - } - - def main(args : Array[String]) { - if (args.length != 2) { - System.err.println("Usage: PageViewGenerator ") - System.exit(1) - } - val port = args(0).toInt - val viewsPerSecond = args(1).toFloat - val sleepDelayMs = (1000.0 / viewsPerSecond).toInt - val listener = new ServerSocket(port) - println("Listening on port: " + port) - - while (true) { - val socket = listener.accept() - new Thread() { - override def run = { - println("Got client connected from: " + socket.getInetAddress) - val out = new PrintWriter(socket.getOutputStream(), true) - - while (true) { - Thread.sleep(sleepDelayMs) - out.write(getNextClickEvent()) - out.flush() - } - socket.close() - } - }.start() - } - } -} diff --git a/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala b/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala deleted file mode 100644 index 152da23489..0000000000 --- a/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.streaming.examples.clickstream - -import spark.streaming.{Seconds, StreamingContext} -import spark.streaming.StreamingContext._ -import spark.SparkContext._ - -/** Analyses a streaming dataset of web page views. This class demonstrates several types of - * operators available in Spark streaming. - * - * This should be used in tandem with PageViewStream.scala. 
Example: - * $ ./run-example spark.streaming.examples.clickstream.PageViewGenerator 44444 10 - * $ ./run-example spark.streaming.examples.clickstream.PageViewStream errorRatePerZipCode localhost 44444 - */ -object PageViewStream { - def main(args: Array[String]) { - if (args.length != 3) { - System.err.println("Usage: PageViewStream ") - System.err.println(" must be one of pageCounts, slidingPageCounts," + - " errorRatePerZipCode, activeUserCount, popularUsersSeen") - System.exit(1) - } - val metric = args(0) - val host = args(1) - val port = args(2).toInt - - // Create the context - val ssc = new StreamingContext("local[2]", "PageViewStream", Seconds(1), - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - - // Create a NetworkInputDStream on target host:port and convert each line to a PageView - val pageViews = ssc.socketTextStream(host, port) - .flatMap(_.split("\n")) - .map(PageView.fromString(_)) - - // Return a count of views per URL seen in each batch - val pageCounts = pageViews.map(view => view.url).countByValue() - - // Return a sliding window of page views per URL in the last ten seconds - val slidingPageCounts = pageViews.map(view => view.url) - .countByValueAndWindow(Seconds(10), Seconds(2)) - - - // Return the rate of error pages (a non 200 status) in each zip code over the last 30 seconds - val statusesPerZipCode = pageViews.window(Seconds(30), Seconds(2)) - .map(view => ((view.zipCode, view.status))) - .groupByKey() - val errorRatePerZipCode = statusesPerZipCode.map{ - case(zip, statuses) => - val normalCount = statuses.filter(_ == 200).size - val errorCount = statuses.size - normalCount - val errorRatio = errorCount.toFloat / statuses.size - if (errorRatio > 0.05) {"%s: **%s**".format(zip, errorRatio)} - else {"%s: %s".format(zip, errorRatio)} - } - - // Return the number unique users in last 15 seconds - val activeUserCount = pageViews.window(Seconds(15), Seconds(2)) - .map(view => (view.userID, 1)) - .groupByKey() - .count() - .map("Unique active users: " + _) - - // An external dataset we want to join to this stream - val userList = ssc.sparkContext.parallelize( - Map(1 -> "Patrick Wendell", 2->"Reynold Xin", 3->"Matei Zaharia").toSeq) - - metric match { - case "pageCounts" => pageCounts.print() - case "slidingPageCounts" => slidingPageCounts.print() - case "errorRatePerZipCode" => errorRatePerZipCode.print() - case "activeUserCount" => activeUserCount.print() - case "popularUsersSeen" => - // Look for users in our existing dataset and print it out if we have a match - pageViews.map(view => (view.userID, 1)) - .foreach((rdd, time) => rdd.join(userList) - .map(_._2._2) - .take(10) - .foreach(u => println("Saw user %s at time %s".format(u, time)))) - case _ => println("Invalid metric entered: " + metric) - } - - ssc.start() - } -} diff --git a/mllib/pom.xml b/mllib/pom.xml index ab31d5734e..2d5d3c00d1 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -19,13 +19,13 @@ 4.0.0 - org.spark-project + org.apache.spark spark-parent 0.8.0-SNAPSHOT ../pom.xml - org.spark-project + org.apache.spark spark-mllib jar Spark Project ML Library @@ -33,7 +33,7 @@ - org.spark-project + org.apache.spark spark-core ${project.version} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala new file mode 100644 index 0000000000..4f4a7f5296 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ 
-0,0 +1,21 @@ +package org.apache.spark.mllib.classification + +import org.apache.spark.RDD + +trait ClassificationModel extends Serializable { + /** + * Predict values for the given data set using the model trained. + * + * @param testData RDD representing data points to be predicted + * @return RDD[Double] where each entry contains the corresponding prediction + */ + def predict(testData: RDD[Array[Double]]): RDD[Double] + + /** + * Predict values for a single data point using the model trained. + * + * @param testData array representing a single data point + * @return Double prediction from the trained model + */ + def predict(testData: Array[Double]): Double +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala new file mode 100644 index 0000000000..91bb50c829 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification + +import scala.math.round + +import org.apache.spark.{Logging, RDD, SparkContext} +import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.util.DataValidators + +import org.jblas.DoubleMatrix + +/** + * Classification model trained using Logistic Regression. + * + * @param weights Weights computed for every feature. + * @param intercept Intercept computed for this model. + */ +class LogisticRegressionModel( + override val weights: Array[Double], + override val intercept: Double) + extends GeneralizedLinearModel(weights, intercept) + with ClassificationModel with Serializable { + + override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, + intercept: Double) = { + val margin = dataMatrix.mmul(weightMatrix).get(0) + intercept + round(1.0/ (1.0 + math.exp(margin * -1))) + } +} + +/** + * Train a classification model for Logistic Regression using Stochastic Gradient Descent.
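As an aside (not part of this commit), the prediction rule that LogisticRegressionModel.predictPoint encodes can be checked with a few lines of plain Scala: the margin is the dot product of the features with the weights plus the intercept, the sigmoid turns it into a probability, and rounding yields the {0, 1} label. All names and numbers below are made up for illustration.

object LogisticPredictionSketch {
  def main(args: Array[String]) {
    val features = Array(0.5, -1.2, 3.0)   // one data point (illustrative values)
    val weights = Array(0.8, 0.1, -0.4)    // trained weights (illustrative values)
    val intercept = 0.2
    // margin = w . x + intercept
    val margin = features.zip(weights).map { case (x, w) => x * w }.sum + intercept
    // The sigmoid maps the margin to a probability in (0, 1)
    val probability = 1.0 / (1.0 + math.exp(-margin))
    // Rounding the probability gives the {0, 1} class label, as predictPoint does
    val label = math.round(probability)
    println("margin = " + margin + ", probability = " + probability + ", label = " + label)
  }
}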
+ * NOTE: Labels used in Logistic Regression should be {0, 1} + */ +class LogisticRegressionWithSGD private ( + var stepSize: Double, + var numIterations: Int, + var regParam: Double, + var miniBatchFraction: Double) + extends GeneralizedLinearAlgorithm[LogisticRegressionModel] + with Serializable { + + val gradient = new LogisticGradient() + val updater = new SimpleUpdater() + override val optimizer = new GradientDescent(gradient, updater) + .setStepSize(stepSize) + .setNumIterations(numIterations) + .setRegParam(regParam) + .setMiniBatchFraction(miniBatchFraction) + override val validators = List(DataValidators.classificationLabels) + + /** + * Construct a LogisticRegression object with default parameters + */ + def this() = this(1.0, 100, 0.0, 1.0) + + def createModel(weights: Array[Double], intercept: Double) = { + new LogisticRegressionModel(weights, intercept) + } +} + +/** + * Top-level methods for calling Logistic Regression. + * NOTE: Labels used in Logistic Regression should be {0, 1} + */ +object LogisticRegressionWithSGD { + // NOTE(shivaram): We use multiple train methods instead of default arguments to support + // Java programs. + + /** + * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed + * number of iterations of gradient descent using the specified step size. Each iteration uses + * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in + * gradient descent are initialized using the initial weights provided. + * NOTE: Labels used in Logistic Regression should be {0, 1} + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @param stepSize Step size to be used for each iteration of gradient descent. + * @param miniBatchFraction Fraction of data to be used per iteration. + * @param initialWeights Initial set of weights to be used. Array should be equal in size to + * the number of features in the data. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + miniBatchFraction: Double, + initialWeights: Array[Double]) + : LogisticRegressionModel = + { + new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run( + input, initialWeights) + } + + /** + * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed + * number of iterations of gradient descent using the specified step size. Each iteration uses + * `miniBatchFraction` fraction of the data to calculate the gradient. + * NOTE: Labels used in Logistic Regression should be {0, 1} + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @param stepSize Step size to be used for each iteration of gradient descent. + + * @param miniBatchFraction Fraction of data to be used per iteration. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + miniBatchFraction: Double) + : LogisticRegressionModel = + { + new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run( + input) + } + + /** + * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed + * number of iterations of gradient descent using the specified step size. We use the entire data + * set to update the gradient in each iteration. 
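A hedged usage sketch for the train overloads above (not part of the patch; the object name, master URL, toy data, and parameter values are made up). It assumes the mllib module built from this commit is on the classpath.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD
import org.apache.spark.mllib.regression.LabeledPoint

object LogisticRegressionExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "LogisticRegressionExample")
    // Tiny made-up training set; labels must be 0 or 1
    val training = sc.parallelize(Seq(
      LabeledPoint(1.0, Array(2.0, 1.5)),
      LabeledPoint(0.0, Array(-1.0, -0.5)),
      LabeledPoint(1.0, Array(1.2, 0.9)),
      LabeledPoint(0.0, Array(-2.3, -1.1)))).cache()
    // 100 iterations of SGD, step size 1.0, full-batch gradient (miniBatchFraction = 1.0)
    val model = LogisticRegressionWithSGD.train(training, 100, 1.0, 1.0)
    println("weights: " + model.weights.mkString(", ") + ", intercept: " + model.intercept)
    println("prediction for (1.0, 1.0): " + model.predict(Array(1.0, 1.0)))
    sc.stop()
  }
}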
+ * NOTE: Labels used in Logistic Regression should be {0, 1} + * + * @param input RDD of (label, array of features) pairs. + * @param stepSize Step size to be used for each iteration of Gradient Descent. + + * @param numIterations Number of iterations of gradient descent to run. + * @return a LogisticRegressionModel which has the weights and offset from training. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double) + : LogisticRegressionModel = + { + train(input, numIterations, stepSize, 1.0) + } + + /** + * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed + * number of iterations of gradient descent using a step size of 1.0. We use the entire data set + * to update the gradient in each iteration. + * NOTE: Labels used in Logistic Regression should be {0, 1} + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @return a LogisticRegressionModel which has the weights and offset from training. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int) + : LogisticRegressionModel = + { + train(input, numIterations, 1.0, 1.0) + } + + def main(args: Array[String]) { + if (args.length != 4) { + println("Usage: LogisticRegression <master> <input_dir> <step_size> " + + "<niters>") + System.exit(1) + } + val sc = new SparkContext(args(0), "LogisticRegression") + val data = MLUtils.loadLabeledData(sc, args(1)) + val model = LogisticRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble) + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala new file mode 100644 index 0000000000..c92c7cc3f3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification + +import scala.math.signum + +import org.apache.spark.{Logging, RDD, SparkContext} +import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.util.DataValidators + +import org.jblas.DoubleMatrix + +/** + * Model built using SVM. + * + * @param weights Weights computed for every feature. + * @param intercept Intercept computed for this model.
+ */ +class SVMModel( + override val weights: Array[Double], + override val intercept: Double) + extends GeneralizedLinearModel(weights, intercept) + with ClassificationModel with Serializable { + + override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, + intercept: Double) = { + val margin = dataMatrix.dot(weightMatrix) + intercept + if (margin < 0) 0.0 else 1.0 + } +} + +/** + * Train an SVM using Stochastic Gradient Descent. + * NOTE: Labels used in SVM should be {0, 1} + */ +class SVMWithSGD private ( + var stepSize: Double, + var numIterations: Int, + var regParam: Double, + var miniBatchFraction: Double) + extends GeneralizedLinearAlgorithm[SVMModel] with Serializable { + + val gradient = new HingeGradient() + val updater = new SquaredL2Updater() + override val optimizer = new GradientDescent(gradient, updater) + .setStepSize(stepSize) + .setNumIterations(numIterations) + .setRegParam(regParam) + .setMiniBatchFraction(miniBatchFraction) + + override val validators = List(DataValidators.classificationLabels) + + /** + * Construct a SVM object with default parameters + */ + def this() = this(1.0, 100, 1.0, 1.0) + + def createModel(weights: Array[Double], intercept: Double) = { + new SVMModel(weights, intercept) + } +} + +/** + * Top-level methods for calling SVM. NOTE: Labels used in SVM should be {0, 1} + */ +object SVMWithSGD { + + /** + * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. Each iteration uses + * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in + * gradient descent are initialized using the initial weights provided. + * NOTE: Labels used in SVM should be {0, 1} + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @param stepSize Step size to be used for each iteration of gradient descent. + * @param regParam Regularization parameter. + * @param miniBatchFraction Fraction of data to be used per iteration. + * @param initialWeights Initial set of weights to be used. Array should be equal in size to + * the number of features in the data. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + regParam: Double, + miniBatchFraction: Double, + initialWeights: Array[Double]) + : SVMModel = + { + new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input, + initialWeights) + } + + /** + * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. Each iteration uses + * `miniBatchFraction` fraction of the data to calculate the gradient. + * NOTE: Labels used in SVM should be {0, 1} + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @param stepSize Step size to be used for each iteration of gradient descent. + * @param regParam Regularization parameter. + * @param miniBatchFraction Fraction of data to be used per iteration. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + regParam: Double, + miniBatchFraction: Double) + : SVMModel = + { + new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) + } + + /** + * Train a SVM model given an RDD of (label, features) pairs. 
We run a fixed number + * of iterations of gradient descent using the specified step size. We use the entire data set to + * update the gradient in each iteration. + * NOTE: Labels used in SVM should be {0, 1} + * + * @param input RDD of (label, array of features) pairs. + * @param stepSize Step size to be used for each iteration of Gradient Descent. + * @param regParam Regularization parameter. + * @param numIterations Number of iterations of gradient descent to run. + * @return a SVMModel which has the weights and offset from training. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + regParam: Double) + : SVMModel = + { + train(input, numIterations, stepSize, regParam, 1.0) + } + + /** + * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using a step size of 1.0. We use the entire data set to + * update the gradient in each iteration. + * NOTE: Labels used in SVM should be {0, 1} + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @return a SVMModel which has the weights and offset from training. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int) + : SVMModel = + { + train(input, numIterations, 1.0, 1.0, 1.0) + } + + def main(args: Array[String]) { + if (args.length != 5) { + println("Usage: SVM <master> <input_dir> <step_size> <regularization_parameter> <niters>") + System.exit(1) + } + val sc = new SparkContext(args(0), "SVM") + val data = MLUtils.loadLabeledData(sc, args(1)) + val model = SVMWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble) + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala new file mode 100644 index 0000000000..2c3db099fa --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -0,0 +1,335 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +import org.apache.spark.{SparkContext, RDD} + +import org.apache.spark.SparkContext._ +import org.apache.spark.Logging +import org.apache.spark.mllib.util.MLUtils + +import org.jblas.DoubleMatrix + + +/** + * K-means clustering with support for multiple parallel runs and a k-means++ like initialization + * mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, + * they are executed together with joint passes over the data for efficiency. + * + * This is an iterative algorithm that will make multiple passes over the data, so any RDDs given + * to it should be cached by the user.
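Before the KMeans class below, a matching usage sketch for the SVMWithSGD entry points just added (again illustrative only: the object name, data, and the regularization value 0.1 are made up). The decision rule in SVMModel.predictPoint is a sign test on the margin w . x + intercept.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.regression.LabeledPoint

object SVMExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "SVMExample")
    // Made-up training data; labels must be 0 or 1, and the RDD is cached because training is iterative
    val training = sc.parallelize(Seq(
      LabeledPoint(1.0, Array(1.0, 2.0)),
      LabeledPoint(0.0, Array(-1.5, -0.3)),
      LabeledPoint(1.0, Array(0.8, 1.7)),
      LabeledPoint(0.0, Array(-2.0, -1.0)))).cache()
    // 100 iterations, step size 1.0, L2 regularization parameter 0.1, full-batch gradient
    val model = SVMWithSGD.train(training, 100, 1.0, 0.1)
    // predict returns 1.0 when the margin is non-negative, otherwise 0.0
    println("prediction for (1.0, 1.0): " + model.predict(Array(1.0, 1.0)))
    sc.stop()
  }
}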
+ */ +class KMeans private ( + var k: Int, + var maxIterations: Int, + var runs: Int, + var initializationMode: String, + var initializationSteps: Int, + var epsilon: Double) + extends Serializable with Logging +{ + private type ClusterCenters = Array[Array[Double]] + + def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4) + + /** Set the number of clusters to create (k). Default: 2. */ + def setK(k: Int): KMeans = { + this.k = k + this + } + + /** Set maximum number of iterations to run. Default: 20. */ + def setMaxIterations(maxIterations: Int): KMeans = { + this.maxIterations = maxIterations + this + } + + /** + * Set the initialization algorithm. This can be either "random" to choose random points as + * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ + * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. + */ + def setInitializationMode(initializationMode: String): KMeans = { + if (initializationMode != KMeans.RANDOM && initializationMode != KMeans.K_MEANS_PARALLEL) { + throw new IllegalArgumentException("Invalid initialization mode: " + initializationMode) + } + this.initializationMode = initializationMode + this + } + + /** + * Set the number of runs of the algorithm to execute in parallel. We initialize the algorithm + * this many times with random starting conditions (configured by the initialization mode), then + * return the best clustering found over any run. Default: 1. + */ + def setRuns(runs: Int): KMeans = { + if (runs <= 0) { + throw new IllegalArgumentException("Number of runs must be positive") + } + this.runs = runs + this + } + + /** + * Set the number of steps for the k-means|| initialization mode. This is an advanced + * setting -- the default of 5 is almost always enough. Default: 5. + */ + def setInitializationSteps(initializationSteps: Int): KMeans = { + if (initializationSteps <= 0) { + throw new IllegalArgumentException("Number of initialization steps must be positive") + } + this.initializationSteps = initializationSteps + this + } + + /** + * Set the distance threshold within which we've consider centers to have converged. + * If all centers move less than this Euclidean distance, we stop iterating one run. + */ + def setEpsilon(epsilon: Double): KMeans = { + this.epsilon = epsilon + this + } + + /** + * Train a K-means model on the given set of points; `data` should be cached for high + * performance, because this is an iterative algorithm. 
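A hedged sketch of the builder-style configuration these setters allow, as an alternative to the KMeans.train shortcuts defined further below (object name, data, and parameter values are made up).

import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.KMeans

object KMeansConfigExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "KMeansConfigExample")
    // Two obvious clusters of made-up 2-D points; cached because the algorithm is iterative
    val points = sc.parallelize(Seq(
      Array(0.0, 0.0), Array(0.1, 0.2), Array(9.0, 8.5), Array(9.2, 9.1))).cache()
    val model = new KMeans()
      .setK(2)
      .setMaxIterations(10)
      .setRuns(3)
      .setInitializationMode(KMeans.K_MEANS_PARALLEL)
      .setEpsilon(1e-4)
      .run(points)
    model.clusterCenters.foreach(c => println(c.mkString(" ")))
    sc.stop()
  }
}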
+ */ + def run(data: RDD[Array[Double]]): KMeansModel = { + // TODO: check whether data is persistent; this needs RDD.storageLevel to be publicly readable + + val sc = data.sparkContext + + val centers = if (initializationMode == KMeans.RANDOM) { + initRandom(data) + } else { + initKMeansParallel(data) + } + + val active = Array.fill(runs)(true) + val costs = Array.fill(runs)(0.0) + + var activeRuns = new ArrayBuffer[Int] ++ (0 until runs) + var iteration = 0 + + // Execute iterations of Lloyd's algorithm until all runs have converged + while (iteration < maxIterations && !activeRuns.isEmpty) { + type WeightedPoint = (DoubleMatrix, Long) + def mergeContribs(p1: WeightedPoint, p2: WeightedPoint): WeightedPoint = { + (p1._1.addi(p2._1), p1._2 + p2._2) + } + + val activeCenters = activeRuns.map(r => centers(r)).toArray + val costAccums = activeRuns.map(_ => sc.accumulator(0.0)) + + // Find the sum and count of points mapping to each center + val totalContribs = data.mapPartitions { points => + val runs = activeCenters.length + val k = activeCenters(0).length + val dims = activeCenters(0)(0).length + + val sums = Array.fill(runs, k)(new DoubleMatrix(dims)) + val counts = Array.fill(runs, k)(0L) + + for (point <- points; (centers, runIndex) <- activeCenters.zipWithIndex) { + val (bestCenter, cost) = KMeans.findClosest(centers, point) + costAccums(runIndex) += cost + sums(runIndex)(bestCenter).addi(new DoubleMatrix(point)) + counts(runIndex)(bestCenter) += 1 + } + + val contribs = for (i <- 0 until runs; j <- 0 until k) yield { + ((i, j), (sums(i)(j), counts(i)(j))) + } + contribs.iterator + }.reduceByKey(mergeContribs).collectAsMap() + + // Update the cluster centers and costs for each active run + for ((run, i) <- activeRuns.zipWithIndex) { + var changed = false + for (j <- 0 until k) { + val (sum, count) = totalContribs((i, j)) + if (count != 0) { + val newCenter = sum.divi(count).data + if (MLUtils.squaredDistance(newCenter, centers(run)(j)) > epsilon * epsilon) { + changed = true + } + centers(run)(j) = newCenter + } + } + if (!changed) { + active(run) = false + logInfo("Run " + run + " finished in " + (iteration + 1) + " iterations") + } + costs(run) = costAccums(i).value + } + + activeRuns = activeRuns.filter(active(_)) + iteration += 1 + } + + val bestRun = costs.zipWithIndex.min._2 + new KMeansModel(centers(bestRun)) + } + + /** + * Initialize `runs` sets of cluster centers at random. + */ + private def initRandom(data: RDD[Array[Double]]): Array[ClusterCenters] = { + // Sample all the cluster centers in one pass to avoid repeated scans + val sample = data.takeSample(true, runs * k, new Random().nextInt()).toSeq + Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).toArray) + } + + /** + * Initialize `runs` sets of cluster centers using the k-means|| algorithm by Bahmani et al. + * (Bahmani et al., Scalable K-Means++, VLDB 2012). This is a variant of k-means++ that tries + * to find with dissimilar cluster centers by starting with a random center and then doing + * passes where more centers are chosen with probability proportional to their squared distance + * to the current cluster set. It results in a provable approximation to an optimal clustering. + * + * The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf. 
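For intuition, a purely local sketch (no Spark, made-up data) of the per-iteration update that run() distributes above: assign each point to its closest center, accumulate a sum and count per center, move each center to the mean of its points, and stop once no center moves by more than epsilon.

object LloydIterationSketch {
  def squaredDistance(a: Array[Double], b: Array[Double]): Double =
    a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

  def main(args: Array[String]) {
    val points = Array(Array(0.0, 0.0), Array(0.2, 0.1), Array(8.0, 9.0), Array(8.4, 8.8))
    val centers = Array(Array(0.0, 0.1), Array(9.0, 9.0))
    val epsilon = 1e-4
    var converged = false
    var iteration = 0
    while (!converged && iteration < 20) {
      val dims = centers(0).length
      val sums = Array.fill(centers.length)(new Array[Double](dims))
      val counts = Array.fill(centers.length)(0L)
      for (p <- points) {
        // Pick the closest center, as KMeans.findClosest does
        val best = centers.indices.minBy(i => squaredDistance(p, centers(i)))
        for (d <- 0 until dims) sums(best)(d) += p(d)
        counts(best) += 1
      }
      // New center = sum / count; converge when every center moves less than epsilon
      converged = true
      for (i <- centers.indices if counts(i) != 0) {
        val newCenter = sums(i).map(_ / counts(i))
        if (squaredDistance(newCenter, centers(i)) > epsilon * epsilon) converged = false
        centers(i) = newCenter
      }
      iteration += 1
    }
    centers.foreach(c => println(c.mkString(" ")))
  }
}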
+ */ + private def initKMeansParallel(data: RDD[Array[Double]]): Array[ClusterCenters] = { + // Initialize each run's center to a random point + val seed = new Random().nextInt() + val sample = data.takeSample(true, runs, seed).toSeq + val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r))) + + // On each step, sample 2 * k points on average for each run with probability proportional + // to their squared distance from that run's current centers + for (step <- 0 until initializationSteps) { + val centerArrays = centers.map(_.toArray) + val sumCosts = data.flatMap { point => + for (r <- 0 until runs) yield (r, KMeans.pointCost(centerArrays(r), point)) + }.reduceByKey(_ + _).collectAsMap() + val chosen = data.mapPartitionsWithIndex { (index, points) => + val rand = new Random(seed ^ (step << 16) ^ index) + for { + p <- points + r <- 0 until runs + if rand.nextDouble() < KMeans.pointCost(centerArrays(r), p) * 2 * k / sumCosts(r) + } yield (r, p) + }.collect() + for ((r, p) <- chosen) { + centers(r) += p + } + } + + // Finally, we might have a set of more than k candidate centers for each run; weigh each + // candidate by the number of points in the dataset mapping to it and run a local k-means++ + // on the weighted centers to pick just k of them + val centerArrays = centers.map(_.toArray) + val weightMap = data.flatMap { p => + for (r <- 0 until runs) yield ((r, KMeans.findClosest(centerArrays(r), p)._1), 1.0) + }.reduceByKey(_ + _).collectAsMap() + val finalCenters = (0 until runs).map { r => + val myCenters = centers(r).toArray + val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray + LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30) + } + + finalCenters.toArray + } +} + + +/** + * Top-level methods for calling K-means clustering. + */ +object KMeans { + // Initialization mode names + val RANDOM = "random" + val K_MEANS_PARALLEL = "k-means||" + + def train( + data: RDD[Array[Double]], + k: Int, + maxIterations: Int, + runs: Int, + initializationMode: String) + : KMeansModel = + { + new KMeans().setK(k) + .setMaxIterations(maxIterations) + .setRuns(runs) + .setInitializationMode(initializationMode) + .run(data) + } + + def train(data: RDD[Array[Double]], k: Int, maxIterations: Int, runs: Int): KMeansModel = { + train(data, k, maxIterations, runs, K_MEANS_PARALLEL) + } + + def train(data: RDD[Array[Double]], k: Int, maxIterations: Int): KMeansModel = { + train(data, k, maxIterations, 1, K_MEANS_PARALLEL) + } + + /** + * Return the index of the closest point in `centers` to `point`, as well as its distance. + */ + private[mllib] def findClosest(centers: Array[Array[Double]], point: Array[Double]) + : (Int, Double) = + { + var bestDistance = Double.PositiveInfinity + var bestIndex = 0 + for (i <- 0 until centers.length) { + val distance = MLUtils.squaredDistance(point, centers(i)) + if (distance < bestDistance) { + bestDistance = distance + bestIndex = i + } + } + (bestIndex, bestDistance) + } + + /** + * Return the K-means cost of a given point against the given cluster centers. 
+ */ + private[mllib] def pointCost(centers: Array[Array[Double]], point: Array[Double]): Double = { + var bestDistance = Double.PositiveInfinity + for (i <- 0 until centers.length) { + val distance = MLUtils.squaredDistance(point, centers(i)) + if (distance < bestDistance) { + bestDistance = distance + } + } + bestDistance + } + + def main(args: Array[String]) { + if (args.length < 4) { + println("Usage: KMeans <master> <input_file> <k> <max_iterations> [<runs>]") + System.exit(1) + } + val (master, inputFile, k, iters) = (args(0), args(1), args(2).toInt, args(3).toInt) + val runs = if (args.length >= 5) args(4).toInt else 1 + val sc = new SparkContext(master, "KMeans") + val data = sc.textFile(inputFile).map(line => line.split(' ').map(_.toDouble)).cache() + val model = KMeans.train(data, k, iters, runs) + val cost = model.computeCost(data) + println("Cluster centers:") + for (c <- model.clusterCenters) { + println(" " + c.mkString(" ")) + } + println("Cost: " + cost) + System.exit(0) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala new file mode 100644 index 0000000000..d1fe5d138d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.apache.spark.RDD + +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.util.MLUtils + + +/** + * A clustering model for K-means. Each point belongs to the cluster with the closest center. + */ +class KMeansModel(val clusterCenters: Array[Array[Double]]) extends Serializable { + /** Total number of clusters. */ + def k: Int = clusterCenters.length + + /** Return the cluster index that a given point belongs to. */ + def predict(point: Array[Double]): Int = { + KMeans.findClosest(clusterCenters, point)._1 + } + + /** + * Return the K-means cost (sum of squared distances of points to their nearest center) for this + * model on the given data. + */ + def computeCost(data: RDD[Array[Double]]): Double = { + data.map(p => KMeans.pointCost(clusterCenters, p)).sum + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala new file mode 100644 index 0000000000..baf8251d8f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
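A hedged usage sketch for the KMeansModel added above: a model is just a set of centers, so one can be built directly and queried with predict and computeCost (the object name, centers, and data are made up).

import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.KMeansModel

object KMeansModelExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "KMeansModelExample")
    val data = sc.parallelize(Seq(Array(0.3, 0.2), Array(9.7, 10.1), Array(0.1, -0.2)))
    // Two fixed cluster centers; normally these come from KMeans.train
    val model = new KMeansModel(Array(Array(0.0, 0.0), Array(10.0, 10.0)))
    println("cluster for (0.3, 0.2): " + model.predict(Array(0.3, 0.2)))
    // Sum of squared distances of the points to their nearest center
    println("cost: " + model.computeCost(data))
    sc.stop()
  }
}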
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import scala.util.Random + +import org.jblas.{DoubleMatrix, SimpleBlas} + +/** + * An utility object to run K-means locally. This is private to the ML package because it's used + * in the initialization of KMeans but not meant to be publicly exposed. + */ +private[mllib] object LocalKMeans { + /** + * Run K-means++ on the weighted point set `points`. This first does the K-means++ + * initialization procedure and then roudns of Lloyd's algorithm. + */ + def kMeansPlusPlus( + seed: Int, + points: Array[Array[Double]], + weights: Array[Double], + k: Int, + maxIterations: Int) + : Array[Array[Double]] = + { + val rand = new Random(seed) + val dimensions = points(0).length + val centers = new Array[Array[Double]](k) + + // Initialize centers by sampling using the k-means++ procedure + centers(0) = pickWeighted(rand, points, weights) + for (i <- 1 until k) { + // Pick the next center with a probability proportional to cost under current centers + val curCenters = centers.slice(0, i) + val sum = points.zip(weights).map { case (p, w) => + w * KMeans.pointCost(curCenters, p) + }.sum + val r = rand.nextDouble() * sum + var cumulativeScore = 0.0 + var j = 0 + while (j < points.length && cumulativeScore < r) { + cumulativeScore += weights(j) * KMeans.pointCost(curCenters, points(j)) + j += 1 + } + centers(i) = points(j-1) + } + + // Run up to maxIterations iterations of Lloyd's algorithm + val oldClosest = Array.fill(points.length)(-1) + var iteration = 0 + var moved = true + while (moved && iteration < maxIterations) { + moved = false + val sums = Array.fill(k)(new DoubleMatrix(dimensions)) + val counts = Array.fill(k)(0.0) + for ((p, i) <- points.zipWithIndex) { + val index = KMeans.findClosest(centers, p)._1 + SimpleBlas.axpy(weights(i), new DoubleMatrix(p), sums(index)) + counts(index) += weights(i) + if (index != oldClosest(i)) { + moved = true + oldClosest(i) = index + } + } + // Update centers + for (i <- 0 until k) { + if (counts(i) == 0.0) { + // Assign center to a random point + centers(i) = points(rand.nextInt(points.length)) + } else { + centers(i) = sums(i).divi(counts(i)).data + } + } + iteration += 1 + } + + centers + } + + private def pickWeighted[T](rand: Random, data: Array[T], weights: Array[Double]): T = { + val r = rand.nextDouble() * weights.sum + var i = 0 + var curWeight = 0.0 + while (i < data.length && curWeight < r) { + curWeight += weights(i) + i += 1 + } + data(i - 1) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala new file mode 100644 index 0000000000..749e7364f4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
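A small local sketch (made-up items and weights) of the cumulative-sum sampling idea behind LocalKMeans.pickWeighted above, which is also how the k-means++ seeding step chooses each next center with probability proportional to its weighted cost. The helper below is a slight restatement of that idea, not a copy of the patch's code.

object WeightedPickSketch {
  // Draw r uniformly in [0, sum of weights) and walk the cumulative sum until it reaches r
  def pickWeighted[T](rand: scala.util.Random, data: Array[T], weights: Array[Double]): T = {
    val r = rand.nextDouble() * weights.sum
    var i = 0
    var cumulative = weights(0)
    while (cumulative < r && i < data.length - 1) {
      i += 1
      cumulative += weights(i)
    }
    data(i)
  }

  def main(args: Array[String]) {
    val rand = new scala.util.Random(42)
    val items = Array("a", "b", "c")
    val weights = Array(1.0, 1.0, 8.0)   // "c" should be drawn roughly 80% of the time
    val draws = Seq.fill(1000)(pickWeighted(rand, items, weights))
    items.foreach(x => println(x + ": " + draws.count(_ == x)))
  }
}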
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.optimization + +import org.jblas.DoubleMatrix + +/** + * Class used to compute the gradient for a loss function, given a single data point. + */ +abstract class Gradient extends Serializable { + /** + * Compute the gradient and loss given features of a single data point. + * + * @param data - Feature values for one data point. Column matrix of size nx1 + * where n is the number of features. + * @param label - Label for this data item. + * @param weights - Column matrix containing weights for every feature. + * + * @return A tuple of 2 elements. The first element is a column matrix containing the computed + * gradient and the second element is the loss computed at this data point. + * + */ + def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): + (DoubleMatrix, Double) +} + +/** + * Compute gradient and loss for a logistic loss function. + */ +class LogisticGradient extends Gradient { + override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): + (DoubleMatrix, Double) = { + val margin: Double = -1.0 * data.dot(weights) + val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label + + val gradient = data.mul(gradientMultiplier) + val loss = + if (margin > 0) { + math.log(1 + math.exp(0 - margin)) + } else { + math.log(1 + math.exp(margin)) - margin + } + + (gradient, loss) + } +} + +/** + * Compute gradient and loss for a Least-squared loss function. + */ +class SquaredGradient extends Gradient { + override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): + (DoubleMatrix, Double) = { + val diff: Double = data.dot(weights) - label + + val loss = 0.5 * diff * diff + val gradient = data.mul(diff) + + (gradient, loss) + } +} + +/** + * Compute gradient and loss for a Hinge loss function. + * NOTE: This assumes that the labels are {0,1} + */ +class HingeGradient extends Gradient { + override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): + (DoubleMatrix, Double) = { + + val dotProduct = data.dot(weights) + + // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) + // Therefore the gradient is -(2y - 1)*x + val labelScaled = 2 * label - 1.0 + + if (1.0 > labelScaled * dotProduct) { + (data.mul(-labelScaled), 1.0 - labelScaled * dotProduct) + } else { + (DoubleMatrix.zeros(1, weights.length), 0.0) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala new file mode 100644 index 0000000000..b62c9b3340 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
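To make the Gradient contract above concrete, a hedged sketch (made-up numbers) that evaluates LogisticGradient on a single point using jblas column vectors, the same shapes GradientDescent passes in. It assumes the optimization package from this patch and jblas are on the classpath.

import org.jblas.DoubleMatrix
import org.apache.spark.mllib.optimization.LogisticGradient

object GradientSketch {
  def main(args: Array[String]) {
    // Column vectors of size n x 1, as the Gradient contract requires
    val features = new DoubleMatrix(Array(0.5, -1.0, 2.0))
    val weights = new DoubleMatrix(Array(0.1, 0.2, -0.3))
    val label = 1.0
    val (gradient, loss) = new LogisticGradient().compute(features, label, weights)
    println("loss = " + loss)
    println("gradient = " + gradient.toArray.mkString(", "))
  }
}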
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.optimization + +import org.apache.spark.{Logging, RDD, SparkContext} +import org.apache.spark.SparkContext._ + +import org.jblas.DoubleMatrix + +import scala.collection.mutable.ArrayBuffer + +/** + * Class used to solve an optimization problem using Gradient Descent. + * @param gradient Gradient function to be used. + * @param updater Updater to be used to update weights after every iteration. + */ +class GradientDescent(var gradient: Gradient, var updater: Updater) extends Optimizer { + + private var stepSize: Double = 1.0 + private var numIterations: Int = 100 + private var regParam: Double = 0.0 + private var miniBatchFraction: Double = 1.0 + + /** + * Set the step size per-iteration of SGD. Default 1.0. + */ + def setStepSize(step: Double): this.type = { + this.stepSize = step + this + } + + /** + * Set fraction of data to be used for each SGD iteration. Default 1.0. + */ + def setMiniBatchFraction(fraction: Double): this.type = { + this.miniBatchFraction = fraction + this + } + + /** + * Set the number of iterations for SGD. Default 100. + */ + def setNumIterations(iters: Int): this.type = { + this.numIterations = iters + this + } + + /** + * Set the regularization parameter used for SGD. Default 0.0. + */ + def setRegParam(regParam: Double): this.type = { + this.regParam = regParam + this + } + + /** + * Set the gradient function to be used for SGD. + */ + def setGradient(gradient: Gradient): this.type = { + this.gradient = gradient + this + } + + + /** + * Set the updater function to be used for SGD. + */ + def setUpdater(updater: Updater): this.type = { + this.updater = updater + this + } + + def optimize(data: RDD[(Double, Array[Double])], initialWeights: Array[Double]) + : Array[Double] = { + + val (weights, stochasticLossHistory) = GradientDescent.runMiniBatchSGD( + data, + gradient, + updater, + stepSize, + numIterations, + regParam, + miniBatchFraction, + initialWeights) + weights + } + +} + +// Top-level method to run gradient descent. +object GradientDescent extends Logging { + /** + * Run gradient descent in parallel using mini batches. + * + * @param data - Input data for SGD. RDD of form (label, [feature values]). + * @param gradient - Gradient object that will be used to compute the gradient. + * @param updater - Updater object that will be used to update the model. + * @param stepSize - stepSize to be used during update. + * @param numIterations - number of iterations that SGD should be run. + * @param regParam - regularization parameter + * @param miniBatchFraction - fraction of the input data set that should be used for + * one iteration of SGD. Default value 1.0. + * + * @return A tuple containing two elements. 
The first element is a column matrix containing + * weights for every feature, and the second element is an array containing the stochastic + * loss computed for every iteration. + */ + def runMiniBatchSGD( + data: RDD[(Double, Array[Double])], + gradient: Gradient, + updater: Updater, + stepSize: Double, + numIterations: Int, + regParam: Double, + miniBatchFraction: Double, + initialWeights: Array[Double]) : (Array[Double], Array[Double]) = { + + val stochasticLossHistory = new ArrayBuffer[Double](numIterations) + + val nexamples: Long = data.count() + val miniBatchSize = nexamples * miniBatchFraction + + // Initialize weights as a column vector + var weights = new DoubleMatrix(initialWeights.length, 1, initialWeights:_*) + var regVal = 0.0 + + for (i <- 1 to numIterations) { + val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42+i).map { + case (y, features) => + val featuresCol = new DoubleMatrix(features.length, 1, features:_*) + val (grad, loss) = gradient.compute(featuresCol, y, weights) + (grad, loss) + }.reduce((a, b) => (a._1.addi(b._1), a._2 + b._2)) + + /** + * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration + * and regVal is the regularization value computed in the previous iteration as well. + */ + stochasticLossHistory.append(lossSum / miniBatchSize + regVal) + val update = updater.compute( + weights, gradientSum.div(miniBatchSize), stepSize, i, regParam) + weights = update._1 + regVal = update._2 + } + + logInfo("GradientDescent finished. Last 10 stochastic losses %s".format( + stochasticLossHistory.takeRight(10).mkString(", "))) + + (weights.toArray, stochasticLossHistory.toArray) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala new file mode 100644 index 0000000000..50059d385d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.optimization + +import org.apache.spark.RDD + +trait Optimizer { + + /** + * Solve the provided convex optimization problem. + */ + def optimize(data: RDD[(Double, Array[Double])], initialWeights: Array[Double]): Array[Double] + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala new file mode 100644 index 0000000000..4c51f4f881 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
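A hedged sketch of driving the optimizer above directly through the Optimizer interface (most callers go through the classification or regression wrappers instead). The toy least-squares problem, object name, and parameter values are made up; the label is roughly twice the single feature, so the learned weight should end up close to 2.0.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.optimization.{GradientDescent, SimpleUpdater, SquaredGradient}

object GradientDescentExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "GradientDescentExample")
    // (label, feature array) pairs, the layout Optimizer.optimize expects
    val data = sc.parallelize(Seq(
      (2.0, Array(1.0)), (4.0, Array(2.0)), (6.0, Array(3.0)), (8.0, Array(4.0))))
    val optimizer = new GradientDescent(new SquaredGradient(), new SimpleUpdater())
      .setStepSize(0.1)
      .setNumIterations(200)
      .setMiniBatchFraction(1.0)
    val weights = optimizer.optimize(data, Array(0.0))
    println("learned weight: " + weights.mkString(", "))
    sc.stop()
  }
}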
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.optimization + +import scala.math._ +import org.jblas.DoubleMatrix + +/** + * Class used to update weights used in Gradient Descent. + */ +abstract class Updater extends Serializable { + /** + * Compute an updated value for weights given the gradient, stepSize, iteration number and + * regularization parameter. Also returns the regularization value computed using the + * *updated* weights. + * + * @param weightsOld - Column matrix of size nx1 where n is the number of features. + * @param gradient - Column matrix of size nx1 where n is the number of features. + * @param stepSize - step size across iterations + * @param iter - Iteration number + * @param regParam - Regularization parameter + * + * @return A tuple of 2 elements. The first element is a column matrix containing updated weights, + * and the second element is the regularization value computed using updated weights. + */ + def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, stepSize: Double, iter: Int, + regParam: Double): (DoubleMatrix, Double) +} + +/** + * A simple updater that adaptively adjusts the learning rate the + * square root of the number of iterations. Does not perform any regularization. + */ +class SimpleUpdater extends Updater { + override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, + stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { + val thisIterStepSize = stepSize / math.sqrt(iter) + val normGradient = gradient.mul(thisIterStepSize) + (weightsOld.sub(normGradient), 0) + } +} + +/** + * Updater that adjusts learning rate and performs L1 regularization. + * + * The corresponding proximal operator used is the soft-thresholding function. + * That is, each weight component is shrunk towards 0 by shrinkageVal. + * + * If w > shrinkageVal, set weight component to w-shrinkageVal. + * If w < -shrinkageVal, set weight component to w+shrinkageVal. + * If -shrinkageVal < w < shrinkageVal, set weight component to 0. 
+ * + * Equivalently, set weight component to signum(w) * max(0.0, abs(w) - shrinkageVal) + */ +class L1Updater extends Updater { + override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, + stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { + val thisIterStepSize = stepSize / math.sqrt(iter) + val normGradient = gradient.mul(thisIterStepSize) + // Take gradient step + val newWeights = weightsOld.sub(normGradient) + // Soft thresholding + val shrinkageVal = regParam * thisIterStepSize + (0 until newWeights.length).foreach { i => + val wi = newWeights.get(i) + newWeights.put(i, signum(wi) * max(0.0, abs(wi) - shrinkageVal)) + } + (newWeights, newWeights.norm1 * regParam) + } +} + +/** + * Updater that adjusts the learning rate and performs L2 regularization + */ +class SquaredL2Updater extends Updater { + override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, + stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { + val thisIterStepSize = stepSize / math.sqrt(iter) + val normGradient = gradient.mul(thisIterStepSize) + val newWeights = weightsOld.sub(normGradient).div(2.0 * thisIterStepSize * regParam + 1.0) + (newWeights, pow(newWeights.norm2, 2.0) * regParam) + } +} + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala new file mode 100644 index 0000000000..218217acfe --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -0,0 +1,453 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.recommendation + +import scala.collection.mutable.{ArrayBuffer, BitSet} +import scala.util.Random +import scala.util.Sorting + +import org.apache.spark.{HashPartitioner, Partitioner, SparkContext, RDD} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.KryoRegistrator +import org.apache.spark.SparkContext._ + +import com.esotericsoftware.kryo.Kryo +import org.jblas.{DoubleMatrix, SimpleBlas, Solve} + + +/** + * Out-link information for a user or product block. This includes the original user/product IDs + * of the elements within this block, and the list of destination blocks that each user or + * product will need to send its feature vector to. + */ +private[recommendation] case class OutLinkBlock(elementIds: Array[Int], shouldSend: Array[BitSet]) + + +/** + * In-link information for a user (or product) block. This includes the original user/product IDs + * of the elements within this block, as well as an array of indices and ratings that specify + * which user in the block will be rated by which products from each product block (or vice-versa). 
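A standalone numeric sketch (made-up weights and parameters) of the soft-thresholding rule the L1Updater applies after its gradient step: every component is shrunk toward zero by shrinkageVal = regParam * stepSize / sqrt(iter), and components smaller than the shrinkage are zeroed out.

object SoftThresholdSketch {
  import scala.math.{abs, max, signum}

  def softThreshold(w: Double, shrinkage: Double): Double =
    signum(w) * max(0.0, abs(w) - shrinkage)

  def main(args: Array[String]) {
    val weights = Array(0.9, -0.03, 0.04, -2.5)
    val regParam = 0.1
    val stepSize = 1.0
    val iter = 4
    val shrinkage = regParam * stepSize / math.sqrt(iter)   // 0.05 for these values
    // 0.9 shrinks to 0.85, -2.5 to -2.45, and the two small components are zeroed out
    println(weights.map(softThreshold(_, shrinkage)).mkString(", "))
  }
}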
+ * Specifically, if this InLinkBlock is for users, ratingsForBlock(b)(i) will contain two arrays, + * indices and ratings, for the i'th product that will be sent to us by product block b (call this + * P). These arrays represent the users that product P had ratings for (by their index in this + * block), as well as the corresponding rating for each one. We can thus use this information when + * we get product block b's message to update the corresponding users. + */ +private[recommendation] case class InLinkBlock( + elementIds: Array[Int], ratingsForBlock: Array[Array[(Array[Int], Array[Double])]]) + + +/** + * A more compact class to represent a rating than Tuple3[Int, Int, Double]. + */ +case class Rating(val user: Int, val product: Int, val rating: Double) + +/** + * Alternating Least Squares matrix factorization. + * + * This is a blocked implementation of the ALS factorization algorithm that groups the two sets + * of factors (referred to as "users" and "products") into blocks and reduces communication by only + * sending one copy of each user vector to each product block on each iteration, and only for the + * product blocks that need that user's feature vector. This is achieved by precomputing some + * information about the ratings matrix to determine the "out-links" of each user (which blocks of + * products it will contribute to) and "in-link" information for each product (which of the feature + * vectors it receives from each user block it will depend on). This allows us to send only an + * array of feature vectors between each user block and product block, and have the product block + * find the users' ratings and update the products based on these messages. + */ +class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var lambda: Double) + extends Serializable +{ + def this() = this(-1, 10, 10, 0.01) + + /** + * Set the number of blocks to parallelize the computation into; pass -1 for an auto-configured + * number of blocks. Default: -1. + */ + def setBlocks(numBlocks: Int): ALS = { + this.numBlocks = numBlocks + this + } + + /** Set the rank of the feature matrices computed (number of features). Default: 10. */ + def setRank(rank: Int): ALS = { + this.rank = rank + this + } + + /** Set the number of iterations to run. Default: 10. */ + def setIterations(iterations: Int): ALS = { + this.iterations = iterations + this + } + + /** Set the regularization parameter, lambda. Default: 0.01. */ + def setLambda(lambda: Double): ALS = { + this.lambda = lambda + this + } + + /** + * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples. + * Returns a MatrixFactorizationModel with feature vectors for each user and product. 
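A hedged usage sketch for the ALS entry points defined at the bottom of this file (made-up ratings, tiny rank and iteration count). It assumes the MatrixFactorizationModel returned by run() exposes predict(user, product), as defined elsewhere in this patch.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.recommendation.{ALS, Rating}

object ALSExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "ALSExample")
    // (user, product, rating) triples; the values are made up
    val ratings = sc.parallelize(Seq(
      Rating(1, 1, 5.0), Rating(1, 2, 1.0),
      Rating(2, 1, 4.0), Rating(2, 3, 2.0),
      Rating(3, 2, 5.0), Rating(3, 3, 4.0)))
    // rank = 2 latent features, 10 iterations of ALS, lambda = 0.01
    val model = ALS.train(ratings, 2, 10, 0.01)
    println("predicted rating of product 3 by user 1: " + model.predict(1, 3))
    sc.stop()
  }
}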
+ */ + def run(ratings: RDD[Rating]): MatrixFactorizationModel = { + val numBlocks = if (this.numBlocks == -1) { + math.max(ratings.context.defaultParallelism, ratings.partitions.size / 2) + } else { + this.numBlocks + } + + val partitioner = new HashPartitioner(numBlocks) + + val ratingsByUserBlock = ratings.map{ rating => (rating.user % numBlocks, rating) } + val ratingsByProductBlock = ratings.map{ rating => + (rating.product % numBlocks, Rating(rating.product, rating.user, rating.rating)) + } + + val (userInLinks, userOutLinks) = makeLinkRDDs(numBlocks, ratingsByUserBlock) + val (productInLinks, productOutLinks) = makeLinkRDDs(numBlocks, ratingsByProductBlock) + + // Initialize user and product factors randomly, but use a deterministic seed for each partition + // so that fault recovery works + val seedGen = new Random() + val seed1 = seedGen.nextInt() + val seed2 = seedGen.nextInt() + // Hash an integer to propagate random bits at all positions, similar to java.util.HashTable + def hash(x: Int): Int = { + val r = x ^ (x >>> 20) ^ (x >>> 12) + r ^ (r >>> 7) ^ (r >>> 4) + } + var users = userOutLinks.mapPartitionsWithIndex { (index, itr) => + val rand = new Random(hash(seed1 ^ index)) + itr.map { case (x, y) => + (x, y.elementIds.map(_ => randomFactor(rank, rand))) + } + } + var products = productOutLinks.mapPartitionsWithIndex { (index, itr) => + val rand = new Random(hash(seed2 ^ index)) + itr.map { case (x, y) => + (x, y.elementIds.map(_ => randomFactor(rank, rand))) + } + } + + for (iter <- 0 until iterations) { + // perform ALS update + products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda) + users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda) + } + + // Flatten and cache the two final RDDs to un-block them + val usersOut = users.join(userOutLinks).flatMap { case (b, (factors, outLinkBlock)) => + for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i)) + } + val productsOut = products.join(productOutLinks).flatMap { case (b, (factors, outLinkBlock)) => + for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i)) + } + + usersOut.persist() + productsOut.persist() + + new MatrixFactorizationModel(rank, usersOut, productsOut) + } + + /** + * Make the out-links table for a block of the users (or products) dataset given the list of + * (user, product, rating) values for the users in that block (or the opposite for products). + */ + private def makeOutLinkBlock(numBlocks: Int, ratings: Array[Rating]): OutLinkBlock = { + val userIds = ratings.map(_.user).distinct.sorted + val numUsers = userIds.length + val userIdToPos = userIds.zipWithIndex.toMap + val shouldSend = Array.fill(numUsers)(new BitSet(numBlocks)) + for (r <- ratings) { + shouldSend(userIdToPos(r.user))(r.product % numBlocks) = true + } + OutLinkBlock(userIds, shouldSend) + } + + /** + * Make the in-links table for a block of the users (or products) dataset given a list of + * (user, product, rating) values for the users in that block (or the opposite for products). 
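A purely local sketch (made-up ratings, two blocks) of the bookkeeping makeOutLinkBlock performs in run() above: users and products fall into blocks by ID modulo numBlocks, and a BitSet per user records which product blocks need that user's feature vector.

object OutLinkSketch {
  import scala.collection.mutable.BitSet

  def main(args: Array[String]) {
    val numBlocks = 2
    // (user, product) pairs that have ratings; the values are made up
    val ratings = Seq((1, 4), (1, 7), (2, 4), (3, 5), (3, 8))
    val userIds = ratings.map(_._1).distinct.sorted
    val userIdToPos = userIds.zipWithIndex.toMap
    // shouldSend(u) holds the product blocks that user u must send its factors to
    val shouldSend = Array.fill(userIds.length)(new BitSet(numBlocks))
    for ((user, product) <- ratings) {
      shouldSend(userIdToPos(user)) += product % numBlocks
    }
    for ((user, pos) <- userIdToPos.toSeq.sortBy(_._2)) {
      println("user " + user + " -> product blocks " + shouldSend(pos).mkString(", "))
    }
  }
}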
+ */ + private def makeInLinkBlock(numBlocks: Int, ratings: Array[Rating]): InLinkBlock = { + val userIds = ratings.map(_.user).distinct.sorted + val numUsers = userIds.length + val userIdToPos = userIds.zipWithIndex.toMap + // Split out our ratings by product block + val blockRatings = Array.fill(numBlocks)(new ArrayBuffer[Rating]) + for (r <- ratings) { + blockRatings(r.product % numBlocks) += r + } + val ratingsForBlock = new Array[Array[(Array[Int], Array[Double])]](numBlocks) + for (productBlock <- 0 until numBlocks) { + // Create an array of (product, Seq(Rating)) ratings + val groupedRatings = blockRatings(productBlock).groupBy(_.product).toArray + // Sort them by product ID + val ordering = new Ordering[(Int, ArrayBuffer[Rating])] { + def compare(a: (Int, ArrayBuffer[Rating]), b: (Int, ArrayBuffer[Rating])): Int = a._1 - b._1 + } + Sorting.quickSort(groupedRatings)(ordering) + // Translate the user IDs to indices based on userIdToPos + ratingsForBlock(productBlock) = groupedRatings.map { case (p, rs) => + (rs.view.map(r => userIdToPos(r.user)).toArray, rs.view.map(_.rating).toArray) + } + } + InLinkBlock(userIds, ratingsForBlock) + } + + /** + * Make RDDs of InLinkBlocks and OutLinkBlocks given an RDD of (blockId, (u, p, r)) values for + * the users (or (blockId, (p, u, r)) for the products). We create these simultaneously to avoid + * having to shuffle the (blockId, (u, p, r)) RDD twice, or to cache it. + */ + private def makeLinkRDDs(numBlocks: Int, ratings: RDD[(Int, Rating)]) + : (RDD[(Int, InLinkBlock)], RDD[(Int, OutLinkBlock)]) = + { + val grouped = ratings.partitionBy(new HashPartitioner(numBlocks)) + val links = grouped.mapPartitionsWithIndex((blockId, elements) => { + val ratings = elements.map{_._2}.toArray + val inLinkBlock = makeInLinkBlock(numBlocks, ratings) + val outLinkBlock = makeOutLinkBlock(numBlocks, ratings) + Iterator.single((blockId, (inLinkBlock, outLinkBlock))) + }, true) + links.persist(StorageLevel.MEMORY_AND_DISK) + (links.mapValues(_._1), links.mapValues(_._2)) + } + + /** + * Make a random factor vector with the given random. + */ + private def randomFactor(rank: Int, rand: Random): Array[Double] = { + Array.fill(rank)(rand.nextDouble) + } + + /** + * Compute the user feature vectors given the current products (or vice-versa). This first joins + * the products with their out-links to generate a set of messages to each destination block + * (specifically, the features for the products that user block cares about), then groups these + * by destination and joins them with the in-link info to figure out how to update each user. + * It returns an RDD of new feature vectors for each user block. 
+ */ + private def updateFeatures( + products: RDD[(Int, Array[Array[Double]])], + productOutLinks: RDD[(Int, OutLinkBlock)], + userInLinks: RDD[(Int, InLinkBlock)], + partitioner: Partitioner, + rank: Int, + lambda: Double) + : RDD[(Int, Array[Array[Double]])] = + { + val numBlocks = products.partitions.size + productOutLinks.join(products).flatMap { case (bid, (outLinkBlock, factors)) => + val toSend = Array.fill(numBlocks)(new ArrayBuffer[Array[Double]]) + for (p <- 0 until outLinkBlock.elementIds.length; userBlock <- 0 until numBlocks) { + if (outLinkBlock.shouldSend(p)(userBlock)) { + toSend(userBlock) += factors(p) + } + } + toSend.zipWithIndex.map{ case (buf, idx) => (idx, (bid, buf.toArray)) } + }.groupByKey(partitioner) + .join(userInLinks) + .mapValues{ case (messages, inLinkBlock) => updateBlock(messages, inLinkBlock, rank, lambda) } + } + + /** + * Compute the new feature vectors for a block of the users matrix given the list of factors + * it received from each product and its InLinkBlock. + */ + def updateBlock(messages: Seq[(Int, Array[Array[Double]])], inLinkBlock: InLinkBlock, + rank: Int, lambda: Double) + : Array[Array[Double]] = + { + // Sort the incoming block factor messages by block ID and make them an array + val blockFactors = messages.sortBy(_._1).map(_._2).toArray // Array[Array[Double]] + val numBlocks = blockFactors.length + val numUsers = inLinkBlock.elementIds.length + + // We'll sum up the XtXes using vectors that represent only the lower-triangular part, since + // the matrices are symmetric + val triangleSize = rank * (rank + 1) / 2 + val userXtX = Array.fill(numUsers)(DoubleMatrix.zeros(triangleSize)) + val userXy = Array.fill(numUsers)(DoubleMatrix.zeros(rank)) + + // Some temp variables to avoid memory allocation + val tempXtX = DoubleMatrix.zeros(triangleSize) + val fullXtX = DoubleMatrix.zeros(rank, rank) + + // Compute the XtX and Xy values for each user by adding products it rated in each product block + for (productBlock <- 0 until numBlocks) { + for (p <- 0 until blockFactors(productBlock).length) { + val x = new DoubleMatrix(blockFactors(productBlock)(p)) + fillXtX(x, tempXtX) + val (us, rs) = inLinkBlock.ratingsForBlock(productBlock)(p) + for (i <- 0 until us.length) { + userXtX(us(i)).addi(tempXtX) + SimpleBlas.axpy(rs(i), x, userXy(us(i))) + } + } + } + + // Solve the least-squares problem for each user and return the new feature vectors + userXtX.zipWithIndex.map{ case (triangularXtX, index) => + // Compute the full XtX matrix from the lower-triangular part we got above + fillFullMatrix(triangularXtX, fullXtX) + // Add regularization + (0 until rank).foreach(i => fullXtX.data(i*rank + i) += lambda) + // Solve the resulting matrix, which is symmetric and positive-definite + Solve.solvePositive(fullXtX, userXy(index)).data + } + } + + /** + * Set xtxDest to the lower-triangular part of x transpose * x. For efficiency in summing + * these matrices, we store xtxDest as only rank * (rank+1) / 2 values, namely the values + * at (0,0), (1,0), (1,1), (2,0), (2,1), (2,2), etc in that order. + */ + private def fillXtX(x: DoubleMatrix, xtxDest: DoubleMatrix) { + var i = 0 + var pos = 0 + while (i < x.length) { + var j = 0 + while (j <= i) { + xtxDest.data(pos) = x.data(i) * x.data(j) + pos += 1 + j += 1 + } + i += 1 + } + } + + /** + * Given a triangular matrix in the order of fillXtX above, compute the full symmetric square + * matrix that it represents, storing it into destMatrix. 
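+ *
+ * For example, with rank = 3 the packed values (a, b, c, d, e, f), stored in the order
+ * (0,0), (1,0), (1,1), (2,0), (2,1), (2,2), expand to the symmetric matrix
+ * {{{
+ *   a b d
+ *   b c e
+ *   d e f
+ * }}}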
+ */ + private def fillFullMatrix(triangularMatrix: DoubleMatrix, destMatrix: DoubleMatrix) { + val rank = destMatrix.rows + var i = 0 + var pos = 0 + while (i < rank) { + var j = 0 + while (j <= i) { + destMatrix.data(i*rank + j) = triangularMatrix.data(pos) + destMatrix.data(j*rank + i) = triangularMatrix.data(pos) + pos += 1 + j += 1 + } + i += 1 + } + } +} + + +/** + * Top-level methods for calling Alternating Least Squares (ALS) matrix factorizaton. + */ +object ALS { + /** + * Train a matrix factorization model given an RDD of ratings given by users to some products, + * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the + * product of two lower-rank matrices of a given rank (number of features). To solve for these + * features, we run a given number of iterations of ALS. This is done using a level of + * parallelism given by `blocks`. + * + * @param ratings RDD of (userID, productID, rating) pairs + * @param rank number of features to use + * @param iterations number of iterations of ALS (recommended: 10-20) + * @param lambda regularization factor (recommended: 0.01) + * @param blocks level of parallelism to split computation into + */ + def train( + ratings: RDD[Rating], + rank: Int, + iterations: Int, + lambda: Double, + blocks: Int) + : MatrixFactorizationModel = + { + new ALS(blocks, rank, iterations, lambda).run(ratings) + } + + /** + * Train a matrix factorization model given an RDD of ratings given by users to some products, + * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the + * product of two lower-rank matrices of a given rank (number of features). To solve for these + * features, we run a given number of iterations of ALS. The level of parallelism is determined + * automatically based on the number of partitions in `ratings`. + * + * @param ratings RDD of (userID, productID, rating) pairs + * @param rank number of features to use + * @param iterations number of iterations of ALS (recommended: 10-20) + * @param lambda regularization factor (recommended: 0.01) + */ + def train(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double) + : MatrixFactorizationModel = + { + train(ratings, rank, iterations, lambda, -1) + } + + /** + * Train a matrix factorization model given an RDD of ratings given by users to some products, + * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the + * product of two lower-rank matrices of a given rank (number of features). To solve for these + * features, we run a given number of iterations of ALS. The level of parallelism is determined + * automatically based on the number of partitions in `ratings`. 
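+ *
+ * Example usage (a hypothetical sketch; the input path and parameter values are assumptions):
+ * {{{
+ *   val ratings = sc.textFile("ratings.csv").map { line =>
+ *     val fields = line.split(',')
+ *     Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble)
+ *   }
+ *   val model = ALS.train(ratings, 10, 20, 0.01)
+ * }}}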
+ * + * @param ratings RDD of (userID, productID, rating) pairs + * @param rank number of features to use + * @param iterations number of iterations of ALS (recommended: 10-20) + */ + def train(ratings: RDD[Rating], rank: Int, iterations: Int) + : MatrixFactorizationModel = + { + train(ratings, rank, iterations, 0.01, -1) + } + + private class ALSRegistrator extends KryoRegistrator { + override def registerClasses(kryo: Kryo) { + kryo.register(classOf[Rating]) + } + } + + def main(args: Array[String]) { + if (args.length != 5 && args.length != 6) { + println("Usage: ALS []") + System.exit(1) + } + val (master, ratingsFile, rank, iters, outputDir) = + (args(0), args(1), args(2).toInt, args(3).toInt, args(4)) + val blocks = if (args.length == 6) args(5).toInt else -1 + System.setProperty("spark.serializer", "org.apache.spark.KryoSerializer") + System.setProperty("spark.kryo.registrator", classOf[ALSRegistrator].getName) + System.setProperty("spark.kryo.referenceTracking", "false") + System.setProperty("spark.kryoserializer.buffer.mb", "8") + System.setProperty("spark.locality.wait", "10000") + val sc = new SparkContext(master, "ALS") + val ratings = sc.textFile(ratingsFile).map { line => + val fields = line.split(',') + Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble) + } + val model = ALS.train(ratings, rank, iters, 0.01, blocks) + model.userFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") } + .saveAsTextFile(outputDir + "/userFeatures") + model.productFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") } + .saveAsTextFile(outputDir + "/productFeatures") + println("Final user/product features written to " + outputDir) + System.exit(0) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala new file mode 100644 index 0000000000..ae9fe48aec --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.recommendation + +import org.apache.spark.RDD +import org.apache.spark.SparkContext._ + +import org.jblas._ + +/** + * Model representing the result of matrix factorization. + * + * @param rank Rank for the features in this model. + * @param userFeatures RDD of tuples where each tuple represents the userId and + * the features computed for this user. + * @param productFeatures RDD of tuples where each tuple represents the productId + * and the features computed for this product. 
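+ *
+ * A predicted rating is simply the dot product of the two feature vectors, e.g. (illustrative
+ * values only): for a userFeatures entry (1, Array(0.5, 1.0)) and a productFeatures entry
+ * (2, Array(2.0, 3.0)), predict(1, 2) returns 0.5 * 2.0 + 1.0 * 3.0 = 4.0.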
+ */ +class MatrixFactorizationModel( + val rank: Int, + val userFeatures: RDD[(Int, Array[Double])], + val productFeatures: RDD[(Int, Array[Double])]) + extends Serializable +{ + /** Predict the rating of one user for one product. */ + def predict(user: Int, product: Int): Double = { + val userVector = new DoubleMatrix(userFeatures.lookup(user).head) + val productVector = new DoubleMatrix(productFeatures.lookup(product).head) + userVector.dot(productVector) + } + + // TODO: Figure out what good bulk prediction methods would look like. + // Probably want a way to get the top users for a product or vice-versa. +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala new file mode 100644 index 0000000000..06015110ac --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import org.apache.spark.{Logging, RDD, SparkException} +import org.apache.spark.mllib.optimization._ + +import org.jblas.DoubleMatrix + +/** + * GeneralizedLinearModel (GLM) represents a model trained using + * GeneralizedLinearAlgorithm. GLMs consist of a weight vector and + * an intercept. + * + * @param weights Weights computed for every feature. + * @param intercept Intercept computed for this model. + */ +abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept: Double) + extends Serializable { + + // Create a column vector that can be used for predictions + private val weightsMatrix = new DoubleMatrix(weights.length, 1, weights:_*) + + /** + * Predict the result given a data point and the weights learned. + * + * @param dataMatrix Row vector containing the features for this data point + * @param weightMatrix Column vector containing the weights of the model + * @param intercept Intercept of the model. + */ + def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, + intercept: Double): Double + + /** + * Predict values for the given data set using the model trained. + * + * @param testData RDD representing data points to be predicted + * @return RDD[Double] where each entry contains the corresponding prediction + */ + def predict(testData: RDD[Array[Double]]): RDD[Double] = { + // A small optimization to avoid serializing the entire model. Only the weightsMatrix + // and intercept is needed. 
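+ // Capturing them in local vals means the closure below serializes only a DoubleMatrix and a
+ // Double rather than the whole model object.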
+ val localWeights = weightsMatrix + val localIntercept = intercept + + testData.map { x => + val dataMatrix = new DoubleMatrix(1, x.length, x:_*) + predictPoint(dataMatrix, localWeights, localIntercept) + } + } + + /** + * Predict values for a single data point using the model trained. + * + * @param testData array representing a single data point + * @return Double prediction from the trained model + */ + def predict(testData: Array[Double]): Double = { + val dataMat = new DoubleMatrix(1, testData.length, testData:_*) + predictPoint(dataMat, weightsMatrix, intercept) + } +} + +/** + * GeneralizedLinearAlgorithm implements methods to train a Genearalized Linear Model (GLM). + * This class should be extended with an Optimizer to create a new GLM. + */ +abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] + extends Logging with Serializable { + + protected val validators: Seq[RDD[LabeledPoint] => Boolean] = List() + + val optimizer: Optimizer + + protected var addIntercept: Boolean = true + + protected var validateData: Boolean = true + + /** + * Create a model given the weights and intercept + */ + protected def createModel(weights: Array[Double], intercept: Double): M + + /** + * Set if the algorithm should add an intercept. Default true. + */ + def setIntercept(addIntercept: Boolean): this.type = { + this.addIntercept = addIntercept + this + } + + /** + * Set if the algorithm should validate data before training. Default true. + */ + def setValidateData(validateData: Boolean): this.type = { + this.validateData = validateData + this + } + + /** + * Run the algorithm with the configured parameters on an input + * RDD of LabeledPoint entries. + */ + def run(input: RDD[LabeledPoint]) : M = { + val nfeatures: Int = input.first().features.length + val initialWeights = Array.fill(nfeatures)(1.0) + run(input, initialWeights) + } + + /** + * Run the algorithm with the configured parameters on an input RDD + * of LabeledPoint entries starting from the initial weights provided. + */ + def run(input: RDD[LabeledPoint], initialWeights: Array[Double]) : M = { + + // Check the data properties before running the optimizer + if (validateData && !validators.forall(func => func(input))) { + throw new SparkException("Input validation failed.") + } + + // Add a extra variable consisting of all 1.0's for the intercept. + val data = if (addIntercept) { + input.map(labeledPoint => (labeledPoint.label, Array(1.0, labeledPoint.features:_*))) + } else { + input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) + } + + val initialWeightsWithIntercept = if (addIntercept) { + Array(1.0, initialWeights:_*) + } else { + initialWeights + } + + val weights = optimizer.optimize(data, initialWeightsWithIntercept) + val intercept = weights(0) + val weightsScaled = weights.tail + + val model = createModel(weightsScaled, intercept) + + logInfo("Final model weights " + model.weights.mkString(",")) + logInfo("Final model intercept " + model.intercept) + model + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala new file mode 100644 index 0000000000..63240e24dc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +/** + * Class that represents the features and labels of a data point. + * + * @param label Label for this data point. + * @param features List of features for this data point. + */ +case class LabeledPoint(val label: Double, val features: Array[Double]) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala new file mode 100644 index 0000000000..df3beb1959 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import org.apache.spark.{Logging, RDD, SparkContext} +import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.util.MLUtils + +import org.jblas.DoubleMatrix + +/** + * Regression model trained using Lasso. + * + * @param weights Weights computed for every feature. + * @param intercept Intercept computed for this model. + */ +class LassoModel( + override val weights: Array[Double], + override val intercept: Double) + extends GeneralizedLinearModel(weights, intercept) + with RegressionModel with Serializable { + + override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, + intercept: Double) = { + dataMatrix.dot(weightMatrix) + intercept + } +} + +/** + * Train a regression model with L1-regularization using Stochastic Gradient Descent. + */ +class LassoWithSGD private ( + var stepSize: Double, + var numIterations: Int, + var regParam: Double, + var miniBatchFraction: Double) + extends GeneralizedLinearAlgorithm[LassoModel] + with Serializable { + + val gradient = new SquaredGradient() + val updater = new L1Updater() + @transient val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize) + .setNumIterations(numIterations) + .setRegParam(regParam) + .setMiniBatchFraction(miniBatchFraction) + + // We don't want to penalize the intercept, so set this to false. 
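+ // (The intercept is instead recovered from the column means and standard deviations in
+ // createModel below.)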
+ setIntercept(false) + + var yMean = 0.0 + var xColMean: DoubleMatrix = _ + var xColSd: DoubleMatrix = _ + + /** + * Construct a Lasso object with default parameters + */ + def this() = this(1.0, 100, 1.0, 1.0) + + def createModel(weights: Array[Double], intercept: Double) = { + val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*) + val weightsScaled = weightsMat.div(xColSd) + val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)) + + new LassoModel(weightsScaled.data, interceptScaled) + } + + override def run( + input: RDD[LabeledPoint], + initialWeights: Array[Double]) + : LassoModel = + { + val nfeatures: Int = input.first.features.length + val nexamples: Long = input.count() + + // To avoid penalizing the intercept, we center and scale the data. + val stats = MLUtils.computeStats(input, nfeatures, nexamples) + yMean = stats._1 + xColMean = stats._2 + xColSd = stats._3 + + val normalizedData = input.map { point => + val yNormalized = point.label - yMean + val featuresMat = new DoubleMatrix(nfeatures, 1, point.features:_*) + val featuresNormalized = featuresMat.sub(xColMean).divi(xColSd) + LabeledPoint(yNormalized, featuresNormalized.toArray) + } + + super.run(normalizedData, initialWeights) + } +} + +/** + * Top-level methods for calling Lasso. + */ +object LassoWithSGD { + + /** + * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. Each iteration uses + * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in + * gradient descent are initialized using the initial weights provided. + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @param stepSize Step size to be used for each iteration of gradient descent. + * @param regParam Regularization parameter. + * @param miniBatchFraction Fraction of data to be used per iteration. + * @param initialWeights Initial set of weights to be used. Array should be equal in size to + * the number of features in the data. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + regParam: Double, + miniBatchFraction: Double, + initialWeights: Array[Double]) + : LassoModel = + { + new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input, + initialWeights) + } + + /** + * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. Each iteration uses + * `miniBatchFraction` fraction of the data to calculate the gradient. + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @param stepSize Step size to be used for each iteration of gradient descent. + * @param regParam Regularization parameter. + * @param miniBatchFraction Fraction of data to be used per iteration. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + regParam: Double, + miniBatchFraction: Double) + : LassoModel = + { + new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) + } + + /** + * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. 
We use the entire data set to + * update the gradient in each iteration. + * + * @param input RDD of (label, array of features) pairs. + * @param stepSize Step size to be used for each iteration of Gradient Descent. + * @param regParam Regularization parameter. + * @param numIterations Number of iterations of gradient descent to run. + * @return a LassoModel which has the weights and offset from training. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + regParam: Double) + : LassoModel = + { + train(input, numIterations, stepSize, regParam, 1.0) + } + + /** + * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using a step size of 1.0. We use the entire data set to + * update the gradient in each iteration. + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @return a LassoModel which has the weights and offset from training. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int) + : LassoModel = + { + train(input, numIterations, 1.0, 1.0, 1.0) + } + + def main(args: Array[String]) { + if (args.length != 5) { + println("Usage: Lasso ") + System.exit(1) + } + val sc = new SparkContext(args(0), "Lasso") + val data = MLUtils.loadLabeledData(sc, args(1)) + val model = LassoWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble) + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala new file mode 100644 index 0000000000..71f968471c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import org.apache.spark.{Logging, RDD, SparkContext} +import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.util.MLUtils + +import org.jblas.DoubleMatrix + +/** + * Regression model trained using LinearRegression. + * + * @param weights Weights computed for every feature. + * @param intercept Intercept computed for this model. + */ +class LinearRegressionModel( + override val weights: Array[Double], + override val intercept: Double) + extends GeneralizedLinearModel(weights, intercept) + with RegressionModel with Serializable { + + override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, + intercept: Double) = { + dataMatrix.dot(weightMatrix) + intercept + } +} + +/** + * Train a regression model with no regularization using Stochastic Gradient Descent. 
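+ *
+ * Example usage (a hypothetical sketch; the input path and parameter values are assumptions):
+ * {{{
+ *   val data = MLUtils.loadLabeledData(sc, "lpsa.data")
+ *   val model = LinearRegressionWithSGD.train(data, 100, 1.0)
+ * }}}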
+ */ +class LinearRegressionWithSGD private ( + var stepSize: Double, + var numIterations: Int, + var miniBatchFraction: Double) + extends GeneralizedLinearAlgorithm[LinearRegressionModel] + with Serializable { + + val gradient = new SquaredGradient() + val updater = new SimpleUpdater() + val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize) + .setNumIterations(numIterations) + .setMiniBatchFraction(miniBatchFraction) + + /** + * Construct a LinearRegression object with default parameters + */ + def this() = this(1.0, 100, 1.0) + + def createModel(weights: Array[Double], intercept: Double) = { + new LinearRegressionModel(weights, intercept) + } +} + +/** + * Top-level methods for calling LinearRegression. + */ +object LinearRegressionWithSGD { + + /** + * Train a Linear Regression model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. Each iteration uses + * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in + * gradient descent are initialized using the initial weights provided. + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @param stepSize Step size to be used for each iteration of gradient descent. + * @param miniBatchFraction Fraction of data to be used per iteration. + * @param initialWeights Initial set of weights to be used. Array should be equal in size to + * the number of features in the data. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + miniBatchFraction: Double, + initialWeights: Array[Double]) + : LinearRegressionModel = + { + new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction).run(input, + initialWeights) + } + + /** + * Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. Each iteration uses + * `miniBatchFraction` fraction of the data to calculate the gradient. + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @param stepSize Step size to be used for each iteration of gradient descent. + * @param miniBatchFraction Fraction of data to be used per iteration. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + miniBatchFraction: Double) + : LinearRegressionModel = + { + new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction).run(input) + } + + /** + * Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. We use the entire data set to + * update the gradient in each iteration. + * + * @param input RDD of (label, array of features) pairs. + * @param stepSize Step size to be used for each iteration of Gradient Descent. + * @param numIterations Number of iterations of gradient descent to run. + * @return a LinearRegressionModel which has the weights and offset from training. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double) + : LinearRegressionModel = + { + train(input, numIterations, stepSize, 1.0) + } + + /** + * Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using a step size of 1.0. 
We use the entire data set to + * update the gradient in each iteration. + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @return a LinearRegressionModel which has the weights and offset from training. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int) + : LinearRegressionModel = + { + train(input, numIterations, 1.0, 1.0) + } + + def main(args: Array[String]) { + if (args.length != 5) { + println("Usage: LinearRegression ") + System.exit(1) + } + val sc = new SparkContext(args(0), "LinearRegression") + val data = MLUtils.loadLabeledData(sc, args(1)) + val model = LinearRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble) + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala new file mode 100644 index 0000000000..8dd325efc0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import org.apache.spark.RDD + +trait RegressionModel extends Serializable { + /** + * Predict values for the given data set using the model trained. + * + * @param testData RDD representing data points to be predicted + * @return RDD[Double] where each entry contains the corresponding prediction + */ + def predict(testData: RDD[Array[Double]]): RDD[Double] + + /** + * Predict values for a single data point using the model trained. + * + * @param testData array representing a single data point + * @return Double prediction from the trained model + */ + def predict(testData: Array[Double]): Double +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala new file mode 100644 index 0000000000..228ab9e4e8 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import org.apache.spark.{Logging, RDD, SparkContext} +import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.util.MLUtils + +import org.jblas.DoubleMatrix + +/** + * Regression model trained using RidgeRegression. + * + * @param weights Weights computed for every feature. + * @param intercept Intercept computed for this model. + */ +class RidgeRegressionModel( + override val weights: Array[Double], + override val intercept: Double) + extends GeneralizedLinearModel(weights, intercept) + with RegressionModel with Serializable { + + override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, + intercept: Double) = { + dataMatrix.dot(weightMatrix) + intercept + } +} + +/** + * Train a regression model with L2-regularization using Stochastic Gradient Descent. + */ +class RidgeRegressionWithSGD private ( + var stepSize: Double, + var numIterations: Int, + var regParam: Double, + var miniBatchFraction: Double) + extends GeneralizedLinearAlgorithm[RidgeRegressionModel] + with Serializable { + + val gradient = new SquaredGradient() + val updater = new SquaredL2Updater() + + @transient val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize) + .setNumIterations(numIterations) + .setRegParam(regParam) + .setMiniBatchFraction(miniBatchFraction) + + // We don't want to penalize the intercept in RidgeRegression, so set this to false. + setIntercept(false) + + var yMean = 0.0 + var xColMean: DoubleMatrix = _ + var xColSd: DoubleMatrix = _ + + /** + * Construct a RidgeRegression object with default parameters + */ + def this() = this(1.0, 100, 1.0, 1.0) + + def createModel(weights: Array[Double], intercept: Double) = { + val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*) + val weightsScaled = weightsMat.div(xColSd) + val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)) + + new RidgeRegressionModel(weightsScaled.data, interceptScaled) + } + + override def run( + input: RDD[LabeledPoint], + initialWeights: Array[Double]) + : RidgeRegressionModel = + { + val nfeatures: Int = input.first.features.length + val nexamples: Long = input.count() + + // To avoid penalizing the intercept, we center and scale the data. + val stats = MLUtils.computeStats(input, nfeatures, nexamples) + yMean = stats._1 + xColMean = stats._2 + xColSd = stats._3 + + val normalizedData = input.map { point => + val yNormalized = point.label - yMean + val featuresMat = new DoubleMatrix(nfeatures, 1, point.features:_*) + val featuresNormalized = featuresMat.sub(xColMean).divi(xColSd) + LabeledPoint(yNormalized, featuresNormalized.toArray) + } + + super.run(normalizedData, initialWeights) + } +} + +/** + * Top-level methods for calling RidgeRegression. + */ +object RidgeRegressionWithSGD { + + /** + * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. 
Each iteration uses + * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in + * gradient descent are initialized using the initial weights provided. + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @param stepSize Step size to be used for each iteration of gradient descent. + * @param regParam Regularization parameter. + * @param miniBatchFraction Fraction of data to be used per iteration. + * @param initialWeights Initial set of weights to be used. Array should be equal in size to + * the number of features in the data. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + regParam: Double, + miniBatchFraction: Double, + initialWeights: Array[Double]) + : RidgeRegressionModel = + { + new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run( + input, initialWeights) + } + + /** + * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. Each iteration uses + * `miniBatchFraction` fraction of the data to calculate the gradient. + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @param stepSize Step size to be used for each iteration of gradient descent. + * @param regParam Regularization parameter. + * @param miniBatchFraction Fraction of data to be used per iteration. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + regParam: Double, + miniBatchFraction: Double) + : RidgeRegressionModel = + { + new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) + } + + /** + * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using the specified step size. We use the entire data set to + * update the gradient in each iteration. + * + * @param input RDD of (label, array of features) pairs. + * @param stepSize Step size to be used for each iteration of Gradient Descent. + * @param regParam Regularization parameter. + * @param numIterations Number of iterations of gradient descent to run. + * @return a RidgeRegressionModel which has the weights and offset from training. + */ + def train( + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + regParam: Double) + : RidgeRegressionModel = + { + train(input, numIterations, stepSize, regParam, 1.0) + } + + /** + * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number + * of iterations of gradient descent using a step size of 1.0. We use the entire data set to + * update the gradient in each iteration. + * + * @param input RDD of (label, array of features) pairs. + * @param numIterations Number of iterations of gradient descent to run. + * @return a RidgeRegressionModel which has the weights and offset from training. 
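+ *
+ * Example usage (a hypothetical sketch; the input path is an assumption):
+ * {{{
+ *   val data = MLUtils.loadLabeledData(sc, "ridge-data.txt")
+ *   val model = RidgeRegressionWithSGD.train(data, 100)
+ * }}}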
+ */ + def train( + input: RDD[LabeledPoint], + numIterations: Int) + : RidgeRegressionModel = + { + train(input, numIterations, 1.0, 1.0, 1.0) + } + + def main(args: Array[String]) { + if (args.length != 5) { + println("Usage: RidgeRegression " + + " ") + System.exit(1) + } + val sc = new SparkContext(args(0), "RidgeRegression") + val data = MLUtils.loadLabeledData(sc, args(1)) + val model = RidgeRegressionWithSGD.train(data, args(4).toInt, args(2).toDouble, + args(3).toDouble) + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala new file mode 100644 index 0000000000..7fd4623071 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import org.apache.spark.{RDD, Logging} +import org.apache.spark.mllib.regression.LabeledPoint + +/** + * A collection of methods used to validate data before applying ML algorithms. + */ +object DataValidators extends Logging { + + /** + * Function to check if labels used for classification are either zero or one. + * + * @param data - input data set that needs to be checked + * + * @return True if labels are all zero or one, false otherwise. + */ + val classificationLabels: RDD[LabeledPoint] => Boolean = { data => + val numInvalid = data.filter(x => x.label != 1.0 && x.label != 0.0).count() + if (numInvalid != 0) { + logError("Classification labels should be 0 or 1. Found " + numInvalid + " invalid labels") + } + numInvalid == 0 + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala new file mode 100644 index 0000000000..6500d47183 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.util + +import scala.util.Random + +import org.apache.spark.{RDD, SparkContext} + +/** + * Generate test data for KMeans. This class first chooses k cluster centers + * from a d-dimensional Gaussian distribution scaled by factor r and then creates a Gaussian + * cluster with scale 1 around each center. + */ + +object KMeansDataGenerator { + + /** + * Generate an RDD containing test data for KMeans. + * + * @param sc SparkContext to use for creating the RDD + * @param numPoints Number of points that will be contained in the RDD + * @param k Number of clusters + * @param d Number of dimensions + * @param r Scaling factor for the distribution of the initial centers + * @param numPartitions Number of partitions of the generated RDD; default 2 + */ + def generateKMeansRDD( + sc: SparkContext, + numPoints: Int, + k: Int, + d: Int, + r: Double, + numPartitions: Int = 2) + : RDD[Array[Double]] = + { + // First, generate some centers + val rand = new Random(42) + val centers = Array.fill(k)(Array.fill(d)(rand.nextGaussian() * r)) + // Then generate points around each center + sc.parallelize(0 until numPoints, numPartitions).map { idx => + val center = centers(idx % k) + val rand2 = new Random(42 + idx) + Array.tabulate(d)(i => center(i) + rand2.nextGaussian()) + } + } + + def main(args: Array[String]) { + if (args.length < 6) { + println("Usage: KMeansGenerator " + + " []") + System.exit(1) + } + + val sparkMaster = args(0) + val outputPath = args(1) + val numPoints = args(2).toInt + val k = args(3).toInt + val d = args(4).toInt + val r = args(5).toDouble + val parts = if (args.length >= 7) args(6).toInt else 2 + + val sc = new SparkContext(sparkMaster, "KMeansDataGenerator") + val data = generateKMeansRDD(sc, numPoints, k, d, r, parts) + data.map(_.mkString(" ")).saveAsTextFile(outputPath) + + System.exit(0) + } +} + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala new file mode 100644 index 0000000000..4c49d484b4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import scala.collection.JavaConversions._ +import scala.util.Random + +import org.jblas.DoubleMatrix + +import org.apache.spark.{RDD, SparkContext} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.LabeledPoint + +/** + * Generate sample data used for Linear Data. This class generates + * uniformly random values for every feature and adds Gaussian noise with mean `eps` to the + * response variable `Y`. 
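+ *
+ * Concretely, each example is generated as follows (as implemented in generateLinearInput below):
+ * {{{
+ *   x(i) ~ Uniform(-1, 1)  for each feature i
+ *   y = weights . x + intercept + eps * N(0, 1)
+ * }}}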
+ */ +object LinearDataGenerator { + + /** + * Return a Java List of synthetic data randomly generated according to a multi + * collinear model. + * @param intercept Data intercept + * @param weights Weights to be applied. + * @param nPoints Number of points in sample. + * @param seed Random seed + * @return Java List of input. + */ + def generateLinearInputAsList( + intercept: Double, + weights: Array[Double], + nPoints: Int, + seed: Int, + eps: Double): java.util.List[LabeledPoint] = { + seqAsJavaList(generateLinearInput(intercept, weights, nPoints, seed, eps)) + } + + /** + * + * @param intercept Data intercept + * @param weights Weights to be applied. + * @param nPoints Number of points in sample. + * @param seed Random seed + * @param eps Epsilon scaling factor. + * @return + */ + def generateLinearInput( + intercept: Double, + weights: Array[Double], + nPoints: Int, + seed: Int, + eps: Double = 0.1): Seq[LabeledPoint] = { + + val rnd = new Random(seed) + val weightsMat = new DoubleMatrix(1, weights.length, weights:_*) + val x = Array.fill[Array[Double]](nPoints)( + Array.fill[Double](weights.length)(2 * rnd.nextDouble - 1.0)) + val y = x.map { xi => + (new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) + intercept + eps * rnd.nextGaussian() + } + y.zip(x).map(p => LabeledPoint(p._1, p._2)) + } + + /** + * Generate an RDD containing sample data for Linear Regression models - including Ridge, Lasso, + * and uregularized variants. + * + * @param sc SparkContext to be used for generating the RDD. + * @param nexamples Number of examples that will be contained in the RDD. + * @param nfeatures Number of features to generate for each example. + * @param eps Epsilon factor by which examples are scaled. + * @param weights Weights associated with the first weights.length features. + * @param nparts Number of partitions in the RDD. Default value is 2. + * + * @return RDD of LabeledPoint containing sample data. 
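+ *
+ * Example (hypothetical parameter values):
+ * {{{
+ *   val data = LinearDataGenerator.generateLinearRDD(sc, 10000, 50, eps = 10.0, nparts = 4)
+ * }}}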
+ */ + def generateLinearRDD( + sc: SparkContext, + nexamples: Int, + nfeatures: Int, + eps: Double, + nparts: Int = 2, + intercept: Double = 0.0) : RDD[LabeledPoint] = { + org.jblas.util.Random.seed(42) + // Random values distributed uniformly in [-0.5, 0.5] + val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5) + + val data: RDD[LabeledPoint] = sc.parallelize(0 until nparts, nparts).flatMap { p => + val seed = 42 + p + val examplesInPartition = nexamples / nparts + generateLinearInput(intercept, w.toArray, examplesInPartition, seed, eps) + } + data + } + + def main(args: Array[String]) { + if (args.length < 2) { + println("Usage: LinearDataGenerator " + + " [num_examples] [num_features] [num_partitions]") + System.exit(1) + } + + val sparkMaster: String = args(0) + val outputPath: String = args(1) + val nexamples: Int = if (args.length > 2) args(2).toInt else 1000 + val nfeatures: Int = if (args.length > 3) args(3).toInt else 100 + val parts: Int = if (args.length > 4) args(4).toInt else 2 + val eps = 10 + + val sc = new SparkContext(sparkMaster, "LinearDataGenerator") + val data = generateLinearRDD(sc, nexamples, nfeatures, eps, nparts = parts) + + MLUtils.saveLabeledData(data, outputPath) + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala new file mode 100644 index 0000000000..f553298fc5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import scala.util.Random + +import org.apache.spark.{RDD, SparkContext} +import org.apache.spark.mllib.regression.LabeledPoint + +/** + * Generate test data for LogisticRegression. This class chooses positive labels + * with probability `probOne` and scales features for positive examples by `eps`. + */ + +object LogisticRegressionDataGenerator { + + /** + * Generate an RDD containing test data for LogisticRegression. + * + * @param sc SparkContext to use for creating the RDD. + * @param nexamples Number of examples that will be contained in the RDD. + * @param nfeatures Number of features to generate for each example. + * @param eps Epsilon factor by which positive examples are scaled. + * @param nparts Number of partitions of the generated RDD. Default value is 2. + * @param probOne Probability that a label is 1 (and not 0). Default value is 0.5. 
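+ *
+ * Example (hypothetical parameter values):
+ * {{{
+ *   val data = LogisticRegressionDataGenerator.generateLogisticRDD(sc, 10000, 20, 3.0)
+ * }}}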
+ */ + def generateLogisticRDD( + sc: SparkContext, + nexamples: Int, + nfeatures: Int, + eps: Double, + nparts: Int = 2, + probOne: Double = 0.5): RDD[LabeledPoint] = { + val data = sc.parallelize(0 until nexamples, nparts).map { idx => + val rnd = new Random(42 + idx) + + val y = if (idx % 2 == 0) 0.0 else 1.0 + val x = Array.fill[Double](nfeatures) { + rnd.nextGaussian() + (y * eps) + } + LabeledPoint(y, x) + } + data + } + + def main(args: Array[String]) { + if (args.length != 5) { + println("Usage: LogisticRegressionGenerator " + + " ") + System.exit(1) + } + + val sparkMaster: String = args(0) + val outputPath: String = args(1) + val nexamples: Int = if (args.length > 2) args(2).toInt else 1000 + val nfeatures: Int = if (args.length > 3) args(3).toInt else 2 + val parts: Int = if (args.length > 4) args(4).toInt else 2 + val eps = 3 + + val sc = new SparkContext(sparkMaster, "LogisticRegressionDataGenerator") + val data = generateLogisticRDD(sc, nexamples, nfeatures, eps, parts) + + MLUtils.saveLabeledData(data, outputPath) + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala new file mode 100644 index 0000000000..7eb69ae81c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.recommendation + +import scala.util.Random + +import org.jblas.DoubleMatrix + +import org.apache.spark.{RDD, SparkContext} +import org.apache.spark.mllib.util.MLUtils + +/** +* Generate RDD(s) containing data for Matrix Factorization. +* +* This method samples training entries according to the oversampling factor +* 'trainSampFact', which is a multiplicative factor of the number of +* degrees of freedom of the matrix: rank*(m+n-rank). +* +* It optionally samples entries for a testing matrix using +* 'testSampFact', the percentage of the number of training entries +* to use for testing. +* +* This method takes the following inputs: +* sparkMaster (String) The master URL. +* outputPath (String) Directory to save output. +* m (Int) Number of rows in data matrix. +* n (Int) Number of columns in data matrix. +* rank (Int) Underlying rank of data matrix. +* trainSampFact (Double) Oversampling factor. +* noise (Boolean) Whether to add gaussian noise to training data. +* sigma (Double) Standard deviation of added gaussian noise. +* test (Boolean) Whether to create testing RDD. +* testSampFact (Double) Percentage of training data to use as test data. 
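+*
+* For example (illustrative numbers): with m = n = 100 and rank = 10, the matrix has
+* rank * (m + n - rank) = 10 * 190 = 1900 degrees of freedom, so trainSampFact = 2.0 samples
+* roughly 3800 of the 10000 entries for training.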
+*/ + +object MFDataGenerator{ + + def main(args: Array[String]) { + if (args.length < 2) { + println("Usage: MFDataGenerator " + + " [m] [n] [rank] [trainSampFact] [noise] [sigma] [test] [testSampFact]") + System.exit(1) + } + + val sparkMaster: String = args(0) + val outputPath: String = args(1) + val m: Int = if (args.length > 2) args(2).toInt else 100 + val n: Int = if (args.length > 3) args(3).toInt else 100 + val rank: Int = if (args.length > 4) args(4).toInt else 10 + val trainSampFact: Double = if (args.length > 5) args(5).toDouble else 1.0 + val noise: Boolean = if (args.length > 6) args(6).toBoolean else false + val sigma: Double = if (args.length > 7) args(7).toDouble else 0.1 + val test: Boolean = if (args.length > 8) args(8).toBoolean else false + val testSampFact: Double = if (args.length > 9) args(9).toDouble else 0.1 + + val sc = new SparkContext(sparkMaster, "MFDataGenerator") + + val A = DoubleMatrix.randn(m, rank) + val B = DoubleMatrix.randn(rank, n) + val z = 1 / (scala.math.sqrt(scala.math.sqrt(rank))) + A.mmuli(z) + B.mmuli(z) + val fullData = A.mmul(B) + + val df = rank * (m + n - rank) + val sampSize = scala.math.min(scala.math.round(trainSampFact * df), + scala.math.round(.99 * m * n)).toInt + val rand = new Random() + val mn = m * n + val shuffled = rand.shuffle(1 to mn toIterable) + + val omega = shuffled.slice(0, sampSize) + val ordered = omega.sortWith(_ < _).toArray + val trainData: RDD[(Int, Int, Double)] = sc.parallelize(ordered) + .map(x => (fullData.indexRows(x - 1), fullData.indexColumns(x - 1), fullData.get(x - 1))) + + // optionally add gaussian noise + if (noise) { + trainData.map(x => (x._1, x._2, x._3 + rand.nextGaussian * sigma)) + } + + trainData.map(x => x._1 + "," + x._2 + "," + x._3).saveAsTextFile(outputPath) + + // optionally generate testing data + if (test) { + val testSampSize = scala.math + .min(scala.math.round(sampSize * testSampFact),scala.math.round(mn - sampSize)).toInt + val testOmega = shuffled.slice(sampSize, sampSize + testSampSize) + val testOrdered = testOmega.sortWith(_ < _).toArray + val testData: RDD[(Int, Int, Double)] = sc.parallelize(testOrdered) + .map(x => (fullData.indexRows(x - 1), fullData.indexColumns(x - 1), fullData.get(x - 1))) + testData.map(x => x._1 + "," + x._2 + "," + x._3).saveAsTextFile(outputPath) + } + + sc.stop() + + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala new file mode 100644 index 0000000000..0aeafbe23c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.util + +import org.apache.spark.{RDD, SparkContext} +import org.apache.spark.SparkContext._ + +import org.jblas.DoubleMatrix +import org.apache.spark.mllib.regression.LabeledPoint + +/** + * Helper methods to load, save and pre-process data used in ML Lib. + */ +object MLUtils { + + /** + * Load labeled data from a file. The data format used here is + * <L>, <f1> <f2> ... + * where <f1>, <f2> are feature values in Double and <L> is the corresponding label as Double. + * + * @param sc SparkContext + * @param dir Directory to the input data files. + * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is + * the label, and the second element represents the feature values (an array of Double). + */ + def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { + sc.textFile(dir).map { line => + val parts = line.split(',') + val label = parts(0).toDouble + val features = parts(1).trim().split(' ').map(_.toDouble) + LabeledPoint(label, features) + } + } + + /** + * Save labeled data to a file. The data format used here is + * <L>, <f1> <f2> ... + * where <f1>, <f2> are feature values in Double and <L> is the corresponding label as Double. + * + * @param data An RDD of LabeledPoints containing data to be saved. + * @param dir Directory to save the data. + */ + def saveLabeledData(data: RDD[LabeledPoint], dir: String) { + val dataStr = data.map(x => x.label + "," + x.features.mkString(" ")) + dataStr.saveAsTextFile(dir) + } + + /** + * Utility function to compute mean and standard deviation on a given dataset. + * + * @param data - input data set whose statistics are computed + * @param nfeatures - number of features + * @param nexamples - number of examples in input dataset + * + * @return (yMean, xColMean, xColSd) - Tuple consisting of + * yMean - mean of the labels + * xColMean - Row vector with mean for every column (or feature) of the input data + * xColSd - Row vector standard deviation for every column (or feature) of the input data. + */ + def computeStats(data: RDD[LabeledPoint], nfeatures: Int, nexamples: Long): + (Double, DoubleMatrix, DoubleMatrix) = { + val yMean: Double = data.map { labeledPoint => labeledPoint.label }.reduce(_ + _) / nexamples + + // NOTE: We shuffle X by column here to compute column sum and sum of squares. + val xColSumSq: RDD[(Int, (Double, Double))] = data.flatMap { labeledPoint => + val nCols = labeledPoint.features.length + // Traverse over every column and emit (col, value, value^2) + Iterator.tabulate(nCols) { i => + (i, (labeledPoint.features(i), labeledPoint.features(i)*labeledPoint.features(i))) + } + }.reduceByKey { case(x1, x2) => + (x1._1 + x2._1, x1._2 + x2._2) + } + val xColSumsMap = xColSumSq.collectAsMap() + + val xColMean = DoubleMatrix.zeros(nfeatures, 1) + val xColSd = DoubleMatrix.zeros(nfeatures, 1) + + // Compute mean and unbiased variance using column sums + var col = 0 + while (col < nfeatures) { + xColMean.put(col, xColSumsMap(col)._1 / nexamples) + val variance = + (xColSumsMap(col)._2 - (math.pow(xColSumsMap(col)._1, 2) / nexamples)) / (nexamples) + xColSd.put(col, math.sqrt(variance)) + col += 1 + } + + (yMean, xColMean, xColSd) + } + + /** + * Return the squared Euclidean distance between two vectors.
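A minimal local sketch of the labeled-data text format handled by loadLabeledData and saveLabeledData above (no Spark involved; the object name and sample line are illustrative):

    object LabeledLineSketch {
      // Parse one line of the "<L>, <f1> <f2> ..." format that loadLabeledData expects.
      def parse(line: String): (Double, Array[Double]) = {
        val parts = line.split(',')
        val label = parts(0).toDouble
        val features = parts(1).trim().split(' ').map(_.toDouble)
        (label, features)
      }

      def main(args: Array[String]): Unit = {
        val (label, features) = parse("1.0, 0.5 2.3 -0.1")
        println(s"$label -> ${features.mkString(", ")}")   // 1.0 -> 0.5, 2.3, -0.1
      }
    }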
+ */ + def squaredDistance(v1: Array[Double], v2: Array[Double]): Double = { + if (v1.length != v2.length) { + throw new IllegalArgumentException("Vector sizes don't match") + } + var i = 0 + var sum = 0.0 + while (i < v1.length) { + sum += (v1(i) - v2(i)) * (v1(i) - v2(i)) + i += 1 + } + sum + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala new file mode 100644 index 0000000000..d3f191b05b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala @@ -0,0 +1,50 @@ +package org.apache.spark.mllib.util + +import scala.util.Random + +import org.apache.spark.{RDD, SparkContext} + +import org.jblas.DoubleMatrix +import org.apache.spark.mllib.regression.LabeledPoint + +/** + * Generate sample data used for SVM. This class generates uniform random values + * for the features and adds Gaussian noise with weight 0.1 to generate labels. + */ +object SVMDataGenerator { + + def main(args: Array[String]) { + if (args.length < 2) { + println("Usage: SVMGenerator " + + " [num_examples] [num_features] [num_partitions]") + System.exit(1) + } + + val sparkMaster: String = args(0) + val outputPath: String = args(1) + val nexamples: Int = if (args.length > 2) args(2).toInt else 1000 + val nfeatures: Int = if (args.length > 3) args(3).toInt else 2 + val parts: Int = if (args.length > 4) args(4).toInt else 2 + + val sc = new SparkContext(sparkMaster, "SVMGenerator") + + val globalRnd = new Random(94720) + val trueWeights = new DoubleMatrix(1, nfeatures + 1, + Array.fill[Double](nfeatures + 1)(globalRnd.nextGaussian()):_*) + + val data: RDD[LabeledPoint] = sc.parallelize(0 until nexamples, parts).map { idx => + val rnd = new Random(42 + idx) + + val x = Array.fill[Double](nfeatures) { + rnd.nextDouble() * 2.0 - 1.0 + } + val yD = (new DoubleMatrix(1, x.length, x:_*)).dot(trueWeights) + rnd.nextGaussian() * 0.1 + val y = if (yD < 0) 0.0 else 1.0 + LabeledPoint(y, x) + } + + MLUtils.saveLabeledData(data, outputPath) + + sc.stop() + } +} diff --git a/mllib/src/main/scala/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/spark/mllib/classification/ClassificationModel.scala deleted file mode 100644 index 70fae8c15a..0000000000 --- a/mllib/src/main/scala/spark/mllib/classification/ClassificationModel.scala +++ /dev/null @@ -1,21 +0,0 @@ -package spark.mllib.classification - -import spark.RDD - -trait ClassificationModel extends Serializable { - /** - * Predict values for the given data set using the model trained. - * - * @param testData RDD representing data points to be predicted - * @return RDD[Int] where each entry contains the corresponding prediction - */ - def predict(testData: RDD[Array[Double]]): RDD[Double] - - /** - * Predict values for a single data point using the model trained. - * - * @param testData array representing a single data point - * @return Int prediction from the trained model - */ - def predict(testData: Array[Double]): Double -} diff --git a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala deleted file mode 100644 index 482e4a6745..0000000000 --- a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala +++ /dev/null @@ -1,188 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
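Roughly, each SVM example above is labeled by thresholding a noisy linear score; a hedged, jblas-free sketch of that rule (names and numbers are illustrative, not the generator itself):

    object SvmLabelSketch {
      // Label a point the way SVMDataGenerator does: 1.0 if w . x plus Gaussian noise is
      // non-negative, otherwise 0.0 (plain-Scala dot product instead of jblas).
      def label(weights: Array[Double], x: Array[Double], noise: Double): Double = {
        val margin = weights.zip(x).map { case (w, v) => w * v }.sum + noise
        if (margin < 0) 0.0 else 1.0
      }

      def main(args: Array[String]): Unit = {
        // 0.5 * 1.0 - 1.0 * 0.2 + 0.05 = 0.35, which is non-negative, so the label is 1.0.
        println(label(Array(0.5, -1.0), Array(1.0, 0.2), noise = 0.05))
      }
    }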
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.classification - -import scala.math.round - -import spark.{Logging, RDD, SparkContext} -import spark.mllib.optimization._ -import spark.mllib.regression._ -import spark.mllib.util.MLUtils -import spark.mllib.util.DataValidators - -import org.jblas.DoubleMatrix - -/** - * Classification model trained using Logistic Regression. - * - * @param weights Weights computed for every feature. - * @param intercept Intercept computed for this model. - */ -class LogisticRegressionModel( - override val weights: Array[Double], - override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) - with ClassificationModel with Serializable { - - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { - val margin = dataMatrix.mmul(weightMatrix).get(0) + intercept - round(1.0/ (1.0 + math.exp(margin * -1))) - } -} - -/** - * Train a classification model for Logistic Regression using Stochastic Gradient Descent. - * NOTE: Labels used in Logistic Regression should be {0, 1} - */ -class LogisticRegressionWithSGD private ( - var stepSize: Double, - var numIterations: Int, - var regParam: Double, - var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[LogisticRegressionModel] - with Serializable { - - val gradient = new LogisticGradient() - val updater = new SimpleUpdater() - override val optimizer = new GradientDescent(gradient, updater) - .setStepSize(stepSize) - .setNumIterations(numIterations) - .setRegParam(regParam) - .setMiniBatchFraction(miniBatchFraction) - override val validators = List(DataValidators.classificationLabels) - - /** - * Construct a LogisticRegression object with default parameters - */ - def this() = this(1.0, 100, 0.0, 1.0) - - def createModel(weights: Array[Double], intercept: Double) = { - new LogisticRegressionModel(weights, intercept) - } -} - -/** - * Top-level methods for calling Logistic Regression. - * NOTE: Labels used in Logistic Regression should be {0, 1} - */ -object LogisticRegressionWithSGD { - // NOTE(shivaram): We use multiple train methods instead of default arguments to support - // Java programs. - - /** - * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed - * number of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in - * gradient descent are initialized using the initial weights provided. - * NOTE: Labels used in Logistic Regression should be {0, 1} - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. 
- * @param miniBatchFraction Fraction of data to be used per iteration. - * @param initialWeights Initial set of weights to be used. Array should be equal in size to - * the number of features in the data. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - miniBatchFraction: Double, - initialWeights: Array[Double]) - : LogisticRegressionModel = - { - new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run( - input, initialWeights) - } - - /** - * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed - * number of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate the gradient. - * NOTE: Labels used in Logistic Regression should be {0, 1} - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - - * @param miniBatchFraction Fraction of data to be used per iteration. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - miniBatchFraction: Double) - : LogisticRegressionModel = - { - new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run( - input) - } - - /** - * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed - * number of iterations of gradient descent using the specified step size. We use the entire data - * set to update the gradient in each iteration. - * NOTE: Labels used in Logistic Regression should be {0, 1} - * - * @param input RDD of (label, array of features) pairs. - * @param stepSize Step size to be used for each iteration of Gradient Descent. - - * @param numIterations Number of iterations of gradient descent to run. - * @return a LogisticRegressionModel which has the weights and offset from training. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double) - : LogisticRegressionModel = - { - train(input, numIterations, stepSize, 1.0) - } - - /** - * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed - * number of iterations of gradient descent using a step size of 1.0. We use the entire data set - * to update the gradient in each iteration. - * NOTE: Labels used in Logistic Regression should be {0, 1} - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @return a LogisticRegressionModel which has the weights and offset from training. 
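The predictPoint rule of LogisticRegressionModel shown earlier in this file reduces to rounding a sigmoid of the margin; a small standalone sketch (hypothetical names):

    object LogisticPredictSketch {
      // The prediction rule from predictPoint: squash the margin (w . x + intercept)
      // through the logistic function and round to {0, 1}.
      def predict(margin: Double): Double =
        math.round(1.0 / (1.0 + math.exp(-margin))).toDouble

      def main(args: Array[String]): Unit = {
        println(predict(2.0))   // sigmoid(2.0) is about 0.88, so the prediction is 1.0
        println(predict(-1.0))  // sigmoid(-1.0) is about 0.27, so the prediction is 0.0
      }
    }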
- */ - def train( - input: RDD[LabeledPoint], - numIterations: Int) - : LogisticRegressionModel = - { - train(input, numIterations, 1.0, 1.0) - } - - def main(args: Array[String]) { - if (args.length != 4) { - println("Usage: LogisticRegression " + - "") - System.exit(1) - } - val sc = new SparkContext(args(0), "LogisticRegression") - val data = MLUtils.loadLabeledData(sc, args(1)) - val model = LogisticRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble) - - sc.stop() - } -} diff --git a/mllib/src/main/scala/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/spark/mllib/classification/SVM.scala deleted file mode 100644 index 69393cd7b0..0000000000 --- a/mllib/src/main/scala/spark/mllib/classification/SVM.scala +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.classification - -import scala.math.signum - -import spark.{Logging, RDD, SparkContext} -import spark.mllib.optimization._ -import spark.mllib.regression._ -import spark.mllib.util.MLUtils -import spark.mllib.util.DataValidators - -import org.jblas.DoubleMatrix - -/** - * Model built using SVM. - * - * @param weights Weights computed for every feature. - * @param intercept Intercept computed for this model. - */ -class SVMModel( - override val weights: Array[Double], - override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) - with ClassificationModel with Serializable { - - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { - val margin = dataMatrix.dot(weightMatrix) + intercept - if (margin < 0) 0.0 else 1.0 - } -} - -/** - * Train an SVM using Stochastic Gradient Descent. - * NOTE: Labels used in SVM should be {0, 1} - */ -class SVMWithSGD private ( - var stepSize: Double, - var numIterations: Int, - var regParam: Double, - var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[SVMModel] with Serializable { - - val gradient = new HingeGradient() - val updater = new SquaredL2Updater() - override val optimizer = new GradientDescent(gradient, updater) - .setStepSize(stepSize) - .setNumIterations(numIterations) - .setRegParam(regParam) - .setMiniBatchFraction(miniBatchFraction) - - override val validators = List(DataValidators.classificationLabels) - - /** - * Construct a SVM object with default parameters - */ - def this() = this(1.0, 100, 1.0, 1.0) - - def createModel(weights: Array[Double], intercept: Double) = { - new SVMModel(weights, intercept) - } -} - -/** - * Top-level methods for calling SVM. NOTE: Labels used in SVM should be {0, 1} - */ -object SVMWithSGD { - - /** - * Train a SVM model given an RDD of (label, features) pairs. 
We run a fixed number - * of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in - * gradient descent are initialized using the initial weights provided. - * NOTE: Labels used in SVM should be {0, 1} - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param regParam Regularization parameter. - * @param miniBatchFraction Fraction of data to be used per iteration. - * @param initialWeights Initial set of weights to be used. Array should be equal in size to - * the number of features in the data. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double, - miniBatchFraction: Double, - initialWeights: Array[Double]) - : SVMModel = - { - new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input, - initialWeights) - } - - /** - * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate the gradient. - * NOTE: Labels used in SVM should be {0, 1} - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param regParam Regularization parameter. - * @param miniBatchFraction Fraction of data to be used per iteration. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double, - miniBatchFraction: Double) - : SVMModel = - { - new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) - } - - /** - * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. We use the entire data set to - * update the gradient in each iteration. - * NOTE: Labels used in SVM should be {0, 1} - * - * @param input RDD of (label, array of features) pairs. - * @param stepSize Step size to be used for each iteration of Gradient Descent. - * @param regParam Regularization parameter. - * @param numIterations Number of iterations of gradient descent to run. - * @return a SVMModel which has the weights and offset from training. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double) - : SVMModel = - { - train(input, numIterations, stepSize, regParam, 1.0) - } - - /** - * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using a step size of 1.0. We use the entire data set to - * update the gradient in each iteration. - * NOTE: Labels used in SVM should be {0, 1} - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @return a SVMModel which has the weights and offset from training. 
- */ - def train( - input: RDD[LabeledPoint], - numIterations: Int) - : SVMModel = - { - train(input, numIterations, 1.0, 1.0, 1.0) - } - - def main(args: Array[String]) { - if (args.length != 5) { - println("Usage: SVM ") - System.exit(1) - } - val sc = new SparkContext(args(0), "SVM") - val data = MLUtils.loadLabeledData(sc, args(1)) - val model = SVMWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble) - - sc.stop() - } -} diff --git a/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala deleted file mode 100644 index 97e3d110ae..0000000000 --- a/mllib/src/main/scala/spark/mllib/clustering/KMeans.scala +++ /dev/null @@ -1,335 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.clustering - -import scala.collection.mutable.ArrayBuffer -import scala.util.Random - -import spark.{SparkContext, RDD} -import spark.SparkContext._ -import spark.Logging -import spark.mllib.util.MLUtils - -import org.jblas.DoubleMatrix - - -/** - * K-means clustering with support for multiple parallel runs and a k-means++ like initialization - * mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, - * they are executed together with joint passes over the data for efficiency. - * - * This is an iterative algorithm that will make multiple passes over the data, so any RDDs given - * to it should be cached by the user. - */ -class KMeans private ( - var k: Int, - var maxIterations: Int, - var runs: Int, - var initializationMode: String, - var initializationSteps: Int, - var epsilon: Double) - extends Serializable with Logging -{ - private type ClusterCenters = Array[Array[Double]] - - def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4) - - /** Set the number of clusters to create (k). Default: 2. */ - def setK(k: Int): KMeans = { - this.k = k - this - } - - /** Set maximum number of iterations to run. Default: 20. */ - def setMaxIterations(maxIterations: Int): KMeans = { - this.maxIterations = maxIterations - this - } - - /** - * Set the initialization algorithm. This can be either "random" to choose random points as - * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ - * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. - */ - def setInitializationMode(initializationMode: String): KMeans = { - if (initializationMode != KMeans.RANDOM && initializationMode != KMeans.K_MEANS_PARALLEL) { - throw new IllegalArgumentException("Invalid initialization mode: " + initializationMode) - } - this.initializationMode = initializationMode - this - } - - /** - * Set the number of runs of the algorithm to execute in parallel. 
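A hedged sketch of how the KMeans class in this (pre-rename) file is typically driven end to end; it assumes the rest of the class shown further down (setRuns, run, KMeans.K_MEANS_PARALLEL, KMeansModel.computeCost), the old two-argument SparkContext constructor, and a hypothetical input path:

    import spark.SparkContext
    import spark.mllib.clustering.KMeans

    object KMeansUsageSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext("local", "KMeansUsageSketch")
        // Hypothetical input: one whitespace-separated vector per line, as KMeans.main below expects.
        val points = sc.textFile("/tmp/points.txt").map(_.split(' ').map(_.toDouble)).cache()
        val model = new KMeans()
          .setK(3)
          .setMaxIterations(20)
          .setRuns(5)
          .setInitializationMode(KMeans.K_MEANS_PARALLEL)
          .run(points)
        println("Cost: " + model.computeCost(points))
        sc.stop()
      }
    }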
We initialize the algorithm - * this many times with random starting conditions (configured by the initialization mode), then - * return the best clustering found over any run. Default: 1. - */ - def setRuns(runs: Int): KMeans = { - if (runs <= 0) { - throw new IllegalArgumentException("Number of runs must be positive") - } - this.runs = runs - this - } - - /** - * Set the number of steps for the k-means|| initialization mode. This is an advanced - * setting -- the default of 5 is almost always enough. Default: 5. - */ - def setInitializationSteps(initializationSteps: Int): KMeans = { - if (initializationSteps <= 0) { - throw new IllegalArgumentException("Number of initialization steps must be positive") - } - this.initializationSteps = initializationSteps - this - } - - /** - * Set the distance threshold within which we've consider centers to have converged. - * If all centers move less than this Euclidean distance, we stop iterating one run. - */ - def setEpsilon(epsilon: Double): KMeans = { - this.epsilon = epsilon - this - } - - /** - * Train a K-means model on the given set of points; `data` should be cached for high - * performance, because this is an iterative algorithm. - */ - def run(data: RDD[Array[Double]]): KMeansModel = { - // TODO: check whether data is persistent; this needs RDD.storageLevel to be publicly readable - - val sc = data.sparkContext - - val centers = if (initializationMode == KMeans.RANDOM) { - initRandom(data) - } else { - initKMeansParallel(data) - } - - val active = Array.fill(runs)(true) - val costs = Array.fill(runs)(0.0) - - var activeRuns = new ArrayBuffer[Int] ++ (0 until runs) - var iteration = 0 - - // Execute iterations of Lloyd's algorithm until all runs have converged - while (iteration < maxIterations && !activeRuns.isEmpty) { - type WeightedPoint = (DoubleMatrix, Long) - def mergeContribs(p1: WeightedPoint, p2: WeightedPoint): WeightedPoint = { - (p1._1.addi(p2._1), p1._2 + p2._2) - } - - val activeCenters = activeRuns.map(r => centers(r)).toArray - val costAccums = activeRuns.map(_ => sc.accumulator(0.0)) - - // Find the sum and count of points mapping to each center - val totalContribs = data.mapPartitions { points => - val runs = activeCenters.length - val k = activeCenters(0).length - val dims = activeCenters(0)(0).length - - val sums = Array.fill(runs, k)(new DoubleMatrix(dims)) - val counts = Array.fill(runs, k)(0L) - - for (point <- points; (centers, runIndex) <- activeCenters.zipWithIndex) { - val (bestCenter, cost) = KMeans.findClosest(centers, point) - costAccums(runIndex) += cost - sums(runIndex)(bestCenter).addi(new DoubleMatrix(point)) - counts(runIndex)(bestCenter) += 1 - } - - val contribs = for (i <- 0 until runs; j <- 0 until k) yield { - ((i, j), (sums(i)(j), counts(i)(j))) - } - contribs.iterator - }.reduceByKey(mergeContribs).collectAsMap() - - // Update the cluster centers and costs for each active run - for ((run, i) <- activeRuns.zipWithIndex) { - var changed = false - for (j <- 0 until k) { - val (sum, count) = totalContribs((i, j)) - if (count != 0) { - val newCenter = sum.divi(count).data - if (MLUtils.squaredDistance(newCenter, centers(run)(j)) > epsilon * epsilon) { - changed = true - } - centers(run)(j) = newCenter - } - } - if (!changed) { - active(run) = false - logInfo("Run " + run + " finished in " + (iteration + 1) + " iterations") - } - costs(run) = costAccums(i).value - } - - activeRuns = activeRuns.filter(active(_)) - iteration += 1 - } - - val bestRun = costs.zipWithIndex.min._2 - new 
KMeansModel(centers(bestRun)) - } - - /** - * Initialize `runs` sets of cluster centers at random. - */ - private def initRandom(data: RDD[Array[Double]]): Array[ClusterCenters] = { - // Sample all the cluster centers in one pass to avoid repeated scans - val sample = data.takeSample(true, runs * k, new Random().nextInt()).toSeq - Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).toArray) - } - - /** - * Initialize `runs` sets of cluster centers using the k-means|| algorithm by Bahmani et al. - * (Bahmani et al., Scalable K-Means++, VLDB 2012). This is a variant of k-means++ that tries - * to find with dissimilar cluster centers by starting with a random center and then doing - * passes where more centers are chosen with probability proportional to their squared distance - * to the current cluster set. It results in a provable approximation to an optimal clustering. - * - * The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf. - */ - private def initKMeansParallel(data: RDD[Array[Double]]): Array[ClusterCenters] = { - // Initialize each run's center to a random point - val seed = new Random().nextInt() - val sample = data.takeSample(true, runs, seed).toSeq - val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r))) - - // On each step, sample 2 * k points on average for each run with probability proportional - // to their squared distance from that run's current centers - for (step <- 0 until initializationSteps) { - val centerArrays = centers.map(_.toArray) - val sumCosts = data.flatMap { point => - for (r <- 0 until runs) yield (r, KMeans.pointCost(centerArrays(r), point)) - }.reduceByKey(_ + _).collectAsMap() - val chosen = data.mapPartitionsWithIndex { (index, points) => - val rand = new Random(seed ^ (step << 16) ^ index) - for { - p <- points - r <- 0 until runs - if rand.nextDouble() < KMeans.pointCost(centerArrays(r), p) * 2 * k / sumCosts(r) - } yield (r, p) - }.collect() - for ((r, p) <- chosen) { - centers(r) += p - } - } - - // Finally, we might have a set of more than k candidate centers for each run; weigh each - // candidate by the number of points in the dataset mapping to it and run a local k-means++ - // on the weighted centers to pick just k of them - val centerArrays = centers.map(_.toArray) - val weightMap = data.flatMap { p => - for (r <- 0 until runs) yield ((r, KMeans.findClosest(centerArrays(r), p)._1), 1.0) - }.reduceByKey(_ + _).collectAsMap() - val finalCenters = (0 until runs).map { r => - val myCenters = centers(r).toArray - val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray - LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30) - } - - finalCenters.toArray - } -} - - -/** - * Top-level methods for calling K-means clustering. 
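A local, single-machine sketch of one k-means|| sampling pass as implemented in initKMeansParallel above: each point is kept with probability proportional to its squared distance from the current centers, scaled so that about 2 * k points survive per pass (all names below are illustrative):

    import scala.util.Random

    object KMeansParallelStepSketch {
      def squaredDistance(a: Array[Double], b: Array[Double]): Double =
        a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

      // A point's cost is its squared distance to the nearest current center.
      def pointCost(centers: Seq[Array[Double]], p: Array[Double]): Double =
        centers.map(c => squaredDistance(c, p)).min

      // One sampling pass: keep each point with probability proportional to its cost,
      // scaled so that roughly 2 * k points are expected to survive the pass.
      def samplePass(points: Seq[Array[Double]], centers: Seq[Array[Double]], k: Int,
          rand: Random): Seq[Array[Double]] = {
        val sumCost = points.map(p => pointCost(centers, p)).sum
        points.filter(p => rand.nextDouble() < pointCost(centers, p) * 2 * k / sumCost)
      }

      def main(args: Array[String]): Unit = {
        val points = Seq(Array(0.0, 0.0), Array(0.1, 0.0), Array(5.0, 5.0), Array(5.1, 4.9))
        val centers = Seq(Array(0.0, 0.0))
        // The two far-away points carry almost all of the cost, so they are almost always kept.
        println(samplePass(points, centers, k = 1, rand = new Random(1)).size)
      }
    }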
- */ -object KMeans { - // Initialization mode names - val RANDOM = "random" - val K_MEANS_PARALLEL = "k-means||" - - def train( - data: RDD[Array[Double]], - k: Int, - maxIterations: Int, - runs: Int, - initializationMode: String) - : KMeansModel = - { - new KMeans().setK(k) - .setMaxIterations(maxIterations) - .setRuns(runs) - .setInitializationMode(initializationMode) - .run(data) - } - - def train(data: RDD[Array[Double]], k: Int, maxIterations: Int, runs: Int): KMeansModel = { - train(data, k, maxIterations, runs, K_MEANS_PARALLEL) - } - - def train(data: RDD[Array[Double]], k: Int, maxIterations: Int): KMeansModel = { - train(data, k, maxIterations, 1, K_MEANS_PARALLEL) - } - - /** - * Return the index of the closest point in `centers` to `point`, as well as its distance. - */ - private[mllib] def findClosest(centers: Array[Array[Double]], point: Array[Double]) - : (Int, Double) = - { - var bestDistance = Double.PositiveInfinity - var bestIndex = 0 - for (i <- 0 until centers.length) { - val distance = MLUtils.squaredDistance(point, centers(i)) - if (distance < bestDistance) { - bestDistance = distance - bestIndex = i - } - } - (bestIndex, bestDistance) - } - - /** - * Return the K-means cost of a given point against the given cluster centers. - */ - private[mllib] def pointCost(centers: Array[Array[Double]], point: Array[Double]): Double = { - var bestDistance = Double.PositiveInfinity - for (i <- 0 until centers.length) { - val distance = MLUtils.squaredDistance(point, centers(i)) - if (distance < bestDistance) { - bestDistance = distance - } - } - bestDistance - } - - def main(args: Array[String]) { - if (args.length < 4) { - println("Usage: KMeans []") - System.exit(1) - } - val (master, inputFile, k, iters) = (args(0), args(1), args(2).toInt, args(3).toInt) - val runs = if (args.length >= 5) args(4).toInt else 1 - val sc = new SparkContext(master, "KMeans") - val data = sc.textFile(inputFile).map(line => line.split(' ').map(_.toDouble)).cache() - val model = KMeans.train(data, k, iters, runs) - val cost = model.computeCost(data) - println("Cluster centers:") - for (c <- model.clusterCenters) { - println(" " + c.mkString(" ")) - } - println("Cost: " + cost) - System.exit(0) - } -} diff --git a/mllib/src/main/scala/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/spark/mllib/clustering/KMeansModel.scala deleted file mode 100644 index b8f80e80cd..0000000000 --- a/mllib/src/main/scala/spark/mllib/clustering/KMeansModel.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.clustering - -import spark.RDD -import spark.SparkContext._ -import spark.mllib.util.MLUtils - - -/** - * A clustering model for K-means. Each point belongs to the cluster with the closest center. 
- */ -class KMeansModel(val clusterCenters: Array[Array[Double]]) extends Serializable { - /** Total number of clusters. */ - def k: Int = clusterCenters.length - - /** Return the cluster index that a given point belongs to. */ - def predict(point: Array[Double]): Int = { - KMeans.findClosest(clusterCenters, point)._1 - } - - /** - * Return the K-means cost (sum of squared distances of points to their nearest center) for this - * model on the given data. - */ - def computeCost(data: RDD[Array[Double]]): Double = { - data.map(p => KMeans.pointCost(clusterCenters, p)).sum - } -} diff --git a/mllib/src/main/scala/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/spark/mllib/clustering/LocalKMeans.scala deleted file mode 100644 index 89fe7d7e85..0000000000 --- a/mllib/src/main/scala/spark/mllib/clustering/LocalKMeans.scala +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.clustering - -import scala.util.Random - -import org.jblas.{DoubleMatrix, SimpleBlas} - -/** - * An utility object to run K-means locally. This is private to the ML package because it's used - * in the initialization of KMeans but not meant to be publicly exposed. - */ -private[mllib] object LocalKMeans { - /** - * Run K-means++ on the weighted point set `points`. This first does the K-means++ - * initialization procedure and then roudns of Lloyd's algorithm. 
- */ - def kMeansPlusPlus( - seed: Int, - points: Array[Array[Double]], - weights: Array[Double], - k: Int, - maxIterations: Int) - : Array[Array[Double]] = - { - val rand = new Random(seed) - val dimensions = points(0).length - val centers = new Array[Array[Double]](k) - - // Initialize centers by sampling using the k-means++ procedure - centers(0) = pickWeighted(rand, points, weights) - for (i <- 1 until k) { - // Pick the next center with a probability proportional to cost under current centers - val curCenters = centers.slice(0, i) - val sum = points.zip(weights).map { case (p, w) => - w * KMeans.pointCost(curCenters, p) - }.sum - val r = rand.nextDouble() * sum - var cumulativeScore = 0.0 - var j = 0 - while (j < points.length && cumulativeScore < r) { - cumulativeScore += weights(j) * KMeans.pointCost(curCenters, points(j)) - j += 1 - } - centers(i) = points(j-1) - } - - // Run up to maxIterations iterations of Lloyd's algorithm - val oldClosest = Array.fill(points.length)(-1) - var iteration = 0 - var moved = true - while (moved && iteration < maxIterations) { - moved = false - val sums = Array.fill(k)(new DoubleMatrix(dimensions)) - val counts = Array.fill(k)(0.0) - for ((p, i) <- points.zipWithIndex) { - val index = KMeans.findClosest(centers, p)._1 - SimpleBlas.axpy(weights(i), new DoubleMatrix(p), sums(index)) - counts(index) += weights(i) - if (index != oldClosest(i)) { - moved = true - oldClosest(i) = index - } - } - // Update centers - for (i <- 0 until k) { - if (counts(i) == 0.0) { - // Assign center to a random point - centers(i) = points(rand.nextInt(points.length)) - } else { - centers(i) = sums(i).divi(counts(i)).data - } - } - iteration += 1 - } - - centers - } - - private def pickWeighted[T](rand: Random, data: Array[T], weights: Array[Double]): T = { - val r = rand.nextDouble() * weights.sum - var i = 0 - var curWeight = 0.0 - while (i < data.length && curWeight < r) { - curWeight += weights(i) - i += 1 - } - data(i - 1) - } -} diff --git a/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala deleted file mode 100644 index 05568f55af..0000000000 --- a/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.optimization - -import org.jblas.DoubleMatrix - -/** - * Class used to compute the gradient for a loss function, given a single data point. - */ -abstract class Gradient extends Serializable { - /** - * Compute the gradient and loss given features of a single data point. - * - * @param data - Feature values for one data point. Column matrix of size nx1 - * where n is the number of features. - * @param label - Label for this data item. 
- * @param weights - Column matrix containing weights for every feature. - * - * @return A tuple of 2 elements. The first element is a column matrix containing the computed - * gradient and the second element is the loss computed at this data point. - * - */ - def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): - (DoubleMatrix, Double) -} - -/** - * Compute gradient and loss for a logistic loss function. - */ -class LogisticGradient extends Gradient { - override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): - (DoubleMatrix, Double) = { - val margin: Double = -1.0 * data.dot(weights) - val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label - - val gradient = data.mul(gradientMultiplier) - val loss = - if (margin > 0) { - math.log(1 + math.exp(0 - margin)) - } else { - math.log(1 + math.exp(margin)) - margin - } - - (gradient, loss) - } -} - -/** - * Compute gradient and loss for a Least-squared loss function. - */ -class SquaredGradient extends Gradient { - override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): - (DoubleMatrix, Double) = { - val diff: Double = data.dot(weights) - label - - val loss = 0.5 * diff * diff - val gradient = data.mul(diff) - - (gradient, loss) - } -} - -/** - * Compute gradient and loss for a Hinge loss function. - * NOTE: This assumes that the labels are {0,1} - */ -class HingeGradient extends Gradient { - override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): - (DoubleMatrix, Double) = { - - val dotProduct = data.dot(weights) - - // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) - // Therefore the gradient is -(2y - 1)*x - val labelScaled = 2 * label - 1.0 - - if (1.0 > labelScaled * dotProduct) { - (data.mul(-labelScaled), 1.0 - labelScaled * dotProduct) - } else { - (DoubleMatrix.zeros(1, weights.length), 0.0) - } - } -} diff --git a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala deleted file mode 100644 index 31917df7e8..0000000000 --- a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala +++ /dev/null @@ -1,166 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.optimization - -import spark.{Logging, RDD, SparkContext} -import spark.SparkContext._ - -import org.jblas.DoubleMatrix - -import scala.collection.mutable.ArrayBuffer - -/** - * Class used to solve an optimization problem using Gradient Descent. - * @param gradient Gradient function to be used. - * @param updater Updater to be used to update weights after every iteration. 
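A plain-array sketch of the logistic gradient computed by LogisticGradient.compute above, with one worked data point (names are illustrative):

    object LogisticGradientSketch {
      // Same arithmetic as LogisticGradient.compute, on plain arrays:
      // margin = -(w . x), multiplier = 1 / (1 + exp(margin)) - label, gradient = multiplier * x.
      def compute(x: Array[Double], label: Double, w: Array[Double]): (Array[Double], Double) = {
        val margin = -x.zip(w).map { case (xi, wi) => xi * wi }.sum
        val multiplier = 1.0 / (1.0 + math.exp(margin)) - label
        val gradient = x.map(_ * multiplier)
        val loss =
          if (margin > 0) math.log(1 + math.exp(-margin))
          else math.log(1 + math.exp(margin)) - margin
        (gradient, loss)
      }

      def main(args: Array[String]): Unit = {
        val (grad, loss) = compute(Array(1.0, 2.0), label = 1.0, w = Array(0.5, -0.25))
        // Here w . x = 0.0, so the multiplier is 0.5 - 1.0 = -0.5 and the loss is log(2).
        println(grad.mkString(", ") + "  loss=" + loss)
      }
    }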
- */ -class GradientDescent(var gradient: Gradient, var updater: Updater) extends Optimizer { - - private var stepSize: Double = 1.0 - private var numIterations: Int = 100 - private var regParam: Double = 0.0 - private var miniBatchFraction: Double = 1.0 - - /** - * Set the step size per-iteration of SGD. Default 1.0. - */ - def setStepSize(step: Double): this.type = { - this.stepSize = step - this - } - - /** - * Set fraction of data to be used for each SGD iteration. Default 1.0. - */ - def setMiniBatchFraction(fraction: Double): this.type = { - this.miniBatchFraction = fraction - this - } - - /** - * Set the number of iterations for SGD. Default 100. - */ - def setNumIterations(iters: Int): this.type = { - this.numIterations = iters - this - } - - /** - * Set the regularization parameter used for SGD. Default 0.0. - */ - def setRegParam(regParam: Double): this.type = { - this.regParam = regParam - this - } - - /** - * Set the gradient function to be used for SGD. - */ - def setGradient(gradient: Gradient): this.type = { - this.gradient = gradient - this - } - - - /** - * Set the updater function to be used for SGD. - */ - def setUpdater(updater: Updater): this.type = { - this.updater = updater - this - } - - def optimize(data: RDD[(Double, Array[Double])], initialWeights: Array[Double]) - : Array[Double] = { - - val (weights, stochasticLossHistory) = GradientDescent.runMiniBatchSGD( - data, - gradient, - updater, - stepSize, - numIterations, - regParam, - miniBatchFraction, - initialWeights) - weights - } - -} - -// Top-level method to run gradient descent. -object GradientDescent extends Logging { - /** - * Run gradient descent in parallel using mini batches. - * - * @param data - Input data for SGD. RDD of form (label, [feature values]). - * @param gradient - Gradient object that will be used to compute the gradient. - * @param updater - Updater object that will be used to update the model. - * @param stepSize - stepSize to be used during update. - * @param numIterations - number of iterations that SGD should be run. - * @param regParam - regularization parameter - * @param miniBatchFraction - fraction of the input data set that should be used for - * one iteration of SGD. Default value 1.0. - * - * @return A tuple containing two elements. The first element is a column matrix containing - * weights for every feature, and the second element is an array containing the stochastic - * loss computed for every iteration. 
- */ - def runMiniBatchSGD( - data: RDD[(Double, Array[Double])], - gradient: Gradient, - updater: Updater, - stepSize: Double, - numIterations: Int, - regParam: Double, - miniBatchFraction: Double, - initialWeights: Array[Double]) : (Array[Double], Array[Double]) = { - - val stochasticLossHistory = new ArrayBuffer[Double](numIterations) - - val nexamples: Long = data.count() - val miniBatchSize = nexamples * miniBatchFraction - - // Initialize weights as a column vector - var weights = new DoubleMatrix(initialWeights.length, 1, initialWeights:_*) - var regVal = 0.0 - - for (i <- 1 to numIterations) { - val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42+i).map { - case (y, features) => - val featuresCol = new DoubleMatrix(features.length, 1, features:_*) - val (grad, loss) = gradient.compute(featuresCol, y, weights) - (grad, loss) - }.reduce((a, b) => (a._1.addi(b._1), a._2 + b._2)) - - /** - * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration - * and regVal is the regularization value computed in the previous iteration as well. - */ - stochasticLossHistory.append(lossSum / miniBatchSize + regVal) - val update = updater.compute( - weights, gradientSum.div(miniBatchSize), stepSize, i, regParam) - weights = update._1 - regVal = update._2 - } - - logInfo("GradientDescent finished. Last 10 stochastic losses %s".format( - stochasticLossHistory.takeRight(10).mkString(", "))) - - (weights.toArray, stochasticLossHistory.toArray) - } -} diff --git a/mllib/src/main/scala/spark/mllib/optimization/Optimizer.scala b/mllib/src/main/scala/spark/mllib/optimization/Optimizer.scala deleted file mode 100644 index 76a519c338..0000000000 --- a/mllib/src/main/scala/spark/mllib/optimization/Optimizer.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.optimization - -import spark.RDD - -trait Optimizer { - - /** - * Solve the provided convex optimization problem. - */ - def optimize(data: RDD[(Double, Array[Double])], initialWeights: Array[Double]): Array[Double] - -} diff --git a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/spark/mllib/optimization/Updater.scala deleted file mode 100644 index db67d6b0bc..0000000000 --- a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
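A minimal single-machine analogue of the runMiniBatchSGD loop above, using squared loss and the SimpleUpdater step size stepSize / sqrt(iter) on a tiny dataset (everything here is illustrative, not the Spark implementation):

    object MiniSgdSketch {
      // Tiny local version of the loop in runMiniBatchSGD: squared loss, decaying step size,
      // full batch on every iteration.
      def fit(data: Seq[(Double, Array[Double])], stepSize: Double, iters: Int): Array[Double] = {
        var w = Array.fill(data.head._2.length)(0.0)
        for (i <- 1 to iters) {
          // Average gradient of 0.5 * (w . x - y)^2 over the batch: (w . x - y) * x.
          val grad = data.map { case (y, x) =>
            val diff = x.zip(w).map { case (xi, wi) => xi * wi }.sum - y
            x.map(_ * diff)
          }.reduce((a, b) => a.zip(b).map { case (p, q) => p + q })
            .map(_ / data.size)
          val step = stepSize / math.sqrt(i)
          w = w.zip(grad).map { case (wi, gi) => wi - step * gi }
        }
        w
      }

      def main(args: Array[String]): Unit = {
        // The data follows y = 2 * x, so the single weight should move close to 2.0.
        val data = Seq((2.0, Array(1.0)), (4.0, Array(2.0)), (6.0, Array(3.0)))
        println(fit(data, stepSize = 0.1, iters = 100).mkString(", "))
      }
    }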
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.optimization - -import scala.math._ -import org.jblas.DoubleMatrix - -/** - * Class used to update weights used in Gradient Descent. - */ -abstract class Updater extends Serializable { - /** - * Compute an updated value for weights given the gradient, stepSize, iteration number and - * regularization parameter. Also returns the regularization value computed using the - * *updated* weights. - * - * @param weightsOld - Column matrix of size nx1 where n is the number of features. - * @param gradient - Column matrix of size nx1 where n is the number of features. - * @param stepSize - step size across iterations - * @param iter - Iteration number - * @param regParam - Regularization parameter - * - * @return A tuple of 2 elements. The first element is a column matrix containing updated weights, - * and the second element is the regularization value computed using updated weights. - */ - def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, stepSize: Double, iter: Int, - regParam: Double): (DoubleMatrix, Double) -} - -/** - * A simple updater that adaptively adjusts the learning rate the - * square root of the number of iterations. Does not perform any regularization. - */ -class SimpleUpdater extends Updater { - override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, - stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { - val thisIterStepSize = stepSize / math.sqrt(iter) - val normGradient = gradient.mul(thisIterStepSize) - (weightsOld.sub(normGradient), 0) - } -} - -/** - * Updater that adjusts learning rate and performs L1 regularization. - * - * The corresponding proximal operator used is the soft-thresholding function. - * That is, each weight component is shrunk towards 0 by shrinkageVal. - * - * If w > shrinkageVal, set weight component to w-shrinkageVal. - * If w < -shrinkageVal, set weight component to w+shrinkageVal. - * If -shrinkageVal < w < shrinkageVal, set weight component to 0. 
- * - * Equivalently, set weight component to signum(w) * max(0.0, abs(w) - shrinkageVal) - */ -class L1Updater extends Updater { - override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, - stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { - val thisIterStepSize = stepSize / math.sqrt(iter) - val normGradient = gradient.mul(thisIterStepSize) - // Take gradient step - val newWeights = weightsOld.sub(normGradient) - // Soft thresholding - val shrinkageVal = regParam * thisIterStepSize - (0 until newWeights.length).foreach { i => - val wi = newWeights.get(i) - newWeights.put(i, signum(wi) * max(0.0, abs(wi) - shrinkageVal)) - } - (newWeights, newWeights.norm1 * regParam) - } -} - -/** - * Updater that adjusts the learning rate and performs L2 regularization - */ -class SquaredL2Updater extends Updater { - override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, - stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { - val thisIterStepSize = stepSize / math.sqrt(iter) - val normGradient = gradient.mul(thisIterStepSize) - val newWeights = weightsOld.sub(normGradient).div(2.0 * thisIterStepSize * regParam + 1.0) - (newWeights, pow(newWeights.norm2, 2.0) * regParam) - } -} - diff --git a/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala deleted file mode 100644 index dbfbf59975..0000000000 --- a/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala +++ /dev/null @@ -1,453 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.recommendation - -import scala.collection.mutable.{ArrayBuffer, BitSet} -import scala.util.Random -import scala.util.Sorting - -import spark.{HashPartitioner, Partitioner, SparkContext, RDD} -import spark.storage.StorageLevel -import spark.KryoRegistrator -import spark.SparkContext._ - -import com.esotericsoftware.kryo.Kryo -import org.jblas.{DoubleMatrix, SimpleBlas, Solve} - - -/** - * Out-link information for a user or product block. This includes the original user/product IDs - * of the elements within this block, and the list of destination blocks that each user or - * product will need to send its feature vector to. - */ -private[recommendation] case class OutLinkBlock(elementIds: Array[Int], shouldSend: Array[BitSet]) - - -/** - * In-link information for a user (or product) block. This includes the original user/product IDs - * of the elements within this block, as well as an array of indices and ratings that specify - * which user in the block will be rated by which products from each product block (or vice-versa). 
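The per-weight updates performed by L1Updater and SquaredL2Updater above, written out as small standalone functions with a couple of worked values (hypothetical names):

    import scala.math.{abs, max, signum, sqrt}

    object UpdaterSketch {
      // L1Updater's proximal step for one weight: shrink it toward zero by shrinkageVal.
      def softThreshold(w: Double, shrinkageVal: Double): Double =
        signum(w) * max(0.0, abs(w) - shrinkageVal)

      // SquaredL2Updater's step for one weight: take the gradient step, then divide by
      // (2 * thisIterStepSize * regParam + 1).
      def l2Step(w: Double, grad: Double, stepSize: Double, iter: Int, regParam: Double): Double = {
        val thisIterStepSize = stepSize / sqrt(iter)
        (w - thisIterStepSize * grad) / (2.0 * thisIterStepSize * regParam + 1.0)
      }

      def main(args: Array[String]): Unit = {
        println(softThreshold(0.3, 0.5))   // inside the dead zone, so the weight becomes 0.0
        println(softThreshold(-1.2, 0.5))  // -1.2 shrunk toward zero becomes -0.7
        println(l2Step(1.0, 0.0, stepSize = 1.0, iter = 1, regParam = 0.5))  // 1.0 / 2.0 = 0.5
      }
    }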
- * Specifically, if this InLinkBlock is for users, ratingsForBlock(b)(i) will contain two arrays, - * indices and ratings, for the i'th product that will be sent to us by product block b (call this - * P). These arrays represent the users that product P had ratings for (by their index in this - * block), as well as the corresponding rating for each one. We can thus use this information when - * we get product block b's message to update the corresponding users. - */ -private[recommendation] case class InLinkBlock( - elementIds: Array[Int], ratingsForBlock: Array[Array[(Array[Int], Array[Double])]]) - - -/** - * A more compact class to represent a rating than Tuple3[Int, Int, Double]. - */ -case class Rating(val user: Int, val product: Int, val rating: Double) - -/** - * Alternating Least Squares matrix factorization. - * - * This is a blocked implementation of the ALS factorization algorithm that groups the two sets - * of factors (referred to as "users" and "products") into blocks and reduces communication by only - * sending one copy of each user vector to each product block on each iteration, and only for the - * product blocks that need that user's feature vector. This is achieved by precomputing some - * information about the ratings matrix to determine the "out-links" of each user (which blocks of - * products it will contribute to) and "in-link" information for each product (which of the feature - * vectors it receives from each user block it will depend on). This allows us to send only an - * array of feature vectors between each user block and product block, and have the product block - * find the users' ratings and update the products based on these messages. - */ -class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var lambda: Double) - extends Serializable -{ - def this() = this(-1, 10, 10, 0.01) - - /** - * Set the number of blocks to parallelize the computation into; pass -1 for an auto-configured - * number of blocks. Default: -1. - */ - def setBlocks(numBlocks: Int): ALS = { - this.numBlocks = numBlocks - this - } - - /** Set the rank of the feature matrices computed (number of features). Default: 10. */ - def setRank(rank: Int): ALS = { - this.rank = rank - this - } - - /** Set the number of iterations to run. Default: 10. */ - def setIterations(iterations: Int): ALS = { - this.iterations = iterations - this - } - - /** Set the regularization parameter, lambda. Default: 0.01. */ - def setLambda(lambda: Double): ALS = { - this.lambda = lambda - this - } - - /** - * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples. - * Returns a MatrixFactorizationModel with feature vectors for each user and product. 
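A toy sketch of the out-link bookkeeping described above: with ids assigned to blocks by id % numBlocks, each user block records which product blocks need that user's feature vector (the names and the tiny rating set are illustrative):

    import scala.collection.mutable.BitSet

    object AlsBlockingSketch {
      // For each user, mark the product blocks (product id % numBlocks) that must receive
      // that user's feature vector, mirroring OutLinkBlock.shouldSend.
      def outLinks(ratings: Seq[(Int, Int)], numBlocks: Int): Map[Int, BitSet] = {
        ratings
          .groupBy { case (user, _) => user }
          .map { case (user, rs) =>
            val bits = new BitSet(numBlocks)
            rs.foreach { case (_, product) => bits += product % numBlocks }
            user -> bits
          }
      }

      def main(args: Array[String]): Unit = {
        // User 7 rated products 3 and 5, so with 2 blocks its vector only goes to block 1;
        // user 8 rated product 4, which lives in block 0.
        println(outLinks(Seq((7, 3), (7, 5), (8, 4)), numBlocks = 2))
      }
    }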
- */ - def run(ratings: RDD[Rating]): MatrixFactorizationModel = { - val numBlocks = if (this.numBlocks == -1) { - math.max(ratings.context.defaultParallelism, ratings.partitions.size / 2) - } else { - this.numBlocks - } - - val partitioner = new HashPartitioner(numBlocks) - - val ratingsByUserBlock = ratings.map{ rating => (rating.user % numBlocks, rating) } - val ratingsByProductBlock = ratings.map{ rating => - (rating.product % numBlocks, Rating(rating.product, rating.user, rating.rating)) - } - - val (userInLinks, userOutLinks) = makeLinkRDDs(numBlocks, ratingsByUserBlock) - val (productInLinks, productOutLinks) = makeLinkRDDs(numBlocks, ratingsByProductBlock) - - // Initialize user and product factors randomly, but use a deterministic seed for each partition - // so that fault recovery works - val seedGen = new Random() - val seed1 = seedGen.nextInt() - val seed2 = seedGen.nextInt() - // Hash an integer to propagate random bits at all positions, similar to java.util.HashTable - def hash(x: Int): Int = { - val r = x ^ (x >>> 20) ^ (x >>> 12) - r ^ (r >>> 7) ^ (r >>> 4) - } - var users = userOutLinks.mapPartitionsWithIndex { (index, itr) => - val rand = new Random(hash(seed1 ^ index)) - itr.map { case (x, y) => - (x, y.elementIds.map(_ => randomFactor(rank, rand))) - } - } - var products = productOutLinks.mapPartitionsWithIndex { (index, itr) => - val rand = new Random(hash(seed2 ^ index)) - itr.map { case (x, y) => - (x, y.elementIds.map(_ => randomFactor(rank, rand))) - } - } - - for (iter <- 0 until iterations) { - // perform ALS update - products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda) - users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda) - } - - // Flatten and cache the two final RDDs to un-block them - val usersOut = users.join(userOutLinks).flatMap { case (b, (factors, outLinkBlock)) => - for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i)) - } - val productsOut = products.join(productOutLinks).flatMap { case (b, (factors, outLinkBlock)) => - for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i)) - } - - usersOut.persist() - productsOut.persist() - - new MatrixFactorizationModel(rank, usersOut, productsOut) - } - - /** - * Make the out-links table for a block of the users (or products) dataset given the list of - * (user, product, rating) values for the users in that block (or the opposite for products). - */ - private def makeOutLinkBlock(numBlocks: Int, ratings: Array[Rating]): OutLinkBlock = { - val userIds = ratings.map(_.user).distinct.sorted - val numUsers = userIds.length - val userIdToPos = userIds.zipWithIndex.toMap - val shouldSend = Array.fill(numUsers)(new BitSet(numBlocks)) - for (r <- ratings) { - shouldSend(userIdToPos(r.user))(r.product % numBlocks) = true - } - OutLinkBlock(userIds, shouldSend) - } - - /** - * Make the in-links table for a block of the users (or products) dataset given a list of - * (user, product, rating) values for the users in that block (or the opposite for products). 
- */ - private def makeInLinkBlock(numBlocks: Int, ratings: Array[Rating]): InLinkBlock = { - val userIds = ratings.map(_.user).distinct.sorted - val numUsers = userIds.length - val userIdToPos = userIds.zipWithIndex.toMap - // Split out our ratings by product block - val blockRatings = Array.fill(numBlocks)(new ArrayBuffer[Rating]) - for (r <- ratings) { - blockRatings(r.product % numBlocks) += r - } - val ratingsForBlock = new Array[Array[(Array[Int], Array[Double])]](numBlocks) - for (productBlock <- 0 until numBlocks) { - // Create an array of (product, Seq(Rating)) ratings - val groupedRatings = blockRatings(productBlock).groupBy(_.product).toArray - // Sort them by product ID - val ordering = new Ordering[(Int, ArrayBuffer[Rating])] { - def compare(a: (Int, ArrayBuffer[Rating]), b: (Int, ArrayBuffer[Rating])): Int = a._1 - b._1 - } - Sorting.quickSort(groupedRatings)(ordering) - // Translate the user IDs to indices based on userIdToPos - ratingsForBlock(productBlock) = groupedRatings.map { case (p, rs) => - (rs.view.map(r => userIdToPos(r.user)).toArray, rs.view.map(_.rating).toArray) - } - } - InLinkBlock(userIds, ratingsForBlock) - } - - /** - * Make RDDs of InLinkBlocks and OutLinkBlocks given an RDD of (blockId, (u, p, r)) values for - * the users (or (blockId, (p, u, r)) for the products). We create these simultaneously to avoid - * having to shuffle the (blockId, (u, p, r)) RDD twice, or to cache it. - */ - private def makeLinkRDDs(numBlocks: Int, ratings: RDD[(Int, Rating)]) - : (RDD[(Int, InLinkBlock)], RDD[(Int, OutLinkBlock)]) = - { - val grouped = ratings.partitionBy(new HashPartitioner(numBlocks)) - val links = grouped.mapPartitionsWithIndex((blockId, elements) => { - val ratings = elements.map{_._2}.toArray - val inLinkBlock = makeInLinkBlock(numBlocks, ratings) - val outLinkBlock = makeOutLinkBlock(numBlocks, ratings) - Iterator.single((blockId, (inLinkBlock, outLinkBlock))) - }, true) - links.persist(StorageLevel.MEMORY_AND_DISK) - (links.mapValues(_._1), links.mapValues(_._2)) - } - - /** - * Make a random factor vector with the given random. - */ - private def randomFactor(rank: Int, rand: Random): Array[Double] = { - Array.fill(rank)(rand.nextDouble) - } - - /** - * Compute the user feature vectors given the current products (or vice-versa). This first joins - * the products with their out-links to generate a set of messages to each destination block - * (specifically, the features for the products that user block cares about), then groups these - * by destination and joins them with the in-link info to figure out how to update each user. - * It returns an RDD of new feature vectors for each user block. 
- */ - private def updateFeatures( - products: RDD[(Int, Array[Array[Double]])], - productOutLinks: RDD[(Int, OutLinkBlock)], - userInLinks: RDD[(Int, InLinkBlock)], - partitioner: Partitioner, - rank: Int, - lambda: Double) - : RDD[(Int, Array[Array[Double]])] = - { - val numBlocks = products.partitions.size - productOutLinks.join(products).flatMap { case (bid, (outLinkBlock, factors)) => - val toSend = Array.fill(numBlocks)(new ArrayBuffer[Array[Double]]) - for (p <- 0 until outLinkBlock.elementIds.length; userBlock <- 0 until numBlocks) { - if (outLinkBlock.shouldSend(p)(userBlock)) { - toSend(userBlock) += factors(p) - } - } - toSend.zipWithIndex.map{ case (buf, idx) => (idx, (bid, buf.toArray)) } - }.groupByKey(partitioner) - .join(userInLinks) - .mapValues{ case (messages, inLinkBlock) => updateBlock(messages, inLinkBlock, rank, lambda) } - } - - /** - * Compute the new feature vectors for a block of the users matrix given the list of factors - * it received from each product and its InLinkBlock. - */ - def updateBlock(messages: Seq[(Int, Array[Array[Double]])], inLinkBlock: InLinkBlock, - rank: Int, lambda: Double) - : Array[Array[Double]] = - { - // Sort the incoming block factor messages by block ID and make them an array - val blockFactors = messages.sortBy(_._1).map(_._2).toArray // Array[Array[Double]] - val numBlocks = blockFactors.length - val numUsers = inLinkBlock.elementIds.length - - // We'll sum up the XtXes using vectors that represent only the lower-triangular part, since - // the matrices are symmetric - val triangleSize = rank * (rank + 1) / 2 - val userXtX = Array.fill(numUsers)(DoubleMatrix.zeros(triangleSize)) - val userXy = Array.fill(numUsers)(DoubleMatrix.zeros(rank)) - - // Some temp variables to avoid memory allocation - val tempXtX = DoubleMatrix.zeros(triangleSize) - val fullXtX = DoubleMatrix.zeros(rank, rank) - - // Compute the XtX and Xy values for each user by adding products it rated in each product block - for (productBlock <- 0 until numBlocks) { - for (p <- 0 until blockFactors(productBlock).length) { - val x = new DoubleMatrix(blockFactors(productBlock)(p)) - fillXtX(x, tempXtX) - val (us, rs) = inLinkBlock.ratingsForBlock(productBlock)(p) - for (i <- 0 until us.length) { - userXtX(us(i)).addi(tempXtX) - SimpleBlas.axpy(rs(i), x, userXy(us(i))) - } - } - } - - // Solve the least-squares problem for each user and return the new feature vectors - userXtX.zipWithIndex.map{ case (triangularXtX, index) => - // Compute the full XtX matrix from the lower-triangular part we got above - fillFullMatrix(triangularXtX, fullXtX) - // Add regularization - (0 until rank).foreach(i => fullXtX.data(i*rank + i) += lambda) - // Solve the resulting matrix, which is symmetric and positive-definite - Solve.solvePositive(fullXtX, userXy(index)).data - } - } - - /** - * Set xtxDest to the lower-triangular part of x transpose * x. For efficiency in summing - * these matrices, we store xtxDest as only rank * (rank+1) / 2 values, namely the values - * at (0,0), (1,0), (1,1), (2,0), (2,1), (2,2), etc in that order. - */ - private def fillXtX(x: DoubleMatrix, xtxDest: DoubleMatrix) { - var i = 0 - var pos = 0 - while (i < x.length) { - var j = 0 - while (j <= i) { - xtxDest.data(pos) = x.data(i) * x.data(j) - pos += 1 - j += 1 - } - i += 1 - } - } - - /** - * Given a triangular matrix in the order of fillXtX above, compute the full symmetric square - * matrix that it represents, storing it into destMatrix. 
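fillXtX above and fillFullMatrix below are inverses: one packs the lower triangle of x * x^T into rank*(rank+1)/2 slots, the other mirrors it back into a full symmetric matrix. A dependency-free round-trip sketch, with flat arrays standing in for the jblas matrices:

    object PackedXtX {
      def main(args: Array[String]): Unit = {
        val x = Array(1.0, 2.0, 3.0)               // one factor vector, rank = 3
        val rank = x.length
        // Pack the lower triangle of x * x^T in the order (0,0), (1,0), (1,1), ...
        val packed = new Array[Double](rank * (rank + 1) / 2)
        var pos = 0
        for (i <- 0 until rank; j <- 0 to i) {
          packed(pos) = x(i) * x(j)
          pos += 1
        }
        // Unpack into a full row-major rank x rank matrix by mirroring.
        val full = new Array[Double](rank * rank)
        pos = 0
        for (i <- 0 until rank; j <- 0 to i) {
          full(i * rank + j) = packed(pos)
          full(j * rank + i) = packed(pos)
          pos += 1
        }
        full.grouped(rank).foreach(row => println(row.mkString(" ")))
      }
    }
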
- */
-  private def fillFullMatrix(triangularMatrix: DoubleMatrix, destMatrix: DoubleMatrix) {
-    val rank = destMatrix.rows
-    var i = 0
-    var pos = 0
-    while (i < rank) {
-      var j = 0
-      while (j <= i) {
-        destMatrix.data(i*rank + j) = triangularMatrix.data(pos)
-        destMatrix.data(j*rank + i) = triangularMatrix.data(pos)
-        pos += 1
-        j += 1
-      }
-      i += 1
-    }
-  }
-}
-
-
-/**
- * Top-level methods for calling Alternating Least Squares (ALS) matrix factorization.
- */
-object ALS {
-  /**
-   * Train a matrix factorization model given an RDD of ratings given by users to some products,
-   * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the
-   * product of two lower-rank matrices of a given rank (number of features). To solve for these
-   * features, we run a given number of iterations of ALS. This is done using a level of
-   * parallelism given by `blocks`.
-   *
-   * @param ratings RDD of (userID, productID, rating) pairs
-   * @param rank number of features to use
-   * @param iterations number of iterations of ALS (recommended: 10-20)
-   * @param lambda regularization factor (recommended: 0.01)
-   * @param blocks level of parallelism to split computation into
-   */
-  def train(
-      ratings: RDD[Rating],
-      rank: Int,
-      iterations: Int,
-      lambda: Double,
-      blocks: Int)
-    : MatrixFactorizationModel =
-  {
-    new ALS(blocks, rank, iterations, lambda).run(ratings)
-  }
-
-  /**
-   * Train a matrix factorization model given an RDD of ratings given by users to some products,
-   * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the
-   * product of two lower-rank matrices of a given rank (number of features). To solve for these
-   * features, we run a given number of iterations of ALS. The level of parallelism is determined
-   * automatically based on the number of partitions in `ratings`.
-   *
-   * @param ratings RDD of (userID, productID, rating) pairs
-   * @param rank number of features to use
-   * @param iterations number of iterations of ALS (recommended: 10-20)
-   * @param lambda regularization factor (recommended: 0.01)
-   */
-  def train(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double)
-    : MatrixFactorizationModel =
-  {
-    train(ratings, rank, iterations, lambda, -1)
-  }
-
-  /**
-   * Train a matrix factorization model given an RDD of ratings given by users to some products,
-   * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the
-   * product of two lower-rank matrices of a given rank (number of features). To solve for these
-   * features, we run a given number of iterations of ALS. The level of parallelism is determined
-   * automatically based on the number of partitions in `ratings`.
-   *
-   * @param ratings RDD of (userID, productID, rating) pairs
-   * @param rank number of features to use
-   * @param iterations number of iterations of ALS (recommended: 10-20)
-   */
-  def train(ratings: RDD[Rating], rank: Int, iterations: Int)
-    : MatrixFactorizationModel =
-  {
-    train(ratings, rank, iterations, 0.01, -1)
-  }
-
-  private class ALSRegistrator extends KryoRegistrator {
-    override def registerClasses(kryo: Kryo) {
-      kryo.register(classOf[Rating])
-    }
-  }
-
-  def main(args: Array[String]) {
-    if (args.length != 5 && args.length != 6) {
-      println("Usage: ALS <master> <ratings_file> <rank> <iterations> <output_dir> [<blocks>]")
-      System.exit(1)
-    }
-    val (master, ratingsFile, rank, iters, outputDir) =
-      (args(0), args(1), args(2).toInt, args(3).toInt, args(4))
-    val blocks = if (args.length == 6) args(5).toInt else -1
-    System.setProperty("spark.serializer", "spark.KryoSerializer")
-    System.setProperty("spark.kryo.registrator", classOf[ALSRegistrator].getName)
-    System.setProperty("spark.kryo.referenceTracking", "false")
-    System.setProperty("spark.kryoserializer.buffer.mb", "8")
-    System.setProperty("spark.locality.wait", "10000")
-    val sc = new SparkContext(master, "ALS")
-    val ratings = sc.textFile(ratingsFile).map { line =>
-      val fields = line.split(',')
-      Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble)
-    }
-    val model = ALS.train(ratings, rank, iters, 0.01, blocks)
-    model.userFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") }
-      .saveAsTextFile(outputDir + "/userFeatures")
-    model.productFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") }
-      .saveAsTextFile(outputDir + "/productFeatures")
-    println("Final user/product features written to " + outputDir)
-    System.exit(0)
-  }
-}
diff --git a/mllib/src/main/scala/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/spark/mllib/recommendation/MatrixFactorizationModel.scala
deleted file mode 100644
index 5e21717da5..0000000000
--- a/mllib/src/main/scala/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package spark.mllib.recommendation
-
-import spark.RDD
-import spark.SparkContext._
-
-import org.jblas._
-
-/**
- * Model representing the result of matrix factorization.
- *
- * @param rank Rank for the features in this model.
- * @param userFeatures RDD of tuples where each tuple represents the userId and
- *                     the features computed for this user.
- * @param productFeatures RDD of tuples where each tuple represents the productId
- *                        and the features computed for this product.
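The trained factors land in the two RDDs just described. For orientation, a minimal end-to-end use of this API with a local SparkContext, a tiny in-memory ratings set, and the pre-rename `spark` package names used throughout this patch (a sketch, not part of the deleted sources):

    import spark.SparkContext
    import spark.mllib.recommendation.{ALS, Rating}

    object ALSExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext("local", "ALSExample")
        // Tiny (user, product, rating) set: users 0 and 1 agree on both products.
        val ratings = sc.parallelize(Seq(
          Rating(0, 0, 5.0), Rating(0, 1, 1.0),
          Rating(1, 0, 4.0), Rating(1, 1, 1.0)))
        val model = ALS.train(ratings, 2, 10, 0.01)  // rank = 2, 10 iterations
        println("predicted rating of product 1 by user 0: " + model.predict(0, 1))
        sc.stop()
      }
    }
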
- */ -class MatrixFactorizationModel( - val rank: Int, - val userFeatures: RDD[(Int, Array[Double])], - val productFeatures: RDD[(Int, Array[Double])]) - extends Serializable -{ - /** Predict the rating of one user for one product. */ - def predict(user: Int, product: Int): Double = { - val userVector = new DoubleMatrix(userFeatures.lookup(user).head) - val productVector = new DoubleMatrix(productFeatures.lookup(product).head) - userVector.dot(productVector) - } - - // TODO: Figure out what good bulk prediction methods would look like. - // Probably want a way to get the top users for a product or vice-versa. -} diff --git a/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala deleted file mode 100644 index d164d415d6..0000000000 --- a/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.regression - -import spark.{Logging, RDD, SparkException} -import spark.mllib.optimization._ - -import org.jblas.DoubleMatrix - -/** - * GeneralizedLinearModel (GLM) represents a model trained using - * GeneralizedLinearAlgorithm. GLMs consist of a weight vector and - * an intercept. - * - * @param weights Weights computed for every feature. - * @param intercept Intercept computed for this model. - */ -abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept: Double) - extends Serializable { - - // Create a column vector that can be used for predictions - private val weightsMatrix = new DoubleMatrix(weights.length, 1, weights:_*) - - /** - * Predict the result given a data point and the weights learned. - * - * @param dataMatrix Row vector containing the features for this data point - * @param weightMatrix Column vector containing the weights of the model - * @param intercept Intercept of the model. - */ - def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double): Double - - /** - * Predict values for the given data set using the model trained. - * - * @param testData RDD representing data points to be predicted - * @return RDD[Double] where each entry contains the corresponding prediction - */ - def predict(testData: spark.RDD[Array[Double]]): RDD[Double] = { - // A small optimization to avoid serializing the entire model. Only the weightsMatrix - // and intercept is needed. - val localWeights = weightsMatrix - val localIntercept = intercept - - testData.map { x => - val dataMatrix = new DoubleMatrix(1, x.length, x:_*) - predictPoint(dataMatrix, localWeights, localIntercept) - } - } - - /** - * Predict values for a single data point using the model trained. 
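The local copies in the RDD-based `predict` above are deliberate: a closure that touches an instance field captures the whole enclosing object, so Spark would ship the entire model with every task. A plain-Scala sketch of the same pattern, with `Seq` standing in for the RDD (names here are illustrative, not from the deleted sources):

    class ModelSketch(val weights: Array[Double], val intercept: Double)
      extends Serializable {

      def predictAll(testData: Seq[Array[Double]]): Seq[Double] = {
        // Copy fields into locals: the closure below then captures only these
        // two values, not the enclosing ModelSketch instance.
        val localWeights = weights
        val localIntercept = intercept
        testData.map { x =>
          var dot = 0.0
          var i = 0
          while (i < x.length) { dot += localWeights(i) * x(i); i += 1 }
          dot + localIntercept
        }
      }
    }
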
-   *
-   * @param testData array representing a single data point
-   * @return Double prediction from the trained model
-   */
-  def predict(testData: Array[Double]): Double = {
-    val dataMat = new DoubleMatrix(1, testData.length, testData:_*)
-    predictPoint(dataMat, weightsMatrix, intercept)
-  }
-}
-
-/**
- * GeneralizedLinearAlgorithm implements methods to train a Generalized Linear Model (GLM).
- * This class should be extended with an Optimizer to create a new GLM.
- */
-abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
-  extends Logging with Serializable {
-
-  protected val validators: Seq[RDD[LabeledPoint] => Boolean] = List()
-
-  val optimizer: Optimizer
-
-  protected var addIntercept: Boolean = true
-
-  protected var validateData: Boolean = true
-
-  /**
-   * Create a model given the weights and intercept
-   */
-  protected def createModel(weights: Array[Double], intercept: Double): M
-
-  /**
-   * Set if the algorithm should add an intercept. Default true.
-   */
-  def setIntercept(addIntercept: Boolean): this.type = {
-    this.addIntercept = addIntercept
-    this
-  }
-
-  /**
-   * Set if the algorithm should validate data before training. Default true.
-   */
-  def setValidateData(validateData: Boolean): this.type = {
-    this.validateData = validateData
-    this
-  }
-
-  /**
-   * Run the algorithm with the configured parameters on an input
-   * RDD of LabeledPoint entries.
-   */
-  def run(input: RDD[LabeledPoint]) : M = {
-    val nfeatures: Int = input.first().features.length
-    val initialWeights = Array.fill(nfeatures)(1.0)
-    run(input, initialWeights)
-  }
-
-  /**
-   * Run the algorithm with the configured parameters on an input RDD
-   * of LabeledPoint entries starting from the initial weights provided.
-   */
-  def run(input: RDD[LabeledPoint], initialWeights: Array[Double]) : M = {
-
-    // Check the data properties before running the optimizer
-    if (validateData && !validators.forall(func => func(input))) {
-      throw new SparkException("Input validation failed.")
-    }
-
-    // Add an extra variable consisting of all 1.0's for the intercept.
-    val data = if (addIntercept) {
-      input.map(labeledPoint => (labeledPoint.label, Array(1.0, labeledPoint.features:_*)))
-    } else {
-      input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
-    }
-
-    val initialWeightsWithIntercept = if (addIntercept) {
-      Array(1.0, initialWeights:_*)
-    } else {
-      initialWeights
-    }
-
-    val weights = optimizer.optimize(data, initialWeightsWithIntercept)
-    val intercept = weights(0)
-    val weightsScaled = weights.tail
-
-    val model = createModel(weightsScaled, intercept)
-
-    logInfo("Final model weights " + model.weights.mkString(","))
-    logInfo("Final model intercept " + model.intercept)
-    model
-  }
-}
diff --git a/mllib/src/main/scala/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/spark/mllib/regression/LabeledPoint.scala
deleted file mode 100644
index 3de60482c5..0000000000
--- a/mllib/src/main/scala/spark/mllib/regression/LabeledPoint.scala
+++ /dev/null
@@ -1,26 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.regression - -/** - * Class that represents the features and labels of a data point. - * - * @param label Label for this data point. - * @param features List of features for this data point. - */ -case class LabeledPoint(val label: Double, val features: Array[Double]) diff --git a/mllib/src/main/scala/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/spark/mllib/regression/Lasso.scala deleted file mode 100644 index 0f33456ef4..0000000000 --- a/mllib/src/main/scala/spark/mllib/regression/Lasso.scala +++ /dev/null @@ -1,210 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.regression - -import spark.{Logging, RDD, SparkContext} -import spark.mllib.optimization._ -import spark.mllib.util.MLUtils - -import org.jblas.DoubleMatrix - -/** - * Regression model trained using Lasso. - * - * @param weights Weights computed for every feature. - * @param intercept Intercept computed for this model. - */ -class LassoModel( - override val weights: Array[Double], - override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable { - - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { - dataMatrix.dot(weightMatrix) + intercept - } -} - -/** - * Train a regression model with L1-regularization using Stochastic Gradient Descent. - */ -class LassoWithSGD private ( - var stepSize: Double, - var numIterations: Int, - var regParam: Double, - var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[LassoModel] - with Serializable { - - val gradient = new SquaredGradient() - val updater = new L1Updater() - @transient val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize) - .setNumIterations(numIterations) - .setRegParam(regParam) - .setMiniBatchFraction(miniBatchFraction) - - // We don't want to penalize the intercept, so set this to false. 
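The comment above pairs with the `setIntercept(false)` call that follows: rather than penalizing the intercept, the data is centered and standardized before optimization, and `createModel` later maps the weights back to the original scale. A plain-array sketch of that back-transformation, with made-up statistics (the jblas code below does the same with matrices):

    // Assumed example statistics: weights were learned on (x - mean) / sd data
    // with labels centered at yMean.
    val w = Array(0.8, -0.2)
    val xColMean = Array(2.0, 10.0)
    val xColSd = Array(0.5, 4.0)
    val yMean = 3.0

    // Original-scale weights: divide by each column's standard deviation.
    val wScaled = w.indices.map(i => w(i) / xColSd(i))
    // Original-scale intercept: yMean minus the weights applied to the column means.
    val intercept = yMean - w.indices.map(i => w(i) * xColMean(i) / xColSd(i)).sum
    println(wScaled.mkString(", ") + " intercept=" + intercept)
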
- setIntercept(false) - - var yMean = 0.0 - var xColMean: DoubleMatrix = _ - var xColSd: DoubleMatrix = _ - - /** - * Construct a Lasso object with default parameters - */ - def this() = this(1.0, 100, 1.0, 1.0) - - def createModel(weights: Array[Double], intercept: Double) = { - val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*) - val weightsScaled = weightsMat.div(xColSd) - val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)) - - new LassoModel(weightsScaled.data, interceptScaled) - } - - override def run( - input: RDD[LabeledPoint], - initialWeights: Array[Double]) - : LassoModel = - { - val nfeatures: Int = input.first.features.length - val nexamples: Long = input.count() - - // To avoid penalizing the intercept, we center and scale the data. - val stats = MLUtils.computeStats(input, nfeatures, nexamples) - yMean = stats._1 - xColMean = stats._2 - xColSd = stats._3 - - val normalizedData = input.map { point => - val yNormalized = point.label - yMean - val featuresMat = new DoubleMatrix(nfeatures, 1, point.features:_*) - val featuresNormalized = featuresMat.sub(xColMean).divi(xColSd) - LabeledPoint(yNormalized, featuresNormalized.toArray) - } - - super.run(normalizedData, initialWeights) - } -} - -/** - * Top-level methods for calling Lasso. - */ -object LassoWithSGD { - - /** - * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in - * gradient descent are initialized using the initial weights provided. - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param regParam Regularization parameter. - * @param miniBatchFraction Fraction of data to be used per iteration. - * @param initialWeights Initial set of weights to be used. Array should be equal in size to - * the number of features in the data. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double, - miniBatchFraction: Double, - initialWeights: Array[Double]) - : LassoModel = - { - new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input, - initialWeights) - } - - /** - * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate the gradient. - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param regParam Regularization parameter. - * @param miniBatchFraction Fraction of data to be used per iteration. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double, - miniBatchFraction: Double) - : LassoModel = - { - new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) - } - - /** - * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. 
We use the entire data set to
-   * update the gradient in each iteration.
-   *
-   * @param input RDD of (label, array of features) pairs.
-   * @param stepSize Step size to be used for each iteration of Gradient Descent.
-   * @param regParam Regularization parameter.
-   * @param numIterations Number of iterations of gradient descent to run.
-   * @return a LassoModel which has the weights and offset from training.
-   */
-  def train(
-      input: RDD[LabeledPoint],
-      numIterations: Int,
-      stepSize: Double,
-      regParam: Double)
-    : LassoModel =
-  {
-    train(input, numIterations, stepSize, regParam, 1.0)
-  }
-
-  /**
-   * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number
-   * of iterations of gradient descent using a step size of 1.0. We use the entire data set to
-   * update the gradient in each iteration.
-   *
-   * @param input RDD of (label, array of features) pairs.
-   * @param numIterations Number of iterations of gradient descent to run.
-   * @return a LassoModel which has the weights and offset from training.
-   */
-  def train(
-      input: RDD[LabeledPoint],
-      numIterations: Int)
-    : LassoModel =
-  {
-    train(input, numIterations, 1.0, 1.0, 1.0)
-  }
-
-  def main(args: Array[String]) {
-    if (args.length != 5) {
-      println("Usage: Lasso <master> <input_dir> <step_size> <regularization_parameter> <niters>")
-      System.exit(1)
-    }
-    val sc = new SparkContext(args(0), "Lasso")
-    val data = MLUtils.loadLabeledData(sc, args(1))
-    val model = LassoWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
-
-    sc.stop()
-  }
-}
diff --git a/mllib/src/main/scala/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/spark/mllib/regression/LinearRegression.scala
deleted file mode 100644
index 885ff5a30d..0000000000
--- a/mllib/src/main/scala/spark/mllib/regression/LinearRegression.scala
+++ /dev/null
@@ -1,167 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package spark.mllib.regression
-
-import spark.{Logging, RDD, SparkContext}
-import spark.mllib.optimization._
-import spark.mllib.util.MLUtils
-
-import org.jblas.DoubleMatrix
-
-/**
- * Regression model trained using LinearRegression.
- *
- * @param weights Weights computed for every feature.
- * @param intercept Intercept computed for this model.
- */
-class LinearRegressionModel(
-    override val weights: Array[Double],
-    override val intercept: Double)
-  extends GeneralizedLinearModel(weights, intercept)
-  with RegressionModel with Serializable {
-
-  override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
-      intercept: Double) = {
-    dataMatrix.dot(weightMatrix) + intercept
-  }
-}
-
-/**
- * Train a regression model with no regularization using Stochastic Gradient Descent.
- */ -class LinearRegressionWithSGD private ( - var stepSize: Double, - var numIterations: Int, - var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[LinearRegressionModel] - with Serializable { - - val gradient = new SquaredGradient() - val updater = new SimpleUpdater() - val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize) - .setNumIterations(numIterations) - .setMiniBatchFraction(miniBatchFraction) - - /** - * Construct a LinearRegression object with default parameters - */ - def this() = this(1.0, 100, 1.0) - - def createModel(weights: Array[Double], intercept: Double) = { - new LinearRegressionModel(weights, intercept) - } -} - -/** - * Top-level methods for calling LinearRegression. - */ -object LinearRegressionWithSGD { - - /** - * Train a Linear Regression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in - * gradient descent are initialized using the initial weights provided. - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param miniBatchFraction Fraction of data to be used per iteration. - * @param initialWeights Initial set of weights to be used. Array should be equal in size to - * the number of features in the data. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - miniBatchFraction: Double, - initialWeights: Array[Double]) - : LinearRegressionModel = - { - new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction).run(input, - initialWeights) - } - - /** - * Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate the gradient. - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param miniBatchFraction Fraction of data to be used per iteration. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - miniBatchFraction: Double) - : LinearRegressionModel = - { - new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction).run(input) - } - - /** - * Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. We use the entire data set to - * update the gradient in each iteration. - * - * @param input RDD of (label, array of features) pairs. - * @param stepSize Step size to be used for each iteration of Gradient Descent. - * @param numIterations Number of iterations of gradient descent to run. - * @return a LinearRegressionModel which has the weights and offset from training. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double) - : LinearRegressionModel = - { - train(input, numIterations, stepSize, 1.0) - } - - /** - * Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using a step size of 1.0. 
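A compact usage sketch for the trainers above (local mode, synthetic y = 2x data, pre-rename package names; the step size and iteration count are illustrative guesses, not recommendations):

    import spark.SparkContext
    import spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}

    object LinearRegressionExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext("local", "LinearRegressionExample")
        // y = 2x exactly, with x in [0, 1], so SGD should recover a weight near 2.
        val data = sc.parallelize((1 to 100).map { i =>
          val x = i / 100.0
          LabeledPoint(2.0 * x, Array(x))
        })
        val model = LinearRegressionWithSGD.train(data, 200, 0.1)  // iterations, step size
        println("weights: " + model.weights.mkString(", "))
        println("prediction at x = 0.5: " + model.predict(Array(0.5)))
        sc.stop()
      }
    }
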
We use the entire data set to
-   * update the gradient in each iteration.
-   *
-   * @param input RDD of (label, array of features) pairs.
-   * @param numIterations Number of iterations of gradient descent to run.
-   * @return a LinearRegressionModel which has the weights and offset from training.
-   */
-  def train(
-      input: RDD[LabeledPoint],
-      numIterations: Int)
-    : LinearRegressionModel =
-  {
-    train(input, numIterations, 1.0, 1.0)
-  }
-
-  def main(args: Array[String]) {
-    if (args.length != 5) {
-      println("Usage: LinearRegression <master> <input_dir> <step_size> <niters>")
-      System.exit(1)
-    }
-    val sc = new SparkContext(args(0), "LinearRegression")
-    val data = MLUtils.loadLabeledData(sc, args(1))
-    val model = LinearRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble)
-
-    sc.stop()
-  }
-}
diff --git a/mllib/src/main/scala/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/spark/mllib/regression/RegressionModel.scala
deleted file mode 100644
index b845ba1a89..0000000000
--- a/mllib/src/main/scala/spark/mllib/regression/RegressionModel.scala
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package spark.mllib.regression
-
-import spark.RDD
-
-trait RegressionModel extends Serializable {
-  /**
-   * Predict values for the given data set using the model trained.
-   *
-   * @param testData RDD representing data points to be predicted
-   * @return RDD[Double] where each entry contains the corresponding prediction
-   */
-  def predict(testData: RDD[Array[Double]]): RDD[Double]
-
-  /**
-   * Predict values for a single data point using the model trained.
-   *
-   * @param testData array representing a single data point
-   * @return Double prediction from the trained model
-   */
-  def predict(testData: Array[Double]): Double
-}
diff --git a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
deleted file mode 100644
index cb1303dd99..0000000000
--- a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
+++ /dev/null
@@ -1,213 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.regression - -import spark.{Logging, RDD, SparkContext} -import spark.mllib.optimization._ -import spark.mllib.util.MLUtils - -import org.jblas.DoubleMatrix - -/** - * Regression model trained using RidgeRegression. - * - * @param weights Weights computed for every feature. - * @param intercept Intercept computed for this model. - */ -class RidgeRegressionModel( - override val weights: Array[Double], - override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable { - - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { - dataMatrix.dot(weightMatrix) + intercept - } -} - -/** - * Train a regression model with L2-regularization using Stochastic Gradient Descent. - */ -class RidgeRegressionWithSGD private ( - var stepSize: Double, - var numIterations: Int, - var regParam: Double, - var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[RidgeRegressionModel] - with Serializable { - - val gradient = new SquaredGradient() - val updater = new SquaredL2Updater() - - @transient val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize) - .setNumIterations(numIterations) - .setRegParam(regParam) - .setMiniBatchFraction(miniBatchFraction) - - // We don't want to penalize the intercept in RidgeRegression, so set this to false. - setIntercept(false) - - var yMean = 0.0 - var xColMean: DoubleMatrix = _ - var xColSd: DoubleMatrix = _ - - /** - * Construct a RidgeRegression object with default parameters - */ - def this() = this(1.0, 100, 1.0, 1.0) - - def createModel(weights: Array[Double], intercept: Double) = { - val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*) - val weightsScaled = weightsMat.div(xColSd) - val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)) - - new RidgeRegressionModel(weightsScaled.data, interceptScaled) - } - - override def run( - input: RDD[LabeledPoint], - initialWeights: Array[Double]) - : RidgeRegressionModel = - { - val nfeatures: Int = input.first.features.length - val nexamples: Long = input.count() - - // To avoid penalizing the intercept, we center and scale the data. - val stats = MLUtils.computeStats(input, nfeatures, nexamples) - yMean = stats._1 - xColMean = stats._2 - xColSd = stats._3 - - val normalizedData = input.map { point => - val yNormalized = point.label - yMean - val featuresMat = new DoubleMatrix(nfeatures, 1, point.features:_*) - val featuresNormalized = featuresMat.sub(xColMean).divi(xColSd) - LabeledPoint(yNormalized, featuresNormalized.toArray) - } - - super.run(normalizedData, initialWeights) - } -} - -/** - * Top-level methods for calling RidgeRegression. - */ -object RidgeRegressionWithSGD { - - /** - * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in - * gradient descent are initialized using the initial weights provided. - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. 
- * @param regParam Regularization parameter. - * @param miniBatchFraction Fraction of data to be used per iteration. - * @param initialWeights Initial set of weights to be used. Array should be equal in size to - * the number of features in the data. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double, - miniBatchFraction: Double, - initialWeights: Array[Double]) - : RidgeRegressionModel = - { - new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run( - input, initialWeights) - } - - /** - * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. Each iteration uses - * `miniBatchFraction` fraction of the data to calculate the gradient. - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @param stepSize Step size to be used for each iteration of gradient descent. - * @param regParam Regularization parameter. - * @param miniBatchFraction Fraction of data to be used per iteration. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double, - miniBatchFraction: Double) - : RidgeRegressionModel = - { - new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) - } - - /** - * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using the specified step size. We use the entire data set to - * update the gradient in each iteration. - * - * @param input RDD of (label, array of features) pairs. - * @param stepSize Step size to be used for each iteration of Gradient Descent. - * @param regParam Regularization parameter. - * @param numIterations Number of iterations of gradient descent to run. - * @return a RidgeRegressionModel which has the weights and offset from training. - */ - def train( - input: RDD[LabeledPoint], - numIterations: Int, - stepSize: Double, - regParam: Double) - : RidgeRegressionModel = - { - train(input, numIterations, stepSize, regParam, 1.0) - } - - /** - * Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number - * of iterations of gradient descent using a step size of 1.0. We use the entire data set to - * update the gradient in each iteration. - * - * @param input RDD of (label, array of features) pairs. - * @param numIterations Number of iterations of gradient descent to run. - * @return a RidgeRegressionModel which has the weights and offset from training. 
- */
-  def train(
-      input: RDD[LabeledPoint],
-      numIterations: Int)
-    : RidgeRegressionModel =
-  {
-    train(input, numIterations, 1.0, 1.0, 1.0)
-  }
-
-  def main(args: Array[String]) {
-    if (args.length != 5) {
-      println("Usage: RidgeRegression <master> <input_dir> <step_size> <regularization_parameter>" +
-        " <niters>")
-      System.exit(1)
-    }
-    val sc = new SparkContext(args(0), "RidgeRegression")
-    val data = MLUtils.loadLabeledData(sc, args(1))
-    val model = RidgeRegressionWithSGD.train(data, args(4).toInt, args(2).toDouble,
-      args(3).toDouble)
-
-    sc.stop()
-  }
-}
diff --git a/mllib/src/main/scala/spark/mllib/util/DataValidators.scala b/mllib/src/main/scala/spark/mllib/util/DataValidators.scala
deleted file mode 100644
index 57553accf1..0000000000
--- a/mllib/src/main/scala/spark/mllib/util/DataValidators.scala
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package spark.mllib.util
-
-import spark.{RDD, Logging}
-import spark.mllib.regression.LabeledPoint
-
-/**
- * A collection of methods used to validate data before applying ML algorithms.
- */
-object DataValidators extends Logging {
-
-  /**
-   * Function to check if labels used for classification are either zero or one.
-   *
-   * @param data - input data set that needs to be checked
-   *
-   * @return True if labels are all zero or one, false otherwise.
-   */
-  val classificationLabels: RDD[LabeledPoint] => Boolean = { data =>
-    val numInvalid = data.filter(x => x.label != 1.0 && x.label != 0.0).count()
-    if (numInvalid != 0) {
-      logError("Classification labels should be 0 or 1. Found " + numInvalid + " invalid labels")
-    }
-    numInvalid == 0
-  }
-}
diff --git a/mllib/src/main/scala/spark/mllib/util/KMeansDataGenerator.scala b/mllib/src/main/scala/spark/mllib/util/KMeansDataGenerator.scala
deleted file mode 100644
index 672b63f65a..0000000000
--- a/mllib/src/main/scala/spark/mllib/util/KMeansDataGenerator.scala
+++ /dev/null
@@ -1,84 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package spark.mllib.util
-
-import scala.util.Random
-
-import spark.{RDD, SparkContext}
-
-/**
- * Generate test data for KMeans. This class first chooses k cluster centers
- * from a d-dimensional Gaussian distribution scaled by factor r and then creates a Gaussian
- * cluster with scale 1 around each center.
- */
-
-object KMeansDataGenerator {
-
-  /**
-   * Generate an RDD containing test data for KMeans.
-   *
-   * @param sc SparkContext to use for creating the RDD
-   * @param numPoints Number of points that will be contained in the RDD
-   * @param k Number of clusters
-   * @param d Number of dimensions
-   * @param r Scaling factor for the distribution of the initial centers
-   * @param numPartitions Number of partitions of the generated RDD; default 2
-   */
-  def generateKMeansRDD(
-      sc: SparkContext,
-      numPoints: Int,
-      k: Int,
-      d: Int,
-      r: Double,
-      numPartitions: Int = 2)
-    : RDD[Array[Double]] =
-  {
-    // First, generate some centers
-    val rand = new Random(42)
-    val centers = Array.fill(k)(Array.fill(d)(rand.nextGaussian() * r))
-    // Then generate points around each center
-    sc.parallelize(0 until numPoints, numPartitions).map { idx =>
-      val center = centers(idx % k)
-      val rand2 = new Random(42 + idx)
-      Array.tabulate(d)(i => center(i) + rand2.nextGaussian())
-    }
-  }
-
-  def main(args: Array[String]) {
-    if (args.length < 6) {
-      println("Usage: KMeansGenerator " +
-        "<master> <output_dir> <num_points> <k> <d> <r> [<num_partitions>]")
-      System.exit(1)
-    }
-
-    val sparkMaster = args(0)
-    val outputPath = args(1)
-    val numPoints = args(2).toInt
-    val k = args(3).toInt
-    val d = args(4).toInt
-    val r = args(5).toDouble
-    val parts = if (args.length >= 7) args(6).toInt else 2
-
-    val sc = new SparkContext(sparkMaster, "KMeansDataGenerator")
-    val data = generateKMeansRDD(sc, numPoints, k, d, r, parts)
-    data.map(_.mkString(" ")).saveAsTextFile(outputPath)
-
-    System.exit(0)
-  }
-}
-
diff --git a/mllib/src/main/scala/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/spark/mllib/util/LinearDataGenerator.scala
deleted file mode 100644
index 9f48477f84..0000000000
--- a/mllib/src/main/scala/spark/mllib/util/LinearDataGenerator.scala
+++ /dev/null
@@ -1,132 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package spark.mllib.util
-
-import scala.collection.JavaConversions._
-import scala.util.Random
-
-import org.jblas.DoubleMatrix
-
-import spark.{RDD, SparkContext}
-import spark.mllib.regression.LabeledPoint
-
-/**
- * Generate sample data used for linear regression. This class generates
- * uniformly random values for every feature and adds Gaussian noise with
- * standard deviation `eps` to the response variable `Y`.
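A short driver for `generateKMeansRDD` above (local mode, pre-rename package names; parameter values are illustrative):

    import spark.SparkContext
    import spark.mllib.util.KMeansDataGenerator

    object KMeansDataExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext("local", "KMeansDataExample")
        // 1000 points in 3 dimensions around 5 centers drawn from N(0, 10^2 I).
        val points = KMeansDataGenerator.generateKMeansRDD(sc, 1000, 5, 3, 10.0)
        points.take(3).foreach(p => println(p.mkString(" ")))
        sc.stop()
      }
    }
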
- */
-object LinearDataGenerator {
-
-  /**
-   * Return a Java List of synthetic data randomly generated according to a
-   * multicollinear model.
-   * @param intercept Data intercept
-   * @param weights Weights to be applied.
-   * @param nPoints Number of points in sample.
-   * @param seed Random seed
-   * @param eps Epsilon scaling factor.
-   * @return Java List of input.
-   */
-  def generateLinearInputAsList(
-      intercept: Double,
-      weights: Array[Double],
-      nPoints: Int,
-      seed: Int,
-      eps: Double): java.util.List[LabeledPoint] = {
-    seqAsJavaList(generateLinearInput(intercept, weights, nPoints, seed, eps))
-  }
-
-  /**
-   * Generate a sequence of LabeledPoints from a linear model with Gaussian noise.
-   *
-   * @param intercept Data intercept
-   * @param weights Weights to be applied.
-   * @param nPoints Number of points in sample.
-   * @param seed Random seed
-   * @param eps Epsilon scaling factor.
-   * @return Seq of input.
-   */
-  def generateLinearInput(
-      intercept: Double,
-      weights: Array[Double],
-      nPoints: Int,
-      seed: Int,
-      eps: Double = 0.1): Seq[LabeledPoint] = {
-
-    val rnd = new Random(seed)
-    val weightsMat = new DoubleMatrix(1, weights.length, weights:_*)
-    val x = Array.fill[Array[Double]](nPoints)(
-      Array.fill[Double](weights.length)(2 * rnd.nextDouble - 1.0))
-    val y = x.map { xi =>
-      (new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) + intercept + eps * rnd.nextGaussian()
-    }
-    y.zip(x).map(p => LabeledPoint(p._1, p._2))
-  }
-
-  /**
-   * Generate an RDD containing sample data for Linear Regression models - including Ridge, Lasso,
-   * and unregularized variants.
-   *
-   * @param sc SparkContext to be used for generating the RDD.
-   * @param nexamples Number of examples that will be contained in the RDD.
-   * @param nfeatures Number of features to generate for each example.
-   * @param eps Epsilon scaling factor for the Gaussian noise.
-   * @param nparts Number of partitions in the RDD. Default value is 2.
-   *
-   * @return RDD of LabeledPoint containing sample data.
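Concretely, `generateLinearInput` draws features uniformly from [-1, 1] and sets y = w . x + intercept + eps * N(0, 1). A dependency-free restatement of that sampling loop (seed and parameter values are made up):

    import scala.util.Random

    object LinearInputSketch {
      def main(args: Array[String]): Unit = {
        val rnd = new Random(42)
        val weights = Array(2.0, -1.0)
        val intercept = 0.5
        val eps = 0.1
        val points = (0 until 5).map { _ =>
          // Features uniform in [-1, 1], as in generateLinearInput above.
          val x = Array.fill(weights.length)(2 * rnd.nextDouble() - 1.0)
          val y = weights.zip(x).map { case (w, xi) => w * xi }.sum +
            intercept + eps * rnd.nextGaussian()
          (y, x)
        }
        points.foreach { case (y, x) => println(y + " <- " + x.mkString(", ")) }
      }
    }
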
- */
-  def generateLinearRDD(
-      sc: SparkContext,
-      nexamples: Int,
-      nfeatures: Int,
-      eps: Double,
-      nparts: Int = 2,
-      intercept: Double = 0.0) : RDD[LabeledPoint] = {
-    org.jblas.util.Random.seed(42)
-    // Random values distributed uniformly in [-0.5, 0.5]
-    val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5)
-
-    val data: RDD[LabeledPoint] = sc.parallelize(0 until nparts, nparts).flatMap { p =>
-      val seed = 42 + p
-      val examplesInPartition = nexamples / nparts
-      generateLinearInput(intercept, w.toArray, examplesInPartition, seed, eps)
-    }
-    data
-  }
-
-  def main(args: Array[String]) {
-    if (args.length < 2) {
-      println("Usage: LinearDataGenerator " +
-        "<master> <output_dir> [num_examples] [num_features] [num_partitions]")
-      System.exit(1)
-    }
-
-    val sparkMaster: String = args(0)
-    val outputPath: String = args(1)
-    val nexamples: Int = if (args.length > 2) args(2).toInt else 1000
-    val nfeatures: Int = if (args.length > 3) args(3).toInt else 100
-    val parts: Int = if (args.length > 4) args(4).toInt else 2
-    val eps = 10
-
-    val sc = new SparkContext(sparkMaster, "LinearDataGenerator")
-    val data = generateLinearRDD(sc, nexamples, nfeatures, eps, nparts = parts)
-
-    MLUtils.saveLabeledData(data, outputPath)
-    sc.stop()
-  }
-}
diff --git a/mllib/src/main/scala/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/spark/mllib/util/LogisticRegressionDataGenerator.scala
deleted file mode 100644
index d6402f23e2..0000000000
--- a/mllib/src/main/scala/spark/mllib/util/LogisticRegressionDataGenerator.scala
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package spark.mllib.util
-
-import scala.util.Random
-
-import spark.{RDD, SparkContext}
-import spark.mllib.regression.LabeledPoint
-
-/**
- * Generate test data for LogisticRegression. This class chooses positive labels
- * with probability `probOne` and scales features for positive examples by `eps`.
- */
-
-object LogisticRegressionDataGenerator {
-
-  /**
-   * Generate an RDD containing test data for LogisticRegression.
-   *
-   * @param sc SparkContext to use for creating the RDD.
-   * @param nexamples Number of examples that will be contained in the RDD.
-   * @param nfeatures Number of features to generate for each example.
-   * @param eps Epsilon factor by which positive examples are scaled.
-   * @param nparts Number of partitions of the generated RDD. Default value is 2.
-   * @param probOne Probability that a label is 1 (and not 0). Default value is 0.5.
- */
-  def generateLogisticRDD(
-    sc: SparkContext,
-    nexamples: Int,
-    nfeatures: Int,
-    eps: Double,
-    nparts: Int = 2,
-    probOne: Double = 0.5): RDD[LabeledPoint] = {
-    val data = sc.parallelize(0 until nexamples, nparts).map { idx =>
-      val rnd = new Random(42 + idx)
-
-      val y = if (idx % 2 == 0) 0.0 else 1.0
-      val x = Array.fill[Double](nfeatures) {
-        rnd.nextGaussian() + (y * eps)
-      }
-      LabeledPoint(y, x)
-    }
-    data
-  }
-
-  def main(args: Array[String]) {
-    if (args.length != 5) {
-      println("Usage: LogisticRegressionGenerator " +
-        "<master> <output_dir> <num_examples> <num_features> <num_partitions>")
-      System.exit(1)
-    }
-
-    val sparkMaster: String = args(0)
-    val outputPath: String = args(1)
-    val nexamples: Int = if (args.length > 2) args(2).toInt else 1000
-    val nfeatures: Int = if (args.length > 3) args(3).toInt else 2
-    val parts: Int = if (args.length > 4) args(4).toInt else 2
-    val eps = 3
-
-    val sc = new SparkContext(sparkMaster, "LogisticRegressionDataGenerator")
-    val data = generateLogisticRDD(sc, nexamples, nfeatures, eps, parts)
-
-    MLUtils.saveLabeledData(data, outputPath)
-    sc.stop()
-  }
-}
diff --git a/mllib/src/main/scala/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/spark/mllib/util/MFDataGenerator.scala
deleted file mode 100644
index 88992cde0c..0000000000
--- a/mllib/src/main/scala/spark/mllib/util/MFDataGenerator.scala
+++ /dev/null
@@ -1,113 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package spark.mllib.recommendation
-
-import scala.util.Random
-
-import org.jblas.DoubleMatrix
-
-import spark.{RDD, SparkContext}
-import spark.mllib.util.MLUtils
-
-/**
-* Generate RDD(s) containing data for Matrix Factorization.
-*
-* This method samples training entries according to the oversampling factor
-* 'trainSampFact', which is a multiplicative factor of the number of
-* degrees of freedom of the matrix: rank*(m+n-rank).
-*
-* It optionally samples entries for a testing matrix using
-* 'testSampFact', the percentage of the number of training entries
-* to use for testing.
-*
-* This method takes the following inputs:
-*   sparkMaster    (String) The master URL.
-*   outputPath     (String) Directory to save output.
-*   m              (Int) Number of rows in data matrix.
-*   n              (Int) Number of columns in data matrix.
-*   rank           (Int) Underlying rank of data matrix.
-*   trainSampFact  (Double) Oversampling factor.
-*   noise          (Boolean) Whether to add gaussian noise to training data.
-*   sigma          (Double) Standard deviation of added gaussian noise.
-*   test           (Boolean) Whether to create testing RDD.
-*   testSampFact   (Double) Percentage of training data to use as test data.
-*/ - -object MFDataGenerator{ - - def main(args: Array[String]) { - if (args.length < 2) { - println("Usage: MFDataGenerator " + - " [m] [n] [rank] [trainSampFact] [noise] [sigma] [test] [testSampFact]") - System.exit(1) - } - - val sparkMaster: String = args(0) - val outputPath: String = args(1) - val m: Int = if (args.length > 2) args(2).toInt else 100 - val n: Int = if (args.length > 3) args(3).toInt else 100 - val rank: Int = if (args.length > 4) args(4).toInt else 10 - val trainSampFact: Double = if (args.length > 5) args(5).toDouble else 1.0 - val noise: Boolean = if (args.length > 6) args(6).toBoolean else false - val sigma: Double = if (args.length > 7) args(7).toDouble else 0.1 - val test: Boolean = if (args.length > 8) args(8).toBoolean else false - val testSampFact: Double = if (args.length > 9) args(9).toDouble else 0.1 - - val sc = new SparkContext(sparkMaster, "MFDataGenerator") - - val A = DoubleMatrix.randn(m, rank) - val B = DoubleMatrix.randn(rank, n) - val z = 1 / (scala.math.sqrt(scala.math.sqrt(rank))) - A.mmuli(z) - B.mmuli(z) - val fullData = A.mmul(B) - - val df = rank * (m + n - rank) - val sampSize = scala.math.min(scala.math.round(trainSampFact * df), - scala.math.round(.99 * m * n)).toInt - val rand = new Random() - val mn = m * n - val shuffled = rand.shuffle(1 to mn toIterable) - - val omega = shuffled.slice(0, sampSize) - val ordered = omega.sortWith(_ < _).toArray - val trainData: RDD[(Int, Int, Double)] = sc.parallelize(ordered) - .map(x => (fullData.indexRows(x - 1), fullData.indexColumns(x - 1), fullData.get(x - 1))) - - // optionally add gaussian noise - if (noise) { - trainData.map(x => (x._1, x._2, x._3 + rand.nextGaussian * sigma)) - } - - trainData.map(x => x._1 + "," + x._2 + "," + x._3).saveAsTextFile(outputPath) - - // optionally generate testing data - if (test) { - val testSampSize = scala.math - .min(scala.math.round(sampSize * testSampFact),scala.math.round(mn - sampSize)).toInt - val testOmega = shuffled.slice(sampSize, sampSize + testSampSize) - val testOrdered = testOmega.sortWith(_ < _).toArray - val testData: RDD[(Int, Int, Double)] = sc.parallelize(testOrdered) - .map(x => (fullData.indexRows(x - 1), fullData.indexColumns(x - 1), fullData.get(x - 1))) - testData.map(x => x._1 + "," + x._2 + "," + x._3).saveAsTextFile(outputPath) - } - - sc.stop() - - } -} \ No newline at end of file diff --git a/mllib/src/main/scala/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/spark/mllib/util/MLUtils.scala deleted file mode 100644 index a8e6ae9953..0000000000 --- a/mllib/src/main/scala/spark/mllib/util/MLUtils.scala +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
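[Editor's note] Two details of MFDataGenerator above may be worth spelling out. The 1 / rank^(1/4) scaling of A and B keeps the entries of the low-rank product A * B at roughly unit variance, and the training sample size is trainSampFact times the degrees of freedom of the matrix, capped at 99% of its entries; with the defaults m = n = 100, rank = 10, trainSampFact = 1.0 that is min(10 * (100 + 100 - 10), 9900) = 1900 sampled entries. Also note that the `if (noise)` branch maps trainData to add Gaussian noise but never assigns the result, so as written the saved output is noise-free. A small sketch of the sample-size computation (hypothetical helper, not part of the patch):

    // Observed-entry count for a rank-`rank` m x n matrix, mirroring the
    // expression used in MFDataGenerator.main.
    def trainSampleSize(m: Int, n: Int, rank: Int, trainSampFact: Double): Int = {
      val df = rank * (m + n - rank)                              // degrees of freedom
      math.min(math.round(trainSampFact * df), math.round(0.99 * m * n)).toInt
    }

    // trainSampleSize(100, 100, 10, 1.0) == 1900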
- */ - -package spark.mllib.util - -import spark.{RDD, SparkContext} -import spark.SparkContext._ - -import org.jblas.DoubleMatrix -import spark.mllib.regression.LabeledPoint - -/** - * Helper methods to load, save and pre-process data used in ML Lib. - */ -object MLUtils { - - /** - * Load labeled data from a file. The data format used here is - * , ... - * where , are feature values in Double and is the corresponding label as Double. - * - * @param sc SparkContext - * @param dir Directory to the input data files. - * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is - * the label, and the second element represents the feature values (an array of Double). - */ - def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { - sc.textFile(dir).map { line => - val parts = line.split(',') - val label = parts(0).toDouble - val features = parts(1).trim().split(' ').map(_.toDouble) - LabeledPoint(label, features) - } - } - - /** - * Save labeled data to a file. The data format used here is - * , ... - * where , are feature values in Double and is the corresponding label as Double. - * - * @param data An RDD of LabeledPoints containing data to be saved. - * @param dir Directory to save the data. - */ - def saveLabeledData(data: RDD[LabeledPoint], dir: String) { - val dataStr = data.map(x => x.label + "," + x.features.mkString(" ")) - dataStr.saveAsTextFile(dir) - } - - /** - * Utility function to compute mean and standard deviation on a given dataset. - * - * @param data - input data set whose statistics are computed - * @param nfeatures - number of features - * @param nexamples - number of examples in input dataset - * - * @return (yMean, xColMean, xColSd) - Tuple consisting of - * yMean - mean of the labels - * xColMean - Row vector with mean for every column (or feature) of the input data - * xColSd - Row vector standard deviation for every column (or feature) of the input data. - */ - def computeStats(data: RDD[LabeledPoint], nfeatures: Int, nexamples: Long): - (Double, DoubleMatrix, DoubleMatrix) = { - val yMean: Double = data.map { labeledPoint => labeledPoint.label }.reduce(_ + _) / nexamples - - // NOTE: We shuffle X by column here to compute column sum and sum of squares. - val xColSumSq: RDD[(Int, (Double, Double))] = data.flatMap { labeledPoint => - val nCols = labeledPoint.features.length - // Traverse over every column and emit (col, value, value^2) - Iterator.tabulate(nCols) { i => - (i, (labeledPoint.features(i), labeledPoint.features(i)*labeledPoint.features(i))) - } - }.reduceByKey { case(x1, x2) => - (x1._1 + x2._1, x1._2 + x2._2) - } - val xColSumsMap = xColSumSq.collectAsMap() - - val xColMean = DoubleMatrix.zeros(nfeatures, 1) - val xColSd = DoubleMatrix.zeros(nfeatures, 1) - - // Compute mean and unbiased variance using column sums - var col = 0 - while (col < nfeatures) { - xColMean.put(col, xColSumsMap(col)._1 / nexamples) - val variance = - (xColSumsMap(col)._2 - (math.pow(xColSumsMap(col)._1, 2) / nexamples)) / (nexamples) - xColSd.put(col, math.sqrt(variance)) - col += 1 - } - - (yMean, xColMean, xColSd) - } - - /** - * Return the squared Euclidean distance between two vectors. 
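[Editor's note] computeStats above accumulates a per-column (sum, sum of squares) pair with reduceByKey and recovers each column's variance as (sumSq - sum^2 / n) / n. The in-code comment calls this the unbiased variance, but dividing by n rather than n - 1 is the population variance. A tiny local check of the identity it relies on (illustrative only, not part of the patch):

    val xs = Array(1.0, 2.0, 4.0)
    val n = xs.length.toDouble
    val sum = xs.sum
    val sumSq = xs.map(x => x * x).sum
    val mean = sum / n
    val viaSums = (sumSq - sum * sum / n) / n                      // formula used in computeStats
    val direct = xs.map(x => (x - mean) * (x - mean)).sum / n      // population variance by definition
    assert(math.abs(viaSums - direct) < 1e-9)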
- */ - def squaredDistance(v1: Array[Double], v2: Array[Double]): Double = { - if (v1.length != v2.length) { - throw new IllegalArgumentException("Vector sizes don't match") - } - var i = 0 - var sum = 0.0 - while (i < v1.length) { - sum += (v1(i) - v2(i)) * (v1(i) - v2(i)) - i += 1 - } - sum - } -} diff --git a/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala deleted file mode 100644 index eff456cad6..0000000000 --- a/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala +++ /dev/null @@ -1,50 +0,0 @@ -package spark.mllib.util - -import scala.util.Random - -import spark.{RDD, SparkContext} - -import org.jblas.DoubleMatrix -import spark.mllib.regression.LabeledPoint - -/** - * Generate sample data used for SVM. This class generates uniform random values - * for the features and adds Gaussian noise with weight 0.1 to generate labels. - */ -object SVMDataGenerator { - - def main(args: Array[String]) { - if (args.length < 2) { - println("Usage: SVMGenerator " + - " [num_examples] [num_features] [num_partitions]") - System.exit(1) - } - - val sparkMaster: String = args(0) - val outputPath: String = args(1) - val nexamples: Int = if (args.length > 2) args(2).toInt else 1000 - val nfeatures: Int = if (args.length > 3) args(3).toInt else 2 - val parts: Int = if (args.length > 4) args(4).toInt else 2 - - val sc = new SparkContext(sparkMaster, "SVMGenerator") - - val globalRnd = new Random(94720) - val trueWeights = new DoubleMatrix(1, nfeatures + 1, - Array.fill[Double](nfeatures + 1)(globalRnd.nextGaussian()):_*) - - val data: RDD[LabeledPoint] = sc.parallelize(0 until nexamples, parts).map { idx => - val rnd = new Random(42 + idx) - - val x = Array.fill[Double](nfeatures) { - rnd.nextDouble() * 2.0 - 1.0 - } - val yD = (new DoubleMatrix(1, x.length, x:_*)).dot(trueWeights) + rnd.nextGaussian() * 0.1 - val y = if (yD < 0) 0.0 else 1.0 - LabeledPoint(y, x) - } - - MLUtils.saveLabeledData(data, outputPath) - - sc.stop() - } -} diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java new file mode 100644 index 0000000000..e18e3bc6a8 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
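[Editor's note] SVMDataGenerator above draws features uniformly in [-1, 1] and assigns the label by thresholding the noisy linear score x . w at zero; readers may also want to double-check that trueWeights is built with nfeatures + 1 entries (presumably leaving room for an intercept) while each feature vector has only nfeatures entries. A hypothetical single-example version of the labeling rule, assuming w holds exactly nfeatures weights (not part of the patch):

    import scala.util.Random
    import org.jblas.DoubleMatrix

    // Label one example by thresholding the noisy linear score w . x at zero.
    def svmExample(idx: Int, w: DoubleMatrix, nfeatures: Int): (Double, Array[Double]) = {
      val rnd = new Random(42 + idx)
      val x = Array.fill(nfeatures)(rnd.nextDouble() * 2.0 - 1.0)          // uniform in [-1, 1]
      val score = new DoubleMatrix(1, x.length, x: _*).dot(w) + rnd.nextGaussian() * 0.1
      (if (score < 0) 0.0 else 1.0, x)
    }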
+ */ + +package org.apache.spark.mllib.classification; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; + +import org.apache.spark.mllib.regression.LabeledPoint; + +public class JavaLogisticRegressionSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + System.clearProperty("spark.driver.port"); + } + + int validatePrediction(List validationData, LogisticRegressionModel model) { + int numAccurate = 0; + for (LabeledPoint point: validationData) { + Double prediction = model.predict(point.features()); + if (prediction == point.label()) { + numAccurate++; + } + } + return numAccurate; + } + + @Test + public void runLRUsingConstructor() { + int nPoints = 10000; + double A = 2.0; + double B = -1.5; + + JavaRDD testRDD = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + List validationData = + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); + + LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD(); + lrImpl.optimizer().setStepSize(1.0) + .setRegParam(1.0) + .setNumIterations(100); + LogisticRegressionModel model = lrImpl.run(testRDD.rdd()); + + int numAccurate = validatePrediction(validationData, model); + Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + } + + @Test + public void runLRUsingStaticMethods() { + int nPoints = 10000; + double A = 2.0; + double B = -1.5; + + JavaRDD testRDD = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + List validationData = + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); + + LogisticRegressionModel model = LogisticRegressionWithSGD.train( + testRDD.rdd(), 100, 1.0, 1.0); + + int numAccurate = validatePrediction(validationData, model); + Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + } + +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java new file mode 100644 index 0000000000..117e5eaa8b --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.classification; + + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; + +import org.apache.spark.mllib.regression.LabeledPoint; + +public class JavaSVMSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaSVMSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + System.clearProperty("spark.driver.port"); + } + + int validatePrediction(List validationData, SVMModel model) { + int numAccurate = 0; + for (LabeledPoint point: validationData) { + Double prediction = model.predict(point.features()); + if (prediction == point.label()) { + numAccurate++; + } + } + return numAccurate; + } + + @Test + public void runSVMUsingConstructor() { + int nPoints = 10000; + double A = 2.0; + double[] weights = {-1.5, 1.0}; + + JavaRDD testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A, + weights, nPoints, 42), 2).cache(); + List validationData = + SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); + + SVMWithSGD svmSGDImpl = new SVMWithSGD(); + svmSGDImpl.optimizer().setStepSize(1.0) + .setRegParam(1.0) + .setNumIterations(100); + SVMModel model = svmSGDImpl.run(testRDD.rdd()); + + int numAccurate = validatePrediction(validationData, model); + Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + } + + @Test + public void runSVMUsingStaticMethods() { + int nPoints = 10000; + double A = 2.0; + double[] weights = {-1.5, 1.0}; + + JavaRDD testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A, + weights, nPoints, 42), 2).cache(); + List validationData = + SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); + + SVMModel model = SVMWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0, 1.0); + + int numAccurate = validatePrediction(validationData, model); + Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + } + +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java new file mode 100644 index 0000000000..32d3934ac1 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.clustering; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; + +public class JavaKMeansSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaKMeans"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + System.clearProperty("spark.driver.port"); + } + + // L1 distance between two points + double distance1(double[] v1, double[] v2) { + double distance = 0.0; + for (int i = 0; i < v1.length; ++i) { + distance = Math.max(distance, Math.abs(v1[i] - v2[i])); + } + return distance; + } + + // Assert that two sets of points are equal, within EPSILON tolerance + void assertSetsEqual(double[][] v1, double[][] v2) { + double EPSILON = 1e-4; + Assert.assertTrue(v1.length == v2.length); + for (int i = 0; i < v1.length; ++i) { + double minDistance = Double.MAX_VALUE; + for (int j = 0; j < v2.length; ++j) { + minDistance = Math.min(minDistance, distance1(v1[i], v2[j])); + } + Assert.assertTrue(minDistance <= EPSILON); + } + + for (int i = 0; i < v2.length; ++i) { + double minDistance = Double.MAX_VALUE; + for (int j = 0; j < v1.length; ++j) { + minDistance = Math.min(minDistance, distance1(v2[i], v1[j])); + } + Assert.assertTrue(minDistance <= EPSILON); + } + } + + + @Test + public void runKMeansUsingStaticMethods() { + List points = new ArrayList(); + points.add(new double[]{1.0, 2.0, 6.0}); + points.add(new double[]{1.0, 3.0, 0.0}); + points.add(new double[]{1.0, 4.0, 6.0}); + + double[][] expectedCenter = { {1.0, 3.0, 4.0} }; + + JavaRDD data = sc.parallelize(points, 2); + KMeansModel model = KMeans.train(data.rdd(), 1, 1); + assertSetsEqual(model.clusterCenters(), expectedCenter); + + model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.RANDOM()); + assertSetsEqual(model.clusterCenters(), expectedCenter); + } + + @Test + public void runKMeansUsingConstructor() { + List points = new ArrayList(); + points.add(new double[]{1.0, 2.0, 6.0}); + points.add(new double[]{1.0, 3.0, 0.0}); + points.add(new double[]{1.0, 4.0, 6.0}); + + double[][] expectedCenter = { {1.0, 3.0, 4.0} }; + + JavaRDD data = sc.parallelize(points, 2); + KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd()); + assertSetsEqual(model.clusterCenters(), expectedCenter); + + model = new KMeans().setK(1) + .setMaxIterations(1) + .setRuns(1) + .setInitializationMode(KMeans.RANDOM()) + .run(data.rdd()); + assertSetsEqual(model.clusterCenters(), expectedCenter); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java new file mode 100644 index 0000000000..3323f6cee2 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
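[Editor's note] In JavaKMeansSuite above, the helper distance1 is commented as the "L1 distance", but taking the maximum absolute per-coordinate difference is the L-infinity (Chebyshev) distance; the L1 distance would sum those differences. Either metric works for assertSetsEqual's EPSILON comparison of cluster centers, but the name is misleading. A short sketch of the two metrics (illustrative only, not part of the patch):

    // What distance1 actually computes (Chebyshev / L-infinity).
    def chebyshev(v1: Array[Double], v2: Array[Double]): Double =
      v1.zip(v2).map { case (a, b) => math.abs(a - b) }.max

    // What its comment describes (L1 / Manhattan).
    def l1(v1: Array[Double], v2: Array[Double]): Double =
      v1.zip(v2).map { case (a, b) => math.abs(a - b) }.sum

    // chebyshev(Array(1.0, 3.0), Array(2.0, 5.0)) == 2.0, while l1(...) == 3.0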
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.recommendation; + +import java.io.Serializable; +import java.util.List; + +import scala.Tuple2; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; + +import org.jblas.DoubleMatrix; + +public class JavaALSSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaALS"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + System.clearProperty("spark.driver.port"); + } + + void validatePrediction(MatrixFactorizationModel model, int users, int products, int features, + DoubleMatrix trueRatings, double matchThreshold) { + DoubleMatrix predictedU = new DoubleMatrix(users, features); + List> userFeatures = model.userFeatures().toJavaRDD().collect(); + for (int i = 0; i < features; ++i) { + for (scala.Tuple2 userFeature : userFeatures) { + predictedU.put((Integer)userFeature._1(), i, userFeature._2()[i]); + } + } + DoubleMatrix predictedP = new DoubleMatrix(products, features); + + List> productFeatures = + model.productFeatures().toJavaRDD().collect(); + for (int i = 0; i < features; ++i) { + for (scala.Tuple2 productFeature : productFeatures) { + predictedP.put((Integer)productFeature._1(), i, productFeature._2()[i]); + } + } + + DoubleMatrix predictedRatings = predictedU.mmul(predictedP.transpose()); + + for (int u = 0; u < users; ++u) { + for (int p = 0; p < products; ++p) { + double prediction = predictedRatings.get(u, p); + double correct = trueRatings.get(u, p); + Assert.assertTrue(Math.abs(prediction - correct) < matchThreshold); + } + } + } + + @Test + public void runALSUsingStaticMethods() { + int features = 1; + int iterations = 15; + int users = 10; + int products = 10; + scala.Tuple2, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + users, products, features, 0.7); + + JavaRDD data = sc.parallelize(testData._1()); + MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations); + validatePrediction(model, users, products, features, testData._2(), 0.3); + } + + @Test + public void runALSUsingConstructor() { + int features = 2; + int iterations = 15; + int users = 20; + int products = 30; + scala.Tuple2, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + users, products, features, 0.7); + + JavaRDD data = sc.parallelize(testData._1()); + + MatrixFactorizationModel model = new ALS().setRank(features) + .setIterations(iterations) + .run(data.rdd()); + validatePrediction(model, users, products, features, testData._2(), 0.3); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java new file mode 100644 index 0000000000..f44b25cd44 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor 
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.util.LinearDataGenerator; + +public class JavaLassoSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaLassoSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + System.clearProperty("spark.driver.port"); + } + + int validatePrediction(List validationData, LassoModel model) { + int numAccurate = 0; + for (LabeledPoint point: validationData) { + Double prediction = model.predict(point.features()); + // A prediction is off if the prediction is more than 0.5 away from expected value. + if (Math.abs(prediction - point.label()) <= 0.5) { + numAccurate++; + } + } + return numAccurate; + } + + @Test + public void runLassoUsingConstructor() { + int nPoints = 10000; + double A = 2.0; + double[] weights = {-1.5, 1.0e-2}; + + JavaRDD testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, + weights, nPoints, 42, 0.1), 2).cache(); + List validationData = + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); + + LassoWithSGD lassoSGDImpl = new LassoWithSGD(); + lassoSGDImpl.optimizer().setStepSize(1.0) + .setRegParam(0.01) + .setNumIterations(20); + LassoModel model = lassoSGDImpl.run(testRDD.rdd()); + + int numAccurate = validatePrediction(validationData, model); + Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + } + + @Test + public void runLassoUsingStaticMethods() { + int nPoints = 10000; + double A = 2.0; + double[] weights = {-1.5, 1.0e-2}; + + JavaRDD testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, + weights, nPoints, 42, 0.1), 2).cache(); + List validationData = + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); + + LassoModel model = LassoWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0); + + int numAccurate = validatePrediction(validationData, model); + Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + } + +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java new file mode 100644 index 0000000000..5a4410a632 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.util.LinearDataGenerator; + +public class JavaLinearRegressionSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + System.clearProperty("spark.driver.port"); + } + + int validatePrediction(List validationData, LinearRegressionModel model) { + int numAccurate = 0; + for (LabeledPoint point: validationData) { + Double prediction = model.predict(point.features()); + // A prediction is off if the prediction is more than 0.5 away from expected value. + if (Math.abs(prediction - point.label()) <= 0.5) { + numAccurate++; + } + } + return numAccurate; + } + + @Test + public void runLinearRegressionUsingConstructor() { + int nPoints = 100; + double A = 3.0; + double[] weights = {10, 10}; + + JavaRDD testRDD = sc.parallelize( + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); + List validationData = + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); + + LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(); + LinearRegressionModel model = linSGDImpl.run(testRDD.rdd()); + + int numAccurate = validatePrediction(validationData, model); + Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + } + + @Test + public void runLinearRegressionUsingStaticMethods() { + int nPoints = 100; + double A = 3.0; + double[] weights = {10, 10}; + + JavaRDD testRDD = sc.parallelize( + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); + List validationData = + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); + + LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100); + + int numAccurate = validatePrediction(validationData, model); + Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); + } + +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java new file mode 100644 index 0000000000..2fdd5fc8fd --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
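[Editor's note] The ridge regression suite that follows generates data whose first two true weights are set to a large eps and then asserts that training with regParam = 0.1 yields a lower validation error than regParam = 0.0, i.e. that shrinking the weights reduces variance on this noisy data. As a reference point, a hedged sketch of the usual ridge objective being traded off, with hypothetical names and scaling conventions (this is not the MLlib implementation):

    import org.jblas.DoubleMatrix

    // Ridge objective (1/n) * ||X w - y||^2 + lambda * ||w||^2 for a candidate w,
    // where X is n x d and y is n x 1.
    def ridgeObjective(x: DoubleMatrix, y: DoubleMatrix, w: DoubleMatrix, lambda: Double): Double = {
      val residual = x.mmul(w).sub(y)
      residual.dot(residual) / x.rows + lambda * w.dot(w)
    }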
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.jblas.DoubleMatrix; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.util.LinearDataGenerator; + +public class JavaRidgeRegressionSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaRidgeRegressionSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + System.clearProperty("spark.driver.port"); + } + + double predictionError(List validationData, RidgeRegressionModel model) { + double errorSum = 0; + for (LabeledPoint point: validationData) { + Double prediction = model.predict(point.features()); + errorSum += (prediction - point.label()) * (prediction - point.label()); + } + return errorSum / validationData.size(); + } + + List generateRidgeData(int numPoints, int nfeatures, double eps) { + org.jblas.util.Random.seed(42); + // Pick weights as random values distributed uniformly in [-0.5, 0.5] + DoubleMatrix w = DoubleMatrix.rand(nfeatures, 1).subi(0.5); + // Set first two weights to eps + w.put(0, 0, eps); + w.put(1, 0, eps); + return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, eps); + } + + @Test + public void runRidgeRegressionUsingConstructor() { + int nexamples = 200; + int nfeatures = 20; + double eps = 10.0; + List data = generateRidgeData(2*nexamples, nfeatures, eps); + + JavaRDD testRDD = sc.parallelize(data.subList(0, nexamples)); + List validationData = data.subList(nexamples, 2*nexamples); + + RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD(); + ridgeSGDImpl.optimizer().setStepSize(1.0) + .setRegParam(0.0) + .setNumIterations(200); + RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd()); + double unRegularizedErr = predictionError(validationData, model); + + ridgeSGDImpl.optimizer().setRegParam(0.1); + model = ridgeSGDImpl.run(testRDD.rdd()); + double regularizedErr = predictionError(validationData, model); + + Assert.assertTrue(regularizedErr < unRegularizedErr); + } + + @Test + public void runRidgeRegressionUsingStaticMethods() { + int nexamples = 200; + int nfeatures = 20; + double eps = 10.0; + List data = generateRidgeData(2*nexamples, nfeatures, eps); + + JavaRDD testRDD = sc.parallelize(data.subList(0, nexamples)); + List validationData = data.subList(nexamples, 2*nexamples); + + RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0); + double unRegularizedErr = predictionError(validationData, model); + + model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.1); + double 
regularizedErr = predictionError(validationData, model); + + Assert.assertTrue(regularizedErr < unRegularizedErr); + } +} diff --git a/mllib/src/test/java/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/spark/mllib/classification/JavaLogisticRegressionSuite.java deleted file mode 100644 index e0ebd45cd8..0000000000 --- a/mllib/src/test/java/spark/mllib/classification/JavaLogisticRegressionSuite.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.classification; - -import java.io.Serializable; -import java.util.List; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; - -import spark.mllib.regression.LabeledPoint; - -public class JavaLogisticRegressionSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - System.clearProperty("spark.driver.port"); - } - - int validatePrediction(List validationData, LogisticRegressionModel model) { - int numAccurate = 0; - for (LabeledPoint point: validationData) { - Double prediction = model.predict(point.features()); - if (prediction == point.label()) { - numAccurate++; - } - } - return numAccurate; - } - - @Test - public void runLRUsingConstructor() { - int nPoints = 10000; - double A = 2.0; - double B = -1.5; - - JavaRDD testRDD = sc.parallelize( - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); - List validationData = - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); - - LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD(); - lrImpl.optimizer().setStepSize(1.0) - .setRegParam(1.0) - .setNumIterations(100); - LogisticRegressionModel model = lrImpl.run(testRDD.rdd()); - - int numAccurate = validatePrediction(validationData, model); - Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); - } - - @Test - public void runLRUsingStaticMethods() { - int nPoints = 10000; - double A = 2.0; - double B = -1.5; - - JavaRDD testRDD = sc.parallelize( - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); - List validationData = - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); - - LogisticRegressionModel model = LogisticRegressionWithSGD.train( - testRDD.rdd(), 100, 1.0, 1.0); - - int numAccurate = validatePrediction(validationData, model); - Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); - } - -} diff --git a/mllib/src/test/java/spark/mllib/classification/JavaSVMSuite.java 
b/mllib/src/test/java/spark/mllib/classification/JavaSVMSuite.java deleted file mode 100644 index 7881b3c38f..0000000000 --- a/mllib/src/test/java/spark/mllib/classification/JavaSVMSuite.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.classification; - - -import java.io.Serializable; -import java.util.List; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; - -import spark.mllib.regression.LabeledPoint; - -public class JavaSVMSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaSVMSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - System.clearProperty("spark.driver.port"); - } - - int validatePrediction(List validationData, SVMModel model) { - int numAccurate = 0; - for (LabeledPoint point: validationData) { - Double prediction = model.predict(point.features()); - if (prediction == point.label()) { - numAccurate++; - } - } - return numAccurate; - } - - @Test - public void runSVMUsingConstructor() { - int nPoints = 10000; - double A = 2.0; - double[] weights = {-1.5, 1.0}; - - JavaRDD testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A, - weights, nPoints, 42), 2).cache(); - List validationData = - SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); - - SVMWithSGD svmSGDImpl = new SVMWithSGD(); - svmSGDImpl.optimizer().setStepSize(1.0) - .setRegParam(1.0) - .setNumIterations(100); - SVMModel model = svmSGDImpl.run(testRDD.rdd()); - - int numAccurate = validatePrediction(validationData, model); - Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); - } - - @Test - public void runSVMUsingStaticMethods() { - int nPoints = 10000; - double A = 2.0; - double[] weights = {-1.5, 1.0}; - - JavaRDD testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A, - weights, nPoints, 42), 2).cache(); - List validationData = - SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); - - SVMModel model = SVMWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0, 1.0); - - int numAccurate = validatePrediction(validationData, model); - Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); - } - -} diff --git a/mllib/src/test/java/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/spark/mllib/clustering/JavaKMeansSuite.java deleted file mode 100644 index 3f2d82bfb4..0000000000 --- a/mllib/src/test/java/spark/mllib/clustering/JavaKMeansSuite.java +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.clustering; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; - -public class JavaKMeansSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaKMeans"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - System.clearProperty("spark.driver.port"); - } - - // L1 distance between two points - double distance1(double[] v1, double[] v2) { - double distance = 0.0; - for (int i = 0; i < v1.length; ++i) { - distance = Math.max(distance, Math.abs(v1[i] - v2[i])); - } - return distance; - } - - // Assert that two sets of points are equal, within EPSILON tolerance - void assertSetsEqual(double[][] v1, double[][] v2) { - double EPSILON = 1e-4; - Assert.assertTrue(v1.length == v2.length); - for (int i = 0; i < v1.length; ++i) { - double minDistance = Double.MAX_VALUE; - for (int j = 0; j < v2.length; ++j) { - minDistance = Math.min(minDistance, distance1(v1[i], v2[j])); - } - Assert.assertTrue(minDistance <= EPSILON); - } - - for (int i = 0; i < v2.length; ++i) { - double minDistance = Double.MAX_VALUE; - for (int j = 0; j < v1.length; ++j) { - minDistance = Math.min(minDistance, distance1(v2[i], v1[j])); - } - Assert.assertTrue(minDistance <= EPSILON); - } - } - - - @Test - public void runKMeansUsingStaticMethods() { - List points = new ArrayList(); - points.add(new double[]{1.0, 2.0, 6.0}); - points.add(new double[]{1.0, 3.0, 0.0}); - points.add(new double[]{1.0, 4.0, 6.0}); - - double[][] expectedCenter = { {1.0, 3.0, 4.0} }; - - JavaRDD data = sc.parallelize(points, 2); - KMeansModel model = KMeans.train(data.rdd(), 1, 1); - assertSetsEqual(model.clusterCenters(), expectedCenter); - - model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.RANDOM()); - assertSetsEqual(model.clusterCenters(), expectedCenter); - } - - @Test - public void runKMeansUsingConstructor() { - List points = new ArrayList(); - points.add(new double[]{1.0, 2.0, 6.0}); - points.add(new double[]{1.0, 3.0, 0.0}); - points.add(new double[]{1.0, 4.0, 6.0}); - - double[][] expectedCenter = { {1.0, 3.0, 4.0} }; - - JavaRDD data = sc.parallelize(points, 2); - KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd()); - assertSetsEqual(model.clusterCenters(), expectedCenter); - - model = new KMeans().setK(1) - .setMaxIterations(1) - .setRuns(1) - .setInitializationMode(KMeans.RANDOM()) - .run(data.rdd()); - assertSetsEqual(model.clusterCenters(), expectedCenter); - } -} diff --git a/mllib/src/test/java/spark/mllib/recommendation/JavaALSSuite.java 
b/mllib/src/test/java/spark/mllib/recommendation/JavaALSSuite.java deleted file mode 100644 index 7993629a6d..0000000000 --- a/mllib/src/test/java/spark/mllib/recommendation/JavaALSSuite.java +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.recommendation; - -import java.io.Serializable; -import java.util.List; - -import scala.Tuple2; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; - -import org.jblas.DoubleMatrix; - -public class JavaALSSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaALS"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - System.clearProperty("spark.driver.port"); - } - - void validatePrediction(MatrixFactorizationModel model, int users, int products, int features, - DoubleMatrix trueRatings, double matchThreshold) { - DoubleMatrix predictedU = new DoubleMatrix(users, features); - List> userFeatures = model.userFeatures().toJavaRDD().collect(); - for (int i = 0; i < features; ++i) { - for (scala.Tuple2 userFeature : userFeatures) { - predictedU.put((Integer)userFeature._1(), i, userFeature._2()[i]); - } - } - DoubleMatrix predictedP = new DoubleMatrix(products, features); - - List> productFeatures = - model.productFeatures().toJavaRDD().collect(); - for (int i = 0; i < features; ++i) { - for (scala.Tuple2 productFeature : productFeatures) { - predictedP.put((Integer)productFeature._1(), i, productFeature._2()[i]); - } - } - - DoubleMatrix predictedRatings = predictedU.mmul(predictedP.transpose()); - - for (int u = 0; u < users; ++u) { - for (int p = 0; p < products; ++p) { - double prediction = predictedRatings.get(u, p); - double correct = trueRatings.get(u, p); - Assert.assertTrue(Math.abs(prediction - correct) < matchThreshold); - } - } - } - - @Test - public void runALSUsingStaticMethods() { - int features = 1; - int iterations = 15; - int users = 10; - int products = 10; - scala.Tuple2, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( - users, products, features, 0.7); - - JavaRDD data = sc.parallelize(testData._1()); - MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations); - validatePrediction(model, users, products, features, testData._2(), 0.3); - } - - @Test - public void runALSUsingConstructor() { - int features = 2; - int iterations = 15; - int users = 20; - int products = 30; - scala.Tuple2, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( - users, products, features, 0.7); - - JavaRDD data = sc.parallelize(testData._1()); - - MatrixFactorizationModel model = new 
ALS().setRank(features) - .setIterations(iterations) - .run(data.rdd()); - validatePrediction(model, users, products, features, testData._2(), 0.3); - } -} diff --git a/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java deleted file mode 100644 index 5863140baf..0000000000 --- a/mllib/src/test/java/spark/mllib/regression/JavaLassoSuite.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.regression; - -import java.io.Serializable; -import java.util.List; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.mllib.util.LinearDataGenerator; - -public class JavaLassoSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaLassoSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - System.clearProperty("spark.driver.port"); - } - - int validatePrediction(List validationData, LassoModel model) { - int numAccurate = 0; - for (LabeledPoint point: validationData) { - Double prediction = model.predict(point.features()); - // A prediction is off if the prediction is more than 0.5 away from expected value. 
- if (Math.abs(prediction - point.label()) <= 0.5) { - numAccurate++; - } - } - return numAccurate; - } - - @Test - public void runLassoUsingConstructor() { - int nPoints = 10000; - double A = 2.0; - double[] weights = {-1.5, 1.0e-2}; - - JavaRDD testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, - weights, nPoints, 42, 0.1), 2).cache(); - List validationData = - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); - - LassoWithSGD lassoSGDImpl = new LassoWithSGD(); - lassoSGDImpl.optimizer().setStepSize(1.0) - .setRegParam(0.01) - .setNumIterations(20); - LassoModel model = lassoSGDImpl.run(testRDD.rdd()); - - int numAccurate = validatePrediction(validationData, model); - Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); - } - - @Test - public void runLassoUsingStaticMethods() { - int nPoints = 10000; - double A = 2.0; - double[] weights = {-1.5, 1.0e-2}; - - JavaRDD testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, - weights, nPoints, 42, 0.1), 2).cache(); - List validationData = - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); - - LassoModel model = LassoWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0); - - int numAccurate = validatePrediction(validationData, model); - Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); - } - -} diff --git a/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java deleted file mode 100644 index 50716c7861..0000000000 --- a/mllib/src/test/java/spark/mllib/regression/JavaLinearRegressionSuite.java +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.regression; - -import java.io.Serializable; -import java.util.List; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.mllib.util.LinearDataGenerator; - -public class JavaLinearRegressionSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - System.clearProperty("spark.driver.port"); - } - - int validatePrediction(List validationData, LinearRegressionModel model) { - int numAccurate = 0; - for (LabeledPoint point: validationData) { - Double prediction = model.predict(point.features()); - // A prediction is off if the prediction is more than 0.5 away from expected value. 
- if (Math.abs(prediction - point.label()) <= 0.5) { - numAccurate++; - } - } - return numAccurate; - } - - @Test - public void runLinearRegressionUsingConstructor() { - int nPoints = 100; - double A = 3.0; - double[] weights = {10, 10}; - - JavaRDD testRDD = sc.parallelize( - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); - List validationData = - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); - - LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(); - LinearRegressionModel model = linSGDImpl.run(testRDD.rdd()); - - int numAccurate = validatePrediction(validationData, model); - Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); - } - - @Test - public void runLinearRegressionUsingStaticMethods() { - int nPoints = 100; - double A = 3.0; - double[] weights = {10, 10}; - - JavaRDD testRDD = sc.parallelize( - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); - List validationData = - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); - - LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100); - - int numAccurate = validatePrediction(validationData, model); - Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); - } - -} diff --git a/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java deleted file mode 100644 index 2c0aabad30..0000000000 --- a/mllib/src/test/java/spark/mllib/regression/JavaRidgeRegressionSuite.java +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.mllib.regression; - -import java.io.Serializable; -import java.util.List; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import org.jblas.DoubleMatrix; - -import spark.api.java.JavaRDD; -import spark.api.java.JavaSparkContext; -import spark.mllib.util.LinearDataGenerator; - -public class JavaRidgeRegressionSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaRidgeRegressionSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - System.clearProperty("spark.driver.port"); - } - - double predictionError(List validationData, RidgeRegressionModel model) { - double errorSum = 0; - for (LabeledPoint point: validationData) { - Double prediction = model.predict(point.features()); - errorSum += (prediction - point.label()) * (prediction - point.label()); - } - return errorSum / validationData.size(); - } - - List generateRidgeData(int numPoints, int nfeatures, double eps) { - org.jblas.util.Random.seed(42); - // Pick weights as random values distributed uniformly in [-0.5, 0.5] - DoubleMatrix w = DoubleMatrix.rand(nfeatures, 1).subi(0.5); - // Set first two weights to eps - w.put(0, 0, eps); - w.put(1, 0, eps); - return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, eps); - } - - @Test - public void runRidgeRegressionUsingConstructor() { - int nexamples = 200; - int nfeatures = 20; - double eps = 10.0; - List data = generateRidgeData(2*nexamples, nfeatures, eps); - - JavaRDD testRDD = sc.parallelize(data.subList(0, nexamples)); - List validationData = data.subList(nexamples, 2*nexamples); - - RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD(); - ridgeSGDImpl.optimizer().setStepSize(1.0) - .setRegParam(0.0) - .setNumIterations(200); - RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd()); - double unRegularizedErr = predictionError(validationData, model); - - ridgeSGDImpl.optimizer().setRegParam(0.1); - model = ridgeSGDImpl.run(testRDD.rdd()); - double regularizedErr = predictionError(validationData, model); - - Assert.assertTrue(regularizedErr < unRegularizedErr); - } - - @Test - public void runRidgeRegressionUsingStaticMethods() { - int nexamples = 200; - int nfeatures = 20; - double eps = 10.0; - List data = generateRidgeData(2*nexamples, nfeatures, eps); - - JavaRDD testRDD = sc.parallelize(data.subList(0, nexamples)); - List validationData = data.subList(nexamples, 2*nexamples); - - RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0); - double unRegularizedErr = predictionError(validationData, model); - - model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.1); - double regularizedErr = predictionError(validationData, model); - - Assert.assertTrue(regularizedErr < unRegularizedErr); - } -} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala new file mode 100644 index 0000000000..34c67294e9 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification + +import scala.util.Random +import scala.collection.JavaConversions._ + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.regression._ + +object LogisticRegressionSuite { + + def generateLogisticInputAsList( + offset: Double, + scale: Double, + nPoints: Int, + seed: Int): java.util.List[LabeledPoint] = { + seqAsJavaList(generateLogisticInput(offset, scale, nPoints, seed)) + } + + // Generate input of the form Y = logistic(offset + scale*X) + def generateLogisticInput( + offset: Double, + scale: Double, + nPoints: Int, + seed: Int): Seq[LabeledPoint] = { + val rnd = new Random(seed) + val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian()) + + // NOTE: if U is uniform[0, 1] then ln(u) - ln(1-u) is Logistic(0,1) + val unifRand = new scala.util.Random(45) + val rLogis = (0 until nPoints).map { i => + val u = unifRand.nextDouble() + math.log(u) - math.log(1.0-u) + } + + // y <- A + B*x + rLogis() + // y <- as.numeric(y > 0) + val y: Seq[Int] = (0 until nPoints).map { i => + val yVal = offset + scale * x1(i) + rLogis(i) + if (yVal > 0) 1 else 0 + } + + val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Array(x1(i)))) + testData + } + +} + +class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers { + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + + override def afterAll() { + sc.stop() + System.clearProperty("spark.driver.port") + } + + def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { + val numOffPredictions = predictions.zip(input).filter { case (prediction, expected) => + (prediction != expected.label) + }.size + // At least 83% of the predictions should be on. + ((input.length - numOffPredictions).toDouble / input.length) should be > 0.83 + } + + // Test if we can correctly learn A, B where Y = logistic(A + B*X) + test("logistic regression") { + val nPoints = 10000 + val A = 2.0 + val B = -1.5 + + val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + val lr = new LogisticRegressionWithSGD() + lr.optimizer.setStepSize(10.0).setNumIterations(20) + + val model = lr.run(testRDD) + + // Test the weights + val weight0 = model.weights(0) + assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") + assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") + + val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + // Test prediction on RDD. 
+ validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } + + test("logistic regression with initial weights") { + val nPoints = 10000 + val A = 2.0 + val B = -1.5 + + val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) + + val initialB = -1.0 + val initialWeights = Array(initialB) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + // Use half as many iterations as the previous test. + val lr = new LogisticRegressionWithSGD() + lr.optimizer.setStepSize(10.0).setNumIterations(10) + + val model = lr.run(testRDD, initialWeights) + + val weight0 = model.weights(0) + assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") + assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") + + val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala new file mode 100644 index 0000000000..6a957e3ddc --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.classification + +import scala.util.Random +import scala.math.signum +import scala.collection.JavaConversions._ + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.jblas.DoubleMatrix + +import org.apache.spark.{SparkException, SparkContext} +import org.apache.spark.mllib.regression._ + +object SVMSuite { + + def generateSVMInputAsList( + intercept: Double, + weights: Array[Double], + nPoints: Int, + seed: Int): java.util.List[LabeledPoint] = { + seqAsJavaList(generateSVMInput(intercept, weights, nPoints, seed)) + } + + // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise) + def generateSVMInput( + intercept: Double, + weights: Array[Double], + nPoints: Int, + seed: Int): Seq[LabeledPoint] = { + val rnd = new Random(seed) + val weightsMat = new DoubleMatrix(1, weights.length, weights:_*) + val x = Array.fill[Array[Double]](nPoints)( + Array.fill[Double](weights.length)(rnd.nextDouble() * 2.0 - 1.0)) + val y = x.map { xi => + val yD = (new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) + + intercept + 0.01 * rnd.nextGaussian() + if (yD < 0) 0.0 else 1.0 + } + y.zip(x).map(p => LabeledPoint(p._1, p._2)) + } + +} + +class SVMSuite extends FunSuite with BeforeAndAfterAll { + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + override def afterAll() { + sc.stop() + System.clearProperty("spark.driver.port") + } + + def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { + val numOffPredictions = predictions.zip(input).filter { case (prediction, expected) => + (prediction != expected.label) + }.size + // At least 80% of the predictions should be on. + assert(numOffPredictions < input.length / 5) + } + + + test("SVM using local random SGD") { + val nPoints = 10000 + + // NOTE: Intercept should be small for generating equal 0s and 1s + val A = 0.01 + val B = -1.5 + val C = 1.0 + + val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + val svm = new SVMWithSGD() + svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100) + + val model = svm.run(testRDD) + + val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } + + test("SVM local random SGD with initial weights") { + val nPoints = 10000 + + // NOTE: Intercept should be small for generating equal 0s and 1s + val A = 0.01 + val B = -1.5 + val C = 1.0 + + val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + + val initialB = -1.0 + val initialC = -1.0 + val initialWeights = Array(initialB,initialC) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + val svm = new SVMWithSGD() + svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100) + + val model = svm.run(testRDD, initialWeights) + + val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17) + val validationRDD = sc.parallelize(validationData,2) + + // Test prediction on RDD. 
+ validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } + + test("SVM with invalid labels") { + val nPoints = 10000 + + // NOTE: Intercept should be small for generating equal 0s and 1s + val A = 0.01 + val B = -1.5 + val C = 1.0 + + val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + val testRDD = sc.parallelize(testData, 2) + + val testRDDInvalid = testRDD.map { lp => + if (lp.label == 0.0) { + LabeledPoint(-1.0, lp.features) + } else { + lp + } + } + + intercept[SparkException] { + val model = SVMWithSGD.train(testRDDInvalid, 100) + } + + // Turning off data validation should not throw an exception + val noValidationModel = new SVMWithSGD().setValidateData(false).run(testRDDInvalid) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala new file mode 100644 index 0000000000..94245f6027 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.clustering + +import scala.util.Random + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ + +import org.jblas._ + +class KMeansSuite extends FunSuite with BeforeAndAfterAll { + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + override def afterAll() { + sc.stop() + System.clearProperty("spark.driver.port") + } + + val EPSILON = 1e-4 + + import KMeans.{RANDOM, K_MEANS_PARALLEL} + + def prettyPrint(point: Array[Double]): String = point.mkString("(", ", ", ")") + + def prettyPrint(points: Array[Array[Double]]): String = { + points.map(prettyPrint).mkString("(", "; ", ")") + } + + // L1 distance between two points + def distance1(v1: Array[Double], v2: Array[Double]): Double = { + v1.zip(v2).map{ case (a, b) => math.abs(a-b) }.max + } + + // Assert that two vectors are equal within tolerance EPSILON + def assertEqual(v1: Array[Double], v2: Array[Double]) { + def errorMessage = prettyPrint(v1) + " did not equal " + prettyPrint(v2) + assert(v1.length == v2.length, errorMessage) + assert(distance1(v1, v2) <= EPSILON, errorMessage) + } + + // Assert that two sets of points are equal, within EPSILON tolerance + def assertSetsEqual(set1: Array[Array[Double]], set2: Array[Array[Double]]) { + def errorMessage = prettyPrint(set1) + " did not equal " + prettyPrint(set2) + assert(set1.length == set2.length, errorMessage) + for (v <- set1) { + val closestDistance = set2.map(w => distance1(v, w)).min + if (closestDistance > EPSILON) { + fail(errorMessage) + } + } + for (v <- set2) { + val closestDistance = set1.map(w => distance1(v, w)).min + if (closestDistance > EPSILON) { + fail(errorMessage) + } + } + } + + test("single cluster") { + val data = sc.parallelize(Array( + Array(1.0, 2.0, 6.0), + Array(1.0, 3.0, 0.0), + Array(1.0, 4.0, 6.0) + )) + + // No matter how many runs or iterations we use, we should get one cluster, + // centered at the mean of the points + + var model = KMeans.train(data, k=1, maxIterations=1) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train(data, k=1, maxIterations=2) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train(data, k=1, maxIterations=5) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train(data, k=1, maxIterations=1, runs=5) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train(data, k=1, maxIterations=1, runs=5) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train( + data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + } + + test("single cluster with big dataset") { + val smallData = Array( + Array(1.0, 2.0, 6.0), + Array(1.0, 3.0, 0.0), + Array(1.0, 4.0, 6.0) + ) + val data = sc.parallelize((1 to 100).flatMap(_ => smallData), 4) + + // No matter how many runs or iterations we use, we should get one cluster, + // centered at the mean of the points + + var model = KMeans.train(data, k=1, maxIterations=1) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train(data, k=1, 
maxIterations=2) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train(data, k=1, maxIterations=5) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train(data, k=1, maxIterations=1, runs=5) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train(data, k=1, maxIterations=1, runs=5) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + + model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL) + assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) + } + + test("k-means|| initialization") { + val points = Array( + Array(1.0, 2.0, 6.0), + Array(1.0, 3.0, 0.0), + Array(1.0, 4.0, 6.0), + Array(1.0, 0.0, 1.0), + Array(1.0, 1.0, 1.0) + ) + val rdd = sc.parallelize(points) + + // K-means|| initialization should place all clusters into distinct centers because + // it will make at least five passes, and it will give non-zero probability to each + // unselected point as long as it hasn't yet selected all of them + + var model = KMeans.train(rdd, k=5, maxIterations=1) + assertSetsEqual(model.clusterCenters, points) + + // Iterations of Lloyd's should not change the answer either + model = KMeans.train(rdd, k=5, maxIterations=10) + assertSetsEqual(model.clusterCenters, points) + + // Neither should more runs + model = KMeans.train(rdd, k=5, maxIterations=10, runs=5) + assertSetsEqual(model.clusterCenters, points) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala new file mode 100644 index 0000000000..347ef238f4 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.recommendation + +import scala.collection.JavaConversions._ +import scala.util.Random + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ + +import org.jblas._ + +object ALSSuite { + + def generateRatingsAsJavaList( + users: Int, + products: Int, + features: Int, + samplingRate: Double): (java.util.List[Rating], DoubleMatrix) = { + val (sampledRatings, trueRatings) = generateRatings(users, products, features, samplingRate) + (seqAsJavaList(sampledRatings), trueRatings) + } + + def generateRatings( + users: Int, + products: Int, + features: Int, + samplingRate: Double): (Seq[Rating], DoubleMatrix) = { + val rand = new Random(42) + + // Create a random matrix with uniform values from -1 to 1 + def randomMatrix(m: Int, n: Int) = + new DoubleMatrix(m, n, Array.fill(m * n)(rand.nextDouble() * 2 - 1): _*) + + val userMatrix = randomMatrix(users, features) + val productMatrix = randomMatrix(features, products) + val trueRatings = userMatrix.mmul(productMatrix) + + val sampledRatings = { + for (u <- 0 until users; p <- 0 until products if rand.nextDouble() < samplingRate) + yield Rating(u, p, trueRatings.get(u, p)) + } + + (sampledRatings, trueRatings) + } + +} + + +class ALSSuite extends FunSuite with BeforeAndAfterAll { + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + override def afterAll() { + sc.stop() + System.clearProperty("spark.driver.port") + } + + test("rank-1 matrices") { + testALS(10, 20, 1, 15, 0.7, 0.3) + } + + test("rank-2 matrices") { + testALS(20, 30, 2, 15, 0.7, 0.3) + } + + /** + * Test if we can correctly factorize R = U * P where U and P are of known rank. 
+ * + * @param users number of users + * @param products number of products + * @param features number of features (rank of problem) + * @param iterations number of iterations to run + * @param samplingRate what fraction of the user-product pairs are known + * @param matchThreshold max difference allowed to consider a predicted rating correct + */ + def testALS(users: Int, products: Int, features: Int, iterations: Int, + samplingRate: Double, matchThreshold: Double) + { + val (sampledRatings, trueRatings) = ALSSuite.generateRatings(users, products, + features, samplingRate) + val model = ALS.train(sc.parallelize(sampledRatings), features, iterations) + + val predictedU = new DoubleMatrix(users, features) + for ((u, vec) <- model.userFeatures.collect(); i <- 0 until features) { + predictedU.put(u, i, vec(i)) + } + val predictedP = new DoubleMatrix(products, features) + for ((p, vec) <- model.productFeatures.collect(); i <- 0 until features) { + predictedP.put(p, i, vec(i)) + } + val predictedRatings = predictedU.mmul(predictedP.transpose) + + for (u <- 0 until users; p <- 0 until products) { + val prediction = predictedRatings.get(u, p) + val correct = trueRatings.get(u, p) + if (math.abs(prediction - correct) > matchThreshold) { + fail("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format( + u, p, correct, prediction, trueRatings, predictedRatings, predictedU, predictedP)) + } + } + } +} + diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala new file mode 100644 index 0000000000..db980c7bae --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import scala.collection.JavaConversions._ +import scala.util.Random + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.util.LinearDataGenerator + + +class LassoSuite extends FunSuite with BeforeAndAfterAll { + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + + override def afterAll() { + sc.stop() + System.clearProperty("spark.driver.port") + } + + def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { + val numOffPredictions = predictions.zip(input).filter { case (prediction, expected) => + // A prediction is off if the prediction is more than 0.5 away from expected value. + math.abs(prediction - expected.label) > 0.5 + }.size + // At least 80% of the predictions should be on. 
+ assert(numOffPredictions < input.length / 5) + } + + test("Lasso local random SGD") { + val nPoints = 10000 + + val A = 2.0 + val B = -1.5 + val C = 1.0e-2 + + val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 42) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + val ls = new LassoWithSGD() + ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(20) + + val model = ls.run(testRDD) + + val weight0 = model.weights(0) + val weight1 = model.weights(1) + assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") + assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") + assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]") + + val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } + + test("Lasso local random SGD with initial weights") { + val nPoints = 10000 + + val A = 2.0 + val B = -1.5 + val C = 1.0e-2 + + val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 42) + + val initialB = -1.0 + val initialC = -1.0 + val initialWeights = Array(initialB,initialC) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + val ls = new LassoWithSGD() + ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(20) + + val model = ls.run(testRDD, initialWeights) + + val weight0 = model.weights(0) + val weight1 = model.weights(1) + assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") + assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") + assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]") + + val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17) + val validationRDD = sc.parallelize(validationData,2) + + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala new file mode 100644 index 0000000000..ef500c704c --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.util.LinearDataGenerator + +class LinearRegressionSuite extends FunSuite with BeforeAndAfterAll { + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + override def afterAll() { + sc.stop() + System.clearProperty("spark.driver.port") + } + + def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { + val numOffPredictions = predictions.zip(input).filter { case (prediction, expected) => + // A prediction is off if the prediction is more than 0.5 away from expected value. + math.abs(prediction - expected.label) > 0.5 + }.size + // At least 80% of the predictions should be on. + assert(numOffPredictions < input.length / 5) + } + + // Test if we can correctly learn Y = 3 + 10*X1 + 10*X2 + test("linear regression") { + val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput( + 3.0, Array(10.0, 10.0), 100, 42), 2).cache() + val linReg = new LinearRegressionWithSGD() + linReg.optimizer.setNumIterations(1000).setStepSize(1.0) + + val model = linReg.run(testRDD) + + assert(model.intercept >= 2.5 && model.intercept <= 3.5) + assert(model.weights.length === 2) + assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0) + assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0) + + val validationData = LinearDataGenerator.generateLinearInput( + 3.0, Array(10.0, 10.0), 100, 17) + val validationRDD = sc.parallelize(validationData, 2).cache() + + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala new file mode 100644 index 0000000000..c18092d804 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.regression + +import scala.collection.JavaConversions._ +import scala.util.Random + +import org.jblas.DoubleMatrix +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.util.LinearDataGenerator + +class RidgeRegressionSuite extends FunSuite with BeforeAndAfterAll { + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + override def afterAll() { + sc.stop() + System.clearProperty("spark.driver.port") + } + + def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = { + predictions.zip(input).map { case (prediction, expected) => + (prediction - expected.label) * (prediction - expected.label) + }.reduceLeft(_ + _) / predictions.size + } + + test("regularization with skewed weights") { + val nexamples = 200 + val nfeatures = 20 + val eps = 10 + + org.jblas.util.Random.seed(42) + // Pick weights as random values distributed uniformly in [-0.5, 0.5] + val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5) + // Set first two weights to eps + w.put(0, 0, eps) + w.put(1, 0, eps) + + // Use half of data for training and other half for validation + val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2*nexamples, 42, eps) + val testData = data.take(nexamples) + val validationData = data.takeRight(nexamples) + + val testRDD = sc.parallelize(testData, 2).cache() + val validationRDD = sc.parallelize(validationData, 2).cache() + + // First run without regularization. + val linearReg = new LinearRegressionWithSGD() + linearReg.optimizer.setNumIterations(200) + .setStepSize(1.0) + + val linearModel = linearReg.run(testRDD) + val linearErr = predictionError( + linearModel.predict(validationRDD.map(_.features)).collect(), validationData) + + val ridgeReg = new RidgeRegressionWithSGD() + ridgeReg.optimizer.setNumIterations(200) + .setRegParam(0.1) + .setStepSize(1.0) + val ridgeModel = ridgeReg.run(testRDD) + val ridgeErr = predictionError( + ridgeModel.predict(validationRDD.map(_.features)).collect(), validationData) + + // Ridge CV-error should be lower than linear regression + assert(ridgeErr < linearErr, + "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")") + } +} diff --git a/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala deleted file mode 100644 index bd87c528c3..0000000000 --- a/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.mllib.classification - -import scala.util.Random -import scala.collection.JavaConversions._ - -import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite -import org.scalatest.matchers.ShouldMatchers - -import spark.SparkContext -import spark.mllib.regression._ - -object LogisticRegressionSuite { - - def generateLogisticInputAsList( - offset: Double, - scale: Double, - nPoints: Int, - seed: Int): java.util.List[LabeledPoint] = { - seqAsJavaList(generateLogisticInput(offset, scale, nPoints, seed)) - } - - // Generate input of the form Y = logistic(offset + scale*X) - def generateLogisticInput( - offset: Double, - scale: Double, - nPoints: Int, - seed: Int): Seq[LabeledPoint] = { - val rnd = new Random(seed) - val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian()) - - // NOTE: if U is uniform[0, 1] then ln(u) - ln(1-u) is Logistic(0,1) - val unifRand = new scala.util.Random(45) - val rLogis = (0 until nPoints).map { i => - val u = unifRand.nextDouble() - math.log(u) - math.log(1.0-u) - } - - // y <- A + B*x + rLogis() - // y <- as.numeric(y > 0) - val y: Seq[Int] = (0 until nPoints).map { i => - val yVal = offset + scale * x1(i) + rLogis(i) - if (yVal > 0) 1 else 0 - } - - val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Array(x1(i)))) - testData - } - -} - -class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } - - def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { - val numOffPredictions = predictions.zip(input).filter { case (prediction, expected) => - (prediction != expected.label) - }.size - // At least 83% of the predictions should be on. - ((input.length - numOffPredictions).toDouble / input.length) should be > 0.83 - } - - // Test if we can correctly learn A, B where Y = logistic(A + B*X) - test("logistic regression") { - val nPoints = 10000 - val A = 2.0 - val B = -1.5 - - val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) - - val testRDD = sc.parallelize(testData, 2) - testRDD.cache() - val lr = new LogisticRegressionWithSGD() - lr.optimizer.setStepSize(10.0).setNumIterations(20) - - val model = lr.run(testRDD) - - // Test the weights - val weight0 = model.weights(0) - assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") - assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") - - val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) - val validationRDD = sc.parallelize(validationData, 2) - // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) - - // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) - } - - test("logistic regression with initial weights") { - val nPoints = 10000 - val A = 2.0 - val B = -1.5 - - val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) - - val initialB = -1.0 - val initialWeights = Array(initialB) - - val testRDD = sc.parallelize(testData, 2) - testRDD.cache() - - // Use half as many iterations as the previous test. 
- val lr = new LogisticRegressionWithSGD() - lr.optimizer.setStepSize(10.0).setNumIterations(10) - - val model = lr.run(testRDD, initialWeights) - - val weight0 = model.weights(0) - assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") - assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") - - val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) - val validationRDD = sc.parallelize(validationData, 2) - // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) - - // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) - } -} diff --git a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala deleted file mode 100644 index 894ae458ad..0000000000 --- a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala +++ /dev/null @@ -1,169 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.mllib.classification - -import scala.util.Random -import scala.math.signum -import scala.collection.JavaConversions._ - -import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite - -import spark.SparkContext -import spark.mllib.regression._ - -import org.jblas.DoubleMatrix - -object SVMSuite { - - def generateSVMInputAsList( - intercept: Double, - weights: Array[Double], - nPoints: Int, - seed: Int): java.util.List[LabeledPoint] = { - seqAsJavaList(generateSVMInput(intercept, weights, nPoints, seed)) - } - - // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise) - def generateSVMInput( - intercept: Double, - weights: Array[Double], - nPoints: Int, - seed: Int): Seq[LabeledPoint] = { - val rnd = new Random(seed) - val weightsMat = new DoubleMatrix(1, weights.length, weights:_*) - val x = Array.fill[Array[Double]](nPoints)( - Array.fill[Double](weights.length)(rnd.nextDouble() * 2.0 - 1.0)) - val y = x.map { xi => - val yD = (new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) + - intercept + 0.01 * rnd.nextGaussian() - if (yD < 0) 0.0 else 1.0 - } - y.zip(x).map(p => LabeledPoint(p._1, p._2)) - } - -} - -class SVMSuite extends FunSuite with BeforeAndAfterAll { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } - - def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { - val numOffPredictions = predictions.zip(input).filter { case (prediction, expected) => - (prediction != expected.label) - }.size - // At least 80% of the predictions should be on. - assert(numOffPredictions < input.length / 5) - } - - - test("SVM using local random SGD") { - val nPoints = 10000 - - // NOTE: Intercept should be small for generating equal 0s and 1s - val A = 0.01 - val B = -1.5 - val C = 1.0 - - val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) - - val testRDD = sc.parallelize(testData, 2) - testRDD.cache() - - val svm = new SVMWithSGD() - svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100) - - val model = svm.run(testRDD) - - val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17) - val validationRDD = sc.parallelize(validationData, 2) - - // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) - - // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) - } - - test("SVM local random SGD with initial weights") { - val nPoints = 10000 - - // NOTE: Intercept should be small for generating equal 0s and 1s - val A = 0.01 - val B = -1.5 - val C = 1.0 - - val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) - - val initialB = -1.0 - val initialC = -1.0 - val initialWeights = Array(initialB,initialC) - - val testRDD = sc.parallelize(testData, 2) - testRDD.cache() - - val svm = new SVMWithSGD() - svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100) - - val model = svm.run(testRDD, initialWeights) - - val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17) - val validationRDD = sc.parallelize(validationData,2) - - // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) - - // Test prediction on Array. 
- validatePrediction(validationData.map(row => model.predict(row.features)), validationData) - } - - test("SVM with invalid labels") { - val nPoints = 10000 - - // NOTE: Intercept should be small for generating equal 0s and 1s - val A = 0.01 - val B = -1.5 - val C = 1.0 - - val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) - val testRDD = sc.parallelize(testData, 2) - - val testRDDInvalid = testRDD.map { lp => - if (lp.label == 0.0) { - LabeledPoint(-1.0, lp.features) - } else { - lp - } - } - - intercept[spark.SparkException] { - val model = SVMWithSGD.train(testRDDInvalid, 100) - } - - // Turning off data validation should not throw an exception - val noValidationModel = new SVMWithSGD().setValidateData(false).run(testRDDInvalid) - } -} diff --git a/mllib/src/test/scala/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/spark/mllib/clustering/KMeansSuite.scala deleted file mode 100644 index d5d95c8639..0000000000 --- a/mllib/src/test/scala/spark/mllib/clustering/KMeansSuite.scala +++ /dev/null @@ -1,173 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.mllib.clustering - -import scala.util.Random - -import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite - -import spark.SparkContext -import spark.SparkContext._ - -import org.jblas._ - -class KMeansSuite extends FunSuite with BeforeAndAfterAll { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } - - val EPSILON = 1e-4 - - import KMeans.{RANDOM, K_MEANS_PARALLEL} - - def prettyPrint(point: Array[Double]): String = point.mkString("(", ", ", ")") - - def prettyPrint(points: Array[Array[Double]]): String = { - points.map(prettyPrint).mkString("(", "; ", ")") - } - - // L1 distance between two points - def distance1(v1: Array[Double], v2: Array[Double]): Double = { - v1.zip(v2).map{ case (a, b) => math.abs(a-b) }.max - } - - // Assert that two vectors are equal within tolerance EPSILON - def assertEqual(v1: Array[Double], v2: Array[Double]) { - def errorMessage = prettyPrint(v1) + " did not equal " + prettyPrint(v2) - assert(v1.length == v2.length, errorMessage) - assert(distance1(v1, v2) <= EPSILON, errorMessage) - } - - // Assert that two sets of points are equal, within EPSILON tolerance - def assertSetsEqual(set1: Array[Array[Double]], set2: Array[Array[Double]]) { - def errorMessage = prettyPrint(set1) + " did not equal " + prettyPrint(set2) - assert(set1.length == set2.length, errorMessage) - for (v <- set1) { - val closestDistance = set2.map(w => distance1(v, w)).min - if (closestDistance > EPSILON) { - fail(errorMessage) - } - } - for (v <- set2) { - val closestDistance = set1.map(w => distance1(v, w)).min - if (closestDistance > EPSILON) { - fail(errorMessage) - } - } - } - - test("single cluster") { - val data = sc.parallelize(Array( - Array(1.0, 2.0, 6.0), - Array(1.0, 3.0, 0.0), - Array(1.0, 4.0, 6.0) - )) - - // No matter how many runs or iterations we use, we should get one cluster, - // centered at the mean of the points - - var model = KMeans.train(data, k=1, maxIterations=1) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train(data, k=1, maxIterations=2) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train(data, k=1, maxIterations=5) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train(data, k=1, maxIterations=1, runs=5) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train(data, k=1, maxIterations=1, runs=5) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train( - data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - } - - test("single cluster with big dataset") { - val smallData = Array( - Array(1.0, 2.0, 6.0), - Array(1.0, 3.0, 0.0), - Array(1.0, 4.0, 6.0) - ) - val data = sc.parallelize((1 to 100).flatMap(_ => smallData), 4) - - // No matter how many runs or iterations we use, we should get one cluster, - // centered at the mean of the points - - var model = KMeans.train(data, k=1, maxIterations=1) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train(data, k=1, maxIterations=2) - 
assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train(data, k=1, maxIterations=5) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train(data, k=1, maxIterations=1, runs=5) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train(data, k=1, maxIterations=1, runs=5) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - - model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL) - assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0))) - } - - test("k-means|| initialization") { - val points = Array( - Array(1.0, 2.0, 6.0), - Array(1.0, 3.0, 0.0), - Array(1.0, 4.0, 6.0), - Array(1.0, 0.0, 1.0), - Array(1.0, 1.0, 1.0) - ) - val rdd = sc.parallelize(points) - - // K-means|| initialization should place all clusters into distinct centers because - // it will make at least five passes, and it will give non-zero probability to each - // unselected point as long as it hasn't yet selected all of them - - var model = KMeans.train(rdd, k=5, maxIterations=1) - assertSetsEqual(model.clusterCenters, points) - - // Iterations of Lloyd's should not change the answer either - model = KMeans.train(rdd, k=5, maxIterations=10) - assertSetsEqual(model.clusterCenters, points) - - // Neither should more runs - model = KMeans.train(rdd, k=5, maxIterations=10, runs=5) - assertSetsEqual(model.clusterCenters, points) - } -} diff --git a/mllib/src/test/scala/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/spark/mllib/recommendation/ALSSuite.scala deleted file mode 100644 index 15a60efda6..0000000000 --- a/mllib/src/test/scala/spark/mllib/recommendation/ALSSuite.scala +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.mllib.recommendation - -import scala.collection.JavaConversions._ -import scala.util.Random - -import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite - -import spark.SparkContext -import spark.SparkContext._ - -import org.jblas._ - -object ALSSuite { - - def generateRatingsAsJavaList( - users: Int, - products: Int, - features: Int, - samplingRate: Double): (java.util.List[Rating], DoubleMatrix) = { - val (sampledRatings, trueRatings) = generateRatings(users, products, features, samplingRate) - (seqAsJavaList(sampledRatings), trueRatings) - } - - def generateRatings( - users: Int, - products: Int, - features: Int, - samplingRate: Double): (Seq[Rating], DoubleMatrix) = { - val rand = new Random(42) - - // Create a random matrix with uniform values from -1 to 1 - def randomMatrix(m: Int, n: Int) = - new DoubleMatrix(m, n, Array.fill(m * n)(rand.nextDouble() * 2 - 1): _*) - - val userMatrix = randomMatrix(users, features) - val productMatrix = randomMatrix(features, products) - val trueRatings = userMatrix.mmul(productMatrix) - - val sampledRatings = { - for (u <- 0 until users; p <- 0 until products if rand.nextDouble() < samplingRate) - yield Rating(u, p, trueRatings.get(u, p)) - } - - (sampledRatings, trueRatings) - } - -} - - -class ALSSuite extends FunSuite with BeforeAndAfterAll { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } - - test("rank-1 matrices") { - testALS(10, 20, 1, 15, 0.7, 0.3) - } - - test("rank-2 matrices") { - testALS(20, 30, 2, 15, 0.7, 0.3) - } - - /** - * Test if we can correctly factorize R = U * P where U and P are of known rank. 
- * - * @param users number of users - * @param products number of products - * @param features number of features (rank of problem) - * @param iterations number of iterations to run - * @param samplingRate what fraction of the user-product pairs are known - * @param matchThreshold max difference allowed to consider a predicted rating correct - */ - def testALS(users: Int, products: Int, features: Int, iterations: Int, - samplingRate: Double, matchThreshold: Double) - { - val (sampledRatings, trueRatings) = ALSSuite.generateRatings(users, products, - features, samplingRate) - val model = ALS.train(sc.parallelize(sampledRatings), features, iterations) - - val predictedU = new DoubleMatrix(users, features) - for ((u, vec) <- model.userFeatures.collect(); i <- 0 until features) { - predictedU.put(u, i, vec(i)) - } - val predictedP = new DoubleMatrix(products, features) - for ((p, vec) <- model.productFeatures.collect(); i <- 0 until features) { - predictedP.put(p, i, vec(i)) - } - val predictedRatings = predictedU.mmul(predictedP.transpose) - - for (u <- 0 until users; p <- 0 until products) { - val prediction = predictedRatings.get(u, p) - val correct = trueRatings.get(u, p) - if (math.abs(prediction - correct) > matchThreshold) { - fail("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format( - u, p, correct, prediction, trueRatings, predictedRatings, predictedU, predictedP)) - } - } - } -} - diff --git a/mllib/src/test/scala/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/spark/mllib/regression/LassoSuite.scala deleted file mode 100644 index 622dbbab7f..0000000000 --- a/mllib/src/test/scala/spark/mllib/regression/LassoSuite.scala +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.regression - -import scala.collection.JavaConversions._ -import scala.util.Random - -import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite - -import spark.SparkContext -import spark.mllib.util.LinearDataGenerator - - -class LassoSuite extends FunSuite with BeforeAndAfterAll { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } - - def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { - val numOffPredictions = predictions.zip(input).filter { case (prediction, expected) => - // A prediction is off if the prediction is more than 0.5 away from expected value. - math.abs(prediction - expected.label) > 0.5 - }.size - // At least 80% of the predictions should be on. 
- assert(numOffPredictions < input.length / 5) - } - - test("Lasso local random SGD") { - val nPoints = 10000 - - val A = 2.0 - val B = -1.5 - val C = 1.0e-2 - - val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 42) - - val testRDD = sc.parallelize(testData, 2) - testRDD.cache() - - val ls = new LassoWithSGD() - ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(20) - - val model = ls.run(testRDD) - - val weight0 = model.weights(0) - val weight1 = model.weights(1) - assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") - assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") - assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]") - - val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17) - val validationRDD = sc.parallelize(validationData, 2) - - // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) - - // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) - } - - test("Lasso local random SGD with initial weights") { - val nPoints = 10000 - - val A = 2.0 - val B = -1.5 - val C = 1.0e-2 - - val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 42) - - val initialB = -1.0 - val initialC = -1.0 - val initialWeights = Array(initialB,initialC) - - val testRDD = sc.parallelize(testData, 2) - testRDD.cache() - - val ls = new LassoWithSGD() - ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(20) - - val model = ls.run(testRDD, initialWeights) - - val weight0 = model.weights(0) - val weight1 = model.weights(1) - assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") - assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") - assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]") - - val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17) - val validationRDD = sc.parallelize(validationData,2) - - // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) - - // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) - } -} diff --git a/mllib/src/test/scala/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/regression/LinearRegressionSuite.scala deleted file mode 100644 index acc48a3283..0000000000 --- a/mllib/src/test/scala/spark/mllib/regression/LinearRegressionSuite.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package spark.mllib.regression - -import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite - -import spark.SparkContext -import spark.SparkContext._ -import spark.mllib.util.LinearDataGenerator - -class LinearRegressionSuite extends FunSuite with BeforeAndAfterAll { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } - - def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { - val numOffPredictions = predictions.zip(input).filter { case (prediction, expected) => - // A prediction is off if the prediction is more than 0.5 away from expected value. - math.abs(prediction - expected.label) > 0.5 - }.size - // At least 80% of the predictions should be on. - assert(numOffPredictions < input.length / 5) - } - - // Test if we can correctly learn Y = 3 + 10*X1 + 10*X2 - test("linear regression") { - val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput( - 3.0, Array(10.0, 10.0), 100, 42), 2).cache() - val linReg = new LinearRegressionWithSGD() - linReg.optimizer.setNumIterations(1000).setStepSize(1.0) - - val model = linReg.run(testRDD) - - assert(model.intercept >= 2.5 && model.intercept <= 3.5) - assert(model.weights.length === 2) - assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0) - assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0) - - val validationData = LinearDataGenerator.generateLinearInput( - 3.0, Array(10.0, 10.0), 100, 17) - val validationRDD = sc.parallelize(validationData, 2).cache() - - // Test prediction on RDD. - validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) - - // Test prediction on Array. - validatePrediction(validationData.map(row => model.predict(row.features)), validationData) - } -} diff --git a/mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala deleted file mode 100644 index c482035706..0000000000 --- a/mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package spark.mllib.regression - -import scala.collection.JavaConversions._ -import scala.util.Random - -import org.jblas.DoubleMatrix -import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite - -import spark.SparkContext -import spark.SparkContext._ -import spark.mllib.util.LinearDataGenerator - -class RidgeRegressionSuite extends FunSuite with BeforeAndAfterAll { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } - - def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = { - predictions.zip(input).map { case (prediction, expected) => - (prediction - expected.label) * (prediction - expected.label) - }.reduceLeft(_ + _) / predictions.size - } - - test("regularization with skewed weights") { - val nexamples = 200 - val nfeatures = 20 - val eps = 10 - - org.jblas.util.Random.seed(42) - // Pick weights as random values distributed uniformly in [-0.5, 0.5] - val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5) - // Set first two weights to eps - w.put(0, 0, eps) - w.put(1, 0, eps) - - // Use half of data for training and other half for validation - val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2*nexamples, 42, eps) - val testData = data.take(nexamples) - val validationData = data.takeRight(nexamples) - - val testRDD = sc.parallelize(testData, 2).cache() - val validationRDD = sc.parallelize(validationData, 2).cache() - - // First run without regularization. - val linearReg = new LinearRegressionWithSGD() - linearReg.optimizer.setNumIterations(200) - .setStepSize(1.0) - - val linearModel = linearReg.run(testRDD) - val linearErr = predictionError( - linearModel.predict(validationRDD.map(_.features)).collect(), validationData) - - val ridgeReg = new RidgeRegressionWithSGD() - ridgeReg.optimizer.setNumIterations(200) - .setRegParam(0.1) - .setStepSize(1.0) - val ridgeModel = ridgeReg.run(testRDD) - val ridgeErr = predictionError( - ridgeModel.predict(validationRDD.map(_.features)).collect(), validationData) - - // Ridge CV-error should be lower than linear regression - assert(ridgeErr < linearErr, - "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")") - } -} diff --git a/pom.xml b/pom.xml index e2fd54a966..9230611eae 100644 --- a/pom.xml +++ b/pom.xml @@ -18,22 +18,22 @@ 4.0.0 - org.spark-project + org.apache.spark spark-parent 0.8.0-SNAPSHOT pom Spark Project Parent POM - http://spark-project.org/ + http://spark.incubator.apache.org/ - BSD License - https://github.com/mesos/spark/blob/master/LICENSE + Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0.html repo - scm:git:git@github.com:mesos/spark.git - scm:git:git@github.com:mesos/spark.git + scm:git:git@github.com:apache/incubator-spark.git + scm:git:git@github.com:apache/incubator-spark.git @@ -46,7 +46,7 @@ - github + JIRA https://spark-project.atlassian.net/browse/SPARK diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 2e26812671..18e86d2cae 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -74,7 +74,7 @@ object SparkBuild extends Build { core, repl, examples, bagel, streaming, mllib, tools, assemblyProj) ++ maybeYarnRef def sharedSettings = Defaults.defaultSettings ++ Seq( - organization := "org.spark-project", + organization := "org.apache.spark", version := "0.8.0-SNAPSHOT", scalaVersion := "2.9.3", scalacOptions := Seq("-unchecked", 
"-optimize", "-deprecation"), @@ -103,7 +103,7 @@ object SparkBuild extends Build { //useGpg in Global := true, pomExtra := ( - http://spark-project.org/ + http://spark.incubator.apache.org/ Apache 2.0 License @@ -112,8 +112,8 @@ object SparkBuild extends Build { - scm:git:git@github.com:mesos/spark.git - scm:git:git@github.com:mesos/spark.git + scm:git:git@github.com:apache/incubator-spark.git + scm:git:git@github.com:apache/incubator-spark.git @@ -125,6 +125,10 @@ object SparkBuild extends Build { http://www.cs.berkeley.edu/ + + JIRA + https://spark-project.atlassian.net/browse/SPARK + ), /* diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 2803ce90f3..906e9221a1 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -114,9 +114,9 @@ class SparkContext(object): self.addPyFile(path) # Create a temporary directory inside spark.local.dir: - local_dir = self._jvm.spark.Utils.getLocalDir() + local_dir = self._jvm.org.apache.spark.Utils.getLocalDir() self._temp_dir = \ - self._jvm.spark.Utils.createTempDir(local_dir).getAbsolutePath() + self._jvm.org.apache.spark.Utils.createTempDir(local_dir).getAbsolutePath() @property def defaultParallelism(self): diff --git a/python/pyspark/files.py b/python/pyspark/files.py index 89bcbcfe06..57ee14eeb7 100644 --- a/python/pyspark/files.py +++ b/python/pyspark/files.py @@ -52,4 +52,4 @@ class SparkFiles(object): return cls._root_directory else: # This will have to change if we support multiple SparkContexts: - return cls._sc._jvm.spark.SparkFiles.getRootDirectory() + return cls._sc._jvm.org.apache.spark.SparkFiles.getRootDirectory() diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 3ccf062c86..26fbe0f080 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -53,7 +53,7 @@ def launch_gateway(): # Connect to the gateway gateway = JavaGateway(GatewayClient(port=port), auto_convert=False) # Import the classes used by PySpark - java_import(gateway.jvm, "spark.api.java.*") - java_import(gateway.jvm, "spark.api.python.*") + java_import(gateway.jvm, "org.apache.spark.api.java.*") + java_import(gateway.jvm, "org.apache.spark.api.python.*") java_import(gateway.jvm, "scala.Tuple2") return gateway diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index 919e35f240..6a1b09e8df 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -19,13 +19,13 @@ 4.0.0 - org.spark-project + org.apache.spark spark-parent 0.8.0-SNAPSHOT ../pom.xml - org.spark-project + org.apache.spark spark-repl-bin pom Spark Project REPL binary packaging @@ -39,18 +39,18 @@ - org.spark-project + org.apache.spark spark-core ${project.version} - org.spark-project + org.apache.spark spark-bagel ${project.version} runtime - org.spark-project + org.apache.spark spark-repl ${project.version} runtime @@ -109,7 +109,7 @@ hadoop2-yarn - org.spark-project + org.apache.spark spark-yarn ${project.version} diff --git a/repl/pom.xml b/repl/pom.xml index f800664cff..f6276f1895 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -19,13 +19,13 @@ 4.0.0 - org.spark-project + org.apache.spark spark-parent 0.8.0-SNAPSHOT ../pom.xml - org.spark-project + org.apache.spark spark-repl jar Spark Project REPL @@ -38,18 +38,18 @@ - org.spark-project + org.apache.spark spark-core ${project.version} - org.spark-project + org.apache.spark spark-bagel ${project.version} runtime - org.spark-project + org.apache.spark spark-mllib ${project.version} runtime @@ -136,7 +136,7 @@ hadoop2-yarn - org.spark-project + org.apache.spark 
spark-yarn ${project.version} diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala new file mode 100644 index 0000000000..3e171849e3 --- /dev/null +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.repl + +import java.io.{ByteArrayOutputStream, InputStream} +import java.net.{URI, URL, URLClassLoader, URLEncoder} +import java.util.concurrent.{Executors, ExecutorService} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.objectweb.asm._ +import org.objectweb.asm.Opcodes._ + + +/** + * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI, + * used to load classes defined by the interpreter when the REPL is used + */ +class ExecutorClassLoader(classUri: String, parent: ClassLoader) +extends ClassLoader(parent) { + val uri = new URI(classUri) + val directory = uri.getPath + + // Hadoop FileSystem object for our URI, if it isn't using HTTP + var fileSystem: FileSystem = { + if (uri.getScheme() == "http") + null + else + FileSystem.get(uri, new Configuration()) + } + + override def findClass(name: String): Class[_] = { + try { + val pathInDirectory = name.replace('.', '/') + ".class" + val inputStream = { + if (fileSystem != null) + fileSystem.open(new Path(directory, pathInDirectory)) + else + new URL(classUri + "/" + urlEncode(pathInDirectory)).openStream() + } + val bytes = readAndTransformClass(name, inputStream) + inputStream.close() + return defineClass(name, bytes, 0, bytes.length) + } catch { + case e: Exception => throw new ClassNotFoundException(name, e) + } + } + + def readAndTransformClass(name: String, in: InputStream): Array[Byte] = { + if (name.startsWith("line") && name.endsWith("$iw$")) { + // Class seems to be an interpreter "wrapper" object storing a val or var. + // Replace its constructor with a dummy one that does not run the + // initialization code placed there by the REPL. The val or var will + // be initialized later through reflection when it is used in a task. 
+ val cr = new ClassReader(in) + val cw = new ClassWriter( + ClassWriter.COMPUTE_FRAMES + ClassWriter.COMPUTE_MAXS) + val cleaner = new ConstructorCleaner(name, cw) + cr.accept(cleaner, 0) + return cw.toByteArray + } else { + // Pass the class through unmodified + val bos = new ByteArrayOutputStream + val bytes = new Array[Byte](4096) + var done = false + while (!done) { + val num = in.read(bytes) + if (num >= 0) + bos.write(bytes, 0, num) + else + done = true + } + return bos.toByteArray + } + } + + /** + * URL-encode a string, preserving only slashes + */ + def urlEncode(str: String): String = { + str.split('/').map(part => URLEncoder.encode(part, "UTF-8")).mkString("/") + } +} + +class ConstructorCleaner(className: String, cv: ClassVisitor) +extends ClassVisitor(ASM4, cv) { + override def visitMethod(access: Int, name: String, desc: String, + sig: String, exceptions: Array[String]): MethodVisitor = { + val mv = cv.visitMethod(access, name, desc, sig, exceptions) + if (name == "" && (access & ACC_STATIC) == 0) { + // This is the constructor, time to clean it; just output some new + // instructions to mv that create the object and set the static MODULE$ + // field in the class to point to it, but do nothing otherwise. + mv.visitCode() + mv.visitVarInsn(ALOAD, 0) // load this + mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "", "()V") + mv.visitVarInsn(ALOAD, 0) // load this + //val classType = className.replace('.', '/') + //mv.visitFieldInsn(PUTSTATIC, classType, "MODULE$", "L" + classType + ";") + mv.visitInsn(RETURN) + mv.visitMaxs(-1, -1) // stack size and local vars will be auto-computed + mv.visitEnd() + return null + } else { + return mv + } + } +} diff --git a/repl/src/main/scala/org/apache/spark/repl/Main.scala b/repl/src/main/scala/org/apache/spark/repl/Main.scala new file mode 100644 index 0000000000..17e149f8ab --- /dev/null +++ b/repl/src/main/scala/org/apache/spark/repl/Main.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
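[Aside, not part of the patch] The new ExecutorClassLoader above fetches interpreter-generated classes from either a Hadoop FileSystem or an HTTP class server. A minimal usage sketch, assuming the driver has set spark.repl.class.uri (SparkIMain does this later in this patch) and that class names follow the REPL's $lineN.$read... scheme shown further below; the property lookup and the example class name are illustrative only.

    // Illustrative only: load a REPL-defined class on an executor by name.
    val classUri = System.getProperty("spark.repl.class.uri")   // set by SparkIMain on the driver
    val loader = new org.apache.spark.repl.ExecutorClassLoader(
      classUri, Thread.currentThread.getContextClassLoader)
    // hypothetical name of a class the user defined at the shell prompt
    val bippy = loader.loadClass("$line19.$read$$iw$$iw$Bippy")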
+ */ + +package org.apache.spark.repl + +import scala.collection.mutable.Set + +object Main { + private var _interp: SparkILoop = null + + def interp = _interp + + def interp_=(i: SparkILoop) { _interp = i } + + def main(args: Array[String]) { + _interp = new SparkILoop + _interp.process(args) + } +} diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkHelper.scala b/repl/src/main/scala/org/apache/spark/repl/SparkHelper.scala new file mode 100644 index 0000000000..d8fb7191b4 --- /dev/null +++ b/repl/src/main/scala/org/apache/spark/repl/SparkHelper.scala @@ -0,0 +1,5 @@ +package scala.tools.nsc + +object SparkHelper { + def explicitParentLoader(settings: Settings) = settings.explicitParentLoader +} diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala new file mode 100644 index 0000000000..193ccb48ee --- /dev/null +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -0,0 +1,1008 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2011 LAMP/EPFL + * @author Alexander Spoon + */ + +package org.apache.spark.repl + +import scala.tools.nsc._ +import scala.tools.nsc.interpreter._ + +import Predef.{ println => _, _ } +import java.io.{ BufferedReader, FileReader, PrintWriter } +import scala.sys.process.Process +import session._ +import scala.tools.nsc.interpreter.{ Results => IR } +import scala.tools.util.{ SignalManager, Signallable, Javap } +import scala.annotation.tailrec +import scala.util.control.Exception.{ ignoring } +import scala.collection.mutable.ListBuffer +import scala.concurrent.ops +import util.{ ClassPath, Exceptional, stringFromWriter, stringFromStream } +import interpreter._ +import io.{ File, Sources } + +import org.apache.spark.Logging +import org.apache.spark.SparkContext + +/** The Scala interactive shell. It provides a read-eval-print loop + * around the Interpreter class. + * After instantiation, clients should call the main() method. + * + * If no in0 is specified, then input will come from the console, and + * the class will attempt to provide input editing feature such as + * input history. + * + * @author Moez A. Abdel-Gawad + * @author Lex Spoon + * @version 1.2 + */ +class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: Option[String]) + extends AnyRef + with LoopCommands + with Logging +{ + def this(in0: BufferedReader, out: PrintWriter, master: String) = this(Some(in0), out, Some(master)) + def this(in0: BufferedReader, out: PrintWriter) = this(Some(in0), out, None) + def this() = this(None, new PrintWriter(Console.out, true), None) + + var in: InteractiveReader = _ // the input stream from which commands come + var settings: Settings = _ + var intp: SparkIMain = _ + + /* + lazy val power = { + val g = intp.global + Power[g.type](this, g) + } + */ + + // TODO + // object opt extends AestheticSettings + // + @deprecated("Use `intp` instead.", "2.9.0") + def interpreter = intp + + @deprecated("Use `intp` instead.", "2.9.0") + def interpreter_= (i: SparkIMain): Unit = intp = i + + def history = in.history + + /** The context class loader at the time this object was created */ + protected val originalClassLoader = Thread.currentThread.getContextClassLoader + + // Install a signal handler so we can be prodded. 
+ private val signallable = + /*if (isReplDebug) Signallable("Dump repl state.")(dumpCommand()) + else*/ null + + // classpath entries added via :cp + var addedClasspath: String = "" + + /** A reverse list of commands to replay if the user requests a :replay */ + var replayCommandStack: List[String] = Nil + + /** A list of commands to replay if the user requests a :replay */ + def replayCommands = replayCommandStack.reverse + + /** Record a command for replay should the user request a :replay */ + def addReplay(cmd: String) = replayCommandStack ::= cmd + + /** Try to install sigint handler: ignore failure. Signal handler + * will interrupt current line execution if any is in progress. + * + * Attempting to protect the repl from accidental exit, we only honor + * a single ctrl-C if the current buffer is empty: otherwise we look + * for a second one within a short time. + */ + private def installSigIntHandler() { + def onExit() { + Console.println("") // avoiding "shell prompt in middle of line" syndrome + sys.exit(1) + } + ignoring(classOf[Exception]) { + SignalManager("INT") = { + if (intp == null) + onExit() + else if (intp.lineManager.running) + intp.lineManager.cancel() + else if (in.currentLine != "") { + // non-empty buffer, so make them hit ctrl-C a second time + SignalManager("INT") = onExit() + io.timer(5)(installSigIntHandler()) // and restore original handler if they don't + } + else onExit() + } + } + } + + /** Close the interpreter and set the var to null. */ + def closeInterpreter() { + if (intp ne null) { + intp.close + intp = null + Thread.currentThread.setContextClassLoader(originalClassLoader) + } + } + + class SparkILoopInterpreter extends SparkIMain(settings, out) { + override lazy val formatting = new Formatting { + def prompt = SparkILoop.this.prompt + } + override protected def createLineManager() = new Line.Manager { + override def onRunaway(line: Line[_]): Unit = { + val template = """ + |// She's gone rogue, captain! Have to take her out! + |// Calling Thread.stop on runaway %s with offending code: + |// scala> %s""".stripMargin + + echo(template.format(line.thread, line.code)) + // XXX no way to suppress the deprecation warning + line.thread.stop() + in.redrawLine() + } + } + override protected def parentClassLoader = { + SparkHelper.explicitParentLoader(settings).getOrElse( classOf[SparkILoop].getClassLoader ) + } + } + + /** Create a new interpreter. */ + def createInterpreter() { + if (addedClasspath != "") + settings.classpath append addedClasspath + + intp = new SparkILoopInterpreter + intp.setContextClassLoader() + installSigIntHandler() + } + + /** print a friendly help message */ + def helpCommand(line: String): Result = { + if (line == "") helpSummary() + else uniqueCommand(line) match { + case Some(lc) => echo("\n" + lc.longHelp) + case _ => ambiguousError(line) + } + } + private def helpSummary() = { + val usageWidth = commands map (_.usageMsg.length) max + val formatStr = "%-" + usageWidth + "s %s %s" + + echo("All commands can be abbreviated, e.g. :he instead of :help.") + echo("Those marked with a * have more detailed help, e.g. :help imports.\n") + + commands foreach { cmd => + val star = if (cmd.hasLongHelp) "*" else " " + echo(formatStr.format(cmd.usageMsg, star, cmd.help)) + } + } + private def ambiguousError(cmd: String): Result = { + matchingCommands(cmd) match { + case Nil => echo(cmd + ": no such command. 
Type :help for help.") + case xs => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?") + } + Result(true, None) + } + private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd) + private def uniqueCommand(cmd: String): Option[LoopCommand] = { + // this lets us add commands willy-nilly and only requires enough command to disambiguate + matchingCommands(cmd) match { + case List(x) => Some(x) + // exact match OK even if otherwise appears ambiguous + case xs => xs find (_.name == cmd) + } + } + + /** Print a welcome message */ + def printWelcome() { + echo("""Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ version 0.8.0 + /_/ +""") + import Properties._ + val welcomeMsg = "Using Scala %s (%s, Java %s)".format( + versionString, javaVmName, javaVersion) + echo(welcomeMsg) + } + + /** Show the history */ + lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { + override def usage = "[num]" + def defaultLines = 20 + + def apply(line: String): Result = { + if (history eq NoHistory) + return "No history available." + + val xs = words(line) + val current = history.index + val count = try xs.head.toInt catch { case _: Exception => defaultLines } + val lines = history.asStrings takeRight count + val offset = current - lines.size + 1 + + for ((line, index) <- lines.zipWithIndex) + echo("%3d %s".format(index + offset, line)) + } + } + + private def echo(msg: String) = { + out println msg + out.flush() + } + private def echoNoNL(msg: String) = { + out print msg + out.flush() + } + + /** Search the history */ + def searchHistory(_cmdline: String) { + val cmdline = _cmdline.toLowerCase + val offset = history.index - history.size + 1 + + for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline) + echo("%d %s".format(index + offset, line)) + } + + private var currentPrompt = Properties.shellPromptString + def setPrompt(prompt: String) = currentPrompt = prompt + /** Prompt to print when awaiting input */ + def prompt = currentPrompt + + import LoopCommand.{ cmd, nullary } + + /** Standard commands **/ + lazy val standardCommands = List( + cmd("cp", "", "add a jar or directory to the classpath", addClasspath), + cmd("help", "[command]", "print this summary or command-specific help", helpCommand), + historyCommand, + cmd("h?", "", "search the history", searchHistory), + cmd("imports", "[name name ...]", "show import history, identifying sources of names", importsCommand), + cmd("implicits", "[-v]", "show the implicits in scope", implicitsCommand), + cmd("javap", "", "disassemble a file or class name", javapCommand), + nullary("keybindings", "show how ctrl-[A-Z] and other keys are bound", keybindingsCommand), + cmd("load", "", "load and interpret a Scala file", loadCommand), + nullary("paste", "enter paste mode: all input up to ctrl-D compiled together", pasteCommand), + //nullary("power", "enable power user mode", powerCmd), + nullary("quit", "exit the interpreter", () => Result(false, None)), + nullary("replay", "reset execution and replay all previous commands", replay), + shCommand, + nullary("silent", "disable/enable automatic printing of results", verbosity), + cmd("type", "", "display the type of an expression without evaluating it", typeCommand) + ) + + /** Power user commands */ + lazy val powerCommands: List[LoopCommand] = List( + //nullary("dump", "displays a view of the interpreter's internal state", 
dumpCommand), + //cmd("phase", "", "set the implicit phase for power commands", phaseCommand), + cmd("wrap", "", "name of method to wrap around each repl line", wrapCommand) withLongHelp (""" + |:wrap + |:wrap clear + |:wrap + | + |Installs a wrapper around each line entered into the repl. + |Currently it must be the simple name of an existing method + |with the specific signature shown in the following example. + | + |def timed[T](body: => T): T = { + | val start = System.nanoTime + | try body + | finally println((System.nanoTime - start) + " nanos elapsed.") + |} + |:wrap timed + | + |If given no argument, :wrap names the wrapper installed. + |An argument of clear will remove the wrapper if any is active. + |Note that wrappers do not compose (a new one replaces the old + |one) and also that the :phase command uses the same machinery, + |so setting :wrap will clear any :phase setting. + """.stripMargin.trim) + ) + + /* + private def dumpCommand(): Result = { + echo("" + power) + history.asStrings takeRight 30 foreach echo + in.redrawLine() + } + */ + + private val typeTransforms = List( + "scala.collection.immutable." -> "immutable.", + "scala.collection.mutable." -> "mutable.", + "scala.collection.generic." -> "generic.", + "java.lang." -> "jl.", + "scala.runtime." -> "runtime." + ) + + private def importsCommand(line: String): Result = { + val tokens = words(line) + val handlers = intp.languageWildcardHandlers ++ intp.importHandlers + val isVerbose = tokens contains "-v" + + handlers.filterNot(_.importedSymbols.isEmpty).zipWithIndex foreach { + case (handler, idx) => + val (types, terms) = handler.importedSymbols partition (_.name.isTypeName) + val imps = handler.implicitSymbols + val found = tokens filter (handler importsSymbolNamed _) + val typeMsg = if (types.isEmpty) "" else types.size + " types" + val termMsg = if (terms.isEmpty) "" else terms.size + " terms" + val implicitMsg = if (imps.isEmpty) "" else imps.size + " are implicit" + val foundMsg = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "") + val statsMsg = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")") + + intp.reporter.printMessage("%2d) %-30s %s%s".format( + idx + 1, + handler.importString, + statsMsg, + foundMsg + )) + } + } + + private def implicitsCommand(line: String): Result = { + val intp = SparkILoop.this.intp + import intp._ + import global.Symbol + + def p(x: Any) = intp.reporter.printMessage("" + x) + + // If an argument is given, only show a source with that + // in its name somewhere. + val args = line split "\\s+" + val filtered = intp.implicitSymbolsBySource filter { + case (source, syms) => + (args contains "-v") || { + if (line == "") (source.fullName.toString != "scala.Predef") + else (args exists (source.name.toString contains _)) + } + } + + if (filtered.isEmpty) + return "No implicits have been imported other than those in Predef." + + filtered foreach { + case (source, syms) => + p("/* " + syms.size + " implicit members imported from " + source.fullName + " */") + + // This groups the members by where the symbol is defined + val byOwner = syms groupBy (_.owner) + val sortedOwners = byOwner.toList sortBy { case (owner, _) => intp.afterTyper(source.info.baseClasses indexOf owner) } + + sortedOwners foreach { + case (owner, members) => + // Within each owner, we cluster results based on the final result type + // if there are more than a couple, and sort each cluster based on name. 
+ // This is really just trying to make the 100 or so implicits imported + // by default into something readable. + val memberGroups: List[List[Symbol]] = { + val groups = members groupBy (_.tpe.finalResultType) toList + val (big, small) = groups partition (_._2.size > 3) + val xss = ( + (big sortBy (_._1.toString) map (_._2)) :+ + (small flatMap (_._2)) + ) + + xss map (xs => xs sortBy (_.name.toString)) + } + + val ownerMessage = if (owner == source) " defined in " else " inherited from " + p(" /* " + members.size + ownerMessage + owner.fullName + " */") + + memberGroups foreach { group => + group foreach (s => p(" " + intp.symbolDefString(s))) + p("") + } + } + p("") + } + } + + protected def newJavap() = new Javap(intp.classLoader, new SparkIMain.ReplStrippingWriter(intp)) { + override def tryClass(path: String): Array[Byte] = { + // Look for Foo first, then Foo$, but if Foo$ is given explicitly, + // we have to drop the $ to find object Foo, then tack it back onto + // the end of the flattened name. + def className = intp flatName path + def moduleName = (intp flatName path.stripSuffix("$")) + "$" + + val bytes = super.tryClass(className) + if (bytes.nonEmpty) bytes + else super.tryClass(moduleName) + } + } + private lazy val javap = + try newJavap() + catch { case _: Exception => null } + + private def typeCommand(line: String): Result = { + intp.typeOfExpression(line) match { + case Some(tp) => tp.toString + case _ => "Failed to determine type." + } + } + + private def javapCommand(line: String): Result = { + if (javap == null) + return ":javap unavailable on this platform." + if (line == "") + return ":javap [-lcsvp] [path1 path2 ...]" + + javap(words(line)) foreach { res => + if (res.isError) return "Failed: " + res.value + else res.show() + } + } + private def keybindingsCommand(): Result = { + if (in.keyBindings.isEmpty) "Key bindings unavailable." + else { + echo("Reading jline properties for default key bindings.") + echo("Accuracy not guaranteed: treat this as a guideline only.\n") + in.keyBindings foreach (x => echo ("" + x)) + } + } + private def wrapCommand(line: String): Result = { + def failMsg = "Argument to :wrap must be the name of a method with signature [T](=> T): T" + val intp = SparkILoop.this.intp + val g: intp.global.type = intp.global + import g._ + + words(line) match { + case Nil => + intp.executionWrapper match { + case "" => "No execution wrapper is set." + case s => "Current execution wrapper: " + s + } + case "clear" :: Nil => + intp.executionWrapper match { + case "" => "No execution wrapper is set." + case s => intp.clearExecutionWrapper() ; "Cleared execution wrapper." + } + case wrapper :: Nil => + intp.typeOfExpression(wrapper) match { + case Some(PolyType(List(targ), MethodType(List(arg), restpe))) => + intp setExecutionWrapper intp.pathToTerm(wrapper) + "Set wrapper to '" + wrapper + "'" + case Some(x) => + failMsg + "\nFound: " + x + case _ => + failMsg + "\nFound: " + } + case _ => failMsg + } + } + + private def pathToPhaseWrapper = intp.pathToTerm("$r") + ".phased.atCurrent" + /* + private def phaseCommand(name: String): Result = { + // This line crashes us in TreeGen: + // + // if (intp.power.phased set name) "..." 
+ // + // Exception in thread "main" java.lang.AssertionError: assertion failed: ._7.type + // at scala.Predef$.assert(Predef.scala:99) + // at scala.tools.nsc.ast.TreeGen.mkAttributedQualifier(TreeGen.scala:69) + // at scala.tools.nsc.ast.TreeGen.mkAttributedQualifier(TreeGen.scala:44) + // at scala.tools.nsc.ast.TreeGen.mkAttributedRef(TreeGen.scala:101) + // at scala.tools.nsc.ast.TreeGen.mkAttributedStableRef(TreeGen.scala:143) + // + // But it works like so, type annotated. + val phased: Phased = power.phased + import phased.NoPhaseName + + if (name == "clear") { + phased.set(NoPhaseName) + intp.clearExecutionWrapper() + "Cleared active phase." + } + else if (name == "") phased.get match { + case NoPhaseName => "Usage: :phase (e.g. typer, erasure.next, erasure+3)" + case ph => "Active phase is '%s'. (To clear, :phase clear)".format(phased.get) + } + else { + val what = phased.parse(name) + if (what.isEmpty || !phased.set(what)) + "'" + name + "' does not appear to represent a valid phase." + else { + intp.setExecutionWrapper(pathToPhaseWrapper) + val activeMessage = + if (what.toString.length == name.length) "" + what + else "%s (%s)".format(what, name) + + "Active phase is now: " + activeMessage + } + } + } + */ + + /** Available commands */ + def commands: List[LoopCommand] = standardCommands /* ++ ( + if (isReplPower) powerCommands else Nil + )*/ + + val replayQuestionMessage = + """|The repl compiler has crashed spectacularly. Shall I replay your + |session? I can re-run all lines except the last one. + |[y/n] + """.trim.stripMargin + + private val crashRecovery: PartialFunction[Throwable, Unit] = { + case ex: Throwable => + if (settings.YrichExes.value) { + val sources = implicitly[Sources] + echo("\n" + ex.getMessage) + echo( + if (isReplDebug) "[searching " + sources.path + " for exception contexts...]" + else "[searching for exception contexts...]" + ) + echo(Exceptional(ex).force().context()) + } + else { + echo(util.stackTraceString(ex)) + } + ex match { + case _: NoSuchMethodError | _: NoClassDefFoundError => + echo("Unrecoverable error.") + throw ex + case _ => + def fn(): Boolean = in.readYesOrNo(replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() }) + if (fn()) replay() + else echo("\nAbandoning crashed session.") + } + } + + /** The main read-eval-print loop for the repl. It calls + * command() for each line of input, and stops when + * command() returns false. 
+ */ + def loop() { + def readOneLine() = { + out.flush() + in readLine prompt + } + // return false if repl should exit + def processLine(line: String): Boolean = + if (line eq null) false // assume null means EOF + else command(line) match { + case Result(false, _) => false + case Result(_, Some(finalLine)) => addReplay(finalLine) ; true + case _ => true + } + + while (true) { + try if (!processLine(readOneLine)) return + catch crashRecovery + } + } + + /** interpret all lines from a specified file */ + def interpretAllFrom(file: File) { + val oldIn = in + val oldReplay = replayCommandStack + + try file applyReader { reader => + in = SimpleReader(reader, out, false) + echo("Loading " + file + "...") + loop() + } + finally { + in = oldIn + replayCommandStack = oldReplay + } + } + + /** create a new interpreter and replay all commands so far */ + def replay() { + closeInterpreter() + createInterpreter() + for (cmd <- replayCommands) { + echo("Replaying: " + cmd) // flush because maybe cmd will have its own output + command(cmd) + echo("") + } + } + + /** fork a shell and run a command */ + lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") { + override def usage = "" + def apply(line: String): Result = line match { + case "" => showUsage() + case _ => + val toRun = classOf[ProcessResult].getName + "(" + string2codeQuoted(line) + ")" + intp interpret toRun + () + } + } + + def withFile(filename: String)(action: File => Unit) { + val f = File(filename) + + if (f.exists) action(f) + else echo("That file does not exist") + } + + def loadCommand(arg: String) = { + var shouldReplay: Option[String] = None + withFile(arg)(f => { + interpretAllFrom(f) + shouldReplay = Some(":load " + arg) + }) + Result(true, shouldReplay) + } + + def addClasspath(arg: String): Unit = { + val f = File(arg).normalize + if (f.exists) { + addedClasspath = ClassPath.join(addedClasspath, f.path) + val totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath) + echo("Added '%s'. Your new classpath is:\n\"%s\"".format(f.path, totalClasspath)) + replay() + } + else echo("The path '" + f + "' doesn't seem to exist.") + } + + def powerCmd(): Result = { + if (isReplPower) "Already in power mode." + else enablePowerMode() + } + def enablePowerMode() = { + //replProps.power setValue true + //power.unleash() + //echo(power.banner) + } + + def verbosity() = { + val old = intp.printResults + intp.printResults = !old + echo("Switched " + (if (old) "off" else "on") + " result printing.") + } + + /** Run one command submitted by the user. Two values are returned: + * (1) whether to keep running, (2) the line to record for replay, + * if any. 
*/ + def command(line: String): Result = { + if (line startsWith ":") { + val cmd = line.tail takeWhile (x => !x.isWhitespace) + uniqueCommand(cmd) match { + case Some(lc) => lc(line.tail stripPrefix cmd dropWhile (_.isWhitespace)) + case _ => ambiguousError(cmd) + } + } + else if (intp.global == null) Result(false, None) // Notice failure to create compiler + else Result(true, interpretStartingWith(line)) + } + + private def readWhile(cond: String => Boolean) = { + Iterator continually in.readLine("") takeWhile (x => x != null && cond(x)) + } + + def pasteCommand(): Result = { + echo("// Entering paste mode (ctrl-D to finish)\n") + val code = readWhile(_ => true) mkString "\n" + echo("\n// Exiting paste mode, now interpreting.\n") + intp interpret code + () + } + + private object paste extends Pasted { + val ContinueString = " | " + val PromptString = "scala> " + + def interpret(line: String): Unit = { + echo(line.trim) + intp interpret line + echo("") + } + + def transcript(start: String) = { + // Printing this message doesn't work very well because it's buried in the + // transcript they just pasted. Todo: a short timer goes off when + // lines stop coming which tells them to hit ctrl-D. + // + // echo("// Detected repl transcript paste: ctrl-D to finish.") + apply(Iterator(start) ++ readWhile(_.trim != PromptString.trim)) + } + } + import paste.{ ContinueString, PromptString } + + /** Interpret expressions starting with the first line. + * Read lines until a complete compilation unit is available + * or until a syntax error has been seen. If a full unit is + * read, go ahead and interpret it. Return the full string + * to be recorded for replay, if any. + */ + def interpretStartingWith(code: String): Option[String] = { + // signal completion non-completion input has been received + in.completion.resetVerbosity() + + def reallyInterpret = { + val reallyResult = intp.interpret(code) + (reallyResult, reallyResult match { + case IR.Error => None + case IR.Success => Some(code) + case IR.Incomplete => + if (in.interactive && code.endsWith("\n\n")) { + echo("You typed two blank lines. Starting a new command.") + None + } + else in.readLine(ContinueString) match { + case null => + // we know compilation is going to fail since we're at EOF and the + // parser thinks the input is still incomplete, but since this is + // a file being read non-interactively we want to fail. So we send + // it straight to the compiler for the nice error message. + intp.compileString(code) + None + + case line => interpretStartingWith(code + "\n" + line) + } + }) + } + + /** Here we place ourselves between the user and the interpreter and examine + * the input they are ostensibly submitting. We intervene in several cases: + * + * 1) If the line starts with "scala> " it is assumed to be an interpreter paste. + * 2) If the line starts with "." (but not ".." or "./") it is treated as an invocation + * on the previous result. + * 3) If the Completion object's execute returns Some(_), we inject that value + * and avoid the interpreter, as it's likely not valid scala code. 
+ */ + if (code == "") None + else if (!paste.running && code.trim.startsWith(PromptString)) { + paste.transcript(code) + None + } + else if (Completion.looksLikeInvocation(code) && intp.mostRecentVar != "") { + interpretStartingWith(intp.mostRecentVar + code) + } + else { + def runCompletion = in.completion execute code map (intp bindValue _) + /** Due to my accidentally letting file completion execution sneak ahead + * of actual parsing this now operates in such a way that the scala + * interpretation always wins. However to avoid losing useful file + * completion I let it fail and then check the others. So if you + * type /tmp it will echo a failure and then give you a Directory object. + * It's not pretty: maybe I'll implement the silence bits I need to avoid + * echoing the failure. + */ + if (intp isParseable code) { + val (code, result) = reallyInterpret + //if (power != null && code == IR.Error) + // runCompletion + + result + } + else runCompletion match { + case Some(_) => None // completion hit: avoid the latent error + case _ => reallyInterpret._2 // trigger the latent error + } + } + } + + // runs :load `file` on any files passed via -i + def loadFiles(settings: Settings) = settings match { + case settings: GenericRunnerSettings => + for (filename <- settings.loadfiles.value) { + val cmd = ":load " + filename + command(cmd) + addReplay(cmd) + echo("") + } + case _ => + } + + /** Tries to create a JLineReader, falling back to SimpleReader: + * unless settings or properties are such that it should start + * with SimpleReader. + */ + def chooseReader(settings: Settings): InteractiveReader = { + if (settings.Xnojline.value || Properties.isEmacsShell) + SimpleReader() + else try SparkJLineReader( + if (settings.noCompletion.value) NoCompletion + else new SparkJLineCompletion(intp) + ) + catch { + case ex @ (_: Exception | _: NoClassDefFoundError) => + echo("Failed to created SparkJLineReader: " + ex + "\nFalling back to SimpleReader.") + SimpleReader() + } + } + + def initializeSpark() { + intp.beQuietDuring { + command(""" + org.apache.spark.repl.Main.interp.out.println("Creating SparkContext..."); + org.apache.spark.repl.Main.interp.out.flush(); + @transient val sc = org.apache.spark.repl.Main.interp.createSparkContext(); + org.apache.spark.repl.Main.interp.out.println("Spark context available as sc."); + org.apache.spark.repl.Main.interp.out.flush(); + """) + command("import org.apache.spark.SparkContext._") + } + echo("Type in expressions to have them evaluated.") + echo("Type :help for more information.") + } + + var sparkContext: SparkContext = null + + def createSparkContext(): SparkContext = { + val uri = System.getenv("SPARK_EXECUTOR_URI") + if (uri != null) { + System.setProperty("spark.executor.uri", uri) + } + val master = this.master match { + case Some(m) => m + case None => { + val prop = System.getenv("MASTER") + if (prop != null) prop else "local" + } + } + val jars = Option(System.getenv("ADD_JARS")).map(_.split(',')) + .getOrElse(new Array[String](0)) + .map(new java.io.File(_).getAbsolutePath) + sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars) + sparkContext + } + + def process(settings: Settings): Boolean = { + // Ensure logging is initialized before any Spark threads try to use logs + // (because SLF4J initialization is not thread safe) + initLogging() + + printWelcome() + echo("Initializing interpreter...") + + // Add JARS specified in Spark's ADD_JARS variable to classpath + val jars = 
Option(System.getenv("ADD_JARS")).map(_.split(',')).getOrElse(new Array[String](0)) + jars.foreach(settings.classpath.append(_)) + + this.settings = settings + createInterpreter() + + // sets in to some kind of reader depending on environmental cues + in = in0 match { + case Some(reader) => SimpleReader(reader, out, true) + case None => chooseReader(settings) + } + + loadFiles(settings) + // it is broken on startup; go ahead and exit + if (intp.reporter.hasErrors) + return false + + try { + // this is about the illusion of snappiness. We call initialize() + // which spins off a separate thread, then print the prompt and try + // our best to look ready. Ideally the user will spend a + // couple seconds saying "wow, it starts so fast!" and by the time + // they type a command the compiler is ready to roll. + intp.initialize() + initializeSpark() + if (isReplPower) { + echo("Starting in power mode, one moment...\n") + enablePowerMode() + } + loop() + } + finally closeInterpreter() + true + } + + /** process command-line arguments and do as they request */ + def process(args: Array[String]): Boolean = { + val command = new CommandLine(args.toList, msg => echo("scala: " + msg)) + def neededHelp(): String = + (if (command.settings.help.value) command.usageMsg + "\n" else "") + + (if (command.settings.Xhelp.value) command.xusageMsg + "\n" else "") + + // if they asked for no help and command is valid, we call the real main + neededHelp() match { + case "" => command.ok && process(command.settings) + case help => echoNoNL(help) ; true + } + } + + @deprecated("Use `process` instead", "2.9.0") + def main(args: Array[String]): Unit = { + if (isReplDebug) + System.out.println(new java.util.Date) + + process(args) + } + @deprecated("Use `process` instead", "2.9.0") + def main(settings: Settings): Unit = process(settings) +} + +object SparkILoop { + implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp + private def echo(msg: String) = Console println msg + + // Designed primarily for use by test code: take a String with a + // bunch of code, and prints out a transcript of what it would look + // like if you'd just typed it into the repl. + def runForTranscript(code: String, settings: Settings): String = { + import java.io.{ BufferedReader, StringReader, OutputStreamWriter } + + stringFromStream { ostream => + Console.withOut(ostream) { + val output = new PrintWriter(new OutputStreamWriter(ostream), true) { + override def write(str: String) = { + // completely skip continuation lines + if (str forall (ch => ch.isWhitespace || ch == '|')) () + // print a newline on empty scala prompts + else if ((str contains '\n') && (str.trim == "scala> ")) super.write("\n") + else super.write(str) + } + } + val input = new BufferedReader(new StringReader(code)) { + override def readLine(): String = { + val s = super.readLine() + // helping out by printing the line being interpreted. + if (s != null) + output.println(s) + s + } + } + val repl = new SparkILoop(input, output) + if (settings.classpath.isDefault) + settings.classpath.value = sys.props("java.class.path") + + repl process settings + } + } + } + + /** Creates an interpreter loop with default settings and feeds + * the given code to it as input. 
+ */ + def run(code: String, sets: Settings = new Settings): String = { + import java.io.{ BufferedReader, StringReader, OutputStreamWriter } + + stringFromStream { ostream => + Console.withOut(ostream) { + val input = new BufferedReader(new StringReader(code)) + val output = new PrintWriter(new OutputStreamWriter(ostream), true) + val repl = new SparkILoop(input, output) + + if (sets.classpath.isDefault) + sets.classpath.value = sys.props("java.class.path") + + repl process sets + } + } + } + def run(lines: List[String]): String = run(lines map (_ + "\n") mkString) + + // provide the enclosing type T + // in order to set up the interpreter's classpath and parent class loader properly + def breakIf[T: Manifest](assertion: => Boolean, args: NamedParam*): Unit = + if (assertion) break[T](args.toList) + + // start a repl, binding supplied args + def break[T: Manifest](args: List[NamedParam]): Unit = { + val msg = if (args.isEmpty) "" else " Binding " + args.size + " value%s.".format( + if (args.size == 1) "" else "s" + ) + echo("Debug repl starting." + msg) + val repl = new SparkILoop { + override def prompt = "\ndebug> " + } + repl.settings = new Settings(echo) + repl.settings.embeddedDefaults[T] + repl.createInterpreter() + repl.in = SparkJLineReader(repl) + + // rebind exit so people don't accidentally call sys.exit by way of predef + repl.quietRun("""def exit = println("Type :quit to resume program execution.")""") + args foreach (p => repl.bind(p.name, p.tpe, p.value)) + repl.loop() + + echo("\nDebug repl exiting.") + repl.closeInterpreter() + } +} diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala new file mode 100644 index 0000000000..7e244e48a2 --- /dev/null +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -0,0 +1,1160 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2011 LAMP/EPFL + * @author Martin Odersky + */ + +package org.apache.spark.repl + +import scala.tools.nsc._ +import scala.tools.nsc.interpreter._ + +import Predef.{ println => _, _ } +import java.io.{ PrintWriter } +import java.lang.reflect +import java.net.URL +import util.{ Set => _, _ } +import io.{ AbstractFile, PlainFile, VirtualDirectory } +import reporters.{ ConsoleReporter, Reporter } +import symtab.{ Flags, Names } +import scala.tools.nsc.interpreter.{ Results => IR } +import scala.tools.util.PathResolver +import scala.tools.nsc.util.{ ScalaClassLoader, Exceptional } +import ScalaClassLoader.URLClassLoader +import Exceptional.unwrap +import scala.collection.{ mutable, immutable } +import scala.PartialFunction.{ cond, condOpt } +import scala.util.control.Exception.{ ultimately } +import scala.reflect.NameTransformer +import SparkIMain._ + +import org.apache.spark.HttpServer +import org.apache.spark.Utils +import org.apache.spark.SparkEnv + +/** An interpreter for Scala code. + * + * The main public entry points are compile(), interpret(), and bind(). + * The compile() method loads a complete Scala file. The interpret() method + * executes one line of Scala code at the request of the user. The bind() + * method binds an object to a variable that can then be used by later + * interpreted code. + * + * The overall approach is based on compiling the requested code and then + * using a Java classloader and Java reflection to run the code + * and access its results. + * + * In more detail, a single compiler instance is used + * to accumulate all successfully compiled or interpreted Scala code. 
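[Aside, not part of the patch] The run helper defined just above drives an entire REPL session from a string, which is how test code typically exercises the shell. A small sketch, assuming default Settings and treating the returned transcript purely as human-readable output:

    // Feed a couple of lines to a fresh SparkILoop and capture the transcript.
    // The exact transcript formatting is an implementation detail.
    val transcript: String = org.apache.spark.repl.SparkILoop.run(
      "val xs = 1 to 5\n" +
      "xs.sum\n")
    println(transcript)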
To + * "interpret" a line of code, the compiler generates a fresh object that + * includes the line of code and which has public member(s) to export + * all variables defined by that code. To extract the result of an + * interpreted line to show the user, a second "result object" is created + * which imports the variables exported by the above object and then + * exports a single member named "$export". To accomodate user expressions + * that read from variables or methods defined in previous statements, "import" + * statements are used. + * + * This interpreter shares the strengths and weaknesses of using the + * full compiler-to-Java. The main strength is that interpreted code + * behaves exactly as does compiled code, including running at full speed. + * The main weakness is that redefining classes and methods is not handled + * properly, because rebinding at the Java level is technically difficult. + * + * @author Moez A. Abdel-Gawad + * @author Lex Spoon + */ +class SparkIMain(val settings: Settings, protected val out: PrintWriter) extends SparkImports { + imain => + + /** construct an interpreter that reports to Console */ + def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true)) + def this() = this(new Settings()) + + /** whether to print out result lines */ + var printResults: Boolean = true + + /** whether to print errors */ + var totalSilence: Boolean = false + + private val RESULT_OBJECT_PREFIX = "RequestResult$" + + lazy val formatting: Formatting = new Formatting { + val prompt = Properties.shellPromptString + } + import formatting._ + + val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") + + /** Local directory to save .class files too */ + val outputDir = { + val tmp = System.getProperty("java.io.tmpdir") + val rootDir = System.getProperty("spark.repl.classdir", tmp) + Utils.createTempDir(rootDir) + } + if (SPARK_DEBUG_REPL) { + echo("Output directory: " + outputDir) + } + + /** Scala compiler virtual directory for outputDir */ + val virtualDirectory = new PlainFile(outputDir) + + /** Jetty server that will serve our classes to worker nodes */ + val classServer = new HttpServer(outputDir) + + // Start the classServer and store its URI in a spark system property + // (which will be passed to executors so that they can connect to it) + classServer.start() + System.setProperty("spark.repl.class.uri", classServer.uri) + if (SPARK_DEBUG_REPL) { + echo("Class server started, URI = " + classServer.uri) + } + + /* + // directory to save .class files to + val virtualDirectory = new VirtualDirectory("(memory)", None) { + private def pp(root: io.AbstractFile, indentLevel: Int) { + val spaces = " " * indentLevel + out.println(spaces + root.name) + if (root.isDirectory) + root.toList sortBy (_.name) foreach (x => pp(x, indentLevel + 1)) + } + // print the contents hierarchically + def show() = pp(this, 0) + } + */ + + /** reporter */ + lazy val reporter: ConsoleReporter = new SparkIMain.ReplReporter(this) + import reporter.{ printMessage, withoutTruncating } + + // not sure if we have some motivation to print directly to console + private def echo(msg: String) { Console println msg } + + // protected def defaultImports: List[String] = List("_root_.scala.sys.exit") + + /** We're going to go to some trouble to initialize the compiler asynchronously. 
+ * It's critical that nothing call into it until it's been initialized or we will + * run into unrecoverable issues, but the perceived repl startup time goes + * through the roof if we wait for it. So we initialize it with a future and + * use a lazy val to ensure that any attempt to use the compiler object waits + * on the future. + */ + private val _compiler: Global = newCompiler(settings, reporter) + private var _initializeComplete = false + def isInitializeComplete = _initializeComplete + + private def _initialize(): Boolean = { + val source = """ + |class $repl_$init { + | List(1) map (_ + 1) + |} + |""".stripMargin + + val result = try { + new _compiler.Run() compileSources List(new BatchSourceFile("", source)) + if (isReplDebug || settings.debug.value) { + // Can't use printMessage here, it deadlocks + Console.println("Repl compiler initialized.") + } + // addImports(defaultImports: _*) + true + } + catch { + case x: AbstractMethodError => + printMessage(""" + |Failed to initialize compiler: abstract method error. + |This is most often remedied by a full clean and recompile. + |""".stripMargin + ) + x.printStackTrace() + false + case x: MissingRequirementError => printMessage(""" + |Failed to initialize compiler: %s not found. + |** Note that as of 2.8 scala does not assume use of the java classpath. + |** For the old behavior pass -usejavacp to scala, or if using a Settings + |** object programatically, settings.usejavacp.value = true.""".stripMargin.format(x.req) + ) + false + } + + try result + finally _initializeComplete = result + } + + // set up initialization future + private var _isInitialized: () => Boolean = null + def initialize() = synchronized { + if (_isInitialized == null) + _isInitialized = scala.concurrent.ops future _initialize() + } + + /** the public, go through the future compiler */ + lazy val global: Global = { + initialize() + + // blocks until it is ; false means catastrophic failure + if (_isInitialized()) _compiler + else null + } + @deprecated("Use `global` for access to the compiler instance.", "2.9.0") + lazy val compiler: global.type = global + + import global._ + + object naming extends { + val global: imain.global.type = imain.global + } with Naming { + // make sure we don't overwrite their unwisely named res3 etc. + override def freshUserVarName(): String = { + val name = super.freshUserVarName() + if (definedNameMap contains name) freshUserVarName() + else name + } + } + import naming._ + + // object dossiers extends { + // val intp: imain.type = imain + // } with Dossiers { } + // import dossiers._ + + lazy val memberHandlers = new { + val intp: imain.type = imain + } with SparkMemberHandlers + import memberHandlers._ + + def atPickler[T](op: => T): T = atPhase(currentRun.picklerPhase)(op) + def afterTyper[T](op: => T): T = atPhase(currentRun.typerPhase.next)(op) + + /** Temporarily be quiet */ + def beQuietDuring[T](operation: => T): T = { + val wasPrinting = printResults + ultimately(printResults = wasPrinting) { + if (isReplDebug) echo(">> beQuietDuring") + else printResults = false + + operation + } + } + def beSilentDuring[T](operation: => T): T = { + val saved = totalSilence + totalSilence = true + try operation + finally totalSilence = saved + } + + def quietRun[T](code: String) = beQuietDuring(interpret(code)) + + /** whether to bind the lastException variable */ + private var bindLastException = true + + /** A string representing code to be wrapped around all lines. 
+
+  /** A string representing code to be wrapped around all lines. */
+  private var _executionWrapper: String = ""
+  def executionWrapper = _executionWrapper
+  def setExecutionWrapper(code: String) = _executionWrapper = code
+  def clearExecutionWrapper() = _executionWrapper = ""
+
+  /** Temporarily stop binding lastException */
+  def withoutBindingLastException[T](operation: => T): T = {
+    val wasBinding = bindLastException
+    ultimately(bindLastException = wasBinding) {
+      bindLastException = false
+      operation
+    }
+  }
+
+  protected def createLineManager(): Line.Manager = new Line.Manager
+  lazy val lineManager = createLineManager()
+
+  /** interpreter settings */
+  lazy val isettings = new SparkISettings(this)
+
+  /** Instantiate a compiler. Subclasses can override this to
+   *  change the compiler class used by this interpreter. */
+  protected def newCompiler(settings: Settings, reporter: Reporter) = {
+    settings.outputDirs setSingleOutput virtualDirectory
+    settings.exposeEmptyPackage.value = true
+    new Global(settings, reporter)
+  }
+
+  /** the compiler's classpath, as URLs */
+  lazy val compilerClasspath: List[URL] = new PathResolver(settings) asURLs
+
+  /* A single class loader is used for all commands interpreted by this Interpreter.
+     It would also be possible to create a new class loader for each command
+     to interpret. The advantages of the current approach are:
+
+       - Expressions are only evaluated one time. This is especially
+         significant for I/O, e.g. "val x = Console.readLine"
+
+     The main disadvantage is:
+
+       - Objects, classes, and methods cannot be rebound. Instead, definitions
+         shadow the old ones, and old code objects refer to the old
+         definitions.
+  */
+  private var _classLoader: AbstractFileClassLoader = null
+  def resetClassLoader() = _classLoader = makeClassLoader()
+  def classLoader: AbstractFileClassLoader = {
+    if (_classLoader == null)
+      resetClassLoader()
+
+    _classLoader
+  }
+  private def makeClassLoader(): AbstractFileClassLoader = {
+    val parent =
+      if (parentClassLoader == null) ScalaClassLoader fromURLs compilerClasspath
+      else new URLClassLoader(compilerClasspath, parentClassLoader)
+
+    new AbstractFileClassLoader(virtualDirectory, parent) {
+      /** Overridden here to try translating a simple name to the generated
+       *  class name if the original attempt fails. This method is used by
+       *  getResourceAsStream as well as findClass.
+       */
+      override protected def findAbstractFile(name: String): AbstractFile = {
+        super.findAbstractFile(name) match {
+          // deadlocks on startup if we try to translate names too early
+          case null if isInitializeComplete => generatedName(name) map (x => super.findAbstractFile(x)) orNull
+          case file => file
+        }
+      }
+    }
+  }
+  private def loadByName(s: String): JClass =
+    (classLoader tryToInitializeClass s) getOrElse sys.error("Failed to load expected class: '" + s + "'")
+
+  protected def parentClassLoader: ClassLoader =
+    SparkHelper.explicitParentLoader(settings).getOrElse( this.getClass.getClassLoader() )
+
+  def getInterpreterClassLoader() = classLoader
+
+  // Set the current Java "context" class loader to this interpreter's class loader
+  def setContextClassLoader() = classLoader.setAsContext()
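+
+  // Illustrative REPL session (hypothetical) showing the rebinding limitation noted in
+  // the class loader comment above: redefining a class shadows the old definition, but
+  // objects created earlier keep referring to the old one.
+  //   scala> class Foo { def n = 1 }
+  //   scala> val f = new Foo
+  //   scala> class Foo { def n = 2 }
+  //   scala> f.n        // still 1, not 2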
for "Bippy" it may return + * + * $line19.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Bippy + */ + def generatedName(simpleName: String): Option[String] = { + if (simpleName endsWith "$") optFlatName(simpleName.init) map (_ + "$") + else optFlatName(simpleName) + } + def flatName(id: String) = optFlatName(id) getOrElse id + def optFlatName(id: String) = requestForIdent(id) map (_ fullFlatName id) + + def allDefinedNames = definedNameMap.keys.toList sortBy (_.toString) + def pathToType(id: String): String = pathToName(newTypeName(id)) + def pathToTerm(id: String): String = pathToName(newTermName(id)) + def pathToName(name: Name): String = { + if (definedNameMap contains name) + definedNameMap(name) fullPath name + else name.toString + } + + /** Most recent tree handled which wasn't wholly synthetic. */ + private def mostRecentlyHandledTree: Option[Tree] = { + prevRequests.reverse foreach { req => + req.handlers.reverse foreach { + case x: MemberDefHandler if x.definesValue && !isInternalVarName(x.name.toString) => return Some(x.member) + case _ => () + } + } + None + } + + /** Stubs for work in progress. */ + def handleTypeRedefinition(name: TypeName, old: Request, req: Request) = { + for (t1 <- old.simpleNameOfType(name) ; t2 <- req.simpleNameOfType(name)) { + DBG("Redefining type '%s'\n %s -> %s".format(name, t1, t2)) + } + } + + def handleTermRedefinition(name: TermName, old: Request, req: Request) = { + for (t1 <- old.compilerTypeOf get name ; t2 <- req.compilerTypeOf get name) { + // Printing the types here has a tendency to cause assertion errors, like + // assertion failed: fatal: has owner value x, but a class owner is required + // so DBG is by-name now to keep it in the family. (It also traps the assertion error, + // but we don't want to unnecessarily risk hosing the compiler's internal state.) + DBG("Redefining term '%s'\n %s -> %s".format(name, t1, t2)) + } + } + def recordRequest(req: Request) { + if (req == null || referencedNameMap == null) + return + + prevRequests += req + req.referencedNames foreach (x => referencedNameMap(x) = req) + + // warning about serially defining companions. It'd be easy + // enough to just redefine them together but that may not always + // be what people want so I'm waiting until I can do it better. + if (!settings.nowarnings.value) { + for { + name <- req.definedNames filterNot (x => req.definedNames contains x.companionName) + oldReq <- definedNameMap get name.companionName + newSym <- req.definedSymbols get name + oldSym <- oldReq.definedSymbols get name.companionName + } { + printMessage("warning: previously defined %s is not a companion to %s.".format(oldSym, newSym)) + printMessage("Companions must be defined together; you may wish to use :paste mode for this.") + } + } + + // Updating the defined name map + req.definedNames foreach { name => + if (definedNameMap contains name) { + if (name.isTypeName) handleTypeRedefinition(name.toTypeName, definedNameMap(name), req) + else handleTermRedefinition(name.toTermName, definedNameMap(name), req) + } + definedNameMap(name) = req + } + } + + /** Parse a line into a sequence of trees. Returns None if the input is incomplete. 
+
+  /** Parse a line into a sequence of trees. Returns None if the input is incomplete. */
+  def parse(line: String): Option[List[Tree]] = {
+    var justNeedsMore = false
+    reporter.withIncompleteHandler((pos,msg) => {justNeedsMore = true}) {
+      // simple parse: just parse it, nothing else
+      def simpleParse(code: String): List[Tree] = {
+        reporter.reset()
+        val unit = new CompilationUnit(new BatchSourceFile("", code))
+        val scanner = new syntaxAnalyzer.UnitParser(unit)
+
+        scanner.templateStatSeq(false)._2
+      }
+      val trees = simpleParse(line)
+
+      if (reporter.hasErrors) Some(Nil) // the result did not parse, so stop
+      else if (justNeedsMore) None
+      else Some(trees)
+    }
+  }
+
+  def isParseable(line: String): Boolean = {
+    beSilentDuring {
+      parse(line) match {
+        case Some(xs) => xs.nonEmpty // parses as-is
+        case None => true // incomplete
+      }
+    }
+  }
+
+  /** Compile an nsc SourceFile. Returns true if there are
+   *  no compilation errors, or false otherwise.
+   */
+  def compileSources(sources: SourceFile*): Boolean = {
+    reporter.reset()
+    new Run() compileSources sources.toList
+    !reporter.hasErrors
+  }
+
+  /** Compile a string. Returns true if there are no
+   *  compilation errors, or false otherwise.
+   */
+  def compileString(code: String): Boolean =
+    compileSources(new BatchSourceFile("