author     Reza Zadeh <rizlar@gmail.com>   2014-01-13 23:52:34 -0800
committer  Reza Zadeh <rizlar@gmail.com>   2014-01-13 23:52:34 -0800
commit     845e568fada0550e632e7381748c5a9ebbe53e16 (patch)
tree       3a4fa34894df649b5ef429cd794b73cf4b3e99b1
parent     f324d5355514b1c7ae85019b476046bb64b5593e (diff)
parent     fdaabdc67387524ffb84354f87985f48bd31cf60 (diff)
download   spark-845e568fada0550e632e7381748c5a9ebbe53e16.tar.gz
           spark-845e568fada0550e632e7381748c5a9ebbe53e16.tar.bz2
           spark-845e568fada0550e632e7381748c5a9ebbe53e16.zip
Merge remote-tracking branch 'upstream/master' into sparsesvd
-rw-r--r--bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala1
-rwxr-xr-xbin/compute-classpath.sh2
-rw-r--r--core/src/main/scala/org/apache/spark/Accumulators.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/Aggregator.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/CacheManager.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/FutureAction.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/HttpFileServer.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/InterruptibleIterator.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/Logging.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/MapOutputTracker.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/Partitioner.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala75
-rw-r--r--core/src/main/scala/org/apache/spark/SparkEnv.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala15
-rw-r--r--core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/Client.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/network/BufferMessage.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/network/Connection.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/network/Message.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/package.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala32
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala24
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDD.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Pool.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala17
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Stage.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManager.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala20
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockMessage.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/storage/MemoryStore.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/storage/StorageLevel.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala53
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/util/CompletionIterator.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/util/SizeEstimator.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala36
-rw-r--r--core/src/main/scala/org/apache/spark/util/Vector.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/BitSet.scala87
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala72
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala23
-rw-r--r--core/src/test/scala/org/apache/spark/LocalSparkContext.scala1
-rw-r--r--core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala1
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala9
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala3
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala7
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala14
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala77
-rw-r--r--docs/_config.yml2
-rwxr-xr-xdocs/_layouts/global.html8
-rw-r--r--docs/_plugins/copy_api_dirs.rb2
-rw-r--r--docs/api.md1
-rw-r--r--docs/bagel-programming-guide.md10
-rw-r--r--docs/configuration.md13
-rw-r--r--docs/graphx-programming-guide.md1003
-rw-r--r--docs/img/data_parallel_vs_graph_parallel.pngbin0 -> 432725 bytes
-rw-r--r--docs/img/edge-cut.pngbin0 -> 12563 bytes
-rw-r--r--docs/img/edge_cut_vs_vertex_cut.pngbin0 -> 79745 bytes
-rw-r--r--docs/img/graph_analytics_pipeline.pngbin0 -> 427220 bytes
-rw-r--r--docs/img/graph_parallel.pngbin0 -> 92288 bytes
-rw-r--r--docs/img/graphx_figures.pptxbin0 -> 1123363 bytes
-rw-r--r--docs/img/graphx_logo.pngbin0 -> 40324 bytes
-rw-r--r--docs/img/graphx_performance_comparison.pngbin0 -> 166343 bytes
-rw-r--r--docs/img/property_graph.pngbin0 -> 225151 bytes
-rw-r--r--docs/img/tables_and_graphs.pngbin0 -> 166265 bytes
-rw-r--r--docs/img/triplet.pngbin0 -> 31489 bytes
-rw-r--r--docs/img/vertex-cut.pngbin0 -> 12246 bytes
-rw-r--r--docs/img/vertex_routing_edge_tables.pngbin0 -> 570007 bytes
-rw-r--r--docs/index.md4
-rw-r--r--docs/mllib-guide.md19
-rw-r--r--docs/python-programming-guide.md8
-rw-r--r--docs/streaming-programming-guide.md6
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/LocalALS.scala8
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala2
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala2
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/SparkALS.scala6
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala2
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala12
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala49
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala2
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala2
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdCMS.scala4
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdHLL.scala4
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/TwitterPopularTags.scala4
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala2
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewStream.scala2
-rw-r--r--external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala4
-rw-r--r--external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala6
-rw-r--r--external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala4
-rw-r--r--external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala2
-rw-r--r--external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala7
-rw-r--r--external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala2
-rw-r--r--external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala3
-rw-r--r--graphx/data/followers.txt8
-rw-r--r--graphx/data/users.txt7
-rw-r--r--graphx/pom.xml67
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/Edge.scala45
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala44
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala102
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala49
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/Graph.scala405
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala31
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala72
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala301
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala103
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala139
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala347
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala220
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala45
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala42
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala379
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala98
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala195
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTable.scala65
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala395
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala261
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/impl/package.scala7
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala136
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala38
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala147
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala138
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala94
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala76
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/package.scala18
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala117
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala218
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/util/collection/PrimitiveKeyOpenHashMap.scala153
-rw-r--r--graphx/src/test/resources/log4j.properties28
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala66
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala273
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala28
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala41
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala183
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala85
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala76
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala113
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala113
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala119
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala31
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala57
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala70
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala93
-rw-r--r--mllib/data/sample_naive_bayes_data.txt6
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala46
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala65
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala15
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala4
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java72
-rw-r--r--pom.xml5
-rw-r--r--project/SparkBuild.scala21
-rw-r--r--python/pyspark/mllib/_common.py2
-rw-r--r--python/pyspark/mllib/classification.py77
-rw-r--r--python/pyspark/mllib/clustering.py11
-rw-r--r--python/pyspark/mllib/recommendation.py10
-rw-r--r--python/pyspark/mllib/regression.py35
-rw-r--r--python/pyspark/worker.py4
-rwxr-xr-xpython/run-tests5
-rw-r--r--repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala6
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala28
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala12
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala117
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala3
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala30
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala1
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala28
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala (renamed from streaming/src/main/scala/org/apache/spark/streaming/DStream.scala)166
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala (renamed from streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala)12
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala86
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala4
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala (renamed from streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala)7
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala10
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala3
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala17
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala9
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala81
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala141
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala18
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala40
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala3
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala23
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala4
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala4
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala (renamed from core/src/main/scala/org/apache/spark/util/RateLimitedOutputStream.scala)2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala13
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala13
-rw-r--r--streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java2
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala79
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala25
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala219
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala1
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala10
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala15
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala (renamed from core/src/test/scala/org/apache/spark/util/RateLimitedOutputStreamSuite.scala)2
-rw-r--r--tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala50
-rw-r--r--yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala7
-rw-r--r--yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala8
-rw-r--r--yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala7
-rw-r--r--yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala10
-rw-r--r--yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala2
-rw-r--r--yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala3
252 files changed, 8912 insertions, 903 deletions
diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
index 7b954a4775..9c37fadb78 100644
--- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
+++ b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
@@ -38,7 +38,6 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeo
}
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
- System.clearProperty("spark.hostPort")
}
test("halting by voting") {
diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh
index 0c82310421..278969655d 100755
--- a/bin/compute-classpath.sh
+++ b/bin/compute-classpath.sh
@@ -39,6 +39,7 @@ if [ -f "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*-dep
CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/classes"
DEPS_ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*-deps.jar`
@@ -59,6 +60,7 @@ if [[ $SPARK_TESTING == 1 ]]; then
CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/test-classes"
CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/test-classes"
CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SCALA_VERSION/test-classes"
CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/test-classes"
fi
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index 5f73d234aa..2ba871a600 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -24,7 +24,7 @@ import scala.collection.generic.Growable
import org.apache.spark.serializer.JavaSerializer
/**
- * A datatype that can be accumulated, i.e. has an commutative and associative "add" operation,
+ * A datatype that can be accumulated, ie has an commutative and associative "add" operation,
* but where the result type, `R`, may be different from the element type being added, `T`.
*
* You must define how to add data, and how to merge two of these together. For some datatypes,
@@ -185,7 +185,7 @@ class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Ser
}
/**
- * A simpler value of [[org.apache.spark.Accumulable]] where the result type being accumulated is the same
+ * A simpler value of [[Accumulable]] where the result type being accumulated is the same
* as the types of elements being merged.
*
* @param initialValue initial value of accumulator
@@ -218,7 +218,7 @@ private object Accumulators {
def newId: Long = synchronized {
lastId += 1
- return lastId
+ lastId
}
def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized {
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index 8b30cd4bfe..6d439fdc68 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -32,9 +32,10 @@ case class Aggregator[K, V, C] (
mergeCombiners: (C, C) => C) {
private val sparkConf = SparkEnv.get.conf
- private val externalSorting = sparkConf.getBoolean("spark.shuffle.externalSorting", true)
+ private val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true)
- def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = {
+ def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]],
+ context: TaskContext) : Iterator[(K, C)] = {
if (!externalSorting) {
val combiners = new AppendOnlyMap[K,C]
var kv: Product2[K, V] = null
@@ -47,17 +48,18 @@ case class Aggregator[K, V, C] (
}
combiners.iterator
} else {
- val combiners =
- new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
+ val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
while (iter.hasNext) {
val (k, v) = iter.next()
combiners.insert(k, v)
}
+ context.taskMetrics.memoryBytesSpilled = combiners.memoryBytesSpilled
+ context.taskMetrics.diskBytesSpilled = combiners.diskBytesSpilled
combiners.iterator
}
}
- def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = {
+ def combineCombinersByKey(iter: Iterator[(K, C)], context: TaskContext) : Iterator[(K, C)] = {
if (!externalSorting) {
val combiners = new AppendOnlyMap[K,C]
var kc: Product2[K, C] = null
@@ -75,6 +77,8 @@ case class Aggregator[K, V, C] (
val (k, c) = iter.next()
combiners.insert(k, c)
}
+ context.taskMetrics.memoryBytesSpilled = combiners.memoryBytesSpilled
+ context.taskMetrics.diskBytesSpilled = combiners.diskBytesSpilled
combiners.iterator
}
}
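The Aggregator hunk above does two things: it reads the spill switch from spark.shuffle.spill instead of spark.shuffle.externalSorting, and it threads a TaskContext into both combine methods so that the bytes spilled by ExternalAppendOnlyMap can be copied into the task's metrics. A minimal self-contained sketch of that metric-forwarding pattern; SpillableMap and Metrics are stand-in types for illustration, not Spark classes:

    // Sketch only: mirrors the pattern in the hunk, not the real ExternalAppendOnlyMap.
    class Metrics { var memoryBytesSpilled = 0L; var diskBytesSpilled = 0L }

    class SpillableMap[K, V, C](createCombiner: V => C, mergeValue: (C, V) => C) {
      var memoryBytesSpilled = 0L    // would grow as in-memory data is spilled
      var diskBytesSpilled = 0L
      private val data = scala.collection.mutable.Map.empty[K, C]
      def insert(k: K, v: V): Unit =
        data(k) = data.get(k).map(mergeValue(_, v)).getOrElse(createCombiner(v))
      def iterator: Iterator[(K, C)] = data.iterator
    }

    def combineValuesByKey[K, V, C](iter: Iterator[(K, V)], metrics: Metrics,
        createCombiner: V => C, mergeValue: (C, V) => C): Iterator[(K, C)] = {
      val combiners = new SpillableMap[K, V, C](createCombiner, mergeValue)
      iter.foreach { case (k, v) => combiners.insert(k, v) }
      // Copy the spill counters into the per-task metrics before handing back the
      // iterator, which is exactly the step the patch adds to Aggregator.
      metrics.memoryBytesSpilled = combiners.memoryBytesSpilled
      metrics.diskBytesSpilled = combiners.diskBytesSpilled
      combiners.iterator
    }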
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index 519ecde50a..8e5dd8a850 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -38,7 +38,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
blockManager.get(key) match {
case Some(values) =>
// Partition is already materialized, so just return its values
- return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
+ new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
case None =>
// Mark the split as loading (unless someone else marks it first)
@@ -74,7 +74,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
val elements = new ArrayBuffer[Any]
elements ++= computedValues
blockManager.put(key, elements, storageLevel, tellMaster = true)
- return elements.iterator.asInstanceOf[Iterator[T]]
+ elements.iterator.asInstanceOf[Iterator[T]]
} finally {
loading.synchronized {
loading.remove(key)
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index c6b4ac5192..d7d10285da 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -27,8 +27,8 @@ import org.apache.spark.rdd.RDD
/**
- * A future for the result of an action. This is an extension of the Scala Future interface to
- * support cancellation.
+ * A future for the result of an action to support cancellation. This is an extension of the
+ * Scala Future interface to support cancellation.
*/
trait FutureAction[T] extends Future[T] {
// Note that we redefine methods of the Future trait here explicitly so we can specify a different
@@ -86,7 +86,7 @@ trait FutureAction[T] extends Future[T] {
/**
- * The future holding the result of an action that triggers a single job. Examples include
+ * A [[FutureAction]] holding the result of an action that triggers a single job. Examples include
* count, collect, reduce.
*/
class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
@@ -150,7 +150,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
/**
- * A FutureAction for actions that could trigger multiple Spark jobs. Examples include take,
+ * A [[FutureAction]] for actions that could trigger multiple Spark jobs. Examples include take,
* takeSample. Cancellation works by setting the cancelled flag to true and interrupting the
* action thread if it is being blocked by a job.
*/
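The FutureAction doc edits above emphasize that it is Spark's cancellable extension of the Scala Future, with SimpleFutureAction backing single-job actions (count, collect, reduce) and a second flavor for actions that may trigger multiple jobs (take, takeSample). A hedged usage sketch; countAsync is assumed to come from AsyncRDDActions via the SparkContext._ implicits, and local[2] is just an example master:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.SparkContext._

    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("future-action-sketch"))
    val pending = sc.parallelize(1 to 1000000).countAsync()   // a FutureAction[Long]
    // Cancellation is what FutureAction adds on top of scala.concurrent.Future:
    if (!pending.isCompleted) pending.cancel()
    sc.stop()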
diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
index ad1ee20045..a885898ad4 100644
--- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
@@ -47,17 +47,17 @@ private[spark] class HttpFileServer extends Logging {
def addFile(file: File) : String = {
addFileToDir(file, fileDir)
- return serverUri + "/files/" + file.getName
+ serverUri + "/files/" + file.getName
}
def addJar(file: File) : String = {
addFileToDir(file, jarDir)
- return serverUri + "/jars/" + file.getName
+ serverUri + "/jars/" + file.getName
}
def addFileToDir(file: File, dir: File) : String = {
Files.copy(file, new File(dir, file.getName))
- return dir + "/" + file.getName
+ dir + "/" + file.getName
}
}
diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
index 56e0b8d2c0..9b1601d5b9 100644
--- a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
+++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
@@ -19,7 +19,7 @@ package org.apache.spark
/**
* An iterator that wraps around an existing iterator to provide task killing functionality.
- * It works by checking the interrupted flag in TaskContext.
+ * It works by checking the interrupted flag in [[TaskContext]].
*/
class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
extends Iterator[T] {
diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala
index 4a34989e50..b749e5414d 100644
--- a/core/src/main/scala/org/apache/spark/Logging.scala
+++ b/core/src/main/scala/org/apache/spark/Logging.scala
@@ -41,7 +41,7 @@ trait Logging {
}
log_ = LoggerFactory.getLogger(className)
}
- return log_
+ log_
}
// Log methods that take only a String
@@ -122,7 +122,7 @@ trait Logging {
}
}
-object Logging {
+private object Logging {
@volatile private var initialized = false
val initLock = new Object()
}
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 77b8ca1cce..30d182b008 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -32,15 +32,16 @@ import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.{AkkaUtils, MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
private[spark] sealed trait MapOutputTrackerMessage
-private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String)
+private[spark] case class GetMapOutputStatuses(shuffleId: Int)
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster)
extends Actor with Logging {
def receive = {
- case GetMapOutputStatuses(shuffleId: Int, requester: String) =>
- logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester)
+ case GetMapOutputStatuses(shuffleId: Int) =>
+ val hostPort = sender.path.address.hostPort
+ logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
sender ! tracker.getSerializedMapOutputStatuses(shuffleId)
case StopMapOutputTracker =>
@@ -119,11 +120,10 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
if (fetchedStatuses == null) {
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
- val hostPort = Utils.localHostPort(conf)
// This try-finally prevents hangs due to timeouts:
try {
val fetchedBytes =
- askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
+ askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]]
fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
@@ -139,7 +139,7 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
}
}
- else{
+ else {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing all output locations for shuffle " + shuffleId))
}
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index 9b043f06dd..fc0a749882 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -53,9 +53,9 @@ object Partitioner {
return r.partitioner.get
}
if (rdd.context.conf.contains("spark.default.parallelism")) {
- return new HashPartitioner(rdd.context.defaultParallelism)
+ new HashPartitioner(rdd.context.defaultParallelism)
} else {
- return new HashPartitioner(bySize.head.partitions.size)
+ new HashPartitioner(bySize.head.partitions.size)
}
}
}
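The Partitioner hunk only drops redundant returns; defaultPartitioner still prefers an existing upstream partitioner, then spark.default.parallelism, then the partition count of the largest upstream RDD. A hedged sketch of that behavior, assuming a SparkContext sc such as the local one constructed in the earlier sketch:

    import org.apache.spark.{HashPartitioner, Partitioner}
    import org.apache.spark.SparkContext._

    val a = sc.parallelize(Seq((1, "a"), (2, "b")), 8)
    val b = a.partitionBy(new HashPartitioner(4))
    // b already carries a partitioner, so defaultPartitioner reuses it
    // instead of building a new HashPartitioner.
    val chosen: Partitioner = Partitioner.defaultPartitioner(a, b)
    assert(chosen == b.partitioner.get)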
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index d7e681d921..55ac76bf63 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -345,9 +345,20 @@ class SparkContext(
}
/**
- * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and any
- * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable,
- * etc).
+ * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and other
+ * necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable),
+ * using the older MapReduce API (`org.apache.hadoop.mapred`).
+ *
+ * @param conf JobConf for setting up the dataset
+ * @param inputFormatClass Class of the [[InputFormat]]
+ * @param keyClass Class of the keys
+ * @param valueClass Class of the values
+ * @param minSplits Minimum number of Hadoop Splits to generate.
+ * @param cloneRecords If true, Spark will clone the records produced by Hadoop RecordReader.
+ * Most RecordReader implementations reuse wrapper objects across multiple
+ * records, and can cause problems in RDD collect or aggregation operations.
+ * By default the records are cloned in Spark. However, application
+ * programmers can explicitly disable the cloning for better performance.
*/
def hadoopRDD[K: ClassTag, V: ClassTag](
conf: JobConf,
@@ -355,11 +366,11 @@ class SparkContext(
keyClass: Class[K],
valueClass: Class[V],
minSplits: Int = defaultMinSplits,
- cloneKeyValues: Boolean = true
+ cloneRecords: Boolean = true
): RDD[(K, V)] = {
// Add necessary security credentials to the JobConf before broadcasting it.
SparkHadoopUtil.get.addCredentials(conf)
- new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits, cloneKeyValues)
+ new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits, cloneRecords)
}
/** Get an RDD for a Hadoop file with an arbitrary InputFormat */
@@ -369,7 +380,7 @@ class SparkContext(
keyClass: Class[K],
valueClass: Class[V],
minSplits: Int = defaultMinSplits,
- cloneKeyValues: Boolean = true
+ cloneRecords: Boolean = true
): RDD[(K, V)] = {
// A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it.
val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration))
@@ -382,7 +393,7 @@ class SparkContext(
keyClass,
valueClass,
minSplits,
- cloneKeyValues)
+ cloneRecords)
}
/**
@@ -394,14 +405,14 @@ class SparkContext(
* }}}
*/
def hadoopFile[K, V, F <: InputFormat[K, V]]
- (path: String, minSplits: Int, cloneKeyValues: Boolean = true)
+ (path: String, minSplits: Int, cloneRecords: Boolean = true)
(implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = {
hadoopFile(path,
fm.runtimeClass.asInstanceOf[Class[F]],
km.runtimeClass.asInstanceOf[Class[K]],
vm.runtimeClass.asInstanceOf[Class[V]],
minSplits,
- cloneKeyValues = cloneKeyValues)
+ cloneRecords)
}
/**
@@ -412,20 +423,20 @@ class SparkContext(
* val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path)
* }}}
*/
- def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, cloneKeyValues: Boolean = true)
+ def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, cloneRecords: Boolean = true)
(implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] =
- hadoopFile[K, V, F](path, defaultMinSplits, cloneKeyValues)
+ hadoopFile[K, V, F](path, defaultMinSplits, cloneRecords)
/** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */
def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]]
- (path: String, cloneKeyValues: Boolean = true)
+ (path: String, cloneRecords: Boolean = true)
(implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = {
newAPIHadoopFile(
path,
fm.runtimeClass.asInstanceOf[Class[F]],
km.runtimeClass.asInstanceOf[Class[K]],
vm.runtimeClass.asInstanceOf[Class[V]],
- cloneKeyValues = cloneKeyValues)
+ cloneRecords = cloneRecords)
}
/**
@@ -438,11 +449,11 @@ class SparkContext(
kClass: Class[K],
vClass: Class[V],
conf: Configuration = hadoopConfiguration,
- cloneKeyValues: Boolean = true): RDD[(K, V)] = {
+ cloneRecords: Boolean = true): RDD[(K, V)] = {
val job = new NewHadoopJob(conf)
NewFileInputFormat.addInputPath(job, new Path(path))
val updatedConf = job.getConfiguration
- new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf, cloneKeyValues)
+ new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf, cloneRecords)
}
/**
@@ -454,8 +465,8 @@ class SparkContext(
fClass: Class[F],
kClass: Class[K],
vClass: Class[V],
- cloneKeyValues: Boolean = true): RDD[(K, V)] = {
- new NewHadoopRDD(this, fClass, kClass, vClass, conf, cloneKeyValues)
+ cloneRecords: Boolean = true): RDD[(K, V)] = {
+ new NewHadoopRDD(this, fClass, kClass, vClass, conf, cloneRecords)
}
/** Get an RDD for a Hadoop SequenceFile with given key and value types. */
@@ -463,16 +474,16 @@ class SparkContext(
keyClass: Class[K],
valueClass: Class[V],
minSplits: Int,
- cloneKeyValues: Boolean = true
+ cloneRecords: Boolean = true
): RDD[(K, V)] = {
val inputFormatClass = classOf[SequenceFileInputFormat[K, V]]
- hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits, cloneKeyValues)
+ hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits, cloneRecords)
}
/** Get an RDD for a Hadoop SequenceFile with given key and value types. */
def sequenceFile[K: ClassTag, V: ClassTag](path: String, keyClass: Class[K], valueClass: Class[V],
- cloneKeyValues: Boolean = true): RDD[(K, V)] =
- sequenceFile(path, keyClass, valueClass, defaultMinSplits, cloneKeyValues)
+ cloneRecords: Boolean = true): RDD[(K, V)] =
+ sequenceFile(path, keyClass, valueClass, defaultMinSplits, cloneRecords)
/**
* Version of sequenceFile() for types implicitly convertible to Writables through a
@@ -490,17 +501,18 @@ class SparkContext(
* for the appropriate type. In addition, we pass the converter a ClassTag of its type to
* allow it to figure out the Writable class to use in the subclass case.
*/
- def sequenceFile[K, V](path: String, minSplits: Int = defaultMinSplits,
- cloneKeyValues: Boolean = true) (implicit km: ClassTag[K], vm: ClassTag[V],
- kcf: () => WritableConverter[K], vcf: () => WritableConverter[V])
+ def sequenceFile[K, V]
+ (path: String, minSplits: Int = defaultMinSplits, cloneRecords: Boolean = true)
+ (implicit km: ClassTag[K], vm: ClassTag[V],
+ kcf: () => WritableConverter[K], vcf: () => WritableConverter[V])
: RDD[(K, V)] = {
val kc = kcf()
val vc = vcf()
val format = classOf[SequenceFileInputFormat[Writable, Writable]]
val writables = hadoopFile(path, format,
kc.writableClass(km).asInstanceOf[Class[Writable]],
- vc.writableClass(vm).asInstanceOf[Class[Writable]], minSplits, cloneKeyValues)
- writables.map{case (k,v) => (kc.convert(k), vc.convert(v))}
+ vc.writableClass(vm).asInstanceOf[Class[Writable]], minSplits, cloneRecords)
+ writables.map { case (k, v) => (kc.convert(k), vc.convert(v)) }
}
/**
@@ -774,8 +786,11 @@ class SparkContext(
private[spark] def getCallSite(): String = {
val callSite = getLocalProperty("externalCallSite")
- if (callSite == null) return Utils.formatSparkCallSite
- callSite
+ if (callSite == null) {
+ Utils.formatSparkCallSite
+ } else {
+ callSite
+ }
}
/**
@@ -925,7 +940,7 @@ class SparkContext(
*/
private[spark] def clean[F <: AnyRef](f: F): F = {
ClosureCleaner.clean(f)
- return f
+ f
}
/**
@@ -937,7 +952,7 @@ class SparkContext(
val path = new Path(dir, UUID.randomUUID().toString)
val fs = path.getFileSystem(hadoopConfiguration)
fs.mkdirs(path)
- fs.getFileStatus(path).getPath().toString
+ fs.getFileStatus(path).getPath.toString
}
}
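The SparkContext hunk renames cloneKeyValues to cloneRecords across the Hadoop input methods and documents the reason: most Hadoop RecordReaders reuse a single Writable instance, so collecting or caching un-cloned records can silently yield duplicates. A hedged usage sketch based on the signatures shown in the hunk, assuming a SparkContext sc; the HDFS path is hypothetical:

    import org.apache.hadoop.io.{LongWritable, Text}
    import org.apache.hadoop.mapred.TextInputFormat

    // Default behaviour after this patch: records are cloned, so caching or
    // collecting them is safe.
    val safe = sc.hadoopFile[LongWritable, Text, TextInputFormat]("hdfs:///data/input.txt")

    // Opt out of cloning for speed when records are consumed immediately and never cached.
    val fast = sc.hadoopFile[LongWritable, Text, TextInputFormat](
      "hdfs:///data/input.txt", minSplits = 4, cloneRecords = false)
    val lineLengths = fast.map { case (_, line) => line.toString.length }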
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 08b592df71..ed788560e7 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -132,16 +132,6 @@ object SparkEnv extends Logging {
conf.set("spark.driver.port", boundPort.toString)
}
- // set only if unset until now.
- if (!conf.contains("spark.hostPort")) {
- if (!isDriver){
- // unexpected
- Utils.logErrorWithStack("Unexpected NOT to have spark.hostPort set")
- }
- Utils.checkHost(hostname)
- conf.set("spark.hostPort", hostname + ":" + boundPort)
- }
-
val classLoader = Thread.currentThread.getContextClassLoader
// Create an instance of the class named by the given Java system property, or by
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index 618d95015f..4e63117a51 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -134,28 +134,28 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
format = conf.value.getOutputFormat()
.asInstanceOf[OutputFormat[AnyRef,AnyRef]]
}
- return format
+ format
}
private def getOutputCommitter(): OutputCommitter = {
if (committer == null) {
committer = conf.value.getOutputCommitter
}
- return committer
+ committer
}
private def getJobContext(): JobContext = {
if (jobContext == null) {
jobContext = newJobContext(conf.value, jID.value)
}
- return jobContext
+ jobContext
}
private def getTaskContext(): TaskAttemptContext = {
if (taskContext == null) {
taskContext = newTaskAttemptContext(conf.value, taID.value)
}
- return taskContext
+ taskContext
}
private def setIDs(jobid: Int, splitid: Int, attemptid: Int) {
@@ -182,19 +182,18 @@ object SparkHadoopWriter {
def createJobID(time: Date, id: Int): JobID = {
val formatter = new SimpleDateFormat("yyyyMMddHHmm")
val jobtrackerID = formatter.format(new Date())
- return new JobID(jobtrackerID, id)
+ new JobID(jobtrackerID, id)
}
def createPathFromString(path: String, conf: JobConf): Path = {
if (path == null) {
throw new IllegalArgumentException("Output path is null")
}
- var outputPath = new Path(path)
+ val outputPath = new Path(path)
val fs = outputPath.getFileSystem(conf)
if (outputPath == null || fs == null) {
throw new IllegalArgumentException("Incorrectly formatted output path")
}
- outputPath = outputPath.makeQualified(fs)
- return outputPath
+ outputPath.makeQualified(fs)
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 40c519b5bd..82527fe663 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -95,7 +95,7 @@ private[spark] class PythonRDD[T: ClassTag](
// Return an iterator that read lines from the process's stdout
val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
- return new Iterator[Array[Byte]] {
+ val stdoutIterator = new Iterator[Array[Byte]] {
def next(): Array[Byte] = {
val obj = _nextObj
if (hasNext) {
@@ -156,6 +156,7 @@ private[spark] class PythonRDD[T: ClassTag](
def hasNext = _nextObj.length != 0
}
+ stdoutIterator
}
val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
index 0fc478a419..6bfe2cb4a2 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicLong
import org.apache.spark._
+private[spark]
abstract class Broadcast[T](private[spark] val id: Long) extends Serializable {
def value: T
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
index fb161ce69d..940e5ab805 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -25,7 +25,7 @@ import org.apache.spark.SparkConf
* BroadcastFactory implementation to instantiate a particular broadcast for the
* entire Spark job.
*/
-private[spark] trait BroadcastFactory {
+trait BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf): Unit
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def stop(): Unit
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 0eacda3d7d..39ee0dbb92 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -63,7 +63,10 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
}
}
-private[spark] class HttpBroadcastFactory extends BroadcastFactory {
+/**
+ * A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium.
+ */
+class HttpBroadcastFactory extends BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf) { HttpBroadcast.initialize(isDriver, conf) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index fdf92eca4f..d351dfc1f5 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -203,16 +203,16 @@ extends Logging {
}
bais.close()
- var tInfo = TorrentInfo(retVal, blockNum, byteArray.length)
+ val tInfo = TorrentInfo(retVal, blockNum, byteArray.length)
tInfo.hasBlocks = blockNum
- return tInfo
+ tInfo
}
def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock],
totalBytes: Int,
totalBlocks: Int): T = {
- var retByteArray = new Array[Byte](totalBytes)
+ val retByteArray = new Array[Byte](totalBytes)
for (i <- 0 until totalBlocks) {
System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
@@ -236,8 +236,10 @@ private[spark] case class TorrentInfo(
@transient var hasBlocks = 0
}
-private[spark] class TorrentBroadcastFactory
- extends BroadcastFactory {
+/**
+ * A [[BroadcastFactory]] that creates a torrent-based implementation of broadcast.
+ */
+class TorrentBroadcastFactory extends BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.initialize(isDriver, conf) }
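The two broadcast hunks above turn HttpBroadcastFactory and TorrentBroadcastFactory into public implementations of the now-public BroadcastFactory trait, so the broadcast transport can be chosen, or replaced with a custom factory, through configuration. A hedged sketch, assuming the spark.broadcast.factory setting of this era selects the factory class:

    import org.apache.spark.{SparkConf, SparkContext}

    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("broadcast-factory-sketch")
      .set("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")

    val sc2 = new SparkContext(conf)
    val lookup = sc2.broadcast(Map("a" -> 1, "b" -> 2))   // shipped via TorrentBroadcast
    val resolved = sc2.parallelize(Seq("a", "b", "a")).map(k => lookup.value(k)).collect()
    sc2.stop()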
diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala
index e133893f6c..9987e2300c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -29,13 +29,12 @@ import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.{DriverState, Master}
import org.apache.spark.util.{AkkaUtils, Utils}
-import akka.actor.Actor.emptyBehavior
import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent}
/**
* Proxy that relays messages to the driver.
*/
-class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends Actor with Logging {
+private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends Actor with Logging {
var masterActor: ActorSelection = _
val timeout = AkkaUtils.askTimeout(conf)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
index 7507bf8ad0..cf6a23339d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
@@ -10,8 +10,9 @@ import org.apache.spark.util.Utils
/**
** Utilities for running commands with the spark classpath.
*/
+private[spark]
object CommandUtils extends Logging {
- private[spark] def buildCommandSeq(command: Command, memory: Int, sparkHome: String): Seq[String] = {
+ def buildCommandSeq(command: Command, memory: Int, sparkHome: String): Seq[String] = {
val runner = getEnv("JAVA_HOME", command).map(_ + "/bin/java").getOrElse("java")
// SPARK-698: do not call the run.cmd script, as process.destroy()
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index f9e43e0e94..45b43b403d 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -103,7 +103,6 @@ private[spark] object CoarseGrainedExecutorBackend {
indestructible = true, conf = new SparkConf)
// set it
val sparkHostPort = hostname + ":" + boundPort
-// conf.set("spark.hostPort", sparkHostPort)
actorSystem.actorOf(
Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId, sparkHostPort, cores),
name = "Executor")
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index a7b2328a02..c1b57f74d7 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -57,7 +57,7 @@ private[spark] class Executor(
Utils.setCustomHostname(slaveHostname)
// Set spark.* properties from executor arg
- val conf = new SparkConf(false)
+ val conf = new SparkConf(true)
conf.setAll(properties)
// If we are in yarn mode, systems can have different disk layouts so we must set it
@@ -279,7 +279,7 @@ private[spark] class Executor(
//System.exit(1)
}
} finally {
- // TODO: Unregister shuffle memory only for ShuffleMapTask
+ // TODO: Unregister shuffle memory only for ResultTask
val shuffleMemoryMap = env.shuffleMemoryMap
shuffleMemoryMap.synchronized {
shuffleMemoryMap.remove(Thread.currentThread().getId)
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index bb1471d9ee..0c8f4662a5 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -49,6 +49,16 @@ class TaskMetrics extends Serializable {
var resultSerializationTime: Long = _
/**
+ * The number of in-memory bytes spilled by this task
+ */
+ var memoryBytesSpilled: Long = _
+
+ /**
+ * The number of on-disk bytes spilled by this task
+ */
+ var diskBytesSpilled: Long = _
+
+ /**
* If this task reads from shuffle output, metrics on getting shuffle data will be collected here
*/
var shuffleReadMetrics: Option[ShuffleReadMetrics] = None
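The new memoryBytesSpilled and diskBytesSpilled fields expose the spill counters that the Aggregator and CoGroupedRDD changes in this merge start recording. A hedged sketch of reading them from a listener, assuming the SparkListenerTaskEnd event of this era carries the task's TaskMetrics and that SparkContext.addSparkListener is available:

    import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

    // Log per-task spill volume as tasks finish.
    class SpillLogger extends SparkListener {
      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
        val m = taskEnd.taskMetrics
        if (m != null && (m.memoryBytesSpilled > 0 || m.diskBytesSpilled > 0)) {
          println("spilled " + m.memoryBytesSpilled + " bytes in memory, " +
            m.diskBytesSpilled + " bytes on disk")
        }
      }
    }
    // Registered with something like: sc.addSparkListener(new SpillLogger())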
diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
index f736bb3713..fb4c65909a 100644
--- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
+++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
@@ -46,7 +46,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
throw new Exception("Max chunk size is " + maxChunkSize)
}
- if (size == 0 && gotChunkForSendingOnce == false) {
+ if (size == 0 && !gotChunkForSendingOnce) {
val newChunk = new MessageChunk(
new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
gotChunkForSendingOnce = true
diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala
index 95cb0206ac..cba8477ed5 100644
--- a/core/src/main/scala/org/apache/spark/network/Connection.scala
+++ b/core/src/main/scala/org/apache/spark/network/Connection.scala
@@ -330,7 +330,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
// Is highly unlikely unless there was an unclean close of socket, etc
registerInterest()
logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
- return true
+ true
} catch {
case e: Exception => {
logWarning("Error finishing connection to " + address, e)
@@ -385,7 +385,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
}
}
// should not happen - to keep scala compiler happy
- return true
+ true
}
// This is a hack to determine if remote socket was closed or not.
@@ -559,7 +559,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S
}
}
// should not happen - to keep scala compiler happy
- return true
+ true
}
def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback}
diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala
index f2ecc6d439..2612884bdb 100644
--- a/core/src/main/scala/org/apache/spark/network/Message.scala
+++ b/core/src/main/scala/org/apache/spark/network/Message.scala
@@ -61,7 +61,7 @@ private[spark] object Message {
if (dataBuffers.exists(_ == null)) {
throw new Exception("Attempting to create buffer message with null buffer")
}
- return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId)
+ new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId)
}
def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage =
@@ -69,9 +69,9 @@ private[spark] object Message {
def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = {
if (dataBuffer == null) {
- return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId)
+ createBufferMessage(Array(ByteBuffer.allocate(0)), ackId)
} else {
- return createBufferMessage(Array(dataBuffer), ackId)
+ createBufferMessage(Array(dataBuffer), ackId)
}
}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
index 546d921067..44204a8c46 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
@@ -64,7 +64,7 @@ private[spark] object ShuffleSender {
val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
val file = new File(subDir, blockId.name)
- return new FileSegment(file, 0, file.length())
+ new FileSegment(file, 0, file.length())
}
}
val sender = new ShuffleSender(port, pResovler)
diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala
index 70a5a8caff..2625a7f6a5 100644
--- a/core/src/main/scala/org/apache/spark/package.scala
+++ b/core/src/main/scala/org/apache/spark/package.scala
@@ -29,6 +29,9 @@ package org.apache
* be saved as SequenceFiles. These operations are automatically available on any RDD of the right
* type (e.g. RDD[(Int, Int)] through implicit conversions when you
* `import org.apache.spark.SparkContext._`.
+ *
+ * Java programmers should reference the [[spark.api.java]] package
+ * for Spark programming APIs in Java.
*/
package object spark {
// For package docs only
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index a73714abca..9c6b308804 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -106,6 +106,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
override val partitioner = Some(part)
override def compute(s: Partition, context: TaskContext): Iterator[(K, CoGroupCombiner)] = {
+
val externalSorting = sparkConf.getBoolean("spark.shuffle.externalSorting", true)
val split = s.asInstanceOf[CoGroupPartition]
val numRdds = split.deps.size
@@ -150,6 +151,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
map.insert(kv._1, new CoGroupValue(kv._2, depNum))
}
}
+ context.taskMetrics.memoryBytesSpilled = map.memoryBytesSpilled
+ context.taskMetrics.diskBytesSpilled = map.diskBytesSpilled
new InterruptibleIterator(context, map.iterator)
}
}
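
The two lines added above are what feed the new spill counters into TaskMetrics. A minimal sketch of consuming them from a custom listener, assuming only the fields this patch reads elsewhere (taskEnd.taskMetrics.memoryBytesSpilled / diskBytesSpilled); the registration call at the end is an assumed hook, not part of this diff:

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

class SpillLoggingListener extends SparkListener {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
    // taskMetrics can be null for failed tasks, hence the Option wrapper
    Option(taskEnd.taskMetrics).foreach { m =>
      println("spilled " + m.memoryBytesSpilled + " bytes in memory, " +
        m.diskBytesSpilled + " bytes on disk")
    }
  }
}

// sc.addSparkListener(new SpillLoggingListener)   // registration hook assumed to exist on SparkContext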
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
index 98da35763b..cefcc3d2d9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -295,10 +295,10 @@ private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanc
val prefPartActual = prefPart.get
- if (minPowerOfTwo.size + slack <= prefPartActual.size) // more imbalance than the slack allows
- return minPowerOfTwo // prefer balance over locality
- else {
- return prefPartActual // prefer locality over balance
+ if (minPowerOfTwo.size + slack <= prefPartActual.size) { // more imbalance than the slack allows
+ minPowerOfTwo // prefer balance over locality
+ } else {
+ prefPartActual // prefer locality over balance
}
}
@@ -331,7 +331,7 @@ private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanc
*/
def run(): Array[PartitionGroup] = {
setupGroups(math.min(prev.partitions.length, maxPartitions)) // setup the groups (bins)
- throwBalls() // assign partitions (balls) to each group (bins)
+ throwBalls() // assign partitions (balls) to each group (bins)
getPartitions
}
}
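
For context, PartitionCoalescer backs RDD.coalesce when no shuffle is requested; the slack comparison above decides whether a group is chosen for balance or for locality. A minimal usage sketch, assuming sc is an existing SparkContext:

val wide = sc.parallelize(1 to 1000, 100)
val narrow = wide.coalesce(10)   // shuffle is false by default; grouping is done by PartitionCoalescer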
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 2da4611b9c..dbe76f3431 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -45,14 +45,14 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp
val inputSplit = new SerializableWritable[InputSplit](s)
- override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt
+ override def hashCode(): Int = 41 * (41 + rddId) + idx
override val index: Int = idx
}
/**
* An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS,
- * sources in HBase, or S3).
+ * sources in HBase, or S3), using the older MapReduce API (`org.apache.hadoop.mapred`).
*
* @param sc The SparkContext to associate the RDD with.
* @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed
@@ -64,6 +64,11 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp
* @param keyClass Class of the key associated with the inputFormatClass.
* @param valueClass Class of the value associated with the inputFormatClass.
* @param minSplits Minimum number of Hadoop Splits (HadoopRDD partitions) to generate.
+ * @param cloneRecords If true, Spark will clone the records produced by Hadoop RecordReader.
+ * Most RecordReader implementations reuse wrapper objects across multiple
+ * records, which can cause problems in RDD collect or aggregation operations.
+ * By default the records are cloned in Spark. However, application
+ * programmers can explicitly disable the cloning for better performance.
*/
class HadoopRDD[K: ClassTag, V: ClassTag](
sc: SparkContext,
@@ -73,7 +78,7 @@ class HadoopRDD[K: ClassTag, V: ClassTag](
keyClass: Class[K],
valueClass: Class[V],
minSplits: Int,
- cloneKeyValues: Boolean)
+ cloneRecords: Boolean = true)
extends RDD[(K, V)](sc, Nil) with Logging {
def this(
@@ -83,7 +88,7 @@ class HadoopRDD[K: ClassTag, V: ClassTag](
keyClass: Class[K],
valueClass: Class[V],
minSplits: Int,
- cloneKeyValues: Boolean) = {
+ cloneRecords: Boolean) = {
this(
sc,
sc.broadcast(new SerializableWritable(conf))
@@ -93,7 +98,7 @@ class HadoopRDD[K: ClassTag, V: ClassTag](
keyClass,
valueClass,
minSplits,
- cloneKeyValues)
+ cloneRecords)
}
protected val jobConfCacheKey = "rdd_%d_job_conf".format(id)
@@ -105,11 +110,11 @@ class HadoopRDD[K: ClassTag, V: ClassTag](
val conf: Configuration = broadcastedConf.value.value
if (conf.isInstanceOf[JobConf]) {
// A user-broadcasted JobConf was provided to the HadoopRDD, so always use it.
- return conf.asInstanceOf[JobConf]
+ conf.asInstanceOf[JobConf]
} else if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) {
// getJobConf() has been called previously, so there is already a local cache of the JobConf
// needed by this RDD.
- return HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf]
+ HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf]
} else {
// Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the
// local process. The local cache is accessed through HadoopRDD.putCachedMetadata().
@@ -117,7 +122,7 @@ class HadoopRDD[K: ClassTag, V: ClassTag](
val newJobConf = new JobConf(broadcastedConf.value.value)
initLocalJobConfFuncOpt.map(f => f(newJobConf))
HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
- return newJobConf
+ newJobConf
}
}
@@ -133,7 +138,7 @@ class HadoopRDD[K: ClassTag, V: ClassTag](
newInputFormat.asInstanceOf[Configurable].setConf(conf)
}
HadoopRDD.putCachedMetadata(inputFormatCacheKey, newInputFormat)
- return newInputFormat
+ newInputFormat
}
override def getPartitions: Array[Partition] = {
@@ -165,9 +170,9 @@ class HadoopRDD[K: ClassTag, V: ClassTag](
// Register an on-task-completion callback to close the input stream.
context.addOnCompleteCallback{ () => closeIfNeeded() }
val key: K = reader.createKey()
- val keyCloneFunc = cloneWritables[K](getConf)
+ val keyCloneFunc = cloneWritables[K](jobConf)
val value: V = reader.createValue()
- val valueCloneFunc = cloneWritables[V](getConf)
+ val valueCloneFunc = cloneWritables[V](jobConf)
override def getNext() = {
try {
finished = !reader.next(key, value)
@@ -175,9 +180,8 @@ class HadoopRDD[K: ClassTag, V: ClassTag](
case eof: EOFException =>
finished = true
}
- if (cloneKeyValues) {
- (keyCloneFunc(key.asInstanceOf[Writable]),
- valueCloneFunc(value.asInstanceOf[Writable]))
+ if (cloneRecords) {
+ (keyCloneFunc(key.asInstanceOf[Writable]), valueCloneFunc(value.asInstanceOf[Writable]))
} else {
(key, value)
}
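
The renamed cloneRecords flag above exists because Hadoop RecordReaders typically reuse a single Writable instance across records. A small, Spark-independent sketch of the pitfall that the default (cloning enabled) guards against, using Hadoop's Text type:

import org.apache.hadoop.io.Text
import scala.collection.mutable.ArrayBuffer

val reused = new Text()                 // stands in for the reader's reused wrapper object
val buffered = new ArrayBuffer[Text]()
for (word <- Seq("a", "b", "c")) {
  reused.set(word)                      // the reader mutates the same instance in place
  buffered += reused                    // without cloning, every entry aliases one object
}
// buffered now holds three references whose value is "c"; cloning each record
// (e.g. new Text(reused)) before buffering is what cloneRecords = true provides.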
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index a34786495b..992bd4aa0a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -36,16 +36,31 @@ class NewHadoopPartition(rddId: Int, val index: Int, @transient rawSplit: InputS
val serializableHadoopSplit = new SerializableWritable(rawSplit)
- override def hashCode(): Int = (41 * (41 + rddId) + index)
+ override def hashCode(): Int = 41 * (41 + rddId) + index
}
+/**
+ * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS,
+ * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`).
+ *
+ * @param sc The SparkContext to associate the RDD with.
+ * @param inputFormatClass Storage format of the data to be read.
+ * @param keyClass Class of the key associated with the inputFormatClass.
+ * @param valueClass Class of the value associated with the inputFormatClass.
+ * @param conf The Hadoop configuration.
+ * @param cloneRecords If true, Spark will clone the records produced by Hadoop RecordReader.
+ * Most RecordReader implementations reuse wrapper objects across multiple
+ * records, which can cause problems in RDD collect or aggregation operations.
+ * By default the records are cloned in Spark. However, application
+ * programmers can explicitly disable the cloning for better performance.
+ */
class NewHadoopRDD[K: ClassTag, V: ClassTag](
sc : SparkContext,
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
@transient conf: Configuration,
- cloneKeyValues: Boolean)
+ cloneRecords: Boolean)
extends RDD[(K, V)](sc, Nil)
with SparkHadoopMapReduceUtil
with Logging {
@@ -112,9 +127,8 @@ class NewHadoopRDD[K: ClassTag, V: ClassTag](
havePair = false
val key = reader.getCurrentKey
val value = reader.getCurrentValue
- if (cloneKeyValues) {
- (keyCloneFunc(key.asInstanceOf[Writable]),
- valueCloneFunc(value.asInstanceOf[Writable]))
+ if (cloneRecords) {
+ (keyCloneFunc(key.asInstanceOf[Writable]), valueCloneFunc(value.asInstanceOf[Writable]))
} else {
(key, value)
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 1248409e35..4148581f52 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -88,20 +88,22 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (self.partitioner == Some(partitioner)) {
self.mapPartitionsWithContext((context, iter) => {
- new InterruptibleIterator(context, aggregator.combineValuesByKey(iter))
+ new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
}, preservesPartitioning = true)
} else if (mapSideCombine) {
- val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
+ val combined = self.mapPartitionsWithContext((context, iter) => {
+ aggregator.combineValuesByKey(iter, context)
+ }, preservesPartitioning = true)
val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
.setSerializer(serializerClass)
partitioned.mapPartitionsWithContext((context, iter) => {
- new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter))
+ new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter, context))
}, preservesPartitioning = true)
} else {
// Don't apply map-side combiner.
val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
values.mapPartitionsWithContext((context, iter) => {
- new InterruptibleIterator(context, aggregator.combineValuesByKey(iter))
+ new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
}, preservesPartitioning = true)
}
}
@@ -286,7 +288,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
if (getKeyClass().isArray && partitioner.isInstanceOf[HashPartitioner]) {
throw new SparkException("Default partitioner cannot partition array keys.")
}
- new ShuffledRDD[K, V, (K, V)](self, partitioner)
+ if (self.partitioner == partitioner) self else new ShuffledRDD[K, V, (K, V)](self, partitioner)
}
/**
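
The combineByKey path rewritten above now passes the TaskContext into the Aggregator so spill metrics can be recorded; the user-facing API is unchanged. A minimal usage sketch of that API, assuming sc is an existing SparkContext:

import org.apache.spark.SparkContext._   // enables the PairRDDFunctions implicits

val pairs = sc.parallelize(Seq("a" -> 1, "b" -> 2, "a" -> 3))
val sums = pairs.combineByKey[Int](
  (v: Int) => v,                    // createCombiner
  (c: Int, v: Int) => c + v,        // mergeValue
  (c1: Int, c2: Int) => c1 + c2)    // mergeCombiners
// sums.collect() yields ("a", 4) and ("b", 2) in some order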
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index 1dbbe39898..8ef919c4b5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -27,7 +27,6 @@ import scala.io.Source
import scala.reflect.ClassTag
import org.apache.spark.{SparkEnv, Partition, TaskContext}
-import org.apache.spark.broadcast.Broadcast
/**
@@ -96,7 +95,7 @@ class PipedRDD[T: ClassTag](
// Return an iterator that read lines from the process's stdout
val lines = Source.fromInputStream(proc.getInputStream).getLines
- return new Iterator[String] {
+ new Iterator[String] {
def next() = lines.next()
def hasNext = {
if (lines.hasNext) {
@@ -113,7 +112,7 @@ class PipedRDD[T: ClassTag](
}
}
-object PipedRDD {
+private object PipedRDD {
// Split a string into words using a standard StringTokenizer
def tokenize(command: String): Seq[String] = {
val buf = new ArrayBuffer[String]
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index f9dc12eee3..cd90a1561a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -549,6 +549,11 @@ abstract class RDD[T: ClassTag](
* of elements in each partition.
*/
def zipPartitions[B: ClassTag, V: ClassTag]
+ (rdd2: RDD[B], preservesPartitioning: Boolean)
+ (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] =
+ new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2, preservesPartitioning)
+
+ def zipPartitions[B: ClassTag, V: ClassTag]
(rdd2: RDD[B])
(f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] =
new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2, false)
@@ -764,7 +769,7 @@ abstract class RDD[T: ClassTag](
val entry = iter.next()
m1.put(entry.getKey, m1.getLong(entry.getKey) + entry.getLongValue)
}
- return m1
+ m1
}
val myResult = mapPartitions(countPartition).reduce(mergeMaps)
myResult.asInstanceOf[java.util.Map[T, Long]] // Will be wrapped as a Scala mutable Map
@@ -842,7 +847,7 @@ abstract class RDD[T: ClassTag](
partsScanned += numPartsToTry
}
- return buf.toArray
+ buf.toArray
}
/**
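
The new zipPartitions overload above lets callers declare that partitioning is preserved. A minimal sketch, assuming sc is an existing SparkContext; both RDDs must have the same number of partitions:

val nums  = sc.parallelize(1 to 8, 4)
val names = sc.parallelize(Seq("a", "b", "c", "d", "e", "f", "g", "h"), 4)
val zipped = nums.zipPartitions(names, preservesPartitioning = true) {
  (is, ss) => is.zip(ss)            // combine the two partition iterators element-wise
}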
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 38b536023b..7046c06d20 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -133,7 +133,8 @@ class DAGScheduler(
private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
- private[spark] val listenerBus = new SparkListenerBus()
+ // An async scheduler event bus. The bus should be stopped when DAGScheduler is stopped.
+ private[spark] val listenerBus = new SparkListenerBus
// Contains the locations that each RDD's partitions are cached on
private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]]
@@ -1121,5 +1122,6 @@ class DAGScheduler(
}
metadataCleaner.cancel()
taskSched.stop()
+ listenerBus.stop()
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
index 90eb8a747f..cc10cc0849 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
@@ -103,7 +103,7 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, split)
}
- return retval.toSet
+ retval.toSet
}
// This method does not expect failures, since validate has already passed ...
@@ -121,18 +121,18 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
elem => retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, elem)
)
- return retval.toSet
+ retval.toSet
}
private def findPreferredLocations(): Set[SplitInfo] = {
logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat +
", inputFormatClazz : " + inputFormatClazz)
if (mapreduceInputFormat) {
- return prefLocsFromMapreduceInputFormat()
+ prefLocsFromMapreduceInputFormat()
}
else {
assert(mapredInputFormat)
- return prefLocsFromMapredInputFormat()
+ prefLocsFromMapredInputFormat()
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
index 1791242215..4bc13c23d9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
@@ -75,12 +75,12 @@ private[spark] class Pool(
return schedulableNameToSchedulable(schedulableName)
}
for (schedulable <- schedulableQueue) {
- var sched = schedulable.getSchedulableByName(schedulableName)
+ val sched = schedulable.getSchedulableByName(schedulableName)
if (sched != null) {
return sched
}
}
- return null
+ null
}
override def executorLost(executorId: String, host: String) {
@@ -92,7 +92,7 @@ private[spark] class Pool(
for (schedulable <- schedulableQueue) {
shouldRevive |= schedulable.checkSpeculatableTasks()
}
- return shouldRevive
+ shouldRevive
}
override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
@@ -101,7 +101,7 @@ private[spark] class Pool(
for (schedulable <- sortedSchedulableQueue) {
sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue()
}
- return sortedTaskSetQueue
+ sortedTaskSetQueue
}
def increaseRunningTasks(taskNum: Int) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala
index 3418640b8c..5e62c8468f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala
@@ -37,9 +37,9 @@ private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm {
res = math.signum(stageId1 - stageId2)
}
if (res < 0) {
- return true
+ true
} else {
- return false
+ false
}
}
}
@@ -56,7 +56,6 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm {
val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0).toDouble
val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble
val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble
- var res:Boolean = true
var compare:Int = 0
if (s1Needy && !s2Needy) {
@@ -70,11 +69,11 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm {
}
if (compare < 0) {
- return true
+ true
} else if (compare > 0) {
- return false
+ false
} else {
- return s1.name < s2.name
+ s1.name < s2.name
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 627995c826..d8e97c3b7c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler
import java.util.Properties
import org.apache.spark.util.{Utils, Distribution}
-import org.apache.spark.{Logging, SparkContext, TaskEndReason}
+import org.apache.spark.{Logging, TaskEndReason}
import org.apache.spark.executor.TaskMetrics
sealed trait SparkListenerEvents
@@ -27,7 +27,7 @@ sealed trait SparkListenerEvents
case class SparkListenerStageSubmitted(stage: StageInfo, properties: Properties)
extends SparkListenerEvents
-case class SparkListenerStageCompleted(val stage: StageInfo) extends SparkListenerEvents
+case class SparkListenerStageCompleted(stage: StageInfo) extends SparkListenerEvents
case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
@@ -43,6 +43,12 @@ case class SparkListenerJobStart(job: ActiveJob, stageIds: Array[Int], propertie
case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult)
extends SparkListenerEvents
+/** An event used in the listener to shut down the listener daemon thread. */
+private[scheduler] case object SparkListenerShutdown extends SparkListenerEvents
+
+/**
+ * Interface for listening to events from the Spark scheduler.
+ */
trait SparkListener {
/**
* Called when a stage is completed, with information on the completed stage
@@ -112,7 +118,7 @@ class StatsReportListener extends SparkListener with Logging {
}
-object StatsReportListener extends Logging {
+private[spark] object StatsReportListener extends Logging {
//for profiling, the extremes are more interesting
val percentiles = Array[Int](0,5,10,25,50,75,90,95,100)
@@ -199,9 +205,9 @@ object StatsReportListener extends Logging {
}
}
+private case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double)
-case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double)
-object RuntimePercentage {
+private object RuntimePercentage {
def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = {
val denom = totalTime.toDouble
val fetchTime = metrics.shuffleReadMetrics.map{_.fetchWaitTime}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index e7defd768b..17b1328b86 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -24,15 +24,17 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
import org.apache.spark.Logging
/** Asynchronously passes SparkListenerEvents to registered SparkListeners. */
-private[spark] class SparkListenerBus() extends Logging {
- private val sparkListeners = new ArrayBuffer[SparkListener]() with SynchronizedBuffer[SparkListener]
+private[spark] class SparkListenerBus extends Logging {
+ private val sparkListeners = new ArrayBuffer[SparkListener] with SynchronizedBuffer[SparkListener]
/* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than
* an OOM exception) if it's perpetually being added to more quickly than it's being drained. */
- private val EVENT_QUEUE_CAPACITY = 10000
+ private val EVENT_QUEUE_CAPACITY = 10000
private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents](EVENT_QUEUE_CAPACITY)
private var queueFullErrorMessageLogged = false
+ // Create a new daemon thread to listen for events. This thread is stopped when it receives
+ // a SparkListenerShutdown event, using the stop method.
new Thread("SparkListenerBus") {
setDaemon(true)
override def run() {
@@ -53,6 +55,9 @@ private[spark] class SparkListenerBus() extends Logging {
sparkListeners.foreach(_.onTaskGettingResult(taskGettingResult))
case taskEnd: SparkListenerTaskEnd =>
sparkListeners.foreach(_.onTaskEnd(taskEnd))
+ case SparkListenerShutdown =>
+ // Get out of the while loop and shutdown the daemon thread
+ return
case _ =>
}
}
@@ -80,7 +85,7 @@ private[spark] class SparkListenerBus() extends Logging {
*/
def waitUntilEmpty(timeoutMillis: Int): Boolean = {
val finishTime = System.currentTimeMillis + timeoutMillis
- while (!eventQueue.isEmpty()) {
+ while (!eventQueue.isEmpty) {
if (System.currentTimeMillis > finishTime) {
return false
}
@@ -88,6 +93,8 @@ private[spark] class SparkListenerBus() extends Logging {
* add overhead in the general case. */
Thread.sleep(10)
}
- return true
+ true
}
+
+ def stop(): Unit = post(SparkListenerShutdown)
}
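
The stop() method added above works by posting a dedicated shutdown event that the daemon thread treats as a poison pill. A generic sketch of the same pattern, independent of the Spark classes (all names here are illustrative only):

import java.util.concurrent.LinkedBlockingQueue

sealed trait Event
case class Payload(value: String) extends Event
case object Shutdown extends Event            // plays the role of SparkListenerShutdown

val queue = new LinkedBlockingQueue[Event](10000)
val consumer = new Thread("event-bus") {
  setDaemon(true)
  override def run() {
    while (true) {
      queue.take() match {
        case Payload(v) => println("handled " + v)
        case Shutdown   => return              // leave the loop; the daemon thread exits
      }
    }
  }
}
consumer.start()
queue.put(Payload("task-end"))
queue.put(Shutdown)                            // the equivalent of listenerBus.stop()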
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 7cb3fe46e5..c60e9896de 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -96,7 +96,7 @@ private[spark] class Stage(
def newAttemptId(): Int = {
val id = nextAttemptId
nextAttemptId += 1
- return id
+ id
}
val name = callSite.getOrElse(rdd.origin)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index e80cc6b0f6..9d3e615826 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -74,6 +74,6 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long
def value(): T = {
val resultSer = SparkEnv.get.serializer.newInstance()
- return resultSer.deserialize(valueBytes)
+ resultSer.deserialize(valueBytes)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index c52d6175d2..35e9544718 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -37,7 +37,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
protected val serializer = new ThreadLocal[SerializerInstance] {
override def initialValue(): SerializerInstance = {
- return sparkEnv.closureSerializer.newInstance()
+ sparkEnv.closureSerializer.newInstance()
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index d4f74d3e18..6cc608ea5b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -352,9 +352,8 @@ private[spark] class TaskSchedulerImpl(
taskResultGetter.stop()
}
- // sleeping for an arbitrary 5 seconds : to ensure that messages are sent out.
- // TODO: Do something better !
- Thread.sleep(5000L)
+ // sleeping for an arbitrary 1 second to ensure that messages are sent out.
+ Thread.sleep(1000L)
}
override def defaultParallelism() = backend.defaultParallelism()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index a10e5397ad..fc0ee07089 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -228,7 +228,7 @@ private[spark] class TaskSetManager(
return Some(index)
}
}
- return None
+ None
}
/** Check whether a task is currently running an attempt on a given host */
@@ -291,7 +291,7 @@ private[spark] class TaskSetManager(
}
}
- return None
+ None
}
/**
@@ -332,7 +332,7 @@ private[spark] class TaskSetManager(
}
// Finally, if all else has failed, find a speculative task
- return findSpeculativeTask(execId, host, locality)
+ findSpeculativeTask(execId, host, locality)
}
/**
@@ -387,7 +387,7 @@ private[spark] class TaskSetManager(
case _ =>
}
}
- return None
+ None
}
/**
@@ -584,7 +584,7 @@ private[spark] class TaskSetManager(
}
override def getSchedulableByName(name: String): Schedulable = {
- return null
+ null
}
override def addSchedulable(schedulable: Schedulable) {}
@@ -594,7 +594,7 @@ private[spark] class TaskSetManager(
override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
var sortedTaskSetQueue = ArrayBuffer[TaskSetManager](this)
sortedTaskSetQueue += this
- return sortedTaskSetQueue
+ sortedTaskSetQueue
}
/** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */
@@ -669,7 +669,7 @@ private[spark] class TaskSetManager(
}
}
}
- return foundTasks
+ foundTasks
}
private def getLocalityWait(level: TaskLocality.TaskLocality): Long = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 8d596a76c2..0208388e86 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -165,7 +165,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
override def start() {
val properties = new ArrayBuffer[(String, String)]
for ((key, value) <- scheduler.sc.conf.getAll) {
- if (key.startsWith("spark.") && !key.equals("spark.hostPort")) {
+ if (key.startsWith("spark.")) {
properties += ((key, value))
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index e16d60c54c..c27049bdb5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -140,7 +140,7 @@ private[spark] class CoarseMesosSchedulerBackend(
.format(basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
- return command.build()
+ command.build()
}
override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index b428c82a48..49781485d9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -141,13 +141,13 @@ private[spark] class MesosSchedulerBackend(
// Serialize the map as an array of (String, String) pairs
execArgs = Utils.serialize(props.toArray)
}
- return execArgs
+ execArgs
}
private def setClassLoader(): ClassLoader = {
val oldClassLoader = Thread.currentThread.getContextClassLoader
Thread.currentThread.setContextClassLoader(classLoader)
- return oldClassLoader
+ oldClassLoader
}
private def restoreClassLoader(oldClassLoader: ClassLoader) {
@@ -255,7 +255,7 @@ private[spark] class MesosSchedulerBackend(
.setType(Value.Type.SCALAR)
.setScalar(Value.Scalar.newBuilder().setValue(1).build())
.build()
- return MesosTaskInfo.newBuilder()
+ MesosTaskInfo.newBuilder()
.setTaskId(taskId)
.setSlaveId(SlaveID.newBuilder().setValue(slaveId).build())
.setExecutor(createExecutorInfo(slaveId))
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index ff9f241fc1..6461deee32 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -80,11 +80,11 @@ private[spark] class BlockManager(
val compressShuffle = conf.getBoolean("spark.shuffle.compress", true)
// Whether to compress RDD partitions that are stored serialized
val compressRdds = conf.getBoolean("spark.rdd.compress", false)
+ // Whether to compress shuffle output temporarily spilled to disk
+ val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", false)
val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf)
- val hostPort = Utils.localHostPort(conf)
-
val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
@@ -412,7 +412,7 @@ private[spark] class BlockManager(
logDebug("The value of block " + blockId + " is null")
}
logDebug("Block " + blockId + " not found")
- return None
+ None
}
/**
@@ -792,6 +792,7 @@ private[spark] class BlockManager(
case ShuffleBlockId(_, _, _) => compressShuffle
case BroadcastBlockId(_) => compressBroadcast
case RDDBlockId(_, _) => compressRdds
+ case TempBlockId(_) => compressShuffleSpill
case _ => false
}
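
The new spark.shuffle.spill.compress flag read above controls whether shuffle data temporarily spilled to disk (TempBlockId blocks) is compressed; per the getBoolean call it defaults to false. A hedged configuration sketch:

import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setMaster("local")
  .setAppName("spill-compress-sketch")
  .set("spark.shuffle.spill.compress", "true")   // compress spilled shuffle blocks
val sc = new SparkContext(conf)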
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
index 21f003609b..42f52d7b26 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
@@ -42,15 +42,15 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage)
logDebug("Parsed as a block message array")
val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get)
- return Some(new BlockMessageArray(responseMessages).toBufferMessage)
+ Some(new BlockMessageArray(responseMessages).toBufferMessage)
} catch {
case e: Exception => logError("Exception handling buffer message", e)
- return None
+ None
}
}
case otherMessage: Any => {
logError("Unknown type message received: " + otherMessage)
- return None
+ None
}
}
}
@@ -61,7 +61,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel)
logDebug("Received [" + pB + "]")
putBlock(pB.id, pB.data, pB.level)
- return None
+ None
}
case BlockMessage.TYPE_GET_BLOCK => {
val gB = new GetBlock(blockMessage.getId)
@@ -70,9 +70,9 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
if (buffer == null) {
return None
}
- return Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer)))
+ Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer)))
}
- case _ => return None
+ case _ => None
}
}
@@ -93,7 +93,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
}
logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs)
+ " and got buffer " + buffer)
- return buffer
+ buffer
}
}
@@ -111,7 +111,7 @@ private[spark] object BlockManagerWorker extends Logging {
val blockMessageArray = new BlockMessageArray(blockMessage)
val resultMessage = connectionManager.sendMessageReliablySync(
toConnManagerId, blockMessageArray.toBufferMessage)
- return (resultMessage != None)
+ resultMessage != None
}
def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = {
@@ -130,8 +130,8 @@ private[spark] object BlockManagerWorker extends Logging {
return blockMessage.getData
})
}
- case None => logDebug("No response message received"); return null
+ case None => logDebug("No response message received")
}
- return null
+ null
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
index 80dcb5a207..fbafcf79d2 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
@@ -154,7 +154,7 @@ private[spark] class BlockMessage() {
println()
*/
val finishTime = System.currentTimeMillis
- return Message.createBufferMessage(buffers)
+ Message.createBufferMessage(buffers)
}
override def toString: String = {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
index a06f50a0ac..59329361f3 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
@@ -96,7 +96,7 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockM
println()
println()
*/
- return Message.createBufferMessage(buffers)
+ Message.createBufferMessage(buffers)
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 369a277232..48cec4be41 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -32,7 +32,7 @@ import org.apache.spark.serializer.{SerializationStream, Serializer}
*
* This interface does not support concurrent writes.
*/
-abstract class BlockObjectWriter(val blockId: BlockId) {
+private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
def open(): BlockObjectWriter
@@ -69,7 +69,7 @@ abstract class BlockObjectWriter(val blockId: BlockId) {
}
/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
-class DiskBlockObjectWriter(
+private[spark] class DiskBlockObjectWriter(
blockId: BlockId,
file: File,
serializer: Serializer,
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index 05f676c6e2..27f057b9f2 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -245,7 +245,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
return false
}
}
- return true
+ true
}
override def contains(blockId: BlockId): Boolean = {
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index 6e0ff143b7..e2b24298a5 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -64,7 +64,7 @@ class ShuffleBlockManager(blockManager: BlockManager) {
// Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
// TODO: Remove this once the shuffle file consolidation feature is stable.
val consolidateShuffleFiles =
- conf.getBoolean("spark.shuffle.consolidateFiles", true)
+ conf.getBoolean("spark.shuffle.consolidateFiles", false)
private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
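
The default flipped above turns shuffle file consolidation off while the feature stabilizes; it can still be enabled explicitly. A hedged configuration sketch using the two property names visible in this hunk:

val conf = new org.apache.spark.SparkConf()
  .set("spark.shuffle.consolidateFiles", "true")   // opt back in explicitly
  .set("spark.shuffle.file.buffer.kb", "100")      // per-writer buffer size read just below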
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
index b5596dffd3..1b7934d59f 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
@@ -74,7 +74,7 @@ class StorageLevel private(
if (deserialized_) {
ret |= 1
}
- return ret
+ ret
}
override def writeExternal(out: ObjectOutput) {
@@ -108,6 +108,10 @@ class StorageLevel private(
}
+/**
+ * The various [[org.apache.spark.storage.StorageLevel]] constants, plus utility functions for creating
+ * new storage levels.
+ */
object StorageLevel {
val NONE = new StorageLevel(false, false, false)
val DISK_ONLY = new StorageLevel(true, false, false)
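
A minimal usage sketch for the constants documented above, assuming sc is an existing SparkContext:

import org.apache.spark.storage.StorageLevel

val data = sc.parallelize(1 to 100)
data.persist(StorageLevel.MEMORY_AND_DISK)   // keep partitions in memory, spill to disk if needed
data.count()                                  // the first action materializes the cache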
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala
index 3c53e88380..64e22a30b4 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala
@@ -24,4 +24,6 @@ private[spark] class ExecutorSummary {
var succeededTasks : Int = 0
var shuffleRead : Long = 0
var shuffleWrite : Long = 0
+ var memoryBytesSpilled : Long = 0
+ var diskBytesSpilled : Long = 0
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
index 0dd876480a..ab03eb5ce1 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
@@ -48,6 +48,8 @@ private[spark] class ExecutorTable(val parent: JobProgressUI, val stageId: Int)
<th>Succeeded Tasks</th>
<th>Shuffle Read</th>
<th>Shuffle Write</th>
+ <th>Shuffle Spill (Memory)</th>
+ <th>Shuffle Spill (Disk)</th>
</thead>
<tbody>
{createExecutorTable()}
@@ -80,6 +82,8 @@ private[spark] class ExecutorTable(val parent: JobProgressUI, val stageId: Int)
<td>{v.succeededTasks}</td>
<td>{Utils.bytesToString(v.shuffleRead)}</td>
<td>{Utils.bytesToString(v.shuffleWrite)}</td>
+ <td>{Utils.bytesToString(v.memoryBytesSpilled)}</td>
+ <td>{Utils.bytesToString(v.diskBytesSpilled)}</td>
</tr>
}
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index bcd2824450..858a10ce75 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -52,6 +52,8 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
val stageIdToTime = HashMap[Int, Long]()
val stageIdToShuffleRead = HashMap[Int, Long]()
val stageIdToShuffleWrite = HashMap[Int, Long]()
+ val stageIdToMemoryBytesSpilled = HashMap[Int, Long]()
+ val stageIdToDiskBytesSpilled = HashMap[Int, Long]()
val stageIdToTasksActive = HashMap[Int, HashSet[TaskInfo]]()
val stageIdToTasksComplete = HashMap[Int, Int]()
val stageIdToTasksFailed = HashMap[Int, Int]()
@@ -78,6 +80,8 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
stageIdToTime.remove(s.stageId)
stageIdToShuffleRead.remove(s.stageId)
stageIdToShuffleWrite.remove(s.stageId)
+ stageIdToMemoryBytesSpilled.remove(s.stageId)
+ stageIdToDiskBytesSpilled.remove(s.stageId)
stageIdToTasksActive.remove(s.stageId)
stageIdToTasksComplete.remove(s.stageId)
stageIdToTasksFailed.remove(s.stageId)
@@ -149,6 +153,8 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
Option(taskEnd.taskMetrics).foreach { taskMetrics =>
taskMetrics.shuffleReadMetrics.foreach { y.shuffleRead += _.remoteBytesRead }
taskMetrics.shuffleWriteMetrics.foreach { y.shuffleWrite += _.shuffleBytesWritten }
+ y.memoryBytesSpilled += taskMetrics.memoryBytesSpilled
+ y.diskBytesSpilled += taskMetrics.diskBytesSpilled
}
}
case _ => {}
@@ -184,6 +190,14 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
stageIdToShuffleWrite(sid) += shuffleWrite
totalShuffleWrite += shuffleWrite
+ stageIdToMemoryBytesSpilled.getOrElseUpdate(sid, 0L)
+ val memoryBytesSpilled = metrics.map(m => m.memoryBytesSpilled).getOrElse(0L)
+ stageIdToMemoryBytesSpilled(sid) += memoryBytesSpilled
+
+ stageIdToDiskBytesSpilled.getOrElseUpdate(sid, 0L)
+ val diskBytesSpilled = metrics.map(m => m.diskBytesSpilled).getOrElse(0L)
+ stageIdToDiskBytesSpilled(sid) += diskBytesSpilled
+
val taskList = stageIdToTaskInfos.getOrElse(
sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
taskList -= ((taskEnd.taskInfo, None, None))
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index d1e58016be..cfaf121895 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -56,6 +56,9 @@ private[spark] class StagePage(parent: JobProgressUI) {
val hasShuffleRead = shuffleReadBytes > 0
val shuffleWriteBytes = listener.stageIdToShuffleWrite.getOrElse(stageId, 0L)
val hasShuffleWrite = shuffleWriteBytes > 0
+ val memoryBytesSpilled = listener.stageIdToMemoryBytesSpilled.getOrElse(stageId, 0L)
+ val diskBytesSpilled = listener.stageIdToDiskBytesSpilled.getOrElse(stageId, 0L)
+ val hasBytesSpilled = (memoryBytesSpilled > 0 && diskBytesSpilled > 0)
var activeTime = 0L
listener.stageIdToTasksActive(stageId).foreach(activeTime += _.timeRunning(now))
@@ -81,6 +84,16 @@ private[spark] class StagePage(parent: JobProgressUI) {
{Utils.bytesToString(shuffleWriteBytes)}
</li>
}
+ {if (hasBytesSpilled)
+ <li>
+ <strong>Shuffle spill (memory): </strong>
+ {Utils.bytesToString(memoryBytesSpilled)}
+ </li>
+ <li>
+ <strong>Shuffle spill (disk): </strong>
+ {Utils.bytesToString(diskBytesSpilled)}
+ </li>
+ }
</ul>
</div>
@@ -89,9 +102,10 @@ private[spark] class StagePage(parent: JobProgressUI) {
Seq("Duration", "GC Time", "Result Ser Time") ++
{if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++
{if (hasShuffleWrite) Seq("Write Time", "Shuffle Write") else Nil} ++
+ {if (hasBytesSpilled) Seq("Shuffle Spill (Memory)", "Shuffle Spill (Disk)") else Nil} ++
Seq("Errors")
- val taskTable = listingTable(taskHeaders, taskRow(hasShuffleRead, hasShuffleWrite), tasks)
+ val taskTable = listingTable(taskHeaders, taskRow(hasShuffleRead, hasShuffleWrite, hasBytesSpilled), tasks)
// Excludes tasks which failed and have incomplete metrics
val validTasks = tasks.filter(t => t._1.status == "SUCCESS" && (t._2.isDefined))
@@ -153,13 +167,29 @@ private[spark] class StagePage(parent: JobProgressUI) {
}
val shuffleWriteQuantiles = "Shuffle Write" +: getQuantileCols(shuffleWriteSizes)
+ val memoryBytesSpilledSizes = validTasks.map {
+ case(info, metrics, exception) =>
+ metrics.get.memoryBytesSpilled.toDouble
+ }
+ val memoryBytesSpilledQuantiles = "Shuffle spill (memory)" +:
+ getQuantileCols(memoryBytesSpilledSizes)
+
+ val diskBytesSpilledSizes = validTasks.map {
+ case(info, metrics, exception) =>
+ metrics.get.diskBytesSpilled.toDouble
+ }
+ val diskBytesSpilledQuantiles = "Shuffle spill (disk)" +:
+ getQuantileCols(diskBytesSpilledSizes)
+
val listings: Seq[Seq[String]] = Seq(
serializationQuantiles,
serviceQuantiles,
gettingResultQuantiles,
schedulerDelayQuantiles,
if (hasShuffleRead) shuffleReadQuantiles else Nil,
- if (hasShuffleWrite) shuffleWriteQuantiles else Nil)
+ if (hasShuffleWrite) shuffleWriteQuantiles else Nil,
+ if (hasBytesSpilled) memoryBytesSpilledQuantiles else Nil,
+ if (hasBytesSpilled) diskBytesSpilledQuantiles else Nil)
val quantileHeaders = Seq("Metric", "Min", "25th percentile",
"Median", "75th percentile", "Max")
@@ -178,8 +208,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
}
}
-
- def taskRow(shuffleRead: Boolean, shuffleWrite: Boolean)
+ def taskRow(shuffleRead: Boolean, shuffleWrite: Boolean, bytesSpilled: Boolean)
(taskData: (TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])): Seq[Node] = {
def fmtStackTrace(trace: Seq[StackTraceElement]): Seq[Node] =
trace.map(e => <span style="display:block;">{e.toString}</span>)
@@ -205,6 +234,14 @@ private[spark] class StagePage(parent: JobProgressUI) {
val writeTimeReadable = maybeWriteTime.map{ t => t / (1000 * 1000)}.map{ ms =>
if (ms == 0) "" else parent.formatDuration(ms)}.getOrElse("")
+ val maybeMemoryBytesSpilled = metrics.map{m => m.memoryBytesSpilled}
+ val memoryBytesSpilledSortable = maybeMemoryBytesSpilled.map(_.toString).getOrElse("")
+ val memoryBytesSpilledReadable = maybeMemoryBytesSpilled.map{Utils.bytesToString(_)}.getOrElse("")
+
+ val maybeDiskBytesSpilled = metrics.map{m => m.diskBytesSpilled}
+ val diskBytesSpilledSortable = maybeDiskBytesSpilled.map(_.toString).getOrElse("")
+ val diskBytesSpilledReadable = maybeDiskBytesSpilled.map{Utils.bytesToString(_)}.getOrElse("")
+
<tr>
<td>{info.index}</td>
<td>{info.taskId}</td>
@@ -234,6 +271,14 @@ private[spark] class StagePage(parent: JobProgressUI) {
{shuffleWriteReadable}
</td>
}}
+ {if (bytesSpilled) {
+ <td sorttable_customkey={memoryBytesSpilledSortable}>
+ {memoryBytesSpilledReadable}
+ </td>
+ <td sorttable_customkey={diskBytesSpilledSortable}>
+ {diskBytesSpilledReadable}
+ </td>
+ }}
<td>{exception.map(e =>
<span>
{e.className} ({e.description})<br/>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index 463d85dfd5..9ad6de3c6d 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -48,7 +48,7 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr
{if (isFairScheduler) {<th>Pool Name</th>} else {}}
<th>Description</th>
<th>Submitted</th>
- <th>Task Time</th>
+ <th>Duration</th>
<th>Tasks: Succeeded/Total</th>
<th>Shuffle Read</th>
<th>Shuffle Write</th>
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index 7108595e3e..1df6b87fb0 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -61,7 +61,7 @@ private[spark] object ClosureCleaner extends Logging {
return f.getType :: Nil // Stop at the first $outer that is not a closure
}
}
- return Nil
+ Nil
}
// Get a list of the outer objects for a given closure object.
@@ -74,7 +74,7 @@ private[spark] object ClosureCleaner extends Logging {
return f.get(obj) :: Nil // Stop at the first $outer that is not a closure
}
}
- return Nil
+ Nil
}
private def getInnerClasses(obj: AnyRef): List[Class[_]] = {
@@ -174,7 +174,7 @@ private[spark] object ClosureCleaner extends Logging {
field.setAccessible(true)
field.set(obj, outer)
}
- return obj
+ obj
}
}
}
@@ -182,7 +182,7 @@ private[spark] object ClosureCleaner extends Logging {
private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) {
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
- return new MethodVisitor(ASM4) {
+ new MethodVisitor(ASM4) {
override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) {
if (op == GETFIELD) {
for (cl <- output.keys if cl.getName == owner.replace('/', '.')) {
@@ -215,7 +215,7 @@ private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisi
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
- return new MethodVisitor(ASM4) {
+ new MethodVisitor(ASM4) {
override def visitMethodInsn(op: Int, owner: String, name: String,
desc: String) {
val argTypes = Type.getArgumentTypes(desc)
diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala
index dc15a38b29..fcc1ca9502 100644
--- a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala
+++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala
@@ -18,14 +18,15 @@
package org.apache.spark.util
/**
- * Wrapper around an iterator which calls a completion method after it successfully iterates through all the elements
+ * Wrapper around an iterator which calls a completion method after it successfully iterates
+ * through all the elements.
*/
-abstract class CompletionIterator[+A, +I <: Iterator[A]](sub: I) extends Iterator[A]{
- def next = sub.next
+private[spark] abstract class CompletionIterator[+A, +I <: Iterator[A]](sub: I) extends Iterator[A]{
+ def next() = sub.next()
def hasNext = {
val r = sub.hasNext
if (!r) {
- completion
+ completion()
}
r
}
@@ -33,7 +34,7 @@ abstract class CompletionIterator[+A, +I <: Iterator[A]](sub: I) extends Iterato
def completion()
}
-object CompletionIterator {
+private[spark] object CompletionIterator {
def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit) : CompletionIterator[A,I] = {
new CompletionIterator[A,I](sub) {
def completion() = completionFunction
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index ac07a55cb9..b0febe906a 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -18,13 +18,13 @@
package org.apache.spark.util
import java.util.{TimerTask, Timer}
-import org.apache.spark.{SparkConf, SparkContext, Logging}
+import org.apache.spark.{SparkConf, Logging}
/**
* Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries)
*/
-class MetadataCleaner(
+private[spark] class MetadataCleaner(
cleanerType: MetadataCleanerType.MetadataCleanerType,
cleanupFunc: (Long) => Unit,
conf: SparkConf)
@@ -60,7 +60,7 @@ class MetadataCleaner(
}
}
-object MetadataCleanerType extends Enumeration {
+private[spark] object MetadataCleanerType extends Enumeration {
val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK,
SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value
@@ -72,7 +72,7 @@ object MetadataCleanerType extends Enumeration {
// TODO: This mutates a Conf to set properties right now, which is kind of ugly when used in the
// initialization of StreamingContext. It's okay for users trying to configure stuff themselves.
-object MetadataCleaner {
+private[spark] object MetadataCleaner {
def getDelaySeconds(conf: SparkConf) = {
conf.getInt("spark.cleaner.ttl", -1)
}
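
MetadataCleaner above is driven by spark.cleaner.ttl in seconds; -1, the default read by getDelaySeconds, disables periodic cleanup. A hedged configuration sketch:

val conf = new org.apache.spark.SparkConf()
  .set("spark.cleaner.ttl", "3600")   // drop cached metadata older than one hour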
diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
index bddb3bb735..3cf94892e9 100644
--- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
+++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
@@ -108,7 +108,7 @@ private[spark] object SizeEstimator extends Logging {
val bean = ManagementFactory.newPlatformMXBeanProxy(server,
hotSpotMBeanName, hotSpotMBeanClass)
// TODO: We could use reflection on the VMOption returned ?
- return getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true")
+ getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true")
} catch {
case e: Exception => {
// Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB
@@ -141,7 +141,7 @@ private[spark] object SizeEstimator extends Logging {
def dequeue(): AnyRef = {
val elem = stack.last
stack.trimEnd(1)
- return elem
+ elem
}
}
@@ -162,7 +162,7 @@ private[spark] object SizeEstimator extends Logging {
while (!state.isFinished) {
visitSingleObject(state.dequeue(), state)
}
- return state.size
+ state.size
}
private def visitSingleObject(obj: AnyRef, state: SearchState) {
@@ -276,11 +276,11 @@ private[spark] object SizeEstimator extends Logging {
// Create and cache a new ClassInfo
val newInfo = new ClassInfo(shellSize, pointerFields)
classInfos.put(cls, newInfo)
- return newInfo
+ newInfo
}
private def alignSize(size: Long): Long = {
val rem = size % ALIGN_SIZE
- return if (rem == 0) size else (size + ALIGN_SIZE - rem)
+ if (rem == 0) size else (size + ALIGN_SIZE - rem)
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 23b72701c2..caa9bf4c92 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -73,14 +73,14 @@ private[spark] object Utils extends Logging {
val oos = new ObjectOutputStream(bos)
oos.writeObject(o)
oos.close()
- return bos.toByteArray
+ bos.toByteArray
}
/** Deserialize an object using Java serialization */
def deserialize[T](bytes: Array[Byte]): T = {
val bis = new ByteArrayInputStream(bytes)
val ois = new ObjectInputStream(bis)
- return ois.readObject.asInstanceOf[T]
+ ois.readObject.asInstanceOf[T]
}
/** Deserialize an object using Java serialization and the given ClassLoader */
@@ -90,7 +90,7 @@ private[spark] object Utils extends Logging {
override def resolveClass(desc: ObjectStreamClass) =
Class.forName(desc.getName, false, loader)
}
- return ois.readObject.asInstanceOf[T]
+ ois.readObject.asInstanceOf[T]
}
/** Deserialize a Long value (used for {@link org.apache.spark.api.python.PythonPartitioner}) */
@@ -168,7 +168,7 @@ private[spark] object Utils extends Logging {
i += 1
}
}
- return buf
+ buf
}
private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]()
@@ -420,15 +420,6 @@ private[spark] object Utils extends Logging {
InetAddress.getByName(address).getHostName
}
- def localHostPort(conf: SparkConf): String = {
- val retval = conf.get("spark.hostPort", null)
- if (retval == null) {
- logErrorWithStack("spark.hostPort not set but invoking localHostPort")
- return localHostName()
- }
- retval
- }
-
def checkHost(host: String, message: String = "") {
assert(host.indexOf(':') == -1, message)
}
@@ -437,14 +428,6 @@ private[spark] object Utils extends Logging {
assert(hostPort.indexOf(':') != -1, message)
}
- def logErrorWithStack(msg: String) {
- try {
- throw new Exception
- } catch {
- case ex: Exception => logError(msg, ex)
- }
- }
-
// Typically, this will be of order of number of nodes in cluster
// If not, we should change it to LRUCache or something.
private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]()
@@ -452,7 +435,7 @@ private[spark] object Utils extends Logging {
def parseHostPort(hostPort: String): (String, Int) = {
{
// Check cache first.
- var cached = hostPortParseResults.get(hostPort)
+ val cached = hostPortParseResults.get(hostPort)
if (cached != null) return cached
}
@@ -755,7 +738,7 @@ private[spark] object Utils extends Logging {
} catch {
case ise: IllegalStateException => return true
}
- return false
+ false
}
def isSpace(c: Char): Boolean = {
@@ -772,7 +755,7 @@ private[spark] object Utils extends Logging {
var inWord = false
var inSingleQuote = false
var inDoubleQuote = false
- var curWord = new StringBuilder
+ val curWord = new StringBuilder
def endWord() {
buf += curWord.toString
curWord.clear()
@@ -818,7 +801,7 @@ private[spark] object Utils extends Logging {
if (inWord || inDoubleQuote || inSingleQuote) {
endWord()
}
- return buf
+ buf
}
/* Calculates 'x' modulo 'mod', takes to consideration sign of x,
@@ -846,8 +829,7 @@ private[spark] object Utils extends Logging {
/** Returns a copy of the system properties that is thread-safe to iterator over. */
def getSystemProperties(): Map[String, String] = {
- return System.getProperties().clone()
- .asInstanceOf[java.util.Properties].toMap[String, String]
+ System.getProperties.clone().asInstanceOf[java.util.Properties].toMap[String, String]
}
/**
diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala
index 62fd6d8da5..fcdf848637 100644
--- a/core/src/main/scala/org/apache/spark/util/Vector.scala
+++ b/core/src/main/scala/org/apache/spark/util/Vector.scala
@@ -27,7 +27,7 @@ class Vector(val elements: Array[Double]) extends Serializable {
def + (other: Vector): Vector = {
if (length != other.length)
throw new IllegalArgumentException("Vectors of different length")
- return Vector(length, i => this(i) + other(i))
+ Vector(length, i => this(i) + other(i))
}
def add(other: Vector) = this + other
@@ -35,7 +35,7 @@ class Vector(val elements: Array[Double]) extends Serializable {
def - (other: Vector): Vector = {
if (length != other.length)
throw new IllegalArgumentException("Vectors of different length")
- return Vector(length, i => this(i) - other(i))
+ Vector(length, i => this(i) - other(i))
}
def subtract(other: Vector) = this - other
@@ -49,7 +49,7 @@ class Vector(val elements: Array[Double]) extends Serializable {
ans += this(i) * other(i)
i += 1
}
- return ans
+ ans
}
/**
@@ -69,7 +69,7 @@ class Vector(val elements: Array[Double]) extends Serializable {
ans += (this(i) + plus(i)) * other(i)
i += 1
}
- return ans
+ ans
}
def += (other: Vector): Vector = {
@@ -104,7 +104,7 @@ class Vector(val elements: Array[Double]) extends Serializable {
ans += (this(i) - other(i)) * (this(i) - other(i))
i += 1
}
- return ans
+ ans
}
def dist(other: Vector): Double = math.sqrt(squaredDist(other))
@@ -119,7 +119,7 @@ object Vector {
def apply(length: Int, initializer: Int => Double): Vector = {
val elements: Array[Double] = Array.tabulate(length)(initializer)
- return new Vector(elements)
+ new Vector(elements)
}
def zeros(length: Int) = new Vector(new Array[Double](length))
diff --git a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
index d98c7aa3d7..b8c852b4ff 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
@@ -75,7 +75,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K,
i += 1
}
}
- return null.asInstanceOf[V]
+ null.asInstanceOf[V]
}
/** Set the value for a key */
diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
index a1a452315d..856eb772a1 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
@@ -22,10 +22,72 @@ package org.apache.spark.util.collection
* A simple, fixed-size bit set implementation. This implementation is fast because it avoids
* safety/bound checking.
*/
-class BitSet(numBits: Int) {
+class BitSet(numBits: Int) extends Serializable {
- private[this] val words = new Array[Long](bit2words(numBits))
- private[this] val numWords = words.length
+ private val words = new Array[Long](bit2words(numBits))
+ private val numWords = words.length
+
+ /**
+ * Compute the capacity (number of bits) that can be represented
+ * by this bitset.
+ */
+ def capacity: Int = numWords * 64
+
+ /**
+ * Set all the bits up to a given index
+ */
+ def setUntil(bitIndex: Int) {
+ val wordIndex = bitIndex >> 6 // divide by 64
+ var i = 0
+ while(i < wordIndex) { words(i) = -1; i += 1 }
+ if(wordIndex < words.size) {
+ // Set the remaining bits (note that the mask could still be zero)
+ val mask = ~(-1L << (bitIndex & 0x3f))
+ words(wordIndex) |= mask
+ }
+ }
+
+ /**
+ * Compute the bit-wise AND of the two sets returning the
+ * result.
+ */
+ def &(other: BitSet): BitSet = {
+ val newBS = new BitSet(math.max(capacity, other.capacity))
+ val smaller = math.min(numWords, other.numWords)
+ assert(newBS.numWords >= numWords)
+ assert(newBS.numWords >= other.numWords)
+ var ind = 0
+ while( ind < smaller ) {
+ newBS.words(ind) = words(ind) & other.words(ind)
+ ind += 1
+ }
+ newBS
+ }
+
+ /**
+ * Compute the bit-wise OR of the two sets returning the
+ * result.
+ */
+ def |(other: BitSet): BitSet = {
+ val newBS = new BitSet(math.max(capacity, other.capacity))
+ assert(newBS.numWords >= numWords)
+ assert(newBS.numWords >= other.numWords)
+ val smaller = math.min(numWords, other.numWords)
+ var ind = 0
+ while( ind < smaller ) {
+ newBS.words(ind) = words(ind) | other.words(ind)
+ ind += 1
+ }
+ while( ind < numWords ) {
+ newBS.words(ind) = words(ind)
+ ind += 1
+ }
+ while( ind < other.numWords ) {
+ newBS.words(ind) = other.words(ind)
+ ind += 1
+ }
+ newBS
+ }
/**
* Sets the bit at the specified index to true.
@@ -36,6 +98,11 @@ class BitSet(numBits: Int) {
words(index >> 6) |= bitmask // div by 64 and mask
}
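+  /** Sets the bit at the specified index to false. */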
+ def unset(index: Int) {
+ val bitmask = 1L << (index & 0x3f) // mod 64 and shift
+ words(index >> 6) &= ~bitmask // div by 64 and mask
+ }
+
/**
* Return the value of the bit with the specified index. The value is true if the bit with
* the index is currently set in this BitSet; otherwise, the result is false.
@@ -48,6 +115,20 @@ class BitSet(numBits: Int) {
(words(index >> 6) & bitmask) != 0 // div by 64 and mask
}
+ /**
+ * Get an iterator over the set bits.
+ */
+ def iterator = new Iterator[Int] {
+ var ind = nextSetBit(0)
+ override def hasNext: Boolean = ind >= 0
+ override def next() = {
+ val tmp = ind
+ ind = nextSetBit(ind+1)
+ tmp
+ }
+ }
+
+
/** Return the number of bits set to true in this BitSet. */
def cardinality(): Int = {
var sum = 0
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index e3bcd895aa..64e9b436f0 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -26,8 +26,8 @@ import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{Logging, SparkEnv}
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.storage.{DiskBlockManager, DiskBlockObjectWriter}
+import org.apache.spark.serializer.{KryoDeserializationStream, KryoSerializationStream, Serializer}
+import org.apache.spark.storage.{BlockId, BlockManager, DiskBlockManager, DiskBlockObjectWriter}
/**
* An append-only map that spills sorted content to disk when there is insufficient space for it
@@ -60,7 +60,7 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C,
serializer: Serializer = SparkEnv.get.serializerManager.default,
- diskBlockManager: DiskBlockManager = SparkEnv.get.blockManager.diskBlockManager)
+ blockManager: BlockManager = SparkEnv.get.blockManager)
extends Iterable[(K, C)] with Serializable with Logging {
import ExternalAppendOnlyMap._
@@ -68,6 +68,7 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
private var currentMap = new SizeTrackingAppendOnlyMap[K, C]
private val spilledMaps = new ArrayBuffer[DiskMapIterator]
private val sparkConf = SparkEnv.get.conf
+ private val diskBlockManager = blockManager.diskBlockManager
// Collective memory threshold shared across all running tasks
private val maxMemoryThreshold = {
@@ -77,14 +78,26 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
}
// Number of pairs in the in-memory map
- private var numPairsInMemory = 0
+ private var numPairsInMemory = 0L
// Number of in-memory pairs inserted before tracking the map's shuffle memory usage
private val trackMemoryThreshold = 1000
+ // Size of object batches when reading/writing from serializers. Objects are written in
+ // batches, with each batch using its own serialization stream. This cuts down on the size
+ // of reference-tracking maps constructed when deserializing a stream.
+ //
+ // NOTE: Setting this too low can cause excess copying when serializing, since some serializers
+ // grow internal data structures by growing + copying every time the number of objects doubles.
+ private val serializerBatchSize = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000)
+
// How many times we have spilled so far
private var spillCount = 0
+ // Number of bytes spilled in total
+ private var _memoryBytesSpilled = 0L
+ private var _diskBytesSpilled = 0L
+
private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
private val syncWrites = sparkConf.getBoolean("spark.shuffle.sync", false)
private val comparator = new KCComparator[K, C]
@@ -139,21 +152,35 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
logWarning("Spilling in-memory map of %d MB to disk (%d time%s so far)"
.format(mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
val (blockId, file) = diskBlockManager.createTempBlock()
- val writer =
- new DiskBlockObjectWriter(blockId, file, serializer, fileBufferSize, identity, syncWrites)
+
+ val compressStream: OutputStream => OutputStream = blockManager.wrapForCompression(blockId, _)
+ def getNewWriter = new DiskBlockObjectWriter(blockId, file, serializer, fileBufferSize,
+ compressStream, syncWrites)
+
+ var writer = getNewWriter
+ var objectsWritten = 0
try {
val it = currentMap.destructiveSortedIterator(comparator)
while (it.hasNext) {
val kv = it.next()
writer.write(kv)
+ objectsWritten += 1
+
+ if (objectsWritten == serializerBatchSize) {
+ writer.commit()
+ writer = getNewWriter
+ objectsWritten = 0
+ }
}
- writer.commit()
+
+ if (objectsWritten > 0) writer.commit()
} finally {
// Partial failures cannot be tolerated; do not revert partial writes
+ _diskBytesSpilled += writer.bytesWritten
writer.close()
}
currentMap = new SizeTrackingAppendOnlyMap[K, C]
- spilledMaps.append(new DiskMapIterator(file))
+ spilledMaps.append(new DiskMapIterator(file, blockId))
// Reset the amount of shuffle memory used by this map in the global pool
val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
@@ -161,8 +188,12 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
shuffleMemoryMap(Thread.currentThread().getId) = 0
}
numPairsInMemory = 0
+ _memoryBytesSpilled += mapSize
}
+ def memoryBytesSpilled: Long = _memoryBytesSpilled
+ def diskBytesSpilled: Long = _diskBytesSpilled
+
/**
* Return an iterator that merges the in-memory map with the spilled maps.
* If no spill has occurred, simply return the in-memory map's iterator.
@@ -297,16 +328,35 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
/**
* An iterator that returns (K, C) pairs in sorted order from an on-disk map
*/
- private class DiskMapIterator(file: File) extends Iterator[(K, C)] {
+ private class DiskMapIterator(file: File, blockId: BlockId) extends Iterator[(K, C)] {
val fileStream = new FileInputStream(file)
- val bufferedStream = new FastBufferedInputStream(fileStream)
- val deserializeStream = ser.deserializeStream(bufferedStream)
+ val bufferedStream = new FastBufferedInputStream(fileStream, fileBufferSize)
+ val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream)
+ var deserializeStream = ser.deserializeStream(compressedStream)
+ var objectsRead = 0
+
var nextItem: (K, C) = null
var eof = false
def readNextItem(): (K, C) = {
if (!eof) {
try {
+ if (objectsRead == serializerBatchSize) {
+ val newInputStream = deserializeStream match {
+ case stream: KryoDeserializationStream =>
+ // Kryo's serializer stores an internal buffer that pre-fetches from the underlying
+ // stream. We need to capture this buffer and feed it to the new serialization
+ // stream so that the bytes are not lost.
+ val kryoInput = stream.input
+ val remainingBytes = kryoInput.limit() - kryoInput.position()
+ val extraBuf = kryoInput.readBytes(remainingBytes)
+ new SequenceInputStream(new ByteArrayInputStream(extraBuf), compressedStream)
+ case _ => compressedStream
+ }
+ deserializeStream = ser.deserializeStream(newInputStream)
+ objectsRead = 0
+ }
+ objectsRead += 1
return deserializeStream.readObject().asInstanceOf[(K, C)]
} catch {
case e: EOFException =>
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
index 87e009a4de..5ded5d0b6d 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -84,6 +84,8 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
protected var _bitset = new BitSet(_capacity)
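+  /** Return the underlying bit set tracking which positions contain an element. */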
+ def getBitSet = _bitset
+
// Init of the array in constructor (instead of in declaration) to work around a Scala compiler
// specialization bug that would generate two arrays (one for Object and one for specialized T).
protected var _data: Array[T] = _
@@ -161,7 +163,8 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
def getPos(k: T): Int = {
var pos = hashcode(hasher.hash(k)) & _mask
var i = 1
- while (true) {
+ val maxProbe = _data.size
+ while (i < maxProbe) {
if (!_bitset.get(pos)) {
return INVALID_POS
} else if (k == _data(pos)) {
@@ -179,6 +182,22 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
/** Return the value at the specified position. */
def getValue(pos: Int): T = _data(pos)
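+  /** Return an iterator over the elements of the set. */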
+ def iterator = new Iterator[T] {
+ var pos = nextPos(0)
+ override def hasNext: Boolean = pos != INVALID_POS
+ override def next(): T = {
+ val tmp = getValue(pos)
+ pos = nextPos(pos+1)
+ tmp
+ }
+ }
+
+  /** Return the value at the specified position, asserting that the position holds an element. */
+ def getValueSafe(pos: Int): T = {
+ assert(_bitset.get(pos))
+ _data(pos)
+ }
+
/**
* Return the next position with an element stored, starting from the given position inclusively.
*/
@@ -259,7 +278,7 @@ object OpenHashSet {
* A set of specialized hash function implementation to avoid boxing hash code computation
* in the specialized implementation of OpenHashSet.
*/
- sealed class Hasher[@specialized(Long, Int) T] {
+ sealed class Hasher[@specialized(Long, Int) T] extends Serializable {
def hash(o: T): Int = o.hashCode()
}
diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
index 8dd5786da6..3ac706110e 100644
--- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
+++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
@@ -53,7 +53,6 @@ object LocalSparkContext {
}
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
- System.clearProperty("spark.hostPort")
}
/** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index afc1beff98..930c2523ca 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -99,7 +99,6 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
val hostname = "localhost"
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf)
System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
- System.setProperty("spark.hostPort", hostname + ":" + boundPort)
val masterTracker = new MapOutputTrackerMaster(conf)
masterTracker.trackerActor = actorSystem.actorOf(
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
index 7bf2020fe3..235d31709a 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
@@ -64,7 +64,7 @@ class FakeTaskSetManager(
}
override def getSchedulableByName(name: String): Schedulable = {
- return null
+ null
}
override def executorLost(executorId: String, host: String): Unit = {
@@ -79,13 +79,14 @@ class FakeTaskSetManager(
{
if (tasksSuccessful + runningTasks < numTasks) {
increaseRunningTasks(1)
- return Some(new TaskDescription(0, execId, "task 0:0", 0, null))
+ Some(new TaskDescription(0, execId, "task 0:0", 0, null))
+ } else {
+ None
}
- return None
}
override def checkSpeculatableTasks(): Boolean = {
- return true
+ true
}
def taskFinished() {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 2aa259daf3..f0236ef1e9 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -122,7 +122,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
locations: Seq[Seq[String]] = Nil
): MyRDD = {
val maxPartition = numPartitions - 1
- return new MyRDD(sc, dependencies) {
+ val newRDD = new MyRDD(sc, dependencies) {
override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
throw new RuntimeException("should not be reached")
override def getPartitions = (0 to maxPartition).map(i => new Partition {
@@ -135,6 +135,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
Nil
override def toString: String = "DAGSchedulerSuiteRDD " + id
}
+ newRDD
}
/**
diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
index 5cc48ee00a..29102913c7 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
@@ -42,12 +42,9 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
def buildJobDepTest(jobID: Int, stage: Stage) = buildJobDep(jobID, stage)
}
type MyRDD = RDD[(Int, Int)]
- def makeRdd(
- numPartitions: Int,
- dependencies: List[Dependency[_]]
- ): MyRDD = {
+ def makeRdd(numPartitions: Int, dependencies: List[Dependency[_]]): MyRDD = {
val maxPartition = numPartitions - 1
- return new MyRDD(sc, dependencies) {
+ new MyRDD(sc, dependencies) {
override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
throw new RuntimeException("should not be reached")
override def getPartitions = (0 to maxPartition).map(i => new Partition {
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index f60ce270c7..18aa587662 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -53,7 +53,6 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, conf = conf)
this.actorSystem = actorSystem
conf.set("spark.driver.port", boundPort.toString)
- conf.set("spark.hostPort", "localhost:" + boundPort)
master = new BlockManagerMaster(
actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf))), conf)
@@ -65,13 +64,10 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
conf.set("spark.storage.disableBlockManagerHeartBeat", "true")
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
- // Set some value ...
- conf.set("spark.hostPort", Utils.localHostName() + ":" + 1111)
}
after {
System.clearProperty("spark.driver.port")
- System.clearProperty("spark.hostPort")
if (store != null) {
store.stop()
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index 0ed366fb70..de4871d043 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -61,8 +61,8 @@ class NonSerializable {}
object TestObject {
def run(): Int = {
var nonSer = new NonSerializable
- var x = 5
- return withSpark(new SparkContext("local", "test")) { sc =>
+ val x = 5
+ withSpark(new SparkContext("local", "test")) { sc =>
val nums = sc.parallelize(Array(1, 2, 3, 4))
nums.map(_ + x).reduce(_ + _)
}
@@ -76,7 +76,7 @@ class TestClass extends Serializable {
def run(): Int = {
var nonSer = new NonSerializable
- return withSpark(new SparkContext("local", "test")) { sc =>
+ withSpark(new SparkContext("local", "test")) { sc =>
val nums = sc.parallelize(Array(1, 2, 3, 4))
nums.map(_ + getX).reduce(_ + _)
}
@@ -88,7 +88,7 @@ class TestClassWithoutDefaultConstructor(x: Int) extends Serializable {
def run(): Int = {
var nonSer = new NonSerializable
- return withSpark(new SparkContext("local", "test")) { sc =>
+ withSpark(new SparkContext("local", "test")) { sc =>
val nums = sc.parallelize(Array(1, 2, 3, 4))
nums.map(_ + getX).reduce(_ + _)
}
@@ -103,7 +103,7 @@ class TestClassWithoutFieldAccess {
def run(): Int = {
var nonSer2 = new NonSerializable
var x = 5
- return withSpark(new SparkContext("local", "test")) { sc =>
+ withSpark(new SparkContext("local", "test")) { sc =>
val nums = sc.parallelize(Array(1, 2, 3, 4))
nums.map(_ + x).reduce(_ + _)
}
@@ -115,7 +115,7 @@ object TestObjectWithNesting {
def run(): Int = {
var nonSer = new NonSerializable
var answer = 0
- return withSpark(new SparkContext("local", "test")) { sc =>
+ withSpark(new SparkContext("local", "test")) { sc =>
val nums = sc.parallelize(Array(1, 2, 3, 4))
var y = 1
for (i <- 1 to 4) {
@@ -134,7 +134,7 @@ class TestClassWithNesting(val y: Int) extends Serializable {
def run(): Int = {
var nonSer = new NonSerializable
var answer = 0
- return withSpark(new SparkContext("local", "test")) { sc =>
+ withSpark(new SparkContext("local", "test")) { sc =>
val nums = sc.parallelize(Array(1, 2, 3, 4))
for (i <- 1 to 4) {
var nonSer2 = new NonSerializable
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index ef957bb0e5..c3391f3e53 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -9,22 +9,19 @@ import org.apache.spark.SparkContext._
class ExternalAppendOnlyMapSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
- override def beforeEach() {
- val conf = new SparkConf(false)
- conf.set("spark.shuffle.externalSorting", "true")
- sc = new SparkContext("local", "test", conf)
- }
-
- val createCombiner: (Int => ArrayBuffer[Int]) = i => ArrayBuffer[Int](i)
- val mergeValue: (ArrayBuffer[Int], Int) => ArrayBuffer[Int] = (buffer, i) => {
+ private val createCombiner: (Int => ArrayBuffer[Int]) = i => ArrayBuffer[Int](i)
+ private val mergeValue: (ArrayBuffer[Int], Int) => ArrayBuffer[Int] = (buffer, i) => {
buffer += i
}
- val mergeCombiners: (ArrayBuffer[Int], ArrayBuffer[Int]) => ArrayBuffer[Int] =
+ private val mergeCombiners: (ArrayBuffer[Int], ArrayBuffer[Int]) => ArrayBuffer[Int] =
(buf1, buf2) => {
buf1 ++= buf2
}
test("simple insert") {
+ val conf = new SparkConf(false)
+ sc = new SparkContext("local", "test", conf)
+
val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
mergeValue, mergeCombiners)
@@ -48,6 +45,9 @@ class ExternalAppendOnlyMapSuite extends FunSuite with BeforeAndAfter with Local
}
test("insert with collision") {
+ val conf = new SparkConf(false)
+ sc = new SparkContext("local", "test", conf)
+
val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
mergeValue, mergeCombiners)
@@ -67,6 +67,9 @@ class ExternalAppendOnlyMapSuite extends FunSuite with BeforeAndAfter with Local
}
test("ordering") {
+ val conf = new SparkConf(false)
+ sc = new SparkContext("local", "test", conf)
+
val map1 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
mergeValue, mergeCombiners)
map1.insert(1, 10)
@@ -109,6 +112,9 @@ class ExternalAppendOnlyMapSuite extends FunSuite with BeforeAndAfter with Local
}
test("null keys and values") {
+ val conf = new SparkConf(false)
+ sc = new SparkContext("local", "test", conf)
+
val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
mergeValue, mergeCombiners)
map.insert(1, 5)
@@ -147,6 +153,9 @@ class ExternalAppendOnlyMapSuite extends FunSuite with BeforeAndAfter with Local
}
test("simple aggregator") {
+ val conf = new SparkConf(false)
+ sc = new SparkContext("local", "test", conf)
+
// reduceByKey
val rdd = sc.parallelize(1 to 10).map(i => (i%2, 1))
val result1 = rdd.reduceByKey(_+_).collect()
@@ -159,6 +168,8 @@ class ExternalAppendOnlyMapSuite extends FunSuite with BeforeAndAfter with Local
}
test("simple cogroup") {
+ val conf = new SparkConf(false)
+ sc = new SparkContext("local", "test", conf)
val rdd1 = sc.parallelize(1 to 4).map(i => (i, i))
val rdd2 = sc.parallelize(1 to 4).map(i => (i%2, i))
val result = rdd1.cogroup(rdd2).collect()
@@ -175,56 +186,56 @@ class ExternalAppendOnlyMapSuite extends FunSuite with BeforeAndAfter with Local
}
test("spilling") {
- // TODO: Figure out correct memory parameters to actually induce spilling
- // System.setProperty("spark.shuffle.buffer.mb", "1")
- // System.setProperty("spark.shuffle.buffer.fraction", "0.05")
+ // TODO: Use SparkConf (which currently throws connection reset exception)
+ System.setProperty("spark.shuffle.memoryFraction", "0.001")
+ sc = new SparkContext("local-cluster[1,1,512]", "test")
- // reduceByKey - should spill exactly 6 times
- val rddA = sc.parallelize(0 until 10000).map(i => (i/2, i))
+ // reduceByKey - should spill ~8 times
+ val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i))
val resultA = rddA.reduceByKey(math.max(_, _)).collect()
- assert(resultA.length == 5000)
+ assert(resultA.length == 50000)
resultA.foreach { case(k, v) =>
k match {
case 0 => assert(v == 1)
- case 2500 => assert(v == 5001)
- case 4999 => assert(v == 9999)
+ case 25000 => assert(v == 50001)
+ case 49999 => assert(v == 99999)
case _ =>
}
}
- // groupByKey - should spill exactly 11 times
- val rddB = sc.parallelize(0 until 10000).map(i => (i/4, i))
+ // groupByKey - should spill ~17 times
+ val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i))
val resultB = rddB.groupByKey().collect()
- assert(resultB.length == 2500)
+ assert(resultB.length == 25000)
resultB.foreach { case(i, seq) =>
i match {
case 0 => assert(seq.toSet == Set[Int](0, 1, 2, 3))
- case 1250 => assert(seq.toSet == Set[Int](5000, 5001, 5002, 5003))
- case 2499 => assert(seq.toSet == Set[Int](9996, 9997, 9998, 9999))
+ case 12500 => assert(seq.toSet == Set[Int](50000, 50001, 50002, 50003))
+ case 24999 => assert(seq.toSet == Set[Int](99996, 99997, 99998, 99999))
case _ =>
}
}
- // cogroup - should spill exactly 7 times
- val rddC1 = sc.parallelize(0 until 1000).map(i => (i, i))
- val rddC2 = sc.parallelize(0 until 1000).map(i => (i%100, i))
+ // cogroup - should spill ~7 times
+ val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i))
+ val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i))
val resultC = rddC1.cogroup(rddC2).collect()
- assert(resultC.length == 1000)
+ assert(resultC.length == 10000)
resultC.foreach { case(i, (seq1, seq2)) =>
i match {
case 0 =>
assert(seq1.toSet == Set[Int](0))
- assert(seq2.toSet == Set[Int](0, 100, 200, 300, 400, 500, 600, 700, 800, 900))
- case 500 =>
- assert(seq1.toSet == Set[Int](500))
+ assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000))
+ case 5000 =>
+ assert(seq1.toSet == Set[Int](5000))
assert(seq2.toSet == Set[Int]())
- case 999 =>
- assert(seq1.toSet == Set[Int](999))
+ case 9999 =>
+ assert(seq1.toSet == Set[Int](9999))
assert(seq2.toSet == Set[Int]())
case _ =>
}
}
- }
- // TODO: Test memory allocation for multiple concurrently running tasks
+ System.clearProperty("spark.shuffle.memoryFraction")
+ }
}
diff --git a/docs/_config.yml b/docs/_config.yml
index 11d18f0ac2..ce0fdf5fb4 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -5,6 +5,6 @@ markdown: kramdown
# of Spark, Scala, and Mesos.
SPARK_VERSION: 0.9.0-incubating-SNAPSHOT
SPARK_VERSION_SHORT: 0.9.0
-SCALA_VERSION: 2.10
+SCALA_VERSION: "2.10"
MESOS_VERSION: 0.13.0
SPARK_ISSUE_TRACKER_URL: https://spark-project.atlassian.net
diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html
index ad7969d012..c529d89ffd 100755
--- a/docs/_layouts/global.html
+++ b/docs/_layouts/global.html
@@ -21,7 +21,7 @@
<link rel="stylesheet" href="css/main.css">
<script src="js/vendor/modernizr-2.6.1-respond-1.1.0.min.js"></script>
-
+
<link rel="stylesheet" href="css/pygments-default.css">
<!-- Google analytics script -->
@@ -68,9 +68,10 @@
<li><a href="streaming-programming-guide.html">Spark Streaming</a></li>
<li><a href="mllib-guide.html">MLlib (Machine Learning)</a></li>
<li><a href="bagel-programming-guide.html">Bagel (Pregel on Spark)</a></li>
+ <li><a href="graphx-programming-guide.html">GraphX (Graph Processing)</a></li>
</ul>
</li>
-
+
<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown">API Docs<b class="caret"></b></a>
<ul class="dropdown-menu">
@@ -80,6 +81,7 @@
<li><a href="api/streaming/index.html#org.apache.spark.streaming.package">Spark Streaming</a></li>
<li><a href="api/mllib/index.html#org.apache.spark.mllib.package">MLlib (Machine Learning)</a></li>
<li><a href="api/bagel/index.html#org.apache.spark.bagel.package">Bagel (Pregel on Spark)</a></li>
+ <li><a href="api/graphx/index.html#org.apache.spark.graphx.package">GraphX (Graph Processing)</a></li>
</ul>
</li>
@@ -161,7 +163,7 @@
<script src="js/vendor/jquery-1.8.0.min.js"></script>
<script src="js/vendor/bootstrap.min.js"></script>
<script src="js/main.js"></script>
-
+
<!-- A script to fix internal hash links because we have an overlapping top bar.
Based on https://github.com/twitter/bootstrap/issues/193#issuecomment-2281510 -->
<script>
diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb
index 431de909cb..acc6bf0816 100644
--- a/docs/_plugins/copy_api_dirs.rb
+++ b/docs/_plugins/copy_api_dirs.rb
@@ -20,7 +20,7 @@ include FileUtils
if not (ENV['SKIP_API'] == '1' or ENV['SKIP_SCALADOC'] == '1')
# Build Scaladoc for Java/Scala
- projects = ["core", "examples", "repl", "bagel", "streaming", "mllib"]
+ projects = ["core", "examples", "repl", "bagel", "graphx", "streaming", "mllib"]
puts "Moving to project root and building scaladoc."
curr_dir = pwd
diff --git a/docs/api.md b/docs/api.md
index e86d07770a..91c8e51d26 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -9,4 +9,5 @@ Here you can find links to the Scaladoc generated for the Spark sbt subprojects.
- [Spark Examples](api/examples/index.html)
- [Spark Streaming](api/streaming/index.html)
- [Bagel](api/bagel/index.html)
+- [GraphX](api/graphx/index.html)
- [PySpark](api/pyspark/index.html)
diff --git a/docs/bagel-programming-guide.md b/docs/bagel-programming-guide.md
index c4f1f6d6ad..cffa55ee95 100644
--- a/docs/bagel-programming-guide.md
+++ b/docs/bagel-programming-guide.md
@@ -3,6 +3,8 @@ layout: global
title: Bagel Programming Guide
---
+**Bagel will soon be superseded by [GraphX](graphx-programming-guide.html); we recommend that new users try GraphX instead.**
+
Bagel is a Spark implementation of Google's [Pregel](http://portal.acm.org/citation.cfm?id=1807184) graph processing framework. Bagel currently supports basic graph computation, combiners, and aggregators.
In the Pregel programming model, jobs run as a sequence of iterations called _supersteps_. In each superstep, each vertex in the graph runs a user-specified function that can update state associated with the vertex and send messages to other vertices for use in the *next* iteration.
@@ -21,7 +23,7 @@ To use Bagel in your program, add the following SBT or Maven dependency:
Bagel operates on a graph represented as a [distributed dataset](scala-programming-guide.html) of (K, V) pairs, where keys are vertex IDs and values are vertices plus their associated state. In each superstep, Bagel runs a user-specified compute function on each vertex that takes as input the current vertex state and a list of messages sent to that vertex during the previous superstep, and returns the new vertex state and a list of outgoing messages.
-For example, we can use Bagel to implement PageRank. Here, vertices represent pages, edges represent links between pages, and messages represent shares of PageRank sent to the pages that a particular page links to.
+For example, we can use Bagel to implement PageRank. Here, vertices represent pages, edges represent links between pages, and messages represent shares of PageRank sent to the pages that a particular page links to.
We first extend the default `Vertex` class to store a `Double`
representing the current PageRank of the vertex, and similarly extend
@@ -38,7 +40,7 @@ import org.apache.spark.bagel.Bagel._
val active: Boolean) extends Vertex
@serializable class PRMessage(
- val targetId: String, val rankShare: Double) extends Message
+ val targetId: String, val rankShare: Double) extends Message
{% endhighlight %}
Next, we load a sample graph from a text file as a distributed dataset and package it into `PRVertex` objects. We also cache the distributed dataset because Bagel will use it multiple times and we'd like to avoid recomputing it.
@@ -114,7 +116,7 @@ Here are the actions and types in the Bagel API. See [Bagel.scala](https://githu
/*** Full form ***/
Bagel.run(sc, vertices, messages, combiner, aggregator, partitioner, numSplits)(compute)
-// where compute takes (vertex: V, combinedMessages: Option[C], aggregated: Option[A], superstep: Int)
+// where compute takes (vertex: V, combinedMessages: Option[C], aggregated: Option[A], superstep: Int)
// and returns (newVertex: V, outMessages: Array[M])
/*** Abbreviated forms ***/
@@ -124,7 +126,7 @@ Bagel.run(sc, vertices, messages, combiner, partitioner, numSplits)(compute)
// and returns (newVertex: V, outMessages: Array[M])
Bagel.run(sc, vertices, messages, combiner, numSplits)(compute)
-// where compute takes (vertex: V, combinedMessages: Option[C], superstep: Int)
+// where compute takes (vertex: V, combinedMessages: Option[C], superstep: Int)
// and returns (newVertex: V, outMessages: Array[M])
Bagel.run(sc, vertices, messages, numSplits)(compute)
diff --git a/docs/configuration.md b/docs/configuration.md
index ad75e06fc7..be06bd19be 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -116,7 +116,7 @@ Apart from these, the following properties are also available, and may be useful
<td>0.3</td>
<td>
Fraction of Java heap to use for aggregation and cogroups during shuffles, if
- <code>spark.shuffle.externalSorting</code> is enabled. At any given time, the collective size of
+ <code>spark.shuffle.spill</code> is true. At any given time, the collective size of
all in-memory maps used for shuffles is bounded by this limit, beyond which the contents will
begin to spill to disk. If spills are often, consider increasing this value at the expense of
<code>spark.storage.memoryFraction</code>.
@@ -155,6 +155,13 @@ Apart from these, the following properties are also available, and may be useful
</td>
</tr>
<tr>
+ <td>spark.shuffle.spill.compress</td>
+ <td>false</td>
+ <td>
+ Whether to compress data spilled during shuffles.
+ </td>
+</tr>
+<tr>
<td>spark.broadcast.compress</td>
<td>true</td>
<td>
@@ -382,13 +389,13 @@ Apart from these, the following properties are also available, and may be useful
<tr>
<td>spark.shuffle.consolidateFiles</td>
- <td>true</td>
+ <td>false</td>
<td>
If set to "true", consolidates intermediate files created during a shuffle. Creating fewer files can improve filesystem performance for shuffles with large numbers of reduce tasks. It is recommended to set this to "true" when using ext4 or xfs filesystems. On ext3, this option might degrade performance on machines with many (>8) cores due to filesystem limitations.
</td>
</tr>
<tr>
- <td>spark.shuffle.externalSorting</td>
+ <td>spark.shuffle.spill</td>
<td>true</td>
<td>
If set to "true", limits the amount of memory used during reduces by spilling data out to disk. This spilling
diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md
new file mode 100644
index 0000000000..9fbde4eb09
--- /dev/null
+++ b/docs/graphx-programming-guide.md
@@ -0,0 +1,1003 @@
+---
+layout: global
+title: GraphX Programming Guide
+---
+
+* This will become a table of contents (this text will be scraped).
+{:toc}
+
+<p style="text-align: center;">
+ <img src="img/graphx_logo.png"
+ title="GraphX Logo"
+ alt="GraphX"
+ width="65%" />
+ <!-- Images are downsized intentionally to improve quality on retina displays -->
+</p>
+
+# Overview
+
+GraphX is the new (alpha) Spark API for graphs and graph-parallel computation. At a high-level,
+GraphX extends the Spark [RDD](api/core/index.html#org.apache.spark.rdd.RDD) by introducing the
+[Resilient Distributed Property Graph (RDG)](#property_graph): a directed multigraph with properties
+attached to each vertex and edge. To support graph computation, GraphX exposes a set of fundamental
+operators (e.g., [subgraph](#structural_operators), [joinVertices](#join_operators), and
+[mapReduceTriplets](#mrTriplets)) as well as an optimized variant of the [Pregel](#pregel) API. In
+addition, GraphX includes a growing collection of graph [algorithms](#graph_algorithms) and
+[builders](#graph_builders) to simplify graph analytics tasks.
+
+## Background on Graph-Parallel Computation
+
+From social networks to language modeling, the growing scale and importance of
+graph data has driven the development of numerous new *graph-parallel* systems
+(e.g., [Giraph](http://giraph.apache.org) and
+[GraphLab](http://graphlab.org)). By restricting the types of computation that can be
+expressed and introducing new techniques to partition and distribute graphs,
+these systems can efficiently execute sophisticated graph algorithms orders of
+magnitude faster than more general *data-parallel* systems.
+
+<p style="text-align: center;">
+ <img src="img/data_parallel_vs_graph_parallel.png"
+ title="Data-Parallel vs. Graph-Parallel"
+ alt="Data-Parallel vs. Graph-Parallel"
+ width="50%" />
+ <!-- Images are downsized intentionally to improve quality on retina displays -->
+</p>
+
+However, the same restrictions that enable these substantial performance gains
+also make it difficult to express many of the important stages in a typical graph-analytics pipeline:
+constructing the graph, modifying its structure, or expressing computation that
+spans multiple graphs. As a consequence, existing graph analytics pipelines
+compose graph-parallel and data-parallel systems, leading to extensive data
+movement and duplication and a complicated programming model.
+
+<p style="text-align: center;">
+ <img src="img/graph_analytics_pipeline.png"
+ title="Graph Analytics Pipeline"
+ alt="Graph Analytics Pipeline"
+ width="50%" />
+ <!-- Images are downsized intentionally to improve quality on retina displays -->
+</p>
+
+The goal of the GraphX project is to unify graph-parallel and data-parallel computation in one
+system with a single composable API. The GraphX API enables users to view data both as a graph and
+as collections (i.e., RDDs) without data movement or duplication. By incorporating recent advances
+in graph-parallel systems, GraphX is able to optimize the execution of graph operations.
+
+## GraphX Replaces the Spark Bagel API
+
+Prior to the release of GraphX, graph computation in Spark was expressed using Bagel, an
+implementation of Pregel. GraphX improves upon Bagel by exposing a richer property graph API, a
+more streamlined version of the Pregel abstraction, and system optimizations to improve performance
+and reduce memory overhead. While we plan to eventually deprecate Bagel, we will continue to
+support the [Bagel API](api/bagel/index.html#org.apache.spark.bagel.package) and
+[Bagel programming guide](bagel-programming-guide.html). However, we encourage Bagel users to
+explore the new GraphX API and comment on issues that may complicate the transition from Bagel.
+
+# Getting Started
+
+To get started you first need to import Spark and GraphX into your project, as follows:
+
+{% highlight scala %}
+import org.apache.spark._
+import org.apache.spark.graphx._
+// To make some of the examples work we will also need RDD
+import org.apache.spark.rdd.RDD
+{% endhighlight %}
+
+If you are not using the Spark shell you will also need a `SparkContext`. To learn more about
+getting started with Spark refer to the [Spark Quick Start Guide](quick-start.html).
+
+# The Property Graph
+<a name="property_graph"></a>
+
+The [property graph](api/graphx/index.html#org.apache.spark.graphx.Graph) is a directed multigraph
+with user defined objects attached to each vertex and edge. A directed multigraph is a directed
+graph with potentially multiple parallel edges sharing the same source and destination vertex. The
+ability to support parallel edges simplifies modeling scenarios where there can be multiple
+relationships (e.g., co-worker and friend) between the same vertices. Each vertex is keyed by a
+*unique* 64-bit long identifier (`VertexId`). Similarly, edges have corresponding source and
+destination vertex identifiers. GraphX does not impose any ordering or constraints on the vertex
+identifiers. The property graph is parameterized over the vertex `VD` and edge `ED` types. These
+are the types of the objects associated with each vertex and edge respectively.
+
+> GraphX optimizes the representation of `VD` and `ED` when they are plain old data types (e.g.,
+> int, double, etc.), reducing the in-memory footprint.
+
+In some cases we may wish to have vertices with different property types in the same graph. This can
+be accomplished through inheritance. For example, to model users and products as a bipartite graph
+we might do the following:
+
+{% highlight scala %}
+class VertexProperty()
+case class UserProperty(val name: String) extends VertexProperty
+case class ProductProperty(val name: String, val price: Double) extends VertexProperty
+// The graph might then have the type:
+var graph: Graph[VertexProperty, String] = null
+{% endhighlight %}
+
+Like RDDs, property graphs are immutable, distributed, and fault-tolerant. Changes to the values or
+structure of the graph are accomplished by producing a new graph with the desired changes. The graph
+is partitioned across the workers using a range of vertex-partitioning heuristics. As with RDDs,
+each partition of the graph can be recreated on a different machine in the event of a failure.
+
+Logically the property graph corresponds to a pair of typed collections (RDDs) encoding the
+properties for each vertex and edge. As a consequence, the graph class contains members to access
+the vertices and edges of the graph:
+
+{% highlight scala %}
+class Graph[VD, ED] {
+ val vertices: VertexRDD[VD]
+ val edges: EdgeRDD[ED]
+}
+{% endhighlight %}
+
+The classes `VertexRDD[VD]` and `EdgeRDD[ED]` extend and are optimized versions of `RDD[(VertexId,
+VD)]` and `RDD[Edge[ED]]` respectively. Both `VertexRDD[VD]` and `EdgeRDD[ED]` provide additional
+functionality built around graph computation and leverage internal optimizations. We discuss the
+`VertexRDD` and `EdgeRDD` API in greater detail in the section on [vertex and edge
+RDDs](#vertex_and_edge_rdds) but for now they can be thought of as simply RDDs of the form:
+`RDD[(VertexId, VD)]` and `RDD[Edge[ED]]`.
+
+### Example Property Graph
+
+Suppose we want to construct a property graph consisting of the various collaborators on the GraphX
+project. The vertex property might contain the username and occupation. We could annotate edges
+with a string describing the relationships between collaborators:
+
+<p style="text-align: center;">
+ <img src="img/property_graph.png"
+ title="The Property Graph"
+ alt="The Property Graph"
+ width="50%" />
+ <!-- Images are downsized intentionally to improve quality on retina displays -->
+</p>
+
+The resulting graph would have the type signature:
+
+{% highlight scala %}
+val userGraph: Graph[(String, String), String]
+{% endhighlight %}
+
+There are numerous ways to construct a property graph from raw files, RDDs, and even synthetic
+generators and these are discussed in more detail in the section on
+[graph builders](#graph_builders). Probably the most general method is to use the
+[Graph object](api/graphx/index.html#org.apache.spark.graphx.Graph$). For example the following
+code constructs a graph from a collection of RDDs:
+
+{% highlight scala %}
+// Assume the SparkContext has already been constructed
+val sc: SparkContext
+// Create an RDD for the vertices
+val users: RDD[(VertexID, (String, String))] =
+ sc.parallelize(Array((3L, ("rxin", "student")), (7L, ("jgonzal", "postdoc")),
+ (5L, ("franklin", "prof")), (2L, ("istoica", "prof"))))
+// Create an RDD for edges
+val relationships: RDD[Edge[String]] =
+ sc.parallelize(Array(Edge(3L, 7L, "collab"), Edge(5L, 3L, "advisor"),
+ Edge(2L, 5L, "colleague"), Edge(5L, 7L, "pi")))
+// Define a default user in case there are relationships with missing users
+val defaultUser = ("John Doe", "Missing")
+// Build the initial Graph
+val graph = Graph(users, relationships, defaultUser)
+{% endhighlight %}
+
+In the above example we make use of the [`Edge`][Edge] case class. Edges have a `srcId` and a
+`dstId` corresponding to the source and destination vertex identifiers. In addition, the `Edge`
+class contains the `attr` member which contains the edge property.
+
+[Edge]: api/graphx/index.html#org.apache.spark.graphx.Edge
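+
+As a small sketch (reusing the `relationships` RDD defined above; the `firstEdge` value is just
+illustrative), these fields can be read directly:
+
+{% highlight scala %}
+// Inspect the fields of a single edge: srcId, dstId, and attr
+val firstEdge: Edge[String] = relationships.first()
+println(firstEdge.srcId + " --" + firstEdge.attr + "--> " + firstEdge.dstId)
+{% endhighlight %}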
+
+We can deconstruct a graph into the respective vertex and edge views by using the `graph.vertices`
+and `graph.edges` members respectively.
+
+{% highlight scala %}
+val graph: Graph[(String, String), String] // Constructed from above
+// Count all users which are postdocs
+graph.vertices.filter { case (id, (name, pos)) => pos == "postdoc" }.count
+// Count all the edges where src > dst
+graph.edges.filter(e => e.srcId > e.dstId).count
+{% endhighlight %}
+
+> Note that `graph.vertices` returns a `VertexRDD[(String, String)]` which extends
+> `RDD[(VertexId, (String, String))]`, and so we use the Scala `case` expression to deconstruct the
+> tuple. On the other hand, `graph.edges` returns an `EdgeRDD` containing `Edge[String]` objects.
+> We could have also used the case class type constructor as in the following:
+> {% highlight scala %}
+graph.edges.filter { case Edge(src, dst, prop) => src > dst }.count
+{% endhighlight %}
+
+In addition to the vertex and edge views of the property graph, GraphX also exposes a triplet view.
+The triplet view logically joins the vertex and edge properties yielding an
+`RDD[EdgeTriplet[VD, ED]]` containing instances of the [`EdgeTriplet`][EdgeTriplet] class. This
+*join* can be expressed in the following SQL expression:
+
+[EdgeTriplet]: api/graphx/index.html#org.apache.spark.graphx.EdgeTriplet
+
+{% highlight sql %}
+SELECT src.id, dst.id, src.attr, e.attr, dst.attr
+FROM edges AS e LEFT JOIN vertices AS src, vertices AS dst
+ON e.srcId = src.Id AND e.dstId = dst.Id
+{% endhighlight %}
+
+or graphically as:
+
+<p style="text-align: center;">
+ <img src="img/triplet.png"
+ title="Edge Triplet"
+ alt="Edge Triplet"
+ width="50%" />
+ <!-- Images are downsized intentionally to improve quality on retina displays -->
+</p>
+
+The [`EdgeTriplet`][EdgeTriplet] class extends the [`Edge`][Edge] class by adding the `srcAttr` and
+`dstAttr` members which contain the source and destination properties respectively. We can use the
+triplet view of a graph to render a collection of strings describing relationships between users.
+
+{% highlight scala %}
+val graph: Graph[(String, String), String] // Constructed from above
+// Use the triplets view to create an RDD of facts.
+val facts: RDD[String] =
+ graph.triplets.map(triplet =>
+ triplet.srcAttr._1 + " is the " + triplet.attr + " of " + triplet.dstAttr._1)
+facts.collect.foreach(println(_))
+{% endhighlight %}
+
+# Graph Operators
+
+Just as RDDs have basic operations like `map`, `filter`, and `reduceByKey`, property graphs also
+have a collection of basic operators that take user defined functions and produce new graphs with
+transformed properties and structure. The core operators that have optimized implementations are
+defined in [`Graph`][Graph] and convenient operators that are expressed as compositions of the
+core operators are defined in [`GraphOps`][GraphOps]. However, thanks to Scala implicits the
+operators in `GraphOps` are automatically available as members of `Graph`. For example, we can
+compute the in-degree of each vertex (defined in `GraphOps`) by the following:
+
+[Graph]: api/graphx/index.html#org.apache.spark.graphx.Graph
+[GraphOps]: api/graphx/index.html#org.apache.spark.graphx.GraphOps
+
+{% highlight scala %}
+val graph: Graph[(String, String), String]
+// Use the implicit GraphOps.inDegrees operator
+val inDegrees: VertexRDD[Int] = graph.inDegrees
+{% endhighlight %}
+
+The reason for differentiating between core graph operations and [`GraphOps`][GraphOps] is to be
+able to support different graph representations in the future. Each graph representation must
+provide implementations of the core operations and reuse many of the useful operations defined in
+[`GraphOps`][GraphOps].
+
+## Property Operators
+
+In direct analogy to the RDD `map` operator, the property
+graph contains the following:
+
+{% highlight scala %}
+class Graph[VD, ED] {
+ def mapVertices[VD2](map: (VertexID, VD) => VD2): Graph[VD2, ED]
+ def mapEdges[ED2](map: Edge[ED] => ED2): Graph[VD, ED2]
+ def mapTriplets[ED2](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2]
+}
+{% endhighlight %}
+
+Each of these operators yields a new graph with the vertex or edge properties modified by the user
+defined `map` function.
+
+> Note that in all cases the graph structure is unaffected. This is a key feature of these operators
+> which allows the resulting graph to reuse the structural indices of the original graph. The
+> following snippets are logically equivalent, but the first one does not preserve the structural
+> indices and would not benefit from the GraphX system optimizations:
+> {% highlight scala %}
+val newVertices = graph.vertices.map { case (id, attr) => (id, mapUdf(id, attr)) }
+val newGraph = Graph(newVertices, graph.edges)
+{% endhighlight %}
+> Instead, use [`mapVertices`][Graph.mapVertices] to preserve the indices:
+> {% highlight scala %}
+val newGraph = graph.mapVertices((id, attr) => mapUdf(id, attr))
+{% endhighlight %}
+
+[Graph.mapVertices]: api/graphx/index.html#org.apache.spark.graphx.Graph@mapVertices[VD2]((VertexID,VD)⇒VD2)(ClassTag[VD2]):Graph[VD2,ED]
+
+These operators are often used to initialize the graph for a particular computation or project away
+unnecessary properties. For example, given a graph with the out-degrees as the vertex properties
+(we describe how to construct such a graph later), we initialize it for PageRank:
+
+{% highlight scala %}
+// Given a graph where the vertex property is the out-degree
+val inputGraph: Graph[Int, String] =
+ graph.outerJoinVertices(graph.outDegrees)((vid, _, degOpt) => degOpt.getOrElse(0))
+// Construct a graph where each edge contains the weight
+// and each vertex is the initial PageRank
+val outputGraph: Graph[Double, Double] =
+ inputGraph.mapTriplets(triplet => 1.0 / triplet.srcAttr).mapVertices((id, _) => 1.0)
+{% endhighlight %}
+
+## Structural Operators
+<a name="structural_operators"></a>
+
+Currently GraphX supports only a simple set of commonly used structural operators and we expect to
+add more in the future. The following is a list of the basic structural operators.
+
+{% highlight scala %}
+class Graph[VD, ED] {
+ def reverse: Graph[VD, ED]
+ def subgraph(epred: EdgeTriplet[VD,ED] => Boolean,
+ vpred: (VertexID, VD) => Boolean): Graph[VD, ED]
+ def mask[VD2, ED2](other: Graph[VD2, ED2]): Graph[VD, ED]
+ def groupEdges(merge: (ED, ED) => ED): Graph[VD,ED]
+}
+{% endhighlight %}
+
+The [`reverse`][Graph.reverse] operator returns a new graph with all the edge directions reversed.
+This can be useful when, for example, trying to compute the inverse PageRank. Because the reverse
+operation does not modify vertex or edge properties or change the number of edges, it can be
+implemented efficiently without data-movement or duplication.
+
+[Graph.reverse]: api/graphx/index.html#org.apache.spark.graphx.Graph@reverse:Graph[VD,ED]
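+
+As a minimal sketch (reusing the `graph` of users and relationships constructed earlier), reversing
+the graph swaps the source and destination of every edge while leaving all properties intact:
+
+{% highlight scala %}
+// Reverse every edge; vertex and edge properties are unchanged
+val reversedGraph: Graph[(String, String), String] = graph.reverse
+// The edge (5L, 3L, "advisor") now appears as (3L, 5L, "advisor")
+reversedGraph.edges.collect.foreach(println(_))
+{% endhighlight %}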
+
+The [`subgraph`][Graph.subgraph] operator takes vertex and edge predicates and returns the graph
+containing only the vertices that satisfy the vertex predicate (evaluate to true) and edges that
+satisfy the edge predicate *and connect vertices that satisfy the vertex predicate*. The `subgraph`
+operator can be used in a number of situations to restrict the graph to the vertices and edges of
+interest or to eliminate broken links. For example, in the following code we remove broken links:
+
+[Graph.subgraph]: api/graphx/index.html#org.apache.spark.graphx.Graph@subgraph((EdgeTriplet[VD,ED])⇒Boolean,(VertexID,VD)⇒Boolean):Graph[VD,ED]
+
+{% highlight scala %}
+// Create an RDD for the vertices
+val users: RDD[(VertexID, (String, String))] =
+ sc.parallelize(Array((3L, ("rxin", "student")), (7L, ("jgonzal", "postdoc")),
+ (5L, ("franklin", "prof")), (2L, ("istoica", "prof")),
+ (4L, ("peter", "student"))))
+// Create an RDD for edges
+val relationships: RDD[Edge[String]] =
+ sc.parallelize(Array(Edge(3L, 7L, "collab"), Edge(5L, 3L, "advisor"),
+ Edge(2L, 5L, "colleague"), Edge(5L, 7L, "pi"),
+ Edge(4L, 0L, "student"), Edge(5L, 0L, "colleague")))
+// Define a default user in case there are relationships with missing users
+val defaultUser = ("John Doe", "Missing")
+// Build the initial Graph
+val graph = Graph(users, relationships, defaultUser)
+// Notice that there is a user 0 (for which we have no information) connected to users
+// 4 (peter) and 5 (franklin).
+graph.triplets.map(
+ triplet => triplet.srcAttr._1 + " is the " + triplet.attr + " of " + triplet.dstAttr._1
+ ).collect.foreach(println(_))
+// Remove missing vertices as well as the edges connected to them
+val validGraph = graph.subgraph(vpred = (id, attr) => attr._2 != "Missing")
+// The valid subgraph will disconnect users 4 and 5 by removing user 0
+validGraph.vertices.collect.foreach(println(_))
+validGraph.triplets.map(
+ triplet => triplet.srcAttr._1 + " is the " + triplet.attr + " of " + triplet.dstAttr._1
+ ).collect.foreach(println(_))
+{% endhighlight %}
+
+> Note that in the above example only the vertex predicate is provided. The `subgraph` operator
+> defaults to `true` if the vertex or edge predicates are not provided.
+
+The [`mask`][Graph.mask] operator also constructs a subgraph by returning a graph that contains the
+vertices and edges that are also found in the input graph. This can be used in conjunction with the
+`subgraph` operator to restrict a graph based on the properties in another related graph. For
+example, we might run connected components using the graph with missing vertices and then restrict
+the answer to the valid subgraph.
+
+[Graph.mask]: api/graphx/index.html#org.apache.spark.graphx.Graph@mask[VD2,ED2](Graph[VD2,ED2])(ClassTag[VD2],ClassTag[ED2]):Graph[VD,ED]
+
+{% highlight scala %}
+// Run Connected Components
+val ccGraph = graph.connectedComponents() // No longer contains missing field
+// Remove missing vertices as well as the edges connected to them
+val validGraph = graph.subgraph(vpred = (id, attr) => attr._2 != "Missing")
+// Restrict the answer to the valid subgraph
+val validCCGraph = ccGraph.mask(validGraph)
+{% endhighlight %}
+
+The [`groupEdges`][Graph.groupEdges] operator merges parallel edges (i.e., duplicate edges between
+pairs of vertices) in the multigraph. In many numerical applications, parallel edges can be *added*
+(their weights combined) into a single edge thereby reducing the size of the graph.
+
+[Graph.groupEdges]: api/graphx/index.html#org.apache.spark.graphx.Graph@groupEdges((ED,ED)⇒ED):Graph[VD,ED]
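+
+As a brief sketch (the `weightedGraph` value below is hypothetical, and parallel edges are assumed
+to reside in the same partition, e.g. after repartitioning the graph), duplicate edges can be
+collapsed by summing their weights:
+
+{% highlight scala %}
+val weightedGraph: Graph[String, Double] // A multigraph with Double edge weights
+// Merge parallel edges between the same pair of vertices by adding their weights
+val mergedGraph: Graph[String, Double] = weightedGraph.groupEdges((w1, w2) => w1 + w2)
+{% endhighlight %}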
+
+## Join Operators
+<a name="join_operators"></a>
+
+In many cases it is necessary to join data from external collections (RDDs) with graphs. For
+example, we might have extra user properties that we want to merge with an existing graph or we
+might want to pull vertex properties from one graph into another. These tasks can be accomplished
+using the *join* operators. Below we list the key join operators:
+
+{% highlight scala %}
+class Graph[VD, ED] {
+ def joinVertices[U](table: RDD[(VertexID, U)])(map: (VertexID, VD, U) => VD)
+ : Graph[VD, ED]
+ def outerJoinVertices[U, VD2](table: RDD[(VertexID, U)])(map: (VertexID, VD, Option[U]) => VD2)
+ : Graph[VD2, ED]
+}
+{% endhighlight %}
+
+The [`joinVertices`][GraphOps.joinVertices] operator joins the vertices with the input RDD and
+returns a new graph with the vertex properties obtained by applying the user defined `map` function
+to the result of the joined vertices. Vertices without a matching value in the RDD retain their
+original value.
+
+[GraphOps.joinVertices]: api/graphx/index.html#org.apache.spark.graphx.GraphOps@joinVertices[U](RDD[(VertexID,U)])((VertexID,VD,U)⇒VD)(ClassTag[U]):Graph[VD,ED]
+
+> Note that if the RDD contains more than one value for a given vertex, only one will be used. It
+> is therefore recommended that the input RDD first be made unique using the following, which will
+> also *pre-index* the resulting values to substantially accelerate the subsequent join.
+> {% highlight scala %}
+val nonUniqueCosts: RDD[(VertexID, Double)]
+val uniqueCosts: VertexRDD[Double] =
+ graph.vertices.aggregateUsingIndex(nonUniqueCosts, (a, b) => a + b)
+val joinedGraph = graph.joinVertices(uniqueCosts)(
+ (id, oldCost, extraCost) => oldCost + extraCost)
+{% endhighlight %}
+
+The more general [`outerJoinVertices`][Graph.outerJoinVertices] behaves similarly to `joinVertices`
+except that the user defined `map` function is applied to all vertices and can change the vertex
+property type. Because not all vertices may have a matching value in the input RDD, the `map`
+function takes an `Option` type. For example, we can set up a graph for PageRank by initializing
+vertex properties with their `outDegree`.
+
+[Graph.outerJoinVertices]: api/graphx/index.html#org.apache.spark.graphx.Graph@outerJoinVertices[U,VD2](RDD[(VertexID,U)])((VertexID,VD,Option[U])⇒VD2)(ClassTag[U],ClassTag[VD2]):Graph[VD2,ED]
+
+
+{% highlight scala %}
+val outDegrees: VertexRDD[Int] = graph.outDegrees
+val degreeGraph = graph.outerJoinVertices(outDegrees) { (id, oldAttr, outDegOpt) =>
+ outDegOpt match {
+ case Some(outDeg) => outDeg
+ case None => 0 // No outDegree means zero outDegree
+ }
+}
+{% endhighlight %}
+
+> You may have noticed the multiple parameter list (e.g., `f(a)(b)`) curried function pattern used
+> in the above examples. While we could have equally written `f(a)(b)` as `f(a,b)`, this would mean
+> that type inference on `b` would not depend on `a`. As a consequence, the user would need to
+> provide type annotations for the user defined function:
+> {% highlight scala %}
+val joinedGraph = graph.joinVertices(uniqueCosts,
+ (id: VertexID, oldCost: Double, extraCost: Double) => oldCost + extraCost)
+{% endhighlight %}
+
+
+## Neighborhood Aggregation
+
+A key part of graph computation is aggregating information about the neighborhood of each vertex.
+For example, we might want to know the number of followers each user has or the average age of the
+followers of each user. Many iterative graph algorithms (e.g., PageRank, Shortest Path, and
+connected components) repeatedly aggregate properties of neighboring vertices (e.g., current
+PageRank value, shortest path to the source, and smallest reachable vertex id).
+
+### Map Reduce Triplets (mapReduceTriplets)
+<a name="mrTriplets"></a>
+
+[Graph.mapReduceTriplets]: api/graphx/index.html#org.apache.spark.graphx.Graph@mapReduceTriplets[A](mapFunc:org.apache.spark.graphx.EdgeTriplet[VD,ED]=&gt;Iterator[(org.apache.spark.graphx.VertexID,A)],reduceFunc:(A,A)=&gt;A,activeSetOpt:Option[(org.apache.spark.graphx.VertexRDD[_],org.apache.spark.graphx.EdgeDirection)])(implicitevidence$10:scala.reflect.ClassTag[A]):org.apache.spark.graphx.VertexRDD[A]
+
+The core (heavily optimized) aggregation primitive in GraphX is the
+[`mapReduceTriplets`][Graph.mapReduceTriplets] operator:
+
+{% highlight scala %}
+class Graph[VD, ED] {
+ def mapReduceTriplets[A](
+ map: EdgeTriplet[VD, ED] => Iterator[(VertexID, A)],
+ reduce: (A, A) => A)
+ : VertexRDD[A]
+}
+{% endhighlight %}
+
+The [`mapReduceTriplets`][Graph.mapReduceTriplets] operator takes a user defined map function which
+is applied to each triplet and can yield *messages* destined to neither, either, or both vertices
+in the triplet. To facilitate optimized pre-aggregation, we currently only support messages destined
+to the source or destination vertex of the triplet. The user defined `reduce` function combines the
+messages destined to each vertex. The `mapReduceTriplets` operator returns a `VertexRDD[A]`
+containing the aggregate message (of type `A`) destined to each vertex. Vertices that do not
+receive a message are not included in the returned `VertexRDD`.
+
+<blockquote>
+<p>
+Note that <code>mapReduceTriplets</code> takes an additional optional <code>activeSet</code>
+(see API docs) which restricts the map phase to edges adjacent to the vertices in the provided
+<code>VertexRDD</code>:
+</p>
+{% highlight scala %}
+ activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None
+{% endhighlight %}
+<p>
+The <code>EdgeDirection</code> specifies which edges adjacent to the vertex set are included in the
+map phase. If the direction is <code>In</code>, <code>mapFunc</code> will be run only on edges with
+the destination in the active set. If the direction is <code>Out</code>, <code>mapFunc</code> will
+be run only on edges originating from vertices in the active set. If the direction is
+<code>Either</code>, <code>mapFunc</code> will be run only on edges with <i>either</i> vertex in the
+active set. If the direction is <code>Both</code>, <code>mapFunc</code> will be run only on edges
+with both vertices in the active set. The active set must be derived from the set of vertices in
+the graph. Restricting computation to triplets adjacent to a subset of the vertices is often
+necessary in incremental iterative computation and is a key part of the GraphX implementation of
+Pregel.
+</p>
+</blockquote>
+
+In the following example we use the `mapReduceTriplets` operator to compute the average age of the
+more senior followers of each user.
+
+{% highlight scala %}
+// Import random graph generation library
+import org.apache.spark.graphx.util.GraphGenerators
+// Create a graph with "age" as the vertex property. Here we use a random graph for simplicity.
+val graph: Graph[Double, Int] =
+ GraphGenerators.logNormalGraph(sc, numVertices = 100).mapVertices( (id, _) => id.toDouble )
+// Compute the number of older followers and their total age
+val olderFollowers: VertexRDD[(Int, Double)] = graph.mapReduceTriplets[(Int, Double)](
+ triplet => { // Map Function
+ if (triplet.srcAttr > triplet.dstAttr) {
+ // Send message to destination vertex containing counter and age
+ Iterator((triplet.dstId, (1, triplet.srcAttr)))
+ } else {
+ // Don't send a message for this triplet
+ Iterator.empty
+ }
+ },
+ // Add counter and age
+ (a, b) => (a._1 + b._1, a._2 + b._2) // Reduce Function
+)
+// Divide total age by number of older followers to get average age of older followers
+val avgAgeOfOlderFollowers: VertexRDD[Double] =
+ olderFollowers.mapValues( (id, value) => value match { case (count, totalAge) => totalAge / count } )
+// Display the results
+avgAgeOfOlderFollowers.collect.foreach(println(_))
+{% endhighlight %}
+
+> Note that the `mapReduceTriplets` operation performs optimally when the messages (and their sums)
+> are constant sized (e.g., floats and addition instead of lists and concatenation). More
+> precisely, the result of `mapReduceTriplets` should ideally be sub-linear in the degree of each
+> vertex.
+
+### Computing Degree Information
+
+A common aggregation task is computing the degree of each vertex: the number of edges adjacent to
+each vertex. In the context of directed graphs it is often necessary to know the in-degree, out-
+degree, and the total degree of each vertex. The [`GraphOps`][GraphOps] class contains a
+collection of operators to compute the degrees of each vertex. For example in the following we
+compute the max in, out, and total degrees:
+
+{% highlight scala %}
+// Define a reduce operation to compute the highest degree vertex
+def max(a: (VertexID, Int), b: (VertexID, Int)): (VertexID, Int) = {
+ if (a._2 > b._2) a else b
+}
+// Compute the max degrees
+val maxInDegree: (VertexID, Int) = graph.inDegrees.reduce(max)
+val maxOutDegree: (VertexID, Int) = graph.outDegrees.reduce(max)
+val maxDegrees: (VertexID, Int) = graph.degrees.reduce(max)
+{% endhighlight %}
+
+### Collecting Neighbors
+
+In some cases it may be easier to express computation by collecting neighboring vertices and their
+attributes at each vertex. This can be easily accomplished using the
+[`collectNeighborIds`][GraphOps.collectNeighborIds] and the
+[`collectNeighbors`][GraphOps.collectNeighbors] operators.
+
+[GraphOps.collectNeighborIds]: api/graphx/index.html#org.apache.spark.graphx.GraphOps@collectNeighborIds(EdgeDirection):VertexRDD[Array[VertexID]]
+[GraphOps.collectNeighbors]: api/graphx/index.html#org.apache.spark.graphx.GraphOps@collectNeighbors(EdgeDirection):VertexRDD[Array[(VertexID,VD)]]
+
+
+{% highlight scala %}
+class GraphOps[VD, ED] {
+ def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexID]]
+ def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[ Array[(VertexID, VD)] ]
+}
+{% endhighlight %}
+
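+For example, a minimal sketch (reusing the `graph` from the examples above) that collects the ids
+of each vertex's in-neighbors:
+
+{% highlight scala %}
+// Collect the vertex ids of each vertex's in-neighbors
+val inNbrIds: VertexRDD[Array[VertexID]] = graph.collectNeighborIds(EdgeDirection.In)
+inNbrIds.collect.foreach { case (id, nbrs) =>
+ println(id + " has in-neighbors: " + nbrs.mkString(", "))
+}
+{% endhighlight %}
+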
+> Note that these operators can be quite costly as they duplicate information and require
+> substantial communication. If possible try expressing the same computation using the
+> `mapReduceTriplets` operator directly.
+
+# Pregel API
+<a name="pregel"></a>
+
+Graphs are inherently recursive data-structures as properties of vertices depend on properties of
+their neighbors, which in turn depend on properties of *their* neighbors. As a
+consequence many important graph algorithms iteratively recompute the properties of each vertex
+until a fixed-point condition is reached. A range of graph-parallel abstractions have been proposed
+to express these iterative algorithms. GraphX exposes a Pregel-like operator which is a fusion of
+the widely used Pregel and GraphLab abstractions.
+
+At a high level the Pregel operator in GraphX is a bulk-synchronous parallel messaging abstraction
+*constrained to the topology of the graph*. The Pregel operator executes in a series of super steps
+in which vertices receive the *sum* of their inbound messages from the previous super step, compute
+a new value for the vertex property, and then send messages to neighboring vertices in the next
+super step. Unlike Pregel and more like GraphLab, messages are computed in parallel as a
+function of the edge triplet and the message computation has access to both the source and
+destination vertex attributes. Vertices that do not receive a message are skipped within a super
+step. The Pregel operator terminates iteration and returns the final graph when there are no
+messages remaining.
+
+> Note, unlike more standard Pregel implementations, vertices in GraphX can only send messages to
+> neighboring vertices and the message construction is done in parallel using a user defined
+> messaging function. These constraints allow additional optimization within GraphX.
+
+The following is the type signature of the [Pregel operator][GraphOps.pregel] as well as a *sketch*
+of its implementation (note calls to graph.cache have been removed):
+
+[GraphOps.pregel]: api/graphx/index.html#org.apache.spark.graphx.GraphOps@pregel[A](A,Int,EdgeDirection)((VertexID,VD,A)⇒VD,(EdgeTriplet[VD,ED])⇒Iterator[(VertexID,A)],(A,A)⇒A)(ClassTag[A]):Graph[VD,ED]
+
+{% highlight scala %}
+class GraphOps[VD, ED] {
+ def pregel[A]
+ (initialMsg: A,
+ maxIterations: Int = Int.MaxValue,
+ activeDir: EdgeDirection = EdgeDirection.Out)
+ (vprog: (VertexID, VD, A) => VD,
+ sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexID, A)],
+ mergeMsg: (A, A) => A)
+ : Graph[VD, ED] = {
+ // Receive the initial message at each vertex
+ var g = mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg) ).cache()
+ // compute the messages
+ var messages = g.mapReduceTriplets(sendMsg, mergeMsg)
+ var activeMessages = messages.count()
+ // Loop until no messages remain or maxIterations is achieved
+ var i = 0
+ while (activeMessages > 0 && i < maxIterations) {
+ // Receive the messages: -----------------------------------------------------------------------
+ // Run the vertex program on all vertices that receive messages
+ val newVerts = g.vertices.innerJoin(messages)(vprog).cache()
+ // Merge the new vertex values back into the graph
+ g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) }.cache()
+ // Send Messages: ------------------------------------------------------------------------------
+ // Vertices that didn't receive a message above don't appear in newVerts and therefore don't
+ // get to send messages. More precisely the map phase of mapReduceTriplets is only invoked
+ // on edges in the activeDir of vertices in newVerts
+ messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, activeDir))).cache()
+ activeMessages = messages.count()
+ i += 1
+ }
+ g
+ }
+}
+{% endhighlight %}
+
+Notice that Pregel takes two argument lists (i.e., `graph.pregel(list1)(list2)`). The first
+argument list contains configuration parameters including the initial message, the maximum number of
+iterations, and the edge direction in which to send messages (by default along out edges). The
+second argument list contains the user defined functions for receiving messages (the vertex program
+`vprog`), computing messages (`sendMsg`), and combining messages (`mergeMsg`).
+
+We can use the Pregel operator to express computation such as single source
+shortest path in the following example.
+
+{% highlight scala %}
+import org.apache.spark.graphx._
+// Import random graph generation library
+import org.apache.spark.graphx.util.GraphGenerators
+// A graph with edge attributes containing distances
+val graph: Graph[Int, Double] =
+ GraphGenerators.logNormalGraph(sc, numVertices = 100).mapEdges(e => e.attr.toDouble)
+val sourceId: VertexID = 42 // The ultimate source
+// Initialize the graph such that all vertices except the root have distance infinity.
+val initialGraph = graph.mapVertices((id, _) => if (id == sourceId) 0.0 else Double.PositiveInfinity)
+val sssp = initialGraph.pregel(Double.PositiveInfinity)(
+ (id, dist, newDist) => math.min(dist, newDist), // Vertex Program
+ triplet => { // Send Message
+ if (triplet.srcAttr + triplet.attr < triplet.dstAttr) {
+ Iterator((triplet.dstId, triplet.srcAttr + triplet.attr))
+ } else {
+ Iterator.empty
+ }
+ },
+ (a,b) => math.min(a,b) // Merge Message
+ )
+println(sssp.vertices.collect.mkString("\n"))
+{% endhighlight %}
+
+# Graph Builders
+<a name="graph_builders"></a>
+
+GraphX provides several ways of building a graph from a collection of vertices and edges in an RDD or on disk. None of the graph builders repartitions the graph's edges by default; instead, edges are left in their default partitions (such as their original blocks in HDFS). [`Graph.groupEdges`][Graph.groupEdges] requires the graph to be repartitioned because it assumes identical edges will be colocated on the same partition, so you must call [`Graph.partitionBy`][Graph.partitionBy] before calling `groupEdges`.
+
+{% highlight scala %}
+object GraphLoader {
+ def edgeListFile(
+ sc: SparkContext,
+ path: String,
+ canonicalOrientation: Boolean = false,
+ minEdgePartitions: Int = 1)
+ : Graph[Int, Int]
+}
+{% endhighlight %}
+
+[`GraphLoader.edgeListFile`][GraphLoader.edgeListFile] provides a way to load a graph from a list of edges on disk. It parses an adjacency list of (source vertex ID, destination vertex ID) pairs of the following form, skipping comment lines that begin with `#`:
+
+~~~
+# This is a comment
+2 1
+4 1
+1 2
+~~~
+
+It creates a `Graph` from the specified edges, automatically creating any vertices mentioned by edges. All vertex and edge attributes default to 1. The `canonicalOrientation` argument allows reorienting edges in the positive direction (`srcId < dstId`), which is required by the [connected components][ConnectedComponents] algorithm. The `minEdgePartitions` argument specifies the minimum number of edge partitions to generate; there may be more edge partitions than specified if, for example, the HDFS file has more blocks.
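+
+For example, a minimal sketch that loads the edge list above from a hypothetical path and reports
+the size of the resulting graph:
+
+{% highlight scala %}
+// Load the edge list, reorienting edges canonically and requesting at least 4 edge partitions.
+// The path "graphx/data/edges.txt" is hypothetical; point it at a file like the one shown above.
+val graph = GraphLoader.edgeListFile(sc, "graphx/data/edges.txt",
+ canonicalOrientation = true, minEdgePartitions = 4)
+println("vertices: " + graph.numVertices + ", edges: " + graph.numEdges)
+{% endhighlight %}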
+
+{% highlight scala %}
+object Graph {
+ def apply[VD, ED](
+ vertices: RDD[(VertexID, VD)],
+ edges: RDD[Edge[ED]],
+ defaultVertexAttr: VD = null)
+ : Graph[VD, ED]
+
+ def fromEdges[VD, ED](
+ edges: RDD[Edge[ED]],
+ defaultValue: VD): Graph[VD, ED]
+
+ def fromEdgeTuples[VD](
+ rawEdges: RDD[(VertexID, VertexID)],
+ defaultValue: VD,
+ uniqueEdges: Option[PartitionStrategy] = None): Graph[VD, Int]
+
+}
+{% endhighlight %}
+
+[`Graph.apply`][Graph.apply] allows creating a graph from RDDs of vertices and edges. Duplicate vertices are picked arbitrarily and vertices found in the edge RDD but not the vertex RDD are assigned the default attribute.
+
+[`Graph.fromEdges`][Graph.fromEdges] allows creating a graph from only an RDD of edges, automatically creating any vertices mentioned by edges and assigning them the default value.
+
+[`Graph.fromEdgeTuples`][Graph.fromEdgeTuples] allows creating a graph from only an RDD of edge tuples, assigning the edges the value 1, and automatically creating any vertices mentioned by edges and assigning them the default value. It also supports deduplicating the edges; to deduplicate, pass `Some` of a [`PartitionStrategy`][PartitionStrategy] as the `uniqueEdges` parameter (for example, `uniqueEdges = Some(PartitionStrategy.RandomVertexCut)`). A partition strategy is necessary to colocate identical edges on the same partition so they can be deduplicated.
+
+[PartitionStrategy]: api/graphx/index.html#org.apache.spark.graphx.PartitionStrategy$
+
+[GraphLoader.edgeListFile]: api/graphx/index.html#org.apache.spark.graphx.GraphLoader$@edgeListFile(SparkContext,String,Boolean,Int):Graph[Int,Int]
+[Graph.apply]: api/graphx/index.html#org.apache.spark.graphx.Graph$@apply[VD,ED](RDD[(VertexID,VD)],RDD[Edge[ED]],VD)(ClassTag[VD],ClassTag[ED]):Graph[VD,ED]
+[Graph.fromEdgeTuples]: api/graphx/index.html#org.apache.spark.graphx.Graph$@fromEdgeTuples[VD](RDD[(VertexID,VertexID)],VD,Option[PartitionStrategy])(ClassTag[VD]):Graph[VD,Int]
+[Graph.fromEdges]: api/graphx/index.html#org.apache.spark.graphx.Graph$@fromEdges[VD,ED](RDD[Edge[ED]],VD)(ClassTag[VD],ClassTag[ED]):Graph[VD,ED]
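+
+For example, the following sketch (with hypothetical edge tuples) builds a graph with
+`Graph.fromEdgeTuples` and deduplicates its parallel edges by supplying a partition strategy:
+
+{% highlight scala %}
+// Hypothetical raw edge tuples containing a duplicate (1, 2) edge
+val rawEdges: RDD[(VertexID, VertexID)] =
+ sc.parallelize(Array((1L, 2L), (1L, 2L), (2L, 3L)))
+// Supplying a PartitionStrategy colocates identical edges so they can be merged
+val dedupedGraph: Graph[String, Int] =
+ Graph.fromEdgeTuples(rawEdges, "defaultUser",
+ uniqueEdges = Some(PartitionStrategy.RandomVertexCut))
+dedupedGraph.edges.collect.foreach(println(_))
+{% endhighlight %}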
+
+# Vertex and Edge RDDs
+<a name="vertex_and_edge_rdds"></a>
+
+GraphX exposes `RDD` views of the vertices and edges stored within the graph. However, because
+GraphX maintains the vertices and edges in optimized data-structures and these data-structures
+provide additional functionality, the vertices and edges are returned as `VertexRDD` and `EdgeRDD`
+respectively. In this section we review some of the additional useful functionality in these types.
+
+## VertexRDDs
+
+The `VertexRDD[A]` extends the more traditional `RDD[(VertexId, A)]` but adds the additional
+constraint that each `VertexId` occurs only *once*. Moreover, `VertexRDD[A]` represents a *set* of
+vertices each with an attribute of type `A`. Internally, this is achieved by storing the vertex
+attributes in a reusable hash-map data-structure. As a consequence if two `VertexRDD`s are derived
+from the same base `VertexRDD` (e.g., by `filter` or `mapValues`) they can be joined in constant
+time without hash evaluations. To leverage this indexed data-structure, the `VertexRDD` exposes the
+following additional functionality:
+
+{% highlight scala %}
+class VertexRDD[VD] {
+ // Filter the vertex set but preserve the internal index
+ def filter(pred: Tuple2[VertexID, VD] => Boolean): VertexRDD[VD]
+ // Transform the values without changing the ids (preserves the internal index)
+ def mapValues[VD2](map: VD => VD2): VertexRDD[VD2]
+ def mapValues[VD2](map: (VertexID, VD) => VD2): VertexRDD[VD2]
+ // Remove vertices from this set that appear in the other set
+ def diff(other: VertexRDD[VD]): VertexRDD[VD]
+ // Join operators that take advantage of the internal indexing to accelerate joins (substantially)
+ def leftJoin[VD2, VD3](other: RDD[(VertexID, VD2)])(f: (VertexID, VD, Option[VD2]) => VD3): VertexRDD[VD3]
+ def innerJoin[U, VD2](other: RDD[(VertexID, U)])(f: (VertexID, VD, U) => VD2): VertexRDD[VD2]
+ // Use the index on this RDD to accelerate a `reduceByKey` operation on the input RDD.
+ def aggregateUsingIndex[VD2](other: RDD[(VertexID, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2]
+}
+{% endhighlight %}
+
+Notice, for example, how the `filter` operator returns a `VertexRDD`. Filter is actually
+implemented using a `BitSet`, thereby reusing the index and preserving the ability to do fast joins
+with other `VertexRDD`s. Likewise, the `mapValues` operators do not allow the `map` function to
+change the `VertexID`, thereby enabling the same `HashMap` data-structures to be reused. Both
+`leftJoin` and `innerJoin` can identify when they are joining two `VertexRDD`s derived from the same
+`HashMap` and implement the join by linear scan rather than costly point lookups.
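+
+For example, a minimal sketch in which two `VertexRDD`s derived from the same vertex set are joined
+without rehashing:
+
+{% highlight scala %}
+// Construct an indexed vertex set
+val vertexSet: VertexRDD[Int] = VertexRDD(sc.parallelize(0L until 100L).map(id => (id, 1)))
+// Both derived sets reuse vertexSet's index
+val evens: VertexRDD[Int] = vertexSet.filter { case (id, attr) => id % 2 == 0 }
+val doubled: VertexRDD[Int] = vertexSet.mapValues(attr => attr * 2)
+// Joining sets that share an index is implemented as a linear scan
+val joined: VertexRDD[Int] = evens.innerJoin(doubled)((id, a, b) => a + b)
+{% endhighlight %}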
+
+The `aggregateUsingIndex` operator is useful for efficient construction of a new `VertexRDD` from an
+`RDD[(VertexID, A)]`. Conceptually, if I have constructed a `VertexRDD[B]` over a set of vertices
+that is a *super-set* of the vertices in some `RDD[(VertexID, A)]`, then I can reuse the index to
+both aggregate and subsequently index the `RDD[(VertexID, A)]`. For example:
+
+{% highlight scala %}
+val setA: VertexRDD[Int] = VertexRDD(sc.parallelize(0L until 100L).map(id => (id, 1)))
+val rddB: RDD[(VertexID, Double)] = sc.parallelize(0L until 100L).flatMap(id => List((id, 1.0), (id, 2.0)))
+// There should be 200 entries in rddB
+rddB.count
+val setB: VertexRDD[Double] = setA.aggregateUsingIndex(rddB, _ + _)
+// There should be 100 entries in setB
+setB.count
+// Joining A and B should now be fast!
+val setC: VertexRDD[Double] = setA.innerJoin(setB)((id, a, b) => a + b)
+{% endhighlight %}
+
+## EdgeRDDs
+
+The `EdgeRDD[ED]`, which extends `RDD[Edge[ED]]`, is considerably simpler than the `VertexRDD`.
+GraphX organizes the edges in blocks partitioned using one of the various partitioning strategies
+defined in [`PartitionStrategy`][PartitionStrategy]. Within each partition, edge attributes and the
+adjacency structure are stored separately, enabling maximum reuse when changing attribute values.
+
+[PartitionStrategy]: api/graphx/index.html#org.apache.spark.graphx.PartitionStrategy
+
+The three additional functions exposed by the `EdgeRDD` are:
+{% highlight scala %}
+// Transform the edge attributes while preserving the structure
+def mapValues[ED2](f: Edge[ED] => ED2): EdgeRDD[ED2]
+// Reverse the edges reusing both attributes and structure
+def reverse: EdgeRDD[ED]
+// Join two `EdgeRDD`s partitioned using the same partitioning strategy.
+def innerJoin[ED2, ED3](other: EdgeRDD[ED2])(f: (VertexID, VertexID, ED, ED2) => ED3): EdgeRDD[ED3]
+{% endhighlight %}
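+
+For example, a minimal sketch (assuming a property graph `graph` with `String` edge attributes, such
+as the one constructed at the beginning of this guide) that transforms and reverses the edge set:
+
+{% highlight scala %}
+// The edges view of a graph is an EdgeRDD
+val edges: EdgeRDD[String] = graph.edges
+// Transform the relationship strings while preserving the adjacency structure
+val upperCased: EdgeRDD[String] = edges.mapValues(e => e.attr.toUpperCase)
+// Reverse every edge, reusing both the stored attributes and structure
+val reversed: EdgeRDD[String] = edges.reverse
+{% endhighlight %}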
+
+In most applications we have found that operations on the `EdgeRDD` are accomplished through the
+graph or rely on operations defined in the base `RDD` class.
+
+# Optimized Representation
+
+While a detailed description of the optimizations used in the GraphX representation of distributed
+graphs is beyond the scope of this guide, some high-level understanding may aid in the design of
+scalable algorithms as well as optimal use of the API. GraphX adopts a vertex-cut approach to
+distributed graph partitioning:
+
+<p style="text-align: center;">
+ <img src="img/edge_cut_vs_vertex_cut.png"
+ title="Edge Cut vs. Vertex Cut"
+ alt="Edge Cut vs. Vertex Cut"
+ width="50%" />
+ <!-- Images are downsized intentionally to improve quality on retina displays -->
+</p>
+
+Rather than splitting graphs along edges, GraphX partitions the graph along vertices which can
+reduce both the communication and storage overhead. Logically, this corresponds to assigning edges
+to machines and allowing vertices to span multiple machines. The exact method of assigning edges
+depends on the [`PartitionStrategy`][PartitionStrategy] and there are several tradeoffs to the
+various heuristics. Users can choose between different strategies by repartitioning the graph with
+the [`Graph.partitionBy`][Graph.partitionBy] operator.
+
+[Graph.partitionBy]: api/graphx/index.html#org.apache.spark.graphx.Graph$@partitionBy(partitionStrategy:org.apache.spark.graphx.PartitionStrategy):org.apache.spark.graphx.Graph[VD,ED]
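+
+For example, a minimal sketch that repartitions a graph's edges with one of the built-in heuristics:
+
+{% highlight scala %}
+// Reassign edges to partitions using the 2D edge partitioning heuristic
+val partitionedGraph = graph.partitionBy(PartitionStrategy.EdgePartition2D)
+{% endhighlight %}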
+
+<p style="text-align: center;">
+ <img src="img/vertex_routing_edge_tables.png"
+ title="RDD Graph Representation"
+ alt="RDD Graph Representation"
+ width="50%" />
+ <!-- Images are downsized intentionally to improve quality on retina displays -->
+</p>
+
+Once the edges have been partitioned, the key challenge to efficient graph-parallel computation is
+efficiently joining vertex attributes with the edges. Because real-world graphs typically have more
+edges than vertices, we move vertex attributes to the edges.
+
+
+
+
+
+# Graph Algorithms
+<a name="graph_algorithms"></a>
+
+GraphX includes a set of graph algorithms to simplify analytics tasks. The algorithms are contained in the `org.apache.spark.graphx.lib` package and can be accessed directly as methods on `Graph` via [`GraphOps`][GraphOps]. This section describes the algorithms and how they are used.
+
+## PageRank
+<a name="pagerank"></a>
+
+PageRank measures the importance of each vertex in a graph, assuming an edge from *u* to *v* represents an endorsement of *v*'s importance by *u*. For example, if a Twitter user is followed by many others, the user will be ranked highly.
+
+GraphX comes with static and dynamic implementations of PageRank as methods on the [`PageRank` object][PageRank]. Static PageRank runs for a fixed number of iterations, while dynamic PageRank runs until the ranks converge (i.e., stop changing by more than a specified tolerance). [`GraphOps`][GraphOps] allows calling these algorithms directly as methods on `Graph`.
+
+GraphX also includes an example social network dataset that we can run PageRank on. A set of users is given in `graphx/data/users.txt`, and a set of relationships between users is given in `graphx/data/followers.txt`. We compute the PageRank of each user as follows:
+
+[PageRank]: api/graphx/index.html#org.apache.spark.graphx.lib.PageRank$
+
+{% highlight scala %}
+// Load the edges as a graph
+val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt")
+// Run PageRank
+val ranks = graph.pageRank(0.0001).vertices
+// Join the ranks with the usernames
+val users = sc.textFile("graphx/data/users.txt").map { line =>
+ val fields = line.split(",")
+ (fields(0).toLong, fields(1))
+}
+val ranksByUsername = users.join(ranks).map {
+ case (id, (username, rank)) => (username, rank)
+}
+// Print the result
+println(ranksByUsername.collect().mkString("\n"))
+{% endhighlight %}
+
+## Connected Components
+
+The connected components algorithm labels each connected component of the graph with the ID of its lowest-numbered vertex. For example, in a social network, connected components can approximate clusters. GraphX contains an implementation of the algorithm in the [`ConnectedComponents` object][ConnectedComponents], and we compute the connected components of the example social network dataset from the [PageRank section](#pagerank) as follows:
+
+[ConnectedComponents]: api/graphx/index.html#org.apache.spark.graphx.lib.ConnectedComponents$
+
+{% highlight scala %}
+// Load the graph as in the PageRank example
+val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt")
+// Find the connected components
+val cc = graph.connectedComponents().vertices
+// Join the connected components with the usernames
+val users = sc.textFile("graphx/data/users.txt").map { line =>
+ val fields = line.split(",")
+ (fields(0).toLong, fields(1))
+}
+val ccByUsername = users.join(cc).map {
+ case (id, (username, cc)) => (username, cc)
+}
+// Print the result
+println(ccByUsername.collect().mkString("\n"))
+{% endhighlight %}
+
+## Triangle Counting
+
+A vertex is part of a triangle when it has two adjacent vertices with an edge between them. GraphX implements a triangle counting algorithm in the [`TriangleCount` object][TriangleCount] that determines the number of triangles passing through each vertex, providing a measure of clustering. We compute the triangle count of the social network dataset from the [PageRank section](#pagerank). *Note that `TriangleCount` requires the edges to be in canonical orientation (`srcId < dstId`) and the graph to be partitioned using [`Graph.partitionBy`][Graph.partitionBy].*
+
+[TriangleCount]: api/graphx/index.html#org.apache.spark.graphx.lib.TriangleCount$
+[Graph.partitionBy]: api/graphx/index.html#org.apache.spark.graphx.Graph@partitionBy(PartitionStrategy):Graph[VD,ED]
+
+{% highlight scala %}
+// Load the edges in canonical order and partition the graph for triangle count
+val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt", true).partitionBy(PartitionStrategy.RandomVertexCut)
+// Find the triangle count for each vertex
+val triCounts = graph.triangleCount().vertices
+// Join the triangle counts with the usernames
+val users = sc.textFile("graphx/data/users.txt").map { line =>
+ val fields = line.split(",")
+ (fields(0).toLong, fields(1))
+}
+val triCountByUsername = users.join(triCounts).map { case (id, (username, tc)) =>
+ (username, tc)
+}
+// Print the result
+println(triCountByUsername.collect().mkString("\n"))
+{% endhighlight %}
+
+<p style="text-align: center;">
+ <img src="img/tables_and_graphs.png"
+ title="Tables and Graphs"
+ alt="Tables and Graphs"
+ width="50%" />
+ <!-- Images are downsized intentionally to improve quality on retina displays -->
+</p>
+
+# Examples
+
+Suppose I want to build a graph from some text files, restrict the graph
+to important relationships and users, run PageRank on the subgraph, and
+then finally return attributes associated with the top users. I can do
+all of this in just a few lines with GraphX:
+
+{% highlight scala %}
+// Connect to the Spark cluster
+val sc = new SparkContext("spark://master.amplab.org", "research")
+
+// Load my user data and parse into tuples of user id and attribute list
+val users = (sc.textFile("graphx/data/users.txt")
+ .map(line => line.split(",")).map( parts => (parts.head.toLong, parts.tail) ))
+
+// Parse the edge data which is already in userId -> userId format
+val followerGraph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt")
+
+// Attach the user attributes
+val graph = followerGraph.outerJoinVertices(users) {
+ case (uid, deg, Some(attrList)) => attrList
+ // Some users may not have attributes so we set them as empty
+ case (uid, deg, None) => Array.empty[String]
+}
+
+// Restrict the graph to users with usernames and names
+val subgraph = graph.subgraph(vpred = (vid, attr) => attr.size == 2)
+
+// Compute the PageRank
+val pagerankGraph = subgraph.pageRank(0.001)
+
+// Get the attributes of the top pagerank users
+val userInfoWithPageRank = subgraph.outerJoinVertices(pagerankGraph.vertices) {
+ case (uid, attrList, Some(pr)) => (pr, attrList.toList)
+ case (uid, attrList, None) => (0.0, attrList.toList)
+}
+
+println(userInfoWithPageRank.vertices.top(5)(Ordering.by(_._2._1)).mkString("\n"))
+
+{% endhighlight %}
diff --git a/docs/img/data_parallel_vs_graph_parallel.png b/docs/img/data_parallel_vs_graph_parallel.png
new file mode 100644
index 0000000000..d3918f01d8
--- /dev/null
+++ b/docs/img/data_parallel_vs_graph_parallel.png
Binary files differ
diff --git a/docs/img/edge-cut.png b/docs/img/edge-cut.png
new file mode 100644
index 0000000000..698f4ff181
--- /dev/null
+++ b/docs/img/edge-cut.png
Binary files differ
diff --git a/docs/img/edge_cut_vs_vertex_cut.png b/docs/img/edge_cut_vs_vertex_cut.png
new file mode 100644
index 0000000000..ae30396d3f
--- /dev/null
+++ b/docs/img/edge_cut_vs_vertex_cut.png
Binary files differ
diff --git a/docs/img/graph_analytics_pipeline.png b/docs/img/graph_analytics_pipeline.png
new file mode 100644
index 0000000000..6d606e0189
--- /dev/null
+++ b/docs/img/graph_analytics_pipeline.png
Binary files differ
diff --git a/docs/img/graph_parallel.png b/docs/img/graph_parallel.png
new file mode 100644
index 0000000000..330be5567c
--- /dev/null
+++ b/docs/img/graph_parallel.png
Binary files differ
diff --git a/docs/img/graphx_figures.pptx b/docs/img/graphx_figures.pptx
new file mode 100644
index 0000000000..e567bf08fe
--- /dev/null
+++ b/docs/img/graphx_figures.pptx
Binary files differ
diff --git a/docs/img/graphx_logo.png b/docs/img/graphx_logo.png
new file mode 100644
index 0000000000..9869ac148c
--- /dev/null
+++ b/docs/img/graphx_logo.png
Binary files differ
diff --git a/docs/img/graphx_performance_comparison.png b/docs/img/graphx_performance_comparison.png
new file mode 100644
index 0000000000..62dcf098c9
--- /dev/null
+++ b/docs/img/graphx_performance_comparison.png
Binary files differ
diff --git a/docs/img/property_graph.png b/docs/img/property_graph.png
new file mode 100644
index 0000000000..6f3f89a010
--- /dev/null
+++ b/docs/img/property_graph.png
Binary files differ
diff --git a/docs/img/tables_and_graphs.png b/docs/img/tables_and_graphs.png
new file mode 100644
index 0000000000..ec37bb45a6
--- /dev/null
+++ b/docs/img/tables_and_graphs.png
Binary files differ
diff --git a/docs/img/triplet.png b/docs/img/triplet.png
new file mode 100644
index 0000000000..8b82a09bed
--- /dev/null
+++ b/docs/img/triplet.png
Binary files differ
diff --git a/docs/img/vertex-cut.png b/docs/img/vertex-cut.png
new file mode 100644
index 0000000000..0a508dcee9
--- /dev/null
+++ b/docs/img/vertex-cut.png
Binary files differ
diff --git a/docs/img/vertex_routing_edge_tables.png b/docs/img/vertex_routing_edge_tables.png
new file mode 100644
index 0000000000..4379becc87
--- /dev/null
+++ b/docs/img/vertex_routing_edge_tables.png
Binary files differ
diff --git a/docs/index.md b/docs/index.md
index 86d574daaa..debdb33108 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -5,7 +5,7 @@ title: Spark Overview
Apache Spark is a fast and general-purpose cluster computing system.
It provides high-level APIs in [Scala](scala-programming-guide.html), [Java](java-programming-guide.html), and [Python](python-programming-guide.html) that make parallel jobs easy to write, and an optimized engine that supports general computation graphs.
-It also supports a rich set of higher-level tools including [Shark](http://shark.cs.berkeley.edu) (Hive on Spark), [MLlib](mllib-guide.html) for machine learning, [Bagel](bagel-programming-guide.html) for graph processing, and [Spark Streaming](streaming-programming-guide.html).
+It also supports a rich set of higher-level tools including [Shark](http://shark.cs.berkeley.edu) (Hive on Spark), [MLlib](mllib-guide.html) for machine learning, [GraphX](graphx-programming-guide.html) for graph processing, and [Spark Streaming](streaming-programming-guide.html).
# Downloading
@@ -78,6 +78,7 @@ For this version of Spark (0.8.1) Hadoop 2.2.x (or newer) users will have to bui
* [Spark Streaming](streaming-programming-guide.html): using the alpha release of Spark Streaming
* [MLlib (Machine Learning)](mllib-guide.html): Spark's built-in machine learning library
* [Bagel (Pregel on Spark)](bagel-programming-guide.html): simple graph processing model
+* [GraphX (Graphs on Spark)](graphx-programming-guide.html): Spark's new API for graphs
**API Docs:**
@@ -86,6 +87,7 @@ For this version of Spark (0.8.1) Hadoop 2.2.x (or newer) users will have to bui
* [Spark Streaming for Java/Scala (Scaladoc)](api/streaming/index.html)
* [MLlib (Machine Learning) for Java/Scala (Scaladoc)](api/mllib/index.html)
* [Bagel (Pregel on Spark) for Scala (Scaladoc)](api/bagel/index.html)
+* [GraphX (Graphs on Spark) for Scala (Scaladoc)](api/graphx/index.html)
**Deployment guides:**
diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md
index 21d0464852..a140ecb618 100644
--- a/docs/mllib-guide.md
+++ b/docs/mllib-guide.md
@@ -21,6 +21,8 @@ depends on native Fortran routines. You may need to install the
if it is not already present on your nodes. MLlib will throw a linking error if it cannot
detect these libraries automatically.
+To use MLlib in Python, you will also need [NumPy](http://www.numpy.org) version 1.7 or newer.
+
# Binary Classification
Binary classification is a supervised learning problem in which we want to
@@ -316,6 +318,13 @@ other signals), you can use the trainImplicit method to get better results.
val model = ALS.trainImplicit(ratings, 1, 20, 0.01)
{% endhighlight %}
+# Using MLLib in Java
+
+All of MLlib's methods use Java-friendly types, so you can import and call them there the same
+way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the
+Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by
+calling `.rdd()` on your `JavaRDD` object.
+
# Using MLLib in Python
Following examples can be tested in the PySpark shell.
@@ -330,7 +339,7 @@ from numpy import array
# Load and parse the data
data = sc.textFile("mllib/data/sample_svm_data.txt")
parsedData = data.map(lambda line: array([float(x) for x in line.split(' ')]))
-model = LogisticRegressionWithSGD.train(sc, parsedData)
+model = LogisticRegressionWithSGD.train(parsedData)
# Build the model
labelsAndPreds = parsedData.map(lambda point: (int(point.item(0)),
@@ -356,7 +365,7 @@ data = sc.textFile("mllib/data/ridge-data/lpsa.data")
parsedData = data.map(lambda line: array([float(x) for x in line.replace(',', ' ').split(' ')]))
# Build the model
-model = LinearRegressionWithSGD.train(sc, parsedData)
+model = LinearRegressionWithSGD.train(parsedData)
# Evaluate the model on training data
valuesAndPreds = parsedData.map(lambda point: (point.item(0),
@@ -382,7 +391,7 @@ data = sc.textFile("kmeans_data.txt")
parsedData = data.map(lambda line: array([float(x) for x in line.split(' ')]))
# Build the model (cluster the data)
-clusters = KMeans.train(sc, parsedData, 2, maxIterations=10,
+clusters = KMeans.train(parsedData, 2, maxIterations=10,
runs=30, initialization_mode="random")
# Evaluate clustering by computing Within Set Sum of Squared Errors
@@ -411,7 +420,7 @@ data = sc.textFile("mllib/data/als/test.data")
ratings = data.map(lambda line: array([float(x) for x in line.split(',')]))
# Build the recommendation model using Alternating Least Squares
-model = ALS.train(sc, ratings, 1, 20)
+model = ALS.train(ratings, 1, 20)
# Evaluate the model on training data
testdata = ratings.map(lambda p: (int(p[0]), int(p[1])))
@@ -426,7 +435,7 @@ signals), you can use the trainImplicit method to get better results.
{% highlight python %}
# Build the recommendation model using Alternating Least Squares based on implicit ratings
-model = ALS.trainImplicit(sc, ratings, 1, 20)
+model = ALS.trainImplicit(ratings, 1, 20)
{% endhighlight %}
diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md
index c4236f8312..b07899c2e1 100644
--- a/docs/python-programming-guide.md
+++ b/docs/python-programming-guide.md
@@ -52,7 +52,7 @@ In addition, PySpark fully supports interactive use---simply run `./bin/pyspark`
# Installing and Configuring PySpark
-PySpark requires Python 2.6 or higher.
+PySpark requires Python 2.7 or higher.
PySpark applications are executed using a standard CPython interpreter in order to support Python modules that use C extensions.
We have not tested PySpark with Python 3 or with alternative Python interpreters, such as [PyPy](http://pypy.org/) or [Jython](http://www.jython.org/).
@@ -149,6 +149,12 @@ sc = SparkContext(conf = conf)
[API documentation](api/pyspark/index.html) for PySpark is available as Epydoc.
Many of the methods also contain [doctests](http://docs.python.org/2/library/doctest.html) that provide additional usage examples.
+# Libraries
+
+[MLlib](mllib-guide.html) is also available in PySpark. To use it, you'll need
+[NumPy](http://www.numpy.org) version 1.7 or newer. The [MLlib guide](mllib-guide.html) contains
+some example applications.
+
# Where to Go from Here
PySpark also includes several sample programs in the [`python/examples` folder](https://github.com/apache/incubator-spark/tree/master/python/examples).
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 1c9ece6270..4e8a680a75 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -167,7 +167,7 @@ Spark Streaming features windowed computations, which allow you to apply transfo
</tr>
</table>
-A complete list of DStream operations is available in the API documentation of [DStream](api/streaming/index.html#org.apache.spark.streaming.DStream) and [PairDStreamFunctions](api/streaming/index.html#org.apache.spark.streaming.PairDStreamFunctions).
+A complete list of DStream operations is available in the API documentation of [DStream](api/streaming/index.html#org.apache.spark.streaming.dstream.DStream) and [PairDStreamFunctions](api/streaming/index.html#org.apache.spark.streaming.dstream.PairDStreamFunctions).
## Output Operations
When an output operator is called, it triggers the computation of a stream. Currently the following output operators are defined:
@@ -175,7 +175,7 @@ When an output operator is called, it triggers the computation of a stream. Curr
<table class="table">
<tr><th style="width:30%">Operator</th><th>Meaning</th></tr>
<tr>
- <td> <b>foreach</b>(<i>func</i>) </td>
+ <td> <b>foreachRDD</b>(<i>func</i>) </td>
<td> The fundamental output operator. Applies a function, <i>func</i>, to each RDD generated from the stream. This function should have side effects, such as printing output, saving the RDD to external files, or writing it over the network to an external system. </td>
</tr>
@@ -375,7 +375,7 @@ There are two failure behaviors based on which input sources are used.
1. _Using HDFS files as input source_ - Since the data is reliably stored on HDFS, all data can re-computed and therefore no data will be lost due to any failure.
1. _Using any input source that receives data through a network_ - For network-based data sources like Kafka and Flume, the received input data is replicated in memory between nodes of the cluster (default replication factor is 2). So if a worker node fails, then the system can recompute the lost from the the left over copy of the input data. However, if the worker node where a network receiver was running fails, then a tiny bit of data may be lost, that is, the data received by the system but not yet replicated to other node(s). The receiver will be started on a different node and it will continue to receive data.
-Since all data is modeled as RDDs with their lineage of deterministic operations, any recomputation always leads to the same result. As a result, all DStream transformations are guaranteed to have _exactly-once_ semantics. That is, the final transformed result will be same even if there were was a worker node failure. However, output operations (like `foreach`) have _at-least once_ semantics, that is, the transformed data may get written to an external entity more than once in the event of a worker failure. While this is acceptable for saving to HDFS using the `saveAs*Files` operations (as the file will simply get over-written by the same data), additional transactions-like mechanisms may be necessary to achieve exactly-once semantics for output operations.
+Since all data is modeled as RDDs with their lineage of deterministic operations, any recomputation always leads to the same result. As a result, all DStream transformations are guaranteed to have _exactly-once_ semantics. That is, the final transformed result will be same even if there were was a worker node failure. However, output operations (like `foreachRDD`) have _at-least once_ semantics, that is, the transformed data may get written to an external entity more than once in the event of a worker failure. While this is acceptable for saving to HDFS using the `saveAs*Files` operations (as the file will simply get over-written by the same data), additional transactions-like mechanisms may be necessary to achieve exactly-once semantics for output operations.
## Failure of the Driver Node
A system that is required to operate 24/7 needs to be able tolerate the failure of the driver node as well. Spark Streaming does this by saving the state of the DStream computation periodically to a HDFS file, that can be used to restart the streaming computation in the event of a failure of the driver node. This checkpointing is enabled by setting a HDFS directory for checkpointing using `ssc.checkpoint(<checkpoint directory>)` as described [earlier](#rdd-checkpointing-within-dstreams). To elaborate, the following state is periodically saved to a file.
diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala
index 83db8b9e26..c8ecbb8e41 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala
@@ -43,7 +43,7 @@ object LocalALS {
def generateR(): DoubleMatrix2D = {
val mh = factory2D.random(M, F)
val uh = factory2D.random(U, F)
- return algebra.mult(mh, algebra.transpose(uh))
+ algebra.mult(mh, algebra.transpose(uh))
}
def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D],
@@ -56,7 +56,7 @@ object LocalALS {
//println("R: " + r)
blas.daxpy(-1, targetR, r)
val sumSqs = r.aggregate(Functions.plus, Functions.square)
- return sqrt(sumSqs / (M * U))
+ sqrt(sumSqs / (M * U))
}
def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D],
@@ -80,7 +80,7 @@ object LocalALS {
val ch = new CholeskyDecomposition(XtX)
val Xty2D = factory2D.make(Xty.toArray, F)
val solved2D = ch.solve(Xty2D)
- return solved2D.viewColumn(0)
+ solved2D.viewColumn(0)
}
def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D],
@@ -104,7 +104,7 @@ object LocalALS {
val ch = new CholeskyDecomposition(XtX)
val Xty2D = factory2D.make(Xty.toArray, F)
val solved2D = ch.solve(Xty2D)
- return solved2D.viewColumn(0)
+ solved2D.viewColumn(0)
}
def main(args: Array[String]) {
diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala
index fb130ea198..9ab5f5a486 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala
@@ -28,7 +28,7 @@ object LocalFileLR {
def parsePoint(line: String): DataPoint = {
val nums = line.split(' ').map(_.toDouble)
- return DataPoint(new Vector(nums.slice(1, D+1)), nums(0))
+ DataPoint(new Vector(nums.slice(1, D+1)), nums(0))
}
def main(args: Array[String]) {
diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala
index f90ea35cd4..a730464ea1 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala
@@ -55,7 +55,7 @@ object LocalKMeans {
}
}
- return bestIndex
+ bestIndex
}
def main(args: Array[String]) {
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
index 30c86d83e6..17bafc2218 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
@@ -44,7 +44,7 @@ object SparkALS {
def generateR(): DoubleMatrix2D = {
val mh = factory2D.random(M, F)
val uh = factory2D.random(U, F)
- return algebra.mult(mh, algebra.transpose(uh))
+ algebra.mult(mh, algebra.transpose(uh))
}
def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D],
@@ -57,7 +57,7 @@ object SparkALS {
//println("R: " + r)
blas.daxpy(-1, targetR, r)
val sumSqs = r.aggregate(Functions.plus, Functions.square)
- return sqrt(sumSqs / (M * U))
+ sqrt(sumSqs / (M * U))
}
def update(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D],
@@ -83,7 +83,7 @@ object SparkALS {
val ch = new CholeskyDecomposition(XtX)
val Xty2D = factory2D.make(Xty.toArray, F)
val solved2D = ch.solve(Xty2D)
- return solved2D.viewColumn(0)
+ solved2D.viewColumn(0)
}
def main(args: Array[String]) {
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
index ff72532db1..39819064ed 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
@@ -43,7 +43,7 @@ object SparkHdfsLR {
while (i < D) {
x(i) = tok.nextToken.toDouble; i += 1
}
- return DataPoint(new Vector(x), y)
+ DataPoint(new Vector(x), y)
}
def main(args: Array[String]) {
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala
index 8c99025eaa..9fe2465235 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala
@@ -30,7 +30,7 @@ object SparkKMeans {
val rand = new Random(42)
def parseVector(line: String): Vector = {
- return new Vector(line.split(' ').map(_.toDouble))
+ new Vector(line.split(' ').map(_.toDouble))
}
def closestPoint(p: Vector, centers: Array[Vector]): Int = {
@@ -46,7 +46,7 @@ object SparkKMeans {
}
}
- return bestIndex
+ bestIndex
}
def main(args: Array[String]) {
@@ -61,15 +61,15 @@ object SparkKMeans {
val K = args(2).toInt
val convergeDist = args(3).toDouble
- var kPoints = data.takeSample(false, K, 42).toArray
+ val kPoints = data.takeSample(withReplacement = false, K, 42).toArray
var tempDist = 1.0
while(tempDist > convergeDist) {
- var closest = data.map (p => (closestPoint(p, kPoints), (p, 1)))
+ val closest = data.map (p => (closestPoint(p, kPoints), (p, 1)))
- var pointStats = closest.reduceByKey{case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2)}
+ val pointStats = closest.reduceByKey{case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2)}
- var newPoints = pointStats.map {pair => (pair._1, pair._2._1 / pair._2._2)}.collectAsMap()
+ val newPoints = pointStats.map {pair => (pair._1, pair._2._1 / pair._2._2)}.collectAsMap()
tempDist = 0.0
for (i <- 0 until K) {
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala
new file mode 100644
index 0000000000..d58fddff2b
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.graphx
+
+import org.apache.spark.SparkContext._
+import org.apache.spark._
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.lib.Analytics
+
+/**
+ * Uses GraphX to run PageRank on a LiveJournal social network graph. Download the dataset from
+ * http://snap.stanford.edu/data/soc-LiveJournal1.html.
+ */
+object LiveJournalPageRank {
+ def main(args: Array[String]) {
+ if (args.length < 2) {
+ System.err.println(
+ "Usage: LiveJournalPageRank <master> <edge_list_file>\n" +
+ " [--tol=<tolerance>]\n" +
+ " The tolerance allowed at convergence (smaller => more accurate). Default is " +
+ "0.001.\n" +
+ " [--output=<output_file>]\n" +
+ " If specified, the file to write the ranks to.\n" +
+ " [--numEPart=<num_edge_partitions>]\n" +
+ " The number of partitions for the graph's edge RDD. Default is 4.\n" +
+ " [--partStrategy=RandomVertexCut | EdgePartition1D | EdgePartition2D | " +
+ "CanonicalRandomVertexCut]\n" +
+ " The way edges are assigned to edge partitions. Default is RandomVertexCut.")
+ System.exit(-1)
+ }
+
+ Analytics.main(args.patch(1, List("pagerank"), 0))
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala
index 3d08d86567..99b79c3949 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala
@@ -58,7 +58,7 @@ object RawNetworkGrep {
val rawStreams = (1 to numStreams).map(_ =>
ssc.rawSocketStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray
val union = ssc.union(rawStreams)
- union.filter(_.contains("the")).count().foreach(r =>
+ union.filter(_.contains("the")).count().foreachRDD(r =>
println("Grep count: " + r.collect().mkString))
ssc.start()
}
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala
index d51e6e9418..8c5d0bd568 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala
@@ -82,7 +82,7 @@ object RecoverableNetworkWordCount {
val lines = ssc.socketTextStream(ip, port)
val words = lines.flatMap(_.split(" "))
val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
- wordCounts.foreach((rdd: RDD[(String, Int)], time: Time) => {
+ wordCounts.foreachRDD((rdd: RDD[(String, Int)], time: Time) => {
val counts = "Counts at time " + time + " " + rdd.collect().mkString("[", ", ", "]")
println(counts)
println("Appending to " + outputFile.getAbsolutePath)
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdCMS.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdCMS.scala
index 80b5a98b14..483c4d3118 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdCMS.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdCMS.scala
@@ -81,7 +81,7 @@ object TwitterAlgebirdCMS {
val exactTopUsers = users.map(id => (id, 1))
.reduceByKey((a, b) => a + b)
- approxTopUsers.foreach(rdd => {
+ approxTopUsers.foreachRDD(rdd => {
if (rdd.count() != 0) {
val partial = rdd.first()
val partialTopK = partial.heavyHitters.map(id =>
@@ -96,7 +96,7 @@ object TwitterAlgebirdCMS {
}
})
- exactTopUsers.foreach(rdd => {
+ exactTopUsers.foreachRDD(rdd => {
if (rdd.count() != 0) {
val partialMap = rdd.collect().toMap
val partialTopK = rdd.map(
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdHLL.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdHLL.scala
index cb2f2c51a0..94c2bf29ac 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdHLL.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdHLL.scala
@@ -67,7 +67,7 @@ object TwitterAlgebirdHLL {
val exactUsers = users.map(id => Set(id)).reduce(_ ++ _)
- approxUsers.foreach(rdd => {
+ approxUsers.foreachRDD(rdd => {
if (rdd.count() != 0) {
val partial = rdd.first()
globalHll += partial
@@ -76,7 +76,7 @@ object TwitterAlgebirdHLL {
}
})
- exactUsers.foreach(rdd => {
+ exactUsers.foreachRDD(rdd => {
if (rdd.count() != 0) {
val partial = rdd.first()
userSet ++= partial
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterPopularTags.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterPopularTags.scala
index 16c10feaba..8a70d4a978 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterPopularTags.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterPopularTags.scala
@@ -56,13 +56,13 @@ object TwitterPopularTags {
// Print popular hashtags
- topCounts60.foreach(rdd => {
+ topCounts60.foreachRDD(rdd => {
val topList = rdd.take(5)
println("\nPopular topics in last 60 seconds (%s total):".format(rdd.count()))
topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))}
})
- topCounts10.foreach(rdd => {
+ topCounts10.foreachRDD(rdd => {
val topList = rdd.take(5)
println("\nPopular topics in last 10 seconds (%s total):".format(rdd.count()))
topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))}
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala
index 4fe57de4a4..a2600989ca 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala
@@ -65,7 +65,7 @@ object PageViewGenerator {
return item
}
}
- return inputMap.take(1).head._1 // Shouldn't get here if probabilities add up to 1.0
+ inputMap.take(1).head._1 // Shouldn't get here if probabilities add up to 1.0
}
def getNextClickEvent() : String = {
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewStream.scala
index da6b67bcce..bb44bc3d06 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewStream.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewStream.scala
@@ -91,7 +91,7 @@ object PageViewStream {
case "popularUsersSeen" =>
// Look for users in our existing dataset and print it out if we have a match
pageViews.map(view => (view.userID, 1))
- .foreach((rdd, time) => rdd.join(userList)
+ .foreachRDD((rdd, time) => rdd.join(userList)
.map(_._2._2)
.take(10)
.foreach(u => println("Saw user %s at time %s".format(u, time))))
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala
index 834b775d4f..d53b66dd46 100644
--- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala
+++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala
@@ -18,8 +18,9 @@
package org.apache.spark.streaming.flume
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.{StreamingContext, DStream}
+import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.api.java.{JavaStreamingContext, JavaDStream}
+import org.apache.spark.streaming.dstream.DStream
object FlumeUtils {
/**
@@ -42,6 +43,7 @@ object FlumeUtils {
/**
* Creates an input stream from a Flume source.
+ * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2.
* @param hostname Hostname of the slave machine to which the flume data will be sent
* @param port Port of the slave machine to which the flume data will be sent
*/
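As a usage sketch (not part of the patch), the Scala side of this API can be driven as below; the host and port are placeholders, and the storage level falls back to MEMORY_AND_DISK_SER_2 when omitted, as the new doc line states:
{{{
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.flume.FlumeUtils

object FlumeSketch {
  def main(args: Array[String]) {
    val ssc = new StreamingContext("local[2]", "FlumeSketch", Seconds(2))
    // Default storage level (MEMORY_AND_DISK_SER_2).
    val events = FlumeUtils.createStream(ssc, "localhost", 41414)
    // Explicit storage level, if the default is not desired.
    val inMemory = FlumeUtils.createStream(ssc, "localhost", 41414, StorageLevel.MEMORY_ONLY_SER_2)
    events.count().foreachRDD(rdd => println("Flume events: " + rdd.collect().mkString))
    ssc.start()
  }
}
}}}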
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
index c2d851f943..37c03be4e7 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
@@ -26,8 +26,9 @@ import java.util.{Map => JMap}
import kafka.serializer.{Decoder, StringDecoder}
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.{StreamingContext, DStream}
+import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.api.java.{JavaStreamingContext, JavaPairDStream}
+import org.apache.spark.streaming.dstream.DStream
object KafkaUtils {
@@ -77,6 +78,7 @@ object KafkaUtils {
/**
* Create an input stream that pulls messages from a Kafka Broker.
+ * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2.
* @param jssc JavaStreamingContext object
* @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..)
* @param groupId The group id for this consumer
@@ -127,7 +129,7 @@ object KafkaUtils {
* see http://kafka.apache.org/08/configuration.html
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread
- * @param storageLevel RDD storage level. Defaults to MEMORY_AND_DISK_2.
+ * @param storageLevel RDD storage level.
*/
def createStream[K, V, U <: Decoder[_], T <: Decoder[_]](
jssc: JavaStreamingContext,
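A hedged sketch of the corresponding Scala `createStream` call (the hunk above shows the Java-facing overload); the ZooKeeper quorum, group id, and topic map are placeholders, with each topic mapped to its number of consumer threads:
{{{
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka.KafkaUtils

object KafkaSketch {
  def main(args: Array[String]) {
    val ssc = new StreamingContext("local[2]", "KafkaSketch", Seconds(2))
    // Yields (key, message) pairs; storage level defaults to MEMORY_AND_DISK_SER_2.
    val messages = KafkaUtils.createStream(ssc, "localhost:2181", "sketch-group", Map("pageviews" -> 1))
    messages.map(_._2).count().foreachRDD(rdd => println("Messages: " + rdd.collect().mkString))
    ssc.start()
  }
}
}}}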
diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala
index 0e6c25dbee..3636e46bb8 100644
--- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala
+++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala
@@ -18,9 +18,10 @@
package org.apache.spark.streaming.mqtt
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.{StreamingContext, DStream}
+import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.api.java.{JavaStreamingContext, JavaDStream}
import scala.reflect.ClassTag
+import org.apache.spark.streaming.dstream.DStream
object MQTTUtils {
/**
@@ -43,6 +44,7 @@ object MQTTUtils {
/**
* Create an input stream that receives messages pushed by an MQTT publisher.
+ * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2.
* @param jssc JavaStreamingContext object
* @param brokerUrl Url of remote MQTT publisher
* @param topic Topic name to subscribe to
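A minimal sketch of the Scala `createStream` documented above, assuming a local MQTT broker; the broker URL and topic are placeholders:
{{{
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.mqtt.MQTTUtils

object MQTTSketch {
  def main(args: Array[String]) {
    val ssc = new StreamingContext("local[2]", "MQTTSketch", Seconds(2))
    // Payloads arrive as a DStream[String]; storage level defaults to MEMORY_AND_DISK_SER_2.
    val lines = MQTTUtils.createStream(ssc, "tcp://localhost:1883", "sensors/temperature")
    lines.print()
    ssc.start()
  }
}
}}}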
diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
index fcc159e85a..73e7ce6e96 100644
--- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
+++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.storage.StorageLevel
class MQTTStreamSuite extends TestSuiteBase {
- test("MQTT input stream") {
+ test("mqtt input stream") {
val ssc = new StreamingContext(master, framework, batchDuration)
val brokerUrl = "abc"
val topic = "def"
diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala
index 5e506ffabc..b8bae7b6d3 100644
--- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala
+++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala
@@ -20,8 +20,9 @@ package org.apache.spark.streaming.twitter
import twitter4j.Status
import twitter4j.auth.Authorization
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.{StreamingContext, DStream}
+import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.api.java.{JavaDStream, JavaStreamingContext}
+import org.apache.spark.streaming.dstream.DStream
object TwitterUtils {
/**
@@ -50,6 +51,7 @@ object TwitterUtils {
* OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey,
* twitter4j.oauth.consumerSecret, twitter4j.oauth.accessToken and
* twitter4j.oauth.accessTokenSecret.
+ * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2.
* @param jssc JavaStreamingContext object
*/
def createStream(jssc: JavaStreamingContext): JavaDStream[Status] = {
@@ -61,6 +63,7 @@ object TwitterUtils {
* OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey,
* twitter4j.oauth.consumerSecret, twitter4j.oauth.accessToken and
* twitter4j.oauth.accessTokenSecret.
+ * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2.
* @param jssc JavaStreamingContext object
* @param filters Set of filter strings to get only those tweets that match them
*/
@@ -87,6 +90,7 @@ object TwitterUtils {
/**
* Create an input stream that returns tweets received from Twitter.
+ * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2.
* @param jssc JavaStreamingContext object
* @param twitterAuth Twitter4J Authorization
*/
@@ -96,6 +100,7 @@ object TwitterUtils {
/**
* Create an input stream that returns tweets received from Twitter.
+ * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2.
* @param jssc JavaStreamingContext object
* @param twitterAuth Twitter4J Authorization
* @param filters Set of filter strings to get only those tweets that match them
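A minimal sketch of the Scala variant, assuming the twitter4j.oauth.* system properties mentioned above are already set; the filter term is a placeholder:
{{{
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.twitter.TwitterUtils

object TwitterSketch {
  def main(args: Array[String]) {
    val ssc = new StreamingContext("local[2]", "TwitterSketch", Seconds(2))
    // Passing None uses system-property OAuth; storage level defaults to MEMORY_AND_DISK_SER_2.
    val tweets = TwitterUtils.createStream(ssc, None, Seq("spark"))
    tweets.map(_.getText).print()
    ssc.start()
  }
}
}}}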
diff --git a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala
index a0a8fe617b..ccc38784ef 100644
--- a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala
+++ b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala
@@ -23,7 +23,7 @@ import twitter4j.auth.{NullAuthorization, Authorization}
class TwitterStreamSuite extends TestSuiteBase {
- test("kafka input stream") {
+ test("twitter input stream") {
val ssc = new StreamingContext(master, framework, batchDuration)
val filters = Seq("filter1", "filter2")
val authorization: Authorization = NullAuthorization.getInstance()
diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala
index 546d9df3b5..7a14b3d2bf 100644
--- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala
+++ b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala
@@ -25,8 +25,9 @@ import akka.zeromq.Subscribe
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.receivers.ReceiverSupervisorStrategy
-import org.apache.spark.streaming.{StreamingContext, DStream}
+import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.api.java.{JavaStreamingContext, JavaDStream}
+import org.apache.spark.streaming.dstream.DStream
object ZeroMQUtils {
/**
diff --git a/graphx/data/followers.txt b/graphx/data/followers.txt
new file mode 100644
index 0000000000..7bb8e900e2
--- /dev/null
+++ b/graphx/data/followers.txt
@@ -0,0 +1,8 @@
+2 1
+4 1
+1 2
+6 3
+7 3
+7 6
+6 7
+3 7
diff --git a/graphx/data/users.txt b/graphx/data/users.txt
new file mode 100644
index 0000000000..982d19d50b
--- /dev/null
+++ b/graphx/data/users.txt
@@ -0,0 +1,7 @@
+1,BarackObama,Barack Obama
+2,ladygaga,Goddess of Love
+3,jeresig,John Resig
+4,justinbieber,Justin Bieber
+6,matei_zaharia,Matei Zaharia
+7,odersky,Martin Odersky
+8,anonsys
diff --git a/graphx/pom.xml b/graphx/pom.xml
new file mode 100644
index 0000000000..3e5faf230d
--- /dev/null
+++ b/graphx/pom.xml
@@ -0,0 +1,67 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-parent</artifactId>
+ <version>0.9.0-incubating-SNAPSHOT</version>
+ <relativePath>../pom.xml</relativePath>
+ </parent>
+
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-graphx_2.10</artifactId>
+ <packaging>jar</packaging>
+ <name>Spark Project GraphX</name>
+ <url>http://spark-project.org/</url>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-server</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.scalacheck</groupId>
+ <artifactId>scalacheck_${scala.binary.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+ <plugins>
+ <plugin>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest-maven-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+</project>
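For projects built with sbt rather than Maven, the equivalent dependency would be a one-liner; a sketch using the coordinates declared in this pom:
{{{
// build.sbt
libraryDependencies += "org.apache.spark" %% "spark-graphx" % "0.9.0-incubating-SNAPSHOT"
}}}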
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Edge.scala b/graphx/src/main/scala/org/apache/spark/graphx/Edge.scala
new file mode 100644
index 0000000000..738a38b27f
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Edge.scala
@@ -0,0 +1,45 @@
+package org.apache.spark.graphx
+
+/**
+ * A single directed edge consisting of a source id, target id,
+ * and the data associated with the edge.
+ *
+ * @tparam ED type of the edge attribute
+ *
+ * @param srcId The vertex id of the source vertex
+ * @param dstId The vertex id of the target vertex
+ * @param attr The attribute associated with the edge
+ */
+case class Edge[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) ED] (
+ var srcId: VertexID = 0,
+ var dstId: VertexID = 0,
+ var attr: ED = null.asInstanceOf[ED])
+ extends Serializable {
+
+ /**
+ * Given one vertex in the edge return the other vertex.
+ *
+ * @param vid the id of one of the two vertices on the edge.
+ * @return the id of the other vertex on the edge.
+ */
+ def otherVertexId(vid: VertexID): VertexID =
+ if (srcId == vid) dstId else { assert(dstId == vid); srcId }
+
+ /**
+ * Return the relative direction of the edge to the corresponding
+ * vertex.
+ *
+ * @param vid the id of one of the two vertices in the edge.
+ * @return the relative direction of the edge to the corresponding
+ * vertex.
+ */
+ def relativeDirection(vid: VertexID): EdgeDirection =
+ if (vid == srcId) EdgeDirection.Out else { assert(vid == dstId); EdgeDirection.In }
+}
+
+object Edge {
+ private[graphx] def lexicographicOrdering[ED] = new Ordering[Edge[ED]] {
+ override def compare(a: Edge[ED], b: Edge[ED]): Int =
+ (if (a.srcId != b.srcId) a.srcId - b.srcId else a.dstId - b.dstId).toInt
+ }
+}
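A small sketch of the helpers defined above, assuming the `VertexID` alias for `Long` from the graphx package object:
{{{
import org.apache.spark.graphx.{Edge, EdgeDirection}

object EdgeSketch extends App {
  val e = Edge(srcId = 1L, dstId = 2L, attr = "follows")
  println(e.otherVertexId(1L))        // 2: the opposite endpoint
  println(e.relativeDirection(2L))    // EdgeDirection.In: the edge arrives at vertex 2
}
}}}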
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala
new file mode 100644
index 0000000000..f265764006
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala
@@ -0,0 +1,44 @@
+package org.apache.spark.graphx
+
+/**
+ * The direction of a directed edge relative to a vertex.
+ */
+class EdgeDirection private (private val name: String) extends Serializable {
+ /**
+ * Reverse the direction of an edge. An in becomes out,
+ * out becomes in and both and either remain the same.
+ */
+ def reverse: EdgeDirection = this match {
+ case EdgeDirection.In => EdgeDirection.Out
+ case EdgeDirection.Out => EdgeDirection.In
+ case EdgeDirection.Either => EdgeDirection.Either
+ case EdgeDirection.Both => EdgeDirection.Both
+ }
+
+ override def toString: String = "EdgeDirection." + name
+
+ override def equals(o: Any) = o match {
+ case other: EdgeDirection => other.name == name
+ case _ => false
+ }
+
+ override def hashCode = name.hashCode
+}
+
+
+/**
+ * A set of [[EdgeDirection]]s.
+ */
+object EdgeDirection {
+ /** Edges arriving at a vertex. */
+ final val In = new EdgeDirection("In")
+
+ /** Edges originating from a vertex. */
+ final val Out = new EdgeDirection("Out")
+
+ /** Edges originating from *or* arriving at a vertex of interest. */
+ final val Either = new EdgeDirection("Either")
+
+ /** Edges originating from *and* arriving at a vertex of interest. */
+ final val Both = new EdgeDirection("Both")
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
new file mode 100644
index 0000000000..832b7816fe
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
@@ -0,0 +1,102 @@
+package org.apache.spark.graphx
+
+import scala.reflect.{classTag, ClassTag}
+
+import org.apache.spark.{OneToOneDependency, Partition, Partitioner, TaskContext}
+import org.apache.spark.graphx.impl.EdgePartition
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * `EdgeRDD[ED]` extends `RDD[Edge[ED]]` by storing the edges in columnar format on each partition
+ * for performance.
+ */
+class EdgeRDD[@specialized ED: ClassTag](
+ val partitionsRDD: RDD[(PartitionID, EdgePartition[ED])])
+ extends RDD[Edge[ED]](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) {
+
+ partitionsRDD.setName("EdgeRDD")
+
+ override protected def getPartitions: Array[Partition] = partitionsRDD.partitions
+
+ /**
+ * If `partitionsRDD` already has a partitioner, use it. Otherwise assume that the
+ * [[PartitionID]]s in `partitionsRDD` correspond to the actual partitions and create a new
+ * partitioner that allows co-partitioning with `partitionsRDD`.
+ */
+ override val partitioner =
+ partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD)))
+
+ override def compute(part: Partition, context: TaskContext): Iterator[Edge[ED]] = {
+ firstParent[(PartitionID, EdgePartition[ED])].iterator(part, context).next._2.iterator
+ }
+
+ override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect()
+
+ override def persist(newLevel: StorageLevel): EdgeRDD[ED] = {
+ partitionsRDD.persist(newLevel)
+ this
+ }
+
+ /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
+ override def persist(): EdgeRDD[ED] = persist(StorageLevel.MEMORY_ONLY)
+
+ /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
+ override def cache(): EdgeRDD[ED] = persist()
+
+ override def unpersist(blocking: Boolean = true): EdgeRDD[ED] = {
+ partitionsRDD.unpersist(blocking)
+ this
+ }
+
+ private[graphx] def mapEdgePartitions[ED2: ClassTag](f: (PartitionID, EdgePartition[ED]) => EdgePartition[ED2])
+ : EdgeRDD[ED2] = {
+ new EdgeRDD[ED2](partitionsRDD.mapPartitions({ iter =>
+ val (pid, ep) = iter.next()
+ Iterator(Tuple2(pid, f(pid, ep)))
+ }, preservesPartitioning = true))
+ }
+
+ /**
+ * Map the values of the edges, preserving the structure but changing the values.
+ *
+ * @tparam ED2 the new edge value type
+ * @param f the function from an edge to a new edge value
+ * @return a new EdgeRDD containing the new edge values
+ */
+ def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDD[ED2] =
+ mapEdgePartitions((pid, part) => part.map(f))
+
+ /**
+ * Reverse all the edges in this RDD.
+ *
+ * @return a new EdgeRDD containing all the edges reversed
+ */
+ def reverse: EdgeRDD[ED] = mapEdgePartitions((pid, part) => part.reverse)
+
+ /**
+ * Inner joins this EdgeRDD with another EdgeRDD, assuming both are partitioned using the same
+ * [[PartitionStrategy]].
+ *
+ * @param other the EdgeRDD to join with
+ * @param f the join function applied to corresponding values of `this` and `other`
+ * @return a new EdgeRDD containing only edges that appear in both `this` and `other`, with values
+ * supplied by `f`
+ */
+ def innerJoin[ED2: ClassTag, ED3: ClassTag]
+ (other: EdgeRDD[ED2])
+ (f: (VertexID, VertexID, ED, ED2) => ED3): EdgeRDD[ED3] = {
+ val ed2Tag = classTag[ED2]
+ val ed3Tag = classTag[ED3]
+ new EdgeRDD[ED3](partitionsRDD.zipPartitions(other.partitionsRDD, true) {
+ (thisIter, otherIter) =>
+ val (pid, thisEPart) = thisIter.next()
+ val (_, otherEPart) = otherIter.next()
+ Iterator(Tuple2(pid, thisEPart.innerJoin(otherEPart)(f)(ed2Tag, ed3Tag)))
+ })
+ }
+
+ private[graphx] def collectVertexIDs(): RDD[VertexID] = {
+ partitionsRDD.flatMap { case (_, p) => Array.concat(p.srcIds, p.dstIds) }
+ }
+}
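A usage sketch for `mapValues` and `reverse` via a graph's edge RDD (a minimal local example; `Graph.fromEdges` appears later in this patch):
{{{
import org.apache.spark.SparkContext
import org.apache.spark.graphx.{Edge, Graph}

object EdgeRDDSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "EdgeRDDSketch")
    val edges = sc.parallelize(Seq(Edge(1L, 2L, 1), Edge(2L, 3L, 2)))
    val graph = Graph.fromEdges(edges, defaultValue = 0)
    // graph.edges is an EdgeRDD; both operations keep the columnar partition layout.
    println(graph.edges.mapValues(e => e.attr * 2).collect().mkString(", "))
    println(graph.edges.reverse.collect().mkString(", "))
    sc.stop()
  }
}
}}}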
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala
new file mode 100644
index 0000000000..4253b24b5a
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala
@@ -0,0 +1,49 @@
+package org.apache.spark.graphx
+
+/**
+ * An edge triplet represents an edge along with the vertex attributes of its neighboring vertices.
+ *
+ * @tparam VD the type of the vertex attribute.
+ * @tparam ED the type of the edge attribute
+ */
+class EdgeTriplet[VD, ED] extends Edge[ED] {
+ /**
+ * The source vertex attribute
+ */
+ var srcAttr: VD = _ //nullValue[VD]
+
+ /**
+ * The destination vertex attribute
+ */
+ var dstAttr: VD = _ //nullValue[VD]
+
+ /**
+ * Set the edge properties of this triplet.
+ */
+ protected[spark] def set(other: Edge[ED]): EdgeTriplet[VD,ED] = {
+ srcId = other.srcId
+ dstId = other.dstId
+ attr = other.attr
+ this
+ }
+
+ /**
+ * Given one vertex in the edge return the other vertex.
+ *
+ * @param vid the id of one of the two vertices on the edge
+ * @return the attribute for the other vertex on the edge
+ */
+ def otherVertexAttr(vid: VertexID): VD =
+ if (srcId == vid) dstAttr else { assert(dstId == vid); srcAttr }
+
+ /**
+ * Get the vertex object for the given vertex in the edge.
+ *
+ * @param vid the id of one of the two vertices on the edge
+ * @return the attr for the vertex with that id
+ */
+ def vertexAttr(vid: VertexID): VD =
+ if (srcId == vid) srcAttr else { assert(dstId == vid); dstAttr }
+
+ override def toString = ((srcId, srcAttr), (dstId, dstAttr), attr).toString()
+}
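A short sketch showing how triplet fields are read; the vertex and edge data here are hypothetical:
{{{
import org.apache.spark.SparkContext
import org.apache.spark.graphx.{Edge, Graph}

object TripletSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "TripletSketch")
    val vertices = sc.parallelize(Seq((1L, "alice"), (2L, "bob")))
    val edges = sc.parallelize(Seq(Edge(1L, 2L, "follows")))
    val graph = Graph(vertices, edges)
    // Each EdgeTriplet exposes srcId/dstId/attr plus srcAttr and dstAttr.
    graph.triplets.collect().foreach(t => println(t.srcAttr + " " + t.attr + " " + t.dstAttr))
    sc.stop()
  }
}
}}}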
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
new file mode 100644
index 0000000000..9dd05ade0a
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
@@ -0,0 +1,405 @@
+package org.apache.spark.graphx
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.graphx.impl._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * The Graph abstractly represents a graph with arbitrary objects
+ * associated with vertices and edges. The graph provides basic
+ * operations to access and manipulate the data associated with
+ * vertices and edges as well as the underlying structure. Like Spark
+ * RDDs, the graph is a functional data-structure in which mutating
+ * operations return new graphs.
+ *
+ * @note [[GraphOps]] contains additional convenience operations and graph algorithms.
+ *
+ * @tparam VD the vertex attribute type
+ * @tparam ED the edge attribute type
+ */
+abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializable {
+
+ /**
+ * An RDD containing the vertices and their associated attributes.
+ *
+ * @note vertex ids are unique.
+ * @return an RDD containing the vertices in this graph
+ */
+ val vertices: VertexRDD[VD]
+
+ /**
+ * An RDD containing the edges and their associated attributes. The entries in the RDD contain
+ * just the source id and target id along with the edge data.
+ *
+ * @return an RDD containing the edges in this graph
+ *
+ * @see [[Edge]] for the edge type.
+ * @see [[triplets]] to get an RDD which contains all the edges
+ * along with their vertex data.
+ *
+ */
+ val edges: EdgeRDD[ED]
+
+ /**
+ * An RDD containing the edge triplets, which are edges along with the vertex data associated with
+ * the adjacent vertices. The caller should use [[edges]] if the vertex data are not needed, i.e.
+ * if only the edge data and adjacent vertex ids are needed.
+ *
+ * @return an RDD containing edge triplets
+ *
+ * @example This operation might be used to evaluate a graph
+ * coloring where we would like to check that both vertices are a
+ * different color.
+ * {{{
+ * type Color = Int
+ * val graph: Graph[Color, Int] = GraphLoader.edgeListFile("hdfs://file.tsv")
+ * val numInvalid = graph.triplets.map(e => if (e.srcAttr == e.dstAttr) 1 else 0).sum
+ * }}}
+ */
+ val triplets: RDD[EdgeTriplet[VD, ED]]
+
+ /**
+ * Caches the vertices and edges associated with this graph at the specified storage level.
+ *
+ * @param newLevel the level at which to cache the graph.
+ *
+ * @return A reference to this graph for convenience.
+ */
+ def persist(newLevel: StorageLevel = StorageLevel.MEMORY_ONLY): Graph[VD, ED]
+
+ /**
+ * Caches the vertices and edges associated with this graph. This is used to
+ * pin a graph in memory enabling multiple queries to reuse the same
+ * construction process.
+ */
+ def cache(): Graph[VD, ED]
+
+ /**
+ * Uncaches only the vertices of this graph, leaving the edges alone. This is useful in iterative
+ * algorithms that modify the vertex attributes but reuse the edges. This method can be used to
+ * uncache the vertex attributes of previous iterations once they are no longer needed, improving
+ * GC performance.
+ */
+ def unpersistVertices(blocking: Boolean = true): Graph[VD, ED]
+
+ /**
+ * Repartitions the edges in the graph according to `partitionStrategy`.
+ */
+ def partitionBy(partitionStrategy: PartitionStrategy): Graph[VD, ED]
+
+ /**
+ * Transforms each vertex attribute in the graph using the map function.
+ *
+ * @note The new graph has the same structure. As a consequence the underlying index structures
+ * can be reused.
+ *
+ * @param map the function from a vertex object to a new vertex value
+ *
+ * @tparam VD2 the new vertex data type
+ *
+ * @example We might use this operation to change the vertex values
+ * from one type to another to initialize an algorithm.
+ * {{{
+ * val rawGraph: Graph[(), ()] = Graph.textFile("hdfs://file")
+ * val root = 42
+ * var bfsGraph = rawGraph.mapVertices[Int]((vid, data) => if (vid == root) 0 else Int.MaxValue)
+ * }}}
+ *
+ */
+ def mapVertices[VD2: ClassTag](map: (VertexID, VD) => VD2): Graph[VD2, ED]
+
+ /**
+ * Transforms each edge attribute in the graph using the map function. The map function is not
+ * passed the vertex value for the vertices adjacent to the edge. If vertex values are desired,
+ * use `mapTriplets`.
+ *
+ * @note This graph is not changed, and the new graph has the
+ * same structure. As a consequence the underlying index structures
+ * can be reused.
+ *
+ * @param map the function from an edge object to a new edge value.
+ *
+ * @tparam ED2 the new edge data type
+ *
+ * @example This function might be used to initialize edge
+ * attributes.
+ *
+ */
+ def mapEdges[ED2: ClassTag](map: Edge[ED] => ED2): Graph[VD, ED2] = {
+ mapEdges((pid, iter) => iter.map(map))
+ }
+
+ /**
+ * Transforms each edge attribute using the map function, passing it a whole partition at a
+ * time. The map function is given an iterator over edges within a logical partition as well as
+ * the partition's ID, and it should return a new iterator over the new values of each edge. The
+ * new iterator's elements must correspond one-to-one with the old iterator's elements. If
+ * adjacent vertex values are desired, use `mapTriplets`.
+ *
+ * @note This does not change the structure of the
+ * graph or modify the values of this graph. As a consequence
+ * the underlying index structures can be reused.
+ *
+ * @param map a function that takes a partition id and an iterator
+ * over all the edges in the partition, and must return an iterator over
+ * the new values for each edge in the order of the input iterator
+ *
+ * @tparam ED2 the new edge data type
+ *
+ */
+ def mapEdges[ED2: ClassTag](map: (PartitionID, Iterator[Edge[ED]]) => Iterator[ED2])
+ : Graph[VD, ED2]
+
+ /**
+ * Transforms each edge attribute using the map function, passing it the adjacent vertex attributes
+ * as well. If adjacent vertex values are not required, consider using `mapEdges` instead.
+ *
+ * @note This does not change the structure of the
+ * graph or modify the values of this graph. As a consequence
+ * the underlying index structures can be reused.
+ *
+ * @param map the function from an edge object to a new edge value.
+ *
+ * @tparam ED2 the new edge data type
+ *
+ * @example This function might be used to initialize edge
+ * attributes based on the attributes associated with each vertex.
+ * {{{
+ * val rawGraph: Graph[Int, Int] = someLoadFunction()
+ * val graph = rawGraph.mapTriplets[Int]( edge =>
+ * edge.srcAttr - edge.dstAttr)
+ * }}}
+ *
+ */
+ def mapTriplets[ED2: ClassTag](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
+ mapTriplets((pid, iter) => iter.map(map))
+ }
+
+ /**
+ * Transforms each edge attribute a partition at a time using the map function, passing it the
+ * adjacent vertex attributes as well. The map function is given an iterator over edge triplets
+ * within a logical partition and should yield a new iterator over the new values of each edge in
+ * the order in which they are provided. If adjacent vertex values are not required, consider
+ * using `mapEdges` instead.
+ *
+ * @note This does not change the structure of the
+ * graph or modify the values of this graph. As a consequence
+ * the underlying index structures can be reused.
+ *
+ * @param map the iterator transform
+ *
+ * @tparam ED2 the new edge data type
+ *
+ */
+ def mapTriplets[ED2: ClassTag](map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2])
+ : Graph[VD, ED2]
+
+ /**
+ * Reverses all edges in the graph. If this graph contains an edge from a to b then the returned
+ * graph contains an edge from b to a.
+ */
+ def reverse: Graph[VD, ED]
+
+ /**
+ * Restricts the graph to only the vertices and edges satisfying the predicates. The resulting
+ * subgraph satisfies
+ *
+ * {{{
+ * V' = {v : for all v in V where vpred(v)}
+ * E' = {(u,v): for all (u,v) in E where epred((u,v)) && vpred(u) && vpred(v)}
+ * }}}
+ *
+ * @param epred the edge predicate, which takes a triplet and
+ * evaluates to true if the edge is to remain in the subgraph. Note
+ * that only edges where both vertices satisfy the vertex
+ * predicate are considered.
+ *
+ * @param vpred the vertex predicate, which takes a vertex object and
+ * evaluates to true if the vertex is to be included in the subgraph
+ *
+ * @return the subgraph containing only the vertices and edges that
+ * satisfy the predicates
+ */
+ def subgraph(
+ epred: EdgeTriplet[VD,ED] => Boolean = (x => true),
+ vpred: (VertexID, VD) => Boolean = ((v, d) => true))
+ : Graph[VD, ED]
+
+ /**
+ * Restricts the graph to only the vertices and edges that are also in `other`, but keeps the
+ * attributes from this graph.
+ * @param other the graph to project this graph onto
+ * @return a graph with vertices and edges that exist in both the current graph and `other`,
+ * with vertex and edge data from the current graph
+ */
+ def mask[VD2: ClassTag, ED2: ClassTag](other: Graph[VD2, ED2]): Graph[VD, ED]
+
+ /**
+ * Merges multiple edges between two vertices into a single edge. For correct results, the graph
+ * must have been partitioned using [[partitionBy]].
+ *
+ * @param merge the user-supplied commutative associative function to merge edge attributes
+ * for duplicate edges.
+ *
+ * @return The resulting graph with a single edge for each (source, dest) vertex pair.
+ */
+ def groupEdges(merge: (ED, ED) => ED): Graph[VD, ED]
+
+ /**
+ * Aggregates values from the neighboring edges and vertices of each vertex. The user supplied
+ * `mapFunc` function is invoked on each edge of the graph, generating 0 or more "messages" to be
+ * "sent" to either vertex in the edge. The `reduceFunc` is then used to combine the output of
+ * the map phase destined to each vertex.
+ *
+ * @tparam A the type of "message" to be sent to each vertex
+ *
+ * @param mapFunc the user defined map function which returns 0 or
+ * more messages to neighboring vertices
+ *
+ * @param reduceFunc the user defined reduce function which should
+ * be commutative and associative and is used to combine the output
+ * of the map phase
+ *
+ * @param activeSetOpt optionally, a set of "active" vertices and a direction of edges to consider
+ * when running `mapFunc`. If the direction is `In`, `mapFunc` will only be run on edges with
+ * destination in the active set. If the direction is `Out`, `mapFunc` will only be run on edges
+ * originating from vertices in the active set. If the direction is `Either`, `mapFunc` will be
+ * run on edges with *either* vertex in the active set. If the direction is `Both`, `mapFunc` will
+ * be run on edges with *both* vertices in the active set. The active set must have the same index
+ * as the graph's vertices.
+ *
+ * @example We can use this function to compute the in-degree of each
+ * vertex
+ * {{{
+ * val rawGraph: Graph[(),()] = Graph.textFile("twittergraph")
+ * val inDeg: RDD[(VertexID, Int)] =
+ * rawGraph.mapReduceTriplets[Int](et => Iterator((et.dstId, 1)), _ + _)
+ * }}}
+ *
+ * @note By expressing computation at the edge level we achieve
+ * maximum parallelism. This is one of the core functions in the
+ * Graph API in that it enables neighborhood-level computation. For
+ * example this function can be used to count neighbors satisfying a
+ * predicate or implement PageRank.
+ *
+ */
+ def mapReduceTriplets[A: ClassTag](
+ mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexID, A)],
+ reduceFunc: (A, A) => A,
+ activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None)
+ : VertexRDD[A]
+
+ /**
+ * Joins the vertices with entries in the `other` RDD and merges the results using `mapFunc`. The
+ * input table should contain at most one entry for each vertex. If no entry in `other` is
+ * provided for a particular vertex in the graph, the map function receives `None`.
+ *
+ * @tparam U the type of entry in the table of updates
+ * @tparam VD2 the new vertex value type
+ *
+ * @param other the table to join with the vertices in the graph.
+ * The table should contain at most one entry for each vertex.
+ * @param mapFunc the function used to compute the new vertex values.
+ * The map function is invoked for all vertices, even those
+ * that do not have a corresponding entry in the table.
+ *
+ * @example This function is used to update the vertices with new values based on external data.
+ * For example we could add the out-degree to each vertex record:
+ *
+ * {{{
+ * val rawGraph: Graph[_, _] = Graph.textFile("webgraph")
+ * val outDeg: RDD[(VertexID, Int)] = rawGraph.outDegrees
+ * val graph = rawGraph.outerJoinVertices(outDeg) {
+ * (vid, data, optDeg) => optDeg.getOrElse(0)
+ * }
+ * }}}
+ */
+ def outerJoinVertices[U: ClassTag, VD2: ClassTag](other: RDD[(VertexID, U)])
+ (mapFunc: (VertexID, VD, Option[U]) => VD2)
+ : Graph[VD2, ED]
+
+ /**
+ * The associated [[GraphOps]] object.
+ */
+ // Save a copy of the GraphOps object so there is always one unique GraphOps object
+ // for a given Graph object, and thus the lazy vals in GraphOps would work as intended.
+ val ops = new GraphOps(this)
+} // end of Graph
+
+
+/**
+ * The Graph object contains a collection of routines used to construct graphs from RDDs.
+ */
+object Graph {
+
+ /**
+ * Construct a graph from a collection of edges encoded as vertex id pairs.
+ *
+ * @param rawEdges a collection of edges in (src, dst) form
+ * @param uniqueEdges if multiple identical edges are found they are combined and the edge
+ * attribute is set to the sum. Otherwise duplicate edges are treated as separate. To enable
+ * `uniqueEdges`, a [[PartitionStrategy]] must be provided.
+ *
+ * @return a graph with edge attributes containing either the count of duplicate edges or 1
+ * (if `uniqueEdges` is `None`) and every vertex attribute set to `defaultValue`.
+ */
+ def fromEdgeTuples[VD: ClassTag](
+ rawEdges: RDD[(VertexID, VertexID)],
+ defaultValue: VD,
+ uniqueEdges: Option[PartitionStrategy] = None): Graph[VD, Int] =
+ {
+ val edges = rawEdges.map(p => Edge(p._1, p._2, 1))
+ val graph = GraphImpl(edges, defaultValue)
+ uniqueEdges match {
+ case Some(p) => graph.partitionBy(p).groupEdges((a, b) => a + b)
+ case None => graph
+ }
+ }
+
+ /**
+ * Construct a graph from a collection of edges.
+ *
+ * @param edges the RDD containing the set of edges in the graph
+ * @param defaultValue the default vertex attribute to use for each vertex
+ *
+ * @return a graph with edge attributes described by `edges` and vertices
+ * given by all vertices in `edges` with value `defaultValue`
+ */
+ def fromEdges[VD: ClassTag, ED: ClassTag](
+ edges: RDD[Edge[ED]],
+ defaultValue: VD): Graph[VD, ED] = {
+ GraphImpl(edges, defaultValue)
+ }
+
+ /**
+ * Construct a graph from a collection of vertices and
+ * edges with attributes. Duplicate vertices are picked arbitrarily and
+ * vertices found in the edge collection but not in the input
+ * vertices are assigned the default attribute.
+ *
+ * @tparam VD the vertex attribute type
+ * @tparam ED the edge attribute type
+ * @param vertices the "set" of vertices and their attributes
+ * @param edges the collection of edges in the graph
+ * @param defaultVertexAttr the default vertex attribute to use for vertices that are
+ * mentioned in edges but not in vertices
+ */
+ def apply[VD: ClassTag, ED: ClassTag](
+ vertices: RDD[(VertexID, VD)],
+ edges: RDD[Edge[ED]],
+ defaultVertexAttr: VD = null.asInstanceOf[VD]): Graph[VD, ED] = {
+ GraphImpl(vertices, edges, defaultVertexAttr)
+ }
+
+ /**
+ * Implicitly extracts the [[GraphOps]] member from a graph.
+ *
+ * To improve modularity the Graph type only contains a small set of basic operations.
+ * All the convenience operations are defined in the [[GraphOps]] class which may be
+ * shared across multiple graph implementations.
+ */
+ implicit def graphToGraphOps[VD: ClassTag, ED: ClassTag](g: Graph[VD, ED]) = g.ops
+} // end of Graph object
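To tie the pieces above together, a hedged end-to-end sketch (local master, toy edge list) of graph construction plus a few of the transformations documented in this file:
{{{
import org.apache.spark.SparkContext
import org.apache.spark.graphx.Graph

object GraphSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "GraphSketch")
    val rawEdges = sc.parallelize(Seq((1L, 2L), (2L, 3L), (3L, 1L), (1L, 3L)))
    // Every vertex receives the default attribute 1; no uniqueEdges, so duplicates stay separate.
    val graph = Graph.fromEdgeTuples(rawEdges, defaultValue = 1)
    val labelled = graph.mapVertices((vid, attr) => "v" + vid)
    val described = labelled.mapTriplets(t => t.srcAttr + "->" + t.dstAttr)
    val sub = described.subgraph(vpred = (vid, attr) => vid != 3L)
    println(sub.triplets.collect().mkString(", "))
    sc.stop()
  }
}
}}}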
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
new file mode 100644
index 0000000000..d79bdf9618
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
@@ -0,0 +1,31 @@
+package org.apache.spark.graphx
+
+import com.esotericsoftware.kryo.Kryo
+
+import org.apache.spark.graphx.impl._
+import org.apache.spark.serializer.KryoRegistrator
+import org.apache.spark.util.collection.BitSet
+import org.apache.spark.util.BoundedPriorityQueue
+
+/**
+ * Registers GraphX classes with Kryo for improved performance.
+ */
+class GraphKryoRegistrator extends KryoRegistrator {
+
+ def registerClasses(kryo: Kryo) {
+ kryo.register(classOf[Edge[Object]])
+ kryo.register(classOf[MessageToPartition[Object]])
+ kryo.register(classOf[VertexBroadcastMsg[Object]])
+ kryo.register(classOf[(VertexID, Object)])
+ kryo.register(classOf[EdgePartition[Object]])
+ kryo.register(classOf[BitSet])
+ kryo.register(classOf[VertexIdToIndexMap])
+ kryo.register(classOf[VertexAttributeBlock[Object]])
+ kryo.register(classOf[PartitionStrategy])
+ kryo.register(classOf[BoundedPriorityQueue[Object]])
+ kryo.register(classOf[EdgeDirection])
+
+ // This avoids a large number of hash table lookups.
+ kryo.setReferences(false)
+ }
+}
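A configuration sketch (app name and master are placeholders) for wiring this registrator into a job via the standard Kryo properties:
{{{
import org.apache.spark.{SparkConf, SparkContext}

object KryoSketch {
  def main(args: Array[String]) {
    val conf = new SparkConf()
      .setMaster("local")
      .setAppName("KryoSketch")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
    val sc = new SparkContext(conf)
    // ... build and process graphs as usual ...
    sc.stop()
  }
}
}}}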
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala
new file mode 100644
index 0000000000..5904aa3a28
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala
@@ -0,0 +1,72 @@
+package org.apache.spark.graphx
+
+import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.graphx.impl.{EdgePartitionBuilder, GraphImpl}
+
+/**
+ * Provides utilities for loading [[Graph]]s from files.
+ */
+object GraphLoader extends Logging {
+
+ /**
+ * Loads a graph from an edge list formatted file where each line contains two integers: a source
+ * id and a target id. Skips lines that begin with `#`.
+ *
+ * If desired the edges can be automatically oriented in the positive
+ * direction (source Id < target Id) by setting `canonicalOrientation` to
+ * true.
+ *
+ * @example Loads a file in the following format:
+ * {{{
+ * # Comment Line
+ * # Source Id <\t> Target Id
+ * 1 -5
+ * 1 2
+ * 2 7
+ * 1 8
+ * }}}
+ *
+ * @param sc SparkContext
+ * @param path the path to the file (e.g., /home/data/file or hdfs://file)
+ * @param canonicalOrientation whether to orient edges in the positive
+ * direction
+ * @param minEdgePartitions the number of partitions for the edge RDD
+ */
+ def edgeListFile(
+ sc: SparkContext,
+ path: String,
+ canonicalOrientation: Boolean = false,
+ minEdgePartitions: Int = 1)
+ : Graph[Int, Int] =
+ {
+ val startTime = System.currentTimeMillis
+
+ // Parse the edge data table directly into edge partitions
+ val edges = sc.textFile(path, minEdgePartitions).mapPartitionsWithIndex { (pid, iter) =>
+ val builder = new EdgePartitionBuilder[Int]
+ iter.foreach { line =>
+ if (!line.isEmpty && line(0) != '#') {
+ val lineArray = line.split("\\s+")
+ if (lineArray.length < 2) {
+ logWarning("Invalid line: " + line)
+ }
+ val srcId = lineArray(0).toLong
+ val dstId = lineArray(1).toLong
+ if (canonicalOrientation && srcId > dstId) {
+ builder.add(dstId, srcId, 1)
+ } else {
+ builder.add(srcId, dstId, 1)
+ }
+ }
+ }
+ Iterator((pid, builder.toEdgePartition))
+ }.cache()
+ edges.count()
+
+ logInfo("It took %d ms to load the edges".format(System.currentTimeMillis - startTime))
+
+ GraphImpl.fromEdgePartitions(edges, defaultVertexAttr = 1)
+ } // end of edgeListFile
+
+}
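A usage sketch that loads the followers.txt edge list added earlier in this patch:
{{{
import org.apache.spark.SparkContext
import org.apache.spark.graphx.GraphLoader

object LoaderSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "LoaderSketch")
    // Whitespace-separated "src dst" pairs; lines starting with '#' are skipped.
    val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt")
    println("vertices: " + graph.vertices.count() + ", edges: " + graph.edges.count())
    sc.stop()
  }
}
}}}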
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
new file mode 100644
index 0000000000..f10e63f059
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
@@ -0,0 +1,301 @@
+package org.apache.spark.graphx
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.SparkException
+import org.apache.spark.graphx.lib._
+import org.apache.spark.rdd.RDD
+
+/**
+ * Contains additional functionality for [[Graph]]. All operations are expressed in terms of the
+ * efficient GraphX API. This class is implicitly constructed for each Graph object.
+ *
+ * @tparam VD the vertex attribute type
+ * @tparam ED the edge attribute type
+ */
+class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Serializable {
+
+ /** The number of edges in the graph. */
+ lazy val numEdges: Long = graph.edges.count()
+
+ /** The number of vertices in the graph. */
+ lazy val numVertices: Long = graph.vertices.count()
+
+ /**
+ * The in-degree of each vertex in the graph.
+ * @note Vertices with no in-edges are not returned in the resulting RDD.
+ */
+ lazy val inDegrees: VertexRDD[Int] = degreesRDD(EdgeDirection.In)
+
+ /**
+ * The out-degree of each vertex in the graph.
+ * @note Vertices with no out-edges are not returned in the resulting RDD.
+ */
+ lazy val outDegrees: VertexRDD[Int] = degreesRDD(EdgeDirection.Out)
+
+ /**
+ * The degree of each vertex in the graph.
+ * @note Vertices with no edges are not returned in the resulting RDD.
+ */
+ lazy val degrees: VertexRDD[Int] = degreesRDD(EdgeDirection.Either)
+
+ /**
+ * Computes the degree of each vertex, counting edges in the given direction.
+ *
+ * @param edgeDirection the direction along which to collect neighboring vertex attributes
+ */
+ private def degreesRDD(edgeDirection: EdgeDirection): VertexRDD[Int] = {
+ if (edgeDirection == EdgeDirection.In) {
+ graph.mapReduceTriplets(et => Iterator((et.dstId,1)), _ + _)
+ } else if (edgeDirection == EdgeDirection.Out) {
+ graph.mapReduceTriplets(et => Iterator((et.srcId,1)), _ + _)
+ } else { // EdgeDirection.Either
+ graph.mapReduceTriplets(et => Iterator((et.srcId,1), (et.dstId,1)), _ + _)
+ }
+ }
+
+ /**
+ * Collect the neighbor vertex ids for each vertex.
+ *
+ * @param edgeDirection the direction along which to collect
+ * neighboring vertices
+ *
+ * @return the set of neighboring ids for each vertex
+ */
+ def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexID]] = {
+ val nbrs =
+ if (edgeDirection == EdgeDirection.Either) {
+ graph.mapReduceTriplets[Array[VertexID]](
+ mapFunc = et => Iterator((et.srcId, Array(et.dstId)), (et.dstId, Array(et.srcId))),
+ reduceFunc = _ ++ _
+ )
+ } else if (edgeDirection == EdgeDirection.Out) {
+ graph.mapReduceTriplets[Array[VertexID]](
+ mapFunc = et => Iterator((et.srcId, Array(et.dstId))),
+ reduceFunc = _ ++ _)
+ } else if (edgeDirection == EdgeDirection.In) {
+ graph.mapReduceTriplets[Array[VertexID]](
+ mapFunc = et => Iterator((et.dstId, Array(et.srcId))),
+ reduceFunc = _ ++ _)
+ } else {
+ throw new SparkException("It doesn't make sense to collect neighbor ids without a " +
+ "direction. (EdgeDirection.Both is not supported; use EdgeDirection.Either instead.)")
+ }
+ graph.vertices.leftZipJoin(nbrs) { (vid, vdata, nbrsOpt) =>
+ nbrsOpt.getOrElse(Array.empty[VertexID])
+ }
+ } // end of collectNeighborIds
+
+ /**
+ * Collect the neighbor vertex attributes for each vertex.
+ *
+ * @note This function could be highly inefficient on power-law
+ * graphs where high-degree vertices may force a large amount of
+ * information to be collected to a single location.
+ *
+ * @param edgeDirection the direction along which to collect
+ * neighboring vertices
+ *
+ * @return the vertex set of neighboring vertex attributes for each vertex
+ */
+ def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexID, VD)]] = {
+ val nbrs = graph.mapReduceTriplets[Array[(VertexID,VD)]](
+ edge => {
+ val msgToSrc = (edge.srcId, Array((edge.dstId, edge.dstAttr)))
+ val msgToDst = (edge.dstId, Array((edge.srcId, edge.srcAttr)))
+ edgeDirection match {
+ case EdgeDirection.Either => Iterator(msgToSrc, msgToDst)
+ case EdgeDirection.In => Iterator(msgToDst)
+ case EdgeDirection.Out => Iterator(msgToSrc)
+ case EdgeDirection.Both =>
+ throw new SparkException("collectNeighbors does not support EdgeDirection.Both. Use" +
+ "EdgeDirection.Either instead.")
+ }
+ },
+ (a, b) => a ++ b)
+
+ graph.vertices.leftZipJoin(nbrs) { (vid, vdata, nbrsOpt) =>
+ nbrsOpt.getOrElse(Array.empty[(VertexID, VD)])
+ }
+ } // end of collectNeighbor
+
+ /**
+ * Join the vertices with an RDD and then apply a function from the
+ * the vertex and RDD entry to a new vertex value. The input table
+ * should contain at most one entry for each vertex. If no entry is
+ * provided the map function is skipped and the old value is used.
+ *
+ * @tparam U the type of entry in the table of updates
+ * @param table the table to join with the vertices in the graph.
+ * The table should contain at most one entry for each vertex.
+ * @param mapFunc the function used to compute the new vertex
+ * values. The map function is invoked only for vertices with a
+ * corresponding entry in the table otherwise the old vertex value
+ * is used.
+ *
+ * @example This function is used to update the vertices with new
+ * values based on external data. For example we could add the out
+ * degree to each vertex record
+ *
+ * {{{
+ * val rawGraph: Graph[Int, Int] = GraphLoader.edgeListFile(sc, "webgraph")
+ * .mapVertices((vid, attr) => 0)
+ * val outDeg: RDD[(VertexID, Int)] = rawGraph.outDegrees
+ * val graph = rawGraph.joinVertices[Int](outDeg)((vid, attr, deg) => deg)
+ * }}}
+ *
+ */
+ def joinVertices[U: ClassTag](table: RDD[(VertexID, U)])(mapFunc: (VertexID, VD, U) => VD)
+ : Graph[VD, ED] = {
+ val uf = (id: VertexID, data: VD, o: Option[U]) => {
+ o match {
+ case Some(u) => mapFunc(id, data, u)
+ case None => data
+ }
+ }
+ graph.outerJoinVertices(table)(uf)
+ }
+
+ /**
+ * Filter the graph by computing some values to filter on, and applying the predicates.
+ *
+ * @param preprocess a function to compute new vertex and edge data before filtering
+ * @param epred edge pred to filter on after preprocess, see more details under
+ * [[org.apache.spark.graphx.Graph#subgraph]]
+ * @param vpred vertex pred to filter on after preprocess, see more details under
+ * [[org.apache.spark.graphx.Graph#subgraph]]
+ * @tparam VD2 vertex type the vpred operates on
+ * @tparam ED2 edge type the epred operates on
+ * @return a subgraph of the original graph, with its data unchanged
+ *
+ * @example This function can be used to filter the graph based on some property, without
+ * changing the vertex and edge values in your program. For example, we could remove the vertices
+ * in a graph with 0 outdegree
+ *
+ * {{{
+ * graph.filter(
+ * graph => {
+ * val degrees: VertexRDD[Int] = graph.outDegrees
+ * graph.outerJoinVertices(degrees) {(vid, data, deg) => deg.getOrElse(0)}
+ * },
+ * vpred = (vid: VertexID, deg:Int) => deg > 0
+ * )
+ * }}}
+ *
+ */
+ def filter[VD2: ClassTag, ED2: ClassTag](
+ preprocess: Graph[VD, ED] => Graph[VD2, ED2],
+ epred: (EdgeTriplet[VD2, ED2]) => Boolean = (x: EdgeTriplet[VD2, ED2]) => true,
+ vpred: (VertexID, VD2) => Boolean = (v:VertexID, d:VD2) => true): Graph[VD, ED] = {
+ graph.mask(preprocess(graph).subgraph(epred, vpred))
+ }
+
+ /**
+ * Execute a Pregel-like iterative vertex-parallel abstraction. The
+ * user-defined vertex-program `vprog` is executed in parallel on
+ * each vertex receiving any inbound messages and computing a new
+ * value for the vertex. The `sendMsg` function is then invoked on
+ * all out-edges and is used to compute an optional message to the
+ * destination vertex. The `mergeMsg` function is a commutative
+ * associative function used to combine messages destined to the
+ * same vertex.
+ *
+ * On the first iteration all vertices receive the `initialMsg` and
+ * on subsequent iterations if a vertex does not receive a message
+ * then the vertex-program is not invoked.
+ *
+ * This function iterates until there are no remaining messages, or
+ * for `maxIterations` iterations.
+ *
+ * @tparam A the Pregel message type
+ *
+ * @param initialMsg the message each vertex will receive on
+ * the first iteration
+ *
+ * @param maxIterations the maximum number of iterations to run for
+ *
+ * @param activeDirection the direction of edges incident to a vertex that received a message in
+ * the previous round on which to run `sendMsg`. For example, if this is `EdgeDirection.Out`, only
+ * out-edges of vertices that received a message in the previous round will run.
+ *
+ * @param vprog the user-defined vertex program which runs on each
+ * vertex and receives the inbound message and computes a new vertex
+ * value. On the first iteration the vertex program is invoked on
+ * all vertices and is passed the default message. On subsequent
+ * iterations the vertex program is only invoked on those vertices
+ * that receive messages.
+ *
+ * @param sendMsg a user supplied function that is applied to out
+ * edges of vertices that received messages in the current
+ * iteration
+ *
+ * @param mergeMsg a user supplied function that takes two incoming
+ * messages of type A and merges them into a single message of type
+ * A. ''This function must be commutative and associative and
+ * ideally the size of A should not increase.''
+ *
+ * @return the resulting graph at the end of the computation
+ *
+ */
+ def pregel[A: ClassTag](
+ initialMsg: A,
+ maxIterations: Int = Int.MaxValue,
+ activeDirection: EdgeDirection = EdgeDirection.Either)(
+ vprog: (VertexID, VD, A) => VD,
+ sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexID,A)],
+ mergeMsg: (A, A) => A)
+ : Graph[VD, ED] = {
+ Pregel(graph, initialMsg, maxIterations, activeDirection)(vprog, sendMsg, mergeMsg)
+ }
+
+ /**
+ * Run a dynamic version of PageRank returning a graph with vertex attributes containing the
+ * PageRank and edge attributes containing the normalized edge weight.
+ *
+ * @see [[org.apache.spark.graphx.lib.PageRank$#runUntilConvergence]]
+ */
+ def pageRank(tol: Double, resetProb: Double = 0.15): Graph[Double, Double] = {
+ PageRank.runUntilConvergence(graph, tol, resetProb)
+ }
+
+ /**
+ * Run PageRank for a fixed number of iterations returning a graph with vertex attributes
+ * containing the PageRank and edge attributes the normalized edge weight.
+ *
+ * @see [[org.apache.spark.graphx.lib.PageRank$#run]]
+ */
+ def staticPageRank(numIter: Int, resetProb: Double = 0.15): Graph[Double, Double] = {
+ PageRank.run(graph, numIter, resetProb)
+ }
+
+ /**
+ * Compute the connected component membership of each vertex and return a graph with the vertex
+ * value containing the lowest vertex id in the connected component containing that vertex.
+ *
+ * @see [[org.apache.spark.graphx.lib.ConnectedComponents$#run]]
+ */
+ def connectedComponents(): Graph[VertexID, ED] = {
+ ConnectedComponents.run(graph)
+ }
+
+ /**
+ * Compute the number of triangles passing through each vertex.
+ *
+ * @see [[org.apache.spark.graphx.lib.TriangleCount$#run]]
+ */
+ def triangleCount(): Graph[Int, ED] = {
+ TriangleCount.run(graph)
+ }
+
+ /**
+ * Compute the strongly connected component (SCC) of each vertex and return a graph with the
+ * vertex value containing the lowest vertex id in the SCC containing that vertex.
+ *
+ * @see [[org.apache.spark.graphx.lib.StronglyConnectedComponents$#run]]
+ */
+ def stronglyConnectedComponents(numIter: Int): Graph[VertexID, ED] = {
+ StronglyConnectedComponents.run(graph, numIter)
+ }
+} // end of GraphOps
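A hedged sketch combining several of these operations over the followers.txt and users.txt files added in this patch (the pair-RDD `join` needs the SparkContext implicits in this Spark version):
{{{
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.graphx.GraphLoader

object OpsSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "OpsSketch")
    val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt")
    // GraphOps members are reached through the implicit in the Graph companion object.
    val ranks = graph.pageRank(0.0001).vertices
    val users = sc.textFile("graphx/data/users.txt").map { line =>
      val fields = line.split(",")
      (fields(0).toLong, fields(1))
    }
    val ranksByUsername = users.join(ranks).map { case (id, (username, rank)) => (username, rank) }
    println(ranksByUsername.collect().mkString("\n"))
    println("components: " + graph.connectedComponents().vertices.map(_._2).distinct().count())
    sc.stop()
  }
}
}}}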
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala
new file mode 100644
index 0000000000..6d2990a3f6
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala
@@ -0,0 +1,103 @@
+package org.apache.spark.graphx
+
+/**
+ * Represents the way edges are assigned to edge partitions based on their source and destination
+ * vertex IDs.
+ */
+trait PartitionStrategy extends Serializable {
+ /** Returns the partition number for a given edge. */
+ def getPartition(src: VertexID, dst: VertexID, numParts: PartitionID): PartitionID
+}
+
+/**
+ * Collection of built-in [[PartitionStrategy]] implementations.
+ */
+object PartitionStrategy {
+ /**
+ * Assigns edges to partitions using a 2D partitioning of the sparse edge adjacency matrix,
+ * guaranteeing a `2 * sqrt(numParts)` bound on vertex replication.
+ *
+ * Suppose we have a graph with 11 vertices that we want to partition
+ * over 9 machines. We can use the following sparse matrix representation:
+ *
+ * <pre>
+ * __________________________________
+ * v0 | P0 * | P1 | P2 * |
+ * v1 | **** | * | |
+ * v2 | ******* | ** | **** |
+ * v3 | ***** | * * | * |
+ * ----------------------------------
+ * v4 | P3 * | P4 *** | P5 ** * |
+ * v5 | * * | * | |
+ * v6 | * | ** | **** |
+ * v7 | * * * | * * | * |
+ * ----------------------------------
+ * v8 | P6 * | P7 * | P8 * *|
+ * v9 | * | * * | |
+ * v10 | * | ** | * * |
+ * v11 | * <-E | *** | ** |
+ * ----------------------------------
+ * </pre>
+ *
+ * The edge denoted by `E` connects `v11` with `v1` and is assigned to processor `P6`. To get the
+ * processor number we divide the matrix into `sqrt(numParts)` by `sqrt(numParts)` blocks. Notice
+ * that edges adjacent to `v11` can only be in the first column of blocks `(P0, P3, P6)` or the last
+ * row of blocks `(P6, P7, P8)`. As a consequence we can guarantee that `v11` will need to be
+ * replicated to at most `2 * sqrt(numParts)` machines.
+ *
+ * Notice that `P0` has many edges and as a consequence this partitioning would lead to poor work
+ * balance. To improve balance we first multiply each vertex id by a large prime to shuffle the
+ * vertex locations.
+ *
+ * One of the limitations of this approach is that the number of machines must be a perfect
+ * square. We partially address this limitation by computing the machine assignment to the next
+ * largest perfect square and then mapping back down to the actual number of machines.
+ * Unfortunately, this can also lead to work imbalance and so it is suggested that a perfect square
+ * is used.
+ */
+ case object EdgePartition2D extends PartitionStrategy {
+ override def getPartition(src: VertexID, dst: VertexID, numParts: PartitionID): PartitionID = {
+ val ceilSqrtNumParts: PartitionID = math.ceil(math.sqrt(numParts)).toInt
+ val mixingPrime: VertexID = 1125899906842597L
+ val col: PartitionID = ((math.abs(src) * mixingPrime) % ceilSqrtNumParts).toInt
+ val row: PartitionID = ((math.abs(dst) * mixingPrime) % ceilSqrtNumParts).toInt
+ (col * ceilSqrtNumParts + row) % numParts
+ }
+ }
+
+ /**
+ * Assigns edges to partitions using only the source vertex ID, colocating edges with the same
+ * source.
+ */
+ case object EdgePartition1D extends PartitionStrategy {
+ override def getPartition(src: VertexID, dst: VertexID, numParts: PartitionID): PartitionID = {
+ val mixingPrime: VertexID = 1125899906842597L
+ (math.abs(src) * mixingPrime).toInt % numParts
+ }
+ }
+
+
+ /**
+ * Assigns edges to partitions by hashing the source and destination vertex IDs, resulting in a
+ * random vertex cut that colocates all same-direction edges between two vertices.
+ */
+ case object RandomVertexCut extends PartitionStrategy {
+ override def getPartition(src: VertexID, dst: VertexID, numParts: PartitionID): PartitionID = {
+ math.abs((src, dst).hashCode()) % numParts
+ }
+ }
+
+
+ /**
+ * Assigns edges to partitions by hashing the source and destination vertex IDs in a canonical
+ * direction, resulting in a random vertex cut that colocates all edges between two vertices,
+ * regardless of direction.
+ */
+ case object CanonicalRandomVertexCut extends PartitionStrategy {
+ override def getPartition(src: VertexID, dst: VertexID, numParts: PartitionID): PartitionID = {
+ val lower = math.min(src, dst)
+ val higher = math.max(src, dst)
+ math.abs((lower, higher).hashCode()) % numParts
+ }
+ }
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
new file mode 100644
index 0000000000..fc18f7e785
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -0,0 +1,139 @@
+package org.apache.spark.graphx
+
+import scala.reflect.ClassTag
+
+
+/**
+ * Implements a Pregel-like bulk-synchronous message-passing API.
+ *
+ * Unlike the original Pregel API, the GraphX Pregel API factors the sendMessage computation over
+ * edges, enables the message sending computation to read both vertex attributes, and constrains
+ * messages to the graph structure. These changes allow for substantially more efficient
+ * distributed execution while also exposing greater flexibility for graph-based computation.
+ *
+ * @example We can use the Pregel abstraction to implement PageRank:
+ * {{{
+ * val pagerankGraph: Graph[Double, Double] = graph
+ * // Associate the degree with each vertex
+ * .outerJoinVertices(graph.outDegrees) {
+ * (vid, vdata, deg) => deg.getOrElse(0)
+ * }
+ * // Set the weight on the edges based on the degree
+ * .mapTriplets(e => 1.0 / e.srcAttr)
+ * // Set the vertex attributes to the initial pagerank values
+ * .mapVertices((id, attr) => 1.0)
+ *
+ * def vertexProgram(id: VertexID, attr: Double, msgSum: Double): Double =
+ * resetProb + (1.0 - resetProb) * msgSum
+ * def sendMessage(edge: EdgeTriplet[Double, Double]): Iterator[(VertexID, Double)] =
+ * Iterator((edge.dstId, edge.srcAttr * edge.attr))
+ * def messageCombiner(a: Double, b: Double): Double = a + b
+ * val initialMessage = 0.0
+ * // Execute Pregel for a fixed number of iterations.
+ * Pregel(pagerankGraph, initialMessage, numIter)(
+ * vertexProgram, sendMessage, messageCombiner)
+ * }}}
+ *
+ */
+object Pregel {
+
+ /**
+ * Execute a Pregel-like iterative vertex-parallel abstraction. The
+ * user-defined vertex-program `vprog` is executed in parallel on
+ * each vertex receiving any inbound messages and computing a new
+ * value for the vertex. The `sendMsg` function is then invoked on
+ * all out-edges and is used to compute an optional message to the
+ * destination vertex. The `mergeMsg` function is a commutative
+ * associative function used to combine messages destined to the
+ * same vertex.
+ *
+ * On the first iteration all vertices receive the `initialMsg`. On
+ * subsequent iterations the vertex-program is only invoked on
+ * vertices that receive a message.
+ *
+ * This function iterates until there are no remaining messages, or
+ * for `maxIterations` iterations.
+ *
+ * @tparam VD the vertex data type
+ * @tparam ED the edge data type
+ * @tparam A the Pregel message type
+ *
+ * @param graph the input graph.
+ *
+ * @param initialMsg the message each vertex will receive on the
+ * first iteration
+ *
+ * @param maxIterations the maximum number of iterations to run for
+ *
+ * @param activeDirection the direction of edges incident to a vertex that received a message in
+ * the previous round on which to run `sendMsg`. For example, if this is `EdgeDirection.Out`, only
+ * out-edges of vertices that received a message in the previous round will run. The default is
+ * `EdgeDirection.Either`, which will run `sendMsg` on edges where either side received a message
+ * in the previous round. If this is `EdgeDirection.Both`, `sendMsg` will only run on edges where
+ * *both* vertices received a message.
+ *
+ * @param vprog the user-defined vertex program which runs on each
+ * vertex and receives the inbound message and computes a new vertex
+ * value. On the first iteration the vertex program is invoked on
+ * all vertices and is passed the default message. On subsequent
+ * iterations the vertex program is only invoked on those vertices
+ * that receive messages.
+ *
+ * @param sendMsg a user supplied function that is applied to out
+ * edges of vertices that received messages in the current
+ * iteration
+ *
+ * @param mergeMsg a user supplied function that takes two incoming
+ * messages of type A and merges them into a single message of type
+ * A. ''This function must be commutative and associative and
+ * ideally the size of A should not increase.''
+ *
+ * @return the resulting graph at the end of the computation
+ *
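+ * @example A minimal sketch of single-source shortest paths, assuming an input graph
+ * `spGraph: Graph[Double, Double]` whose edge attributes are distances and whose vertex
+ * attributes are initialized to `Double.PositiveInfinity` except for the source vertex (0.0):
+ * {{{
+ * val sssp = Pregel(spGraph, Double.PositiveInfinity)(
+ *   (id, dist, newDist) => math.min(dist, newDist),
+ *   triplet => {
+ *     if (triplet.srcAttr + triplet.attr < triplet.dstAttr) {
+ *       Iterator((triplet.dstId, triplet.srcAttr + triplet.attr))
+ *     } else {
+ *       Iterator.empty
+ *     }
+ *   },
+ *   (a, b) => math.min(a, b))
+ * }}}
+ *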
+ */
+ def apply[VD: ClassTag, ED: ClassTag, A: ClassTag]
+ (graph: Graph[VD, ED],
+ initialMsg: A,
+ maxIterations: Int = Int.MaxValue,
+ activeDirection: EdgeDirection = EdgeDirection.Either)
+ (vprog: (VertexID, VD, A) => VD,
+ sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexID, A)],
+ mergeMsg: (A, A) => A)
+ : Graph[VD, ED] =
+ {
+ var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
+ // compute the messages
+ var messages = g.mapReduceTriplets(sendMsg, mergeMsg)
+ var activeMessages = messages.count()
+ // Loop
+ var prevG: Graph[VD, ED] = null
+ var i = 0
+ while (activeMessages > 0 && i < maxIterations) {
+ // Receive the messages. Vertices that didn't get any messages do not appear in newVerts.
+ val newVerts = g.vertices.innerJoin(messages)(vprog).cache()
+ // Update the graph with the new vertices.
+ prevG = g
+ g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) }
+ g.cache()
+
+ val oldMessages = messages
+ // Send new messages. Vertices that didn't get any messages don't appear in newVerts, so don't
+ // get to send messages. We must cache messages so it can be materialized on the next line,
+ // allowing us to uncache the previous iteration.
+ messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, activeDirection))).cache()
+ // The call to count() materializes `messages`, `newVerts`, and the vertices of `g`. This
+ // hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the
+ // vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g).
+ activeMessages = messages.count()
+ // Unpersist the RDDs hidden by newly-materialized RDDs
+ oldMessages.unpersist(blocking = false)
+ newVerts.unpersist(blocking = false)
+ prevG.unpersistVertices(blocking = false)
+ // count the iteration
+ i += 1
+ }
+
+ g
+ } // end of apply
+
+} // end of class Pregel
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
new file mode 100644
index 0000000000..9a95364cb1
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
@@ -0,0 +1,347 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.graphx
+
+import scala.reflect.ClassTag
+
+import org.apache.spark._
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd._
+import org.apache.spark.storage.StorageLevel
+
+import org.apache.spark.graphx.impl.MsgRDDFunctions
+import org.apache.spark.graphx.impl.VertexPartition
+
+/**
+ * Extends `RDD[(VertexID, VD)]` by ensuring that there is only one entry for each vertex and by
+ * pre-indexing the entries for fast, efficient joins. Two VertexRDDs with the same index can be
+ * joined efficiently. All operations except [[reindex]] preserve the index. To construct a
+ * `VertexRDD`, use the [[org.apache.spark.graphx.VertexRDD$ VertexRDD object]].
+ *
+ * @example Construct a `VertexRDD` from a plain RDD:
+ * {{{
+ * // Construct an initial vertex set
+ * val someData: RDD[(VertexID, SomeType)] = loadData(someFile)
+ * val vset = VertexRDD(someData)
+ * // If there were redundant values in someData we would use a reduceFunc
+ * val vset2 = VertexRDD(someData, reduceFunc)
+ * // Finally we can use the VertexRDD to index another dataset
+ * val otherData: RDD[(VertexID, OtherType)] = loadData(otherFile)
+ * val vset3 = vset2.innerJoin(otherData) { (vid, a, b) => b }
+ * // Now we can construct very fast joins between the two sets
+ * val vset4: VertexRDD[(SomeType, OtherType)] = vset.innerJoin(vset3) { (vid, a, b) => (a, b) }
+ * }}}
+ *
+ * @tparam VD the vertex attribute associated with each vertex in the set.
+ */
+class VertexRDD[@specialized VD: ClassTag](
+ val partitionsRDD: RDD[VertexPartition[VD]])
+ extends RDD[(VertexID, VD)](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) {
+
+ require(partitionsRDD.partitioner.isDefined)
+
+ partitionsRDD.setName("VertexRDD")
+
+ /**
+ * Construct a new VertexRDD that is indexed by only the visible vertices. The resulting
+ * VertexRDD will be based on a different index and can no longer be quickly joined with this RDD.
+ */
+ def reindex(): VertexRDD[VD] = new VertexRDD(partitionsRDD.map(_.reindex()))
+
+ override val partitioner = partitionsRDD.partitioner
+
+ override protected def getPartitions: Array[Partition] = partitionsRDD.partitions
+
+ override protected def getPreferredLocations(s: Partition): Seq[String] =
+ partitionsRDD.preferredLocations(s)
+
+ override def persist(newLevel: StorageLevel): VertexRDD[VD] = {
+ partitionsRDD.persist(newLevel)
+ this
+ }
+
+ /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
+ override def persist(): VertexRDD[VD] = persist(StorageLevel.MEMORY_ONLY)
+
+ /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
+ override def cache(): VertexRDD[VD] = persist()
+
+ override def unpersist(blocking: Boolean = true): VertexRDD[VD] = {
+ partitionsRDD.unpersist(blocking)
+ this
+ }
+
+ /** The number of vertices in the RDD. */
+ override def count(): Long = {
+ partitionsRDD.map(_.size).reduce(_ + _)
+ }
+
+ /**
+ * Provides the `RDD[(VertexID, VD)]` equivalent output.
+ */
+ override def compute(part: Partition, context: TaskContext): Iterator[(VertexID, VD)] = {
+ firstParent[VertexPartition[VD]].iterator(part, context).next.iterator
+ }
+
+ /**
+ * Applies a function to each `VertexPartition` of this RDD and returns a new VertexRDD.
+ */
+ private[graphx] def mapVertexPartitions[VD2: ClassTag](f: VertexPartition[VD] => VertexPartition[VD2])
+ : VertexRDD[VD2] = {
+ val newPartitionsRDD = partitionsRDD.mapPartitions(_.map(f), preservesPartitioning = true)
+ new VertexRDD(newPartitionsRDD)
+ }
+
+
+ /**
+ * Restricts the vertex set to the set of vertices satisfying the given predicate. This operation
+ * preserves the index for efficient joins with the original RDD, and it sets bits in the bitmask
+ * rather than allocating new memory.
+ *
+ * @param pred the user defined predicate, which takes a tuple to conform to the
+ * `RDD[(VertexID, VD)]` interface
+ */
+ override def filter(pred: Tuple2[VertexID, VD] => Boolean): VertexRDD[VD] =
+ this.mapVertexPartitions(_.filter(Function.untupled(pred)))
+
+ /**
+ * Maps each vertex attribute, preserving the index.
+ *
+ * @tparam VD2 the type returned by the map function
+ *
+ * @param f the function applied to each value in the RDD
+ * @return a new VertexRDD with values obtained by applying `f` to each of the entries in the
+ * original VertexRDD
+ */
+ def mapValues[VD2: ClassTag](f: VD => VD2): VertexRDD[VD2] =
+ this.mapVertexPartitions(_.map((vid, attr) => f(attr)))
+
+ /**
+ * Maps each vertex attribute, additionally supplying the vertex ID.
+ *
+ * @tparam VD2 the type returned by the map function
+ *
+ * @param f the function applied to each ID-value pair in the RDD
+ * @return a new VertexRDD with values obtained by applying `f` to each of the entries in the
+ * original VertexRDD. The resulting VertexRDD retains the same index.
+ */
+ def mapValues[VD2: ClassTag](f: (VertexID, VD) => VD2): VertexRDD[VD2] =
+ this.mapVertexPartitions(_.map(f))
+
+ /**
+ * Hides vertices that have the same value in `this` and `other`; for vertices whose values
+ * differ, keeps the values from `other`.
+ */
+ def diff(other: VertexRDD[VD]): VertexRDD[VD] = {
+ val newPartitionsRDD = partitionsRDD.zipPartitions(
+ other.partitionsRDD, preservesPartitioning = true
+ ) { (thisIter, otherIter) =>
+ val thisPart = thisIter.next()
+ val otherPart = otherIter.next()
+ Iterator(thisPart.diff(otherPart))
+ }
+ new VertexRDD(newPartitionsRDD)
+ }
+
+ /**
+ * Left joins this RDD with another VertexRDD with the same index. This function will fail if
+ * the two VertexRDDs do not share the same index. The resulting vertex set contains an entry for each
+ * vertex in `this`. If `other` is missing any vertex in this VertexRDD, `f` is passed `None`.
+ *
+ * @tparam VD2 the attribute type of the other VertexRDD
+ * @tparam VD3 the attribute type of the resulting VertexRDD
+ *
+ * @param other the other VertexRDD with which to join.
+ * @param f the function mapping a vertex id and its attributes in this and the other vertex set
+ * to a new vertex attribute.
+ * @return a VertexRDD containing the results of `f`
+ */
+ def leftZipJoin[VD2: ClassTag, VD3: ClassTag]
+ (other: VertexRDD[VD2])(f: (VertexID, VD, Option[VD2]) => VD3): VertexRDD[VD3] = {
+ val newPartitionsRDD = partitionsRDD.zipPartitions(
+ other.partitionsRDD, preservesPartitioning = true
+ ) { (thisIter, otherIter) =>
+ val thisPart = thisIter.next()
+ val otherPart = otherIter.next()
+ Iterator(thisPart.leftJoin(otherPart)(f))
+ }
+ new VertexRDD(newPartitionsRDD)
+ }
+
+ /**
+ * Left joins this VertexRDD with an RDD containing vertex attribute pairs. If the other RDD is
+ * backed by a VertexRDD with the same index then the efficient [[leftZipJoin]] implementation is
+ * used. The resulting VertexRDD contains an entry for each vertex in `this`. If `other` is
+ * missing any vertex in this VertexRDD, `f` is passed `None`. If there are duplicates, the vertex
+ * is picked arbitrarily.
+ *
+ * @tparam VD2 the attribute type of the other VertexRDD
+ * @tparam VD3 the attribute type of the resulting VertexRDD
+ *
+ * @param other the other VertexRDD with which to join
+ * @param f the function mapping a vertex id and its attributes in this and the other vertex set
+ * to a new vertex attribute.
+ * @return a VertexRDD containing all the vertices in this VertexRDD with the attributes emitted
+ * by `f`.
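+ *
+ * @example A minimal sketch, assuming a SparkContext `sc` and an existing `verts: VertexRDD[Int]`:
+ * {{{
+ * val labels: RDD[(VertexID, String)] = sc.parallelize(Seq((1L, "a"), (2L, "b")))
+ * // Every vertex in `verts` is kept; unlabeled vertices fall back to their numeric attribute.
+ * val labeled: VertexRDD[String] = verts.leftJoin(labels) { (vid, attr, labelOpt) =>
+ *   labelOpt.getOrElse(attr.toString)
+ * }
+ * }}}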
+ */
+ def leftJoin[VD2: ClassTag, VD3: ClassTag]
+ (other: RDD[(VertexID, VD2)])
+ (f: (VertexID, VD, Option[VD2]) => VD3)
+ : VertexRDD[VD3] = {
+ // Test if the other RDD is a VertexRDD to choose the optimal join strategy.
+ // If the other set is a VertexRDD then we use the much more efficient leftZipJoin
+ other match {
+ case other: VertexRDD[_] =>
+ leftZipJoin(other)(f)
+ case _ =>
+ new VertexRDD[VD3](
+ partitionsRDD.zipPartitions(
+ other.partitionBy(this.partitioner.get), preservesPartitioning = true)
+ { (part, msgs) =>
+ val vertexPartition: VertexPartition[VD] = part.next()
+ Iterator(vertexPartition.leftJoin(msgs)(f))
+ }
+ )
+ }
+ }
+
+ /**
+ * Efficiently inner joins this VertexRDD with another VertexRDD sharing the same index. See
+ * [[innerJoin]] for the behavior of the join.
+ */
+ def innerZipJoin[U: ClassTag, VD2: ClassTag](other: VertexRDD[U])
+ (f: (VertexID, VD, U) => VD2): VertexRDD[VD2] = {
+ val newPartitionsRDD = partitionsRDD.zipPartitions(
+ other.partitionsRDD, preservesPartitioning = true
+ ) { (thisIter, otherIter) =>
+ val thisPart = thisIter.next()
+ val otherPart = otherIter.next()
+ Iterator(thisPart.innerJoin(otherPart)(f))
+ }
+ new VertexRDD(newPartitionsRDD)
+ }
+
+ /**
+ * Inner joins this VertexRDD with an RDD containing vertex attribute pairs. If the other RDD is
+ * backed by a VertexRDD with the same index then the efficient [[innerZipJoin]] implementation is
+ * used.
+ *
+ * @param other an RDD containing vertices to join. If there are multiple entries for the same
+ * vertex, one is picked arbitrarily. Use [[aggregateUsingIndex]] to merge multiple entries.
+ * @param f the join function applied to corresponding values of `this` and `other`
+ * @return a VertexRDD co-indexed with `this`, containing only vertices that appear in both `this`
+ * and `other`, with values supplied by `f`
+ */
+ def innerJoin[U: ClassTag, VD2: ClassTag](other: RDD[(VertexID, U)])
+ (f: (VertexID, VD, U) => VD2): VertexRDD[VD2] = {
+ // Test if the other RDD is a VertexRDD to choose the optimal join strategy.
+ // If the other set is a VertexRDD then we use the much more efficient innerZipJoin
+ other match {
+ case other: VertexRDD[_] =>
+ innerZipJoin(other)(f)
+ case _ =>
+ new VertexRDD(
+ partitionsRDD.zipPartitions(
+ other.partitionBy(this.partitioner.get), preservesPartitioning = true)
+ { (part, msgs) =>
+ val vertexPartition: VertexPartition[VD] = part.next()
+ Iterator(vertexPartition.innerJoin(msgs)(f))
+ }
+ )
+ }
+ }
+
+ /**
+ * Aggregates vertices in `messages` that have the same ids using `reduceFunc`, returning a
+ * VertexRDD co-indexed with `this`.
+ *
+ * @param messages an RDD containing messages to aggregate, where each message is a pair of its
+ * target vertex ID and the message data
+ * @param reduceFunc the associative aggregation function for merging messages to the same vertex
+ * @return a VertexRDD co-indexed with `this`, containing only vertices that received messages.
+ * For those vertices, their values are the result of applying `reduceFunc` to all received
+ * messages.
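+ *
+ * @example A minimal sketch, assuming a SparkContext `sc` and an existing `verts: VertexRDD[_]`:
+ * {{{
+ * val msgs: RDD[(VertexID, Int)] = sc.parallelize(Seq((1L, 1), (1L, 2), (2L, 5)))
+ * // Vertices 1 and 2 (if present in `verts`) receive 3 and 5; all other vertices are absent.
+ * val counts: VertexRDD[Int] = verts.aggregateUsingIndex(msgs, _ + _)
+ * }}}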
+ */
+ def aggregateUsingIndex[VD2: ClassTag](
+ messages: RDD[(VertexID, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] = {
+ val shuffled = MsgRDDFunctions.partitionForAggregation(messages, this.partitioner.get)
+ val parts = partitionsRDD.zipPartitions(shuffled, true) { (thisIter, msgIter) =>
+ val vertexPartition: VertexPartition[VD] = thisIter.next()
+ Iterator(vertexPartition.aggregateUsingIndex(msgIter, reduceFunc))
+ }
+ new VertexRDD[VD2](parts)
+ }
+
+} // end of VertexRDD
+
+
+/**
+ * The VertexRDD companion object provides factory methods for constructing VertexRDDs.
+ */
+object VertexRDD {
+
+ /**
+ * Construct a `VertexRDD` from an RDD of vertex-attribute pairs.
+ * Duplicate entries are removed arbitrarily.
+ *
+ * @tparam VD the vertex attribute type
+ *
+ * @param rdd the collection of vertex-attribute pairs
+ */
+ def apply[VD: ClassTag](rdd: RDD[(VertexID, VD)]): VertexRDD[VD] = {
+ val partitioned: RDD[(VertexID, VD)] = rdd.partitioner match {
+ case Some(p) => rdd
+ case None => rdd.partitionBy(new HashPartitioner(rdd.partitions.size))
+ }
+ val vertexPartitions = partitioned.mapPartitions(
+ iter => Iterator(VertexPartition(iter)),
+ preservesPartitioning = true)
+ new VertexRDD(vertexPartitions)
+ }
+
+ /**
+ * Constructs a `VertexRDD` from an RDD of vertex-attribute pairs, merging duplicates using
+ * `mergeFunc`.
+ *
+ * @tparam VD the vertex attribute type
+ *
+ * @param rdd the collection of vertex-attribute pairs
+ * @param mergeFunc the associative, commutative merge function.
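+ *
+ * @example A minimal sketch, assuming a SparkContext `sc`:
+ * {{{
+ * val pairs: RDD[(VertexID, Int)] = sc.parallelize(Seq((1L, 1), (1L, 2), (2L, 3)))
+ * // Vertex 1 ends up with the merged value 1 + 2 = 3; vertex 2 keeps its single value 3.
+ * val vset: VertexRDD[Int] = VertexRDD(pairs, (a: Int, b: Int) => a + b)
+ * }}}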
+ */
+ def apply[VD: ClassTag](rdd: RDD[(VertexID, VD)], mergeFunc: (VD, VD) => VD): VertexRDD[VD] = {
+ val partitioned: RDD[(VertexID, VD)] = rdd.partitioner match {
+ case Some(p) => rdd
+ case None => rdd.partitionBy(new HashPartitioner(rdd.partitions.size))
+ }
+ // Merge duplicate vertex ids within each partition using `mergeFunc` (this assumes a
+ // VertexPartition factory overload that accepts a merge function).
+ val vertexPartitions = partitioned.mapPartitions(
+ iter => Iterator(VertexPartition(iter, mergeFunc)),
+ preservesPartitioning = true)
+ new VertexRDD(vertexPartitions)
+ }
+
+ /**
+ * Constructs a VertexRDD from the vertex IDs in `vids`, taking attributes from `rdd` and using
+ * `defaultVal` otherwise.
+ */
+ def apply[VD: ClassTag](vids: RDD[VertexID], rdd: RDD[(VertexID, VD)], defaultVal: VD)
+ : VertexRDD[VD] = {
+ VertexRDD(vids.map(vid => (vid, defaultVal))).leftJoin(rdd) { (vid, default, value) =>
+ value.getOrElse(default)
+ }
+ }
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
new file mode 100644
index 0000000000..ee95ead3ad
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
@@ -0,0 +1,220 @@
+package org.apache.spark.graphx.impl
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap
+
+/**
+ * A collection of edges stored in 3 large columnar arrays (src, dst, attribute). The arrays are
+ * clustered by src.
+ *
+ * @param srcIds the source vertex id of each edge
+ * @param dstIds the destination vertex id of each edge
+ * @param data the attribute associated with each edge
+ * @param index a clustered index on source vertex id
+ * @tparam ED the edge attribute type.
+ */
+private[graphx]
+class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) ED: ClassTag](
+ val srcIds: Array[VertexID],
+ val dstIds: Array[VertexID],
+ val data: Array[ED],
+ val index: PrimitiveKeyOpenHashMap[VertexID, Int]) extends Serializable {
+
+ /**
+ * Reverse all the edges in this partition.
+ *
+ * @return a new edge partition with all edges reversed.
+ */
+ def reverse: EdgePartition[ED] = {
+ val builder = new EdgePartitionBuilder(size)
+ for (e <- iterator) {
+ builder.add(e.dstId, e.srcId, e.attr)
+ }
+ builder.toEdgePartition
+ }
+
+ /**
+ * Construct a new edge partition by applying the function f to all
+ * edges in this partition.
+ *
+ * @param f a function from an edge to a new attribute
+ * @tparam ED2 the type of the new attribute
+ * @return a new edge partition with the result of the function `f`
+ * applied to each edge
+ */
+ def map[ED2: ClassTag](f: Edge[ED] => ED2): EdgePartition[ED2] = {
+ val newData = new Array[ED2](data.size)
+ val edge = new Edge[ED]()
+ val size = data.size
+ var i = 0
+ while (i < size) {
+ edge.srcId = srcIds(i)
+ edge.dstId = dstIds(i)
+ edge.attr = data(i)
+ newData(i) = f(edge)
+ i += 1
+ }
+ new EdgePartition(srcIds, dstIds, newData, index)
+ }
+
+ /**
+ * Construct a new edge partition by using the edge attributes
+ * contained in the iterator.
+ *
+ * @note The input iterator should return exactly one new attribute per edge, in the
+ * order of the edges returned by `EdgePartition.iterator`.
+ *
+ * @param iter an iterator over the new edge attribute values
+ * @tparam ED2 the type of the new attribute
+ * @return a new edge partition with the edge attributes replaced by the values
+ * drawn from `iter`
+ */
+ def map[ED2: ClassTag](iter: Iterator[ED2]): EdgePartition[ED2] = {
+ val newData = new Array[ED2](data.size)
+ var i = 0
+ while (iter.hasNext) {
+ newData(i) = iter.next()
+ i += 1
+ }
+ assert(newData.size == i)
+ new EdgePartition(srcIds, dstIds, newData, index)
+ }
+
+ /**
+ * Apply the function f to all edges in this partition.
+ *
+ * @param f a user-defined function applied to each edge, typically mutating external state.
+ */
+ def foreach(f: Edge[ED] => Unit) {
+ iterator.foreach(f)
+ }
+
+ /**
+ * Merge all the edges with the same src and dst ids into a single
+ * edge using the `merge` function.
+ *
+ * @param merge a commutative associative merge operation
+ * @return a new edge partition without duplicate edges
+ */
+ def groupEdges(merge: (ED, ED) => ED): EdgePartition[ED] = {
+ val builder = new EdgePartitionBuilder[ED]
+ var currSrcId: VertexID = null.asInstanceOf[VertexID]
+ var currDstId: VertexID = null.asInstanceOf[VertexID]
+ var currAttr: ED = null.asInstanceOf[ED]
+ var i = 0
+ while (i < size) {
+ if (i > 0 && currSrcId == srcIds(i) && currDstId == dstIds(i)) {
+ currAttr = merge(currAttr, data(i))
+ } else {
+ if (i > 0) {
+ builder.add(currSrcId, currDstId, currAttr)
+ }
+ currSrcId = srcIds(i)
+ currDstId = dstIds(i)
+ currAttr = data(i)
+ }
+ i += 1
+ }
+ if (size > 0) {
+ builder.add(currSrcId, currDstId, currAttr)
+ }
+ builder.toEdgePartition
+ }
+
+ /**
+ * Apply `f` to all edges present in both `this` and `other` and return a new EdgePartition
+ * containing the resulting edges.
+ *
+ * If there are multiple edges with the same src and dst in `this`, `f` will be invoked once for
+ * each edge, but each time it may be invoked on any corresponding edge in `other`.
+ *
+ * If there are multiple edges with the same src and dst in `other`, `f` will only be invoked
+ * once.
+ */
+ def innerJoin[ED2: ClassTag, ED3: ClassTag]
+ (other: EdgePartition[ED2])
+ (f: (VertexID, VertexID, ED, ED2) => ED3): EdgePartition[ED3] = {
+ val builder = new EdgePartitionBuilder[ED3]
+ var i = 0
+ var j = 0
+ // For i = index of each edge in `this`...
+ while (i < size && j < other.size) {
+ val srcId = this.srcIds(i)
+ val dstId = this.dstIds(i)
+ // ... forward j to the index of the corresponding edge in `other`, and...
+ while (j < other.size && other.srcIds(j) < srcId) { j += 1 }
+ if (j < other.size && other.srcIds(j) == srcId) {
+ while (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) < dstId) { j += 1 }
+ if (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) == dstId) {
+ // ... run `f` on the matching edge
+ builder.add(srcId, dstId, f(srcId, dstId, this.data(i), other.data(j)))
+ }
+ }
+ i += 1
+ }
+ builder.toEdgePartition
+ }
+
+ /**
+ * The number of edges in this partition
+ *
+ * @return size of the partition
+ */
+ def size: Int = srcIds.size
+
+ /** The number of unique source vertices in the partition. */
+ def indexSize: Int = index.size
+
+ /**
+ * Get an iterator over the edges in this partition.
+ *
+ * @return an iterator over edges in the partition
+ */
+ def iterator = new Iterator[Edge[ED]] {
+ private[this] val edge = new Edge[ED]
+ private[this] var pos = 0
+
+ override def hasNext: Boolean = pos < EdgePartition.this.size
+
+ override def next(): Edge[ED] = {
+ edge.srcId = srcIds(pos)
+ edge.dstId = dstIds(pos)
+ edge.attr = data(pos)
+ pos += 1
+ edge
+ }
+ }
+
+ /**
+ * Get an iterator over the edges in this partition whose source vertex ids match srcIdPred. The
+ * iterator is generated using an index scan, so it is efficient at skipping edges that don't
+ * match srcIdPred.
+ */
+ def indexIterator(srcIdPred: VertexID => Boolean): Iterator[Edge[ED]] =
+ index.iterator.filter(kv => srcIdPred(kv._1)).flatMap(Function.tupled(clusterIterator))
+
+ /**
+ * Get an iterator over the cluster of edges in this partition with source vertex id `srcId`. The
+ * cluster must start at position `index`.
+ */
+ private def clusterIterator(srcId: VertexID, index: Int) = new Iterator[Edge[ED]] {
+ private[this] val edge = new Edge[ED]
+ private[this] var pos = index
+
+ override def hasNext: Boolean = {
+ pos >= 0 && pos < EdgePartition.this.size && srcIds(pos) == srcId
+ }
+
+ override def next(): Edge[ED] = {
+ assert(srcIds(pos) == srcId)
+ edge.srcId = srcIds(pos)
+ edge.dstId = dstIds(pos)
+ edge.attr = data(pos)
+ pos += 1
+ edge
+ }
+ }
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
new file mode 100644
index 0000000000..9d072f9335
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
@@ -0,0 +1,45 @@
+package org.apache.spark.graphx.impl
+
+import scala.reflect.ClassTag
+import scala.util.Sorting
+
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap
+import org.apache.spark.util.collection.PrimitiveVector
+
+private[graphx]
+class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag](size: Int = 64) {
+ var edges = new PrimitiveVector[Edge[ED]](size)
+
+ /** Add a new edge to the partition. */
+ def add(src: VertexID, dst: VertexID, d: ED) {
+ edges += Edge(src, dst, d)
+ }
+
+ def toEdgePartition: EdgePartition[ED] = {
+ val edgeArray = edges.trim().array
+ Sorting.quickSort(edgeArray)(Edge.lexicographicOrdering)
+ val srcIds = new Array[VertexID](edgeArray.size)
+ val dstIds = new Array[VertexID](edgeArray.size)
+ val data = new Array[ED](edgeArray.size)
+ val index = new PrimitiveKeyOpenHashMap[VertexID, Int]
+ // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and
+ // adding them to the index
+ if (edgeArray.length > 0) {
+ // Index the first cluster from the sorted edge array itself; the srcIds columnar array has
+ // not been filled in yet at this point.
+ index.update(edgeArray(0).srcId, 0)
+ var currSrcId: VertexID = edgeArray(0).srcId
+ var i = 0
+ while (i < edgeArray.size) {
+ srcIds(i) = edgeArray(i).srcId
+ dstIds(i) = edgeArray(i).dstId
+ data(i) = edgeArray(i).attr
+ if (edgeArray(i).srcId != currSrcId) {
+ currSrcId = edgeArray(i).srcId
+ index.update(currSrcId, i)
+ }
+ i += 1
+ }
+ }
+ new EdgePartition(srcIds, dstIds, data, index)
+ }
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala
new file mode 100644
index 0000000000..bad840f1cd
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala
@@ -0,0 +1,42 @@
+package org.apache.spark.graphx.impl
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap
+
+/**
+ * The Iterator type returned when constructing edge triplets. This class technically could be
+ * an anonymous class in GraphImpl.triplets, but we name it here explicitly so it is easier to
+ * debug / profile.
+ */
+private[impl]
+class EdgeTripletIterator[VD: ClassTag, ED: ClassTag](
+ val vidToIndex: VertexIdToIndexMap,
+ val vertexArray: Array[VD],
+ val edgePartition: EdgePartition[ED])
+ extends Iterator[EdgeTriplet[VD, ED]] {
+
+ // Current position in the array.
+ private var pos = 0
+
+ // A triplet object that this iterator.next() call returns. We reuse this object to avoid
+ // allocating too many temporary Java objects.
+ private val triplet = new EdgeTriplet[VD, ED]
+
+ private val vmap = new PrimitiveKeyOpenHashMap[VertexID, VD](vidToIndex, vertexArray)
+
+ override def hasNext: Boolean = pos < edgePartition.size
+
+ override def next() = {
+ triplet.srcId = edgePartition.srcIds(pos)
+ // assert(vmap.containsKey(e.src.id))
+ triplet.srcAttr = vmap(triplet.srcId)
+ triplet.dstId = edgePartition.dstIds(pos)
+ // assert(vmap.containsKey(e.dst.id))
+ triplet.dstAttr = vmap(triplet.dstId)
+ triplet.attr = edgePartition.data(pos)
+ pos += 1
+ triplet
+ }
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
new file mode 100644
index 0000000000..56d1d9efea
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
@@ -0,0 +1,379 @@
+package org.apache.spark.graphx.impl
+
+import scala.reflect.{classTag, ClassTag}
+
+import org.apache.spark.util.collection.PrimitiveVector
+import org.apache.spark.{HashPartitioner, Partitioner}
+import org.apache.spark.SparkContext._
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.impl.GraphImpl._
+import org.apache.spark.graphx.impl.MsgRDDFunctions._
+import org.apache.spark.graphx.util.BytecodeUtils
+import org.apache.spark.rdd.{ShuffledRDD, RDD}
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.ClosureCleaner
+
+
+/**
+ * An implementation of [[org.apache.spark.graphx.Graph]] to support computation on graphs.
+ *
+ * Graphs are represented using two classes of data: vertex-partitioned and
+ * edge-partitioned. `vertices` contains vertex attributes, which are vertex-partitioned. `edges`
+ * contains edge attributes, which are edge-partitioned. For operations on vertex neighborhoods,
+ * vertex attributes are replicated to the edge partitions where they appear as sources or
+ * destinations. `routingTable` stores the routing information for shipping vertex attributes to
+ * edge partitions. `replicatedVertexView` stores a view of the replicated vertex attributes created
+ * using the routing table.
+ */
+class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
+ @transient val vertices: VertexRDD[VD],
+ @transient val edges: EdgeRDD[ED],
+ @transient val routingTable: RoutingTable,
+ @transient val replicatedVertexView: ReplicatedVertexView[VD])
+ extends Graph[VD, ED] with Serializable {
+
+ /** Default constructor is provided to support serialization */
+ protected def this() = this(null, null, null, null)
+
+ /** Return an RDD that brings edges together with their source and destination vertices. */
+ @transient override val triplets: RDD[EdgeTriplet[VD, ED]] = {
+ val vdTag = classTag[VD]
+ val edTag = classTag[ED]
+ edges.partitionsRDD.zipPartitions(
+ replicatedVertexView.get(true, true), true) { (ePartIter, vPartIter) =>
+ val (pid, ePart) = ePartIter.next()
+ val (_, vPart) = vPartIter.next()
+ new EdgeTripletIterator(vPart.index, vPart.values, ePart)(vdTag, edTag)
+ }
+ }
+
+ override def persist(newLevel: StorageLevel): Graph[VD, ED] = {
+ vertices.persist(newLevel)
+ edges.persist(newLevel)
+ this
+ }
+
+ override def cache(): Graph[VD, ED] = persist(StorageLevel.MEMORY_ONLY)
+
+ override def unpersistVertices(blocking: Boolean = true): Graph[VD, ED] = {
+ vertices.unpersist(blocking)
+ replicatedVertexView.unpersist(blocking)
+ this
+ }
+
+ override def partitionBy(partitionStrategy: PartitionStrategy): Graph[VD, ED] = {
+ val numPartitions = edges.partitions.size
+ val edTag = classTag[ED]
+ val newEdges = new EdgeRDD(edges.map { e =>
+ val part: PartitionID = partitionStrategy.getPartition(e.srcId, e.dstId, numPartitions)
+
+ // TODO: Should we be using a 3-tuple or an optimized class?
+ new MessageToPartition(part, (e.srcId, e.dstId, e.attr))
+ }
+ .partitionBy(new HashPartitioner(numPartitions))
+ .mapPartitionsWithIndex( { (pid, iter) =>
+ val builder = new EdgePartitionBuilder[ED]()(edTag)
+ iter.foreach { message =>
+ val data = message.data
+ builder.add(data._1, data._2, data._3)
+ }
+ val edgePartition = builder.toEdgePartition
+ Iterator((pid, edgePartition))
+ }, preservesPartitioning = true).cache())
+ GraphImpl(vertices, newEdges)
+ }
+
+ override def reverse: Graph[VD, ED] = {
+ val newETable = edges.mapEdgePartitions((pid, part) => part.reverse)
+ new GraphImpl(vertices, newETable, routingTable, replicatedVertexView)
+ }
+
+ override def mapVertices[VD2: ClassTag](f: (VertexID, VD) => VD2): Graph[VD2, ED] = {
+ if (classTag[VD] equals classTag[VD2]) {
+ // The map preserves type, so we can use incremental replication
+ val newVerts = vertices.mapVertexPartitions(_.map(f)).cache()
+ val changedVerts = vertices.asInstanceOf[VertexRDD[VD2]].diff(newVerts)
+ val newReplicatedVertexView = new ReplicatedVertexView[VD2](
+ changedVerts, edges, routingTable,
+ Some(replicatedVertexView.asInstanceOf[ReplicatedVertexView[VD2]]))
+ new GraphImpl(newVerts, edges, routingTable, newReplicatedVertexView)
+ } else {
+ // The map does not preserve type, so we must re-replicate all vertices
+ GraphImpl(vertices.mapVertexPartitions(_.map(f)), edges, routingTable)
+ }
+ }
+
+ override def mapEdges[ED2: ClassTag](
+ f: (PartitionID, Iterator[Edge[ED]]) => Iterator[ED2]): Graph[VD, ED2] = {
+ val newETable = edges.mapEdgePartitions((pid, part) => part.map(f(pid, part.iterator)))
+ new GraphImpl(vertices, newETable, routingTable, replicatedVertexView)
+ }
+
+ override def mapTriplets[ED2: ClassTag](
+ f: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2]): Graph[VD, ED2] = {
+ val newEdgePartitions =
+ edges.partitionsRDD.zipPartitions(replicatedVertexView.get(true, true), true) {
+ (ePartIter, vTableReplicatedIter) =>
+ val (ePid, edgePartition) = ePartIter.next()
+ val (vPid, vPart) = vTableReplicatedIter.next()
+ assert(!vTableReplicatedIter.hasNext)
+ assert(ePid == vPid)
+ val et = new EdgeTriplet[VD, ED]
+ val inputIterator = edgePartition.iterator.map { e =>
+ et.set(e)
+ et.srcAttr = vPart(e.srcId)
+ et.dstAttr = vPart(e.dstId)
+ et
+ }
+ // Apply the user function to the triplets in this edge partition
+ val outputIter = f(ePid, inputIterator)
+ // Consume the iterator to update the edge attributes
+ val newEdgePartition = edgePartition.map(outputIter)
+ Iterator((ePid, newEdgePartition))
+ }
+ new GraphImpl(vertices, new EdgeRDD(newEdgePartitions), routingTable, replicatedVertexView)
+ }
+
+ override def subgraph(
+ epred: EdgeTriplet[VD, ED] => Boolean = x => true,
+ vpred: (VertexID, VD) => Boolean = (a, b) => true): Graph[VD, ED] = {
+ // Filter the vertices, reusing the partitioner and the index from this graph
+ val newVerts = vertices.mapVertexPartitions(_.filter(vpred))
+
+ // Filter the edges
+ val edTag = classTag[ED]
+ val newEdges = new EdgeRDD[ED](triplets.filter { et =>
+ vpred(et.srcId, et.srcAttr) && vpred(et.dstId, et.dstAttr) && epred(et)
+ }.mapPartitionsWithIndex( { (pid, iter) =>
+ val builder = new EdgePartitionBuilder[ED]()(edTag)
+ iter.foreach { et => builder.add(et.srcId, et.dstId, et.attr) }
+ val edgePartition = builder.toEdgePartition
+ Iterator((pid, edgePartition))
+ }, preservesPartitioning = true)).cache()
+
+ // Reuse the previous ReplicatedVertexView unmodified. The replicated vertices that have been
+ // removed will be ignored, since we only refer to replicated vertices when they are adjacent to
+ // an edge.
+ new GraphImpl(newVerts, newEdges, new RoutingTable(newEdges, newVerts), replicatedVertexView)
+ } // end of subgraph
+
+ override def mask[VD2: ClassTag, ED2: ClassTag](other: Graph[VD2, ED2]): Graph[VD, ED] = {
+ val newVerts = vertices.innerJoin(other.vertices) { (vid, v, w) => v }
+ val newEdges = edges.innerJoin(other.edges) { (src, dst, v, w) => v }
+ // Reuse the previous ReplicatedVertexView unmodified. The replicated vertices that have been
+ // removed will be ignored, since we only refer to replicated vertices when they are adjacent to
+ // an edge.
+ new GraphImpl(newVerts, newEdges, routingTable, replicatedVertexView)
+ }
+
+ override def groupEdges(merge: (ED, ED) => ED): Graph[VD, ED] = {
+ ClosureCleaner.clean(merge)
+ val newETable = edges.mapEdgePartitions((pid, part) => part.groupEdges(merge))
+ new GraphImpl(vertices, newETable, routingTable, replicatedVertexView)
+ }
+
+ //////////////////////////////////////////////////////////////////////////////////////////////////
+ // Lower level transformation methods
+ //////////////////////////////////////////////////////////////////////////////////////////////////
+
+ override def mapReduceTriplets[A: ClassTag](
+ mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexID, A)],
+ reduceFunc: (A, A) => A,
+ activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None) = {
+
+ ClosureCleaner.clean(mapFunc)
+ ClosureCleaner.clean(reduceFunc)
+
+ // For each vertex, replicate its attribute only to partitions where it is
+ // in the relevant position in an edge.
+ val mapUsesSrcAttr = accessesVertexAttr(mapFunc, "srcAttr")
+ val mapUsesDstAttr = accessesVertexAttr(mapFunc, "dstAttr")
+ val vs = activeSetOpt match {
+ case Some((activeSet, _)) =>
+ replicatedVertexView.get(mapUsesSrcAttr, mapUsesDstAttr, activeSet)
+ case None =>
+ replicatedVertexView.get(mapUsesSrcAttr, mapUsesDstAttr)
+ }
+ val activeDirectionOpt = activeSetOpt.map(_._2)
+
+ // Map and combine.
+ val preAgg = edges.partitionsRDD.zipPartitions(vs, true) { (ePartIter, vPartIter) =>
+ val (ePid, edgePartition) = ePartIter.next()
+ val (vPid, vPart) = vPartIter.next()
+ assert(!vPartIter.hasNext)
+ assert(ePid == vPid)
+ // Choose scan method
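+ // An index scan over the clustered source-vertex index pays off only when relatively few
+ // vertices are active; above the 0.8 threshold we simply scan every edge in the partition.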
+ val activeFraction = vPart.numActives.getOrElse(0) / edgePartition.indexSize.toFloat
+ val edgeIter = activeDirectionOpt match {
+ case Some(EdgeDirection.Both) =>
+ if (activeFraction < 0.8) {
+ edgePartition.indexIterator(srcVertexID => vPart.isActive(srcVertexID))
+ .filter(e => vPart.isActive(e.dstId))
+ } else {
+ edgePartition.iterator.filter(e => vPart.isActive(e.srcId) && vPart.isActive(e.dstId))
+ }
+ case Some(EdgeDirection.Either) =>
+ // TODO: Because we only have a clustered index on the source vertex ID, we can't filter
+ // the index here. Instead we have to scan all edges and then do the filter.
+ edgePartition.iterator.filter(e => vPart.isActive(e.srcId) || vPart.isActive(e.dstId))
+ case Some(EdgeDirection.Out) =>
+ if (activeFraction < 0.8) {
+ edgePartition.indexIterator(srcVertexID => vPart.isActive(srcVertexID))
+ } else {
+ edgePartition.iterator.filter(e => vPart.isActive(e.srcId))
+ }
+ case Some(EdgeDirection.In) =>
+ edgePartition.iterator.filter(e => vPart.isActive(e.dstId))
+ case _ => // None
+ edgePartition.iterator
+ }
+
+ // Scan edges and run the map function
+ val et = new EdgeTriplet[VD, ED]
+ val mapOutputs = edgeIter.flatMap { e =>
+ et.set(e)
+ if (mapUsesSrcAttr) {
+ et.srcAttr = vPart(e.srcId)
+ }
+ if (mapUsesDstAttr) {
+ et.dstAttr = vPart(e.dstId)
+ }
+ mapFunc(et)
+ }
+ // Note: This doesn't allow users to send messages to arbitrary vertices.
+ vPart.aggregateUsingIndex(mapOutputs, reduceFunc).iterator
+ }
+
+ // do the final reduction reusing the index map
+ vertices.aggregateUsingIndex(preAgg, reduceFunc)
+ } // end of mapReduceTriplets
+
+ override def outerJoinVertices[U: ClassTag, VD2: ClassTag]
+ (other: RDD[(VertexID, U)])
+ (updateF: (VertexID, VD, Option[U]) => VD2): Graph[VD2, ED] =
+ {
+ if (classTag[VD] equals classTag[VD2]) {
+ // updateF preserves type, so we can use incremental replication
+ val newVerts = vertices.leftJoin(other)(updateF)
+ val changedVerts = vertices.asInstanceOf[VertexRDD[VD2]].diff(newVerts)
+ val newReplicatedVertexView = new ReplicatedVertexView[VD2](
+ changedVerts, edges, routingTable,
+ Some(replicatedVertexView.asInstanceOf[ReplicatedVertexView[VD2]]))
+ new GraphImpl(newVerts, edges, routingTable, newReplicatedVertexView)
+ } else {
+ // updateF does not preserve type, so we must re-replicate all vertices
+ val newVerts = vertices.leftJoin(other)(updateF)
+ GraphImpl(newVerts, edges, routingTable)
+ }
+ }
+
+ /** Test whether the closure accesses the attribute with name `attrName`. */
+ private def accessesVertexAttr(closure: AnyRef, attrName: String): Boolean = {
+ try {
+ BytecodeUtils.invokedMethod(closure, classOf[EdgeTriplet[VD, ED]], attrName)
+ } catch {
+ case _: ClassNotFoundException => true // if we don't know, be conservative
+ }
+ }
+} // end of class GraphImpl
+
+
+object GraphImpl {
+
+ def apply[VD: ClassTag, ED: ClassTag](
+ edges: RDD[Edge[ED]],
+ defaultVertexAttr: VD): GraphImpl[VD, ED] =
+ {
+ fromEdgeRDD(createEdgeRDD(edges), defaultVertexAttr)
+ }
+
+ def fromEdgePartitions[VD: ClassTag, ED: ClassTag](
+ edgePartitions: RDD[(PartitionID, EdgePartition[ED])],
+ defaultVertexAttr: VD): GraphImpl[VD, ED] = {
+ fromEdgeRDD(new EdgeRDD(edgePartitions), defaultVertexAttr)
+ }
+
+ def apply[VD: ClassTag, ED: ClassTag](
+ vertices: RDD[(VertexID, VD)],
+ edges: RDD[Edge[ED]],
+ defaultVertexAttr: VD): GraphImpl[VD, ED] =
+ {
+ val edgeRDD = createEdgeRDD(edges).cache()
+
+ // Get the set of all vids
+ val partitioner = Partitioner.defaultPartitioner(vertices)
+ val vPartitioned = vertices.partitionBy(partitioner)
+ val vidsFromEdges = collectVertexIDsFromEdges(edgeRDD, partitioner)
+ val vids = vPartitioned.zipPartitions(vidsFromEdges) { (vertexIter, vidsFromEdgesIter) =>
+ vertexIter.map(_._1) ++ vidsFromEdgesIter.map(_._1)
+ }
+
+ val vertexRDD = VertexRDD(vids, vPartitioned, defaultVertexAttr)
+
+ GraphImpl(vertexRDD, edgeRDD)
+ }
+
+ def apply[VD: ClassTag, ED: ClassTag](
+ vertices: VertexRDD[VD],
+ edges: EdgeRDD[ED]): GraphImpl[VD, ED] = {
+ // Cache RDDs that are referenced multiple times
+ edges.cache()
+
+ GraphImpl(vertices, edges, new RoutingTable(edges, vertices))
+ }
+
+ def apply[VD: ClassTag, ED: ClassTag](
+ vertices: VertexRDD[VD],
+ edges: EdgeRDD[ED],
+ routingTable: RoutingTable): GraphImpl[VD, ED] = {
+ // Cache RDDs that are referenced multiple times. `routingTable` is cached by default, so we
+ // don't cache it explicitly.
+ vertices.cache()
+ edges.cache()
+
+ new GraphImpl(
+ vertices, edges, routingTable, new ReplicatedVertexView(vertices, edges, routingTable))
+ }
+
+ /**
+ * Create the edge RDD, whose columnar layout is much more efficient for Java heap storage than
+ * the straightforward representation (RDD[(VertexID, VertexID, ED)]).
+ *
+ * The edge RDD contains multiple partitions, and each partition contains only one RDD key-value
+ * pair: the key is the partition id, and the value is an EdgePartition object containing all the
+ * edges in a partition.
+ */
+ private def createEdgeRDD[ED: ClassTag](
+ edges: RDD[Edge[ED]]): EdgeRDD[ED] = {
+ val edgePartitions = edges.mapPartitionsWithIndex { (pid, iter) =>
+ val builder = new EdgePartitionBuilder[ED]
+ iter.foreach { e =>
+ builder.add(e.srcId, e.dstId, e.attr)
+ }
+ Iterator((pid, builder.toEdgePartition))
+ }
+ new EdgeRDD(edgePartitions)
+ }
+
+ private def fromEdgeRDD[VD: ClassTag, ED: ClassTag](
+ edges: EdgeRDD[ED],
+ defaultVertexAttr: VD): GraphImpl[VD, ED] = {
+ edges.cache()
+ // Get the set of all vids
+ val vids = collectVertexIDsFromEdges(edges, new HashPartitioner(edges.partitions.size))
+ // Create the VertexRDD.
+ val vertices = VertexRDD(vids.mapValues(x => defaultVertexAttr))
+ GraphImpl(vertices, edges)
+ }
+
+ /** Collects all vids mentioned in edges and partitions them by partitioner. */
+ private def collectVertexIDsFromEdges(
+ edges: EdgeRDD[_],
+ partitioner: Partitioner): RDD[(VertexID, Int)] = {
+ // TODO: Consider doing map side distinct before shuffle.
+ new ShuffledRDD[VertexID, Int, (VertexID, Int)](
+ edges.collectVertexIDs.map(vid => (vid, 0)), partitioner)
+ .setSerializer(classOf[VertexIDMsgSerializer].getName)
+ }
+} // end of object GraphImpl
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
new file mode 100644
index 0000000000..05508ff716
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
@@ -0,0 +1,98 @@
+package org.apache.spark.graphx.impl
+
+import scala.reflect.{classTag, ClassTag}
+
+import org.apache.spark.Partitioner
+import org.apache.spark.graphx.{PartitionID, VertexID}
+import org.apache.spark.rdd.{ShuffledRDD, RDD}
+
+
+private[graphx]
+class VertexBroadcastMsg[@specialized(Int, Long, Double, Boolean) T](
+ @transient var partition: PartitionID,
+ var vid: VertexID,
+ var data: T)
+ extends Product2[PartitionID, (VertexID, T)] with Serializable {
+
+ override def _1 = partition
+
+ override def _2 = (vid, data)
+
+ override def canEqual(that: Any): Boolean = that.isInstanceOf[VertexBroadcastMsg[_]]
+}
+
+
+/**
+ * A message used to send a specific value to a partition.
+ * @param partition index of the target partition.
+ * @param data value to send
+ */
+private[graphx]
+class MessageToPartition[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T](
+ @transient var partition: PartitionID,
+ var data: T)
+ extends Product2[PartitionID, T] with Serializable {
+
+ override def _1 = partition
+
+ override def _2 = data
+
+ override def canEqual(that: Any): Boolean = that.isInstanceOf[MessageToPartition[_]]
+}
+
+
+private[graphx]
+class VertexBroadcastMsgRDDFunctions[T: ClassTag](self: RDD[VertexBroadcastMsg[T]]) {
+ def partitionBy(partitioner: Partitioner): RDD[VertexBroadcastMsg[T]] = {
+ val rdd = new ShuffledRDD[PartitionID, (VertexID, T), VertexBroadcastMsg[T]](self, partitioner)
+
+ // Set a custom serializer if the data is of int or double type.
+ if (classTag[T] == ClassTag.Int) {
+ rdd.setSerializer(classOf[IntVertexBroadcastMsgSerializer].getName)
+ } else if (classTag[T] == ClassTag.Long) {
+ rdd.setSerializer(classOf[LongVertexBroadcastMsgSerializer].getName)
+ } else if (classTag[T] == ClassTag.Double) {
+ rdd.setSerializer(classOf[DoubleVertexBroadcastMsgSerializer].getName)
+ }
+ rdd
+ }
+}
+
+
+private[graphx]
+class MsgRDDFunctions[T: ClassTag](self: RDD[MessageToPartition[T]]) {
+
+ /**
+ * Return a copy of the RDD partitioned using the specified partitioner.
+ */
+ def partitionBy(partitioner: Partitioner): RDD[MessageToPartition[T]] = {
+ new ShuffledRDD[PartitionID, T, MessageToPartition[T]](self, partitioner)
+ }
+
+}
+
+
+private[graphx]
+object MsgRDDFunctions {
+ implicit def rdd2PartitionRDDFunctions[T: ClassTag](rdd: RDD[MessageToPartition[T]]) = {
+ new MsgRDDFunctions(rdd)
+ }
+
+ implicit def rdd2vertexMessageRDDFunctions[T: ClassTag](rdd: RDD[VertexBroadcastMsg[T]]) = {
+ new VertexBroadcastMsgRDDFunctions(rdd)
+ }
+
+ def partitionForAggregation[T: ClassTag](msgs: RDD[(VertexID, T)], partitioner: Partitioner) = {
+ val rdd = new ShuffledRDD[VertexID, T, (VertexID, T)](msgs, partitioner)
+
+ // Set a custom serializer if the data is of int or double type.
+ if (classTag[T] == ClassTag.Int) {
+ rdd.setSerializer(classOf[IntAggMsgSerializer].getName)
+ } else if (classTag[T] == ClassTag.Long) {
+ rdd.setSerializer(classOf[LongAggMsgSerializer].getName)
+ } else if (classTag[T] == ClassTag.Double) {
+ rdd.setSerializer(classOf[DoubleAggMsgSerializer].getName)
+ }
+ rdd
+ }
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala
new file mode 100644
index 0000000000..4ebe0b0267
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala
@@ -0,0 +1,195 @@
+package org.apache.spark.graphx.impl
+
+import scala.reflect.{classTag, ClassTag}
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.collection.{PrimitiveVector, OpenHashSet}
+
+import org.apache.spark.graphx._
+
+/**
+ * A view of the vertices after they are shipped to the join sites specified in
+ * `routingTable`. The resulting view is co-partitioned with `edges`. If `prevViewOpt` is
+ * specified, `updatedVerts` are treated as incremental updates to the previous view. Otherwise, a
+ * fresh view is created.
+ *
+ * The view is always cached (i.e., once it is evaluated, it remains materialized). This avoids
+ * constructing it twice if the user calls graph.triplets followed by graph.mapReduceTriplets, for
+ * example. However, it means iterative algorithms must manually call `Graph.unpersist` on previous
+ * iterations' graphs for best GC performance. See the implementation of
+ * [[org.apache.spark.graphx.Pregel]] for an example.
+ */
+private[impl]
+class ReplicatedVertexView[VD: ClassTag](
+ updatedVerts: VertexRDD[VD],
+ edges: EdgeRDD[_],
+ routingTable: RoutingTable,
+ prevViewOpt: Option[ReplicatedVertexView[VD]] = None) {
+
+ /**
+ * Within each edge partition, create a local map from vid to an index into the attribute
+ * array. Each map contains a superset of the vertices that it will receive, because it stores
+ * vids from both the source and destination of edges. It must always include both source and
+ * destination vids because some operations, such as GraphImpl.mapReduceTriplets, rely on this.
+ */
+ private val localVertexIDMap: RDD[(Int, VertexIdToIndexMap)] = prevViewOpt match {
+ case Some(prevView) =>
+ prevView.localVertexIDMap
+ case None =>
+ edges.partitionsRDD.mapPartitions(_.map {
+ case (pid, epart) =>
+ val vidToIndex = new VertexIdToIndexMap
+ epart.foreach { e =>
+ vidToIndex.add(e.srcId)
+ vidToIndex.add(e.dstId)
+ }
+ (pid, vidToIndex)
+ }, preservesPartitioning = true).cache().setName("ReplicatedVertexView localVertexIDMap")
+ }
+
+ private lazy val bothAttrs: RDD[(PartitionID, VertexPartition[VD])] = create(true, true)
+ private lazy val srcAttrOnly: RDD[(PartitionID, VertexPartition[VD])] = create(true, false)
+ private lazy val dstAttrOnly: RDD[(PartitionID, VertexPartition[VD])] = create(false, true)
+ private lazy val noAttrs: RDD[(PartitionID, VertexPartition[VD])] = create(false, false)
+
+ def unpersist(blocking: Boolean = true): ReplicatedVertexView[VD] = {
+ bothAttrs.unpersist(blocking)
+ srcAttrOnly.unpersist(blocking)
+ dstAttrOnly.unpersist(blocking)
+ noAttrs.unpersist(blocking)
+ // Don't unpersist localVertexIDMap because a future ReplicatedVertexView may be using it
+ // without modification
+ this
+ }
+
+ def get(includeSrc: Boolean, includeDst: Boolean): RDD[(PartitionID, VertexPartition[VD])] = {
+ (includeSrc, includeDst) match {
+ case (true, true) => bothAttrs
+ case (true, false) => srcAttrOnly
+ case (false, true) => dstAttrOnly
+ case (false, false) => noAttrs
+ }
+ }
+
+ def get(
+ includeSrc: Boolean,
+ includeDst: Boolean,
+ actives: VertexRDD[_]): RDD[(PartitionID, VertexPartition[VD])] = {
+ // Ship active sets to edge partitions using the routing table, but ignoring includeSrc and
+ // includeDst. These flags govern attribute shipping, but the activeness of a vertex must be
+ // shipped to all edges mentioning that vertex, regardless of whether the vertex attribute is
+ // also shipped there.
+ val shippedActives = routingTable.get(true, true)
+ .zipPartitions(actives.partitionsRDD)(ReplicatedVertexView.buildActiveBuffer(_, _))
+ .partitionBy(edges.partitioner.get)
+ // Update the view with shippedActives, setting activeness flags in the resulting
+ // VertexPartitions
+ get(includeSrc, includeDst).zipPartitions(shippedActives) { (viewIter, shippedActivesIter) =>
+ val (pid, vPart) = viewIter.next()
+ val newPart = vPart.replaceActives(shippedActivesIter.flatMap(_._2.iterator))
+ Iterator((pid, newPart))
+ }
+ }
+
+ private def create(includeSrc: Boolean, includeDst: Boolean)
+ : RDD[(PartitionID, VertexPartition[VD])] = {
+ val vdTag = classTag[VD]
+
+ // Ship vertex attributes to edge partitions according to the routing table
+ val verts = updatedVerts.partitionsRDD
+ val shippedVerts = routingTable.get(includeSrc, includeDst)
+ .zipPartitions(verts)(ReplicatedVertexView.buildBuffer(_, _)(vdTag))
+ .partitionBy(edges.partitioner.get)
+ // TODO: Consider using a specialized shuffler.
+
+ prevViewOpt match {
+ case Some(prevView) =>
+ // Update prevView with shippedVerts, setting staleness flags in the resulting
+ // VertexPartitions
+ prevView.get(includeSrc, includeDst).zipPartitions(shippedVerts) {
+ (prevViewIter, shippedVertsIter) =>
+ val (pid, prevVPart) = prevViewIter.next()
+ val newVPart = prevVPart.innerJoinKeepLeft(shippedVertsIter.flatMap(_._2.iterator))
+ Iterator((pid, newVPart))
+ }.cache().setName("ReplicatedVertexView delta %s %s".format(includeSrc, includeDst))
+
+ case None =>
+ // Within each edge partition, place the shipped vertex attributes into the correct
+ // locations specified in localVertexIDMap
+ localVertexIDMap.zipPartitions(shippedVerts) { (mapIter, shippedVertsIter) =>
+ val (pid, vidToIndex) = mapIter.next()
+ assert(!mapIter.hasNext)
+ // Populate the vertex array using the vidToIndex map
+ val vertexArray = vdTag.newArray(vidToIndex.capacity)
+ for ((_, block) <- shippedVertsIter) {
+ for (i <- 0 until block.vids.size) {
+ val vid = block.vids(i)
+ val attr = block.attrs(i)
+ val ind = vidToIndex.getPos(vid)
+ vertexArray(ind) = attr
+ }
+ }
+ val newVPart = new VertexPartition(
+ vidToIndex, vertexArray, vidToIndex.getBitSet)(vdTag)
+ Iterator((pid, newVPart))
+ }.cache().setName("ReplicatedVertexView %s %s".format(includeSrc, includeDst))
+ }
+ }
+}
+
+private object ReplicatedVertexView {
+ protected def buildBuffer[VD: ClassTag](
+ pid2vidIter: Iterator[Array[Array[VertexID]]],
+ vertexPartIter: Iterator[VertexPartition[VD]]) = {
+ val pid2vid: Array[Array[VertexID]] = pid2vidIter.next()
+ val vertexPart: VertexPartition[VD] = vertexPartIter.next()
+
+ Iterator.tabulate(pid2vid.size) { pid =>
+ val vidsCandidate = pid2vid(pid)
+ val size = vidsCandidate.length
+ val vids = new PrimitiveVector[VertexID](pid2vid(pid).size)
+ val attrs = new PrimitiveVector[VD](pid2vid(pid).size)
+ var i = 0
+ while (i < size) {
+ val vid = vidsCandidate(i)
+ if (vertexPart.isDefined(vid)) {
+ vids += vid
+ attrs += vertexPart(vid)
+ }
+ i += 1
+ }
+ (pid, new VertexAttributeBlock(vids.trim().array, attrs.trim().array))
+ }
+ }
+
+ protected def buildActiveBuffer(
+ pid2vidIter: Iterator[Array[Array[VertexID]]],
+ activePartIter: Iterator[VertexPartition[_]])
+ : Iterator[(Int, Array[VertexID])] = {
+ val pid2vid: Array[Array[VertexID]] = pid2vidIter.next()
+ val activePart: VertexPartition[_] = activePartIter.next()
+
+ Iterator.tabulate(pid2vid.size) { pid =>
+ val vidsCandidate = pid2vid(pid)
+ val size = vidsCandidate.length
+ val actives = new PrimitiveVector[VertexID](vidsCandidate.size)
+ var i = 0
+ while (i < size) {
+ val vid = vidsCandidate(i)
+ if (activePart.isDefined(vid)) {
+ actives += vid
+ }
+ i += 1
+ }
+ (pid, actives.trim().array)
+ }
+ }
+}
+
+private[graphx]
+class VertexAttributeBlock[VD: ClassTag](val vids: Array[VertexID], val attrs: Array[VD])
+ extends Serializable {
+ def iterator: Iterator[(VertexID, VD)] =
+ (0 until vids.size).iterator.map { i => (vids(i), attrs(i)) }
+}
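A VertexAttributeBlock is the unit that actually travels through the shuffle: buildBuffer packs, for each destination edge partition, the vertex IDs and their attributes into two parallel arrays, and the receiving side walks them back out via iterator. A minimal sketch of that unpacking, runnable only from inside the graphx package and using made-up IDs and attributes (not part of the diff above):

val block = new VertexAttributeBlock[Double](Array(1L, 4L, 7L), Array(0.5, 1.5, 2.5))
block.iterator.foreach { case (vid, attr) => println(vid + " -> " + attr) }
// prints 1 -> 0.5, 4 -> 1.5, 7 -> 2.5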
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTable.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTable.scala
new file mode 100644
index 0000000000..f342fd7437
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTable.scala
@@ -0,0 +1,65 @@
+package org.apache.spark.graphx.impl
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.graphx._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.collection.PrimitiveVector
+
+/**
+ * Stores the locations of edge-partition join sites for each vertex attribute; that is, the routing
+ * information for shipping vertex attributes to edge partitions. This is always cached because it
+ * may be used multiple times in ReplicatedVertexView -- once to ship the vertex attributes and
+ * (possibly) once to ship the active-set information.
+ */
+private[impl]
+class RoutingTable(edges: EdgeRDD[_], vertices: VertexRDD[_]) {
+
+ val bothAttrs: RDD[Array[Array[VertexID]]] = createPid2Vid(true, true)
+ val srcAttrOnly: RDD[Array[Array[VertexID]]] = createPid2Vid(true, false)
+ val dstAttrOnly: RDD[Array[Array[VertexID]]] = createPid2Vid(false, true)
+ val noAttrs: RDD[Array[Array[VertexID]]] = createPid2Vid(false, false)
+
+ def get(includeSrcAttr: Boolean, includeDstAttr: Boolean): RDD[Array[Array[VertexID]]] =
+ (includeSrcAttr, includeDstAttr) match {
+ case (true, true) => bothAttrs
+ case (true, false) => srcAttrOnly
+ case (false, true) => dstAttrOnly
+ case (false, false) => noAttrs
+ }
+
+ private def createPid2Vid(
+ includeSrcAttr: Boolean, includeDstAttr: Boolean): RDD[Array[Array[VertexID]]] = {
+ // Determine which vertices each edge partition needs by creating a mapping from vid to pid.
+ val vid2pid: RDD[(VertexID, PartitionID)] = edges.partitionsRDD.mapPartitions { iter =>
+ val (pid: PartitionID, edgePartition: EdgePartition[_]) = iter.next()
+ val numEdges = edgePartition.size
+ val vSet = new VertexSet
+ if (includeSrcAttr) { // Add src vertices to the set.
+ var i = 0
+ while (i < numEdges) {
+ vSet.add(edgePartition.srcIds(i))
+ i += 1
+ }
+ }
+ if (includeDstAttr) { // Add dst vertices to the set.
+ var i = 0
+ while (i < numEdges) {
+ vSet.add(edgePartition.dstIds(i))
+ i += 1
+ }
+ }
+ vSet.iterator.map { vid => (vid, pid) }
+ }
+
+ val numPartitions = vertices.partitions.size
+ vid2pid.partitionBy(vertices.partitioner.get).mapPartitions { iter =>
+ val pid2vid = Array.fill(numPartitions)(new PrimitiveVector[VertexID])
+ for ((vid, pid) <- iter) {
+ pid2vid(pid) += vid
+ }
+
+ Iterator(pid2vid.map(_.trim().array))
+ }.cache().setName("RoutingTable %s %s".format(includeSrcAttr, includeDstAttr))
+ }
+}
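In other words, the table answers "which vertex IDs does each edge partition need?" once, caches the answer, and lets ReplicatedVertexView reuse it for both attribute shipping and active-set shipping. A hedged usage sketch from inside the impl package, assuming `edges: EdgeRDD[_]` and `vertices: VertexRDD[_]` have already been built as in GraphImpl (not part of the diff above):

val routingTable = new RoutingTable(edges, vertices)
// Routing info for the case where both src and dst attributes must be shipped.
val pid2vid = routingTable.get(includeSrcAttr = true, includeDstAttr = true)
// Within each vertex partition, pid2vid holds one Array[Array[VertexID]] whose
// element at index pid lists the vertex IDs needed by edge partition pid.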
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
new file mode 100644
index 0000000000..cbd6318f33
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
@@ -0,0 +1,395 @@
+package org.apache.spark.graphx.impl
+
+import java.io.{EOFException, InputStream, OutputStream}
+import java.nio.ByteBuffer
+
+import org.apache.spark.SparkConf
+import org.apache.spark.graphx._
+import org.apache.spark.serializer._
+
+private[graphx]
+class VertexIDMsgSerializer(conf: SparkConf) extends Serializer {
+ override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+ override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+ def writeObject[T](t: T) = {
+ val msg = t.asInstanceOf[(VertexID, _)]
+ writeVarLong(msg._1, optimizePositive = false)
+ this
+ }
+ }
+
+ override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+ override def readObject[T](): T = {
+ (readVarLong(optimizePositive = false), null).asInstanceOf[T]
+ }
+ }
+ }
+}
+
+/** A special shuffle serializer for VertexBroadcastMessage[Int]. */
+private[graphx]
+class IntVertexBroadcastMsgSerializer(conf: SparkConf) extends Serializer {
+ override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+ override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+ def writeObject[T](t: T) = {
+ val msg = t.asInstanceOf[VertexBroadcastMsg[Int]]
+ writeVarLong(msg.vid, optimizePositive = false)
+ writeInt(msg.data)
+ this
+ }
+ }
+
+ override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+ override def readObject[T](): T = {
+ val a = readVarLong(optimizePositive = false)
+ val b = readInt()
+ new VertexBroadcastMsg[Int](0, a, b).asInstanceOf[T]
+ }
+ }
+ }
+}
+
+/** A special shuffle serializer for VertexBroadcastMessage[Long]. */
+private[graphx]
+class LongVertexBroadcastMsgSerializer(conf: SparkConf) extends Serializer {
+ override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+ override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+ def writeObject[T](t: T) = {
+ val msg = t.asInstanceOf[VertexBroadcastMsg[Long]]
+ writeVarLong(msg.vid, optimizePositive = false)
+ writeLong(msg.data)
+ this
+ }
+ }
+
+ override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+ override def readObject[T](): T = {
+ val a = readVarLong(optimizePositive = false)
+ val b = readLong()
+ new VertexBroadcastMsg[Long](0, a, b).asInstanceOf[T]
+ }
+ }
+ }
+}
+
+/** A special shuffle serializer for VertexBroadcastMessage[Double]. */
+private[graphx]
+class DoubleVertexBroadcastMsgSerializer(conf: SparkConf) extends Serializer {
+ override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+ override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+ def writeObject[T](t: T) = {
+ val msg = t.asInstanceOf[VertexBroadcastMsg[Double]]
+ writeVarLong(msg.vid, optimizePositive = false)
+ writeDouble(msg.data)
+ this
+ }
+ }
+
+ override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+ def readObject[T](): T = {
+ val a = readVarLong(optimizePositive = false)
+ val b = readDouble()
+ new VertexBroadcastMsg[Double](0, a, b).asInstanceOf[T]
+ }
+ }
+ }
+}
+
+/** A special shuffle serializer for AggregationMessage[Int]. */
+private[graphx]
+class IntAggMsgSerializer(conf: SparkConf) extends Serializer {
+ override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+ override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+ def writeObject[T](t: T) = {
+ val msg = t.asInstanceOf[(VertexID, Int)]
+ writeVarLong(msg._1, optimizePositive = false)
+ writeUnsignedVarInt(msg._2)
+ this
+ }
+ }
+
+ override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+ override def readObject[T](): T = {
+ val a = readVarLong(optimizePositive = false)
+ val b = readUnsignedVarInt()
+ (a, b).asInstanceOf[T]
+ }
+ }
+ }
+}
+
+/** A special shuffle serializer for AggregationMessage[Long]. */
+private[graphx]
+class LongAggMsgSerializer(conf: SparkConf) extends Serializer {
+ override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+ override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+ def writeObject[T](t: T) = {
+ val msg = t.asInstanceOf[(VertexID, Long)]
+ writeVarLong(msg._1, optimizePositive = false)
+ writeVarLong(msg._2, optimizePositive = true)
+ this
+ }
+ }
+
+ override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+ override def readObject[T](): T = {
+ val a = readVarLong(optimizePositive = false)
+ val b = readVarLong(optimizePositive = true)
+ (a, b).asInstanceOf[T]
+ }
+ }
+ }
+}
+
+/** A special shuffle serializer for AggregationMessage[Double]. */
+private[graphx]
+class DoubleAggMsgSerializer(conf: SparkConf) extends Serializer {
+ override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+ override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+ def writeObject[T](t: T) = {
+ val msg = t.asInstanceOf[(VertexID, Double)]
+ writeVarLong(msg._1, optimizePositive = false)
+ writeDouble(msg._2)
+ this
+ }
+ }
+
+ override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+ def readObject[T](): T = {
+ val a = readVarLong(optimizePositive = false)
+ val b = readDouble()
+ (a, b).asInstanceOf[T]
+ }
+ }
+ }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Helper classes to shorten the implementation of those special serializers.
+////////////////////////////////////////////////////////////////////////////////
+
+private[graphx]
+abstract class ShuffleSerializationStream(s: OutputStream) extends SerializationStream {
+ // The implementation should override this one.
+ def writeObject[T](t: T): SerializationStream
+
+ def writeInt(v: Int) {
+ s.write(v >> 24)
+ s.write(v >> 16)
+ s.write(v >> 8)
+ s.write(v)
+ }
+
+ def writeUnsignedVarInt(value: Int) {
+ if ((value >>> 7) == 0) {
+ s.write(value.toInt)
+ } else if ((value >>> 14) == 0) {
+ s.write((value & 0x7F) | 0x80)
+ s.write(value >>> 7)
+ } else if ((value >>> 21) == 0) {
+ s.write((value & 0x7F) | 0x80)
+ s.write(value >>> 7 | 0x80)
+ s.write(value >>> 14)
+ } else if ((value >>> 28) == 0) {
+ s.write((value & 0x7F) | 0x80)
+ s.write(value >>> 7 | 0x80)
+ s.write(value >>> 14 | 0x80)
+ s.write(value >>> 21)
+ } else {
+ s.write((value & 0x7F) | 0x80)
+ s.write(value >>> 7 | 0x80)
+ s.write(value >>> 14 | 0x80)
+ s.write(value >>> 21 | 0x80)
+ s.write(value >>> 28)
+ }
+ }
+
+ def writeVarLong(value: Long, optimizePositive: Boolean) {
+ val v = if (!optimizePositive) (value << 1) ^ (value >> 63) else value
+ if ((v >>> 7) == 0) {
+ s.write(v.toInt)
+ } else if ((v >>> 14) == 0) {
+ s.write(((v & 0x7F) | 0x80).toInt)
+ s.write((v >>> 7).toInt)
+ } else if ((v >>> 21) == 0) {
+ s.write(((v & 0x7F) | 0x80).toInt)
+ s.write((v >>> 7 | 0x80).toInt)
+ s.write((v >>> 14).toInt)
+ } else if ((v >>> 28) == 0) {
+ s.write(((v & 0x7F) | 0x80).toInt)
+ s.write((v >>> 7 | 0x80).toInt)
+ s.write((v >>> 14 | 0x80).toInt)
+ s.write((v >>> 21).toInt)
+ } else if ((v >>> 35) == 0) {
+ s.write(((v & 0x7F) | 0x80).toInt)
+ s.write((v >>> 7 | 0x80).toInt)
+ s.write((v >>> 14 | 0x80).toInt)
+ s.write((v >>> 21 | 0x80).toInt)
+ s.write((v >>> 28).toInt)
+ } else if ((v >>> 42) == 0) {
+ s.write(((v & 0x7F) | 0x80).toInt)
+ s.write((v >>> 7 | 0x80).toInt)
+ s.write((v >>> 14 | 0x80).toInt)
+ s.write((v >>> 21 | 0x80).toInt)
+ s.write((v >>> 28 | 0x80).toInt)
+ s.write((v >>> 35).toInt)
+ } else if ((v >>> 49) == 0) {
+ s.write(((v & 0x7F) | 0x80).toInt)
+ s.write((v >>> 7 | 0x80).toInt)
+ s.write((v >>> 14 | 0x80).toInt)
+ s.write((v >>> 21 | 0x80).toInt)
+ s.write((v >>> 28 | 0x80).toInt)
+ s.write((v >>> 35 | 0x80).toInt)
+ s.write((v >>> 42).toInt)
+ } else if ((v >>> 56) == 0) {
+ s.write(((v & 0x7F) | 0x80).toInt)
+ s.write((v >>> 7 | 0x80).toInt)
+ s.write((v >>> 14 | 0x80).toInt)
+ s.write((v >>> 21 | 0x80).toInt)
+ s.write((v >>> 28 | 0x80).toInt)
+ s.write((v >>> 35 | 0x80).toInt)
+ s.write((v >>> 42 | 0x80).toInt)
+ s.write((v >>> 49).toInt)
+ } else {
+ s.write(((v & 0x7F) | 0x80).toInt)
+ s.write((v >>> 7 | 0x80).toInt)
+ s.write((v >>> 14 | 0x80).toInt)
+ s.write((v >>> 21 | 0x80).toInt)
+ s.write((v >>> 28 | 0x80).toInt)
+ s.write((v >>> 35 | 0x80).toInt)
+ s.write((v >>> 42 | 0x80).toInt)
+ s.write((v >>> 49 | 0x80).toInt)
+ s.write((v >>> 56).toInt)
+ }
+ }
+
+ def writeLong(v: Long) {
+ s.write((v >>> 56).toInt)
+ s.write((v >>> 48).toInt)
+ s.write((v >>> 40).toInt)
+ s.write((v >>> 32).toInt)
+ s.write((v >>> 24).toInt)
+ s.write((v >>> 16).toInt)
+ s.write((v >>> 8).toInt)
+ s.write(v.toInt)
+ }
+
+ //def writeDouble(v: Double): Unit = writeUnsignedVarLong(java.lang.Double.doubleToLongBits(v))
+ def writeDouble(v: Double): Unit = writeLong(java.lang.Double.doubleToLongBits(v))
+
+ override def flush(): Unit = s.flush()
+
+ override def close(): Unit = s.close()
+}
+
+private[graphx]
+abstract class ShuffleDeserializationStream(s: InputStream) extends DeserializationStream {
+ // The implementation should override this one.
+ def readObject[T](): T
+
+ def readInt(): Int = {
+ val first = s.read()
+ if (first < 0) throw new EOFException
+ (first & 0xFF) << 24 | (s.read() & 0xFF) << 16 | (s.read() & 0xFF) << 8 | (s.read() & 0xFF)
+ }
+
+ def readUnsignedVarInt(): Int = {
+ var value: Int = 0
+ var i: Int = 0
+ def readOrThrow(): Int = {
+ val in = s.read()
+ if (in < 0) throw new EOFException
+ in & 0xFF
+ }
+ var b: Int = readOrThrow()
+ while ((b & 0x80) != 0) {
+ value |= (b & 0x7F) << i
+ i += 7
+ if (i > 35) throw new IllegalArgumentException("Variable length quantity is too long")
+ b = readOrThrow()
+ }
+ value | (b << i)
+ }
+
+ def readVarLong(optimizePositive: Boolean): Long = {
+ def readOrThrow(): Int = {
+ val in = s.read()
+ if (in < 0) throw new EOFException
+ in & 0xFF
+ }
+ var b = readOrThrow()
+ var ret: Long = b & 0x7F
+ if ((b & 0x80) != 0) {
+ b = readOrThrow()
+ ret |= (b & 0x7F) << 7
+ if ((b & 0x80) != 0) {
+ b = readOrThrow()
+ ret |= (b & 0x7F) << 14
+ if ((b & 0x80) != 0) {
+ b = readOrThrow()
+ ret |= (b & 0x7F) << 21
+ if ((b & 0x80) != 0) {
+ b = readOrThrow()
+ ret |= (b & 0x7F).toLong << 28
+ if ((b & 0x80) != 0) {
+ b = readOrThrow()
+ ret |= (b & 0x7F).toLong << 35
+ if ((b & 0x80) != 0) {
+ b = readOrThrow()
+ ret |= (b & 0x7F).toLong << 42
+ if ((b & 0x80) != 0) {
+ b = readOrThrow()
+ ret |= (b & 0x7F).toLong << 49
+ if ((b & 0x80) != 0) {
+ b = readOrThrow()
+ ret |= b.toLong << 56
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ if (!optimizePositive) (ret >>> 1) ^ -(ret & 1) else ret
+ }
+
+ def readLong(): Long = {
+ val first = s.read()
+ if (first < 0) throw new EOFException()
+ (first.toLong << 56) |
+ (s.read() & 0xFF).toLong << 48 |
+ (s.read() & 0xFF).toLong << 40 |
+ (s.read() & 0xFF).toLong << 32 |
+ (s.read() & 0xFF).toLong << 24 |
+ (s.read() & 0xFF) << 16 |
+ (s.read() & 0xFF) << 8 |
+ (s.read() & 0xFF)
+ }
+
+ //def readDouble(): Double = java.lang.Double.longBitsToDouble(readUnsignedVarLong())
+ def readDouble(): Double = java.lang.Double.longBitsToDouble(readLong())
+
+ override def close(): Unit = s.close()
+}
+
+private[graphx] sealed trait ShuffleSerializerInstance extends SerializerInstance {
+
+ override def serialize[T](t: T): ByteBuffer = throw new UnsupportedOperationException
+
+ override def deserialize[T](bytes: ByteBuffer): T = throw new UnsupportedOperationException
+
+ override def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T =
+ throw new UnsupportedOperationException
+
+ // The implementation should override the following two.
+ override def serializeStream(s: OutputStream): SerializationStream
+ override def deserializeStream(s: InputStream): DeserializationStream
+}
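The writeVarLong/readVarLong pair above is a standard zigzag-plus-7-bits-per-byte varint, so small vertex IDs cost only a byte or two on the wire. A hedged round-trip sketch through one of the concrete serializers, usable only from inside the graphx package since these classes are package-private; the streams and values are illustrative (not part of the diff above):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.spark.SparkConf

val ser = new IntAggMsgSerializer(new SparkConf()).newInstance()
val buf = new ByteArrayOutputStream()
val out = ser.serializeStream(buf)
out.writeObject((42L, 7))  // a (VertexID, Int) aggregation message
out.flush()
val in = ser.deserializeStream(new ByteArrayInputStream(buf.toByteArray))
val msg = in.readObject[(Long, Int)]()  // expected to read back (42L, 7)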
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala
new file mode 100644
index 0000000000..f97ff75fb2
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala
@@ -0,0 +1,261 @@
+package org.apache.spark.graphx.impl
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.Logging
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap
+import org.apache.spark.util.collection.BitSet
+
+private[graphx] object VertexPartition {
+
+ def apply[VD: ClassTag](iter: Iterator[(VertexID, VD)]): VertexPartition[VD] = {
+ val map = new PrimitiveKeyOpenHashMap[VertexID, VD]
+ iter.foreach { case (k, v) =>
+ map(k) = v
+ }
+ new VertexPartition(map.keySet, map._values, map.keySet.getBitSet)
+ }
+
+ def apply[VD: ClassTag](iter: Iterator[(VertexID, VD)], mergeFunc: (VD, VD) => VD)
+ : VertexPartition[VD] =
+ {
+ val map = new PrimitiveKeyOpenHashMap[VertexID, VD]
+ iter.foreach { case (k, v) =>
+ map.setMerge(k, v, mergeFunc)
+ }
+ new VertexPartition(map.keySet, map._values, map.keySet.getBitSet)
+ }
+}
+
+
+private[graphx]
+class VertexPartition[@specialized(Long, Int, Double) VD: ClassTag](
+ val index: VertexIdToIndexMap,
+ val values: Array[VD],
+ val mask: BitSet,
+ /** A set of vids of active vertices. May contain vids not in index due to join rewrite. */
+ private val activeSet: Option[VertexSet] = None)
+ extends Logging {
+
+ val capacity: Int = index.capacity
+
+ def size: Int = mask.cardinality()
+
+ /** Return the vertex attribute for the given vertex ID. */
+ def apply(vid: VertexID): VD = values(index.getPos(vid))
+
+ def isDefined(vid: VertexID): Boolean = {
+ val pos = index.getPos(vid)
+ pos >= 0 && mask.get(pos)
+ }
+
+ /** Look up vid in activeSet, throwing an exception if it is None. */
+ def isActive(vid: VertexID): Boolean = {
+ activeSet.get.contains(vid)
+ }
+
+ /** The number of active vertices, if any exist. */
+ def numActives: Option[Int] = activeSet.map(_.size)
+
+ /**
+ * Pass each vertex attribute along with the vertex id through a map
+ * function and retain the original RDD's partitioning and index.
+ *
+ * @tparam VD2 the type returned by the map function
+ *
+ * @param f the function applied to each vertex id and vertex
+ * attribute in the RDD
+ *
+ * @return a new VertexPartition with values obtained by applying `f` to
+ * each of the entries in the original VertexRDD. The resulting
+ * VertexPartition retains the same index.
+ */
+ def map[VD2: ClassTag](f: (VertexID, VD) => VD2): VertexPartition[VD2] = {
+ // Construct a view of the map transformation
+ val newValues = new Array[VD2](capacity)
+ var i = mask.nextSetBit(0)
+ while (i >= 0) {
+ newValues(i) = f(index.getValue(i), values(i))
+ i = mask.nextSetBit(i + 1)
+ }
+ new VertexPartition[VD2](index, newValues, mask)
+ }
+
+ /**
+ * Restrict the vertex set to the set of vertices satisfying the given predicate.
+ *
+ * @param pred the user defined predicate
+ *
+ * @note The vertex set preserves the original index structure which means that the returned
+ * RDD can be easily joined with the original vertex-set. Furthermore, the filter only
+ * modifies the bitmap index and so no new values are allocated.
+ */
+ def filter(pred: (VertexID, VD) => Boolean): VertexPartition[VD] = {
+ // Allocate the array to store the results into
+ val newMask = new BitSet(capacity)
+ // Iterate over the active bits in the old mask and evaluate the predicate
+ var i = mask.nextSetBit(0)
+ while (i >= 0) {
+ if (pred(index.getValue(i), values(i))) {
+ newMask.set(i)
+ }
+ i = mask.nextSetBit(i + 1)
+ }
+ new VertexPartition(index, values, newMask)
+ }
+
+ /**
+ * Hides vertices that are the same between this and other. For vertices that are different, keeps
+ * the values from `other`. The indices of `this` and `other` must be the same.
+ */
+ def diff(other: VertexPartition[VD]): VertexPartition[VD] = {
+ if (index != other.index) {
+ logWarning("Diffing two VertexPartitions with different indexes is slow.")
+ diff(createUsingIndex(other.iterator))
+ } else {
+ val newMask = mask & other.mask
+ var i = newMask.nextSetBit(0)
+ while (i >= 0) {
+ if (values(i) == other.values(i)) {
+ newMask.unset(i)
+ }
+ i = newMask.nextSetBit(i + 1)
+ }
+ new VertexPartition(index, other.values, newMask)
+ }
+ }
+
+ /** Left outer join another VertexPartition. */
+ def leftJoin[VD2: ClassTag, VD3: ClassTag]
+ (other: VertexPartition[VD2])
+ (f: (VertexID, VD, Option[VD2]) => VD3): VertexPartition[VD3] = {
+ if (index != other.index) {
+ logWarning("Joining two VertexPartitions with different indexes is slow.")
+ leftJoin(createUsingIndex(other.iterator))(f)
+ } else {
+ val newValues = new Array[VD3](capacity)
+
+ var i = mask.nextSetBit(0)
+ while (i >= 0) {
+ val otherV: Option[VD2] = if (other.mask.get(i)) Some(other.values(i)) else None
+ newValues(i) = f(index.getValue(i), values(i), otherV)
+ i = mask.nextSetBit(i + 1)
+ }
+ new VertexPartition(index, newValues, mask)
+ }
+ }
+
+ /** Left outer join another iterator of messages. */
+ def leftJoin[VD2: ClassTag, VD3: ClassTag]
+ (other: Iterator[(VertexID, VD2)])
+ (f: (VertexID, VD, Option[VD2]) => VD3): VertexPartition[VD3] = {
+ leftJoin(createUsingIndex(other))(f)
+ }
+
+ /** Inner join another VertexPartition. */
+ def innerJoin[U: ClassTag, VD2: ClassTag](other: VertexPartition[U])
+ (f: (VertexID, VD, U) => VD2): VertexPartition[VD2] = {
+ if (index != other.index) {
+ logWarning("Joining two VertexPartitions with different indexes is slow.")
+ innerJoin(createUsingIndex(other.iterator))(f)
+ } else {
+ val newMask = mask & other.mask
+ val newValues = new Array[VD2](capacity)
+ var i = newMask.nextSetBit(0)
+ while (i >= 0) {
+ newValues(i) = f(index.getValue(i), values(i), other.values(i))
+ i = newMask.nextSetBit(i + 1)
+ }
+ new VertexPartition(index, newValues, newMask)
+ }
+ }
+
+ /**
+ * Inner join an iterator of messages.
+ */
+ def innerJoin[U: ClassTag, VD2: ClassTag]
+ (iter: Iterator[Product2[VertexID, U]])
+ (f: (VertexID, VD, U) => VD2): VertexPartition[VD2] = {
+ innerJoin(createUsingIndex(iter))(f)
+ }
+
+ /**
+ * Similar effect as aggregateUsingIndex((a, b) => a)
+ */
+ def createUsingIndex[VD2: ClassTag](iter: Iterator[Product2[VertexID, VD2]])
+ : VertexPartition[VD2] = {
+ val newMask = new BitSet(capacity)
+ val newValues = new Array[VD2](capacity)
+ iter.foreach { case (vid, vdata) =>
+ val pos = index.getPos(vid)
+ if (pos >= 0) {
+ newMask.set(pos)
+ newValues(pos) = vdata
+ }
+ }
+ new VertexPartition[VD2](index, newValues, newMask)
+ }
+
+ /**
+ * Similar to innerJoin, but vertices from the left side that don't appear in iter will remain in
+ * the partition, hidden by the bitmask.
+ */
+ def innerJoinKeepLeft(iter: Iterator[Product2[VertexID, VD]]): VertexPartition[VD] = {
+ val newMask = new BitSet(capacity)
+ val newValues = new Array[VD](capacity)
+ System.arraycopy(values, 0, newValues, 0, newValues.length)
+ iter.foreach { case (vid, vdata) =>
+ val pos = index.getPos(vid)
+ if (pos >= 0) {
+ newMask.set(pos)
+ newValues(pos) = vdata
+ }
+ }
+ new VertexPartition(index, newValues, newMask)
+ }
+
+ def aggregateUsingIndex[VD2: ClassTag](
+ iter: Iterator[Product2[VertexID, VD2]],
+ reduceFunc: (VD2, VD2) => VD2): VertexPartition[VD2] = {
+ val newMask = new BitSet(capacity)
+ val newValues = new Array[VD2](capacity)
+ iter.foreach { product =>
+ val vid = product._1
+ val vdata = product._2
+ val pos = index.getPos(vid)
+ if (pos >= 0) {
+ if (newMask.get(pos)) {
+ newValues(pos) = reduceFunc(newValues(pos), vdata)
+ } else { // otherwise just store the new value
+ newMask.set(pos)
+ newValues(pos) = vdata
+ }
+ }
+ }
+ new VertexPartition[VD2](index, newValues, newMask)
+ }
+
+ def replaceActives(iter: Iterator[VertexID]): VertexPartition[VD] = {
+ val newActiveSet = new VertexSet
+ iter.foreach(newActiveSet.add(_))
+ new VertexPartition(index, values, mask, Some(newActiveSet))
+ }
+
+ /**
+ * Construct a new VertexPartition whose index contains only the vertices in the mask.
+ */
+ def reindex(): VertexPartition[VD] = {
+ val hashMap = new PrimitiveKeyOpenHashMap[VertexID, VD]
+ val arbitraryMerge = (a: VD, b: VD) => a
+ for ((k, v) <- this.iterator) {
+ hashMap.setMerge(k, v, arbitraryMerge)
+ }
+ new VertexPartition(hashMap.keySet, hashMap._values, hashMap.keySet.getBitSet)
+ }
+
+ def iterator: Iterator[(VertexID, VD)] =
+ mask.iterator.map(ind => (index.getValue(ind), values(ind)))
+
+ def vidIterator: Iterator[VertexID] = mask.iterator.map(ind => index.getValue(ind))
+}
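VertexPartition is the per-partition core of VertexRDD: a shared open hash index, a flat value array, and a bitmask saying which slots are live. Because the index is reused across map/filter/join, those operations stay cheap. A small illustrative sketch from inside the graphx package, with made-up IDs and attributes (not part of the diff above):

val part = VertexPartition(Iterator((1L, "a"), (2L, "bb"), (3L, "ccc")))
val lengths = part.map((vid, attr) => attr.length)       // same index, new value array
val evens = lengths.filter((vid, len) => vid % 2 == 0)   // same values, new bitmask
evens.iterator.foreach(println)                          // prints (2,2)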
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/package.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/package.scala
new file mode 100644
index 0000000000..cfc3281b64
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/package.scala
@@ -0,0 +1,7 @@
+package org.apache.spark.graphx
+
+import org.apache.spark.util.collection.OpenHashSet
+
+package object impl {
+ private[graphx] type VertexIdToIndexMap = OpenHashSet[VertexID]
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala
new file mode 100644
index 0000000000..e0aff5644e
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala
@@ -0,0 +1,136 @@
+package org.apache.spark.graphx.lib
+
+import org.apache.spark._
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.PartitionStrategy._
+
+/**
+ * Driver program for running graph algorithms.
+ */
+object Analytics extends Logging {
+
+ def main(args: Array[String]) = {
+ val host = args(0)
+ val taskType = args(1)
+ val fname = args(2)
+ val options = args.drop(3).map { arg =>
+ arg.dropWhile(_ == '-').split('=') match {
+ case Array(opt, v) => (opt -> v)
+ case _ => throw new IllegalArgumentException("Invalid argument: " + arg)
+ }
+ }
+
+ def pickPartitioner(v: String): PartitionStrategy = {
+ // TODO: Use reflection rather than listing all the partitioning strategies here.
+ v match {
+ case "RandomVertexCut" => RandomVertexCut
+ case "EdgePartition1D" => EdgePartition1D
+ case "EdgePartition2D" => EdgePartition2D
+ case "CanonicalRandomVertexCut" => CanonicalRandomVertexCut
+ case _ => throw new IllegalArgumentException("Invalid PartitionStrategy: " + v)
+ }
+ }
+
+ val conf = new SparkConf()
+ .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
+
+ taskType match {
+ case "pagerank" =>
+ var tol: Float = 0.001F
+ var outFname = ""
+ var numEPart = 4
+ var partitionStrategy: Option[PartitionStrategy] = None
+
+ options.foreach{
+ case ("tol", v) => tol = v.toFloat
+ case ("output", v) => outFname = v
+ case ("numEPart", v) => numEPart = v.toInt
+ case ("partStrategy", v) => partitionStrategy = Some(pickPartitioner(v))
+ case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt)
+ }
+
+ println("======================================")
+ println("| PageRank |")
+ println("======================================")
+
+ val sc = new SparkContext(host, "PageRank(" + fname + ")", conf)
+
+ val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname,
+ minEdgePartitions = numEPart).cache()
+ val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_))
+
+ println("GRAPHX: Number of vertices " + graph.vertices.count)
+ println("GRAPHX: Number of edges " + graph.edges.count)
+
+ val pr = graph.pageRank(tol).vertices.cache()
+
+ println("GRAPHX: Total rank: " + pr.map(_._2).reduce(_+_))
+
+ if (!outFname.isEmpty) {
+ logWarning("Saving pageranks of pages to " + outFname)
+ pr.map{case (id, r) => id + "\t" + r}.saveAsTextFile(outFname)
+ }
+
+ sc.stop()
+
+ case "cc" =>
+ var numIter = Int.MaxValue
+ var numVPart = 4
+ var numEPart = 4
+ var isDynamic = false
+ var partitionStrategy: Option[PartitionStrategy] = None
+
+ options.foreach{
+ case ("numIter", v) => numIter = v.toInt
+ case ("dynamic", v) => isDynamic = v.toBoolean
+ case ("numEPart", v) => numEPart = v.toInt
+ case ("numVPart", v) => numVPart = v.toInt
+ case ("partStrategy", v) => partitionStrategy = Some(pickPartitioner(v))
+ case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt)
+ }
+
+ if (!isDynamic && numIter == Int.MaxValue) {
+ println("Set number of iterations!")
+ sys.exit(1)
+ }
+ println("======================================")
+ println("| Connected Components |")
+ println("======================================")
+
+ val sc = new SparkContext(host, "ConnectedComponents(" + fname + ")", conf)
+ val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname,
+ minEdgePartitions = numEPart).cache()
+ val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_))
+
+ val cc = ConnectedComponents.run(graph)
+ println("Components: " + cc.vertices.map{ case (vid,data) => data}.distinct())
+ sc.stop()
+
+ case "triangles" =>
+ var numEPart = 4
+ // TriangleCount requires the graph to be partitioned
+ var partitionStrategy: PartitionStrategy = RandomVertexCut
+
+ options.foreach{
+ case ("numEPart", v) => numEPart = v.toInt
+ case ("partStrategy", v) => partitionStrategy = pickPartitioner(v)
+ case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt)
+ }
+ println("======================================")
+ println("| Triangle Count |")
+ println("======================================")
+ val sc = new SparkContext(host, "TriangleCount(" + fname + ")", conf)
+ val graph = GraphLoader.edgeListFile(sc, fname, canonicalOrientation = true,
+ minEdgePartitions = numEPart).partitionBy(partitionStrategy).cache()
+ val triangles = TriangleCount.run(graph)
+ println("Triangles: " + triangles.vertices.map {
+ case (vid,data) => data.toLong
+ }.reduce(_ + _) / 3)
+ sc.stop()
+
+ case _ =>
+ println("Invalid task type.")
+ }
+ }
+}
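Because main only parses three positional arguments plus `--key=value` options, the driver can also be exercised directly from Scala. A hedged sketch; the master string and edge-list path are placeholders (not part of the diff above):

Analytics.main(Array("local[4]", "pagerank", "followers.txt", "--tol=0.001", "--numEPart=8"))
Analytics.main(Array("local[4]", "cc", "followers.txt", "--numIter=10"))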
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
new file mode 100644
index 0000000000..4d1f5e74df
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
@@ -0,0 +1,38 @@
+package org.apache.spark.graphx.lib
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.graphx._
+
+/** Connected components algorithm. */
+object ConnectedComponents {
+ /**
+ * Compute the connected component membership of each vertex and return a graph with the vertex
+ * value containing the lowest vertex id in the connected component containing that vertex.
+ *
+ * @tparam VD the vertex attribute type (discarded in the computation)
+ * @tparam ED the edge attribute type (preserved in the computation)
+ *
+ * @param graph the graph for which to compute the connected components
+ *
+ * @return a graph with vertex attributes containing the smallest vertex id in each
+ * connected component
+ */
+ def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[VertexID, ED] = {
+ val ccGraph = graph.mapVertices { case (vid, _) => vid }
+ def sendMessage(edge: EdgeTriplet[VertexID, ED]) = {
+ if (edge.srcAttr < edge.dstAttr) {
+ Iterator((edge.dstId, edge.srcAttr))
+ } else if (edge.srcAttr > edge.dstAttr) {
+ Iterator((edge.srcId, edge.dstAttr))
+ } else {
+ Iterator.empty
+ }
+ }
+ val initialMessage = Long.MaxValue
+ Pregel(ccGraph, initialMessage, activeDirection = EdgeDirection.Either)(
+ vprog = (id, attr, msg) => math.min(attr, msg),
+ sendMsg = sendMessage,
+ mergeMsg = (a, b) => math.min(a, b))
+ } // end of connectedComponents
+}
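A hedged end-to-end sketch of the algorithm above; the edge list encodes two components, {1, 2, 3} and {10, 11}, and the context setup is illustrative (not part of the diff above):

import org.apache.spark.SparkContext
import org.apache.spark.graphx._
import org.apache.spark.graphx.lib.ConnectedComponents

val sc = new SparkContext("local", "cc-example")
val edges = sc.parallelize(Seq(Edge(1L, 2L, 1), Edge(2L, 3L, 1), Edge(10L, 11L, 1)))
val graph = Graph.fromEdges(edges, 0)
val cc = ConnectedComponents.run(graph)
cc.vertices.collect().foreach(println)  // each vertex paired with the lowest id in its component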
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
new file mode 100644
index 0000000000..2f4d6d6864
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
@@ -0,0 +1,147 @@
+package org.apache.spark.graphx.lib
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.Logging
+import org.apache.spark.graphx._
+
+/**
+ * PageRank algorithm implementation. There are two implementations of PageRank.
+ *
+ * The first implementation uses the [[Pregel]] interface and runs PageRank for a fixed number
+ * of iterations:
+ * {{{
+ * var PR = Array.fill(n)( 1.0 )
+ * val oldPR = Array.fill(n)( 1.0 )
+ * for( iter <- 0 until numIter ) {
+ * swap(oldPR, PR)
+ * for( i <- 0 until n ) {
+ * PR[i] = alpha + (1 - alpha) * inNbrs[i].map(j => oldPR[j] / outDeg[j]).sum
+ * }
+ * }
+ * }}}
+ *
+ * The second implementation uses the standalone [[Graph]] interface and runs PageRank until
+ * convergence:
+ *
+ * {{{
+ * var PR = Array.fill(n)( 1.0 )
+ * val oldPR = Array.fill(n)( 0.0 )
+ * while( max(abs(PR - oldPr)) > tol ) {
+ * swap(oldPR, PR)
+ * for( i <- 0 until n if abs(PR[i] - oldPR[i]) > tol ) {
+ * PR[i] = alpha + (1 - alpha) * inNbrs[i].map(j => oldPR[j] / outDeg[j]).sum
+ * }
+ * }
+ * }}}
+ *
+ * `alpha` is the random reset probability (typically 0.15), `inNbrs[i]` is the set of
+ * neighbors which link to `i` and `outDeg[j]` is the out degree of vertex `j`.
+ *
+ * Note that this is not the "normalized" PageRank and as a consequence pages that have no
+ * inlinks will have a PageRank of alpha.
+ */
+object PageRank extends Logging {
+
+ /**
+ * Run PageRank for a fixed number of iterations returning a graph
+ * with vertex attributes containing the PageRank and edge
+ * attributes the normalized edge weight.
+ *
+ * @tparam VD the original vertex attribute (not used)
+ * @tparam ED the original edge attribute (not used)
+ *
+ * @param graph the graph on which to compute PageRank
+ * @param numIter the number of iterations of PageRank to run
+ * @param resetProb the random reset probability (alpha)
+ *
+ * @return the graph with each vertex containing the PageRank and each edge
+ * containing the normalized weight.
+ *
+ */
+ def run[VD: ClassTag, ED: ClassTag](
+ graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15): Graph[Double, Double] =
+ {
+ // Initialize the pagerankGraph with each edge attribute having
+ // weight 1/outDegree and each vertex with attribute 1.0.
+ val pagerankGraph: Graph[Double, Double] = graph
+ // Associate the degree with each vertex
+ .outerJoinVertices(graph.outDegrees) { (vid, vdata, deg) => deg.getOrElse(0) }
+ // Set the weight on the edges based on the degree
+ .mapTriplets( e => 1.0 / e.srcAttr )
+ // Set the vertex attributes to the initial pagerank values
+ .mapVertices( (id, attr) => 1.0 )
+ .cache()
+
+ // Define the three functions needed to implement PageRank in the GraphX
+ // version of Pregel
+ def vertexProgram(id: VertexID, attr: Double, msgSum: Double): Double =
+ resetProb + (1.0 - resetProb) * msgSum
+ def sendMessage(edge: EdgeTriplet[Double, Double]) =
+ Iterator((edge.dstId, edge.srcAttr * edge.attr))
+ def messageCombiner(a: Double, b: Double): Double = a + b
+ // The initial message received by all vertices in PageRank
+ val initialMessage = 0.0
+
+ // Execute pregel for a fixed number of iterations.
+ Pregel(pagerankGraph, initialMessage, numIter, activeDirection = EdgeDirection.Out)(
+ vertexProgram, sendMessage, messageCombiner)
+ }
+
+ /**
+ * Run a dynamic version of PageRank returning a graph with vertex attributes containing the
+ * PageRank and edge attributes containing the normalized edge weight.
+ *
+ * @tparam VD the original vertex attribute (not used)
+ * @tparam ED the original edge attribute (not used)
+ *
+ * @param graph the graph on which to compute PageRank
+ * @param tol the tolerance allowed at convergence (smaller => more accurate).
+ * @param resetProb the random reset probability (alpha)
+ *
+ * @return the graph with each vertex containing the PageRank and each edge
+ * containing the normalized weight.
+ */
+ def runUntilConvergence[VD: ClassTag, ED: ClassTag](
+ graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15): Graph[Double, Double] =
+ {
+ // Initialize the pagerankGraph with each edge attribute
+ // having weight 1/outDegree and each vertex with attribute 1.0.
+ val pagerankGraph: Graph[(Double, Double), Double] = graph
+ // Associate the degree with each vertex
+ .outerJoinVertices(graph.outDegrees) {
+ (vid, vdata, deg) => deg.getOrElse(0)
+ }
+ // Set the weight on the edges based on the degree
+ .mapTriplets( e => 1.0 / e.srcAttr )
+ // Set the vertex attributes to (initialPR, delta = 0)
+ .mapVertices( (id, attr) => (0.0, 0.0) )
+ .cache()
+
+ // Define the three functions needed to implement PageRank in the GraphX
+ // version of Pregel
+ def vertexProgram(id: VertexID, attr: (Double, Double), msgSum: Double): (Double, Double) = {
+ val (oldPR, lastDelta) = attr
+ val newPR = oldPR + (1.0 - resetProb) * msgSum
+ (newPR, newPR - oldPR)
+ }
+
+ def sendMessage(edge: EdgeTriplet[(Double, Double), Double]) = {
+ if (edge.srcAttr._2 > tol) {
+ Iterator((edge.dstId, edge.srcAttr._2 * edge.attr))
+ } else {
+ Iterator.empty
+ }
+ }
+
+ def messageCombiner(a: Double, b: Double): Double = a + b
+
+ // The initial message received by all vertices in PageRank
+ val initialMessage = resetProb / (1.0 - resetProb)
+
+ // Execute a dynamic version of Pregel.
+ Pregel(pagerankGraph, initialMessage, activeDirection = EdgeDirection.Out)(
+ vertexProgram, sendMessage, messageCombiner)
+ .mapVertices((vid, attr) => attr._1)
+ } // end of deltaPageRank
+}
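A hedged sketch of calling both entry points; the edge-list path is a placeholder and the defaults of GraphLoader.edgeListFile are assumed (not part of the diff above):

import org.apache.spark.SparkContext
import org.apache.spark.graphx.GraphLoader
import org.apache.spark.graphx.lib.PageRank

val sc = new SparkContext("local", "pagerank-example")
val graph = GraphLoader.edgeListFile(sc, "followers.txt").cache()
val ranksFixed = PageRank.run(graph, numIter = 20)                // fixed iteration count
val ranksConv = PageRank.runUntilConvergence(graph, tol = 0.001)  // stop once deltas drop below tol
println("Total rank: " + ranksConv.vertices.map(_._2).reduce(_ + _))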
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
new file mode 100644
index 0000000000..ba6517e012
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
@@ -0,0 +1,138 @@
+package org.apache.spark.graphx.lib
+
+import scala.util.Random
+import org.apache.commons.math.linear._
+import org.apache.spark.rdd._
+import org.apache.spark.graphx._
+
+/** Implementation of SVD++ algorithm. */
+object SVDPlusPlus {
+
+ /** Configuration parameters for SVDPlusPlus. */
+ class Conf(
+ var rank: Int,
+ var maxIters: Int,
+ var minVal: Double,
+ var maxVal: Double,
+ var gamma1: Double,
+ var gamma2: Double,
+ var gamma6: Double,
+ var gamma7: Double)
+ extends Serializable
+
+ /**
+ * Implement SVD++ based on "Factorization Meets the Neighborhood:
+ * a Multifaceted Collaborative Filtering Model",
+ * available at [[http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf]].
+ *
+ * The prediction rule is rui = u + bu + bi + qi*(pu + |N(u)|^(-0.5)*sum(y)),
+ * see the details on page 6.
+ *
+ * @param edges edges for constructing the graph
+ *
+ * @param conf SVDPlusPlus parameters
+ *
+ * @return a graph with vertex attributes containing the trained model
+ */
+ def run(edges: RDD[Edge[Double]], conf: Conf)
+ : (Graph[(RealVector, RealVector, Double, Double), Double], Double) =
+ {
+ // Generate default vertex attribute
+ def defaultF(rank: Int): (RealVector, RealVector, Double, Double) = {
+ val v1 = new ArrayRealVector(rank)
+ val v2 = new ArrayRealVector(rank)
+ for (i <- 0 until rank) {
+ v1.setEntry(i, Random.nextDouble())
+ v2.setEntry(i, Random.nextDouble())
+ }
+ (v1, v2, 0.0, 0.0)
+ }
+
+ // calculate global rating mean
+ edges.cache()
+ val (rs, rc) = edges.map(e => (e.attr, 1L)).reduce((a, b) => (a._1 + b._1, a._2 + b._2))
+ val u = rs / rc
+
+ // construct graph
+ var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache()
+
+ // Calculate initial bias and norm
+ val t0 = g.mapReduceTriplets(
+ et => Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr))),
+ (g1: (Long, Double), g2: (Long, Double)) => (g1._1 + g2._1, g1._2 + g2._2))
+
+ g = g.outerJoinVertices(t0) {
+ (vid: VertexID, vd: (RealVector, RealVector, Double, Double), msg: Option[(Long, Double)]) =>
+ (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1))
+ }
+
+ def mapTrainF(conf: Conf, u: Double)
+ (et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double])
+ : Iterator[(VertexID, (RealVector, RealVector, Double))] = {
+ val (usr, itm) = (et.srcAttr, et.dstAttr)
+ val (p, q) = (usr._1, itm._1)
+ var pred = u + usr._3 + itm._3 + q.dotProduct(usr._2)
+ pred = math.max(pred, conf.minVal)
+ pred = math.min(pred, conf.maxVal)
+ val err = et.attr - pred
+ val updateP = q.mapMultiply(err)
+ .subtract(p.mapMultiply(conf.gamma7))
+ .mapMultiply(conf.gamma2)
+ val updateQ = usr._2.mapMultiply(err)
+ .subtract(q.mapMultiply(conf.gamma7))
+ .mapMultiply(conf.gamma2)
+ val updateY = q.mapMultiply(err * usr._4)
+ .subtract(itm._2.mapMultiply(conf.gamma7))
+ .mapMultiply(conf.gamma2)
+ Iterator((et.srcId, (updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1)),
+ (et.dstId, (updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1)))
+ }
+
+ for (i <- 0 until conf.maxIters) {
+ // Phase 1, calculate pu + |N(u)|^(-0.5)*sum(y) for user nodes
+ g.cache()
+ val t1 = g.mapReduceTriplets(
+ et => Iterator((et.srcId, et.dstAttr._2)),
+ (g1: RealVector, g2: RealVector) => g1.add(g2))
+ g = g.outerJoinVertices(t1) {
+ (vid: VertexID, vd: (RealVector, RealVector, Double, Double), msg: Option[RealVector]) =>
+ if (msg.isDefined) (vd._1, vd._1.add(msg.get.mapMultiply(vd._4)), vd._3, vd._4) else vd
+ }
+
+ // Phase 2, update p for user nodes and q, y for item nodes
+ g.cache()
+ val t2 = g.mapReduceTriplets(
+ mapTrainF(conf, u),
+ (g1: (RealVector, RealVector, Double), g2: (RealVector, RealVector, Double)) =>
+ (g1._1.add(g2._1), g1._2.add(g2._2), g1._3 + g2._3))
+ g = g.outerJoinVertices(t2) {
+ (vid: VertexID,
+ vd: (RealVector, RealVector, Double, Double),
+ msg: Option[(RealVector, RealVector, Double)]) =>
+ (vd._1.add(msg.get._1), vd._2.add(msg.get._2), vd._3 + msg.get._3, vd._4)
+ }
+ }
+
+ // calculate error on training set
+ def mapTestF(conf: Conf, u: Double)
+ (et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double])
+ : Iterator[(VertexID, Double)] =
+ {
+ val (usr, itm) = (et.srcAttr, et.dstAttr)
+ val (p, q) = (usr._1, itm._1)
+ var pred = u + usr._3 + itm._3 + q.dotProduct(usr._2)
+ pred = math.max(pred, conf.minVal)
+ pred = math.min(pred, conf.maxVal)
+ val err = (et.attr - pred) * (et.attr - pred)
+ Iterator((et.dstId, err))
+ }
+ g.cache()
+ val t3 = g.mapReduceTriplets(mapTestF(conf, u), (g1: Double, g2: Double) => g1 + g2)
+ g = g.outerJoinVertices(t3) {
+ (vid: VertexID, vd: (RealVector, RealVector, Double, Double), msg: Option[Double]) =>
+ if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd
+ }
+
+ (g, u)
+ }
+}
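A hedged sketch of training the model; the rating triples are made up and the hyperparameters are illustrative values in the spirit of the paper, not tuned defaults (not part of the diff above):

import org.apache.spark.SparkContext
import org.apache.spark.graphx.Edge
import org.apache.spark.graphx.lib.SVDPlusPlus

val sc = new SparkContext("local", "svdpp-example")
// Edges encode (user, item, rating); user and item ids must not collide.
val ratings = sc.parallelize(Seq(Edge(1L, 101L, 5.0), Edge(1L, 102L, 3.0), Edge(2L, 101L, 4.0)))
val conf = new SVDPlusPlus.Conf(rank = 10, maxIters = 5, minVal = 0.0, maxVal = 5.0,
  gamma1 = 0.007, gamma2 = 0.007, gamma6 = 0.005, gamma7 = 0.015)
val (model, meanRating) = SVDPlusPlus.run(ratings, conf)
// model's vertex attributes hold the trained (p/q, y, bias, norm) tuple per the run() return type.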
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala
new file mode 100644
index 0000000000..d3d496e335
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala
@@ -0,0 +1,94 @@
+package org.apache.spark.graphx.lib
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.graphx._
+
+/** Strongly connected components algorithm implementation. */
+object StronglyConnectedComponents {
+
+ /**
+ * Compute the strongly connected component (SCC) of each vertex and return a graph with the
+ * vertex value containing the lowest vertex id in the SCC containing that vertex.
+ *
+ * @tparam VD the vertex attribute type (discarded in the computation)
+ * @tparam ED the edge attribute type (preserved in the computation)
+ *
+ * @param graph the graph for which to compute the SCC
+ *
+ * @return a graph with vertex attributes containing the smallest vertex id in each SCC
+ */
+ def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], numIter: Int): Graph[VertexID, ED] = {
+
+ // the graph we update with final SCC ids, and the graph we return at the end
+ var sccGraph = graph.mapVertices { case (vid, _) => vid }
+ // graph we are going to work with in our iterations
+ var sccWorkGraph = graph.mapVertices { case (vid, _) => (vid, false) }.cache()
+
+ var numVertices = sccWorkGraph.numVertices
+ var iter = 0
+ while (sccWorkGraph.numVertices > 0 && iter < numIter) {
+ iter += 1
+ do {
+ numVertices = sccWorkGraph.numVertices
+ sccWorkGraph = sccWorkGraph.outerJoinVertices(sccWorkGraph.outDegrees) {
+ (vid, data, degreeOpt) => if (degreeOpt.isDefined) data else (vid, true)
+ }.outerJoinVertices(sccWorkGraph.inDegrees) {
+ (vid, data, degreeOpt) => if (degreeOpt.isDefined) data else (vid, true)
+ }.cache()
+
+ // get all vertices to be removed
+ val finalVertices = sccWorkGraph.vertices
+ .filter { case (vid, (scc, isFinal)) => isFinal}
+ .mapValues { (vid, data) => data._1}
+
+ // write values to sccGraph
+ sccGraph = sccGraph.outerJoinVertices(finalVertices) {
+ (vid, scc, opt) => opt.getOrElse(scc)
+ }
+ // only keep vertices that are not final
+ sccWorkGraph = sccWorkGraph.subgraph(vpred = (vid, data) => !data._2).cache()
+ } while (sccWorkGraph.numVertices < numVertices)
+
+ sccWorkGraph = sccWorkGraph.mapVertices{ case (vid, (color, isFinal)) => (vid, isFinal) }
+
+ // collect min of all my neighbors' scc values, update if it's smaller than mine
+ // then notify any neighbors with scc values larger than mine
+ sccWorkGraph = Pregel[(VertexID, Boolean), ED, VertexID](
+ sccWorkGraph, Long.MaxValue, activeDirection = EdgeDirection.Out)(
+ (vid, myScc, neighborScc) => (math.min(myScc._1, neighborScc), myScc._2),
+ e => {
+ if (e.srcId < e.dstId) {
+ Iterator((e.dstId, e.srcAttr._1))
+ } else {
+ Iterator()
+ }
+ },
+ (vid1, vid2) => math.min(vid1, vid2))
+
+ // start at root of SCCs. Traverse values in reverse, notify all my neighbors
+ // do not propagate if colors do not match!
+ sccWorkGraph = Pregel[(VertexID, Boolean), ED, Boolean](
+ sccWorkGraph, false, activeDirection = EdgeDirection.In)(
+ // vertex is final if it is the root of a color
+ // or it has the same color as a neighbor that is final
+ (vid, myScc, existsSameColorFinalNeighbor) => {
+ val isColorRoot = vid == myScc._1
+ (myScc._1, myScc._2 || isColorRoot || existsSameColorFinalNeighbor)
+ },
+ // activate neighbor if they are not final, you are, and you have the same color
+ e => {
+ val sameColor = e.dstAttr._1 == e.srcAttr._1
+ val onlyDstIsFinal = e.dstAttr._2 && !e.srcAttr._2
+ if (sameColor && onlyDstIsFinal) {
+ Iterator((e.srcId, e.dstAttr._2))
+ } else {
+ Iterator()
+ }
+ },
+ (final1, final2) => final1 || final2)
+ }
+ sccGraph
+ }
+
+}
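A hedged sketch of the entry point above; the graph is a 3-cycle plus one dangling vertex, and the setup is illustrative (not part of the diff above):

import org.apache.spark.SparkContext
import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.graphx.lib.StronglyConnectedComponents

val sc = new SparkContext("local", "scc-example")
val edges = sc.parallelize(Seq(Edge(1L, 2L, 1), Edge(2L, 3L, 1), Edge(3L, 1L, 1), Edge(3L, 4L, 1)))
val graph = Graph.fromEdges(edges, 0)
val scc = StronglyConnectedComponents.run(graph, numIter = 10)
scc.vertices.collect().foreach(println)  // 1, 2, 3 share one SCC id; 4 is its own SCC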
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala
new file mode 100644
index 0000000000..23c9c40594
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala
@@ -0,0 +1,76 @@
+package org.apache.spark.graphx.lib
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.graphx._
+
+/**
+ * Compute the number of triangles passing through each vertex.
+ *
+ * The algorithm is relatively straightforward and can be computed in three steps:
+ *
+ * <ul>
+ * <li>Compute the set of neighbors for each vertex
+ * <li>For each edge compute the intersection of the sets and send the count to both vertices.
+ * <li> Compute the sum at each vertex and divide by two since each triangle is counted twice.
+ * </ul>
+ *
+ * Note that the input graph should have its edges in canonical direction
+ * (i.e. the `sourceId` is less than the `destId`). Also the graph must have been partitioned
+ * using [[org.apache.spark.graphx.Graph#partitionBy]].
+ */
+object TriangleCount {
+
+ def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD,ED]): Graph[Int, ED] = {
+ // Remove redundant edges
+ val g = graph.groupEdges((a, b) => a).cache()
+
+ // Construct set representations of the neighborhoods
+ val nbrSets: VertexRDD[VertexSet] =
+ g.collectNeighborIds(EdgeDirection.Either).mapValues { (vid, nbrs) =>
+ val set = new VertexSet(4)
+ var i = 0
+ while (i < nbrs.size) {
+ // prevent self cycle
+ if(nbrs(i) != vid) {
+ set.add(nbrs(i))
+ }
+ i += 1
+ }
+ set
+ }
+ // join the sets with the graph
+ val setGraph: Graph[VertexSet, ED] = g.outerJoinVertices(nbrSets) {
+ (vid, _, optSet) => optSet.getOrElse(null)
+ }
+ // Edge function computes intersection of smaller vertex with larger vertex
+ def edgeFunc(et: EdgeTriplet[VertexSet, ED]): Iterator[(VertexID, Int)] = {
+ assert(et.srcAttr != null)
+ assert(et.dstAttr != null)
+ val (smallSet, largeSet) = if (et.srcAttr.size < et.dstAttr.size) {
+ (et.srcAttr, et.dstAttr)
+ } else {
+ (et.dstAttr, et.srcAttr)
+ }
+ val iter = smallSet.iterator
+ var counter: Int = 0
+ while (iter.hasNext) {
+ val vid = iter.next()
+ if (vid != et.srcId && vid != et.dstId && largeSet.contains(vid)) {
+ counter += 1
+ }
+ }
+ Iterator((et.srcId, counter), (et.dstId, counter))
+ }
+ // compute the intersection along edges
+ val counters: VertexRDD[Int] = setGraph.mapReduceTriplets(edgeFunc, _ + _)
+ // Merge counters with the graph and divide by two since each triangle is counted twice
+ g.outerJoinVertices(counters) {
+ (vid, _, optCounter: Option[Int]) =>
+ val dblCount = optCounter.getOrElse(0)
+ // double count should be even (divisible by two)
+ assert((dblCount & 1) == 0)
+ dblCount / 2
+ }
+ } // end of TriangleCount
+}
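A hedged sketch of the preconditions spelled out in the doc comment: load with canonical orientation, partition, then count. The edge-list path is a placeholder (not part of the diff above):

import org.apache.spark.SparkContext
import org.apache.spark.graphx.{GraphLoader, PartitionStrategy}
import org.apache.spark.graphx.lib.TriangleCount

val sc = new SparkContext("local", "triangles-example")
val graph = GraphLoader.edgeListFile(sc, "followers.txt", canonicalOrientation = true)
  .partitionBy(PartitionStrategy.RandomVertexCut)
  .cache()
val counts = TriangleCount.run(graph).vertices
// Summing per-vertex counts triple-counts each triangle, hence the division by 3.
println("Triangles: " + counts.map(_._2.toLong).reduce(_ + _) / 3)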
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/package.scala b/graphx/src/main/scala/org/apache/spark/graphx/package.scala
new file mode 100644
index 0000000000..60dfc1dc37
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/package.scala
@@ -0,0 +1,18 @@
+package org.apache.spark
+
+import org.apache.spark.util.collection.OpenHashSet
+
+/** GraphX is a graph processing framework built on top of Spark. */
+package object graphx {
+ /**
+ * A 64-bit vertex identifier that uniquely identifies a vertex within a graph. It does not need
+ * to satisfy any ordering or constraints other than uniqueness.
+ */
+ type VertexID = Long
+
+ /** Integer identifier of a graph partition. */
+ // TODO: Consider using Char.
+ type PartitionID = Int
+
+ private[graphx] type VertexSet = OpenHashSet[VertexID]
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala
new file mode 100644
index 0000000000..1c5b234d74
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala
@@ -0,0 +1,117 @@
+package org.apache.spark.graphx.util
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
+import scala.collection.mutable.HashSet
+
+import org.apache.spark.util.Utils
+
+import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor}
+import org.objectweb.asm.Opcodes._
+
+
+/**
+ * Includes a utility function to test whether a function accesses a specific attribute
+ * of an object.
+ */
+private[graphx] object BytecodeUtils {
+
+ /**
+ * Test whether the given closure invokes the specified method in the specified class.
+ */
+ def invokedMethod(closure: AnyRef, targetClass: Class[_], targetMethod: String): Boolean = {
+ if (_invokedMethod(closure.getClass, "apply", targetClass, targetMethod)) {
+ true
+ } else {
+ // look at closures enclosed in this closure
+ for (f <- closure.getClass.getDeclaredFields
+ if f.getType.getName.startsWith("scala.Function")) {
+ f.setAccessible(true)
+ if (invokedMethod(f.get(closure), targetClass, targetMethod)) {
+ return true
+ }
+ }
+ return false
+ }
+ }
+
+ private def _invokedMethod(cls: Class[_], method: String,
+ targetClass: Class[_], targetMethod: String): Boolean = {
+
+ val seen = new HashSet[(Class[_], String)]
+ var stack = List[(Class[_], String)]((cls, method))
+
+ while (stack.nonEmpty) {
+ val (c, m) = stack.head
+ stack = stack.tail
+ seen.add((c, m))
+ val finder = new MethodInvocationFinder(c.getName, m)
+ getClassReader(c).accept(finder, 0)
+ for (classMethod <- finder.methodsInvoked) {
+ //println(classMethod)
+ if (classMethod._1 == targetClass && classMethod._2 == targetMethod) {
+ return true
+ } else if (!seen.contains(classMethod)) {
+ stack = classMethod :: stack
+ }
+ }
+ }
+ return false
+ }
+
+ /**
+ * Get an ASM class reader for a given class from the JAR that loaded it.
+ */
+ private def getClassReader(cls: Class[_]): ClassReader = {
+ // Copy data over, before delegating to ClassReader - else we can run out of open file handles.
+ val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
+ val resourceStream = cls.getResourceAsStream(className)
+ // todo: Fixme - continuing with earlier behavior ...
+ if (resourceStream == null) return new ClassReader(resourceStream)
+
+ val baos = new ByteArrayOutputStream(128)
+ Utils.copyStream(resourceStream, baos, true)
+ new ClassReader(new ByteArrayInputStream(baos.toByteArray))
+ }
+
+ /**
+ * Given the class name, return whether we should look into the class or not. This is used to
+ * skip examining a large quantity of Java or Scala classes that we know for sure wouldn't access
+ * the closures. Note that the class name is expected in ASM style (i.e. use "/" instead of ".").
+ */
+ private def skipClass(className: String): Boolean = {
+ val c = className
+ c.startsWith("java/") || c.startsWith("scala/") || c.startsWith("javax/")
+ }
+
+ /**
+ * Find the set of methods invoked by the specified method in the specified class.
+ * For example, after running the visitor MethodInvocationFinder("spark/graph/Foo", "test"),
+ * its methodsInvoked variable will contain the set of methods invoked directly by
+ * Foo.test(). Interface invocations are not returned as part of the result set because we cannot
+ * determine the actual method invoked by inspecting the bytecode.
+ */
+ private class MethodInvocationFinder(className: String, methodName: String)
+ extends ClassVisitor(ASM4) {
+
+ val methodsInvoked = new HashSet[(Class[_], String)]
+
+ override def visitMethod(access: Int, name: String, desc: String,
+ sig: String, exceptions: Array[String]): MethodVisitor = {
+ if (name == methodName) {
+ new MethodVisitor(ASM4) {
+ override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) {
+ if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) {
+ if (!skipClass(owner)) {
+ methodsInvoked.add((Class.forName(owner.replace("/", ".")), name))
+ }
+ }
+ }
+ }
+ } else {
+ null
+ }
+ }
+ }
+}
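The point of invokedMethod is to answer questions like "does this user closure ever read srcAttr?", so attribute shipping can be skipped when it does not. A hedged sketch, callable only from inside the graphx package since the object is package-private; the closure is illustrative (not part of the diff above):

import org.apache.spark.graphx.EdgeTriplet
import org.apache.spark.graphx.util.BytecodeUtils

val readsSrc = (et: EdgeTriplet[Int, Int]) => et.srcAttr + 1
val touched = BytecodeUtils.invokedMethod(readsSrc, classOf[EdgeTriplet[Int, Int]], "srcAttr")
// touched should be true: the compiled closure invokes the srcAttr accessor.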
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala
new file mode 100644
index 0000000000..57422ce3f1
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala
@@ -0,0 +1,218 @@
+package org.apache.spark.graphx.util
+
+import scala.annotation.tailrec
+import scala.math._
+import scala.reflect.ClassTag
+import scala.util._
+
+import org.apache.spark._
+import org.apache.spark.serializer._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext._
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.Graph
+import org.apache.spark.graphx.Edge
+import org.apache.spark.graphx.impl.GraphImpl
+
+/** A collection of graph generating functions. */
+object GraphGenerators {
+
+ val RMATa = 0.45
+ val RMATb = 0.15
+ val RMATc = 0.15
+ val RMATd = 0.25
+
+ // Right now it just generates a bunch of edges where
+ // the edge data is the weight (default 1)
+ /**
+ * Generate a graph whose vertex out degree is log normal.
+ */
+ def logNormalGraph(sc: SparkContext, numVertices: Int): Graph[Int, Int] = {
+ // based on Pregel settings
+ val mu = 4
+ val sigma = 1.3
+
+ val vertices: RDD[(VertexID, Int)] = sc.parallelize(0 until numVertices).map{
+ src => (src, sampleLogNormal(mu, sigma, numVertices))
+ }
+ val edges = vertices.flatMap { v =>
+ generateRandomEdges(v._1.toInt, v._2, numVertices)
+ }
+ Graph(vertices, edges, 0)
+ }
+
+ def generateRandomEdges(src: Int, numEdges: Int, maxVertexID: Int): Array[Edge[Int]] = {
+ val rand = new Random()
+ Array.fill(numEdges) { Edge[Int](src, rand.nextInt(maxVertexID), 1) }
+ }
+
+ /**
+ * Randomly samples from a log normal distribution whose corresponding normal distribution has
+ * the given mean and standard deviation. It uses the formula `X = exp(mu + sigma*Z)` where `mu`,
+ * `sigma` are the mean and standard deviation of the underlying normal distribution and
+ * `Z ~ N(0, 1)`. The resulting log normal distribution has mean `e^(mu + sigma^2/2)` and
+ * standard deviation `sqrt[(e^(sigma^2) - 1) * e^(2*mu + sigma^2)]`.
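+ *
+ * For example, with `mu = 4` and `sigma = 1.3` (the settings used by `logNormalGraph` above),
+ * the sampled values have median `e^4 ~= 55` and mean `e^(4 + 1.3^2/2) ~= 127`, before the
+ * rejection loop truncates them to be below `maxVal`.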
+ *
+ * @param mu the mean of the normal distribution
+ * @param sigma the standard deviation of the normal distribution
+ * @param maxVal exclusive upper bound on the value of the sample
+ */
+ private def sampleLogNormal(mu: Double, sigma: Double, maxVal: Int): Int = {
+ val rand = new Random()
+ val m = math.exp(mu+(sigma*sigma)/2.0)
+ val s = math.sqrt((math.exp(sigma*sigma) - 1) * math.exp(2*mu + sigma*sigma))
+ // Z ~ N(0, 1)
+ var X: Double = maxVal
+
+ while (X >= maxVal) {
+ val Z = rand.nextGaussian()
+ X = math.exp(mu + sigma*Z)
+ }
+ math.round(X.toFloat)
+ }
+
+ /**
+ * A random graph generator using the R-MAT model, proposed in
+ * "R-MAT: A Recursive Model for Graph Mining" by Chakrabarti et al.
+ *
+ * See [[http://www.cs.cmu.edu/~christos/PUBLICATIONS/siam04.pdf]].
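+ *
+ * A minimal usage sketch (the vertex and edge counts are arbitrary; `sc` is an existing
+ * SparkContext):
+ * {{{
+ * // ~1024 vertices (1000 rounded up to a power of 2) and 4000 distinct edges
+ * val g = GraphGenerators.rmatGraph(sc, requestedNumVertices = 1000, numEdges = 4000)
+ * }}}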
+ */
+ def rmatGraph(sc: SparkContext, requestedNumVertices: Int, numEdges: Int): Graph[Int, Int] = {
+ // let N = requestedNumVertices
+ // the number of vertices is 2^n where n=ceil(log2[N])
+ // This ensures that the 4 quadrants are the same size at all recursion levels
+ val numVertices = math.round(
+ math.pow(2.0, math.ceil(math.log(requestedNumVertices) / math.log(2.0)))).toInt
+ var edges: Set[Edge[Int]] = Set()
+ while (edges.size < numEdges) {
+ if (edges.size % 100 == 0) {
+ println(edges.size + " edges")
+ }
+ edges += addEdge(numVertices)
+ }
+ outDegreeFromEdges(sc.parallelize(edges.toList))
+ }
+
+ private def outDegreeFromEdges[ED: ClassTag](edges: RDD[Edge[ED]]): Graph[Int, ED] = {
+ val vertices = edges.flatMap { edge => List((edge.srcId, 1)) }
+ .reduceByKey(_ + _)
+ .map{ case (vid, degree) => (vid, degree) }
+ Graph(vertices, edges, 0)
+ }
+
+ /**
+ * @param numVertices the total number of vertices in the graph (used to get
+ * the dimensions of the adjacency matrix)
+ */
+ private def addEdge(numVertices: Int): Edge[Int] = {
+ //val (src, dst) = chooseCell(numVertices/2.0, numVertices/2.0, numVertices/2.0)
+ val v = math.round(numVertices.toFloat/2.0).toInt
+
+ val (src, dst) = chooseCell(v, v, v)
+ Edge[Int](src, dst, 1)
+ }
+
+ /**
+ * This method recursively subdivides the adjacency matrix into quadrants
+ * until it picks a single cell. The naming conventions here match
+ * those of the R-MAT paper. The number of nodes in the graph is a power of 2.
+ * The adjacency matrix looks like:
+ * <pre>
+ *
+ *          dst ->
+ * (x,y) ***************  _
+ *       |      |      |  |
+ *       |  a   |  b   |  |
+ *  src  |      |      |  |
+ *   |   ***************  | T
+ *  \|/  |      |      |  |
+ *       |  c   |  d   |  |
+ *       |      |      |  |
+ *       ***************  -
+ * </pre>
+ *
+ * where this represents the subquadrant of the adj matrix currently being
+ * subdivided. (x,y) represent the upper left hand corner of the subquadrant,
+ * and T represents the side length (guaranteed to be a power of 2).
+ *
+ * After choosing the next level subquadrant, we get the resulting sets
+ * of parameters:
+ * {{{
+ * quad = a, x'=x, y'=y, T'=T/2
+ * quad = b, x'=x+T/2, y'=y, T'=T/2
+ * quad = c, x'=x, y'=y+T/2, T'=T/2
+ * quad = d, x'=x+T/2, y'=y+T/2, T'=T/2
+ * }}}
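+ *
+ * For example, starting from `(x, y, T) = (0, 0, 8)` and picking quadrants d, d, a in
+ * succession, the recursion goes (0,0,8) -> (4,4,4) -> (6,6,2) -> cell (6,6).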
+ */
+ @tailrec
+ private def chooseCell(x: Int, y: Int, t: Int): (Int, Int) = {
+ if (t <= 1) {
+ (x, y)
+ } else {
+ val newT = math.round(t.toFloat/2.0).toInt
+ pickQuadrant(RMATa, RMATb, RMATc, RMATd) match {
+ case 0 => chooseCell(x, y, newT)
+ case 1 => chooseCell(x+newT, y, newT)
+ case 2 => chooseCell(x, y+newT, newT)
+ case 3 => chooseCell(x+newT, y+newT, newT)
+ }
+ }
+ }
+
+ // TODO(crankshaw) turn result into an enum (or case class for pattern matching)
+ private def pickQuadrant(a: Double, b: Double, c: Double, d: Double): Int = {
+ if (a + b + c + d != 1.0) {
+ throw new IllegalArgumentException(
+ "R-MAT probability parameters sum to " + (a+b+c+d) + ", should sum to 1.0")
+ }
+ val rand = new Random()
+ val result = rand.nextDouble()
+ result match {
+ case x if x < a => 0 // 0 corresponds to quadrant a
+ case x if (x >= a && x < a + b) => 1 // 1 corresponds to b
+ case x if (x >= a + b && x < a + b + c) => 2 // 2 corresponds to c
+ case _ => 3 // 3 corresponds to d
+ }
+ }
+
+ /**
+ * Create a `rows` by `cols` grid graph with each vertex connected to its
+ * row+1 and col+1 neighbors. Vertex ids are assigned in row major
+ * order.
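+ *
+ * For example, `gridGraph(sc, 5, 5)` (an arbitrary size) yields 25 vertices and
+ * 5*4 + 4*5 = 40 edges, since each vertex links to its row+1 and col+1 neighbors
+ * when they exist.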
+ *
+ * @param sc the spark context in which to construct the graph
+ * @param rows the number of rows
+ * @param cols the number of columns
+ *
+ * @return A graph containing vertices with the row and column ids
+ * as their attributes and edges with value 1.0.
+ */
+ def gridGraph(sc: SparkContext, rows: Int, cols: Int): Graph[(Int,Int), Double] = {
+ // Convert row column address into vertex ids (row major order)
+ def sub2ind(r: Int, c: Int): VertexID = r * cols + c
+
+ val vertices: RDD[(VertexID, (Int,Int))] =
+ sc.parallelize(0 until rows).flatMap( r => (0 until cols).map( c => (sub2ind(r,c), (r,c)) ) )
+ val edges: RDD[Edge[Double]] =
+ vertices.flatMap{ case (vid, (r,c)) =>
+ (if (r+1 < rows) { Seq( (sub2ind(r, c), sub2ind(r+1, c))) } else { Seq.empty }) ++
+ (if (c+1 < cols) { Seq( (sub2ind(r, c), sub2ind(r, c+1))) } else { Seq.empty })
+ }.map{ case (src, dst) => Edge(src, dst, 1.0) }
+ Graph(vertices, edges)
+ } // end of gridGraph
+
+ /**
+ * Create a star graph with vertex 0 being the center.
+ *
+ * @param sc the spark context in which to construct the graph
+ * @param nverts the number of vertices in the star
+ *
+ * @return A star graph containing `nverts` vertices with vertex 0
+ * being the center vertex.
+ */
+ def starGraph(sc: SparkContext, nverts: Int): Graph[Int, Int] = {
+ val edges: RDD[(VertexID, VertexID)] = sc.parallelize(1 until nverts).map(vid => (vid, 0))
+ Graph.fromEdgeTuples(edges, 1)
+ } // end of starGraph
+
+} // end of Graph Generators
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/PrimitiveKeyOpenHashMap.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/PrimitiveKeyOpenHashMap.scala
new file mode 100644
index 0000000000..7b02e2ed1a
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/PrimitiveKeyOpenHashMap.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.graphx.util.collection
+
+import org.apache.spark.util.collection.OpenHashSet
+
+import scala.reflect._
+
+/**
+ * A fast hash map implementation for primitive, non-null keys. This hash map supports
+ * insertions and updates, but not deletions. This map is about an order of magnitude
+ * faster than java.util.HashMap, while using much less space overhead.
+ *
+ * Under the hood, it uses our OpenHashSet implementation.
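+ *
+ * A small illustrative sketch (keys and values chosen arbitrarily):
+ * {{{
+ * val m = new PrimitiveKeyOpenHashMap[Long, Int](64)
+ * m.update(1L, 10)
+ * m.changeValue(1L, 0, _ + 1)   // key exists, so the value becomes 11
+ * m.changeValue(2L, 0, _ + 1)   // key is new, so the value becomes the default 0
+ * assert(m(1L) == 11 && m.getOrElse(2L, -1) == 0)
+ * }}}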
+ */
+private[graphx]
+class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
+ @specialized(Long, Int, Double) V: ClassTag](
+ val keySet: OpenHashSet[K], var _values: Array[V])
+ extends Iterable[(K, V)]
+ with Serializable {
+
+ /**
+ * Allocate an OpenHashMap with a fixed initial capacity
+ */
+ def this(initialCapacity: Int) =
+ this(new OpenHashSet[K](initialCapacity), new Array[V](initialCapacity))
+
+ /**
+ * Allocate an OpenHashMap with a default initial capacity, providing a true
+ * no-argument constructor.
+ */
+ def this() = this(64)
+
+ /**
+ * Allocate an OpenHashMap backed by an existing OpenHashSet, using the set's capacity.
+ */
+ def this(keySet: OpenHashSet[K]) = this(keySet, new Array[V](keySet.capacity))
+
+ require(classTag[K] == classTag[Long] || classTag[K] == classTag[Int])
+
+ private var _oldValues: Array[V] = null
+
+ override def size = keySet.size
+
+ /** Get the value for a given key */
+ def apply(k: K): V = {
+ val pos = keySet.getPos(k)
+ _values(pos)
+ }
+
+ /** Get the value for a given key, or elseValue if the key doesn't exist. */
+ def getOrElse(k: K, elseValue: V): V = {
+ val pos = keySet.getPos(k)
+ if (pos >= 0) _values(pos) else elseValue
+ }
+
+ /** Set the value for a key */
+ def update(k: K, v: V) {
+ val pos = keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
+ _values(pos) = v
+ keySet.rehashIfNeeded(k, grow, move)
+ _oldValues = null
+ }
+
+
+ /** Set the value for a key, merging with the existing value via mergeF if the key is already present. */
+ def setMerge(k: K, v: V, mergeF: (V, V) => V) {
+ val pos = keySet.addWithoutResize(k)
+ val ind = pos & OpenHashSet.POSITION_MASK
+ if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) { // if first add
+ _values(ind) = v
+ } else {
+ _values(ind) = mergeF(_values(ind), v)
+ }
+ keySet.rehashIfNeeded(k, grow, move)
+ _oldValues = null
+ }
+
+
+ /**
+ * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise,
+ * set its value to mergeValue(oldValue).
+ *
+ * @return the newly updated value.
+ */
+ def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = {
+ val pos = keySet.addWithoutResize(k)
+ if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
+ val newValue = defaultValue
+ _values(pos & OpenHashSet.POSITION_MASK) = newValue
+ keySet.rehashIfNeeded(k, grow, move)
+ newValue
+ } else {
+ _values(pos) = mergeValue(_values(pos))
+ _values(pos)
+ }
+ }
+
+ override def iterator = new Iterator[(K, V)] {
+ var pos = 0
+ var nextPair: (K, V) = computeNextPair()
+
+ /** Get the next value we should return from next(), or null if we're finished iterating */
+ def computeNextPair(): (K, V) = {
+ pos = keySet.nextPos(pos)
+ if (pos >= 0) {
+ val ret = (keySet.getValue(pos), _values(pos))
+ pos += 1
+ ret
+ } else {
+ null
+ }
+ }
+
+ def hasNext = nextPair != null
+
+ def next() = {
+ val pair = nextPair
+ nextPair = computeNextPair()
+ pair
+ }
+ }
+
+ // The following member variables are declared as protected instead of private for the
+ // specialization to work (specialized class extends the unspecialized one and needs access
+ // to the "private" variables).
+ // They should also have been vals. We use vars because of a Scala compiler bug that
+ // would throw an IllegalAccessError at runtime if they were declared as vals.
+ protected var grow = (newCapacity: Int) => {
+ _oldValues = _values
+ _values = new Array[V](newCapacity)
+ }
+
+ protected var move = (oldPos: Int, newPos: Int) => {
+ _values(newPos) = _oldValues(oldPos)
+ }
+}
diff --git a/graphx/src/test/resources/log4j.properties b/graphx/src/test/resources/log4j.properties
new file mode 100644
index 0000000000..85e57f0c4b
--- /dev/null
+++ b/graphx/src/test/resources/log4j.properties
@@ -0,0 +1,28 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Set everything to be logged to the file core/target/unit-tests.log
+log4j.rootCategory=INFO, file
+log4j.appender.file=org.apache.log4j.FileAppender
+log4j.appender.file.append=false
+log4j.appender.file.file=graphx/target/unit-tests.log
+log4j.appender.file.layout=org.apache.log4j.PatternLayout
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
+
+# Ignore messages below warning level from Jetty, because it's a bit verbose
+log4j.logger.org.eclipse.jetty=WARN
+org.eclipse.jetty.LEVEL=WARN
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
new file mode 100644
index 0000000000..280f50e39a
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
@@ -0,0 +1,66 @@
+package org.apache.spark.graphx
+
+import org.apache.spark.SparkContext
+import org.apache.spark.graphx.Graph._
+import org.apache.spark.graphx.impl.EdgePartition
+import org.apache.spark.rdd._
+import org.scalatest.FunSuite
+
+class GraphOpsSuite extends FunSuite with LocalSparkContext {
+
+ test("joinVertices") {
+ withSpark { sc =>
+ val vertices =
+ sc.parallelize(Seq[(VertexID, String)]((1, "one"), (2, "two"), (3, "three")), 2)
+ val edges = sc.parallelize((Seq(Edge(1, 2, "onetwo"))))
+ val g: Graph[String, String] = Graph(vertices, edges)
+
+ val tbl = sc.parallelize(Seq[(VertexID, Int)]((1, 10), (2, 20)))
+ val g1 = g.joinVertices(tbl) { (vid: VertexID, attr: String, u: Int) => attr + u }
+
+ val v = g1.vertices.collect().toSet
+ assert(v === Set((1, "one10"), (2, "two20"), (3, "three")))
+ }
+ }
+
+ test("collectNeighborIds") {
+ withSpark { sc =>
+ val chain = (0 until 100).map(x => (x, (x+1)%100) )
+ val rawEdges = sc.parallelize(chain, 3).map { case (s,d) => (s.toLong, d.toLong) }
+ val graph = Graph.fromEdgeTuples(rawEdges, 1.0).cache()
+ val nbrs = graph.collectNeighborIds(EdgeDirection.Either).cache()
+ assert(nbrs.count === chain.size)
+ assert(graph.numVertices === nbrs.count)
+ nbrs.collect.foreach { case (vid, nbrs) => assert(nbrs.size === 2) }
+ nbrs.collect.foreach { case (vid, nbrs) =>
+ val s = nbrs.toSet
+ assert(s.contains((vid + 1) % 100))
+ assert(s.contains(if (vid > 0) vid - 1 else 99 ))
+ }
+ }
+ }
+
+ test ("filter") {
+ withSpark { sc =>
+ val n = 5
+ val vertices = sc.parallelize((0 to n).map(x => (x:VertexID, x)))
+ val edges = sc.parallelize((1 to n).map(x => Edge(0, x, x)))
+ val graph: Graph[Int, Int] = Graph(vertices, edges).cache()
+ val filteredGraph = graph.filter(
+ graph => {
+ val degrees: VertexRDD[Int] = graph.outDegrees
+ graph.outerJoinVertices(degrees) {(vid, data, deg) => deg.getOrElse(0)}
+ },
+ vpred = (vid: VertexID, deg:Int) => deg > 0
+ ).cache()
+
+ val v = filteredGraph.vertices.collect().toSet
+ assert(v === Set((0,0)))
+
+ // the map is necessary because of object-reuse in the edge iterator
+ val e = filteredGraph.edges.map(e => Edge(e.srcId, e.dstId, e.attr)).collect().toSet
+ assert(e.isEmpty)
+ }
+ }
+
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
new file mode 100644
index 0000000000..9587f04c3e
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -0,0 +1,273 @@
+package org.apache.spark.graphx
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.graphx.Graph._
+import org.apache.spark.graphx.PartitionStrategy._
+import org.apache.spark.rdd._
+
+class GraphSuite extends FunSuite with LocalSparkContext {
+
+ def starGraph(sc: SparkContext, n: Int): Graph[String, Int] = {
+ Graph.fromEdgeTuples(sc.parallelize((1 to n).map(x => (0: VertexID, x: VertexID)), 3), "v")
+ }
+
+ test("Graph.fromEdgeTuples") {
+ withSpark { sc =>
+ val ring = (0L to 100L).zip((1L to 99L) :+ 0L)
+ val doubleRing = ring ++ ring
+ val graph = Graph.fromEdgeTuples(sc.parallelize(doubleRing), 1)
+ assert(graph.edges.count() === doubleRing.size)
+ assert(graph.edges.collect.forall(e => e.attr == 1))
+
+ // uniqueEdges option should uniquify edges and store duplicate count in edge attributes
+ val uniqueGraph = Graph.fromEdgeTuples(sc.parallelize(doubleRing), 1, Some(RandomVertexCut))
+ assert(uniqueGraph.edges.count() === ring.size)
+ assert(uniqueGraph.edges.collect.forall(e => e.attr == 2))
+ }
+ }
+
+ test("Graph.fromEdges") {
+ withSpark { sc =>
+ val ring = (0L to 100L).zip((1L to 99L) :+ 0L).map { case (a, b) => Edge(a, b, 1) }
+ val graph = Graph.fromEdges(sc.parallelize(ring), 1.0F)
+ assert(graph.edges.count() === ring.size)
+ }
+ }
+
+ test("Graph.apply") {
+ withSpark { sc =>
+ val rawEdges = (0L to 98L).zip((1L to 99L) :+ 0L)
+ val edges: RDD[Edge[Int]] = sc.parallelize(rawEdges).map { case (s, t) => Edge(s, t, 1) }
+ val vertices: RDD[(VertexID, Boolean)] = sc.parallelize((0L until 10L).map(id => (id, true)))
+ val graph = Graph(vertices, edges, false)
+ assert( graph.edges.count() === rawEdges.size )
+ // Vertices not explicitly provided but referenced by edges should be created automatically
+ assert( graph.vertices.count() === 100)
+ graph.triplets.map { et =>
+ assert((et.srcId < 10 && et.srcAttr) || (et.srcId >= 10 && !et.srcAttr))
+ assert((et.dstId < 10 && et.dstAttr) || (et.dstId >= 10 && !et.dstAttr))
+ }
+ }
+ }
+
+ test("triplets") {
+ withSpark { sc =>
+ val n = 5
+ val star = starGraph(sc, n)
+ assert(star.triplets.map(et => (et.srcId, et.dstId, et.srcAttr, et.dstAttr)).collect.toSet ===
+ (1 to n).map(x => (0: VertexID, x: VertexID, "v", "v")).toSet)
+ }
+ }
+
+ test("partitionBy") {
+ withSpark { sc =>
+ def mkGraph(edges: List[(Long, Long)]) = Graph.fromEdgeTuples(sc.parallelize(edges, 2), 0)
+ def nonemptyParts(graph: Graph[Int, Int]) = {
+ graph.edges.partitionsRDD.mapPartitions { iter =>
+ Iterator(iter.next()._2.iterator.toList)
+ }.filter(_.nonEmpty)
+ }
+ val identicalEdges = List((0L, 1L), (0L, 1L))
+ val canonicalEdges = List((0L, 1L), (1L, 0L))
+ val sameSrcEdges = List((0L, 1L), (0L, 2L))
+
+ // The two edges start out in different partitions
+ for (edges <- List(identicalEdges, canonicalEdges, sameSrcEdges)) {
+ assert(nonemptyParts(mkGraph(edges)).count === 2)
+ }
+ // partitionBy(RandomVertexCut) puts identical edges in the same partition
+ assert(nonemptyParts(mkGraph(identicalEdges).partitionBy(RandomVertexCut)).count === 1)
+ // partitionBy(EdgePartition1D) puts same-source edges in the same partition
+ assert(nonemptyParts(mkGraph(sameSrcEdges).partitionBy(EdgePartition1D)).count === 1)
+ // partitionBy(CanonicalRandomVertexCut) puts edges that are identical modulo direction into
+ // the same partition
+ assert(nonemptyParts(mkGraph(canonicalEdges).partitionBy(CanonicalRandomVertexCut)).count === 1)
+ // partitionBy(EdgePartition2D) puts identical edges in the same partition
+ assert(nonemptyParts(mkGraph(identicalEdges).partitionBy(EdgePartition2D)).count === 1)
+
+ // partitionBy(EdgePartition2D) ensures that vertices need only be replicated to 2 * sqrt(p)
+ // partitions
+ val n = 100
+ val p = 100
+ val verts = 1 to n
+ val graph = Graph.fromEdgeTuples(sc.parallelize(verts.flatMap(x =>
+ verts.filter(y => y % x == 0).map(y => (x: VertexID, y: VertexID))), p), 0)
+ assert(graph.edges.partitions.length === p)
+ val partitionedGraph = graph.partitionBy(EdgePartition2D)
+ assert(graph.edges.partitions.length === p)
+ val bound = 2 * math.sqrt(p)
+ // Each vertex should be replicated to at most 2 * sqrt(p) partitions
+ val partitionSets = partitionedGraph.edges.partitionsRDD.mapPartitions { iter =>
+ val part = iter.next()._2
+ Iterator((part.srcIds ++ part.dstIds).toSet)
+ }.collect
+ assert(verts.forall(id => partitionSets.count(_.contains(id)) <= bound))
+ // This should not be true for the default hash partitioning
+ val partitionSetsUnpartitioned = graph.edges.partitionsRDD.mapPartitions { iter =>
+ val part = iter.next()._2
+ Iterator((part.srcIds ++ part.dstIds).toSet)
+ }.collect
+ assert(verts.exists(id => partitionSetsUnpartitioned.count(_.contains(id)) > bound))
+ }
+ }
+
+ test("mapVertices") {
+ withSpark { sc =>
+ val n = 5
+ val star = starGraph(sc, n)
+ // mapVertices preserving type
+ val mappedVAttrs = star.mapVertices((vid, attr) => attr + "2")
+ assert(mappedVAttrs.vertices.collect.toSet === (0 to n).map(x => (x: VertexID, "v2")).toSet)
+ // mapVertices changing type
+ val mappedVAttrs2 = star.mapVertices((vid, attr) => attr.length)
+ assert(mappedVAttrs2.vertices.collect.toSet === (0 to n).map(x => (x: VertexID, 1)).toSet)
+ }
+ }
+
+ test("mapEdges") {
+ withSpark { sc =>
+ val n = 3
+ val star = starGraph(sc, n)
+ val starWithEdgeAttrs = star.mapEdges(e => e.dstId)
+
+ val edges = starWithEdgeAttrs.edges.collect()
+ assert(edges.size === n)
+ assert(edges.toSet === (1 to n).map(x => Edge(0, x, x)).toSet)
+ }
+ }
+
+ test("mapTriplets") {
+ withSpark { sc =>
+ val n = 5
+ val star = starGraph(sc, n)
+ assert(star.mapTriplets(et => et.srcAttr + et.dstAttr).edges.collect.toSet ===
+ (1L to n).map(x => Edge(0, x, "vv")).toSet)
+ }
+ }
+
+ test("reverse") {
+ withSpark { sc =>
+ val n = 5
+ val star = starGraph(sc, n)
+ assert(star.reverse.outDegrees.collect.toSet === (1 to n).map(x => (x: VertexID, 1)).toSet)
+ }
+ }
+
+ test("subgraph") {
+ withSpark { sc =>
+ // Create a star graph of 10 vertices.
+ val n = 10
+ val star = starGraph(sc, n)
+ // Take only vertices whose vids are even
+ val subgraph = star.subgraph(vpred = (vid, attr) => vid % 2 == 0)
+
+ // We should have 6 vertices (the even ones, including 0).
+ assert(subgraph.vertices.collect().toSet === (0 to n by 2).map(x => (x, "v")).toSet)
+
+ // And 5 edges (from 0 to each even vertex).
+ assert(subgraph.edges.map(_.copy()).collect().toSet === (2 to n by 2).map(x => Edge(0, x, 1)).toSet)
+ }
+ }
+
+ test("mask") {
+ withSpark { sc =>
+ val n = 5
+ val vertices = sc.parallelize((0 to n).map(x => (x:VertexID, x)))
+ val edges = sc.parallelize((1 to n).map(x => Edge(0, x, x)))
+ val graph: Graph[Int, Int] = Graph(vertices, edges).cache()
+
+ val subgraph = graph.subgraph(
+ e => e.dstId != 4L,
+ (vid, vdata) => vid != 3L
+ ).mapVertices((vid, vdata) => -1).mapEdges(e => -1)
+
+ val projectedGraph = graph.mask(subgraph)
+
+ val v = projectedGraph.vertices.collect().toSet
+ assert(v === Set((0,0), (1,1), (2,2), (4,4), (5,5)))
+
+ // the map is necessary because of object-reuse in the edge iterator
+ val e = projectedGraph.edges.map(e => Edge(e.srcId, e.dstId, e.attr)).collect().toSet
+ assert(e === Set(Edge(0,1,1), Edge(0,2,2), Edge(0,5,5)))
+
+ }
+ }
+
+ test("groupEdges") {
+ withSpark { sc =>
+ val n = 5
+ val star = starGraph(sc, n)
+ val doubleStar = Graph.fromEdgeTuples(
+ sc.parallelize((1 to n).flatMap(x =>
+ List((0: VertexID, x: VertexID), (0: VertexID, x: VertexID))), 1), "v")
+ val star2 = doubleStar.groupEdges { (a, b) => a}
+ assert(star2.edges.collect.toArray.sorted(Edge.lexicographicOrdering[Int]) ===
+ star.edges.collect.toArray.sorted(Edge.lexicographicOrdering[Int]))
+ assert(star2.vertices.collect.toSet === star.vertices.collect.toSet)
+ }
+ }
+
+ test("mapReduceTriplets") {
+ withSpark { sc =>
+ val n = 5
+ val star = starGraph(sc, n).mapVertices { (_, _) => 0 }.cache()
+ val starDeg = star.joinVertices(star.degrees){ (vid, oldV, deg) => deg }
+ val neighborDegreeSums = starDeg.mapReduceTriplets(
+ edge => Iterator((edge.srcId, edge.dstAttr), (edge.dstId, edge.srcAttr)),
+ (a: Int, b: Int) => a + b)
+ assert(neighborDegreeSums.collect().toSet === (0 to n).map(x => (x, n)).toSet)
+
+ // activeSetOpt
+ val allPairs = for (x <- 1 to n; y <- 1 to n) yield (x: VertexID, y: VertexID)
+ val complete = Graph.fromEdgeTuples(sc.parallelize(allPairs, 3), 0)
+ val vids = complete.mapVertices((vid, attr) => vid).cache()
+ val active = vids.vertices.filter { case (vid, attr) => attr % 2 == 0 }
+ val numEvenNeighbors = vids.mapReduceTriplets(et => {
+ // Map function should only run on edges with destination in the active set
+ if (et.dstId % 2 != 0) {
+ throw new Exception("map ran on edge with dst vid %d, which is odd".format(et.dstId))
+ }
+ Iterator((et.srcId, 1))
+ }, (a: Int, b: Int) => a + b, Some((active, EdgeDirection.In))).collect.toSet
+ assert(numEvenNeighbors === (1 to n).map(x => (x: VertexID, n / 2)).toSet)
+
+ // outerJoinVertices followed by mapReduceTriplets(activeSetOpt)
+ val ringEdges = sc.parallelize((0 until n).map(x => (x: VertexID, (x+1) % n: VertexID)), 3)
+ val ring = Graph.fromEdgeTuples(ringEdges, 0).mapVertices((vid, attr) => vid).cache()
+ val changed = ring.vertices.filter { case (vid, attr) => attr % 2 == 1 }.mapValues(-_).cache()
+ val changedGraph = ring.outerJoinVertices(changed) { (vid, old, newOpt) => newOpt.getOrElse(old) }
+ val numOddNeighbors = changedGraph.mapReduceTriplets(et => {
+ // Map function should only run on edges with source in the active set
+ if (et.srcId % 2 != 1) {
+ throw new Exception("map ran on edge with src vid %d, which is even".format(et.dstId))
+ }
+ Iterator((et.dstId, 1))
+ }, (a: Int, b: Int) => a + b, Some(changed, EdgeDirection.Out)).collect.toSet
+ assert(numOddNeighbors === (2 to n by 2).map(x => (x: VertexID, 1)).toSet)
+
+ }
+ }
+
+ test("outerJoinVertices") {
+ withSpark { sc =>
+ val n = 5
+ val reverseStar = starGraph(sc, n).reverse.cache()
+ // outerJoinVertices changing type
+ val reverseStarDegrees =
+ reverseStar.outerJoinVertices(reverseStar.outDegrees) { (vid, a, bOpt) => bOpt.getOrElse(0) }
+ val neighborDegreeSums = reverseStarDegrees.mapReduceTriplets(
+ et => Iterator((et.srcId, et.dstAttr), (et.dstId, et.srcAttr)),
+ (a: Int, b: Int) => a + b).collect.toSet
+ assert(neighborDegreeSums === Set((0: VertexID, n)) ++ (1 to n).map(x => (x: VertexID, 0)))
+ // outerJoinVertices preserving type
+ val messages = reverseStar.vertices.mapValues { (vid, attr) => vid.toString }
+ val newReverseStar =
+ reverseStar.outerJoinVertices(messages) { (vid, a, bOpt) => a + bOpt.getOrElse("") }
+ assert(newReverseStar.vertices.map(_._2).collect.toSet ===
+ (0 to n).map(x => "v%d".format(x)).toSet)
+ }
+ }
+
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala
new file mode 100644
index 0000000000..aa9ba84084
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala
@@ -0,0 +1,28 @@
+package org.apache.spark.graphx
+
+import org.scalatest.Suite
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkContext
+
+/**
+ * Provides a method to run tests against a [[org.apache.spark.SparkContext]] variable that is
+ * correctly stopped after each test.
+ */
+trait LocalSparkContext {
+ /** Runs `f` on a new SparkContext and ensures that it is stopped afterwards. */
+ def withSpark[T](f: SparkContext => T) = {
+ val conf = new SparkConf()
+ .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
+ val sc = new SparkContext("local", "test", conf)
+ try {
+ f(sc)
+ } finally {
+ sc.stop()
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.driver.port")
+ }
+ }
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala
new file mode 100644
index 0000000000..bceff11b8e
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala
@@ -0,0 +1,41 @@
+package org.apache.spark.graphx
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.rdd._
+
+class PregelSuite extends FunSuite with LocalSparkContext {
+
+ test("1 iteration") {
+ withSpark { sc =>
+ val n = 5
+ val starEdges = (1 to n).map(x => (0: VertexID, x: VertexID))
+ val star = Graph.fromEdgeTuples(sc.parallelize(starEdges, 3), "v").cache()
+ val result = Pregel(star, 0)(
+ (vid, attr, msg) => attr,
+ et => Iterator.empty,
+ (a: Int, b: Int) => throw new Exception("mergeMsg run unexpectedly"))
+ assert(result.vertices.collect.toSet === star.vertices.collect.toSet)
+ }
+ }
+
+ test("chain propagation") {
+ withSpark { sc =>
+ val n = 5
+ val chain = Graph.fromEdgeTuples(
+ sc.parallelize((1 until n).map(x => (x: VertexID, x + 1: VertexID)), 3),
+ 0).cache()
+ assert(chain.vertices.collect.toSet === (1 to n).map(x => (x: VertexID, 0)).toSet)
+ val chainWithSeed = chain.mapVertices { (vid, attr) => if (vid == 1) 1 else 0 }.cache()
+ assert(chainWithSeed.vertices.collect.toSet ===
+ Set((1: VertexID, 1)) ++ (2 to n).map(x => (x: VertexID, 0)).toSet)
+ val result = Pregel(chainWithSeed, 0)(
+ (vid, attr, msg) => math.max(msg, attr),
+ et => if (et.dstAttr != et.srcAttr) Iterator((et.dstId, et.srcAttr)) else Iterator.empty,
+ (a: Int, b: Int) => math.max(a, b))
+ assert(result.vertices.collect.toSet ===
+ chain.vertices.mapValues { (vid, attr) => attr + 1 }.collect.toSet)
+ }
+ }
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala
new file mode 100644
index 0000000000..3ba412c1f8
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala
@@ -0,0 +1,183 @@
+package org.apache.spark.graphx
+
+import java.io.{EOFException, ByteArrayInputStream, ByteArrayOutputStream}
+
+import scala.util.Random
+
+import org.scalatest.FunSuite
+
+import org.apache.spark._
+import org.apache.spark.graphx.impl._
+import org.apache.spark.graphx.impl.MsgRDDFunctions._
+import org.apache.spark.serializer.SerializationStream
+
+
+class SerializerSuite extends FunSuite with LocalSparkContext {
+
+ test("IntVertexBroadcastMsgSerializer") {
+ val conf = new SparkConf(false)
+ val outMsg = new VertexBroadcastMsg[Int](3, 4, 5)
+ val bout = new ByteArrayOutputStream
+ val outStrm = new IntVertexBroadcastMsgSerializer(conf).newInstance().serializeStream(bout)
+ outStrm.writeObject(outMsg)
+ outStrm.writeObject(outMsg)
+ bout.flush()
+ val bin = new ByteArrayInputStream(bout.toByteArray)
+ val inStrm = new IntVertexBroadcastMsgSerializer(conf).newInstance().deserializeStream(bin)
+ val inMsg1: VertexBroadcastMsg[Int] = inStrm.readObject()
+ val inMsg2: VertexBroadcastMsg[Int] = inStrm.readObject()
+ assert(outMsg.vid === inMsg1.vid)
+ assert(outMsg.vid === inMsg2.vid)
+ assert(outMsg.data === inMsg1.data)
+ assert(outMsg.data === inMsg2.data)
+
+ intercept[EOFException] {
+ inStrm.readObject()
+ }
+ }
+
+ test("LongVertexBroadcastMsgSerializer") {
+ val conf = new SparkConf(false)
+ val outMsg = new VertexBroadcastMsg[Long](3, 4, 5)
+ val bout = new ByteArrayOutputStream
+ val outStrm = new LongVertexBroadcastMsgSerializer(conf).newInstance().serializeStream(bout)
+ outStrm.writeObject(outMsg)
+ outStrm.writeObject(outMsg)
+ bout.flush()
+ val bin = new ByteArrayInputStream(bout.toByteArray)
+ val inStrm = new LongVertexBroadcastMsgSerializer(conf).newInstance().deserializeStream(bin)
+ val inMsg1: VertexBroadcastMsg[Long] = inStrm.readObject()
+ val inMsg2: VertexBroadcastMsg[Long] = inStrm.readObject()
+ assert(outMsg.vid === inMsg1.vid)
+ assert(outMsg.vid === inMsg2.vid)
+ assert(outMsg.data === inMsg1.data)
+ assert(outMsg.data === inMsg2.data)
+
+ intercept[EOFException] {
+ inStrm.readObject()
+ }
+ }
+
+ test("DoubleVertexBroadcastMsgSerializer") {
+ val conf = new SparkConf(false)
+ val outMsg = new VertexBroadcastMsg[Double](3, 4, 5.0)
+ val bout = new ByteArrayOutputStream
+ val outStrm = new DoubleVertexBroadcastMsgSerializer(conf).newInstance().serializeStream(bout)
+ outStrm.writeObject(outMsg)
+ outStrm.writeObject(outMsg)
+ bout.flush()
+ val bin = new ByteArrayInputStream(bout.toByteArray)
+ val inStrm = new DoubleVertexBroadcastMsgSerializer(conf).newInstance().deserializeStream(bin)
+ val inMsg1: VertexBroadcastMsg[Double] = inStrm.readObject()
+ val inMsg2: VertexBroadcastMsg[Double] = inStrm.readObject()
+ assert(outMsg.vid === inMsg1.vid)
+ assert(outMsg.vid === inMsg2.vid)
+ assert(outMsg.data === inMsg1.data)
+ assert(outMsg.data === inMsg2.data)
+
+ intercept[EOFException] {
+ inStrm.readObject()
+ }
+ }
+
+ test("IntAggMsgSerializer") {
+ val conf = new SparkConf(false)
+ val outMsg = (4: VertexID, 5)
+ val bout = new ByteArrayOutputStream
+ val outStrm = new IntAggMsgSerializer(conf).newInstance().serializeStream(bout)
+ outStrm.writeObject(outMsg)
+ outStrm.writeObject(outMsg)
+ bout.flush()
+ val bin = new ByteArrayInputStream(bout.toByteArray)
+ val inStrm = new IntAggMsgSerializer(conf).newInstance().deserializeStream(bin)
+ val inMsg1: (VertexID, Int) = inStrm.readObject()
+ val inMsg2: (VertexID, Int) = inStrm.readObject()
+ assert(outMsg === inMsg1)
+ assert(outMsg === inMsg2)
+
+ intercept[EOFException] {
+ inStrm.readObject()
+ }
+ }
+
+ test("LongAggMsgSerializer") {
+ val conf = new SparkConf(false)
+ val outMsg = (4: VertexID, 1L << 32)
+ val bout = new ByteArrayOutputStream
+ val outStrm = new LongAggMsgSerializer(conf).newInstance().serializeStream(bout)
+ outStrm.writeObject(outMsg)
+ outStrm.writeObject(outMsg)
+ bout.flush()
+ val bin = new ByteArrayInputStream(bout.toByteArray)
+ val inStrm = new LongAggMsgSerializer(conf).newInstance().deserializeStream(bin)
+ val inMsg1: (VertexID, Long) = inStrm.readObject()
+ val inMsg2: (VertexID, Long) = inStrm.readObject()
+ assert(outMsg === inMsg1)
+ assert(outMsg === inMsg2)
+
+ intercept[EOFException] {
+ inStrm.readObject()
+ }
+ }
+
+ test("DoubleAggMsgSerializer") {
+ val conf = new SparkConf(false)
+ val outMsg = (4: VertexID, 5.0)
+ val bout = new ByteArrayOutputStream
+ val outStrm = new DoubleAggMsgSerializer(conf).newInstance().serializeStream(bout)
+ outStrm.writeObject(outMsg)
+ outStrm.writeObject(outMsg)
+ bout.flush()
+ val bin = new ByteArrayInputStream(bout.toByteArray)
+ val inStrm = new DoubleAggMsgSerializer(conf).newInstance().deserializeStream(bin)
+ val inMsg1: (VertexID, Double) = inStrm.readObject()
+ val inMsg2: (VertexID, Double) = inStrm.readObject()
+ assert(outMsg === inMsg1)
+ assert(outMsg === inMsg2)
+
+ intercept[EOFException] {
+ inStrm.readObject()
+ }
+ }
+
+ test("TestShuffleVertexBroadcastMsg") {
+ withSpark { sc =>
+ val bmsgs = sc.parallelize(0 until 100, 10).map { pid =>
+ new VertexBroadcastMsg[Int](pid, pid, pid)
+ }
+ bmsgs.partitionBy(new HashPartitioner(3)).collect()
+ }
+ }
+
+ test("variable long encoding") {
+ def testVarLongEncoding(v: Long, optimizePositive: Boolean) {
+ val bout = new ByteArrayOutputStream
+ val stream = new ShuffleSerializationStream(bout) {
+ def writeObject[T](t: T): SerializationStream = {
+ writeVarLong(t.asInstanceOf[Long], optimizePositive = optimizePositive)
+ this
+ }
+ }
+ stream.writeObject(v)
+
+ val bin = new ByteArrayInputStream(bout.toByteArray)
+ val dstream = new ShuffleDeserializationStream(bin) {
+ def readObject[T](): T = {
+ readVarLong(optimizePositive).asInstanceOf[T]
+ }
+ }
+ val read = dstream.readObject[Long]()
+ assert(read === v)
+ }
+
+ // Test all variable encoding code paths (each branch uses 7 bits, i.e. a 1L << 7 difference)
+ val d = Random.nextLong() % 128
+ Seq[Long](0, (1L << 0) + d, (1L << 7) + d, (1L << 14) + d, (1L << 21) + d, (1L << 28) + d,
+ (1L << 35) + d, (1L << 42) + d, (1L << 49) + d, (1L << 56) + d, (1L << 63) + d).foreach { number =>
+ testVarLongEncoding(number, optimizePositive = false)
+ testVarLongEncoding(number, optimizePositive = true)
+ testVarLongEncoding(-number, optimizePositive = false)
+ testVarLongEncoding(-number, optimizePositive = true)
+ }
+ }
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
new file mode 100644
index 0000000000..d94a3aa67c
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala
@@ -0,0 +1,85 @@
+package org.apache.spark.graphx
+
+import org.apache.spark.SparkContext
+import org.apache.spark.graphx.Graph._
+import org.apache.spark.graphx.impl.EdgePartition
+import org.apache.spark.rdd._
+import org.scalatest.FunSuite
+
+class VertexRDDSuite extends FunSuite with LocalSparkContext {
+
+ def vertices(sc: SparkContext, n: Int) = {
+ VertexRDD(sc.parallelize((0 to n).map(x => (x.toLong, x)), 5))
+ }
+
+ test("filter") {
+ withSpark { sc =>
+ val n = 100
+ val verts = vertices(sc, n)
+ val evens = verts.filter(q => ((q._2 % 2) == 0))
+ assert(evens.count === (0 to n).filter(_ % 2 == 0).size)
+ }
+ }
+
+ test("mapValues") {
+ withSpark { sc =>
+ val n = 100
+ val verts = vertices(sc, n)
+ val negatives = verts.mapValues(x => -x).cache() // Allow joining b with a derived RDD of b
+ assert(negatives.count === n + 1)
+ }
+ }
+
+ test("diff") {
+ withSpark { sc =>
+ val n = 100
+ val verts = vertices(sc, n).cache()
+ val flipEvens = verts.mapValues(x => if (x % 2 == 0) -x else x).cache()
+ // diff should keep only the changed vertices
+ assert(verts.diff(flipEvens).map(_._2).collect().toSet === (2 to n by 2).map(-_).toSet)
+ // diff should keep the vertex values from `other`
+ assert(flipEvens.diff(verts).map(_._2).collect().toSet === (2 to n by 2).toSet)
+ }
+ }
+
+ test("leftJoin") {
+ withSpark { sc =>
+ val n = 100
+ val verts = vertices(sc, n).cache()
+ val evens = verts.filter(q => ((q._2 % 2) == 0)).cache()
+ // leftJoin with another VertexRDD
+ assert(verts.leftJoin(evens) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect.toSet ===
+ (0 to n by 2).map(x => (x.toLong, 0)).toSet ++ (1 to n by 2).map(x => (x.toLong, x)).toSet)
+ // leftJoin with an RDD
+ val evensRDD = evens.map(identity)
+ assert(verts.leftJoin(evensRDD) { (id, a, bOpt) => a - bOpt.getOrElse(0) }.collect.toSet ===
+ (0 to n by 2).map(x => (x.toLong, 0)).toSet ++ (1 to n by 2).map(x => (x.toLong, x)).toSet)
+ }
+ }
+
+ test("innerJoin") {
+ withSpark { sc =>
+ val n = 100
+ val verts = vertices(sc, n).cache()
+ val evens = verts.filter(q => ((q._2 % 2) == 0)).cache()
+ // innerJoin with another VertexRDD
+ assert(verts.innerJoin(evens) { (id, a, b) => a - b }.collect.toSet ===
+ (0 to n by 2).map(x => (x.toLong, 0)).toSet)
+ // innerJoin with an RDD
+ val evensRDD = evens.map(identity)
+ assert(verts.innerJoin(evensRDD) { (id, a, b) => a - b }.collect.toSet ===
+ (0 to n by 2).map(x => (x.toLong, 0)).toSet) }
+ }
+
+ test("aggregateUsingIndex") {
+ withSpark { sc =>
+ val n = 100
+ val verts = vertices(sc, n)
+ val messageTargets = (0 to n) ++ (0 to n by 2)
+ val messages = sc.parallelize(messageTargets.map(x => (x.toLong, 1)))
+ assert(verts.aggregateUsingIndex[Int](messages, _ + _).collect.toSet ===
+ (0 to n).map(x => (x.toLong, if (x % 2 == 0) 2 else 1)).toSet)
+ }
+ }
+
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
new file mode 100644
index 0000000000..eb82436f09
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
@@ -0,0 +1,76 @@
+package org.apache.spark.graphx.impl
+
+import scala.reflect.ClassTag
+import scala.util.Random
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.graphx._
+
+class EdgePartitionSuite extends FunSuite {
+
+ test("reverse") {
+ val edges = List(Edge(0, 1, 0), Edge(1, 2, 0), Edge(2, 0, 0))
+ val reversedEdges = List(Edge(0, 2, 0), Edge(1, 0, 0), Edge(2, 1, 0))
+ val builder = new EdgePartitionBuilder[Int]
+ for (e <- edges) {
+ builder.add(e.srcId, e.dstId, e.attr)
+ }
+ val edgePartition = builder.toEdgePartition
+ assert(edgePartition.reverse.iterator.map(_.copy()).toList === reversedEdges)
+ assert(edgePartition.reverse.reverse.iterator.map(_.copy()).toList === edges)
+ }
+
+ test("map") {
+ val edges = List(Edge(0, 1, 0), Edge(1, 2, 0), Edge(2, 0, 0))
+ val builder = new EdgePartitionBuilder[Int]
+ for (e <- edges) {
+ builder.add(e.srcId, e.dstId, e.attr)
+ }
+ val edgePartition = builder.toEdgePartition
+ assert(edgePartition.map(e => e.srcId + e.dstId).iterator.map(_.copy()).toList ===
+ edges.map(e => e.copy(attr = e.srcId + e.dstId)))
+ }
+
+ test("groupEdges") {
+ val edges = List(
+ Edge(0, 1, 1), Edge(1, 2, 2), Edge(2, 0, 4), Edge(0, 1, 8), Edge(1, 2, 16), Edge(2, 0, 32))
+ val groupedEdges = List(Edge(0, 1, 9), Edge(1, 2, 18), Edge(2, 0, 36))
+ val builder = new EdgePartitionBuilder[Int]
+ for (e <- edges) {
+ builder.add(e.srcId, e.dstId, e.attr)
+ }
+ val edgePartition = builder.toEdgePartition
+ assert(edgePartition.groupEdges(_ + _).iterator.map(_.copy()).toList === groupedEdges)
+ }
+
+ test("indexIterator") {
+ val edgesFrom0 = List(Edge(0, 1, 0))
+ val edgesFrom1 = List(Edge(1, 0, 0), Edge(1, 2, 0))
+ val sortedEdges = edgesFrom0 ++ edgesFrom1
+ val builder = new EdgePartitionBuilder[Int]
+ for (e <- Random.shuffle(sortedEdges)) {
+ builder.add(e.srcId, e.dstId, e.attr)
+ }
+
+ val edgePartition = builder.toEdgePartition
+ assert(edgePartition.iterator.map(_.copy()).toList === sortedEdges)
+ assert(edgePartition.indexIterator(_ == 0).map(_.copy()).toList === edgesFrom0)
+ assert(edgePartition.indexIterator(_ == 1).map(_.copy()).toList === edgesFrom1)
+ }
+
+ test("innerJoin") {
+ def makeEdgePartition[A: ClassTag](xs: Iterable[(Int, Int, A)]): EdgePartition[A] = {
+ val builder = new EdgePartitionBuilder[A]
+ for ((src, dst, attr) <- xs) { builder.add(src: VertexID, dst: VertexID, attr) }
+ builder.toEdgePartition
+ }
+ val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0))
+ val bList = List((0, 1, 0), (1, 0, 0), (1, 1, 0), (3, 4, 0), (5, 5, 0))
+ val a = makeEdgePartition(aList)
+ val b = makeEdgePartition(bList)
+
+ assert(a.innerJoin(b) { (src, dst, a, b) => a }.iterator.map(_.copy()).toList ===
+ List(Edge(0, 1, 0), Edge(1, 0, 0), Edge(5, 5, 0)))
+ }
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
new file mode 100644
index 0000000000..d37d64e8c8
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
@@ -0,0 +1,113 @@
+package org.apache.spark.graphx.impl
+
+import org.apache.spark.graphx._
+import org.scalatest.FunSuite
+
+class VertexPartitionSuite extends FunSuite {
+
+ test("isDefined, filter") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1))).filter { (vid, attr) => vid == 0 }
+ assert(vp.isDefined(0))
+ assert(!vp.isDefined(1))
+ assert(!vp.isDefined(2))
+ assert(!vp.isDefined(-1))
+ }
+
+ test("isActive, numActives, replaceActives") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1)))
+ .filter { (vid, attr) => vid == 0 }
+ .replaceActives(Iterator(0, 2, 0))
+ assert(vp.isActive(0))
+ assert(!vp.isActive(1))
+ assert(vp.isActive(2))
+ assert(!vp.isActive(-1))
+ assert(vp.numActives == Some(2))
+ }
+
+ test("map") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1))).map { (vid, attr) => 2 }
+ assert(vp(0) === 2)
+ }
+
+ test("diff") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+ val vp2 = vp.filter { (vid, attr) => vid <= 1 }
+ val vp3a = vp.map { (vid, attr) => 2 }
+ val vp3b = VertexPartition(vp3a.iterator)
+ // diff with same index
+ val diff1 = vp2.diff(vp3a)
+ assert(diff1(0) === 2)
+ assert(diff1(1) === 2)
+ assert(diff1(2) === 2)
+ assert(!diff1.isDefined(2))
+ // diff with different indexes
+ val diff2 = vp2.diff(vp3b)
+ assert(diff2(0) === 2)
+ assert(diff2(1) === 2)
+ assert(diff2(2) === 2)
+ assert(!diff2.isDefined(2))
+ }
+
+ test("leftJoin") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+ val vp2a = vp.filter { (vid, attr) => vid <= 1 }.map { (vid, attr) => 2 }
+ val vp2b = VertexPartition(vp2a.iterator)
+ // leftJoin with same index
+ val join1 = vp.leftJoin(vp2a) { (vid, a, bOpt) => bOpt.getOrElse(a) }
+ assert(join1.iterator.toSet === Set((0L, 2), (1L, 2), (2L, 1)))
+ // leftJoin with different indexes
+ val join2 = vp.leftJoin(vp2b) { (vid, a, bOpt) => bOpt.getOrElse(a) }
+ assert(join2.iterator.toSet === Set((0L, 2), (1L, 2), (2L, 1)))
+ // leftJoin an iterator
+ val join3 = vp.leftJoin(vp2a.iterator) { (vid, a, bOpt) => bOpt.getOrElse(a) }
+ assert(join3.iterator.toSet === Set((0L, 2), (1L, 2), (2L, 1)))
+ }
+
+ test("innerJoin") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+ val vp2a = vp.filter { (vid, attr) => vid <= 1 }.map { (vid, attr) => 2 }
+ val vp2b = VertexPartition(vp2a.iterator)
+ // innerJoin with same index
+ val join1 = vp.innerJoin(vp2a) { (vid, a, b) => b }
+ assert(join1.iterator.toSet === Set((0L, 2), (1L, 2)))
+ // innerJoin with different indexes
+ val join2 = vp.innerJoin(vp2b) { (vid, a, b) => b }
+ assert(join2.iterator.toSet === Set((0L, 2), (1L, 2)))
+ // innerJoin an iterator
+ val join3 = vp.innerJoin(vp2a.iterator) { (vid, a, b) => b }
+ assert(join3.iterator.toSet === Set((0L, 2), (1L, 2)))
+ }
+
+ test("createUsingIndex") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+ val elems = List((0L, 2), (2L, 2), (3L, 2))
+ val vp2 = vp.createUsingIndex(elems.iterator)
+ assert(vp2.iterator.toSet === Set((0L, 2), (2L, 2)))
+ assert(vp.index === vp2.index)
+ }
+
+ test("innerJoinKeepLeft") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+ val elems = List((0L, 2), (2L, 2), (3L, 2))
+ val vp2 = vp.innerJoinKeepLeft(elems.iterator)
+ assert(vp2.iterator.toSet === Set((0L, 2), (2L, 2)))
+ assert(vp2(1) === 1)
+ }
+
+ test("aggregateUsingIndex") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+ val messages = List((0L, "a"), (2L, "b"), (0L, "c"), (3L, "d"))
+ val vp2 = vp.aggregateUsingIndex[String](messages.iterator, _ + _)
+ assert(vp2.iterator.toSet === Set((0L, "ac"), (2L, "b")))
+ }
+
+ test("reindex") {
+ val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+ val vp2 = vp.filter { (vid, attr) => vid <= 1 }
+ val vp3 = vp2.reindex()
+ assert(vp2.iterator.toSet === vp3.iterator.toSet)
+ assert(vp2(2) === 1)
+ assert(vp3.index.getPos(2) === -1)
+ }
+
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala
new file mode 100644
index 0000000000..27c8705bca
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala
@@ -0,0 +1,113 @@
+package org.apache.spark.graphx.lib
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext._
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.util.GraphGenerators
+import org.apache.spark.rdd._
+
+
+class ConnectedComponentsSuite extends FunSuite with LocalSparkContext {
+
+ test("Grid Connected Components") {
+ withSpark { sc =>
+ val gridGraph = GraphGenerators.gridGraph(sc, 10, 10)
+ val ccGraph = gridGraph.connectedComponents()
+ val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum
+ assert(maxCCid === 0)
+ }
+ } // end of Grid connected components
+
+
+ test("Reverse Grid Connected Components") {
+ withSpark { sc =>
+ val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).reverse
+ val ccGraph = gridGraph.connectedComponents()
+ val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum
+ assert(maxCCid === 0)
+ }
+ } // end of Grid connected components
+
+
+ test("Chain Connected Components") {
+ withSpark { sc =>
+ val chain1 = (0 until 9).map(x => (x, x+1) )
+ val chain2 = (10 until 20).map(x => (x, x+1) )
+ val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
+ val twoChains = Graph.fromEdgeTuples(rawEdges, 1.0)
+ val ccGraph = twoChains.connectedComponents()
+ val vertices = ccGraph.vertices.collect()
+ for ( (id, cc) <- vertices ) {
+ if(id < 10) { assert(cc === 0) }
+ else { assert(cc === 10) }
+ }
+ val ccMap = vertices.toMap
+ for (id <- 0 until 20) {
+ if (id < 10) {
+ assert(ccMap(id) === 0)
+ } else {
+ assert(ccMap(id) === 10)
+ }
+ }
+ }
+ } // end of chain connected components
+
+ test("Reverse Chain Connected Components") {
+ withSpark { sc =>
+ val chain1 = (0 until 9).map(x => (x, x+1) )
+ val chain2 = (10 until 20).map(x => (x, x+1) )
+ val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
+ val twoChains = Graph.fromEdgeTuples(rawEdges, true).reverse
+ val ccGraph = twoChains.connectedComponents()
+ val vertices = ccGraph.vertices.collect
+ for ( (id, cc) <- vertices ) {
+ if (id < 10) {
+ assert(cc === 0)
+ } else {
+ assert(cc === 10)
+ }
+ }
+ val ccMap = vertices.toMap
+ for ( id <- 0 until 20 ) {
+ if (id < 10) {
+ assert(ccMap(id) === 0)
+ } else {
+ assert(ccMap(id) === 10)
+ }
+ }
+ }
+ } // end of reverse chain connected components
+
+ test("Connected Components on a Toy Connected Graph") {
+ withSpark { sc =>
+ // Create an RDD for the vertices
+ val users: RDD[(VertexID, (String, String))] =
+ sc.parallelize(Array((3L, ("rxin", "student")), (7L, ("jgonzal", "postdoc")),
+ (5L, ("franklin", "prof")), (2L, ("istoica", "prof")),
+ (4L, ("peter", "student"))))
+ // Create an RDD for edges
+ val relationships: RDD[Edge[String]] =
+ sc.parallelize(Array(Edge(3L, 7L, "collab"), Edge(5L, 3L, "advisor"),
+ Edge(2L, 5L, "colleague"), Edge(5L, 7L, "pi"),
+ Edge(4L, 0L, "student"), Edge(5L, 0L, "colleague")))
+ // Edges are:
+ //   2 ---> 5 ---> 3
+ //          | \
+ //          V   \|
+ //   4 ---> 0    7
+ //
+ // Define a default user in case there are relationships with a missing user
+ val defaultUser = ("John Doe", "Missing")
+ // Build the initial Graph
+ val graph = Graph(users, relationships, defaultUser)
+ val ccGraph = graph.connectedComponents()
+ val vertices = ccGraph.vertices.collect
+ for ( (id, cc) <- vertices ) {
+ assert(cc == 0)
+ }
+ }
+ } // end of toy connected components
+
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala
new file mode 100644
index 0000000000..fe7e4261f8
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala
@@ -0,0 +1,119 @@
+package org.apache.spark.graphx.lib
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext._
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.lib._
+import org.apache.spark.graphx.util.GraphGenerators
+import org.apache.spark.rdd._
+
+object GridPageRank {
+ def apply(nRows: Int, nCols: Int, nIter: Int, resetProb: Double) = {
+ val inNbrs = Array.fill(nRows * nCols)(collection.mutable.MutableList.empty[Int])
+ val outDegree = Array.fill(nRows * nCols)(0)
+ // Convert row column address into vertex ids (row major order)
+ def sub2ind(r: Int, c: Int): Int = r * nCols + c
+ // Make the grid graph
+ for (r <- 0 until nRows; c <- 0 until nCols) {
+ val ind = sub2ind(r,c)
+ if (r+1 < nRows) {
+ outDegree(ind) += 1
+ inNbrs(sub2ind(r+1,c)) += ind
+ }
+ if (c+1 < nCols) {
+ outDegree(ind) += 1
+ inNbrs(sub2ind(r,c+1)) += ind
+ }
+ }
+ // compute the pagerank
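+ // The update below is the reference PageRank power iteration:
+ // pr(i) = resetProb + (1 - resetProb) * sum over in-neighbors j of oldPr(j) / outDegree(j)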
+ var pr = Array.fill(nRows * nCols)(resetProb)
+ for (iter <- 0 until nIter) {
+ val oldPr = pr
+ pr = new Array[Double](nRows * nCols)
+ for (ind <- 0 until (nRows * nCols)) {
+ pr(ind) = resetProb + (1.0 - resetProb) *
+ inNbrs(ind).map( nbr => oldPr(nbr) / outDegree(nbr)).sum
+ }
+ }
+ (0L until (nRows * nCols)).zip(pr)
+ }
+
+}
+
+
+class PageRankSuite extends FunSuite with LocalSparkContext {
+
+ def compareRanks(a: VertexRDD[Double], b: VertexRDD[Double]): Double = {
+ a.leftJoin(b) { case (id, a, bOpt) => (a - bOpt.getOrElse(0.0)) * (a - bOpt.getOrElse(0.0)) }
+ .map { case (id, error) => error }.sum
+ }
+
+ test("Star PageRank") {
+ withSpark { sc =>
+ val nVertices = 100
+ val starGraph = GraphGenerators.starGraph(sc, nVertices).cache()
+ val resetProb = 0.15
+ val errorTol = 1.0e-5
+
+ val staticRanks1 = starGraph.staticPageRank(numIter = 1, resetProb).vertices
+ val staticRanks2 = starGraph.staticPageRank(numIter = 2, resetProb).vertices.cache()
+
+ // Static PageRank should only take 2 iterations to converge
+ val notMatching = staticRanks1.innerZipJoin(staticRanks2) { (vid, pr1, pr2) =>
+ if (pr1 != pr2) 1 else 0
+ }.map { case (vid, test) => test }.sum
+ assert(notMatching === 0)
+
+ val staticErrors = staticRanks2.map { case (vid, pr) =>
+ val correct = (vid > 0 && pr == resetProb) ||
+ (vid == 0 && math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * (nVertices - 1)) )) < 1.0E-5)
+ if (!correct) 1 else 0
+ }
+ assert(staticErrors.sum === 0)
+
+ val dynamicRanks = starGraph.pageRank(0, resetProb).vertices.cache()
+ assert(compareRanks(staticRanks2, dynamicRanks) < errorTol)
+ }
+ } // end of test Star PageRank
+
+
+
+ test("Grid PageRank") {
+ withSpark { sc =>
+ val rows = 10
+ val cols = 10
+ val resetProb = 0.15
+ val tol = 0.0001
+ val numIter = 50
+ val errorTol = 1.0e-5
+ val gridGraph = GraphGenerators.gridGraph(sc, rows, cols).cache()
+
+ val staticRanks = gridGraph.staticPageRank(numIter, resetProb).vertices.cache()
+ val dynamicRanks = gridGraph.pageRank(tol, resetProb).vertices.cache()
+ val referenceRanks = VertexRDD(sc.parallelize(GridPageRank(rows, cols, numIter, resetProb))).cache()
+
+ assert(compareRanks(staticRanks, referenceRanks) < errorTol)
+ assert(compareRanks(dynamicRanks, referenceRanks) < errorTol)
+ }
+ } // end of Grid PageRank
+
+
+ test("Chain PageRank") {
+ withSpark { sc =>
+ val chain1 = (0 until 9).map(x => (x, x+1) )
+ val rawEdges = sc.parallelize(chain1, 1).map { case (s,d) => (s.toLong, d.toLong) }
+ val chain = Graph.fromEdgeTuples(rawEdges, 1.0).cache()
+ val resetProb = 0.15
+ val tol = 0.0001
+ val numIter = 10
+ val errorTol = 1.0e-5
+
+ val staticRanks = chain.staticPageRank(numIter, resetProb).vertices
+ val dynamicRanks = chain.pageRank(tol, resetProb).vertices
+
+ assert(compareRanks(staticRanks, dynamicRanks) < errorTol)
+ }
+ }
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala
new file mode 100644
index 0000000000..e173c652a5
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala
@@ -0,0 +1,31 @@
+package org.apache.spark.graphx.lib
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext._
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.util.GraphGenerators
+import org.apache.spark.rdd._
+
+
+class SVDPlusPlusSuite extends FunSuite with LocalSparkContext {
+
+ test("Test SVD++ with mean square error on training set") {
+ withSpark { sc =>
+ val svdppErr = 8.0
+ val edges = sc.textFile("mllib/data/als/test.data").map { line =>
+ val fields = line.split(",")
+ Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble)
+ }
+ val conf = new SVDPlusPlus.Conf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations
+ var (graph, u) = SVDPlusPlus.run(edges, conf)
+ graph.cache()
+ val err = graph.vertices.collect.map{ case (vid, vd) =>
+ if (vid % 2 == 1) vd._4 else 0.0
+ }.reduce(_ + _) / graph.triplets.collect.size
+ assert(err <= svdppErr)
+ }
+ }
+
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala
new file mode 100644
index 0000000000..0458311661
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala
@@ -0,0 +1,57 @@
+package org.apache.spark.graphx.lib
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext._
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.util.GraphGenerators
+import org.apache.spark.rdd._
+
+
+class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext {
+
+ test("Island Strongly Connected Components") {
+ withSpark { sc =>
+ val vertices = sc.parallelize((1L to 5L).map(x => (x, -1)))
+ val edges = sc.parallelize(Seq.empty[Edge[Int]])
+ val graph = Graph(vertices, edges)
+ val sccGraph = graph.stronglyConnectedComponents(5)
+ for ((id, scc) <- sccGraph.vertices.collect) {
+ assert(id == scc)
+ }
+ }
+ }
+
+ test("Cycle Strongly Connected Components") {
+ withSpark { sc =>
+ val rawEdges = sc.parallelize((0L to 6L).map(x => (x, (x + 1) % 7)))
+ val graph = Graph.fromEdgeTuples(rawEdges, -1)
+ val sccGraph = graph.stronglyConnectedComponents(20)
+ for ((id, scc) <- sccGraph.vertices.collect) {
+ assert(0L == scc)
+ }
+ }
+ }
+
+ test("2 Cycle Strongly Connected Components") {
+ withSpark { sc =>
+ val edges =
+ Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++
+ Array(3L -> 4L, 4L -> 5L, 5L -> 3L) ++
+ Array(6L -> 0L, 5L -> 7L)
+ val rawEdges = sc.parallelize(edges)
+ val graph = Graph.fromEdgeTuples(rawEdges, -1)
+ val sccGraph = graph.stronglyConnectedComponents(20)
+ for ((id, scc) <- sccGraph.vertices.collect) {
+ if (id < 3)
+ assert(0L == scc)
+ else if (id < 6)
+ assert(3L == scc)
+ else
+ assert(id == scc)
+ }
+ }
+ }
+
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala
new file mode 100644
index 0000000000..3452ce9764
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala
@@ -0,0 +1,70 @@
+package org.apache.spark.graphx.lib
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.PartitionStrategy.RandomVertexCut
+
+
+class TriangleCountSuite extends FunSuite with LocalSparkContext {
+
+ test("Count a single triangle") {
+ withSpark { sc =>
+ val rawEdges = sc.parallelize(Array( 0L->1L, 1L->2L, 2L->0L ), 2)
+ val graph = Graph.fromEdgeTuples(rawEdges, true).cache()
+ val triangleCount = graph.triangleCount()
+ val verts = triangleCount.vertices
+ verts.collect.foreach { case (vid, count) => assert(count === 1) }
+ }
+ }
+
+ test("Count two triangles") {
+ withSpark { sc =>
+ val triangles = Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++
+ Array(0L -> -1L, -1L -> -2L, -2L -> 0L)
+ val rawEdges = sc.parallelize(triangles, 2)
+ val graph = Graph.fromEdgeTuples(rawEdges, true).cache()
+ val triangleCount = graph.triangleCount()
+ val verts = triangleCount.vertices
+ verts.collect().foreach { case (vid, count) =>
+ if (vid == 0) {
+ assert(count === 2)
+ } else {
+ assert(count === 1)
+ }
+ }
+ }
+ }
+
+ test("Count two triangles with bi-directed edges") {
+ withSpark { sc =>
+ val triangles =
+ Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++
+ Array(0L -> -1L, -1L -> -2L, -2L -> 0L)
+ val revTriangles = triangles.map { case (a,b) => (b,a) }
+ val rawEdges = sc.parallelize(triangles ++ revTriangles, 2)
+ val graph = Graph.fromEdgeTuples(rawEdges, true).cache()
+ val triangleCount = graph.triangleCount()
+ val verts = triangleCount.vertices
+ verts.collect().foreach { case (vid, count) =>
+ if (vid == 0) {
+ assert(count === 4)
+ } else {
+ assert(count === 2)
+ }
+ }
+ }
+ }
+
+ test("Count a single triangle with duplicate edges") {
+ withSpark { sc =>
+ val rawEdges = sc.parallelize(Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++
+ Array(0L -> 1L, 1L -> 2L, 2L -> 0L), 2)
+ val graph = Graph.fromEdgeTuples(rawEdges, true, uniqueEdges = Some(RandomVertexCut)).cache()
+ val triangleCount = graph.triangleCount()
+ val verts = triangleCount.vertices
+ verts.collect.foreach { case (vid, count) => assert(count === 1) }
+ }
+ }
+
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala
new file mode 100644
index 0000000000..11db339750
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala
@@ -0,0 +1,93 @@
+package org.apache.spark.graphx.util
+
+import org.scalatest.FunSuite
+
+
+class BytecodeUtilsSuite extends FunSuite {
+
+ import BytecodeUtilsSuite.TestClass
+
+ test("closure invokes a method") {
+ val c1 = {e: TestClass => println(e.foo); println(e.bar); println(e.baz); }
+ assert(BytecodeUtils.invokedMethod(c1, classOf[TestClass], "foo"))
+ assert(BytecodeUtils.invokedMethod(c1, classOf[TestClass], "bar"))
+ assert(BytecodeUtils.invokedMethod(c1, classOf[TestClass], "baz"))
+
+ val c2 = {e: TestClass => println(e.foo); println(e.bar); }
+ assert(BytecodeUtils.invokedMethod(c2, classOf[TestClass], "foo"))
+ assert(BytecodeUtils.invokedMethod(c2, classOf[TestClass], "bar"))
+ assert(!BytecodeUtils.invokedMethod(c2, classOf[TestClass], "baz"))
+
+ val c3 = {e: TestClass => println(e.foo); }
+ assert(BytecodeUtils.invokedMethod(c3, classOf[TestClass], "foo"))
+ assert(!BytecodeUtils.invokedMethod(c3, classOf[TestClass], "bar"))
+ assert(!BytecodeUtils.invokedMethod(c3, classOf[TestClass], "baz"))
+ }
+
+ test("closure inside a closure invokes a method") {
+ val c1 = {e: TestClass => println(e.foo); println(e.bar); println(e.baz); }
+ val c2 = {e: TestClass => c1(e); println(e.foo); }
+ assert(BytecodeUtils.invokedMethod(c2, classOf[TestClass], "foo"))
+ assert(BytecodeUtils.invokedMethod(c2, classOf[TestClass], "bar"))
+ assert(BytecodeUtils.invokedMethod(c2, classOf[TestClass], "baz"))
+ }
+
+ test("closure inside a closure inside a closure invokes a method") {
+ val c1 = {e: TestClass => println(e.baz); }
+ val c2 = {e: TestClass => c1(e); println(e.foo); }
+ val c3 = {e: TestClass => c2(e) }
+ assert(BytecodeUtils.invokedMethod(c3, classOf[TestClass], "foo"))
+ assert(!BytecodeUtils.invokedMethod(c3, classOf[TestClass], "bar"))
+ assert(BytecodeUtils.invokedMethod(c3, classOf[TestClass], "baz"))
+ }
+
+ test("closure calling a function that invokes a method") {
+ def zoo(e: TestClass) {
+ println(e.baz)
+ }
+ val c1 = {e: TestClass => zoo(e)}
+ assert(!BytecodeUtils.invokedMethod(c1, classOf[TestClass], "foo"))
+ assert(!BytecodeUtils.invokedMethod(c1, classOf[TestClass], "bar"))
+ assert(BytecodeUtils.invokedMethod(c1, classOf[TestClass], "baz"))
+ }
+
+ test("closure calling a function that invokes a method which uses another closure") {
+ val c2 = {e: TestClass => println(e.baz)}
+ def zoo(e: TestClass) {
+ c2(e)
+ }
+ val c1 = {e: TestClass => zoo(e)}
+ assert(!BytecodeUtils.invokedMethod(c1, classOf[TestClass], "foo"))
+ assert(!BytecodeUtils.invokedMethod(c1, classOf[TestClass], "bar"))
+ assert(BytecodeUtils.invokedMethod(c1, classOf[TestClass], "baz"))
+ }
+
+ test("nested closure") {
+ val c2 = {e: TestClass => println(e.baz)}
+ def zoo(e: TestClass, c: TestClass => Unit) {
+ c(e)
+ }
+ val c1 = {e: TestClass => zoo(e, c2)}
+ assert(!BytecodeUtils.invokedMethod(c1, classOf[TestClass], "foo"))
+ assert(!BytecodeUtils.invokedMethod(c1, classOf[TestClass], "bar"))
+ assert(BytecodeUtils.invokedMethod(c1, classOf[TestClass], "baz"))
+ }
+
+ // The following doesn't work yet, because the byte code doesn't contain any information
+ // about what exactly "c" is.
+// test("invoke interface") {
+// val c1 = {e: TestClass => c(e)}
+// assert(!BytecodeUtils.invokedMethod(c1, classOf[TestClass], "foo"))
+// assert(!BytecodeUtils.invokedMethod(c1, classOf[TestClass], "bar"))
+// assert(BytecodeUtils.invokedMethod(c1, classOf[TestClass], "baz"))
+// }
+
+ private val c = {e: TestClass => println(e.baz)}
+}
+
+
+object BytecodeUtilsSuite {
+ class TestClass(val foo: Int, val bar: Long) {
+ def baz: Boolean = false
+ }
+}
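
The suite above pins down what BytecodeUtils.invokedMethod can and cannot see through closures; a hedged sketch of the kind of check GraphX is assumed to make with it, deciding which triplet fields a user closure actually reads so that unused vertex attributes need not be shipped (the call site here is illustrative, not the exact one in the library):

    // illustrative closure over a triplet: srcAttr is read, dstAttr is not
    val mapFunc = (et: EdgeTriplet[Int, Int]) => et.srcAttr + et.attr
    val usesSrc = BytecodeUtils.invokedMethod(mapFunc, classOf[EdgeTriplet[Int, Int]], "srcAttr")  // expected true
    val usesDst = BytecodeUtils.invokedMethod(mapFunc, classOf[EdgeTriplet[Int, Int]], "dstAttr")  // expected false
    // ship src/dst vertex attributes only when the corresponding flag is true
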
diff --git a/mllib/data/sample_naive_bayes_data.txt b/mllib/data/sample_naive_bayes_data.txt
new file mode 100644
index 0000000000..f874adbaf4
--- /dev/null
+++ b/mllib/data/sample_naive_bayes_data.txt
@@ -0,0 +1,6 @@
+0, 1 0 0
+0, 2 0 0
+1, 0 1 0
+1, 0 2 0
+2, 0 0 1
+2, 0 0 2
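
Each row above is "label, f1 f2 f3"; a minimal sketch of parsing such a row into a LabeledPoint, mirroring what MLUtils.loadLabeledData is assumed to do with this layout:

    // assumption: a comma separates the label from a whitespace-separated feature vector
    def parse(line: String): LabeledPoint = {
      val Array(label, features) = line.split(",", 2)
      LabeledPoint(label.trim.toDouble, features.trim.split("\\s+").map(_.toDouble))
    }
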
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 2d8623392e..3fec1a909d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -48,7 +48,7 @@ class PythonMLLibAPI extends Serializable {
val db = bb.asDoubleBuffer()
val ans = new Array[Double](length.toInt)
db.get(ans)
- return ans
+ ans
}
private def serializeDoubleVector(doubles: Array[Double]): Array[Byte] = {
@@ -60,7 +60,7 @@ class PythonMLLibAPI extends Serializable {
bb.putLong(len)
val db = bb.asDoubleBuffer()
db.put(doubles)
- return bytes
+ bytes
}
private def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = {
@@ -86,7 +86,7 @@ class PythonMLLibAPI extends Serializable {
ans(i) = new Array[Double](cols.toInt)
db.get(ans(i))
}
- return ans
+ ans
}
private def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = {
@@ -102,11 +102,10 @@ class PythonMLLibAPI extends Serializable {
bb.putLong(rows)
bb.putLong(cols)
val db = bb.asDoubleBuffer()
- var i = 0
for (i <- 0 until rows) {
db.put(doubles(i))
}
- return bytes
+ bytes
}
private def trainRegressionModel(trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel,
@@ -121,7 +120,7 @@ class PythonMLLibAPI extends Serializable {
val ret = new java.util.LinkedList[java.lang.Object]()
ret.add(serializeDoubleVector(model.weights))
ret.add(model.intercept: java.lang.Double)
- return ret
+ ret
}
/**
@@ -130,7 +129,7 @@ class PythonMLLibAPI extends Serializable {
def trainLinearRegressionModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]],
numIterations: Int, stepSize: Double, miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
- return trainRegressionModel((data, initialWeights) =>
+ trainRegressionModel((data, initialWeights) =>
LinearRegressionWithSGD.train(data, numIterations, stepSize,
miniBatchFraction, initialWeights),
dataBytesJRDD, initialWeightsBA)
@@ -142,7 +141,7 @@ class PythonMLLibAPI extends Serializable {
def trainLassoModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
stepSize: Double, regParam: Double, miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
- return trainRegressionModel((data, initialWeights) =>
+ trainRegressionModel((data, initialWeights) =>
LassoWithSGD.train(data, numIterations, stepSize, regParam,
miniBatchFraction, initialWeights),
dataBytesJRDD, initialWeightsBA)
@@ -154,7 +153,7 @@ class PythonMLLibAPI extends Serializable {
def trainRidgeModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
stepSize: Double, regParam: Double, miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
- return trainRegressionModel((data, initialWeights) =>
+ trainRegressionModel((data, initialWeights) =>
RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam,
miniBatchFraction, initialWeights),
dataBytesJRDD, initialWeightsBA)
@@ -166,7 +165,7 @@ class PythonMLLibAPI extends Serializable {
def trainSVMModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
stepSize: Double, regParam: Double, miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
- return trainRegressionModel((data, initialWeights) =>
+ trainRegressionModel((data, initialWeights) =>
SVMWithSGD.train(data, numIterations, stepSize, regParam,
miniBatchFraction, initialWeights),
dataBytesJRDD, initialWeightsBA)
@@ -178,13 +177,30 @@ class PythonMLLibAPI extends Serializable {
def trainLogisticRegressionModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]],
numIterations: Int, stepSize: Double, miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
- return trainRegressionModel((data, initialWeights) =>
+ trainRegressionModel((data, initialWeights) =>
LogisticRegressionWithSGD.train(data, numIterations, stepSize,
miniBatchFraction, initialWeights),
dataBytesJRDD, initialWeightsBA)
}
/**
+ * Java stub for NaiveBayes.train()
+ */
+ def trainNaiveBayes(dataBytesJRDD: JavaRDD[Array[Byte]], lambda: Double)
+ : java.util.List[java.lang.Object] =
+ {
+ val data = dataBytesJRDD.rdd.map(xBytes => {
+ val x = deserializeDoubleVector(xBytes)
+ LabeledPoint(x(0), x.slice(1, x.length))
+ })
+ val model = NaiveBayes.train(data, lambda)
+ val ret = new java.util.LinkedList[java.lang.Object]()
+ ret.add(serializeDoubleVector(model.pi))
+ ret.add(serializeDoubleMatrix(model.theta))
+ ret
+ }
+
+ /**
* Java stub for Python mllib KMeans.train()
*/
def trainKMeansModel(dataBytesJRDD: JavaRDD[Array[Byte]], k: Int,
@@ -194,7 +210,7 @@ class PythonMLLibAPI extends Serializable {
val model = KMeans.train(data, k, maxIterations, runs, initializationMode)
val ret = new java.util.LinkedList[java.lang.Object]()
ret.add(serializeDoubleMatrix(model.clusterCenters))
- return ret
+ ret
}
/** Unpack a Rating object from an array of bytes */
@@ -204,7 +220,7 @@ class PythonMLLibAPI extends Serializable {
val user = bb.getInt()
val product = bb.getInt()
val rating = bb.getDouble()
- return new Rating(user, product, rating)
+ new Rating(user, product, rating)
}
/** Unpack a tuple of Ints from an array of bytes */
@@ -245,7 +261,7 @@ class PythonMLLibAPI extends Serializable {
def trainALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int,
iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = {
val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
- return ALS.train(ratings, rank, iterations, lambda, blocks)
+ ALS.train(ratings, rank, iterations, lambda, blocks)
}
/**
@@ -257,6 +273,6 @@ class PythonMLLibAPI extends Serializable {
def trainImplicitALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int,
iterations: Int, lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = {
val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
- return ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha)
+ ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 50aede9c07..a481f52276 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -97,7 +97,7 @@ object LogisticRegressionWithSGD {
* @param numIterations Number of iterations of gradient descent to run.
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param miniBatchFraction Fraction of data to be used per iteration.
- * @param initialWeights Initial set of weights to be used. Array should be equal in size to
+ * @param initialWeights Initial set of weights to be used. Array should be equal in size to
* the number of features in the data.
*/
def train(
@@ -183,6 +183,8 @@ object LogisticRegressionWithSGD {
val sc = new SparkContext(args(0), "LogisticRegression")
val data = MLUtils.loadLabeledData(sc, args(1))
val model = LogisticRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble)
+ println("Weights: " + model.weights.mkString("[", ", ", "]"))
+ println("Intercept: " + model.intercept)
sc.stop()
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 524300d6ae..6539b2f339 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -21,17 +21,18 @@ import scala.collection.mutable
import org.jblas.DoubleMatrix
-import org.apache.spark.Logging
+import org.apache.spark.{SparkContext, Logging}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.util.MLUtils
/**
* Model for Naive Bayes Classifiers.
*
* @param pi Log of class priors, whose dimension is C.
- * @param theta Log of class conditional probabilities, whose dimension is CXD.
+ * @param theta Log of class conditional probabilities, whose dimension is CxD.
*/
-class NaiveBayesModel(pi: Array[Double], theta: Array[Array[Double]])
+class NaiveBayesModel(val pi: Array[Double], val theta: Array[Array[Double]])
extends ClassificationModel with Serializable {
// Create a column vector that can be used for predictions
@@ -50,10 +51,21 @@ class NaiveBayesModel(pi: Array[Double], theta: Array[Array[Double]])
/**
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
*
- * @param lambda The smooth parameter
+ * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
+ * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
+ * document classification. By making every vector a 0-1 vector, it can also be used as
+ * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
*/
-class NaiveBayes private (val lambda: Double = 1.0)
- extends Serializable with Logging {
+class NaiveBayes private (var lambda: Double)
+ extends Serializable with Logging
+{
+ def this() = this(1.0)
+
+ /** Set the smoothing parameter. Default: 1.0. */
+ def setLambda(lambda: Double): NaiveBayes = {
+ this.lambda = lambda
+ this
+ }
/**
* Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
@@ -106,14 +118,49 @@ object NaiveBayes {
*
* This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
* discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
- * document classification. By making every vector a 0-1 vector. it can also be used as
+ * document classification. By making every vector a 0-1 vector, it can also be used as
+ * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
+ *
+ * This version of the method uses a default smoothing parameter of 1.0.
+ *
+ * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
+ * vector or a count vector.
+ */
+ def train(input: RDD[LabeledPoint]): NaiveBayesModel = {
+ new NaiveBayes().run(input)
+ }
+
+ /**
+ * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
+ *
+ * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
+ * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
+ * document classification. By making every vector a 0-1 vector, it can also be used as
* Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
*
* @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
* vector or a count vector.
- * @param lambda The smooth parameter
+ * @param lambda The smoothing parameter
*/
- def train(input: RDD[LabeledPoint], lambda: Double = 1.0): NaiveBayesModel = {
+ def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
new NaiveBayes(lambda).run(input)
}
+
+ def main(args: Array[String]) {
+ if (args.length != 2 && args.length != 3) {
+ println("Usage: NaiveBayes <master> <input_dir> [<lambda>]")
+ System.exit(1)
+ }
+ val sc = new SparkContext(args(0), "NaiveBayes")
+ val data = MLUtils.loadLabeledData(sc, args(1))
+ val model = if (args.length == 2) {
+ NaiveBayes.train(data)
+ } else {
+ NaiveBayes.train(data, args(2).toDouble)
+ }
+ println("Pi: " + model.pi.mkString("[", ", ", "]"))
+ println("Theta:\n" + model.theta.map(_.mkString("[", ", ", "]")).mkString("[", "\n ", "]"))
+
+ sc.stop()
+ }
}
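
A minimal usage sketch for the API added above, combining the builder-style setLambda with the static train helpers (the existing SparkContext and the data path are assumptions):

    // assumption: sc is an existing SparkContext; the file follows the "label, f1 f2 f3" layout
    val points = MLUtils.loadLabeledData(sc, "mllib/data/sample_naive_bayes_data.txt")
    val tuned   = new NaiveBayes().setLambda(0.5).run(points)  // builder style
    val default = NaiveBayes.train(points)                     // static helper, lambda defaults to 1.0
    println(tuned.predict(Array(0.0, 1.0, 0.0)))               // most likely class for one feature vector
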
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index 3b8f8550d0..f2964ea446 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -183,6 +183,8 @@ object SVMWithSGD {
val sc = new SparkContext(args(0), "SVM")
val data = MLUtils.loadLabeledData(sc, args(1))
val model = SVMWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
+ println("Weights: " + model.weights.mkString("[", ", ", "]"))
+ println("Intercept: " + model.intercept)
sc.stop()
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 8b27ecf82c..89ee07063d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -22,7 +22,7 @@ import scala.util.Random
import scala.util.Sorting
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.{Logging, HashPartitioner, Partitioner, SparkContext}
+import org.apache.spark.{Logging, HashPartitioner, Partitioner, SparkContext, SparkConf}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.KryoRegistrator
@@ -578,12 +578,13 @@ object ALS {
val implicitPrefs = if (args.length >= 7) args(6).toBoolean else false
val alpha = if (args.length >= 8) args(7).toDouble else 1
val blocks = if (args.length == 9) args(8).toInt else -1
- val sc = new SparkContext(master, "ALS")
- sc.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- sc.conf.set("spark.kryo.registrator", classOf[ALSRegistrator].getName)
- sc.conf.set("spark.kryo.referenceTracking", "false")
- sc.conf.set("spark.kryoserializer.buffer.mb", "8")
- sc.conf.set("spark.locality.wait", "10000")
+ val conf = new SparkConf()
+ .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ .set("spark.kryo.registrator", classOf[ALSRegistrator].getName)
+ .set("spark.kryo.referenceTracking", "false")
+ .set("spark.kryoserializer.buffer.mb", "8")
+ .set("spark.locality.wait", "10000")
+ val sc = new SparkContext(master, "ALS", conf)
val ratings = sc.textFile(ratingsFile).map { line =>
val fields = line.split(',')
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
index 63240e24dc..1a18292fe3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
@@ -23,4 +23,8 @@ package org.apache.spark.mllib.regression
* @param label Label for this data point.
* @param features List of features for this data point.
*/
-case class LabeledPoint(val label: Double, val features: Array[Double])
+case class LabeledPoint(label: Double, features: Array[Double]) {
+ override def toString: String = {
+ "LabeledPoint(%s, %s)".format(label, features.mkString("[", ", ", "]"))
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index d959695325..7c41793722 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -121,7 +121,7 @@ object LassoWithSGD {
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param regParam Regularization parameter.
* @param miniBatchFraction Fraction of data to be used per iteration.
- * @param initialWeights Initial set of weights to be used. Array should be equal in size to
+ * @param initialWeights Initial set of weights to be used. Array should be equal in size to
* the number of features in the data.
*/
def train(
@@ -205,6 +205,8 @@ object LassoWithSGD {
val sc = new SparkContext(args(0), "Lasso")
val data = MLUtils.loadLabeledData(sc, args(1))
val model = LassoWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
+ println("Weights: " + model.weights.mkString("[", ", ", "]"))
+ println("Intercept: " + model.intercept)
sc.stop()
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index 597d55e0bb..fe5cce064b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -162,6 +162,8 @@ object LinearRegressionWithSGD {
val sc = new SparkContext(args(0), "LinearRegression")
val data = MLUtils.loadLabeledData(sc, args(1))
val model = LinearRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble)
+ println("Weights: " + model.weights.mkString("[", ", ", "]"))
+ println("Intercept: " + model.intercept)
sc.stop()
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index b29508d2b9..c125c6797a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -122,7 +122,7 @@ object RidgeRegressionWithSGD {
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param regParam Regularization parameter.
* @param miniBatchFraction Fraction of data to be used per iteration.
- * @param initialWeights Initial set of weights to be used. Array should be equal in size to
+ * @param initialWeights Initial set of weights to be used. Array should be equal in size to
* the number of features in the data.
*/
def train(
@@ -208,6 +208,8 @@ object RidgeRegressionWithSGD {
val data = MLUtils.loadLabeledData(sc, args(1))
val model = RidgeRegressionWithSGD.train(data, args(4).toInt, args(2).toDouble,
args(3).toDouble)
+ println("Weights: " + model.weights.mkString("[", ", ", "]"))
+ println("Intercept: " + model.intercept)
sc.stop()
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
new file mode 100644
index 0000000000..23ea3548b9
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
@@ -0,0 +1,72 @@
+package org.apache.spark.mllib.classification;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+public class JavaNaiveBayesSuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaNaiveBayesSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ System.clearProperty("spark.driver.port");
+ }
+
+ private static final List<LabeledPoint> POINTS = Arrays.asList(
+ new LabeledPoint(0, new double[] {1.0, 0.0, 0.0}),
+ new LabeledPoint(0, new double[] {2.0, 0.0, 0.0}),
+ new LabeledPoint(1, new double[] {0.0, 1.0, 0.0}),
+ new LabeledPoint(1, new double[] {0.0, 2.0, 0.0}),
+ new LabeledPoint(2, new double[] {0.0, 0.0, 1.0}),
+ new LabeledPoint(2, new double[] {0.0, 0.0, 2.0})
+ );
+
+ private int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) {
+ int correct = 0;
+ for (LabeledPoint p: points) {
+ if (model.predict(p.features()) == p.label()) {
+ correct += 1;
+ }
+ }
+ return correct;
+ }
+
+ @Test
+ public void runUsingConstructor() {
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();
+
+ NaiveBayes nb = new NaiveBayes().setLambda(1.0);
+ NaiveBayesModel model = nb.run(testRDD.rdd());
+
+ int numAccurate = validatePrediction(POINTS, model);
+ Assert.assertEquals(POINTS.size(), numAccurate);
+ }
+
+ @Test
+ public void runUsingStaticMethods() {
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();
+
+ NaiveBayesModel model1 = NaiveBayes.train(testRDD.rdd());
+ int numAccurate1 = validatePrediction(POINTS, model1);
+ Assert.assertEquals(POINTS.size(), numAccurate1);
+
+ NaiveBayesModel model2 = NaiveBayes.train(testRDD.rdd(), 0.5);
+ int numAccurate2 = validatePrediction(POINTS, model2);
+ Assert.assertEquals(POINTS.size(), numAccurate2);
+ }
+}
diff --git a/pom.xml b/pom.xml
index 6e2dd33d49..b25d9d7ef8 100644
--- a/pom.xml
+++ b/pom.xml
@@ -87,6 +87,7 @@
<modules>
<module>core</module>
<module>bagel</module>
+ <module>graphx</module>
<module>mllib</module>
<module>tools</module>
<module>streaming</module>
@@ -123,7 +124,7 @@
</properties>
<repositories>
- <repository>
+ <repository>
<id>maven-repo</id> <!-- This should be at top, it makes maven try the central repo first and then others and hence faster dep resolution -->
<name>Maven Repository</name>
<url>http://repo.maven.apache.org/maven2</url>
@@ -206,7 +207,7 @@
</dependency>
<!-- In theory we need not directly depend on protobuf since Spark does not directly
use it. However, when building with Hadoop/YARN 2.2 Maven doesn't correctly bump
- the protobuf version up from the one Mesos gives. For now we include this variable
+ the protobuf version up from the one Mesos gives. For now we include this variable
to explicitly bump the version when building with YARN. It would be nice to figure
out why Maven can't resolve this correctly (like SBT does). -->
<dependency>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index c8b5f09ab5..a9f9937cb1 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -48,18 +48,20 @@ object SparkBuild extends Build {
lazy val core = Project("core", file("core"), settings = coreSettings)
lazy val repl = Project("repl", file("repl"), settings = replSettings)
- .dependsOn(core, bagel, mllib)
+ .dependsOn(core, graphx, bagel, mllib)
lazy val tools = Project("tools", file("tools"), settings = toolsSettings) dependsOn(core) dependsOn(streaming)
lazy val bagel = Project("bagel", file("bagel"), settings = bagelSettings) dependsOn(core)
+ lazy val graphx = Project("graphx", file("graphx"), settings = graphxSettings) dependsOn(core)
+
lazy val streaming = Project("streaming", file("streaming"), settings = streamingSettings) dependsOn(core)
lazy val mllib = Project("mllib", file("mllib"), settings = mllibSettings) dependsOn(core)
lazy val assemblyProj = Project("assembly", file("assembly"), settings = assemblyProjSettings)
- .dependsOn(core, bagel, mllib, repl, streaming) dependsOn(maybeYarn: _*)
+ .dependsOn(core, graphx, bagel, mllib, repl, streaming) dependsOn(maybeYarn: _*)
lazy val assembleDeps = TaskKey[Unit]("assemble-deps", "Build assembly of dependencies and packages Spark projects")
@@ -109,10 +111,10 @@ object SparkBuild extends Build {
lazy val allExternalRefs = Seq[ProjectReference](externalTwitter, externalKafka, externalFlume, externalZeromq, externalMqtt)
lazy val examples = Project("examples", file("examples"), settings = examplesSettings)
- .dependsOn(core, mllib, bagel, streaming, externalTwitter) dependsOn(allExternal: _*)
+ .dependsOn(core, mllib, graphx, bagel, streaming, externalTwitter) dependsOn(allExternal: _*)
// Everything except assembly, tools and examples belong to packageProjects
- lazy val packageProjects = Seq[ProjectReference](core, repl, bagel, streaming, mllib) ++ maybeYarnRef
+ lazy val packageProjects = Seq[ProjectReference](core, repl, bagel, streaming, mllib, graphx) ++ maybeYarnRef
lazy val allProjects = packageProjects ++ allExternalRefs ++ Seq[ProjectReference](examples, tools, assemblyProj)
@@ -136,6 +138,13 @@ object SparkBuild extends Build {
javaOptions += "-Xmx3g",
// Show full stack trace and duration in test cases.
testOptions in Test += Tests.Argument("-oDF"),
+ // Remove certain packages from Scaladoc
+ scalacOptions in (Compile,doc) := Seq("-skip-packages", Seq(
+ "akka",
+ "org.apache.spark.network",
+ "org.apache.spark.deploy",
+ "org.apache.spark.util.collection"
+ ).mkString(":")),
// Only allow one test at a time, even across projects, since they run in the same JVM
concurrentRestrictions in Global += Tags.limit(Tags.Test, 1),
@@ -307,6 +316,10 @@ object SparkBuild extends Build {
name := "spark-tools"
) ++ assemblySettings ++ extraAssemblySettings
+ def graphxSettings = sharedSettings ++ Seq(
+ name := "spark-graphx"
+ )
+
def bagelSettings = sharedSettings ++ Seq(
name := "spark-bagel"
)
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index 769d88dfb9..20a0e309d1 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -16,7 +16,7 @@
#
from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape
-from pyspark import SparkContext
+from pyspark import SparkContext, RDD
from pyspark.serializers import Serializer
import struct
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index 70de332d34..19b90dfd6e 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -15,6 +15,8 @@
# limitations under the License.
#
+import numpy
+
from numpy import array, dot, shape
from pyspark import SparkContext
from pyspark.mllib._common import \
@@ -29,8 +31,8 @@ class LogisticRegressionModel(LinearModel):
"""A linear binary classification model derived from logistic regression.
>>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0]).reshape(4,2)
- >>> lrm = LogisticRegressionWithSGD.train(sc, sc.parallelize(data))
- >>> lrm.predict(array([1.0])) != None
+ >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(data))
+ >>> lrm.predict(array([1.0])) > 0
True
"""
def predict(self, x):
@@ -41,20 +43,21 @@ class LogisticRegressionModel(LinearModel):
class LogisticRegressionWithSGD(object):
@classmethod
- def train(cls, sc, data, iterations=100, step=1.0,
- mini_batch_fraction=1.0, initial_weights=None):
+ def train(cls, data, iterations=100, step=1.0,
+ miniBatchFraction=1.0, initialWeights=None):
"""Train a logistic regression model on the given data."""
+ sc = data.context
return _regression_train_wrapper(sc, lambda d, i:
sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD(d._jrdd,
- iterations, step, mini_batch_fraction, i),
- LogisticRegressionModel, data, initial_weights)
+ iterations, step, miniBatchFraction, i),
+ LogisticRegressionModel, data, initialWeights)
class SVMModel(LinearModel):
"""A support vector machine.
>>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0]).reshape(4,2)
- >>> svm = SVMWithSGD.train(sc, sc.parallelize(data))
- >>> svm.predict(array([1.0])) != None
+ >>> svm = SVMWithSGD.train(sc.parallelize(data))
+ >>> svm.predict(array([1.0])) > 0
True
"""
def predict(self, x):
@@ -64,13 +67,63 @@ class SVMModel(LinearModel):
class SVMWithSGD(object):
@classmethod
- def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0,
- mini_batch_fraction=1.0, initial_weights=None):
+ def train(cls, data, iterations=100, step=1.0, regParam=1.0,
+ miniBatchFraction=1.0, initialWeights=None):
"""Train a support vector machine on the given data."""
+ sc = data.context
return _regression_train_wrapper(sc, lambda d, i:
sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD(d._jrdd,
- iterations, step, reg_param, mini_batch_fraction, i),
- SVMModel, data, initial_weights)
+ iterations, step, regParam, miniBatchFraction, i),
+ SVMModel, data, initialWeights)
+
+class NaiveBayesModel(object):
+ """
+ Model for Naive Bayes classifiers.
+
+ Contains two parameters:
+ - pi: vector of logs of class priors (dimension C)
+ - theta: matrix of logs of class conditional probabilities (CxD)
+
+ >>> data = array([0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 1.0, 1.0, 0.0]).reshape(3,3)
+ >>> model = NaiveBayes.train(sc.parallelize(data))
+ >>> model.predict(array([0.0, 1.0]))
+ 0
+ >>> model.predict(array([1.0, 0.0]))
+ 1
+ """
+
+ def __init__(self, pi, theta):
+ self.pi = pi
+ self.theta = theta
+
+ def predict(self, x):
+ """Return the most likely class for a data vector x"""
+ return numpy.argmax(self.pi + dot(x, self.theta))
+
+class NaiveBayes(object):
+ @classmethod
+ def train(cls, data, lambda_=1.0):
+ """
+ Train a Naive Bayes model given an RDD of (label, features) vectors.
+
+ This is the Multinomial NB (U{http://tinyurl.com/lsdw6p}) which can
+ handle all kinds of discrete data. For example, by converting
+ documents into TF-IDF vectors, it can be used for document
+ classification. By making every vector a 0-1 vector, it can also be
+ used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}).
+
+ @param data: RDD of NumPy vectors, one per element, where the first
+ coordinate is the label and the rest is the feature vector
+ (e.g. a count vector).
+ @param lambda_: The smoothing parameter
+ """
+ sc = data.context
+ dataBytes = _get_unmangled_double_vector_rdd(data)
+ ans = sc._jvm.PythonMLLibAPI().trainNaiveBayes(dataBytes._jrdd, lambda_)
+ return NaiveBayesModel(
+ _deserialize_double_vector(ans[0]),
+ _deserialize_double_matrix(ans[1]))
+
def _test():
import doctest
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 8cf20e591a..30862918c3 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -28,12 +28,12 @@ class KMeansModel(object):
"""A clustering model derived from the k-means method.
>>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2)
- >>> clusters = KMeans.train(sc, sc.parallelize(data), 2, maxIterations=10, runs=30, initialization_mode="random")
+ >>> clusters = KMeans.train(sc.parallelize(data), 2, maxIterations=10, runs=30, initializationMode="random")
>>> clusters.predict(array([0.0, 0.0])) == clusters.predict(array([1.0, 1.0]))
True
>>> clusters.predict(array([8.0, 9.0])) == clusters.predict(array([9.0, 8.0]))
True
- >>> clusters = KMeans.train(sc, sc.parallelize(data), 2)
+ >>> clusters = KMeans.train(sc.parallelize(data), 2)
"""
def __init__(self, centers_):
self.centers = centers_
@@ -52,12 +52,13 @@ class KMeansModel(object):
class KMeans(object):
@classmethod
- def train(cls, sc, data, k, maxIterations=100, runs=1,
- initialization_mode="k-means||"):
+ def train(cls, data, k, maxIterations=100, runs=1,
+ initializationMode="k-means||"):
"""Train a k-means clustering model."""
+ sc = data.context
dataBytes = _get_unmangled_double_vector_rdd(data)
ans = sc._jvm.PythonMLLibAPI().trainKMeansModel(dataBytes._jrdd,
- k, maxIterations, runs, initialization_mode)
+ k, maxIterations, runs, initializationMode)
if len(ans) != 1:
raise RuntimeError("JVM call result had unexpected length")
elif type(ans[0]) != bytearray:
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 0eeb5bb66b..f4a83f0209 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -32,11 +32,11 @@ class MatrixFactorizationModel(object):
>>> r2 = (1, 2, 2.0)
>>> r3 = (2, 1, 2.0)
>>> ratings = sc.parallelize([r1, r2, r3])
- >>> model = ALS.trainImplicit(sc, ratings, 1)
+ >>> model = ALS.trainImplicit(ratings, 1)
>>> model.predict(2,2) is not None
True
>>> testset = sc.parallelize([(1, 2), (1, 1)])
- >>> model.predictAll(testset).count == 2
+ >>> model.predictAll(testset).count() == 2
True
"""
@@ -57,14 +57,16 @@ class MatrixFactorizationModel(object):
class ALS(object):
@classmethod
- def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
+ def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
+ sc = ratings.context
ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
mod = sc._jvm.PythonMLLibAPI().trainALSModel(ratingBytes._jrdd,
rank, iterations, lambda_, blocks)
return MatrixFactorizationModel(sc, mod)
@classmethod
- def trainImplicit(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01):
+ def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01):
+ sc = ratings.context
ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
mod = sc._jvm.PythonMLLibAPI().trainImplicitALSModel(ratingBytes._jrdd,
rank, iterations, lambda_, blocks, alpha)
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index a3a68b29e0..7656db07f6 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -47,54 +47,57 @@ class LinearRegressionModel(LinearRegressionModelBase):
"""A linear regression model derived from a least-squares fit.
>>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2)
- >>> lrm = LinearRegressionWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0]))
+ >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0]))
"""
class LinearRegressionWithSGD(object):
@classmethod
- def train(cls, sc, data, iterations=100, step=1.0,
- mini_batch_fraction=1.0, initial_weights=None):
+ def train(cls, data, iterations=100, step=1.0,
+ miniBatchFraction=1.0, initialWeights=None):
"""Train a linear regression model on the given data."""
+ sc = data.context
return _regression_train_wrapper(sc, lambda d, i:
sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD(
- d._jrdd, iterations, step, mini_batch_fraction, i),
- LinearRegressionModel, data, initial_weights)
+ d._jrdd, iterations, step, miniBatchFraction, i),
+ LinearRegressionModel, data, initialWeights)
class LassoModel(LinearRegressionModelBase):
"""A linear regression model derived from a least-squares fit with an
l_1 penalty term.
>>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2)
- >>> lrm = LassoWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0]))
+ >>> lrm = LassoWithSGD.train(sc.parallelize(data), initialWeights=array([1.0]))
"""
-
+
class LassoWithSGD(object):
@classmethod
- def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0,
- mini_batch_fraction=1.0, initial_weights=None):
+ def train(cls, data, iterations=100, step=1.0, regParam=1.0,
+ miniBatchFraction=1.0, initialWeights=None):
"""Train a Lasso regression model on the given data."""
+ sc = data.context
return _regression_train_wrapper(sc, lambda d, i:
sc._jvm.PythonMLLibAPI().trainLassoModelWithSGD(d._jrdd,
- iterations, step, reg_param, mini_batch_fraction, i),
- LassoModel, data, initial_weights)
+ iterations, step, regParam, miniBatchFraction, i),
+ LassoModel, data, initialWeights)
class RidgeRegressionModel(LinearRegressionModelBase):
"""A linear regression model derived from a least-squares fit with an
l_2 penalty term.
>>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2)
- >>> lrm = RidgeRegressionWithSGD.train(sc, sc.parallelize(data), initial_weights=array([1.0]))
+ >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0]))
"""
class RidgeRegressionWithSGD(object):
@classmethod
- def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0,
- mini_batch_fraction=1.0, initial_weights=None):
+ def train(cls, data, iterations=100, step=1.0, regParam=1.0,
+ miniBatchFraction=1.0, initialWeights=None):
"""Train a ridge regression model on the given data."""
+ sc = data.context
return _regression_train_wrapper(sc, lambda d, i:
sc._jvm.PythonMLLibAPI().trainRidgeModelWithSGD(d._jrdd,
- iterations, step, reg_param, mini_batch_fraction, i),
- RidgeRegressionModel, data, initial_weights)
+ iterations, step, regParam, miniBatchFraction, i),
+ RidgeRegressionModel, data, initialWeights)
def _test():
import doctest
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index f2b3f3c142..d77981f61f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -76,6 +76,10 @@ def main(infile, outfile):
iterator = deserializer.load_stream(infile)
serializer.dump_stream(func(split_index, iterator), outfile)
except Exception as e:
+ # Write the error to stderr in addition to trying to pass it back to
+ # Java, in case it happened while serializing a record
+ print >> sys.stderr, "PySpark worker failed with exception:"
+ print >> sys.stderr, traceback.format_exc()
write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
write_with_length(traceback.format_exc(), outfile)
sys.exit(-1)
diff --git a/python/run-tests b/python/run-tests
index feba97cee0..2005f610b4 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -40,6 +40,11 @@ run_test "-m doctest pyspark/broadcast.py"
run_test "-m doctest pyspark/accumulators.py"
run_test "-m doctest pyspark/serializers.py"
run_test "pyspark/tests.py"
+#run_test "pyspark/mllib/_common.py"
+#run_test "pyspark/mllib/classification.py"
+#run_test "pyspark/mllib/clustering.py"
+#run_test "pyspark/mllib/recommendation.py"
+#run_test "pyspark/mllib/regression.py"
if [[ $FAILED != 0 ]]; then
echo -en "\033[31m" # Red
diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index daaa2a0305..8aad273665 100644
--- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -35,7 +35,6 @@ class ReplSuite extends FunSuite {
}
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
- System.clearProperty("spark.hostPort")
return out.toString
}
@@ -75,7 +74,6 @@ class ReplSuite extends FunSuite {
interp.sparkContext.stop()
System.clearProperty("spark.driver.port")
- System.clearProperty("spark.hostPort")
}
test("simple foreach with accumulator") {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index 1249ef4c3d..5046a1d53f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -40,13 +40,13 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
val graph = ssc.graph
val checkpointDir = ssc.checkpointDir
val checkpointDuration = ssc.checkpointDuration
- val pendingTimes = ssc.scheduler.getPendingTimes()
+ val pendingTimes = ssc.scheduler.getPendingTimes().toArray
val delaySeconds = MetadataCleaner.getDelaySeconds(ssc.conf)
val sparkConf = ssc.conf
// These should be unset when a checkpoint is deserialized,
// otherwise the SparkContext won't initialize correctly.
- sparkConf.remove("spark.hostPort").remove("spark.driver.host").remove("spark.driver.port")
+ sparkConf.remove("spark.driver.host").remove("spark.driver.port")
def validate() {
assert(master != null, "Checkpoint.master is null")
@@ -271,6 +271,6 @@ class ObjectInputStreamWithLoader(inputStream_ : InputStream, loader: ClassLoade
} catch {
case e: Exception =>
}
- return super.resolveClass(desc)
+ super.resolveClass(desc)
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala
new file mode 100644
index 0000000000..1f5dacb543
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala
@@ -0,0 +1,28 @@
+package org.apache.spark.streaming
+
+private[streaming] class ContextWaiter {
+ private var error: Throwable = null
+ private var stopped: Boolean = false
+
+ def notifyError(e: Throwable) = synchronized {
+ error = e
+ notifyAll()
+ }
+
+ def notifyStop() = synchronized {
+ notifyAll()
+ }
+
+ def waitForStopOrError(timeout: Long = -1) = synchronized {
+ // If an error has already occurred, throw it
+ if (error != null) {
+ throw error
+ }
+
+ // If not already stopped, then wait
+ if (!stopped) {
+ if (timeout < 0) wait() else wait(timeout)
+ if (error != null) throw error
+ }
+ }
+}
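
A short sketch of how this waiter is intended to be used, consistent with the StreamingContext changes later in this patch; the real callers (JobScheduler on errors, stop() for shutdown) lie outside this hunk, so the wiring below is illustrative only:

    // assumption: illustrative wiring, not the exact call sites
    val waiter = new ContextWaiter
    new Thread(new Runnable {
      def run() {
        Thread.sleep(1000)
        waiter.notifyStop()              // or waiter.notifyError(e) when a streaming job fails
      }
    }).start()
    waiter.waitForStopOrError()          // blocks here and rethrows any reported error
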
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
index eee9591ffc..8faa79f8c7 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
@@ -17,11 +17,11 @@
package org.apache.spark.streaming
-import dstream.InputDStream
+import scala.collection.mutable.ArrayBuffer
import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
-import collection.mutable.ArrayBuffer
import org.apache.spark.Logging
import org.apache.spark.streaming.scheduler.Job
+import org.apache.spark.streaming.dstream.{DStream, NetworkInputDStream, InputDStream}
final private[streaming] class DStreamGraph extends Serializable with Logging {
@@ -78,7 +78,7 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
def remember(duration: Duration) {
this.synchronized {
if (rememberDuration != null) {
- throw new Exception("Batch duration already set as " + batchDuration +
+ throw new Exception("Remember duration already set as " + batchDuration +
". cannot set it again.")
}
rememberDuration = duration
@@ -103,6 +103,12 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
def getOutputStreams() = this.synchronized { outputStreams.toArray }
+ def getNetworkInputStreams() = this.synchronized {
+ inputStreams.filter(_.isInstanceOf[NetworkInputDStream[_]])
+ .map(_.asInstanceOf[NetworkInputDStream[_]])
+ .toArray
+ }
+
def generateJobs(time: Time): Seq[Job] = {
logDebug("Generating jobs for time " + time)
val jobs = this.synchronized {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index dd34f6f4f2..7b27933403 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -46,7 +46,7 @@ import org.apache.hadoop.conf.Configuration
* information (such as, cluster URL and job name) to internally create a SparkContext, it provides
* methods used to create DStream from various input sources.
*/
-class StreamingContext private (
+class StreamingContext private[streaming] (
sc_ : SparkContext,
cp_ : Checkpoint,
batchDur_ : Duration
@@ -101,20 +101,9 @@ class StreamingContext private (
"both SparkContext and checkpoint as null")
}
- private val conf_ = Option(sc_).map(_.conf).getOrElse(cp_.sparkConf)
+ private[streaming] val isCheckpointPresent = (cp_ != null)
- if(cp_ != null && cp_.delaySeconds >= 0 && MetadataCleaner.getDelaySeconds(conf_) < 0) {
- MetadataCleaner.setDelaySeconds(conf_, cp_.delaySeconds)
- }
-
- if (MetadataCleaner.getDelaySeconds(conf_) < 0) {
- throw new SparkException("Spark Streaming cannot be used without setting spark.cleaner.ttl; "
- + "set this property before creating a SparkContext (use SPARK_JAVA_OPTS for the shell)")
- }
-
- protected[streaming] val isCheckpointPresent = (cp_ != null)
-
- protected[streaming] val sc: SparkContext = {
+ private[streaming] val sc: SparkContext = {
if (isCheckpointPresent) {
new SparkContext(cp_.sparkConf)
} else {
@@ -122,11 +111,16 @@ class StreamingContext private (
}
}
- protected[streaming] val conf = sc.conf
+ if (MetadataCleaner.getDelaySeconds(sc.conf) < 0) {
+ throw new SparkException("Spark Streaming cannot be used without setting spark.cleaner.ttl; "
+ + "set this property before creating a SparkContext (use SPARK_JAVA_OPTS for the shell)")
+ }
+
+ private[streaming] val conf = sc.conf
- protected[streaming] val env = SparkEnv.get
+ private[streaming] val env = SparkEnv.get
- protected[streaming] val graph: DStreamGraph = {
+ private[streaming] val graph: DStreamGraph = {
if (isCheckpointPresent) {
cp_.graph.setContext(this)
cp_.graph.restoreCheckpointData()
@@ -139,10 +133,9 @@ class StreamingContext private (
}
}
- protected[streaming] val nextNetworkInputStreamId = new AtomicInteger(0)
- protected[streaming] var networkInputTracker: NetworkInputTracker = null
+ private val nextNetworkInputStreamId = new AtomicInteger(0)
- protected[streaming] var checkpointDir: String = {
+ private[streaming] var checkpointDir: String = {
if (isCheckpointPresent) {
sc.setCheckpointDir(cp_.checkpointDir)
cp_.checkpointDir
@@ -151,11 +144,13 @@ class StreamingContext private (
}
}
- protected[streaming] val checkpointDuration: Duration = {
+ private[streaming] val checkpointDuration: Duration = {
if (isCheckpointPresent) cp_.checkpointDuration else graph.batchDuration
}
- protected[streaming] val scheduler = new JobScheduler(this)
+ private[streaming] val scheduler = new JobScheduler(this)
+
+ private[streaming] val waiter = new ContextWaiter
/**
* Return the associated Spark context
*/
@@ -173,7 +168,7 @@ class StreamingContext private (
}
/**
- * Set the context to periodically checkpoint the DStream operations for master
+ * Set the context to periodically checkpoint the DStream operations for driver
* fault-tolerance.
* @param directory HDFS-compatible directory where the checkpoint data will be reliably stored.
* Note that this must be a fault-tolerant file system like HDFS for
@@ -191,11 +186,11 @@ class StreamingContext private (
}
}
- protected[streaming] def initialCheckpoint: Checkpoint = {
+ private[streaming] def initialCheckpoint: Checkpoint = {
if (isCheckpointPresent) cp_ else null
}
- protected[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement()
+ private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement()
/**
* Create an input stream with any arbitrary user implemented network receiver.
@@ -225,7 +220,7 @@ class StreamingContext private (
def actorStream[T: ClassTag](
props: Props,
name: String,
- storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2,
+ storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2,
supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy
): DStream[T] = {
networkStream(new ActorReceiver[T](props, name, storageLevel, supervisorStrategy))
@@ -277,6 +272,7 @@ class StreamingContext private (
* @param hostname Hostname to connect to for receiving data
* @param port Port to connect to for receiving data
* @param storageLevel Storage level to use for storing the received objects
+ * (default: StorageLevel.MEMORY_AND_DISK_SER_2)
* @tparam T Type of the objects in the received blocks
*/
def rawSocketStream[T: ClassTag](
@@ -416,7 +412,7 @@ class StreamingContext private (
scheduler.listenerBus.addListener(streamingListener)
}
- protected def validate() {
+ private def validate() {
assert(graph != null, "Graph is null")
graph.validate()
@@ -430,38 +426,37 @@ class StreamingContext private (
/**
* Start the execution of the streams.
*/
- def start() {
+ def start() = synchronized {
validate()
+ scheduler.start()
+ }
- // Get the network input streams
- val networkInputStreams = graph.getInputStreams().filter(s => s match {
- case n: NetworkInputDStream[_] => true
- case _ => false
- }).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray
-
- // Start the network input tracker (must start before receivers)
- if (networkInputStreams.length > 0) {
- networkInputTracker = new NetworkInputTracker(this, networkInputStreams)
- networkInputTracker.start()
- }
- Thread.sleep(1000)
+ /**
+ * Wait for the execution to stop. Any exceptions that occur during the execution
+ * will be thrown in this thread.
+ */
+ def awaitTermination() {
+ waiter.waitForStopOrError()
+ }
- // Start the scheduler
- scheduler.start()
+ /**
+ * Wait for the execution to stop. Any exceptions that occur during the execution
+ * will be thrown in this thread.
+ * @param timeout time to wait in milliseconds
+ */
+ def awaitTermination(timeout: Long) {
+ waiter.waitForStopOrError(timeout)
}
/**
* Stop the execution of the streams.
+ * @param stopSparkContext Stop the associated SparkContext or not
*/
- def stop() {
- try {
- if (scheduler != null) scheduler.stop()
- if (networkInputTracker != null) networkInputTracker.stop()
- sc.stop()
- logInfo("StreamingContext stopped successfully")
- } catch {
- case e: Exception => logWarning("Error while stopping", e)
- }
+ def stop(stopSparkContext: Boolean = true) = synchronized {
+ scheduler.stop()
+ logInfo("StreamingContext stopped successfully")
+ waiter.notifyStop()
+ if (stopSparkContext) sc.stop()
}
}
@@ -472,6 +467,8 @@ class StreamingContext private (
object StreamingContext extends Logging {
+ private[streaming] val DEFAULT_CLEANER_TTL = 3600
+
implicit def toPairDStreamFunctions[K: ClassTag, V: ClassTag](stream: DStream[(K,V)]) = {
new PairDStreamFunctions[K, V](stream)
}
@@ -515,37 +512,29 @@ object StreamingContext extends Logging {
*/
def jarOfClass(cls: Class[_]) = SparkContext.jarOfClass(cls)
-
- protected[streaming] def createNewSparkContext(conf: SparkConf): SparkContext = {
+ private[streaming] def createNewSparkContext(conf: SparkConf): SparkContext = {
// Set the default cleaner delay to an hour if not already set.
// This should be sufficient for even 1 second batch intervals.
if (MetadataCleaner.getDelaySeconds(conf) < 0) {
- MetadataCleaner.setDelaySeconds(conf, 3600)
+ MetadataCleaner.setDelaySeconds(conf, DEFAULT_CLEANER_TTL)
}
val sc = new SparkContext(conf)
sc
}
- protected[streaming] def createNewSparkContext(
+ private[streaming] def createNewSparkContext(
master: String,
appName: String,
sparkHome: String,
jars: Seq[String],
environment: Map[String, String]
): SparkContext = {
-
val conf = SparkContext.updatedConf(
new SparkConf(), master, appName, sparkHome, jars, environment)
- // Set the default cleaner delay to an hour if not already set.
- // This should be sufficient for even 1 second batch intervals.
- if (MetadataCleaner.getDelaySeconds(conf) < 0) {
- MetadataCleaner.setDelaySeconds(conf, 3600)
- }
- val sc = new SparkContext(master, appName, sparkHome, jars, environment)
- sc
+ createNewSparkContext(conf)
}
- protected[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = {
+ private[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = {
if (prefix == null) {
time.milliseconds.toString
} else if (suffix == null || suffix.length ==0) {
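The hunk above turns start() into a synchronized call that delegates to the JobScheduler, adds a blocking awaitTermination() that rethrows any exception raised during execution, and gives stop() a stopSparkContext flag. A minimal driver-program sketch of the new lifecycle, assuming a local master, a one-second batch interval and a socket source (hypothetical application code, not part of this patch):

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    object StreamingLifecycleSketch {
      def main(args: Array[String]) {
        val conf = new SparkConf().setMaster("local[2]").setAppName("LifecycleSketch")
        // createNewSparkContext() fills in the metadata cleaner delay (DEFAULT_CLEANER_TTL = 3600 s) if unset.
        val ssc = new StreamingContext(conf, Seconds(1))

        val lines = ssc.socketTextStream("localhost", 9999)
        lines.count().print()

        ssc.start()              // synchronized; validates the graph and starts the scheduler
        ssc.awaitTermination()   // blocks; exceptions from the streaming jobs surface here
        // To shut down while keeping the shared SparkContext alive:
        // ssc.stop(stopSparkContext = false)
      }
    }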
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala
index d29033df32..c92854ccd9 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala
@@ -17,13 +17,14 @@
package org.apache.spark.streaming.api.java
-import org.apache.spark.streaming.{Duration, Time, DStream}
+import org.apache.spark.streaming.{Duration, Time}
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.rdd.RDD
import scala.reflect.ClassTag
+import org.apache.spark.streaming.dstream.DStream
/**
* A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
index 64f38ce1c0..a493a8279f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
@@ -30,6 +30,7 @@ import org.apache.spark.api.java.function.{Function3 => JFunction3, _}
import java.util
import org.apache.spark.rdd.RDD
import JavaDStream._
+import org.apache.spark.streaming.dstream.DStream
trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]]
extends Serializable {
@@ -207,7 +208,6 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
dstream.reduceByWindow(reduceFunc, windowDuration, slideDuration)
}
-
/**
* Return a new DStream in which each RDD has a single element generated by reducing all
* elements in a sliding window over this DStream. However, the reduction is done incrementally
@@ -243,17 +243,39 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
/**
* Apply a function to each RDD in this DStream. This is an output operator, so
* 'this' DStream will be registered as an output stream and therefore materialized.
+ *
+ * @deprecated As of release 0.9.0, replaced by foreachRDD
*/
+ @Deprecated
def foreach(foreachFunc: JFunction[R, Void]) {
- dstream.foreach(rdd => foreachFunc.call(wrapRDD(rdd)))
+ foreachRDD(foreachFunc)
}
/**
* Apply a function to each RDD in this DStream. This is an output operator, so
* 'this' DStream will be registered as an output stream and therefore materialized.
+ *
+ * @deprecated As of release 0.9.0, replaced by foreachRDD
*/
+ @Deprecated
def foreach(foreachFunc: JFunction2[R, Time, Void]) {
- dstream.foreach((rdd, time) => foreachFunc.call(wrapRDD(rdd), time))
+ foreachRDD(foreachFunc)
+ }
+
+ /**
+ * Apply a function to each RDD in this DStream. This is an output operator, so
+ * 'this' DStream will be registered as an output stream and therefore materialized.
+ */
+ def foreachRDD(foreachFunc: JFunction[R, Void]) {
+ dstream.foreachRDD(rdd => foreachFunc.call(wrapRDD(rdd)))
+ }
+
+ /**
+ * Apply a function to each RDD in this DStream. This is an output operator, so
+ * 'this' DStream will be registered as an output stream and therefore materialized.
+ */
+ def foreachRDD(foreachFunc: JFunction2[R, Time, Void]) {
+ dstream.foreachRDD((rdd, time) => foreachFunc.call(wrapRDD(rdd), time))
}
/**
@@ -387,7 +409,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
}
/**
- * Enable periodic checkpointing of RDDs of this DStream
+ * Enable periodic checkpointing of RDDs of this DStream.
* @param interval Time interval after which generated RDD will be checkpointed
*/
def checkpoint(interval: Duration) = {
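foreach is now deprecated on both the Scala DStream and the Java facade in favour of foreachRDD, which names the RDD granularity of the operation explicitly while keeping identical semantics. A small Scala sketch of the migrated output action (the helper and the stream it receives are hypothetical, not part of this patch):

    import org.apache.spark.rdd.RDD
    import org.apache.spark.streaming.Time
    import org.apache.spark.streaming.dstream.DStream

    object OutputOps {
      // Registers an output operation that logs the size of every generated batch.
      def logBatchSizes[T](stream: DStream[T]): Unit = {
        // was: stream.foreach((rdd, time) => ...)  -- still compiles, but now emits a deprecation warning
        stream.foreachRDD { (rdd: RDD[T], time: Time) =>
          println("Batch at " + time + " contains " + rdd.count() + " records")
        }
      }
    }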
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
index 6c3467d405..6bb985ca54 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
@@ -35,6 +35,7 @@ import org.apache.spark.storage.StorageLevel
import com.google.common.base.Optional
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.PairRDDFunctions
+import org.apache.spark.streaming.dstream.DStream
class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
implicit val kManifest: ClassTag[K],
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index 523173d45a..108950466a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -36,6 +36,7 @@ import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming._
import org.apache.spark.streaming.scheduler.StreamingListener
import org.apache.hadoop.conf.Configuration
+import org.apache.spark.streaming.dstream.DStream
/**
* A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
@@ -131,7 +132,7 @@ class JavaStreamingContext(val ssc: StreamingContext) {
* Re-creates a StreamingContext from a checkpoint file.
* @param path Path to the directory that was specified as the checkpoint directory
*/
- def this(path: String) = this(new StreamingContext(path))
+ def this(path: String) = this(new StreamingContext(path, new Configuration))
/**
* Re-creates a StreamingContext from a checkpoint file.
@@ -150,7 +151,6 @@ class JavaStreamingContext(val ssc: StreamingContext) {
* @param hostname Hostname to connect to for receiving data
* @param port Port to connect to for receiving data
* @param storageLevel Storage level to use for storing the received objects
- * (default: StorageLevel.MEMORY_AND_DISK_SER_2)
*/
def socketTextStream(hostname: String, port: Int, storageLevel: StorageLevel)
: JavaDStream[String] = {
@@ -160,7 +160,7 @@ class JavaStreamingContext(val ssc: StreamingContext) {
/**
* Create a input stream from network source hostname:port. Data is received using
* a TCP socket and the receive bytes is interpreted as UTF8 encoded \n delimited
- * lines.
+ * lines. Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2.
* @param hostname Hostname to connect to for receiving data
* @param port Port to connect to for receiving data
*/
@@ -301,6 +301,7 @@ class JavaStreamingContext(val ssc: StreamingContext) {
/**
* Create an input stream with any arbitrary user implemented actor receiver.
+ * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2.
* @param props Props object defining creation of the actor
* @param name Name of the actor
*
@@ -483,9 +484,28 @@ class JavaStreamingContext(val ssc: StreamingContext) {
def start() = ssc.start()
/**
- * Stop the execution of the streams.
+ * Wait for the execution to stop. Any exceptions that occur during the execution
+ * will be thrown in this thread.
+ */
+ def awaitTermination() = ssc.awaitTermination()
+
+ /**
+ * Wait for the execution to stop. Any exceptions that occur during the execution
+ * will be thrown in this thread.
+ * @param timeout time to wait in milliseconds
+ */
+ def awaitTermination(timeout: Long) = ssc.awaitTermination(timeout)
+
+ /**
+ * Stop the execution of the streams. Will stop the associated JavaSparkContext as well.
*/
def stop() = ssc.stop()
+
+ /**
+ * Stop the execution of the streams.
+ * @param stopSparkContext Whether to stop the associated SparkContext
+ */
+ def stop(stopSparkContext: Boolean) = ssc.stop(stopSparkContext)
}
/**
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index b98f4a5101..426f61e24a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -15,36 +15,39 @@
* limitations under the License.
*/
-package org.apache.spark.streaming
+package org.apache.spark.streaming.dstream
-import StreamingContext._
-import org.apache.spark.streaming.dstream._
-import org.apache.spark.streaming.scheduler.Job
-import org.apache.spark.Logging
-import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.MetadataCleaner
+import scala.deprecated
import scala.collection.mutable.HashMap
import scala.reflect.ClassTag
-import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
+import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
+import org.apache.spark.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.MetadataCleaner
+import org.apache.spark.streaming._
+import org.apache.spark.streaming.StreamingContext._
+import org.apache.spark.streaming.scheduler.Job
+import org.apache.spark.streaming.Duration
/**
* A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous
- * sequence of RDDs (of the same type) representing a continuous stream of data (see [[org.apache.spark.rdd.RDD]]
- * for more details on RDDs). DStreams can either be created from live data (such as, data from
- * HDFS, Kafka or Flume) or it can be generated by transformation existing DStreams using operations
- * such as `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming program is running, each
- * DStream periodically generates a RDD, either from live data or by transforming the RDD generated
- * by a parent DStream.
+ * sequence of RDDs (of the same type) representing a continuous stream of data (see
+ * [[org.apache.spark.rdd.RDD]] for more details on RDDs). DStreams can either be created from
+ * live data (such as data from HDFS, Kafka or Flume) or it can be generated by transforming
+ * existing DStreams using operations such as `map`, `window` and `reduceByKeyAndWindow`.
+ * While a Spark Streaming program is running, each DStream periodically generates an RDD,
+ * either from live data or by transforming the RDD generated by a parent DStream.
*
* This class contains the basic operations available on all DStreams, such as `map`, `filter` and
- * `window`. In addition, [[org.apache.spark.streaming.PairDStreamFunctions]] contains operations available
- * only on DStreams of key-value pairs, such as `groupByKeyAndWindow` and `join`. These operations
- * are automatically available on any DStream of the right type (e.g., DStream[(Int, Int)] through
- * implicit conversions when `spark.streaming.StreamingContext._` is imported.
+ * `window`. In addition, [[org.apache.spark.streaming.dstream.PairDStreamFunctions]] contains
+ * operations available only on DStreams of key-value pairs, such as `groupByKeyAndWindow` and
+ * `join`. These operations are automatically available on any DStream of pairs
+ * (e.g., DStream[(Int, Int)]) through implicit conversions when
+ * `org.apache.spark.streaming.StreamingContext._` is imported.
*
* DStreams internally is characterized by a few basic properties:
* - A list of other DStreams that the DStream depends on
@@ -53,7 +56,7 @@ import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
*/
abstract class DStream[T: ClassTag] (
- @transient protected[streaming] var ssc: StreamingContext
+ @transient private[streaming] var ssc: StreamingContext
) extends Serializable with Logging {
// =======================================================================
@@ -73,31 +76,31 @@ abstract class DStream[T: ClassTag] (
// Methods and fields available on all DStreams
// =======================================================================
- // RDDs generated, marked as protected[streaming] so that testsuites can access it
+ // RDDs generated, marked as private[streaming] so that testsuites can access it
@transient
- protected[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] ()
+ private[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] ()
// Time zero for the DStream
- protected[streaming] var zeroTime: Time = null
+ private[streaming] var zeroTime: Time = null
// Duration for which the DStream will remember each RDD created
- protected[streaming] var rememberDuration: Duration = null
+ private[streaming] var rememberDuration: Duration = null
// Storage level of the RDDs in the stream
- protected[streaming] var storageLevel: StorageLevel = StorageLevel.NONE
+ private[streaming] var storageLevel: StorageLevel = StorageLevel.NONE
// Checkpoint details
- protected[streaming] val mustCheckpoint = false
- protected[streaming] var checkpointDuration: Duration = null
- protected[streaming] val checkpointData = new DStreamCheckpointData(this)
+ private[streaming] val mustCheckpoint = false
+ private[streaming] var checkpointDuration: Duration = null
+ private[streaming] val checkpointData = new DStreamCheckpointData(this)
// Reference to whole DStream graph
- protected[streaming] var graph: DStreamGraph = null
+ private[streaming] var graph: DStreamGraph = null
- protected[streaming] def isInitialized = (zeroTime != null)
+ private[streaming] def isInitialized = (zeroTime != null)
// Duration for which the DStream requires its parent DStream to remember each RDD created
- protected[streaming] def parentRememberDuration = rememberDuration
+ private[streaming] def parentRememberDuration = rememberDuration
/** Return the StreamingContext associated with this DStream */
def context = ssc
@@ -137,7 +140,7 @@ abstract class DStream[T: ClassTag] (
* the validity of future times is calculated. This method also recursively initializes
* its parent DStreams.
*/
- protected[streaming] def initialize(time: Time) {
+ private[streaming] def initialize(time: Time) {
if (zeroTime != null && zeroTime != time) {
throw new Exception("ZeroTime is already initialized to " + zeroTime
+ ", cannot initialize it again to " + time)
@@ -153,7 +156,8 @@ abstract class DStream[T: ClassTag] (
// Set the minimum value of the rememberDuration if not already set
var minRememberDuration = slideDuration
if (checkpointDuration != null && minRememberDuration <= checkpointDuration) {
- minRememberDuration = checkpointDuration * 2 // times 2 just to be sure that the latest checkpoint is not forgetten
+ // times 2 just to be sure that the latest checkpoint is not forgotten (#paranoia)
+ minRememberDuration = checkpointDuration * 2
}
if (rememberDuration == null || rememberDuration < minRememberDuration) {
rememberDuration = minRememberDuration
@@ -163,7 +167,7 @@ abstract class DStream[T: ClassTag] (
dependencies.foreach(_.initialize(zeroTime))
}
- protected[streaming] def validate() {
+ private[streaming] def validate() {
assert(rememberDuration != null, "Remember duration is set to null")
assert(
@@ -227,7 +231,7 @@ abstract class DStream[T: ClassTag] (
logInfo("Initialized and validated " + this)
}
- protected[streaming] def setContext(s: StreamingContext) {
+ private[streaming] def setContext(s: StreamingContext) {
if (ssc != null && ssc != s) {
throw new Exception("Context is already set in " + this + ", cannot set it again")
}
@@ -236,7 +240,7 @@ abstract class DStream[T: ClassTag] (
dependencies.foreach(_.setContext(ssc))
}
- protected[streaming] def setGraph(g: DStreamGraph) {
+ private[streaming] def setGraph(g: DStreamGraph) {
if (graph != null && graph != g) {
throw new Exception("Graph is already set in " + this + ", cannot set it again")
}
@@ -244,7 +248,7 @@ abstract class DStream[T: ClassTag] (
dependencies.foreach(_.setGraph(graph))
}
- protected[streaming] def remember(duration: Duration) {
+ private[streaming] def remember(duration: Duration) {
if (duration != null && duration > rememberDuration) {
rememberDuration = duration
logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this)
@@ -253,14 +257,15 @@ abstract class DStream[T: ClassTag] (
}
/** Checks whether the 'time' is valid wrt slideDuration for generating RDD */
- protected def isTimeValid(time: Time): Boolean = {
+ private[streaming] def isTimeValid(time: Time): Boolean = {
if (!isInitialized) {
throw new Exception (this + " has not been initialized")
} else if (time <= zeroTime || ! (time - zeroTime).isMultipleOf(slideDuration)) {
- logInfo("Time " + time + " is invalid as zeroTime is " + zeroTime + " and slideDuration is " + slideDuration + " and difference is " + (time - zeroTime))
+ logInfo("Time " + time + " is invalid as zeroTime is " + zeroTime +
+ " and slideDuration is " + slideDuration + " and difference is " + (time - zeroTime))
false
} else {
- logInfo("Time " + time + " is valid")
+ logDebug("Time " + time + " is valid")
true
}
}
@@ -269,7 +274,7 @@ abstract class DStream[T: ClassTag] (
* Retrieve a precomputed RDD of this DStream, or computes the RDD. This is an internal
* method that should not be called directly.
*/
- protected[streaming] def getOrCompute(time: Time): Option[RDD[T]] = {
+ private[streaming] def getOrCompute(time: Time): Option[RDD[T]] = {
// If this DStream was not initialized (i.e., zeroTime not set), then do it
// If RDD was already generated, then retrieve it from HashMap
generatedRDDs.get(time) match {
@@ -286,11 +291,14 @@ abstract class DStream[T: ClassTag] (
case Some(newRDD) =>
if (storageLevel != StorageLevel.NONE) {
newRDD.persist(storageLevel)
- logInfo("Persisting RDD " + newRDD.id + " for time " + time + " to " + storageLevel + " at time " + time)
+ logInfo("Persisting RDD " + newRDD.id + " for time " +
+ time + " to " + storageLevel + " at time " + time)
}
- if (checkpointDuration != null && (time - zeroTime).isMultipleOf(checkpointDuration)) {
+ if (checkpointDuration != null &&
+ (time - zeroTime).isMultipleOf(checkpointDuration)) {
newRDD.checkpoint()
- logInfo("Marking RDD " + newRDD.id + " for time " + time + " for checkpointing at time " + time)
+ logInfo("Marking RDD " + newRDD.id + " for time " + time +
+ " for checkpointing at time " + time)
}
generatedRDDs.put(time, newRDD)
Some(newRDD)
@@ -310,7 +318,7 @@ abstract class DStream[T: ClassTag] (
* that materializes the corresponding RDD. Subclasses of DStream may override this
* to generate their own jobs.
*/
- protected[streaming] def generateJob(time: Time): Option[Job] = {
+ private[streaming] def generateJob(time: Time): Option[Job] = {
getOrCompute(time) match {
case Some(rdd) => {
val jobFunc = () => {
@@ -329,18 +337,22 @@ abstract class DStream[T: ClassTag] (
* implementation clears the old generated RDDs. Subclasses of DStream may override
* this to clear their own metadata along with the generated RDDs.
*/
- protected[streaming] def clearMetadata(time: Time) {
+ private[streaming] def clearMetadata(time: Time) {
val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration))
generatedRDDs --= oldRDDs.keys
+ if (ssc.conf.getBoolean("spark.streaming.unpersist", false)) {
+ logDebug("Unpersisting old RDDs: " + oldRDDs.keys.mkString(", "))
+ oldRDDs.values.foreach(_.unpersist(false))
+ }
logDebug("Cleared " + oldRDDs.size + " RDDs that were older than " +
(time - rememberDuration) + ": " + oldRDDs.keys.mkString(", "))
dependencies.foreach(_.clearMetadata(time))
}
/* Adds metadata to the Stream while it is running.
- * This methd should be overwritten by sublcasses of InputDStream.
+ * This method should be overridden by subclasses of InputDStream.
*/
- protected[streaming] def addMetadata(metadata: Any) {
+ private[streaming] def addMetadata(metadata: Any) {
if (metadata != null) {
logInfo("Dropping Metadata: " + metadata.toString)
}
@@ -353,18 +365,18 @@ abstract class DStream[T: ClassTag] (
* checkpointData. Subclasses of DStream (especially those of InputDStream) may override
* this method to save custom checkpoint data.
*/
- protected[streaming] def updateCheckpointData(currentTime: Time) {
- logInfo("Updating checkpoint data for time " + currentTime)
+ private[streaming] def updateCheckpointData(currentTime: Time) {
+ logDebug("Updating checkpoint data for time " + currentTime)
checkpointData.update(currentTime)
dependencies.foreach(_.updateCheckpointData(currentTime))
logDebug("Updated checkpoint data for time " + currentTime + ": " + checkpointData)
}
- protected[streaming] def clearCheckpointData(time: Time) {
- logInfo("Clearing checkpoint data")
+ private[streaming] def clearCheckpointData(time: Time) {
+ logDebug("Clearing checkpoint data")
checkpointData.cleanup(time)
dependencies.foreach(_.clearCheckpointData(time))
- logInfo("Cleared checkpoint data")
+ logDebug("Cleared checkpoint data")
}
/**
@@ -373,7 +385,7 @@ abstract class DStream[T: ClassTag] (
* from the checkpoint file names stored in checkpointData. Subclasses of DStream that
* override the updateCheckpointData() method would also need to override this method.
*/
- protected[streaming] def restoreCheckpointData() {
+ private[streaming] def restoreCheckpointData() {
// Create RDDs from the checkpoint data
logInfo("Restoring checkpoint data")
checkpointData.restore()
@@ -399,7 +411,8 @@ abstract class DStream[T: ClassTag] (
}
}
} else {
- throw new java.io.NotSerializableException("Graph is unexpectedly null when DStream is being serialized.")
+ throw new java.io.NotSerializableException(
+ "Graph is unexpectedly null when DStream is being serialized.")
}
}
@@ -487,15 +500,29 @@ abstract class DStream[T: ClassTag] (
* Apply a function to each RDD in this DStream. This is an output operator, so
* 'this' DStream will be registered as an output stream and therefore materialized.
*/
- def foreach(foreachFunc: RDD[T] => Unit) {
- this.foreach((r: RDD[T], t: Time) => foreachFunc(r))
+ @deprecated("use foreachRDD", "0.9.0")
+ def foreach(foreachFunc: RDD[T] => Unit) = this.foreachRDD(foreachFunc)
+
+ /**
+ * Apply a function to each RDD in this DStream. This is an output operator, so
+ * 'this' DStream will be registered as an output stream and therefore materialized.
+ */
+ @deprecated("use foreachRDD", "0.9.0")
+ def foreach(foreachFunc: (RDD[T], Time) => Unit) = this.foreachRDD(foreachFunc)
+
+ /**
+ * Apply a function to each RDD in this DStream. This is an output operator, so
+ * 'this' DStream will be registered as an output stream and therefore materialized.
+ */
+ def foreachRDD(foreachFunc: RDD[T] => Unit) {
+ this.foreachRDD((r: RDD[T], t: Time) => foreachFunc(r))
}
/**
* Apply a function to each RDD in this DStream. This is an output operator, so
* 'this' DStream will be registered as an output stream and therefore materialized.
*/
- def foreach(foreachFunc: (RDD[T], Time) => Unit) {
+ def foreachRDD(foreachFunc: (RDD[T], Time) => Unit) {
ssc.registerOutputStream(new ForEachDStream(this, context.sparkContext.clean(foreachFunc)))
}
@@ -635,8 +662,8 @@ abstract class DStream[T: ClassTag] (
/**
* Return a new DStream in which each RDD has a single element generated by counting the number
- * of elements in a sliding window over this DStream. Hash partitioning is used to generate the RDDs with
- * Spark's default number of partitions.
+ * of elements in a sliding window over this DStream. Hash partitioning is used to generate
+ * the RDDs with Spark's default number of partitions.
* @param windowDuration width of the window; must be a multiple of this DStream's
* batching interval
* @param slideDuration sliding interval of the window (i.e., the interval after which
@@ -684,7 +711,7 @@ abstract class DStream[T: ClassTag] (
/**
* Return all the RDDs defined by the Interval object (both end times included)
*/
- protected[streaming] def slice(interval: Interval): Seq[RDD[T]] = {
+ def slice(interval: Interval): Seq[RDD[T]] = {
slice(interval.beginTime, interval.endTime)
}
@@ -693,10 +720,12 @@ abstract class DStream[T: ClassTag] (
*/
def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = {
if (!(fromTime - zeroTime).isMultipleOf(slideDuration)) {
- logWarning("fromTime (" + fromTime + ") is not a multiple of slideDuration (" + slideDuration + ")")
+ logWarning("fromTime (" + fromTime + ") is not a multiple of slideDuration ("
+ + slideDuration + ")")
}
if (!(toTime - zeroTime).isMultipleOf(slideDuration)) {
- logWarning("toTime (" + fromTime + ") is not a multiple of slideDuration (" + slideDuration + ")")
+ logWarning("toTime (" + fromTime + ") is not a multiple of slideDuration ("
+ + slideDuration + ")")
}
val alignedToTime = toTime.floor(slideDuration)
val alignedFromTime = fromTime.floor(slideDuration)
@@ -719,7 +748,7 @@ abstract class DStream[T: ClassTag] (
val file = rddToFileName(prefix, suffix, time)
rdd.saveAsObjectFile(file)
}
- this.foreach(saveFunc)
+ this.foreachRDD(saveFunc)
}
/**
@@ -732,10 +761,15 @@ abstract class DStream[T: ClassTag] (
val file = rddToFileName(prefix, suffix, time)
rdd.saveAsTextFile(file)
}
- this.foreach(saveFunc)
+ this.foreachRDD(saveFunc)
}
- def register() {
+ /**
+ * Register this DStream as an output stream. This would ensure that RDDs of this
+ * DStream will be generated.
+ */
+ def register(): DStream[T] = {
ssc.registerOutputStream(this)
+ this
}
}
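Two smaller behavioural changes in this file are worth reading together: register() now returns the DStream itself, so output registration can sit at the end of a transformation chain, and clearMetadata() also unpersists the RDDs it forgets when spark.streaming.unpersist is enabled. A short sketch combining both, assuming a spark-shell-style session with a socket source (hypothetical stream and host, not part of this patch):

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.StreamingContext._   // pair-DStream operations such as reduceByKey

    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("UnpersistSketch")
      .set("spark.streaming.unpersist", "true")   // old, forgotten RDDs are also unpersisted

    val ssc = new StreamingContext(conf, Seconds(2))

    val counts = ssc.socketTextStream("localhost", 9999)
      .flatMap(_.split(" "))
      .map(word => (word, 1))
      .reduceByKey(_ + _)
      .register()   // returns the DStream, so the chained expression can still be assigned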
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala
index 671f7bbce7..38bad5ac80 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala
@@ -15,17 +15,15 @@
* limitations under the License.
*/
-package org.apache.spark.streaming
+package org.apache.spark.streaming.dstream
-import scala.collection.mutable.{HashMap, HashSet}
+import scala.collection.mutable.HashMap
import scala.reflect.ClassTag
-
+import java.io.{ObjectInputStream, IOException}
import org.apache.hadoop.fs.Path
import org.apache.hadoop.fs.FileSystem
-
import org.apache.spark.Logging
-
-import java.io.{ObjectInputStream, IOException}
+import org.apache.spark.streaming.Time
private[streaming]
class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
@@ -96,7 +94,7 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
}
}
case None =>
- logInfo("Nothing to delete")
+ logDebug("Nothing to delete")
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
index 1f0f31c4b1..8a6051622e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
@@ -25,7 +25,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.UnionRDD
-import org.apache.spark.streaming.{DStreamCheckpointData, StreamingContext, Time}
+import org.apache.spark.streaming.{StreamingContext, Time}
import org.apache.spark.util.TimeStampedHashMap
@@ -39,24 +39,22 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData
- // Latest file mod time seen till any point of time
- private val prevModTimeFiles = new HashSet[String]()
- private var prevModTime = 0L
+ // files found in the last interval
+ private val lastFoundFiles = new HashSet[String]
+
+ // Files with mod time earlier than this are ignored. This is updated every interval
+ // such that in the current interval, files older than any file found in the
+ // previous interval will be ignored. Obviously this time keeps moving forward.
+ private var ignoreTime = if (newFilesOnly) 0L else System.currentTimeMillis()
+ // Latest file mod time seen till any point of time
@transient private var path_ : Path = null
@transient private var fs_ : FileSystem = null
@transient private[streaming] var files = new HashMap[Time, Array[String]]
@transient private var fileModTimes = new TimeStampedHashMap[String, Long](true)
@transient private var lastNewFileFindingTime = 0L
- override def start() {
- if (newFilesOnly) {
- prevModTime = graph.zeroTime.milliseconds
- } else {
- prevModTime = 0
- }
- logDebug("LastModTime initialized to " + prevModTime + ", new files only = " + newFilesOnly)
- }
+ override def start() { }
override def stop() { }
@@ -70,20 +68,16 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
* the previous call.
*/
override def compute(validTime: Time): Option[RDD[(K, V)]] = {
- assert(validTime.milliseconds >= prevModTime,
- "Trying to get new files for really old time [" + validTime + " < " + prevModTime + "]")
+ assert(validTime.milliseconds >= ignoreTime,
+ "Trying to get new files for a really old time [" + validTime + " < " + ignoreTime + "]")
// Find new files
- val (newFiles, latestModTime, latestModTimeFiles) = findNewFiles(validTime.milliseconds)
+ val (newFiles, minNewFileModTime) = findNewFiles(validTime.milliseconds)
logInfo("New files at time " + validTime + ":\n" + newFiles.mkString("\n"))
- if (newFiles.length > 0) {
- // Update the modification time and the files processed for that modification time
- if (prevModTime < latestModTime) {
- prevModTime = latestModTime
- prevModTimeFiles.clear()
- }
- prevModTimeFiles ++= latestModTimeFiles
- logDebug("Last mod time updated to " + prevModTime)
+ if (!newFiles.isEmpty) {
+ lastFoundFiles.clear()
+ lastFoundFiles ++= newFiles
+ ignoreTime = minNewFileModTime
}
files += ((validTime, newFiles.toArray))
Some(filesToRDD(newFiles))
@@ -92,7 +86,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
/** Clear the old time-to-files mappings along with old RDDs */
protected[streaming] override def clearMetadata(time: Time) {
super.clearMetadata(time)
- val oldFiles = files.filter(_._1 <= (time - rememberDuration))
+ val oldFiles = files.filter(_._1 < (time - rememberDuration))
files --= oldFiles.keys
logInfo("Cleared " + oldFiles.size + " old files that were older than " +
(time - rememberDuration) + ": " + oldFiles.keys.mkString(", "))
@@ -106,7 +100,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
* Find files which have modification timestamp <= current time and return a 3-tuple of
* (new files found, latest modification time among them, files with latest modification time)
*/
- private def findNewFiles(currentTime: Long): (Seq[String], Long, Seq[String]) = {
+ private def findNewFiles(currentTime: Long): (Seq[String], Long) = {
logDebug("Trying to get new files for time " + currentTime)
lastNewFileFindingTime = System.currentTimeMillis
val filter = new CustomPathFilter(currentTime)
@@ -121,7 +115,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
"files in the monitored directory."
)
}
- (newFiles, filter.latestModTime, filter.latestModTimeFiles.toSeq)
+ (newFiles, filter.minNewFileModTime)
}
/** Generate one RDD from an array of files */
@@ -200,38 +194,42 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
}
/**
- * Custom PathFilter class to find new files that have modification timestamps <= current time,
- * but have not been seen before (i.e. the file should not be in lastModTimeFiles)
+ * Custom PathFilter class to find new files that
+ * ... have modification time more than ignore time
+ * ... have not been seen in the last interval
+ * ... have modification time less than maxModTime
*/
private[streaming]
class CustomPathFilter(maxModTime: Long) extends PathFilter {
- // Latest file mod time seen in this round of fetching files and its corresponding files
- var latestModTime = 0L
- val latestModTimeFiles = new HashSet[String]()
+
+ // Minimum of the mod times of new files found in the current interval
+ var minNewFileModTime = -1L
+
def accept(path: Path): Boolean = {
try {
if (!filter(path)) { // Reject file if it does not satisfy filter
logDebug("Rejected by filter " + path)
return false
}
+ // Reject file if it was found in the last interval
+ if (lastFoundFiles.contains(path.toString)) {
+ logDebug("File already considered in the last interval: " + path)
+ return false
+ }
val modTime = getFileModTime(path)
logDebug("Mod time for " + path + " is " + modTime)
- if (modTime < prevModTime) {
- logDebug("Mod time less than last mod time")
- return false // If the file was created before the last time it was called
- } else if (modTime == prevModTime && prevModTimeFiles.contains(path.toString)) {
- logDebug("Mod time equal to last mod time, but file considered already")
- return false // If the file was created exactly as lastModTime but not reported yet
+ if (modTime < ignoreTime) {
+ // Reject file if it was created before the ignore time (or, before last interval)
+ logDebug("Mod time " + modTime + " less than ignore time " + ignoreTime)
+ return false
} else if (modTime > maxModTime) {
+ // Reject file if it is so new that considering it may give errors
logDebug("Mod time more than ")
- return false // If the file is too new that considering it may give errors
+ return false
}
- if (modTime > latestModTime) {
- latestModTime = modTime
- latestModTimeFiles.clear()
- logDebug("Latest mod time updated to " + latestModTime)
+ if (minNewFileModTime < 0 || modTime < minNewFileModTime) {
+ minNewFileModTime = modTime
}
- latestModTimeFiles += path.toString
logDebug("Accepted " + path)
} catch {
case fnfe: java.io.FileNotFoundException =>
@@ -239,7 +237,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
reset()
return false
}
- return true
+ true
}
}
}
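The rewritten filter replaces the (prevModTime, prevModTimeFiles) bookkeeping with an ignoreTime threshold plus the set of files found in the previous interval, and tracks the minimum mod time of the newly accepted files so the threshold can advance. A simplified, self-contained sketch of that acceptance rule (plain Scala rather than the Hadoop PathFilter; names mirror the diff but the function itself is hypothetical):

    import scala.collection.mutable

    object FileSelectionSketch {
      // One-interval snapshot of the new-file selection logic.
      def selectNewFiles(
          candidates: Seq[(String, Long)],   // (path, modification time)
          lastFoundFiles: Set[String],       // files already picked up in the previous interval
          ignoreTime: Long,                  // files modified before this are ignored
          maxModTime: Long                   // files newer than the current batch time are ignored
        ): (Seq[String], Long) = {
        var minNewFileModTime = -1L
        val accepted = mutable.ArrayBuffer[String]()
        for ((path, modTime) <- candidates) {
          val ok = !lastFoundFiles.contains(path) && modTime >= ignoreTime && modTime <= maxModTime
          if (ok) {
            accepted += path
            if (minNewFileModTime < 0 || modTime < minNewFileModTime) minNewFileModTime = modTime
          }
        }
        // The caller then sets ignoreTime = minNewFileModTime and lastFoundFiles = accepted.toSet
        (accepted, minNewFileModTime)
      }
    }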
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala
index db2e0a4cee..c81534ae58 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala
@@ -17,7 +17,7 @@
package org.apache.spark.streaming.dstream
-import org.apache.spark.streaming.{Duration, DStream, Time}
+import org.apache.spark.streaming.{Duration, Time}
import org.apache.spark.rdd.RDD
import scala.reflect.ClassTag
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala
index 244dc3ee4f..6586234554 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala
@@ -17,7 +17,7 @@
package org.apache.spark.streaming.dstream
-import org.apache.spark.streaming.{Duration, DStream, Time}
+import org.apache.spark.streaming.{Duration, Time}
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import scala.reflect.ClassTag
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala
index 336c4b7a92..c7bb2833ea 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala
@@ -17,7 +17,7 @@
package org.apache.spark.streaming.dstream
-import org.apache.spark.streaming.{Duration, DStream, Time}
+import org.apache.spark.streaming.{Duration, Time}
import org.apache.spark.rdd.RDD
import scala.reflect.ClassTag
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala
index 364abcde68..905bc723f6 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala
@@ -18,7 +18,7 @@
package org.apache.spark.streaming.dstream
import org.apache.spark.rdd.RDD
-import org.apache.spark.streaming.{Duration, DStream, Time}
+import org.apache.spark.streaming.{Duration, Time}
import org.apache.spark.streaming.scheduler.Job
import scala.reflect.ClassTag
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala
index 23136f44fa..a9bb51f054 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala
@@ -17,7 +17,7 @@
package org.apache.spark.streaming.dstream
-import org.apache.spark.streaming.{Duration, DStream, Time}
+import org.apache.spark.streaming.{Duration, Time}
import org.apache.spark.rdd.RDD
import scala.reflect.ClassTag
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
index f01e67fe13..a1075ad304 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
@@ -17,7 +17,7 @@
package org.apache.spark.streaming.dstream
-import org.apache.spark.streaming.{Time, Duration, StreamingContext, DStream}
+import org.apache.spark.streaming.{Time, Duration, StreamingContext}
import scala.reflect.ClassTag
@@ -43,7 +43,7 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext)
* This ensures that InputDStream.compute() is called strictly on increasing
* times.
*/
- override protected def isTimeValid(time: Time): Boolean = {
+ override private[streaming] def isTimeValid(time: Time): Boolean = {
if (!super.isTimeValid(time)) {
false // Time not valid
} else {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala
index 8a04060e5b..3d8ee29df1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala
@@ -17,7 +17,7 @@
package org.apache.spark.streaming.dstream
-import org.apache.spark.streaming.{Duration, DStream, Time}
+import org.apache.spark.streaming.{Duration, Time}
import org.apache.spark.rdd.RDD
import scala.reflect.ClassTag
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala
index 0ce364fd46..7aea1f945d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala
@@ -17,7 +17,7 @@
package org.apache.spark.streaming.dstream
-import org.apache.spark.streaming.{Duration, DStream, Time}
+import org.apache.spark.streaming.{Duration, Time}
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import scala.reflect.ClassTag
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala
index c0b7491d09..02704a8d1c 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala
@@ -17,7 +17,7 @@
package org.apache.spark.streaming.dstream
-import org.apache.spark.streaming.{Duration, DStream, Time}
+import org.apache.spark.streaming.{Duration, Time}
import org.apache.spark.rdd.RDD
import scala.reflect.ClassTag
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
index d41f726f83..0f1f6fc2ce 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
@@ -68,7 +68,7 @@ abstract class NetworkInputDStream[T: ClassTag](@transient ssc_ : StreamingConte
// then this returns an empty RDD. This may happen when recovering from a
// master failure
if (validTime >= graph.startTime) {
- val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime)
+ val blockIds = ssc.scheduler.networkInputTracker.getBlockIds(id, validTime)
Some(new BlockRDD[T](ssc.sc, blockIds))
} else {
Some(new BlockRDD[T](ssc.sc, Array[BlockId]()))
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
index 56dbcbda23..6b3e48382e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.streaming
+package org.apache.spark.streaming.dstream
import org.apache.spark.streaming.StreamingContext._
import org.apache.spark.streaming.dstream._
@@ -33,6 +33,7 @@ import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.conf.Configuration
+import org.apache.spark.streaming.{Time, Duration}
class PairDStreamFunctions[K: ClassTag, V: ClassTag](self: DStream[(K,V)])
extends Serializable {
@@ -582,7 +583,7 @@ extends Serializable {
val file = rddToFileName(prefix, suffix, time)
rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, conf)
}
- self.foreach(saveFunc)
+ self.foreachRDD(saveFunc)
}
/**
@@ -612,7 +613,7 @@ extends Serializable {
val file = rddToFileName(prefix, suffix, time)
rdd.saveAsNewAPIHadoopFile(file, keyClass, valueClass, outputFormatClass, conf)
}
- self.foreach(saveFunc)
+ self.foreachRDD(saveFunc)
}
private def getKeyClass() = implicitly[ClassTag[K]].runtimeClass
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala
index db56345ca8..7a6b1ea35e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala
@@ -26,7 +26,7 @@ import org.apache.spark.SparkContext._
import org.apache.spark.storage.StorageLevel
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.streaming.{Duration, Interval, Time, DStream}
+import org.apache.spark.streaming.{Duration, Interval, Time}
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala
index 84e69f277b..880a89bc36 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala
@@ -20,7 +20,7 @@ package org.apache.spark.streaming.dstream
import org.apache.spark.Partitioner
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
-import org.apache.spark.streaming.{Duration, DStream, Time}
+import org.apache.spark.streaming.{Duration, Time}
import scala.reflect.ClassTag
private[streaming]
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
index e0ff3ccba4..9d8889b655 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
@@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.Partitioner
import org.apache.spark.SparkContext._
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.{Duration, Time, DStream}
+import org.apache.spark.streaming.{Duration, Time}
import scala.reflect.ClassTag
@@ -65,7 +65,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
//logDebug("Generating state RDD for time " + validTime)
- return Some(stateRDD)
+ Some(stateRDD)
}
case None => { // If parent RDD does not exist
@@ -76,7 +76,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
updateFuncLocal(i)
}
val stateRDD = prevStateRDD.mapPartitions(finalFunc, preservePartitioning)
- return Some(stateRDD)
+ Some(stateRDD)
}
}
}
@@ -98,11 +98,11 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
val groupedRDD = parentRDD.groupByKey(partitioner)
val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning)
//logDebug("Generating state RDD for time " + validTime + " (first)")
- return Some(sessionRDD)
+ Some(sessionRDD)
}
case None => { // If parent RDD does not exist, then nothing to do!
//logDebug("Not generating state RDD (no previous state, no parent)")
- return None
+ None
}
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala
index aeea060df7..7cd4554282 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala
@@ -18,7 +18,7 @@
package org.apache.spark.streaming.dstream
import org.apache.spark.rdd.RDD
-import org.apache.spark.streaming.{Duration, DStream, Time}
+import org.apache.spark.streaming.{Duration, Time}
import scala.reflect.ClassTag
private[streaming]
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala
index 0d84ec84f2..4ecba03ab5 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala
@@ -17,9 +17,8 @@
package org.apache.spark.streaming.dstream
-import org.apache.spark.streaming.{Duration, DStream, Time}
+import org.apache.spark.streaming.{Duration, Time}
import org.apache.spark.rdd.RDD
-import collection.mutable.ArrayBuffer
import org.apache.spark.rdd.UnionRDD
import scala.collection.mutable.ArrayBuffer
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
index 89c43ff935..6301772468 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
@@ -32,13 +32,14 @@ class WindowedDStream[T: ClassTag](
extends DStream[T](parent.ssc) {
if (!_windowDuration.isMultipleOf(parent.slideDuration))
- throw new Exception("The window duration of WindowedDStream (" + _slideDuration + ") " +
- "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")")
+ throw new Exception("The window duration of windowed DStream (" + _slideDuration + ") " +
+ "must be a multiple of the slide duration of parent DStream (" + parent.slideDuration + ")")
if (!_slideDuration.isMultipleOf(parent.slideDuration))
- throw new Exception("The slide duration of WindowedDStream (" + _slideDuration + ") " +
- "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")")
+ throw new Exception("The slide duration of windowed DStream (" + _slideDuration + ") " +
+ "must be a multiple of the slide duration of parent DStream (" + parent.slideDuration + ")")
+ // Persist the parent DStream by default, as its RDDs are obviously going to be reused.
parent.persist(StorageLevel.MEMORY_ONLY_SER)
def windowDuration: Duration = _windowDuration
@@ -49,6 +50,14 @@ class WindowedDStream[T: ClassTag](
override def parentRememberDuration: Duration = rememberDuration + windowDuration
+ override def persist(level: StorageLevel): DStream[T] = {
+ // Do not let this windowed DStream be persisted, as windowed (unioned) RDDs share underlying
+ // RDDs and persisting the windowed RDDs would store numerous copies of the underlying data.
+ // Instead control the persistence of the parent DStream.
+ parent.persist(level)
+ this
+ }
+
override def compute(validTime: Time): Option[RDD[T]] = {
val currentWindow = new Interval(validTime - windowDuration + parent.slideDuration, validTime)
val rddsInWindow = parent.slice(currentWindow)
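The persist override means that caching a windowed stream no longer stores the unioned window RDDs, which would duplicate the parent's blocks; the storage level is pushed down to the parent instead. From the caller's side the API is unchanged, as in this hedged sketch (hypothetical helper, not part of this patch; the window and slide durations are assumed to be multiples of the batch interval):

    import org.apache.spark.storage.StorageLevel
    import org.apache.spark.streaming.Seconds
    import org.apache.spark.streaming.dstream.DStream

    object WindowSketch {
      // Counts elements over a 30 s window every 10 s; persist() here now caches the parent's RDDs.
      def windowedCounts(lines: DStream[String]): DStream[Long] = {
        lines.window(Seconds(30), Seconds(10))
          .persist(StorageLevel.MEMORY_ONLY_SER)
          .count()
      }
    }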
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala
index c8ee93bf5b..7e0f6b2cdf 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala
@@ -18,6 +18,7 @@
package org.apache.spark.streaming.scheduler
import org.apache.spark.streaming.Time
+import scala.util.Try
/**
* Class representing a Spark computation. It may contain multiple Spark jobs.
@@ -25,12 +26,10 @@ import org.apache.spark.streaming.Time
private[streaming]
class Job(val time: Time, func: () => _) {
var id: String = _
+ var result: Try[_] = null
- def run(): Long = {
- val startTime = System.currentTimeMillis
- func()
- val stopTime = System.currentTimeMillis
- (stopTime - startTime)
+ def run() {
+ result = Try(func())
}
def setId(number: Int) {
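Job.run() now records its outcome as a scala.util.Try instead of returning the elapsed time, so a failing job becomes a Failure value the scheduler can report rather than an exception escaping the worker thread. A standalone illustration of that pattern (plain Scala, not the Spark classes):

    import scala.util.{Failure, Success, Try}

    object TrySketch {
      // Run an arbitrary computation and capture its outcome as a value.
      def runCaptured[A](body: => A): Try[A] = Try(body)

      def main(args: Array[String]) {
        val ok   = runCaptured { 21 * 2 }                    // Success(42)
        val boom = runCaptured { sys.error("job failed") }   // Failure(RuntimeException)

        Seq(ok, boom).foreach {
          case Success(v) => println("completed with result " + v)
          case Failure(e) => println("failed: " + e.getMessage)
        }
      }
    }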
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 2fa6853ae0..b5f11d3440 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -17,11 +17,11 @@
package org.apache.spark.streaming.scheduler
-import akka.actor.{Props, Actor}
-import org.apache.spark.SparkEnv
-import org.apache.spark.Logging
+import akka.actor.{ActorRef, ActorSystem, Props, Actor}
+import org.apache.spark.{SparkException, SparkEnv, Logging}
import org.apache.spark.streaming.{Checkpoint, Time, CheckpointWriter}
import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock}
+import scala.util.{Failure, Success, Try}
/** Event classes for JobGenerator */
private[scheduler] sealed trait JobGeneratorEvent
@@ -37,29 +37,38 @@ private[scheduler] case class ClearCheckpointData(time: Time) extends JobGenerat
private[streaming]
class JobGenerator(jobScheduler: JobScheduler) extends Logging {
- val ssc = jobScheduler.ssc
- val graph = ssc.graph
- val eventProcessorActor = ssc.env.actorSystem.actorOf(Props(new Actor {
- def receive = {
- case event: JobGeneratorEvent =>
- logDebug("Got event of type " + event.getClass.getName)
- processEvent(event)
- }
- }))
+ private val ssc = jobScheduler.ssc
+ private val graph = ssc.graph
val clock = {
val clockClass = ssc.sc.conf.get(
"spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock")
Class.forName(clockClass).newInstance().asInstanceOf[Clock]
}
- val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds,
- longTime => eventProcessorActor ! GenerateJobs(new Time(longTime)))
- lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) {
+ private val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds,
+ longTime => eventActor ! GenerateJobs(new Time(longTime)))
+ private lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) {
new CheckpointWriter(this, ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration)
} else {
null
}
+ // eventActor is created when generator starts.
+ // This not being null means the generator has been started and not stopped
+ private var eventActor: ActorRef = null
+
+ /** Start generation of jobs */
def start() = synchronized {
+ if (eventActor != null) {
+ throw new SparkException("JobGenerator already started")
+ }
+
+ eventActor = ssc.env.actorSystem.actorOf(Props(new Actor {
+ def receive = {
+ case event: JobGeneratorEvent =>
+ logDebug("Got event of type " + event.getClass.getName)
+ processEvent(event)
+ }
+ }), "JobGenerator")
if (ssc.isCheckpointPresent) {
restart()
} else {
@@ -67,22 +76,26 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
}
}
- def stop() {
- timer.stop()
- if (checkpointWriter != null) checkpointWriter.stop()
- ssc.graph.stop()
- logInfo("JobGenerator stopped")
+ /** Stop generation of jobs */
+ def stop() = synchronized {
+ if (eventActor != null) {
+ timer.stop()
+ ssc.env.actorSystem.stop(eventActor)
+ if (checkpointWriter != null) checkpointWriter.stop()
+ ssc.graph.stop()
+ logInfo("JobGenerator stopped")
+ }
}
/**
* On batch completion, clear old metadata and checkpoint computation.
*/
- private[scheduler] def onBatchCompletion(time: Time) {
- eventProcessorActor ! ClearMetadata(time)
+ def onBatchCompletion(time: Time) {
+ eventActor ! ClearMetadata(time)
}
- private[streaming] def onCheckpointCompletion(time: Time) {
- eventProcessorActor ! ClearCheckpointData(time)
+ def onCheckpointCompletion(time: Time) {
+ eventActor ! ClearCheckpointData(time)
}
/** Processes all events */
@@ -121,14 +134,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
val checkpointTime = ssc.initialCheckpoint.checkpointTime
val restartTime = new Time(timer.getRestartTime(graph.zeroTime.milliseconds))
val downTimes = checkpointTime.until(restartTime, batchDuration)
- logInfo("Batches during down time (" + downTimes.size + " batches): " + downTimes.mkString(", "))
+ logInfo("Batches during down time (" + downTimes.size + " batches): "
+ + downTimes.mkString(", "))
// Batches that were unprocessed before failure
val pendingTimes = ssc.initialCheckpoint.pendingTimes.sorted(Time.ordering)
- logInfo("Batches pending processing (" + pendingTimes.size + " batches): " + pendingTimes.mkString(", "))
+ logInfo("Batches pending processing (" + pendingTimes.size + " batches): " +
+ pendingTimes.mkString(", "))
// Reschedule jobs for these times
val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering)
- logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " + timesToReschedule.mkString(", "))
+ logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " +
+ timesToReschedule.mkString(", "))
timesToReschedule.foreach(time =>
jobScheduler.runJobs(time, graph.generateJobs(time))
)
@@ -141,15 +157,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
/** Generate jobs and perform checkpoint for the given `time`. */
private def generateJobs(time: Time) {
SparkEnv.set(ssc.env)
- logInfo("\n-----------------------------------------------------\n")
- jobScheduler.runJobs(time, graph.generateJobs(time))
- eventProcessorActor ! DoCheckpoint(time)
+ Try(graph.generateJobs(time)) match {
+ case Success(jobs) => jobScheduler.runJobs(time, jobs)
+ case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e)
+ }
+ eventActor ! DoCheckpoint(time)
}
/** Clear DStream metadata for the given `time`. */
private def clearMetadata(time: Time) {
ssc.graph.clearMetadata(time)
- eventProcessorActor ! DoCheckpoint(time)
+ eventActor ! DoCheckpoint(time)
}
/** Clear DStream checkpoint data for the given `time`. */
@@ -166,4 +184,3 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
}
}
}
-
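start() and stop() in both JobGenerator and JobScheduler are now synchronized and keyed on whether the event actor has been created: a second start() fails fast with a SparkException, and stop() only tears things down if the component was actually started. The same guard pattern in isolation (hypothetical class with a plain placeholder instead of the Akka actor, and a generic exception standing in for SparkException):

    object LifecycleGuardSketch {
      class GuardedService {
        // Non-null only between a successful start() and the matching stop().
        private var handle: AnyRef = null

        def start(): Unit = synchronized {
          if (handle != null) {
            throw new IllegalStateException("already started")
          }
          handle = new Object   // stands in for creating the event actor and timer
        }

        def stop(): Unit = synchronized {
          if (handle != null) {   // stopping before start(), or twice, is a no-op
            handle = null
          }
        }
      }
    }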
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
index 30c070c274..de675d3c7f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -17,36 +17,68 @@
package org.apache.spark.streaming.scheduler
-import org.apache.spark.Logging
-import org.apache.spark.SparkEnv
+import scala.util.{Failure, Success, Try}
+import scala.collection.JavaConversions._
import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors}
-import scala.collection.mutable.HashSet
+import akka.actor.{ActorRef, Actor, Props}
+import org.apache.spark.{SparkException, Logging, SparkEnv}
import org.apache.spark.streaming._
+
+private[scheduler] sealed trait JobSchedulerEvent
+private[scheduler] case class JobStarted(job: Job) extends JobSchedulerEvent
+private[scheduler] case class JobCompleted(job: Job) extends JobSchedulerEvent
+private[scheduler] case class ErrorReported(msg: String, e: Throwable) extends JobSchedulerEvent
+
/**
* This class schedules jobs to be run on Spark. It uses the JobGenerator to generate
- * the jobs and runs them using a thread pool. Number of threads
+ * the jobs and runs them using a thread pool.
*/
private[streaming]
class JobScheduler(val ssc: StreamingContext) extends Logging {
- val jobSets = new ConcurrentHashMap[Time, JobSet]
- val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1)
- val executor = Executors.newFixedThreadPool(numConcurrentJobs)
- val generator = new JobGenerator(this)
+ private val jobSets = new ConcurrentHashMap[Time, JobSet]
+ private val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1)
+ private val executor = Executors.newFixedThreadPool(numConcurrentJobs)
+ private val jobGenerator = new JobGenerator(this)
+ val clock = jobGenerator.clock
val listenerBus = new StreamingListenerBus()
- def clock = generator.clock
+ // These two are created only when the scheduler starts.
+ // A non-null eventActor means the scheduler has been started and not stopped.
+ var networkInputTracker: NetworkInputTracker = null
+ private var eventActor: ActorRef = null
+
+
+ def start() = synchronized {
+ if (eventActor != null) {
+ throw new SparkException("JobScheduler already started")
+ }
- def start() {
- generator.start()
+ eventActor = ssc.env.actorSystem.actorOf(Props(new Actor {
+ def receive = {
+ case event: JobSchedulerEvent => processEvent(event)
+ }
+ }), "JobScheduler")
+ listenerBus.start()
+ networkInputTracker = new NetworkInputTracker(ssc)
+ networkInputTracker.start()
+ Thread.sleep(1000)
+ jobGenerator.start()
+ logInfo("JobScheduler started")
}
- def stop() {
- generator.stop()
- executor.shutdown()
- if (!executor.awaitTermination(5, TimeUnit.SECONDS)) {
- executor.shutdownNow()
+ def stop() = synchronized {
+ if (eventActor != null) {
+ jobGenerator.stop()
+ networkInputTracker.stop()
+ executor.shutdown()
+ if (!executor.awaitTermination(2, TimeUnit.SECONDS)) {
+ executor.shutdownNow()
+ }
+ listenerBus.stop()
+ ssc.env.actorSystem.stop(eventActor)
+ logInfo("JobScheduler stopped")
}
}
@@ -61,46 +93,67 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
}
}
- def getPendingTimes(): Array[Time] = {
- jobSets.keySet.toArray(new Array[Time](0))
+ def getPendingTimes(): Seq[Time] = {
+ jobSets.keySet.toSeq
+ }
+
+ def reportError(msg: String, e: Throwable) {
+ eventActor ! ErrorReported(msg, e)
}
- private def beforeJobStart(job: Job) {
+ private def processEvent(event: JobSchedulerEvent) {
+ try {
+ event match {
+ case JobStarted(job) => handleJobStart(job)
+ case JobCompleted(job) => handleJobCompletion(job)
+ case ErrorReported(m, e) => handleError(m, e)
+ }
+ } catch {
+ case e: Throwable =>
+ reportError("Error in job scheduler", e)
+ }
+ }
+
+ private def handleJobStart(job: Job) {
val jobSet = jobSets.get(job.time)
if (!jobSet.hasStarted) {
- listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo()))
+ listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo))
}
- jobSet.beforeJobStart(job)
+ jobSet.handleJobStart(job)
logInfo("Starting job " + job.id + " from job set of time " + jobSet.time)
- SparkEnv.set(generator.ssc.env)
+ SparkEnv.set(ssc.env)
}
- private def afterJobEnd(job: Job) {
- val jobSet = jobSets.get(job.time)
- jobSet.afterJobStop(job)
- logInfo("Finished job " + job.id + " from job set of time " + jobSet.time)
- if (jobSet.hasCompleted) {
- jobSets.remove(jobSet.time)
- generator.onBatchCompletion(jobSet.time)
- logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format(
- jobSet.totalDelay / 1000.0, jobSet.time.toString,
- jobSet.processingDelay / 1000.0
- ))
- listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo()))
+ private def handleJobCompletion(job: Job) {
+ job.result match {
+ case Success(_) =>
+ val jobSet = jobSets.get(job.time)
+ jobSet.handleJobCompletion(job)
+ logInfo("Finished job " + job.id + " from job set of time " + jobSet.time)
+ if (jobSet.hasCompleted) {
+ jobSets.remove(jobSet.time)
+ jobGenerator.onBatchCompletion(jobSet.time)
+ logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format(
+ jobSet.totalDelay / 1000.0, jobSet.time.toString,
+ jobSet.processingDelay / 1000.0
+ ))
+ listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo))
+ }
+ case Failure(e) =>
+ reportError("Error running job " + job, e)
}
}
- private[streaming]
- class JobHandler(job: Job) extends Runnable {
+ private def handleError(msg: String, e: Throwable) {
+ logError(msg, e)
+ ssc.waiter.notifyError(e)
+ }
+
+ private class JobHandler(job: Job) extends Runnable {
def run() {
- beforeJobStart(job)
- try {
- job.run()
- } catch {
- case e: Exception =>
- logError("Running " + job + " failed", e)
- }
- afterJobEnd(job)
+ eventActor ! JobStarted(job)
+ job.run()
+ eventActor ! JobCompleted(job)
}
}
}
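The rewritten JobScheduler above replaces the direct beforeJobStart/afterJobEnd calls with JobStarted/JobCompleted/ErrorReported messages handled by a single actor, so all bookkeeping runs on one thread. A stripped-down sketch of that dispatch, with illustrative names and no Akka dependency:

    object SchedulerEventSketch {
      sealed trait Event
      case class JobStarted(id: Int) extends Event
      case class JobCompleted(id: Int) extends Event
      case class ErrorReported(msg: String, e: Throwable) extends Event

      // Single entry point, mirroring processEvent in the patch above.
      def processEvent(event: Event): Unit = event match {
        case JobStarted(id)      => println(s"job $id started")
        case JobCompleted(id)    => println(s"job $id completed")
        case ErrorReported(m, e) => println(s"$m: ${e.getMessage}")
      }
    }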
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
index 57268674ea..fcf303aee6 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
@@ -17,7 +17,7 @@
package org.apache.spark.streaming.scheduler
-import scala.collection.mutable.HashSet
+import scala.collection.mutable.{ArrayBuffer, HashSet}
import org.apache.spark.streaming.Time
/** Class representing a set of Jobs
@@ -27,25 +27,25 @@ private[streaming]
case class JobSet(time: Time, jobs: Seq[Job]) {
private val incompleteJobs = new HashSet[Job]()
- var submissionTime = System.currentTimeMillis() // when this jobset was submitted
- var processingStartTime = -1L // when the first job of this jobset started processing
- var processingEndTime = -1L // when the last job of this jobset finished processing
+ private val submissionTime = System.currentTimeMillis() // when this jobset was submitted
+ private var processingStartTime = -1L // when the first job of this jobset started processing
+ private var processingEndTime = -1L // when the last job of this jobset finished processing
jobs.zipWithIndex.foreach { case (job, i) => job.setId(i) }
incompleteJobs ++= jobs
- def beforeJobStart(job: Job) {
+ def handleJobStart(job: Job) {
if (processingStartTime < 0) processingStartTime = System.currentTimeMillis()
}
- def afterJobStop(job: Job) {
+ def handleJobCompletion(job: Job) {
incompleteJobs -= job
if (hasCompleted) processingEndTime = System.currentTimeMillis()
}
- def hasStarted() = (processingStartTime > 0)
+ def hasStarted = processingStartTime > 0
- def hasCompleted() = incompleteJobs.isEmpty
+ def hasCompleted = incompleteJobs.isEmpty
// Time taken to process all the jobs from the time they started processing
// (i.e. not including the time they wait in the streaming scheduler queue)
@@ -57,7 +57,7 @@ case class JobSet(time: Time, jobs: Seq[Job]) {
processingEndTime - time.milliseconds
}
- def toBatchInfo(): BatchInfo = {
+ def toBatchInfo: BatchInfo = {
new BatchInfo(
time,
submissionTime,
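The JobSet hunk above makes the timing fields private while keeping the delay arithmetic visible in the context lines: processing delay measures execution time only, while total delay also counts the time a batch waited in the queue. A hypothetical sketch of that bookkeeping (not the Spark class):

    case class BatchTimesSketch(batchTime: Long, processingStart: Long, processingEnd: Long) {
      def processingDelay: Long = processingEnd - processingStart  // execution only
      def totalDelay: Long      = processingEnd - batchTime        // queueing + execution
    }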
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
index 75f7244643..0d9733fa69 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
@@ -19,8 +19,7 @@ package org.apache.spark.streaming.scheduler
import org.apache.spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver}
import org.apache.spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError}
-import org.apache.spark.Logging
-import org.apache.spark.SparkEnv
+import org.apache.spark.{SparkException, Logging, SparkEnv}
import org.apache.spark.SparkContext._
import scala.collection.mutable.HashMap
@@ -32,6 +31,7 @@ import akka.pattern.ask
import akka.dispatch._
import org.apache.spark.storage.BlockId
import org.apache.spark.streaming.{Time, StreamingContext}
+import org.apache.spark.util.AkkaUtils
private[streaming] sealed trait NetworkInputTrackerMessage
private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage
@@ -39,33 +39,47 @@ private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[BlockId], m
private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage
/**
- * This class manages the execution of the receivers of NetworkInputDStreams.
+ * This class manages the execution of the receivers of NetworkInputDStreams. An instance of
+ * this class must be created after all input streams have been added and StreamingContext.start()
+ * has been called because it needs the final set of input streams at the time of instantiation.
*/
private[streaming]
-class NetworkInputTracker(
- @transient ssc: StreamingContext,
- @transient networkInputStreams: Array[NetworkInputDStream[_]])
- extends Logging {
+class NetworkInputTracker(ssc: StreamingContext) extends Logging {
+ val networkInputStreams = ssc.graph.getNetworkInputStreams()
val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*)
val receiverExecutor = new ReceiverExecutor()
val receiverInfo = new HashMap[Int, ActorRef]
val receivedBlockIds = new HashMap[Int, Queue[BlockId]]
- val timeout = 5000.milliseconds
+ val timeout = AkkaUtils.askTimeout(ssc.conf)
+
+ // The actor is created when the tracker starts.
+ // It not being null means the tracker has been started and not stopped.
+ var actor: ActorRef = null
var currentTime: Time = null
/** Start the actor and receiver execution thread. */
def start() {
- ssc.env.actorSystem.actorOf(Props(new NetworkInputTrackerActor), "NetworkInputTracker")
- receiverExecutor.start()
+ if (actor != null) {
+ throw new SparkException("NetworkInputTracker already started")
+ }
+
+ if (!networkInputStreams.isEmpty) {
+ actor = ssc.env.actorSystem.actorOf(Props(new NetworkInputTrackerActor), "NetworkInputTracker")
+ receiverExecutor.start()
+ logInfo("NetworkInputTracker started")
+ }
}
/** Stop the receiver execution thread. */
def stop() {
- // TODO: stop the actor as well
- receiverExecutor.interrupt()
- receiverExecutor.stopReceivers()
+ if (!networkInputStreams.isEmpty && actor != null) {
+ receiverExecutor.interrupt()
+ receiverExecutor.stopReceivers()
+ ssc.env.actorSystem.stop(actor)
+ logInfo("NetworkInputTracker stopped")
+ }
}
/** Return all the blocks received from a receiver. */
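JobGenerator, JobScheduler and NetworkInputTracker all adopt the same start-once/stop-if-started discipline in this patch, using a null actor reference as the "started" flag and throwing an exception on a second start(). A generic sketch of the guard, substituting a boolean flag and a plain IllegalStateException for the actor reference and SparkException:

    class StartOnceSketch {
      private var started = false
      def start(): Unit = synchronized {
        if (started) throw new IllegalStateException("already started")
        started = true
      }
      def stop(): Unit = synchronized {
        if (started) started = false  // a no-op unless start() succeeded
      }
    }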
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala
index 36225e190c..461ea35064 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala
@@ -24,9 +24,10 @@ import org.apache.spark.util.Distribution
sealed trait StreamingListenerEvent
case class StreamingListenerBatchCompleted(batchInfo: BatchInfo) extends StreamingListenerEvent
-
case class StreamingListenerBatchStarted(batchInfo: BatchInfo) extends StreamingListenerEvent
+/** An event used by the listener bus to shut down the listener daemon thread. */
+private[scheduler] case object StreamingListenerShutdown extends StreamingListenerEvent
/**
* A listener interface for receiving information about an ongoing streaming
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala
index 110a20f282..3063cf10a3 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala
@@ -31,7 +31,7 @@ private[spark] class StreamingListenerBus() extends Logging {
private val eventQueue = new LinkedBlockingQueue[StreamingListenerEvent](EVENT_QUEUE_CAPACITY)
private var queueFullErrorMessageLogged = false
- new Thread("StreamingListenerBus") {
+ val listenerThread = new Thread("StreamingListenerBus") {
setDaemon(true)
override def run() {
while (true) {
@@ -41,11 +41,18 @@ private[spark] class StreamingListenerBus() extends Logging {
listeners.foreach(_.onBatchStarted(batchStarted))
case batchCompleted: StreamingListenerBatchCompleted =>
listeners.foreach(_.onBatchCompleted(batchCompleted))
+ case StreamingListenerShutdown =>
+ // Get out of the while loop and shut down the daemon thread
+ return
case _ =>
}
}
}
- }.start()
+ }
+
+ def start() {
+ listenerThread.start()
+ }
def addListener(listener: StreamingListener) {
listeners += listener
@@ -54,9 +61,9 @@ private[spark] class StreamingListenerBus() extends Logging {
def post(event: StreamingListenerEvent) {
val eventAdded = eventQueue.offer(event)
if (!eventAdded && !queueFullErrorMessageLogged) {
- logError("Dropping SparkListenerEvent because no remaining room in event queue. " +
- "This likely means one of the SparkListeners is too slow and cannot keep up with the " +
- "rate at which tasks are being started by the scheduler.")
+ logError("Dropping StreamingListenerEvent because no remaining room in event queue. " +
+ "This likely means one of the StreamingListeners is too slow and cannot keep up with the " +
+ "rate at which events are being started by the scheduler.")
queueFullErrorMessageLogged = true
}
}
@@ -68,7 +75,7 @@ private[spark] class StreamingListenerBus() extends Logging {
*/
def waitUntilEmpty(timeoutMillis: Int): Boolean = {
val finishTime = System.currentTimeMillis + timeoutMillis
- while (!eventQueue.isEmpty()) {
+ while (!eventQueue.isEmpty) {
if (System.currentTimeMillis > finishTime) {
return false
}
@@ -76,6 +83,8 @@ private[spark] class StreamingListenerBus() extends Logging {
* add overhead in the general case. */
Thread.sleep(10)
}
- return true
+ true
}
+
+ def stop(): Unit = post(StreamingListenerShutdown)
}
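The listener bus changes above introduce an explicit start()/stop() lifecycle, where stop() posts a StreamingListenerShutdown sentinel that makes the daemon thread's loop return. A standalone sketch of that poison-pill pattern, with illustrative event names:

    import java.util.concurrent.LinkedBlockingQueue

    object ListenerBusSketch {
      sealed trait Event
      case class Message(text: String) extends Event
      case object Shutdown extends Event

      private val queue = new LinkedBlockingQueue[Event](100)
      private val thread = new Thread("ListenerBusSketch") {
        setDaemon(true)
        override def run(): Unit = {
          while (true) {
            queue.take() match {
              case Message(text) => println(text)
              case Shutdown      => return  // leave the loop so the daemon thread exits
            }
          }
        }
      }

      def start(): Unit = thread.start()
      def post(event: Event): Unit = queue.offer(event)
      def stop(): Unit = post(Shutdown)
    }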
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala
index f67bb2f6ac..c3a849d276 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala
@@ -66,7 +66,7 @@ class SystemClock() extends Clock {
}
Thread.sleep(sleepTime)
}
- return -1
+ -1
}
}
@@ -96,6 +96,6 @@ class ManualClock() extends Clock {
this.wait(100)
}
}
- return currentTime()
+ currentTime()
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala
index 162b19d7f0..be67af3a64 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala
@@ -20,7 +20,7 @@ package org.apache.spark.streaming.util
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming._
-import org.apache.spark.streaming.dstream.ForEachDStream
+import org.apache.spark.streaming.dstream.{DStream, ForEachDStream}
import StreamingContext._
import scala.util.Random
@@ -186,7 +186,6 @@ object MasterFailureTest extends Logging {
// Setup the streaming computation with the given operation
System.clearProperty("spark.driver.port")
- System.clearProperty("spark.hostPort")
val ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration, null, Nil, Map())
ssc.checkpoint(checkpointDir.toString)
val inputStream = ssc.textFileStream(testDir.toString)
@@ -233,7 +232,6 @@ object MasterFailureTest extends Logging {
// (iii) It's not timed out yet
System.clearProperty("spark.streaming.clock")
System.clearProperty("spark.driver.port")
- System.clearProperty("spark.hostPort")
ssc.start()
val startTime = System.currentTimeMillis()
while (!killed && !isLastOutputGenerated && !isTimedOut) {
diff --git a/core/src/main/scala/org/apache/spark/util/RateLimitedOutputStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala
index 47e1b45004..b9c0596378 100644
--- a/core/src/main/scala/org/apache/spark/util/RateLimitedOutputStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.util
+package org.apache.spark.streaming.util
import scala.annotation.tailrec
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala
index 4e6ce6eabd..5b6c048a39 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala
@@ -90,7 +90,7 @@ object RawTextHelper {
}
}
}
- return taken.toIterator
+ taken.toIterator
}
/**
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala
index 6585d494a6..463617a713 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala
@@ -17,14 +17,17 @@
package org.apache.spark.streaming.util
-import java.nio.ByteBuffer
-import org.apache.spark.util.{RateLimitedOutputStream, IntParam}
+import java.io.IOException
import java.net.ServerSocket
-import org.apache.spark.{SparkConf, Logging}
-import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+import java.nio.ByteBuffer
+
import scala.io.Source
-import java.io.IOException
+
+import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+
+import org.apache.spark.{SparkConf, Logging}
import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.util.IntParam
/**
* A helper program that sends blocks of Kryo-serialized text strings out on a socket at a
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
index d644240405..559c247385 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
@@ -20,17 +20,7 @@ package org.apache.spark.streaming.util
private[streaming]
class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => Unit) {
- private val minPollTime = 25L
-
- private val pollTime = {
- if (period / 10.0 > minPollTime) {
- (period / 10.0).toLong
- } else {
- minPollTime
- }
- }
-
- private val thread = new Thread() {
+ private val thread = new Thread("RecurringTimer") {
override def run() { loop }
}
@@ -66,7 +56,6 @@ class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) =>
callback(nextTime)
nextTime += period
}
-
} catch {
case e: InterruptedException =>
}
diff --git a/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
index 34bee56885..849bbf1299 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
@@ -28,7 +28,6 @@ public abstract class LocalJavaStreamingContext {
@Before
public void setUp() {
System.clearProperty("spark.driver.port");
- System.clearProperty("spark.hostPort");
System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock");
ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000));
ssc.checkpoint("checkpoint");
@@ -41,6 +40,5 @@ public abstract class LocalJavaStreamingContext {
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port");
- System.clearProperty("spark.hostPort");
}
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index 5ccef7f461..b73edf81d4 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -24,6 +24,9 @@ import org.apache.spark.SparkContext._
import util.ManualClock
import org.apache.spark.{SparkContext, SparkConf}
+import org.apache.spark.streaming.dstream.{WindowedDStream, DStream}
+import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer}
+import scala.reflect.ClassTag
class BasicOperationsSuite extends TestSuiteBase {
test("map") {
@@ -375,15 +378,11 @@ class BasicOperationsSuite extends TestSuiteBase {
}
test("slice") {
- val conf2 = conf.clone()
- .setMaster("local[2]")
- .setAppName("BasicOperationsSuite")
- .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
- val ssc = new StreamingContext(new SparkContext(conf2), Seconds(1))
+ val ssc = new StreamingContext(conf, Seconds(1))
val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4))
val stream = new TestInputStream[Int](ssc, input, 2)
ssc.registerInputStream(stream)
- stream.foreach(_ => {}) // Dummy output stream
+ stream.foreachRDD(_ => {}) // Dummy output stream
ssc.start()
Thread.sleep(2000)
def getInputFromSlice(fromMillis: Long, toMillis: Long) = {
@@ -398,40 +397,31 @@ class BasicOperationsSuite extends TestSuiteBase {
Thread.sleep(1000)
}
- test("forgetting of RDDs - map and window operations") {
- assert(batchDuration === Seconds(1), "Batch duration has changed from 1 second")
+ val cleanupTestInput = (0 until 10).map(x => Seq(x, x + 1)).toSeq
- val input = (0 until 10).map(x => Seq(x, x + 1)).toSeq
+ test("rdd cleanup - map and window") {
val rememberDuration = Seconds(3)
-
- assert(input.size === 10, "Number of inputs have changed")
-
def operation(s: DStream[Int]): DStream[(Int, Int)] = {
s.map(x => (x % 10, 1))
.window(Seconds(2), Seconds(1))
.window(Seconds(4), Seconds(2))
}
- val ssc = setupStreams(input, operation _)
- ssc.remember(rememberDuration)
- runStreams[(Int, Int)](ssc, input.size, input.size / 2)
-
- val windowedStream2 = ssc.graph.getOutputStreams().head.dependencies.head
- val windowedStream1 = windowedStream2.dependencies.head
+ val operatedStream = runCleanupTest(conf, operation _,
+ numExpectedOutput = cleanupTestInput.size / 2, rememberDuration = Seconds(3))
+ val windowedStream2 = operatedStream.asInstanceOf[WindowedDStream[_]]
+ val windowedStream1 = windowedStream2.dependencies.head.asInstanceOf[WindowedDStream[_]]
val mappedStream = windowedStream1.dependencies.head
- val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
- assert(clock.time === Seconds(10).milliseconds)
-
- // IDEALLY
- // WindowedStream2 should remember till 7 seconds: 10, 8,
- // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5
- // MappedStream should remember till 7 seconds: 10, 9, 8, 7, 6, 5, 4, 3,
+ // Checkpoint remember durations
+ assert(windowedStream2.rememberDuration === rememberDuration)
+ assert(windowedStream1.rememberDuration === rememberDuration + windowedStream2.windowDuration)
+ assert(mappedStream.rememberDuration ===
+ rememberDuration + windowedStream2.windowDuration + windowedStream1.windowDuration)
- // IN THIS TEST
- // WindowedStream2 should remember till 7 seconds: 10, 8,
+ // WindowedStream2 should remember till 7 seconds: 10, 9, 8, 7
// WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5, 4
- // MappedStream should remember till 7 seconds: 10, 9, 8, 7, 6, 5, 4, 3, 2
+ // MappedStream should remember till 2 seconds: 10, 9, 8, 7, 6, 5, 4, 3, 2
// WindowedStream2
assert(windowedStream2.generatedRDDs.contains(Time(10000)))
@@ -448,4 +438,37 @@ class BasicOperationsSuite extends TestSuiteBase {
assert(mappedStream.generatedRDDs.contains(Time(2000)))
assert(!mappedStream.generatedRDDs.contains(Time(1000)))
}
+
+ test("rdd cleanup - updateStateByKey") {
+ val updateFunc = (values: Seq[Int], state: Option[Int]) => {
+ Some(values.foldLeft(0)(_ + _) + state.getOrElse(0))
+ }
+ val stateStream = runCleanupTest(
+ conf, _.map(_ -> 1).updateStateByKey(updateFunc).checkpoint(Seconds(3)))
+
+ assert(stateStream.rememberDuration === stateStream.checkpointDuration * 2)
+ assert(stateStream.generatedRDDs.contains(Time(10000)))
+ assert(!stateStream.generatedRDDs.contains(Time(4000)))
+ }
+
+ /** Test cleanup of RDDs in DStream metadata */
+ def runCleanupTest[T: ClassTag](
+ conf2: SparkConf,
+ operation: DStream[Int] => DStream[T],
+ numExpectedOutput: Int = cleanupTestInput.size,
+ rememberDuration: Duration = null
+ ): DStream[T] = {
+
+ // Setup the stream computation
+ assert(batchDuration === Seconds(1),
+ "Batch duration has changed from 1 second, check cleanup tests")
+ val ssc = setupStreams(cleanupTestInput, operation)
+ val operatedStream = ssc.graph.getOutputStreams().head.dependencies.head.asInstanceOf[DStream[T]]
+ if (rememberDuration != null) ssc.remember(rememberDuration)
+ val output = runStreams[(Int, Int)](ssc, cleanupTestInput.size, numExpectedOutput)
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ assert(clock.time === Seconds(10).milliseconds)
+ assert(output.size === numExpectedOutput)
+ operatedStream
+ }
}
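The assertions in runCleanupTest above encode how remember durations accumulate up the DStream chain. With the values used in the "rdd cleanup - map and window" test (remember = 3 s, windowedStream2's window = 4 s, windowedStream1's window = 2 s, 1 s batches), the expected durations work out as follows (illustrative arithmetic only):

    // windowedStream2.rememberDuration = 3 s                                  (set via remember())
    // windowedStream1.rememberDuration = 3 s + 4 s (windowedStream2's window) = 7 s
    // mappedStream.rememberDuration    = 7 s + 2 s (windowedStream1's window) = 9 s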
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index 6499de98c9..0c68c44ddb 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -26,8 +26,10 @@ import com.google.common.io.Files
import org.apache.hadoop.fs.{Path, FileSystem}
import org.apache.hadoop.conf.Configuration
import org.apache.spark.streaming.StreamingContext._
-import org.apache.spark.streaming.dstream.FileInputDStream
+import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
import org.apache.spark.streaming.util.ManualClock
+import org.apache.spark.util.Utils
+import org.apache.spark.SparkConf
/**
* This test suites tests the checkpointing functionality of DStreams -
@@ -142,6 +144,26 @@ class CheckpointSuite extends TestSuiteBase {
ssc = null
}
+ // This tests whether the SparkConf persists through checkpoints, and that certain
+ // configs get scrubbed
+ test("persistence of conf through checkpoints") {
+ val key = "spark.mykey"
+ val value = "myvalue"
+ System.setProperty(key, value)
+ ssc = new StreamingContext(master, framework, batchDuration)
+ val cp = new Checkpoint(ssc, Time(1000))
+ assert(!cp.sparkConf.contains("spark.driver.host"))
+ assert(!cp.sparkConf.contains("spark.driver.port"))
+ assert(!cp.sparkConf.contains("spark.hostPort"))
+ assert(cp.sparkConf.get(key) === value)
+ ssc.stop()
+ val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp))
+ assert(!newCp.sparkConf.contains("spark.driver.host"))
+ assert(!newCp.sparkConf.contains("spark.driver.port"))
+ assert(!newCp.sparkConf.contains("spark.hostPort"))
+ assert(newCp.sparkConf.get(key) === value)
+ }
+
// This tests whether the system can recover from a master failure with simple
// non-stateful operations. This assumes a reliable, replayable input
@@ -336,7 +358,6 @@ class CheckpointSuite extends TestSuiteBase {
)
ssc = new StreamingContext(checkpointDir)
System.clearProperty("spark.driver.port")
- System.clearProperty("spark.hostPort")
ssc.start()
val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches)
// the first element will be re-processed data of the last batch before restart
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
new file mode 100644
index 0000000000..f7f3346f81
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import org.scalatest.{FunSuite, BeforeAndAfter}
+import org.scalatest.exceptions.TestFailedDueToTimeoutException
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.time.SpanSugar._
+import org.apache.spark.{SparkException, SparkConf, SparkContext}
+import org.apache.spark.util.{Utils, MetadataCleaner}
+import org.apache.spark.streaming.dstream.DStream
+
+class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
+
+ val master = "local[2]"
+ val appName = this.getClass.getSimpleName
+ val batchDuration = Seconds(1)
+ val sparkHome = "someDir"
+ val envPair = "key" -> "value"
+ val ttl = StreamingContext.DEFAULT_CLEANER_TTL + 100
+
+ var sc: SparkContext = null
+ var ssc: StreamingContext = null
+
+ before {
+ System.clearProperty("spark.cleaner.ttl")
+ }
+
+ after {
+ if (ssc != null) {
+ ssc.stop()
+ ssc = null
+ }
+ if (sc != null) {
+ sc.stop()
+ sc = null
+ }
+ }
+
+ test("from no conf constructor") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ assert(ssc.sparkContext.conf.get("spark.master") === master)
+ assert(ssc.sparkContext.conf.get("spark.app.name") === appName)
+ assert(MetadataCleaner.getDelaySeconds(ssc.sparkContext.conf) ===
+ StreamingContext.DEFAULT_CLEANER_TTL)
+ }
+
+ test("from no conf + spark home") {
+ ssc = new StreamingContext(master, appName, batchDuration, sparkHome, Nil)
+ assert(ssc.conf.get("spark.home") === sparkHome)
+ assert(MetadataCleaner.getDelaySeconds(ssc.sparkContext.conf) ===
+ StreamingContext.DEFAULT_CLEANER_TTL)
+ }
+
+ test("from no conf + spark home + env") {
+ ssc = new StreamingContext(master, appName, batchDuration,
+ sparkHome, Nil, Map(envPair))
+ assert(ssc.conf.getExecutorEnv.exists(_ == envPair))
+ assert(MetadataCleaner.getDelaySeconds(ssc.sparkContext.conf) ===
+ StreamingContext.DEFAULT_CLEANER_TTL)
+ }
+
+ test("from conf without ttl set") {
+ val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName)
+ ssc = new StreamingContext(myConf, batchDuration)
+ assert(MetadataCleaner.getDelaySeconds(ssc.conf) ===
+ StreamingContext.DEFAULT_CLEANER_TTL)
+ }
+
+ test("from conf with ttl set") {
+ val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName)
+ myConf.set("spark.cleaner.ttl", ttl.toString)
+ ssc = new StreamingContext(myConf, batchDuration)
+ assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === ttl)
+ }
+
+ test("from existing SparkContext without ttl set") {
+ sc = new SparkContext(master, appName)
+ val exception = intercept[SparkException] {
+ ssc = new StreamingContext(sc, batchDuration)
+ }
+ assert(exception.getMessage.contains("ttl"))
+ }
+
+ test("from existing SparkContext with ttl set") {
+ val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName)
+ myConf.set("spark.cleaner.ttl", ttl.toString)
+ ssc = new StreamingContext(myConf, batchDuration)
+ assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === ttl)
+ }
+
+ test("from checkpoint") {
+ val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName)
+ myConf.set("spark.cleaner.ttl", ttl.toString)
+ val ssc1 = new StreamingContext(myConf, batchDuration)
+ val cp = new Checkpoint(ssc1, Time(1000))
+ assert(MetadataCleaner.getDelaySeconds(cp.sparkConf) === ttl)
+ ssc1.stop()
+ val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp))
+ assert(MetadataCleaner.getDelaySeconds(newCp.sparkConf) === ttl)
+ ssc = new StreamingContext(null, cp, null)
+ assert(MetadataCleaner.getDelaySeconds(ssc.conf) === ttl)
+ }
+
+ test("start multiple times") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ addInputStream(ssc).register
+
+ ssc.start()
+ intercept[SparkException] {
+ ssc.start()
+ }
+ }
+
+ test("stop multiple times") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ addInputStream(ssc).register
+ ssc.start()
+ ssc.stop()
+ ssc.stop()
+ ssc = null
+ }
+
+ test("stop only streaming context") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ sc = ssc.sparkContext
+ addInputStream(ssc).register
+ ssc.start()
+ ssc.stop(false)
+ ssc = null
+ assert(sc.makeRDD(1 to 100).collect().size === 100)
+ ssc = new StreamingContext(sc, batchDuration)
+ }
+
+ test("awaitTermination") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ val inputStream = addInputStream(ssc)
+ inputStream.map(x => x).register
+
+ // test whether start() blocks indefinitely or not
+ failAfter(2000 millis) {
+ ssc.start()
+ }
+
+ // test whether awaitTermination() exits after the given amount of time
+ failAfter(1000 millis) {
+ ssc.awaitTermination(500)
+ }
+
+ // test whether awaitTermination() does not exit if no timeout is given
+ val exception = intercept[Exception] {
+ failAfter(1000 millis) {
+ ssc.awaitTermination()
+ throw new Exception("Did not wait for stop")
+ }
+ }
+ assert(exception.isInstanceOf[TestFailedDueToTimeoutException], "Did not wait for stop")
+
+ // test whether the wait exits when the context is stopped
+ failAfter(10000 millis) { // 10 seconds because Spark takes a long time to shut down
+ new Thread() {
+ override def run {
+ Thread.sleep(500)
+ ssc.stop()
+ }
+ }.start()
+ ssc.awaitTermination()
+ }
+ }
+
+ test("awaitTermination with error in task") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ val inputStream = addInputStream(ssc)
+ inputStream.map(x => { throw new TestException("error in map task"); x})
+ .foreachRDD(_.count)
+
+ val exception = intercept[Exception] {
+ ssc.start()
+ ssc.awaitTermination(5000)
+ }
+ assert(exception.getMessage.contains("map task"), "Expected exception not thrown")
+ }
+
+ test("awaitTermination with error in job generation") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ val inputStream = addInputStream(ssc)
+
+ inputStream.transform(rdd => { throw new TestException("error in transform"); rdd }).register
+ val exception = intercept[TestException] {
+ ssc.start()
+ ssc.awaitTermination(5000)
+ }
+ assert(exception.getMessage.contains("transform"), "Expected exception not thrown")
+ }
+
+ def addInputStream(s: StreamingContext): DStream[Int] = {
+ val input = (1 to 100).map(i => (1 to i))
+ val inputStream = new TestInputStream(s, input, 1)
+ s.registerInputStream(inputStream)
+ inputStream
+ }
+}
+
+class TestException(msg: String) extends Exception(msg)
\ No newline at end of file
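The new suite above exercises the awaitTermination API added to StreamingContext: it blocks until stop() is called, an optional timeout elapses, or an error from a job or from job generation is re-thrown. A minimal driver-side usage sketch, assuming textFileStream and foreachRDD as shown elsewhere in this patch (the input path is illustrative):

    import org.apache.spark.streaming.{Seconds, StreamingContext}

    object AwaitTerminationSketch {
      def main(args: Array[String]): Unit = {
        val ssc = new StreamingContext("local[2]", "AwaitTerminationSketch", Seconds(1))
        val lines = ssc.textFileStream("/tmp/streaming-input")  // hypothetical input directory
        lines.foreachRDD(rdd => println("records in batch: " + rdd.count()))
        ssc.start()
        ssc.awaitTermination()  // blocks until ssc.stop() is called or an error is re-thrown
      }
    }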
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
index fa64142096..9e0f2c900e 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.streaming
import org.apache.spark.streaming.scheduler._
import scala.collection.mutable.ArrayBuffer
import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.streaming.dstream.DStream
class StreamingListenerSuite extends TestSuiteBase with ShouldMatchers {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index b20d02f996..535e5bd1f1 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -17,7 +17,7 @@
package org.apache.spark.streaming
-import org.apache.spark.streaming.dstream.{InputDStream, ForEachDStream}
+import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream}
import org.apache.spark.streaming.util.ManualClock
import scala.collection.mutable.ArrayBuffer
@@ -137,7 +137,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
val conf = new SparkConf()
.setMaster(master)
.setAppName(framework)
- .set("spark.cleaner.ttl", "3600")
+ .set("spark.cleaner.ttl", StreamingContext.DEFAULT_CLEANER_TTL.toString)
// Default before function for any streaming test suite. Override this
// if you want to add your stuff to "before" (i.e., don't call before { } )
@@ -156,7 +156,6 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
def afterFunction() {
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
- System.clearProperty("spark.hostPort")
}
before(beforeFunction)
@@ -273,10 +272,11 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
val startTime = System.currentTimeMillis()
while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) {
logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput)
- Thread.sleep(10)
+ ssc.awaitTermination(50)
}
val timeTaken = System.currentTimeMillis() - startTime
-
+ logInfo("Output generated in " + timeTaken + " milliseconds")
+ output.foreach(x => logInfo("[" + x.mkString(",") + "]"))
assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms")
assert(output.size === numExpectedOutput, "Unexpected number of outputs generated")
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
index c39abfc21b..471c99fab4 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
@@ -18,6 +18,8 @@
package org.apache.spark.streaming
import org.apache.spark.streaming.StreamingContext._
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.storage.StorageLevel
class WindowOperationsSuite extends TestSuiteBase {
@@ -143,6 +145,19 @@ class WindowOperationsSuite extends TestSuiteBase {
Seconds(3)
)
+ test("window - persistence level") {
+ val input = Seq( Seq(0), Seq(1), Seq(2), Seq(3), Seq(4), Seq(5))
+ val ssc = new StreamingContext(conf, batchDuration)
+ val inputStream = new TestInputStream[Int](ssc, input, 1)
+ val windowStream1 = inputStream.window(batchDuration * 2)
+ assert(windowStream1.storageLevel === StorageLevel.NONE)
+ assert(inputStream.storageLevel === StorageLevel.MEMORY_ONLY_SER)
+ windowStream1.persist(StorageLevel.MEMORY_ONLY)
+ assert(windowStream1.storageLevel === StorageLevel.NONE)
+ assert(inputStream.storageLevel === StorageLevel.MEMORY_ONLY)
+ ssc.stop()
+ }
+
// Testing naive reduceByKeyAndWindow (without invertible function)
testReduceByKeyAndWindow(
diff --git a/core/src/test/scala/org/apache/spark/util/RateLimitedOutputStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala
index a9dd0b1a5b..15f13d5b19 100644
--- a/core/src/test/scala/org/apache/spark/util/RateLimitedOutputStreamSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.util
+package org.apache.spark.streaming.util
import org.scalatest.FunSuite
import java.io.ByteArrayOutputStream
diff --git a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala
index f670f65bf5..4886cd6ea8 100644
--- a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala
+++ b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala
@@ -24,8 +24,9 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark._
import org.apache.spark.api.java._
import org.apache.spark.rdd.{RDD, DoubleRDDFunctions, PairRDDFunctions, OrderedRDDFunctions}
-import org.apache.spark.streaming.{PairDStreamFunctions, DStream, StreamingContext}
+import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStream, JavaStreamingContext}
+import org.apache.spark.streaming.dstream.{DStream, PairDStreamFunctions}
private[spark] abstract class SparkType(val name: String)
@@ -147,7 +148,7 @@ object JavaAPICompletenessChecker {
} else {
ParameterizedType(classOf[JavaRDD[_]].getName, parameters.map(applySubs))
}
- case "org.apache.spark.streaming.DStream" =>
+ case "org.apache.spark.streaming.dstream.DStream" =>
if (parameters(0).name == classOf[Tuple2[_, _]].getName) {
val tupleParams =
parameters(0).asInstanceOf[ParameterizedType].parameters.map(applySubs)
@@ -248,30 +249,29 @@ object JavaAPICompletenessChecker {
"org.apache.spark.SparkContext.getSparkHome",
"org.apache.spark.SparkContext.executorMemoryRequested",
"org.apache.spark.SparkContext.getExecutorStorageStatus",
- "org.apache.spark.streaming.DStream.generatedRDDs",
- "org.apache.spark.streaming.DStream.zeroTime",
- "org.apache.spark.streaming.DStream.rememberDuration",
- "org.apache.spark.streaming.DStream.storageLevel",
- "org.apache.spark.streaming.DStream.mustCheckpoint",
- "org.apache.spark.streaming.DStream.checkpointDuration",
- "org.apache.spark.streaming.DStream.checkpointData",
- "org.apache.spark.streaming.DStream.graph",
- "org.apache.spark.streaming.DStream.isInitialized",
- "org.apache.spark.streaming.DStream.parentRememberDuration",
- "org.apache.spark.streaming.DStream.initialize",
- "org.apache.spark.streaming.DStream.validate",
- "org.apache.spark.streaming.DStream.setContext",
- "org.apache.spark.streaming.DStream.setGraph",
- "org.apache.spark.streaming.DStream.remember",
- "org.apache.spark.streaming.DStream.getOrCompute",
- "org.apache.spark.streaming.DStream.generateJob",
- "org.apache.spark.streaming.DStream.clearOldMetadata",
- "org.apache.spark.streaming.DStream.addMetadata",
- "org.apache.spark.streaming.DStream.updateCheckpointData",
- "org.apache.spark.streaming.DStream.restoreCheckpointData",
- "org.apache.spark.streaming.DStream.isTimeValid",
+ "org.apache.spark.streaming.dstream.DStream.generatedRDDs",
+ "org.apache.spark.streaming.dstream.DStream.zeroTime",
+ "org.apache.spark.streaming.dstream.DStream.rememberDuration",
+ "org.apache.spark.streaming.dstream.DStream.storageLevel",
+ "org.apache.spark.streaming.dstream.DStream.mustCheckpoint",
+ "org.apache.spark.streaming.dstream.DStream.checkpointDuration",
+ "org.apache.spark.streaming.dstream.DStream.checkpointData",
+ "org.apache.spark.streaming.dstream.DStream.graph",
+ "org.apache.spark.streaming.dstream.DStream.isInitialized",
+ "org.apache.spark.streaming.dstream.DStream.parentRememberDuration",
+ "org.apache.spark.streaming.dstream.DStream.initialize",
+ "org.apache.spark.streaming.dstream.DStream.validate",
+ "org.apache.spark.streaming.dstream.DStream.setContext",
+ "org.apache.spark.streaming.dstream.DStream.setGraph",
+ "org.apache.spark.streaming.dstream.DStream.remember",
+ "org.apache.spark.streaming.dstream.DStream.getOrCompute",
+ "org.apache.spark.streaming.dstream.DStream.generateJob",
+ "org.apache.spark.streaming.dstream.DStream.clearOldMetadata",
+ "org.apache.spark.streaming.dstream.DStream.addMetadata",
+ "org.apache.spark.streaming.dstream.DStream.updateCheckpointData",
+ "org.apache.spark.streaming.dstream.DStream.restoreCheckpointData",
+ "org.apache.spark.streaming.dstream.DStream.isTimeValid",
"org.apache.spark.streaming.StreamingContext.nextNetworkInputStreamId",
- "org.apache.spark.streaming.StreamingContext.networkInputTracker",
"org.apache.spark.streaming.StreamingContext.checkpointDir",
"org.apache.spark.streaming.StreamingContext.checkpointDuration",
"org.apache.spark.streaming.StreamingContext.receiverJobThread",
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 23781ea35c..e56bc02897 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -158,7 +158,7 @@ class Client(args: ClientArguments, conf: Configuration, sparkConf: SparkConf)
val appContext = Records.newRecord(classOf[ApplicationSubmissionContext])
appContext.setApplicationId(appId)
appContext.setApplicationName(args.appName)
- return appContext
+ appContext
}
/** See if two file systems are the same or not. */
@@ -193,7 +193,8 @@ class Client(args: ClientArguments, conf: Configuration, sparkConf: SparkConf)
if (srcUri.getPort() != dstUri.getPort()) {
return false
}
- return true
+
+ true
}
/** Copy the file into HDFS if needed. */
@@ -299,7 +300,7 @@ class Client(args: ClientArguments, conf: Configuration, sparkConf: SparkConf)
}
UserGroupInformation.getCurrentUser().addCredentials(credentials)
- return localResources
+ localResources
}
def setupLaunchEnv(
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
index 62b20b8fba..9fe4d64a0f 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
@@ -145,7 +145,7 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar
val containerId = ConverterUtils.toContainerId(containerIdString)
val appAttemptId = containerId.getApplicationAttemptId()
logInfo("ApplicationAttemptId: " + appAttemptId)
- return appAttemptId
+ appAttemptId
}
private def registerWithResourceManager(): AMRMProtocol = {
@@ -153,7 +153,7 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar
YarnConfiguration.RM_SCHEDULER_ADDRESS,
YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS))
logInfo("Connecting to ResourceManager at " + rmAddress)
- return rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol]
+ rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol]
}
private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
@@ -167,7 +167,7 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar
appMasterRequest.setRpcPort(0)
// What do we provide here? Might make sense to expose something sensible later?
appMasterRequest.setTrackingUrl("")
- return resourceManager.registerApplicationMaster(appMasterRequest)
+ resourceManager.registerApplicationMaster(appMasterRequest)
}
private def waitForSparkMaster() {
@@ -240,7 +240,7 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar
t.setDaemon(true)
t.start()
logInfo("Started progress reporter thread - sleep time : " + sleepTime)
- return t
+ t
}
private def sendProgress() {
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
index 132630e5ef..d32cdcc879 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
@@ -195,7 +195,7 @@ class WorkerRunnable(
}
logInfo("Prepared Local resources " + localResources)
- return localResources
+ localResources
}
def prepareEnvironment: HashMap[String, String] = {
@@ -207,7 +207,7 @@ class WorkerRunnable(
Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV"))
System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v }
- return env
+ env
}
def connectToCM: ContainerManager = {
@@ -226,8 +226,7 @@ class WorkerRunnable(
val proxy = user
.doAs(new PrivilegedExceptionAction[ContainerManager] {
def run: ContainerManager = {
- return rpc.getProxy(classOf[ContainerManager],
- cmAddress, conf).asInstanceOf[ContainerManager]
+ rpc.getProxy(classOf[ContainerManager], cmAddress, conf).asInstanceOf[ContainerManager]
}
})
proxy
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
index 5f159b073f..535abbfb7f 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
@@ -143,7 +143,7 @@ class ClientDistributedCacheManager() extends Logging {
if (isPublic(conf, uri, statCache)) {
return LocalResourceVisibility.PUBLIC
}
- return LocalResourceVisibility.PRIVATE
+ LocalResourceVisibility.PRIVATE
}
/**
@@ -161,7 +161,7 @@ class ClientDistributedCacheManager() extends Logging {
if (!checkPermissionOfOther(fs, current, FsAction.READ, statCache)) {
return false
}
- return ancestorsHaveExecutePermissions(fs, current.getParent(), statCache)
+ ancestorsHaveExecutePermissions(fs, current.getParent(), statCache)
}
/**
@@ -183,7 +183,7 @@ class ClientDistributedCacheManager() extends Logging {
}
current = current.getParent()
}
- return true
+ true
}
/**
@@ -203,7 +203,7 @@ class ClientDistributedCacheManager() extends Logging {
if (otherAction.implies(action)) {
return true
}
- return false
+ false
}
/**
@@ -223,6 +223,6 @@ class ClientDistributedCacheManager() extends Logging {
statCache.put(uri, newStat)
newStat
}
- return stat
+ stat
}
}
diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
index 2941356bc5..458df4fa3c 100644
--- a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
+++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
@@ -42,7 +42,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
class MockClientDistributedCacheManager extends ClientDistributedCacheManager {
override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]):
LocalResourceVisibility = {
- return LocalResourceVisibility.PRIVATE
+ LocalResourceVisibility.PRIVATE
}
}
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 952e963389..51d9adb9d4 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -208,7 +208,8 @@ class Client(args: ClientArguments, conf: Configuration, sparkConf: SparkConf)
if (srcUri.getPort() != dstUri.getPort()) {
return false
}
- return true
+
+ true
}
/** Copy the file into HDFS if needed. */