-rw-r--r--  .gitignore | 2
-rw-r--r--  bagel/pom.xml | 37
-rw-r--r--  bagel/src/test/scala/bagel/BagelSuite.scala | 1
-rw-r--r--  bin/compute-classpath.cmd | 52
-rwxr-xr-x  bin/compute-classpath.sh | 89
-rwxr-xr-x  bin/start-slave.sh | 3
-rwxr-xr-x  conf/spark-env.sh.template | 18
-rw-r--r--  core/pom.xml | 82
-rw-r--r--  core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala | 3
-rw-r--r--  core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala | 3
-rw-r--r--  core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala | 23
-rw-r--r--  core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala | 13
-rw-r--r--  core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala | 13
-rw-r--r--  core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala | 63
-rw-r--r--  core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala | 329
-rw-r--r--  core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala | 77
-rw-r--r--  core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala | 272
-rw-r--r--  core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala | 105
-rw-r--r--  core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala | 171
-rw-r--r--  core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala | 547
-rw-r--r--  core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala | 42
-rw-r--r--  core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala | 3
-rw-r--r--  core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala | 3
-rw-r--r--  core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala | 23
-rw-r--r--  core/src/main/java/spark/network/netty/FileClient.java | 72
-rw-r--r--  core/src/main/java/spark/network/netty/FileClientChannelInitializer.java | 24
-rw-r--r--  core/src/main/java/spark/network/netty/FileClientHandler.java | 43
-rw-r--r--  core/src/main/java/spark/network/netty/FileServer.java | 86
-rw-r--r--  core/src/main/java/spark/network/netty/FileServerChannelInitializer.java | 25
-rw-r--r--  core/src/main/java/spark/network/netty/FileServerHandler.java | 65
-rwxr-xr-x  core/src/main/java/spark/network/netty/PathResolver.java | 12
-rw-r--r--  core/src/main/scala/spark/BlockStoreShuffleFetcher.scala | 33
-rw-r--r--  core/src/main/scala/spark/ClosureCleaner.scala | 23
-rw-r--r--  core/src/main/scala/spark/Dependency.scala | 4
-rw-r--r--  core/src/main/scala/spark/FetchFailedException.scala | 25
-rw-r--r--  core/src/main/scala/spark/HadoopWriter.scala | 12
-rw-r--r--  core/src/main/scala/spark/Logging.scala | 4
-rw-r--r--  core/src/main/scala/spark/MapOutputTracker.scala | 99
-rw-r--r--  core/src/main/scala/spark/PairRDDFunctions.scala | 73
-rw-r--r--  core/src/main/scala/spark/RDD.scala | 129
-rw-r--r--  core/src/main/scala/spark/RDDCheckpointData.scala | 15
-rw-r--r--  core/src/main/scala/spark/SequenceFileRDDFunctions.scala | 15
-rw-r--r--  core/src/main/scala/spark/ShuffleFetcher.scala | 7
-rw-r--r--  core/src/main/scala/spark/SizeEstimator.scala | 2
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala | 167
-rw-r--r--  core/src/main/scala/spark/SparkEnv.scala | 68
-rw-r--r--  core/src/main/scala/spark/TaskEndReason.scala | 16
-rw-r--r--  core/src/main/scala/spark/Utils.scala | 296
-rw-r--r--  core/src/main/scala/spark/api/java/JavaPairRDD.scala | 11
-rw-r--r--  core/src/main/scala/spark/api/java/JavaRDD.scala | 9
-rw-r--r--  core/src/main/scala/spark/api/java/JavaRDDLike.scala | 50
-rw-r--r--  core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala | 11
-rw-r--r--  core/src/main/scala/spark/api/python/PythonRDD.scala | 109
-rw-r--r--  core/src/main/scala/spark/api/python/PythonWorkerFactory.scala | 113
-rw-r--r--  core/src/main/scala/spark/deploy/ApplicationDescription.scala | 5
-rw-r--r--  core/src/main/scala/spark/deploy/DeployMessage.scala | 19
-rw-r--r--  core/src/main/scala/spark/deploy/JsonProtocol.scala | 5
-rw-r--r--  core/src/main/scala/spark/deploy/LocalSparkCluster.scala | 8
-rw-r--r--  core/src/main/scala/spark/deploy/client/Client.scala | 9
-rw-r--r--  core/src/main/scala/spark/deploy/client/ClientListener.scala | 2
-rw-r--r--  core/src/main/scala/spark/deploy/client/TestClient.scala | 4
-rw-r--r--  core/src/main/scala/spark/deploy/master/ApplicationInfo.scala | 6
-rw-r--r--  core/src/main/scala/spark/deploy/master/Master.scala | 19
-rw-r--r--  core/src/main/scala/spark/deploy/master/MasterArguments.scala | 17
-rw-r--r--  core/src/main/scala/spark/deploy/master/MasterWebUI.scala | 4
-rw-r--r--  core/src/main/scala/spark/deploy/master/WorkerInfo.scala | 9
-rw-r--r--  core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala | 36
-rw-r--r--  core/src/main/scala/spark/deploy/worker/Worker.scala | 25
-rw-r--r--  core/src/main/scala/spark/deploy/worker/WorkerArguments.scala | 13
-rw-r--r--  core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala | 3
-rw-r--r--  core/src/main/scala/spark/executor/Executor.scala | 53
-rw-r--r--  core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala | 32
-rw-r--r--  core/src/main/scala/spark/executor/TaskMetrics.scala | 17
-rw-r--r--  core/src/main/scala/spark/network/BufferMessage.scala | 94
-rw-r--r--  core/src/main/scala/spark/network/Connection.scala | 291
-rw-r--r--  core/src/main/scala/spark/network/ConnectionManager.scala | 549
-rw-r--r--  core/src/main/scala/spark/network/ConnectionManagerId.scala | 21
-rw-r--r--  core/src/main/scala/spark/network/Message.scala | 178
-rw-r--r--  core/src/main/scala/spark/network/MessageChunk.scala | 25
-rw-r--r--  core/src/main/scala/spark/network/MessageChunkHeader.scala | 58
-rw-r--r--  core/src/main/scala/spark/network/netty/FileHeader.scala | 57
-rw-r--r--  core/src/main/scala/spark/network/netty/ShuffleCopier.scala | 101
-rw-r--r--  core/src/main/scala/spark/network/netty/ShuffleSender.scala | 53
-rw-r--r--  core/src/main/scala/spark/rdd/BlockRDD.scala | 10
-rw-r--r--  core/src/main/scala/spark/rdd/CheckpointRDD.scala | 31
-rw-r--r--  core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 20
-rw-r--r--  core/src/main/scala/spark/rdd/EmptyRDD.scala | 16
-rw-r--r--  core/src/main/scala/spark/rdd/JdbcRDD.scala | 103
-rw-r--r--  core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 2
-rw-r--r--  core/src/main/scala/spark/rdd/PipedRDD.scala | 27
-rw-r--r--  core/src/main/scala/spark/rdd/ShuffledRDD.scala | 10
-rw-r--r--  core/src/main/scala/spark/rdd/SubtractedRDD.scala | 18
-rw-r--r--  core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala | 138
-rw-r--r--  core/src/main/scala/spark/rdd/ZippedRDD.scala | 23
-rw-r--r--  core/src/main/scala/spark/scheduler/ActiveJob.scala | 5
-rw-r--r--  core/src/main/scala/spark/scheduler/DAGScheduler.scala | 59
-rw-r--r--  core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala | 9
-rw-r--r--  core/src/main/scala/spark/scheduler/InputFormatInfo.scala | 156
-rw-r--r--  core/src/main/scala/spark/scheduler/JobLogger.scala | 306
-rw-r--r--  core/src/main/scala/spark/scheduler/ResultTask.scala | 9
-rw-r--r--  core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 62
-rw-r--r--  core/src/main/scala/spark/scheduler/SparkListener.scala | 50
-rw-r--r--  core/src/main/scala/spark/scheduler/SplitInfo.scala | 61
-rw-r--r--  core/src/main/scala/spark/scheduler/Stage.scala | 4
-rw-r--r--  core/src/main/scala/spark/scheduler/TaskScheduler.scala | 4
-rw-r--r--  core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala | 3
-rw-r--r--  core/src/main/scala/spark/scheduler/TaskSet.scala | 11
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala | 364
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala | 747
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/Pool.scala | 104
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/Schedulable.scala | 27
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala | 115
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala | 11
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala | 64
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala | 7
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 9
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala | 7
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala | 34
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala | 9
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala | 431
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala | 2
-rw-r--r--  core/src/main/scala/spark/scheduler/local/LocalScheduler.scala | 225
-rw-r--r--  core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala | 172
-rw-r--r--  core/src/main/scala/spark/serializer/Serializer.scala | 8
-rw-r--r--  core/src/main/scala/spark/serializer/SerializerManager.scala | 45
-rw-r--r--  core/src/main/scala/spark/storage/BlockException.scala | 5
-rw-r--r--  core/src/main/scala/spark/storage/BlockFetchTracker.scala | 12
-rw-r--r--  core/src/main/scala/spark/storage/BlockFetcherIterator.scala | 330
-rw-r--r--  core/src/main/scala/spark/storage/BlockManager.scala | 465
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerId.scala | 62
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerMaster.scala | 24
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerMasterActor.scala | 211
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerMessages.scala | 3
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala | 8
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerUI.scala | 25
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerWorker.scala | 28
-rw-r--r--  core/src/main/scala/spark/storage/BlockMessageArray.scala | 1
-rw-r--r--  core/src/main/scala/spark/storage/BlockObjectWriter.scala | 50
-rw-r--r--  core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala | 12
-rw-r--r--  core/src/main/scala/spark/storage/DiskStore.scala | 190
-rw-r--r--  core/src/main/scala/spark/storage/MemoryStore.scala | 4
-rw-r--r--  core/src/main/scala/spark/storage/ShuffleBlockManager.scala | 50
-rw-r--r--  core/src/main/scala/spark/storage/StorageLevel.scala | 8
-rw-r--r--  core/src/main/scala/spark/storage/StorageUtils.scala | 47
-rw-r--r--  core/src/main/scala/spark/util/AkkaUtils.scala | 23
-rw-r--r--  core/src/main/scala/spark/util/BoundedPriorityQueue.scala | 45
-rw-r--r--  core/src/main/scala/spark/util/StatCounter.scala | 26
-rw-r--r--  core/src/main/scala/spark/util/TimeStampedHashMap.scala | 8
-rw-r--r--  core/src/main/scala/spark/util/TimedIterator.scala | 32
-rw-r--r--  core/src/main/twirl/spark/deploy/master/app_details.scala.html | 12
-rw-r--r--  core/src/main/twirl/spark/deploy/master/executor_row.scala.html | 2
-rw-r--r--  core/src/main/twirl/spark/deploy/master/index.scala.html | 2
-rw-r--r--  core/src/main/twirl/spark/deploy/master/worker_row.scala.html | 2
-rw-r--r--  core/src/main/twirl/spark/deploy/worker/index.scala.html | 2
-rw-r--r--  core/src/main/twirl/spark/storage/worker_table.scala.html | 2
-rw-r--r--  core/src/test/resources/fairscheduler.xml | 14
-rw-r--r--  core/src/test/scala/spark/CheckpointSuite.scala | 10
-rw-r--r--  core/src/test/scala/spark/DistributedSuite.scala | 67
-rw-r--r--  core/src/test/scala/spark/FileSuite.scala | 46
-rw-r--r--  core/src/test/scala/spark/JavaAPISuite.java | 71
-rw-r--r--  core/src/test/scala/spark/LocalSparkContext.scala | 3
-rw-r--r--  core/src/test/scala/spark/MapOutputTrackerSuite.scala | 36
-rw-r--r--  core/src/test/scala/spark/PairRDDFunctionsSuite.scala | 287
-rw-r--r--  core/src/test/scala/spark/PartitioningSuite.scala | 30
-rw-r--r--  core/src/test/scala/spark/PipedRDDSuite.scala | 45
-rw-r--r--  core/src/test/scala/spark/RDDSuite.scala | 98
-rw-r--r--  core/src/test/scala/spark/SharedSparkContext.scala | 25
-rw-r--r--  core/src/test/scala/spark/ShuffleNettySuite.scala | 17
-rw-r--r--  core/src/test/scala/spark/ShuffleSuite.scala | 322
-rw-r--r--  core/src/test/scala/spark/SizeEstimatorSuite.scala | 2
-rw-r--r--  core/src/test/scala/spark/SortingSuite.scala | 23
-rw-r--r--  core/src/test/scala/spark/UnpersistSuite.scala | 30
-rw-r--r--  core/src/test/scala/spark/UtilsSuite.scala | 53
-rw-r--r--  core/src/test/scala/spark/ZippedPartitionsSuite.scala | 33
-rw-r--r--  core/src/test/scala/spark/rdd/JdbcRDDSuite.scala | 56
-rw-r--r--  core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala | 250
-rw-r--r--  core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala | 20
-rw-r--r--  core/src/test/scala/spark/scheduler/JobLoggerSuite.scala | 104
-rw-r--r--  core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala | 206
-rw-r--r--  core/src/test/scala/spark/scheduler/SparkListenerSuite.scala | 3
-rw-r--r--  core/src/test/scala/spark/storage/BlockManagerSuite.scala | 76
-rw-r--r--  docs/_plugins/copy_api_dirs.rb | 2
-rw-r--r--  docs/configuration.md | 42
-rw-r--r--  docs/ec2-scripts.md | 5
-rw-r--r--  docs/python-programming-guide.md | 12
-rw-r--r--  docs/quick-start.md | 21
-rw-r--r--  docs/running-on-yarn.md | 60
-rw-r--r--  docs/scala-programming-guide.md | 12
-rw-r--r--  docs/tuning.md | 6
-rwxr-xr-x  ec2/spark_ec2.py | 2
-rw-r--r--  examples/pom.xml | 93
-rw-r--r--  examples/src/main/scala/spark/examples/CassandraTest.scala | 196
-rw-r--r--  examples/src/main/scala/spark/examples/HBaseTest.scala | 35
-rw-r--r--  examples/src/main/scala/spark/examples/SparkHdfsLR.scala | 10
-rw-r--r--  examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala | 2
-rw-r--r--  examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala | 50
-rw-r--r--  examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala | 9
-rw-r--r--  examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala | 9
-rw-r--r--  examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala | 9
-rw-r--r--  pom.xml | 108
-rw-r--r--  project/SparkBuild.scala | 98
-rw-r--r--  project/plugins.sbt | 2
-rw-r--r--  python/pyspark/daemon.py | 164
-rw-r--r--  python/pyspark/serializers.py | 4
-rw-r--r--  python/pyspark/tests.py | 43
-rw-r--r--  python/pyspark/worker.py | 55
-rw-r--r--  repl-bin/pom.xml | 55
-rw-r--r--  repl/pom.xml | 79
-rw-r--r--  repl/src/main/scala/spark/repl/ExecutorClassLoader.scala | 3
-rw-r--r--  repl/src/test/scala/spark/repl/ReplSuite.scala | 23
-rwxr-xr-x  run | 98
-rw-r--r--  run2.cmd | 33
-rw-r--r--  streaming/pom.xml | 43
-rw-r--r--  streaming/src/main/scala/spark/streaming/Checkpoint.scala | 33
-rw-r--r--  streaming/src/main/scala/spark/streaming/DStream.scala | 9
-rw-r--r--  streaming/src/main/scala/spark/streaming/DStreamGraph.scala | 2
-rw-r--r--  streaming/src/main/scala/spark/streaming/StreamingContext.scala | 60
-rw-r--r--  streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala | 124
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala | 94
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala | 2
-rw-r--r--  streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala | 30
-rw-r--r--  streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala | 7
-rw-r--r--  streaming/src/main/scala/spark/streaming/util/RawTextSender.scala | 2
-rw-r--r--  streaming/src/test/java/spark/streaming/JavaAPISuite.java | 14
-rw-r--r--  streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala | 5
-rw-r--r--  streaming/src/test/scala/spark/streaming/CheckpointSuite.scala | 2
-rw-r--r--  streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala | 12
-rw-r--r--  streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala | 1
228 files changed, 11461 insertions, 2760 deletions
diff --git a/.gitignore b/.gitignore
index 155e785b01..ae39c52b11 100644
--- a/.gitignore
+++ b/.gitignore
@@ -36,3 +36,5 @@ streaming-tests.log
dependency-reduced-pom.xml
.ensime
.ensime_lucene
+checkpoint
+derby.log
diff --git a/bagel/pom.xml b/bagel/pom.xml
index be2e358091..b83a0ef6c0 100644
--- a/bagel/pom.xml
+++ b/bagel/pom.xml
@@ -102,5 +102,42 @@
</plugins>
</build>
</profile>
+ <profile>
+ <id>hadoop2-yarn</id>
+ <dependencies>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-core</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2-yarn</classifier>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-client</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-api</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-common</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <configuration>
+ <classifier>hadoop2-yarn</classifier>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
</profiles>
</project>
diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala
index 25db395c22..a09c978068 100644
--- a/bagel/src/test/scala/bagel/BagelSuite.scala
+++ b/bagel/src/test/scala/bagel/BagelSuite.scala
@@ -23,6 +23,7 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeo
}
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
}
test("halting by voting") {
diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd
new file mode 100644
index 0000000000..6e7efbd334
--- /dev/null
+++ b/bin/compute-classpath.cmd
@@ -0,0 +1,52 @@
+@echo off
+
+rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
+rem script and the ExecutorRunner in standalone cluster mode.
+
+set SCALA_VERSION=2.9.3
+
+rem Figure out where the Spark framework is installed
+set FWDIR=%~dp0..\
+
+rem Load environment variables from conf\spark-env.cmd, if it exists
+if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd"
+
+set CORE_DIR=%FWDIR%core
+set REPL_DIR=%FWDIR%repl
+set EXAMPLES_DIR=%FWDIR%examples
+set BAGEL_DIR=%FWDIR%bagel
+set STREAMING_DIR=%FWDIR%streaming
+set PYSPARK_DIR=%FWDIR%python
+
+rem Build up classpath
+set CLASSPATH=%SPARK_CLASSPATH%;%MESOS_CLASSPATH%;%FWDIR%conf;%CORE_DIR%\target\scala-%SCALA_VERSION%\classes
+set CLASSPATH=%CLASSPATH%;%CORE_DIR%\target\scala-%SCALA_VERSION%\test-classes;%CORE_DIR%\src\main\resources
+set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\classes;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\test-classes
+set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\lib\org\apache\kafka\kafka\0.7.2-spark\*
+set CLASSPATH=%CLASSPATH%;%REPL_DIR%\target\scala-%SCALA_VERSION%\classes;%EXAMPLES_DIR%\target\scala-%SCALA_VERSION%\classes
+set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\jars\*
+set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\bundles\*
+set CLASSPATH=%CLASSPATH%;%FWDIR%repl\lib\*
+set CLASSPATH=%CLASSPATH%;%FWDIR%python\lib\*
+set CLASSPATH=%CLASSPATH%;%BAGEL_DIR%\target\scala-%SCALA_VERSION%\classes
+
+rem Add the Hadoop config dir - else FileSystem.* etc. fail
+rem Note: this assumes that either HADOOP_CONF_DIR or YARN_CONF_DIR hosts
+rem the configuration files.
+if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir
+ set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR%
+:no_hadoop_conf_dir
+
+if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir
+ set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR%
+:no_yarn_conf_dir
+
+rem Add Scala standard library
+set CLASSPATH=%CLASSPATH%;%SCALA_HOME%\lib\scala-library.jar;%SCALA_HOME%\lib\scala-compiler.jar;%SCALA_HOME%\lib\jline.jar
+
+rem A bit of a hack to allow calling this script within run2.cmd without seeing output
+if "%DONT_PRINT_CLASSPATH%"=="1" goto exit
+
+echo %CLASSPATH%
+
+:exit
diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh
new file mode 100755
index 0000000000..3a78880290
--- /dev/null
+++ b/bin/compute-classpath.sh
@@ -0,0 +1,89 @@
+#!/bin/bash
+
+# This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
+# script and the ExecutorRunner in standalone cluster mode.
+
+SCALA_VERSION=2.9.3
+
+# Figure out where Spark is installed
+FWDIR="$(cd `dirname $0`/..; pwd)"
+
+# Load environment variables from conf/spark-env.sh, if it exists
+if [ -e $FWDIR/conf/spark-env.sh ] ; then
+ . $FWDIR/conf/spark-env.sh
+fi
+
+CORE_DIR="$FWDIR/core"
+REPL_DIR="$FWDIR/repl"
+REPL_BIN_DIR="$FWDIR/repl-bin"
+EXAMPLES_DIR="$FWDIR/examples"
+BAGEL_DIR="$FWDIR/bagel"
+STREAMING_DIR="$FWDIR/streaming"
+PYSPARK_DIR="$FWDIR/python"
+
+# Build up classpath
+CLASSPATH="$SPARK_CLASSPATH"
+CLASSPATH="$CLASSPATH:$FWDIR/conf"
+CLASSPATH="$CLASSPATH:$CORE_DIR/target/scala-$SCALA_VERSION/classes"
+if [ -n "$SPARK_TESTING" ] ; then
+ CLASSPATH="$CLASSPATH:$CORE_DIR/target/scala-$SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$STREAMING_DIR/target/scala-$SCALA_VERSION/test-classes"
+fi
+CLASSPATH="$CLASSPATH:$CORE_DIR/src/main/resources"
+CLASSPATH="$CLASSPATH:$REPL_DIR/target/scala-$SCALA_VERSION/classes"
+CLASSPATH="$CLASSPATH:$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes"
+CLASSPATH="$CLASSPATH:$STREAMING_DIR/target/scala-$SCALA_VERSION/classes"
+CLASSPATH="$CLASSPATH:$STREAMING_DIR/lib/org/apache/kafka/kafka/0.7.2-spark/*" # <-- our in-project Kafka Jar
+if [ -e "$FWDIR/lib_managed" ]; then
+ CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/jars/*"
+ CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/bundles/*"
+fi
+CLASSPATH="$CLASSPATH:$REPL_DIR/lib/*"
+# Add the shaded JAR for Maven builds
+if [ -e $REPL_BIN_DIR/target ]; then
+ for jar in `find "$REPL_BIN_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do
+ CLASSPATH="$CLASSPATH:$jar"
+ done
+ # The shaded JAR doesn't contain examples, so include those separately
+ EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar`
+ CLASSPATH+=":$EXAMPLES_JAR"
+fi
+CLASSPATH="$CLASSPATH:$BAGEL_DIR/target/scala-$SCALA_VERSION/classes"
+for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do
+ CLASSPATH="$CLASSPATH:$jar"
+done
+
+# Figure out the JAR file that our examples were packaged into. This includes a bit of a hack
+# to avoid the -sources and -doc packages that are built by publish-local.
+if [ -e "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar ]; then
+ # Use the JAR from the SBT build
+ export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar`
+fi
+if [ -e "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar ]; then
+ # Use the JAR from the Maven build
+ export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar`
+fi
+
+# Add the Hadoop config dir - else FileSystem.* etc. fail!
+# Note: this assumes that either HADOOP_CONF_DIR or YARN_CONF_DIR hosts
+# the configuration files.
+if [ "x" != "x$HADOOP_CONF_DIR" ]; then
+ CLASSPATH="$CLASSPATH:$HADOOP_CONF_DIR"
+fi
+if [ "x" != "x$YARN_CONF_DIR" ]; then
+ CLASSPATH="$CLASSPATH:$YARN_CONF_DIR"
+fi
+
+# Add Scala standard library
+if [ -z "$SCALA_LIBRARY_PATH" ]; then
+ if [ -z "$SCALA_HOME" ]; then
+ echo "SCALA_HOME is not set" >&2
+ exit 1
+ fi
+ SCALA_LIBRARY_PATH="$SCALA_HOME/lib"
+fi
+CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/scala-library.jar"
+CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/scala-compiler.jar"
+CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/jline.jar"
+
+echo "$CLASSPATH"
diff --git a/bin/start-slave.sh b/bin/start-slave.sh
index 616c76e4ee..26b5b9d462 100755
--- a/bin/start-slave.sh
+++ b/bin/start-slave.sh
@@ -6,7 +6,8 @@ bin=`cd "$bin"; pwd`
# Set SPARK_PUBLIC_DNS so slaves can be linked in master web UI
if [ "$SPARK_PUBLIC_DNS" = "" ]; then
# If we appear to be running on EC2, use the public address by default:
- if [[ `hostname` == *ec2.internal ]]; then
+ # NOTE: ec2-metadata is installed on the Amazon Linux AMI, so check for it as well as the hostname
+ if command -v ec2-metadata > /dev/null || [[ `hostname` == *ec2.internal ]]; then
export SPARK_PUBLIC_DNS=`wget -q -O - http://instance-data.ec2.internal/latest/meta-data/public-hostname`
fi
fi
diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template
index 37565ca827..b8936314ec 100755
--- a/conf/spark-env.sh.template
+++ b/conf/spark-env.sh.template
@@ -3,8 +3,10 @@
# This file contains environment variables required to run Spark. Copy it as
# spark-env.sh and edit that to configure Spark for your site. At a minimum,
# the following two variables should be set:
-# - MESOS_NATIVE_LIBRARY, to point to your Mesos native library (libmesos.so)
-# - SCALA_HOME, to point to your Scala installation
+# - SCALA_HOME, to point to your Scala installation, or SCALA_LIBRARY_PATH to
+# point to the directory for Scala library JARs (if you install Scala as a
+# Debian or RPM package, these are in a separate path, often /usr/share/java)
+# - MESOS_NATIVE_LIBRARY, to point to your libmesos.so if you use Mesos
#
# If using the standalone deploy mode, you can also set variables for it:
# - SPARK_MASTER_IP, to bind the master to a different IP address
@@ -12,14 +14,6 @@
# - SPARK_WORKER_CORES, to set the number of cores to use on this machine
# - SPARK_WORKER_MEMORY, to set how much memory to use (e.g. 1000m, 2g)
# - SPARK_WORKER_PORT / SPARK_WORKER_WEBUI_PORT
-# - SPARK_WORKER_INSTANCES, to set the number of worker instances/processes to be spawned on every slave machine
-#
-# Finally, Spark also relies on the following variables, but these can be set
-# on just the *master* (i.e. in your driver program), and will automatically
-# be propagated to workers:
-# - SPARK_MEM, to change the amount of memory used per node (this should
-# be in the same format as the JVM's -Xmx option, e.g. 300m or 1g)
-# - SPARK_CLASSPATH, to add elements to Spark's classpath
-# - SPARK_JAVA_OPTS, to add JVM options
-# - SPARK_LIBRARY_PATH, to add extra search paths for native libraries.
+# - SPARK_WORKER_INSTANCES, to set the number of worker instances/processes
+# to be spawned on every slave machine
diff --git a/core/pom.xml b/core/pom.xml
index 7f5cffc818..385663a638 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -32,8 +32,8 @@
<artifactId>compress-lzf</artifactId>
</dependency>
<dependency>
- <groupId>asm</groupId>
- <artifactId>asm-all</artifactId>
+ <groupId>org.ow2.asm</groupId>
+ <artifactId>asm</artifactId>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
@@ -85,18 +85,27 @@
</dependency>
<dependency>
<groupId>com.github.scala-incubator.io</groupId>
- <artifactId>scala-io-file_${scala.version}</artifactId>
+ <artifactId>scala-io-file_2.10</artifactId>
</dependency>
<dependency>
<groupId>org.apache.mesos</groupId>
<artifactId>mesos</artifactId>
</dependency>
<dependency>
+ <groupId>io.netty</groupId>
+ <artifactId>netty-all</artifactId>
+ </dependency>
+ <dependency>
<groupId>log4j</groupId>
<artifactId>log4j</artifactId>
</dependency>
<dependency>
+ <groupId>org.apache.derby</groupId>
+ <artifactId>derby</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.version}</artifactId>
<scope>test</scope>
@@ -283,5 +292,72 @@
</plugins>
</build>
</profile>
+ <profile>
+ <id>hadoop2-yarn</id>
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-client</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-api</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-common</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-client</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.codehaus.mojo</groupId>
+ <artifactId>build-helper-maven-plugin</artifactId>
+ <executions>
+ <execution>
+ <id>add-source</id>
+ <phase>generate-sources</phase>
+ <goals>
+ <goal>add-source</goal>
+ </goals>
+ <configuration>
+ <sources>
+ <source>src/main/scala</source>
+ <source>src/hadoop2-yarn/scala</source>
+ </sources>
+ </configuration>
+ </execution>
+ <execution>
+ <id>add-scala-test-sources</id>
+ <phase>generate-test-sources</phase>
+ <goals>
+ <goal>add-test-source</goal>
+ </goals>
+ <configuration>
+ <sources>
+ <source>src/test/scala</source>
+ </sources>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <configuration>
+ <classifier>hadoop2-yarn</classifier>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
</profiles>
</project>
diff --git a/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
index ca9f7219de..f286f2cf9c 100644
--- a/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
+++ b/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
@@ -4,4 +4,7 @@ trait HadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContext(conf, jobId)
def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContext(conf, attemptId)
+
+ def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier,
+ jobId, isMap, taskId, attemptId)
}
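The `newTaskAttemptID` factory added above exists because Hadoop 1 constructs `TaskAttemptID` with a boolean `isMap` flag while Hadoop 2/YARN replaces it with a `TaskType` enum; routing the call through the trait keeps call sites source-compatible across both profiles. A minimal sketch of such a call site (the class name here is hypothetical, not part of the patch):

    import org.apache.hadoop.mapred.{HadoopMapRedUtil, TaskAttemptID}

    class AttemptIdFactory extends HadoopMapRedUtil {
      // Compiles against either Hadoop binding; the trait picks the right constructor.
      def attemptIdFor(jobId: Int, splitId: Int, attempt: Int): TaskAttemptID =
        newTaskAttemptID("jt", jobId, isMap = true, splitId, attempt)
    }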
diff --git a/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
index de7b0f81e3..264d421d14 100644
--- a/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
+++ b/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
@@ -6,4 +6,7 @@ trait HadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContext(conf, jobId)
def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContext(conf, attemptId)
+
+ def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier,
+ jobId, isMap, taskId, attemptId)
}
diff --git a/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala
new file mode 100644
index 0000000000..a0fb4fe25d
--- /dev/null
+++ b/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala
@@ -0,0 +1,23 @@
+package spark.deploy
+import org.apache.hadoop.conf.Configuration
+
+
+/**
+ * Contains util methods to interact with Hadoop from spark.
+ */
+object SparkHadoopUtil {
+
+ def getUserNameFromEnvironment(): String = {
+ // defaulting to -D ...
+ System.getProperty("user.name")
+ }
+
+ def runAsUser(func: (Product) => Unit, args: Product) {
+
+ // Add impersonation support here if it exists - for now, simply run func directly.
+ func(args)
+ }
+
+ // Return an appropriate (subclass of) Configuration. Creating a config can initialize some Hadoop subsystems.
+ def newConfiguration(): Configuration = new Configuration()
+}
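On the hadoop1 profile, `runAsUser` is deliberately a pass-through: there is no impersonation support, so the supplied function simply runs inline with its `Product` of arguments. A hedged usage sketch (the function name and tuple values are illustrative):

    import spark.deploy.SparkHadoopUtil

    def boot(args: Product) {
      // A Tuple2 is a Product, so positional access recovers the values.
      println("starting on " + args.productElement(0) + ":" + args.productElement(1))
    }

    SparkHadoopUtil.runAsUser(boot, ("localhost", 7077))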
diff --git a/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
new file mode 100644
index 0000000000..875c0a220b
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
@@ -0,0 +1,13 @@
+
+package org.apache.hadoop.mapred
+
+import org.apache.hadoop.mapreduce.TaskType
+
+trait HadoopMapRedUtil {
+ def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)
+
+ def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+
+ def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) =
+ new TaskAttemptID(jtIdentifier, jobId, if (isMap) TaskType.MAP else TaskType.REDUCE, taskId, attemptId)
+}
diff --git a/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
new file mode 100644
index 0000000000..8bc6fb6dea
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
@@ -0,0 +1,13 @@
+package org.apache.hadoop.mapreduce
+
+import org.apache.hadoop.conf.Configuration
+import task.{TaskAttemptContextImpl, JobContextImpl}
+
+trait HadoopMapReduceUtil {
+ def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)
+
+ def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+
+ def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) =
+ new TaskAttemptID(jtIdentifier, jobId, if (isMap) TaskType.MAP else TaskType.REDUCE, taskId, attemptId)
+}
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala
new file mode 100644
index 0000000000..ab1ab9d8a7
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala
@@ -0,0 +1,63 @@
+package spark.deploy
+
+import collection.mutable.HashMap
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import java.security.PrivilegedExceptionAction
+
+/**
+ * Contains util methods to interact with Hadoop from spark.
+ */
+object SparkHadoopUtil {
+
+ val yarnConf = newConfiguration()
+
+ def getUserNameFromEnvironment(): String = {
+ // defaulting to env if -D is not present ...
+ val retval = System.getProperty(Environment.USER.name, System.getenv(Environment.USER.name))
+
+ // If nothing found, default to user we are running as
+ if (retval == null) System.getProperty("user.name") else retval
+ }
+
+ def runAsUser(func: (Product) => Unit, args: Product) {
+ runAsUser(func, args, getUserNameFromEnvironment())
+ }
+
+ def runAsUser(func: (Product) => Unit, args: Product, user: String) {
+
+ // println("running as user " + jobUserName)
+
+ UserGroupInformation.setConfiguration(yarnConf)
+ val appMasterUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(user)
+ appMasterUgi.doAs(new PrivilegedExceptionAction[AnyRef] {
+ def run: AnyRef = {
+ func(args)
+ // no return value ...
+ null
+ }
+ })
+ }
+
+ // Note that all params which start with SPARK are propagated all the way through, so if in yarn mode, this MUST be set to true.
+ def isYarnMode(): Boolean = {
+ val yarnMode = System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))
+ java.lang.Boolean.valueOf(yarnMode)
+ }
+
+ // Set an env variable indicating we are running in YARN mode.
+ // Note that anything with SPARK prefix gets propagated to all (remote) processes
+ def setYarnMode() {
+ System.setProperty("SPARK_YARN_MODE", "true")
+ }
+
+ def setYarnMode(env: HashMap[String, String]) {
+ env("SPARK_YARN_MODE") = "true"
+ }
+
+ // Return an appropriate (subclass of) Configuration. Creating a config can initialize some Hadoop subsystems.
+ // Always create a new config; don't reuse yarnConf.
+ def newConfiguration(): Configuration = new YarnConfiguration(new Configuration())
+}
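Unlike the hadoop1 stub, this variant wraps the supplied function in `UserGroupInformation.doAs`, so any Hadoop filesystem calls made inside it are attributed to the given user. A sketch of the intended use, with an illustrative user name (in the ApplicationMaster the name actually comes from `getUserNameFromEnvironment()`):

    import org.apache.hadoop.security.UserGroupInformation
    import spark.deploy.SparkHadoopUtil

    SparkHadoopUtil.setYarnMode()   // sets SPARK_YARN_MODE, which propagates to spawned processes
    assert(SparkHadoopUtil.isYarnMode())

    // "jobuser" is hypothetical.
    SparkHadoopUtil.runAsUser(
      (_: Product) => println("running as " + UserGroupInformation.getCurrentUser()),
      Tuple1(()),
      "jobuser")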
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala
new file mode 100644
index 0000000000..aa72c1e5fe
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala
@@ -0,0 +1,329 @@
+package spark.deploy.yarn
+
+import java.net.Socket
+import java.util.concurrent.CopyOnWriteArrayList
+import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
+import scala.collection.JavaConversions._
+import spark.{SparkContext, Logging, Utils}
+import org.apache.hadoop.security.UserGroupInformation
+import java.security.PrivilegedExceptionAction
+
+class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging {
+
+ def this(args: ApplicationMasterArguments) = this(args, new Configuration())
+
+ private var rpc: YarnRPC = YarnRPC.create(conf)
+ private var resourceManager: AMRMProtocol = null
+ private var appAttemptId: ApplicationAttemptId = null
+ private var userThread: Thread = null
+ private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+
+ private var yarnAllocator: YarnAllocationHandler = null
+
+ def run() {
+
+ // Initialization
+ val jobUserName = Utils.getUserNameFromEnvironment()
+ logInfo("running as user " + jobUserName)
+
+ // run as user ...
+ UserGroupInformation.setConfiguration(yarnConf)
+ val appMasterUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(jobUserName)
+ appMasterUgi.doAs(new PrivilegedExceptionAction[AnyRef] {
+ def run: AnyRef = {
+ runImpl()
+ return null
+ }
+ })
+ }
+
+ private def runImpl() {
+
+ appAttemptId = getApplicationAttemptId()
+ resourceManager = registerWithResourceManager()
+ val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster()
+
+ // Compute number of threads for akka
+ val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory()
+
+ if (minimumMemory > 0) {
+ val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD
+ val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0)
+
+ if (numCore > 0) {
+ // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406
+ // TODO: Uncomment when hadoop is on a version which has this fixed.
+ // args.workerCores = numCore
+ }
+ }
+
+ // Workaround until hadoop moves to something which has
+ // https://issues.apache.org/jira/browse/HADOOP-8406
+ // ignore result
+ // This does not, unfortunately, always work reliably ... but it alleviates the bug in many cases.
+ // Hence args.workerCores = numCore is disabled above. Any better option?
+ // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf)
+
+ ApplicationMaster.register(this)
+ // Start the user's JAR
+ userThread = startUserClass()
+
+ // This is a bit hacky, but we need to wait until the spark.driver.port property has
+ // been set by the Thread executing the user class.
+ waitForSparkMaster()
+
+ // Allocate all containers
+ allocateWorkers()
+
+ // Wait for the user class to Finish
+ userThread.join()
+
+ // Finish the ApplicationMaster
+ finishApplicationMaster()
+ // TODO: Exit based on success/failure
+ System.exit(0)
+ }
+
+ private def getApplicationAttemptId(): ApplicationAttemptId = {
+ val envs = System.getenv()
+ val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV)
+ val containerId = ConverterUtils.toContainerId(containerIdString)
+ val appAttemptId = containerId.getApplicationAttemptId()
+ logInfo("ApplicationAttemptId: " + appAttemptId)
+ return appAttemptId
+ }
+
+ private def registerWithResourceManager(): AMRMProtocol = {
+ val rmAddress = NetUtils.createSocketAddr(yarnConf.get(
+ YarnConfiguration.RM_SCHEDULER_ADDRESS,
+ YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS))
+ logInfo("Connecting to ResourceManager at " + rmAddress)
+ return rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol]
+ }
+
+ private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
+ logInfo("Registering the ApplicationMaster")
+ val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest])
+ .asInstanceOf[RegisterApplicationMasterRequest]
+ appMasterRequest.setApplicationAttemptId(appAttemptId)
+ // Setting this to master host,port - so that the ApplicationReport at client has some sensible info.
+ // Users can then monitor stderr/stdout on that node if required.
+ appMasterRequest.setHost(Utils.localHostName())
+ appMasterRequest.setRpcPort(0)
+ // What do we provide here ? Might make sense to expose something sensible later ?
+ appMasterRequest.setTrackingUrl("")
+ return resourceManager.registerApplicationMaster(appMasterRequest)
+ }
+
+ private def waitForSparkMaster() {
+ logInfo("Waiting for spark driver to be reachable.")
+ var driverUp = false
+ while(!driverUp) {
+ val driverHost = System.getProperty("spark.driver.host")
+ val driverPort = System.getProperty("spark.driver.port")
+ try {
+ val socket = new Socket(driverHost, driverPort.toInt)
+ socket.close()
+ logInfo("Master now available: " + driverHost + ":" + driverPort)
+ driverUp = true
+ } catch {
+ case e: Exception =>
+ logError("Failed to connect to driver at " + driverHost + ":" + driverPort)
+ Thread.sleep(100)
+ }
+ }
+ }
+
+ private def startUserClass(): Thread = {
+ logInfo("Starting the user JAR in a separate Thread")
+ val mainMethod = Class.forName(args.userClass, false, Thread.currentThread.getContextClassLoader)
+ .getMethod("main", classOf[Array[String]])
+ val t = new Thread {
+ override def run() {
+ // Copy
+ var mainArgs: Array[String] = new Array[String](args.userArgs.size())
+ args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size())
+ mainMethod.invoke(null, mainArgs)
+ }
+ }
+ t.start()
+ return t
+ }
+
+ private def allocateWorkers() {
+ logInfo("Waiting for spark context initialization")
+
+ try {
+ var sparkContext: SparkContext = null
+ ApplicationMaster.sparkContextRef.synchronized {
+ var count = 0
+ while (ApplicationMaster.sparkContextRef.get() == null) {
+ logInfo("Waiting for spark context initialization ... " + count)
+ count = count + 1
+ ApplicationMaster.sparkContextRef.wait(10000L)
+ }
+ sparkContext = ApplicationMaster.sparkContextRef.get()
+ assert(sparkContext != null)
+ this.yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args, sparkContext.preferredNodeLocationData)
+ }
+
+
+ logInfo("Allocating " + args.numWorkers + " workers.")
+ // Wait until all containers have finished
+ // TODO: This is a bit ugly. Can we make it nicer?
+ // TODO: Handle container failure
+ while(yarnAllocator.getNumWorkersRunning < args.numWorkers &&
+ // If user thread exists, then quit !
+ userThread.isAlive) {
+
+ this.yarnAllocator.allocateContainers(math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0))
+ ApplicationMaster.incrementAllocatorLoop(1)
+ Thread.sleep(100)
+ }
+ } finally {
+ // in case of exceptions, etc. - ensure that count is at least ALLOCATOR_LOOP_WAIT_COUNT,
+ // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks
+ ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT)
+ }
+ logInfo("All workers have launched.")
+
+ // Launch a progress reporter thread, else the app will be killed after the expiry (default: 10 mins) timeout
+ if (userThread.isAlive){
+ // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse.
+
+ val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
+ // must be <= timeoutInterval / 2.
+ // On the other hand, also ensure that we are reasonably responsive without causing too many requests to the RM,
+ // so at least 1 minute or timeoutInterval / 10 - whichever is higher.
+ val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L))
+ launchReporterThread(interval)
+ }
+ }
+
+ // TODO: We might want to extend this to allocate more containers in case they die !
+ private def launchReporterThread(_sleepTime: Long): Thread = {
+ val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime
+
+ val t = new Thread {
+ override def run() {
+ while (userThread.isAlive){
+ val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning
+ if (missingWorkerCount > 0) {
+ logInfo("Allocating " + missingWorkerCount + " containers to make up for (potentially ?) lost containers")
+ yarnAllocator.allocateContainers(missingWorkerCount)
+ }
+ else sendProgress()
+ Thread.sleep(sleepTime)
+ }
+ }
+ }
+ // setting to daemon status, though this is usually not a good idea.
+ t.setDaemon(true)
+ t.start()
+ logInfo("Started progress reporter thread - sleep time : " + sleepTime)
+ return t
+ }
+
+ private def sendProgress() {
+ logDebug("Sending progress")
+ // simulated with an allocate request with no nodes requested ...
+ yarnAllocator.allocateContainers(0)
+ }
+
+ /*
+ def printContainers(containers: List[Container]) = {
+ for (container <- containers) {
+ logInfo("Launching shell command on a new container."
+ + ", containerId=" + container.getId()
+ + ", containerNode=" + container.getNodeId().getHost()
+ + ":" + container.getNodeId().getPort()
+ + ", containerNodeURI=" + container.getNodeHttpAddress()
+ + ", containerState" + container.getState()
+ + ", containerResourceMemory"
+ + container.getResource().getMemory())
+ }
+ }
+ */
+
+ def finishApplicationMaster() {
+ val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest])
+ .asInstanceOf[FinishApplicationMasterRequest]
+ finishReq.setAppAttemptId(appAttemptId)
+ // TODO: Check if the application has failed or succeeded
+ finishReq.setFinishApplicationStatus(FinalApplicationStatus.SUCCEEDED)
+ resourceManager.finishApplicationMaster(finishReq)
+ }
+
+}
+
+object ApplicationMaster {
+ // number of times to wait for the allocator loop to complete.
+ // each loop iteration waits for 100ms, so a maximum of 3 seconds.
+ // This is to ensure that we have a reasonable number of containers before we start
+ // TODO: Currently, the task-to-container mapping is computed once (TaskSetManager) - which need not be optimal as more
+ // containers become available. Might need to handle this better.
+ private val ALLOCATOR_LOOP_WAIT_COUNT = 30
+ def incrementAllocatorLoop(by: Int) {
+ val count = yarnAllocatorLoop.getAndAdd(by)
+ if (count >= ALLOCATOR_LOOP_WAIT_COUNT){
+ yarnAllocatorLoop.synchronized {
+ // to wake threads off wait ...
+ yarnAllocatorLoop.notifyAll()
+ }
+ }
+ }
+
+ private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]()
+
+ def register(master: ApplicationMaster) {
+ applicationMasters.add(master)
+ }
+
+ val sparkContextRef: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null)
+ val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0)
+
+ def sparkContextInitialized(sc: SparkContext): Boolean = {
+ var modified = false
+ sparkContextRef.synchronized {
+ modified = sparkContextRef.compareAndSet(null, sc)
+ sparkContextRef.notifyAll()
+ }
+
+ // Add a shutdown hook - as a best-effort measure in case users do not call sc.stop or do System.exit
+ // Should not really have to do this, but it helps YARN to evict resources earlier -
+ // not to mention, it prevents the Client from declaring failure even though we exited properly.
+ if (modified) {
+ Runtime.getRuntime().addShutdownHook(new Thread with Logging {
+ // This is not just to log, but also to ensure that the log system is initialized for this instance when we are actually 'run'
+ logInfo("Adding shutdown hook for context " + sc)
+ override def run() {
+ logInfo("Invoking sc stop from shutdown hook")
+ sc.stop()
+ // best case ...
+ for (master <- applicationMasters) master.finishApplicationMaster
+ }
+ } )
+ }
+
+ // Wait for initialization to complete and at least 'some' nodes to get allocated
+ yarnAllocatorLoop.synchronized {
+ while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT){
+ yarnAllocatorLoop.wait(1000L)
+ }
+ }
+ modified
+ }
+
+ def main(argStrings: Array[String]) {
+ val args = new ApplicationMasterArguments(argStrings)
+ new ApplicationMaster(args).run()
+ }
+}
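The reporter interval computed in `allocateWorkers` is easy to misread; a worked example of the clamping (the 10-minute timeout is YARN's usual default for the AM expiry interval, used here purely for illustration):

    // Heartbeat at most every timeout/2, but no more often than needed:
    // at least one minute, or timeout/10 if that is larger.
    val timeoutInterval = 600000L  // 10 minutes, illustrative
    val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval / 10, 60000L))
    assert(interval == 60000L)     // a 10-minute expiry yields a 1-minute heartbeat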
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala
new file mode 100644
index 0000000000..1b00208511
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala
@@ -0,0 +1,77 @@
+package spark.deploy.yarn
+
+import spark.util.IntParam
+import collection.mutable.ArrayBuffer
+
+class ApplicationMasterArguments(val args: Array[String]) {
+ var userJar: String = null
+ var userClass: String = null
+ var userArgs: Seq[String] = Seq[String]()
+ var workerMemory = 1024
+ var workerCores = 1
+ var numWorkers = 2
+
+ parseArgs(args.toList)
+
+ private def parseArgs(inputArgs: List[String]): Unit = {
+ val userArgsBuffer = new ArrayBuffer[String]()
+
+ var args = inputArgs
+
+ while (! args.isEmpty) {
+
+ args match {
+ case ("--jar") :: value :: tail =>
+ userJar = value
+ args = tail
+
+ case ("--class") :: value :: tail =>
+ userClass = value
+ args = tail
+
+ case ("--args") :: value :: tail =>
+ userArgsBuffer += value
+ args = tail
+
+ case ("--num-workers") :: IntParam(value) :: tail =>
+ numWorkers = value
+ args = tail
+
+ case ("--worker-memory") :: IntParam(value) :: tail =>
+ workerMemory = value
+ args = tail
+
+ case ("--worker-cores") :: IntParam(value) :: tail =>
+ workerCores = value
+ args = tail
+
+ case Nil =>
+ if (userJar == null || userClass == null) {
+ printUsageAndExit(1)
+ }
+
+ case _ =>
+ printUsageAndExit(1, args)
+ }
+ }
+
+ userArgs = userArgsBuffer.readOnly
+ }
+
+ def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+ if (unknownParam != null) {
+ System.err.println("Unknown/unsupported param " + unknownParam)
+ }
+ System.err.println(
+ "Usage: spark.deploy.yarn.ApplicationMaster [options] \n" +
+ "Options:\n" +
+ " --jar JAR_PATH Path to your application's JAR file (required)\n" +
+ " --class CLASS_NAME Name of your application's main class (required)\n" +
+ " --args ARGS Arguments to be passed to your application's main class.\n" +
+ " Multiple invocations are possible, each will be passed in order.\n" +
+ " --num-workers NUM Number of workers to start (Default: 2)\n" +
+ " --worker-cores NUM Number of cores for the workers (Default: 1)\n" +
+ " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n")
+ System.exit(exitCode)
+ }
+}
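A sketch of how the list-based pattern match above consumes a command line; repeated `--args` flags accumulate in order, and the values below are purely illustrative:

    val amArgs = new ApplicationMasterArguments(Array(
      "--jar", "hdfs:///tmp/app.jar",
      "--class", "my.Main",
      "--args", "input", "--args", "output",
      "--num-workers", "4", "--worker-memory", "2048", "--worker-cores", "2"))

    assert(amArgs.numWorkers == 4)
    assert(amArgs.userArgs == Seq("input", "output"))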
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala
new file mode 100644
index 0000000000..7a881e26df
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala
@@ -0,0 +1,272 @@
+package spark.deploy.yarn
+
+import java.net.{InetSocketAddress, URI}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.client.YarnClientImpl
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions._
+import spark.{Logging, Utils}
+import org.apache.hadoop.yarn.util.{Apps, Records, ConverterUtils}
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import spark.deploy.SparkHadoopUtil
+
+class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl with Logging {
+
+ def this(args: ClientArguments) = this(new Configuration(), args)
+
+ var rpc: YarnRPC = YarnRPC.create(conf)
+ val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+
+ def run() {
+ init(yarnConf)
+ start()
+ logClusterResourceDetails()
+
+ val newApp = super.getNewApplication()
+ val appId = newApp.getApplicationId()
+
+ verifyClusterResources(newApp)
+ val appContext = createApplicationSubmissionContext(appId)
+ val localResources = prepareLocalResources(appId, "spark")
+ val env = setupLaunchEnv(localResources)
+ val amContainer = createContainerLaunchContext(newApp, localResources, env)
+
+ appContext.setQueue(args.amQueue)
+ appContext.setAMContainerSpec(amContainer)
+ appContext.setUser(args.amUser)
+
+ submitApp(appContext)
+
+ monitorApplication(appId)
+ System.exit(0)
+ }
+
+
+ def logClusterResourceDetails() {
+ val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics
+ logInfo("Got Cluster metric info from ASM, numNodeManagers=" + clusterMetrics.getNumNodeManagers)
+
+ val queueInfo: QueueInfo = super.getQueueInfo(args.amQueue)
+ logInfo("Queue info .. queueName=" + queueInfo.getQueueName + ", queueCurrentCapacity=" + queueInfo.getCurrentCapacity +
+ ", queueMaxCapacity=" + queueInfo.getMaximumCapacity + ", queueApplicationCount=" + queueInfo.getApplications.size +
+ ", queueChildQueueCount=" + queueInfo.getChildQueues.size)
+ }
+
+
+ def verifyClusterResources(app: GetNewApplicationResponse) = {
+ val maxMem = app.getMaximumResourceCapability().getMemory()
+ logInfo("Max mem capability of resources in this cluster " + maxMem)
+
+ // If the cluster does not have enough memory resources, exit.
+ val requestedMem = (args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + args.numWorkers * args.workerMemory
+ if (requestedMem > maxMem) {
+ logError("Cluster cannot satisfy memory resource request of " + requestedMem)
+ System.exit(1)
+ }
+ }
+
+ def createApplicationSubmissionContext(appId: ApplicationId): ApplicationSubmissionContext = {
+ logInfo("Setting up application submission context for ASM")
+ val appContext = Records.newRecord(classOf[ApplicationSubmissionContext])
+ appContext.setApplicationId(appId)
+ appContext.setApplicationName("Spark")
+ return appContext
+ }
+
+ def prepareLocalResources(appId: ApplicationId, appName: String): HashMap[String, LocalResource] = {
+ logInfo("Preparing Local resources")
+ val locaResources = HashMap[String, LocalResource]()
+ // Upload Spark and the application JAR to the remote file system
+ // Add them as local resources to the AM
+ val fs = FileSystem.get(conf)
+ Map("spark.jar" -> System.getenv("SPARK_JAR"), "app.jar" -> args.userJar, "log4j.properties" -> System.getenv("SPARK_LOG4J_CONF"))
+ .foreach { case(destName, _localPath) =>
+ val localPath: String = if (_localPath != null) _localPath.trim() else ""
+ if (! localPath.isEmpty()) {
+ val src = new Path(localPath)
+ val pathSuffix = appName + "/" + appId.getId() + destName
+ val dst = new Path(fs.getHomeDirectory(), pathSuffix)
+ logInfo("Uploading " + src + " to " + dst)
+ fs.copyFromLocalFile(false, true, src, dst)
+ val destStatus = fs.getFileStatus(dst)
+
+ val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ amJarRsrc.setType(LocalResourceType.FILE)
+ amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION)
+ amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(dst))
+ amJarRsrc.setTimestamp(destStatus.getModificationTime())
+ amJarRsrc.setSize(destStatus.getLen())
+ locaResources(destName) = amJarRsrc
+ }
+ }
+ return locaResources
+ }
+
+ def setupLaunchEnv(localResources: HashMap[String, LocalResource]): HashMap[String, String] = {
+ logInfo("Setting up the launch environment")
+ val log4jConfLocalRes = localResources.getOrElse("log4j.properties", null)
+
+ val env = new HashMap[String, String]()
+ Apps.addToEnvironment(env, Environment.USER.name, args.amUser)
+
+ // If log4j present, ensure ours overrides all others
+ if (log4jConfLocalRes != null) Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./")
+
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*")
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH")
+ Client.populateHadoopClasspath(yarnConf, env)
+ SparkHadoopUtil.setYarnMode(env)
+ env("SPARK_YARN_JAR_PATH") =
+ localResources("spark.jar").getResource().getScheme.toString() + "://" +
+ localResources("spark.jar").getResource().getFile().toString()
+ env("SPARK_YARN_JAR_TIMESTAMP") = localResources("spark.jar").getTimestamp().toString()
+ env("SPARK_YARN_JAR_SIZE") = localResources("spark.jar").getSize().toString()
+
+ env("SPARK_YARN_USERJAR_PATH") =
+ localResources("app.jar").getResource().getScheme.toString() + "://" +
+ localResources("app.jar").getResource().getFile().toString()
+ env("SPARK_YARN_USERJAR_TIMESTAMP") = localResources("app.jar").getTimestamp().toString()
+ env("SPARK_YARN_USERJAR_SIZE") = localResources("app.jar").getSize().toString()
+
+ if (log4jConfLocalRes != null) {
+ env("SPARK_YARN_LOG4J_PATH") =
+ log4jConfLocalRes.getResource().getScheme.toString() + "://" + log4jConfLocalRes.getResource().getFile().toString()
+ env("SPARK_YARN_LOG4J_TIMESTAMP") = log4jConfLocalRes.getTimestamp().toString()
+ env("SPARK_YARN_LOG4J_SIZE") = log4jConfLocalRes.getSize().toString()
+ }
+
+ // Add each SPARK-* key to the environment
+ System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v }
+ return env
+ }
+
+ def userArgsToString(clientArgs: ClientArguments): String = {
+ val prefix = " --args "
+ val args = clientArgs.userArgs
+ val retval = new StringBuilder()
+ for (arg <- args){
+ retval.append(prefix).append(" '").append(arg).append("' ")
+ }
+
+ retval.toString
+ }
+
+ def createContainerLaunchContext(newApp: GetNewApplicationResponse,
+ localResources: HashMap[String, LocalResource],
+ env: HashMap[String, String]): ContainerLaunchContext = {
+ logInfo("Setting up container launch context")
+ val amContainer = Records.newRecord(classOf[ContainerLaunchContext])
+ amContainer.setLocalResources(localResources)
+ amContainer.setEnvironment(env)
+
+ val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory()
+
+ var amMemory = ((args.amMemory / minResMemory) * minResMemory) +
+ (if (0 != (args.amMemory % minResMemory)) minResMemory else 0) - YarnAllocationHandler.MEMORY_OVERHEAD
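+ // Worked example (illustrative): with minResMemory = 1024 and args.amMemory = 1536, the request
+ // is rounded up to the next multiple of the minimum (2048), and after subtracting
+ // YarnAllocationHandler.MEMORY_OVERHEAD (384) the JVM heap below becomes -Xmx1664m.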
+
+ // Extra options for the JVM
+ var JAVA_OPTS = ""
+
+ // Add Xmx for am memory
+ JAVA_OPTS += "-Xmx" + amMemory + "m "
+
+ // Commented out for now so that people can refer to these properties if required; remove once the cpuset version is pushed out.
+ // Context: the default GC on server-class machines ends up using all cores for GC, so if there are multiple containers on the
+ // same node, Spark's GC affects the performance of all other containers (which may themselves be Spark containers).
+ // Instead of using this, rely on cpusets in YARN to ensure Spark behaves 'properly' in multi-tenant environments. It is not clear
+ // how the default Java GC behaves when limited to a subset of cores on a node.
+ if (env.isDefinedAt("SPARK_USE_CONC_INCR_GC") && java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC"))) {
+ // In our experiments, the (default) throughput collector had severe performance ramifications on multi-tenant machines
+ JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
+ JAVA_OPTS += " -XX:+CMSIncrementalMode "
+ JAVA_OPTS += " -XX:+CMSIncrementalPacing "
+ JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
+ JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
+ }
+ if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
+ JAVA_OPTS += env("SPARK_JAVA_OPTS") + " "
+ }
+
+ // Command for the ApplicationMaster
+ val commands = List[String]("java " +
+ " -server " +
+ JAVA_OPTS +
+ " spark.deploy.yarn.ApplicationMaster" +
+ " --class " + args.userClass +
+ " --jar " + args.userJar +
+ userArgsToString(args) +
+ " --worker-memory " + args.workerMemory +
+ " --worker-cores " + args.workerCores +
+ " --num-workers " + args.numWorkers +
+ " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" +
+ " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
+ logInfo("Command for the ApplicationMaster: " + commands(0))
+ amContainer.setCommands(commands)
+
+ val capability = Records.newRecord(classOf[Resource]).asInstanceOf[Resource]
+ // Memory for the ApplicationMaster
+ capability.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+ amContainer.setResource(capability)
+
+ return amContainer
+ }
+
+ def submitApp(appContext: ApplicationSubmissionContext) = {
+ // Submit the application to the applications manager
+ logInfo("Submitting application to ASM")
+ super.submitApplication(appContext)
+ }
+
+ def monitorApplication(appId: ApplicationId): Boolean = {
+ while(true) {
+ Thread.sleep(1000)
+ val report = super.getApplicationReport(appId)
+
+ logInfo("Application report from ASM: \n" +
+ "\t application identifier: " + appId.toString() + "\n" +
+ "\t appId: " + appId.getId() + "\n" +
+ "\t clientToken: " + report.getClientToken() + "\n" +
+ "\t appDiagnostics: " + report.getDiagnostics() + "\n" +
+ "\t appMasterHost: " + report.getHost() + "\n" +
+ "\t appQueue: " + report.getQueue() + "\n" +
+ "\t appMasterRpcPort: " + report.getRpcPort() + "\n" +
+ "\t appStartTime: " + report.getStartTime() + "\n" +
+ "\t yarnAppState: " + report.getYarnApplicationState() + "\n" +
+ "\t distributedFinalState: " + report.getFinalApplicationStatus() + "\n" +
+ "\t appTrackingUrl: " + report.getTrackingUrl() + "\n" +
+ "\t appUser: " + report.getUser()
+ )
+
+ val state = report.getYarnApplicationState()
+ val dsStatus = report.getFinalApplicationStatus()
+ if (state == YarnApplicationState.FINISHED ||
+ state == YarnApplicationState.FAILED ||
+ state == YarnApplicationState.KILLED) {
+ return true
+ }
+ }
+ return true
+ }
+}
+
+object Client {
+ def main(argStrings: Array[String]) {
+ val args = new ClientArguments(argStrings)
+ SparkHadoopUtil.setYarnMode()
+ new Client(args).run
+ }
+
+ // Based on code from org.apache.hadoop.mapreduce.v2.util.MRApps
+ def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) {
+ for (c <- conf.getStrings(YarnConfiguration.YARN_APPLICATION_CLASSPATH)) {
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, c.trim)
+ }
+ }
+}
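+
+// Illustrative invocation (the paths and names below are hypothetical):
+//   java spark.deploy.yarn.Client \
+//     --jar /path/to/app.jar --class my.app.Main \
+//     --args 'input' --num-workers 4 --worker-memory 2G --master-memory 1G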
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala
new file mode 100644
index 0000000000..24110558e7
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala
@@ -0,0 +1,105 @@
+package spark.deploy.yarn
+
+import spark.util.MemoryParam
+import spark.util.IntParam
+import collection.mutable.{ArrayBuffer, HashMap}
+import spark.scheduler.{InputFormatInfo, SplitInfo}
+
+// TODO: Add code and support for ensuring that YARN resource 'asks' are location aware!
+class ClientArguments(val args: Array[String]) {
+ var userJar: String = null
+ var userClass: String = null
+ var userArgs: Seq[String] = Seq[String]()
+ var workerMemory = 1024
+ var workerCores = 1
+ var numWorkers = 2
+ var amUser = System.getProperty("user.name")
+ var amQueue = System.getProperty("QUEUE", "default")
+ var amMemory: Int = 512
+ // TODO
+ var inputFormatInfo: List[InputFormatInfo] = null
+
+ parseArgs(args.toList)
+
+ private def parseArgs(inputArgs: List[String]): Unit = {
+ val userArgsBuffer: ArrayBuffer[String] = new ArrayBuffer[String]()
+ val inputFormatMap: HashMap[String, InputFormatInfo] = new HashMap[String, InputFormatInfo]()
+
+ var args = inputArgs
+
+ while (! args.isEmpty) {
+
+ args match {
+ case ("--jar") :: value :: tail =>
+ userJar = value
+ args = tail
+
+ case ("--class") :: value :: tail =>
+ userClass = value
+ args = tail
+
+ case ("--args") :: value :: tail =>
+ userArgsBuffer += value
+ args = tail
+
+ case ("--master-memory") :: MemoryParam(value) :: tail =>
+ amMemory = value
+ args = tail
+
+ case ("--num-workers") :: IntParam(value) :: tail =>
+ numWorkers = value
+ args = tail
+
+ case ("--worker-memory") :: MemoryParam(value) :: tail =>
+ workerMemory = value
+ args = tail
+
+ case ("--worker-cores") :: IntParam(value) :: tail =>
+ workerCores = value
+ args = tail
+
+ case ("--user") :: value :: tail =>
+ amUser = value
+ args = tail
+
+ case ("--queue") :: value :: tail =>
+ amQueue = value
+ args = tail
+
+ case _ =>
+ printUsageAndExit(1, args)
+ }
+ }
+
+ // A Nil case inside the while loop above is unreachable, so validate the required
+ // arguments here, after parsing completes.
+ if (userJar == null || userClass == null) {
+ printUsageAndExit(1)
+ }
+
+ userArgs = userArgsBuffer.readOnly
+ inputFormatInfo = inputFormatMap.values.toList
+ }
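+ // A minimal, self-contained sketch of the list-pattern parsing idiom used above (the flags
+ // and values here are illustrative only):
+ //   var rest = List("--jar", "a.jar", "--num-workers", "4")
+ //   while (!rest.isEmpty) rest match {
+ //     case "--jar" :: value :: tail => userJar = value; rest = tail
+ //     case "--num-workers" :: IntParam(n) :: tail => numWorkers = n; rest = tail
+ //     case _ => printUsageAndExit(1, rest)
+ //   }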
+
+
+ def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+ if (unknownParam != null) {
+ System.err.println("Unknown/unsupported param " + unknownParam)
+ }
+ System.err.println(
+ "Usage: spark.deploy.yarn.Client [options] \n" +
+ "Options:\n" +
+ " --jar JAR_PATH Path to your application's JAR file (required)\n" +
+ " --class CLASS_NAME Name of your application's main class (required)\n" +
+ " --args ARGS Arguments to be passed to your application's main class.\n" +
+ " Mutliple invocations are possible, each will be passed in order.\n" +
+ " --num-workers NUM Number of workers to start (Default: 2)\n" +
+ " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" +
+ " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" +
+ " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
+ " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" +
+ " --user USERNAME Run the ApplicationMaster (and slaves) as a different user\n"
+ )
+ System.exit(exitCode)
+ }
+
+}
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala
new file mode 100644
index 0000000000..a2bf0af762
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala
@@ -0,0 +1,171 @@
+package spark.deploy.yarn
+
+import java.net.URI
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records}
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.HashMap
+
+import spark.{Logging, Utils}
+
+class WorkerRunnable(container: Container, conf: Configuration, masterAddress: String,
+ slaveId: String, hostname: String, workerMemory: Int, workerCores: Int)
+ extends Runnable with Logging {
+
+ var rpc: YarnRPC = YarnRPC.create(conf)
+ var cm: ContainerManager = null
+ val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+
+ def run = {
+ logInfo("Starting Worker Container")
+ cm = connectToCM
+ startContainer
+ }
+
+ def startContainer = {
+ logInfo("Setting up ContainerLaunchContext")
+
+ val ctx = Records.newRecord(classOf[ContainerLaunchContext])
+ .asInstanceOf[ContainerLaunchContext]
+
+ ctx.setContainerId(container.getId())
+ ctx.setResource(container.getResource())
+ val localResources = prepareLocalResources
+ ctx.setLocalResources(localResources)
+
+ val env = prepareEnvironment
+ ctx.setEnvironment(env)
+
+ // Extra options for the JVM
+ var JAVA_OPTS = ""
+ // Set the JVM memory
+ val workerMemoryString = workerMemory + "m"
+ JAVA_OPTS += "-Xms" + workerMemoryString + " -Xmx" + workerMemoryString + " "
+ if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
+ JAVA_OPTS += env("SPARK_JAVA_OPTS") + " "
+ }
+ // Commented out for now so that people can refer to these properties if required; remove once the cpuset version is pushed out.
+ // Context: the default GC on server-class machines ends up using all cores for GC, so if there are multiple containers on the
+ // same node, Spark's GC affects the performance of all other containers (which may themselves be Spark containers).
+ // Instead of using this, rely on cpusets in YARN to ensure Spark behaves 'properly' in multi-tenant environments. It is not clear
+ // how the default Java GC behaves when limited to a subset of cores on a node.
+/*
+ else {
+ // If no java_opts are specified, default to using -XX:+CMSIncrementalMode
+ // It is possible that other modes/configs are being set in SPARK_JAVA_OPTS, so we don't want to clobber it.
+ // In our experiments, the (default) throughput collector had severe performance ramifications on multi-tenant machines
+ // The options are based on
+ // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use%20the%20Concurrent%20Low%20Pause%20Collector|outline
+ JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
+ JAVA_OPTS += " -XX:+CMSIncrementalMode "
+ JAVA_OPTS += " -XX:+CMSIncrementalPacing "
+ JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
+ JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
+ }
+*/
+
+ ctx.setUser(UserGroupInformation.getCurrentUser().getShortUserName())
+ val commands = List[String]("java " +
+ " -server " +
+ // Kill the JVM if an OOM is raised - leverage YARN's failure handling to cause rescheduling.
+ // Not killing the task leaves various aspects of the worker and (to some extent) the JVM in an inconsistent state.
+ // TODO: If the OOM is not recoverable by rescheduling on a different node, then do 'something' to fail the job, akin to blacklisting trackers in mapred.
+ " -XX:OnOutOfMemoryError='kill %p' " +
+ JAVA_OPTS +
+ " spark.executor.StandaloneExecutorBackend " +
+ masterAddress + " " +
+ slaveId + " " +
+ hostname + " " +
+ workerCores +
+ " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" +
+ " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
+ logInfo("Setting up worker with commands: " + commands)
+ ctx.setCommands(commands)
+
+ // Send the start request to the ContainerManager
+ val startReq = Records.newRecord(classOf[StartContainerRequest])
+ .asInstanceOf[StartContainerRequest]
+ startReq.setContainerLaunchContext(ctx)
+ cm.startContainer(startReq)
+ }
+
+
+ def prepareLocalResources: HashMap[String, LocalResource] = {
+ logInfo("Preparing Local resources")
+ val localResources = HashMap[String, LocalResource]()
+
+ // Spark JAR
+ val sparkJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ sparkJarResource.setType(LocalResourceType.FILE)
+ sparkJarResource.setVisibility(LocalResourceVisibility.APPLICATION)
+ sparkJarResource.setResource(ConverterUtils.getYarnUrlFromURI(
+ new URI(System.getenv("SPARK_YARN_JAR_PATH"))))
+ sparkJarResource.setTimestamp(System.getenv("SPARK_YARN_JAR_TIMESTAMP").toLong)
+ sparkJarResource.setSize(System.getenv("SPARK_YARN_JAR_SIZE").toLong)
+ locaResources("spark.jar") = sparkJarResource
+ // User JAR
+ val userJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ userJarResource.setType(LocalResourceType.FILE)
+ userJarResource.setVisibility(LocalResourceVisibility.APPLICATION)
+ userJarResource.setResource(ConverterUtils.getYarnUrlFromURI(
+ new URI(System.getenv("SPARK_YARN_USERJAR_PATH"))))
+ userJarResource.setTimestamp(System.getenv("SPARK_YARN_USERJAR_TIMESTAMP").toLong)
+ userJarResource.setSize(System.getenv("SPARK_YARN_USERJAR_SIZE").toLong)
+ locaResources("app.jar") = userJarResource
+
+ // Log4j conf - if available
+ if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
+ val log4jConfResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ log4jConfResource.setType(LocalResourceType.FILE)
+ log4jConfResource.setVisibility(LocalResourceVisibility.APPLICATION)
+ log4jConfResource.setResource(ConverterUtils.getYarnUrlFromURI(
+ new URI(System.getenv("SPARK_YARN_LOG4J_PATH"))))
+ log4jConfResource.setTimestamp(System.getenv("SPARK_YARN_LOG4J_TIMESTAMP").toLong)
+ log4jConfResource.setSize(System.getenv("SPARK_YARN_LOG4J_SIZE").toLong)
+ locaResources("log4j.properties") = log4jConfResource
+ }
+
+
+ logInfo("Prepared Local resources " + locaResources)
+ return locaResources
+ }
+
+ def prepareEnvironment: HashMap[String, String] = {
+ val env = new HashMap[String, String]()
+ // should we add this ?
+ Apps.addToEnvironment(env, Environment.USER.name, Utils.getUserNameFromEnvironment())
+
+ // If log4j present, ensure ours overrides all others
+ if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
+ // Which is correct ?
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./log4j.properties")
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./")
+ }
+
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*")
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH")
+ Client.populateHadoopClasspath(yarnConf, env)
+
+ System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v }
+ return env
+ }
+
+ def connectToCM: ContainerManager = {
+ val cmHostPortStr = container.getNodeId().getHost() + ":" + container.getNodeId().getPort()
+ val cmAddress = NetUtils.createSocketAddr(cmHostPortStr)
+ logInfo("Connecting to ContainerManager at " + cmHostPortStr)
+ return rpc.getProxy(classOf[ContainerManager], cmAddress, conf).asInstanceOf[ContainerManager]
+ }
+
+}
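+
+// Illustrative usage (mirrors YarnAllocationHandler): each allocated container is handed to a
+// WorkerRunnable on its own thread, e.g.
+//   new Thread(new WorkerRunnable(container, conf, driverUrl, workerId,
+//     workerHostname, workerMemory, workerCores)).start()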
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala
new file mode 100644
index 0000000000..61dd72a651
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala
@@ -0,0 +1,547 @@
+package spark.deploy.yarn
+
+import spark.{Logging, Utils}
+import spark.scheduler.SplitInfo
+import scala.collection
+import org.apache.hadoop.yarn.api.records.{AMResponse, ApplicationAttemptId, ContainerId, Priority, Resource, ResourceRequest, ContainerStatus, Container}
+import spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend}
+import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse}
+import org.apache.hadoop.yarn.util.{RackResolver, Records}
+import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap}
+import java.util.concurrent.atomic.AtomicInteger
+import org.apache.hadoop.yarn.api.AMRMProtocol
+import collection.JavaConversions._
+import collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import org.apache.hadoop.conf.Configuration
+import java.util.{Collections, Set => JSet}
+import java.lang.{Boolean => JBoolean}
+
+object AllocationType extends Enumeration("HOST", "RACK", "ANY") {
+ type AllocationType = Value
+ val HOST, RACK, ANY = Value
+}
+
+// Too many params? Refactor it 'somehow'?
+// Needs to be mt-safe.
+// Need to refactor this to make it 'cleaner' ... right now all computation is reactive; it should be made
+// more proactive and decoupled.
+// Note that right now we assume all node asks are uniform in terms of capabilities and priority.
+// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for more info
+// on how we are requesting containers.
+private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceManager: AMRMProtocol,
+ val appAttemptId: ApplicationAttemptId,
+ val maxWorkers: Int, val workerMemory: Int, val workerCores: Int,
+ val preferredHostToCount: Map[String, Int],
+ val preferredRackToCount: Map[String, Int])
+ extends Logging {
+
+
+ // These three are locked on allocatedHostToContainersMap. They are complementary data structures:
+ // allocatedHostToContainersMap : containers which are running, as host -> Set<containerId>
+ // allocatedContainerToHostMap : container to host mapping
+ private val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]]()
+ private val allocatedContainerToHostMap = new HashMap[ContainerId, String]()
+ // allocatedRackCount is populated ONLY if allocation happens (and decremented when an allocated node completes).
+ // As with the two data structures above, it is tightly coupled with them and must be locked on allocatedHostToContainersMap.
+ private val allocatedRackCount = new HashMap[String, Int]()
+
+ // containers which have been released.
+ private val releasedContainerList = new CopyOnWriteArrayList[ContainerId]()
+ // containers to be released in next request to RM
+ private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean]
+
+ private val numWorkersRunning = new AtomicInteger()
+ // Used to generate a unique id per worker
+ private val workerIdCounter = new AtomicInteger()
+ private val lastResponseId = new AtomicInteger()
+
+ def getNumWorkersRunning: Int = numWorkersRunning.intValue
+
+
+ def isResourceConstraintSatisfied(container: Container): Boolean = {
+ container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+ }
+
+ def allocateContainers(workersToRequest: Int) {
+ // We need to send the request only once from what I understand ... but for now, not modifying this much.
+
+ // Keep polling the Resource Manager for containers
+ val amResp = allocateWorkerResources(workersToRequest).getAMResponse
+
+ val _allocatedContainers = amResp.getAllocatedContainers()
+ if (_allocatedContainers.size > 0) {
+
+
+ logDebug("Allocated " + _allocatedContainers.size + " containers, current count " +
+ numWorkersRunning.get() + ", to-be-released " + releasedContainerList +
+ ", pendingReleaseContainers : " + pendingReleaseContainers)
+ logDebug("Cluster Resources: " + amResp.getAvailableResources)
+
+ val hostToContainers = new HashMap[String, ArrayBuffer[Container]]()
+
+ // Ignore containers that do not satisfy the resource constraints
+ for (container <- _allocatedContainers) {
+ if (isResourceConstraintSatisfied(container)) {
+ // allocatedContainers += container
+
+ val host = container.getNodeId.getHost
+ val containers = hostToContainers.getOrElseUpdate(host, new ArrayBuffer[Container]())
+
+ containers += container
+ }
+ // Add all ignored containers to released list
+ else releasedContainerList.add(container.getId())
+ }
+
+ // Find the appropriate containers to use
+ // Slightly non trivial groupBy I guess ...
+ val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
+ val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
+ val offRackContainers = new HashMap[String, ArrayBuffer[Container]]()
+
+ for (candidateHost <- hostToContainers.keySet)
+ {
+ val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0)
+ val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost)
+
+ var remainingContainers = hostToContainers.get(candidateHost).getOrElse(null)
+ assert(remainingContainers != null)
+
+ if (requiredHostCount >= remainingContainers.size){
+ // Since we got <= required containers, add all to dataLocalContainers
+ dataLocalContainers.put(candidateHost, remainingContainers)
+ // all consumed
+ remainingContainers = null
+ }
+ else if (requiredHostCount > 0) {
+ // container list has more containers than we need for data locality.
+ // Split into two : data local container count of (remainingContainers.size - requiredHostCount)
+ // and rest as remainingContainer
+ val (dataLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredHostCount)
+ dataLocalContainers.put(candidateHost, dataLocal)
+ // remainingContainers = remaining
+
+ // YARN has a nasty habit of allocating a tonne of containers on a host - discourage this:
+ // add the remaining containers to the release list. If we end up with insufficient containers, the next
+ // allocation cycle will reallocate (but won't treat them as data local).
+ for (container <- remaining) releasedContainerList.add(container.getId())
+ remainingContainers = null
+ }
+
+ // now rack local
+ if (remainingContainers != null){
+ val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)
+
+ if (rack != null){
+ val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0)
+ val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) -
+ rackLocalContainers.get(rack).getOrElse(List()).size
+
+
+ if (requiredRackCount >= remainingContainers.size) {
+ // Since we got <= the required count, add all of them as rack local
+ rackLocalContainers.getOrElseUpdate(rack, new ArrayBuffer[Container]()) ++= remainingContainers
+ // all consumed
+ remainingContainers = null
+ }
+ else if (requiredRackCount > 0) {
+ // Container list has more containers than we need for rack locality.
+ // Split into two: rack-local containers of count (remainingContainers.size - requiredRackCount)
+ // and the rest as remainingContainers
+ val (rackLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredRackCount)
+ val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, new ArrayBuffer[Container]())
+
+ existingRackLocal ++= rackLocal
+ remainingContainers = remaining
+ }
+ }
+ }
+
+ // If still not consumed, then it is an off-rack host - add it to that list.
+ if (remainingContainers != null){
+ offRackContainers.put(candidateHost, remainingContainers)
+ }
+ }
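+ // Worked example (illustrative): if preferredHostToCount("h1") == 2, nothing is yet allocated
+ // on h1, and YARN returned 5 containers on h1, then requiredHostCount == 2: the splitAt above
+ // keeps 3 containers as data local and adds the other 2 to the release list.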
+
+ // Now that we have split the containers into various groups, go through them in order :
+ // first host local, then rack local and then off rack (everything else).
+ // Note that the list we create below tries to ensure that not all containers end up within a host
+ // if there are sufficiently large number of hosts/containers.
+
+ val allocatedContainers = new ArrayBuffer[Container](_allocatedContainers.size)
+ allocatedContainers ++= ClusterScheduler.prioritizeContainers(dataLocalContainers)
+ allocatedContainers ++= ClusterScheduler.prioritizeContainers(rackLocalContainers)
+ allocatedContainers ++= ClusterScheduler.prioritizeContainers(offRackContainers)
+
+ // Run each of the allocated containers
+ for (container <- allocatedContainers) {
+ val numWorkersRunningNow = numWorkersRunning.incrementAndGet()
+ val workerHostname = container.getNodeId.getHost
+ val containerId = container.getId
+
+ assert (container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD))
+
+ if (numWorkersRunningNow > maxWorkers) {
+ logInfo("Ignoring container " + containerId + " at host " + workerHostname +
+ " .. we already have required number of containers")
+ releasedContainerList.add(containerId)
+ // reset counter back to old value.
+ numWorkersRunning.decrementAndGet()
+ }
+ else {
+ // deallocate + allocate can result in reusing id's wrongly - so use a different counter (workerIdCounter)
+ val workerId = workerIdCounter.incrementAndGet().toString
+ val driverUrl = "akka://spark@%s:%s/user/%s".format(
+ System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
+ StandaloneSchedulerBackend.ACTOR_NAME)
+
+ logInfo("launching container on " + containerId + " host " + workerHostname)
+ // just to be safe, simply remove it from pendingReleaseContainers. Should not be there, but ..
+ pendingReleaseContainers.remove(containerId)
+
+ val rack = YarnAllocationHandler.lookupRack(conf, workerHostname)
+ allocatedHostToContainersMap.synchronized {
+ val containerSet = allocatedHostToContainersMap.getOrElseUpdate(workerHostname, new HashSet[ContainerId]())
+
+ containerSet += containerId
+ allocatedContainerToHostMap.put(containerId, workerHostname)
+ if (rack != null) allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1)
+ }
+
+ new Thread(
+ new WorkerRunnable(container, conf, driverUrl, workerId,
+ workerHostname, workerMemory, workerCores)
+ ).start()
+ }
+ }
+ logDebug("After allocated " + allocatedContainers.size + " containers (orig : " +
+ _allocatedContainers.size + "), current count " + numWorkersRunning.get() +
+ ", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers)
+ }
+
+
+ val completedContainers = amResp.getCompletedContainersStatuses()
+ if (completedContainers.size > 0){
+ logDebug("Completed " + completedContainers.size + " containers, current count " + numWorkersRunning.get() +
+ ", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers)
+
+ for (completedContainer <- completedContainers){
+ val containerId = completedContainer.getContainerId
+
+ // Was this released by us ? If yes, then simply remove from containerSet and move on.
+ if (pendingReleaseContainers.containsKey(containerId)) {
+ pendingReleaseContainers.remove(containerId)
+ }
+ else {
+ // simply decrement count - next iteration of ReporterThread will take care of allocating !
+ numWorkersRunning.decrementAndGet()
+ logInfo("Container completed ? nodeId: " + containerId + ", state " + completedContainer.getState +
+ " httpaddress: " + completedContainer.getDiagnostics)
+ }
+
+ allocatedHostToContainersMap.synchronized {
+ if (allocatedContainerToHostMap.containsKey(containerId)) {
+ val host = allocatedContainerToHostMap.get(containerId).getOrElse(null)
+ assert (host != null)
+
+ val containerSet = allocatedHostToContainersMap.get(host).getOrElse(null)
+ assert (containerSet != null)
+
+ containerSet -= containerId
+ if (containerSet.isEmpty) allocatedHostToContainersMap.remove(host)
+ else allocatedHostToContainersMap.update(host, containerSet)
+
+ allocatedContainerToHostMap -= containerId
+
+ // doing this within locked context, sigh ... move to outside ?
+ val rack = YarnAllocationHandler.lookupRack(conf, host)
+ if (rack != null) {
+ val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1
+ if (rackCount > 0) allocatedRackCount.put(rack, rackCount)
+ else allocatedRackCount.remove(rack)
+ }
+ }
+ }
+ }
+ logDebug("After completed " + completedContainers.size + " containers, current count " +
+ numWorkersRunning.get() + ", to-be-released " + releasedContainerList +
+ ", pendingReleaseContainers : " + pendingReleaseContainers)
+ }
+ }
+
+ def createRackResourceRequests(hostContainers: List[ResourceRequest]): List[ResourceRequest] = {
+ // First generate the modified racks and the new set of hosts under each, then issue the requests
+ val rackToCounts = new HashMap[String, Int]()
+
+ // Generate rack counts for the given host containers; reads the rack-related maps too.
+ for (container <- hostContainers) {
+ val candidateHost = container.getHostName
+ val candidateNumContainers = container.getNumContainers
+ assert(YarnAllocationHandler.ANY_HOST != candidateHost)
+
+ val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)
+ if (rack != null) {
+ var count = rackToCounts.getOrElse(rack, 0)
+ count += candidateNumContainers
+ rackToCounts.put(rack, count)
+ }
+ }
+
+ val requestedContainers: ArrayBuffer[ResourceRequest] =
+ new ArrayBuffer[ResourceRequest](rackToCounts.size)
+ for ((rack, count) <- rackToCounts){
+ requestedContainers +=
+ createResourceRequest(AllocationType.RACK, rack, count, YarnAllocationHandler.PRIORITY)
+ }
+
+ requestedContainers.toList
+ }
+
+ def allocatedContainersOnHost(host: String): Int = {
+ var retval = 0
+ allocatedHostToContainersMap.synchronized {
+ retval = allocatedHostToContainersMap.getOrElse(host, Set()).size
+ }
+ retval
+ }
+
+ def allocatedContainersOnRack(rack: String): Int = {
+ var retval = 0
+ allocatedHostToContainersMap.synchronized {
+ retval = allocatedRackCount.getOrElse(rack, 0)
+ }
+ retval
+ }
+
+ private def allocateWorkerResources(numWorkers: Int): AllocateResponse = {
+
+ var resourceRequests: List[ResourceRequest] = null
+
+ // default.
+ if (numWorkers <= 0 || preferredHostToCount.isEmpty) {
+ logDebug("numWorkers: " + numWorkers + ", host preferences ? " + preferredHostToCount.isEmpty)
+ resourceRequests = List(
+ createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY))
+ }
+ else {
+ // Request containers on all the preferred hosts, plus an ANY request for numWorkers
+ // containers under the default allocation policy.
+ val hostContainerRequests: ArrayBuffer[ResourceRequest] =
+ new ArrayBuffer[ResourceRequest](preferredHostToCount.size)
+ for ((candidateHost, candidateCount) <- preferredHostToCount) {
+ val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost)
+
+ if (requiredCount > 0) {
+ hostContainerRequests +=
+ createResourceRequest(AllocationType.HOST, candidateHost, requiredCount, YarnAllocationHandler.PRIORITY)
+ }
+ }
+ val rackContainerRequests: List[ResourceRequest] = createRackResourceRequests(hostContainerRequests.toList)
+
+ val anyContainerRequests: ResourceRequest =
+ createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY)
+
+ val containerRequests: ArrayBuffer[ResourceRequest] =
+ new ArrayBuffer[ResourceRequest](hostContainerRequests.size + rackContainerRequests.size + 1)
+
+ containerRequests ++= hostContainerRequests
+ containerRequests ++= rackContainerRequests
+ containerRequests += anyContainerRequests
+
+ resourceRequests = containerRequests.toList
+ }
+
+ val req = Records.newRecord(classOf[AllocateRequest])
+ req.setResponseId(lastResponseId.incrementAndGet)
+ req.setApplicationAttemptId(appAttemptId)
+
+ req.addAllAsks(resourceRequests)
+
+ val releasedContainerList = createReleasedContainerList()
+ req.addAllReleases(releasedContainerList)
+
+ if (numWorkers > 0) {
+ logInfo("Allocating " + numWorkers + " worker containers with " + (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + " of memory each.")
+ }
+ else {
+ logDebug("Empty allocation req .. release : " + releasedContainerList)
+ }
+
+ for (req <- resourceRequests) {
+ logInfo("rsrcRequest ... host : " + req.getHostName + ", numContainers : " + req.getNumContainers +
+ ", p = " + req.getPriority().getPriority + ", capability: " + req.getCapability)
+ }
+ resourceManager.allocate(req)
+ }
+
+
+ private def createResourceRequest(requestType: AllocationType.AllocationType,
+ resource:String, numWorkers: Int, priority: Int): ResourceRequest = {
+
+ // If a hostname is specified, we need at least two requests - node local and rack local.
+ // There must also be a third request - ANY - which is specially handled.
+ requestType match {
+ case AllocationType.HOST => {
+ assert (YarnAllocationHandler.ANY_HOST != resource)
+
+ val hostname = resource
+ val nodeLocal = createResourceRequestImpl(hostname, numWorkers, priority)
+
+ // add to host->rack mapping
+ YarnAllocationHandler.populateRackInfo(conf, hostname)
+
+ nodeLocal
+ }
+
+ case AllocationType.RACK => {
+ val rack = resource
+ createResourceRequestImpl(rack, numWorkers, priority)
+ }
+
+ case AllocationType.ANY => {
+ createResourceRequestImpl(YarnAllocationHandler.ANY_HOST, numWorkers, priority)
+ }
+
+ case _ => throw new IllegalArgumentException("Unexpected/unsupported request type .. " + requestType)
+ }
+ }
+
+ private def createResourceRequestImpl(hostname:String, numWorkers: Int, priority: Int): ResourceRequest = {
+
+ val rsrcRequest = Records.newRecord(classOf[ResourceRequest])
+ val memCapability = Records.newRecord(classOf[Resource])
+ // There probably is some overhead here, let's reserve a bit more memory.
+ memCapability.setMemory(workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+ rsrcRequest.setCapability(memCapability)
+
+ val pri = Records.newRecord(classOf[Priority])
+ pri.setPriority(priority)
+ rsrcRequest.setPriority(pri)
+
+ rsrcRequest.setHostName(hostname)
+
+ rsrcRequest.setNumContainers(java.lang.Math.max(numWorkers, 0))
+ rsrcRequest
+ }
+
+ def createReleasedContainerList(): ArrayBuffer[ContainerId] = {
+
+ val retval = new ArrayBuffer[ContainerId](1)
+ // iterator on COW list ...
+ for (container <- releasedContainerList.iterator()){
+ retval += container
+ }
+ // remove from the original list.
+ if (! retval.isEmpty) {
+ releasedContainerList.removeAll(retval)
+ for (v <- retval) pendingReleaseContainers.put(v, true)
+ logInfo("Releasing " + retval.size + " containers. pendingReleaseContainers : " +
+ pendingReleaseContainers)
+ }
+
+ retval
+ }
+}
+
+object YarnAllocationHandler {
+
+ val ANY_HOST = "*"
+ // All requests are issued with the same priority; we do not (yet) have any distinction between
+ // request types (like map/reduce in Hadoop, for example)
+ val PRIORITY = 1
+
+ // Additional memory overhead, in MB
+ val MEMORY_OVERHEAD = 384
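+ // Illustrative example: a request with workerMemory = 1024 asks YARN for containers of
+ // 1024 + 384 = 1408 MB (see createResourceRequestImpl and isResourceConstraintSatisfied above).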
+
+ // Host-to-rack map, saved from allocation requests. We expect this not to change.
+ // Note that it is possible for it to change: the RM would indicate that to us via an update
+ // response to allocate. But we are punting on handling that for now.
+ private val hostToRack = new ConcurrentHashMap[String, String]()
+ private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]()
+
+ def newAllocator(conf: Configuration,
+ resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId,
+ args: ApplicationMasterArguments,
+ map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = {
+
+ val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)
+
+
+ new YarnAllocationHandler(conf, resourceManager, appAttemptId, args.numWorkers,
+ args.workerMemory, args.workerCores, hostToCount, rackToCount)
+ }
+
+ def newAllocator(conf: Configuration,
+ resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId,
+ maxWorkers: Int, workerMemory: Int, workerCores: Int,
+ map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = {
+
+ val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)
+
+ new YarnAllocationHandler(conf, resourceManager, appAttemptId, maxWorkers,
+ workerMemory, workerCores, hostToCount, rackToCount)
+ }
+
+ // A simple method to copy the split info map.
+ private def generateNodeToWeight(conf: Configuration, input: collection.Map[String, collection.Set[SplitInfo]]) :
+ // host to count, rack to count
+ (Map[String, Int], Map[String, Int]) = {
+
+ if (input == null) return (Map[String, Int](), Map[String, Int]())
+
+ val hostToCount = new HashMap[String, Int]
+ val rackToCount = new HashMap[String, Int]
+
+ for ((host, splits) <- input) {
+ val hostCount = hostToCount.getOrElse(host, 0)
+ hostToCount.put(host, hostCount + splits.size)
+
+ val rack = lookupRack(conf, host)
+ if (rack != null){
+ val rackCount = rackToCount.getOrElse(rack, 0)
+ rackToCount.put(rack, rackCount + splits.size)
+ }
+ }
+
+ (hostToCount.toMap, rackToCount.toMap)
+ }
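+ // Worked example (illustrative): for input Map("h1" -> <2 splits>, "h2" -> <1 split>) where both
+ // hosts resolve to rack "/r1", this returns (Map("h1" -> 2, "h2" -> 1), Map("/r1" -> 3)).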
+
+ def lookupRack(conf: Configuration, host: String): String = {
+ if (!hostToRack.containsKey(host)) populateRackInfo(conf, host)
+ hostToRack.get(host)
+ }
+
+ def fetchCachedHostsForRack(rack: String): Option[Set[String]] = {
+ val set = rackToHostSet.get(rack)
+ if (set == null) return None
+
+ // No better way to get a Set[String] from JSet ?
+ val convertedSet: collection.mutable.Set[String] = set
+ Some(convertedSet.toSet)
+ }
+
+ def populateRackInfo(conf: Configuration, hostname: String) {
+ Utils.checkHost(hostname)
+
+ if (!hostToRack.containsKey(hostname)) {
+ // If there are repeated failures to resolve, add to an ignore list?
+ val rackInfo = RackResolver.resolve(conf, hostname)
+ if (rackInfo != null && rackInfo.getNetworkLocation != null) {
+ val rack = rackInfo.getNetworkLocation
+ hostToRack.put(hostname, rack)
+ if (! rackToHostSet.containsKey(rack)) {
+ rackToHostSet.putIfAbsent(rack, Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]()))
+ }
+ rackToHostSet.get(rack).add(hostname)
+
+ // Since RackResolver caches, we are disabling this for now ...
+ } /* else {
+ // right ? Else we will keep calling rack resolver in case we cant resolve rack info ...
+ hostToRack.put(hostname, null)
+ } */
+ }
+ }
+}
diff --git a/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala b/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala
new file mode 100644
index 0000000000..ed732d36bf
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -0,0 +1,42 @@
+package spark.scheduler.cluster
+
+import spark._
+import spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler}
+import org.apache.hadoop.conf.Configuration
+
+/**
+ * A simple extension to ClusterScheduler that ensures the appropriate initialization of the
+ * ApplicationMaster, etc. is done.
+ */
+private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
+
+ def this(sc: SparkContext) = this(sc, new Configuration())
+
+ // Nothing else for now ... initialize the application master, which needs the SparkContext to determine how to allocate.
+ // Note that only the first creation of a SparkContext influences this (and ideally there should be only one SparkContext).
+ // Subsequent creations are ignored, since the nodes are already allocated by then.
+
+
+ // By default, rack is unknown
+ override def getRackForHost(hostPort: String): Option[String] = {
+ val host = Utils.parseHostPort(hostPort)._1
+ val retval = YarnAllocationHandler.lookupRack(conf, host)
+ if (retval != null) Some(retval) else None
+ }
+
+ // By default, if rack is unknown, return nothing
+ override def getCachedHostsForRack(rack: String): Option[Set[String]] = {
+ if (rack == null) return None
+
+ YarnAllocationHandler.fetchCachedHostsForRack(rack)
+ }
+
+ override def postStartHook() {
+ val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc)
+ if (sparkContextInitialized){
+ // Wait a few seconds for the slaves to bootstrap and register with the master - a best-effort attempt
+ Thread.sleep(3000L)
+ }
+ logInfo("YarnClusterScheduler.postStartHook done")
+ }
+}
diff --git a/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
index 35300cea58..a0652d7fc7 100644
--- a/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
+++ b/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
@@ -4,4 +4,7 @@ trait HadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)
def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+
+ def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier,
+ jobId, isMap, taskId, attemptId)
}
diff --git a/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
index 7afdbff320..7fdbe322fd 100644
--- a/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
+++ b/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
@@ -7,4 +7,7 @@ trait HadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)
def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+
+ def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier,
+ jobId, isMap, taskId, attemptId)
}
diff --git a/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala
new file mode 100644
index 0000000000..a0fb4fe25d
--- /dev/null
+++ b/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala
@@ -0,0 +1,23 @@
+package spark.deploy
+import org.apache.hadoop.conf.Configuration
+
+
+/**
+ * Contains util methods to interact with Hadoop from spark.
+ */
+object SparkHadoopUtil {
+
+ def getUserNameFromEnvironment(): String = {
+ // defaulting to -D ...
+ System.getProperty("user.name")
+ }
+
+ def runAsUser(func: (Product) => Unit, args: Product) {
+
+ // Add support if it exists - for now, simply run func!
+ func(args)
+ }
+
+ // Returns an appropriate (subclass of) Configuration. Creating a config can initialize some Hadoop subsystems.
+ def newConfiguration(): Configuration = new Configuration()
+}
diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java
new file mode 100644
index 0000000000..a4bb4bc701
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileClient.java
@@ -0,0 +1,72 @@
+package spark.network.netty;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.oio.OioEventLoopGroup;
+import io.netty.channel.socket.oio.OioSocketChannel;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+class FileClient {
+
+ private Logger LOG = LoggerFactory.getLogger(this.getClass().getName());
+ private FileClientHandler handler = null;
+ private Channel channel = null;
+ private Bootstrap bootstrap = null;
+ private int connectTimeout = 60*1000; // 1 min
+
+ public FileClient(FileClientHandler handler, int connectTimeout) {
+ this.handler = handler;
+ this.connectTimeout = connectTimeout;
+ }
+
+ public void init() {
+ bootstrap = new Bootstrap();
+ bootstrap.group(new OioEventLoopGroup())
+ .channel(OioSocketChannel.class)
+ .option(ChannelOption.SO_KEEPALIVE, true)
+ .option(ChannelOption.TCP_NODELAY, true)
+ .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout)
+ .handler(new FileClientChannelInitializer(handler));
+ }
+
+ public void connect(String host, int port) {
+ try {
+ // Start the connection attempt.
+ channel = bootstrap.connect(host, port).sync().channel();
+ // ChannelFuture cf = channel.closeFuture();
+ //cf.addListener(new ChannelCloseListener(this));
+ } catch (InterruptedException e) {
+ close();
+ }
+ }
+
+ public void waitForClose() {
+ try {
+ channel.closeFuture().sync();
+ } catch (InterruptedException e) {
+ LOG.warn("FileClient interrupted", e);
+ }
+ }
+
+ public void sendRequest(String file) {
+ // assert(file != null);
+ // assert(channel != null);
+ channel.write(file + "\r\n");
+ }
+
+ public void close() {
+ if (channel != null) {
+ channel.close();
+ channel = null;
+ }
+ if (bootstrap != null) {
+ bootstrap.shutdown();
+ bootstrap = null;
+ }
+ }
+}
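+
+// Illustrative usage (the handler, host, port and block id below are hypothetical):
+//   FileClient client = new FileClient(myHandler, 60 * 1000);
+//   client.init();
+//   client.connect("localhost", 12345);
+//   client.sendRequest("shuffle_0_0_0");
+//   client.waitForClose();
+//   client.close();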
diff --git a/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java
new file mode 100644
index 0000000000..af25baf641
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java
@@ -0,0 +1,24 @@
+package spark.network.netty;
+
+import io.netty.buffer.BufType;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.codec.string.StringEncoder;
+
+
+class FileClientChannelInitializer extends ChannelInitializer<SocketChannel> {
+
+ private FileClientHandler fhandler;
+
+ public FileClientChannelInitializer(FileClientHandler handler) {
+ fhandler = handler;
+ }
+
+ @Override
+ public void initChannel(SocketChannel channel) {
+ // Files are limited to no more than 2GB.
+ channel.pipeline()
+ .addLast("encoder", new StringEncoder(BufType.BYTE))
+ .addLast("handler", fhandler);
+ }
+}
diff --git a/core/src/main/java/spark/network/netty/FileClientHandler.java b/core/src/main/java/spark/network/netty/FileClientHandler.java
new file mode 100644
index 0000000000..9fc9449827
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileClientHandler.java
@@ -0,0 +1,43 @@
+package spark.network.netty;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundByteHandlerAdapter;
+
+
+abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
+
+ private FileHeader currentHeader = null;
+
+ private volatile boolean handlerCalled = false;
+
+ public boolean isComplete() {
+ return handlerCalled;
+ }
+
+ public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header);
+ public abstract void handleError(String blockId);
+
+ @Override
+ public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) {
+ // Use direct buffer if possible.
+ return ctx.alloc().ioBuffer();
+ }
+
+ @Override
+ public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) {
+ // get header
+ if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) {
+ currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE()));
+ }
+ // get file
+ if (currentHeader != null && in.readableBytes() >= currentHeader.fileLen()) {
+ handle(ctx, in, currentHeader);
+ handlerCalled = true;
+ currentHeader = null;
+ ctx.close();
+ }
+ }
+
+}
+
diff --git a/core/src/main/java/spark/network/netty/FileServer.java b/core/src/main/java/spark/network/netty/FileServer.java
new file mode 100644
index 0000000000..dd3a557ae5
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileServer.java
@@ -0,0 +1,86 @@
+package spark.network.netty;
+
+import java.net.InetSocketAddress;
+
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.oio.OioEventLoopGroup;
+import io.netty.channel.socket.oio.OioServerSocketChannel;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Server that accepts the path of a file and echoes back its contents.
+ */
+class FileServer {
+
+ private Logger LOG = LoggerFactory.getLogger(this.getClass().getName());
+
+ private ServerBootstrap bootstrap = null;
+ private ChannelFuture channelFuture = null;
+ private int port = 0;
+ private Thread blockingThread = null;
+
+ public FileServer(PathResolver pResolver, int port) {
+ InetSocketAddress addr = new InetSocketAddress(port);
+
+ // Configure the server.
+ bootstrap = new ServerBootstrap();
+ bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup())
+ .channel(OioServerSocketChannel.class)
+ .option(ChannelOption.SO_BACKLOG, 100)
+ .option(ChannelOption.SO_RCVBUF, 1500)
+ .childHandler(new FileServerChannelInitializer(pResolver));
+ // Start the server.
+ channelFuture = bootstrap.bind(addr);
+ try {
+ // Get the address we bound to.
+ InetSocketAddress boundAddress =
+ ((InetSocketAddress) channelFuture.sync().channel().localAddress());
+ this.port = boundAddress.getPort();
+ } catch (InterruptedException ie) {
+ this.port = 0;
+ }
+ }
+
+ /**
+ * Start the file server asynchronously in a new thread.
+ */
+ public void start() {
+ blockingThread = new Thread() {
+ public void run() {
+ try {
+ channelFuture.channel().closeFuture().sync();
+ LOG.info("FileServer exiting");
+ } catch (InterruptedException e) {
+ LOG.error("File server start got interrupted", e);
+ }
+ // NOTE: bootstrap is shutdown in stop()
+ }
+ };
+ blockingThread.setDaemon(true);
+ blockingThread.start();
+ }
+
+ public int getPort() {
+ return port;
+ }
+
+ public void stop() {
+ // Close the bound channel.
+ if (channelFuture != null) {
+ channelFuture.channel().close();
+ channelFuture = null;
+ }
+ // Shutdown bootstrap.
+ if (bootstrap != null) {
+ bootstrap.shutdown();
+ bootstrap = null;
+ }
+ // TODO: Shutdown all accepted channels as well ?
+ }
+}
diff --git a/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java
new file mode 100644
index 0000000000..8f1f5c65cd
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java
@@ -0,0 +1,25 @@
+package spark.network.netty;
+
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.codec.DelimiterBasedFrameDecoder;
+import io.netty.handler.codec.Delimiters;
+import io.netty.handler.codec.string.StringDecoder;
+
+
+class FileServerChannelInitializer extends ChannelInitializer<SocketChannel> {
+
+ PathResolver pResolver;
+
+ public FileServerChannelInitializer(PathResolver pResolver) {
+ this.pResolver = pResolver;
+ }
+
+ @Override
+ public void initChannel(SocketChannel channel) {
+ channel.pipeline()
+ .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter()))
+ .addLast("strDecoder", new StringDecoder())
+ .addLast("handler", new FileServerHandler(pResolver));
+ }
+}
diff --git a/core/src/main/java/spark/network/netty/FileServerHandler.java b/core/src/main/java/spark/network/netty/FileServerHandler.java
new file mode 100644
index 0000000000..a78eddb1b5
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileServerHandler.java
@@ -0,0 +1,65 @@
+package spark.network.netty;
+
+import java.io.File;
+import java.io.FileInputStream;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundMessageHandlerAdapter;
+import io.netty.channel.DefaultFileRegion;
+
+
+class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
+
+ PathResolver pResolver;
+
+ public FileServerHandler(PathResolver pResolver){
+ this.pResolver = pResolver;
+ }
+
+ @Override
+ public void messageReceived(ChannelHandlerContext ctx, String blockId) {
+ String path = pResolver.getAbsolutePath(blockId);
+ // If getAbsolutePath returns null, ignore the request
+ if (path == null) {
+ //ctx.close();
+ return;
+ }
+ File file = new File(path);
+ if (file.exists()) {
+ if (!file.isFile()) {
+ //logger.info("Not a file : " + file.getAbsolutePath());
+ ctx.write(new FileHeader(0, blockId).buffer());
+ ctx.flush();
+ return;
+ }
+ long length = file.length();
+ if (length > Integer.MAX_VALUE || length <= 0) {
+ //logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length);
+ ctx.write(new FileHeader(0, blockId).buffer());
+ ctx.flush();
+ return;
+ }
+ int len = (int) length;
+ //logger.info("Sending block "+blockId+" filelen = "+len);
+ //logger.info("header = "+ (new FileHeader(len, blockId)).buffer());
+ ctx.write((new FileHeader(len, blockId)).buffer());
+ try {
+ ctx.sendFile(new DefaultFileRegion(new FileInputStream(file)
+ .getChannel(), 0, file.length()));
+ } catch (Exception e) {
+ //logger.warning("Exception when sending file : " + file.getAbsolutePath());
+ e.printStackTrace();
+ }
+ } else {
+ //logger.warning("File not found: " + file.getAbsolutePath());
+ ctx.write(new FileHeader(0, blockId).buffer());
+ }
+ ctx.flush();
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+ cause.printStackTrace();
+ ctx.close();
+ }
+}
diff --git a/core/src/main/java/spark/network/netty/PathResolver.java b/core/src/main/java/spark/network/netty/PathResolver.java
new file mode 100755
index 0000000000..302411672c
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/PathResolver.java
@@ -0,0 +1,12 @@
+package spark.network.netty;
+
+
+public interface PathResolver {
+ /**
+ * Get the absolute path of the file.
+ *
+ * @param fileId the id of the file
+ * @return the absolute path of the file
+ */
+ public String getAbsolutePath(String fileId);
+}
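+
+// A minimal illustrative implementation (the directory layout is hypothetical):
+//   public class DiskPathResolver implements PathResolver {
+//     public String getAbsolutePath(String fileId) {
+//       return "/tmp/spark-files/" + fileId;
+//     }
+//   }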
diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
index c27ed36406..3239f4c385 100644
--- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
@@ -1,14 +1,19 @@
package spark
-import executor.{ShuffleReadMetrics, TaskMetrics}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
-import spark.storage.{DelegateBlockFetchTracker, BlockManagerId}
-import util.{CompletionIterator, TimedIterator}
+import spark.executor.{ShuffleReadMetrics, TaskMetrics}
+import spark.serializer.Serializer
+import spark.storage.BlockManagerId
+import spark.util.CompletionIterator
+
private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
- override def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) = {
+
+ override def fetch[K, V](
+ shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer) = {
+
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager
@@ -48,18 +53,18 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
}
}
- val blockFetcherItr = blockManager.getMultiple(blocksByAddress)
- val itr = new TimedIterator(blockFetcherItr.flatMap(unpackBlock)) with DelegateBlockFetchTracker
- itr.setDelegate(blockFetcherItr)
+ val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
+ val itr = blockFetcherItr.flatMap(unpackBlock)
+
CompletionIterator[(K,V), Iterator[(K,V)]](itr, {
val shuffleMetrics = new ShuffleReadMetrics
- shuffleMetrics.shuffleReadMillis = itr.getNetMillis
- shuffleMetrics.remoteFetchTime = itr.remoteFetchTime
- shuffleMetrics.fetchWaitTime = itr.fetchWaitTime
- shuffleMetrics.remoteBytesRead = itr.remoteBytesRead
- shuffleMetrics.totalBlocksFetched = itr.totalBlocks
- shuffleMetrics.localBlocksFetched = itr.numLocalBlocks
- shuffleMetrics.remoteBlocksFetched = itr.numRemoteBlocks
+ shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
+ shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
+ shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
+ shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
+ shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
+ shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
+ shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
metrics.shuffleReadMetrics = Some(shuffleMetrics)
})
}
diff --git a/core/src/main/scala/spark/ClosureCleaner.scala b/core/src/main/scala/spark/ClosureCleaner.scala
index 98525b99c8..d5e7132ff9 100644
--- a/core/src/main/scala/spark/ClosureCleaner.scala
+++ b/core/src/main/scala/spark/ClosureCleaner.scala
@@ -5,15 +5,22 @@ import java.lang.reflect.Field
import scala.collection.mutable.Map
import scala.collection.mutable.Set
-import org.objectweb.asm.{ClassReader, MethodVisitor, Type}
-import org.objectweb.asm.commons.EmptyVisitor
+import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
import org.objectweb.asm.Opcodes._
+import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream}
private[spark] object ClosureCleaner extends Logging {
// Get an ASM class reader for a given class from the JAR that loaded it
private def getClassReader(cls: Class[_]): ClassReader = {
- new ClassReader(cls.getResourceAsStream(
- cls.getName.replaceFirst("^.*\\.", "") + ".class"))
+ // Copy data over, before delegating to ClassReader - else we can run out of open file handles.
+ val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
+ val resourceStream = cls.getResourceAsStream(className)
+ // TODO: fix me - continuing with the earlier behavior for now ...
+ if (resourceStream == null) return new ClassReader(resourceStream)
+
+ val baos = new ByteArrayOutputStream(128)
+ Utils.copyStream(resourceStream, baos, true)
+ new ClassReader(new ByteArrayInputStream(baos.toByteArray))
}
// Check whether a class represents a Scala closure
@@ -154,10 +161,10 @@ private[spark] object ClosureCleaner extends Logging {
}
}
-private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor {
+private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) {
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
- return new EmptyVisitor {
+ return new MethodVisitor(ASM4) {
override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) {
if (op == GETFIELD) {
for (cl <- output.keys if cl.getName == owner.replace('/', '.')) {
@@ -180,7 +187,7 @@ private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) exten
}
}
-private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor {
+private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) {
var myName: String = null
override def visit(version: Int, access: Int, name: String, sig: String,
@@ -190,7 +197,7 @@ private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisi
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
- return new EmptyVisitor {
+ return new MethodVisitor(ASM4) {
override def visitMethodInsn(op: Int, owner: String, name: String,
desc: String) {
val argTypes = Type.getArgumentTypes(desc)
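
The buffering change in getClassReader is a general technique: drain the resource stream into memory and close it immediately, so the number of open file handles stays bounded no matter how long the parser holds on to the bytes. A hedged sketch of the same idea outside ASM (names are illustrative):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream}

    object BufferedResource {
      // Drain `in` fully and close it, releasing the underlying handle
      // before the bytes are handed to any downstream consumer.
      def slurp(in: InputStream): Array[Byte] = {
        val out = new ByteArrayOutputStream(128)
        val buf = new Array[Byte](8192)
        var n = in.read(buf)
        while (n != -1) {
          out.write(buf, 0, n)
          n = in.read(buf)
        }
        in.close()
        out.toByteArray
      }

      def main(args: Array[String]) {
        val stream = getClass.getResourceAsStream("/java/lang/String.class")
        if (stream != null) {
          val bytes = slurp(stream)
          println("read " + bytes.length + " bytes; handle already closed")
          new ByteArrayInputStream(bytes) // parse at leisure, no handle held
        }
      }
    }
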
diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala
index 5eea907322..2af44aa383 100644
--- a/core/src/main/scala/spark/Dependency.scala
+++ b/core/src/main/scala/spark/Dependency.scala
@@ -25,10 +25,12 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
* @param shuffleId the shuffle id
* @param rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
+ * @param serializerClass class name of the serializer to use
*/
class ShuffleDependency[K, V](
@transient rdd: RDD[(K, V)],
- val partitioner: Partitioner)
+ val partitioner: Partitioner,
+ val serializerClass: String = null)
extends Dependency(rdd) {
val shuffleId: Int = rdd.context.newShuffleId()
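
The new serializerClass field lets an individual shuffle override the default serializer without touching the global spark.serializer setting. A hedged usage sketch (assumes spark.KryoSerializer is available on the classpath):

    import spark.{HashPartitioner, ShuffleDependency, SparkContext}

    object ShuffleDepDemo {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "ShuffleDepDemo")
        val pairs = sc.parallelize(Seq(("a", 1), ("b", 2)))
        // The class name is carried as a String and resolved lazily by the
        // shuffle writer/fetcher; null keeps the default serializer.
        val dep = new ShuffleDependency[String, Int](
          pairs, new HashPartitioner(2), "spark.KryoSerializer")
        println("shuffle id: " + dep.shuffleId)
        sc.stop()
      }
    }
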
diff --git a/core/src/main/scala/spark/FetchFailedException.scala b/core/src/main/scala/spark/FetchFailedException.scala
index a953081d24..40b0193f19 100644
--- a/core/src/main/scala/spark/FetchFailedException.scala
+++ b/core/src/main/scala/spark/FetchFailedException.scala
@@ -3,18 +3,25 @@ package spark
import spark.storage.BlockManagerId
private[spark] class FetchFailedException(
- val bmAddress: BlockManagerId,
- val shuffleId: Int,
- val mapId: Int,
- val reduceId: Int,
+ taskEndReason: TaskEndReason,
+ message: String,
cause: Throwable)
extends Exception {
-
- override def getMessage(): String =
- "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId)
+
+ def this(bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int, cause: Throwable) =
+ this(FetchFailed(bmAddress, shuffleId, mapId, reduceId),
+ "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId),
+ cause)
+
+ def this(shuffleId: Int, reduceId: Int, cause: Throwable) =
+ this(FetchFailed(null, shuffleId, -1, reduceId),
+ "Unable to fetch locations from master: %d %d".format(shuffleId, reduceId), cause)
+
+ override def getMessage(): String = message
+
override def getCause(): Throwable = cause
- def toTaskEndReason: TaskEndReason =
- FetchFailed(bmAddress, shuffleId, mapId, reduceId)
+ def toTaskEndReason: TaskEndReason = taskEndReason
+
}
diff --git a/core/src/main/scala/spark/HadoopWriter.scala b/core/src/main/scala/spark/HadoopWriter.scala
index afcf9f6db4..5e8396edb9 100644
--- a/core/src/main/scala/spark/HadoopWriter.scala
+++ b/core/src/main/scala/spark/HadoopWriter.scala
@@ -2,14 +2,10 @@ package org.apache.hadoop.mapred
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
-import org.apache.hadoop.util.ReflectionUtils
-import org.apache.hadoop.io.NullWritable
-import org.apache.hadoop.io.Text
import java.text.SimpleDateFormat
import java.text.NumberFormat
import java.io.IOException
-import java.net.URI
import java.util.Date
import spark.Logging
@@ -24,7 +20,7 @@ import spark.SerializableWritable
* a filename to write to, etc, exactly like in a Hadoop MapReduce job.
*/
class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRedUtil with Serializable {
-
+
private val now = new Date()
private val conf = new SerializableWritable(jobConf)
@@ -106,6 +102,12 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRe
}
}
+ def commitJob() {
+ // TODO: always commit, or only if cmtr.needsTaskCommit?
+ val cmtr = getOutputCommitter()
+ cmtr.commitJob(getJobContext())
+ }
+
def cleanup() {
getOutputCommitter().cleanupJob(getJobContext())
}
diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala
index 7c1c1bb144..0fc8c31463 100644
--- a/core/src/main/scala/spark/Logging.scala
+++ b/core/src/main/scala/spark/Logging.scala
@@ -68,6 +68,10 @@ trait Logging {
if (log.isErrorEnabled) log.error(msg, throwable)
}
+ protected def isTraceEnabled(): Boolean = {
+ log.isTraceEnabled
+ }
+
// Method for ensuring that logging is initialized, to avoid having multiple
// threads do it concurrently (as SLF4J initialization is not thread safe).
protected def initLogging() { log }
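
The isTraceEnabled() accessor makes it possible to skip whole blocks of diagnostic work, not just the message construction. A hedged sketch (assuming the Logging trait and its logTrace method are visible to the caller; NoisyComponent is hypothetical):

    class NoisyComponent extends spark.Logging {
      def process(items: Seq[Int]) {
        if (isTraceEnabled()) {
          // Only computed when trace logging is actually on.
          val histogram = items.groupBy(identity).mapValues(_.size)
          logTrace("input histogram: " + histogram)
        }
        // ... actual processing ...
      }
    }
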
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index 50708d9cb1..0fc6427307 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -1,7 +1,6 @@
package spark
import java.io._
-import java.util.concurrent.ConcurrentHashMap
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.HashMap
@@ -11,6 +10,7 @@ import akka.actor._
import scala.concurrent.Await
import akka.pattern.ask
import akka.remote._
+
import scala.concurrent.duration.Duration
import akka.util.Timeout
import scala.concurrent.duration._
@@ -40,10 +40,12 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac
private[spark] class MapOutputTracker extends Logging {
+ private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+
// Set to the MapOutputTrackerActor living on the driver
var trackerActor: ActorRef = _
- var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
+ private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
@@ -52,7 +54,7 @@ private[spark] class MapOutputTracker extends Logging {
// Cache a serialized version of the output statuses for each shuffle to send them out faster
var cacheGeneration = generation
- val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
+ private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)
@@ -60,7 +62,6 @@ private[spark] class MapOutputTracker extends Logging {
// throw a SparkException if this fails.
def askTracker(message: Any): Any = {
try {
- val timeout = 10.seconds
val future = trackerActor.ask(message)(timeout)
return Await.result(future, timeout)
} catch {
@@ -77,10 +78,9 @@ private[spark] class MapOutputTracker extends Logging {
}
def registerShuffle(shuffleId: Int, numMaps: Int) {
- if (mapStatuses.get(shuffleId) != None) {
+ if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
- mapStatuses.put(shuffleId, new Array[MapStatus](numMaps))
}
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
@@ -101,8 +101,9 @@ private[spark] class MapOutputTracker extends Logging {
}
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
- var array = mapStatuses(shuffleId)
- if (array != null) {
+ val arrayOpt = mapStatuses.get(shuffleId)
+ if (arrayOpt.isDefined && arrayOpt.get != null) {
+ val array = arrayOpt.get
array.synchronized {
if (array(mapId) != null && array(mapId).location == bmAddress) {
array(mapId) = null
@@ -115,13 +116,14 @@ private[spark] class MapOutputTracker extends Logging {
}
// Remembers which map output locations are currently being fetched on a worker
- val fetching = new HashSet[Int]
+ private val fetching = new HashSet[Int]
// Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
+ var fetchedStatuses: Array[MapStatus] = null
fetching.synchronized {
if (fetching.contains(shuffleId)) {
// Someone else is fetching it; wait for them to be done
@@ -132,31 +134,48 @@ private[spark] class MapOutputTracker extends Logging {
case e: InterruptedException =>
}
}
- return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, mapStatuses(shuffleId))
- } else {
+ }
+
+ // Either while we waited the fetch happened successfully, or
+ // someone fetched it in between the get and the fetching.synchronized.
+ fetchedStatuses = mapStatuses.get(shuffleId).orNull
+ if (fetchedStatuses == null) {
+ // We have to do the fetch, get others to wait for us.
fetching += shuffleId
}
}
- // We won the race to fetch the output locs; do so
- logInfo("Doing the fetch; tracker actor = " + trackerActor)
- val host = System.getProperty("spark.hostname", Utils.localHostName)
- // This try-finally prevents hangs due to timeouts:
- var fetchedStatuses: Array[MapStatus] = null
- try {
- val fetchedBytes =
- askTracker(GetMapOutputStatuses(shuffleId, host)).asInstanceOf[Array[Byte]]
- fetchedStatuses = deserializeStatuses(fetchedBytes)
- logInfo("Got the output locations")
- mapStatuses.put(shuffleId, fetchedStatuses)
- } finally {
- fetching.synchronized {
- fetching -= shuffleId
- fetching.notifyAll()
+
+ if (fetchedStatuses == null) {
+ // We won the race to fetch the output locs; do so
+ logInfo("Doing the fetch; tracker actor = " + trackerActor)
+ val hostPort = Utils.localHostPort()
+ // This try-finally prevents hangs due to timeouts:
+ try {
+ val fetchedBytes =
+ askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
+ fetchedStatuses = deserializeStatuses(fetchedBytes)
+ logInfo("Got the output locations")
+ mapStatuses.put(shuffleId, fetchedStatuses)
+ } finally {
+ fetching.synchronized {
+ fetching -= shuffleId
+ fetching.notifyAll()
+ }
+ }
+ }
+ if (fetchedStatuses != null) {
+ fetchedStatuses.synchronized {
+ return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
}
}
- return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
+ else {
+ throw new FetchFailedException(null, shuffleId, -1, reduceId,
+ new Exception("Missing all output locations for shuffle " + shuffleId))
+ }
} else {
- return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
+ statuses.synchronized {
+ return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
+ }
}
}
@@ -194,7 +213,8 @@ private[spark] class MapOutputTracker extends Logging {
generationLock.synchronized {
if (newGen > generation) {
logInfo("Updating generation to " + newGen + " and clearing cache")
- mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
+ // mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
+ mapStatuses.clear()
generation = newGen
}
}
@@ -232,10 +252,13 @@ private[spark] class MapOutputTracker extends Logging {
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
- def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
+ private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
val out = new ByteArrayOutputStream
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
- objOut.writeObject(statuses)
+ // Since statuses can be modified in parallel, sync on it
+ statuses.synchronized {
+ objOut.writeObject(statuses)
+ }
objOut.close()
out.toByteArray
}
@@ -243,7 +266,10 @@ private[spark] class MapOutputTracker extends Logging {
// Opposite of serializeStatuses.
def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
- objIn.readObject().asInstanceOf[Array[MapStatus]]
+ objIn.readObject().
+ // // drop all nulls from status - not sure why they are occurring though. Causes NPE downstream in slave if present
+ // commented out - the nulls could be due to a missing location?
+ asInstanceOf[Array[MapStatus]] // .filter( _ != null )
}
}
@@ -253,16 +279,13 @@ private[spark] object MapOutputTracker {
// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
// any of the statuses is null (indicating a missing location due to a failed mapper),
// throw a FetchFailedException.
- def convertMapStatuses(
+ private def convertMapStatuses(
shuffleId: Int,
reduceId: Int,
statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
- if (statuses == null) {
- throw new FetchFailedException(null, shuffleId, -1, reduceId,
- new Exception("Missing all output locations for shuffle " + shuffleId))
- }
+ assert(statuses != null)
statuses.map {
- status =>
+ status =>
if (status == null) {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing an output location for shuffle " + shuffleId))
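
The reworked getServerStatuses implements a classic "fetch once, others wait" scheme: the first thread to miss the cache performs the remote fetch while concurrent callers block on the fetching set, then everyone re-reads the cache. A standalone sketch of the same concurrency pattern (hypothetical names, not the MapOutputTracker API):

    import scala.collection.mutable.{HashMap, HashSet}

    class OnceFetcher[K, V](remoteFetch: K => V) {
      private val cache = new HashMap[K, V]
      private val fetching = new HashSet[K]

      def get(key: K): V = {
        fetching.synchronized {
          // Wait while someone else is fetching this key.
          while (fetching.contains(key)) fetching.wait()
          cache.get(key) match {
            case Some(v) => return v        // fetched while we waited
            case None    => fetching += key // we won the race; others wait
          }
        }
        try {
          val v = remoteFetch(key)
          fetching.synchronized { cache(key) = v }
          v
        } finally {
          fetching.synchronized { fetching -= key; fetching.notifyAll() }
        }
      }
    }
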
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 2052d05788..fe812fe530 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -1,5 +1,6 @@
package spark
+import java.nio.ByteBuffer
import java.util.{Date, HashMap => JHashMap}
import java.text.SimpleDateFormat
@@ -11,6 +12,8 @@ import scala.reflect.{ ClassTag, classTag}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.compress.CompressionCodec
+import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.mapred.FileOutputCommitter
import org.apache.hadoop.mapred.FileOutputFormat
import org.apache.hadoop.mapred.HadoopWriter
@@ -18,7 +21,7 @@ import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
-import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil, TaskAttemptID, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil}
import spark.partial.BoundedDouble
import spark.partial.PartialResult
@@ -53,7 +56,8 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C,
partitioner: Partitioner,
- mapSideCombine: Boolean = true): RDD[(K, C)] = {
+ mapSideCombine: Boolean = true,
+ serializerClass: String = null): RDD[(K, C)] = {
if (getKeyClass().isArray) {
if (mapSideCombine) {
throw new SparkException("Cannot use map-side combining with array keys.")
@@ -62,19 +66,18 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](
throw new SparkException("Default partitioner cannot partition array keys.")
}
}
- val aggregator =
- new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
+ val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (self.partitioner == Some(partitioner)) {
self.mapPartitions(aggregator.combineValuesByKey(_), true)
} else if (mapSideCombine) {
val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
- val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner)
+ val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner, serializerClass)
partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true)
} else {
// Don't apply map-side combiner.
// A sanity check to make sure mergeCombiners is not defined.
assert(mergeCombiners == null)
- val values = new ShuffledRDD[K, V](self, partitioner)
+ val values = new ShuffledRDD[K, V](self, partitioner, serializerClass)
values.mapPartitions(aggregator.combineValuesByKey(_), true)
}
}
@@ -95,7 +98,16 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](
* list concatenation, 0 for addition, or 1 for multiplication.).
*/
def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = {
- combineByKey[V]({v: V => func(zeroValue, v)}, func, func, partitioner)
+ // Serialize the zero value to a byte array so that we can get a new clone of it on each key
+ val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue)
+ val zeroArray = new Array[Byte](zeroBuffer.limit)
+ zeroBuffer.get(zeroArray)
+
+ // When deserializing, use a lazy val to create just one instance of the serializer per task
+ lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance()
+ def createZero() = cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray))
+
+ combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner)
}
/**
@@ -185,11 +197,13 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](
* partitioning of the resulting key-value pair RDD by passing a Partitioner.
*/
def groupByKey(partitioner: Partitioner): RDD[(K, Seq[V])] = {
+ // groupByKey shouldn't use map side combine because map side combine does not
+ // reduce the amount of data shuffled and requires all map side data be inserted
+ // into a hash table, leading to more objects in the old gen.
def createCombiner(v: V) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
- def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2
val bufs = combineByKey[ArrayBuffer[V]](
- createCombiner _, mergeValue _, mergeCombiners _, partitioner)
+ createCombiner _, mergeValue _, null, partitioner, mapSideCombine = false)
bufs.asInstanceOf[RDD[(K, Seq[V])]]
}
@@ -516,6 +530,16 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](
}
/**
+ * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
+ * supporting the key and value types K and V in this RDD. Compress the result with the
+ * supplied codec.
+ */
+ def saveAsHadoopFile[F <: OutputFormat[K, V]](
+ path: String, codec: Class[_ <: CompressionCodec])(implicit fm: ClassTag[F]) {
+ saveAsHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]], codec)
+ }
+
+ /**
* Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
* (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
*/
@@ -546,8 +570,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" <split #> <attempt # = spark task #> */
- val attemptId = new TaskAttemptID(jobtrackerID,
- stageId, false, context.splitId, attemptNumber)
+ val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber)
val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
val format = outputFormatClass.newInstance
val committer = format.getOutputCommitter(hadoopContext)
@@ -566,16 +589,31 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](
* however we're only going to use this local OutputCommitter for
* setupJob/commitJob, so we just use a dummy "map" task.
*/
- val jobAttemptId = new TaskAttemptID(jobtrackerID, stageId, true, 0, 0)
+ val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, true, 0, 0)
val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
jobCommitter.setupJob(jobTaskContext)
val count = self.context.runJob(self, writeShard _).sum
+ jobCommitter.commitJob(jobTaskContext)
jobCommitter.cleanupJob(jobTaskContext)
}
/**
* Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
+ * supporting the key and value types K and V in this RDD. Compress with the supplied codec.
+ */
+ def saveAsHadoopFile(
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[_ <: OutputFormat[_, _]],
+ codec: Class[_ <: CompressionCodec]) {
+ saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass,
+ new JobConf(self.context.hadoopConfiguration), Some(codec))
+ }
+
+ /**
+ * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
* supporting the key and value types K and V in this RDD.
*/
def saveAsHadoopFile(
@@ -583,11 +621,19 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[_ <: OutputFormat[_, _]],
- conf: JobConf = new JobConf(self.context.hadoopConfiguration)) {
+ conf: JobConf = new JobConf(self.context.hadoopConfiguration),
+ codec: Option[Class[_ <: CompressionCodec]] = None) {
conf.setOutputKeyClass(keyClass)
conf.setOutputValueClass(valueClass)
// conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug
conf.set("mapred.output.format.class", outputFormatClass.getName)
+ for (c <- codec) {
+ conf.setCompressMapOutput(true)
+ conf.set("mapred.output.compress", "true")
+ conf.setMapOutputCompressorClass(c)
+ conf.set("mapred.output.compression.codec", c.getCanonicalName)
+ conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString)
+ }
conf.setOutputCommitter(classOf[FileOutputCommitter])
FileOutputFormat.setOutputPath(conf, HadoopWriter.createPathFromString(path, conf))
saveAsHadoopDataset(conf)
@@ -638,6 +684,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](
}
self.context.runJob(self, writeToFile _)
+ writer.commitJob()
writer.cleanup()
}
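
From the caller's side, the codec-aware overloads reduce compressed output to a one-liner. A hedged usage sketch (local master and output path are illustrative):

    import org.apache.hadoop.io.compress.GzipCodec
    import org.apache.hadoop.mapred.TextOutputFormat
    import spark.SparkContext
    import spark.SparkContext._

    object CompressedSaveDemo {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "CompressedSaveDemo")
        val counts = sc.parallelize(Seq(("a", 1), ("b", 2)))
        // The codec overload sets mapred.output.compress and the codec
        // class on the JobConf before delegating to saveAsHadoopDataset.
        counts.saveAsHadoopFile[TextOutputFormat[String, Int]](
          "/tmp/counts-gz", classOf[GzipCodec])
        sc.stop()
      }
    }
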
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 6ee075315a..e88290fdb2 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -1,22 +1,23 @@
package spark
-import java.net.URL
-import java.util.{Date, Random}
-import java.util.{HashMap => JHashMap}
+import java.util.Random
import scala.collection.Map
import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.mutable.ArrayBuffer
+
import scala.collection.mutable.HashMap
import scala.reflect.{classTag, ClassTag}
import org.apache.hadoop.io.BytesWritable
+import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.mapred.TextOutputFormat
import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
+import spark.broadcast.Broadcast
import spark.Partitioner._
import spark.partial.BoundedDouble
import spark.partial.CountEvaluator
@@ -33,10 +34,13 @@ import spark.rdd.MapPartitionsWithIndexRDD
import spark.rdd.PipedRDD
import spark.rdd.SampledRDD
import spark.rdd.ShuffledRDD
-import spark.rdd.SubtractedRDD
import spark.rdd.UnionRDD
import spark.rdd.ZippedRDD
+import spark.rdd.ZippedPartitionsRDD2
+import spark.rdd.ZippedPartitionsRDD3
+import spark.rdd.ZippedPartitionsRDD4
import spark.storage.StorageLevel
+import spark.util.BoundedPriorityQueue
import SparkContext._
@@ -105,7 +109,7 @@ abstract class RDD[T: ClassTag](
// =======================================================================
/** A unique ID for this RDD (within its SparkContext). */
- val id = sc.newRddId()
+ val id: Int = sc.newRddId()
/** A friendly name for this RDD */
var name: String = null
@@ -116,9 +120,18 @@ abstract class RDD[T: ClassTag](
this
}
+ /** User-defined generator of this RDD */
+ var generator = Utils.getCallSiteInfo.firstUserClass
+
+ /** Reset generator */
+ def setGenerator(_generator: String) = {
+ generator = _generator
+ }
+
/**
* Set this RDD's storage level to persist its values across operations after the first time
- * it is computed. Can only be called once on each RDD.
+ * it is computed. This can only be used to assign a new storage level if the RDD does not
+ * have a storage level set yet.
*/
def persist(newLevel: StorageLevel): RDD[T] = {
// TODO: Handle changes of StorageLevel
@@ -138,6 +151,20 @@ abstract class RDD[T: ClassTag](
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): RDD[T] = persist()
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ *
+ * @param blocking Whether to block until all blocks are deleted.
+ * @return This RDD.
+ */
+ def unpersist(blocking: Boolean = true): RDD[T] = {
+ logInfo("Removing RDD " + id + " from persistence list")
+ sc.env.blockManager.master.removeRdd(id, blocking)
+ sc.persistentRdds.remove(id)
+ storageLevel = StorageLevel.NONE
+ this
+ }
+
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel
@@ -257,8 +284,8 @@ abstract class RDD[T: ClassTag](
def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
var fraction = 0.0
var total = 0
- var multiplier = 3.0
- var initialCount = count()
+ val multiplier = 3.0
+ val initialCount = count()
var maxSelected = 0
if (initialCount > Integer.MAX_VALUE - 1) {
@@ -339,13 +366,36 @@ abstract class RDD[T: ClassTag](
/**
* Return an RDD created by piping elements to a forked external process.
*/
- def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command)
+ def pipe(command: String, env: Map[String, String]): RDD[String] =
+ new PipedRDD(this, command, env)
+
/**
* Return an RDD created by piping elements to a forked external process.
- */
- def pipe(command: Seq[String], env: Map[String, String]): RDD[String] =
- new PipedRDD(this, command, env)
+ * The print behavior can be customized by providing two functions.
+ *
+ * @param command command to run in forked process.
+ * @param env environment variables to set.
+ * @param printPipeContext Before piping elements, this function is called as an opportunity
+ *                         to pipe context data. The print line function (like out.println) will be
+ *                         passed as printPipeContext's parameter.
+ * @param printRDDElement Use this function to customize how to pipe elements. This function
+ * will be called with each RDD element as the 1st parameter, and the
+ * print line function (like out.println()) as the 2nd parameter.
+ *                        An example of piping the RDD data of groupBy() in a streaming way,
+ *                        instead of constructing a huge String to concatenate all the elements:
+ *                        def printRDDElement(record: (String, Seq[String]), f: String => Unit) =
+ *                          for (e <- record._2) { f(e) }
+ * @return the result RDD
+ */
+ def pipe(
+ command: Seq[String],
+ env: Map[String, String] = Map(),
+ printPipeContext: (String => Unit) => Unit = null,
+ printRDDElement: (T, String => Unit) => Unit = null): RDD[String] =
+ new PipedRDD(this, command, env,
+ if (printPipeContext ne null) sc.clean(printPipeContext) else null,
+ if (printRDDElement ne null) sc.clean(printRDDElement) else null)
/**
* Return a new RDD by applying a function to each partition of this RDD.
@@ -437,6 +487,31 @@ abstract class RDD[T: ClassTag](
*/
def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)
+ /**
+ * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
+ * applying a function to the zipped partitions. Assumes that all the RDDs have the
+ * *same number of partitions*, but does *not* require them to have the same number
+ * of elements in each partition.
+ */
+ def zipPartitions[B: ClassTag, V: ClassTag](
+ f: (Iterator[T], Iterator[B]) => Iterator[V],
+ rdd2: RDD[B]): RDD[V] =
+ new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2)
+
+ def zipPartitions[B: ClassTag, C: ClassTag, V: ClassTag](
+ f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V],
+ rdd2: RDD[B],
+ rdd3: RDD[C]): RDD[V] =
+ new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3)
+
+ def zipPartitions[B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag](
+ f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
+ rdd2: RDD[B],
+ rdd3: RDD[C],
+ rdd4: RDD[D]): RDD[V] =
+ new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4)
+
+
// Actions (launch a job to return a value to the user program)
/**
@@ -452,7 +527,7 @@ abstract class RDD[T: ClassTag](
*/
def foreachPartition(f: Iterator[T] => Unit) {
val cleanF = sc.clean(f)
- sc.runJob(this, (iter: Iterator[T]) => f(iter))
+ sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
}
/**
@@ -685,6 +760,24 @@ abstract class RDD[T: ClassTag](
}
/**
+ * Returns the top K elements from this RDD as defined by
+ * the specified implicit Ordering[T].
+ * @param num the number of top elements to return
+ * @param ord the implicit ordering for T
+ * @return an array of top elements
+ */
+ def top(num: Int)(implicit ord: Ordering[T]): Array[T] = {
+ mapPartitions { items =>
+ val queue = new BoundedPriorityQueue[T](num)
+ queue ++= items
+ Iterator.single(queue)
+ }.reduce { (queue1, queue2) =>
+ queue1 ++= queue2
+ queue1
+ }.toArray
+ }
+
+ /**
* Save this RDD as a text file, using string representations of elements.
*/
def saveAsTextFile(path: String) {
@@ -693,6 +786,14 @@ abstract class RDD[T: ClassTag](
}
/**
+ * Save this RDD as a compressed text file, using string representations of elements.
+ */
+ def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) {
+ this.map(x => (NullWritable.get(), new Text(x.toString)))
+ .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path, codec)
+ }
+
+ /**
* Save this RDD as a SequenceFile of serialized objects.
*/
def saveAsObjectFile(path: String) {
@@ -750,7 +851,7 @@ abstract class RDD[T: ClassTag](
private var storageLevel: StorageLevel = StorageLevel.NONE
/** Record user function generating this RDD. */
- private[spark] val origin = Utils.getSparkCallSite
+ private[spark] val origin = Utils.formatSparkCallSite
private[spark] def elementClassTag: ClassTag[T] = classTag[T]
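
Two of the additions to RDD.scala are easiest to see from the call site. A hedged usage sketch of top() and zipPartitions (local master is illustrative):

    import spark.SparkContext
    import spark.SparkContext._

    object NewRddOpsDemo {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "NewRddOpsDemo")
        val nums = sc.parallelize(1 to 100, 4)

        // top(k) keeps one k-element bounded priority queue per partition
        // and merges them, so at most k elements per partition are retained.
        println(nums.top(3).toList)

        // zipPartitions pairs partitions positionally; both RDDs need the
        // same number of partitions, but not the same element counts.
        val labels = sc.parallelize(Seq("a", "b", "c", "d"), 4)
        val sizes = nums.zipPartitions(
          (xs: Iterator[Int], ys: Iterator[String]) => Iterator(xs.size + "/" + ys.size),
          labels)
        println(sizes.collect().toList)
        sc.stop()
      }
    }
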
diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala
index 083ba9b8fa..5e7bb029eb 100644
--- a/core/src/main/scala/spark/RDDCheckpointData.scala
+++ b/core/src/main/scala/spark/RDDCheckpointData.scala
@@ -3,6 +3,7 @@ package spark
import scala.reflect.ClassTag
import org.apache.hadoop.fs.Path
+import org.apache.hadoop.conf.Configuration
import rdd.{CheckpointRDD, CoalescedRDD}
@@ -66,14 +67,20 @@ private[spark] class RDDCheckpointData[T: ClassTag](rdd: RDD[T])
}
}
+ // Create the output path for the checkpoint
+ val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id)
+ val fs = path.getFileSystem(new Configuration())
+ if (!fs.mkdirs(path)) {
+ throw new SparkException("Failed to create checkpoint path " + path)
+ }
+
// Save to file, and reload it as an RDD
- val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id).toString
- rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _)
- val newRDD = new CheckpointRDD[T](rdd.context, path)
+ rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString) _)
+ val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
// Change the dependencies and partitions of the RDD
RDDCheckpointData.synchronized {
- cpFile = Some(path)
+ cpFile = Some(path.toString)
cpRDD = Some(newRDD)
rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed
diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
index 883a0152bb..edfde37da3 100644
--- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
+++ b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
@@ -19,6 +19,7 @@ import org.apache.hadoop.mapred.TextOutputFormat
import org.apache.hadoop.mapred.SequenceFileOutputFormat
import org.apache.hadoop.mapred.OutputCommitter
import org.apache.hadoop.mapred.FileOutputCommitter
+import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.io.Writable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.BytesWritable
@@ -63,7 +64,7 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag
* byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported
* file system.
*/
- def saveAsSequenceFile(path: String) {
+ def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) {
def anyToWritable[U <% Writable](u: U): Writable = u
val keyClass = getWritableClass[K]
@@ -73,14 +74,18 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag
logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," + valueClass.getSimpleName + ")" )
val format = classOf[SequenceFileOutputFormat[Writable, Writable]]
+ val jobConf = new JobConf(self.context.hadoopConfiguration)
if (!convertKey && !convertValue) {
- self.saveAsHadoopFile(path, keyClass, valueClass, format)
+ self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec)
} else if (!convertKey && convertValue) {
- self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile(path, keyClass, valueClass, format)
+ self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile(
+ path, keyClass, valueClass, format, jobConf, codec)
} else if (convertKey && !convertValue) {
- self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile(path, keyClass, valueClass, format)
+ self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile(
+ path, keyClass, valueClass, format, jobConf, codec)
} else if (convertKey && convertValue) {
- self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile(path, keyClass, valueClass, format)
+ self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile(
+ path, keyClass, valueClass, format, jobConf, codec)
}
}
}
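
The Option-typed codec parameter keeps old call sites source-compatible while enabling compression on request. A hedged usage sketch (paths illustrative):

    import org.apache.hadoop.io.compress.GzipCodec
    import spark.SparkContext
    import spark.SparkContext._

    object SequenceSaveDemo {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "SequenceSaveDemo")
        val data = sc.parallelize(Seq(("key1", 1), ("key2", 2)))
        // None (the default) keeps the old uncompressed behavior.
        data.saveAsSequenceFile("/tmp/seq-gz", Some(classOf[GzipCodec]))
        sc.stop()
      }
    }
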
diff --git a/core/src/main/scala/spark/ShuffleFetcher.scala b/core/src/main/scala/spark/ShuffleFetcher.scala
index 442e9f0269..9513a00126 100644
--- a/core/src/main/scala/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/spark/ShuffleFetcher.scala
@@ -1,13 +1,16 @@
package spark
-import executor.TaskMetrics
+import spark.executor.TaskMetrics
+import spark.serializer.Serializer
+
private[spark] abstract class ShuffleFetcher {
/**
* Fetch the shuffle outputs for a given ShuffleDependency.
* @return An iterator over the elements of the fetched shuffle outputs.
*/
- def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) : Iterator[(K,V)]
+ def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
+ serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[(K,V)]
/** Stop the fetcher */
def stop() {}
diff --git a/core/src/main/scala/spark/SizeEstimator.scala b/core/src/main/scala/spark/SizeEstimator.scala
index d4e1157250..f8a4c4e489 100644
--- a/core/src/main/scala/spark/SizeEstimator.scala
+++ b/core/src/main/scala/spark/SizeEstimator.scala
@@ -198,7 +198,7 @@ private[spark] object SizeEstimator extends Logging {
val elem = JArray.get(array, index)
size += SizeEstimator.estimate(elem, state.visited)
}
- state.size += ((length / 100.0) * size).toLong
+ state.size += ((length / (ARRAY_SAMPLE_SIZE * 1.0)) * size).toLong
}
}
}
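
The one-line SizeEstimator fix matters because the extrapolation must divide by the actual sample size rather than a hard-coded 100. A small worked sketch of the estimator's arithmetic (ARRAY_SAMPLE_SIZE assumed to be 100 for the example):

    object SampleEstimateDemo {
      // Extrapolate total array footprint from a fixed-size random sample.
      def estimate(length: Int, sampledSizes: Seq[Long]): Long = {
        val sampleSize = sampledSizes.length // plays the ARRAY_SAMPLE_SIZE role
        val sampleTotal = sampledSizes.sum
        ((length / (sampleSize * 1.0)) * sampleTotal).toLong
      }

      def main(args: Array[String]) {
        // 10,000 elements, 100 sampled at ~40 bytes each => ~400,000 bytes.
        println(estimate(10000, Seq.fill(100)(40L))) // prints 400000
      }
    }
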
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 7272a592a5..ef6de87193 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -1,48 +1,56 @@
package spark
import java.io._
-import java.util.concurrent.atomic.AtomicInteger
import java.net.URI
+import java.util.Properties
+import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.atomic.AtomicInteger
+import scala.collection.JavaConversions._
import scala.collection.Map
import scala.collection.generic.Growable
import scala.collection.mutable.HashMap
import scala.collection.JavaConversions._
+
import scala.reflect.{ ClassTag, classTag}
-import org.apache.hadoop.fs.Path
+import scala.util.DynamicVariable
+import scala.collection.mutable.{ConcurrentMap, HashMap}
+
+import akka.actor.Actor._
+
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.mapred.InputFormat
-import org.apache.hadoop.mapred.SequenceFileInputFormat
-import org.apache.hadoop.io.Writable
-import org.apache.hadoop.io.IntWritable
-import org.apache.hadoop.io.LongWritable
-import org.apache.hadoop.io.FloatWritable
-import org.apache.hadoop.io.DoubleWritable
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.ArrayWritable
import org.apache.hadoop.io.BooleanWritable
import org.apache.hadoop.io.BytesWritable
-import org.apache.hadoop.io.ArrayWritable
+import org.apache.hadoop.io.DoubleWritable
+import org.apache.hadoop.io.FloatWritable
+import org.apache.hadoop.io.IntWritable
+import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
+import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.FileInputFormat
+import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapred.SequenceFileInputFormat
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
-import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
+import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
+
import org.apache.mesos.MesosNativeLibrary
-import spark.deploy.LocalSparkCluster
-import spark.partial.ApproximateEvaluator
-import spark.partial.PartialResult
+import spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
+import spark.partial.{ApproximateEvaluator, PartialResult}
import spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD}
-import spark.scheduler._
+import spark.scheduler.{DAGScheduler, ResultTask, ShuffleMapTask, SparkListener, SplitInfo, Stage, StageInfo, TaskScheduler}
+import spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend, ClusterScheduler}
import spark.scheduler.local.LocalScheduler
-import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
-import spark.storage.BlockManagerUI
+import spark.storage.{BlockManagerUI, StorageStatus, StorageUtils, RDDInfo}
import spark.util.{MetadataCleaner, TimeStampedHashMap}
-import spark.storage.{StorageStatus, StorageUtils, RDDInfo}
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -60,7 +68,10 @@ class SparkContext(
val appName: String,
val sparkHome: String = null,
val jars: Seq[String] = Nil,
- val environment: Map[String, String] = Map())
+ val environment: Map[String, String] = Map(),
+ // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, etc.) too.
+ // This is typically generated from InputFormatInfo.computePreferredLocations: host -> set of data-local splits on that host.
+ val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map())
extends Logging {
// Ensure logging is initialized before we spawn any threads
@@ -68,7 +79,7 @@ class SparkContext(
// Set Spark driver host and port system properties
if (System.getProperty("spark.driver.host") == null) {
- System.setProperty("spark.driver.host", Utils.localIpAddress)
+ System.setProperty("spark.driver.host", Utils.localHostName())
}
if (System.getProperty("spark.driver.port") == null) {
System.setProperty("spark.driver.port", "0")
@@ -95,24 +106,29 @@ class SparkContext(
private[spark] val addedJars = HashMap[String, Long]()
// Keeps track of all persisted RDDs
- private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]()
+ private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup)
// Add each JAR given through the constructor
- jars.foreach { addJar(_) }
+ if (jars != null) {
+ jars.foreach { addJar(_) }
+ }
// Environment variables to pass to our executors
private[spark] val executorEnvs = HashMap[String, String]()
// Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner
- for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS",
- "SPARK_TESTING")) {
+ for (key <- Seq("SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", "SPARK_TESTING")) {
val value = System.getenv(key)
if (value != null) {
executorEnvs(key) = value
}
}
- executorEnvs ++= environment
+ // Since memory can be set with a system property too, use that
+ executorEnvs("SPARK_MEM") = SparkContext.executorMemoryRequested + "m"
+ if (environment != null) {
+ executorEnvs ++= environment
+ }
// Create and start the scheduler
private var taskScheduler: TaskScheduler = {
@@ -144,14 +160,12 @@ class SparkContext(
scheduler
case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
- // Check to make sure SPARK_MEM <= memoryPerSlave. Otherwise Spark will just hang.
+ // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang.
val memoryPerSlaveInt = memoryPerSlave.toInt
- val sparkMemEnv = System.getenv("SPARK_MEM")
- val sparkMemEnvInt = if (sparkMemEnv != null) Utils.memoryStringToMb(sparkMemEnv) else 512
- if (sparkMemEnvInt > memoryPerSlaveInt) {
+ if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) {
throw new SparkException(
- "Slave memory (%d MB) cannot be smaller than SPARK_MEM (%d MB)".format(
- memoryPerSlaveInt, sparkMemEnvInt))
+ "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format(
+ memoryPerSlaveInt, SparkContext.executorMemoryRequested))
}
val scheduler = new ClusterScheduler(this)
@@ -165,6 +179,22 @@ class SparkContext(
}
scheduler
+ case "yarn-standalone" =>
+ val scheduler = try {
+ val clazz = Class.forName("spark.scheduler.cluster.YarnClusterScheduler")
+ val cons = clazz.getConstructor(classOf[SparkContext])
+ cons.newInstance(this).asInstanceOf[ClusterScheduler]
+ } catch {
+ // TODO: Enumerate the exact reasons why it can fail
+ // But irrespective of it, it means we cannot proceed !
+ case th: Throwable => {
+ throw new SparkException("YARN mode not available ?", th)
+ }
+ }
+ val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem)
+ scheduler.initialize(backend)
+ scheduler
+
case _ =>
if (MESOS_REGEX.findFirstIn(master).isEmpty) {
logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
@@ -184,12 +214,12 @@ class SparkContext(
}
taskScheduler.start()
- private var dagScheduler = new DAGScheduler(taskScheduler)
+ @volatile private var dagScheduler = new DAGScheduler(taskScheduler)
dagScheduler.start()
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = {
- val conf = new Configuration()
+ val conf = SparkHadoopUtil.newConfiguration()
// Explicitly check for S3 environment variables
if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
@@ -208,6 +238,22 @@ class SparkContext(
private[spark] var checkpointDir: Option[String] = None
+ // Thread Local variable that can be used by users to pass information down the stack
+ private val localProperties = new DynamicVariable[Properties](null)
+
+ def initLocalProperties() {
+ localProperties.value = new Properties()
+ }
+
+ def addLocalProperties(key: String, value: String) {
+ if (localProperties.value == null) {
+ localProperties.value = new Properties()
+ }
+ localProperties.value.setProperty(key, value)
+ }
+ // Post init
+ taskScheduler.postStartHook()
+
// Methods for creating RDDs
/** Distribute a local Scala collection to form an RDD. */
@@ -472,7 +518,7 @@ class SparkContext(
*/
def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
- (blockManagerId.ip + ":" + blockManagerId.port, mem)
+ (blockManagerId.host + ":" + blockManagerId.port, mem)
}
}
@@ -480,7 +526,7 @@ class SparkContext(
* Return information about what RDDs are cached, if they are in mem or on disk, how much space
* they take, etc.
*/
- def getRDDStorageInfo : Array[RDDInfo] = {
+ def getRDDStorageInfo: Array[RDDInfo] = {
StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this)
}
@@ -491,7 +537,7 @@ class SparkContext(
/**
* Return information about blocks stored in all of the slaves
*/
- def getExecutorStorageStatus : Array[StorageStatus] = {
+ def getExecutorStorageStatus: Array[StorageStatus] = {
env.blockManager.master.getStorageStatus
}
@@ -509,13 +555,18 @@ class SparkContext(
* filesystems), or an HTTP, HTTPS or FTP URI.
*/
def addJar(path: String) {
- val uri = new URI(path)
- val key = uri.getScheme match {
- case null | "file" => env.httpFileServer.addJar(new File(uri.getPath))
- case _ => path
+ if (null == path) {
+ logWarning("null specified as parameter to addJar",
+ new SparkException("null specified as parameter to addJar"))
+ } else {
+ val uri = new URI(path)
+ val key = uri.getScheme match {
+ case null | "file" => env.httpFileServer.addJar(new File(uri.getPath))
+ case _ => path
+ }
+ addedJars(key) = System.currentTimeMillis
+ logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key))
}
- addedJars(key) = System.currentTimeMillis
- logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key))
}
/**
@@ -528,10 +579,13 @@ class SparkContext(
/** Shut down the SparkContext. */
def stop() {
- if (dagScheduler != null) {
+ // Do this only if not already stopped - best effort.
+ // Prevents an NPE if stop() is called more than once.
+ val dagSchedulerCopy = dagScheduler
+ dagScheduler = null
+ if (dagSchedulerCopy != null) {
metadataCleaner.cancel()
- dagScheduler.stop()
- dagScheduler = null
+ dagSchedulerCopy.stop()
taskScheduler = null
// TODO: Cache.stop()?
env.stop()
@@ -547,6 +601,7 @@ class SparkContext(
}
}
+
/**
* Get Spark's home location from either a value set through the constructor,
* or the spark.home Java property, or the SPARK_HOME environment variable
@@ -576,10 +631,10 @@ class SparkContext(
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
- val callSite = Utils.getSparkCallSite
+ val callSite = Utils.formatSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
- val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler)
+ val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler, localProperties.value)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
result
@@ -658,12 +713,11 @@ class SparkContext(
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
- timeout: Long
- ): PartialResult[R] = {
- val callSite = Utils.getSparkCallSite
+ timeout: Long): PartialResult[R] = {
+ val callSite = Utils.formatSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
- val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout)
+ val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.value)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
result
}
@@ -686,7 +740,7 @@ class SparkContext(
*/
def setCheckpointDir(dir: String, useExisting: Boolean = false) {
val path = new Path(dir)
- val fs = path.getFileSystem(new Configuration())
+ val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration())
if (!useExisting) {
if (fs.exists(path)) {
throw new Exception("Checkpoint directory '" + path + "' already exists.")
@@ -829,6 +883,15 @@ object SparkContext {
/** Find the JAR that contains the class of a particular object */
def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass)
+
+ /** Get the amount of memory per executor requested through system properties or SPARK_MEM */
+ private[spark] val executorMemoryRequested = {
+ // TODO: Might need to add some extra memory for the non-heap parts of the JVM
+ Option(System.getProperty("spark.executor.memory"))
+ .orElse(Option(System.getenv("SPARK_MEM")))
+ .map(Utils.memoryStringToMb)
+ .getOrElse(512)
+ }
}
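
The executorMemoryRequested resolution order is: the spark.executor.memory system property wins, then the SPARK_MEM environment variable, then a 512 MB default. A self-contained sketch of the same chain (memToMb is a simplified stand-in for Utils.memoryStringToMb, which is private to Spark):

    object MemoryResolutionDemo {
      // Simplified parser: understands "<n>g" and "<n>m" suffixes only.
      def memToMb(s: String): Int = {
        val lower = s.toLowerCase
        if (lower.endsWith("g")) lower.dropRight(1).toInt * 1024
        else if (lower.endsWith("m")) lower.dropRight(1).toInt
        else lower.toInt // assume plain megabytes
      }

      def resolve(): Int =
        Option(System.getProperty("spark.executor.memory"))
          .orElse(Option(System.getenv("SPARK_MEM")))
          .map(memToMb)
          .getOrElse(512)

      def main(args: Array[String]) {
        System.setProperty("spark.executor.memory", "2g")
        println(resolve() + " MB") // prints "2048 MB"
      }
    }
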
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index 144ddea35f..89d52064e1 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -1,14 +1,19 @@
package spark
+import collection.mutable
+import serializer.Serializer
+
import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
import akka.remote.RemoteActorRefProvider
-import serializer.Serializer
import spark.broadcast.BroadcastManager
import spark.storage.BlockManager
import spark.storage.BlockManagerMaster
import spark.network.ConnectionManager
+import spark.serializer.{Serializer, SerializerManager}
import spark.util.AkkaUtils
+import spark.api.python.PythonWorkerFactory
+
/**
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
@@ -20,6 +25,7 @@ import spark.util.AkkaUtils
class SparkEnv (
val executorId: String,
val actorSystem: ActorSystem,
+ val serializerManager: SerializerManager,
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheManager: CacheManager,
@@ -29,10 +35,16 @@ class SparkEnv (
val blockManager: BlockManager,
val connectionManager: ConnectionManager,
val httpFileServer: HttpFileServer,
- val sparkFilesDir: String
- ) {
+ val sparkFilesDir: String,
+ // To be set only as part of initialization of SparkContext.
+ // (executorId, defaultHostPort) => executorHostPort
+ // If executorId is NOT found, return defaultHostPort
+ var executorIdToHostPort: Option[(String, String) => String]) {
+
+ private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
def stop() {
+ pythonWorkers.foreach { case (key, worker) => worker.stop() }
httpFileServer.stop()
mapOutputTracker.stop()
shuffleFetcher.stop()
@@ -45,6 +57,23 @@ class SparkEnv (
// UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it.
//actorSystem.awaitTermination()
}
+
+ def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
+ synchronized {
+ val key = (pythonExec, envVars)
+ pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create()
+ }
+ }
+
+ def resolveExecutorIdToHostPort(executorId: String, defaultHostPort: String): String = {
+ val env = SparkEnv.get
+ if (env.executorIdToHostPort.isEmpty) {
+ // Default to using host, not host:port. Relevant to non-cluster modes.
+ return defaultHostPort
+ }
+
+ env.executorIdToHostPort.get(executorId, defaultHostPort)
+ }
}
object SparkEnv extends Logging {
@@ -73,6 +102,16 @@ object SparkEnv extends Logging {
System.setProperty("spark.driver.port", boundPort.toString)
}
+ // Set spark.hostPort only if it has not been set already.
+ if (System.getProperty("spark.hostPort", null) == null) {
+ if (!isDriver) {
+ // unexpected
+ Utils.logErrorWithStack("Unexpected NOT to have spark.hostPort set")
+ }
+ Utils.checkHost(hostname)
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+ }
+
val classLoader = Thread.currentThread.getContextClassLoader
// Create an instance of the class named by the given Java system property, or by
@@ -82,16 +121,23 @@ object SparkEnv extends Logging {
Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
}
- val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")
-
+ val serializerManager = new SerializerManager
+
+ val serializer = serializerManager.setDefault(
+ System.getProperty("spark.serializer", "spark.JavaSerializer"))
+
+ val closureSerializer = serializerManager.get(
+ System.getProperty("spark.closure.serializer", "spark.JavaSerializer"))
+
def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
if (isDriver) {
logInfo("Registering " + name)
actorSystem.actorOf(Props(newActor), name = name)
} else {
- val driverIp: String = System.getProperty("spark.driver.host", "localhost")
+ val driverHost: String = System.getProperty("spark.driver.host", "localhost")
val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
- val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, name)
+ Utils.checkHost(driverHost, "Expected hostname")
+ val url = "akka://spark@%s:%s/user/%s".format(driverHost, driverPort, name)
logInfo("Connecting to " + name + ": " + url)
actorSystem.actorFor(url)
}
@@ -106,9 +152,6 @@ object SparkEnv extends Logging {
val broadcastManager = new BroadcastManager(isDriver)
- val closureSerializer = instantiateClass[Serializer](
- "spark.closure.serializer", "spark.JavaSerializer")
-
val cacheManager = new CacheManager(blockManager)
// Have to assign trackerActor after initialization as MapOutputTrackerActor
@@ -143,6 +186,7 @@ object SparkEnv extends Logging {
new SparkEnv(
executorId,
actorSystem,
+ serializerManager,
serializer,
closureSerializer,
cacheManager,
@@ -152,7 +196,7 @@ object SparkEnv extends Logging {
blockManager,
connectionManager,
httpFileServer,
- sparkFilesDir)
+ sparkFilesDir,
+ None)
}
-
}
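
createPythonWorker keeps one PythonWorkerFactory per (executable, environment) pair, created lazily under a coarse lock and reused thereafter. A standalone sketch of that keyed-cache shape (names hypothetical):

    import scala.collection.mutable

    class WorkerPool[W](make: (String, Map[String, String]) => W) {
      private val factories = mutable.HashMap[(String, Map[String, String]), W]()

      // getOrElseUpdate builds the entry on first request for a key and
      // returns the cached one afterwards; synchronized serializes creation.
      def workerFor(exec: String, env: Map[String, String]): W = synchronized {
        factories.getOrElseUpdate((exec, env), make(exec, env))
      }
    }

    object WorkerPoolDemo {
      def main(args: Array[String]) {
        val pool = new WorkerPool((exec, env) => exec + " with " + env.size + " vars")
        println(pool.workerFor("python2.7", Map("PYTHONPATH" -> "/tmp")))
        // Same key: returns the cached entry instead of building a new one.
        println(pool.workerFor("python2.7", Map("PYTHONPATH" -> "/tmp")))
      }
    }
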
diff --git a/core/src/main/scala/spark/TaskEndReason.scala b/core/src/main/scala/spark/TaskEndReason.scala
index 420c54bc9a..8140cba084 100644
--- a/core/src/main/scala/spark/TaskEndReason.scala
+++ b/core/src/main/scala/spark/TaskEndReason.scala
@@ -14,9 +14,19 @@ private[spark] case object Success extends TaskEndReason
private[spark]
case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it
-private[spark]
-case class FetchFailed(bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason
+private[spark] case class FetchFailed(
+ bmAddress: BlockManagerId,
+ shuffleId: Int,
+ mapId: Int,
+ reduceId: Int)
+ extends TaskEndReason
-private[spark] case class ExceptionFailure(exception: Throwable) extends TaskEndReason
+private[spark] case class ExceptionFailure(
+ className: String,
+ description: String,
+ stackTrace: Array[StackTraceElement])
+ extends TaskEndReason
private[spark] case class OtherFailure(message: String) extends TaskEndReason
+
+private[spark] case class TaskResultTooBigFailure() extends TaskEndReason
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index cdccb8b336..e02507f83e 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -1,14 +1,16 @@
package spark
import java.io._
-import java.net._
+import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket}
import java.util.{Locale, Random, UUID}
-import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
+
+import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor}
+import java.util.regex.Pattern
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.JavaConversions._
import scala.io.Source
import scala.reflect.ClassTag
@@ -18,11 +20,14 @@ import com.google.common.io.Files
import com.google.common.util.concurrent.ThreadFactoryBuilder
import spark.serializer.SerializerInstance
+import spark.deploy.SparkHadoopUtil
+
/**
* Various utility methods used by Spark.
*/
private object Utils extends Logging {
+
/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
@@ -73,6 +78,40 @@ private object Utils extends Logging {
return buf
}
+ private val shutdownDeletePaths = new collection.mutable.HashSet[String]()
+
+ // Register the path to be deleted via shutdown hook
+ def registerShutdownDeleteDir(file: File) {
+ val absolutePath = file.getAbsolutePath()
+ shutdownDeletePaths.synchronized {
+ shutdownDeletePaths += absolutePath
+ }
+ }
+
+ // Is the path already registered to be deleted via a shutdown hook?
+ def hasShutdownDeleteDir(file: File): Boolean = {
+ val absolutePath = file.getAbsolutePath()
+ shutdownDeletePaths.synchronized {
+ shutdownDeletePaths.contains(absolutePath)
+ }
+ }
+
+ // Note: returns true if the file is a child of some registered path, but not equal to it.
+ // This ensures that two shutdown hooks do not try to delete each other's paths,
+ // which would result in an IOException and incomplete cleanup.
+ def hasRootAsShutdownDeleteDir(file: File): Boolean = {
+ val absolutePath = file.getAbsolutePath()
+ val retval = shutdownDeletePaths.synchronized {
+ shutdownDeletePaths.find { path =>
+ !absolutePath.equals(path) && absolutePath.startsWith(path)
+ }.isDefined
+ }
+ if (retval) {
+ logInfo("path = " + file + ", already present as root for deletion.")
+ }
+ retval
+ }
+
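The parent-path check above is what keeps nested temp directories from being deleted twice: if a registered directory contains another registered directory, the inner one's shutdown hook backs off and lets the outer hook handle the recursive delete. A self-contained sketch of the same predicate (paths illustrative):

    import java.io.File

    val registered = collection.mutable.HashSet[String]()

    def hasRegisteredParent(file: File): Boolean = {
      val abs = file.getAbsolutePath
      // True when some *other* registered path is a string prefix of ours.
      // Note that, as in the patch, /tmp/spark-ab would count as a child of
      // /tmp/spark-a; a separator-aware check would be stricter.
      registered.exists(p => abs != p && abs.startsWith(p))
    }

    registered += new File("/tmp/spark-a").getAbsolutePath
    hasRegisteredParent(new File("/tmp/spark-a/work"))  // true: parent hook deletes us
    hasRegisteredParent(new File("/tmp/spark-b"))       // false: delete ourselves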
/** Create a temporary directory inside the given parent directory */
def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = {
var attempts = 0
@@ -81,8 +120,8 @@ private object Utils extends Logging {
while (dir == null) {
attempts += 1
if (attempts > maxAttempts) {
- throw new IOException("Failed to create a temp directory after " + maxAttempts +
- " attempts!")
+ throw new IOException("Failed to create a temp directory (under " + root + ") after " +
+ maxAttempts + " attempts!")
}
try {
dir = new File(root, "spark-" + UUID.randomUUID.toString)
@@ -91,13 +130,17 @@ private object Utils extends Logging {
}
} catch { case e: IOException => ; }
}
+
+ registerShutdownDeleteDir(dir)
+
// Add a shutdown hook to delete the temp dir when the JVM exits
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) {
override def run() {
- Utils.deleteRecursively(dir)
+ // Only attempt the delete if no registered parent path will already delete this directory.
+ if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir)
}
})
- return dir
+ dir
}
/** Copy all data from an InputStream to an OutputStream */
@@ -140,40 +183,35 @@ private object Utils extends Logging {
Utils.copyStream(in, out, true)
if (targetFile.exists && !Files.equal(tempFile, targetFile)) {
tempFile.delete()
- throw new SparkException("File " + targetFile + " exists and does not match contents of" +
- " " + url)
+ throw new SparkException(
+ "File " + targetFile + " exists and does not match contents of" + " " + url)
} else {
Files.move(tempFile, targetFile)
}
case "file" | null =>
- val sourceFile = if (uri.isAbsolute) {
- new File(uri)
- } else {
- new File(url)
- }
- if (targetFile.exists && !Files.equal(sourceFile, targetFile)) {
- throw new SparkException("File " + targetFile + " exists and does not match contents of" +
- " " + url)
- } else {
- // Remove the file if it already exists
- targetFile.delete()
- // Symlink the file locally.
- if (uri.isAbsolute) {
- // url is absolute, i.e. it starts with "file:///". Extract the source
- // file's absolute path from the url.
- val sourceFile = new File(uri)
- logInfo("Symlinking " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath)
- FileUtil.symLink(sourceFile.getAbsolutePath, targetFile.getAbsolutePath)
+ // In the case of a local file, copy it to the target directory.
+ // Note the difference between the uri and url variables here.
+ val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url)
+ if (targetFile.exists) {
+ // If the target file already exists, fail unless its contents match the source file.
+ if (!Files.equal(sourceFile, targetFile)) {
+ throw new SparkException(
+ "File " + targetFile + " exists and does not match contents of" + " " + url)
} else {
- // url is not absolute, i.e. itself is the path to the source file.
- logInfo("Symlinking " + url + " to " + targetFile.getAbsolutePath)
- FileUtil.symLink(url, targetFile.getAbsolutePath)
+ // Do nothing if the file contents are the same, i.e. this file has been copied
+ // previously.
+ logInfo(sourceFile.getAbsolutePath + " has been previously copied to "
+ + targetFile.getAbsolutePath)
}
+ } else {
+ // The file does not exist in the target directory. Copy it there.
+ logInfo("Copying " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath)
+ Files.copy(sourceFile, targetFile)
}
case _ =>
// Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
val uri = new URI(url)
- val conf = new Configuration()
+ val conf = SparkHadoopUtil.newConfiguration()
val fs = FileSystem.get(uri, conf)
val in = fs.open(new Path(uri))
val out = new FileOutputStream(tempFile)
@@ -232,8 +270,10 @@ private object Utils extends Logging {
/**
* Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4).
+ * Note, this is typically not used from within core spark.
*/
lazy val localIpAddress: String = findLocalIpAddress()
+ lazy val localIpAddressHostname: String = getAddressHostName(localIpAddress)
private def findLocalIpAddress(): String = {
val defaultIpOverride = System.getenv("SPARK_LOCAL_IP")
@@ -271,6 +311,8 @@ private object Utils extends Logging {
* hostname it reports to the master.
*/
def setCustomHostname(hostname: String) {
+ // DEBUG code
+ Utils.checkHost(hostname)
customHostname = Some(hostname)
}
@@ -278,7 +320,91 @@ private object Utils extends Logging {
* Get the local machine's hostname.
*/
def localHostName(): String = {
- customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
+ customHostname.getOrElse(localIpAddressHostname)
+ }
+
+ def getAddressHostName(address: String): String = {
+ InetAddress.getByName(address).getHostName
+ }
+
+ def localHostPort(): String = {
+ val retval = System.getProperty("spark.hostPort", null)
+ if (retval == null) {
+ logErrorWithStack("spark.hostPort not set but invoking localHostPort")
+ return localHostName()
+ }
+
+ retval
+ }
+
+/*
+ // Used by DEBUG code: remove when all testing is done
+ private val ipPattern = Pattern.compile("^[0-9]+(\\.[0-9]+)*$")
+ def checkHost(host: String, message: String = "") {
+ // Currently catches only the ipv4 pattern; this is just a debugging tool, not rigorous!
+ // if (host.matches("^[0-9]+(\\.[0-9]+)*$")) {
+ if (ipPattern.matcher(host).matches()) {
+ Utils.logErrorWithStack("Unexpected to have host " + host + " which matches IP pattern. Message " + message)
+ }
+ if (Utils.parseHostPort(host)._2 != 0){
+ Utils.logErrorWithStack("Unexpected to have host " + host + " which has port in it. Message " + message)
+ }
+ }
+
+ // Used by DEBUG code: remove when all testing is done
+ def checkHostPort(hostPort: String, message: String = "") {
+ val (host, port) = Utils.parseHostPort(hostPort)
+ checkHost(host)
+ if (port <= 0){
+ Utils.logErrorWithStack("Unexpected to have port " + port + " which is not valid in " + hostPort + ". Message " + message)
+ }
+ }
+
+ // Used by DEBUG code: remove when all testing is done
+ def logErrorWithStack(msg: String) {
+ try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } }
+ // temp code for debug
+ System.exit(-1)
+ }
+*/
+
+ // Once testing is complete in various modes, these no-op versions can permanently replace the DEBUG code above.
+ def checkHost(host: String, message: String = "") {}
+ def checkHostPort(hostPort: String, message: String = "") {}
+
+ // Used by DEBUG code: remove when all testing is done
+ def logErrorWithStack(msg: String) {
+ try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } }
+ }
+
+ def getUserNameFromEnvironment(): String = {
+ SparkHadoopUtil.getUserNameFromEnvironment
+ }
+
+ // Typically, this will be on the order of the number of nodes in the cluster.
+ // If not, we should change it to an LRU cache or similar.
+ private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]()
+
+ def parseHostPort(hostPort: String): (String, Int) = {
+ {
+ // Check cache first.
+ var cached = hostPortParseResults.get(hostPort)
+ if (cached != null) return cached
+ }
+
+ val indx: Int = hostPort.lastIndexOf(':')
+ // This is potentially broken when dealing with ipv6 addresses, for example,
+ // but then hadoop does not support ipv6 right now anyway.
+ // For now, we assume that if a port is present it is valid; we do not check that it is an int > 0.
+ if (-1 == indx) {
+ val retval = (hostPort, 0)
+ hostPortParseResults.put(hostPort, retval)
+ return retval
+ }
+
+ val retval = (hostPort.substring(0, indx).trim(), hostPort.substring(indx + 1).trim().toInt)
+ hostPortParseResults.putIfAbsent(hostPort, retval)
+ hostPortParseResults.get(hostPort)
}
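parseHostPort caches its results in a ConcurrentHashMap because, per the comment, the key set should stay around the size of the cluster. Its contract follows directly from the code: no colon means port 0; otherwise split on the last colon. A few illustrative inputs (callable as written only from code in the spark package, since Utils is package-private):

    Utils.parseHostPort("node1.example.com:7077")  // ("node1.example.com", 7077)
    Utils.parseHostPort("node1.example.com")       // ("node1.example.com", 0)
    // The ipv6 caveat from the comment, in action: splitting on the last ':'
    // mangles bare ipv6 literals.
    Utils.parseHostPort("::1")                     // (":", 1)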
private[spark] val daemonThreadFactory: ThreadFactory =
@@ -400,13 +526,45 @@ private object Utils extends Logging {
execute(command, new File("."))
}
+ /**
+ * Execute a command and get its output, throwing an exception if it yields a code other than 0.
+ */
+ def executeAndGetOutput(command: Seq[String], workingDir: File = new File(".")): String = {
+ val process = new ProcessBuilder(command: _*)
+ .directory(workingDir)
+ .start()
+ new Thread("read stderr for " + command(0)) {
+ override def run() {
+ for (line <- Source.fromInputStream(process.getErrorStream).getLines) {
+ System.err.println(line)
+ }
+ }
+ }.start()
+ val output = new StringBuffer
+ val stdoutThread = new Thread("read stdout for " + command(0)) {
+ override def run() {
+ for (line <- Source.fromInputStream(process.getInputStream).getLines) {
+ output.append(line)
+ }
+ }
+ }
+ stdoutThread.start()
+ val exitCode = process.waitFor()
+ stdoutThread.join() // Wait for it to finish reading output
+ if (exitCode != 0) {
+ throw new SparkException("Process " + command + " exited with code " + exitCode)
+ }
+ output.toString
+ }
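executeAndGetOutput is what lets ExecutorRunner, later in this diff, shell out to bin/compute-classpath.sh and capture the classpath. A hedged usage sketch (assumes a POSIX echo on the PATH):

    val out = Utils.executeAndGetOutput(Seq("echo", "hello"))
    // stderr is streamed through to ours on a helper thread; stdout lines are
    // appended to a StringBuffer with no separator and returned once the
    // process exits 0, so multi-line output arrives concatenated.
    assert(out == "hello")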
+ private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String,
+ val firstUserLine: Int, val firstUserClass: String)
/**
* When called inside a class in the spark package, returns the name of the user code class
* (outside the spark package) that called into Spark, as well as which Spark method they called.
* This is used, for example, to tell users where in their code each RDD got created.
*/
- def getSparkCallSite: String = {
+ def getCallSiteInfo: CallSiteInfo = {
val trace = Thread.currentThread.getStackTrace().filter( el =>
(!el.getMethodName.contains("getStackTrace")))
@@ -418,6 +576,7 @@ private object Utils extends Logging {
var firstUserFile = "<unknown>"
var firstUserLine = 0
var finished = false
+ var firstUserClass = "<unknown>"
for (el <- trace) {
if (!finished) {
@@ -432,13 +591,19 @@ private object Utils extends Logging {
else {
firstUserLine = el.getLineNumber
firstUserFile = el.getFileName
+ firstUserClass = el.getClassName
finished = true
}
}
}
- "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine)
+ new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass)
}
+ def formatSparkCallSite = {
+ val callSiteInfo = getCallSiteInfo
+ "%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile,
+ callSiteInfo.firstUserLine)
+ }
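getCallSiteInfo keeps the old behavior (walk the stack past spark-package frames to the first user frame) but returns a structured CallSiteInfo, with formatSparkCallSite reproducing the old string on top of it. For hypothetical user code, the rendered call site would look like:

    // MyApp.scala (illustrative), line 12:
    val lines = sc.textFile("hdfs://...")
    // formatSparkCallSite would then render:
    //   "textFile at MyApp.scala:12"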
/**
* Try to find a free port to bind to on the local host. This should ideally never be needed,
* except that, unfortunately, some of the networking libraries we currently rely on (e.g. Spray)
@@ -480,4 +645,67 @@ private object Utils extends Logging {
}
return false
}
+
+ def isSpace(c: Char): Boolean = {
+ " \t\r\n".indexOf(c) != -1
+ }
+
+ /**
+ * Split a string of potentially quoted arguments from the command line the way that a shell
+ * would do it to determine arguments to a command. For example, if the string is 'a "b c" d',
+ * then it would be parsed as three arguments: 'a', 'b c' and 'd'.
+ */
+ def splitCommandString(s: String): Seq[String] = {
+ val buf = new ArrayBuffer[String]
+ var inWord = false
+ var inSingleQuote = false
+ var inDoubleQuote = false
+ var curWord = new StringBuilder
+ def endWord() {
+ buf += curWord.toString
+ curWord.clear()
+ }
+ var i = 0
+ while (i < s.length) {
+ var nextChar = s.charAt(i)
+ if (inDoubleQuote) {
+ if (nextChar == '"') {
+ inDoubleQuote = false
+ } else if (nextChar == '\\') {
+ if (i < s.length - 1) {
+ // Append the next character directly, because only " and \ may be escaped in
+ // double quotes after the shell's own expansion
+ curWord.append(s.charAt(i + 1))
+ i += 1
+ }
+ } else {
+ curWord.append(nextChar)
+ }
+ } else if (inSingleQuote) {
+ if (nextChar == '\'') {
+ inSingleQuote = false
+ } else {
+ curWord.append(nextChar)
+ }
+ // Backslashes are not treated specially in single quotes
+ } else if (nextChar == '"') {
+ inWord = true
+ inDoubleQuote = true
+ } else if (nextChar == '\'') {
+ inWord = true
+ inSingleQuote = true
+ } else if (!isSpace(nextChar)) {
+ curWord.append(nextChar)
+ inWord = true
+ } else if (inWord && isSpace(nextChar)) {
+ endWord()
+ inWord = false
+ }
+ i += 1
+ }
+ if (inWord || inDoubleQuote || inSingleQuote) {
+ endWord()
+ }
+ return buf
+ }
}
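splitCommandString is added so that SPARK_JAVA_OPTS can be split shell-style by ExecutorRunner.buildJavaOpts further down. Some illustrative cases, following the rules in the code (double quotes allow \" and \\ escapes; single quotes are fully literal):

    Utils.splitCommandString("""a "b c" d""")             // Seq("a", "b c", "d")
    Utils.splitCommandString("-Dkey=\"a b\" -verbose:gc") // Seq("-Dkey=a b", "-verbose:gc")
    Utils.splitCommandString("""-Dpath='/a b/c'""")       // Seq("-Dpath=/a b/c")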
diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
index 89c6d05383..0fa8162f3c 100644
--- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
@@ -7,6 +7,7 @@ import scala.Tuple2
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
+import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
@@ -460,6 +461,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass)
}
+ /** Output the RDD to any Hadoop-supported file system, compressing with the supplied codec. */
+ def saveAsHadoopFile[F <: OutputFormat[_, _]](
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[F],
+ codec: Class[_ <: CompressionCodec]) {
+ rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, codec)
+ }
+
/** Output the RDD to any Hadoop-supported file system. */
def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]](
path: String,
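The new JavaPairRDD overload simply forwards to the Scala RDD's saveAsHadoopFile, adding the codec parameter. A hedged sketch of a call (pairRdd, the path, and the key/value types are all illustrative):

    import org.apache.hadoop.io.compress.GzipCodec
    import org.apache.hadoop.io.{IntWritable, Text}
    import org.apache.hadoop.mapred.TextOutputFormat

    pairRdd.saveAsHadoopFile(
      "hdfs://namenode/out",
      classOf[Text],
      classOf[IntWritable],
      classOf[TextOutputFormat[Text, IntWritable]],
      classOf[GzipCodec])   // part files come out gzip-compressed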
diff --git a/core/src/main/scala/spark/api/java/JavaRDD.scala b/core/src/main/scala/spark/api/java/JavaRDD.scala
index 032506383c..6f44e018e9 100644
--- a/core/src/main/scala/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDD.scala
@@ -17,10 +17,16 @@ JavaRDDLike[T, JavaRDD[T]] {
/**
* Set this RDD's storage level to persist its values across operations after the first time
- * it is computed. Can only be called once on each RDD.
+ * it is computed. This can only be used to assign a new storage level if the RDD does not
+ * have a storage level set yet.
*/
def persist(newLevel: StorageLevel): JavaRDD[T] = wrapRDD(rdd.persist(newLevel))
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ */
+ def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist())
+
// Transformations (return a new RDD)
/**
@@ -81,7 +87,6 @@ JavaRDDLike[T, JavaRDD[T]] {
*/
def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] =
wrapRDD(rdd.subtract(other, p))
-
}
object JavaRDD {
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index a6555081b3..3fe2011f4c 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -1,10 +1,11 @@
package spark.api.java
-import java.util.{List => JList}
+import java.util.{List => JList, Comparator}
import scala.Tuple2
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
+import org.apache.hadoop.io.compress.CompressionCodec
import spark.{SparkContext, Partition, RDD, TaskContext}
import spark.api.java.JavaPairRDD._
import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
@@ -183,6 +184,21 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classTag))(classTag, other.classTag)
}
+ /**
+ * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
+ * applying a function to the zipped partitions. Assumes that all the RDDs have the
+ * *same number of partitions*, but does *not* require them to have the same number
+ * of elements in each partition.
+ */
+ def zipPartitions[U, V](
+ f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V],
+ other: JavaRDDLike[U, _]): JavaRDD[V] = {
+ def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator(
+ f.apply(asJavaIterator(x), asJavaIterator(y)).iterator())
+ JavaRDD.fromRDD(
+ rdd.zipPartitions(fn, other.rdd)(other.classTag, f.elementType()))(f.elementType())
+ }
+
// Actions (launch a job to return a value to the user program)
/**
@@ -296,6 +312,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
*/
def saveAsTextFile(path: String) = rdd.saveAsTextFile(path)
+
+ /**
+ * Save this RDD as a compressed text file, using string representations of elements.
+ */
+ def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) =
+ rdd.saveAsTextFile(path, codec)
+
/**
* Save this RDD as a SequenceFile of serialized objects.
*/
@@ -337,4 +360,29 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def toDebugString(): String = {
rdd.toDebugString
}
+
+ /**
+ * Returns the top K elements from this RDD as defined by
+ * the specified Comparator[T].
+ * @param num the number of top elements to return
+ * @param comp the comparator that defines the order
+ * @return an array of top elements
+ */
+ def top(num: Int, comp: Comparator[T]): JList[T] = {
+ import scala.collection.JavaConversions._
+ val topElems = rdd.top(num)(Ordering.comparatorToOrdering(comp))
+ val arr: java.util.Collection[T] = topElems.toSeq
+ new java.util.ArrayList(arr)
+ }
+
+ /**
+ * Returns the top K elements from this RDD using the
+ * natural ordering for T.
+ * @param num the number of top elements to return
+ * @return an array of top elements
+ */
+ def top(num: Int): JList[T] = {
+ val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]]
+ top(num, comp)
+ }
}
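Both top overloads share one code path: top(num) builds a natural-ordering comparator via Guava and delegates to top(num, comp), which wraps the comparator in a Scala Ordering for the underlying RDD. A hedged usage sketch (sc is an assumed JavaSparkContext):

    import java.util.Arrays

    val rdd = sc.parallelize(Arrays.asList(3, 1, 4, 1, 5, 9))
    rdd.top(3)   // the three largest elements by natural ordering, e.g. [9, 5, 4]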
diff --git a/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala
new file mode 100644
index 0000000000..6044043add
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala
@@ -0,0 +1,11 @@
+package spark.api.java.function
+
+/**
+ * A function that takes two inputs and returns zero or more output records.
+ */
+abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] {
+ @throws(classOf[Exception])
+ def call(a: A, b:B) : java.lang.Iterable[C]
+
+ def elementType() : ClassManifest[C] = ClassManifest.Any.asInstanceOf[ClassManifest[C]]
+}
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 220047c360..3d1e45cb2c 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -2,9 +2,10 @@ package spark.api.python
import java.io._
import java.net._
-import java.util.{List => JList, ArrayList => JArrayList, Collections}
+import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
import scala.collection.JavaConversions._
+
import scala.io.Source
import scala.reflect.ClassTag
@@ -17,16 +18,18 @@ import spark.rdd.PipedRDD
private[spark] class PythonRDD[T: ClassTag](
parent: RDD[T],
command: Seq[String],
- envVars: java.util.Map[String, String],
+ envVars: JMap[String, String],
preservePartitoning: Boolean,
pythonExec: String,
broadcastVars: JList[Broadcast[Array[Byte]]],
accumulator: Accumulator[JList[Array[Byte]]])
extends RDD[Array[Byte]](parent) {
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
- def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
+ def this(parent: RDD[T], command: String, envVars: JMap[String, String],
preservePartitoning: Boolean, pythonExec: String,
broadcastVars: JList[Broadcast[Array[Byte]]],
accumulator: Accumulator[JList[Array[Byte]]]) =
@@ -37,68 +40,57 @@ private[spark] class PythonRDD[T: ClassTag](
override val partitioner = if (preservePartitoning) parent.partitioner else None
- override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
- val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
-
- val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py"))
- // Add the environmental variables to the process.
- val currentEnvVars = pb.environment()
- for ((variable, value) <- envVars) {
- currentEnvVars.put(variable, value)
- }
-
- val proc = pb.start()
+ override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
+ val startTime = System.currentTimeMillis
val env = SparkEnv.get
-
- // Start a thread to print the process's stderr to ours
- new Thread("stderr reader for " + pythonExec) {
- override def run() {
- for (line <- Source.fromInputStream(proc.getErrorStream).getLines) {
- System.err.println(line)
- }
- }
- }.start()
+ val worker = env.createPythonWorker(pythonExec, envVars.toMap)
// Start a thread to feed the process input from our parent's iterator
new Thread("stdin writer for " + pythonExec) {
override def run() {
SparkEnv.set(env)
- val out = new PrintWriter(proc.getOutputStream)
- val dOut = new DataOutputStream(proc.getOutputStream)
+ val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
+ val dataOut = new DataOutputStream(stream)
+ val printOut = new PrintWriter(stream)
// Partition index
- dOut.writeInt(split.index)
+ dataOut.writeInt(split.index)
// sparkFilesDir
- PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut)
+ PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
// Broadcast variables
- dOut.writeInt(broadcastVars.length)
+ dataOut.writeInt(broadcastVars.length)
for (broadcast <- broadcastVars) {
- dOut.writeLong(broadcast.id)
- dOut.writeInt(broadcast.value.length)
- dOut.write(broadcast.value)
- dOut.flush()
+ dataOut.writeLong(broadcast.id)
+ dataOut.writeInt(broadcast.value.length)
+ dataOut.write(broadcast.value)
}
+ dataOut.flush()
// Serialized user code
for (elem <- command) {
- out.println(elem)
+ printOut.println(elem)
}
- out.flush()
+ printOut.flush()
// Data values
for (elem <- parent.iterator(split, context)) {
- PythonRDD.writeAsPickle(elem, dOut)
+ PythonRDD.writeAsPickle(elem, dataOut)
}
- dOut.flush()
- out.flush()
- proc.getOutputStream.close()
+ dataOut.flush()
+ printOut.flush()
+ worker.shutdownOutput()
}
}.start()
// Return an iterator that read lines from the process's stdout
- val stream = new DataInputStream(proc.getInputStream)
+ val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
return new Iterator[Array[Byte]] {
def next(): Array[Byte] = {
val obj = _nextObj
- _nextObj = read()
+ if (hasNext) {
+ // FIXME: can deadlock if worker is waiting for us to
+ // respond to the current message (currently irrelevant because
+ // output is shut down before we read any input)
+ _nextObj = read()
+ }
obj
}
@@ -109,6 +101,17 @@ private[spark] class PythonRDD[T: ClassTag](
val obj = new Array[Byte](length)
stream.readFully(obj)
obj
+ case -3 =>
+ // Timing data from worker
+ val bootTime = stream.readLong()
+ val initTime = stream.readLong()
+ val finishTime = stream.readLong()
+ val boot = bootTime - startTime
+ val init = initTime - bootTime
+ val finish = finishTime - initTime
+ val total = finishTime - startTime
+ logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish))
+ read
case -2 =>
// Signals that an exception has been thrown in python
val exLength = stream.readInt()
@@ -116,23 +119,21 @@ private[spark] class PythonRDD[T: ClassTag](
stream.readFully(obj)
throw new PythonException(new String(obj))
case -1 =>
- // We've finished the data section of the output, but we can still read some
- // accumulator updates; let's do that, breaking when we get EOFException
- while (true) {
- val len2 = stream.readInt()
+ // We've finished the data section of the output, but we can still
+ // read some accumulator updates; let's do that, breaking when we
+ // get a negative length record.
+ var len2 = stream.readInt()
+ while (len2 >= 0) {
val update = new Array[Byte](len2)
stream.readFully(update)
accumulator += Collections.singletonList(update)
+ len2 = stream.readInt()
}
new Array[Byte](0)
}
} catch {
case eof: EOFException => {
- val exitStatus = proc.waitFor()
- if (exitStatus != 0) {
- throw new Exception("Subprocess exited with status " + exitStatus)
- }
- new Array[Byte](0)
+ throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
}
case e : Throwable => throw e
}
@@ -160,7 +161,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
override def compute(split: Partition, context: TaskContext) =
prev.iterator(split, context).grouped(2).map {
case Seq(a, b) => (a, b)
- case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
+ case x => throw new SparkException("PairwiseRDD: unexpected value: " + x)
}
val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
}
@@ -216,7 +217,7 @@ private[spark] object PythonRDD {
dOut.write(s)
dOut.writeByte(Pickle.STOP)
} else {
- throw new Exception("Unexpected RDD type")
+ throw new SparkException("Unexpected RDD type")
}
}
@@ -279,6 +280,10 @@ private class BytesToString extends spark.api.java.function.Function[Array[Byte]
class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
extends AccumulatorParam[JList[Array[Byte]]] {
+ Utils.checkHost(serverHost, "Expected hostname")
+
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+
override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
@@ -291,7 +296,7 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
// This happens on the master, where we pass the updates to Python through a socket
val socket = new Socket(serverHost, serverPort)
val in = socket.getInputStream
- val out = new DataOutputStream(socket.getOutputStream)
+ val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
out.writeInt(val2.size)
for (array <- val2) {
out.writeInt(array.length)
diff --git a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala
new file mode 100644
index 0000000000..85d1dfeac8
--- /dev/null
+++ b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala
@@ -0,0 +1,113 @@
+package spark.api.python
+
+import java.io.{DataInputStream, IOException}
+import java.net.{Socket, SocketException, InetAddress}
+
+import scala.collection.JavaConversions._
+
+import spark._
+
+private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
+ extends Logging {
+ var daemon: Process = null
+ val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
+ var daemonPort: Int = 0
+
+ def create(): Socket = {
+ synchronized {
+ // Start the daemon if it hasn't been started
+ startDaemon()
+
+ // Attempt to connect, restart and retry once if it fails
+ try {
+ new Socket(daemonHost, daemonPort)
+ } catch {
+ case exc: SocketException => {
+ logWarning("Python daemon unexpectedly quit, attempting to restart")
+ stopDaemon()
+ startDaemon()
+ new Socket(daemonHost, daemonPort)
+ }
+ case e => throw e
+ }
+ }
+ }
+
+ def stop() {
+ stopDaemon()
+ }
+
+ private def startDaemon() {
+ synchronized {
+ // Is it already running?
+ if (daemon != null) {
+ return
+ }
+
+ try {
+ // Create and start the daemon
+ val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
+ val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py"))
+ val workerEnv = pb.environment()
+ workerEnv.putAll(envVars)
+ daemon = pb.start()
+
+ // Redirect the stderr to ours
+ new Thread("stderr reader for " + pythonExec) {
+ override def run() {
+ scala.util.control.Exception.ignoring(classOf[IOException]) {
+ // FIXME HACK: We copy the stream on the level of bytes to
+ // attempt to dodge encoding problems.
+ val in = daemon.getErrorStream
+ var buf = new Array[Byte](1024)
+ var len = in.read(buf)
+ while (len != -1) {
+ System.err.write(buf, 0, len)
+ len = in.read(buf)
+ }
+ }
+ }
+ }.start()
+
+ val in = new DataInputStream(daemon.getInputStream)
+ daemonPort = in.readInt()
+
+ // Redirect further stdout output to our stderr
+ new Thread("stdout reader for " + pythonExec) {
+ override def run() {
+ scala.util.control.Exception.ignoring(classOf[IOException]) {
+ // FIXME HACK: We copy the stream on the level of bytes to
+ // attempt to dodge encoding problems.
+ var buf = new Array[Byte](1024)
+ var len = in.read(buf)
+ while (len != -1) {
+ System.err.write(buf, 0, len)
+ len = in.read(buf)
+ }
+ }
+ }
+ }.start()
+ } catch {
+ case e => {
+ stopDaemon()
+ throw e
+ }
+ }
+
+ // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly
+ // detect our disappearance.
+ }
+ }
+
+ private def stopDaemon() {
+ synchronized {
+ // Request shutdown of existing daemon by sending SIGTERM
+ if (daemon != null) {
+ daemon.destroy()
+ }
+
+ daemon = null
+ daemonPort = 0
+ }
+ }
+}
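PythonRDD.compute above now asks SparkEnv.createPythonWorker for a socket instead of forking a fresh interpreter per task; this factory is the piece behind that call, lazily forking python/pyspark/daemon.py once and handing out one connection per worker. A hedged lifecycle sketch (exec path and env vars illustrative; requires SPARK_HOME to point at a checkout containing daemon.py):

    val factory = new PythonWorkerFactory("python", Map("PYTHONPATH" -> "/opt/spark/python"))
    val worker: java.net.Socket = factory.create()  // starts the daemon on first use
    // ... speak the PythonRDD stream protocol over `worker` ...
    worker.close()
    factory.stop()  // destroys the daemon process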
diff --git a/core/src/main/scala/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/spark/deploy/ApplicationDescription.scala
index 6659e53b25..02193c7008 100644
--- a/core/src/main/scala/spark/deploy/ApplicationDescription.scala
+++ b/core/src/main/scala/spark/deploy/ApplicationDescription.scala
@@ -2,10 +2,11 @@ package spark.deploy
private[spark] class ApplicationDescription(
val name: String,
- val cores: Int,
+ val maxCores: Int, /* Integer.MAX_VALUE denotes an unlimited number of cores */
val memoryPerSlave: Int,
val command: Command,
- val sparkHome: String)
+ val sparkHome: String,
+ val appUiUrl: String)
extends Serializable {
val user = System.getProperty("user.name", "<unknown>")
diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala
index 8a3e64e4c2..51274acb1e 100644
--- a/core/src/main/scala/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/spark/deploy/DeployMessage.scala
@@ -4,6 +4,7 @@ import spark.deploy.ExecutorState.ExecutorState
import spark.deploy.master.{WorkerInfo, ApplicationInfo}
import spark.deploy.worker.ExecutorRunner
import scala.collection.immutable.List
+import spark.Utils
private[spark] sealed trait DeployMessage extends Serializable
@@ -19,7 +20,10 @@ case class RegisterWorker(
memory: Int,
webUiPort: Int,
publicAddress: String)
- extends DeployMessage
+ extends DeployMessage {
+ Utils.checkHost(host, "Required hostname")
+ assert (port > 0)
+}
private[spark]
case class ExecutorStateChanged(
@@ -58,7 +62,9 @@ private[spark]
case class RegisteredApplication(appId: String) extends DeployMessage
private[spark]
-case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int)
+case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) {
+ Utils.checkHostPort(hostPort, "Required hostport")
+}
private[spark]
case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String],
@@ -81,6 +87,9 @@ private[spark]
case class MasterState(host: String, port: Int, workers: Array[WorkerInfo],
activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) {
+ Utils.checkHost(host, "Required hostname")
+ assert (port > 0)
+
def uri = "spark://" + host + ":" + port
}
@@ -92,4 +101,8 @@ private[spark] case object RequestWorkerState
private[spark]
case class WorkerState(host: String, port: Int, workerId: String, executors: List[ExecutorRunner],
finishedExecutors: List[ExecutorRunner], masterUrl: String, cores: Int, memory: Int,
- coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String)
+ coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) {
+
+ Utils.checkHost(host, "Required hostname")
+ assert (port > 0)
+}
diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala
index 702defb628..88b03a007c 100644
--- a/core/src/main/scala/spark/deploy/JsonProtocol.scala
+++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala
@@ -12,6 +12,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
def write(obj: WorkerInfo) = JsObject(
"id" -> JsString(obj.id),
"host" -> JsString(obj.host),
+ "port" -> JsNumber(obj.port),
"webuiaddress" -> JsString(obj.webUiAddress),
"cores" -> JsNumber(obj.cores),
"coresused" -> JsNumber(obj.coresUsed),
@@ -25,7 +26,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
"starttime" -> JsNumber(obj.startTime),
"id" -> JsString(obj.id),
"name" -> JsString(obj.desc.name),
- "cores" -> JsNumber(obj.desc.cores),
+ "cores" -> JsNumber(obj.desc.maxCores),
"user" -> JsString(obj.desc.user),
"memoryperslave" -> JsNumber(obj.desc.memoryPerSlave),
"submitdate" -> JsString(obj.submitDate.toString))
@@ -34,7 +35,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
implicit object AppDescriptionJsonFormat extends RootJsonWriter[ApplicationDescription] {
def write(obj: ApplicationDescription) = JsObject(
"name" -> JsString(obj.name),
- "cores" -> JsNumber(obj.cores),
+ "cores" -> JsNumber(obj.maxCores),
"memoryperslave" -> JsNumber(obj.memoryPerSlave),
"user" -> JsString(obj.user)
)
diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
index 6abaaeaa3f..2b0b3b10e7 100644
--- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
@@ -18,7 +18,7 @@ import scala.collection.mutable.ArrayBuffer
private[spark]
class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging {
- private val localIpAddress = Utils.localIpAddress
+ private val localHostname = Utils.localHostName()
private val masterActorSystems = ArrayBuffer[ActorSystem]()
private val workerActorSystems = ArrayBuffer[ActorSystem]()
@@ -26,13 +26,13 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
/* Start the Master */
- val (masterSystem, masterPort) = Master.startSystemAndActor(localIpAddress, 0, 0)
+ val (masterSystem, masterPort) = Master.startSystemAndActor(localHostname, 0, 0)
masterActorSystems += masterSystem
- val masterUrl = "spark://" + localIpAddress + ":" + masterPort
+ val masterUrl = "spark://" + localHostname + ":" + masterPort
/* Start the Workers */
for (workerNum <- 1 to numWorkers) {
- val (workerSystem, _) = Worker.startSystemAndActor(localIpAddress, 0, 0, coresPerWorker,
+ val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker,
memoryPerWorker, masterUrl, null, Some(workerNum))
workerActorSystems += workerSystem
}
diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala
index a38218a391..690bb20e50 100644
--- a/core/src/main/scala/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/spark/deploy/client/Client.scala
@@ -4,6 +4,7 @@ import spark.deploy._
import akka.actor._
import akka.pattern.ask
import scala.concurrent.duration._
+
import akka.pattern.AskTimeoutException
import spark.{SparkException, Logging}
import akka.remote.RemoteClientLifeCycleEvent
@@ -59,10 +60,10 @@ private[spark] class Client(
markDisconnected()
context.stop(self)
- case ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) =>
+ case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) =>
val fullId = appId + "/" + id
- logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, host, cores))
- listener.executorAdded(fullId, workerId, host, cores, memory)
+ logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores))
+ listener.executorAdded(fullId, workerId, hostPort, cores, memory)
case ExecutorUpdated(id, state, message, exitStatus) =>
val fullId = appId + "/" + id
@@ -112,7 +113,7 @@ private[spark] class Client(
def stop() {
if (actor != null) {
try {
- val timeout = 5.seconds
+ val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
val future = actor.ask(StopClient)(timeout)
Await.result(future, timeout)
} catch {
diff --git a/core/src/main/scala/spark/deploy/client/ClientListener.scala b/core/src/main/scala/spark/deploy/client/ClientListener.scala
index b7008321df..e8c4083f9d 100644
--- a/core/src/main/scala/spark/deploy/client/ClientListener.scala
+++ b/core/src/main/scala/spark/deploy/client/ClientListener.scala
@@ -12,7 +12,7 @@ private[spark] trait ClientListener {
def disconnected(): Unit
- def executorAdded(fullId: String, workerId: String, host: String, cores: Int, memory: Int): Unit
+ def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit
def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit
}
diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala
index dc004b59ca..f195082808 100644
--- a/core/src/main/scala/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/spark/deploy/client/TestClient.scala
@@ -16,7 +16,7 @@ private[spark] object TestClient {
System.exit(0)
}
- def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) {}
+ def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {}
def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {}
}
@@ -25,7 +25,7 @@ private[spark] object TestClient {
val url = args(0)
val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0)
val desc = new ApplicationDescription(
- "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home")
+ "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "ignored")
val listener = new TestListener
val client = new Client(actorSystem, url, desc, listener)
client.start()
diff --git a/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala
index 3591a94072..785c16e2be 100644
--- a/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala
@@ -10,7 +10,8 @@ private[spark] class ApplicationInfo(
val id: String,
val desc: ApplicationDescription,
val submitDate: Date,
- val driver: ActorRef)
+ val driver: ActorRef,
+ val appUiUrl: String)
{
var state = ApplicationState.WAITING
var executors = new mutable.HashMap[Int, ExecutorInfo]
@@ -37,7 +38,7 @@ private[spark] class ApplicationInfo(
coresGranted -= exec.cores
}
- def coresLeft: Int = desc.cores - coresGranted
+ def coresLeft: Int = desc.maxCores - coresGranted
private var _retryCount = 0
@@ -60,4 +61,5 @@ private[spark] class ApplicationInfo(
System.currentTimeMillis() - startTime
}
}
+
}
diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala
index d1428bcfc6..770cfe9d05 100644
--- a/core/src/main/scala/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/spark/deploy/master/Master.scala
@@ -15,7 +15,7 @@ import spark.{Logging, SparkException, Utils}
import spark.util.AkkaUtils
-private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
+private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000
@@ -35,9 +35,11 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
var firstApp: Option[ApplicationInfo] = None
+ Utils.checkHost(host, "Expected hostname")
+
val masterPublicAddress = {
val envVar = System.getenv("SPARK_PUBLIC_DNS")
- if (envVar != null) envVar else ip
+ if (envVar != null) envVar else host
}
// As a temporary workaround before better ways of configuring memory, we allow users to set
@@ -46,7 +48,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean
override def preStart() {
- logInfo("Starting Spark master at spark://" + ip + ":" + port)
+ logInfo("Starting Spark master at spark://" + host + ":" + port)
// Listen for remote client disconnection events, since they don't go through Akka's watch()
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
startWebUi()
@@ -146,7 +148,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
}
case RequestMasterState => {
- sender ! MasterState(ip, port, workers.toArray, apps.toArray, completedApps.toArray)
+ sender ! MasterState(host, port, workers.toArray, apps.toArray, completedApps.toArray)
}
}
@@ -212,13 +214,13 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
worker.addExecutor(exec)
worker.actor ! LaunchExecutor(exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome)
- exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory)
+ exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)
}
def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int,
publicAddress: String): WorkerInfo = {
// There may be one or more refs to dead workers on this same node (w/ different ID's), remove them.
- workers.filter(w => (w.host == host) && (w.state == WorkerState.DEAD)).foreach(workers -= _)
+ workers.filter(w => (w.host == host && w.port == port) && (w.state == WorkerState.DEAD)).foreach(workers -= _)
val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
workers += worker
idToWorker(worker.id) = worker
@@ -243,7 +245,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
def addApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = {
val now = System.currentTimeMillis()
val date = new Date(now)
- val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver)
+ val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl)
apps += app
idToApp(app.id) = app
actorToApp(driver) = app
@@ -274,6 +276,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
for (exec <- app.executors.values) {
exec.worker.removeExecutor(exec)
exec.worker.actor ! KillExecutor(exec.application.id, exec.id)
+ exec.state = ExecutorState.KILLED
}
app.markFinished(state)
app.driver ! ApplicationRemoved(state.toString)
@@ -308,7 +311,7 @@ private[spark] object Master {
def main(argStrings: Array[String]) {
val args = new MasterArguments(argStrings)
- val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort)
+ val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort)
actorSystem.awaitTermination()
}
diff --git a/core/src/main/scala/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/spark/deploy/master/MasterArguments.scala
index 4ceab3fc03..3d28ecabb4 100644
--- a/core/src/main/scala/spark/deploy/master/MasterArguments.scala
+++ b/core/src/main/scala/spark/deploy/master/MasterArguments.scala
@@ -7,13 +7,13 @@ import spark.Utils
* Command-line parser for the master.
*/
private[spark] class MasterArguments(args: Array[String]) {
- var ip = Utils.localHostName()
+ var host = Utils.localHostName()
var port = 7077
var webUiPort = 8080
// Check for settings in environment variables
- if (System.getenv("SPARK_MASTER_IP") != null) {
- ip = System.getenv("SPARK_MASTER_IP")
+ if (System.getenv("SPARK_MASTER_HOST") != null) {
+ host = System.getenv("SPARK_MASTER_HOST")
}
if (System.getenv("SPARK_MASTER_PORT") != null) {
port = System.getenv("SPARK_MASTER_PORT").toInt
@@ -26,7 +26,13 @@ private[spark] class MasterArguments(args: Array[String]) {
def parse(args: List[String]): Unit = args match {
case ("--ip" | "-i") :: value :: tail =>
- ip = value
+ Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
+ host = value
+ parse(tail)
+
+ case ("--host" | "-h") :: value :: tail =>
+ Utils.checkHost(value, "Please use hostname " + value)
+ host = value
parse(tail)
case ("--port" | "-p") :: IntParam(value) :: tail =>
@@ -54,7 +60,8 @@ private[spark] class MasterArguments(args: Array[String]) {
"Usage: Master [options]\n" +
"\n" +
"Options:\n" +
- " -i IP, --ip IP IP address or DNS name to listen on\n" +
+ " -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) \n" +
+ " -h HOST, --host HOST Hostname to listen on\n" +
" -p PORT, --port PORT Port to listen on (default: 7077)\n" +
" --webui-port PORT Port for web UI (default: 8080)")
System.exit(exitCode)
diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
index fe859d48c3..34cee87853 100644
--- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
+++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
@@ -3,6 +3,7 @@ package spark.deploy.master
import akka.actor.{ActorRef, ActorContext, ActorRefFactory}
import scala.concurrent.Await
import akka.pattern.ask
+
import akka.util.Timeout
import scala.concurrent.duration._
import spray.routing.Directives
@@ -25,8 +26,7 @@ class MasterWebUI(master: ActorRef)(implicit val context: ActorContext) extends
val RESOURCE_DIR = "spark/deploy/master/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static"
- implicit val timeout = Timeout(10 seconds)
-
+ implicit val timeout = Timeout(Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds"))
val handler = {
get {
diff --git a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
index 23df1bb463..0c08c5f417 100644
--- a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
@@ -2,6 +2,7 @@ package spark.deploy.master
import akka.actor.ActorRef
import scala.collection.mutable
+import spark.Utils
private[spark] class WorkerInfo(
val id: String,
@@ -13,6 +14,9 @@ private[spark] class WorkerInfo(
val webUiPort: Int,
val publicAddress: String) {
+ Utils.checkHost(host, "Expected hostname")
+ assert (port > 0)
+
var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info
var state: WorkerState.Value = WorkerState.ALIVE
var coresUsed = 0
@@ -23,6 +27,11 @@ private[spark] class WorkerInfo(
def coresFree: Int = cores - coresUsed
def memoryFree: Int = memory - memoryUsed
+ def hostPort: String = {
+ assert (port > 0)
+ host + ":" + port
+ }
+
def addExecutor(exec: ExecutorInfo) {
executors(exec.fullId) = exec
coresUsed += exec.cores
diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
index de11771c8e..d7f58b2cb1 100644
--- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
@@ -1,6 +1,7 @@
package spark.deploy.worker
import java.io._
+import java.lang.System.getenv
import spark.deploy.{ExecutorState, ExecutorStateChanged, ApplicationDescription}
import akka.actor.ActorRef
import spark.{Utils, Logging}
@@ -21,11 +22,13 @@ private[spark] class ExecutorRunner(
val memory: Int,
val worker: ActorRef,
val workerId: String,
- val hostname: String,
+ val hostPort: String,
val sparkHome: File,
val workDir: File)
extends Logging {
+ Utils.checkHostPort(hostPort, "Expected hostport")
+
val fullId = appId + "/" + execId
var workerThread: Thread = null
var process: Process = null
@@ -38,7 +41,7 @@ private[spark] class ExecutorRunner(
workerThread.start()
// Shutdown hook that kills actors on shutdown.
- shutdownHook = new Thread() {
+ shutdownHook = new Thread() {
override def run() {
if (process != null) {
logInfo("Shutdown hook killing child process.")
@@ -68,16 +71,36 @@ private[spark] class ExecutorRunner(
/** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */
def substituteVariables(argument: String): String = argument match {
case "{{EXECUTOR_ID}}" => execId.toString
- case "{{HOSTNAME}}" => hostname
+ case "{{HOSTNAME}}" => Utils.parseHostPort(hostPort)._1
case "{{CORES}}" => cores.toString
case other => other
}
def buildCommandSeq(): Seq[String] = {
val command = appDesc.command
- val script = if (System.getProperty("os.name").startsWith("Windows")) "run.cmd" else "run"
- val runScript = new File(sparkHome, script).getCanonicalPath
- Seq(runScript, command.mainClass) ++ (command.arguments ++ Seq(appId)).map(substituteVariables)
+ val runner = Option(getenv("JAVA_HOME")).map(_ + "/bin/java").getOrElse("java")
+ // SPARK-698: do not call the run.cmd script, as process.destroy()
+ // fails to kill a process tree on Windows
+ Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++
+ command.arguments.map(substituteVariables)
+ }
+
+ /**
+ * Attention: this must always be aligned with the environment variables in the run scripts and
+ * the way the JAVA_OPTS are assembled there.
+ */
+ def buildJavaOpts(): Seq[String] = {
+ val libraryOpts = Option(getenv("SPARK_LIBRARY_PATH"))
+ .map(p => List("-Djava.library.path=" + p))
+ .getOrElse(Nil)
+ val userOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString).getOrElse(Nil)
+ val memoryOpts = Seq("-Xms" + memory + "M", "-Xmx" + memory + "M")
+
+ // Figure out our classpath with the external compute-classpath script
+ val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh"
+ val classPath = Utils.executeAndGetOutput(Seq(sparkHome + "/bin/compute-classpath" + ext))
+
+ Seq("-cp", classPath) ++ libraryOpts ++ userOpts ++ memoryOpts
}
/** Spawn a thread that will redirect a given stream to a file */
@@ -113,7 +136,6 @@ private[spark] class ExecutorRunner(
for ((key, value) <- appDesc.command.environment) {
env.put(key, value)
}
- env.put("SPARK_MEM", memory.toString + "m")
// In case we are running this from within the Spark Shell, avoid creating a "scala"
// parent process for the executor command
env.put("SPARK_LAUNCH_WITH_SCALA", "0")
diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala
index 5bcf00443c..b5dfd16e67 100644
--- a/core/src/main/scala/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/spark/deploy/worker/Worker.scala
@@ -17,7 +17,7 @@ import java.io.File
private[spark] class Worker(
- ip: String,
+ host: String,
port: Int,
webUiPort: Int,
cores: Int,
@@ -26,6 +26,9 @@ private[spark] class Worker(
workDirPath: String = null)
extends Actor with Logging {
+ Utils.checkHost(host, "Expected hostname")
+ assert (port > 0)
+
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs
// Send a heartbeat every (heartbeat timeout) / 4 milliseconds
@@ -40,7 +43,7 @@ private[spark] class Worker(
val finishedExecutors = new HashMap[String, ExecutorRunner]
val publicAddress = {
val envVar = System.getenv("SPARK_PUBLIC_DNS")
- if (envVar != null) envVar else ip
+ if (envVar != null) envVar else host
}
var coresUsed = 0
@@ -52,10 +55,14 @@ private[spark] class Worker(
def createWorkDir() {
workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work"))
try {
- if (!workDir.exists() && !workDir.mkdirs()) {
+ // Checking !workDir.exists() && !workDir.mkdirs() sporadically fails; we are not sure why.
+ // So we attempt to create the directory, then check whether it was actually created.
+ workDir.mkdirs()
+ if ( !workDir.exists() || !workDir.isDirectory) {
logError("Failed to create work directory " + workDir)
System.exit(1)
}
+ assert (workDir.isDirectory)
} catch {
case e: Exception =>
logError("Failed to create work directory " + workDir, e)
@@ -65,7 +72,7 @@ private[spark] class Worker(
override def preStart() {
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
- ip, port, cores, Utils.memoryMegabytesToString(memory)))
+ host, port, cores, Utils.memoryMegabytesToString(memory)))
sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse("."))
logInfo("Spark home: " + sparkHome)
createWorkDir()
@@ -76,7 +83,7 @@ private[spark] class Worker(
def connectToMaster() {
logInfo("Connecting to master " + masterUrl)
master = context.actorFor(Master.toAkkaUrl(masterUrl))
- master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
+ master ! RegisterWorker(workerId, host, port, cores, memory, webUiPort, publicAddress)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.watch(master) // Doesn't work with remote actors, but useful for testing
}
@@ -108,7 +115,7 @@ private[spark] class Worker(
case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) =>
logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
val manager = new ExecutorRunner(
- appId, execId, appDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir)
+ appId, execId, appDesc, cores_, memory_, self, workerId, host + ":" + port, new File(execSparkHome_), workDir)
executors(appId + "/" + execId) = manager
manager.start()
coresUsed += cores_
@@ -143,7 +150,7 @@ private[spark] class Worker(
masterDisconnected()
case RequestWorkerState => {
- sender ! WorkerState(ip, port, workerId, executors.values.toList,
+ sender ! WorkerState(host, port, workerId, executors.values.toList,
finishedExecutors.values.toList, masterUrl, cores, memory,
coresUsed, memoryUsed, masterWebUiUrl)
}
@@ -158,7 +165,7 @@ private[spark] class Worker(
}
def generateWorkerId(): String = {
- "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), ip, port)
+ "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), host, port)
}
override def postStop() {
@@ -169,7 +176,7 @@ private[spark] class Worker(
private[spark] object Worker {
def main(argStrings: Array[String]) {
val args = new WorkerArguments(argStrings)
- val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort, args.cores,
+ val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores,
args.memory, args.master, args.workDir)
actorSystem.awaitTermination()
}
diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
index 08f02bad80..2b96611ee3 100644
--- a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
@@ -9,7 +9,7 @@ import java.lang.management.ManagementFactory
* Command-line parser for the master.
*/
private[spark] class WorkerArguments(args: Array[String]) {
- var ip = Utils.localHostName()
+ var host = Utils.localHostName()
var port = 0
var webUiPort = 8081
var cores = inferDefaultCores()
@@ -38,7 +38,13 @@ private[spark] class WorkerArguments(args: Array[String]) {
def parse(args: List[String]): Unit = args match {
case ("--ip" | "-i") :: value :: tail =>
- ip = value
+ Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
+ host = value
+ parse(tail)
+
+ case ("--host" | "-h") :: value :: tail =>
+ Utils.checkHost(value, "Please use hostname " + value)
+ host = value
parse(tail)
case ("--port" | "-p") :: IntParam(value) :: tail =>
@@ -93,7 +99,8 @@ private[spark] class WorkerArguments(args: Array[String]) {
" -c CORES, --cores CORES Number of cores to use\n" +
" -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G)\n" +
" -d DIR, --work-dir DIR Directory to run apps in (default: SPARK_HOME/work)\n" +
- " -i IP, --ip IP IP address or DNS name to listen on\n" +
+ " -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h)\n" +
+ " -h HOST, --host HOST Hostname to listen on\n" +
" -p PORT, --port PORT Port to listen on (default: random)\n" +
" --webui-port PORT Port for web UI (default: 8081)")
System.exit(exitCode)
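
A quick sketch of the new flag handling, with hypothetical host and port values (the parser runs from the WorkerArguments constructor; --ip still parses but now insists on a bare hostname via Utils.checkHost):

    val args = new WorkerArguments(Array("--host", "worker1.example.com", "--port", "7078"))
    assert(args.host == "worker1.example.com")
    assert(args.port == 7078)
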
diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
index 33a2a9516e..cc2ab6187a 100644
--- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
+++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
@@ -3,6 +3,7 @@ package spark.deploy.worker
import akka.actor.{ActorRef, ActorContext}
import scala.concurrent.Await
import akka.pattern.ask
+
import akka.util.Timeout
import scala.concurrent.duration._
import spray.routing.Directives
@@ -25,7 +26,7 @@ class WorkerWebUI(worker: ActorRef, workDir: File)(implicit val context: ActorCo
val RESOURCE_DIR = "spark/deploy/worker/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static"
- implicit val timeout = Timeout(10 seconds)
+ implicit val timeout = Timeout(Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds"))
val handler = {
get {
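
The formerly hard-coded 10 second ask timeout is now read from a system property, interpreted as seconds. A sketch of overriding it; the property must be set before the WorkerWebUI is constructed, since the timeout is captured once at initialization:

    System.setProperty("spark.akka.askTimeout", "30")   // hypothetical 30-second timeout
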
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index 3e7407b58d..2bf55ea9a9 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -17,7 +17,7 @@ import java.nio.ByteBuffer
* The Mesos executor for Spark.
*/
private[spark] class Executor(executorId: String, slaveHostname: String, properties: Seq[(String, String)]) extends Logging {
-
+
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
@@ -27,6 +27,11 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
initLogging()
+ // No ip or host:port - just hostname
+ Utils.checkHost(slaveHostname, "Expected executor slave to be a hostname")
+ // Must not have a port specified.
+ assert (0 == Utils.parseHostPort(slaveHostname)._2)
+
// Make sure the local hostname we report matches the cluster scheduler's name for this host
Utils.setCustomHostname(slaveHostname)
@@ -37,7 +42,8 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
// Create our ClassLoader and set it on this thread
private val urlClassLoader = createClassLoader()
- Thread.currentThread.setContextClassLoader(urlClassLoader)
+ private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
+ Thread.currentThread.setContextClassLoader(replClassLoader)
// Make any thread terminations due to uncaught exceptions kill the entire
// executor process to avoid surprising stalls.
@@ -67,6 +73,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
// Initialize Spark environment (using system properties read above)
val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false)
SparkEnv.set(env)
+ private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size")
// Start worker thread pool
val threadPool = new ThreadPoolExecutor(
@@ -82,7 +89,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
override def run() {
val startTime = System.currentTimeMillis()
SparkEnv.set(env)
- Thread.currentThread.setContextClassLoader(urlClassLoader)
+ Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo("Running task ID " + taskId)
context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
@@ -98,6 +105,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
val value = task.run(taskId.toInt)
val taskFinish = System.currentTimeMillis()
task.metrics.foreach{ m =>
+ m.hostname = Utils.localHostName
m.executorDeserializeTime = (taskStart - startTime).toInt
m.executorRunTime = (taskFinish - taskStart).toInt
}
@@ -108,6 +116,10 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
val result = new TaskResult(value, accumUpdates, task.metrics.getOrElse(null))
val serializedResult = ser.serialize(result)
logInfo("Serialized size of result for " + taskId + " is " + serializedResult.limit)
+ if (serializedResult.limit >= (akkaFrameSize - 1024)) {
+ context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(TaskResultTooBigFailure()))
+ return
+ }
context.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
logInfo("Finished task ID " + taskId)
} catch {
@@ -117,7 +129,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
}
case t: Throwable => {
- val reason = ExceptionFailure(t)
+ val reason = ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace)
context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
// TODO: Should we exit the whole executor here? On the one hand, the failed task may
@@ -142,26 +154,31 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
val urls = currentJars.keySet.map { uri =>
new File(uri.split("/").last).toURI.toURL
}.toArray
- loader = new URLClassLoader(urls, loader)
+ new ExecutorURLClassLoader(urls, loader)
+ }
- // If the REPL is in use, add another ClassLoader that will read
- // new classes defined by the REPL as the user types code
+ /**
+ * If the REPL is in use, add another ClassLoader that will read
+ * new classes defined by the REPL as the user types code
+ */
+ private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = {
val classUri = System.getProperty("spark.repl.class.uri")
if (classUri != null) {
logInfo("Using REPL class URI: " + classUri)
- loader = {
- try {
- val klass = Class.forName("spark.repl.ExecutorClassLoader")
- .asInstanceOf[Class[_ <: ClassLoader]]
- val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
- constructor.newInstance(classUri, loader)
- } catch {
- case _: ClassNotFoundException => loader
- }
+ try {
+ val klass = Class.forName("spark.repl.ExecutorClassLoader")
+ .asInstanceOf[Class[_ <: ClassLoader]]
+ val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
+ return constructor.newInstance(classUri, parent)
+ } catch {
+ case _: ClassNotFoundException =>
+ logError("Could not find spark.repl.ExecutorClassLoader on classpath!")
+ System.exit(1)
+ null
}
+ } else {
+ return parent
}
-
- return new ExecutorURLClassLoader(Array(), loader)
}
/**
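
The new result-size guard above compares the serialized TaskResult against the Akka frame size, leaving 1024 bytes of headroom for the message envelope. A sketch with hypothetical sizes:

    // akkaFrameSize is read from "akka.remote.netty.message-frame-size" above.
    val akkaFrameSize = 10L * 1024 * 1024          // assume a 10 MB frame
    val serializedResultSize = 11L * 1024 * 1024   // hypothetical oversized result
    if (serializedResultSize >= akkaFrameSize - 1024) {
      // Too big to ship over Akka: report TaskResultTooBigFailure instead of
      // attempting (and failing) the send.
    }
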
diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
index 1047f71c6a..ebe2ac68d8 100644
--- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
+++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
@@ -12,23 +12,27 @@ import spark.scheduler.cluster.RegisteredExecutor
import spark.scheduler.cluster.LaunchTask
import spark.scheduler.cluster.RegisterExecutorFailed
import spark.scheduler.cluster.RegisterExecutor
+import spark.Utils
+import spark.deploy.SparkHadoopUtil
private[spark] class StandaloneExecutorBackend(
driverUrl: String,
executorId: String,
- hostname: String,
+ hostPort: String,
cores: Int)
extends Actor
with ExecutorBackend
with Logging {
+ Utils.checkHostPort(hostPort, "Expected hostport")
+
var executor: Executor = null
var driver: ActorRef = null
override def preStart() {
logInfo("Connecting to driver: " + driverUrl)
driver = context.actorFor(driverUrl)
- driver ! RegisterExecutor(executorId, hostname, cores)
+ driver ! RegisterExecutor(executorId, hostPort, cores)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.watch(driver) // Doesn't work with remote actors, but useful for testing
}
@@ -36,7 +40,8 @@ private[spark] class StandaloneExecutorBackend(
override def receive = {
case RegisteredExecutor(sparkProperties) =>
logInfo("Successfully registered with driver")
- executor = new Executor(executorId, hostname, sparkProperties)
+ // Make this host instead of hostPort?
+ executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties)
case RegisterExecutorFailed(message) =>
logError("Slave registration failed: " + message)
@@ -63,11 +68,30 @@ private[spark] class StandaloneExecutorBackend(
private[spark] object StandaloneExecutorBackend {
def run(driverUrl: String, executorId: String, hostname: String, cores: Int) {
+ SparkHadoopUtil.runAsUser(run0, Tuple4[Any, Any, Any, Any](driverUrl, executorId, hostname, cores))
+ }
+
+ // This will be run 'as' the user
+ def run0(args: Product) {
+ assert(4 == args.productArity)
+ runImpl(args.productElement(0).asInstanceOf[String],
+ args.productElement(1).asInstanceOf[String],
+ args.productElement(2).asInstanceOf[String],
+ args.productElement(3).asInstanceOf[Int])
+ }
+
+ private def runImpl(driverUrl: String, executorId: String, hostname: String, cores: Int) {
+ // Debug code
+ Utils.checkHost(hostname)
+
// Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
// before getting started with all our system properties, etc
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0)
+ // Set spark.hostPort so that other components can see the actual bound port
+ val sparkHostPort = hostname + ":" + boundPort
+ System.setProperty("spark.hostPort", sparkHostPort)
val actor = actorSystem.actorOf(
- Props(new StandaloneExecutorBackend(driverUrl, executorId, hostname, cores)),
+ Props(new StandaloneExecutorBackend(driverUrl, executorId, sparkHostPort, cores)),
name = "Executor")
actorSystem.awaitTermination()
}
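
runImpl above publishes the actual bound port under the spark.hostPort system property. A sketch of how a downstream component might consume it (values hypothetical):

    val hostPort = System.getProperty("spark.hostPort")   // e.g. "worker1:43210"
    val (host, port) = Utils.parseHostPort(hostPort)      // ("worker1", 43210)
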
diff --git a/core/src/main/scala/spark/executor/TaskMetrics.scala b/core/src/main/scala/spark/executor/TaskMetrics.scala
index 93bbb6b458..1dc13754f9 100644
--- a/core/src/main/scala/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/spark/executor/TaskMetrics.scala
@@ -2,6 +2,11 @@ package spark.executor
class TaskMetrics extends Serializable {
/**
+ * Hostname the task runs on
+ */
+ var hostname: String = _
+
+ /**
* Time taken on the executor to deserialize this task
*/
var executorDeserializeTime: Int = _
@@ -34,9 +39,14 @@ object TaskMetrics {
class ShuffleReadMetrics extends Serializable {
/**
+ * Time when the shuffle finishes
+ */
+ var shuffleFinishTime: Long = _
+
+ /**
* Total number of blocks fetched in a shuffle (remote or local)
*/
- var totalBlocksFetched : Int = _
+ var totalBlocksFetched: Int = _
/**
* Number of remote blocks fetched in a shuffle
@@ -49,11 +59,6 @@ class ShuffleReadMetrics extends Serializable {
var localBlocksFetched: Int = _
/**
- * Total time to read shuffle data
- */
- var shuffleReadMillis: Long = _
-
- /**
* Total time that is spent blocked waiting for shuffle to fetch data
*/
var fetchWaitTime: Long = _
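
For reference, a sketch of how the executor fills the new hostname field (mirroring Executor.run above; the timing values are hypothetical):

    val metrics = new TaskMetrics
    metrics.hostname = Utils.localHostName
    metrics.executorDeserializeTime = 12    // hypothetical millis
    metrics.executorRunTime = 340           // hypothetical millis
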
diff --git a/core/src/main/scala/spark/network/BufferMessage.scala b/core/src/main/scala/spark/network/BufferMessage.scala
new file mode 100644
index 0000000000..7b0e489a6c
--- /dev/null
+++ b/core/src/main/scala/spark/network/BufferMessage.scala
@@ -0,0 +1,94 @@
+package spark.network
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+import spark.storage.BlockManager
+
+
+private[spark]
+class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
+ extends Message(Message.BUFFER_MESSAGE, id_) {
+
+ val initialSize = currentSize()
+ var gotChunkForSendingOnce = false
+
+ def size = initialSize
+
+ def currentSize() = {
+ if (buffers == null || buffers.isEmpty) {
+ 0
+ } else {
+ buffers.map(_.remaining).reduceLeft(_ + _)
+ }
+ }
+
+ def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
+ if (maxChunkSize <= 0) {
+ throw new Exception("Max chunk size is " + maxChunkSize)
+ }
+
+ if (size == 0 && gotChunkForSendingOnce == false) {
+ val newChunk = new MessageChunk(
+ new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
+ gotChunkForSendingOnce = true
+ return Some(newChunk)
+ }
+
+ while(!buffers.isEmpty) {
+ val buffer = buffers(0)
+ if (buffer.remaining == 0) {
+ BlockManager.dispose(buffer)
+ buffers -= buffer
+ } else {
+ val newBuffer = if (buffer.remaining <= maxChunkSize) {
+ buffer.duplicate()
+ } else {
+ buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
+ }
+ buffer.position(buffer.position + newBuffer.remaining)
+ val newChunk = new MessageChunk(new MessageChunkHeader(
+ typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ gotChunkForSendingOnce = true
+ return Some(newChunk)
+ }
+ }
+ None
+ }
+
+ def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
+ // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
+ if (buffers.size > 1) {
+ throw new Exception("Attempting to get chunk from message with multiple data buffers")
+ }
+ val buffer = buffers(0)
+ if (buffer.remaining > 0) {
+ if (buffer.remaining < chunkSize) {
+ throw new Exception("Not enough space in data buffer for receiving chunk")
+ }
+ val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
+ buffer.position(buffer.position + newBuffer.remaining)
+ val newChunk = new MessageChunk(new MessageChunkHeader(
+ typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ return Some(newChunk)
+ }
+ None
+ }
+
+ def flip() {
+ buffers.foreach(_.flip)
+ }
+
+ def hasAckId() = (ackId != 0)
+
+ def isCompletelyReceived() = !buffers(0).hasRemaining
+
+ override def toString = {
+ if (hasAckId) {
+ "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
+ } else {
+ "BufferMessage(id = " + id + ", size = " + size + ")"
+ }
+ }
+} \ No newline at end of file
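
A sketch of how a sender drains a BufferMessage into bounded chunks via getChunkForSending. Sizes are hypothetical; it assumes a Message.createBufferMessage(ByteBuffer) factory (an ackId overload is used elsewhere in this patch), and senderAddress must be set, as ConnectionManager does, before chunk headers go out:

    import java.net.InetSocketAddress
    import java.nio.ByteBuffer

    val payload = ByteBuffer.allocate(150 * 1024)          // 150 KB hypothetical payload
    val msg = Message.createBufferMessage(payload)         // assumed factory overload
    msg.senderAddress = new InetSocketAddress("worker1.example.com", 43210)

    var chunk = msg.getChunkForSending(64 * 1024)          // 64 KB max chunk size
    while (chunk.isDefined) {
      // write chunk.get.buffers to the channel here ...
      chunk = msg.getChunkForSending(64 * 1024)
    }
    // Here this yields 64 KB + 64 KB + 22 KB chunks, then None.
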
diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala
index d1451bc212..6e28f677a3 100644
--- a/core/src/main/scala/spark/network/Connection.scala
+++ b/core/src/main/scala/spark/network/Connection.scala
@@ -13,12 +13,13 @@ import java.net._
private[spark]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
- val remoteConnectionManagerId: ConnectionManagerId) extends Logging {
+ val socketRemoteConnectionManagerId: ConnectionManagerId)
+ extends Logging {
+
def this(channel_ : SocketChannel, selector_ : Selector) = {
this(channel_, selector_,
- ConnectionManagerId.fromSocketAddress(
- channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
- ))
+ ConnectionManagerId.fromSocketAddress(
+ channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]))
}
channel.configureBlocking(false)
@@ -33,16 +34,47 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
val remoteAddress = getRemoteAddress()
+ // Read channels typically do not register for write, and write channels do not register for read.
+ // Currently, write channels do register for read too (temporarily), but only to detect
+ // channel close - NOT to actually read/consume data on the channel!
+ // How does this work if/when we move to SSL?
+
+ // What is the interest to register with selector for when we want this connection to be selected
+ def registerInterest()
+
+ // What is the interest to register with selector for when we want this connection to
+ // be de-selected
+ // Traditionally 0 - but in our case (e.g. for the close-detection hack on SendingConnection)
+ // it will be SelectionKey.OP_READ (until we fix it properly)
+ def unregisterInterest()
+
+ // On receiving a read event, should we change the interest for this channel or not?
+ // Will be true for ReceivingConnection, false for SendingConnection.
+ def changeInterestForRead(): Boolean
+
+ // On receiving a write event, should we change the interest for this channel or not?
+ // Will be false for ReceivingConnection, true for SendingConnection.
+ // Actually, for now, should not get triggered for ReceivingConnection
+ def changeInterestForWrite(): Boolean
+
+ def getRemoteConnectionManagerId(): ConnectionManagerId = {
+ socketRemoteConnectionManagerId
+ }
+
def key() = channel.keyFor(selector)
def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
- def read() {
- throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString)
+ // Returns whether we have to register for further reads or not.
+ def read(): Boolean = {
+ throw new UnsupportedOperationException(
+ "Cannot read on connection of type " + this.getClass.toString)
}
-
- def write() {
- throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString)
+
+ // Returns whether we have to register for further writes or not.
+ def write(): Boolean = {
+ throw new UnsupportedOperationException(
+ "Cannot write on connection of type " + this.getClass.toString)
}
def close() {
@@ -54,26 +86,32 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
callOnCloseCallback()
}
- def onClose(callback: Connection => Unit) {onCloseCallback = callback}
+ def onClose(callback: Connection => Unit) {
+ onCloseCallback = callback
+ }
- def onException(callback: (Connection, Exception) => Unit) {onExceptionCallback = callback}
+ def onException(callback: (Connection, Exception) => Unit) {
+ onExceptionCallback = callback
+ }
- def onKeyInterestChange(callback: (Connection, Int) => Unit) {onKeyInterestChangeCallback = callback}
+ def onKeyInterestChange(callback: (Connection, Int) => Unit) {
+ onKeyInterestChangeCallback = callback
+ }
def callOnExceptionCallback(e: Exception) {
if (onExceptionCallback != null) {
onExceptionCallback(this, e)
} else {
- logError("Error in connection to " + remoteConnectionManagerId +
+ logError("Error in connection to " + getRemoteConnectionManagerId() +
" and OnExceptionCallback not registered", e)
}
}
-
+
def callOnCloseCallback() {
if (onCloseCallback != null) {
onCloseCallback(this)
} else {
- logWarning("Connection to " + remoteConnectionManagerId +
+ logWarning("Connection to " + getRemoteConnectionManagerId() +
" closed and OnExceptionCallback not registered")
}
@@ -81,7 +119,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
def changeConnectionKeyInterest(ops: Int) {
if (onKeyInterestChangeCallback != null) {
- onKeyInterestChangeCallback(this, ops)
+ onKeyInterestChangeCallback(this, ops)
} else {
throw new Exception("OnKeyInterestChangeCallback not registered")
}
@@ -105,24 +143,25 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
print(" (" + position + ", " + length + ")")
buffer.position(curPosition)
}
-
}
-private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
- remoteId_ : ConnectionManagerId)
-extends Connection(SocketChannel.open, selector_, remoteId_) {
+private[spark]
+class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
+ remoteId_ : ConnectionManagerId)
+ extends Connection(SocketChannel.open, selector_, remoteId_) {
class Outbox(fair: Int = 0) {
val messages = new Queue[Message]()
- val defaultChunkSize = 65536 //32768 //16384
+ val defaultChunkSize = 65536 //32768 //16384
var nextMessageToBeUsed = 0
def addMessage(message: Message) {
- messages.synchronized{
+ messages.synchronized{
/*messages += message*/
messages.enqueue(message)
- logDebug("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]")
+ logDebug("Added [" + message + "] to outbox for sending to " +
+ "[" + getRemoteConnectionManagerId() + "]")
}
}
@@ -147,18 +186,18 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
message.started = true
message.startTime = System.currentTimeMillis
}
- return chunk
+ return chunk
} else {
- /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
+ /*logInfo("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "]")*/
message.finishTime = System.currentTimeMillis
- logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId +
+ logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
"] in " + message.timeTaken )
}
}
}
None
}
-
+
private def getChunkRR(): Option[MessageChunk] = {
messages.synchronized {
while (!messages.isEmpty) {
@@ -170,15 +209,17 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
messages.enqueue(message)
nextMessageToBeUsed = nextMessageToBeUsed + 1
if (!message.started) {
- logDebug("Starting to send [" + message + "] to [" + remoteConnectionManagerId + "]")
+ logDebug(
+ "Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]")
message.started = true
message.startTime = System.currentTimeMillis
}
- logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]")
- return chunk
+ logTrace(
+ "Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]")
+ return chunk
} else {
message.finishTime = System.currentTimeMillis
- logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId +
+ logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
"] in " + message.timeTaken )
}
}
@@ -186,27 +227,40 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
None
}
}
-
- val outbox = new Outbox(1)
+
+ private val outbox = new Outbox(1)
val currentBuffers = new ArrayBuffer[ByteBuffer]()
/*channel.socket.setSendBufferSize(256 * 1024)*/
- override def getRemoteAddress() = address
+ override def getRemoteAddress() = address
+
+ val DEFAULT_INTEREST = SelectionKey.OP_READ
+
+ override def registerInterest() {
+ // Registering read too - does not really help in most cases, but for some
+ // it does - so let us keep it for now.
+ changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST)
+ }
+
+ override def unregisterInterest() {
+ changeConnectionKeyInterest(DEFAULT_INTEREST)
+ }
def send(message: Message) {
outbox.synchronized {
outbox.addMessage(message)
if (channel.isConnected) {
- changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ)
+ registerInterest()
}
}
}
+ // MUST be called within the selector loop
def connect() {
try{
- channel.connect(address)
channel.register(selector, SelectionKey.OP_CONNECT)
+ channel.connect(address)
logInfo("Initiating connection to [" + address + "]")
} catch {
case e: Exception => {
@@ -216,36 +270,52 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
}
}
- def finishConnect() {
+ def finishConnect(force: Boolean): Boolean = {
try {
- channel.finishConnect
- changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ)
+ // Typically, this should finish immediately since it was triggered by a connect
+ // selection - though need not necessarily always complete successfully.
+ val connected = channel.finishConnect
+ if (!force && !connected) {
+ logInfo(
+ "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending")
+ return false
+ }
+
+ // Fall back to the previous behavior - assume finishConnect completed.
+ // This happens only after finishConnect has failed some number of times (10 or so),
+ // which is highly unlikely unless there was an unclean close of the socket, etc.
+ registerInterest()
logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
+ return true
} catch {
case e: Exception => {
logWarning("Error finishing connection to " + address, e)
callOnExceptionCallback(e)
+ // ignore
+ return true
}
}
}
- override def write() {
- try{
- while(true) {
+ override def write(): Boolean = {
+ try {
+ while (true) {
if (currentBuffers.size == 0) {
outbox.synchronized {
outbox.getChunk() match {
case Some(chunk) => {
- currentBuffers ++= chunk.buffers
+ currentBuffers ++= chunk.buffers
}
case None => {
- changeConnectionKeyInterest(SelectionKey.OP_READ)
- return
+ // changeConnectionKeyInterest(0)
+ /*key.interestOps(0)*/
+ return false
}
}
}
}
-
+
if (currentBuffers.size > 0) {
val buffer = currentBuffers(0)
val remainingBytes = buffer.remaining
@@ -254,69 +324,109 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
currentBuffers -= buffer
}
if (writtenBytes < remainingBytes) {
- return
+ // re-register for write.
+ return true
}
}
}
} catch {
- case e: Exception => {
- logWarning("Error writing in connection to " + remoteConnectionManagerId, e)
+ case e: Exception => {
+ logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
+ return false
}
}
+ // should not happen - to keep scala compiler happy
+ return true
}
- override def read() {
+ // This is a hack to determine whether the remote socket was closed or not.
+ // A SendingConnection DOES NOT expect to receive any data - if it does, it is an error.
+ // In most cases, read returns -1 when the remote socket is closed, so we
+ // register for reads to detect that.
+ override def read(): Boolean = {
// We don't expect the other side to send anything; so, we just read to detect an error or EOF.
try {
val length = channel.read(ByteBuffer.allocate(1))
if (length == -1) { // EOF
close()
} else if (length > 0) {
- logWarning("Unexpected data read from SendingConnection to " + remoteConnectionManagerId)
+ logWarning(
+ "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId())
}
} catch {
case e: Exception =>
- logError("Exception while reading SendingConnection to " + remoteConnectionManagerId, e)
+ logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
}
+
+ false
}
+
+ override def changeInterestForRead(): Boolean = false
+
+ override def changeInterestForWrite(): Boolean = true
}
-private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
-extends Connection(channel_, selector_) {
-
+// Must be created within selector loop - else deadlock
+private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
+ extends Connection(channel_, selector_) {
+
class Inbox() {
val messages = new HashMap[Int, BufferMessage]()
-
+
def getChunk(header: MessageChunkHeader): Option[MessageChunk] = {
-
+
def createNewMessage: BufferMessage = {
val newMessage = Message.create(header).asInstanceOf[BufferMessage]
newMessage.started = true
newMessage.startTime = System.currentTimeMillis
- logDebug("Starting to receive [" + newMessage + "] from [" + remoteConnectionManagerId + "]")
+ logDebug(
+ "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]")
messages += ((newMessage.id, newMessage))
newMessage
}
-
+
val message = messages.getOrElseUpdate(header.id, createNewMessage)
- logTrace("Receiving chunk of [" + message + "] from [" + remoteConnectionManagerId + "]")
+ logTrace(
+ "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]")
message.getChunkForReceiving(header.chunkSize)
}
-
+
def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = {
- messages.get(chunk.header.id)
+ messages.get(chunk.header.id)
}
def removeMessage(message: Message) {
messages -= message.id
}
}
-
+
+ @volatile private var inferredRemoteManagerId: ConnectionManagerId = null
+
+ override def getRemoteConnectionManagerId(): ConnectionManagerId = {
+ val currId = inferredRemoteManagerId
+ if (currId != null) currId else super.getRemoteConnectionManagerId()
+ }
+
+ // The receiver's remote address is the local socket on the remote side, which is NOT
+ // the connection manager id of the receiver.
+ // We infer the latter from the messages we receive on the receiver socket.
+ private def processConnectionManagerId(header: MessageChunkHeader) {
+ val currId = inferredRemoteManagerId
+ if (header.address == null || currId != null) return
+
+ val managerId = ConnectionManagerId.fromSocketAddress(header.address)
+
+ if (managerId != null) {
+ inferredRemoteManagerId = managerId
+ }
+ }
+
+
val inbox = new Inbox()
val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
var onReceiveCallback: (Connection , Message) => Unit = null
@@ -324,24 +434,29 @@ extends Connection(channel_, selector_) {
channel.register(selector, SelectionKey.OP_READ)
- override def read() {
+ override def read(): Boolean = {
try {
while (true) {
if (currentChunk == null) {
val headerBytesRead = channel.read(headerBuffer)
if (headerBytesRead == -1) {
close()
- return
+ return false
}
if (headerBuffer.remaining > 0) {
- return
+ // re-register for read event ...
+ return true
}
headerBuffer.flip
if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) {
- throw new Exception("Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
+ throw new Exception(
+ "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
}
val header = MessageChunkHeader.create(headerBuffer)
headerBuffer.clear()
+
+ processConnectionManagerId(header)
+
header.typ match {
case Message.BUFFER_MESSAGE => {
if (header.totalSize == 0) {
@@ -349,7 +464,8 @@ extends Connection(channel_, selector_) {
onReceiveCallback(this, Message.create(header))
}
currentChunk = null
- return
+ // re-register for read event ...
+ return true
} else {
currentChunk = inbox.getChunk(header).orNull
}
@@ -357,26 +473,28 @@ extends Connection(channel_, selector_) {
case _ => throw new Exception("Message of unknown type received")
}
}
-
+
if (currentChunk == null) throw new Exception("No message chunk to receive data")
-
+
val bytesRead = channel.read(currentChunk.buffer)
if (bytesRead == 0) {
- return
+ // re-register for read event ...
+ return true
} else if (bytesRead == -1) {
close()
- return
+ return false
}
/*logDebug("Read " + bytesRead + " bytes for the buffer")*/
-
+
if (currentChunk.buffer.remaining == 0) {
/*println("Filled buffer at " + System.currentTimeMillis)*/
val bufferMessage = inbox.getMessageForChunk(currentChunk).get
if (bufferMessage.isCompletelyReceived) {
bufferMessage.flip
bufferMessage.finishTime = System.currentTimeMillis
- logDebug("Finished receiving [" + bufferMessage + "] from [" + remoteConnectionManagerId + "] in " + bufferMessage.timeTaken)
+ logDebug("Finished receiving [" + bufferMessage + "] from " +
+ "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken)
if (onReceiveCallback != null) {
onReceiveCallback(this, bufferMessage)
}
@@ -386,13 +504,32 @@ extends Connection(channel_, selector_) {
}
}
} catch {
- case e: Exception => {
- logWarning("Error reading from connection to " + remoteConnectionManagerId, e)
+ case e: Exception => {
+ logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
+ return false
}
}
+ // should not happen - to keep scala compiler happy
+ return true
}
-
+
def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback}
+
+ override def changeInterestForRead(): Boolean = true
+
+ override def changeInterestForWrite(): Boolean = {
+ throw new IllegalStateException("Unexpected invocation right now")
+ }
+
+ override def registerInterest() {
+ // Registering read too - does not really help in most cases, but for some
+ // it does - so let us keep it for now.
+ changeConnectionKeyInterest(SelectionKey.OP_READ)
+ }
+
+ override def unregisterInterest() {
+ changeConnectionKeyInterest(0)
+ }
}
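
The register/unregister methods above reduce to a small interest-op convention per connection type; summarized as a sketch:

    import java.nio.channels.SelectionKey

    // SendingConnection keeps OP_READ registered even when idle, purely for the
    // close-detection hack in read() above; OP_WRITE is added while sending.
    val sendingActive   = SelectionKey.OP_WRITE | SelectionKey.OP_READ
    val sendingIdle     = SelectionKey.OP_READ

    // ReceivingConnection toggles OP_READ off while a read task runs and back
    // on when the task asks to be re-registered.
    val receivingActive = SelectionKey.OP_READ
    val receivingIdle   = 0
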
diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala
index 8f8892b8c7..cc5c62a542 100644
--- a/core/src/main/scala/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/spark/network/ConnectionManager.scala
@@ -6,28 +6,19 @@ import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
import java.net._
-import java.util.concurrent.Executors
+import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor}
+import scala.collection.mutable.HashSet
import scala.collection.mutable.HashMap
import scala.collection.mutable.SynchronizedMap
import scala.collection.mutable.SynchronizedQueue
-import scala.collection.mutable.Queue
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{Await, Promise, ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.concurrent.duration._
-private[spark] case class ConnectionManagerId(host: String, port: Int) {
- def toSocketAddress() = new InetSocketAddress(host, port)
-}
-private[spark] object ConnectionManagerId {
- def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
- new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
- }
-}
-
private[spark] class ConnectionManager(port: Int) extends Logging {
class MessageStatus(
@@ -41,73 +32,263 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
def markDone() { completionHandler(this) }
}
-
- val selector = SelectorProvider.provider.openSelector()
- val handleMessageExecutor = Executors.newFixedThreadPool(System.getProperty("spark.core.connection.handler.threads","20").toInt)
- val serverChannel = ServerSocketChannel.open()
- val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
- val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
- val messageStatuses = new HashMap[Int, MessageStatus]
- val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
- val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
- val sendMessageRequests = new Queue[(Message, SendingConnection)]
+
+ private val selector = SelectorProvider.provider.openSelector()
+
+ private val handleMessageExecutor = new ThreadPoolExecutor(
+ System.getProperty("spark.core.connection.handler.threads.min","20").toInt,
+ System.getProperty("spark.core.connection.handler.threads.max","60").toInt,
+ System.getProperty("spark.core.connection.handler.threads.keepalive","60").toInt, TimeUnit.SECONDS,
+ new LinkedBlockingDeque[Runnable]())
+
+ private val handleReadWriteExecutor = new ThreadPoolExecutor(
+ System.getProperty("spark.core.connection.io.threads.min","4").toInt,
+ System.getProperty("spark.core.connection.io.threads.max","32").toInt,
+ System.getProperty("spark.core.connection.io.threads.keepalive","60").toInt, TimeUnit.SECONDS,
+ new LinkedBlockingDeque[Runnable]())
+
+ // Use a different, smaller thread pool for connects - it is used infrequently, with very short-lived tasks that should be executed asap
+ private val handleConnectExecutor = new ThreadPoolExecutor(
+ System.getProperty("spark.core.connection.connect.threads.min","1").toInt,
+ System.getProperty("spark.core.connection.connect.threads.max","8").toInt,
+ System.getProperty("spark.core.connection.connect.threads.keepalive","60").toInt, TimeUnit.SECONDS,
+ new LinkedBlockingDeque[Runnable]())
+
+ private val serverChannel = ServerSocketChannel.open()
+ private val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
+ private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
+ private val messageStatuses = new HashMap[Int, MessageStatus]
+ private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
+ private val registerRequests = new SynchronizedQueue[SendingConnection]
implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())
- var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
+ private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
serverChannel.configureBlocking(false)
serverChannel.socket.setReuseAddress(true)
- serverChannel.socket.setReceiveBufferSize(256 * 1024)
+ serverChannel.socket.setReceiveBufferSize(256 * 1024)
serverChannel.socket.bind(new InetSocketAddress(port))
serverChannel.register(selector, SelectionKey.OP_ACCEPT)
val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
-
- val selectorThread = new Thread("connection-manager-thread") {
+
+ private val selectorThread = new Thread("connection-manager-thread") {
override def run() = ConnectionManager.this.run()
}
selectorThread.setDaemon(true)
selectorThread.start()
- private def run() {
- try {
- while(!selectorThread.isInterrupted) {
- for ((connectionManagerId, sendingConnection) <- connectionRequests) {
- sendingConnection.connect()
- addConnection(sendingConnection)
- connectionRequests -= connectionManagerId
+ private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
+
+ private def triggerWrite(key: SelectionKey) {
+ val conn = connectionsByKey.getOrElse(key, null)
+ if (conn == null) return
+
+ writeRunnableStarted.synchronized {
+ // So that we do not trigger more write events while processing this one.
+ // The write method will re-register when done.
+ if (conn.changeInterestForWrite()) conn.unregisterInterest()
+ if (writeRunnableStarted.contains(key)) {
+ // key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE)
+ return
+ }
+
+ writeRunnableStarted += key
+ }
+ handleReadWriteExecutor.execute(new Runnable {
+ override def run() {
+ var register: Boolean = false
+ try {
+ register = conn.write()
+ } finally {
+ writeRunnableStarted.synchronized {
+ writeRunnableStarted -= key
+ if (register && conn.changeInterestForWrite()) {
+ conn.registerInterest()
+ }
+ }
}
- sendMessageRequests.synchronized {
- while (!sendMessageRequests.isEmpty) {
- val (message, connection) = sendMessageRequests.dequeue
- connection.send(message)
+ }
+ } )
+ }
+
+ private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
+
+ private def triggerRead(key: SelectionKey) {
+ val conn = connectionsByKey.getOrElse(key, null)
+ if (conn == null) return
+
+ readRunnableStarted.synchronized {
+ // So that we do not trigger more read events while processing this one.
+ // The read method will re-register when done.
+ if (conn.changeInterestForRead()) conn.unregisterInterest()
+ if (readRunnableStarted.contains(key)) {
+ return
+ }
+
+ readRunnableStarted += key
+ }
+ handleReadWriteExecutor.execute(new Runnable {
+ override def run() {
+ var register: Boolean = false
+ try {
+ register = conn.read()
+ } finally {
+ readRunnableStarted.synchronized {
+ readRunnableStarted -= key
+ if (register && conn.changeInterestForRead()) {
+ conn.registerInterest()
+ }
}
}
+ }
+ } )
+ }
+
+ private def triggerConnect(key: SelectionKey) {
+ val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection]
+ if (conn == null) return
+
+ // prevent other events from being triggered
+ // Since we are still trying to connect, we do not need to do the additional steps in triggerWrite
+ conn.changeConnectionKeyInterest(0)
+
+ handleConnectExecutor.execute(new Runnable {
+ override def run() {
+
+ var tries: Int = 10
+ while (tries >= 0) {
+ if (conn.finishConnect(false)) return
+ // Sleep ?
+ Thread.sleep(1)
+ tries -= 1
+ }
+
+ // Fall back to previous behavior: we should not really get here, since this method was
+ // triggered because the channel became connectable - but at times the first finishConnect
+ // does not succeed, hence the loop to retry a few times.
+ conn.finishConnect(true)
+ }
+ } )
+ }
+
+ // MUST be called within selector loop - else deadlock.
+ private def triggerForceCloseByException(key: SelectionKey, e: Exception) {
+ try {
+ key.interestOps(0)
+ } catch {
+ // ignore exceptions
+ case e: Exception => logDebug("Ignoring exception", e)
+ }
+
+ val conn = connectionsByKey.getOrElse(key, null)
+ if (conn == null) return
+
+ // Pushing to connect threadpool
+ handleConnectExecutor.execute(new Runnable {
+ override def run() {
+ try {
+ conn.callOnExceptionCallback(e)
+ } catch {
+ // ignore exceptions
+ case e: Exception => logDebug("Ignoring exception", e)
+ }
+ try {
+ conn.close()
+ } catch {
+ // ignore exceptions
+ case e: Exception => logDebug("Ignoring exception", e)
+ }
+ }
+ })
+ }
+
- while (!keyInterestChangeRequests.isEmpty) {
+ def run() {
+ try {
+ while(!selectorThread.isInterrupted) {
+ while (! registerRequests.isEmpty) {
+ val conn: SendingConnection = registerRequests.dequeue
+ addListeners(conn)
+ conn.connect()
+ addConnection(conn)
+ }
+
+ while(!keyInterestChangeRequests.isEmpty) {
val (key, ops) = keyInterestChangeRequests.dequeue
- val connection = connectionsByKey(key)
- val lastOps = key.interestOps()
- key.interestOps(ops)
-
- def intToOpStr(op: Int): String = {
- val opStrs = ArrayBuffer[String]()
- if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
- if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
- if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
- if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
- if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
+
+ try {
+ if (key.isValid) {
+ val connection = connectionsByKey.getOrElse(key, null)
+ if (connection != null) {
+ val lastOps = key.interestOps()
+ key.interestOps(ops)
+
+ // Hot loop - avoid materializing the string if trace is not enabled.
+ if (isTraceEnabled()) {
+ def intToOpStr(op: Int): String = {
+ val opStrs = ArrayBuffer[String]()
+ if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
+ if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
+ if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
+ if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
+ if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
+ }
+
+ logTrace("Changed key for connection to [" + connection.getRemoteConnectionManagerId() +
+ "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
+ }
+ }
+ } else {
+ logInfo("Key not valid ? " + key)
+ throw new CancelledKeyException()
+ }
+ } catch {
+ case e: CancelledKeyException => {
+ logInfo("key already cancelled ? " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ case e: Exception => {
+ logError("Exception processing key " + key, e)
+ triggerForceCloseByException(key, e)
+ }
}
-
- logTrace("Changed key for connection to [" + connection.remoteConnectionManagerId +
- "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
-
}
- val selectedKeysCount = selector.select()
+ val selectedKeysCount =
+ try {
+ selector.select()
+ } catch {
+ // Explicitly only dealing with CancelledKeyException here since other exceptions should be dealt with differently.
+ case e: CancelledKeyException => {
+ // Some keys within the selectors list are invalid/closed. clear them.
+ val allKeys = selector.keys().iterator()
+
+ while (allKeys.hasNext()) {
+ val key = allKeys.next()
+ try {
+ if (! key.isValid) {
+ logInfo("Key not valid ? " + key)
+ throw new CancelledKeyException()
+ }
+ } catch {
+ case e: CancelledKeyException => {
+ logInfo("key already cancelled ? " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ case e: Exception => {
+ logError("Exception processing key " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ }
+ }
+ }
+ 0
+ }
+
if (selectedKeysCount == 0) {
logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys")
}
@@ -115,20 +296,40 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
logInfo("Selector thread was interrupted!")
return
}
-
- val selectedKeys = selector.selectedKeys().iterator()
- while (selectedKeys.hasNext()) {
- val key = selectedKeys.next
- selectedKeys.remove()
- if (key.isValid) {
- if (key.isAcceptable) {
- acceptConnection(key)
- } else if (key.isConnectable) {
- connectionsByKey(key).asInstanceOf[SendingConnection].finishConnect()
- } else if (key.isReadable) {
- connectionsByKey(key).read()
- } else if (key.isWritable) {
- connectionsByKey(key).write()
+
+ if (0 != selectedKeysCount) {
+ val selectedKeys = selector.selectedKeys().iterator()
+ while (selectedKeys.hasNext()) {
+ val key = selectedKeys.next
+ selectedKeys.remove()
+ try {
+ if (key.isValid) {
+ if (key.isAcceptable) {
+ acceptConnection(key)
+ } else
+ if (key.isConnectable) {
+ triggerConnect(key)
+ } else
+ if (key.isReadable) {
+ triggerRead(key)
+ } else
+ if (key.isWritable) {
+ triggerWrite(key)
+ }
+ } else {
+ logInfo("Key not valid ? " + key)
+ throw new CancelledKeyException()
+ }
+ } catch {
+ // weird, but we saw this happening - even though key.isValid was true, key.isAcceptable would throw CancelledKeyException.
+ case e: CancelledKeyException => {
+ logInfo("key already cancelled ? " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ case e: Exception => {
+ logError("Exception processing key " + key, e)
+ triggerForceCloseByException(key, e)
+ }
}
}
}
@@ -137,97 +338,119 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
case e: Exception => logError("Error in select loop", e)
}
}
-
- private def acceptConnection(key: SelectionKey) {
+
+ def acceptConnection(key: SelectionKey) {
val serverChannel = key.channel.asInstanceOf[ServerSocketChannel]
- val newChannel = serverChannel.accept()
- val newConnection = new ReceivingConnection(newChannel, selector)
- newConnection.onReceive(receiveMessage)
- newConnection.onClose(removeConnection)
- addConnection(newConnection)
- logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
- }
- private def addConnection(connection: Connection) {
- connectionsByKey += ((connection.key, connection))
- if (connection.isInstanceOf[SendingConnection]) {
- val sendingConnection = connection.asInstanceOf[SendingConnection]
- connectionsById += ((sendingConnection.remoteConnectionManagerId, sendingConnection))
+ var newChannel = serverChannel.accept()
+
+ // Accept them all in a tight loop; non-blocking accept with no processing should be fine.
+ while (newChannel != null) {
+ try {
+ val newConnection = new ReceivingConnection(newChannel, selector)
+ newConnection.onReceive(receiveMessage)
+ addListeners(newConnection)
+ addConnection(newConnection)
+ logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
+ } catch {
+ // might happen in case of issues with registering with selector
+ case e: Exception => logError("Error in accept loop", e)
+ }
+
+ newChannel = serverChannel.accept()
}
+ }
+
+ private def addListeners(connection: Connection) {
connection.onKeyInterestChange(changeConnectionKeyInterest)
connection.onException(handleConnectionError)
connection.onClose(removeConnection)
}
- private def removeConnection(connection: Connection) {
+ def addConnection(connection: Connection) {
+ connectionsByKey += ((connection.key, connection))
+ }
+
+ def removeConnection(connection: Connection) {
connectionsByKey -= connection.key
- if (connection.isInstanceOf[SendingConnection]) {
- val sendingConnection = connection.asInstanceOf[SendingConnection]
- val sendingConnectionManagerId = sendingConnection.remoteConnectionManagerId
- logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
-
- connectionsById -= sendingConnectionManagerId
-
- messageStatuses.synchronized {
- messageStatuses
- .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
- logInfo("Notifying " + status)
- status.synchronized {
- status.attempted = true
- status.acked = false
- status.markDone()
- }
+
+ try {
+ if (connection.isInstanceOf[SendingConnection]) {
+ val sendingConnection = connection.asInstanceOf[SendingConnection]
+ val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
+ logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
+
+ connectionsById -= sendingConnectionManagerId
+
+ messageStatuses.synchronized {
+ messageStatuses
+ .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
+ logInfo("Notifying " + status)
+ status.synchronized {
+ status.attempted = true
+ status.acked = false
+ status.markDone()
+ }
+ })
+
+ messageStatuses.retain((i, status) => {
+ status.connectionManagerId != sendingConnectionManagerId
})
+ }
+ } else if (connection.isInstanceOf[ReceivingConnection]) {
+ val receivingConnection = connection.asInstanceOf[ReceivingConnection]
+ val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId()
+ logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
+
+ val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId)
+ if (! sendingConnectionOpt.isDefined) {
+ logError("Corresponding SendingConnectionManagerId not found")
+ return
+ }
- messageStatuses.retain((i, status) => {
- status.connectionManagerId != sendingConnectionManagerId
- })
- }
- } else if (connection.isInstanceOf[ReceivingConnection]) {
- val receivingConnection = connection.asInstanceOf[ReceivingConnection]
- val remoteConnectionManagerId = receivingConnection.remoteConnectionManagerId
- logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
-
- val sendingConnectionManagerId = connectionsById.keys.find(_.host == remoteConnectionManagerId.host).orNull
- if (sendingConnectionManagerId == null) {
- logError("Corresponding SendingConnectionManagerId not found")
- return
- }
- logInfo("Corresponding SendingConnectionManagerId is " + sendingConnectionManagerId)
-
- val sendingConnection = connectionsById(sendingConnectionManagerId)
- sendingConnection.close()
- connectionsById -= sendingConnectionManagerId
-
- messageStatuses.synchronized {
- for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) {
- logInfo("Notifying " + s)
- s.synchronized {
- s.attempted = true
- s.acked = false
- s.markDone()
+ val sendingConnection = sendingConnectionOpt.get
+ connectionsById -= remoteConnectionManagerId
+ sendingConnection.close()
+
+ val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
+
+ assert (sendingConnectionManagerId == remoteConnectionManagerId)
+
+ messageStatuses.synchronized {
+ for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) {
+ logInfo("Notifying " + s)
+ s.synchronized {
+ s.attempted = true
+ s.acked = false
+ s.markDone()
+ }
}
- }
- messageStatuses.retain((i, status) => {
- status.connectionManagerId != sendingConnectionManagerId
- })
+ messageStatuses.retain((i, status) => {
+ status.connectionManagerId != sendingConnectionManagerId
+ })
+ }
}
+ } finally {
+ // So that the selection keys can be removed.
+ wakeupSelector()
}
}
- private def handleConnectionError(connection: Connection, e: Exception) {
- logInfo("Handling connection error on connection to " + connection.remoteConnectionManagerId)
+ def handleConnectionError(connection: Connection, e: Exception) {
+ logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId())
removeConnection(connection)
}
- private def changeConnectionKeyInterest(connection: Connection, ops: Int) {
- keyInterestChangeRequests += ((connection.key, ops))
+ def changeConnectionKeyInterest(connection: Connection, ops: Int) {
+ keyInterestChangeRequests += ((connection.key, ops))
+ // So that registrations happen!
+ wakeupSelector()
}
- private def receiveMessage(connection: Connection, message: Message) {
+ def receiveMessage(connection: Connection, message: Message) {
val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
- logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
+ logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
val runnable = new Runnable() {
val creationTime = System.currentTimeMillis
def run() {
@@ -247,11 +470,11 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
if (bufferMessage.hasAckId) {
val sentMessageStatus = messageStatuses.synchronized {
messageStatuses.get(bufferMessage.ackId) match {
- case Some(status) => {
- messageStatuses -= bufferMessage.ackId
+ case Some(status) => {
+ messageStatuses -= bufferMessage.ackId
status
}
- case None => {
+ case None => {
throw new Exception("Could not find reference for received ack message " + message.id)
null
}
@@ -271,7 +494,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
logDebug("Not calling back as callback is null")
None
}
-
+
if (ackMessage.isDefined) {
if (!ackMessage.get.isInstanceOf[BufferMessage]) {
logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass())
@@ -281,7 +504,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
}
}
- sendMessage(connectionManagerId, ackMessage.getOrElse {
+ sendMessage(connectionManagerId, ackMessage.getOrElse {
Message.createBufferMessage(bufferMessage.id)
})
}
@@ -293,18 +516,22 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
def startNewConnection(): SendingConnection = {
val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
- val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId,
- new SendingConnection(inetSocketAddress, selector, connectionManagerId))
- newConnection
+ val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId)
+ registerRequests.enqueue(newConnection)
+
+ newConnection
}
- val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress)
- val connection = connectionsById.getOrElse(lookupKey, startNewConnection())
+ // I removed the lookupKey stuff as part of the merge ... should I re-add it? We did not find it useful in our test env ...
+ // If we do re-add it, we should use it consistently everywhere, I guess?
+ val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
message.senderAddress = id.toSocketAddress()
logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
- /*connection.send(message)*/
- sendMessageRequests.synchronized {
- sendMessageRequests += ((message, connection))
- }
+ connection.send(message)
+
+ wakeupSelector()
+ }
+
+ private def wakeupSelector() {
selector.wakeup()
}
@@ -337,6 +564,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
logWarning("All connections not cleaned up")
}
handleMessageExecutor.shutdown()
+ handleReadWriteExecutor.shutdown()
+ handleConnectExecutor.shutdown()
logInfo("ConnectionManager stopped")
}
}
@@ -346,17 +575,17 @@ private[spark] object ConnectionManager {
def main(args: Array[String]) {
val manager = new ConnectionManager(9999)
- manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
println("Received [" + msg + "] from [" + id + "]")
None
})
-
+
/*testSequentialSending(manager)*/
/*System.gc()*/
/*testParallelSending(manager)*/
/*System.gc()*/
-
+
/*testParallelDecreasingSending(manager)*/
/*System.gc()*/
@@ -368,9 +597,9 @@ private[spark] object ConnectionManager {
println("--------------------------")
println("Sequential Sending")
println("--------------------------")
- val size = 10 * 1024 * 1024
+ val size = 10 * 1024 * 1024
val count = 10
-
+
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
buffer.flip
@@ -386,7 +615,7 @@ private[spark] object ConnectionManager {
println("--------------------------")
println("Parallel Sending")
println("--------------------------")
- val size = 10 * 1024 * 1024
+ val size = 10 * 1024 * 1024
val count = 10
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
@@ -401,12 +630,12 @@ private[spark] object ConnectionManager {
if (!g.isDefined) println("Failed")
})
val finishTime = System.currentTimeMillis
-
+
val mb = size * count / 1024.0 / 1024.0
val ms = finishTime - startTime
val tput = mb * 1000.0 / ms
println("--------------------------")
- println("Started at " + startTime + ", finished at " + finishTime)
+ println("Started at " + startTime + ", finished at " + finishTime)
println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)")
println("--------------------------")
println()
@@ -416,7 +645,7 @@ private[spark] object ConnectionManager {
println("--------------------------")
println("Parallel Decreasing Sending")
println("--------------------------")
- val size = 10 * 1024 * 1024
+ val size = 10 * 1024 * 1024
val count = 10
val buffers = Array.tabulate(count)(i => ByteBuffer.allocate(size * (i + 1)).put(Array.tabulate[Byte](size * (i + 1))(x => x.toByte)))
buffers.foreach(_.flip)
@@ -431,7 +660,7 @@ private[spark] object ConnectionManager {
if (!g.isDefined) println("Failed")
})
val finishTime = System.currentTimeMillis
-
+
val ms = finishTime - startTime
val tput = mb * 1000.0 / ms
println("--------------------------")
@@ -445,7 +674,7 @@ private[spark] object ConnectionManager {
println("--------------------------")
println("Continuous Sending")
println("--------------------------")
- val size = 10 * 1024 * 1024
+ val size = 10 * 1024 * 1024
val count = 10
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
diff --git a/core/src/main/scala/spark/network/ConnectionManagerId.scala b/core/src/main/scala/spark/network/ConnectionManagerId.scala
new file mode 100644
index 0000000000..b554e84251
--- /dev/null
+++ b/core/src/main/scala/spark/network/ConnectionManagerId.scala
@@ -0,0 +1,21 @@
+package spark.network
+
+import java.net.InetSocketAddress
+
+import spark.Utils
+
+
+private[spark] case class ConnectionManagerId(host: String, port: Int) {
+ // DEBUG code
+ Utils.checkHost(host)
+ assert (port > 0)
+
+ def toSocketAddress() = new InetSocketAddress(host, port)
+}
+
+
+private[spark] object ConnectionManagerId {
+ def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
+ new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
+ }
+}
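
For illustration, a minimal round trip through the new ConnectionManagerId (a sketch; the host name is hypothetical):

    val id = ConnectionManagerId("shuffle-host.example.com", 9999)
    val addr = id.toSocketAddress()
    // fromSocketAddress rebuilds the id via getHostName/getPort, so case-class
    // equality holds when the JDK preserves the original host string
    assert(ConnectionManagerId.fromSocketAddress(addr) == id)
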
diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala
index 525751b5bf..d4f03610eb 100644
--- a/core/src/main/scala/spark/network/Message.scala
+++ b/core/src/main/scala/spark/network/Message.scala
@@ -1,55 +1,10 @@
package spark.network
-import spark._
-
-import scala.collection.mutable.ArrayBuffer
-
import java.nio.ByteBuffer
-import java.net.InetAddress
import java.net.InetSocketAddress
-import storage.BlockManager
-
-private[spark] class MessageChunkHeader(
- val typ: Long,
- val id: Int,
- val totalSize: Int,
- val chunkSize: Int,
- val other: Int,
- val address: InetSocketAddress) {
- lazy val buffer = {
- val ip = address.getAddress.getAddress()
- val port = address.getPort()
- ByteBuffer.
- allocate(MessageChunkHeader.HEADER_SIZE).
- putLong(typ).
- putInt(id).
- putInt(totalSize).
- putInt(chunkSize).
- putInt(other).
- putInt(ip.size).
- put(ip).
- putInt(port).
- position(MessageChunkHeader.HEADER_SIZE).
- flip.asInstanceOf[ByteBuffer]
- }
-
- override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
- " and sizes " + totalSize + " / " + chunkSize + " bytes"
-}
-private[spark] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
- val size = if (buffer == null) 0 else buffer.remaining
- lazy val buffers = {
- val ab = new ArrayBuffer[ByteBuffer]()
- ab += header.buffer
- if (buffer != null) {
- ab += buffer
- }
- ab
- }
+import scala.collection.mutable.ArrayBuffer
- override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
-}
private[spark] abstract class Message(val typ: Long, val id: Int) {
var senderAddress: InetSocketAddress = null
@@ -58,120 +13,16 @@ private[spark] abstract class Message(val typ: Long, val id: Int) {
var finishTime = -1L
def size: Int
-
+
def getChunkForSending(maxChunkSize: Int): Option[MessageChunk]
-
+
def getChunkForReceiving(chunkSize: Int): Option[MessageChunk]
-
+
def timeTaken(): String = (finishTime - startTime).toString + " ms"
override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
}
-private[spark] class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
-extends Message(Message.BUFFER_MESSAGE, id_) {
-
- val initialSize = currentSize()
- var gotChunkForSendingOnce = false
-
- def size = initialSize
-
- def currentSize() = {
- if (buffers == null || buffers.isEmpty) {
- 0
- } else {
- buffers.map(_.remaining).reduceLeft(_ + _)
- }
- }
-
- def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
- if (maxChunkSize <= 0) {
- throw new Exception("Max chunk size is " + maxChunkSize)
- }
-
- if (size == 0 && gotChunkForSendingOnce == false) {
- val newChunk = new MessageChunk(new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
- gotChunkForSendingOnce = true
- return Some(newChunk)
- }
-
- while(!buffers.isEmpty) {
- val buffer = buffers(0)
- if (buffer.remaining == 0) {
- BlockManager.dispose(buffer)
- buffers -= buffer
- } else {
- val newBuffer = if (buffer.remaining <= maxChunkSize) {
- buffer.duplicate()
- } else {
- buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
- }
- buffer.position(buffer.position + newBuffer.remaining)
- val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
- gotChunkForSendingOnce = true
- return Some(newChunk)
- }
- }
- None
- }
-
- def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
- // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
- if (buffers.size > 1) {
- throw new Exception("Attempting to get chunk from message with multiple data buffers")
- }
- val buffer = buffers(0)
- if (buffer.remaining > 0) {
- if (buffer.remaining < chunkSize) {
- throw new Exception("Not enough space in data buffer for receiving chunk")
- }
- val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
- buffer.position(buffer.position + newBuffer.remaining)
- val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
- return Some(newChunk)
- }
- None
- }
-
- def flip() {
- buffers.foreach(_.flip)
- }
-
- def hasAckId() = (ackId != 0)
-
- def isCompletelyReceived() = !buffers(0).hasRemaining
-
- override def toString = {
- if (hasAckId) {
- "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
- } else {
- "BufferMessage(id = " + id + ", size = " + size + ")"
- }
- }
-}
-
-private[spark] object MessageChunkHeader {
- val HEADER_SIZE = 40
-
- def create(buffer: ByteBuffer): MessageChunkHeader = {
- if (buffer.remaining != HEADER_SIZE) {
- throw new IllegalArgumentException("Cannot convert buffer data to Message")
- }
- val typ = buffer.getLong()
- val id = buffer.getInt()
- val totalSize = buffer.getInt()
- val chunkSize = buffer.getInt()
- val other = buffer.getInt()
- val ipSize = buffer.getInt()
- val ipBytes = new Array[Byte](ipSize)
- buffer.get(ipBytes)
- val ip = InetAddress.getByAddress(ipBytes)
- val port = buffer.getInt()
- new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
- }
-}
private[spark] object Message {
val BUFFER_MESSAGE = 1111111111L
@@ -180,14 +31,16 @@ private[spark] object Message {
def getNewId() = synchronized {
lastId += 1
- if (lastId == 0) lastId += 1
+ if (lastId == 0) {
+ lastId += 1
+ }
lastId
}
def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = {
if (dataBuffers == null) {
return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId)
- }
+ }
if (dataBuffers.exists(_ == null)) {
throw new Exception("Attempting to create buffer message with null buffer")
}
@@ -196,7 +49,7 @@ private[spark] object Message {
def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage =
createBufferMessage(dataBuffers, 0)
-
+
def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = {
if (dataBuffer == null) {
return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId)
@@ -204,15 +57,18 @@ private[spark] object Message {
return createBufferMessage(Array(dataBuffer), ackId)
}
}
-
- def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
+
+ def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
createBufferMessage(dataBuffer, 0)
-
- def createBufferMessage(ackId: Int): BufferMessage = createBufferMessage(new Array[ByteBuffer](0), ackId)
+
+ def createBufferMessage(ackId: Int): BufferMessage = {
+ createBufferMessage(new Array[ByteBuffer](0), ackId)
+ }
def create(header: MessageChunkHeader): Message = {
val newMessage: Message = header.typ match {
- case BUFFER_MESSAGE => new BufferMessage(header.id, ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
+ case BUFFER_MESSAGE => new BufferMessage(header.id,
+ ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
}
newMessage.senderAddress = header.address
newMessage
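
A minimal sketch of the retained Message API (assuming the spark.network classes above are on the classpath):

    import java.nio.ByteBuffer

    val payload = ByteBuffer.wrap("hello".getBytes("UTF-8"))
    val msg = Message.createBufferMessage(payload)   // ackId defaults to 0
    // Chunks are handed out until the underlying buffers are drained
    val chunk = msg.getChunkForSending(64 * 1024)    // Some(MessageChunk)
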
diff --git a/core/src/main/scala/spark/network/MessageChunk.scala b/core/src/main/scala/spark/network/MessageChunk.scala
new file mode 100644
index 0000000000..aaf9204d0e
--- /dev/null
+++ b/core/src/main/scala/spark/network/MessageChunk.scala
@@ -0,0 +1,25 @@
+package spark.network
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+
+private[network]
+class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
+
+ val size = if (buffer == null) 0 else buffer.remaining
+
+ lazy val buffers = {
+ val ab = new ArrayBuffer[ByteBuffer]()
+ ab += header.buffer
+ if (buffer != null) {
+ ab += buffer
+ }
+ ab
+ }
+
+ override def toString = {
+ "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
+ }
+}
diff --git a/core/src/main/scala/spark/network/MessageChunkHeader.scala b/core/src/main/scala/spark/network/MessageChunkHeader.scala
new file mode 100644
index 0000000000..3693d509d6
--- /dev/null
+++ b/core/src/main/scala/spark/network/MessageChunkHeader.scala
@@ -0,0 +1,58 @@
+package spark.network
+
+import java.net.InetAddress
+import java.net.InetSocketAddress
+import java.nio.ByteBuffer
+
+
+private[spark] class MessageChunkHeader(
+ val typ: Long,
+ val id: Int,
+ val totalSize: Int,
+ val chunkSize: Int,
+ val other: Int,
+ val address: InetSocketAddress) {
+ lazy val buffer = {
+    // No need to change this; at 'use' time, we do a reverse lookup of the hostname.
+ // Refer to network.Connection
+ val ip = address.getAddress.getAddress()
+ val port = address.getPort()
+ ByteBuffer.
+ allocate(MessageChunkHeader.HEADER_SIZE).
+ putLong(typ).
+ putInt(id).
+ putInt(totalSize).
+ putInt(chunkSize).
+ putInt(other).
+ putInt(ip.size).
+ put(ip).
+ putInt(port).
+ position(MessageChunkHeader.HEADER_SIZE).
+ flip.asInstanceOf[ByteBuffer]
+ }
+
+ override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
+ " and sizes " + totalSize + " / " + chunkSize + " bytes"
+}
+
+
+private[spark] object MessageChunkHeader {
+ val HEADER_SIZE = 40
+
+ def create(buffer: ByteBuffer): MessageChunkHeader = {
+ if (buffer.remaining != HEADER_SIZE) {
+ throw new IllegalArgumentException("Cannot convert buffer data to Message")
+ }
+ val typ = buffer.getLong()
+ val id = buffer.getInt()
+ val totalSize = buffer.getInt()
+ val chunkSize = buffer.getInt()
+ val other = buffer.getInt()
+ val ipSize = buffer.getInt()
+ val ipBytes = new Array[Byte](ipSize)
+ buffer.get(ipBytes)
+ val ip = InetAddress.getByAddress(ipBytes)
+ val port = buffer.getInt()
+ new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
+ }
+}
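
A minimal round-trip sketch for the extracted header class (the sizes here are illustrative):

    import java.net.InetSocketAddress

    val addr = new InetSocketAddress("localhost", 9999)
    val header = new MessageChunkHeader(Message.BUFFER_MESSAGE, 1, 1024, 512, 0, addr)
    // header.buffer is already flipped to HEADER_SIZE bytes; duplicate it so the
    // original stays reusable
    val parsed = MessageChunkHeader.create(header.buffer.duplicate())
    assert(parsed.totalSize == 1024 && parsed.chunkSize == 512)
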
diff --git a/core/src/main/scala/spark/network/netty/FileHeader.scala b/core/src/main/scala/spark/network/netty/FileHeader.scala
new file mode 100644
index 0000000000..aed4254234
--- /dev/null
+++ b/core/src/main/scala/spark/network/netty/FileHeader.scala
@@ -0,0 +1,57 @@
+package spark.network.netty
+
+import io.netty.buffer._
+
+import spark.Logging
+
+private[spark] class FileHeader (
+ val fileLen: Int,
+ val blockId: String) extends Logging {
+
+ lazy val buffer = {
+ val buf = Unpooled.buffer()
+ buf.capacity(FileHeader.HEADER_SIZE)
+ buf.writeInt(fileLen)
+ buf.writeInt(blockId.length)
+ blockId.foreach((x: Char) => buf.writeByte(x))
+    // Pad the remainder of the header
+    if (FileHeader.HEADER_SIZE - buf.readableBytes > 0) {
+      buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes)
+    } else {
+      throw new Exception("Header too long: " + buf.readableBytes + " bytes")
+    }
+ buf
+ }
+
+}
+
+private[spark] object FileHeader {
+
+ val HEADER_SIZE = 40
+
+ def getFileLenOffset = 0
+ def getFileLenSize = Integer.SIZE/8
+
+ def create(buf: ByteBuf): FileHeader = {
+ val length = buf.readInt
+ val idLength = buf.readInt
+ val idBuilder = new StringBuilder(idLength)
+ for (i <- 1 to idLength) {
+ idBuilder += buf.readByte().asInstanceOf[Char]
+ }
+ val blockId = idBuilder.toString()
+ new FileHeader(length, blockId)
+ }
+
+
+  def main(args: Array[String]) {
+    val header = new FileHeader(25, "block_0")
+    val buf = header.buffer
+    val newHeader = FileHeader.create(buf)
+    println("id = " + newHeader.blockId + ", size = " + newHeader.fileLen)
+  }
+}
+
diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala
new file mode 100644
index 0000000000..8d5194a737
--- /dev/null
+++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala
@@ -0,0 +1,101 @@
+package spark.network.netty
+
+import java.util.concurrent.Executors
+
+import io.netty.buffer.ByteBuf
+import io.netty.channel.ChannelHandlerContext
+import io.netty.util.CharsetUtil
+
+import spark.Logging
+import spark.network.ConnectionManagerId
+
+import scala.collection.JavaConverters._
+
+
+private[spark] class ShuffleCopier extends Logging {
+
+ def getBlock(host: String, port: Int, blockId: String,
+ resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+
+ val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
+ val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt
+ val fc = new FileClient(handler, connectTimeout)
+
+ try {
+ fc.init()
+ fc.connect(host, port)
+ fc.sendRequest(blockId)
+ fc.waitForClose()
+ fc.close()
+ } catch {
+ // Handle any socket-related exceptions in FileClient
+ case e: Exception => {
+ logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e)
+ handler.handleError(blockId)
+ }
+ }
+ }
+
+ def getBlock(cmId: ConnectionManagerId, blockId: String,
+ resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+ getBlock(cmId.host, cmId.port, blockId, resultCollectCallback)
+ }
+
+ def getBlocks(cmId: ConnectionManagerId,
+ blocks: Seq[(String, Long)],
+ resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+
+ for ((blockId, size) <- blocks) {
+ getBlock(cmId, blockId, resultCollectCallback)
+ }
+ }
+}
+
+
+private[spark] object ShuffleCopier extends Logging {
+
+ private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit)
+ extends FileClientHandler with Logging {
+
+ override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
+ logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)");
+ resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
+ }
+
+ override def handleError(blockId: String) {
+ if (!isComplete) {
+ resultCollectCallBack(blockId, -1, null)
+ }
+ }
+ }
+
+ def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
+ if (size != -1) {
+ logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
+ }
+ }
+
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println("Usage: ShuffleCopier <host> <port> <shuffle_block_id> <threads>")
+ System.exit(1)
+ }
+ val host = args(0)
+ val port = args(1).toInt
+ val file = args(2)
+ val threads = if (args.length > 3) args(3).toInt else 10
+
+    val copiers = Executors.newFixedThreadPool(threads)
+ val tasks = (for (i <- Range(0, threads)) yield {
+ Executors.callable(new Runnable() {
+ def run() {
+ val copier = new ShuffleCopier()
+ copier.getBlock(host, port, file, echoResultCollectCallBack)
+ }
+ })
+ }).asJava
+ copiers.invokeAll(tasks)
+ copiers.shutdown
+ System.exit(0)
+ }
+}
diff --git a/core/src/main/scala/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/spark/network/netty/ShuffleSender.scala
new file mode 100644
index 0000000000..d6fa4b1e80
--- /dev/null
+++ b/core/src/main/scala/spark/network/netty/ShuffleSender.scala
@@ -0,0 +1,53 @@
+package spark.network.netty
+
+import java.io.File
+
+import spark.Logging
+
+
+private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {
+
+ val server = new FileServer(pResolver, portIn)
+ server.start()
+
+ def stop() {
+ server.stop()
+ }
+
+ def port: Int = server.getPort()
+}
+
+
+/**
+ * An application for testing the shuffle sender as a standalone program.
+ */
+private[spark] object ShuffleSender {
+
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println(
+ "Usage: ShuffleSender <port> <subDirsPerLocalDir> <list of shuffle_block_directories>")
+ System.exit(1)
+ }
+
+ val port = args(0).toInt
+ val subDirsPerLocalDir = args(1).toInt
+ val localDirs = args.drop(2).map(new File(_))
+
+    val pResolver = new PathResolver {
+ override def getAbsolutePath(blockId: String): String = {
+ if (!blockId.startsWith("shuffle_")) {
+ throw new Exception("Block " + blockId + " is not a shuffle block")
+ }
+ // Figure out which local directory it hashes to, and which subdirectory in that
+ val hash = math.abs(blockId.hashCode)
+ val dirId = hash % localDirs.length
+ val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
+ val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
+ val file = new File(subDir, blockId)
+ return file.getAbsolutePath
+ }
+ }
+    val sender = new ShuffleSender(port, pResolver)
+ }
+}
diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala
index f44d37a91f..3e60860b3e 100644
--- a/core/src/main/scala/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/spark/rdd/BlockRDD.scala
@@ -1,8 +1,9 @@
package spark.rdd
-import scala.collection.mutable.HashMap
import scala.reflect.ClassTag
+
import spark.{RDD, SparkContext, SparkEnv, Partition, TaskContext}
+import spark.storage.BlockManager
private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition {
val index = idx
@@ -12,12 +13,7 @@ private[spark]
class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[String])
extends RDD[T](sc, Nil) {
- @transient lazy val locations_ = {
- val blockManager = SparkEnv.get.blockManager
- /*val locations = blockIds.map(id => blockManager.getLocations(id))*/
- val locations = blockManager.getLocations(blockIds)
- HashMap(blockIds.zip(locations):_*)
- }
+ @transient lazy val locations_ = BlockManager.blockIdsToExecutorLocations(blockIds, SparkEnv.get)
override def getPartitions: Array[Partition] = (0 until blockIds.size).map(i => {
new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition]
diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
index 700a4160c8..efd29fa561 100644
--- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
@@ -9,6 +9,7 @@ import org.apache.hadoop.util.ReflectionUtils
import org.apache.hadoop.fs.Path
import java.io.{File, IOException, EOFException}
import java.text.NumberFormat
+import spark.deploy.SparkHadoopUtil
private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {}
@@ -22,13 +23,20 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
@transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
override def getPartitions: Array[Partition] = {
- val dirContents = fs.listStatus(new Path(checkpointPath))
- val partitionFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted
- val numPartitions = partitionFiles.size
- if (numPartitions > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
- ! partitionFiles(numPartitions-1).endsWith(CheckpointRDD.splitIdToFile(numPartitions-1)))) {
- throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
- }
+ val cpath = new Path(checkpointPath)
+ val numPartitions =
+      // listStatus can throw an exception if the path does not exist.
+ if (fs.exists(cpath)) {
+ val dirContents = fs.listStatus(cpath)
+ val partitionFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted
+ val numPart = partitionFiles.size
+ if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
+ ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) {
+ throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
+ }
+ numPart
+ } else 0
+
Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i))
}
@@ -36,7 +44,7 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
checkpointData.get.cpFile = Some(checkpointPath)
override def getPreferredLocations(split: Partition): Seq[String] = {
- val status = fs.getFileStatus(new Path(checkpointPath))
+ val status = fs.getFileStatus(new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index)))
val locations = fs.getFileBlockLocations(status, 0, status.getLen)
locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
}
@@ -59,7 +67,7 @@ private[spark] object CheckpointRDD extends Logging {
def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
val outputDir = new Path(path)
- val fs = outputDir.getFileSystem(new Configuration())
+ val fs = outputDir.getFileSystem(SparkHadoopUtil.newConfiguration())
val finalOutputName = splitIdToFile(ctx.splitId)
val finalOutputPath = new Path(outputDir, finalOutputName)
@@ -84,6 +92,7 @@ private[spark] object CheckpointRDD extends Logging {
if (!fs.rename(tempOutputPath, finalOutputPath)) {
if (!fs.exists(finalOutputPath)) {
+ logInfo("Deleting tempOutputPath " + tempOutputPath)
fs.delete(tempOutputPath, false)
throw new IOException("Checkpoint failed: failed to save output of task: "
+ ctx.attemptId + " and final output path does not exist")
@@ -96,7 +105,7 @@ private[spark] object CheckpointRDD extends Logging {
}
def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = {
- val fs = path.getFileSystem(new Configuration())
+ val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration())
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
val fileInputStream = fs.open(path, bufferSize)
val serializer = SparkEnv.get.serializer.newInstance()
@@ -118,7 +127,7 @@ private[spark] object CheckpointRDD extends Logging {
val sc = new SparkContext(cluster, "CheckpointRDD Test")
val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
val path = new Path(hdfsPath, "temp")
- val fs = path.getFileSystem(new Configuration())
+ val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration())
sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _)
val cpRDD = new CheckpointRDD[Int](sc, path.toString)
assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same")
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index a6235491ca..8966f9f86e 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -6,7 +6,7 @@ import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions
import scala.collection.mutable.ArrayBuffer
-import spark.{Aggregator, Logging, Partition, Partitioner, RDD, SparkEnv, TaskContext}
+import spark.{Aggregator, Partition, Partitioner, RDD, SparkEnv, TaskContext}
import spark.{Dependency, OneToOneDependency, ShuffleDependency}
@@ -49,12 +49,17 @@ private[spark] class CoGroupAggregator
*
* @param rdds parent RDDs.
* @param part partitioner used to partition the shuffle output.
- * @param mapSideCombine flag indicating whether to merge values before shuffle step.
+ * @param mapSideCombine flag indicating whether to merge values before shuffle step. If the flag
+ * is on, Spark does an extra pass over the data on the map side to merge
+ * all values belonging to the same key together. This can reduce the amount
+ * of data shuffled if and only if the number of distinct keys is very small,
+ * and the ratio of key size to value size is also very small.
*/
class CoGroupedRDD[K](
@transient var rdds: Seq[RDD[(K, _)]],
part: Partitioner,
- val mapSideCombine: Boolean = true)
+ val mapSideCombine: Boolean = false,
+ val serializerClass: String = null)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {
private val aggr = new CoGroupAggregator
@@ -68,9 +73,9 @@ class CoGroupedRDD[K](
logInfo("Adding shuffle dependency with " + rdd)
if (mapSideCombine) {
val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
- new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part)
+ new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part, serializerClass)
} else {
- new ShuffleDependency[Any, Any](rdd.asInstanceOf[RDD[(Any, Any)]], part)
+ new ShuffleDependency[Any, Any](rdd.asInstanceOf[RDD[(Any, Any)]], part, serializerClass)
}
}
}
@@ -112,6 +117,7 @@ class CoGroupedRDD[K](
}
}
+ val ser = SparkEnv.get.serializerManager.get(serializerClass)
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
// Read them from the parent
@@ -124,12 +130,12 @@ class CoGroupedRDD[K](
val fetcher = SparkEnv.get.shuffleFetcher
if (mapSideCombine) {
// With map side combine on, for each key, the shuffle fetcher returns a list of values.
- fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics).foreach {
+ fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics, ser).foreach {
case (key, values) => getSeq(key)(depNum) ++= values
}
} else {
// With map side combine off, for each key the shuffle fetcher returns a single value.
- fetcher.fetch[K, Any](shuffleId, split.index, context.taskMetrics).foreach {
+ fetcher.fetch[K, Any](shuffleId, split.index, context.taskMetrics, ser).foreach {
case (key, value) => getSeq(key)(depNum) += value
}
}
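
Since mapSideCombine now defaults to false, a caller who wants the old behaviour must opt in. A minimal sketch (assuming sc: SparkContext; the casts mirror how the codebase itself passes pair RDDs):

    val a = sc.parallelize(Seq((1, "x"), (2, "y"))).asInstanceOf[spark.RDD[(Int, _)]]
    val b = sc.parallelize(Seq((1, 10), (3, 30))).asInstanceOf[spark.RDD[(Int, _)]]
    // Enable map-side combine only when distinct keys are few and values dwarf keys
    val grouped = new CoGroupedRDD[Int](Seq(a, b), new spark.HashPartitioner(2),
      mapSideCombine = true)
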
diff --git a/core/src/main/scala/spark/rdd/EmptyRDD.scala b/core/src/main/scala/spark/rdd/EmptyRDD.scala
new file mode 100644
index 0000000000..e4dd3a7fa7
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/EmptyRDD.scala
@@ -0,0 +1,16 @@
+package spark.rdd
+
+import spark.{RDD, SparkContext, SparkEnv, Partition, TaskContext}
+
+
+/**
+ * An RDD that is empty, i.e. has no element in it.
+ */
+class EmptyRDD[T: ClassManifest](sc: SparkContext) extends RDD[T](sc, Nil) {
+
+ override def getPartitions: Array[Partition] = Array.empty
+
+ override def compute(split: Partition, context: TaskContext): Iterator[T] = {
+ throw new UnsupportedOperationException("empty RDD")
+ }
+}
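
A minimal usage sketch (assuming sc: SparkContext):

    val empty = new EmptyRDD[Int](sc)
    empty.count()    // 0 - with no partitions, compute() is never invoked
    empty.collect()  // Array()
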
diff --git a/core/src/main/scala/spark/rdd/JdbcRDD.scala b/core/src/main/scala/spark/rdd/JdbcRDD.scala
new file mode 100644
index 0000000000..a50f407737
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/JdbcRDD.scala
@@ -0,0 +1,103 @@
+package spark.rdd
+
+import java.sql.{Connection, ResultSet}
+
+import spark.{Logging, Partition, RDD, SparkContext, TaskContext}
+import spark.util.NextIterator
+
+private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
+ override def index = idx
+}
+
+/**
+ * An RDD that executes an SQL query on a JDBC connection and reads results.
+ * For usage example, see test case JdbcRDDSuite.
+ *
+ * @param getConnection a function that returns an open Connection.
+ * The RDD takes care of closing the connection.
+ * @param sql the text of the query.
+ * The query must contain two ? placeholders for parameters used to partition the results.
+ * E.g. "select title, author from books where ? <= id and id <= ?"
+ * @param lowerBound the minimum value of the first placeholder
+ * @param upperBound the maximum value of the second placeholder
+ * The lower and upper bounds are inclusive.
+ * @param numPartitions the number of partitions.
+ * Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
+ * the query would be executed twice, once with (1, 10) and once with (11, 20)
+ * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
+ * This should only call getInt, getString, etc; the RDD takes care of calling next.
+ * The default maps a ResultSet to an array of Object.
+ */
+class JdbcRDD[T: ClassManifest](
+ sc: SparkContext,
+ getConnection: () => Connection,
+ sql: String,
+ lowerBound: Long,
+ upperBound: Long,
+ numPartitions: Int,
+ mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _)
+ extends RDD[T](sc, Nil) with Logging {
+
+ override def getPartitions: Array[Partition] = {
+ // bounds are inclusive, hence the + 1 here and - 1 on end
+ val length = 1 + upperBound - lowerBound
+ (0 until numPartitions).map(i => {
+ val start = lowerBound + ((i * length) / numPartitions).toLong
+ val end = lowerBound + (((i + 1) * length) / numPartitions).toLong - 1
+ new JdbcPartition(i, start, end)
+ }).toArray
+ }
+
+ override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] {
+ context.addOnCompleteCallback{ () => closeIfNeeded() }
+ val part = thePart.asInstanceOf[JdbcPartition]
+ val conn = getConnection()
+ val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
+
+ // setFetchSize(Integer.MIN_VALUE) is a mysql driver specific way to force streaming results,
+ // rather than pulling entire resultset into memory.
+ // see http://dev.mysql.com/doc/refman/5.0/en/connector-j-reference-implementation-notes.html
+ if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) {
+ stmt.setFetchSize(Integer.MIN_VALUE)
+ logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ")
+ }
+
+ stmt.setLong(1, part.lower)
+ stmt.setLong(2, part.upper)
+ val rs = stmt.executeQuery()
+
+ override def getNext: T = {
+ if (rs.next()) {
+ mapRow(rs)
+ } else {
+ finished = true
+ null.asInstanceOf[T]
+ }
+ }
+
+ override def close() {
+ try {
+ if (null != rs && ! rs.isClosed()) rs.close()
+ } catch {
+ case e: Exception => logWarning("Exception closing resultset", e)
+ }
+ try {
+ if (null != stmt && ! stmt.isClosed()) stmt.close()
+ } catch {
+ case e: Exception => logWarning("Exception closing statement", e)
+ }
+ try {
+        if (null != conn && ! conn.isClosed()) conn.close()
+ logInfo("closed connection")
+ } catch {
+ case e: Exception => logWarning("Exception closing connection", e)
+ }
+ }
+ }
+}
+
+object JdbcRDD {
+ def resultSetToObjectArray(rs: ResultSet) = {
+ Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
+ }
+}
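
A minimal usage sketch (the JDBC URL, credentials and `books` table are hypothetical; a matching driver must be on the classpath):

    import java.sql.DriverManager

    val rdd = new JdbcRDD(
      sc,
      () => DriverManager.getConnection("jdbc:mysql://localhost/test", "user", "pass"),
      "SELECT title, author FROM books WHERE ? <= id AND id <= ?",
      lowerBound = 1, upperBound = 20, numPartitions = 2,
      mapRow = rs => (rs.getString(1), rs.getString(2)))
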
diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
index bdd974590a..901d01ef30 100644
--- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
@@ -57,7 +57,7 @@ class NewHadoopRDD[K, V](
override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[NewHadoopPartition]
val conf = confBroadcast.value.value
- val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0)
+ val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
val format = inputFormatClass.newInstance
if (format.isInstanceOf[Configurable]) {
diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala
index 34d32eb85a..349e6162c4 100644
--- a/core/src/main/scala/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/spark/rdd/PipedRDD.scala
@@ -10,6 +10,7 @@ import scala.io.Source
import scala.reflect.ClassTag
import spark.{RDD, SparkEnv, Partition, TaskContext}
+import spark.broadcast.Broadcast
/**
@@ -19,14 +20,21 @@ import spark.{RDD, SparkEnv, Partition, TaskContext}
class PipedRDD[T: ClassTag](
prev: RDD[T],
command: Seq[String],
- envVars: Map[String, String])
+ envVars: Map[String, String],
+ printPipeContext: (String => Unit) => Unit,
+ printRDDElement: (T, String => Unit) => Unit)
extends RDD[String](prev) {
- def this(prev: RDD[T], command: Seq[String]) = this(prev, command, Map())
-
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
- def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command))
+ def this(
+ prev: RDD[T],
+ command: String,
+ envVars: Map[String, String] = Map(),
+ printPipeContext: (String => Unit) => Unit = null,
+ printRDDElement: (T, String => Unit) => Unit = null) =
+ this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement)
+
override def getPartitions: Array[Partition] = firstParent[T].partitions
@@ -53,8 +61,17 @@ class PipedRDD[T: ClassTag](
override def run() {
SparkEnv.set(env)
val out = new PrintWriter(proc.getOutputStream)
+
+      // print the pipe context first
+ if (printPipeContext != null) {
+ printPipeContext(out.println(_))
+ }
for (elem <- firstParent[T].iterator(split, context)) {
- out.println(elem)
+ if (printRDDElement != null) {
+ printRDDElement(elem, out.println(_))
+ } else {
+ out.println(elem)
+ }
}
out.close()
}
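
A sketch of the new hooks (assuming sc: SparkContext; the command and env var are illustrative):

    val data = sc.parallelize(Seq(("k1", Seq(1, 2)), ("k2", Seq(3))))
    val piped = new PipedRDD(
      data,
      "cat",                                       // tokenized on spaces, as before
      Map("LC_ALL" -> "C"),
      (print: String => Unit) => print("HEADER"),  // printPipeContext: written once per partition
      (elem: (String, Seq[Int]), print: String => Unit) =>
        elem._2.foreach(x => print(elem._1 + "\t" + x)))  // one output line per value
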
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index 4e33b7dd5c..c7d1926b83 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -3,6 +3,7 @@ package spark.rdd
import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext}
import spark.SparkContext._
+
private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
override val index = idx
override def hashCode(): Int = idx
@@ -12,13 +13,15 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
* The resulting RDD from a shuffle (e.g. repartitioning of data).
* @param prev the parent RDD.
* @param part the partitioner used to partition the RDD
+ * @param serializerClass class name of the serializer to use.
* @tparam K the key class.
* @tparam V the value class.
*/
class ShuffledRDD[K, V](
@transient prev: RDD[(K, V)],
- part: Partitioner)
- extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part))) {
+ part: Partitioner,
+ serializerClass: String = null)
+ extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part, serializerClass))) {
override val partitioner = Some(part)
@@ -28,6 +31,7 @@ class ShuffledRDD[K, V](
override def compute(split: Partition, context: TaskContext): Iterator[(K, V)] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
- SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics)
+ SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics,
+ SparkEnv.get.serializerManager.get(serializerClass))
}
}
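
A sketch of threading a serializer class name through the shuffle (assuming sc: SparkContext and that spark.KryoSerializer is available, as in this codebase; null falls back to the default serializer):

    val pairs = sc.parallelize(1 to 100).map(i => (i % 10, i))
    val shuffled = new ShuffledRDD[Int, Int](pairs, new spark.HashPartitioner(4),
      classOf[spark.KryoSerializer].getName)
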
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
index 5e56900b18..9274839bca 100644
--- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -15,6 +15,7 @@ import spark.SparkEnv
import spark.ShuffleDependency
import spark.OneToOneDependency
+
/**
* An optimized version of cogroup for set difference/subtraction.
*
@@ -34,7 +35,9 @@ import spark.OneToOneDependency
private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
@transient var rdd1: RDD[(K, V)],
@transient var rdd2: RDD[(K, W)],
- part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) {
+ part: Partitioner,
+ val serializerClass: String = null)
+ extends RDD[(K, V)](rdd1.context, Nil) {
override def getDependencies: Seq[Dependency[_]] = {
Seq(rdd1, rdd2).map { rdd =>
@@ -43,7 +46,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
new OneToOneDependency(rdd)
} else {
logInfo("Adding shuffle dependency with " + rdd)
- new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part)
+ new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part, serializerClass)
}
}
}
@@ -68,6 +71,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
val partition = p.asInstanceOf[CoGroupPartition]
+ val serializer = SparkEnv.get.serializerManager.get(serializerClass)
val map = new JHashMap[K, ArrayBuffer[V]]
def getSeq(k: K): ArrayBuffer[V] = {
val seq = map.get(k)
@@ -80,12 +84,16 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
}
}
def integrate(dep: CoGroupSplitDep, op: ((K, V)) => Unit) = dep match {
- case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
+ case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
for (t <- rdd.iterator(itsSplit, context))
op(t.asInstanceOf[(K, V)])
- case ShuffleCoGroupSplitDep(shuffleId) =>
- for (t <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics))
+ }
+ case ShuffleCoGroupSplitDep(shuffleId) => {
+ val iter = SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index,
+ context.taskMetrics, serializer)
+ for (t <- iter)
op(t.asInstanceOf[(K, V)])
+ }
}
// the first dep is rdd1; add all values to the map
integrate(partition.deps(0), t => getSeq(t._1) += t._2)
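
A minimal sketch of SubtractedRDD with the new serializer hook left at its default (assuming sc: SparkContext):

    val left  = sc.parallelize(Seq((1, "a"), (2, "b"), (3, "c")))
    val right = sc.parallelize(Seq((2, ())))
    val result = new SubtractedRDD(left, right, new spark.HashPartitioner(2))
    result.collect()  // pairs for keys 1 and 3; key 2 is subtracted away
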
diff --git a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala
new file mode 100644
index 0000000000..b234428ab2
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala
@@ -0,0 +1,138 @@
+package spark.rdd
+
+import spark.{Utils, OneToOneDependency, RDD, SparkContext, Partition, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
+
+private[spark] class ZippedPartitionsPartition(
+ idx: Int,
+ @transient rdds: Seq[RDD[_]])
+ extends Partition {
+
+ override val index: Int = idx
+ var partitionValues = rdds.map(rdd => rdd.partitions(idx))
+ def partitions = partitionValues
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ partitionValues = rdds.map(rdd => rdd.partitions(idx))
+ oos.defaultWriteObject()
+ }
+}
+
+abstract class ZippedPartitionsBaseRDD[V: ClassManifest](
+ sc: SparkContext,
+ var rdds: Seq[RDD[_]])
+ extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) {
+
+ override def getPartitions: Array[Partition] = {
+ val sizes = rdds.map(x => x.partitions.size)
+ if (!sizes.forall(x => x == sizes(0))) {
+ throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
+ }
+ val array = new Array[Partition](sizes(0))
+ for (i <- 0 until sizes(0)) {
+ array(i) = new ZippedPartitionsPartition(i, rdds)
+ }
+ array
+ }
+
+ override def getPreferredLocations(s: Partition): Seq[String] = {
+    // Note that as the number of RDDs and/or the number of slaves in the cluster increases, the computed
+    // preferredLocations below can become diminishingly small: so we might need to look at alternate
+    // strategies to alleviate this. If there are no (or very few) preferred locations, we will end up
+    // transferring the blocks to 'any' node in the cluster - paying the network and cache cost.
+    // Maybe pick the node which hosts the maximum amount of data?
+    // Choose the node which is hosting the 'larger' of some subset of blocks?
+    // Look at rack locality to ensure the chosen host is at least rack local to every hosting node? etc.
+    // (it would be good to defer this if possible)
+ val splits = s.asInstanceOf[ZippedPartitionsPartition].partitions
+ val rddSplitZip = rdds.zip(splits)
+
+ // exact match.
+ val exactMatchPreferredLocations = rddSplitZip.map(x => x._1.preferredLocations(x._2))
+ val exactMatchLocations = exactMatchPreferredLocations.reduce((x, y) => x.intersect(y))
+
+ // Remove exact match and then do host local match.
+ val exactMatchHosts = exactMatchLocations.map(Utils.parseHostPort(_)._1)
+ val matchPreferredHosts = exactMatchPreferredLocations.map(locs => locs.map(Utils.parseHostPort(_)._1))
+ .reduce((x, y) => x.intersect(y))
+ val otherNodeLocalLocations = matchPreferredHosts.filter { s => !exactMatchHosts.contains(s) }
+
+ otherNodeLocalLocations ++ exactMatchLocations
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdds = null
+ }
+}
+
+class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest](
+ sc: SparkContext,
+ f: (Iterator[A], Iterator[B]) => Iterator[V],
+ var rdd1: RDD[A],
+ var rdd2: RDD[B])
+ extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) {
+
+ override def compute(s: Partition, context: TaskContext): Iterator[V] = {
+ val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
+ f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context))
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdd1 = null
+ rdd2 = null
+ }
+}
+
+class ZippedPartitionsRDD3
+ [A: ClassManifest, B: ClassManifest, C: ClassManifest, V: ClassManifest](
+ sc: SparkContext,
+ f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V],
+ var rdd1: RDD[A],
+ var rdd2: RDD[B],
+ var rdd3: RDD[C])
+ extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) {
+
+ override def compute(s: Partition, context: TaskContext): Iterator[V] = {
+ val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
+ f(rdd1.iterator(partitions(0), context),
+ rdd2.iterator(partitions(1), context),
+ rdd3.iterator(partitions(2), context))
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdd1 = null
+ rdd2 = null
+ rdd3 = null
+ }
+}
+
+class ZippedPartitionsRDD4
+ [A: ClassManifest, B: ClassManifest, C: ClassManifest, D:ClassManifest, V: ClassManifest](
+ sc: SparkContext,
+ f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
+ var rdd1: RDD[A],
+ var rdd2: RDD[B],
+ var rdd3: RDD[C],
+ var rdd4: RDD[D])
+ extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) {
+
+ override def compute(s: Partition, context: TaskContext): Iterator[V] = {
+ val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
+ f(rdd1.iterator(partitions(0), context),
+ rdd2.iterator(partitions(1), context),
+ rdd3.iterator(partitions(2), context),
+ rdd4.iterator(partitions(3), context))
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdd1 = null
+ rdd2 = null
+ rdd3 = null
+ rdd4 = null
+ }
+}
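
A minimal sketch of the two-RDD variant (assuming sc: SparkContext and equal partition counts):

    val nums = sc.parallelize(1 to 4, 2)
    val strs = sc.parallelize(Seq("a", "b", "c", "d"), 2)
    val zipped = new ZippedPartitionsRDD2[Int, String, String](
      sc,
      (it1: Iterator[Int], it2: Iterator[String]) => it1.zip(it2).map { case (n, s) => s + n },
      nums, strs)
    zipped.collect()  // Array("a1", "b2", "c3", "d4")
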
diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala
index 1b438cd505..be05fb71f9 100644
--- a/core/src/main/scala/spark/rdd/ZippedRDD.scala
+++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala
@@ -1,5 +1,7 @@
package spark.rdd
+import spark.{Utils, OneToOneDependency, RDD, SparkContext, Partition, TaskContext}
+
import java.io.{ObjectOutputStream, IOException}
import scala.reflect.ClassTag
@@ -50,8 +52,27 @@ class ZippedRDD[T: ClassTag, U: ClassTag](
}
override def getPreferredLocations(s: Partition): Seq[String] = {
+    // Note that as the number of slaves in the cluster increases, the computed preferredLocations can become
+    // small: so we might need to look at alternate strategies to alleviate this. (If there are no, or very
+    // few, preferred locations, we will end up transferring the blocks to 'any' node in the cluster - paying
+    // the network and cache cost.)
+    // Maybe pick one or the other, so that at least one block is local?
+    // Choose the node which is hosting the 'larger' of the blocks?
+    // Look at rack locality to ensure the chosen host is at least rack local to both hosting nodes? etc.
+    // (it would be good to defer this if possible)
val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions
- rdd1.preferredLocations(partition1).intersect(rdd2.preferredLocations(partition2))
+ val pref1 = rdd1.preferredLocations(partition1)
+ val pref2 = rdd2.preferredLocations(partition2)
+
+ // exact match - instance local and host local.
+ val exactMatchLocations = pref1.intersect(pref2)
+
+ // remove locations which are already handled via exactMatchLocations, and intersect where both partitions are node local.
+ val otherNodeLocalPref1 = pref1.filter(loc => ! exactMatchLocations.contains(loc)).map(loc => Utils.parseHostPort(loc)._1)
+ val otherNodeLocalPref2 = pref2.filter(loc => ! exactMatchLocations.contains(loc)).map(loc => Utils.parseHostPort(loc)._1)
+ val otherNodeLocalLocations = otherNodeLocalPref1.intersect(otherNodeLocalPref2)
+
+
+ // Can have mix of instance local (hostPort) and node local (host) locations as preference !
+ exactMatchLocations ++ otherNodeLocalLocations
}
override def clearDependencies() {
diff --git a/core/src/main/scala/spark/scheduler/ActiveJob.scala b/core/src/main/scala/spark/scheduler/ActiveJob.scala
index 5a4e9a582d..105eaecb22 100644
--- a/core/src/main/scala/spark/scheduler/ActiveJob.scala
+++ b/core/src/main/scala/spark/scheduler/ActiveJob.scala
@@ -2,6 +2,8 @@ package spark.scheduler
import spark.TaskContext
+import java.util.Properties
+
/**
* Tracks information about an active job in the DAGScheduler.
*/
@@ -11,7 +13,8 @@ private[spark] class ActiveJob(
val func: (TaskContext, Iterator[_]) => _,
val partitions: Array[Int],
val callSite: String,
- val listener: JobListener) {
+ val listener: JobListener,
+ val properties: Properties) {
val numPartitions = partitions.length
val finished = Array.fill[Boolean](numPartitions)(false)
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index b838cf84a8..1164c40c43 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -4,6 +4,7 @@ import cluster.TaskInfo
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit
+import java.util.Properties
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import scala.reflect.ClassTag
@@ -13,7 +14,7 @@ import spark.executor.TaskMetrics
import spark.partial.ApproximateActionListener
import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
-import spark.storage.BlockManagerMaster
+import spark.storage.{BlockManager, BlockManagerMaster}
import spark.util.{MetadataCleaner, TimeStampedHashMap}
/**
@@ -51,6 +52,11 @@ class DAGScheduler(
eventQueue.put(ExecutorLost(execId))
}
+ // Called by TaskScheduler when a host is added
+ override def executorGained(execId: String, hostPort: String) {
+ eventQueue.put(ExecutorGained(execId, hostPort))
+ }
+
// Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
override def taskSetFailed(taskSet: TaskSet, reason: String) {
eventQueue.put(TaskSetFailed(taskSet, reason))
@@ -89,6 +95,8 @@ class DAGScheduler(
// stray messages to detect.
val failedGeneration = new HashMap[String, Long]
+ val idToActiveJob = new HashMap[Int, ActiveJob]
+
val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
val running = new HashSet[Stage] // Stages we are running right now
val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures
@@ -113,9 +121,8 @@ class DAGScheduler(
private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
if (!cacheLocs.contains(rdd.id)) {
val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
- cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map {
- locations => locations.map(_.ip).toList
- }.toArray
+ val locs = BlockManager.blockIdsToExecutorLocations(blockIds, env, blockManagerMaster)
+ cacheLocs(rdd.id) = blockIds.map(locs.getOrElse(_, Nil))
}
cacheLocs(rdd.id)
}
@@ -222,13 +229,14 @@ class DAGScheduler(
partitions: Seq[Int],
callSite: String,
allowLocal: Boolean,
- resultHandler: (Int, U) => Unit)
+ resultHandler: (Int, U) => Unit,
+ properties: Properties = null)
: (JobSubmitted, JobWaiter[U]) =
{
assert(partitions.size > 0)
val waiter = new JobWaiter(partitions.size, resultHandler)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
- val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter)
+ val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
return (toSubmit, waiter)
}
@@ -238,13 +246,14 @@ class DAGScheduler(
partitions: Seq[Int],
callSite: String,
allowLocal: Boolean,
- resultHandler: (Int, U) => Unit)
+ resultHandler: (Int, U) => Unit,
+ properties: Properties = null)
{
if (partitions.size == 0) {
return
}
val (toSubmit, waiter) = prepareJob(
- finalRdd, func, partitions, callSite, allowLocal, resultHandler)
+ finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties)
eventQueue.put(toSubmit)
waiter.awaitResult() match {
case JobSucceeded => {}
@@ -259,13 +268,14 @@ class DAGScheduler(
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
callSite: String,
- timeout: Long)
+ timeout: Long,
+ properties: Properties = null)
: PartialResult[R] =
{
val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.partitions.size).toArray
- eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener))
+ eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener, properties))
return listener.awaitResult() // Will throw an exception if the job fails
}
@@ -275,10 +285,10 @@ class DAGScheduler(
*/
private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
event match {
- case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
+ case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) =>
val runId = nextRunId.getAndIncrement()
val finalStage = newStage(finalRDD, None, runId)
- val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
+ val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener, properties)
clearCacheLocs()
logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
" output partitions (allowLocal=" + allowLocal + ")")
@@ -289,15 +299,22 @@ class DAGScheduler(
// Compute very short actions like first() or take() with no parent stages locally.
runLocally(job)
} else {
+ sparkListeners.foreach(_.onJobStart(SparkListenerJobStart(job, properties)))
+ idToActiveJob(runId) = job
activeJobs += job
resultStageToJob(finalStage) = job
submitStage(finalStage)
}
+ case ExecutorGained(execId, hostPort) =>
+ handleExecutorGained(execId, hostPort)
+
case ExecutorLost(execId) =>
handleExecutorLost(execId)
case completion: CompletionEvent =>
+ sparkListeners.foreach(_.onTaskEnd(SparkListenerTaskEnd(completion.task,
+ completion.reason, completion.taskInfo, completion.taskMetrics)))
handleTaskCompletion(completion)
case TaskSetFailed(taskSet, reason) =>
@@ -308,6 +325,7 @@ class DAGScheduler(
for (job <- activeJobs) {
val error = new SparkException("Job cancelled because SparkContext was shut down")
job.listener.jobFailed(error)
+ sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error))))
}
return true
}
@@ -455,11 +473,13 @@ class DAGScheduler(
}
}
if (tasks.size > 0) {
+ sparkListeners.foreach(_.onStageSubmitted(SparkListenerStageSubmitted(stage, tasks.size)))
logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
myPending ++= tasks
logDebug("New pending tasks: " + myPending)
+ val properties = idToActiveJob(stage.priority).properties
taskSched.submitTasks(
- new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority))
+ new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority, properties))
if (!stage.submissionTime.isDefined) {
stage.submissionTime = Some(System.currentTimeMillis())
}
@@ -508,6 +528,7 @@ class DAGScheduler(
activeJobs -= job
resultStageToJob -= stage
markStageAsFinished(stage)
+ sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobSucceeded)))
}
job.listener.taskSucceeded(rt.outputId, event.result)
}
@@ -631,6 +652,14 @@ class DAGScheduler(
"(generation " + currentGeneration + ")")
}
}
+
+ private def handleExecutorGained(execId: String, hostPort: String) {
+ // remove from failedGeneration(execId) ?
+ if (failedGeneration.contains(execId)) {
+ logInfo("Host gained which was in lost list earlier: " + hostPort)
+ failedGeneration -= execId
+ }
+ }
/**
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
@@ -640,7 +669,9 @@ class DAGScheduler(
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
- job.listener.jobFailed(new SparkException("Job failed: " + reason))
+ val error = new SparkException("Job failed: " + reason)
+ job.listener.jobFailed(error)
+ sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error))))
activeJobs -= job
resultStageToJob -= resultStage
}
diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
index ed0b9bf178..acad915f13 100644
--- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
@@ -1,5 +1,7 @@
package spark.scheduler
+import java.util.Properties
+
import spark.scheduler.cluster.TaskInfo
import scala.collection.mutable.Map
@@ -20,7 +22,8 @@ private[spark] case class JobSubmitted(
partitions: Array[Int],
allowLocal: Boolean,
callSite: String,
- listener: JobListener)
+ listener: JobListener,
+ properties: Properties = null)
extends DAGSchedulerEvent
private[spark] case class CompletionEvent(
@@ -32,6 +35,10 @@ private[spark] case class CompletionEvent(
taskMetrics: TaskMetrics)
extends DAGSchedulerEvent
+private[spark] case class ExecutorGained(execId: String, hostPort: String) extends DAGSchedulerEvent {
+ Utils.checkHostPort(hostPort, "Required hostport")
+}
+
private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
diff --git a/core/src/main/scala/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/spark/scheduler/InputFormatInfo.scala
new file mode 100644
index 0000000000..287f731787
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/InputFormatInfo.scala
@@ -0,0 +1,156 @@
+package spark.scheduler
+
+import spark.Logging
+import scala.collection.immutable.Set
+import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
+import org.apache.hadoop.util.ReflectionUtils
+import org.apache.hadoop.mapreduce.Job
+import org.apache.hadoop.conf.Configuration
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.collection.JavaConversions._
+
+
+/**
+ * Parses and holds information about inputFormat (and files) specified as a parameter.
+ */
+class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Class[_],
+ val path: String) extends Logging {
+
+ var mapreduceInputFormat: Boolean = false
+ var mapredInputFormat: Boolean = false
+
+ validate()
+
+ override def toString(): String = {
+ "InputFormatInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz + ", path : " + path
+ }
+
+ override def hashCode(): Int = {
+ var hashCode = inputFormatClazz.hashCode
+ hashCode = hashCode * 31 + path.hashCode
+ hashCode
+ }
+
+  // Since we do not canonicalize the path, this can be wrong: e.g. relative vs absolute path
+  // .. which is fine; this is a best-effort attempt to remove duplicates.
+ override def equals(other: Any): Boolean = other match {
+ case that: InputFormatInfo => {
+      // not checking the config - that should be fine
+ this.inputFormatClazz == that.inputFormatClazz &&
+ this.path == that.path
+ }
+ case _ => false
+ }
+
+ private def validate() {
+ logDebug("validate InputFormatInfo : " + inputFormatClazz + ", path " + path)
+
+ try {
+ if (classOf[org.apache.hadoop.mapreduce.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) {
+ logDebug("inputformat is from mapreduce package")
+ mapreduceInputFormat = true
+ }
+ else if (classOf[org.apache.hadoop.mapred.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) {
+ logDebug("inputformat is from mapred package")
+ mapredInputFormat = true
+ }
+      else {
+        throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz +
+          " is not a supported input format: it does not implement either of the supported Hadoop APIs")
+      }
+ }
+ catch {
+      case e: ClassNotFoundException => {
+        throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + " cannot be found", e)
+      }
+ }
+ }
+
+
+ // This method does not expect failures, since validate has already passed ...
+ private def prefLocsFromMapreduceInputFormat(): Set[SplitInfo] = {
+ val conf = new JobConf(configuration)
+ FileInputFormat.setInputPaths(conf, path)
+
+ val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] =
+ ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], conf).asInstanceOf[
+ org.apache.hadoop.mapreduce.InputFormat[_, _]]
+ val job = new Job(conf)
+
+ val retval = new ArrayBuffer[SplitInfo]()
+ val list = instance.getSplits(job)
+ for (split <- list) {
+ retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, split)
+ }
+
+ return retval.toSet
+ }
+
+ // This method does not expect failures, since validate has already passed ...
+ private def prefLocsFromMapredInputFormat(): Set[SplitInfo] = {
+ val jobConf = new JobConf(configuration)
+ FileInputFormat.setInputPaths(jobConf, path)
+
+ val instance: org.apache.hadoop.mapred.InputFormat[_, _] =
+ ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], jobConf).asInstanceOf[
+ org.apache.hadoop.mapred.InputFormat[_, _]]
+
+ val retval = new ArrayBuffer[SplitInfo]()
+ instance.getSplits(jobConf, jobConf.getNumMapTasks()).foreach(
+ elem => retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, elem)
+ )
+
+ return retval.toSet
+ }
+
+ private def findPreferredLocations(): Set[SplitInfo] = {
+ logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat +
+ ", inputFormatClazz : " + inputFormatClazz)
+ if (mapreduceInputFormat) {
+ return prefLocsFromMapreduceInputFormat()
+ }
+ else {
+ assert(mapredInputFormat)
+ return prefLocsFromMapredInputFormat()
+ }
+ }
+}
+
+
+
+
+object InputFormatInfo {
+  /**
+    Computes the preferred locations based on input(s) and returns a location-to-block map.
+    A typical use of this method for allocation would follow an algorithm like this
+    (which is what we currently do in the YARN branch), as sketched after this object:
+    a) For each host, count the number of splits hosted on that host.
+    b) Decrement the currently allocated containers on that host.
+    c) Compute rack info for each host and update the rack -> count map based on (b).
+    d) Allocate nodes based on (c).
+    e) On the allocation result, ensure that we don't allocate "too many" jobs on a single node
+       (even if data locality on that node is very high): this is to prevent fragility of the job
+       if a single host (or a small set of hosts) goes down.
+
+    Repeat from (a) until the required number of nodes is allocated.
+
+    If a node 'dies', follow the same procedure.
+  */
+ def computePreferredLocations(formats: Seq[InputFormatInfo]): HashMap[String, HashSet[SplitInfo]] = {
+
+ val nodeToSplit = new HashMap[String, HashSet[SplitInfo]]
+ for (inputSplit <- formats) {
+ val splits = inputSplit.findPreferredLocations()
+
+ for (split <- splits){
+ val location = split.hostLocation
+ val set = nodeToSplit.getOrElseUpdate(location, new HashSet[SplitInfo])
+ set += split
+ }
+ }
+
+ nodeToSplit
+ }
+}
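
The (a)-(e) loop above is easiest to see in code. Below is a minimal, illustrative sketch of the host-level part of that algorithm, driven by the map that computePreferredLocations returns; the allocatedOnHost and maxTasksPerHost inputs (and all other names) are hypothetical, and the rack steps (c)-(d) are omitted:

    import scala.collection.mutable.{HashMap, HashSet}
    import spark.scheduler.{InputFormatInfo, SplitInfo}

    // Hypothetical allocation pass - not part of this commit.
    def desiredContainers(formats: Seq[InputFormatInfo],
                          allocatedOnHost: Map[String, Int],
                          maxTasksPerHost: Int): Map[String, Int] = {
      val splitsPerHost: HashMap[String, HashSet[SplitInfo]] =
        InputFormatInfo.computePreferredLocations(formats)
      splitsPerHost.map { case (host, splits) =>
        // (a) splits hosted here, (b) minus containers already allocated here
        val pending = splits.size - allocatedOnHost.getOrElse(host, 0)
        // (e) cap the per-host allocation so a single host failure hurts less
        host -> math.min(math.max(pending, 0), maxTasksPerHost)
      }.toMap
    }
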
diff --git a/core/src/main/scala/spark/scheduler/JobLogger.scala b/core/src/main/scala/spark/scheduler/JobLogger.scala
new file mode 100644
index 0000000000..178bfaba3d
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/JobLogger.scala
@@ -0,0 +1,306 @@
+package spark.scheduler
+
+import java.io.PrintWriter
+import java.io.File
+import java.io.FileNotFoundException
+import java.text.SimpleDateFormat
+import java.util.{Date, Properties}
+import java.util.concurrent.LinkedBlockingQueue
+import scala.collection.mutable.{Map, HashMap, ListBuffer}
+import scala.io.Source
+import spark._
+import spark.executor.TaskMetrics
+import spark.scheduler.cluster.TaskInfo
+
+// Records runtime information for each job, including the RDD graph,
+// tasks' start/stop and shuffle information, and information supplied from outside
+
+class JobLogger(val logDirName: String) extends SparkListener with Logging {
+ private val logDir =
+ if (System.getenv("SPARK_LOG_DIR") != null)
+ System.getenv("SPARK_LOG_DIR")
+ else
+ "/tmp/spark"
+ private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
+ private val stageIDToJobID = new HashMap[Int, Int]
+ private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
+ private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
+ private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
+
+ createLogDir()
+ def this() = this(String.valueOf(System.currentTimeMillis()))
+
+ def getLogDir = logDir
+ def getJobIDtoPrintWriter = jobIDToPrintWriter
+ def getStageIDToJobID = stageIDToJobID
+ def getJobIDToStages = jobIDToStages
+ def getEventQueue = eventQueue
+
+ new Thread("JobLogger") {
+ setDaemon(true)
+ override def run() {
+ while (true) {
+ val event = eventQueue.take
+ logDebug("Got event of type " + event.getClass.getName)
+ event match {
+ case SparkListenerJobStart(job, properties) =>
+ processJobStartEvent(job, properties)
+ case SparkListenerStageSubmitted(stage, taskSize) =>
+ processStageSubmittedEvent(stage, taskSize)
+ case StageCompleted(stageInfo) =>
+ processStageCompletedEvent(stageInfo)
+ case SparkListenerJobEnd(job, result) =>
+ processJobEndEvent(job, result)
+ case SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics) =>
+ processTaskEndEvent(task, reason, taskInfo, taskMetrics)
+ case _ =>
+ }
+ }
+ }
+ }.start()
+
+  // Create a folder for log files; the folder's name is the creation time of the JobLogger
+ protected def createLogDir() {
+ val dir = new File(logDir + "/" + logDirName + "/")
+ if (dir.exists()) {
+ return
+ }
+    if (!dir.mkdirs()) {
+      logError("Failed to create log directory: " + logDir + "/" + logDirName + "/")
+ }
+ }
+
+  // Create a log file for one job; the file name is the jobID
+ protected def createLogWriter(jobID: Int) {
+    try {
+ val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
+ jobIDToPrintWriter += (jobID -> fileWriter)
+ } catch {
+ case e: FileNotFoundException => e.printStackTrace()
+ }
+ }
+
+ // Close log file, and clean the stage relationship in stageIDToJobID
+ protected def closeLogWriter(jobID: Int) =
+ jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
+ fileWriter.close()
+ jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
+ stageIDToJobID -= stage.id
+ })
+ jobIDToPrintWriter -= jobID
+ jobIDToStages -= jobID
+ }
+
+  // Write log information to the log file; the withTime parameter controls whether to
+  // record a timestamp with the information
+ protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
+ var writeInfo = info
+ if (withTime) {
+ val date = new Date(System.currentTimeMillis())
+      writeInfo = DATE_FORMAT.format(date) + ": " + info
+ }
+ jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
+ }
+
+ protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) =
+ stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
+
+ protected def buildJobDep(jobID: Int, stage: Stage) {
+ if (stage.priority == jobID) {
+ jobIDToStages.get(jobID) match {
+ case Some(stageList) => stageList += stage
+ case None => val stageList = new ListBuffer[Stage]
+ stageList += stage
+ jobIDToStages += (jobID -> stageList)
+ }
+ stageIDToJobID += (stage.id -> jobID)
+ stage.parents.foreach(buildJobDep(jobID, _))
+ }
+ }
+
+ protected def recordStageDep(jobID: Int) {
+ def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
+ var rddList = new ListBuffer[RDD[_]]
+ rddList += rdd
+ rdd.dependencies.foreach{ dep => dep match {
+ case shufDep: ShuffleDependency[_,_] =>
+ case _ => rddList ++= getRddsInStage(dep.rdd)
+ }
+ }
+ rddList
+ }
+ jobIDToStages.get(jobID).foreach {_.foreach { stage =>
+ var depRddDesc: String = ""
+ getRddsInStage(stage.rdd).foreach { rdd =>
+ depRddDesc += rdd.id + ","
+ }
+ var depStageDesc: String = ""
+ stage.parents.foreach { stage =>
+ depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
+ }
+ jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
+ depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
+ " STAGE_DEP=" + depStageDesc, false)
+ }
+ }
+ }
+
+ // Generate indents and convert to String
+ protected def indentString(indent: Int) = {
+ val sb = new StringBuilder()
+ for (i <- 1 to indent) {
+ sb.append(" ")
+ }
+ sb.toString()
+ }
+
+ protected def getRddName(rdd: RDD[_]) = {
+ var rddName = rdd.getClass.getName
+ if (rdd.name != null) {
+ rddName = rdd.name
+ }
+ rddName
+ }
+
+ protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
+ val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
+ jobLogInfo(jobID, indentString(indent) + rddInfo, false)
+ rdd.dependencies.foreach{ dep => dep match {
+ case shufDep: ShuffleDependency[_,_] =>
+ val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
+ jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
+ case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
+ }
+ }
+ }
+
+ protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
+ var stageInfo: String = ""
+ if (stage.isShuffleMap) {
+ stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" +
+ stage.shuffleDep.get.shuffleId
+    } else {
+ stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE"
+ }
+ if (stage.priority == jobID) {
+ jobLogInfo(jobID, indentString(indent) + stageInfo, false)
+ recordRddInStageGraph(jobID, stage.rdd, indent)
+ stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
+ } else
+ jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.priority, false)
+ }
+
+ // Record task metrics into job log files
+ protected def recordTaskMetrics(stageID: Int, status: String,
+ taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
+ val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
+ " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
+ " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
+ val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
+ val readMetrics =
+ taskMetrics.shuffleReadMetrics match {
+ case Some(metrics) =>
+ " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
+ " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
+ " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
+ " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
+ " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
+ " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
+ " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
+ case None => ""
+ }
+ val writeMetrics =
+ taskMetrics.shuffleWriteMetrics match {
+ case Some(metrics) =>
+ " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
+ case None => ""
+ }
+ stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
+ }
+
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
+ eventQueue.put(stageSubmitted)
+ }
+
+ protected def processStageSubmittedEvent(stage: Stage, taskSize: Int) {
+ stageLogInfo(stage.id, "STAGE_ID=" + stage.id + " STATUS=SUBMITTED" + " TASK_SIZE=" + taskSize)
+ }
+
+ override def onStageCompleted(stageCompleted: StageCompleted) {
+ eventQueue.put(stageCompleted)
+ }
+
+ protected def processStageCompletedEvent(stageInfo: StageInfo) {
+ stageLogInfo(stageInfo.stage.id, "STAGE_ID=" +
+ stageInfo.stage.id + " STATUS=COMPLETED")
+
+ }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ eventQueue.put(taskEnd)
+ }
+
+ protected def processTaskEndEvent(task: Task[_], reason: TaskEndReason,
+ taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
+ var taskStatus = ""
+ task match {
+ case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
+ case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
+ }
+ reason match {
+ case Success => taskStatus += " STATUS=SUCCESS"
+ recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskMetrics)
+ case Resubmitted =>
+ taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
+ " STAGE_ID=" + task.stageId
+ stageLogInfo(task.stageId, taskStatus)
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
+ task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
+ mapId + " REDUCE_ID=" + reduceId
+ stageLogInfo(task.stageId, taskStatus)
+ case OtherFailure(message) =>
+ taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
+ " STAGE_ID=" + task.stageId + " INFO=" + message
+ stageLogInfo(task.stageId, taskStatus)
+ case _ =>
+ }
+ }
+
+ override def onJobEnd(jobEnd: SparkListenerJobEnd) {
+ eventQueue.put(jobEnd)
+ }
+
+  protected def processJobEndEvent(job: ActiveJob, reason: JobResult) {
+    var info = "JOB_ID=" + job.runId
+    reason match {
+      case JobSucceeded => info += " STATUS=SUCCESS"
+      case JobFailed(exception) =>
+        info += " STATUS=FAILED REASON="
+        exception.getMessage.split("\\s+").foreach(info += _ + "_")
+        // Trim the trailing "_" here; trimming unconditionally would also clip the
+        // last character of "SUCCESS" in the success case.
+        info = info.substring(0, info.length - 1)
+      case _ =>
+    }
+    jobLogInfo(job.runId, info.toUpperCase)
+    closeLogWriter(job.runId)
+  }
+
+ protected def recordJobProperties(jobID: Int, properties: Properties) {
+    if (properties != null) {
+ val annotation = properties.getProperty("spark.job.annotation", "")
+ jobLogInfo(jobID, annotation, false)
+ }
+ }
+
+ override def onJobStart(jobStart: SparkListenerJobStart) {
+ eventQueue.put(jobStart)
+ }
+
+ protected def processJobStartEvent(job: ActiveJob, properties: Properties) {
+ createLogWriter(job.runId)
+ recordJobProperties(job.runId, properties)
+ buildJobDep(job.runId, job.finalStage)
+ recordStageDep(job.runId)
+ recordStageDepGraph(job.runId, job.finalStage)
+ jobLogInfo(job.runId, "JOB_ID=" + job.runId + " STATUS=STARTED")
+ }
+}
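
The key design point in JobLogger is that the listener callbacks only enqueue; a daemon thread drains the queue and does the file I/O, so slow disks never block the scheduler's event loop. A minimal sketch of the same pattern, assuming only the listener types introduced in this patch:

    import java.util.concurrent.LinkedBlockingQueue
    import spark.scheduler.{SparkListener, SparkListenerEvents, SparkListenerJobEnd}

    // Illustrative listener: counts finished jobs off the scheduler thread.
    class AsyncJobCounter extends SparkListener {
      private val queue = new LinkedBlockingQueue[SparkListenerEvents]
      @volatile var jobsEnded = 0

      new Thread("AsyncJobCounter") {
        setDaemon(true)
        override def run() {
          while (true) {
            queue.take() match {
              case _: SparkListenerJobEnd => jobsEnded += 1
              case _ => // ignore other events
            }
          }
        }
      }.start()

      // Callback runs on the scheduler thread: enqueue and return immediately.
      override def onJobEnd(jobEnd: SparkListenerJobEnd) { queue.put(jobEnd) }
    }
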
diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala
index beb21a76fe..83166bce22 100644
--- a/core/src/main/scala/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/spark/scheduler/ResultTask.scala
@@ -70,6 +70,13 @@ private[spark] class ResultTask[T, U](
rdd.partitions(partition)
}
+ private val preferredLocs: Seq[String] = if (locs == null) Nil else locs.toSet.toSeq
+
+ {
+ // DEBUG code
+ preferredLocs.foreach (hostPort => Utils.checkHost(Utils.parseHostPort(hostPort)._1, "preferredLocs : " + preferredLocs))
+ }
+
override def run(attemptId: Long): U = {
val context = new TaskContext(stageId, partition, attemptId)
metrics = Some(context.taskMetrics)
@@ -80,7 +87,7 @@ private[spark] class ResultTask[T, U](
}
}
- override def preferredLocations: Seq[String] = locs
+ override def preferredLocations: Seq[String] = preferredLocs
override def toString = "ResultTask(" + stageId + ", " + partition + ")"
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 36d087a4d0..95647389c3 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -13,9 +13,10 @@ import com.ning.compress.lzf.LZFInputStream
import com.ning.compress.lzf.LZFOutputStream
import spark._
-import executor.ShuffleWriteMetrics
+import spark.executor.ShuffleWriteMetrics
import spark.storage._
-import util.{TimeStampedHashMap, MetadataCleaner}
+import spark.util.{TimeStampedHashMap, MetadataCleaner}
+
private[spark] object ShuffleMapTask {
@@ -77,13 +78,20 @@ private[spark] class ShuffleMapTask(
var rdd: RDD[_],
var dep: ShuffleDependency[_,_],
var partition: Int,
- @transient var locs: Seq[String])
+ @transient private var locs: Seq[String])
extends Task[MapStatus](stageId)
with Externalizable
with Logging {
protected def this() = this(0, null, null, 0, null)
+ @transient private val preferredLocs: Seq[String] = if (locs == null) Nil else locs.toSet.toSeq
+
+ {
+ // DEBUG code
+ preferredLocs.foreach (hostPort => Utils.checkHost(Utils.parseHostPort(hostPort)._1, "preferredLocs : " + preferredLocs))
+ }
+
var split = if (rdd == null) {
null
} else {
@@ -121,40 +129,58 @@ private[spark] class ShuffleMapTask(
val taskContext = new TaskContext(stageId, partition, attemptId)
metrics = Some(taskContext.taskMetrics)
+
+ val blockManager = SparkEnv.get.blockManager
+ var shuffle: ShuffleBlocks = null
+ var buckets: ShuffleWriterGroup = null
+
try {
- // Partition the map output.
- val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
+ // Obtain all the block writers for shuffle blocks.
+ val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
+ shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser)
+ buckets = shuffle.acquireWriters(partition)
+
+ // Write the map output to its associated buckets.
for (elem <- rdd.iterator(split, taskContext)) {
val pair = elem.asInstanceOf[(Any, Any)]
val bucketId = dep.partitioner.getPartition(pair._1)
- buckets(bucketId) += pair
+ buckets.writers(bucketId).write(pair)
}
- val compressedSizes = new Array[Byte](numOutputSplits)
-
- var totalBytes = 0l
-
- val blockManager = SparkEnv.get.blockManager
- for (i <- 0 until numOutputSplits) {
- val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
- // Get a Scala iterator from Java map
- val iter: Iterator[(Any, Any)] = buckets(i).iterator
- val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
+ // Commit the writes. Get the size of each bucket block (total block size).
+ var totalBytes = 0L
+ val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter =>
+ writer.commit()
+ writer.close()
+ val size = writer.size()
totalBytes += size
- compressedSizes(i) = MapOutputTracker.compressSize(size)
+ MapOutputTracker.compressSize(size)
}
+
+ // Update shuffle metrics.
val shuffleMetrics = new ShuffleWriteMetrics
shuffleMetrics.shuffleBytesWritten = totalBytes
metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
return new MapStatus(blockManager.blockManagerId, compressedSizes)
+ } catch { case e: Exception =>
+ // If there is an exception from running the task, revert the partial writes
+ // and throw the exception upstream to Spark.
+ if (buckets != null) {
+ buckets.writers.foreach(_.revertPartialWrites())
+ }
+ throw e
} finally {
+ // Release the writers back to the shuffle block manager.
+ if (shuffle != null && buckets != null) {
+ shuffle.releaseWriters(buckets)
+ }
// Execute the callbacks on task completion.
taskContext.executeOnCompleteCallbacks()
}
}
- override def preferredLocations: Seq[String] = locs
+ override def preferredLocations: Seq[String] = preferredLocs
override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
}
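
Distilled, the new write path above is an acquire/write/commit lifecycle, with revertPartialWrites() as the failure path and releaseWriters() always running. A sketch of that shape, assuming the ShuffleBlocks/ShuffleWriterGroup/BlockObjectWriter interfaces referenced in this hunk (import locations and the helper itself are illustrative, not part of this commit):

    import spark.Partitioner
    import spark.storage.ShuffleBlocks

    // Illustrative helper mirroring the structure of ShuffleMapTask.run above.
    def writePartition(shuffle: ShuffleBlocks, partition: Int,
                       records: Iterator[(Any, Any)],
                       partitioner: Partitioner): Long = {
      val buckets = shuffle.acquireWriters(partition)
      try {
        // Stream each record straight to its bucket's disk writer.
        for (pair <- records) {
          buckets.writers(partitioner.getPartition(pair._1)).write(pair)
        }
        // Commit and close every bucket; sum the bytes written.
        var totalBytes = 0L
        for (writer <- buckets.writers) {
          writer.commit()
          writer.close()
          totalBytes += writer.size()
        }
        totalBytes
      } catch { case e: Exception =>
        // Failure path: undo the partial writes, then rethrow.
        buckets.writers.foreach(_.revertPartialWrites())
        throw e
      } finally {
        shuffle.releaseWriters(buckets)
      }
    }
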
diff --git a/core/src/main/scala/spark/scheduler/SparkListener.scala b/core/src/main/scala/spark/scheduler/SparkListener.scala
index a65140b145..bac984b5c9 100644
--- a/core/src/main/scala/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/spark/scheduler/SparkListener.scala
@@ -1,27 +1,59 @@
package spark.scheduler
+import java.util.Properties
import spark.scheduler.cluster.TaskInfo
import spark.util.Distribution
-import spark.{Utils, Logging}
+import spark.{Logging, SparkContext, TaskEndReason, Utils}
import spark.executor.TaskMetrics
-trait SparkListener {
- /**
- * called when a stage is completed, with information on the completed stage
- */
- def onStageCompleted(stageCompleted: StageCompleted)
-}
-
sealed trait SparkListenerEvents
+case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int) extends SparkListenerEvents
+
case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents
+case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo,
+ taskMetrics: TaskMetrics) extends SparkListenerEvents
+
+case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null)
+ extends SparkListenerEvents
+
+case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult)
+ extends SparkListenerEvents
+
+trait SparkListener {
+ /**
+ * Called when a stage is completed, with information on the completed stage
+ */
+ def onStageCompleted(stageCompleted: StageCompleted) { }
+
+ /**
+ * Called when a stage is submitted
+ */
+ def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { }
+
+ /**
+ * Called when a task ends
+ */
+ def onTaskEnd(taskEnd: SparkListenerTaskEnd) { }
+
+ /**
+ * Called when a job starts
+ */
+ def onJobStart(jobStart: SparkListenerJobStart) { }
+
+ /**
+ * Called when a job ends
+ */
+ def onJobEnd(jobEnd: SparkListenerJobEnd) { }
+
+}
/**
* Simple SparkListener that logs a few summary statistics when each stage completes
*/
class StatsReportListener extends SparkListener with Logging {
- def onStageCompleted(stageCompleted: StageCompleted) {
+ override def onStageCompleted(stageCompleted: StageCompleted) {
import spark.scheduler.StatsReportListener._
implicit val sc = stageCompleted
this.logInfo("Finished stage: " + stageCompleted.stageInfo)
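
Because every callback now has a no-op default body, a custom listener only overrides what it cares about. For instance, a minimal failure reporter (illustrative, using the JobFailed result type from this patch):

    import spark.Logging
    import spark.scheduler.{JobFailed, SparkListener, SparkListenerJobEnd}

    class FailureReporter extends SparkListener with Logging {
      override def onJobEnd(jobEnd: SparkListenerJobEnd) {
        jobEnd.jobResult match {
          case JobFailed(exception) =>
            logError("Job " + jobEnd.job.runId + " failed: " + exception.getMessage)
          case _ => // JobSucceeded: nothing to report
        }
      }
    }
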
diff --git a/core/src/main/scala/spark/scheduler/SplitInfo.scala b/core/src/main/scala/spark/scheduler/SplitInfo.scala
new file mode 100644
index 0000000000..6abfb7a1f7
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/SplitInfo.scala
@@ -0,0 +1,61 @@
+package spark.scheduler
+
+import collection.mutable.ArrayBuffer
+
+// Information about a specific split instance. Handles both the mapred and mapreduce split types,
+// so that callers do not need to worry about the differences.
+class SplitInfo(val inputFormatClazz: Class[_], val hostLocation: String, val path: String,
+ val length: Long, val underlyingSplit: Any) {
+ override def toString(): String = {
+ "SplitInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz +
+ ", hostLocation : " + hostLocation + ", path : " + path +
+ ", length : " + length + ", underlyingSplit " + underlyingSplit
+ }
+
+ override def hashCode(): Int = {
+ var hashCode = inputFormatClazz.hashCode
+ hashCode = hashCode * 31 + hostLocation.hashCode
+ hashCode = hashCode * 31 + path.hashCode
+    // Overflow is fine here - it is only a hash code anyway!
+ hashCode = hashCode * 31 + (length & 0x7fffffff).toInt
+ hashCode
+ }
+
+  // This is practically useless, since most InputSplit implementations don't seem to implement equals :-(
+  // So unless there is identity equality between underlyingSplits, it will always fail, even if both
+  // point to the same block.
+ override def equals(other: Any): Boolean = other match {
+ case that: SplitInfo => {
+ this.hostLocation == that.hostLocation &&
+ this.inputFormatClazz == that.inputFormatClazz &&
+ this.path == that.path &&
+ this.length == that.length &&
+ // other split specific checks (like start for FileSplit)
+ this.underlyingSplit == that.underlyingSplit
+ }
+ case _ => false
+ }
+}
+
+object SplitInfo {
+
+ def toSplitInfo(inputFormatClazz: Class[_], path: String,
+ mapredSplit: org.apache.hadoop.mapred.InputSplit): Seq[SplitInfo] = {
+ val retval = new ArrayBuffer[SplitInfo]()
+ val length = mapredSplit.getLength
+ for (host <- mapredSplit.getLocations) {
+ retval += new SplitInfo(inputFormatClazz, host, path, length, mapredSplit)
+ }
+ retval
+ }
+
+ def toSplitInfo(inputFormatClazz: Class[_], path: String,
+ mapreduceSplit: org.apache.hadoop.mapreduce.InputSplit): Seq[SplitInfo] = {
+ val retval = new ArrayBuffer[SplitInfo]()
+ val length = mapreduceSplit.getLength
+ for (host <- mapreduceSplit.getLocations) {
+ retval += new SplitInfo(inputFormatClazz, host, path, length, mapreduceSplit)
+ }
+ retval
+ }
+}
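
A quick REPL-style check of the fan-out: a mapred split that reports three replica hosts yields three SplitInfo entries with the same path and length (the path and host names below are made up):

    import org.apache.hadoop.fs.Path
    import org.apache.hadoop.mapred.{FileSplit, TextInputFormat}
    import spark.scheduler.SplitInfo

    val split = new FileSplit(new Path("hdfs:///data/part-0"), 0L,
                              64L * 1024 * 1024, Array("host1", "host2", "host3"))
    val infos = SplitInfo.toSplitInfo(classOf[TextInputFormat],
                                      "hdfs:///data/part-0", split)
    assert(infos.size == 3)  // one SplitInfo per replica host
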
diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala
index 552061e46b..7fc9e13fd9 100644
--- a/core/src/main/scala/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/spark/scheduler/Stage.scala
@@ -26,7 +26,7 @@ private[spark] class Stage(
val parents: List[Stage],
val priority: Int)
extends Logging {
-
+
val isShuffleMap = shuffleDep != None
val numPartitions = rdd.partitions.size
val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
@@ -60,7 +60,7 @@ private[spark] class Stage(
numAvailableOutputs -= 1
}
}
-
+
def removeOutputsOnExecutor(execId: String) {
var becameUnavailable = false
for (partition <- 0 until numPartitions) {
diff --git a/core/src/main/scala/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/spark/scheduler/TaskScheduler.scala
index d549b184b0..7787b54762 100644
--- a/core/src/main/scala/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/TaskScheduler.scala
@@ -10,6 +10,10 @@ package spark.scheduler
private[spark] trait TaskScheduler {
def start(): Unit
+  // Invoked after the system has successfully initialized (typically in the SparkContext).
+  // YARN uses this to bootstrap resource allocation based on preferred locations, wait for slave registrations, etc.
+ def postStartHook() { }
+
// Disconnect from the cluster.
def stop(): Unit
diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
index 771518dddf..b75d3736cf 100644
--- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
+++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
@@ -14,6 +14,9 @@ private[spark] trait TaskSchedulerListener {
def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any],
taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit
+ // A node was added to the cluster.
+ def executorGained(execId: String, hostPort: String): Unit
+
// A node was lost from the cluster.
def executorLost(execId: String): Unit
diff --git a/core/src/main/scala/spark/scheduler/TaskSet.scala b/core/src/main/scala/spark/scheduler/TaskSet.scala
index a3002ca477..e4b5fcaedb 100644
--- a/core/src/main/scala/spark/scheduler/TaskSet.scala
+++ b/core/src/main/scala/spark/scheduler/TaskSet.scala
@@ -1,11 +1,18 @@
package spark.scheduler
+import java.util.Properties
+
/**
* A set of tasks submitted together to the low-level TaskScheduler, usually representing
* missing partitions of a particular stage.
*/
-private[spark] class TaskSet(val tasks: Array[Task[_]], val stageId: Int, val attempt: Int, val priority: Int) {
- val id: String = stageId + "." + attempt
+private[spark] class TaskSet(
+ val tasks: Array[Task[_]],
+ val stageId: Int,
+ val attempt: Int,
+ val priority: Int,
+ val properties: Properties) {
+ val id: String = stageId + "." + attempt
override def toString: String = "TaskSet " + id
}
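
The new properties field is how per-job settings reach the scheduler and its listeners: JobLogger (above) reads spark.job.annotation from it, and ClusterScheduler (below) hands the same object to schedulableBuilder.addTaskSetManager. A sketch, ignoring the private[spark] visibility for illustration:

    import java.util.Properties
    import spark.scheduler.{Task, TaskSet}

    val props = new Properties()
    props.setProperty("spark.job.annotation", "nightly ETL run")  // read by JobLogger

    val taskSet = new TaskSet(Array.empty[Task[_]], stageId = 0, attempt = 0,
                              priority = 0, properties = props)
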
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
index 26fdef101b..3a0c29b27f 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -1,6 +1,6 @@
package spark.scheduler.cluster
-import java.io.{File, FileInputStream, FileOutputStream}
+import java.lang.{Boolean => JBoolean}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
@@ -25,17 +25,45 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
// Threshold above which we warn user initial TaskSet may be starved
val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong
+  // How often to revive offers when there are pending tasks - that is, how often to try to get
+  // tasks scheduled when nodes are available. The default of 0 disables it, preserving existing behavior.
+  // Note that this is required because of delayed scheduling due to data locality waits, etc.
+  // TODO: rename the property?
+ val TASK_REVIVAL_INTERVAL = System.getProperty("spark.tasks.revive.interval", "0").toLong
+
+  /*
+   This property controls how aggressively we relax waiting for node-local task scheduling.
+   To elaborate, there is currently a time limit (3 seconds by default) to ensure that Spark waits for node locality of tasks before
+   scheduling on other nodes. We have modified this in the YARN branch so that offers to a task set happen in prioritized order:
+   node-local, rack-local and then others.
+   But once all available node-local (and no-preference) tasks are scheduled, instead of waiting 3 seconds before
+   scheduling on other nodes (which degrades performance for time-sensitive tasks and on larger clusters), we can
+   relax the constraint to also allow rack-local nodes or any node. The default is still NODE_LOCAL, so that previous behavior is
+   maintained. This allows tuning the tension between pulling RDD data off-node and scheduling computation as soon as possible.
+
+   TODO: rename the property? The value is one of
+   - NODE_LOCAL (default; no change w.r.t. current behavior),
+   - RACK_LOCAL and
+   - ANY
+
+   Note that this property makes more sense when used in conjunction with spark.tasks.revive.interval > 0; otherwise it is not very effective.
+
+   Additional note: for non-trivial clusters, we have seen a 4x - 5x reduction in running time (in some of our experiments) depending on whether
+   it is left at the default NODE_LOCAL, or set to RACK_LOCAL (if the cluster is configured to be rack aware) or ANY.
+   If the cluster is rack aware, then setting it to RACK_LOCAL gives the best tradeoff and a 3x - 4x performance improvement while minimizing IO impact.
+   It also brings down the variance in running time drastically.
+   */
+ val TASK_SCHEDULING_AGGRESSION = TaskLocality.parse(System.getProperty("spark.tasks.schedule.aggression", "NODE_LOCAL"))
val activeTaskSets = new HashMap[String, TaskSetManager]
- var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager]
val taskIdToTaskSetId = new HashMap[Long, String]
val taskIdToExecutorId = new HashMap[Long, String]
val taskSetTaskIds = new HashMap[String, HashSet[Long]]
- var hasReceivedTask = false
- var hasLaunchedTask = false
- val starvationTimer = new Timer(true)
+ @volatile private var hasReceivedTask = false
+ @volatile private var hasLaunchedTask = false
+ private val starvationTimer = new Timer(true)
// Incrementing Mesos task IDs
val nextTaskId = new AtomicLong(0)
@@ -43,11 +71,16 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
// Which executor IDs we have executors on
val activeExecutorIds = new HashSet[String]
+  // TODO: we might want to remove this and merge it with the execId data structures - but later.
+  // Which hosts in the cluster are alive (contains hostPorts) - used for process-local and node-local task locality.
+ private val hostPortsAlive = new HashSet[String]
+ private val hostToAliveHostPorts = new HashMap[String, HashSet[String]]
+
// The set of executors we have on each host; this is used to compute hostsAlive, which
// in turn is used to decide when we can attain data locality on a given host
- val executorsByHost = new HashMap[String, HashSet[String]]
+ private val executorsByHostPort = new HashMap[String, HashSet[String]]
- val executorIdToHost = new HashMap[String, String]
+ private val executorIdToHostPort = new HashMap[String, String]
// JAR server, if any JARs were added by the user to the SparkContext
var jarServer: HttpServer = null
@@ -62,24 +95,50 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
val mapOutputTracker = SparkEnv.get.mapOutputTracker
+ var schedulableBuilder: SchedulableBuilder = null
+ var rootPool: Pool = null
+
override def setListener(listener: TaskSchedulerListener) {
this.listener = listener
}
def initialize(context: SchedulerBackend) {
backend = context
+    // Default scheduler is FIFO
+    val schedulingMode = System.getProperty("spark.cluster.schedulingmode", "FIFO")
+    // Temporarily set the rootPool name to empty
+ rootPool = new Pool("", SchedulingMode.withName(schedulingMode), 0, 0)
+ schedulableBuilder = {
+ schedulingMode match {
+ case "FIFO" =>
+ new FIFOSchedulableBuilder(rootPool)
+ case "FAIR" =>
+ new FairSchedulableBuilder(rootPool)
+ }
+ }
+ schedulableBuilder.buildPools()
+ // resolve executorId to hostPort mapping.
+ def executorToHostPort(executorId: String, defaultHostPort: String): String = {
+ executorIdToHostPort.getOrElse(executorId, defaultHostPort)
+ }
+
+    // Unfortunately, this means that SparkEnv is indirectly referencing ClusterScheduler.
+    // Will that be a design violation?
+ SparkEnv.get.executorIdToHostPort = Some(executorToHostPort)
}
+
def newTaskId(): Long = nextTaskId.getAndIncrement()
override def start() {
backend.start()
- if (System.getProperty("spark.speculation", "false") == "true") {
+ if (JBoolean.getBoolean("spark.speculation")) {
new Thread("ClusterScheduler speculation check") {
setDaemon(true)
override def run() {
+ logInfo("Starting speculative execution thread")
while (true) {
try {
Thread.sleep(SPECULATION_INTERVAL)
@@ -91,15 +150,36 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}.start()
}
+
+
+    // TODO: change this to always run with some default interval if TASK_REVIVAL_INTERVAL <= 0?
+ if (TASK_REVIVAL_INTERVAL > 0) {
+ new Thread("ClusterScheduler task offer revival check") {
+ setDaemon(true)
+
+ override def run() {
+          logInfo("Starting task offer revival thread")
+ while (true) {
+ try {
+ Thread.sleep(TASK_REVIVAL_INTERVAL)
+ } catch {
+ case e: InterruptedException => {}
+ }
+
+ if (hasPendingTasks()) backend.reviveOffers()
+ }
+ }
+ }.start()
+ }
}
override def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
- val manager = new TaskSetManager(this, taskSet)
+ val manager = new ClusterTaskSetManager(this, taskSet)
activeTaskSets(taskSet.id) = manager
- activeTaskSetsQueue += manager
+ schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
taskSetTaskIds(taskSet.id) = new HashSet[Long]()
if (hasReceivedTask == false) {
@@ -122,7 +202,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
def taskSetFinished(manager: TaskSetManager) {
this.synchronized {
activeTaskSets -= manager.taskSet.id
- activeTaskSetsQueue -= manager
+ manager.parent.removeSchedulable(manager)
+ logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id)
taskSetTaskIds.remove(manager.taskSet.id)
@@ -139,22 +220,128 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
SparkEnv.set(sc.env)
// Mark each slave as alive and remember its hostname
for (o <- offers) {
- executorIdToHost(o.executorId) = o.hostname
- if (!executorsByHost.contains(o.hostname)) {
- executorsByHost(o.hostname) = new HashSet()
+ // DEBUG Code
+ Utils.checkHostPort(o.hostPort)
+
+ executorIdToHostPort(o.executorId) = o.hostPort
+      if (!executorsByHostPort.contains(o.hostPort)) {
+ executorsByHostPort(o.hostPort) = new HashSet[String]()
}
+
+ hostPortsAlive += o.hostPort
+ hostToAliveHostPorts.getOrElseUpdate(Utils.parseHostPort(o.hostPort)._1, new HashSet[String]).add(o.hostPort)
+ executorGained(o.executorId, o.hostPort)
}
// Build a list of tasks to assign to each slave
val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
+    // TODO: merge availableCpus into the nodeToAvailableCpus block?
val availableCpus = offers.map(o => o.cores).toArray
+ val nodeToAvailableCpus = {
+ val map = new HashMap[String, Int]()
+ for (offer <- offers) {
+ val hostPort = offer.hostPort
+ val cores = offer.cores
+ // DEBUG code
+ Utils.checkHostPort(hostPort)
+
+ val host = Utils.parseHostPort(hostPort)._1
+
+ map.put(host, map.getOrElse(host, 0) + cores)
+ }
+
+ map
+ }
var launchedTask = false
- for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) {
+ val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
+    for (manager <- sortedTaskSetQueue) {
+ logInfo("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks))
+ }
+ for (manager <- sortedTaskSetQueue) {
+
+ // Split offers based on node local, rack local and off-rack tasks.
+ val processLocalOffers = new HashMap[String, ArrayBuffer[Int]]()
+ val nodeLocalOffers = new HashMap[String, ArrayBuffer[Int]]()
+ val rackLocalOffers = new HashMap[String, ArrayBuffer[Int]]()
+ val otherOffers = new HashMap[String, ArrayBuffer[Int]]()
+
+ for (i <- 0 until offers.size) {
+ val hostPort = offers(i).hostPort
+ // DEBUG code
+ Utils.checkHostPort(hostPort)
+
+ val numProcessLocalTasks = math.max(0, math.min(manager.numPendingTasksForHostPort(hostPort), availableCpus(i)))
+        if (numProcessLocalTasks > 0) {
+ val list = processLocalOffers.getOrElseUpdate(hostPort, new ArrayBuffer[Int])
+ for (j <- 0 until numProcessLocalTasks) list += i
+ }
+
+ val host = Utils.parseHostPort(hostPort)._1
+ val numNodeLocalTasks = math.max(0,
+          // Remove process-local tasks (which are also node-local) from this count
+ math.min(manager.numPendingTasksForHost(hostPort) - numProcessLocalTasks, nodeToAvailableCpus(host)))
+        if (numNodeLocalTasks > 0) {
+ val list = nodeLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int])
+ for (j <- 0 until numNodeLocalTasks) list += i
+ }
+
+ val numRackLocalTasks = math.max(0,
+          // Remove node-local tasks (which are also rack-local) from this count
+ math.min(manager.numRackLocalPendingTasksForHost(hostPort) - numProcessLocalTasks - numNodeLocalTasks, nodeToAvailableCpus(host)))
+        if (numRackLocalTasks > 0) {
+ val list = rackLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int])
+ for (j <- 0 until numRackLocalTasks) list += i
+ }
+        if (numNodeLocalTasks <= 0 && numRackLocalTasks <= 0) {
+          // Add to the others list - spread even these across the cluster.
+ val list = otherOffers.getOrElseUpdate(host, new ArrayBuffer[Int])
+ list += i
+ }
+ }
+
+ val offersPriorityList = new ArrayBuffer[Int](
+ processLocalOffers.size + nodeLocalOffers.size + rackLocalOffers.size + otherOffers.size)
+
+ // First process local, then host local, then rack, then others
+
+ // numNodeLocalOffers contains count of both process local and host offers.
+ val numNodeLocalOffers = {
+ val processLocalPriorityList = ClusterScheduler.prioritizeContainers(processLocalOffers)
+ offersPriorityList ++= processLocalPriorityList
+
+ val nodeLocalPriorityList = ClusterScheduler.prioritizeContainers(nodeLocalOffers)
+ offersPriorityList ++= nodeLocalPriorityList
+
+ processLocalPriorityList.size + nodeLocalPriorityList.size
+ }
+ val numRackLocalOffers = {
+ val rackLocalPriorityList = ClusterScheduler.prioritizeContainers(rackLocalOffers)
+ offersPriorityList ++= rackLocalPriorityList
+ rackLocalPriorityList.size
+ }
+ offersPriorityList ++= ClusterScheduler.prioritizeContainers(otherOffers)
+
+ var lastLoop = false
+ val lastLoopIndex = TASK_SCHEDULING_AGGRESSION match {
+ case TaskLocality.NODE_LOCAL => numNodeLocalOffers
+ case TaskLocality.RACK_LOCAL => numRackLocalOffers + numNodeLocalOffers
+ case TaskLocality.ANY => offersPriorityList.size
+ }
+
do {
launchedTask = false
- for (i <- 0 until offers.size) {
+ var loopCount = 0
+ for (i <- offersPriorityList) {
val execId = offers(i).executorId
- val host = offers(i).hostname
- manager.slaveOffer(execId, host, availableCpus(i)) match {
+ val hostPort = offers(i).hostPort
+
+        // If this is the last loop and we are within lastLoopIndex, expand the locality scope - else use null (which keeps the default/existing locality)
+ val overrideLocality = if (lastLoop && loopCount < lastLoopIndex) TASK_SCHEDULING_AGGRESSION else null
+
+        // On the last loop, override waiting for host locality - we have already scheduled all local tasks, and there might be more slots available ...
+ loopCount += 1
+
+ manager.slaveOffer(execId, hostPort, availableCpus(i), overrideLocality) match {
case Some(task) =>
tasks(i) += task
val tid = task.taskId
@@ -162,15 +349,31 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
taskSetTaskIds(manager.taskSet.id) += tid
taskIdToExecutorId(tid) = execId
activeExecutorIds += execId
- executorsByHost(host) += execId
+ executorsByHostPort(hostPort) += execId
availableCpus(i) -= 1
launchedTask = true
case None => {}
+ }
+ }
+      // Loop once more - when lastLoop = true, we try to schedule tasks on all nodes irrespective of
+      // data locality (we still go in order of priority, but that would not change anything, since
+      // if data-local tasks had been available, we would have scheduled them already)
+ if (lastLoop) {
+ // prevent more looping
+ launchedTask = false
+ } else if (!lastLoop && !launchedTask) {
+ // Do this only if TASK_SCHEDULING_AGGRESSION != NODE_LOCAL
+ if (TASK_SCHEDULING_AGGRESSION != TaskLocality.NODE_LOCAL) {
+ // fudge launchedTask to ensure we loop once more
+ launchedTask = true
+          // don't loop anymore
+ lastLoop = true
}
}
} while (launchedTask)
}
+
if (tasks.size > 0) {
hasLaunchedTask = true
}
@@ -223,6 +426,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
backend.reviveOffers()
}
if (taskFailed) {
+
// Also revive offers if a task had failed for some reason other than host lost
backend.reviveOffers()
}
@@ -256,29 +460,40 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
if (jarServer != null) {
jarServer.stop()
}
+
+    // Sleep for an arbitrary 5 seconds to ensure that messages are sent out.
+    // TODO: do something better!
+ Thread.sleep(5000L)
}
override def defaultParallelism() = backend.defaultParallelism()
+
// Check for speculatable tasks in all our active jobs.
def checkSpeculatableTasks() {
var shouldRevive = false
synchronized {
- for (ts <- activeTaskSetsQueue) {
- shouldRevive |= ts.checkSpeculatableTasks()
- }
+ shouldRevive = rootPool.checkSpeculatableTasks()
}
if (shouldRevive) {
backend.reviveOffers()
}
}
+ // Check for pending tasks in all our active jobs.
+ def hasPendingTasks(): Boolean = {
+ synchronized {
+ rootPool.hasPendingTasks()
+ }
+ }
+
def executorLost(executorId: String, reason: ExecutorLossReason) {
var failedExecutor: Option[String] = None
+
synchronized {
if (activeExecutorIds.contains(executorId)) {
- val host = executorIdToHost(executorId)
- logError("Lost executor %s on %s: %s".format(executorId, host, reason))
+ val hostPort = executorIdToHostPort(executorId)
+ logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason))
removeExecutor(executorId)
failedExecutor = Some(executorId)
} else {
@@ -296,19 +511,104 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
- /** Get a list of hosts that currently have executors */
- def hostsAlive: scala.collection.Set[String] = executorsByHost.keySet
-
/** Remove an executor from all our data structures and mark it as lost */
private def removeExecutor(executorId: String) {
activeExecutorIds -= executorId
- val host = executorIdToHost(executorId)
- val execs = executorsByHost.getOrElse(host, new HashSet)
+ val hostPort = executorIdToHostPort(executorId)
+ if (hostPortsAlive.contains(hostPort)) {
+ // DEBUG Code
+ Utils.checkHostPort(hostPort)
+
+ hostPortsAlive -= hostPort
+ hostToAliveHostPorts.getOrElseUpdate(Utils.parseHostPort(hostPort)._1, new HashSet[String]).remove(hostPort)
+ }
+
+ val execs = executorsByHostPort.getOrElse(hostPort, new HashSet)
execs -= executorId
if (execs.isEmpty) {
- executorsByHost -= host
+ executorsByHostPort -= hostPort
}
- executorIdToHost -= executorId
- activeTaskSetsQueue.foreach(_.executorLost(executorId, host))
+ executorIdToHostPort -= executorId
+ rootPool.executorLost(executorId, hostPort)
+ }
+
+ def executorGained(execId: String, hostPort: String) {
+ listener.executorGained(execId, hostPort)
+ }
+
+ def getExecutorsAliveOnHost(host: String): Option[Set[String]] = {
+ Utils.checkHost(host)
+
+ val retval = hostToAliveHostPorts.get(host)
+ if (retval.isDefined) {
+ return Some(retval.get.toSet)
+ }
+
+ None
+ }
+
+ def isExecutorAliveOnHostPort(hostPort: String): Boolean = {
+ // Even if hostPort is a host, it does not matter - it is just a specific check.
+    // But we do have to ensure that only hostPorts get into hostPortsAlive!
+ // So no check against Utils.checkHostPort
+ hostPortsAlive.contains(hostPort)
+ }
+
+ // By default, rack is unknown
+ def getRackForHost(value: String): Option[String] = None
+
+ // By default, (cached) hosts for rack is unknown
+ def getCachedHostsForRack(rack: String): Option[Set[String]] = None
+}
+
+
+object ClusterScheduler {
+
+  // Used to 'spray' available containers across the available set, so that too many containers on the
+  // same host are not used up. Used in YARN mode and in task scheduling (when there are multiple
+  // containers available to execute a task).
+  // For example: YARN can return more containers than we would have requested under ANY; this method
+  // prioritizes how to use the allocated containers.
+  // It flattens the map such that the ArrayBuffer entries are spread out across the returned value.
+  // Given <host, list[container]> == <h1, [c1 .. c5]>, <h2, [c1 .. c3]>, <h3, [c1, c2]>, <h4, c1>, <h5, c1>,
+  // the return value would be: h1c1, h2c1, h3c1, h4c1, h5c1, h1c2, h2c2, h3c2, h1c3, h2c3, h1c4, h1c5
+  // (a worked check of exactly this example follows the object below).
+  // We then 'use' the containers in this order (consuming only the top K from this list, where
+  // K = the number to be used). This ensures that if we have multiple eligible allocations,
+  // they don't end up putting all containers on a small number of hosts - which would increase the
+  // probability of multiple container failures when a host goes down.
+  // Note that there is a bias for keys with a higher number of entries in their value to be picked first (by design).
+  // Also note that an invocation of this method is expected to be given containers of the same 'type'
+  // (host-local, rack-local, off-rack) and not a mix across types: the reordering simply picks better from
+  // the available list - everything else being the same.
+  // That is, we first consume data-local, then rack-local and finally off-rack nodes. So the
+  // prioritization from this method applies within each category.
+ def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = {
+ val _keyList = new ArrayBuffer[K](map.size)
+ _keyList ++= map.keys
+
+ // order keyList based on population of value in map
+ val keyList = _keyList.sortWith(
+ (left, right) => map.get(left).getOrElse(Set()).size > map.get(right).getOrElse(Set()).size
+ )
+
+ val retval = new ArrayBuffer[T](keyList.size * 2)
+ var index = 0
+ var found = true
+
+    while (found) {
+ found = false
+ for (key <- keyList) {
+ val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null)
+ assert(containerList != null)
+ // Get the index'th entry for this host - if present
+        if (index < containerList.size) {
+ retval += containerList.apply(index)
+ found = true
+ }
+ }
+ index += 1
+ }
+
+ retval.toList
}
}
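
A REPL-style check of prioritizeContainers against the worked example in its comment; since HashMap iteration order is unspecified, the equal-sized hosts (h4 and h5 here) may appear in either order:

    import scala.collection.mutable.{ArrayBuffer, HashMap}
    import spark.scheduler.cluster.ClusterScheduler

    val offers = new HashMap[String, ArrayBuffer[String]]
    offers("h1") = ArrayBuffer("h1c1", "h1c2", "h1c3", "h1c4", "h1c5")
    offers("h2") = ArrayBuffer("h2c1", "h2c2", "h2c3")
    offers("h3") = ArrayBuffer("h3c1", "h3c2")
    offers("h4") = ArrayBuffer("h4c1")
    offers("h5") = ArrayBuffer("h5c1")

    ClusterScheduler.prioritizeContainers(offers)
    // => List(h1c1, h2c1, h3c1, h4c1, h5c1, h1c2, h2c2, h3c2, h1c3, h2c3, h1c4, h1c5)
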
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala
new file mode 100644
index 0000000000..d72b0bfc9f
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -0,0 +1,747 @@
+package spark.scheduler.cluster
+
+import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.math.max
+import scala.math.min
+
+import spark._
+import spark.scheduler._
+import spark.TaskState.TaskState
+import java.nio.ByteBuffer
+
+private[spark] object TaskLocality extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") with Logging {
+
+  // PROCESS_LOCAL is expected to be used only within the TaskSetManager for now.
+ val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
+
+ type TaskLocality = Value
+
+ def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = {
+
+    // PROCESS_LOCAL must not be used as the constraint.
+ assert (constraint != TaskLocality.PROCESS_LOCAL)
+
+ constraint match {
+ case TaskLocality.NODE_LOCAL => condition == TaskLocality.NODE_LOCAL
+ case TaskLocality.RACK_LOCAL => condition == TaskLocality.NODE_LOCAL || condition == TaskLocality.RACK_LOCAL
+ // For anything else, allow
+ case _ => true
+ }
+ }
+
+ def parse(str: String): TaskLocality = {
+    // Is there a better way to do this?
+ try {
+ val retval = TaskLocality.withName(str)
+ // Must not specify PROCESS_LOCAL !
+ assert (retval != TaskLocality.PROCESS_LOCAL)
+
+ retval
+ } catch {
+ case nEx: NoSuchElementException => {
+        logWarning("Invalid task locality specified '" + str + "', defaulting to NODE_LOCAL")
+ // default to preserve earlier behavior
+ NODE_LOCAL
+ }
+ }
+ }
+}
+
+/**
+ * Schedules the tasks within a single TaskSet in the ClusterScheduler.
+ */
+private[spark] class ClusterTaskSetManager(
+ sched: ClusterScheduler,
+ val taskSet: TaskSet)
+ extends TaskSetManager
+ with Logging {
+
+ // Maximum time to wait to run a task in a preferred location (in ms)
+ val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
+
+ // CPUs to request per task
+ val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble
+
+ // Maximum times a task is allowed to fail before failing the job
+ val MAX_TASK_FAILURES = 4
+
+ // Quantile of tasks at which to start speculation
+ val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
+ val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
+
+ // Serializer for closures and tasks.
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+
+ val tasks = taskSet.tasks
+ val numTasks = tasks.length
+ val copiesRunning = new Array[Int](numTasks)
+ val finished = new Array[Boolean](numTasks)
+ val numFailures = new Array[Int](numTasks)
+ val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
+ var tasksFinished = 0
+
+ var weight = 1
+ var minShare = 0
+ var runningTasks = 0
+ var priority = taskSet.priority
+ var stageId = taskSet.stageId
+  var name = "TaskSet_" + taskSet.stageId.toString
+  var parent: Schedulable = null
+
+ // Last time when we launched a preferred task (for delay scheduling)
+ var lastPreferredLaunchTime = System.currentTimeMillis
+
+  // List of pending tasks for each hostPort (process-local to the container). These collections are actually
+ // treated as stacks, in which new tasks are added to the end of the
+ // ArrayBuffer and removed from the end. This makes it faster to detect
+  // tasks that repeatedly fail because, whenever a task fails, it is put
+ // back at the head of the stack. They are also only cleaned up lazily;
+ // when a task is launched, it remains in all the pending lists except
+ // the one that it was launched from, but gets removed from them later.
+ private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]]
+
+ // List of pending tasks for each node.
+ // Essentially, similar to pendingTasksForHostPort, except at host level
+ private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
+
+ // List of pending tasks for each node based on rack locality.
+ // Essentially, similar to pendingTasksForHost, except at rack level
+ private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]]
+
+ // List containing pending tasks with no locality preferences
+ val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
+
+ // List containing all pending tasks (also used as a stack, as above)
+ val allPendingTasks = new ArrayBuffer[Int]
+
+ // Tasks that can be speculated. Since these will be a small fraction of total
+ // tasks, we'll just hold them in a HashSet.
+ val speculatableTasks = new HashSet[Int]
+
+ // Task index, start and finish time for each task attempt (indexed by task ID)
+ val taskInfos = new HashMap[Long, TaskInfo]
+
+ // Did the job fail?
+ var failed = false
+ var causeOfFailure = ""
+
+ // How frequently to reprint duplicate exceptions in full, in milliseconds
+ val EXCEPTION_PRINT_INTERVAL =
+ System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong
+ // Map of recent exceptions (identified by string representation and
+ // top stack frame) to duplicate count (how many times the same
+ // exception has appeared) and time the full exception was
+ // printed. This should ideally be an LRU map that can drop old
+ // exceptions automatically.
+ val recentExceptions = HashMap[String, (Int, Long)]()
+
+ // Figure out the current map output tracker generation and set it on all tasks
+ val generation = sched.mapOutputTracker.getGeneration
+ logDebug("Generation for " + taskSet.id + ": " + generation)
+ for (t <- tasks) {
+ t.generation = generation
+ }
+
+ // Add all our tasks to the pending lists. We do this in reverse order
+ // of task index so that tasks with low indices get launched first.
+ for (i <- (0 until numTasks).reverse) {
+ addPendingTask(i)
+ }
+
+  // Note that this follows the locality hierarchy:
+  // if we search for NODE_LOCAL, the output will include PROCESS_LOCAL, and
+  // if we search for RACK_LOCAL, it will include PROCESS_LOCAL & NODE_LOCAL.
+ private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler,
+ taskLocality: TaskLocality.TaskLocality): HashSet[String] = {
+
+ if (TaskLocality.PROCESS_LOCAL == taskLocality) {
+      // Straightforward comparison - special-case it.
+ val retval = new HashSet[String]()
+ scheduler.synchronized {
+ for (location <- _taskPreferredLocations) {
+ if (scheduler.isExecutorAliveOnHostPort(location)) {
+ retval += location
+ }
+ }
+ }
+
+ return retval
+ }
+
+ val taskPreferredLocations =
+ if (TaskLocality.NODE_LOCAL == taskLocality) {
+ _taskPreferredLocations
+ } else {
+ assert (TaskLocality.RACK_LOCAL == taskLocality)
+        // Expand the set to include all 'seen' rack-local hosts.
+        // This works since container allocation/management happens within the master - so any rack locality information is updated in the master.
+        // Best-effort, and maybe sort of a kludge for now ... rework it later?
+ val hosts = new HashSet[String]
+ _taskPreferredLocations.foreach(h => {
+ val rackOpt = scheduler.getRackForHost(h)
+ if (rackOpt.isDefined) {
+ val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get)
+ if (hostsOpt.isDefined) {
+ hosts ++= hostsOpt.get
+ }
+ }
+
+        // Ensure that, irrespective of what the scheduler says, the host is always added!
+ hosts += h
+ })
+
+ hosts
+ }
+
+ val retval = new HashSet[String]
+ scheduler.synchronized {
+ for (prefLocation <- taskPreferredLocations) {
+ val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(Utils.parseHostPort(prefLocation)._1)
+ if (aliveLocationsOpt.isDefined) {
+ retval ++= aliveLocationsOpt.get
+ }
+ }
+ }
+
+ retval
+ }
+
+ // Add a task to all the pending-task lists that it should be on.
+ private def addPendingTask(index: Int) {
+    // We can infer hostLocalLocations from rackLocalLocations by joining them against tasks(index).preferredLocations (with appropriate
+    // hostPort <-> host conversion). But we are not doing that, for simplicity's sake. If it becomes a performance issue, modify it.
+ val processLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.PROCESS_LOCAL)
+ val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
+ val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
+
+ if (rackLocalLocations.size == 0) {
+ // Current impl ensures this.
+ assert (processLocalLocations.size == 0)
+ assert (hostLocalLocations.size == 0)
+ pendingTasksWithNoPrefs += index
+ } else {
+
+ // process local locality
+ for (hostPort <- processLocalLocations) {
+ // DEBUG Code
+ Utils.checkHostPort(hostPort)
+
+ val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer())
+ hostPortList += index
+ }
+
+ // host locality (includes process local)
+ for (hostPort <- hostLocalLocations) {
+ // DEBUG Code
+ Utils.checkHostPort(hostPort)
+
+ val host = Utils.parseHostPort(hostPort)._1
+ val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer())
+ hostList += index
+ }
+
+ // rack locality (includes process local and host local)
+ for (rackLocalHostPort <- rackLocalLocations) {
+ // DEBUG Code
+ Utils.checkHostPort(rackLocalHostPort)
+
+ val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1
+ val list = pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer())
+ list += index
+ }
+ }
+
+ allPendingTasks += index
+ }
+
+ // Return the pending tasks list for a given host port (process local), or an empty list if
+ // there is no map entry for that host
+ private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = {
+ // DEBUG Code
+ Utils.checkHostPort(hostPort)
+ pendingTasksForHostPort.getOrElse(hostPort, ArrayBuffer())
+ }
+
+ // Return the pending tasks list for a given host, or an empty list if
+ // there is no map entry for that host
+ private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
+ val host = Utils.parseHostPort(hostPort)._1
+ pendingTasksForHost.getOrElse(host, ArrayBuffer())
+ }
+
+ // Return the pending tasks (rack level) list for a given host, or an empty list if
+ // there is no map entry for that host
+ private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
+ val host = Utils.parseHostPort(hostPort)._1
+ pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer())
+ }
+
+ // Number of pending tasks for a given host port (which would be process local)
+ def numPendingTasksForHostPort(hostPort: String): Int = {
+ getPendingTasksForHostPort(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
+ }
+
+ // Number of pending tasks for a given host (which would be data local)
+ def numPendingTasksForHost(hostPort: String): Int = {
+ getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
+ }
+
+ // Number of pending rack local tasks for a given host
+ def numRackLocalPendingTasksForHost(hostPort: String): Int = {
+ getRackLocalPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
+ }
+
+
+ // Dequeue a pending task from the given list and return its index.
+ // Return None if the list is empty.
+ // This method also cleans up any tasks in the list that have already
+ // been launched, since we want that to happen lazily.
+ private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
+ while (!list.isEmpty) {
+ val index = list.last
+ list.trimEnd(1)
+ if (copiesRunning(index) == 0 && !finished(index)) {
+ return Some(index)
+ }
+ }
+ return None
+ }
+
+ // Return a speculative task for a given host if any are available. The task should not have an
+ // attempt running on this host, in case the host is slow. In addition, depending on the allowed locality
+ // level, the task must either have a preference for this host/rack or have no preferred locations at all.
+ private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
+
+ assert (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL))
+ speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
+
+ if (speculatableTasks.size > 0) {
+ val localTask = speculatableTasks.find {
+ index =>
+ val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
+ val attemptLocs = taskAttempts(index).map(_.hostPort)
+ (locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort)
+ }
+
+ if (localTask != None) {
+ speculatableTasks -= localTask.get
+ return localTask
+ }
+
+ // check for rack locality
+ if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
+ val rackTask = speculatableTasks.find {
+ index =>
+ val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
+ val attemptLocs = taskAttempts(index).map(_.hostPort)
+ locations.contains(hostPort) && !attemptLocs.contains(hostPort)
+ }
+
+ if (rackTask != None) {
+ speculatableTasks -= rackTask.get
+ return rackTask
+ }
+ }
+
+ // Any task ...
+ if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
+ // Should we check attemptLocs here as well?
+ val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.hostPort).contains(hostPort))
+ if (nonLocalTask != None) {
+ speculatableTasks -= nonLocalTask.get
+ return nonLocalTask
+ }
+ }
+ }
+ return None
+ }
+
+ // Dequeue a pending task for a given node and return its index.
+ // Only tasks at or below the given locality level are considered.
+ private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
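+ // Search order: process-local, then node-local, then rack-local (if allowed),
+ // then tasks with no preferences, then any pending task (if allowed), and
+ // finally speculative tasks.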
+ val processLocalTask = findTaskFromList(getPendingTasksForHostPort(hostPort))
+ if (processLocalTask != None) {
+ return processLocalTask
+ }
+
+ val localTask = findTaskFromList(getPendingTasksForHost(hostPort))
+ if (localTask != None) {
+ return localTask
+ }
+
+ if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
+ val rackLocalTask = findTaskFromList(getRackLocalPendingTasksForHost(hostPort))
+ if (rackLocalTask != None) {
+ return rackLocalTask
+ }
+ }
+
+ // Look for no-pref tasks AFTER rack-local tasks; this has the side effect that we will get to failed tasks later rather than sooner.
+ // TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down).
+ val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs)
+ if (noPrefTask != None) {
+ return noPrefTask
+ }
+
+ if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
+ val nonLocalTask = findTaskFromList(allPendingTasks)
+ if (nonLocalTask != None) {
+ return nonLocalTask
+ }
+ }
+
+ // Finally, if all else has failed, find a speculative task
+ return findSpeculativeTask(hostPort, locality)
+ }
+
+ private def isProcessLocalLocation(task: Task[_], hostPort: String): Boolean = {
+ Utils.checkHostPort(hostPort)
+
+ val locs = task.preferredLocations
+
+ locs.contains(hostPort)
+ }
+
+ private def isHostLocalLocation(task: Task[_], hostPort: String): Boolean = {
+ val locs = task.preferredLocations
+
+ // If no preference, consider it as host local
+ if (locs.isEmpty) return true
+
+ val host = Utils.parseHostPort(hostPort)._1
+ locs.find(h => Utils.parseHostPort(h)._1 == host).isDefined
+ }
+
+ // Does a host count as a rack-local preferred location for a task? (Assumes the host itself is NOT a preferred location.)
+ // This is true only if the task has preferred locations and this host shares a rack with one of them;
+ // with no preferred locations there are no preferred racks, so this returns false.
+ private def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = {
+
+ val locs = task.preferredLocations
+
+ val preferredRacks = new HashSet[String]()
+ for (preferredHost <- locs) {
+ val rack = sched.getRackForHost(preferredHost)
+ if (None != rack) preferredRacks += rack.get
+ }
+
+ if (preferredRacks.isEmpty) return false
+
+ val hostRack = sched.getRackForHost(hostPort)
+
+ return None != hostRack && preferredRacks.contains(hostRack.get)
+ }
+
+ // Respond to an offer of a single slave from the scheduler by finding a task
+ def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
+
+ if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
+ // If explicitly specified, use that
+ val locality = if (overrideLocality != null) overrideLocality else {
+ // Expand locality only if we have waited for more than LOCALITY_WAIT for a node-local task.
+ val time = System.currentTimeMillis
+ if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.NODE_LOCAL else TaskLocality.ANY
+ }
+
+ findTask(hostPort, locality) match {
+ case Some(index) => {
+ // Found a task; do some bookkeeping and return a Mesos task for it
+ val task = tasks(index)
+ val taskId = sched.newTaskId()
+ // Figure out whether this should count as a preferred launch
+ val taskLocality =
+ if (isProcessLocalLocation(task, hostPort)) TaskLocality.PROCESS_LOCAL else
+ if (isHostLocalLocation(task, hostPort)) TaskLocality.NODE_LOCAL else
+ if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else
+ TaskLocality.ANY
+ val prefStr = taskLocality.toString
+ logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
+ taskSet.id, index, taskId, execId, hostPort, prefStr))
+ // Do various bookkeeping
+ copiesRunning(index) += 1
+ val time = System.currentTimeMillis
+ val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality)
+ taskInfos(taskId) = info
+ taskAttempts(index) = info :: taskAttempts(index)
+ if (TaskLocality.NODE_LOCAL == taskLocality) {
+ lastPreferredLaunchTime = time
+ }
+ // Serialize and return the task
+ val startTime = System.currentTimeMillis
+ val serializedTask = Task.serializeWithDependencies(
+ task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+ val timeTaken = System.currentTimeMillis - startTime
+ increaseRunningTasks(1)
+ logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
+ taskSet.id, index, serializedTask.limit, timeTaken))
+ val taskName = "task %s:%d".format(taskSet.id, index)
+ return Some(new TaskDescription(taskId, execId, taskName, serializedTask))
+ }
+ case _ =>
+ }
+ }
+ return None
+ }
+
+ def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ state match {
+ case TaskState.FINISHED =>
+ taskFinished(tid, state, serializedData)
+ case TaskState.LOST =>
+ taskLost(tid, state, serializedData)
+ case TaskState.FAILED =>
+ taskLost(tid, state, serializedData)
+ case TaskState.KILLED =>
+ taskLost(tid, state, serializedData)
+ case _ =>
+ }
+ }
+
+ def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ val info = taskInfos(tid)
+ if (info.failed) {
+ // We might get two task-lost messages for the same task in coarse-grained Mesos mode,
+ // or even from Mesos itself when acks get delayed.
+ return
+ }
+ val index = info.index
+ info.markSuccessful()
+ decreaseRunningTasks(1)
+ if (!finished(index)) {
+ tasksFinished += 1
+ logInfo("Finished TID %s in %d ms (progress: %d/%d)".format(
+ tid, info.duration, tasksFinished, numTasks))
+ // Deserialize task result and pass it to the scheduler
+ try {
+ val result = ser.deserialize[TaskResult[_]](serializedData)
+ result.metrics.resultSize = serializedData.limit()
+ sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
+ } catch {
+ case cnf: ClassNotFoundException =>
+ val loader = Thread.currentThread().getContextClassLoader
+ throw new SparkException("ClassNotFound with classloader: " + loader, cnf)
+ case ex => throw ex
+ }
+ // Mark finished and stop if we've finished all the tasks
+ finished(index) = true
+ if (tasksFinished == numTasks) {
+ sched.taskSetFinished(this)
+ }
+ } else {
+ logInfo("Ignoring task-finished event for TID " + tid +
+ " because task " + index + " is already finished")
+ }
+ }
+
+ def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ val info = taskInfos(tid)
+ if (info.failed) {
+ // We might get two task-lost messages for the same task in coarse-grained Mesos mode,
+ // or even from Mesos itself when acks get delayed.
+ return
+ }
+ val index = info.index
+ info.markFailed()
+ decreaseRunningTasks(1)
+ if (!finished(index)) {
+ logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
+ copiesRunning(index) -= 1
+ // Check if the problem is a map output fetch failure. In that case, this
+ // task will never succeed on any node, so tell the scheduler about it.
+ if (serializedData != null && serializedData.limit() > 0) {
+ val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader)
+ reason match {
+ case fetchFailed: FetchFailed =>
+ logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
+ sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
+ finished(index) = true
+ tasksFinished += 1
+ sched.taskSetFinished(this)
+ decreaseRunningTasks(runningTasks)
+ return
+
+ case taskResultTooBig: TaskResultTooBigFailure =>
+ logInfo("Loss was due to task %s result exceeding Akka frame size; " +
+ "aborting job".format(tid))
+ abort("Task %s result exceeded Akka frame size".format(tid))
+ return
+
+ case ef: ExceptionFailure =>
+ val key = ef.description
+ val now = System.currentTimeMillis
+ val (printFull, dupCount) = {
+ if (recentExceptions.contains(key)) {
+ val (dupCount, printTime) = recentExceptions(key)
+ if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
+ recentExceptions(key) = (0, now)
+ (true, 0)
+ } else {
+ recentExceptions(key) = (dupCount + 1, printTime)
+ (false, dupCount + 1)
+ }
+ } else {
+ recentExceptions(key) = (0, now)
+ (true, 0)
+ }
+ }
+ if (printFull) {
+ val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
+ logInfo("Loss was due to %s\n%s\n%s".format(
+ ef.className, ef.description, locs.mkString("\n")))
+ } else {
+ logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
+ }
+
+ case _ => {}
+ }
+ }
+ // On non-fetch failures, re-enqueue the task as pending for a max number of retries
+ addPendingTask(index)
+ // Count failed attempts only on FAILED and LOST state (not on KILLED)
+ if (state == TaskState.FAILED || state == TaskState.LOST) {
+ numFailures(index) += 1
+ if (numFailures(index) > MAX_TASK_FAILURES) {
+ logError("Task %s:%d failed more than %d times; aborting job".format(
+ taskSet.id, index, MAX_TASK_FAILURES))
+ abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
+ }
+ }
+ } else {
+ logInfo("Ignoring task-lost event for TID " + tid +
+ " because task " + index + " is already finished")
+ }
+ }
+
+ def error(message: String) {
+ // Save the error message
+ abort("Error: " + message)
+ }
+
+ def abort(message: String) {
+ failed = true
+ causeOfFailure = message
+ // TODO: Kill running tasks if we were not terminated due to a Mesos error
+ sched.listener.taskSetFailed(taskSet, message)
+ decreaseRunningTasks(runningTasks)
+ sched.taskSetFinished(this)
+ }
+
+ override def increaseRunningTasks(taskNum: Int) {
+ runningTasks += taskNum
+ if (parent != null) {
+ parent.increaseRunningTasks(taskNum)
+ }
+ }
+
+ override def decreaseRunningTasks(taskNum: Int) {
+ runningTasks -= taskNum
+ if (parent != null) {
+ parent.decreaseRunningTasks(taskNum)
+ }
+ }
+
+ // TODO: For now we only find Pools, not TaskSetManagers; we can extend this function in the future if needed.
+ override def getSchedulableByName(name: String): Schedulable = {
+ return null
+ }
+
+ override def addSchedulable(schedulable:Schedulable) {
+ //nothing
+ }
+
+ override def removeSchedulable(schedulable:Schedulable) {
+ //nothing
+ }
+
+ override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
+ var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
+ sortedTaskSetQueue += this
+ return sortedTaskSetQueue
+ }
+
+ override def executorLost(execId: String, hostPort: String) {
+ logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
+
+ // If some task has preferred locations only on hostname, and there are no more executors there,
+ // put it in the no-prefs list to avoid the wait from delay scheduling
+
+ // Host-local tasks: should we push these to the rack-local or the no-prefs list? For now, preserving behavior and moving them
+ // to the no-prefs list. Note: this was done due to implications related to 'waiting' for data-local tasks, etc.
+ // Note: NOT checking the process-local list, since the host-local list is a superset of it. We need to add to no-prefs only if
+ // there is no host-local node for the task (not if there is no process-local node for the task).
+ for (index <- getPendingTasksForHost(Utils.parseHostPort(hostPort)._1)) {
+ // val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
+ val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
+ if (newLocs.isEmpty) {
+ pendingTasksWithNoPrefs += index
+ }
+ }
+
+ // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
+ if (tasks(0).isInstanceOf[ShuffleMapTask]) {
+ for ((tid, info) <- taskInfos if info.executorId == execId) {
+ val index = taskInfos(tid).index
+ if (finished(index)) {
+ finished(index) = false
+ copiesRunning(index) -= 1
+ tasksFinished -= 1
+ addPendingTask(index)
+ // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
+ // stage finishes when a total of tasks.size tasks finish.
+ sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
+ }
+ }
+ }
+ // Also re-enqueue any tasks that were running on the node
+ for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
+ taskLost(tid, TaskState.KILLED, null)
+ }
+ }
+
+ /**
+ * Check for tasks to be speculated and return true if there are any. This is called periodically
+ * by the ClusterScheduler.
+ *
+ * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
+ * we don't scan the whole task set. It might also help to make this sorted by launch time.
+ */
+ override def checkSpeculatableTasks(): Boolean = {
+ // Can't speculate if we only have one task, or if all tasks have finished.
+ if (numTasks == 1 || tasksFinished == numTasks) {
+ return false
+ }
+ var foundTasks = false
+ val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
+ logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
+ if (tasksFinished >= minFinishedForSpeculation) {
+ val time = System.currentTimeMillis()
+ val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
+ Arrays.sort(durations)
+ val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1))
+ val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
+ // TODO: Threshold should also look at standard deviation of task durations and have a lower
+ // bound based on that.
+ logDebug("Task length threshold for speculation: " + threshold)
+ for ((tid, info) <- taskInfos) {
+ val index = info.index
+ if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
+ !speculatableTasks.contains(index)) {
+ logInfo(
+ "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
+ taskSet.id, index, info.hostPort, threshold))
+ speculatableTasks += index
+ foundTasks = true
+ }
+ }
+ }
+ return foundTasks
+ }
+
+ override def hasPendingTasks(): Boolean = {
+ numTasks > 0 && tasksFinished < numTasks
+ }
+}
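slaveOffer above implements a simple form of delay scheduling: offers arriving within LOCALITY_WAIT of the last preferred launch are restricted to node-local tasks, and only once that window expires does the allowed level fall back to ANY. A minimal sketch of just that timing rule, under the assumption of a single wait threshold (illustrative names, not the patch's API):

object DelaySchedulingSketch {
  val LOCALITY_WAIT = 3000L // ms, mirroring the spark.locality.wait default above

  // The locality level an offer may use, given the time of the last node-local launch.
  def allowedLocality(now: Long, lastPreferredLaunchTime: Long): String =
    if (now - lastPreferredLaunchTime < LOCALITY_WAIT) "NODE_LOCAL" else "ANY"

  def main(args: Array[String]) {
    println(allowedLocality(1000L, 0L)) // NODE_LOCAL: keep waiting for a local slot
    println(allowedLocality(5000L, 0L)) // ANY: waited long enough, run anywhere
  }
}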
diff --git a/core/src/main/scala/spark/scheduler/cluster/Pool.scala b/core/src/main/scala/spark/scheduler/cluster/Pool.scala
new file mode 100644
index 0000000000..941ba7a3f1
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/cluster/Pool.scala
@@ -0,0 +1,104 @@
+package spark.scheduler.cluster
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+
+import spark.Logging
+import spark.scheduler.cluster.SchedulingMode.SchedulingMode
+
+/**
+ * A Schedulable entity that represents a collection of Pools or TaskSetManagers.
+ */
+
+private[spark] class Pool(
+ val poolName: String,
+ val schedulingMode: SchedulingMode,
+ initMinShare: Int,
+ initWeight: Int)
+ extends Schedulable
+ with Logging {
+
+ var schedulableQueue = new ArrayBuffer[Schedulable]
+ var schedulableNameToSchedulable = new HashMap[String, Schedulable]
+
+ var weight = initWeight
+ var minShare = initMinShare
+ var runningTasks = 0
+
+ var priority = 0
+ var stageId = 0
+ var name = poolName
+ var parent:Schedulable = null
+
+ var taskSetSchedulingAlgorithm: SchedulingAlgorithm = {
+ schedulingMode match {
+ case SchedulingMode.FAIR =>
+ new FairSchedulingAlgorithm()
+ case SchedulingMode.FIFO =>
+ new FIFOSchedulingAlgorithm()
+ }
+ }
+
+ override def addSchedulable(schedulable: Schedulable) {
+ schedulableQueue += schedulable
+ schedulableNameToSchedulable(schedulable.name) = schedulable
+ schedulable.parent = this
+ }
+
+ override def removeSchedulable(schedulable: Schedulable) {
+ schedulableQueue -= schedulable
+ schedulableNameToSchedulable -= schedulable.name
+ }
+
+ override def getSchedulableByName(schedulableName: String): Schedulable = {
+ if (schedulableNameToSchedulable.contains(schedulableName)) {
+ return schedulableNameToSchedulable(schedulableName)
+ }
+ for (schedulable <- schedulableQueue) {
+ var sched = schedulable.getSchedulableByName(schedulableName)
+ if (sched != null) {
+ return sched
+ }
+ }
+ return null
+ }
+
+ override def executorLost(executorId: String, host: String) {
+ schedulableQueue.foreach(_.executorLost(executorId, host))
+ }
+
+ override def checkSpeculatableTasks(): Boolean = {
+ var shouldRevive = false
+ for (schedulable <- schedulableQueue) {
+ shouldRevive |= schedulable.checkSpeculatableTasks()
+ }
+ return shouldRevive
+ }
+
+ override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
+ var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
+ val sortedSchedulableQueue = schedulableQueue.sortWith(taskSetSchedulingAlgorithm.comparator)
+ for (schedulable <- sortedSchedulableQueue) {
+ sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue()
+ }
+ return sortedTaskSetQueue
+ }
+
+ override def increaseRunningTasks(taskNum: Int) {
+ runningTasks += taskNum
+ if (parent != null) {
+ parent.increaseRunningTasks(taskNum)
+ }
+ }
+
+ override def decreaseRunningTasks(taskNum: Int) {
+ runningTasks -= taskNum
+ if (parent != null) {
+ parent.decreaseRunningTasks(taskNum)
+ }
+ }
+
+ override def hasPendingTasks(): Boolean = {
+ schedulableQueue.exists(_.hasPendingTasks())
+ }
+}
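Pool's increaseRunningTasks and decreaseRunningTasks walk the parent chain, so a count change on a leaf is visible at every ancestor. A standalone toy (not the class above) demonstrating that propagation:

class ToyPool(val name: String, val parent: ToyPool) {
  var runningTasks = 0
  // Mirrors Pool.increaseRunningTasks: bump this node, then recurse upward.
  def increaseRunningTasks(n: Int) {
    runningTasks += n
    if (parent != null) parent.increaseRunningTasks(n)
  }
}

object PoolTreeSketch {
  def main(args: Array[String]) {
    val root = new ToyPool("root", null)
    val child = new ToyPool("default", root)
    child.increaseRunningTasks(3)
    println(root.runningTasks)  // 3: counts bubble up to the root
    println(child.runningTasks) // 3
  }
}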
diff --git a/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala b/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala
new file mode 100644
index 0000000000..2dd9c0564f
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala
@@ -0,0 +1,27 @@
+package spark.scheduler.cluster
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * An interface for schedulable entities.
+ * There are two types of Schedulable entities: Pools and TaskSetManagers.
+ */
+private[spark] trait Schedulable {
+ var parent: Schedulable
+ def weight: Int
+ def minShare: Int
+ def runningTasks: Int
+ def priority: Int
+ def stageId: Int
+ def name: String
+
+ def increaseRunningTasks(taskNum: Int): Unit
+ def decreaseRunningTasks(taskNum: Int): Unit
+ def addSchedulable(schedulable: Schedulable): Unit
+ def removeSchedulable(schedulable: Schedulable): Unit
+ def getSchedulableByName(name: String): Schedulable
+ def executorLost(executorId: String, host: String): Unit
+ def checkSpeculatableTasks(): Boolean
+ def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager]
+ def hasPendingTasks(): Boolean
+}
diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala
new file mode 100644
index 0000000000..18cc15c2a5
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala
@@ -0,0 +1,115 @@
+package spark.scheduler.cluster
+
+import java.io.{File, FileInputStream, FileOutputStream}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.util.control.Breaks._
+import scala.xml._
+
+import spark.Logging
+import spark.scheduler.cluster.SchedulingMode.SchedulingMode
+
+import java.util.Properties
+
+/**
+ * An interface for building the Schedulable tree.
+ * buildPools: builds the tree nodes (pools)
+ * addTaskSetManager: builds the leaf nodes (TaskSetManagers)
+ */
+private[spark] trait SchedulableBuilder {
+ def buildPools()
+ def addTaskSetManager(manager: Schedulable, properties: Properties)
+}
+
+private[spark] class FIFOSchedulableBuilder(val rootPool: Pool) extends SchedulableBuilder with Logging {
+
+ override def buildPools() {
+ //nothing
+ }
+
+ override def addTaskSetManager(manager: Schedulable, properties: Properties) {
+ rootPool.addSchedulable(manager)
+ }
+}
+
+private[spark] class FairSchedulableBuilder(val rootPool: Pool) extends SchedulableBuilder with Logging {
+
+ val schedulerAllocFile = System.getProperty("spark.fairscheduler.allocation.file", "unspecified")
+ val FAIR_SCHEDULER_PROPERTIES = "spark.scheduler.cluster.fair.pool"
+ val DEFAULT_POOL_NAME = "default"
+ val MINIMUM_SHARES_PROPERTY = "minShare"
+ val SCHEDULING_MODE_PROPERTY = "schedulingMode"
+ val WEIGHT_PROPERTY = "weight"
+ val POOL_NAME_PROPERTY = "@name"
+ val POOLS_PROPERTY = "pool"
+ val DEFAULT_SCHEDULING_MODE = SchedulingMode.FIFO
+ val DEFAULT_MINIMUM_SHARE = 2
+ val DEFAULT_WEIGHT = 1
+
+ override def buildPools() {
+ val file = new File(schedulerAllocFile)
+ if (file.exists()) {
+ val xml = XML.loadFile(file)
+ for (poolNode <- (xml \\ POOLS_PROPERTY)) {
+
+ val poolName = (poolNode \ POOL_NAME_PROPERTY).text
+ var schedulingMode = DEFAULT_SCHEDULING_MODE
+ var minShare = DEFAULT_MINIMUM_SHARE
+ var weight = DEFAULT_WEIGHT
+
+ val xmlSchedulingMode = (poolNode \ SCHEDULING_MODE_PROPERTY).text
+ if (xmlSchedulingMode != "") {
+ try {
+ schedulingMode = SchedulingMode.withName(xmlSchedulingMode)
+ } catch {
+ case e: Exception => logInfo("Invalid schedulingMode in XML, using default schedulingMode")
+ }
+ }
+
+ val xmlMinShare = (poolNode \ MINIMUM_SHARES_PROPERTY).text
+ if (xmlMinShare != "") {
+ minShare = xmlMinShare.toInt
+ }
+
+ val xmlWeight = (poolNode \ WEIGHT_PROPERTY).text
+ if (xmlWeight != "") {
+ weight = xmlWeight.toInt
+ }
+
+ val pool = new Pool(poolName, schedulingMode, minShare, weight)
+ rootPool.addSchedulable(pool)
+ logInfo("Create new pool with name:%s,schedulingMode:%s,minShare:%d,weight:%d".format(
+ poolName, schedulingMode, minShare, weight))
+ }
+ }
+
+ // Finally, create the "default" pool if it was not defined in the XML file
+ if (rootPool.getSchedulableByName(DEFAULT_POOL_NAME) == null) {
+ val pool = new Pool(DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)
+ rootPool.addSchedulable(pool)
+ logInfo("Create default pool with name:%s,schedulingMode:%s,minShare:%d,weight:%d".format(
+ DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT))
+ }
+ }
+
+ override def addTaskSetManager(manager: Schedulable, properties: Properties) {
+ var poolName = DEFAULT_POOL_NAME
+ var parentPool = rootPool.getSchedulableByName(poolName)
+ if (properties != null) {
+ poolName = properties.getProperty(FAIR_SCHEDULER_PROPERTIES, DEFAULT_POOL_NAME)
+ parentPool = rootPool.getSchedulableByName(poolName)
+ if (parentPool == null) {
+ // Create a new pool that the user has configured in the app, rather than one defined in the XML file
+ parentPool = new Pool(poolName, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)
+ rootPool.addSchedulable(parentPool)
+ logInfo("Create pool with name:%s,schedulingMode:%s,minShare:%d,weight:%d".format(
+ poolName, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT))
+ }
+ }
+ parentPool.addSchedulable(manager)
+ logInfo("Added task set " + manager.name + " tasks to pool "+poolName)
+ }
+}
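For reference, an allocation file consumed by FairSchedulableBuilder could look like the XML embedded below. This is an assumed example (the element and attribute names follow the constants above), parsed here with scala.xml the same way buildPools does:

import scala.xml._

object AllocFileSketch {
  // Hypothetical contents for the file named by spark.fairscheduler.allocation.file.
  val sample =
    """<allocations>
      |  <pool name="production">
      |    <schedulingMode>FAIR</schedulingMode>
      |    <minShare>4</minShare>
      |    <weight>2</weight>
      |  </pool>
      |</allocations>""".stripMargin

  def main(args: Array[String]) {
    val xml = XML.loadString(sample)
    for (poolNode <- (xml \\ "pool")) {
      println((poolNode \ "@name").text + ": minShare=" + (poolNode \ "minShare").text)
    }
  }
}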
diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala
index 9ac875de3a..8844057a5c 100644
--- a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala
@@ -1,6 +1,6 @@
package spark.scheduler.cluster
-import spark.Utils
+import spark.{SparkContext, Utils}
/**
* A backend interface for cluster scheduling systems that allows plugging in different ones under
@@ -14,14 +14,7 @@ private[spark] trait SchedulerBackend {
def defaultParallelism(): Int
// Memory used by each executor (in megabytes)
- protected val executorMemory = {
- // TODO: Might need to add some extra memory for the non-heap parts of the JVM
- Option(System.getProperty("spark.executor.memory"))
- .orElse(Option(System.getenv("SPARK_MEM")))
- .map(Utils.memoryStringToMb)
- .getOrElse(512)
- }
-
+ protected val executorMemory: Int = SparkContext.executorMemoryRequested
// TODO: Probably want to add a killTask too
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala
new file mode 100644
index 0000000000..f33310a34a
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala
@@ -0,0 +1,64 @@
+package spark.scheduler.cluster
+
+/**
+ * An interface for the sorting algorithm used by the scheduler.
+ * FIFO: FIFO ordering between TaskSetManagers
+ * FAIR: fair-share ordering between Pools, and FIFO or fair-share within Pools
+ */
+private[spark] trait SchedulingAlgorithm {
+ def comparator(s1: Schedulable, s2: Schedulable): Boolean
+}
+
+private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm {
+ override def comparator(s1: Schedulable, s2: Schedulable): Boolean = {
+ val priority1 = s1.priority
+ val priority2 = s2.priority
+ var res = math.signum(priority1 - priority2)
+ if (res == 0) {
+ val stageId1 = s1.stageId
+ val stageId2 = s2.stageId
+ res = math.signum(stageId1 - stageId2)
+ }
+ if (res < 0) {
+ return true
+ } else {
+ return false
+ }
+ }
+}
+
+private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm {
+ override def comparator(s1: Schedulable, s2: Schedulable): Boolean = {
+ val minShare1 = s1.minShare
+ val minShare2 = s2.minShare
+ val runningTasks1 = s1.runningTasks
+ val runningTasks2 = s2.runningTasks
+ val s1Needy = runningTasks1 < minShare1
+ val s2Needy = runningTasks2 < minShare2
+ val minShareRatio1 = runningTasks1.toDouble / math.max(minShare1, 1.0).toDouble
+ val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0).toDouble
+ val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble
+ val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble
+ var compare: Int = 0
+
+ if (s1Needy && !s2Needy) {
+ return true
+ } else if (!s1Needy && s2Needy) {
+ return false
+ } else if (s1Needy && s2Needy) {
+ compare = minShareRatio1.compareTo(minShareRatio2)
+ } else {
+ compare = taskToWeightRatio1.compareTo(taskToWeightRatio2)
+ }
+
+ if (compare < 0) {
+ return true
+ } else if (compare > 0) {
+ return false
+ } else {
+ return s1.name < s2.name
+ }
+ }
+}
+
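To see the fair comparator's ordering concretely, here is a self-contained restatement of its rule applied to two toy pools (duplicated purely for illustration; the real logic lives in FairSchedulingAlgorithm above):

object FairCompareSketch {
  case class S(running: Int, minShare: Int, weight: Int, name: String)

  // Mirrors FairSchedulingAlgorithm.comparator: needy entities first, then by
  // min-share ratio (both needy) or tasks-to-weight ratio, then by name.
  def before(a: S, b: S): Boolean = {
    val aNeedy = a.running < a.minShare
    val bNeedy = b.running < b.minShare
    if (aNeedy != bNeedy) return aNeedy
    val cmp =
      if (aNeedy)
        (a.running.toDouble / math.max(a.minShare, 1)).compareTo(b.running.toDouble / math.max(b.minShare, 1))
      else
        (a.running.toDouble / a.weight).compareTo(b.running.toDouble / b.weight)
    if (cmp != 0) cmp < 0 else a.name < b.name
  }

  def main(args: Array[String]) {
    val starved = S(running = 1, minShare = 4, weight = 1, name = "p1")
    val busy = S(running = 8, minShare = 2, weight = 1, name = "p2")
    println(before(starved, busy)) // true: p1 is still under its minimum share
  }
}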
diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala
new file mode 100644
index 0000000000..6e0c6793e0
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala
@@ -0,0 +1,7 @@
+package spark.scheduler.cluster
+
+object SchedulingMode extends Enumeration("FAIR", "FIFO") {
+
+ type SchedulingMode = Value
+ val FAIR, FIFO = Value
+}
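Usage note: LocalScheduler (later in this patch) parses the spark.cluster.schedulingmode property via Enumeration's withName, which throws a NoSuchElementException for unknown names; that is why FairSchedulableBuilder wraps its call in a try/catch. A small sketch (assumes spark.scheduler.cluster.SchedulingMode is on the classpath):

object SchedulingModeSketch {
  def main(args: Array[String]) {
    val mode = SchedulingMode.withName("FAIR") // SchedulingMode.FAIR
    println(mode == SchedulingMode.FAIR)       // true
    // SchedulingMode.withName("LIFO") would throw NoSuchElementException.
  }
}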
diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index bb289c9cf3..170ede0f44 100644
--- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -31,7 +31,8 @@ private[spark] class SparkDeploySchedulerBackend(
val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs)
val sparkHome = sc.getSparkHome().getOrElse(
throw new IllegalArgumentException("must supply spark home for spark standalone"))
- val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome)
+ val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome,
+ sc.ui.appUIAddress)
client = new Client(sc.env.actorSystem, master, appDesc, this)
client.start()
@@ -57,9 +58,9 @@ private[spark] class SparkDeploySchedulerBackend(
}
}
- override def executorAdded(executorId: String, workerId: String, host: String, cores: Int, memory: Int) {
- logInfo("Granted executor ID %s on host %s with %d cores, %s RAM".format(
- executorId, host, cores, Utils.memoryMegabytesToString(memory)))
+ override def executorAdded(executorId: String, workerId: String, hostPort: String, cores: Int, memory: Int) {
+ logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format(
+ executorId, hostPort, cores, Utils.memoryMegabytesToString(memory)))
}
override def executorRemoved(executorId: String, message: String, exitStatus: Option[Int]) {
diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
index d766067824..3335294844 100644
--- a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
@@ -3,6 +3,7 @@ package spark.scheduler.cluster
import spark.TaskState.TaskState
import java.nio.ByteBuffer
import spark.util.SerializableBuffer
+import spark.Utils
private[spark] sealed trait StandaloneClusterMessage extends Serializable
@@ -19,8 +20,10 @@ case class RegisterExecutorFailed(message: String) extends StandaloneClusterMess
// Executors to driver
private[spark]
-case class RegisterExecutor(executorId: String, host: String, cores: Int)
- extends StandaloneClusterMessage
+case class RegisterExecutor(executorId: String, hostPort: String, cores: Int)
+ extends StandaloneClusterMessage {
+ Utils.checkHostPort(hostPort, "Expected host port")
+}
private[spark]
case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, data: SerializableBuffer)
diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
index a06d853b46..16131215c8 100644
--- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
@@ -6,8 +6,9 @@ import akka.actor._
import scala.concurrent.duration._
import akka.pattern.ask
-import spark.{SparkException, Logging, TaskState}
+import spark.{Utils, SparkException, Logging, TaskState}
import scala.concurrent.Await
+
import java.util.concurrent.atomic.AtomicInteger
import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent}
@@ -24,12 +25,12 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
var totalCoreCount = new AtomicInteger(0)
class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor {
- val executorActor = new HashMap[String, ActorRef]
- val executorAddress = new HashMap[String, Address]
- val executorHost = new HashMap[String, String]
- val freeCores = new HashMap[String, Int]
- val actorToExecutorId = new HashMap[ActorRef, String]
- val addressToExecutorId = new HashMap[Address, String]
+ private val executorActor = new HashMap[String, ActorRef]
+ private val executorAddress = new HashMap[String, Address]
+ private val executorHostPort = new HashMap[String, String]
+ private val freeCores = new HashMap[String, Int]
+ private val actorToExecutorId = new HashMap[ActorRef, String]
+ private val addressToExecutorId = new HashMap[Address, String]
override def preStart() {
// Listen for remote client disconnection events, since they don't go through Akka's watch()
@@ -37,7 +38,8 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
}
def receive = {
- case RegisterExecutor(executorId, host, cores) =>
+ case RegisterExecutor(executorId, hostPort, cores) =>
+ Utils.checkHostPort(hostPort, "Host port expected " + hostPort)
if (executorActor.contains(executorId)) {
sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId)
} else {
@@ -45,7 +47,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
sender ! RegisteredExecutor(sparkProperties)
context.watch(sender)
executorActor(executorId) = sender
- executorHost(executorId) = host
+ executorHostPort(executorId) = hostPort
freeCores(executorId) = cores
executorAddress(executorId) = sender.path.address
actorToExecutorId(sender) = executorId
@@ -85,13 +87,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
// Make fake resource offers on all executors
def makeOffers() {
launchTasks(scheduler.resourceOffers(
- executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))}))
+ executorHostPort.toArray.map {case (id, hostPort) => new WorkerOffer(id, hostPort, freeCores(id))}))
}
// Make fake resource offers on just one executor
def makeOffers(executorId: String) {
launchTasks(scheduler.resourceOffers(
- Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId)))))
+ Seq(new WorkerOffer(executorId, executorHostPort(executorId), freeCores(executorId)))))
}
// Launch tasks returned by a set of resource offers
@@ -110,9 +112,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
actorToExecutorId -= executorActor(executorId)
addressToExecutorId -= executorAddress(executorId)
executorActor -= executorId
- executorHost -= executorId
+ executorHostPort -= executorId
freeCores -= executorId
- executorHost -= executorId
+ executorHostPort -= executorId
totalCoreCount.addAndGet(-numCores)
scheduler.executorLost(executorId, SlaveLost(reason))
}
@@ -128,7 +130,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
while (iterator.hasNext) {
val entry = iterator.next
val (key, value) = (entry.getKey.toString, entry.getValue.toString)
- if (key.startsWith("spark.")) {
+ if (key.startsWith("spark.") && !key.equals("spark.hostPort")) {
properties += ((key, value))
}
}
@@ -136,10 +138,11 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME)
}
+ private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+
override def stop() {
try {
if (driverActor != null) {
- val timeout = 5.seconds
val future = driverActor.ask(StopDriver)(timeout)
Await.result(future, timeout)
}
@@ -159,7 +162,6 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
// Called by subclasses when notified of a lost worker
def removeExecutor(executorId: String, reason: String) {
try {
- val timeout = 5.seconds
val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout)
Await.result(future, timeout)
} catch {
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
index dfe3c5a85b..718f26bfbd 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
@@ -1,5 +1,7 @@
package spark.scheduler.cluster
+import spark.Utils
+
/**
* Information about a running task attempt inside a TaskSet.
*/
@@ -9,8 +11,11 @@ class TaskInfo(
val index: Int,
val launchTime: Long,
val executorId: String,
- val host: String,
- val preferred: Boolean) {
+ val hostPort: String,
+ val taskLocality: TaskLocality.TaskLocality) {
+
+ Utils.checkHostPort(hostPort, "Expected hostport")
+
var finishTime: Long = 0
var failed = false
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index c9f2c48804..b4dd75d90f 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -1,430 +1,17 @@
package spark.scheduler.cluster
-import java.util.Arrays
-import java.util.{HashMap => JHashMap}
-
import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-import scala.math.max
-import scala.math.min
-
-import spark._
import spark.scheduler._
import spark.TaskState.TaskState
import java.nio.ByteBuffer
-/**
- * Schedules the tasks within a single TaskSet in the ClusterScheduler.
- */
-private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSet) extends Logging {
-
- // Maximum time to wait to run a task in a preferred location (in ms)
- val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
-
- // CPUs to request per task
- val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble
-
- // Maximum times a task is allowed to fail before failing the job
- val MAX_TASK_FAILURES = 4
-
- // Quantile of tasks at which to start speculation
- val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
- val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
-
- // Serializer for closures and tasks.
- val ser = SparkEnv.get.closureSerializer.newInstance()
-
- val priority = taskSet.priority
- val tasks = taskSet.tasks
- val numTasks = tasks.length
- val copiesRunning = new Array[Int](numTasks)
- val finished = new Array[Boolean](numTasks)
- val numFailures = new Array[Int](numTasks)
- val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
- var tasksFinished = 0
-
- // Last time when we launched a preferred task (for delay scheduling)
- var lastPreferredLaunchTime = System.currentTimeMillis
-
- // List of pending tasks for each node. These collections are actually
- // treated as stacks, in which new tasks are added to the end of the
- // ArrayBuffer and removed from the end. This makes it faster to detect
- // tasks that repeatedly fail because whenever a task failed, it is put
- // back at the head of the stack. They are also only cleaned up lazily;
- // when a task is launched, it remains in all the pending lists except
- // the one that it was launched from, but gets removed from them later.
- val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
-
- // List containing pending tasks with no locality preferences
- val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
-
- // List containing all pending tasks (also used as a stack, as above)
- val allPendingTasks = new ArrayBuffer[Int]
-
- // Tasks that can be speculated. Since these will be a small fraction of total
- // tasks, we'll just hold them in a HashSet.
- val speculatableTasks = new HashSet[Int]
-
- // Task index, start and finish time for each task attempt (indexed by task ID)
- val taskInfos = new HashMap[Long, TaskInfo]
-
- // Did the job fail?
- var failed = false
- var causeOfFailure = ""
-
- // How frequently to reprint duplicate exceptions in full, in milliseconds
- val EXCEPTION_PRINT_INTERVAL =
- System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong
- // Map of recent exceptions (identified by string representation and
- // top stack frame) to duplicate count (how many times the same
- // exception has appeared) and time the full exception was
- // printed. This should ideally be an LRU map that can drop old
- // exceptions automatically.
- val recentExceptions = HashMap[String, (Int, Long)]()
-
- // Figure out the current map output tracker generation and set it on all tasks
- val generation = sched.mapOutputTracker.getGeneration
- logDebug("Generation for " + taskSet.id + ": " + generation)
- for (t <- tasks) {
- t.generation = generation
- }
-
- // Add all our tasks to the pending lists. We do this in reverse order
- // of task index so that tasks with low indices get launched first.
- for (i <- (0 until numTasks).reverse) {
- addPendingTask(i)
- }
-
- // Add a task to all the pending-task lists that it should be on.
- private def addPendingTask(index: Int) {
- val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
- if (locations.size == 0) {
- pendingTasksWithNoPrefs += index
- } else {
- for (host <- locations) {
- val list = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer())
- list += index
- }
- }
- allPendingTasks += index
- }
-
- // Return the pending tasks list for a given host, or an empty list if
- // there is no map entry for that host
- private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
- pendingTasksForHost.getOrElse(host, ArrayBuffer())
- }
-
- // Dequeue a pending task from the given list and return its index.
- // Return None if the list is empty.
- // This method also cleans up any tasks in the list that have already
- // been launched, since we want that to happen lazily.
- private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
- while (!list.isEmpty) {
- val index = list.last
- list.trimEnd(1)
- if (copiesRunning(index) == 0 && !finished(index)) {
- return Some(index)
- }
- }
- return None
- }
-
- // Return a speculative task for a given host if any are available. The task should not have an
- // attempt running on this host, in case the host is slow. In addition, if localOnly is set, the
- // task must have a preference for this host (or no preferred locations at all).
- private def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
- val hostsAlive = sched.hostsAlive
- speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
- val localTask = speculatableTasks.find {
- index =>
- val locations = tasks(index).preferredLocations.toSet & hostsAlive
- val attemptLocs = taskAttempts(index).map(_.host)
- (locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host)
- }
- if (localTask != None) {
- speculatableTasks -= localTask.get
- return localTask
- }
- if (!localOnly && speculatableTasks.size > 0) {
- val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.host).contains(host))
- if (nonLocalTask != None) {
- speculatableTasks -= nonLocalTask.get
- return nonLocalTask
- }
- }
- return None
- }
-
- // Dequeue a pending task for a given node and return its index.
- // If localOnly is set to false, allow non-local tasks as well.
- private def findTask(host: String, localOnly: Boolean): Option[Int] = {
- val localTask = findTaskFromList(getPendingTasksForHost(host))
- if (localTask != None) {
- return localTask
- }
- val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs)
- if (noPrefTask != None) {
- return noPrefTask
- }
- if (!localOnly) {
- val nonLocalTask = findTaskFromList(allPendingTasks)
- if (nonLocalTask != None) {
- return nonLocalTask
- }
- }
- // Finally, if all else has failed, find a speculative task
- return findSpeculativeTask(host, localOnly)
- }
-
- // Does a host count as a preferred location for a task? This is true if
- // either the task has preferred locations and this host is one, or it has
- // no preferred locations (in which we still count the launch as preferred).
- private def isPreferredLocation(task: Task[_], host: String): Boolean = {
- val locs = task.preferredLocations
- return (locs.contains(host) || locs.isEmpty)
- }
-
- // Respond to an offer of a single slave from the scheduler by finding a task
- def slaveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = {
- if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
- val time = System.currentTimeMillis
- val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
-
- findTask(host, localOnly) match {
- case Some(index) => {
- // Found a task; do some bookkeeping and return a Mesos task for it
- val task = tasks(index)
- val taskId = sched.newTaskId()
- // Figure out whether this should count as a preferred launch
- val preferred = isPreferredLocation(task, host)
- val prefStr = if (preferred) {
- "preferred"
- } else {
- "non-preferred, not one of " + task.preferredLocations.mkString(", ")
- }
- logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format(
- taskSet.id, index, taskId, execId, host, prefStr))
- // Do various bookkeeping
- copiesRunning(index) += 1
- val info = new TaskInfo(taskId, index, time, execId, host, preferred)
- taskInfos(taskId) = info
- taskAttempts(index) = info :: taskAttempts(index)
- if (preferred) {
- lastPreferredLaunchTime = time
- }
- // Serialize and return the task
- val startTime = System.currentTimeMillis
- val serializedTask = Task.serializeWithDependencies(
- task, sched.sc.addedFiles, sched.sc.addedJars, ser)
- val timeTaken = System.currentTimeMillis - startTime
- logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
- taskSet.id, index, serializedTask.limit, timeTaken))
- val taskName = "task %s:%d".format(taskSet.id, index)
- return Some(new TaskDescription(taskId, execId, taskName, serializedTask))
- }
- case _ =>
- }
- }
- return None
- }
-
- def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- state match {
- case TaskState.FINISHED =>
- taskFinished(tid, state, serializedData)
- case TaskState.LOST =>
- taskLost(tid, state, serializedData)
- case TaskState.FAILED =>
- taskLost(tid, state, serializedData)
- case TaskState.KILLED =>
- taskLost(tid, state, serializedData)
- case _ =>
- }
- }
-
- def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- val info = taskInfos(tid)
- if (info.failed) {
- // We might get two task-lost messages for the same task in coarse-grained Mesos mode,
- // or even from Mesos itself when acks get delayed.
- return
- }
- val index = info.index
- info.markSuccessful()
- if (!finished(index)) {
- tasksFinished += 1
- logInfo("Finished TID %s in %d ms (progress: %d/%d)".format(
- tid, info.duration, tasksFinished, numTasks))
- // Deserialize task result and pass it to the scheduler
- val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
- result.metrics.resultSize = serializedData.limit()
- sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
- // Mark finished and stop if we've finished all the tasks
- finished(index) = true
- if (tasksFinished == numTasks) {
- sched.taskSetFinished(this)
- }
- } else {
- logInfo("Ignoring task-finished event for TID " + tid +
- " because task " + index + " is already finished")
- }
- }
-
- def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- val info = taskInfos(tid)
- if (info.failed) {
- // We might get two task-lost messages for the same task in coarse-grained Mesos mode,
- // or even from Mesos itself when acks get delayed.
- return
- }
- val index = info.index
- info.markFailed()
- if (!finished(index)) {
- logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
- copiesRunning(index) -= 1
- // Check if the problem is a map output fetch failure. In that case, this
- // task will never succeed on any node, so tell the scheduler about it.
- if (serializedData != null && serializedData.limit() > 0) {
- val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader)
- reason match {
- case fetchFailed: FetchFailed =>
- logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
- sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
- finished(index) = true
- tasksFinished += 1
- sched.taskSetFinished(this)
- return
-
- case ef: ExceptionFailure =>
- val key = ef.exception.toString
- val now = System.currentTimeMillis
- val (printFull, dupCount) = {
- if (recentExceptions.contains(key)) {
- val (dupCount, printTime) = recentExceptions(key)
- if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
- recentExceptions(key) = (0, now)
- (true, 0)
- } else {
- recentExceptions(key) = (dupCount + 1, printTime)
- (false, dupCount + 1)
- }
- } else {
- recentExceptions(key) = (0, now)
- (true, 0)
- }
- }
- if (printFull) {
- val locs = ef.exception.getStackTrace.map(loc => "\tat %s".format(loc.toString))
- logInfo("Loss was due to %s\n%s".format(ef.exception.toString, locs.mkString("\n")))
- } else {
- logInfo("Loss was due to %s [duplicate %d]".format(ef.exception.toString, dupCount))
- }
-
- case _ => {}
- }
- }
- // On non-fetch failures, re-enqueue the task as pending for a max number of retries
- addPendingTask(index)
- // Count failed attempts only on FAILED and LOST state (not on KILLED)
- if (state == TaskState.FAILED || state == TaskState.LOST) {
- numFailures(index) += 1
- if (numFailures(index) > MAX_TASK_FAILURES) {
- logError("Task %s:%d failed more than %d times; aborting job".format(
- taskSet.id, index, MAX_TASK_FAILURES))
- abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
- }
- }
- } else {
- logInfo("Ignoring task-lost event for TID " + tid +
- " because task " + index + " is already finished")
- }
- }
-
- def error(message: String) {
- // Save the error message
- abort("Error: " + message)
- }
-
- def abort(message: String) {
- failed = true
- causeOfFailure = message
- // TODO: Kill running tasks if we were not terminated due to a Mesos error
- sched.listener.taskSetFailed(taskSet, message)
- sched.taskSetFinished(this)
- }
-
- def executorLost(execId: String, hostname: String) {
- logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
- val newHostsAlive = sched.hostsAlive
- // If some task has preferred locations only on hostname, and there are no more executors there,
- // put it in the no-prefs list to avoid the wait from delay scheduling
- if (!newHostsAlive.contains(hostname)) {
- for (index <- getPendingTasksForHost(hostname)) {
- val newLocs = tasks(index).preferredLocations.toSet & newHostsAlive
- if (newLocs.isEmpty) {
- pendingTasksWithNoPrefs += index
- }
- }
- }
- // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
- if (tasks(0).isInstanceOf[ShuffleMapTask]) {
- for ((tid, info) <- taskInfos if info.executorId == execId) {
- val index = taskInfos(tid).index
- if (finished(index)) {
- finished(index) = false
- copiesRunning(index) -= 1
- tasksFinished -= 1
- addPendingTask(index)
- // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
- // stage finishes when a total of tasks.size tasks finish.
- sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
- }
- }
- }
- // Also re-enqueue any tasks that were running on the node
- for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
- taskLost(tid, TaskState.KILLED, null)
- }
- }
-
- /**
- * Check for tasks to be speculated and return true if there are any. This is called periodically
- * by the ClusterScheduler.
- *
- * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
- * we don't scan the whole task set. It might also help to make this sorted by launch time.
- */
- def checkSpeculatableTasks(): Boolean = {
- // Can't speculate if we only have one task, or if all tasks have finished.
- if (numTasks == 1 || tasksFinished == numTasks) {
- return false
- }
- var foundTasks = false
- val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
- logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
- if (tasksFinished >= minFinishedForSpeculation) {
- val time = System.currentTimeMillis()
- val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
- Arrays.sort(durations)
- val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1))
- val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
- // TODO: Threshold should also look at standard deviation of task durations and have a lower
- // bound based on that.
- logDebug("Task length threshold for speculation: " + threshold)
- for ((tid, info) <- taskInfos) {
- val index = info.index
- if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
- !speculatableTasks.contains(index)) {
- logInfo(
- "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
- taskSet.id, index, info.host, threshold))
- speculatableTasks += index
- foundTasks = true
- }
- }
- }
- return foundTasks
- }
+private[spark] trait TaskSetManager extends Schedulable {
+ def taskSet: TaskSet
+ def slaveOffer(execId: String, hostPort: String, availableCpus: Double,
+ overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription]
+ def numPendingTasksForHostPort(hostPort: String): Int
+ def numRackLocalPendingTasksForHost(hostPort :String): Int
+ def numPendingTasksForHost(hostPort: String): Int
+ def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer)
+ def error(message: String)
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
index 3c3afcbb14..c47824315c 100644
--- a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
@@ -4,5 +4,5 @@ package spark.scheduler.cluster
* Represents free resources available on an executor.
*/
private[spark]
-class WorkerOffer(val executorId: String, val hostname: String, val cores: Int) {
+class WorkerOffer(val executorId: String, val hostPort: String, val cores: Int) {
}
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index 9e1bde3fbe..93d4318b29 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -2,19 +2,50 @@ package spark.scheduler.local
import java.io.File
import java.util.concurrent.atomic.AtomicInteger
+import java.nio.ByteBuffer
+import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
import spark._
+import spark.TaskState.TaskState
import spark.executor.ExecutorURLClassLoader
import spark.scheduler._
-import spark.scheduler.cluster.TaskInfo
+import spark.scheduler.cluster._
+import akka.actor._
/**
- * A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
+ * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
* the scheduler also allows each task to fail up to maxFailures times, which is useful for
* testing fault recovery.
*/
-private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext)
+
+private[spark] case object LocalReviveOffers
+private[spark] case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
+
+private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging {
+ def receive = {
+ case LocalReviveOffers =>
+ launchTask(localScheduler.resourceOffer(freeCores))
+    case LocalStatusUpdate(taskId, state, serializedData) =>
+      freeCores += 1
+      localScheduler.statusUpdate(taskId, state, serializedData)
+ launchTask(localScheduler.resourceOffer(freeCores))
+ }
+
+  def launchTask(tasks: Seq[TaskDescription]) {
+ for (task <- tasks) {
+ freeCores -= 1
+ localScheduler.threadPool.submit(new Runnable {
+ def run() {
+          localScheduler.runTask(task.taskId, task.serializedTask)
+ }
+ })
+ }
+ }
+}
+
+private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
extends TaskScheduler
with Logging {
@@ -30,87 +61,127 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader)
- // TODO: Need to take into account stage priority in scheduling
+ var schedulableBuilder: SchedulableBuilder = null
+ var rootPool: Pool = null
+ val activeTaskSets = new HashMap[String, TaskSetManager]
+ val taskIdToTaskSetId = new HashMap[Long, String]
+ val taskSetTaskIds = new HashMap[String, HashSet[Long]]
+
+ var localActor: ActorRef = null
+
+ override def start() {
+    // The default scheduling mode is FIFO
+    val schedulingMode = System.getProperty("spark.cluster.schedulingmode", "FIFO")
+    // Temporarily set the rootPool name to the empty string
+ rootPool = new Pool("", SchedulingMode.withName(schedulingMode), 0, 0)
+ schedulableBuilder = {
+ schedulingMode match {
+ case "FIFO" =>
+ new FIFOSchedulableBuilder(rootPool)
+ case "FAIR" =>
+ new FairSchedulableBuilder(rootPool)
+ }
+ }
+ schedulableBuilder.buildPools()
- override def start() { }
+ localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test")
+ }
override def setListener(listener: TaskSchedulerListener) {
this.listener = listener
}
override def submitTasks(taskSet: TaskSet) {
- val tasks = taskSet.tasks
- val failCount = new Array[Int](tasks.size)
-
- def submitTask(task: Task[_], idInJob: Int) {
- val myAttemptId = attemptId.getAndIncrement()
- threadPool.submit(new Runnable {
- def run() {
- runTask(task, idInJob, myAttemptId)
- }
- })
+ synchronized {
+      val manager = new LocalTaskSetManager(this, taskSet)
+ schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
+ activeTaskSets(taskSet.id) = manager
+ taskSetTaskIds(taskSet.id) = new HashSet[Long]()
+ localActor ! LocalReviveOffers
}
+ }
+
+ def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
+ synchronized {
+ var freeCpuCores = freeCores
+ val tasks = new ArrayBuffer[TaskDescription](freeCores)
+ val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
+ for (manager <- sortedTaskSetQueue) {
+ logDebug("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks))
+ }
- def runTask(task: Task[_], idInJob: Int, attemptId: Int) {
- logInfo("Running " + task)
- val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local", true)
- // Set the Spark execution environment for the worker thread
- SparkEnv.set(env)
- try {
- Accumulators.clear()
- Thread.currentThread().setContextClassLoader(classLoader)
-
- // Serialize and deserialize the task so that accumulators are changed to thread-local ones;
- // this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
- val ser = SparkEnv.get.closureSerializer.newInstance()
- val bytes = Task.serializeWithDependencies(task, sc.addedFiles, sc.addedJars, ser)
- logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes")
- val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
- updateDependencies(taskFiles, taskJars) // Download any files added with addFile
- val deserStart = System.currentTimeMillis()
- val deserializedTask = ser.deserialize[Task[_]](
- taskBytes, Thread.currentThread.getContextClassLoader)
- val deserTime = System.currentTimeMillis() - deserStart
-
- // Run it
- val result: Any = deserializedTask.run(attemptId)
-
- // Serialize and deserialize the result to emulate what the Mesos
- // executor does. This is useful to catch serialization errors early
- // on in development (so when users move their local Spark programs
- // to the cluster, they don't get surprised by serialization errors).
- val serResult = ser.serialize(result)
- deserializedTask.metrics.get.resultSize = serResult.limit()
- val resultToReturn = ser.deserialize[Any](serResult)
- val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
- ser.serialize(Accumulators.values))
- logInfo("Finished " + task)
- info.markSuccessful()
- deserializedTask.metrics.get.executorRunTime = info.duration.toInt //close enough
- deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
-
- // If the threadpool has not already been shutdown, notify DAGScheduler
- if (!Thread.currentThread().isInterrupted)
- listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, deserializedTask.metrics.getOrElse(null))
- } catch {
- case t: Throwable => {
- logError("Exception in task " + idInJob, t)
- failCount.synchronized {
- failCount(idInJob) += 1
- if (failCount(idInJob) <= maxFailures) {
- submitTask(task, idInJob)
- } else {
- // TODO: Do something nicer here to return all the way to the user
- if (!Thread.currentThread().isInterrupted)
- listener.taskEnded(task, new ExceptionFailure(t), null, null, info, null)
+ var launchTask = false
+ for (manager <- sortedTaskSetQueue) {
+ do {
+ launchTask = false
+          manager.slaveOffer(null, null, freeCpuCores) match {
+ case Some(task) =>
+ tasks += task
+ taskIdToTaskSetId(task.taskId) = manager.taskSet.id
+ taskSetTaskIds(manager.taskSet.id) += task.taskId
+ freeCpuCores -= 1
+ launchTask = true
+ case None => {}
}
- }
- }
+        } while (launchTask)
}
+ return tasks
}
+ }
- for ((task, i) <- tasks.zipWithIndex) {
- submitTask(task, i)
+ def taskSetFinished(manager: TaskSetManager) {
+ synchronized {
+ activeTaskSets -= manager.taskSet.id
+ manager.parent.removeSchedulable(manager)
+ logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
+ taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
+ taskSetTaskIds -= manager.taskSet.id
+ }
+ }
+
+ def runTask(taskId: Long, bytes: ByteBuffer) {
+ logInfo("Running " + taskId)
+ val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
+ // Set the Spark execution environment for the worker thread
+ SparkEnv.set(env)
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+ try {
+ Accumulators.clear()
+ Thread.currentThread().setContextClassLoader(classLoader)
+
+ // Serialize and deserialize the task so that accumulators are changed to thread-local ones;
+ // this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
+ val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
+ updateDependencies(taskFiles, taskJars) // Download any files added with addFile
+ val deserStart = System.currentTimeMillis()
+ val deserializedTask = ser.deserialize[Task[_]](
+ taskBytes, Thread.currentThread.getContextClassLoader)
+ val deserTime = System.currentTimeMillis() - deserStart
+
+      // Run it, timing how long the task body itself takes
+      val runStart = System.currentTimeMillis()
+      val result: Any = deserializedTask.run(taskId)
+      val runTime = System.currentTimeMillis() - runStart
+
+ // Serialize and deserialize the result to emulate what the Mesos
+ // executor does. This is useful to catch serialization errors early
+ // on in development (so when users move their local Spark programs
+ // to the cluster, they don't get surprised by serialization errors).
+ val serResult = ser.serialize(result)
+ deserializedTask.metrics.get.resultSize = serResult.limit()
+ val resultToReturn = ser.deserialize[Any](serResult)
+ val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
+ ser.serialize(Accumulators.values))
+ logInfo("Finished " + taskId)
+      deserializedTask.metrics.get.executorRunTime = runTime.toInt
+ deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
+
+ val taskResult = new TaskResult(result, accumUpdates, deserializedTask.metrics.getOrElse(null))
+ val serializedResult = ser.serialize(taskResult)
+ localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult)
+ } catch {
+ case t: Throwable => {
+ val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace)
+ localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure))
+ }
}
}
@@ -126,6 +197,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentFiles(name) = timestamp
}
+
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
@@ -141,7 +213,16 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
}
}
- override def stop() {
+  def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
+ synchronized {
+ val taskSetId = taskIdToTaskSetId(taskId)
+ val taskSetManager = activeTaskSets(taskSetId)
+ taskSetTaskIds(taskSetId) -= taskId
+ taskSetManager.statusUpdate(taskId, state, serializedData)
+ }
+ }
+
+ override def stop() {
threadPool.shutdownNow()
}
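
Worth noting about the actor introduced above: routing both LocalReviveOffers and LocalStatusUpdate through a single actor serializes every update to freeCores, so the core count needs no lock. A self-contained sketch of the same revive/complete cycle, simplified to plain messages (names here are illustrative, not the patch's own):

    import akka.actor._
    import scala.collection.mutable.Queue

    case object Revive
    case class Completed(taskId: Long)

    // Mirrors LocalActor: a fixed pool of cores gates a queue of pending task ids.
    class CoreGate(var freeCores: Int, pending: Queue[Long]) extends Actor {
      def receive = {
        case Revive =>
          while (freeCores > 0 && pending.nonEmpty) {
            freeCores -= 1
            val id = pending.dequeue()
            // The real scheduler submits to a thread pool; here we just echo completion.
            self ! Completed(id)
          }
        case Completed(id) =>
          freeCores += 1  // a core came back; see if more work can start
          self ! Revive
      }
    }

    // Usage: actorSystem.actorOf(Props(new CoreGate(4, Queue(0L, 1L, 2L)))) ! Revive
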
diff --git a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala
new file mode 100644
index 0000000000..70b69bb26f
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala
@@ -0,0 +1,172 @@
+package spark.scheduler.local
+
+import java.io.File
+import java.util.concurrent.atomic.AtomicInteger
+import java.nio.ByteBuffer
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+
+import spark._
+import spark.TaskState.TaskState
+import spark.scheduler._
+import spark.scheduler.cluster._
+
+private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet)
+  extends TaskSetManager with Logging {
+ var parent: Schedulable = null
+ var weight: Int = 1
+ var minShare: Int = 0
+ var runningTasks: Int = 0
+ var priority: Int = taskSet.priority
+ var stageId: Int = taskSet.stageId
+ var name: String = "TaskSet_"+taskSet.stageId.toString
+
+
+ var failCount = new Array[Int](taskSet.tasks.size)
+ val taskInfos = new HashMap[Long, TaskInfo]
+ val numTasks = taskSet.tasks.size
+ var numFinished = 0
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+ val copiesRunning = new Array[Int](numTasks)
+ val finished = new Array[Boolean](numTasks)
+ val numFailures = new Array[Int](numTasks)
+ val MAX_TASK_FAILURES = sched.maxFailures
+
+ def increaseRunningTasks(taskNum: Int): Unit = {
+ runningTasks += taskNum
+ if (parent != null) {
+ parent.increaseRunningTasks(taskNum)
+ }
+ }
+
+ def decreaseRunningTasks(taskNum: Int): Unit = {
+ runningTasks -= taskNum
+ if (parent != null) {
+ parent.decreaseRunningTasks(taskNum)
+ }
+ }
+
+ def addSchedulable(schedulable: Schedulable): Unit = {
+    // No-op: a LocalTaskSetManager is a leaf and has no children
+ }
+
+ def removeSchedulable(schedulable: Schedulable): Unit = {
+    // No-op: nothing to remove from a leaf
+ }
+
+ def getSchedulableByName(name: String): Schedulable = {
+ return null
+ }
+
+ def executorLost(executorId: String, host: String): Unit = {
+    // No-op: there are no separate executors in local mode
+ }
+
+ def checkSpeculatableTasks(): Boolean = {
+ return true
+ }
+
+ def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
+    val sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
+    sortedTaskSetQueue += this
+    sortedTaskSetQueue
+ }
+
+ def hasPendingTasks(): Boolean = {
+ return true
+ }
+
+ def findTask(): Option[Int] = {
+    for (i <- 0 until numTasks) {
+ if (copiesRunning(i) == 0 && !finished(i)) {
+ return Some(i)
+ }
+ }
+ return None
+ }
+
+  def slaveOffer(execId: String, hostPort: String, availableCpus: Double,
+    overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
+ SparkEnv.set(sched.env)
+ logDebug("availableCpus:%d,numFinished:%d,numTasks:%d".format(availableCpus.toInt, numFinished, numTasks))
+ if (availableCpus > 0 && numFinished < numTasks) {
+ findTask() match {
+ case Some(index) =>
+ val taskId = sched.attemptId.getAndIncrement()
+ val task = taskSet.tasks(index)
+ val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
+ taskInfos(taskId) = info
+ val bytes = Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+ logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes")
+ val taskName = "task %s:%d".format(taskSet.id, index)
+ copiesRunning(index) += 1
+ increaseRunningTasks(1)
+ return Some(new TaskDescription(taskId, null, taskName, bytes))
+ case None => {}
+ }
+ }
+ return None
+ }
+
+ def numPendingTasksForHostPort(hostPort: String): Int = {
+ return 0
+ }
+
+  def numRackLocalPendingTasksForHost(hostPort: String): Int = {
+ return 0
+ }
+
+ def numPendingTasksForHost(hostPort: String): Int = {
+ return 0
+ }
+
+ def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ state match {
+ case TaskState.FINISHED =>
+ taskEnded(tid, state, serializedData)
+ case TaskState.FAILED =>
+ taskFailed(tid, state, serializedData)
+ case _ => {}
+ }
+ }
+
+ def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ val info = taskInfos(tid)
+ val index = info.index
+ val task = taskSet.tasks(index)
+ info.markSuccessful()
+ val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
+ result.metrics.resultSize = serializedData.limit()
+ sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
+ numFinished += 1
+ decreaseRunningTasks(1)
+ finished(index) = true
+ if (numFinished == numTasks) {
+ sched.taskSetFinished(this)
+ }
+ }
+
+ def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ val info = taskInfos(tid)
+ val index = info.index
+ val task = taskSet.tasks(index)
+ info.markFailed()
+ decreaseRunningTasks(1)
+ val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](serializedData, getClass.getClassLoader)
+ if (!finished(index)) {
+ copiesRunning(index) -= 1
+ numFailures(index) += 1
+ val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString))
+ logInfo("Loss was due to %s\n%s\n%s".format(reason.className, reason.description, locs.mkString("\n")))
+ if (numFailures(index) > MAX_TASK_FAILURES) {
+ val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(taskSet.id, index, 4, reason.description)
+ decreaseRunningTasks(runningTasks)
+ sched.listener.taskSetFailed(taskSet, errorMessage)
+        // Remove the failed TaskSet from the scheduling queue
+ sched.taskSetFinished(this)
+ }
+ }
+ }
+
+ def error(message: String) {
+ }
+}
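
One policy detail in the new manager deserves a callout: a failed task index is retried until its failure count exceeds sched.maxFailures, after which the whole set is aborted. The counting rule in isolation, as a sketch with hypothetical names:

    // Returns true if the task at `index` may be retried, false if the set must abort.
    def mayRetry(numFailures: Array[Int], index: Int, maxFailures: Int): Boolean = {
      numFailures(index) += 1
      numFailures(index) <= maxFailures
    }

    val failures = new Array[Int](3)
    assert(mayRetry(failures, 0, maxFailures = 1))   // first failure: retry
    assert(!mayRetry(failures, 0, maxFailures = 1))  // second failure: abort
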
diff --git a/core/src/main/scala/spark/serializer/Serializer.scala b/core/src/main/scala/spark/serializer/Serializer.scala
index aca86ab6f0..2ad73b711d 100644
--- a/core/src/main/scala/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/spark/serializer/Serializer.scala
@@ -1,10 +1,13 @@
package spark.serializer
-import java.nio.ByteBuffer
import java.io.{EOFException, InputStream, OutputStream}
+import java.nio.ByteBuffer
+
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+
import spark.util.ByteBufferInputStream
+
/**
* A serializer. Because some serialization libraries are not thread safe, this class is used to
* create [[spark.serializer.SerializerInstance]] objects that do the actual serialization and are
@@ -14,6 +17,7 @@ trait Serializer {
def newInstance(): SerializerInstance
}
+
/**
* An instance of a serializer, for use by one thread at a time.
*/
@@ -45,6 +49,7 @@ trait SerializerInstance {
}
}
+
/**
* A stream for writing serialized objects.
*/
@@ -61,6 +66,7 @@ trait SerializationStream {
}
}
+
/**
* A stream for reading serialized objects.
*/
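
For context on the traits being re-spaced here: a Serializer hands out per-thread SerializerInstances, and the common use is a serialize/deserialize round trip through a ByteBuffer. A minimal sketch, assuming the spark.JavaSerializer found elsewhere in this tree:

    import spark.JavaSerializer

    val ser = (new JavaSerializer).newInstance()  // one instance per thread
    val bytes = ser.serialize(Seq(1, 2, 3))       // java.nio.ByteBuffer
    val back = ser.deserialize[Seq[Int]](bytes)
    assert(back == Seq(1, 2, 3))
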
diff --git a/core/src/main/scala/spark/serializer/SerializerManager.scala b/core/src/main/scala/spark/serializer/SerializerManager.scala
new file mode 100644
index 0000000000..60b2aac797
--- /dev/null
+++ b/core/src/main/scala/spark/serializer/SerializerManager.scala
@@ -0,0 +1,45 @@
+package spark.serializer
+
+import java.util.concurrent.ConcurrentHashMap
+
+
+/**
+ * A service that returns a serializer object given the serializer's class name. If a previous
+ * instance of the serializer object has been created, the get method returns that instead of
+ * creating a new one.
+ */
+private[spark] class SerializerManager {
+
+ private val serializers = new ConcurrentHashMap[String, Serializer]
+ private var _default: Serializer = _
+
+ def default = _default
+
+ def setDefault(clsName: String): Serializer = {
+ _default = get(clsName)
+ _default
+ }
+
+ def get(clsName: String): Serializer = {
+ if (clsName == null) {
+ default
+ } else {
+ var serializer = serializers.get(clsName)
+ if (serializer != null) {
+ // If the serializer has been created previously, reuse that.
+ serializer
+ } else this.synchronized {
+ // Otherwise, create a new one. But make sure no other thread has attempted
+ // to create another new one at the same time.
+ serializer = serializers.get(clsName)
+ if (serializer == null) {
+ val clsLoader = Thread.currentThread.getContextClassLoader
+ serializer =
+ Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer]
+ serializers.put(clsName, serializer)
+ }
+ serializer
+ }
+ }
+ }
+}
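
The get method above is classic double-checked locking: an unsynchronized fast path against the ConcurrentHashMap, then a synchronized re-check before reflective instantiation, so each class name maps to exactly one shared instance. Expected usage, with an illustrative class name:

    val manager = new SerializerManager
    manager.setDefault("spark.JavaSerializer")

    val a = manager.get("spark.JavaSerializer")
    val b = manager.get(null)   // null falls back to the default
    assert(a eq b)              // both resolve to the same cached instance
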
diff --git a/core/src/main/scala/spark/storage/BlockException.scala b/core/src/main/scala/spark/storage/BlockException.scala
new file mode 100644
index 0000000000..f275d476df
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockException.scala
@@ -0,0 +1,5 @@
+package spark.storage
+
+private[spark]
+case class BlockException(blockId: String, message: String) extends Exception(message)
+
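
BlockException carries the offending block id alongside the message; the fetcher iterator below throws it, for example on a negative block size. Usage is ordinary case-class matching (block id shown is illustrative):

    try {
      throw new BlockException("shuffle_0_1_2", "Negative block size -1")
    } catch {
      case BlockException(blockId, message) =>
        println("Block " + blockId + " failed: " + message)
    }
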
diff --git a/core/src/main/scala/spark/storage/BlockFetchTracker.scala b/core/src/main/scala/spark/storage/BlockFetchTracker.scala
index 993aece1f7..0718156b1b 100644
--- a/core/src/main/scala/spark/storage/BlockFetchTracker.scala
+++ b/core/src/main/scala/spark/storage/BlockFetchTracker.scala
@@ -1,10 +1,10 @@
package spark.storage
private[spark] trait BlockFetchTracker {
- def totalBlocks : Int
- def numLocalBlocks: Int
- def numRemoteBlocks: Int
- def remoteFetchTime : Long
- def fetchWaitTime: Long
- def remoteBytesRead : Long
+    def totalBlocks: Int
+    def numLocalBlocks: Int
+    def numRemoteBlocks: Int
+    def remoteFetchTime: Long
+    def fetchWaitTime: Long
+    def remoteBytesRead: Long
}
diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
new file mode 100644
index 0000000000..bec876213e
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
@@ -0,0 +1,330 @@
+package spark.storage
+
+import java.nio.ByteBuffer
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashSet
+import scala.collection.mutable.Queue
+
+import io.netty.buffer.ByteBuf
+
+import spark.Logging
+import spark.Utils
+import spark.SparkException
+import spark.network.BufferMessage
+import spark.network.ConnectionManagerId
+import spark.network.netty.ShuffleCopier
+import spark.serializer.Serializer
+
+
+/**
+ * A block fetcher iterator interface. There are two implementations:
+ *
+ * BasicBlockFetcherIterator: uses a custom-built NIO communication layer.
+ * NettyBlockFetcherIterator: uses Netty (OIO) as the communication layer.
+ *
+ * Eventually we would like the two to converge and use a single NIO-based communication layer,
+ * but extensive tests show that under some circumstances (e.g. large shuffles with lots of cores),
+ * NIO can perform poorly, hence the need for the Netty OIO implementation.
+ */
+
+private[storage]
+trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])]
+ with Logging with BlockFetchTracker {
+ def initialize()
+}
+
+
+private[storage]
+object BlockFetcherIterator {
+
+ // A request to fetch one or more blocks, complete with their sizes
+ class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
+ val size = blocks.map(_._2).sum
+ }
+
+ // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
+  // the block (since we want all deserialization to happen in the calling thread); can also
+ // represent a fetch failure if size == -1.
+ class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
+ def failed: Boolean = size == -1
+ }
+
+ class BasicBlockFetcherIterator(
+ private val blockManager: BlockManager,
+ val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+ serializer: Serializer)
+ extends BlockFetcherIterator {
+
+ import blockManager._
+
+    private var _remoteBytesRead = 0L
+    private var _remoteFetchTime = 0L
+    private var _fetchWaitTime = 0L
+
+ if (blocksByAddress == null) {
+ throw new IllegalArgumentException("BlocksByAddress is null")
+ }
+
+    // Total number of blocks to fetch (local + remote); also the number of FetchResults expected
+ protected var _numBlocksToFetch = 0
+
+ protected var startTime = System.currentTimeMillis
+
+ // This represents the number of local blocks, also counting zero-sized blocks
+ private var numLocal = 0
+ // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks
+ protected val localBlocksToFetch = new ArrayBuffer[String]()
+
+ // This represents the number of remote blocks, also counting zero-sized blocks
+ private var numRemote = 0
+ // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks
+ protected val remoteBlocksToFetch = new HashSet[String]()
+
+ // A queue to hold our results.
+ protected val results = new LinkedBlockingQueue[FetchResult]
+
+ // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
+ // the number of bytes in flight is limited to maxBytesInFlight
+ private val fetchRequests = new Queue[FetchRequest]
+
+ // Current bytes in flight from our requests
+ private var bytesInFlight = 0L
+
+ protected def sendRequest(req: FetchRequest) {
+ logDebug("Sending request for %d blocks (%s) from %s".format(
+ req.blocks.size, Utils.memoryBytesToString(req.size), req.address.hostPort))
+ val cmId = new ConnectionManagerId(req.address.host, req.address.port)
+ val blockMessageArray = new BlockMessageArray(req.blocks.map {
+ case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
+ })
+ bytesInFlight += req.size
+ val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
+ val fetchStart = System.currentTimeMillis()
+ val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
+ future.onSuccess {
+ case Some(message) => {
+ val fetchDone = System.currentTimeMillis()
+ _remoteFetchTime += fetchDone - fetchStart
+ val bufferMessage = message.asInstanceOf[BufferMessage]
+ val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
+ for (blockMessage <- blockMessageArray) {
+ if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
+ throw new SparkException(
+ "Unexpected message " + blockMessage.getType + " received from " + cmId)
+ }
+ val blockId = blockMessage.getId
+ results.put(new FetchResult(blockId, sizeMap(blockId),
+ () => dataDeserialize(blockId, blockMessage.getData, serializer)))
+          _remoteBytesRead += sizeMap(blockId)
+ logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
+ }
+ }
+ case None => {
+ logError("Could not get block(s) from " + cmId)
+ for ((blockId, size) <- req.blocks) {
+ results.put(new FetchResult(blockId, -1, null))
+ }
+ }
+ }
+ }
+
+ protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
+ // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
+ // at most maxBytesInFlight in order to limit the amount of data in flight.
+ val remoteRequests = new ArrayBuffer[FetchRequest]
+ for ((address, blockInfos) <- blocksByAddress) {
+ if (address == blockManagerId) {
+ numLocal = blockInfos.size
+ // Filter out zero-sized blocks
+ localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
+ _numBlocksToFetch += localBlocksToFetch.size
+ } else {
+ numRemote += blockInfos.size
+ // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
+ // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
+ // nodes, rather than blocking on reading output from one node.
+ val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
+ logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
+ val iterator = blockInfos.iterator
+ var curRequestSize = 0L
+ var curBlocks = new ArrayBuffer[(String, Long)]
+ while (iterator.hasNext) {
+ val (blockId, size) = iterator.next()
+ // Skip empty blocks
+ if (size > 0) {
+ curBlocks += ((blockId, size))
+ remoteBlocksToFetch += blockId
+ _numBlocksToFetch += 1
+ curRequestSize += size
+ } else if (size < 0) {
+ throw new BlockException(blockId, "Negative block size " + size)
+ }
+ if (curRequestSize >= minRequestSize) {
+ // Add this FetchRequest
+ remoteRequests += new FetchRequest(address, curBlocks)
+ curRequestSize = 0
+ curBlocks = new ArrayBuffer[(String, Long)]
+ }
+ }
+ // Add in the final request
+ if (!curBlocks.isEmpty) {
+ remoteRequests += new FetchRequest(address, curBlocks)
+ }
+ }
+ }
+ logInfo("Getting " + _numBlocksToFetch + " non-zero-bytes blocks out of " +
+ totalBlocks + " blocks")
+ remoteRequests
+ }
+
+ protected def getLocalBlocks() {
+ // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
+ // these all at once because they will just memory-map some files, so they won't consume
+ // any memory that might exceed our maxBytesInFlight
+ for (id <- localBlocksToFetch) {
+ getLocalFromDisk(id, serializer) match {
+ case Some(iter) => {
+ // Pass 0 as size since it's not in flight
+ results.put(new FetchResult(id, 0, () => iter))
+ logDebug("Got local block " + id)
+ }
+ case None => {
+ throw new BlockException(id, "Could not get block " + id + " from local machine")
+ }
+ }
+ }
+ }
+
+ override def initialize() {
+ // Split local and remote blocks.
+ val remoteRequests = splitLocalRemoteBlocks()
+ // Add the remote requests into our queue in a random order
+ fetchRequests ++= Utils.randomize(remoteRequests)
+
+ // Send out initial requests for blocks, up to our maxBytesInFlight
+ while (!fetchRequests.isEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
+ }
+
+ val numGets = remoteRequests.size - fetchRequests.size
+ logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
+
+ // Get Local Blocks
+ startTime = System.currentTimeMillis
+ getLocalBlocks()
+ logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+ }
+
+    // An iterator that reads fetched blocks off the queue as they arrive.
+ @volatile protected var resultsGotten = 0
+
+ override def hasNext: Boolean = resultsGotten < _numBlocksToFetch
+
+ override def next(): (String, Option[Iterator[Any]]) = {
+ resultsGotten += 1
+ val startFetchWait = System.currentTimeMillis()
+ val result = results.take()
+ val stopFetchWait = System.currentTimeMillis()
+ _fetchWaitTime += (stopFetchWait - startFetchWait)
+      if (!result.failed) bytesInFlight -= result.size
+ while (!fetchRequests.isEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
+ }
+ (result.blockId, if (result.failed) None else Some(result.deserialize()))
+ }
+
+ // Implementing BlockFetchTracker trait.
+ override def totalBlocks: Int = numLocal + numRemote
+ override def numLocalBlocks: Int = numLocal
+ override def numRemoteBlocks: Int = numRemote
+ override def remoteFetchTime: Long = _remoteFetchTime
+ override def fetchWaitTime: Long = _fetchWaitTime
+ override def remoteBytesRead: Long = _remoteBytesRead
+ }
+ // End of BasicBlockFetcherIterator
+
+ class NettyBlockFetcherIterator(
+ blockManager: BlockManager,
+ blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+ serializer: Serializer)
+ extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) {
+
+ import blockManager._
+
+ val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest]
+
+ private def startCopiers(numCopiers: Int): List[_ <: Thread] = {
+      (for (i <- 0 until numCopiers) yield {
+        val copier = new Thread {
+          override def run() {
+            try {
+              while (!isInterrupted && !fetchRequestsSync.isEmpty) {
+                sendRequest(fetchRequestsSync.take())
+              }
+            } catch {
+              case x: InterruptedException => logInfo("Copier interrupted")
+              // case _ => throw new SparkException("Exception thrown in shuffle copier")
+            }
+          }
+        }
+        copier.start()
+        copier
+      }).toList
+ }
+
+    // Interrupt the copier threads so they stop when no longer needed
+ private def stopCopiers() {
+ for (copier <- copiers) {
+ copier.interrupt()
+ }
+ }
+
+ override protected def sendRequest(req: FetchRequest) {
+
+ def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) {
+ val fetchResult = new FetchResult(blockId, blockSize,
+ () => dataDeserialize(blockId, blockData.nioBuffer, serializer))
+ results.put(fetchResult)
+ }
+
+ logDebug("Sending request for %d blocks (%s) from %s".format(
+ req.blocks.size, Utils.memoryBytesToString(req.size), req.address.host))
+ val cmId = new ConnectionManagerId(req.address.host, req.address.nettyPort)
+ val cpier = new ShuffleCopier
+ cpier.getBlocks(cmId, req.blocks, putResult)
+ logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host )
+ }
+
+ private var copiers: List[_ <: Thread] = null
+
+ override def initialize() {
+      // Split local and remote blocks, setting numBlocksToFetch as a side effect
+ val remoteRequests = splitLocalRemoteBlocks()
+ // Add the remote requests into our queue in a random order
+ for (request <- Utils.randomize(remoteRequests)) {
+ fetchRequestsSync.put(request)
+ }
+
+ copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt)
+ logInfo("Started " + fetchRequestsSync.size + " remote gets in " +
+ Utils.getUsedTimeMs(startTime))
+
+ // Get Local Blocks
+ startTime = System.currentTimeMillis
+ getLocalBlocks()
+ logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+ }
+
+ override def next(): (String, Option[Iterator[Any]]) = {
+ resultsGotten += 1
+ val result = results.take()
+      // Once all the results have been retrieved, the copiers will exit automatically
+ (result.blockId, if (result.failed) None else Some(result.deserialize()))
+ }
+ }
+ // End of NettyBlockFetcherIterator
+}
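
Both iterators share one flow-control idea: requests are pre-split to roughly maxBytesInFlight / 5 bytes so up to five can run in parallel, and the queue is drained only while the in-flight total stays under the cap. The pattern in isolation, as a standalone sketch:

    import scala.collection.mutable.Queue

    case class Req(sizeInBytes: Long)

    // Dequeue and send requests while staying under `max` bytes in flight.
    // Always allows at least one request, even if it alone exceeds the limit.
    def issueWhileUnderLimit(pending: Queue[Req], startInFlight: Long, max: Long)
                            (send: Req => Unit): Long = {
      var bytesInFlight = startInFlight
      while (pending.nonEmpty &&
             (bytesInFlight == 0 || bytesInFlight + pending.front.sizeInBytes <= max)) {
        val req = pending.dequeue()
        bytesInFlight += req.sizeInBytes
        send(req)
      }
      bytesInFlight
    }

    // e.g. issueWhileUnderLimit(Queue(Req(10), Req(20), Req(40)), 0L, 48L)(r => println(r))
    // sends Req(10) and Req(20), leaving Req(40) queued (30 + 40 > 48).
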
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index d3f6cd78dc..4bb4927b4a 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -2,10 +2,8 @@ package spark.storage
import java.io.{InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
-import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
-import scala.collection.JavaConversions._
+import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet}
import akka.actor.{ActorSystem, Cancellable, Props}
import scala.concurrent.{Await, Future}
@@ -16,7 +14,7 @@ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
-import spark.{Logging, SizeEstimator, SparkEnv, SparkException, Utils}
+import spark.{Logging, SparkEnv, SparkException, Utils}
import spark.network._
import spark.serializer.Serializer
import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap}
@@ -24,30 +22,35 @@ import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStam
import sun.nio.ch.DirectBuffer
-private[spark]
-case class BlockException(blockId: String, message: String, ex: Exception = null)
-extends Exception(message)
-
-private[spark]
-class BlockManager(
+private[spark] class BlockManager(
executorId: String,
actorSystem: ActorSystem,
val master: BlockManagerMaster,
- val serializer: Serializer,
+ val defaultSerializer: Serializer,
maxMemory: Long)
extends Logging {
- class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
- var pending: Boolean = true
- var size: Long = -1L
- var failed: Boolean = false
+ private class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
+ @volatile var pending: Boolean = true
+ @volatile var size: Long = -1L
+ @volatile var initThread: Thread = null
+ @volatile var failed = false
+
+ setInitThread()
+
+ private def setInitThread() {
+      // Set the current thread as the init thread - waitForReady will not block this thread
+      // (in case there is non-trivial initialization which ends up calling waitForReady
+      // as part of initialization itself)
+ this.initThread = Thread.currentThread()
+ }
/**
* Wait for this BlockInfo to be marked as ready (i.e. block is finished writing).
* Return true if the block is available, false otherwise.
*/
def waitForReady(): Boolean = {
- if (pending) {
+ if (initThread != Thread.currentThread() && pending) {
synchronized {
while (pending) this.wait()
}
@@ -57,35 +60,51 @@ class BlockManager(
/** Mark this BlockInfo as ready (i.e. block is finished writing) */
def markReady(sizeInBytes: Long) {
+      assert(pending)
+      size = sizeInBytes
+      initThread = null
+      failed = false
+      pending = false
synchronized {
- pending = false
- failed = false
- size = sizeInBytes
this.notifyAll()
}
}
/** Mark this BlockInfo as ready but failed */
def markFailure() {
+      assert(pending)
+      size = 0
+      initThread = null
+      failed = true
+      pending = false
synchronized {
- failed = true
- pending = false
this.notifyAll()
}
}
}
+ val shuffleBlockManager = new ShuffleBlockManager(this)
+
private val blockInfo = new TimeStampedHashMap[String, BlockInfo]
private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
- private[storage] val diskStore: BlockStore =
+ private[storage] val diskStore: DiskStore =
new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
+ // If we use Netty for shuffle, start a new Netty-based shuffle sender service.
+ private val nettyPort: Int = {
+ val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean
+ val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt
+ if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0
+ }
+
val connectionManager = new ConnectionManager(0)
implicit val futureExecContext = connectionManager.futureExecContext
val blockManagerId = BlockManagerId(
- executorId, connectionManager.id.host, connectionManager.id.port)
+ executorId, connectionManager.id.host, connectionManager.id.port, nettyPort)
// Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory
// for receiving shuffle outputs)
@@ -101,7 +120,7 @@ class BlockManager(
val heartBeatFrequency = BlockManager.getHeartBeatFrequencyFromSystemProperties
- val host = System.getProperty("spark.hostname", Utils.localHostName())
+ val hostPort = Utils.localHostPort()
val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
@@ -212,9 +231,12 @@ class BlockManager(
* Tell the master about the current storage status of a block. This will send a block update
* message reflecting the current status, *not* the desired storage level in its block info.
* For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk.
+ *
+   * droppedMemorySize exists to account for the case when a block is dropped from memory to
+   * disk (so it is still valid). This ensures that the update sent to the master accounts for
+   * the memory the block no longer occupies on the slave.
*/
- def reportBlockStatus(blockId: String, info: BlockInfo) {
- val needReregister = !tryToReportBlockStatus(blockId, info)
+ def reportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L) {
+ val needReregister = !tryToReportBlockStatus(blockId, info, droppedMemorySize)
if (needReregister) {
logInfo("Got told to reregister updating block " + blockId)
// Reregistering will report our new block for free.
@@ -228,7 +250,7 @@ class BlockManager(
* which will be true if the block was successfully recorded and false if
* the slave needs to re-register.
*/
- private def tryToReportBlockStatus(blockId: String, info: BlockInfo): Boolean = {
+  private def tryToReportBlockStatus(
+      blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = {
val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized {
info.level match {
case null =>
@@ -237,7 +259,7 @@ class BlockManager(
val inMem = level.useMemory && memoryStore.contains(blockId)
val onDisk = level.useDisk && diskStore.contains(blockId)
val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication)
- val memSize = if (inMem) memoryStore.getSize(blockId) else 0L
+ val memSize = if (inMem) memoryStore.getSize(blockId) else droppedMemorySize
val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L
(storageLevel, memSize, diskSize, info.tellMaster)
}
@@ -250,26 +272,24 @@ class BlockManager(
}
}
-
/**
- * Get locations of the block.
+ * Get locations of an array of blocks.
*/
- def getLocations(blockId: String): Seq[String] = {
+ def getLocationBlockIds(blockIds: Array[String]): Array[Seq[BlockManagerId]] = {
val startTimeMs = System.currentTimeMillis
- var managers = master.getLocations(blockId)
- val locations = managers.map(_.ip)
- logDebug("Got block locations in " + Utils.getUsedTimeMs(startTimeMs))
- return locations
+ val locations = master.getLocations(blockIds).toArray
+ logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
+ locations
}
/**
- * Get locations of an array of blocks.
+ * A short-circuited method to get blocks directly from disk. This is used for getting
+ * shuffle blocks. It is safe to do so without a lock on block info since disk store
+ * never deletes (recent) items.
*/
- def getLocations(blockIds: Array[String]): Array[Seq[String]] = {
- val startTimeMs = System.currentTimeMillis
- val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray
- logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
- return locations
+ def getLocalFromDisk(blockId: String, serializer: Serializer): Option[Iterator[Any]] = {
+ diskStore.getValues(blockId, serializer).orElse(
+ sys.error("Block " + blockId + " not found on disk, though it should be"))
}
/**
@@ -277,18 +297,6 @@ class BlockManager(
*/
def getLocal(blockId: String): Option[Iterator[Any]] = {
logDebug("Getting local block " + blockId)
-
- // As an optimization for map output fetches, if the block is for a shuffle, return it
- // without acquiring a lock; the disk store never deletes (recent) items so this should work
- if (blockId.startsWith("shuffle_")) {
- return diskStore.getValues(blockId) match {
- case Some(iterator) =>
- Some(iterator)
- case None =>
- throw new Exception("Block " + blockId + " not found on disk, though it should be")
- }
- }
-
val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
@@ -339,6 +347,8 @@ class BlockManager(
case Some(bytes) =>
// Put a copy of the block back in memory before returning it. Note that we can't
// put the ByteBuffer returned by the disk store as that's a memory-mapped file.
+ // The use of rewind assumes this.
+ assert (0 == bytes.position())
val copyForMemory = ByteBuffer.allocate(bytes.limit)
copyForMemory.put(bytes)
memoryStore.putBytes(blockId, copyForMemory, level)
@@ -372,7 +382,7 @@ class BlockManager(
// As an optimization for map output fetches, if the block is for a shuffle, return it
// without acquiring a lock; the disk store never deletes (recent) items so this should work
- if (blockId.startsWith("shuffle_")) {
+ if (ShuffleBlockManager.isShuffle(blockId)) {
return diskStore.getBytes(blockId) match {
case Some(bytes) =>
Some(bytes)
@@ -411,6 +421,7 @@ class BlockManager(
// Read it as a byte buffer into memory first, then return it
diskStore.getBytes(blockId) match {
case Some(bytes) =>
+ assert (0 == bytes.position())
if (level.useMemory) {
if (level.deserialized) {
memoryStore.putBytes(blockId, bytes, level)
@@ -450,7 +461,7 @@ class BlockManager(
for (loc <- locations) {
logDebug("Getting remote block " + blockId + " from " + loc)
val data = BlockManagerWorker.syncGetBlock(
- GetBlock(blockId), ConnectionManagerId(loc.ip, loc.port))
+ GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
if (data != null) {
return Some(dataDeserialize(blockId, data))
}
@@ -473,9 +484,19 @@ class BlockManager(
* fashion as they're received. Expects a size in bytes to be provided for each block fetched,
* so that we can control the maxMegabytesInFlight for the fetch.
*/
- def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])])
+ def getMultiple(
+ blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer)
: BlockFetcherIterator = {
- return new BlockFetcherIterator(this, blocksByAddress)
+
+ val iter =
+ if (System.getProperty("spark.shuffle.use.netty", "false").toBoolean) {
+ new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer)
+ } else {
+ new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer)
+ }
+
+ iter.initialize()
+ iter
}
def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
@@ -486,6 +507,22 @@ class BlockManager(
}
/**
+ * A short circuited method to get a block writer that can write data directly to disk.
+ * This is currently used for writing shuffle files out. Callers should handle error
+ * cases.
+ */
+ def getDiskBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
+ : BlockObjectWriter = {
+ val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize)
+ writer.registerCloseEventHandler(() => {
+ val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false)
+ blockInfo.put(blockId, myInfo)
+ myInfo.markReady(writer.size())
+ })
+ writer
+ }
+
+ /**
* Put a new block of values to the block manager. Returns its (estimated) size in bytes.
*/
def put(blockId: String, values: ArrayBuffer[Any], level: StorageLevel,
@@ -501,17 +538,26 @@ class BlockManager(
throw new IllegalArgumentException("Storage level is null or invalid")
}
- val oldBlock = blockInfo.get(blockId).orNull
- if (oldBlock != null && oldBlock.waitForReady()) {
- logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
- return oldBlock.size
- }
-
// Remember the block's storage level so that we can correctly drop it to disk if it needs
// to be dropped right after it got put into memory. Note, however, that other threads will
// not be able to get() this block until we call markReady on its BlockInfo.
- val myInfo = new BlockInfo(level, tellMaster)
- blockInfo.put(blockId, myInfo)
+ val myInfo = {
+ val tinfo = new BlockInfo(level, tellMaster)
+      // Do this atomically!
+ val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo)
+
+ if (oldBlockOpt.isDefined) {
+ if (oldBlockOpt.get.waitForReady()) {
+ logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
+ return oldBlockOpt.get.size
+ }
+
+      // TODO: The block info exists, but a previous attempt to load it failed. What now? Retry?
+ oldBlockOpt.get
+ } else {
+ tinfo
+ }
+ }
val startTimeMs = System.currentTimeMillis
@@ -531,6 +577,7 @@ class BlockManager(
logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
+ var marked = false
try {
if (level.useMemory) {
// Save it just to memory first, even if it also has useDisk set to true; we will later
@@ -555,26 +602,25 @@ class BlockManager(
// Now that the block is in either the memory or disk store, let other threads read it,
// and tell the master about it.
+ marked = true
myInfo.markReady(size)
if (tellMaster) {
reportBlockStatus(blockId, myInfo)
}
- } catch {
+ } finally {
// If we failed at putting the block to memory/disk, notify other possible readers
// that it has failed, and then remove it from the block info map.
- case e: Exception => {
+      if (!marked) {
// Note that the remove must happen before markFailure otherwise another thread
// could've inserted a new BlockInfo before we remove it.
blockInfo.remove(blockId)
myInfo.markFailure()
- logWarning("Putting block " + blockId + " failed", e)
- throw e
+ logWarning("Putting block " + blockId + " failed")
}
}
}
logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs))
-
// Replicate block if required
if (level.replication > 1) {
val remoteStartTime = System.currentTimeMillis
@@ -611,16 +657,26 @@ class BlockManager(
throw new IllegalArgumentException("Storage level is null or invalid")
}
- if (blockInfo.contains(blockId)) {
- logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
- return
- }
-
// Remember the block's storage level so that we can correctly drop it to disk if it needs
// to be dropped right after it got put into memory. Note, however, that other threads will
// not be able to get() this block until we call markReady on its BlockInfo.
- val myInfo = new BlockInfo(level, tellMaster)
- blockInfo.put(blockId, myInfo)
+ val myInfo = {
+ val tinfo = new BlockInfo(level, tellMaster)
+      // Do this atomically!
+ val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo)
+
+ if (oldBlockOpt.isDefined) {
+ if (oldBlockOpt.get.waitForReady()) {
+ logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
+ return
+ }
+
+      // TODO: The block info exists, but a previous attempt to load it failed. What now? Retry?
+ oldBlockOpt.get
+ } else {
+ tinfo
+ }
+ }
val startTimeMs = System.currentTimeMillis
@@ -639,6 +695,7 @@ class BlockManager(
logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
+ var marked = false
try {
if (level.useMemory) {
// Store it only in memory at first, even if useDisk is also set to true
@@ -649,22 +706,24 @@ class BlockManager(
diskStore.putBytes(blockId, bytes, level)
}
+ // assert (0 == bytes.position(), "" + bytes)
+
// Now that the block is in either the memory or disk store, let other threads read it,
// and tell the master about it.
+ marked = true
myInfo.markReady(bytes.limit)
if (tellMaster) {
reportBlockStatus(blockId, myInfo)
}
- } catch {
+ } finally {
// If we failed at putting the block to memory/disk, notify other possible readers
// that it has failed, and then remove it from the block info map.
- case e: Exception => {
+      if (!marked) {
// Note that the remove must happen before markFailure otherwise another thread
// could've inserted a new BlockInfo before we remove it.
blockInfo.remove(blockId)
myInfo.markFailure()
- logWarning("Putting block " + blockId + " failed", e)
- throw e
+ logWarning("Putting block " + blockId + " failed")
}
}
}
@@ -698,7 +757,7 @@ class BlockManager(
logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is "
+ data.limit() + " Bytes. To node: " + peer)
if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel),
- new ConnectionManagerId(peer.ip, peer.port))) {
+ new ConnectionManagerId(peer.host, peer.port))) {
logError("Failed to call syncPutBlock to " + peer)
}
logDebug("Replicated BlockId " + blockId + " once used " +
@@ -730,6 +789,14 @@ class BlockManager(
val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
+      // Required? As of now this is invoked only for blocks which are ready,
+      // but check anyway for consistency in case that changes in the future.
+      if (!info.waitForReady()) {
+ // If we get here, the block write failed.
+ logWarning("Block " + blockId + " was marked as failure. Nothing to drop")
+ return
+ }
+
val level = info.level
if (level.useDisk && !diskStore.contains(blockId)) {
logInfo("Writing block " + blockId + " to disk")
@@ -740,12 +807,13 @@ class BlockManager(
diskStore.putBytes(blockId, bytes, level)
}
}
+          val droppedMemorySize =
+            if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L
val blockWasRemoved = memoryStore.remove(blockId)
if (!blockWasRemoved) {
logWarning("Block " + blockId + " could not be dropped from memory as it does not exist")
}
if (info.tellMaster) {
- reportBlockStatus(blockId, info)
+ reportBlockStatus(blockId, info, droppedMemorySize)
}
if (!level.useDisk) {
// The block is completely gone from this node; forget it so we can put() it again later.
@@ -758,9 +826,23 @@ class BlockManager(
}
/**
+ * Remove all blocks belonging to the given RDD.
+ * @return The number of blocks removed.
+ */
+ def removeRdd(rddId: Int): Int = {
+ // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps
+ // from RDD.id to blocks.
+ logInfo("Removing RDD " + rddId)
+ val rddPrefix = "rdd_" + rddId + "_"
+ val blocksToRemove = blockInfo.filter(_._1.startsWith(rddPrefix)).map(_._1)
+ blocksToRemove.foreach(blockId => removeBlock(blockId, false))
+ blocksToRemove.size
+ }
+
+ /**
* Remove a block from both memory and disk.
*/
- def removeBlock(blockId: String) {
+ def removeBlock(blockId: String, tellMaster: Boolean = true) {
logInfo("Removing block " + blockId)
val info = blockInfo.get(blockId).orNull
if (info != null) info.synchronized {
@@ -772,7 +854,7 @@ class BlockManager(
"the disk or memory store")
}
blockInfo.remove(blockId)
- if (info.tellMaster) {
+ if (tellMaster && info.tellMaster) {
reportBlockStatus(blockId, info)
}
} else {
@@ -805,7 +887,7 @@ class BlockManager(
}
def shouldCompress(blockId: String): Boolean = {
- if (blockId.startsWith("shuffle_")) {
+ if (ShuffleBlockManager.isShuffle(blockId)) {
compressShuffle
} else if (blockId.startsWith("broadcast_")) {
compressBroadcast
@@ -820,7 +902,11 @@ class BlockManager(
* Wrap an output stream for compression if block compression is enabled for its block type
*/
def wrapForCompression(blockId: String, s: OutputStream): OutputStream = {
- if (shouldCompress(blockId)) new LZFOutputStream(s) else s
+ if (shouldCompress(blockId)) {
+ (new LZFOutputStream(s)).setFinishBlockOnFlush(true)
+ } else {
+ s
+ }
}
/**
@@ -830,7 +916,10 @@ class BlockManager(
if (shouldCompress(blockId)) new LZFInputStream(s) else s
}
- def dataSerialize(blockId: String, values: Iterator[Any]): ByteBuffer = {
+ def dataSerialize(
+ blockId: String,
+ values: Iterator[Any],
+ serializer: Serializer = defaultSerializer): ByteBuffer = {
val byteStream = new FastByteArrayOutputStream(4096)
val ser = serializer.newInstance()
ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
@@ -842,7 +931,10 @@ class BlockManager(
* Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of
* the iterator is reached.
*/
- def dataDeserialize(blockId: String, bytes: ByteBuffer): Iterator[Any] = {
+ def dataDeserialize(
+ blockId: String,
+ bytes: ByteBuffer,
+ serializer: Serializer = defaultSerializer): Iterator[Any] = {
bytes.rewind()
val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true))
serializer.newInstance().deserializeStream(stream).asIterator
@@ -862,8 +954,8 @@ class BlockManager(
}
}
-private[spark]
-object BlockManager extends Logging {
+
+private[spark] object BlockManager extends Logging {
val ID_GENERATOR = new IdGenerator
@@ -873,7 +965,8 @@ object BlockManager extends Logging {
}
def getHeartBeatFrequencyFromSystemProperties: Long =
- System.getProperty("spark.storage.blockManagerHeartBeatMs", "10000").toLong
+ System.getProperty("spark.storage.blockManagerTimeoutIntervalMs", "60000").toLong / 4
def getDisableHeartBeatsForTesting: Boolean =
System.getProperty("spark.test.disableBlockManagerHeartBeat", "false").toBoolean
@@ -892,177 +985,43 @@ object BlockManager extends Logging {
}
}
}
-}
-
-class BlockFetcherIterator(
- private val blockManager: BlockManager,
- val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])]
-) extends Iterator[(String, Option[Iterator[Any]])] with Logging with BlockFetchTracker {
- import blockManager._
-
- private var _remoteBytesRead = 0l
- private var _remoteFetchTime = 0l
- private var _fetchWaitTime = 0l
-
- if (blocksByAddress == null) {
- throw new IllegalArgumentException("BlocksByAddress is null")
- }
- val totalBlocks = blocksByAddress.map(_._2.size).sum
- logDebug("Getting " + totalBlocks + " blocks")
- var startTime = System.currentTimeMillis
- val localBlockIds = new ArrayBuffer[String]()
- val remoteBlockIds = new HashSet[String]()
-
- // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
- // the block (since we want all deserializaton to happen in the calling thread); can also
- // represent a fetch failure if size == -1.
- class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
- def failed: Boolean = size == -1
- }
-
- // A queue to hold our results.
- val results = new LinkedBlockingQueue[FetchResult]
-
- // A request to fetch one or more blocks, complete with their sizes
- class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
- val size = blocks.map(_._2).sum
- }
+  def blockIdsToExecutorLocations(blockIds: Array[String], env: SparkEnv,
+      blockManagerMaster: BlockManagerMaster = null): HashMap[String, List[String]] = {
+ // env == null and blockManagerMaster != null is used in tests
+ assert (env != null || blockManagerMaster != null)
+ val locationBlockIds: Seq[Seq[BlockManagerId]] =
+ if (env != null) {
+ env.blockManager.getLocationBlockIds(blockIds)
+ } else {
+ blockManagerMaster.getLocations(blockIds)
+ }
- // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
- // the number of bytes in flight is limited to maxBytesInFlight
- val fetchRequests = new Queue[FetchRequest]
+ // Convert from block master locations to executor locations (we need that for task scheduling)
+ val executorLocations = new HashMap[String, List[String]]()
+ for (i <- 0 until blockIds.length) {
+ val blockId = blockIds(i)
+ val blockLocations = locationBlockIds(i)
- // Current bytes in flight from our requests
- var bytesInFlight = 0L
+ val executors = new HashSet[String]()
- def sendRequest(req: FetchRequest) {
- logDebug("Sending request for %d blocks (%s) from %s".format(
- req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip))
- val cmId = new ConnectionManagerId(req.address.ip, req.address.port)
- val blockMessageArray = new BlockMessageArray(req.blocks.map {
- case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
- })
- bytesInFlight += req.size
- val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
- val fetchStart = System.currentTimeMillis()
- val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
- future.onSuccess {
- case Some(message) => {
- val fetchDone = System.currentTimeMillis()
- _remoteFetchTime += fetchDone - fetchStart
- val bufferMessage = message.asInstanceOf[BufferMessage]
- val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
- for (blockMessage <- blockMessageArray) {
- if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
- throw new SparkException(
- "Unexpected message " + blockMessage.getType + " received from " + cmId)
- }
- val blockId = blockMessage.getId
- results.put(new FetchResult(
- blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData)))
- _remoteBytesRead += req.size
- logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
+ if (env != null) {
+ for (bkLocation <- blockLocations) {
+ val executorHostPort = env.resolveExecutorIdToHostPort(bkLocation.executorId, bkLocation.host)
+ executors += executorHostPort
+ // logInfo("bkLocation = " + bkLocation + ", executorHostPort = " + executorHostPort)
}
- }
- case None => {
- logError("Could not get block(s) from " + cmId)
- for ((blockId, size) <- req.blocks) {
- results.put(new FetchResult(blockId, -1, null))
+ } else {
+        // Typically during testing; fall back to simply using the host.
+ for (bkLocation <- blockLocations) {
+ executors += bkLocation.host
+ // logInfo("bkLocation = " + bkLocation + ", executorHostPort = " + executorHostPort)
}
}
- }
- }
- // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
- // at most maxBytesInFlight in order to limit the amount of data in flight.
- val remoteRequests = new ArrayBuffer[FetchRequest]
- for ((address, blockInfos) <- blocksByAddress) {
- if (address == blockManagerId) {
- localBlockIds ++= blockInfos.map(_._1)
- } else {
- remoteBlockIds ++= blockInfos.map(_._1)
- // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
- // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
- // nodes, rather than blocking on reading output from one node.
- val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
- logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
- val iterator = blockInfos.iterator
- var curRequestSize = 0L
- var curBlocks = new ArrayBuffer[(String, Long)]
- while (iterator.hasNext) {
- val (blockId, size) = iterator.next()
- curBlocks += ((blockId, size))
- curRequestSize += size
- if (curRequestSize >= minRequestSize) {
- // Add this FetchRequest
- remoteRequests += new FetchRequest(address, curBlocks)
- curRequestSize = 0
- curBlocks = new ArrayBuffer[(String, Long)]
- }
- }
- // Add in the final request
- if (!curBlocks.isEmpty) {
- remoteRequests += new FetchRequest(address, curBlocks)
- }
+ executorLocations.put(blockId, executors.toSeq.toList)
}
- }
- // Add the remote requests into our queue in a random order
- fetchRequests ++= Utils.randomize(remoteRequests)
- // Send out initial requests for blocks, up to our maxBytesInFlight
- while (!fetchRequests.isEmpty &&
- (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
- sendRequest(fetchRequests.dequeue())
+ executorLocations
}
- val numGets = remoteBlockIds.size - fetchRequests.size
- logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
-
- // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
- // these all at once because they will just memory-map some files, so they won't consume
- // any memory that might exceed our maxBytesInFlight
- startTime = System.currentTimeMillis
- for (id <- localBlockIds) {
- getLocal(id) match {
- case Some(iter) => {
- results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
- logDebug("Got local block " + id)
- }
- case None => {
- throw new BlockException(id, "Could not get block " + id + " from local machine")
- }
- }
- }
- logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
-
- //an iterator that will read fetched blocks off the queue as they arrive.
- var resultsGotten = 0
-
- def hasNext: Boolean = resultsGotten < totalBlocks
-
- def next(): (String, Option[Iterator[Any]]) = {
- resultsGotten += 1
- val startFetchWait = System.currentTimeMillis()
- val result = results.take()
- val stopFetchWait = System.currentTimeMillis()
- _fetchWaitTime += (stopFetchWait - startFetchWait)
- bytesInFlight -= result.size
- while (!fetchRequests.isEmpty &&
- (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
- sendRequest(fetchRequests.dequeue())
- }
- (result.blockId, if (result.failed) None else Some(result.deserialize()))
- }
-
-
- //methods to profile the block fetching
- def numLocalBlocks = localBlockIds.size
- def numRemoteBlocks = remoteBlockIds.size
-
- def remoteFetchTime = _remoteFetchTime
- def fetchWaitTime = _fetchWaitTime
-
- def remoteBytesRead = _remoteBytesRead
-
}
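
The new blockIdsToExecutorLocations helper boils down to building a de-duplicating map from block id to locations. A minimal, self-contained sketch of that conversion, with plain strings standing in for the BlockManagerId lookups (all names here are illustrative, not part of the patch):

    import scala.collection.mutable.{HashMap, HashSet}

    // For each block id, de-duplicate its candidate locations through a
    // HashSet, then store them as a List keyed by the block id.
    def toExecutorLocations(
        blockIds: Array[String],
        locations: Seq[Seq[String]]): HashMap[String, List[String]] = {
      val result = new HashMap[String, List[String]]()
      for (i <- 0 until blockIds.length) {
        val executors = new HashSet[String]()
        executors ++= locations(i)
        result.put(blockIds(i), executors.toList)
      }
      result
    }

    // Two replicas on the same executor collapse to a single entry:
    val locs = toExecutorLocations(
      Array("rdd_1_0"), Seq(Seq("host1:7070", "host1:7070")))
    // locs("rdd_1_0") == List("host1:7070")
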
diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala
index f2f1e77d41..1e557d6148 100644
--- a/core/src/main/scala/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerId.scala
@@ -2,51 +2,70 @@ package spark.storage
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.util.concurrent.ConcurrentHashMap
+import spark.Utils
/**
 * This class represents a unique identifier for a BlockManager.
 * The first two constructors of this class are made private to ensure that
- * BlockManagerId objects can be created only using the factory method in
- * [[spark.storage.BlockManager$]]. This allows de-duplication of ID objects.
+ * BlockManagerId objects can be created only using the apply method in
+ * the companion object. This allows de-duplication of ID objects.
* Also, constructor parameters are private to ensure that parameters cannot
* be modified from outside this class.
*/
private[spark] class BlockManagerId private (
private var executorId_ : String,
- private var ip_ : String,
- private var port_ : Int
+ private var host_ : String,
+ private var port_ : Int,
+ private var nettyPort_ : Int
) extends Externalizable {
- private def this() = this(null, null, 0) // For deserialization only
+ private def this() = this(null, null, 0, 0) // For deserialization only
def executorId: String = executorId_
- def ip: String = ip_
+  if (null != host_) {
+ Utils.checkHost(host_, "Expected hostname")
+ assert (port_ > 0)
+ }
+
+ def hostPort: String = {
+ // DEBUG code
+ Utils.checkHost(host)
+ assert (port > 0)
+
+ host + ":" + port
+ }
+
+ def host: String = host_
def port: Int = port_
+ def nettyPort: Int = nettyPort_
+
override def writeExternal(out: ObjectOutput) {
out.writeUTF(executorId_)
- out.writeUTF(ip_)
+ out.writeUTF(host_)
out.writeInt(port_)
+ out.writeInt(nettyPort_)
}
override def readExternal(in: ObjectInput) {
executorId_ = in.readUTF()
- ip_ = in.readUTF()
+ host_ = in.readUTF()
port_ = in.readInt()
+ nettyPort_ = in.readInt()
}
@throws(classOf[IOException])
private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this)
- override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, ip, port)
+ override def toString = "BlockManagerId(%s, %s, %d, %d)".format(executorId, host, port, nettyPort)
- override def hashCode: Int = (executorId.hashCode * 41 + ip.hashCode) * 41 + port
+ override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port + nettyPort
override def equals(that: Any) = that match {
case id: BlockManagerId =>
- executorId == id.executorId && port == id.port && ip == id.ip
+ executorId == id.executorId && port == id.port && host == id.host && nettyPort == id.nettyPort
case _ =>
false
}
@@ -55,8 +74,17 @@ private[spark] class BlockManagerId private (
private[spark] object BlockManagerId {
- def apply(execId: String, ip: String, port: Int) =
- getCachedBlockManagerId(new BlockManagerId(execId, ip, port))
+ /**
+   * Returns a [[spark.storage.BlockManagerId]] for the given configuration.
+ *
+ * @param execId ID of the executor.
+ * @param host Host name of the block manager.
+ * @param port Port of the block manager.
+ * @param nettyPort Optional port for the Netty-based shuffle sender.
+ * @return A new [[spark.storage.BlockManagerId]].
+ */
+ def apply(execId: String, host: String, port: Int, nettyPort: Int) =
+ getCachedBlockManagerId(new BlockManagerId(execId, host, port, nettyPort))
def apply(in: ObjectInput) = {
val obj = new BlockManagerId()
@@ -67,11 +95,7 @@ private[spark] object BlockManagerId {
val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()
def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
- if (blockManagerIdCache.containsKey(id)) {
- blockManagerIdCache.get(id)
- } else {
- blockManagerIdCache.put(id, id)
- id
- }
+ blockManagerIdCache.putIfAbsent(id, id)
+ blockManagerIdCache.get(id)
}
}
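
The rewritten getCachedBlockManagerId replaces a racy containsKey/put pair with an atomic putIfAbsent followed by get, so concurrent callers always converge on one canonical instance. A self-contained sketch of the same interning idiom (the Interner name is illustrative, not part of the patch):

    import java.util.concurrent.ConcurrentHashMap

    object Interner {
      private val cache = new ConcurrentHashMap[String, String]()

      def intern(s: String): String = {
        cache.putIfAbsent(s, s) // atomic: a no-op if another thread cached s first
        cache.get(s)            // always returns the single canonical instance
      }
    }
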
diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
index 4e55936d28..6a9278292e 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
@@ -9,10 +9,13 @@ import scala.util.Random
import akka.actor.{Actor, ActorRef, ActorSystem, Props}
import scala.concurrent.Await
+import scala.concurrent.Future
+import scala.concurrent.ExecutionContext.Implicits.global
+
import akka.pattern.ask
import scala.concurrent.duration._
-import spark.{Logging, SparkException, Utils}
+import spark.{Logging, SparkException}
private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Logging {
@@ -21,7 +24,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster"
- val timeout = 10.seconds
+ val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
/** Remove a dead executor from the driver actor. This is only called on the driver side. */
def removeExecutor(execId: String) {
@@ -87,6 +90,19 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
}
/**
+ * Remove all blocks belonging to the given RDD.
+ */
+ def removeRdd(rddId: Int, blocking: Boolean) {
+ val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId))
+ future onFailure {
+ case e: Throwable => logError("Failed to remove RDD " + rddId, e)
+ }
+ if (blocking) {
+ Await.result(future, timeout)
+ }
+ }
+
+ /**
* Return the memory status for each block manager, in the form of a map from
* the block manager's id to two long values. The first value is the maximum
* amount of memory allocated for the block manager, while the second is the
@@ -97,7 +113,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
}
def getStorageStatus: Array[StorageStatus] = {
- askDriverWithReply[ArrayBuffer[StorageStatus]](GetStorageStatus).toArray
+ askDriverWithReply[Array[StorageStatus]](GetStorageStatus)
}
/** Stop the driver actor, called only on the Spark driver node */
@@ -134,7 +150,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
val future = driverActor.ask(message)(timeout)
val result = Await.result(future, timeout)
if (result == null) {
- throw new Exception("BlockManagerMaster returned null")
+ throw new SparkException("BlockManagerMaster returned null")
}
return result.asInstanceOf[T]
} catch {
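
The new removeRdd attaches a logging onFailure callback in every case, and only blocks on the returned future when the caller asks for blocking semantics. A runnable sketch of that control flow, with a stub future standing in for the driver ask (removeRddSketch is a made-up name):

    import scala.concurrent.{Await, Future}
    import scala.concurrent.duration._
    import scala.concurrent.ExecutionContext.Implicits.global

    def removeRddSketch(rddId: Int, blocking: Boolean) {
      val future: Future[Seq[Int]] = Future { Seq(0) } // stands in for the ask
      future onFailure {
        case e: Throwable => println("Failed to remove RDD " + rddId + ": " + e)
      }
      if (blocking) {
        Await.result(future, 10.seconds) // surface failures to the caller
      }
    }
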
diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala
index 2d39e2c15c..6b5e38124b 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala
@@ -2,14 +2,16 @@ package spark.storage
import java.util.{HashMap => JHashMap}
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.collection.mutable
import scala.collection.JavaConversions._
-import scala.util.Random
import akka.actor.{Actor, ActorRef, Cancellable}
+import akka.pattern.ask
+
import scala.concurrent.duration._
+import scala.concurrent.Future
-import spark.{Logging, Utils}
+import spark.{Logging, Utils, SparkException}
/**
* BlockManagerMasterActor is an actor on the master node to track statuses of
@@ -20,13 +22,16 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
// Mapping from block manager id to the block manager's information.
private val blockManagerInfo =
- new HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo]
+ new mutable.HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo]
// Mapping from executor ID to block manager ID.
- private val blockManagerIdByExecutor = new HashMap[String, BlockManagerId]
+ private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId]
// Mapping from block id to the set of block managers that have the block.
- private val blockLocations = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]]
+ private val blockLocations = new JHashMap[String, mutable.HashSet[BlockManagerId]]
+
+ val akkaTimeout = Duration.create(
+ System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
initLogging()
@@ -34,7 +39,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
"" + (BlockManager.getHeartBeatFrequencyFromSystemProperties * 3)).toLong
val checkTimeoutInterval = System.getProperty("spark.storage.blockManagerTimeoutIntervalMs",
- "5000").toLong
+ "60000").toLong
var timeoutCheckingTask: Cancellable = null
@@ -50,28 +55,34 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
def receive = {
case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) =>
register(blockManagerId, maxMemSize, slaveActor)
+ sender ! true
case UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) =>
+ // TODO: Ideally we want to handle all the message replies in receive instead of in the
+ // individual private methods.
updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size)
case GetLocations(blockId) =>
- getLocations(blockId)
+ sender ! getLocations(blockId)
case GetLocationsMultipleBlockIds(blockIds) =>
- getLocationsMultipleBlockIds(blockIds)
+ sender ! getLocationsMultipleBlockIds(blockIds)
case GetPeers(blockManagerId, size) =>
- getPeersDeterministic(blockManagerId, size)
- /*getPeers(blockManagerId, size)*/
+ sender ! getPeers(blockManagerId, size)
case GetMemoryStatus =>
- getMemoryStatus
+ sender ! memoryStatus
case GetStorageStatus =>
- getStorageStatus
+ sender ! storageStatus
+
+ case RemoveRdd(rddId) =>
+ sender ! removeRdd(rddId)
case RemoveBlock(blockId) =>
- removeBlock(blockId)
+ removeBlockFromWorkers(blockId)
+ sender ! true
case RemoveExecutor(execId) =>
removeExecutor(execId)
@@ -81,7 +92,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
logInfo("Stopping BlockManagerMaster")
sender ! true
if (timeoutCheckingTask != null) {
- timeoutCheckingTask.cancel
+ timeoutCheckingTask.cancel()
}
context.stop(self)
@@ -89,13 +100,36 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
expireDeadHosts()
case HeartBeat(blockManagerId) =>
- heartBeat(blockManagerId)
+ sender ! heartBeat(blockManagerId)
case other =>
- logInfo("Got unknown message: " + other)
+ logWarning("Got unknown message: " + other)
+ }
+
+ private def removeRdd(rddId: Int): Future[Seq[Int]] = {
+ // First remove the metadata for the given RDD, and then asynchronously remove the blocks
+ // from the slaves.
+
+ val prefix = "rdd_" + rddId + "_"
+ // Find all blocks for the given RDD, remove the block from both blockLocations and
+ // the blockManagerInfo that is tracking the blocks.
+ val blocks = blockLocations.keySet().filter(_.startsWith(prefix))
+ blocks.foreach { blockId =>
+ val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId)
+ bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId)))
+ blockLocations.remove(blockId)
+ }
+
+ // Ask the slaves to remove the RDD, and put the result in a sequence of Futures.
+ // The dispatcher is used as an implicit argument into the Future sequence construction.
+ import context.dispatcher
+ val removeMsg = RemoveRdd(rddId)
+ Future.sequence(blockManagerInfo.values.map { bm =>
+ bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
+ }.toSeq)
}
- def removeBlockManager(blockManagerId: BlockManagerId) {
+ private def removeBlockManager(blockManagerId: BlockManagerId) {
val info = blockManagerInfo(blockManagerId)
// Remove the block manager from blockManagerIdByExecutor.
@@ -106,7 +140,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
val iterator = info.blocks.keySet.iterator
while (iterator.hasNext) {
val blockId = iterator.next
- val locations = blockLocations.get(blockId)._2
+ val locations = blockLocations.get(blockId)
locations -= blockManagerId
if (locations.size == 0) {
blockLocations.remove(blockId)
@@ -114,11 +148,11 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
}
}
- def expireDeadHosts() {
+ private def expireDeadHosts() {
logTrace("Checking for hosts with no recent heart beats in BlockManagerMaster.")
val now = System.currentTimeMillis()
val minSeenTime = now - slaveTimeout
- val toRemove = new HashSet[BlockManagerId]
+ val toRemove = new mutable.HashSet[BlockManagerId]
for (info <- blockManagerInfo.values) {
if (info.lastSeenMs < minSeenTime) {
logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats: " +
@@ -129,31 +163,26 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
toRemove.foreach(removeBlockManager)
}
- def removeExecutor(execId: String) {
+ private def removeExecutor(execId: String) {
logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.")
blockManagerIdByExecutor.get(execId).foreach(removeBlockManager)
- sender ! true
}
- def heartBeat(blockManagerId: BlockManagerId) {
+ private def heartBeat(blockManagerId: BlockManagerId): Boolean = {
if (!blockManagerInfo.contains(blockManagerId)) {
- if (blockManagerId.executorId == "<driver>" && !isLocal) {
- sender ! true
- } else {
- sender ! false
- }
+ blockManagerId.executorId == "<driver>" && !isLocal
} else {
blockManagerInfo(blockManagerId).updateLastSeenMs()
- sender ! true
+ true
}
}
// Remove a block from the slaves that have it. This can only be used to remove
// blocks that the master knows about.
- private def removeBlock(blockId: String) {
- val block = blockLocations.get(blockId)
- if (block != null) {
- block._2.foreach { blockManagerId: BlockManagerId =>
+ private def removeBlockFromWorkers(blockId: String) {
+ val locations = blockLocations.get(blockId)
+ if (locations != null) {
+ locations.foreach { blockManagerId: BlockManagerId =>
val blockManager = blockManagerInfo.get(blockManagerId)
if (blockManager.isDefined) {
// Remove the block from the slave's BlockManager.
@@ -163,23 +192,20 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
}
}
}
- sender ! true
}
// Return a map from the block manager id to max memory and remaining memory.
- private def getMemoryStatus() {
- val res = blockManagerInfo.map { case(blockManagerId, info) =>
+ private def memoryStatus: Map[BlockManagerId, (Long, Long)] = {
+ blockManagerInfo.map { case(blockManagerId, info) =>
(blockManagerId, (info.maxMem, info.remainingMem))
}.toMap
- sender ! res
}
- private def getStorageStatus() {
- val res = blockManagerInfo.map { case(blockManagerId, info) =>
+ private def storageStatus: Array[StorageStatus] = {
+ blockManagerInfo.map { case(blockManagerId, info) =>
import collection.JavaConverters._
StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala.toMap)
- }
- sender ! res
+ }.toArray
}
private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
@@ -188,7 +214,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
} else if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
case Some(manager) =>
- // A block manager of the same host name already exists
+ // A block manager of the same executor already exists.
+ // This should never happen. Let's just quit.
logError("Got two different block manager registrations on " + id.executorId)
System.exit(1)
case None =>
@@ -197,7 +224,6 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
blockManagerInfo(id) = new BlockManagerMasterActor.BlockManagerInfo(
id, System.currentTimeMillis(), maxMemSize, slaveActor)
}
- sender ! true
}
private def updateBlockInfo(
@@ -226,12 +252,12 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize)
- var locations: HashSet[BlockManagerId] = null
+ var locations: mutable.HashSet[BlockManagerId] = null
if (blockLocations.containsKey(blockId)) {
- locations = blockLocations.get(blockId)._2
+ locations = blockLocations.get(blockId)
} else {
- locations = new HashSet[BlockManagerId]
- blockLocations.put(blockId, (storageLevel.replication, locations))
+ locations = new mutable.HashSet[BlockManagerId]
+ blockLocations.put(blockId, locations)
}
if (storageLevel.isValid) {
@@ -247,70 +273,24 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
sender ! true
}
- private def getLocations(blockId: String) {
- val startTimeMs = System.currentTimeMillis()
- val tmp = " " + blockId + " "
- if (blockLocations.containsKey(blockId)) {
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- res.appendAll(blockLocations.get(blockId)._2)
- sender ! res.toSeq
- } else {
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- sender ! res
- }
+ private def getLocations(blockId: String): Seq[BlockManagerId] = {
+ if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty
}
- private def getLocationsMultipleBlockIds(blockIds: Array[String]) {
- def getLocations(blockId: String): Seq[BlockManagerId] = {
- val tmp = blockId
- if (blockLocations.containsKey(blockId)) {
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- res.appendAll(blockLocations.get(blockId)._2)
- return res.toSeq
- } else {
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- return res.toSeq
- }
- }
-
- var res: ArrayBuffer[Seq[BlockManagerId]] = new ArrayBuffer[Seq[BlockManagerId]]
- for (blockId <- blockIds) {
- res.append(getLocations(blockId))
- }
- sender ! res.toSeq
+ private def getLocationsMultipleBlockIds(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
+ blockIds.map(blockId => getLocations(blockId))
}
- private def getPeers(blockManagerId: BlockManagerId, size: Int) {
- var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- res.appendAll(peers)
- res -= blockManagerId
- val rand = new Random(System.currentTimeMillis())
- while (res.length > size) {
- res.remove(rand.nextInt(res.length))
- }
- sender ! res.toSeq
- }
-
- private def getPeersDeterministic(blockManagerId: BlockManagerId, size: Int) {
- var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+ private def getPeers(blockManagerId: BlockManagerId, size: Int): Seq[BlockManagerId] = {
+ val peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
val selfIndex = peers.indexOf(blockManagerId)
if (selfIndex == -1) {
- throw new Exception("Self index for " + blockManagerId + " not found")
+ throw new SparkException("Self index for " + blockManagerId + " not found")
}
// Note that this logic will select the same node multiple times if there aren't enough peers
- var index = selfIndex
- while (res.size < size) {
- index += 1
- if (index == selfIndex) {
- throw new Exception("More peer expected than available")
- }
- res += peers(index % peers.size)
- }
- sender ! res.toSeq
+ Array.tabulate[BlockManagerId](size) { i => peers((selfIndex + i + 1) % peers.length) }.toSeq
}
}
@@ -333,8 +313,8 @@ object BlockManagerMasterActor {
// Mapping from block id to its status.
private val _blocks = new JHashMap[String, BlockStatus]
- logInfo("Registering block manager %s:%d with %s RAM".format(
- blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem)))
+ logInfo("Registering block manager %s with %s RAM".format(
+ blockManagerId.hostPort, Utils.memoryBytesToString(maxMem)))
def updateLastSeenMs() {
_lastSeenMs = System.currentTimeMillis()
@@ -359,13 +339,13 @@ object BlockManagerMasterActor {
_blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize))
if (storageLevel.useMemory) {
_remainingMem -= memSize
- logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format(
- blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize),
+ logInfo("Added %s in memory on %s (size: %s, free: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.memoryBytesToString(memSize),
Utils.memoryBytesToString(_remainingMem)))
}
if (storageLevel.useDisk) {
- logInfo("Added %s on disk on %s:%d (size: %s)".format(
- blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize)))
+ logInfo("Added %s on disk on %s (size: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.memoryBytesToString(diskSize)))
}
} else if (_blocks.containsKey(blockId)) {
// If isValid is not true, drop the block.
@@ -373,17 +353,24 @@ object BlockManagerMasterActor {
_blocks.remove(blockId)
if (blockStatus.storageLevel.useMemory) {
_remainingMem += blockStatus.memSize
- logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format(
- blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize),
+ logInfo("Removed %s on %s in memory (size: %s, free: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.memoryBytesToString(memSize),
Utils.memoryBytesToString(_remainingMem)))
}
if (blockStatus.storageLevel.useDisk) {
- logInfo("Removed %s on %s:%d on disk (size: %s)".format(
- blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize)))
+ logInfo("Removed %s on %s on disk (size: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.memoryBytesToString(diskSize)))
}
}
}
+ def removeBlock(blockId: String) {
+ if (_blocks.containsKey(blockId)) {
+ _remainingMem += _blocks.get(blockId).memSize
+ _blocks.remove(blockId)
+ }
+ }
+
def remainingMem: Long = _remainingMem
def lastSeenMs: Long = _lastSeenMs
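
removeRdd in the master actor fans one ask out to every registered slave and folds the replies with Future.sequence, so the caller receives a single Future[Seq[Int]] of per-slave removal counts. A self-contained sketch of that aggregation, with stub futures in place of slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]:

    import scala.concurrent.{Await, Future}
    import scala.concurrent.duration._
    import scala.concurrent.ExecutionContext.Implicits.global

    // One future per slave, each eventually yielding a count of removed blocks.
    val perSlaveCounts: Seq[Future[Int]] = Seq(3, 1, 4).map(n => Future { n })

    // Collapse them: completes once every slave has answered.
    val allRemoved: Future[Seq[Int]] = Future.sequence(perSlaveCounts)
    println(Await.result(allRemoved, 5.seconds).sum) // 8 blocks removed in total
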
diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala
index cff48d9909..0010726c8d 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala
@@ -16,6 +16,9 @@ sealed trait ToBlockManagerSlave
private[spark]
case class RemoveBlock(blockId: String) extends ToBlockManagerSlave
+// Remove all blocks belonging to a specific RDD.
+private[spark] case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave
+
//////////////////////////////////////////////////////////////////////////////////
// Messages from slaves to the master.
diff --git a/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala
index f570cdc52d..b264d1deb5 100644
--- a/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala
@@ -11,6 +11,12 @@ import spark.{Logging, SparkException, Utils}
*/
class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor {
override def receive = {
- case RemoveBlock(blockId) => blockManager.removeBlock(blockId)
+
+ case RemoveBlock(blockId) =>
+ blockManager.removeBlock(blockId)
+
+ case RemoveRdd(rddId) =>
+ val numBlocksRemoved = blockManager.removeRdd(rddId)
+ sender ! numBlocksRemoved
}
}
diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala
index a3397a0fb4..631455abcd 100644
--- a/core/src/main/scala/spark/storage/BlockManagerUI.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala
@@ -1,10 +1,12 @@
package spark.storage
import akka.actor.{ActorRef, ActorSystem}
+
import akka.util.Timeout
import scala.concurrent.duration._
import spray.httpx.TwirlSupport._
import spray.routing.Directives
+
import spark.{Logging, SparkContext}
import spark.util.AkkaUtils
import spark.Utils
@@ -20,20 +22,21 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef,
implicit val implicitActorSystem = actorSystem
val STATIC_RESOURCE_DIR = "spark/deploy/static"
- implicit val timeout = Timeout(10 seconds)
+ implicit val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ val host = Utils.localHostName()
+ val port = if (System.getProperty("spark.ui.port") != null) {
+ System.getProperty("spark.ui.port").toInt
+ } else {
+ // TODO: Unfortunately, it's not possible to pass port 0 to spray and figure out which
+ // random port it bound to, so we have to try to find a local one by creating a socket.
+ Utils.findFreePort()
+ }
/** Start a HTTP server to run the Web interface */
def start() {
try {
- val port = if (System.getProperty("spark.ui.port") != null) {
- System.getProperty("spark.ui.port").toInt
- } else {
- // TODO: Unfortunately, it's not possible to pass port 0 to spray and figure out which
- // random port it bound to, so we have to try to find a local one by creating a socket.
- Utils.findFreePort()
- }
- AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, handler)
- logInfo("Started BlockManager web UI at http://%s:%d".format(Utils.localHostName(), port))
+ AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, handler, "BlockManagerHTTPServer")
+ logInfo("Started BlockManager web UI at http://%s:%d".format(host, port))
} catch {
case e: Exception =>
logError("Failed to create BlockManager WebUI", e)
@@ -74,4 +77,6 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef,
}
}
}
+
+ private[spark] def appUIAddress = "http://" + host + ":" + port
}
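
The hoisted port selection falls back to Utils.findFreePort() when spark.ui.port is unset. That helper is not shown in this diff; a common way to implement it, offered here only as an assumption, is to bind a ServerSocket to port 0 and let the OS pick:

    import java.net.ServerSocket

    def findFreePortSketch(): Int = {
      val socket = new ServerSocket(0) // port 0 asks the OS for any free port
      try {
        socket.getLocalPort
      } finally {
        socket.close() // release it so the caller can bind the port itself
      }
    }
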
diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
index d2985559c1..3057ade233 100644
--- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
@@ -2,13 +2,7 @@ package spark.storage
import java.nio.ByteBuffer
-import scala.actors._
-import scala.actors.Actor._
-import scala.actors.remote._
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
-import scala.util.Random
-
-import spark.{Logging, Utils, SparkEnv}
+import spark.{Logging, Utils}
import spark.network._
/**
@@ -19,7 +13,7 @@ import spark.network._
*/
private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends Logging {
initLogging()
-
+
blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive)
def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = {
@@ -51,7 +45,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
logDebug("Received [" + pB + "]")
putBlock(pB.id, pB.data, pB.level)
return None
- }
+ }
case BlockMessage.TYPE_GET_BLOCK => {
val gB = new GetBlock(blockMessage.getId)
logDebug("Received [" + gB + "]")
@@ -88,30 +82,26 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
private[spark] object BlockManagerWorker extends Logging {
private var blockManagerWorker: BlockManagerWorker = null
- private val DATA_TRANSFER_TIME_OUT_MS: Long = 500
- private val REQUEST_RETRY_INTERVAL_MS: Long = 1000
-
+
initLogging()
-
+
def startBlockManagerWorker(manager: BlockManager) {
blockManagerWorker = new BlockManagerWorker(manager)
}
-
+
def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = {
val blockManager = blockManagerWorker.blockManager
- val connectionManager = blockManager.connectionManager
- val serializer = blockManager.serializer
+ val connectionManager = blockManager.connectionManager
val blockMessage = BlockMessage.fromPutBlock(msg)
val blockMessageArray = new BlockMessageArray(blockMessage)
val resultMessage = connectionManager.sendMessageReliablySync(
toConnManagerId, blockMessageArray.toBufferMessage)
return (resultMessage != None)
}
-
+
def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = {
val blockManager = blockManagerWorker.blockManager
- val connectionManager = blockManager.connectionManager
- val serializer = blockManager.serializer
+ val connectionManager = blockManager.connectionManager
val blockMessage = BlockMessage.fromGetBlock(msg)
val blockMessageArray = new BlockMessageArray(blockMessage)
val responseMessage = connectionManager.sendMessageReliablySync(
diff --git a/core/src/main/scala/spark/storage/BlockMessageArray.scala b/core/src/main/scala/spark/storage/BlockMessageArray.scala
index a25decb123..ee0c5ff9a2 100644
--- a/core/src/main/scala/spark/storage/BlockMessageArray.scala
+++ b/core/src/main/scala/spark/storage/BlockMessageArray.scala
@@ -115,6 +115,7 @@ private[spark] object BlockMessageArray {
val newBuffer = ByteBuffer.allocate(totalSize)
newBuffer.clear()
bufferMessage.buffers.foreach(buffer => {
+ assert (0 == buffer.position())
newBuffer.put(buffer)
buffer.rewind()
})
diff --git a/core/src/main/scala/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/spark/storage/BlockObjectWriter.scala
new file mode 100644
index 0000000000..42e2b07d5c
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockObjectWriter.scala
@@ -0,0 +1,50 @@
+package spark.storage
+
+import java.nio.ByteBuffer
+
+
+/**
+ * An interface for writing JVM objects to some underlying storage. This interface allows
+ * appending data to an existing block, and can guarantee atomicity in the case of faults
+ * as it allows the caller to revert partial writes.
+ *
+ * This interface does not support concurrent writes.
+ */
+abstract class BlockObjectWriter(val blockId: String) {
+
+ var closeEventHandler: () => Unit = _
+
+ def open(): BlockObjectWriter
+
+ def close() {
+ closeEventHandler()
+ }
+
+ def isOpen: Boolean
+
+ def registerCloseEventHandler(handler: () => Unit) {
+ closeEventHandler = handler
+ }
+
+ /**
+ * Flush the partial writes and commit them as a single atomic block. Return the
+ * number of bytes written for this commit.
+ */
+ def commit(): Long
+
+ /**
+ * Reverts writes that haven't been flushed yet. Callers should invoke this function
+ * when there are runtime exceptions.
+ */
+ def revertPartialWrites()
+
+ /**
+ * Writes an object.
+ */
+ def write(value: Any)
+
+ /**
+ * Size of the valid writes, in bytes.
+ */
+ def size(): Long
+}
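
A hedged usage sketch for this contract, not part of the patch: commit() makes the buffered writes durable, while revertPartialWrites() discards anything written since the last commit. It assumes a concrete subclass and that registerCloseEventHandler has been called, since close() invokes the handler unconditionally:

    // Write one partition's values atomically; returns bytes committed.
    def writeBlock(writer: BlockObjectWriter, values: Iterator[Any]): Long = {
      writer.open()
      try {
        values.foreach(writer.write)
        writer.commit() // bytes made durable by this commit
      } catch {
        case e: Exception =>
          writer.revertPartialWrites() // undo everything since the last commit
          throw e
      } finally {
        writer.close()
      }
    }
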
diff --git a/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala b/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala
deleted file mode 100644
index f6c28dce52..0000000000
--- a/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala
+++ /dev/null
@@ -1,12 +0,0 @@
-package spark.storage
-
-private[spark] trait DelegateBlockFetchTracker extends BlockFetchTracker {
- var delegate : BlockFetchTracker = _
- def setDelegate(d: BlockFetchTracker) {delegate = d}
- def totalBlocks = delegate.totalBlocks
- def numLocalBlocks = delegate.numLocalBlocks
- def numRemoteBlocks = delegate.numRemoteBlocks
- def remoteFetchTime = delegate.remoteFetchTime
- def fetchWaitTime = delegate.fetchWaitTime
- def remoteBytesRead = delegate.remoteBytesRead
-}
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index ddbf8821ad..da859eebcb 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -1,41 +1,126 @@
package spark.storage
+import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile}
import java.nio.ByteBuffer
-import java.io.{File, FileOutputStream, RandomAccessFile}
+import java.nio.channels.FileChannel
import java.nio.channels.FileChannel.MapMode
import java.util.{Random, Date}
import java.text.SimpleDateFormat
-import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-
import scala.collection.mutable.ArrayBuffer
-import spark.executor.ExecutorExitCode
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import spark.Utils
+import spark.executor.ExecutorExitCode
+import spark.serializer.{Serializer, SerializationStream}
+import spark.Logging
+import spark.network.netty.ShuffleSender
+import spark.network.netty.PathResolver
+
/**
* Stores BlockManager blocks on disk.
*/
private class DiskStore(blockManager: BlockManager, rootDirs: String)
- extends BlockStore(blockManager) {
+ extends BlockStore(blockManager) with Logging {
+
+ class DiskBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int)
+ extends BlockObjectWriter(blockId) {
+
+ private val f: File = createFile(blockId /*, allowAppendExisting */)
+
+ // The file channel, used for repositioning / truncating the file.
+ private var channel: FileChannel = null
+ private var bs: OutputStream = null
+ private var objOut: SerializationStream = null
+ private var lastValidPosition = 0L
+ private var initialized = false
+
+ override def open(): DiskBlockObjectWriter = {
+ val fos = new FileOutputStream(f, true)
+ channel = fos.getChannel()
+ bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize))
+ objOut = serializer.newInstance().serializeStream(bs)
+ initialized = true
+ this
+ }
+
+ override def close() {
+ if (initialized) {
+ objOut.close()
+ bs.close()
+ channel = null
+ bs = null
+ objOut = null
+ }
+ // Invoke the close callback handler.
+ super.close()
+ }
- val MAX_DIR_CREATION_ATTEMPTS: Int = 10
- val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
+ override def isOpen: Boolean = objOut != null
+ // Flush the partial writes, and set valid length to be the length of the entire file.
+ // Return the number of bytes written for this commit.
+ override def commit(): Long = {
+ if (initialized) {
+ // NOTE: Flush the serializer first and then the compressed/buffered output stream
+ objOut.flush()
+ bs.flush()
+ val prevPos = lastValidPosition
+ lastValidPosition = channel.position()
+ lastValidPosition - prevPos
+ } else {
+ // lastValidPosition is zero if stream is uninitialized
+ lastValidPosition
+ }
+ }
+
+ override def revertPartialWrites() {
+ if (initialized) {
+ // Discard current writes. We do this by flushing the outstanding writes and
+      // truncating the file to the last valid position.
+ objOut.flush()
+ bs.flush()
+ channel.truncate(lastValidPosition)
+ }
+ }
+
+ override def write(value: Any) {
+ if (!initialized) {
+ open()
+ }
+ objOut.writeObject(value)
+ }
+
+ override def size(): Long = lastValidPosition
+ }
+
+ private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
+ private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
+
+ private var shuffleSender : ShuffleSender = null
// Create one local directory for each path mentioned in spark.local.dir; then, inside this
// directory, create multiple subdirectories that we will hash files into, in order to avoid
// having really large inodes at the top level.
- val localDirs = createLocalDirs()
- val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
+ private val localDirs: Array[File] = createLocalDirs()
+ private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
addShutdownHook()
+ def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
+ : BlockObjectWriter = {
+ new DiskBlockObjectWriter(blockId, serializer, bufferSize)
+ }
+
override def getSize(blockId: String): Long = {
getFile(blockId).length()
}
- override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) {
+    // Duplicate so that we do not modify the caller's position and limit.
+    // duplicate() shares the underlying buffer, so this copy is inexpensive.
+ val bytes = _bytes.duplicate()
logDebug("Attempting to put block " + blockId)
val startTime = System.currentTimeMillis
val file = createFile(blockId)
@@ -49,6 +134,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
blockId, Utils.memoryBytesToString(bytes.limit), (finishTime - startTime)))
}
+ private def getFileBytes(file: File): ByteBuffer = {
+ val length = file.length()
+ val channel = new RandomAccessFile(file, "r").getChannel()
+ val buffer = try {
+ channel.map(MapMode.READ_ONLY, 0, length)
+ } finally {
+ channel.close()
+ }
+
+ buffer
+ }
+
override def putValues(
blockId: String,
values: ArrayBuffer[Any],
@@ -61,18 +158,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
val file = createFile(blockId)
val fileOut = blockManager.wrapForCompression(blockId,
new FastBufferedOutputStream(new FileOutputStream(file)))
- val objOut = blockManager.serializer.newInstance().serializeStream(fileOut)
+ val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut)
objOut.writeAll(values.iterator)
objOut.close()
val length = file.length()
+
+ val timeTaken = System.currentTimeMillis - startTime
logDebug("Block %s stored as %s file on disk in %d ms".format(
- blockId, Utils.memoryBytesToString(length), (System.currentTimeMillis - startTime)))
+ blockId, Utils.memoryBytesToString(length), timeTaken))
if (returnValues) {
// Return a byte buffer for the contents of the file
- val channel = new RandomAccessFile(file, "r").getChannel()
- val buffer = channel.map(MapMode.READ_ONLY, 0, length)
- channel.close()
+ val buffer = getFileBytes(file)
PutResult(length, Right(buffer))
} else {
PutResult(length, null)
@@ -81,10 +178,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
override def getBytes(blockId: String): Option[ByteBuffer] = {
val file = getFile(blockId)
- val length = file.length().toInt
- val channel = new RandomAccessFile(file, "r").getChannel()
- val bytes = channel.map(MapMode.READ_ONLY, 0, length)
- channel.close()
+ val bytes = getFileBytes(file)
Some(bytes)
}
@@ -92,11 +186,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes))
}
+ /**
+ * A version of getValues that allows a custom serializer. This is used as part of the
+ * shuffle short-circuit code.
+ */
+ def getValues(blockId: String, serializer: Serializer): Option[Iterator[Any]] = {
+ getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer))
+ }
+
override def remove(blockId: String): Boolean = {
val file = getFile(blockId)
if (file.exists()) {
file.delete()
- true
} else {
false
}
@@ -106,10 +207,13 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
getFile(blockId).exists()
}
- private def createFile(blockId: String): File = {
+ private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = {
val file = getFile(blockId)
- if (file.exists()) {
- throw new Exception("File for block " + blockId + " already exists on disk: " + file)
+ if (!allowAppendExisting && file.exists()) {
+ // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task
+ // was rescheduled on the same machine as the old task.
+ logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting")
+ file.delete()
}
file
}
@@ -144,8 +248,8 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
private def createLocalDirs(): Array[File] = {
logDebug("Creating local directories at root dirs '" + rootDirs + "'")
val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
- rootDirs.split(",").map(rootDir => {
- var foundLocalDir: Boolean = false
+ rootDirs.split(",").map { rootDir =>
+ var foundLocalDir = false
var localDir: File = null
var localDirId: String = null
var tries = 0
@@ -156,12 +260,11 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
localDir = new File(rootDir, "spark-local-" + localDirId)
if (!localDir.exists) {
- localDir.mkdirs()
- foundLocalDir = true
+ foundLocalDir = localDir.mkdirs()
}
} catch {
case e: Exception =>
- logWarning("Attempt " + tries + " to create local dir failed", e)
+ logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
}
}
if (!foundLocalDir) {
@@ -171,19 +274,40 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
logInfo("Created local directory at " + localDir)
localDir
- })
+ }
}
private def addShutdownHook() {
+ localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
override def run() {
logDebug("Shutdown hook called")
- try {
- localDirs.foreach(localDir => Utils.deleteRecursively(localDir))
- } catch {
- case t: Throwable => logError("Exception while deleting local spark dirs", t)
+ localDirs.foreach { localDir =>
+ try {
+ if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
+ } catch {
+ case t: Throwable =>
+ logError("Exception while deleting local spark dir: " + localDir, t)
+ }
+ }
+ if (shuffleSender != null) {
+ shuffleSender.stop
}
}
})
}
+
+ private[storage] def startShuffleBlockSender(port: Int): Int = {
+ val pResolver = new PathResolver {
+ override def getAbsolutePath(blockId: String): String = {
+ if (!blockId.startsWith("shuffle_")) {
+ return null
+ }
+ DiskStore.this.getFile(blockId).getAbsolutePath()
+ }
+ }
+ shuffleSender = new ShuffleSender(port, pResolver)
+ logInfo("Created ShuffleSender binding to port : "+ shuffleSender.port)
+ shuffleSender.port
+ }
}
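
putBytes above now works on _bytes.duplicate() (the MemoryStore change below gets the same treatment). A duplicate shares the underlying bytes but carries its own position and limit, which is why reading through the copy leaves the caller's buffer untouched. A runnable illustration:

    import java.nio.ByteBuffer

    val original = ByteBuffer.wrap("spark".getBytes)
    val copy = original.duplicate()      // no data copied, independent cursor
    while (copy.hasRemaining) copy.get() // drain the duplicate
    println(original.position()) // 0 -- the caller's view is unchanged
    println(copy.position())     // 5
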
diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala
index 949588476c..eba5ee507f 100644
--- a/core/src/main/scala/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/spark/storage/MemoryStore.scala
@@ -31,7 +31,9 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
- override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) {
+ // Work on a duplicate - since the original input might be used elsewhere.
+ val bytes = _bytes.duplicate()
bytes.rewind()
if (level.deserialized) {
val values = blockManager.dataDeserialize(blockId, bytes)
diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
new file mode 100644
index 0000000000..44638e0c2d
--- /dev/null
+++ b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
@@ -0,0 +1,50 @@
+package spark.storage
+
+import spark.serializer.Serializer
+
+
+private[spark]
+class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter])
+
+
+private[spark]
+trait ShuffleBlocks {
+ def acquireWriters(mapId: Int): ShuffleWriterGroup
+ def releaseWriters(group: ShuffleWriterGroup)
+}
+
+
+private[spark]
+class ShuffleBlockManager(blockManager: BlockManager) {
+
+ def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = {
+ new ShuffleBlocks {
+ // Get a group of writers for a map task.
+ override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
+ val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
+ val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
+ blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
+ }
+ new ShuffleWriterGroup(mapId, writers)
+ }
+
+ override def releaseWriters(group: ShuffleWriterGroup) = {
+ // Nothing really to release here.
+ }
+ }
+ }
+}
+
+
+private[spark]
+object ShuffleBlockManager {
+
+ // Returns the block id for a given shuffle block.
+ def blockId(shuffleId: Int, bucketId: Int, groupId: Int): String = {
+ "shuffle_" + shuffleId + "_" + groupId + "_" + bucketId
+ }
+
+ // Returns true if the block is a shuffle block.
+ def isShuffle(blockId: String): Boolean = blockId.startsWith("shuffle_")
+}
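
A sketch of how a map task might drive this API; shuffleBlockManager, serializer, and the actual record writes are assumed context, while the companion-object calls at the end are runnable as-is:

    // One writer per reduce bucket for map task 2 of shuffle 0.
    val shuffle = shuffleBlockManager.forShuffle(0, 4, serializer)
    val group = shuffle.acquireWriters(2) // mapId = 2
    // group.writers(b) appends to block "shuffle_0_2_" + b
    group.writers.foreach(_.commit())
    shuffle.releaseWriters(group)

    // Naming helpers from the companion object:
    ShuffleBlockManager.blockId(0, 3, 2)     // "shuffle_0_2_3"
    ShuffleBlockManager.isShuffle("rdd_1_0") // false
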
diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala
index 3b5a77ab22..cc0c354e7e 100644
--- a/core/src/main/scala/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/spark/storage/StorageLevel.scala
@@ -123,11 +123,7 @@ object StorageLevel {
val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]()
private[spark] def getCachedStorageLevel(level: StorageLevel): StorageLevel = {
- if (storageLevelCache.containsKey(level)) {
- storageLevelCache.get(level)
- } else {
- storageLevelCache.put(level, level)
- level
- }
+ storageLevelCache.putIfAbsent(level, level)
+ storageLevelCache.get(level)
}
}
diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala
index dec47a9d41..950c0cdf35 100644
--- a/core/src/main/scala/spark/storage/StorageUtils.scala
+++ b/core/src/main/scala/spark/storage/StorageUtils.scala
@@ -4,9 +4,9 @@ import spark.{Utils, SparkContext}
import BlockManagerMasterActor.BlockStatus
private[spark]
-case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
+case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
blocks: Map[String, BlockStatus]) {
-
+
def memUsed(blockPrefix: String = "") = {
blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize).
reduceOption(_+_).getOrElse(0l)
@@ -22,53 +22,62 @@ case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
}
case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel,
- numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long) {
+ numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long)
+ extends Ordered[RDDInfo] {
override def toString = {
import Utils.memoryBytesToString
"RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; DiskSize: %s".format(name, id,
storageLevel.toString, numCachedPartitions, numPartitions, memoryBytesToString(memSize), memoryBytesToString(diskSize))
}
+
+ override def compare(that: RDDInfo) = {
+ this.id - that.id
+ }
}
/* Helper methods for storage-related objects */
private[spark]
object StorageUtils {
- /* Given the current storage status of the BlockManager, returns information for each RDD */
- def rddInfoFromStorageStatus(storageStatusList: Array[StorageStatus],
+ /* Given the current storage status of the BlockManager, returns information for each RDD */
+ def rddInfoFromStorageStatus(storageStatusList: Array[StorageStatus],
sc: SparkContext) : Array[RDDInfo] = {
- rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc)
+ rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc)
}
- /* Given a list of BlockStatus objets, returns information for each RDD */
- def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus],
+  /* Given a list of BlockStatus objects, returns information for each RDD */
+ def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus],
sc: SparkContext) : Array[RDDInfo] = {
// Group by rddId, ignore the partition name
- val groupedRddBlocks = infos.groupBy { case(k, v) =>
+ val groupedRddBlocks = infos.filterKeys(_.startsWith("rdd_")).groupBy { case(k, v) =>
k.substring(0,k.lastIndexOf('_'))
}.mapValues(_.values.toArray)
// For each RDD, generate an RDDInfo object
- groupedRddBlocks.map { case(rddKey, rddBlocks) =>
-
+ val rddInfos = groupedRddBlocks.map { case (rddKey, rddBlocks) =>
// Add up memory and disk sizes
val memSize = rddBlocks.map(_.memSize).reduce(_ + _)
val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _)
// Find the id of the RDD, e.g. rdd_1 => 1
val rddId = rddKey.split("_").last.toInt
- // Get the friendly name for the rdd, if available.
- val rdd = sc.persistentRdds(rddId)
- val rddName = Option(rdd.name).getOrElse(rddKey)
- val rddStorageLevel = rdd.getStorageLevel
- RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, rdd.partitions.size, memSize, diskSize)
- }.toArray
+ // Get the friendly name and storage level for the RDD, if available
+ sc.persistentRdds.get(rddId).map { r =>
+ val rddName = Option(r.name).getOrElse(rddKey)
+ val rddStorageLevel = r.getStorageLevel
+ RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size, memSize, diskSize)
+ }
+ }.flatten.toArray
+
+ scala.util.Sorting.quickSort(rddInfos)
+
+ rddInfos
}
- /* Removes all BlockStatus object that are not part of a block prefix */
- def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus],
+  /* Removes all BlockStatus objects that are not part of a block prefix */
+ def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus],
prefix: String) : Array[StorageStatus] = {
storageStatusList.map { status =>
diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala
index e16915c8e9..ea39888c21 100644
--- a/core/src/main/scala/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/spark/util/AkkaUtils.scala
@@ -5,13 +5,15 @@ import com.typesafe.config.ConfigFactory
import scala.concurrent.duration._
import akka.pattern.ask
import akka.remote.RemoteActorRefProvider
+
import spray.routing.Route
import spray.io.IOExtension
import spray.routing.HttpServiceActor
import spray.can.server.{HttpServer, ServerSettings}
import spray.io.SingletonHandler
import scala.concurrent.Await
-import spark.SparkException
+import spark.{Utils, SparkException}
+
import java.util.concurrent.TimeoutException
/**
@@ -29,9 +31,14 @@ private[spark] object AkkaUtils {
def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = {
val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt
val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt
- val akkaTimeout = System.getProperty("spark.akka.timeout", "20").toInt
+
+ val akkaTimeout = System.getProperty("spark.akka.timeout", "60").toInt
+
val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt
- val lifecycleEvents = System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean
+ val lifecycleEvents = if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off"
+ // 10 seconds is the default akka timeout, but in a cluster, we need higher by default.
+ val akkaWriteTimeout = System.getProperty("spark.akka.writeTimeout", "30").toInt
+
val akkaConf = ConfigFactory.parseString("""
akka.daemonic = on
akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"]
@@ -45,10 +52,11 @@ private[spark] object AkkaUtils {
akka.remote.netty.execution-pool-size = %d
akka.actor.default-dispatcher.throughput = %d
akka.remote.log-remote-lifecycle-events = %s
+ akka.remote.netty.write-timeout = %ds
""".format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize,
- if (lifecycleEvents) "on" else "off"))
+ lifecycleEvents, akkaWriteTimeout))
- val actorSystem = ActorSystem(name, akkaConf, getClass.getClassLoader)
+ val actorSystem = ActorSystem(name, akkaConf)
// Figure out the port number we bound to, in case port was passed as 0. This is a bit of a
// hack because Akka doesn't let you figure out the port through the public API yet.
@@ -60,12 +68,13 @@ private[spark] object AkkaUtils {
/**
* Creates a Spray HTTP server bound to a given IP and port with a given Spray Route object to
* handle requests. Returns the bound port or throws a SparkException on failure.
+   * TODO: Not changing ip to host here; is that required?
*/
- def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route) = {
+ def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route, name: String = "HttpServer") = {
val ioWorker = IOExtension(actorSystem).ioBridge()
val httpService = actorSystem.actorOf(Props(HttpServiceActor(route)))
val server = actorSystem.actorOf(
- Props(new HttpServer(ioWorker, SingletonHandler(httpService), ServerSettings())), name = "HttpServer")
+ Props(new HttpServer(ioWorker, SingletonHandler(httpService), ServerSettings())), name = name)
actorSystem.registerOnTermination { actorSystem.stop(ioWorker) }
val timeout = 3.seconds
val future = server.ask(HttpServer.Bind(ip, port))(timeout)
diff --git a/core/src/main/scala/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala
new file mode 100644
index 0000000000..4bc5db8bb7
--- /dev/null
+++ b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala
@@ -0,0 +1,45 @@
+package spark.util
+
+import java.io.Serializable
+import java.util.{PriorityQueue => JPriorityQueue}
+import scala.collection.generic.Growable
+import scala.collection.JavaConverters._
+
+/**
+ * Bounded priority queue. This class wraps the original PriorityQueue
+ * class and modifies it such that only the top K elements are retained.
+ * The top K elements are defined by an implicit Ordering[A].
+ */
+class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A])
+ extends Iterable[A] with Growable[A] with Serializable {
+
+ private val underlying = new JPriorityQueue[A](maxSize, ord)
+
+ override def iterator: Iterator[A] = underlying.iterator.asScala
+
+ override def ++=(xs: TraversableOnce[A]): this.type = {
+ xs.foreach { this += _ }
+ this
+ }
+
+ override def +=(elem: A): this.type = {
+ if (size < maxSize) underlying.offer(elem)
+ else maybeReplaceLowest(elem)
+ this
+ }
+
+ override def +=(elem1: A, elem2: A, elems: A*): this.type = {
+ this += elem1 += elem2 ++= elems
+ }
+
+ override def clear() { underlying.clear() }
+
+ private def maybeReplaceLowest(a: A): Boolean = {
+ val head = underlying.peek()
+ if (head != null && ord.gt(a, head)) {
+ underlying.poll()
+ underlying.offer(a)
+ } else false
+ }
+}
+
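A quick usage sketch of BoundedPriorityQueue (self-contained apart from the class above): keep the three largest of a stream of ints. The underlying java.util.PriorityQueue is a min-heap under ord, so its head is always the smallest retained element, which is exactly the candidate maybeReplaceLowest evicts.

import spark.util.BoundedPriorityQueue

object BpqExample {
  def main(args: Array[String]) {
    val q = new BoundedPriorityQueue[Int](3)
    q ++= Seq(5, 1, 9, 7, 3, 8)
    // Iteration order is heap order, not sorted order, hence the sort here.
    println(q.toList.sorted) // List(7, 8, 9)
  }
}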
diff --git a/core/src/main/scala/spark/util/StatCounter.scala b/core/src/main/scala/spark/util/StatCounter.scala
index 5f80180339..2b980340b7 100644
--- a/core/src/main/scala/spark/util/StatCounter.scala
+++ b/core/src/main/scala/spark/util/StatCounter.scala
@@ -37,17 +37,23 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
if (other == this) {
merge(other.copy()) // Avoid overwriting fields in a weird order
} else {
- val delta = other.mu - mu
- if (other.n * 10 < n) {
- mu = mu + (delta * other.n) / (n + other.n)
- } else if (n * 10 < other.n) {
- mu = other.mu - (delta * n) / (n + other.n)
- } else {
- mu = (mu * n + other.mu * other.n) / (n + other.n)
+ if (n == 0) {
+ mu = other.mu
+ m2 = other.m2
+ n = other.n
+ } else if (other.n != 0) {
+ val delta = other.mu - mu
+ if (other.n * 10 < n) {
+ mu = mu + (delta * other.n) / (n + other.n)
+ } else if (n * 10 < other.n) {
+ mu = other.mu - (delta * n) / (n + other.n)
+ } else {
+ mu = (mu * n + other.mu * other.n) / (n + other.n)
+ }
+ m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
+ n += other.n
}
- m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
- n += other.n
- this
+ this
}
}
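For reference, the merge above implements the standard pairwise update for a running mean and second central moment (Chan et al.). With $\delta = \mu_b - \mu_a$:

\mu = \frac{n_a \mu_a + n_b \mu_b}{n_a + n_b}, \qquad
M_2 = M_{2,a} + M_{2,b} + \frac{\delta^2 \, n_a n_b}{n_a + n_b}

The three mean branches are numerically safer rearrangements of the same formula. The new n == 0 and other.n != 0 guards matter because the unguarded formula divides by n + other.n, which is zero when both counters are empty; the "zero-length partitions" test added to PartitioningSuite later in this diff exercises exactly that case.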
diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala
index 4afba0a4c3..e95ca1fc8e 100644
--- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala
+++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala
@@ -3,6 +3,7 @@ package spark.util
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConversions
import scala.collection.mutable.Map
+import spark.scheduler.MapStatus
/**
* This is a custom implementation of scala.collection.mutable.Map which stores the insertion
@@ -42,6 +43,13 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging {
this
}
+ // Should we return the previous value directly, or wrapped in an Option?
+ def putIfAbsent(key: A, value: B): Option[B] = {
+ val prev = internalMap.putIfAbsent(key, (value, currentTime))
+ if (prev != null) Some(prev._1) else None
+ }
+
+
override def -= (key: A): this.type = {
internalMap.remove(key)
this
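The putIfAbsent added above keeps ConcurrentHashMap's semantics: None means this call inserted the value; Some(prev) means the key was already present and the map is unchanged. A small sketch, assuming the class as defined:

val map = new spark.util.TimeStampedHashMap[String, Int]
assert(map.putIfAbsent("a", 1) == None)    // key absent: inserted
assert(map.putIfAbsent("a", 2) == Some(1)) // already present; value stays 1
assert(map("a") == 1)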
diff --git a/core/src/main/scala/spark/util/TimedIterator.scala b/core/src/main/scala/spark/util/TimedIterator.scala
deleted file mode 100644
index 539b01f4ce..0000000000
--- a/core/src/main/scala/spark/util/TimedIterator.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-package spark.util
-
-/**
- * A utility for tracking the total time an iterator takes to iterate through its elements.
- *
- * In general, this should only be used if you expect it to take a considerable amount of time
- * (eg. milliseconds) to get each element -- otherwise, the timing won't be very accurate,
- * and you are probably just adding more overhead
- */
-class TimedIterator[+A](val sub: Iterator[A]) extends Iterator[A] {
- private var netMillis = 0l
- private var nElems = 0
- def hasNext = {
- val start = System.currentTimeMillis()
- val r = sub.hasNext
- val end = System.currentTimeMillis()
- netMillis += (end - start)
- r
- }
- def next = {
- val start = System.currentTimeMillis()
- val r = sub.next
- val end = System.currentTimeMillis()
- netMillis += (end - start)
- nElems += 1
- r
- }
-
- def getNetMillis = netMillis
- def getAverageTimePerItem = netMillis / nElems.toDouble
-
-}
diff --git a/core/src/main/twirl/spark/deploy/master/app_details.scala.html b/core/src/main/twirl/spark/deploy/master/app_details.scala.html
index 301a7e2124..5e5e5de551 100644
--- a/core/src/main/twirl/spark/deploy/master/app_details.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/app_details.scala.html
@@ -9,19 +9,17 @@
<li><strong>ID:</strong> @app.id</li>
<li><strong>Description:</strong> @app.desc.name</li>
<li><strong>User:</strong> @app.desc.user</li>
- <li><strong>Cores:</strong>
- @app.desc.cores
- (@app.coresGranted Granted
- @if(app.desc.cores == Integer.MAX_VALUE) {
-
+ <li><strong>Cores:</strong>
+ @if(app.desc.maxCores == Integer.MAX_VALUE) {
+ Unlimited (@app.coresGranted granted)
} else {
- , @app.coresLeft
+ @app.desc.maxCores (@app.coresGranted granted, @app.coresLeft left)
}
- )
</li>
<li><strong>Memory per Slave:</strong> @app.desc.memoryPerSlave</li>
<li><strong>Submit Date:</strong> @app.submitDate</li>
<li><strong>State:</strong> @app.state</li>
+ <li><strong><a href="@app.appUiUrl">Application Detail UI</a></strong></li>
</ul>
</div>
</div>
diff --git a/core/src/main/twirl/spark/deploy/master/executor_row.scala.html b/core/src/main/twirl/spark/deploy/master/executor_row.scala.html
index d2d80fad48..21e72c7aab 100644
--- a/core/src/main/twirl/spark/deploy/master/executor_row.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/executor_row.scala.html
@@ -3,7 +3,7 @@
<tr>
<td>@executor.id</td>
<td>
- <a href="@executor.worker.webUiAddress">@executor.worker.id</href>
+ <a href="@executor.worker.webUiAddress">@executor.worker.id</a>
</td>
<td>@executor.cores</td>
<td>@executor.memory</td>
diff --git a/core/src/main/twirl/spark/deploy/master/index.scala.html b/core/src/main/twirl/spark/deploy/master/index.scala.html
index ac51a39a51..b9b9f08810 100644
--- a/core/src/main/twirl/spark/deploy/master/index.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/index.scala.html
@@ -2,7 +2,7 @@
@import spark.deploy.master._
@import spark.Utils
-@spark.common.html.layout(title = "Spark Master on " + state.host) {
+@spark.common.html.layout(title = "Spark Master on " + state.host + ":" + state.port) {
<!-- Cluster Details -->
<div class="row">
diff --git a/core/src/main/twirl/spark/deploy/master/worker_row.scala.html b/core/src/main/twirl/spark/deploy/master/worker_row.scala.html
index be69e9bf02..46277ca421 100644
--- a/core/src/main/twirl/spark/deploy/master/worker_row.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/worker_row.scala.html
@@ -4,7 +4,7 @@
<tr>
<td>
- <a href="@worker.webUiAddress">@worker.id</href>
+ <a href="@worker.webUiAddress">@worker.id</a>
</td>
<td>@{worker.host}:@{worker.port}</td>
<td>@worker.state</td>
diff --git a/core/src/main/twirl/spark/deploy/worker/index.scala.html b/core/src/main/twirl/spark/deploy/worker/index.scala.html
index c39f769a73..0e66af9284 100644
--- a/core/src/main/twirl/spark/deploy/worker/index.scala.html
+++ b/core/src/main/twirl/spark/deploy/worker/index.scala.html
@@ -1,7 +1,7 @@
@(worker: spark.deploy.WorkerState)
@import spark.Utils
-@spark.common.html.layout(title = "Spark Worker on " + worker.host) {
+@spark.common.html.layout(title = "Spark Worker on " + worker.host + ":" + worker.port) {
<!-- Worker Details -->
<div class="row">
diff --git a/core/src/main/twirl/spark/storage/worker_table.scala.html b/core/src/main/twirl/spark/storage/worker_table.scala.html
index d54b8de4cc..cd72a688c1 100644
--- a/core/src/main/twirl/spark/storage/worker_table.scala.html
+++ b/core/src/main/twirl/spark/storage/worker_table.scala.html
@@ -12,7 +12,7 @@
<tbody>
@for(status <- workersStatusList) {
<tr>
- <td>@(status.blockManagerId.ip + ":" + status.blockManagerId.port)</td>
+ <td>@(status.blockManagerId.host + ":" + status.blockManagerId.port)</td>
<td>
@(Utils.memoryBytesToString(status.memUsed(prefix)))
(@(Utils.memoryBytesToString(status.memRemaining)) Total Available)
diff --git a/core/src/test/resources/fairscheduler.xml b/core/src/test/resources/fairscheduler.xml
new file mode 100644
index 0000000000..5a688b0ebb
--- /dev/null
+++ b/core/src/test/resources/fairscheduler.xml
@@ -0,0 +1,14 @@
+<allocations>
+<pool name="1">
+ <minShare>2</minShare>
+ <weight>1</weight>
+ <schedulingMode>FIFO</schedulingMode>
+</pool>
+<pool name="2">
+ <minShare>3</minShare>
+ <weight>1</weight>
+ <schedulingMode>FIFO</schedulingMode>
+</pool>
+<pool name="3">
+</pool>
+</allocations>
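In this test resource each <pool> carries an optional minShare, weight, and schedulingMode; pool "3" deliberately omits all three to exercise the scheduler's defaults. A hypothetical reader for the file using scala.xml (Spark's actual fair-scheduler parser is not part of this diff, and the fallback defaults of 0 and 1 here are assumptions):

import scala.xml.XML

val pools = (XML.loadFile("core/src/test/resources/fairscheduler.xml") \\ "pool").map { p =>
  val name = (p \ "@name").text
  val minShare = (p \ "minShare").text match { case "" => 0; case s => s.toInt }
  val weight = (p \ "weight").text match { case "" => 1; case s => s.toInt }
  (name, minShare, weight)
}
// e.g. ("1", 2, 1), ("2", 3, 1), ("3", 0, 1)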
diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala
index 8836c68ae6..6785787b7e 100644
--- a/core/src/test/scala/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/spark/CheckpointSuite.scala
@@ -28,6 +28,16 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
}
}
+ test("basic checkpointing") {
+ val parCollection = sc.makeRDD(1 to 4)
+ val flatMappedRDD = parCollection.flatMap(x => 1 to x)
+ flatMappedRDD.checkpoint()
+ assert(flatMappedRDD.dependencies.head.rdd == parCollection)
+ val result = flatMappedRDD.collect()
+ assert(flatMappedRDD.dependencies.head.rdd != parCollection)
+ assert(flatMappedRDD.collect() === result)
+ }
+
test("RDDs with one-to-one dependencies") {
testCheckpointing(_.map(x => x.toString))
testCheckpointing(_.flatMap(x => 1 to x))
diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala
index 46b74fe5ee..0866fb47b3 100644
--- a/core/src/test/scala/spark/DistributedSuite.scala
+++ b/core/src/test/scala/spark/DistributedSuite.scala
@@ -3,8 +3,10 @@ package spark
import network.ConnectionManagerId
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
+import org.scalatest.concurrent.Timeouts._
import org.scalatest.matchers.ShouldMatchers
import org.scalatest.prop.Checkers
+import org.scalatest.time.{Span, Millis}
import org.scalacheck.Arbitrary._
import org.scalacheck.Gen
import org.scalacheck.Prop._
@@ -16,7 +18,13 @@ import scala.collection.mutable.ArrayBuffer
import SparkContext._
import storage.{GetBlock, BlockManagerWorker, StorageLevel}
-class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter with LocalSparkContext {
+
+class NotSerializableClass
+class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {}
+
+
+class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
+ with LocalSparkContext {
val clusterUrl = "local-cluster[2,1,512]"
@@ -25,6 +33,24 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
System.clearProperty("spark.storage.memoryFraction")
}
+ test("task throws not serializable exception") {
+ // Ensures that executors do not crash when an exception is not serializable. If executors
+ // crash, this test will hang. Correct behavior is that executors don't crash but fail tasks
+ // and the scheduler throws a SparkException.
+
+ // numSlaves must be less than numPartitions
+ val numSlaves = 3
+ val numPartitions = 10
+
+ sc = new SparkContext("local-cluster[%s,1,512]".format(numSlaves), "test")
+ val data = sc.parallelize(1 to 100, numPartitions).
+ map(x => throw new NotSerializableExn(new NotSerializableClass))
+ intercept[SparkException] {
+ data.count()
+ }
+ resetSparkContext()
+ }
+
test("local-cluster format") {
sc = new SparkContext("local-cluster[2,1,512]", "test")
assert(sc.parallelize(1 to 2, 2).count() == 2)
@@ -153,7 +179,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
val blockManager = SparkEnv.get.blockManager
blockManager.master.getLocations(blockId).foreach(id => {
val bytes = BlockManagerWorker.syncGetBlock(
- GetBlock(blockId), ConnectionManagerId(id.ip, id.port))
+ GetBlock(blockId), ConnectionManagerId(id.host, id.port))
val deserialized = blockManager.dataDeserialize(blockId, bytes).asInstanceOf[Iterator[Int]].toList
assert(deserialized === (1 to 100).toList)
})
@@ -196,7 +222,6 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
sc = new SparkContext(clusterUrl, "test")
val data = sc.parallelize(Seq(true, true), 2)
assert(data.count === 2) // force executors to start
- val masterId = SparkEnv.get.blockManager.blockManagerId
assert(data.map(markNodeIfIdentity).collect.size === 2)
assert(data.map(failOnMarkedIdentity).collect.size === 2)
}
@@ -252,6 +277,42 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
assert(data2.count === 2)
}
}
+
+ test("unpersist RDDs") {
+ DistributedSuite.amMaster = true
+ sc = new SparkContext("local-cluster[3,1,512]", "test")
+ val data = sc.parallelize(Seq(true, false, false, false), 4)
+ data.persist(StorageLevel.MEMORY_ONLY_2)
+ data.count
+ assert(sc.persistentRdds.isEmpty === false)
+ data.unpersist()
+ assert(sc.persistentRdds.isEmpty === true)
+
+ failAfter(Span(3000, Millis)) {
+ try {
+ while (! sc.getRDDStorageInfo.isEmpty) {
+ Thread.sleep(200)
+ }
+ } catch {
+ // Ignore the exception and keep waiting; the block manager may be
+ // racing this thread to remove entries from the driver.
+ case _ => Thread.sleep(10)
+ }
+ }
+ }
+
+ test("job should fail if TaskResult exceeds Akka frame size") {
+ // We must use local-cluster mode since results are returned differently
+ // when running under LocalScheduler:
+ sc = new SparkContext("local-cluster[1,1,512]", "test")
+ val akkaFrameSize =
+ sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
+ val rdd = sc.parallelize(Seq(1)).map{x => new Array[Byte](akkaFrameSize)}
+ val exception = intercept[SparkException] {
+ rdd.reduce((x, y) => x)
+ }
+ exception.getMessage should endWith("result exceeded Akka frame size")
+ }
}
object DistributedSuite {
diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala
index 91b48c7456..e61ff7793d 100644
--- a/core/src/test/scala/spark/FileSuite.scala
+++ b/core/src/test/scala/spark/FileSuite.scala
@@ -7,6 +7,8 @@ import scala.io.Source
import com.google.common.io.Files
import org.scalatest.FunSuite
import org.apache.hadoop.io._
+import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodec, GzipCodec}
+
import SparkContext._
@@ -26,6 +28,28 @@ class FileSuite extends FunSuite with LocalSparkContext {
assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4"))
}
+ test("text files (compressed)") {
+ sc = new SparkContext("local", "test")
+ val tempDir = Files.createTempDir()
+ val normalDir = new File(tempDir, "output_normal").getAbsolutePath
+ val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath
+ val codec = new DefaultCodec()
+
+ val data = sc.parallelize("a" * 10000, 1)
+ data.saveAsTextFile(normalDir)
+ data.saveAsTextFile(compressedOutputDir, classOf[DefaultCodec])
+
+ val normalFile = new File(normalDir, "part-00000")
+ val normalContent = sc.textFile(normalDir).collect
+ assert(normalContent === Array.fill(10000)("a"))
+
+ val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension)
+ val compressedContent = sc.textFile(compressedOutputDir).collect
+ assert(compressedContent === Array.fill(10000)("a"))
+
+ assert(compressedFile.length < normalFile.length)
+ }
+
test("SequenceFiles") {
sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
@@ -37,6 +61,28 @@ class FileSuite extends FunSuite with LocalSparkContext {
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
}
+ test("SequenceFile (compressed)") {
+ sc = new SparkContext("local", "test")
+ val tempDir = Files.createTempDir()
+ val normalDir = new File(tempDir, "output_normal").getAbsolutePath
+ val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath
+ val codec = new DefaultCodec()
+
+ val data = sc.parallelize(Seq.fill(100)("abc"), 1).map(x => (x, x))
+ data.saveAsSequenceFile(normalDir)
+ data.saveAsSequenceFile(compressedOutputDir, Some(classOf[DefaultCodec]))
+
+ val normalFile = new File(normalDir, "part-00000")
+ val normalContent = sc.sequenceFile[String, String](normalDir).collect
+ assert(normalContent === Array.fill(100)("abc", "abc"))
+
+ val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension)
+ val compressedContent = sc.sequenceFile[String, String](compressedOutputDir).collect
+ assert(compressedContent === Array.fill(100)("abc", "abc"))
+
+ assert(compressedFile.length < normalFile.length)
+ }
+
test("SequenceFile with writable key") {
sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
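The compressed-output tests above rely on the saveAsTextFile(path, codec) and saveAsSequenceFile(path, Some(codec)) overloads this changeset introduces. A minimal standalone sketch of the same API with Gzip instead of DefaultCodec (the path and app name are hypothetical):

import org.apache.hadoop.io.compress.GzipCodec
import spark.SparkContext

val sc = new SparkContext("local", "compression-example")
val data = sc.parallelize(Seq("one", "two", "three"), 1)
data.saveAsTextFile("/tmp/out_gzip", classOf[GzipCodec])
// textFile decompresses transparently based on the file extension (.gz here).
assert(sc.textFile("/tmp/out_gzip").collect().toSet == Set("one", "two", "three"))
sc.stop()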
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index d3dcd3bbeb..d306124fca 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -8,6 +8,7 @@ import java.util.*;
import scala.Tuple2;
import com.google.common.base.Charsets;
+import org.apache.hadoop.io.compress.DefaultCodec;
import com.google.common.io.Files;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
@@ -474,6 +475,19 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void textFilesCompressed() throws IOException {
+ File tempDir = Files.createTempDir();
+ String outputDir = new File(tempDir, "output").getAbsolutePath();
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4));
+ rdd.saveAsTextFile(outputDir, DefaultCodec.class);
+
+ // Try reading it in as a text file RDD
+ List<String> expected = Arrays.asList("1", "2", "3", "4");
+ JavaRDD<String> readRDD = sc.textFile(outputDir);
+ Assert.assertEquals(expected, readRDD.collect());
+ }
+
+ @Test
public void sequenceFile() {
File tempDir = Files.createTempDir();
String outputDir = new File(tempDir, "output").getAbsolutePath();
@@ -620,6 +634,37 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void hadoopFileCompressed() {
+ File tempDir = Files.createTempDir();
+ String outputDir = new File(tempDir, "output_compressed").getAbsolutePath();
+ List<Tuple2<Integer, String>> pairs = Arrays.asList(
+ new Tuple2<Integer, String>(1, "a"),
+ new Tuple2<Integer, String>(2, "aa"),
+ new Tuple2<Integer, String>(3, "aaa")
+ );
+ JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs);
+
+ rdd.map(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
+ @Override
+ public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> pair) {
+ return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2()));
+ }
+ }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class,
+ DefaultCodec.class);
+
+ JavaPairRDD<IntWritable, Text> output = sc.hadoopFile(outputDir,
+ SequenceFileInputFormat.class, IntWritable.class, Text.class);
+
+ Assert.assertEquals(pairs.toString(), output.map(new Function<Tuple2<IntWritable, Text>,
+ String>() {
+ @Override
+ public String call(Tuple2<IntWritable, Text> x) {
+ return x.toString();
+ }
+ }).collect().toString());
+ }
+
+ @Test
public void zip() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() {
@@ -633,6 +678,32 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void zipPartitions() {
+ JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2);
+ JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("1", "2", "3", "4"), 2);
+ FlatMapFunction2<Iterator<Integer>, Iterator<String>, Integer> sizesFn =
+ new FlatMapFunction2<Iterator<Integer>, Iterator<String>, Integer>() {
+ @Override
+ public Iterable<Integer> call(Iterator<Integer> i, Iterator<String> s) {
+ int sizeI = 0;
+ int sizeS = 0;
+ while (i.hasNext()) {
+ sizeI += 1;
+ i.next();
+ }
+ while (s.hasNext()) {
+ sizeS += 1;
+ s.next();
+ }
+ return Arrays.asList(sizeI, sizeS);
+ }
+ };
+
+ JavaRDD<Integer> sizes = rdd1.zipPartitions(sizesFn, rdd2);
+ Assert.assertEquals("[3, 2, 3, 2]", sizes.collect().toString());
+ }
+
+ @Test
public void accumulators() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
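zipPartitions pairs corresponding partitions of two RDDs and hands you both iterators at once; the Java test above uses it to count elements per partition. The same computation in Scala, with the function-first argument order the Java API mirrors (a sketch against this changeset's API):

val rdd1 = sc.parallelize(1 to 6, 2)
val rdd2 = sc.parallelize(Seq("1", "2", "3", "4"), 2)
val sizes = rdd1.zipPartitions(
  (i: Iterator[Int], s: Iterator[String]) => Iterator(i.size, s.size),
  rdd2)
assert(sizes.collect().toList == List(3, 2, 3, 2))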
diff --git a/core/src/test/scala/spark/LocalSparkContext.scala b/core/src/test/scala/spark/LocalSparkContext.scala
index ff00dd05dd..76d5258b02 100644
--- a/core/src/test/scala/spark/LocalSparkContext.scala
+++ b/core/src/test/scala/spark/LocalSparkContext.scala
@@ -27,6 +27,7 @@ object LocalSparkContext {
sc.stop()
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
}
/** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */
@@ -38,4 +39,4 @@ object LocalSparkContext {
}
}
-} \ No newline at end of file
+}
diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
index 3abc584b6a..6e585e1c3a 100644
--- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
@@ -8,7 +8,7 @@ import spark.storage.BlockManagerId
import spark.util.AkkaUtils
class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
-
+
test("compressSize") {
assert(MapOutputTracker.compressSize(0L) === 0)
assert(MapOutputTracker.compressSize(1L) === 1)
@@ -45,13 +45,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
- tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000),
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
Array(compressedSize1000, compressedSize10000)))
- tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000),
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
Array(compressedSize10000, compressedSize1000)))
val statuses = tracker.getServerStatuses(10, 0)
- assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000),
- (BlockManagerId("b", "hostB", 1000), size10000)))
+ assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000, 0), size1000),
+ (BlockManagerId("b", "hostB", 1000, 0), size10000)))
tracker.stop()
}
@@ -64,14 +64,14 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
- tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000),
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
Array(compressedSize1000, compressedSize1000, compressedSize1000)))
- tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000),
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
Array(compressedSize10000, compressedSize1000, compressedSize1000)))
// As if we had two simultaneous fetch failures
- tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
- tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
+ tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
+ tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
// The remaining reduce task might try to grab the output despite the shuffle failure;
// this should cause it to fail, and the scheduler will ignore the failure due to the
@@ -80,16 +80,20 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
}
test("remote fetch") {
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", "localhost", 0)
+ val hostname = "localhost"
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0)
+ System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+
val masterTracker = new MapOutputTracker()
masterTracker.trackerActor = actorSystem.actorOf(
Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker")
-
- val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", "localhost", 0)
+
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0)
val slaveTracker = new MapOutputTracker()
slaveTracker.trackerActor = slaveSystem.actorFor(
"akka://spark@localhost:" + boundPort + "/user/MapOutputTracker")
-
+
masterTracker.registerShuffle(10, 1)
masterTracker.incrementGeneration()
slaveTracker.updateGeneration(masterTracker.getGeneration)
@@ -98,13 +102,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
masterTracker.registerMapOutput(10, 0, new MapStatus(
- BlockManagerId("a", "hostA", 1000), Array(compressedSize1000)))
+ BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000)))
masterTracker.incrementGeneration()
slaveTracker.updateGeneration(masterTracker.getGeneration)
assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
- Seq((BlockManagerId("a", "hostA", 1000), size1000)))
+ Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
- masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
+ masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
masterTracker.incrementGeneration()
slaveTracker.updateGeneration(masterTracker.getGeneration)
intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
diff --git a/core/src/test/scala/spark/PairRDDFunctionsSuite.scala b/core/src/test/scala/spark/PairRDDFunctionsSuite.scala
new file mode 100644
index 0000000000..682d2745bf
--- /dev/null
+++ b/core/src/test/scala/spark/PairRDDFunctionsSuite.scala
@@ -0,0 +1,287 @@
+package spark
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashSet
+
+import org.scalatest.FunSuite
+import org.scalatest.prop.Checkers
+import org.scalacheck.Arbitrary._
+import org.scalacheck.Gen
+import org.scalacheck.Prop._
+
+import com.google.common.io.Files
+
+import spark.rdd.ShuffledRDD
+import spark.SparkContext._
+
+class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
+ test("groupByKey") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
+ val groups = pairs.groupByKey().collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("groupByKey with duplicates") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val groups = pairs.groupByKey().collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("groupByKey with negative key hash codes") {
+ val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1)))
+ val groups = pairs.groupByKey().collect()
+ assert(groups.size === 2)
+ val valuesForMinus1 = groups.find(_._1 == -1).get._2
+ assert(valuesForMinus1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("groupByKey with many output partitions") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
+ val groups = pairs.groupByKey(10).collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("reduceByKey") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.reduceByKey(_+_).collect()
+ assert(sums.toSet === Set((1, 7), (2, 1)))
+ }
+
+ test("reduceByKey with collectAsMap") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.reduceByKey(_+_).collectAsMap()
+ assert(sums.size === 2)
+ assert(sums(1) === 7)
+ assert(sums(2) === 1)
+ }
+
+ test("reduceByKey with many output partitons") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.reduceByKey(_+_, 10).collect()
+ assert(sums.toSet === Set((1, 7), (2, 1)))
+ }
+
+ test("reduceByKey with partitioner") {
+ val p = new Partitioner() {
+ def numPartitions = 2
+ def getPartition(key: Any) = key.asInstanceOf[Int]
+ }
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p)
+ val sums = pairs.reduceByKey(_+_)
+ assert(sums.collect().toSet === Set((1, 4), (0, 1)))
+ assert(sums.partitioner === Some(p))
+ // count the dependencies to make sure there is only 1 ShuffledRDD
+ val deps = new HashSet[RDD[_]]()
+ def visit(r: RDD[_]) {
+ for (dep <- r.dependencies) {
+ deps += dep.rdd
+ visit(dep.rdd)
+ }
+ }
+ visit(sums)
+ assert(deps.size === 2) // ShuffledRDD, ParallelCollection
+ }
+
+ test("join") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.join(rdd2).collect()
+ assert(joined.size === 4)
+ assert(joined.toSet === Set(
+ (1, (1, 'x')),
+ (1, (2, 'x')),
+ (2, (1, 'y')),
+ (2, (1, 'z'))
+ ))
+ }
+
+ test("join all-to-all") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y')))
+ val joined = rdd1.join(rdd2).collect()
+ assert(joined.size === 6)
+ assert(joined.toSet === Set(
+ (1, (1, 'x')),
+ (1, (1, 'y')),
+ (1, (2, 'x')),
+ (1, (2, 'y')),
+ (1, (3, 'x')),
+ (1, (3, 'y'))
+ ))
+ }
+
+ test("leftOuterJoin") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.leftOuterJoin(rdd2).collect()
+ assert(joined.size === 5)
+ assert(joined.toSet === Set(
+ (1, (1, Some('x'))),
+ (1, (2, Some('x'))),
+ (2, (1, Some('y'))),
+ (2, (1, Some('z'))),
+ (3, (1, None))
+ ))
+ }
+
+ test("rightOuterJoin") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.rightOuterJoin(rdd2).collect()
+ assert(joined.size === 5)
+ assert(joined.toSet === Set(
+ (1, (Some(1), 'x')),
+ (1, (Some(2), 'x')),
+ (2, (Some(1), 'y')),
+ (2, (Some(1), 'z')),
+ (4, (None, 'w'))
+ ))
+ }
+
+ test("join with no matches") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
+ val joined = rdd1.join(rdd2).collect()
+ assert(joined.size === 0)
+ }
+
+ test("join with many output partitions") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.join(rdd2, 10).collect()
+ assert(joined.size === 4)
+ assert(joined.toSet === Set(
+ (1, (1, 'x')),
+ (1, (2, 'x')),
+ (2, (1, 'y')),
+ (2, (1, 'z'))
+ ))
+ }
+
+ test("groupWith") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.groupWith(rdd2).collect()
+ assert(joined.size === 4)
+ assert(joined.toSet === Set(
+ (1, (ArrayBuffer(1, 2), ArrayBuffer('x'))),
+ (2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))),
+ (3, (ArrayBuffer(1), ArrayBuffer())),
+ (4, (ArrayBuffer(), ArrayBuffer('w')))
+ ))
+ }
+
+ test("zero-partition RDD") {
+ val emptyDir = Files.createTempDir()
+ val file = sc.textFile(emptyDir.getAbsolutePath)
+ assert(file.partitions.size == 0)
+ assert(file.collect().toList === Nil)
+ // Test that a shuffle on the file works, because this used to be a bug
+ assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
+ }
+
+ test("keys and values") {
+ val rdd = sc.parallelize(Array((1, "a"), (2, "b")))
+ assert(rdd.keys.collect().toList === List(1, 2))
+ assert(rdd.values.collect().toList === List("a", "b"))
+ }
+
+ test("default partitioner uses partition size") {
+ // specify 2000 partitions
+ val a = sc.makeRDD(Array(1, 2, 3, 4), 2000)
+ // do a map, which loses the partitioner
+ val b = a.map(a => (a, (a * 2).toString))
+ // then a group by, and see we didn't revert to 2 partitions
+ val c = b.groupByKey()
+ assert(c.partitions.size === 2000)
+ }
+
+ test("default partitioner uses largest partitioner") {
+ val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2)
+ val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000)
+ val c = a.join(b)
+ assert(c.partitions.size === 2000)
+ }
+
+ test("subtract") {
+ val a = sc.parallelize(Array(1, 2, 3), 2)
+ val b = sc.parallelize(Array(2, 3, 4), 4)
+ val c = a.subtract(b)
+ assert(c.collect().toSet === Set(1))
+ assert(c.partitions.size === a.partitions.size)
+ }
+
+ test("subtract with narrow dependency") {
+ // use a deterministic partitioner
+ val p = new Partitioner() {
+ def numPartitions = 5
+ def getPartition(key: Any) = key.asInstanceOf[Int]
+ }
+ // partitionBy so we have a narrow dependency
+ val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
+ // more partitions/no partitioner so a shuffle dependency
+ val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
+ val c = a.subtract(b)
+ assert(c.collect().toSet === Set((1, "a"), (3, "c")))
+ // Ideally we could keep the original partitioner...
+ assert(c.partitioner === None)
+ }
+
+ test("subtractByKey") {
+ val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2)
+ val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4)
+ val c = a.subtractByKey(b)
+ assert(c.collect().toSet === Set((1, "a"), (1, "a")))
+ assert(c.partitions.size === a.partitions.size)
+ }
+
+ test("subtractByKey with narrow dependency") {
+ // use a deterministic partitioner
+ val p = new Partitioner() {
+ def numPartitions = 5
+ def getPartition(key: Any) = key.asInstanceOf[Int]
+ }
+ // partitionBy so we have a narrow dependency
+ val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
+ // more partitions/no partitioner so a shuffle dependency
+ val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
+ val c = a.subtractByKey(b)
+ assert(c.collect().toSet === Set((1, "a"), (1, "a")))
+ assert(c.partitioner.get === p)
+ }
+
+ test("foldByKey") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.foldByKey(0)(_+_).collect()
+ assert(sums.toSet === Set((1, 7), (2, 1)))
+ }
+
+ test("foldByKey with mutable result type") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val bufs = pairs.mapValues(v => ArrayBuffer(v)).cache()
+ // Fold the values using in-place mutation
+ val sums = bufs.foldByKey(new ArrayBuffer[Int])(_ ++= _).collect()
+ assert(sums.toSet === Set((1, ArrayBuffer(1, 2, 3, 1)), (2, ArrayBuffer(1))))
+ // Check that the mutable objects in the original RDD were not changed
+ assert(bufs.collect().toSet === Set(
+ (1, ArrayBuffer(1)),
+ (1, ArrayBuffer(2)),
+ (1, ArrayBuffer(3)),
+ (1, ArrayBuffer(1)),
+ (2, ArrayBuffer(1))))
+ }
+}
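The two "default partitioner" tests above pin down the selection rule: if any input already has a partitioner, reuse the one belonging to the input with the most partitions; otherwise hash-partition with as many partitions as the largest input. A sketch of that rule (not Spark's exact implementation):

def defaultPartitionerSketch(rdds: RDD[_]*): Partitioner = {
  // Consider inputs from most to fewest partitions.
  val bySize = rdds.sortBy(_.partitions.size).reverse
  bySize.flatMap(_.partitioner).headOption.getOrElse(
    new HashPartitioner(bySize.head.partitions.size))
}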
diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala
index 60db759c25..99e433e3bd 100644
--- a/core/src/test/scala/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/spark/PartitioningSuite.scala
@@ -1,13 +1,13 @@
package spark
import org.scalatest.FunSuite
-
import scala.collection.mutable.ArrayBuffer
-
import SparkContext._
+import spark.util.StatCounter
+import scala.math.abs
+
+class PartitioningSuite extends FunSuite with SharedSparkContext {
-class PartitioningSuite extends FunSuite with LocalSparkContext {
-
test("HashPartitioner equality") {
val p2 = new HashPartitioner(2)
val p4 = new HashPartitioner(4)
@@ -21,8 +21,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
}
test("RangePartitioner equality") {
- sc = new SparkContext("local", "test")
-
// Make an RDD where all the elements are the same so that the partition range bounds
// are deterministically all the same.
val rdd = sc.parallelize(Seq(1, 1, 1, 1)).map(x => (x, x))
@@ -50,7 +48,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
}
test("HashPartitioner not equal to RangePartitioner") {
- sc = new SparkContext("local", "test")
val rdd = sc.parallelize(1 to 10).map(x => (x, x))
val rangeP2 = new RangePartitioner(2, rdd)
val hashP2 = new HashPartitioner(2)
@@ -61,8 +58,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
}
test("partitioner preservation") {
- sc = new SparkContext("local", "test")
-
val rdd = sc.parallelize(1 to 10, 4).map(x => (x, x))
val grouped2 = rdd.groupByKey(2)
@@ -101,7 +96,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
}
test("partitioning Java arrays should fail") {
- sc = new SparkContext("local", "test")
val arrs: RDD[Array[Int]] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => Array(x))
val arrPairs: RDD[(Array[Int], Int)] =
sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x))
@@ -120,4 +114,20 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array"))
assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array"))
}
+
+ test("zero-length partitions should be correctly handled") {
+ // Create RDD with some consecutive empty partitions (including the "first" one)
+ val rdd: RDD[Double] = sc
+ .parallelize(Array(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8)
+ .filter(_ >= 0.0)
+
+ // Run the partitions, including the consecutive empty ones, through StatCounter
+ val stats: StatCounter = rdd.stats()
+ assert(abs(6.0 - stats.sum) < 0.01)
+ assert(abs(6.0 / 2 - rdd.mean) < 0.01)
+ assert(abs(1.0 - rdd.variance) < 0.01)
+ assert(abs(1.0 - rdd.stdev) < 0.01)
+
+ // Add other tests here for classes that should be able to handle empty partitions correctly
+ }
}
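The zero-length-partitions test passes because of the StatCounter.merge guard added earlier in this diff: merging when n == 0 now copies the other counter's fields instead of evaluating the combined-mean formula, whose denominator n + other.n is zero when both sides are empty. A tiny demonstration of the failure mode in plain Double arithmetic:

val n = 0L; val otherN = 0L
val mu = 0.0; val otherMu = 0.0
val naive = (mu * n + otherMu * otherN) / (n + otherN)
assert(naive.isNaN) // 0.0 / 0 is NaN, and NaN poisons every later merge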
diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala
index a6344edf8f..1c9ca50811 100644
--- a/core/src/test/scala/spark/PipedRDDSuite.scala
+++ b/core/src/test/scala/spark/PipedRDDSuite.scala
@@ -3,10 +3,9 @@ package spark
import org.scalatest.FunSuite
import SparkContext._
-class PipedRDDSuite extends FunSuite with LocalSparkContext {
-
+class PipedRDDSuite extends FunSuite with SharedSparkContext {
+
test("basic pipe") {
- sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe(Seq("cat"))
@@ -19,8 +18,45 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext {
assert(c(3) === "4")
}
+ test("advanced pipe") {
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ val bl = sc.broadcast(List("0"))
+
+ val piped = nums.pipe(Seq("cat"),
+ Map[String, String](),
+ (f: String => Unit) => {bl.value.map(f(_));f("\u0001")},
+ (i:Int, f: String=> Unit) => f(i + "_"))
+
+ val c = piped.collect()
+
+ assert(c.size === 8)
+ assert(c(0) === "0")
+ assert(c(1) === "\u0001")
+ assert(c(2) === "1_")
+ assert(c(3) === "2_")
+ assert(c(4) === "0")
+ assert(c(5) === "\u0001")
+ assert(c(6) === "3_")
+ assert(c(7) === "4_")
+
+ val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2)
+ val d = nums1.groupBy(str => str.split("\t")(0)).
+ pipe(Seq("cat"),
+ Map[String, String](),
+ (f: String => Unit) => { bl.value.foreach(f); f("\u0001") },
+ (i: Tuple2[String, Seq[String]], f: String => Unit) => { for (e <- i._2) f(e + "_") }).collect()
+ assert(d.size === 8)
+ assert(d(0) === "0")
+ assert(d(1) === "\u0001")
+ assert(d(2) === "b\t2_")
+ assert(d(3) === "b\t4_")
+ assert(d(4) === "0")
+ assert(d(5) === "\u0001")
+ assert(d(6) === "a\t1_")
+ assert(d(7) === "a\t3_")
+ }
+
test("pipe with env variable") {
- sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
val c = piped.collect()
@@ -30,7 +66,6 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext {
}
test("pipe with non-zero exit status") {
- sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe("cat nonexistent_file")
intercept[SparkException] {
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 7fbdd44340..d8db69b1c9 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -2,13 +2,14 @@ package spark
import scala.collection.mutable.HashMap
import org.scalatest.FunSuite
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.time.{Span, Millis}
import spark.SparkContext._
-import spark.rdd.{CoalescedRDD, CoGroupedRDD, PartitionPruningRDD, ShuffledRDD}
+import spark.rdd.{CoalescedRDD, CoGroupedRDD, EmptyRDD, PartitionPruningRDD, ShuffledRDD}
-class RDDSuite extends FunSuite with LocalSparkContext {
+class RDDSuite extends FunSuite with SharedSparkContext {
test("basic operations") {
- sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(nums.collect().toList === List(1, 2, 3, 4))
val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
@@ -44,7 +45,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("SparkContext.union") {
- sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(sc.union(nums).collect().toList === List(1, 2, 3, 4))
assert(sc.union(nums, nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4))
@@ -53,7 +53,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("aggregate") {
- sc = new SparkContext("local", "test")
val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3)))
type StringMap = HashMap[String, Int]
val emptyMap = new StringMap {
@@ -73,27 +72,7 @@ class RDDSuite extends FunSuite with LocalSparkContext {
assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
}
- test("basic checkpointing") {
- import java.io.File
- val checkpointDir = File.createTempFile("temp", "")
- checkpointDir.delete()
-
- sc = new SparkContext("local", "test")
- sc.setCheckpointDir(checkpointDir.toString)
- val parCollection = sc.makeRDD(1 to 4)
- val flatMappedRDD = parCollection.flatMap(x => 1 to x)
- flatMappedRDD.checkpoint()
- assert(flatMappedRDD.dependencies.head.rdd == parCollection)
- val result = flatMappedRDD.collect()
- Thread.sleep(1000)
- assert(flatMappedRDD.dependencies.head.rdd != parCollection)
- assert(flatMappedRDD.collect() === result)
-
- checkpointDir.deleteOnExit()
- }
-
test("basic caching") {
- sc = new SparkContext("local", "test")
val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
assert(rdd.collect().toList === List(1, 2, 3, 4))
assert(rdd.collect().toList === List(1, 2, 3, 4))
@@ -101,7 +80,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("caching with failures") {
- sc = new SparkContext("local", "test")
val onlySplit = new Partition { override def index: Int = 0 }
var shouldFail = true
val rdd = new RDD[Int](sc, Nil) {
@@ -123,38 +101,26 @@ class RDDSuite extends FunSuite with LocalSparkContext {
assert(rdd.collect().toList === List(1, 2, 3, 4))
}
- test("cogrouped RDDs") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.makeRDD(Array((1, "one"), (1, "another one"), (2, "two"), (3, "three")), 2)
- val rdd2 = sc.makeRDD(Array((1, "one1"), (1, "another one1"), (2, "two1")), 2)
-
- // Use cogroup function
- val cogrouped = rdd1.cogroup(rdd2).collectAsMap()
- assert(cogrouped(1) === (Seq("one", "another one"), Seq("one1", "another one1")))
- assert(cogrouped(2) === (Seq("two"), Seq("two1")))
- assert(cogrouped(3) === (Seq("three"), Seq()))
-
- // Construct CoGroupedRDD directly, with map side combine enabled
- val cogrouped1 = new CoGroupedRDD[Int](
- Seq(rdd1.asInstanceOf[RDD[(Int, Any)]], rdd2.asInstanceOf[RDD[(Int, Any)]]),
- new HashPartitioner(3),
- true).collectAsMap()
- assert(cogrouped1(1).toSeq === Seq(Seq("one", "another one"), Seq("one1", "another one1")))
- assert(cogrouped1(2).toSeq === Seq(Seq("two"), Seq("two1")))
- assert(cogrouped1(3).toSeq === Seq(Seq("three"), Seq()))
+ test("empty RDD") {
+ val empty = new EmptyRDD[Int](sc)
+ assert(empty.count === 0)
+ assert(empty.collect().size === 0)
- // Construct CoGroupedRDD directly, with map side combine disabled
- val cogrouped2 = new CoGroupedRDD[Int](
- Seq(rdd1.asInstanceOf[RDD[(Int, Any)]], rdd2.asInstanceOf[RDD[(Int, Any)]]),
- new HashPartitioner(3),
- false).collectAsMap()
- assert(cogrouped2(1).toSeq === Seq(Seq("one", "another one"), Seq("one1", "another one1")))
- assert(cogrouped2(2).toSeq === Seq(Seq("two"), Seq("two1")))
- assert(cogrouped2(3).toSeq === Seq(Seq("three"), Seq()))
+ val thrown = intercept[UnsupportedOperationException]{
+ empty.reduce(_+_)
+ }
+ assert(thrown.getMessage.contains("empty"))
+
+ val emptyKv = new EmptyRDD[(Int, Int)](sc)
+ val rdd = sc.parallelize(1 to 2, 2).map(x => (x, x))
+ assert(rdd.join(emptyKv).collect().size === 0)
+ assert(rdd.rightOuterJoin(emptyKv).collect().size === 0)
+ assert(rdd.leftOuterJoin(emptyKv).collect().size === 2)
+ assert(rdd.cogroup(emptyKv).collect().size === 2)
+ assert(rdd.union(emptyKv).collect().size === 2)
}
- test("coalesced RDDs") {
- sc = new SparkContext("local", "test")
+ test("cogrouped RDDs") {
val data = sc.parallelize(1 to 10, 10)
val coalesced1 = data.coalesce(2)
@@ -192,7 +158,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("zipped RDDs") {
- sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val zipped = nums.zip(nums.map(_ + 1.0))
assert(zipped.glom().map(_.toList).collect().toList ===
@@ -204,7 +169,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("partition pruning") {
- sc = new SparkContext("local", "test")
val data = sc.parallelize(1 to 10, 10)
// Note that split number starts from 0, so > 8 means only 10th partition left.
val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8)
@@ -216,7 +180,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
test("mapWith") {
import java.util.Random
- sc = new SparkContext("local", "test")
val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
val randoms = ones.mapWith(
(index: Int) => new Random(index + 42))
@@ -235,7 +198,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
test("flatMapWith") {
import java.util.Random
- sc = new SparkContext("local", "test")
val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
val randoms = ones.flatMapWith(
(index: Int) => new Random(index + 42))
@@ -257,7 +219,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
test("filterWith") {
import java.util.Random
- sc = new SparkContext("local", "test")
val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
val sample = ints.filterWith(
(index: Int) => new Random(index + 42))
@@ -273,4 +234,21 @@ class RDDSuite extends FunSuite with LocalSparkContext {
assert(sample.size === checkSample.size)
for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
}
+
+ test("top with predefined ordering") {
+ val nums = Array.range(1, 100000)
+ val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2)
+ val topK = ints.top(5)
+ assert(topK.size === 5)
+ assert(topK.sorted === nums.sorted.takeRight(5))
+ }
+
+ test("top with custom ordering") {
+ val words = Vector("a", "b", "c", "d")
+ implicit val ord = implicitly[Ordering[String]].reverse
+ val rdd = sc.makeRDD(words, 2)
+ val topK = rdd.top(2)
+ assert(topK.size === 2)
+ assert(topK.sorted === Array("b", "a"))
+ }
}
diff --git a/core/src/test/scala/spark/SharedSparkContext.scala b/core/src/test/scala/spark/SharedSparkContext.scala
new file mode 100644
index 0000000000..1da79f9824
--- /dev/null
+++ b/core/src/test/scala/spark/SharedSparkContext.scala
@@ -0,0 +1,25 @@
+package spark
+
+import org.scalatest.Suite
+import org.scalatest.BeforeAndAfterAll
+
+/** Shares a local `SparkContext` between all tests in a suite and closes it at the end */
+trait SharedSparkContext extends BeforeAndAfterAll { self: Suite =>
+
+ @transient private var _sc: SparkContext = _
+
+ def sc: SparkContext = _sc
+
+ override def beforeAll() {
+ _sc = new SparkContext("local", "test")
+ super.beforeAll()
+ }
+
+ override def afterAll() {
+ if (_sc != null) {
+ LocalSparkContext.stop(_sc)
+ _sc = null
+ }
+ super.afterAll()
+ }
+}
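Suites opt in by mixing the trait into FunSuite and using the shared sc directly (a hypothetical example):

import org.scalatest.FunSuite

class MySuite extends FunSuite with SharedSparkContext {
  test("uses the shared context") {
    assert(sc.parallelize(1 to 3).count() === 3)
  }
}

Compared with LocalSparkContext, where each test constructs and stops its own context, this trait trades per-test isolation for speed: one context per suite.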
diff --git a/core/src/test/scala/spark/ShuffleNettySuite.scala b/core/src/test/scala/spark/ShuffleNettySuite.scala
new file mode 100644
index 0000000000..bfaffa953e
--- /dev/null
+++ b/core/src/test/scala/spark/ShuffleNettySuite.scala
@@ -0,0 +1,17 @@
+package spark
+
+import org.scalatest.BeforeAndAfterAll
+
+
+class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll {
+
+ // This test suite should run all tests in ShuffleSuite with Netty shuffle mode.
+
+ override def beforeAll(configMap: Map[String, Any]) {
+ System.setProperty("spark.shuffle.use.netty", "true")
+ }
+
+ override def afterAll(configMap: Map[String, Any]) {
+ System.setProperty("spark.shuffle.use.netty", "false")
+ }
+}
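ShuffleNettySuite reruns every ShuffleSuite test with one system property flipped, which is a cheap way to cover a property-driven code path twice. The same pattern works for other toggles (a hypothetical example reusing the beforeAll/afterAll signatures shown above):

class ShuffleCompressionSuite extends ShuffleSuite with BeforeAndAfterAll {
  // Rerun all ShuffleSuite tests with shuffle compression disabled.
  override def beforeAll(configMap: Map[String, Any]) {
    System.setProperty("spark.shuffle.compress", "false")
  }
  override def afterAll(configMap: Map[String, Any]) {
    System.clearProperty("spark.shuffle.compress")
  }
}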
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 2b2a90defa..950218fa28 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -16,54 +16,9 @@ import spark.rdd.ShuffledRDD
import spark.SparkContext._
class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
-
- test("groupByKey") {
- sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
- val groups = pairs.groupByKey().collect()
- assert(groups.size === 2)
- val valuesFor1 = groups.find(_._1 == 1).get._2
- assert(valuesFor1.toList.sorted === List(1, 2, 3))
- val valuesFor2 = groups.find(_._1 == 2).get._2
- assert(valuesFor2.toList.sorted === List(1))
- }
-
- test("groupByKey with duplicates") {
- sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
- val groups = pairs.groupByKey().collect()
- assert(groups.size === 2)
- val valuesFor1 = groups.find(_._1 == 1).get._2
- assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
- val valuesFor2 = groups.find(_._1 == 2).get._2
- assert(valuesFor2.toList.sorted === List(1))
- }
-
- test("groupByKey with negative key hash codes") {
- sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1)))
- val groups = pairs.groupByKey().collect()
- assert(groups.size === 2)
- val valuesForMinus1 = groups.find(_._1 == -1).get._2
- assert(valuesForMinus1.toList.sorted === List(1, 2, 3))
- val valuesFor2 = groups.find(_._1 == 2).get._2
- assert(valuesFor2.toList.sorted === List(1))
- }
-
- test("groupByKey with many output partitions") {
- sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
- val groups = pairs.groupByKey(10).collect()
- assert(groups.size === 2)
- val valuesFor1 = groups.find(_._1 == 1).get._2
- assert(valuesFor1.toList.sorted === List(1, 2, 3))
- val valuesFor2 = groups.find(_._1 == 2).get._2
- assert(valuesFor2.toList.sorted === List(1))
- }
-
test("groupByKey with compression") {
try {
- System.setProperty("spark.blockManager.compress", "true")
+ System.setProperty("spark.shuffle.compress", "true")
sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 4)
val groups = pairs.groupByKey(4).collect()
@@ -77,239 +32,100 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
}
}
- test("reduceByKey") {
- sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
- val sums = pairs.reduceByKey(_+_).collect()
- assert(sums.toSet === Set((1, 7), (2, 1)))
- }
-
- test("reduceByKey with collectAsMap") {
- sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
- val sums = pairs.reduceByKey(_+_).collectAsMap()
- assert(sums.size === 2)
- assert(sums(1) === 7)
- assert(sums(2) === 1)
- }
+ test("shuffle non-zero block size") {
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+ val NUM_BLOCKS = 3
- test("reduceByKey with many output partitons") {
- sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
- val sums = pairs.reduceByKey(_+_, 10).collect()
- assert(sums.toSet === Set((1, 7), (2, 1)))
- }
-
- test("reduceByKey with partitioner") {
- sc = new SparkContext("local", "test")
- val p = new Partitioner() {
- def numPartitions = 2
- def getPartition(key: Any) = key.asInstanceOf[Int]
+ val a = sc.parallelize(1 to 10, 2)
+ val b = a.map { x =>
+ (x, new ShuffleSuite.NonJavaSerializableClass(x * 2))
}
- val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p)
- val sums = pairs.reduceByKey(_+_)
- assert(sums.collect().toSet === Set((1, 4), (0, 1)))
- assert(sums.partitioner === Some(p))
- // count the dependencies to make sure there is only 1 ShuffledRDD
- val deps = new HashSet[RDD[_]]()
- def visit(r: RDD[_]) {
- for (dep <- r.dependencies) {
- deps += dep.rdd
- visit(dep.rdd)
- }
+ // If the Kryo serializer is not used correctly, the shuffle would fail because the
+ // default Java serializer cannot handle the non-serializable class.
+ val c = new ShuffledRDD(b, new HashPartitioner(NUM_BLOCKS),
+ classOf[spark.KryoSerializer].getName)
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+
+ assert(c.count === 10)
+
+ // All blocks must have non-zero size
+ (0 until NUM_BLOCKS).foreach { id =>
+ val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
+ assert(statuses.forall(s => s._2 > 0))
}
- visit(sums)
- assert(deps.size === 2) // ShuffledRDD, ParallelCollection
}
- test("join") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
- val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
- val joined = rdd1.join(rdd2).collect()
- assert(joined.size === 4)
- assert(joined.toSet === Set(
- (1, (1, 'x')),
- (1, (2, 'x')),
- (2, (1, 'y')),
- (2, (1, 'z'))
- ))
+ test("shuffle serializer") {
+ // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+ val a = sc.parallelize(1 to 10, 2)
+ val b = a.map { x =>
+ (x, new ShuffleSuite.NonJavaSerializableClass(x * 2))
+ }
+ // If the Kryo serializer is not used correctly, the shuffle would fail because the
+ // default Java serializer cannot handle the non-serializable class.
+ val c = new ShuffledRDD(b, new HashPartitioner(3), classOf[spark.KryoSerializer].getName)
+ assert(c.count === 10)
}
- test("join all-to-all") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3)))
- val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y')))
- val joined = rdd1.join(rdd2).collect()
- assert(joined.size === 6)
- assert(joined.toSet === Set(
- (1, (1, 'x')),
- (1, (1, 'y')),
- (1, (2, 'x')),
- (1, (2, 'y')),
- (1, (3, 'x')),
- (1, (3, 'y'))
- ))
- }
+ test("zero sized blocks") {
+ // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
- test("leftOuterJoin") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
- val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
- val joined = rdd1.leftOuterJoin(rdd2).collect()
- assert(joined.size === 5)
- assert(joined.toSet === Set(
- (1, (1, Some('x'))),
- (1, (2, Some('x'))),
- (2, (1, Some('y'))),
- (2, (1, Some('z'))),
- (3, (1, None))
- ))
- }
+ // 4 keys spread across 10 partitions, so some shuffle blocks must be empty
+ val NUM_BLOCKS = 10
+ val a = sc.parallelize(1 to 4, NUM_BLOCKS)
+ val b = a.map(x => (x, x*2))
- test("rightOuterJoin") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
- val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
- val joined = rdd1.rightOuterJoin(rdd2).collect()
- assert(joined.size === 5)
- assert(joined.toSet === Set(
- (1, (Some(1), 'x')),
- (1, (Some(2), 'x')),
- (2, (Some(1), 'y')),
- (2, (Some(1), 'z')),
- (4, (None, 'w'))
- ))
- }
+ // NOTE: The default Java serializer does not create zero-sized blocks,
+ // so use Kryo here.
+ val c = new ShuffledRDD(b, new HashPartitioner(10), classOf[spark.KryoSerializer].getName)
- test("join with no matches") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
- val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
- val joined = rdd1.join(rdd2).collect()
- assert(joined.size === 0)
- }
-
- test("join with many output partitions") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
- val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
- val joined = rdd1.join(rdd2, 10).collect()
- assert(joined.size === 4)
- assert(joined.toSet === Set(
- (1, (1, 'x')),
- (1, (2, 'x')),
- (2, (1, 'y')),
- (2, (1, 'z'))
- ))
- }
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ assert(c.count === 4)
- test("groupWith") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
- val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
- val joined = rdd1.groupWith(rdd2).collect()
- assert(joined.size === 4)
- assert(joined.toSet === Set(
- (1, (ArrayBuffer(1, 2), ArrayBuffer('x'))),
- (2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))),
- (3, (ArrayBuffer(1), ArrayBuffer())),
- (4, (ArrayBuffer(), ArrayBuffer('w')))
- ))
- }
+ val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
+ val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
+ statuses.map(x => x._2)
+ }
+ val nonEmptyBlocks = blockSizes.filter(x => x > 0)
- test("zero-partition RDD") {
- sc = new SparkContext("local", "test")
- val emptyDir = Files.createTempDir()
- val file = sc.textFile(emptyDir.getAbsolutePath)
- assert(file.partitions.size == 0)
- assert(file.collect().toList === Nil)
- // Test that a shuffle on the file works, because this used to be a bug
- assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
+ // We should have at most 4 non-empty blocks, since there are only 4 records
+ assert(nonEmptyBlocks.size <= 4)
}
- test("keys and values") {
- sc = new SparkContext("local", "test")
- val rdd = sc.parallelize(Array((1, "a"), (2, "b")))
- assert(rdd.keys.collect().toList === List(1, 2))
- assert(rdd.values.collect().toList === List("a", "b"))
- }
+ test("zero sized blocks without kryo") {
+ // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
- test("default partitioner uses partition size") {
- sc = new SparkContext("local", "test")
- // specify 2000 partitions
- val a = sc.makeRDD(Array(1, 2, 3, 4), 2000)
- // do a map, which loses the partitioner
- val b = a.map(a => (a, (a * 2).toString))
- // then a group by, and see we didn't revert to 2 partitions
- val c = b.groupByKey()
- assert(c.partitions.size === 2000)
- }
+ // 4 keys spread across 10 partitions, so some shuffle blocks must be empty
+ val NUM_BLOCKS = 10
+ val a = sc.parallelize(1 to 4, NUM_BLOCKS)
+ val b = a.map(x => (x, x*2))
- test("default partitioner uses largest partitioner") {
- sc = new SparkContext("local", "test")
- val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2)
- val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000)
- val c = a.join(b)
- assert(c.partitions.size === 2000)
- }
+ // NOTE: The default Java serializer should create zero-sized blocks
+ val c = new ShuffledRDD(b, new HashPartitioner(10))
- test("subtract") {
- sc = new SparkContext("local", "test")
- val a = sc.parallelize(Array(1, 2, 3), 2)
- val b = sc.parallelize(Array(2, 3, 4), 4)
- val c = a.subtract(b)
- assert(c.collect().toSet === Set(1))
- assert(c.partitions.size === a.partitions.size)
- }
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ assert(c.count === 4)
- test("subtract with narrow dependency") {
- sc = new SparkContext("local", "test")
- // use a deterministic partitioner
- val p = new Partitioner() {
- def numPartitions = 5
- def getPartition(key: Any) = key.asInstanceOf[Int]
+ val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
+ val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
+ statuses.map(x => x._2)
}
- // partitionBy so we have a narrow dependency
- val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
- // more partitions/no partitioner so a shuffle dependency
- val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
- val c = a.subtract(b)
- assert(c.collect().toSet === Set((1, "a"), (3, "c")))
- // Ideally we could keep the original partitioner...
- assert(c.partitioner === None)
- }
-
- test("subtractByKey") {
- sc = new SparkContext("local", "test")
- val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2)
- val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4)
- val c = a.subtractByKey(b)
- assert(c.collect().toSet === Set((1, "a"), (1, "a")))
- assert(c.partitions.size === a.partitions.size)
- }
+ val nonEmptyBlocks = blockSizes.filter(x => x > 0)
- test("subtractByKey with narrow dependency") {
- sc = new SparkContext("local", "test")
- // use a deterministic partitioner
- val p = new Partitioner() {
- def numPartitions = 5
- def getPartition(key: Any) = key.asInstanceOf[Int]
- }
- // partitionBy so we have a narrow dependency
- val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
- // more partitions/no partitioner so a shuffle dependency
- val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
- val c = a.subtractByKey(b)
- assert(c.collect().toSet === Set((1, "a"), (1, "a")))
- assert(c.partitioner.get === p)
+ // We should have at most 4 non-empty blocks, since there are only 4 records
+ assert(nonEmptyBlocks.size <= 4)
}
-
}
object ShuffleSuite {
+
def mergeCombineException(x: Int, y: Int): Int = {
throw new SparkException("Exception for map-side combine.")
x + y
}
+
+ class NonJavaSerializableClass(val value: Int)
}
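The Kryo-based tests above hinge on a behavioral difference between the two serializers: Java serialization rejects any class that does not implement java.io.Serializable, while Kryo writes such classes out of the box. A standalone sketch of that difference using the Kryo 2 API (class and object names here are illustrative, not part of this patch):

    import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

    // Stand-in for ShuffleSuite.NonJavaSerializableClass.
    class PlainClass(val value: Int)

    object SerializerContrast {
      def main(args: Array[String]) {
        val javaOut = new ObjectOutputStream(new ByteArrayOutputStream())
        try {
          javaOut.writeObject(new PlainClass(1))  // fails: not java.io.Serializable
        } catch {
          case e: NotSerializableException => println("Java serialization failed: " + e)
        }
        // Kryo, which spark.KryoSerializer builds on, serializes the same class
        // without requiring the Serializable marker interface.
        val kryo = new com.esotericsoftware.kryo.Kryo()
        val kryoOut = new com.esotericsoftware.kryo.io.Output(new ByteArrayOutputStream())
        kryo.writeObject(kryoOut, new PlainClass(1))
        kryoOut.close()
      }
    }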
diff --git a/core/src/test/scala/spark/SizeEstimatorSuite.scala b/core/src/test/scala/spark/SizeEstimatorSuite.scala
index 9f3aa6628d..c385965c35 100644
--- a/core/src/test/scala/spark/SizeEstimatorSuite.scala
+++ b/core/src/test/scala/spark/SizeEstimatorSuite.scala
@@ -78,7 +78,6 @@ class SizeEstimatorSuite
// Arrays containing nulls should just have one pointer per element
expectResult(56)(SizeEstimator.estimate(new Array[String](10)))
expectResult(56)(SizeEstimator.estimate(new Array[AnyRef](10)))
-
// For object arrays with non-null elements, each object should take one pointer plus
// however many bytes that class takes. (Note that Array.fill calls the code in its
// second parameter separately for each object, so we get distinct objects.)
@@ -115,7 +114,6 @@ class SizeEstimatorSuite
expectResult(48)(SizeEstimator.estimate(DummyString("a")))
expectResult(48)(SizeEstimator.estimate(DummyString("ab")))
expectResult(56)(SizeEstimator.estimate(DummyString("abcdefgh")))
-
resetOrClear("os.arch", arch)
}
diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala
index 495f957e53..f7bf207c68 100644
--- a/core/src/test/scala/spark/SortingSuite.scala
+++ b/core/src/test/scala/spark/SortingSuite.scala
@@ -5,16 +5,14 @@ import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.ShouldMatchers
import SparkContext._
-class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers with Logging {
-
+class SortingSuite extends FunSuite with SharedSparkContext with ShouldMatchers with Logging {
+
test("sortByKey") {
- sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2)
- assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
+ assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
}
test("large array") {
- sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
@@ -24,7 +22,6 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
}
test("large array with one split") {
- sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
@@ -32,9 +29,8 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
assert(sorted.partitions.size === 1)
assert(sorted.collect() === pairArr.sortBy(_._1))
}
-
+
test("large array with many partitions") {
- sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
@@ -42,9 +38,8 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
assert(sorted.partitions.size === 20)
assert(sorted.collect() === pairArr.sortBy(_._1))
}
-
+
test("sort descending") {
- sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
@@ -52,15 +47,13 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
}
test("sort descending with one split") {
- sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 1)
assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
}
-
+
test("sort descending with many partitions") {
- sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
@@ -68,7 +61,6 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
}
test("more partitions than elements") {
- sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 30)
@@ -76,14 +68,12 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
}
test("empty RDD") {
- sc = new SparkContext("local", "test")
val pairArr = new Array[(Int, Int)](0)
val pairs = sc.parallelize(pairArr, 2)
assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
}
test("partition balancing") {
- sc = new SparkContext("local", "test")
val pairArr = (1 to 1000).map(x => (x, x)).toArray
val sorted = sc.parallelize(pairArr, 4).sortByKey()
assert(sorted.collect() === pairArr.sortBy(_._1))
@@ -99,7 +89,6 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
}
test("partition balancing for descending sort") {
- sc = new SparkContext("local", "test")
val pairArr = (1 to 1000).map(x => (x, x)).toArray
val sorted = sc.parallelize(pairArr, 4).sortByKey(false)
assert(sorted.collect() === pairArr.sortBy(_._1).reverse)
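SortingSuite now mixes in SharedSparkContext rather than constructing a SparkContext inside every test. A minimal version of such a trait, consistent with how it is used here (the actual trait is defined elsewhere in this commit; this sketch only illustrates the pattern):

    import org.scalatest.{BeforeAndAfterAll, Suite}
    import spark.SparkContext

    // One SparkContext shared by every test in the suite: created before the
    // first test runs and stopped after the last one finishes.
    trait SharedSparkContext extends BeforeAndAfterAll { self: Suite =>
      @transient private var _sc: SparkContext = _

      def sc: SparkContext = _sc

      override def beforeAll() {
        _sc = new SparkContext("local", "test")
        super.beforeAll()
      }

      override def afterAll() {
        if (_sc != null) {
          _sc.stop()
          _sc = null
        }
        System.clearProperty("spark.driver.port")
        super.afterAll()
      }
    }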
diff --git a/core/src/test/scala/spark/UnpersistSuite.scala b/core/src/test/scala/spark/UnpersistSuite.scala
new file mode 100644
index 0000000000..94776e7572
--- /dev/null
+++ b/core/src/test/scala/spark/UnpersistSuite.scala
@@ -0,0 +1,30 @@
+package spark
+
+import org.scalatest.FunSuite
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.time.{Span, Millis}
+import spark.SparkContext._
+
+class UnpersistSuite extends FunSuite with LocalSparkContext {
+ test("unpersist RDD") {
+ sc = new SparkContext("local", "test")
+ val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
+ rdd.count
+ assert(sc.persistentRdds.isEmpty === false)
+ rdd.unpersist()
+ assert(sc.persistentRdds.isEmpty === true)
+
+ failAfter(Span(3000, Millis)) {
+ try {
+ while (!sc.getRDDStorageInfo.isEmpty) {
+ Thread.sleep(200)
+ }
+ } catch {
+ // The block manager may be racing this thread to remove entries
+ // from the driver, so swallow any exception and pause briefly.
+ case _ => Thread.sleep(10)
+ }
+ }
+ assert(sc.getRDDStorageInfo.isEmpty === true)
+ }
+}
diff --git a/core/src/test/scala/spark/UtilsSuite.scala b/core/src/test/scala/spark/UtilsSuite.scala
index ed4701574f..4a113e16bf 100644
--- a/core/src/test/scala/spark/UtilsSuite.scala
+++ b/core/src/test/scala/spark/UtilsSuite.scala
@@ -27,24 +27,49 @@ class UtilsSuite extends FunSuite {
assert(os.toByteArray.toList.equals(bytes.toList))
}
- test("memoryStringToMb"){
- assert(Utils.memoryStringToMb("1") == 0)
- assert(Utils.memoryStringToMb("1048575") == 0)
- assert(Utils.memoryStringToMb("3145728") == 3)
+ test("memoryStringToMb") {
+ assert(Utils.memoryStringToMb("1") === 0)
+ assert(Utils.memoryStringToMb("1048575") === 0)
+ assert(Utils.memoryStringToMb("3145728") === 3)
- assert(Utils.memoryStringToMb("1024k") == 1)
- assert(Utils.memoryStringToMb("5000k") == 4)
- assert(Utils.memoryStringToMb("4024k") == Utils.memoryStringToMb("4024K"))
+ assert(Utils.memoryStringToMb("1024k") === 1)
+ assert(Utils.memoryStringToMb("5000k") === 4)
+ assert(Utils.memoryStringToMb("4024k") === Utils.memoryStringToMb("4024K"))
- assert(Utils.memoryStringToMb("1024m") == 1024)
- assert(Utils.memoryStringToMb("5000m") == 5000)
- assert(Utils.memoryStringToMb("4024m") == Utils.memoryStringToMb("4024M"))
+ assert(Utils.memoryStringToMb("1024m") === 1024)
+ assert(Utils.memoryStringToMb("5000m") === 5000)
+ assert(Utils.memoryStringToMb("4024m") === Utils.memoryStringToMb("4024M"))
- assert(Utils.memoryStringToMb("2g") == 2048)
- assert(Utils.memoryStringToMb("3g") == Utils.memoryStringToMb("3G"))
+ assert(Utils.memoryStringToMb("2g") === 2048)
+ assert(Utils.memoryStringToMb("3g") === Utils.memoryStringToMb("3G"))
- assert(Utils.memoryStringToMb("2t") == 2097152)
- assert(Utils.memoryStringToMb("3t") == Utils.memoryStringToMb("3T"))
+ assert(Utils.memoryStringToMb("2t") === 2097152)
+ assert(Utils.memoryStringToMb("3t") === Utils.memoryStringToMb("3T"))
+ }
+
+ test("splitCommandString") {
+ assert(Utils.splitCommandString("") === Seq())
+ assert(Utils.splitCommandString("a") === Seq("a"))
+ assert(Utils.splitCommandString("aaa") === Seq("aaa"))
+ assert(Utils.splitCommandString("a b c") === Seq("a", "b", "c"))
+ assert(Utils.splitCommandString(" a b\t c ") === Seq("a", "b", "c"))
+ assert(Utils.splitCommandString("a 'b c'") === Seq("a", "b c"))
+ assert(Utils.splitCommandString("a 'b c' d") === Seq("a", "b c", "d"))
+ assert(Utils.splitCommandString("'b c'") === Seq("b c"))
+ assert(Utils.splitCommandString("a \"b c\"") === Seq("a", "b c"))
+ assert(Utils.splitCommandString("a \"b c\" d") === Seq("a", "b c", "d"))
+ assert(Utils.splitCommandString("\"b c\"") === Seq("b c"))
+ assert(Utils.splitCommandString("a 'b\" c' \"d' e\"") === Seq("a", "b\" c", "d' e"))
+ assert(Utils.splitCommandString("a\t'b\nc'\nd") === Seq("a", "b\nc", "d"))
+ assert(Utils.splitCommandString("a \"b\\\\c\"") === Seq("a", "b\\c"))
+ assert(Utils.splitCommandString("a \"b\\\"c\"") === Seq("a", "b\"c"))
+ assert(Utils.splitCommandString("a 'b\\\"c'") === Seq("a", "b\\\"c"))
+ assert(Utils.splitCommandString("'a'b") === Seq("ab"))
+ assert(Utils.splitCommandString("'a''b'") === Seq("ab"))
+ assert(Utils.splitCommandString("\"a\"b") === Seq("ab"))
+ assert(Utils.splitCommandString("\"a\"\"b\"") === Seq("ab"))
+ assert(Utils.splitCommandString("''") === Seq(""))
+ assert(Utils.splitCommandString("\"\"") === Seq(""))
}
}
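The splitCommandString assertions pin down a small shell-like tokenizer: whitespace separates arguments, single quotes take everything literally, double quotes honor the \" and \\ escapes, and adjacent quoted pieces concatenate into one argument. A sketch that satisfies the assertions above (an illustration of the rules, not the implementation in Utils.scala):

    import scala.collection.mutable.ArrayBuffer

    def splitCommandString(s: String): Seq[String] = {
      val buf = new ArrayBuffer[String]
      val cur = new StringBuilder
      var inWord = false            // true once cur holds (possibly empty) content
      var i = 0
      while (i < s.length) {
        val c = s.charAt(i)
        if (c == '\'') {            // single quotes: everything literal until the next '
          inWord = true
          i += 1
          while (i < s.length && s.charAt(i) != '\'') { cur.append(s.charAt(i)); i += 1 }
          i += 1                    // skip the closing quote
        } else if (c == '"') {      // double quotes: backslash escapes the next char
          inWord = true
          i += 1
          while (i < s.length && s.charAt(i) != '"') {
            if (s.charAt(i) == '\\' && i + 1 < s.length) i += 1
            cur.append(s.charAt(i))
            i += 1
          }
          i += 1
        } else if (c.isWhitespace) { // unquoted whitespace ends the current argument
          if (inWord) { buf += cur.toString; cur.clear(); inWord = false }
          i += 1
        } else {
          inWord = true
          cur.append(c)
          i += 1
        }
      }
      if (inWord) buf += cur.toString
      buf
    }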
diff --git a/core/src/test/scala/spark/ZippedPartitionsSuite.scala b/core/src/test/scala/spark/ZippedPartitionsSuite.scala
new file mode 100644
index 0000000000..96cb295f45
--- /dev/null
+++ b/core/src/test/scala/spark/ZippedPartitionsSuite.scala
@@ -0,0 +1,33 @@
+package spark
+
+import scala.collection.immutable.NumericRange
+
+import org.scalatest.FunSuite
+import org.scalatest.prop.Checkers
+import org.scalacheck.Arbitrary._
+import org.scalacheck.Gen
+import org.scalacheck.Prop._
+
+import SparkContext._
+
+
+object ZippedPartitionsSuite {
+ def procZippedData(i: Iterator[Int], s: Iterator[String], d: Iterator[Double]) : Iterator[Int] = {
+ Iterator(i.toArray.size, s.toArray.size, d.toArray.size)
+ }
+}
+
+class ZippedPartitionsSuite extends FunSuite with SharedSparkContext {
+ test("print sizes") {
+ val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2)
+ val data3 = sc.makeRDD(Array(1.0, 2.0), 2)
+
+ val zippedRDD = data1.zipPartitions(ZippedPartitionsSuite.procZippedData, data2, data3)
+
+ val obtainedSizes = zippedRDD.collect()
+ val expectedSizes = Array(2, 3, 1, 2, 3, 1)
+ assert(obtainedSizes.size == 6)
+ assert(obtainedSizes.zip(expectedSizes).forall(x => x._1 == x._2))
+ }
+}
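On the expected sizes: each of the three RDDs has two partitions holding 2, 3, and 1 elements respectively, so procZippedData emits (2, 3, 1) for each partition, six values in total. zipPartitions pairs up the i-th partitions of its operands without a shuffle; a small usage sketch in the same style (names are illustrative):

    // Combine two RDDs partition-by-partition; both must have the same
    // number of partitions (here, 2).
    val left = sc.makeRDD(1 to 4, 2)
    val right = sc.makeRDD(Seq("a", "b", "c", "d"), 2)

    def pairUp(xs: Iterator[Int], ys: Iterator[String]): Iterator[(Int, String)] =
      xs.zip(ys)

    val pairs = left.zipPartitions(pairUp, right).collect()
    // pairs: Array((1,"a"), (2,"b"), (3,"c"), (4,"d"))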
diff --git a/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala
new file mode 100644
index 0000000000..6afb0fa9bc
--- /dev/null
+++ b/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala
@@ -0,0 +1,56 @@
+package spark
+
+import org.scalatest.{ BeforeAndAfter, FunSuite }
+import spark.SparkContext._
+import spark.rdd.JdbcRDD
+import java.sql._
+
+class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
+
+ before {
+ Class.forName("org.apache.derby.jdbc.EmbeddedDriver")
+ val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true")
+ try {
+ val create = conn.createStatement
+ create.execute("""
+ CREATE TABLE FOO(
+ ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1),
+ DATA INTEGER
+ )""")
+ create.close
+ val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)")
+ (1 to 100).foreach { i =>
+ insert.setInt(1, i * 2)
+ insert.executeUpdate
+ }
+ insert.close
+ } catch {
+ case e: SQLException if e.getSQLState == "X0Y32" =>
+ // table exists
+ } finally {
+ conn.close
+ }
+ }
+
+ test("basic functionality") {
+ sc = new SparkContext("local", "test")
+ val rdd = new JdbcRDD(
+ sc,
+ () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") },
+ "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?",
+ 1, 100, 3,
+ (r: ResultSet) => { r.getInt(1) } ).cache
+
+ assert(rdd.count === 100)
+ assert(rdd.reduce(_+_) === 10100)
+ }
+
+ after {
+ try {
+ DriverManager.getConnection("jdbc:derby:;shutdown=true")
+ } catch {
+ case se: SQLException if se.getSQLState == "XJ015" =>
+ // normal shutdown
+ }
+ }
+}
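In the JdbcRDD call above, the arguments (1, 100, 3) are the inclusive lower bound, the upper bound, and the number of partitions; each partition binds its own ID sub-range into the two '?' placeholders of the query. A sketch of how such bounds can be derived (an illustration of the idea, not the code in JdbcRDD.scala):

    // Split the inclusive ID range [lower, upper] into numPartitions
    // contiguous sub-ranges, one per partition.
    def partitionBounds(lower: Long, upper: Long, numPartitions: Int): Seq[(Long, Long)] = {
      val length = upper - lower + 1
      (0 until numPartitions).map { i =>
        val start = lower + (i * length) / numPartitions
        val end = lower + ((i + 1) * length) / numPartitions - 1
        (start, end)
      }
    }

    // partitionBounds(1, 100, 3) == Seq((1,33), (34,66), (67,100))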
diff --git a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala
new file mode 100644
index 0000000000..8e1ad27e14
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala
@@ -0,0 +1,250 @@
+package spark.scheduler
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+
+import spark._
+import spark.scheduler._
+import spark.scheduler.cluster._
+import scala.collection.mutable.ArrayBuffer
+
+import java.util.Properties
+
+class DummyTaskSetManager(
+ initPriority: Int,
+ initStageId: Int,
+ initNumTasks: Int,
+ clusterScheduler: ClusterScheduler,
+ taskSet: TaskSet)
+ extends ClusterTaskSetManager(clusterScheduler,taskSet) {
+
+ parent = null
+ weight = 1
+ minShare = 2
+ runningTasks = 0
+ priority = initPriority
+ stageId = initStageId
+ name = "TaskSet_"+stageId
+ override val numTasks = initNumTasks
+ tasksFinished = 0
+
+ override def increaseRunningTasks(taskNum: Int) {
+ runningTasks += taskNum
+ if (parent != null) {
+ parent.increaseRunningTasks(taskNum)
+ }
+ }
+
+ override def decreaseRunningTasks(taskNum: Int) {
+ runningTasks -= taskNum
+ if (parent != null) {
+ parent.decreaseRunningTasks(taskNum)
+ }
+ }
+
+ override def addSchedulable(schedulable: Schedulable) {
+ }
+
+ override def removeSchedulable(schedulable: Schedulable) {
+ }
+
+ override def getSchedulableByName(name: String): Schedulable = {
+ return null
+ }
+
+ override def executorLost(executorId: String, host: String): Unit = {
+ }
+
+ override def slaveOffer(execId: String, host: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
+ if (tasksFinished + runningTasks < numTasks) {
+ increaseRunningTasks(1)
+ return Some(new TaskDescription(0, execId, "task 0:0", null))
+ }
+ return None
+ }
+
+ override def checkSpeculatableTasks(): Boolean = {
+ return true
+ }
+
+ def taskFinished() {
+ decreaseRunningTasks(1)
+ tasksFinished += 1
+ if (tasksFinished == numTasks) {
+ parent.removeSchedulable(this)
+ }
+ }
+
+ def abort() {
+ decreaseRunningTasks(runningTasks)
+ parent.removeSchedulable(this)
+ }
+}
+
+class DummyTask(stageId: Int) extends Task[Int](stageId)
+{
+ def run(attemptId: Long): Int = {
+ return 0
+ }
+}
+
+class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging {
+
+ def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): DummyTaskSetManager = {
+ new DummyTaskSetManager(priority, stage, numTasks, cs, taskSet)
+ }
+
+ def resourceOffer(rootPool: Pool): Int = {
+ val taskSetQueue = rootPool.getSortedTaskSetQueue()
+ /* Just for the test: log the state of each manager in the queue. */
+ for (manager <- taskSetQueue) {
+ logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks))
+ }
+ for (taskSet <- taskSetQueue) {
+ taskSet.slaveOffer("execId_1", "hostname_1", 1) match {
+ case Some(task) =>
+ return taskSet.stageId
+ case None => {}
+ }
+ }
+ -1
+ }
+
+ def checkTaskSetId(rootPool: Pool, expectedTaskSetId: Int) {
+ assert(resourceOffer(rootPool) === expectedTaskSetId)
+ }
+
+ test("FIFO Scheduler Test") {
+ sc = new SparkContext("local", "ClusterSchedulerSuite")
+ val clusterScheduler = new ClusterScheduler(sc)
+ var tasks = ArrayBuffer[Task[_]]()
+ val task = new DummyTask(0)
+ tasks += task
+ val taskSet = new TaskSet(tasks.toArray,0,0,0,null)
+
+ val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0)
+ val schedulableBuilder = new FIFOSchedulableBuilder(rootPool)
+ schedulableBuilder.buildPools()
+
+ val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, clusterScheduler, taskSet)
+ val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, clusterScheduler, taskSet)
+ val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, clusterScheduler, taskSet)
+ schedulableBuilder.addTaskSetManager(taskSetManager0, null)
+ schedulableBuilder.addTaskSetManager(taskSetManager1, null)
+ schedulableBuilder.addTaskSetManager(taskSetManager2, null)
+
+ checkTaskSetId(rootPool, 0)
+ resourceOffer(rootPool)
+ checkTaskSetId(rootPool, 1)
+ resourceOffer(rootPool)
+ taskSetManager1.abort()
+ checkTaskSetId(rootPool, 2)
+ }
+
+ test("Fair Scheduler Test") {
+ sc = new SparkContext("local", "ClusterSchedulerSuite")
+ val clusterScheduler = new ClusterScheduler(sc)
+ var tasks = ArrayBuffer[Task[_]]()
+ val task = new DummyTask(0)
+ tasks += task
+ val taskSet = new TaskSet(tasks.toArray,0,0,0,null)
+
+ val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
+ System.setProperty("spark.fairscheduler.allocation.file", xmlPath)
+ val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
+ val schedulableBuilder = new FairSchedulableBuilder(rootPool)
+ schedulableBuilder.buildPools()
+
+ assert(rootPool.getSchedulableByName("default") != null)
+ assert(rootPool.getSchedulableByName("1") != null)
+ assert(rootPool.getSchedulableByName("2") != null)
+ assert(rootPool.getSchedulableByName("3") != null)
+ assert(rootPool.getSchedulableByName("1").minShare === 2)
+ assert(rootPool.getSchedulableByName("1").weight === 1)
+ assert(rootPool.getSchedulableByName("2").minShare === 3)
+ assert(rootPool.getSchedulableByName("2").weight === 1)
+ assert(rootPool.getSchedulableByName("3").minShare === 2)
+ assert(rootPool.getSchedulableByName("3").weight === 1)
+
+ val properties1 = new Properties()
+ properties1.setProperty("spark.scheduler.cluster.fair.pool","1")
+ val properties2 = new Properties()
+ properties2.setProperty("spark.scheduler.cluster.fair.pool","2")
+
+ val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, clusterScheduler, taskSet)
+ val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, clusterScheduler, taskSet)
+ val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, clusterScheduler, taskSet)
+ schedulableBuilder.addTaskSetManager(taskSetManager10, properties1)
+ schedulableBuilder.addTaskSetManager(taskSetManager11, properties1)
+ schedulableBuilder.addTaskSetManager(taskSetManager12, properties1)
+
+ val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, clusterScheduler, taskSet)
+ val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, clusterScheduler, taskSet)
+ schedulableBuilder.addTaskSetManager(taskSetManager23, properties2)
+ schedulableBuilder.addTaskSetManager(taskSetManager24, properties2)
+
+ checkTaskSetId(rootPool, 0)
+ checkTaskSetId(rootPool, 3)
+ checkTaskSetId(rootPool, 3)
+ checkTaskSetId(rootPool, 1)
+ checkTaskSetId(rootPool, 4)
+ checkTaskSetId(rootPool, 2)
+ checkTaskSetId(rootPool, 2)
+ checkTaskSetId(rootPool, 4)
+
+ taskSetManager12.taskFinished()
+ assert(rootPool.getSchedulableByName("1").runningTasks === 3)
+ taskSetManager24.abort()
+ assert(rootPool.getSchedulableByName("2").runningTasks === 2)
+ }
+
+ test("Nested Pool Test") {
+ sc = new SparkContext("local", "ClusterSchedulerSuite")
+ val clusterScheduler = new ClusterScheduler(sc)
+ var tasks = ArrayBuffer[Task[_]]()
+ val task = new DummyTask(0)
+ tasks += task
+ val taskSet = new TaskSet(tasks.toArray,0,0,0,null)
+
+ val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
+ val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1)
+ val pool1 = new Pool("1", SchedulingMode.FAIR, 4, 1)
+ rootPool.addSchedulable(pool0)
+ rootPool.addSchedulable(pool1)
+
+ val pool00 = new Pool("00", SchedulingMode.FAIR, 2, 2)
+ val pool01 = new Pool("01", SchedulingMode.FAIR, 1, 1)
+ pool0.addSchedulable(pool00)
+ pool0.addSchedulable(pool01)
+
+ val pool10 = new Pool("10", SchedulingMode.FAIR, 2, 2)
+ val pool11 = new Pool("11", SchedulingMode.FAIR, 2, 1)
+ pool1.addSchedulable(pool10)
+ pool1.addSchedulable(pool11)
+
+ val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, clusterScheduler, taskSet)
+ val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, clusterScheduler, taskSet)
+ pool00.addSchedulable(taskSetManager000)
+ pool00.addSchedulable(taskSetManager001)
+
+ val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, clusterScheduler, taskSet)
+ val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, clusterScheduler, taskSet)
+ pool01.addSchedulable(taskSetManager010)
+ pool01.addSchedulable(taskSetManager011)
+
+ val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, clusterScheduler, taskSet)
+ val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, clusterScheduler, taskSet)
+ pool10.addSchedulable(taskSetManager100)
+ pool10.addSchedulable(taskSetManager101)
+
+ val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, clusterScheduler, taskSet)
+ val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, clusterScheduler, taskSet)
+ pool11.addSchedulable(taskSetManager110)
+ pool11.addSchedulable(taskSetManager111)
+
+ checkTaskSetId(rootPool, 0)
+ checkTaskSetId(rootPool, 4)
+ checkTaskSetId(rootPool, 6)
+ checkTaskSetId(rootPool, 2)
+ }
+}
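The Fair Scheduler test loads pools "1", "2", and "3" from the fairscheduler.xml test resource (the "default" pool is created automatically). A configuration consistent with the minShare and weight assertions above would look roughly like this reconstruction; the real file ships with the test resources, and the schedulingMode entries are an assumption:

    <?xml version="1.0"?>
    <allocations>
      <pool name="1">
        <minShare>2</minShare>
        <weight>1</weight>
        <schedulingMode>FIFO</schedulingMode>
      </pool>
      <pool name="2">
        <minShare>3</minShare>
        <weight>1</weight>
        <schedulingMode>FIFO</schedulingMode>
      </pool>
      <pool name="3">
        <minShare>2</minShare>
        <weight>1</weight>
        <schedulingMode>FIFO</schedulingMode>
      </pool>
    </allocations>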
diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
index 6da58a0f6e..30e6fef950 100644
--- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
@@ -44,7 +44,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
override def submitTasks(taskSet: TaskSet) = {
// normally done by TaskSetManager
taskSet.tasks.foreach(_.generation = mapOutputTracker.getGeneration)
- taskSets += taskSet
+ taskSets += taskSet
}
override def setListener(listener: TaskSchedulerListener) = {}
override def defaultParallelism() = 2
@@ -164,7 +164,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
}
}
}
-
+
/** Sends the rdd to the scheduler for scheduling. */
private def submit(
rdd: RDD[_],
@@ -174,7 +174,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
listener: JobListener = listener) {
runEvent(JobSubmitted(rdd, func, partitions, allowLocal, null, listener))
}
-
+
/** Sends TaskSetFailed to the scheduler. */
private def failed(taskSet: TaskSet, message: String) {
runEvent(TaskSetFailed(taskSet, message))
@@ -209,11 +209,11 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener))
assert(results === Map(0 -> 42))
}
-
+
test("run trivial job w/ dependency") {
val baseRdd = makeRdd(1, Nil)
val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
- submit(finalRdd, Array(0))
+ submit(finalRdd, Array(0))
complete(taskSets(0), Seq((Success, 42)))
assert(results === Map(0 -> 42))
}
@@ -250,7 +250,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
complete(taskSets(1), Seq((Success, 42)))
assert(results === Map(0 -> 42))
}
-
+
test("run trivial shuffle with fetch failure") {
val shuffleMapRdd = makeRdd(2, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
@@ -271,7 +271,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
// have the 2nd attempt pass
complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
// we can see both result blocks now
- assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.ip) === Array("hostA", "hostB"))
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB"))
complete(taskSets(3), Seq((Success, 43)))
assert(results === Map(0 -> 42, 1 -> 43))
}
@@ -385,12 +385,12 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
assert(results === Map(0 -> 42))
}
- /** Assert that the supplied TaskSet has exactly the given preferredLocations. */
+ /** Assert that the supplied TaskSet has exactly the given preferredLocations. Note: the taskSet's locations are compared by host only. */
private def assertLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) {
assert(locations.size === taskSet.tasks.size)
for ((expectLocs, taskLocs) <-
taskSet.tasks.map(_.preferredLocations).zip(locations)) {
- assert(expectLocs === taskLocs)
+ assert(expectLocs.map(loc => spark.Utils.parseHostPort(loc)._1) === taskLocs)
}
}
@@ -398,6 +398,6 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))
private def makeBlockManagerId(host: String): BlockManagerId =
- BlockManagerId("exec-" + host, host, 12345)
+ BlockManagerId("exec-" + host, host, 12345, 0)
}
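The changed assertion compares preferred locations by host only, via spark.Utils.parseHostPort, since locations now carry a "host:port" form. A minimal sketch of such a helper under that assumption (the real implementation lives in Utils.scala):

    // Split a "host:port" string into its parts; a bare "host" maps to port 0.
    def parseHostPort(hostPort: String): (String, Int) = {
      val idx = hostPort.lastIndexOf(':')
      if (idx < 0) (hostPort, 0)
      else (hostPort.substring(0, idx), hostPort.substring(idx + 1).toInt)
    }

    // parseHostPort("hostA:12345")._1 == "hostA"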
diff --git a/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala
new file mode 100644
index 0000000000..699901f1a1
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala
@@ -0,0 +1,104 @@
+package spark.scheduler
+
+import java.util.Properties
+import java.util.concurrent.LinkedBlockingQueue
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import scala.collection.mutable
+import spark._
+import spark.SparkContext._
+
+
+class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+
+ test("inner method") {
+ sc = new SparkContext("local", "joblogger")
+ val joblogger = new JobLogger {
+ def createLogWriterTest(jobID: Int) = createLogWriter(jobID)
+ def closeLogWriterTest(jobID: Int) = closeLogWriter(jobID)
+ def getRddNameTest(rdd: RDD[_]) = getRddName(rdd)
+ def buildJobDepTest(jobID: Int, stage: Stage) = buildJobDep(jobID, stage)
+ }
+ type MyRDD = RDD[(Int, Int)]
+ def makeRdd(
+ numPartitions: Int,
+ dependencies: List[Dependency[_]]
+ ): MyRDD = {
+ val maxPartition = numPartitions - 1
+ return new MyRDD(sc, dependencies) {
+ override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+ throw new RuntimeException("should not be reached")
+ override def getPartitions = (0 to maxPartition).map(i => new Partition {
+ override def index = i
+ }).toArray
+ }
+ }
+ val jobID = 5
+ val parentRdd = makeRdd(4, Nil)
+ val shuffleDep = new ShuffleDependency(parentRdd, null)
+ val rootRdd = makeRdd(4, List(shuffleDep))
+ val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID)
+ val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID)
+
+ joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4))
+ joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName)
+ parentRdd.setName("MyRDD")
+ joblogger.getRddNameTest(parentRdd) should be ("MyRDD")
+ joblogger.createLogWriterTest(jobID)
+ joblogger.getJobIDtoPrintWriter.size should be (1)
+ joblogger.buildJobDepTest(jobID, rootStage)
+ joblogger.getJobIDToStages.get(jobID).get.size should be (2)
+ joblogger.getStageIDToJobID.get(0) should be (Some(jobID))
+ joblogger.getStageIDToJobID.get(1) should be (Some(jobID))
+ joblogger.closeLogWriterTest(jobID)
+ joblogger.getStageIDToJobID.size should be (0)
+ joblogger.getJobIDToStages.size should be (0)
+ joblogger.getJobIDtoPrintWriter.size should be (0)
+ }
+
+ test("inner variables") {
+ sc = new SparkContext("local[4]", "joblogger")
+ val joblogger = new JobLogger {
+ override protected def closeLogWriter(jobID: Int) =
+ getJobIDtoPrintWriter.get(jobID).foreach { fileWriter =>
+ fileWriter.close()
+ }
+ }
+ sc.addSparkListener(joblogger)
+ val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
+ rdd.reduceByKey(_+_).collect()
+
+ joblogger.getLogDir should be ("/tmp/spark")
+ joblogger.getJobIDtoPrintWriter.size should be (1)
+ joblogger.getStageIDToJobID.size should be (2)
+ joblogger.getStageIDToJobID.get(0) should be (Some(0))
+ joblogger.getStageIDToJobID.get(1) should be (Some(0))
+ joblogger.getJobIDToStages.size should be (1)
+ }
+
+
+ test("interface functions") {
+ sc = new SparkContext("local[4]", "joblogger")
+ val joblogger = new JobLogger {
+ var onTaskEndCount = 0
+ var onJobEndCount = 0
+ var onJobStartCount = 0
+ var onStageCompletedCount = 0
+ var onStageSubmittedCount = 0
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = onTaskEndCount += 1
+ override def onJobEnd(jobEnd: SparkListenerJobEnd) = onJobEndCount += 1
+ override def onJobStart(jobStart: SparkListenerJobStart) = onJobStartCount += 1
+ override def onStageCompleted(stageCompleted: StageCompleted) = onStageCompletedCount += 1
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = onStageSubmittedCount += 1
+ }
+ sc.addSparkListener(joblogger)
+ val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
+ rdd.reduceByKey(_+_).collect()
+
+ joblogger.onJobStartCount should be (1)
+ joblogger.onJobEndCount should be (1)
+ joblogger.onTaskEndCount should be (8) // 4 partitions in each of the 2 stages
+ joblogger.onStageSubmittedCount should be (2)
+ joblogger.onStageCompletedCount should be (2)
+ }
+}
diff --git a/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala
new file mode 100644
index 0000000000..8bd813fd14
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala
@@ -0,0 +1,206 @@
+package spark.scheduler
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+
+import spark._
+import spark.scheduler._
+import spark.scheduler.cluster._
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ConcurrentMap, HashMap}
+import java.util.concurrent.Semaphore
+import java.util.concurrent.CountDownLatch
+import java.util.Properties
+
+class Lock() {
+ var finished = false
+ def jobWait() = {
+ synchronized {
+ while(!finished) {
+ this.wait()
+ }
+ }
+ }
+
+ def jobFinished() = {
+ synchronized {
+ finished = true
+ this.notifyAll()
+ }
+ }
+}
+
+object TaskThreadInfo {
+ val threadToLock = HashMap[Int, Lock]()
+ val threadToRunning = HashMap[Int, Boolean]()
+ val threadToStarted = HashMap[Int, CountDownLatch]()
+}
+
+/*
+ * 1. Each thread runs one job.
+ * 2. Each job contains one stage.
+ * 3. Each stage contains only one task.
+ * 4. Each launched task must be started in order (using threadToStarted) to make sure
+ * it gets a CPU core, and it waits to finish until the "Lock" is manually
+ * released, after which the cluster has free CPU cores again.
+ * 5. Each pending task must use "sleep" to make sure it has been added to the
+ * taskSetManager queue, so that it is scheduled once the cluster has free CPU cores.
+ */
+class LocalSchedulerSuite extends FunSuite with LocalSparkContext {
+
+ def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) {
+
+ TaskThreadInfo.threadToRunning(threadIndex) = false
+ val nums = sc.parallelize(threadIndex to threadIndex, 1)
+ TaskThreadInfo.threadToLock(threadIndex) = new Lock()
+ TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1)
+ new Thread {
+ if (poolName != null) {
+ sc.addLocalProperties("spark.scheduler.cluster.fair.pool",poolName)
+ }
+ override def run() {
+ val ans = nums.map(number => {
+ TaskThreadInfo.threadToRunning(number) = true
+ TaskThreadInfo.threadToStarted(number).countDown()
+ TaskThreadInfo.threadToLock(number).jobWait()
+ TaskThreadInfo.threadToRunning(number) = false
+ number
+ }).collect()
+ assert(ans.toList === List(threadIndex))
+ sem.release()
+ }
+ }.start()
+ }
+
+ test("Local FIFO scheduler end-to-end test") {
+ System.setProperty("spark.cluster.schedulingmode", "FIFO")
+ sc = new SparkContext("local[4]", "test")
+ val sem = new Semaphore(0)
+
+ createThread(1,null,sc,sem)
+ TaskThreadInfo.threadToStarted(1).await()
+ createThread(2,null,sc,sem)
+ TaskThreadInfo.threadToStarted(2).await()
+ createThread(3,null,sc,sem)
+ TaskThreadInfo.threadToStarted(3).await()
+ createThread(4,null,sc,sem)
+ TaskThreadInfo.threadToStarted(4).await()
+ // Threads 5 and 6 (whose stages are pending) must satisfy two conditions:
+ // 1. The stages (taskSetManagers) of the jobs in threads 5 and 6 must be added to the
+ // taskSetManager queue before TaskThreadInfo.threadToLock(1).jobFinished() executes.
+ // 2. The stage in thread 5 must have higher priority than the stage in thread 6.
+ // So we simply sleep 1s for each thread.
+ // TODO: any better solution?
+ createThread(5,null,sc,sem)
+ Thread.sleep(1000)
+ createThread(6,null,sc,sem)
+ Thread.sleep(1000)
+
+ assert(TaskThreadInfo.threadToRunning(1) === true)
+ assert(TaskThreadInfo.threadToRunning(2) === true)
+ assert(TaskThreadInfo.threadToRunning(3) === true)
+ assert(TaskThreadInfo.threadToRunning(4) === true)
+ assert(TaskThreadInfo.threadToRunning(5) === false)
+ assert(TaskThreadInfo.threadToRunning(6) === false)
+
+ TaskThreadInfo.threadToLock(1).jobFinished()
+ TaskThreadInfo.threadToStarted(5).await()
+
+ assert(TaskThreadInfo.threadToRunning(1) === false)
+ assert(TaskThreadInfo.threadToRunning(2) === true)
+ assert(TaskThreadInfo.threadToRunning(3) === true)
+ assert(TaskThreadInfo.threadToRunning(4) === true)
+ assert(TaskThreadInfo.threadToRunning(5) === true)
+ assert(TaskThreadInfo.threadToRunning(6) === false)
+
+ TaskThreadInfo.threadToLock(3).jobFinished()
+ TaskThreadInfo.threadToStarted(6).await()
+
+ assert(TaskThreadInfo.threadToRunning(1) === false)
+ assert(TaskThreadInfo.threadToRunning(2) === true)
+ assert(TaskThreadInfo.threadToRunning(3) === false)
+ assert(TaskThreadInfo.threadToRunning(4) === true)
+ assert(TaskThreadInfo.threadToRunning(5) === true)
+ assert(TaskThreadInfo.threadToRunning(6) === true)
+
+ TaskThreadInfo.threadToLock(2).jobFinished()
+ TaskThreadInfo.threadToLock(4).jobFinished()
+ TaskThreadInfo.threadToLock(5).jobFinished()
+ TaskThreadInfo.threadToLock(6).jobFinished()
+ sem.acquire(6)
+ }
+
+ test("Local fair scheduler end-to-end test") {
+ sc = new SparkContext("local[8]", "LocalSchedulerSuite")
+ val sem = new Semaphore(0)
+ System.setProperty("spark.cluster.schedulingmode", "FAIR")
+ val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
+ System.setProperty("spark.fairscheduler.allocation.file", xmlPath)
+
+ createThread(10,"1",sc,sem)
+ TaskThreadInfo.threadToStarted(10).await()
+ createThread(20,"2",sc,sem)
+ TaskThreadInfo.threadToStarted(20).await()
+ createThread(30,"3",sc,sem)
+ TaskThreadInfo.threadToStarted(30).await()
+
+ assert(TaskThreadInfo.threadToRunning(10) === true)
+ assert(TaskThreadInfo.threadToRunning(20) === true)
+ assert(TaskThreadInfo.threadToRunning(30) === true)
+
+ createThread(11,"1",sc,sem)
+ TaskThreadInfo.threadToStarted(11).await()
+ createThread(21,"2",sc,sem)
+ TaskThreadInfo.threadToStarted(21).await()
+ createThread(31,"3",sc,sem)
+ TaskThreadInfo.threadToStarted(31).await()
+
+ assert(TaskThreadInfo.threadToRunning(11) === true)
+ assert(TaskThreadInfo.threadToRunning(21) === true)
+ assert(TaskThreadInfo.threadToRunning(31) === true)
+
+ createThread(12,"1",sc,sem)
+ TaskThreadInfo.threadToStarted(12).await()
+ createThread(22,"2",sc,sem)
+ TaskThreadInfo.threadToStarted(22).await()
+ createThread(32,"3",sc,sem)
+
+ assert(TaskThreadInfo.threadToRunning(12) === true)
+ assert(TaskThreadInfo.threadToRunning(22) === true)
+ assert(TaskThreadInfo.threadToRunning(32) === false)
+
+ TaskThreadInfo.threadToLock(10).jobFinished()
+ TaskThreadInfo.threadToStarted(32).await()
+
+ assert(TaskThreadInfo.threadToRunning(32) === true)
+
+ // 1. As in the scenario above, sleep 1s so the stages of jobs 23 and 33 are added to the
+ // taskSetManager queue, and the cluster assigns a free CPU core to stage 23 once stage 11 finishes.
+ // 2. The priorities of 23 and 33 are meaningless here, since the fair scheduler is in use.
+ createThread(23,"2",sc,sem)
+ createThread(33,"3",sc,sem)
+ Thread.sleep(1000)
+
+ TaskThreadInfo.threadToLock(11).jobFinished()
+ TaskThreadInfo.threadToStarted(23).await()
+
+ assert(TaskThreadInfo.threadToRunning(23) === true)
+ assert(TaskThreadInfo.threadToRunning(33) === false)
+
+ TaskThreadInfo.threadToLock(12).jobFinished()
+ TaskThreadInfo.threadToStarted(33).await()
+
+ assert(TaskThreadInfo.threadToRunning(33) === true)
+
+ TaskThreadInfo.threadToLock(20).jobFinished()
+ TaskThreadInfo.threadToLock(21).jobFinished()
+ TaskThreadInfo.threadToLock(22).jobFinished()
+ TaskThreadInfo.threadToLock(23).jobFinished()
+ TaskThreadInfo.threadToLock(30).jobFinished()
+ TaskThreadInfo.threadToLock(31).jobFinished()
+ TaskThreadInfo.threadToLock(32).jobFinished()
+ TaskThreadInfo.threadToLock(33).jobFinished()
+
+ sem.acquire(11)
+ }
+}
diff --git a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
index 2f5af10e69..48aa67c543 100644
--- a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
@@ -57,7 +57,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
taskMetrics.shuffleReadMetrics should be ('defined)
val sm = taskMetrics.shuffleReadMetrics.get
sm.totalBlocksFetched should be > (0)
- sm.shuffleReadMillis should be > (0l)
sm.localBlocksFetched should be > (0)
sm.remoteBlocksFetched should be (0)
sm.remoteBytesRead should be (0l)
@@ -78,7 +77,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
class SaveStageInfo extends SparkListener {
val stageInfos = mutable.Buffer[StageInfo]()
- def onStageCompleted(stage: StageCompleted) {
+ override def onStageCompleted(stage: StageCompleted) {
stageInfos += stage.stageInfo
}
}
diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
index b8c0f6fb76..b9d5f9668e 100644
--- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
@@ -15,8 +15,10 @@ import org.scalatest.time.SpanSugar._
import spark.JavaSerializer
import spark.KryoSerializer
import spark.SizeEstimator
+import spark.util.AkkaUtils
import spark.util.ByteBufferInputStream
+
class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester {
var store: BlockManager = null
var store2: BlockManager = null
@@ -31,7 +33,11 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
val serializer = new KryoSerializer
before {
- actorSystem = ActorSystem("test")
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0)
+ this.actorSystem = actorSystem
+ System.setProperty("spark.driver.port", boundPort.toString)
+ System.setProperty("spark.hostPort", "localhost:" + boundPort)
+
master = new BlockManagerMaster(
actorSystem.actorOf(Props(new spark.storage.BlockManagerMasterActor(true))))
@@ -41,9 +47,14 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
oldHeartBeat = System.setProperty("spark.storage.disableBlockManagerHeartBeat", "true")
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
+ // Set spark.hostPort to an arbitrary value; the tests only require it to be defined.
+ System.setProperty("spark.hostPort", spark.Utils.localHostName() + ":" + 1111)
}
after {
+ System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
+
if (store != null) {
store.stop()
store = null
@@ -88,9 +99,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("BlockManagerId object caching") {
- val id1 = BlockManagerId("e1", "XXX", 1)
- val id2 = BlockManagerId("e1", "XXX", 1) // this should return the same object as id1
- val id3 = BlockManagerId("e1", "XXX", 2) // this should return a different object
+ val id1 = BlockManagerId("e1", "XXX", 1, 0)
+ val id2 = BlockManagerId("e1", "XXX", 1, 0) // this should return the same object as id1
+ val id3 = BlockManagerId("e1", "XXX", 2, 0) // this should return a different object
assert(id2 === id1, "id2 is not same as id1")
assert(id2.eq(id1), "id2 is not the same object as id1")
assert(id3 != id1, "id3 is same as id1")
@@ -113,7 +124,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
// Putting a1, a2 and a3 in memory and telling master only about a1 and a2
store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY)
- store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY, false)
+ store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY, tellMaster = false)
// Checking whether blocks are in memory
assert(store.getSingle("a1") != None, "a1 was not in store")
@@ -159,7 +170,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
// Putting a1, a2 and a3 in memory and telling master only about a1 and a2
store.putSingle("a1-to-remove", a1, StorageLevel.MEMORY_ONLY)
store.putSingle("a2-to-remove", a2, StorageLevel.MEMORY_ONLY)
- store.putSingle("a3-to-remove", a3, StorageLevel.MEMORY_ONLY, false)
+ store.putSingle("a3-to-remove", a3, StorageLevel.MEMORY_ONLY, tellMaster = false)
// Checking whether blocks are in memory and memory size
val memStatus = master.getMemoryStatus.head._2
@@ -198,6 +209,39 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
}
+ test("removing rdd") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
+ val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ val a3 = new Array[Byte](400)
+ // Putting a1, a2 and a3 in memory.
+ store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY)
+ store.putSingle("nonrddblock", a3, StorageLevel.MEMORY_ONLY)
+ master.removeRdd(0, blocking = false)
+
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("rdd_0_0") should be (None)
+ master.getLocations("rdd_0_0") should have size 0
+ }
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("rdd_0_1") should be (None)
+ master.getLocations("rdd_0_1") should have size 0
+ }
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("nonrddblock") should not be (None)
+ master.getLocations("nonrddblock") should have size (1)
+ }
+
+ store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY)
+ master.removeRdd(0, blocking = true)
+ store.getSingle("rdd_0_0") should be (None)
+ master.getLocations("rdd_0_0") should have size 0
+ store.getSingle("rdd_0_1") should be (None)
+ master.getLocations("rdd_0_1") should have size 0
+ }
+
test("reregistration on heart beat") {
val heartBeat = PrivateMethod[Unit]('heartBeat)
store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
@@ -226,7 +270,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
master.removeExecutor(store.blockManagerId.executorId)
assert(master.getLocations("a1").size == 0, "a1 was not removed from master")
- store.putSingle("a2", a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY)
store.waitForAsyncReregister()
assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master")
@@ -244,7 +288,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
master.removeExecutor(store.blockManagerId.executorId)
val t1 = new Thread {
override def run() {
- store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, true)
+ store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
}
}
val t2 = new Thread {
@@ -454,9 +498,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
val list1 = List(new Array[Byte](200), new Array[Byte](200))
val list2 = List(new Array[Byte](200), new Array[Byte](200))
val list3 = List(new Array[Byte](200), new Array[Byte](200))
- store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, true)
- store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY, true)
- store.put("list3", list3.iterator, StorageLevel.MEMORY_ONLY, true)
+ store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
+ store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
+ store.put("list3", list3.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
assert(store.get("list2") != None, "list2 was not in store")
assert(store.get("list2").get.size == 2)
assert(store.get("list3") != None, "list3 was not in store")
@@ -465,7 +509,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
assert(store.get("list2") != None, "list2 was not in store")
assert(store.get("list2").get.size == 2)
// At this point list2 was gotten last, so LRU will getSingle rid of list3
- store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, true)
+ store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
assert(store.get("list1") != None, "list1 was not in store")
assert(store.get("list1").get.size == 2)
assert(store.get("list2") != None, "list2 was not in store")
@@ -480,9 +524,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
val list3 = List(new Array[Byte](200), new Array[Byte](200))
val list4 = List(new Array[Byte](200), new Array[Byte](200))
// First store list1 and list2, both in memory, and list3, on disk only
- store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY_SER, true)
- store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY_SER, true)
- store.put("list3", list3.iterator, StorageLevel.DISK_ONLY, true)
+ store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true)
+ store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true)
+ store.put("list3", list3.iterator, StorageLevel.DISK_ONLY, tellMaster = true)
// At this point LRU should not kick in because list3 is only on disk
assert(store.get("list1") != None, "list2 was not in store")
assert(store.get("list1").get.size === 2)
@@ -497,7 +541,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
assert(store.get("list3") != None, "list1 was not in store")
assert(store.get("list3").get.size === 2)
// Now let's add in list4, which uses both disk and memory; list1 should drop out
- store.put("list4", list4.iterator, StorageLevel.MEMORY_AND_DISK_SER, true)
+ store.put("list4", list4.iterator, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)
assert(store.get("list1") === None, "list1 was in store")
assert(store.get("list2") != None, "list3 was not in store")
assert(store.get("list2").get.size === 2)
diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb
index d77e53963c..c10ae595de 100644
--- a/docs/_plugins/copy_api_dirs.rb
+++ b/docs/_plugins/copy_api_dirs.rb
@@ -18,7 +18,7 @@ if ENV['SKIP_API'] != '1'
# Copy over the scaladoc from each project into the docs directory.
# This directory will be copied over to _site when `jekyll` command is run.
projects.each do |project_name|
- source = "../" + project_name + "/target/scala-2.9.2/api"
+ source = "../" + project_name + "/target/scala-2.9.3/api"
dest = "api/" + project_name
puts "echo making directory " + dest
diff --git a/docs/configuration.md b/docs/configuration.md
index 17fdbf04d1..3266db7af1 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -22,26 +22,30 @@ the copy executable.
Inside `spark-env.sh`, you *must* set at least the following two variables:
-* `SCALA_HOME`, to point to your Scala installation.
+* `SCALA_HOME`, to point to your Scala installation, or `SCALA_LIBRARY_PATH` to point to the directory for Scala
+ library JARs (if you install Scala as a Debian or RPM package, there is no `SCALA_HOME`, but these libraries
+ are in a separate path, typically /usr/share/java; look for `scala-library.jar`).
* `MESOS_NATIVE_LIBRARY`, if you are [running on a Mesos cluster](running-on-mesos.html).
-In addition, there are four other variables that control execution. These can be set *either in `spark-env.sh`
-or in each job's driver program*, because they will automatically be propagated to workers from the driver.
-For a multi-user environment, we recommend setting the in the driver program instead of `spark-env.sh`, so
-that different user jobs can use different amounts of memory, JVM options, etc.
+In addition, there are four other variables that control execution. These should be set *in the environment that
+launches the job's driver program* instead of `spark-env.sh`, because they will be automatically propagated to
+workers. Setting these per-job instead of in `spark-env.sh` ensures that different jobs can have different settings
+for these variables.
-* `SPARK_MEM`, to set the amount of memory used per node (this should be in the same format as the
- JVM's -Xmx option, e.g. `300m` or `1g`)
* `SPARK_JAVA_OPTS`, to add JVM options. This includes any system properties that you'd like to pass with `-D`.
* `SPARK_CLASSPATH`, to add elements to Spark's classpath.
* `SPARK_LIBRARY_PATH`, to add search directories for native libraries.
+* `SPARK_MEM`, to set the amount of memory used per node. This should be in the same format as the
+ JVM's -Xmx option, e.g. `300m` or `1g`. Note that this option will soon be deprecated in favor of
+ the `spark.executor.memory` system property, so we recommend using that in new code.
-Note that if you do set these in `spark-env.sh`, they will override the values set by user programs, which
-is undesirable; you can choose to have `spark-env.sh` set them only if the user program hasn't, as follows:
+Beware that if you do set these variables in `spark-env.sh`, they will override the values set by user programs,
+which is undesirable; if you prefer, you can choose to have `spark-env.sh` set them only if the user program
+hasn't, as follows:
{% highlight bash %}
-if [ -z "$SPARK_MEM" ] ; then
- SPARK_MEM="1g"
+if [ -z "$SPARK_JAVA_OPTS" ] ; then
+ SPARK_JAVA_OPTS="-verbose:gc"
fi
{% endhighlight %}
@@ -55,11 +59,18 @@ val sc = new SparkContext(...)
{% endhighlight %}
Most of the configurable system properties control internal settings that have reasonable default values. However,
-there are at least four properties that you will commonly want to control:
+there are at least five properties that you will commonly want to control:
<table class="table">
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
<tr>
+ <td>spark.executor.memory</td>
+ <td>512m</td>
+ <td>
+ Amount of memory to use per executor process, in the same format as JVM memory strings (e.g. `512m`, `2g`).
+ </td>
+</tr>
+<tr>
<td>spark.serializer</td>
<td>spark.JavaSerializer</td>
<td>
@@ -260,6 +271,13 @@ Apart from these, the following properties are also available, and may be useful
applications). Note that any RDD that persists in memory for more than this duration will be cleared as well.
</td>
</tr>
+<tr>
+ <td>spark.streaming.blockInterval</td>
+ <td>200</td>
+ <td>
+    Duration (in milliseconds) over which to batch new objects coming from network receivers.
+ </td>
+</tr>
</table>
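+
+As a sketch of how these are set in practice, system properties can be assigned in the driver program
+before the SparkContext is created (the `spark.KryoSerializer` value below is an illustrative
+alternative to the default serializer):
+
+{% highlight scala %}
+// Set configuration system properties before constructing the SparkContext.
+System.setProperty("spark.executor.memory", "1g")
+System.setProperty("spark.serializer", "spark.KryoSerializer")
+val sc = new SparkContext("local", "MyApp")
+{% endhighlight %}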
diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md
index dc57035eba..eab8a0ff20 100644
--- a/docs/ec2-scripts.md
+++ b/docs/ec2-scripts.md
@@ -106,9 +106,8 @@ permissions on your private key file, you can run `launch` with the
# Configuration
You can edit `/root/spark/conf/spark-env.sh` on each machine to set Spark configuration options, such
-as JVM options and, most crucially, the amount of memory to use per machine (`SPARK_MEM`).
-This file needs to be copied to **every machine** to reflect the change. The easiest way to do this
-is to use a script we provide called `copy-dir`. First edit your `spark-env.sh` file on the master,
+as JVM options. This file needs to be copied to **every machine** to reflect the change. The easiest way to
+do this is to use a script we provide called `copy-dir`. First edit your `spark-env.sh` file on the master,
then run `~/spark-ec2/copy-dir /root/spark/conf` to RSYNC it to all the workers.
The [configuration guide](configuration.html) describes the available configuration options.
diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md
index 3a7a8db4a6..e8aaac74d0 100644
--- a/docs/python-programming-guide.md
+++ b/docs/python-programming-guide.md
@@ -17,24 +17,23 @@ There are a few key differences between the Python and Scala APIs:
* Python is dynamically typed, so RDDs can hold objects of different types.
* PySpark does not currently support the following Spark features:
- Special functions on RDDs of doubles, such as `mean` and `stdev`
- - `lookup`
+ - `lookup`, `sample` and `sort`
- `persist` at storage levels other than `MEMORY_ONLY`
- - `sample`
- - `sort`
+ - Execution on Windows -- this is slated for a future release
In PySpark, RDDs support the same methods as their Scala counterparts but take Python functions and return Python collection types.
Short functions can be passed to RDD methods using Python's [`lambda`](http://www.diveintopython.net/power_of_introspection/lambda_functions.html) syntax:
{% highlight python %}
logData = sc.textFile(logFile).cache()
-errors = logData.filter(lambda s: 'ERROR' in s.split())
+errors = logData.filter(lambda line: "ERROR" in line)
{% endhighlight %}
You can also pass functions that are defined using the `def` keyword; this is useful for more complicated functions that cannot be expressed using `lambda`:
{% highlight python %}
def is_error(line):
- return 'ERROR' in line.split()
+ return "ERROR" in line
errors = logData.filter(is_error)
{% endhighlight %}
@@ -43,8 +42,7 @@ Functions can access objects in enclosing scopes, although modifications to thos
{% highlight python %}
error_keywords = ["Exception", "Error"]
def is_error(line):
- words = line.split()
- return any(keyword in words for keyword in error_keywords)
+ return any(keyword in line for keyword in error_keywords)
errors = logData.filter(is_error)
{% endhighlight %}
diff --git a/docs/quick-start.md b/docs/quick-start.md
index 2d961b29cb..335643536a 100644
--- a/docs/quick-start.md
+++ b/docs/quick-start.md
@@ -113,8 +113,8 @@ import SparkContext._
object SimpleJob {
def main(args: Array[String]) {
- val logFile = "/var/log/syslog" // Should be some file on your system
- val sc = new SparkContext("local", "Simple Job", "$YOUR_SPARK_HOME",
+ val logFile = "$YOUR_SPARK_HOME/README.md" // Should be some file on your system
+    val sc = new SparkContext("local", "Simple Job", "$YOUR_SPARK_HOME",
List("target/scala-{{site.SCALA_VERSION}}/simple-project_{{site.SCALA_VERSION}}-1.0.jar"))
val logData = sc.textFile(logFile, 2).cache()
val numAs = logData.filter(line => line.contains("a")).count()
@@ -124,7 +124,7 @@ object SimpleJob {
}
{% endhighlight %}
-This job simply counts the number of lines containing 'a' and the number containing 'b' in a system log file. Unlike the earlier examples with the Spark shell, which initializes its own SparkContext, we initialize a SparkContext as part of the job. We pass the SparkContext constructor four arguments, the type of scheduler we want to use (in this case, a local scheduler), a name for the job, the directory where Spark is installed, and a name for the jar file containing the job's sources. The final two arguments are needed in a distributed setting, where Spark is running across several nodes, so we include them for completeness. Spark will automatically ship the jar files you list to slave nodes.
+This job simply counts the number of lines containing 'a' and the number containing 'b' in the Spark README. Note that you'll need to replace $YOUR_SPARK_HOME with the location where Spark is installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkContext, we initialize a SparkContext as part of the job. We pass the SparkContext constructor four arguments, the type of scheduler we want to use (in this case, a local scheduler), a name for the job, the directory where Spark is installed, and a name for the jar file containing the job's sources. The final two arguments are needed in a distributed setting, where Spark is running across several nodes, so we include them for completeness. Spark will automatically ship the jar files you list to slave nodes.
This file depends on the Spark API, so we'll also include an sbt configuration file, `simple.sbt` which explains that Spark is a dependency. This file also adds two repositories which host Spark dependencies:
@@ -156,7 +156,7 @@ $ find .
$ sbt package
$ sbt run
...
-Lines with a: 8422, Lines with b: 1836
+Lines with a: 46, Lines with b: 23
{% endhighlight %}
This example only runs the job locally; for a tutorial on running jobs across several machines, see the [Standalone Mode](spark-standalone.html) documentation, and consider using a distributed input source, such as HDFS.
@@ -173,7 +173,7 @@ import spark.api.java.function.Function;
public class SimpleJob {
public static void main(String[] args) {
- String logFile = "/var/log/syslog"; // Should be some file on your system
+ String logFile = "$YOUR_SPARK_HOME/README.md"; // Should be some file on your system
JavaSparkContext sc = new JavaSparkContext("local", "Simple Job",
"$YOUR_SPARK_HOME", new String[]{"target/simple-project-1.0.jar"});
JavaRDD<String> logData = sc.textFile(logFile).cache();
@@ -191,7 +191,7 @@ public class SimpleJob {
}
{% endhighlight %}
-This job simply counts the number of lines containing 'a' and the number containing 'b' in a system log file. Note that like in the Scala example, we initialize a SparkContext, though we use the special `JavaSparkContext` class to get a Java-friendly one. We also create RDDs (represented by `JavaRDD`) and run transformations on them. Finally, we pass functions to Spark by creating classes that extend `spark.api.java.function.Function`. The [Java programming guide](java-programming-guide.html) describes these differences in more detail.
+This job simply counts the number of lines containing 'a' and the number containing 'b' in the Spark README. Note that you'll need to replace $YOUR_SPARK_HOME with the location where Spark is installed. As with the Scala example, we initialize a SparkContext, though we use the special `JavaSparkContext` class to get a Java-friendly one. We also create RDDs (represented by `JavaRDD`) and run transformations on them. Finally, we pass functions to Spark by creating classes that extend `spark.api.java.function.Function`. The [Java programming guide](java-programming-guide.html) describes these differences in more detail.
To build the job, we also write a Maven `pom.xml` file that lists Spark as a dependency. Note that Spark artifacts are tagged with a Scala version.
@@ -239,7 +239,7 @@ Now, we can execute the job using Maven:
$ mvn package
$ mvn exec:java -Dexec.mainClass="SimpleJob"
...
-Lines with a: 8422, Lines with b: 1836
+Lines with a: 46, Lines with b: 23
{% endhighlight %}
This example only runs the job locally; for a tutorial on running jobs across several machines, see the [Standalone Mode](spark-standalone.html) documentation, and consider using a distributed input source, such as HDFS.
@@ -253,7 +253,7 @@ As an example, we'll create a simple Spark job, `SimpleJob.py`:
"""SimpleJob.py"""
from pyspark import SparkContext
-logFile = "/var/log/syslog" # Should be some file on your system
+logFile = "$YOUR_SPARK_HOME/README.md" # Should be some file on your system
sc = SparkContext("local", "Simple job")
logData = sc.textFile(logFile).cache()
@@ -265,7 +265,8 @@ print "Lines with a: %i, lines with b: %i" % (numAs, numBs)
This job simply counts the number of lines containing 'a' and the number containing 'b' in the Spark README.
-Like in the Scala and Java examples, we use a SparkContext to create RDDs.
+Note that you'll need to replace $YOUR_SPARK_HOME with the location where Spark is installed.
+As with the Scala and Java examples, we use a SparkContext to create RDDs.
We can pass Python functions to Spark, which are automatically serialized along with any variables that they reference.
For jobs that use custom classes or third-party libraries, we can add those code dependencies to SparkContext to ensure that they will be available on remote machines; this is described in more detail in the [Python programming guide](python-programming-guide.html).
`SimpleJob` is simple enough that we do not need to specify any code dependencies.
@@ -276,7 +277,7 @@ We can run this job using the `pyspark` script:
$ cd $SPARK_HOME
$ ./pyspark SimpleJob.py
...
-Lines with a: 8422, Lines with b: 1836
+Lines with a: 46, Lines with b: 23
{% endhighlight python %}
This example only runs the job locally; for a tutorial on running jobs across several machines, see the [Standalone Mode](spark-standalone.html) documentation, and consider using a distributed input source, such as HDFS.
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index c2957e6cb4..66fb8d73e8 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -5,24 +5,54 @@ title: Launching Spark on YARN
Experimental support for running over a [YARN (Hadoop
NextGen)](http://hadoop.apache.org/docs/r2.0.2-alpha/hadoop-yarn/hadoop-yarn-site/YARN.html)
-cluster was added to Spark in version 0.6.0. Because YARN depends on version
-2.0 of the Hadoop libraries, this currently requires checking out a separate
-branch of Spark, called `yarn`, which you can do as follows:
+cluster was added to Spark in version 0.6.0, and was merged into the master branch as part of the 0.7 release effort.
+To build Spark core with YARN support, use the hadoop2-yarn profile.
+For example: `mvn -Phadoop2-yarn clean install`
- git clone git://github.com/mesos/spark
- cd spark
- git checkout -b yarn --track origin/yarn
+# Building a consolidated Spark core jar
+
+We need a consolidated Spark core jar (which bundles all the required dependencies) to run Spark jobs on a YARN cluster.
+This can be built either through sbt or through Maven.
+
+- Building the assembled Spark jar via sbt.
+  This is a manual process of enabling YARN support in project/SparkBuild.scala:
+  comment out the HADOOP_VERSION, HADOOP_MAJOR_VERSION and HADOOP_YARN
+  variable declarations that appear before the line 'For Hadoop 2 YARN support',
+  then uncomment the subsequent declarations of those same three variables,
+  which enable Hadoop YARN support, as sketched below.
+
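+A sketch of the resulting block in project/SparkBuild.scala (the version strings shown are those
+used at the time of writing):
+
+    // Default (comment these out):
+    // val HADOOP_VERSION = "1.0.4"
+    // val HADOOP_MAJOR_VERSION = "1"
+    // val HADOOP_YARN = false
+
+    // For Hadoop 2 YARN support (uncomment these):
+    val HADOOP_VERSION = "2.0.2-alpha"
+    val HADOOP_MAJOR_VERSION = "2"
+    val HADOOP_YARN = true
+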
+To assemble the jar, run:
+
+ ./sbt/sbt clean assembly
+
+The assembled jar will typically be something like:
+`./core/target/spark-core-assembly-0.8.0-SNAPSHOT.jar`
+
+
+- Building spark assembled jar via Maven.
+ Use the hadoop2-yarn profile and execute the package target.
+
+For example:
+
+ mvn -Phadoop2-yarn clean package -DskipTests=true
+
+
+This will build the shaded (consolidated) jar, typically something like:
+`./repl-bin/target/spark-repl-bin-<VERSION>-shaded-hadoop2-yarn.jar`
# Preparations
-- In order to distribute Spark within the cluster, it must be packaged into a single JAR file. This can be done by running `sbt/sbt assembly`
+- Build the assembled Spark core jar (see above).
- Your application code must be packaged into a separate JAR file.
If you want to test out the YARN deployment mode, you can use the current Spark examples. A `spark-examples_{{site.SCALA_VERSION}}-{{site.SPARK_VERSION}}` file can be generated by running `sbt/sbt package`. NOTE: since the documentation you're reading is for Spark version {{site.SPARK_VERSION}}, we are assuming here that you have downloaded Spark {{site.SPARK_VERSION}} or checked it out of source control. If you are using a different version of Spark, the version numbers in the jar generated by the sbt package command will obviously be different.
# Launching Spark on YARN
+Ensure that HADOOP_CONF_DIR or YARN_CONF_DIR points to the directory that contains the (client-side) configuration files for the Hadoop cluster.
+These are used to connect to the cluster, write to the DFS, and submit jobs to the ResourceManager.
+
The command to launch the YARN Client is as follows:
SPARK_JAR=<SPARK_JAR_FILE> ./run spark.deploy.yarn.Client \
@@ -30,22 +60,28 @@ The command to launch the YARN Client is as follows:
--class <APP_MAIN_CLASS> \
--args <APP_MAIN_ARGUMENTS> \
--num-workers <NUMBER_OF_WORKER_MACHINES> \
+ --master-memory <MEMORY_FOR_MASTER> \
--worker-memory <MEMORY_PER_WORKER> \
- --worker-cores <CORES_PER_WORKER>
+ --worker-cores <CORES_PER_WORKER> \
+ --user <hadoop_user> \
+ --queue <queue_name>
For example:
SPARK_JAR=./core/target/spark-core-assembly-{{site.SPARK_VERSION}}.jar ./run spark.deploy.yarn.Client \
--jar examples/target/scala-{{site.SCALA_VERSION}}/spark-examples_{{site.SCALA_VERSION}}-{{site.SPARK_VERSION}}.jar \
--class spark.examples.SparkPi \
- --args standalone \
+ --args yarn-standalone \
--num-workers 3 \
+ --master-memory 4g \
--worker-memory 2g \
- --worker-cores 2
+ --worker-cores 1
The above starts a YARN client program which periodically polls the Application Master for status updates and displays them in the console. The client will exit once your application has finished running.
# Important Notes
-- When your application instantiates a Spark context it must use a special "standalone" master url. This starts the scheduler without forcing it to connect to a cluster. A good way to handle this is to pass "standalone" as an argument to your program, as shown in the example above.
-- YARN does not support requesting container resources based on the number of cores. Thus the numbers of cores given via command line arguments cannot be guaranteed.
+- When your application instantiates a Spark context, it must use a special "yarn-standalone" master URL. This starts the scheduler without forcing it to connect to a cluster. A good way to handle this is to pass "yarn-standalone" as an argument to your program, as shown in the example above.
+- We do not request container resources based on the number of cores, so the number of cores given via command-line arguments cannot be guaranteed.
+- Currently, Spark has not yet been integrated with Hadoop security. If --user is present, the specified hadoop_user will be used to run the tasks on the cluster. If unspecified, the current user will be used (which should be valid on the cluster).
+  Once Hadoop security support is added, and if the Hadoop cluster has security enabled, additional restrictions will apply via the delegation tokens passed.
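+
+As a minimal sketch of the first note above (the `MyYarnApp` name and job body are illustrative),
+a program that accepts the master URL as its first argument, so that "yarn-standalone" can be
+passed to it, might look like:
+
+    import spark.SparkContext
+
+    object MyYarnApp {
+      def main(args: Array[String]) {
+        // args(0) is "yarn-standalone" when launched through spark.deploy.yarn.Client
+        val sc = new SparkContext(args(0), "MyYarnApp")
+        println(sc.parallelize(1 to 1000).count())
+      }
+    }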
diff --git a/docs/scala-programming-guide.md b/docs/scala-programming-guide.md
index 2315aadbdf..e9cf9ef36f 100644
--- a/docs/scala-programming-guide.md
+++ b/docs/scala-programming-guide.md
@@ -43,12 +43,18 @@ new SparkContext(master, appName, [sparkHome], [jars])
The `master` parameter is a string specifying a [Spark or Mesos cluster URL](#master-urls) to connect to, or a special "local" string to run in local mode, as described below. `appName` is a name for your application, which will be shown in the cluster web UI. Finally, the last two parameters are needed to deploy your code to a cluster if running in distributed mode, as described later.
-In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `MASTER` environment variable. For example, to run on four cores, use
+In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `MASTER` environment variable, and you can add JARs to the classpath with the `ADD_JARS` variable. For example, to run `spark-shell` on four cores, use
{% highlight bash %}
$ MASTER=local[4] ./spark-shell
{% endhighlight %}
+Or, to also add `code.jar` to its classpath, use:
+
+{% highlight bash %}
+$ MASTER=local[4] ADD_JARS=code.jar ./spark-shell
+{% endhighlight %}
+
### Master URLs
The master URL passed to Spark can be in one of the following formats:
@@ -67,6 +73,8 @@ The master URL passed to Spark can be in one of the following formats:
</td></tr>
</table>
+If no master URL is specified, the Spark shell defaults to "local".
+
For running on YARN, Spark launches an instance of the standalone deploy cluster within YARN; see [running on YARN](running-on-yarn.html) for details.
### Deploying Code on a Cluster
@@ -76,7 +84,7 @@ If you want to run your job on a cluster, you will need to specify the two optio
* `sparkHome`: The path at which Spark is installed on your worker machines (it should be the same on all of them).
* `jars`: A list of JAR files on the local machine containing your job's code and any dependencies, which Spark will deploy to all the worker nodes. You'll need to package your job into a set of JARs using your build system. For example, if you're using SBT, the [sbt-assembly](https://github.com/sbt/sbt-assembly) plugin is a good way to make a single JAR with your code and dependencies.
-If you run `spark-shell` on a cluster, any classes you define in the shell will automatically be distributed.
+If you run `spark-shell` on a cluster, you can add JARs to it by specifying the `ADD_JARS` environment variable before you launch it. This variable should contain a comma-separated list of JARs. For example, `ADD_JARS=a.jar,b.jar ./spark-shell` will launch a shell with `a.jar` and `b.jar` on its classpath. In addition, any new classes you define in the shell will automatically be distributed.
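+
+Putting these two options together, here is a sketch of creating a context that deploys a job jar to
+the cluster (the host name, paths, and jar name below are illustrative):
+
+{% highlight scala %}
+val sc = new SparkContext(
+  "spark://master-host:7077",            // master URL
+  "MyJob",                               // appName, shown in the cluster web UI
+  "/usr/local/spark",                    // sparkHome on the worker machines
+  List("target/scala-2.9.3/my-job.jar")) // jars deployed to the workers
+{% endhighlight %}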
# Resilient Distributed Datasets (RDDs)
diff --git a/docs/tuning.md b/docs/tuning.md
index 32c7ab86e9..5ffca54481 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -157,9 +157,9 @@ their work directories), *not* on your driver program.
**Cache Size Tuning**
-One important configuration parameter for GC is the amount of memory that should be used for
-caching RDDs. By default, Spark uses 66% of the configured memory (`SPARK_MEM`) to cache RDDs. This means that
- 33% of memory is available for any objects created during task execution.
+One important configuration parameter for GC is the amount of memory that should be used for caching RDDs.
+By default, Spark uses 66% of the configured executor memory (`spark.executor.memory` or `SPARK_MEM`) to
+cache RDDs. This means that 33% of memory is available for any objects created during task execution.
In case your tasks slow down and you find that your JVM is garbage-collecting frequently or running out of
memory, lowering this value will help reduce the memory consumption. To change this to say 50%, you can call
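+something like the following in the driver before creating the SparkContext (a sketch, assuming the
+`spark.storage.memoryFraction` system property is the one controlling this fraction):
+
+{% highlight scala %}
+System.setProperty("spark.storage.memoryFraction", "0.5")
+{% endhighlight %}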
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 9f2daad2b6..7affe6fffc 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -103,7 +103,7 @@ def parse_args():
parser.print_help()
sys.exit(1)
(action, cluster_name) = args
- if opts.identity_file == None and action in ['launch', 'login']:
+ if opts.identity_file == None and action in ['launch', 'login', 'start']:
print >> stderr, ("ERROR: The -i or --identity-file argument is " +
"required for " + action)
sys.exit(1)
diff --git a/examples/pom.xml b/examples/pom.xml
index f521e85027..064f1179e5 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -34,6 +34,41 @@
<artifactId>scalacheck_${scala.version}</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.cassandra</groupId>
+ <artifactId>cassandra-all</artifactId>
+ <version>1.2.5</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.googlecode.concurrentlinkedhashmap</groupId>
+ <artifactId>concurrentlinkedhashmap-lru</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.ning</groupId>
+ <artifactId>compress-lzf</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>io.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>jline</groupId>
+ <artifactId>jline</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>log4j</groupId>
+ <artifactId>log4j</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.apache.cassandra.deps</groupId>
+ <artifactId>avro</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala.version}/classes</outputDirectory>
@@ -67,6 +102,11 @@
<artifactId>hadoop-core</artifactId>
<scope>provided</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.hbase</groupId>
+ <artifactId>hbase</artifactId>
+ <version>0.94.6</version>
+ </dependency>
</dependencies>
<build>
<plugins>
@@ -105,6 +145,11 @@
<artifactId>hadoop-client</artifactId>
<scope>provided</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.hbase</groupId>
+ <artifactId>hbase</artifactId>
+ <version>0.94.6</version>
+ </dependency>
</dependencies>
<build>
<plugins>
@@ -118,5 +163,53 @@
</plugins>
</build>
</profile>
+ <profile>
+ <id>hadoop2-yarn</id>
+ <dependencies>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-core</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2-yarn</classifier>
+ </dependency>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-streaming</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2-yarn</classifier>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-client</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-api</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-common</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hbase</groupId>
+ <artifactId>hbase</artifactId>
+ <version>0.94.6</version>
+ </dependency>
+ </dependencies>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <configuration>
+ <classifier>hadoop2-yarn</classifier>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
</profiles>
</project>
diff --git a/examples/src/main/scala/spark/examples/CassandraTest.scala b/examples/src/main/scala/spark/examples/CassandraTest.scala
new file mode 100644
index 0000000000..0fe1833e83
--- /dev/null
+++ b/examples/src/main/scala/spark/examples/CassandraTest.scala
@@ -0,0 +1,196 @@
+package spark.examples
+
+import org.apache.hadoop.mapreduce.Job
+import org.apache.cassandra.hadoop.ColumnFamilyOutputFormat
+import org.apache.cassandra.hadoop.ConfigHelper
+import org.apache.cassandra.hadoop.ColumnFamilyInputFormat
+import org.apache.cassandra.thrift._
+import spark.SparkContext
+import spark.SparkContext._
+import java.nio.ByteBuffer
+import java.util.SortedMap
+import org.apache.cassandra.db.IColumn
+import org.apache.cassandra.utils.ByteBufferUtil
+import scala.collection.JavaConversions._
+
+
+/*
+ * This example demonstrates using Spark with Cassandra with the New Hadoop API and Cassandra
+ * support for Hadoop.
+ *
+ * To run this example, run this class with the following command-line parameters:
+ * <spark_master> <cassandra_node> <cassandra_port>
+ *
+ * So, to run it against a Cassandra instance on localhost, this will be:
+ * local[3] localhost 9160
+ *
+ * The example makes some assumptions:
+ * 1. You have already created a keyspace called casDemo and it has a column family named Words
+ * 2. The Words column family has a column named "para" which holds the test content.
+ *
+ * You can create the content by running the following script at the bottom of this file with
+ * cassandra-cli.
+ *
+ */
+object CassandraTest {
+
+ def main(args: Array[String]) {
+
+ // Get a SparkContext
+ val sc = new SparkContext(args(0), "casDemo")
+
+ // Build the job configuration with ConfigHelper provided by Cassandra
+ val job = new Job()
+ job.setInputFormatClass(classOf[ColumnFamilyInputFormat])
+
+ val host: String = args(1)
+ val port: String = args(2)
+
+ ConfigHelper.setInputInitialAddress(job.getConfiguration(), host)
+ ConfigHelper.setInputRpcPort(job.getConfiguration(), port)
+ ConfigHelper.setOutputInitialAddress(job.getConfiguration(), host)
+ ConfigHelper.setOutputRpcPort(job.getConfiguration(), port)
+ ConfigHelper.setInputColumnFamily(job.getConfiguration(), "casDemo", "Words")
+ ConfigHelper.setOutputColumnFamily(job.getConfiguration(), "casDemo", "WordCount")
+
+ val predicate = new SlicePredicate()
+ val sliceRange = new SliceRange()
+ sliceRange.setStart(Array.empty[Byte])
+ sliceRange.setFinish(Array.empty[Byte])
+ predicate.setSlice_range(sliceRange)
+ ConfigHelper.setInputSlicePredicate(job.getConfiguration(), predicate)
+
+ ConfigHelper.setInputPartitioner(job.getConfiguration(), "Murmur3Partitioner")
+ ConfigHelper.setOutputPartitioner(job.getConfiguration(), "Murmur3Partitioner")
+
+ // Make a new Hadoop RDD
+ val casRdd = sc.newAPIHadoopRDD(
+ job.getConfiguration(),
+ classOf[ColumnFamilyInputFormat],
+ classOf[ByteBuffer],
+ classOf[SortedMap[ByteBuffer, IColumn]])
+
+ // Let us first get all the paragraphs from the retrieved rows
+ val paraRdd = casRdd.map {
+ case (key, value) => {
+ ByteBufferUtil.string(value.get(ByteBufferUtil.bytes("para")).value())
+ }
+ }
+
+    // Let's compute the word count across all the paragraphs
+ val counts = paraRdd.flatMap(p => p.split(" ")).map(word => (word, 1)).reduceByKey(_ + _)
+
+ counts.collect().foreach {
+ case (word, count) => println(word + ":" + count)
+ }
+
+ counts.map {
+ case (word, count) => {
+ val colWord = new org.apache.cassandra.thrift.Column()
+ colWord.setName(ByteBufferUtil.bytes("word"))
+ colWord.setValue(ByteBufferUtil.bytes(word))
+ colWord.setTimestamp(System.currentTimeMillis)
+
+ val colCount = new org.apache.cassandra.thrift.Column()
+ colCount.setName(ByteBufferUtil.bytes("wcount"))
+ colCount.setValue(ByteBufferUtil.bytes(count.toLong))
+ colCount.setTimestamp(System.currentTimeMillis)
+
+ val outputkey = ByteBufferUtil.bytes(word + "-COUNT-" + System.currentTimeMillis)
+
+ val mutations: java.util.List[Mutation] = new Mutation() :: new Mutation() :: Nil
+ mutations.get(0).setColumn_or_supercolumn(new ColumnOrSuperColumn())
+ mutations.get(0).column_or_supercolumn.setColumn(colWord)
+ mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn())
+ mutations.get(1).column_or_supercolumn.setColumn(colCount)
+ (outputkey, mutations)
+ }
+ }.saveAsNewAPIHadoopFile("casDemo", classOf[ByteBuffer], classOf[List[Mutation]],
+ classOf[ColumnFamilyOutputFormat], job.getConfiguration)
+ }
+}
+
+/*
+create keyspace casDemo;
+use casDemo;
+
+create column family WordCount with comparator = UTF8Type;
+update column family WordCount with column_metadata =
+ [{column_name: word, validation_class: UTF8Type},
+ {column_name: wcount, validation_class: LongType}];
+
+create column family Words with comparator = UTF8Type;
+update column family Words with column_metadata =
+ [{column_name: book, validation_class: UTF8Type},
+ {column_name: para, validation_class: UTF8Type}];
+
+assume Words keys as utf8;
+
+set Words['3musk001']['book'] = 'The Three Musketeers';
+set Words['3musk001']['para'] = 'On the first Monday of the month of April, 1625, the market
+ town of Meung, in which the author of ROMANCE OF THE ROSE was born, appeared to
+ be in as perfect a state of revolution as if the Huguenots had just made
+ a second La Rochelle of it. Many citizens, seeing the women flying
+ toward the High Street, leaving their children crying at the open doors,
+ hastened to don the cuirass, and supporting their somewhat uncertain
+ courage with a musket or a partisan, directed their steps toward the
+ hostelry of the Jolly Miller, before which was gathered, increasing
+ every minute, a compact group, vociferous and full of curiosity.';
+
+set Words['3musk002']['book'] = 'The Three Musketeers';
+set Words['3musk002']['para'] = 'In those times panics were common, and few days passed without
+ some city or other registering in its archives an event of this kind. There were
+ nobles, who made war against each other; there was the king, who made
+ war against the cardinal; there was Spain, which made war against the
+ king. Then, in addition to these concealed or public, secret or open
+ wars, there were robbers, mendicants, Huguenots, wolves, and scoundrels,
+ who made war upon everybody. The citizens always took up arms readily
+ against thieves, wolves or scoundrels, often against nobles or
+ Huguenots, sometimes against the king, but never against cardinal or
+ Spain. It resulted, then, from this habit that on the said first Monday
+ of April, 1625, the citizens, on hearing the clamor, and seeing neither
+ the red-and-yellow standard nor the livery of the Duc de Richelieu,
+ rushed toward the hostel of the Jolly Miller. When arrived there, the
+ cause of the hubbub was apparent to all';
+
+set Words['3musk003']['book'] = 'The Three Musketeers';
+set Words['3musk003']['para'] = 'You ought, I say, then, to husband the means you have, however
+ large the sum may be; but you ought also to endeavor to perfect yourself in
+ the exercises becoming a gentleman. I will write a letter today to the
+ Director of the Royal Academy, and tomorrow he will admit you without
+ any expense to yourself. Do not refuse this little service. Our
+ best-born and richest gentlemen sometimes solicit it without being able
+ to obtain it. You will learn horsemanship, swordsmanship in all its
+ branches, and dancing. You will make some desirable acquaintances; and
+ from time to time you can call upon me, just to tell me how you are
+ getting on, and to say whether I can be of further service to you.';
+
+
+set Words['thelostworld001']['book'] = 'The Lost World';
+set Words['thelostworld001']['para'] = 'She sat with that proud, delicate profile of hers outlined
+ against the red curtain. How beautiful she was! And yet how aloof! We had been
+ friends, quite good friends; but never could I get beyond the same
+ comradeship which I might have established with one of my
+ fellow-reporters upon the Gazette,--perfectly frank, perfectly kindly,
+ and perfectly unsexual. My instincts are all against a woman being too
+ frank and at her ease with me. It is no compliment to a man. Where
+ the real sex feeling begins, timidity and distrust are its companions,
+ heritage from old wicked days when love and violence went often hand in
+ hand. The bent head, the averted eye, the faltering voice, the wincing
+ figure--these, and not the unshrinking gaze and frank reply, are the
+ true signals of passion. Even in my short life I had learned as much
+ as that--or had inherited it in that race memory which we call instinct.';
+
+set Words['thelostworld002']['book'] = 'The Lost World';
+set Words['thelostworld002']['para'] = 'I always liked McArdle, the crabbed, old, round-backed,
+ red-headed news editor, and I rather hoped that he liked me. Of course, Beaumont was
+ the real boss; but he lived in the rarefied atmosphere of some Olympian
+ height from which he could distinguish nothing smaller than an
+ international crisis or a split in the Cabinet. Sometimes we saw him
+ passing in lonely majesty to his inner sanctum, with his eyes staring
+ vaguely and his mind hovering over the Balkans or the Persian Gulf. He
+ was above and beyond us. But McArdle was his first lieutenant, and it
+ was he that we knew. The old man nodded as I entered the room, and he
+ pushed his spectacles far up on his bald forehead.';
+
+*/
diff --git a/examples/src/main/scala/spark/examples/HBaseTest.scala b/examples/src/main/scala/spark/examples/HBaseTest.scala
new file mode 100644
index 0000000000..6e910154d4
--- /dev/null
+++ b/examples/src/main/scala/spark/examples/HBaseTest.scala
@@ -0,0 +1,35 @@
+package spark.examples
+
+import spark._
+import spark.rdd.NewHadoopRDD
+import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor}
+import org.apache.hadoop.hbase.client.HBaseAdmin
+import org.apache.hadoop.hbase.mapreduce.TableInputFormat
+
+object HBaseTest {
+ def main(args: Array[String]) {
+ val sc = new SparkContext(args(0), "HBaseTest",
+ System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
+
+ val conf = HBaseConfiguration.create()
+
+ // Other options for configuring scan behavior are available. More information available at
+ // http://hbase.apache.org/apidocs/org/apache/hadoop/hbase/mapreduce/TableInputFormat.html
+ conf.set(TableInputFormat.INPUT_TABLE, args(1))
+
+    // Initialize the HBase table if necessary
+ val admin = new HBaseAdmin(conf)
+ if(!admin.isTableAvailable(args(1))) {
+ val tableDesc = new HTableDescriptor(args(1))
+ admin.createTable(tableDesc)
+ }
+
+ val hBaseRDD = sc.newAPIHadoopRDD(conf, classOf[TableInputFormat],
+ classOf[org.apache.hadoop.hbase.io.ImmutableBytesWritable],
+ classOf[org.apache.hadoop.hbase.client.Result])
+
+ hBaseRDD.count()
+
+ System.exit(0)
+ }
+} \ No newline at end of file
diff --git a/examples/src/main/scala/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/spark/examples/SparkHdfsLR.scala
index 0f42f405a0..3d080a0257 100644
--- a/examples/src/main/scala/spark/examples/SparkHdfsLR.scala
+++ b/examples/src/main/scala/spark/examples/SparkHdfsLR.scala
@@ -4,6 +4,8 @@ import java.util.Random
import scala.math.exp
import spark.util.Vector
import spark._
+import spark.deploy.SparkHadoopUtil
+import spark.scheduler.InputFormatInfo
/**
* Logistic regression based classification.
@@ -32,9 +34,13 @@ object SparkHdfsLR {
System.err.println("Usage: SparkHdfsLR <master> <file> <iters>")
System.exit(1)
}
+ val inputPath = args(1)
+ val conf = SparkHadoopUtil.newConfiguration()
val sc = new SparkContext(args(0), "SparkHdfsLR",
- System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
- val lines = sc.textFile(args(1))
+ System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")), Map(),
+ InputFormatInfo.computePreferredLocations(
+ Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath))))
+ val lines = sc.textFile(inputPath)
val points = lines.map(parsePoint _).cache()
val ITERATIONS = args(2).toInt
diff --git a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
index c3a9e491ba..9202e65e09 100644
--- a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
+++ b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
@@ -37,7 +37,7 @@ object KafkaWordCount {
ssc.checkpoint("checkpoint")
val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap
- val lines = ssc.kafkaStream[String](zkQuorum, group, topicpMap)
+ val lines = ssc.kafkaStream(zkQuorum, group, topicpMap)
val words = lines.flatMap(_.split(" "))
val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2)
wordCounts.print()
diff --git a/examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala b/examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala
new file mode 100644
index 0000000000..51c3c9f9b4
--- /dev/null
+++ b/examples/src/main/scala/spark/streaming/examples/StatefulNetworkWordCount.scala
@@ -0,0 +1,50 @@
+package spark.streaming.examples
+
+import spark.streaming._
+import spark.streaming.StreamingContext._
+
+/**
+ * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every second.
+ * Usage: StatefulNetworkWordCount <master> <hostname> <port>
+ * <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1.
+ * <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data.
+ *
+ * To run this on your local machine, you need to first run a Netcat server
+ * `$ nc -lk 9999`
+ * and then run the example
+ * `$ ./run spark.streaming.examples.StatefulNetworkWordCount local[2] localhost 9999`
+ */
+object StatefulNetworkWordCount {
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println("Usage: StatefulNetworkWordCount <master> <hostname> <port>\n" +
+ "In local mode, <master> should be 'local[n]' with n > 1")
+ System.exit(1)
+ }
+
+ val updateFunc = (values: Seq[Int], state: Option[Int]) => {
+ val currentCount = values.foldLeft(0)(_ + _)
+
+ val previousCount = state.getOrElse(0)
+
+ Some(currentCount + previousCount)
+ }
+
+ // Create the context with a 1 second batch size
+ val ssc = new StreamingContext(args(0), "NetworkWordCumulativeCountUpdateStateByKey", Seconds(1),
+ System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
+ ssc.checkpoint(".")
+
+ // Create a NetworkInputDStream on target ip:port and count the
+    // words in the input stream of \n delimited text (e.g. generated by 'nc')
+ val lines = ssc.socketTextStream(args(1), args(2).toInt)
+ val words = lines.flatMap(_.split(" "))
+ val wordDstream = words.map(x => (x, 1))
+
+ // Update the cumulative count using updateStateByKey
+    // This will give a DStream made of state (which is the cumulative count of the words)
+ val stateDstream = wordDstream.updateStateByKey[Int](updateFunc)
+ stateDstream.print()
+ ssc.start()
+ }
+}
diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala
index a9642100e3..528778ed72 100644
--- a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala
+++ b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala
@@ -26,8 +26,8 @@ import spark.SparkContext._
*/
object TwitterAlgebirdCMS {
def main(args: Array[String]) {
- if (args.length < 3) {
- System.err.println("Usage: TwitterAlgebirdCMS <master> <twitter_username> <twitter_password>" +
+ if (args.length < 1) {
+ System.err.println("Usage: TwitterAlgebirdCMS <master>" +
" [filter1] [filter2] ... [filter n]")
System.exit(1)
}
@@ -40,12 +40,11 @@ object TwitterAlgebirdCMS {
// K highest frequency elements to take
val TOPK = 10
- val Array(master, username, password) = args.slice(0, 3)
- val filters = args.slice(3, args.length)
+ val (master, filters) = (args.head, args.tail)
val ssc = new StreamingContext(master, "TwitterAlgebirdCMS", Seconds(10),
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
- val stream = ssc.twitterStream(username, password, filters, StorageLevel.MEMORY_ONLY_SER)
+ val stream = ssc.twitterStream(None, filters, StorageLevel.MEMORY_ONLY_SER)
val users = stream.map(status => status.getUser.getId)
diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala
index f3288bfb85..896e9fd8af 100644
--- a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala
+++ b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala
@@ -21,20 +21,19 @@ import spark.streaming.dstream.TwitterInputDStream
*/
object TwitterAlgebirdHLL {
def main(args: Array[String]) {
- if (args.length < 3) {
- System.err.println("Usage: TwitterAlgebirdHLL <master> <twitter_username> <twitter_password>" +
+ if (args.length < 1) {
+ System.err.println("Usage: TwitterAlgebirdHLL <master>" +
" [filter1] [filter2] ... [filter n]")
System.exit(1)
}
/** Bit size parameter for HyperLogLog, trades off accuracy vs size */
val BIT_SIZE = 12
- val Array(master, username, password) = args.slice(0, 3)
- val filters = args.slice(3, args.length)
+ val (master, filters) = (args.head, args.tail)
val ssc = new StreamingContext(master, "TwitterAlgebirdHLL", Seconds(5),
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
- val stream = ssc.twitterStream(username, password, filters, StorageLevel.MEMORY_ONLY_SER)
+ val stream = ssc.twitterStream(None, filters, StorageLevel.MEMORY_ONLY_SER)
val users = stream.map(status => status.getUser.getId)
diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala b/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala
index 9d4494c6f2..65f0b6d352 100644
--- a/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala
+++ b/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala
@@ -12,18 +12,17 @@ import spark.SparkContext._
*/
object TwitterPopularTags {
def main(args: Array[String]) {
- if (args.length < 3) {
- System.err.println("Usage: TwitterPopularTags <master> <twitter_username> <twitter_password>" +
+ if (args.length < 1) {
+ System.err.println("Usage: TwitterPopularTags <master>" +
" [filter1] [filter2] ... [filter n]")
System.exit(1)
}
- val Array(master, username, password) = args.slice(0, 3)
- val filters = args.slice(3, args.length)
+ val (master, filters) = (args.head, args.tail)
val ssc = new StreamingContext(master, "TwitterPopularTags", Seconds(2),
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
- val stream = ssc.twitterStream(username, password, filters)
+ val stream = ssc.twitterStream(None, filters)
val hashTags = stream.flatMap(status => status.getText.split(" ").filter(_.startsWith("#")))
diff --git a/pom.xml b/pom.xml
index 72fcc00d43..5d624d1446 100644
--- a/pom.xml
+++ b/pom.xml
@@ -56,9 +56,12 @@
<akka.version>2.1.4</akka.version>
<spray.version>1.1-M7</spray.version>
<spray.json.version>1.2.3</spray.json.version>
- <slf4j.version>1.6.1</slf4j.version>
+ <slf4j.version>1.7.2</slf4j.version>
<cdh.version>4.1.2</cdh.version>
<log4j.version>1.2.17</log4j.version>
+
+ <PermGen>64m</PermGen>
+ <MaxPermGen>512m</MaxPermGen>
</properties>
<repositories>
@@ -106,17 +109,6 @@
<enabled>false</enabled>
</snapshots>
</repository>
- <repository>
- <id>twitter4j-repo</id>
- <name>Twitter4J Repository</name>
- <url>http://twitter4j.org/maven2/</url>
- <releases>
- <enabled>true</enabled>
- </releases>
- <snapshots>
- <enabled>false</enabled>
- </snapshots>
- </repository>
</repositories>
<pluginRepositories>
<pluginRepository>
@@ -187,9 +179,9 @@
<version>0.8.4</version>
</dependency>
<dependency>
- <groupId>asm</groupId>
- <artifactId>asm-all</artifactId>
- <version>3.3.1</version>
+ <groupId>org.ow2.asm</groupId>
+ <artifactId>asm</artifactId>
+ <version>4.0</version>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
@@ -253,7 +245,7 @@
</dependency>
<dependency>
<groupId>com.github.scala-incubator.io</groupId>
- <artifactId>scala-io-file_${scala.version}</artifactId>
+ <artifactId>scala-io-file_2.9.2</artifactId>
<version>0.4.1</version>
</dependency>
<dependency>
@@ -261,6 +253,17 @@
<artifactId>mesos</artifactId>
<version>${mesos.version}</version>
</dependency>
+ <dependency>
+ <groupId>io.netty</groupId>
+ <artifactId>netty-all</artifactId>
+ <version>4.0.0.Beta2</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.derby</groupId>
+ <artifactId>derby</artifactId>
+ <version>10.4.2.0</version>
+ <scope>test</scope>
+ </dependency>
<dependency>
<groupId>org.scala-lang</groupId>
@@ -386,6 +389,8 @@
<jvmArgs>
<jvmArg>-Xms64m</jvmArg>
<jvmArg>-Xmx1024m</jvmArg>
+ <jvmArg>-XX:PermSize=${PermGen}</jvmArg>
+ <jvmArg>-XX:MaxPermSize=${MaxPermGen}</jvmArg>
</jvmArgs>
<javacArgs>
<javacArg>-source</javacArg>
@@ -422,8 +427,9 @@
<configuration>
<reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
<junitxml>.</junitxml>
- <filereports>WDF TestSuite.txt</filereports>
+ <filereports>${project.build.directory}/SparkTestSuite.txt</filereports>
<argLine>-Xms64m -Xmx1024m</argLine>
+ <stderr/>
</configuration>
<executions>
<execution>
@@ -517,7 +523,6 @@
<profiles>
<profile>
<id>hadoop1</id>
-
<properties>
<hadoop.major.version>1</hadoop.major.version>
</properties>
@@ -559,6 +564,73 @@
<groupId>org.apache.avro</groupId>
<artifactId>avro-ipc</artifactId>
<version>1.7.1.cloudera.2</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.jboss.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ </dependencies>
+ </dependencyManagement>
+ </profile>
+
+ <profile>
+ <id>hadoop2-yarn</id>
+ <properties>
+ <hadoop.major.version>2</hadoop.major.version>
+ <!-- 0.23.* is same as 2.0.* - except hardened to run production jobs -->
+ <!-- <yarn.version>0.23.7</yarn.version> -->
+ <yarn.version>2.0.2-alpha</yarn.version>
+ </properties>
+
+ <repositories>
+ <repository>
+ <id>maven-root</id>
+ <name>Maven root repository</name>
+ <url>http://repo1.maven.org/maven2/</url>
+ <releases>
+ <enabled>true</enabled>
+ </releases>
+ <snapshots>
+ <enabled>false</enabled>
+ </snapshots>
+ </repository>
+ </repositories>
+
+ <dependencyManagement>
+ <dependencies>
+        <!-- TODO: check versions; bring over from the yarn branch! -->
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-client</artifactId>
+ <version>${yarn.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-api</artifactId>
+ <version>${yarn.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-common</artifactId>
+ <version>${yarn.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-client</artifactId>
+ <version>${yarn.version}</version>
+ </dependency>
+ <!-- Specify Avro version because Kafka also has it as a dependency -->
+ <dependency>
+ <groupId>org.apache.avro</groupId>
+ <artifactId>avro</artifactId>
+ <version>1.7.1.cloudera.2</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.avro</groupId>
+ <artifactId>avro-ipc</artifactId>
+ <version>1.7.1.cloudera.2</version>
</dependency>
</dependencies>
</dependencyManagement>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 71b3008c8d..987b78d317 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -1,3 +1,4 @@
+
import sbt._
import sbt.Classpaths.publishTask
import Keys._
@@ -12,16 +13,23 @@ object SparkBuild extends Build {
// "1.0.4" for Apache releases, or "0.20.2-cdh3u5" for Cloudera Hadoop.
val HADOOP_VERSION = "1.0.4"
val HADOOP_MAJOR_VERSION = "1"
+ val HADOOP_YARN = false
// For Hadoop 2 versions such as "2.0.0-mr1-cdh4.1.1", set the HADOOP_MAJOR_VERSION to "2"
//val HADOOP_VERSION = "2.0.0-mr1-cdh4.1.1"
//val HADOOP_MAJOR_VERSION = "2"
+ //val HADOOP_YARN = false
+
+ // For Hadoop 2 YARN support
+ //val HADOOP_VERSION = "2.0.2-alpha"
+ //val HADOOP_MAJOR_VERSION = "2"
+ //val HADOOP_YARN = true
lazy val root = Project("root", file("."), settings = rootSettings) aggregate(core, repl, examples, bagel, streaming)
lazy val core = Project("core", file("core"), settings = coreSettings)
- lazy val repl = Project("repl", file("repl"), settings = replSettings) dependsOn (core) dependsOn (streaming)
+ lazy val repl = Project("repl", file("repl"), settings = replSettings) dependsOn (core)
lazy val examples = Project("examples", file("examples"), settings = examplesSettings) dependsOn (core) dependsOn (streaming)
@@ -46,7 +54,7 @@ object SparkBuild extends Build {
// Fork new JVMs for tests and set Java options for those
fork := true,
- javaOptions += "-Xmx1g",
+ javaOptions += "-Xmx2500m",
// Only allow one test at a time, even across projects, since they run in the same JVM
concurrentRestrictions in Global += Tags.limit(Tags.Test, 1),
@@ -96,8 +104,8 @@ object SparkBuild extends Build {
*/
libraryDependencies ++= Seq(
- "io.netty" % "netty" % "3.6.6.Final",
- "org.eclipse.jetty" % "jetty-server" % "7.5.3.v20111011",
+ "io.netty" % "netty-all" % "4.0.0.Beta2",
+ "org.eclipse.jetty" % "jetty-server" % "7.6.8.v20121106",
"org.scalatest" %% "scalatest" % "1.9.1" % "test",
"org.scalacheck" %% "scalacheck" % "1.10.0" % "test",
"com.novocode" % "junit-interface" % "0.8" % "test",
@@ -116,21 +124,20 @@ object SparkBuild extends Build {
publishMavenStyle in MavenCompile := true,
publishLocal in MavenCompile <<= publishTask(publishLocalConfiguration in MavenCompile, deliverLocal),
publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn
- )
+ ) ++ net.virtualvoid.sbt.graph.Plugin.graphSettings
- val slf4jVersion = "1.6.1"
+ val slf4jVersion = "1.7.2"
val excludeJackson = ExclusionRule(organization = "org.codehaus.jackson")
val excludeNetty = ExclusionRule(organization = "org.jboss.netty")
+ val excludeAsm = ExclusionRule(organization = "asm")
def coreSettings = sharedSettings ++ Seq(
name := "spark-core",
resolvers ++= Seq(
- "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/",
"JBoss Repository" at "http://repository.jboss.org/nexus/content/repositories/releases/",
"Spray Repository" at "http://repo.spray.cc/",
- "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/",
- "Twitter4J Repository" at "http://twitter4j.org/maven2/"
+ "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/"
),
libraryDependencies ++= Seq(
@@ -139,28 +146,50 @@ object SparkBuild extends Build {
"org.slf4j" % "slf4j-api" % slf4jVersion,
"org.slf4j" % "slf4j-log4j12" % slf4jVersion,
"com.ning" % "compress-lzf" % "0.8.4",
+ "commons-daemon" % "commons-daemon" % "1.0.10",
"org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION excludeAll(excludeNetty, excludeJackson),
- "asm" % "asm-all" % "3.3.1",
+ "org.ow2.asm" % "asm" % "4.0",
"com.google.protobuf" % "protobuf-java" % "2.4.1",
- "de.javakaffee" % "kryo-serializers" % "0.20",
- "com.typesafe.akka" %% "akka-remote" % "2.1.4" excludeAll(excludeNetty),
- "com.typesafe.akka" %% "akka-slf4j" % "2.1.4" excludeAll(excludeNetty),
+ "de.javakaffee" % "kryo-serializers" % "0.22",
+ "com.typesafe.akka" %% "akka-remote" % "2.1.4" excludeAll(excludeNetty),
+ "com.typesafe.akka" %% "akka-slf4j" % "2.1.4" excludeAll(excludeNetty),
"it.unimi.dsi" % "fastutil" % "6.4.4",
- "io.spray" % "spray-can" % "1.1-M7" excludeAll(excludeNetty),
- "io.spray" % "spray-io" % "1.1-M7" excludeAll(excludeNetty),
- "io.spray" % "spray-routing" % "1.1-M7" excludeAll(excludeNetty),
- "io.spray" %% "spray-json" % "1.2.3" excludeAll(excludeNetty),
+ "io.spray" % "spray-can" % "1.1-M7" excludeAll(excludeNetty),
+ "io.spray" % "spray-io" % "1.1-M7" excludeAll(excludeNetty),
+ "io.spray" % "spray-routing" % "1.1-M7" excludeAll(excludeNetty),
+ "io.spray" %% "spray-json" % "1.2.3" excludeAll(excludeNetty),
"colt" % "colt" % "1.2.0",
"org.apache.mesos" % "mesos" % "0.9.0-incubating",
+ "org.apache.derby" % "derby" % "10.4.2.0" % "test",
"org.scala-lang" % "scala-actors" % "2.10.1",
"org.scala-lang" % "jline" % "2.10.1",
"org.scala-lang" % "scala-reflect" % "2.10.1"
- ) ++ (if (HADOOP_MAJOR_VERSION == "2")
- Some("org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION excludeAll(excludeNetty, excludeJackson))
- else
- None
- ).toSeq,
- unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / ("src/hadoop" + HADOOP_MAJOR_VERSION + "/scala") }
+ ) ++ (
+ if (HADOOP_MAJOR_VERSION == "2") {
+ if (HADOOP_YARN) {
+ Seq(
+          // Is the exclude rule required for all of these?
+ "org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION excludeAll(excludeJackson, excludeNetty),
+ "org.apache.hadoop" % "hadoop-yarn-api" % HADOOP_VERSION excludeAll(excludeJackson, excludeNetty),
+ "org.apache.hadoop" % "hadoop-yarn-common" % HADOOP_VERSION excludeAll(excludeJackson, excludeNetty),
+ "org.apache.hadoop" % "hadoop-yarn-client" % HADOOP_VERSION excludeAll(excludeJackson, excludeNetty)
+ )
+ } else {
+ Seq(
+ "org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION excludeAll(excludeJackson, excludeNetty),
+ "org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION excludeAll(excludeJackson, excludeNetty)
+ )
+ }
+ } else {
+ Seq("org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION excludeAll(excludeJackson, excludeNetty) )
+ }),
+ unmanagedSourceDirectories in Compile <+= baseDirectory{ _ /
+ ( if (HADOOP_YARN && HADOOP_MAJOR_VERSION == "2") {
+ "src/hadoop2-yarn/scala"
+ } else {
+ "src/hadoop" + HADOOP_MAJOR_VERSION + "/scala"
+ } )
+ }
) ++ assemblySettings ++ extraAssemblySettings ++ Twirl.settings
def rootSettings = sharedSettings ++ Seq(
@@ -169,30 +198,45 @@ object SparkBuild extends Build {
def replSettings = sharedSettings ++ Seq(
name := "spark-repl",
- // libraryDependencies <+= scalaVersion("org.scala-lang" % "scala-compiler" % _)
libraryDependencies ++= Seq("org.scala-lang" % "scala-compiler" % "2.10.1")
)
def examplesSettings = sharedSettings ++ Seq(
name := "spark-examples",
- libraryDependencies ++= Seq("com.twitter" %% "algebird-core" % "0.1.11")
+ libraryDependencies ++= Seq(
+ "com.twitter" %% "algebird-core" % "0.1.11",
+ "org.apache.hbase" % "hbase" % "0.94.6" excludeAll(excludeNetty, excludeAsm),
+
+ "org.apache.cassandra" % "cassandra-all" % "1.2.5"
+ exclude("com.google.guava", "guava")
+ exclude("com.googlecode.concurrentlinkedhashmap", "concurrentlinkedhashmap-lru")
+ exclude("com.ning","compress-lzf")
+ exclude("io.netty", "netty")
+ exclude("jline","jline")
+ exclude("log4j","log4j")
+ exclude("org.apache.cassandra.deps", "avro")
+ )
)
def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel")
def streamingSettings = sharedSettings ++ Seq(
name := "spark-streaming",
+ resolvers ++= Seq(
+ "Akka Repository" at "http://repo.akka.io/releases/"
+ ),
libraryDependencies ++= Seq(
"org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty),
"com.github.sgroschupf" % "zkclient" % "0.1",
- "org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty),
- "com.typesafe.akka" %% "akka-zeromq" % "2.1.4" excludeAll(excludeNetty)
+ "org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty),
+ "com.typesafe.akka" %% "akka-zeromq" % "2.1.4" excludeAll(excludeNetty)
)
) ++ assemblySettings ++ extraAssemblySettings
def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq(
mergeStrategy in assembly := {
case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
+ case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard
case "reference.conf" => MergeStrategy.concat
case _ => MergeStrategy.first
}
diff --git a/project/plugins.sbt b/project/plugins.sbt
index d4f2442872..f806e66481 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -16,3 +16,5 @@ addSbtPlugin("io.spray" %% "sbt-twirl" % "0.6.1")
//resolvers += Resolver.url("sbt-plugin-releases", new URL("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases/"))(Resolver.ivyStylePatterns)
//addSbtPlugin("com.jsuereth" % "xsbt-gpg-plugin" % "0.6")
+
+addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.3")
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
new file mode 100644
index 0000000000..78c9457b84
--- /dev/null
+++ b/python/pyspark/daemon.py
@@ -0,0 +1,164 @@
+import os
+import signal
+import socket
+import sys
+import traceback
+import multiprocessing
+from ctypes import c_bool
+from errno import EINTR, ECHILD
+from socket import AF_INET, SOCK_STREAM, SOMAXCONN
+from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
+from pyspark.worker import main as worker_main
+from pyspark.serializers import write_int
+
+try:
+ POOLSIZE = multiprocessing.cpu_count()
+except NotImplementedError:
+ POOLSIZE = 4
+
+exit_flag = multiprocessing.Value(c_bool, False)
+
+
+def should_exit():
+ global exit_flag
+ return exit_flag.value
+
+
+def compute_real_exit_code(exit_code):
+    # SystemExit's code can be an integer or a string, but os._exit only accepts integers
+ import numbers
+ if isinstance(exit_code, numbers.Integral):
+ return exit_code
+ else:
+ return 1
+
+
+def worker(listen_sock):
+ # Redirect stdout to stderr
+ os.dup2(2, 1)
+ sys.stdout = sys.stderr # The sys.stdout object is different from file descriptor 1
+
+ # Manager sends SIGHUP to request termination of workers in the pool
+ def handle_sighup(*args):
+ assert should_exit()
+ signal.signal(SIGHUP, handle_sighup)
+
+    # Clean up zombie children
+ def handle_sigchld(*args):
+ pid = status = None
+ try:
+ while (pid, status) != (0, 0):
+ pid, status = os.waitpid(0, os.WNOHANG)
+ except EnvironmentError as err:
+ if err.errno == EINTR:
+ # retry
+ handle_sigchld()
+ elif err.errno != ECHILD:
+ raise
+ signal.signal(SIGCHLD, handle_sigchld)
+
+ # Handle clients
+ while not should_exit():
+ # Wait until a client arrives or we have to exit
+ sock = None
+ while not should_exit() and sock is None:
+ try:
+ sock, addr = listen_sock.accept()
+ except EnvironmentError as err:
+ if err.errno != EINTR:
+ raise
+
+ if sock is not None:
+ # Fork a child to handle the client.
+ # The client is handled in the child so that the manager
+ # never receives SIGCHLD unless a worker crashes.
+ if os.fork() == 0:
+ # Leave the worker pool
+ signal.signal(SIGHUP, SIG_DFL)
+ listen_sock.close()
+ # Read the socket using fdopen instead of socket.makefile() because the latter
+ # seems to be very slow; note that we need to dup() the file descriptor because
+ # otherwise writes also cause a seek that makes us miss data on the read side.
+ infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
+ outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
+ exit_code = 0
+ try:
+ worker_main(infile, outfile)
+ except SystemExit as exc:
+ exit_code = exc.code
+ finally:
+ outfile.flush()
+ sock.close()
+ os._exit(compute_real_exit_code(exit_code))
+ else:
+ sock.close()
+
+
+def launch_worker(listen_sock):
+ if os.fork() == 0:
+ try:
+ worker(listen_sock)
+ except Exception as err:
+ traceback.print_exc()
+ os._exit(1)
+ else:
+ assert should_exit()
+ os._exit(0)
+
+
+def manager():
+ # Create a new process group to corral our children
+ os.setpgid(0, 0)
+
+ # Create a listening socket on the AF_INET loopback interface
+ listen_sock = socket.socket(AF_INET, SOCK_STREAM)
+ listen_sock.bind(('127.0.0.1', 0))
+ listen_sock.listen(max(1024, 2 * POOLSIZE, SOMAXCONN))
+ listen_host, listen_port = listen_sock.getsockname()
+ write_int(listen_port, sys.stdout)
+
+ # Launch initial worker pool
+ for idx in range(POOLSIZE):
+ launch_worker(listen_sock)
+ listen_sock.close()
+
+ def shutdown():
+ global exit_flag
+ exit_flag.value = True
+
+ # Gracefully exit on SIGTERM, don't die on SIGHUP
+ signal.signal(SIGTERM, lambda signum, frame: shutdown())
+ signal.signal(SIGHUP, SIG_IGN)
+
+    # Clean up zombie children
+ def handle_sigchld(*args):
+ try:
+ pid, status = os.waitpid(0, os.WNOHANG)
+ if status != 0 and not should_exit():
+ raise RuntimeError("worker crashed: %s, %s" % (pid, status))
+ except EnvironmentError as err:
+ if err.errno not in (ECHILD, EINTR):
+ raise
+ signal.signal(SIGCHLD, handle_sigchld)
+
+ # Initialization complete
+ sys.stdout.close()
+ try:
+ while not should_exit():
+ try:
+ # Spark tells us to exit by closing stdin
+ if os.read(0, 512) == '':
+ shutdown()
+ except EnvironmentError as err:
+ if err.errno != EINTR:
+ shutdown()
+ raise
+ finally:
+ signal.signal(SIGTERM, SIG_DFL)
+ exit_flag.value = True
+ # Send SIGHUP to notify workers of shutdown
+ os.kill(0, SIGHUP)
+
+
+if __name__ == '__main__':
+ manager()
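
For reference, the manager protocol above can be exercised end to end from the
launching side. A minimal sketch, assuming the daemon writes its port to stdout
as a 4-byte big-endian int and treats EOF on stdin as a shutdown request (as the
code above does); the real driver side lives in PythonWorkerFactory, and the
task payload is elided here:

    import struct
    import socket
    from subprocess import Popen, PIPE

    # Launch the daemon; it writes its ephemeral port to stdout as "!i".
    daemon = Popen(["python", "pyspark/daemon.py"], stdin=PIPE, stdout=PIPE)
    port = struct.unpack("!i", daemon.stdout.read(4))[0]

    # Each accepted connection is handled by a fork of a pooled worker.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.connect(("127.0.0.1", port))
    # ... write the task (split index, broadcasts, function) here ...
    sock.close()

    # Closing stdin asks the manager to shut down itself and its workers.
    daemon.stdin.close()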
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 115cf28cc2..5a95144983 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -46,6 +46,10 @@ def read_long(stream):
return struct.unpack("!q", length)[0]
+def write_long(value, stream):
+ stream.write(struct.pack("!q", value))
+
+
def read_int(stream):
length = stream.read(4)
if length == "":
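The new write_long is the inverse of read_long above: both use network byte
order ("!q", an 8-byte signed long). A quick round-trip check, assuming
read_long consumes exactly eight bytes as in serializers.py (Python 2, to match
pyspark at the time):

    import struct
    from StringIO import StringIO

    def write_long(value, stream):
        stream.write(struct.pack("!q", value))

    def read_long(stream):
        # Mirrors pyspark.serializers.read_long: 8 bytes, big-endian.
        return struct.unpack("!q", stream.read(8))[0]

    buf = StringIO()
    write_long(1234567890123, buf)
    buf.seek(0)
    assert read_long(buf) == 1234567890123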
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 6a1962d267..1e34d47365 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -12,6 +12,7 @@ import unittest
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
from pyspark.java_gateway import SPARK_HOME
+from pyspark.serializers import read_int
class PySparkTestCase(unittest.TestCase):
@@ -117,5 +118,47 @@ class TestIO(PySparkTestCase):
self.sc.parallelize([1]).foreach(func)
+class TestDaemon(unittest.TestCase):
+ def connect(self, port):
+ from socket import socket, AF_INET, SOCK_STREAM
+ sock = socket(AF_INET, SOCK_STREAM)
+ sock.connect(('127.0.0.1', port))
+        # send a split index of -1 to shut down the worker
+ sock.send("\xFF\xFF\xFF\xFF")
+ sock.close()
+ return True
+
+ def do_termination_test(self, terminator):
+ from subprocess import Popen, PIPE
+ from errno import ECONNREFUSED
+
+ # start daemon
+ daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py")
+ daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE)
+
+ # read the port number
+ port = read_int(daemon.stdout)
+
+ # daemon should accept connections
+ self.assertTrue(self.connect(port))
+
+ # request shutdown
+ terminator(daemon)
+ time.sleep(1)
+
+ # daemon should no longer accept connections
+ with self.assertRaises(EnvironmentError) as trap:
+ self.connect(port)
+ self.assertEqual(trap.exception.errno, ECONNREFUSED)
+
+ def test_termination_stdin(self):
+ """Ensure that daemon and workers terminate when stdin is closed."""
+ self.do_termination_test(lambda daemon: daemon.stdin.close())
+
+ def test_termination_sigterm(self):
+ """Ensure that daemon and workers terminate on SIGTERM."""
+ from signal import SIGTERM
+ self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
+
if __name__ == "__main__":
unittest.main()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 812e7a9da5..379bbfd4c2 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -3,6 +3,7 @@ Worker that receives input from Piped RDD.
"""
import os
import sys
+import time
import traceback
from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the
@@ -12,48 +13,60 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, read_with_length, write_int, \
- read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
+ read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
-# Redirect stdout to stderr so that users must return values from functions.
-old_stdout = os.fdopen(os.dup(1), 'w')
-os.dup2(2, 1)
+def load_obj(infile):
+ return load_pickle(standard_b64decode(infile.readline().strip()))
-def load_obj():
- return load_pickle(standard_b64decode(sys.stdin.readline().strip()))
+def report_times(outfile, boot, init, finish):
+ write_int(-3, outfile)
+ write_long(1000 * boot, outfile)
+ write_long(1000 * init, outfile)
+ write_long(1000 * finish, outfile)
-def main():
- split_index = read_int(sys.stdin)
- spark_files_dir = load_pickle(read_with_length(sys.stdin))
+def main(infile, outfile):
+ boot_time = time.time()
+ split_index = read_int(infile)
+ if split_index == -1: # for unit tests
+ return
+ spark_files_dir = load_pickle(read_with_length(infile))
SparkFiles._root_directory = spark_files_dir
SparkFiles._is_running_on_worker = True
sys.path.append(spark_files_dir)
- num_broadcast_variables = read_int(sys.stdin)
+ num_broadcast_variables = read_int(infile)
for _ in range(num_broadcast_variables):
- bid = read_long(sys.stdin)
- value = read_with_length(sys.stdin)
+ bid = read_long(infile)
+ value = read_with_length(infile)
_broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
- func = load_obj()
- bypassSerializer = load_obj()
+ func = load_obj(infile)
+ bypassSerializer = load_obj(infile)
if bypassSerializer:
dumps = lambda x: x
else:
dumps = dump_pickle
- iterator = read_from_pickle_file(sys.stdin)
+ init_time = time.time()
+ iterator = read_from_pickle_file(infile)
try:
for obj in func(split_index, iterator):
- write_with_length(dumps(obj), old_stdout)
+ write_with_length(dumps(obj), outfile)
except Exception as e:
- write_int(-2, old_stdout)
- write_with_length(traceback.format_exc(), old_stdout)
+ write_int(-2, outfile)
+ write_with_length(traceback.format_exc(), outfile)
sys.exit(-1)
+ finish_time = time.time()
+ report_times(outfile, boot_time, init_time, finish_time)
# Mark the beginning of the accumulators section of the output
- write_int(-1, old_stdout)
+ write_int(-1, outfile)
for aid, accum in _accumulatorRegistry.items():
- write_with_length(dump_pickle((aid, accum._value)), old_stdout)
+ write_with_length(dump_pickle((aid, accum._value)), outfile)
+ write_int(-1, outfile)
if __name__ == '__main__':
- main()
+ # Redirect stdout to stderr so that users must return values from functions.
+ old_stdout = os.fdopen(os.dup(1), 'w')
+ os.dup2(2, 1)
+ main(sys.stdin, old_stdout)
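The worker now multiplexes several record types onto outfile, using negative
ints where a length prefix would otherwise appear: -3 precedes three timing
longs, -2 precedes a length-prefixed traceback, and a pair of -1 markers
brackets the accumulator section. The authoritative consumer is the JVM side
(PythonRDD); this Python sketch of the framing is only illustrative:

    import struct

    def read_int(stream):
        return struct.unpack("!i", stream.read(4))[0]

    def read_long(stream):
        return struct.unpack("!q", stream.read(8))[0]

    def drain_worker_output(stream):
        # Schematic reader for the worker's framed output stream.
        while True:
            marker = read_int(stream)
            if marker >= 0:                    # length-prefixed result record
                yield ("value", stream.read(marker))
            elif marker == -3:                 # boot/init/finish times in ms
                yield ("times", [read_long(stream) for _ in range(3)])
            elif marker == -2:                 # length-prefixed traceback string
                raise RuntimeError(stream.read(read_int(stream)))
            elif marker == -1:                 # accumulator section begins
                while True:
                    length = read_int(stream)
                    if length == -1:           # second -1 ends the stream
                        return
                    yield ("accum", stream.read(length))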
diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml
index fe526a7616..7a7280313e 100644
--- a/repl-bin/pom.xml
+++ b/repl-bin/pom.xml
@@ -154,6 +154,61 @@
</dependencies>
</profile>
<profile>
+ <id>hadoop2-yarn</id>
+ <properties>
+ <classifier>hadoop2-yarn</classifier>
+ </properties>
+ <dependencies>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-core</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2-yarn</classifier>
+ </dependency>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-bagel</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2-yarn</classifier>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-examples</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2-yarn</classifier>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-repl</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2-yarn</classifier>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-client</artifactId>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-api</artifactId>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-common</artifactId>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-client</artifactId>
+ <scope>runtime</scope>
+ </dependency>
+ </dependencies>
+ </profile>
+ <profile>
<id>deb</id>
<build>
<plugins>
diff --git a/repl/pom.xml b/repl/pom.xml
index 891b05fb8f..d618f176f7 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -99,13 +99,6 @@
<scope>runtime</scope>
</dependency>
<dependency>
- <groupId>org.spark-project</groupId>
- <artifactId>spark-streaming</artifactId>
- <version>${project.version}</version>
- <classifier>hadoop1</classifier>
- <scope>runtime</scope>
- </dependency>
- <dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-core</artifactId>
<scope>provided</scope>
@@ -150,20 +143,84 @@
<scope>runtime</scope>
</dependency>
<dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-core</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-client</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.avro</groupId>
+ <artifactId>avro</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.avro</groupId>
+ <artifactId>avro-ipc</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <configuration>
+ <classifier>hadoop2</classifier>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ <profile>
+ <id>hadoop2-yarn</id>
+ <properties>
+ <classifier>hadoop2-yarn</classifier>
+ </properties>
+ <dependencies>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-core</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2-yarn</classifier>
+ </dependency>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-bagel</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2-yarn</classifier>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-examples</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2-yarn</classifier>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
<groupId>org.spark-project</groupId>
<artifactId>spark-streaming</artifactId>
<version>${project.version}</version>
- <classifier>hadoop2</classifier>
+ <classifier>hadoop2-yarn</classifier>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
- <artifactId>hadoop-core</artifactId>
+ <artifactId>hadoop-client</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
- <artifactId>hadoop-client</artifactId>
+ <artifactId>hadoop-yarn-api</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-common</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
@@ -183,7 +240,7 @@
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<configuration>
- <classifier>hadoop2</classifier>
+ <classifier>hadoop2-yarn</classifier>
</configuration>
</plugin>
</plugins>
diff --git a/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala
index 13d81ec1cf..0e9aa863b5 100644
--- a/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala
+++ b/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala
@@ -8,7 +8,6 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.objectweb.asm._
-import org.objectweb.asm.commons.EmptyVisitor
import org.objectweb.asm.Opcodes._
@@ -83,7 +82,7 @@ extends ClassLoader(parent) {
}
class ConstructorCleaner(className: String, cv: ClassVisitor)
-extends ClassAdapter(cv) {
+extends ClassVisitor(ASM4, cv) {
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
val mv = cv.visitMethod(access, name, desc, sig, exceptions)
diff --git a/repl/src/test/scala/spark/repl/ReplSuite.scala b/repl/src/test/scala/spark/repl/ReplSuite.scala
index 4dfd3127bf..8df0c3a7f4 100644
--- a/repl/src/test/scala/spark/repl/ReplSuite.scala
+++ b/repl/src/test/scala/spark/repl/ReplSuite.scala
@@ -19,7 +19,7 @@ class ReplSuite extends FunSuite with ReplSuiteMixin {
assertContains("res1: Int = 55", output)
}
- test("external vars") {
+ test ("external vars") {
val output = runInterpreter("local", """
var v = 7
sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_)
@@ -105,6 +105,27 @@ class ReplSuite extends FunSuite with ReplSuiteMixin {
assertContains("res2: Long = 3", output)
}
+ test ("local-cluster mode") {
+ val output = runInterpreter("local-cluster[1,1,512]", """
+ var v = 7
+ def getV() = v
+ sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+ v = 10
+ sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+ var array = new Array[Int](5)
+ val broadcastArray = sc.broadcast(array)
+ sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+ array(0) = 5
+ sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Int = 70", output)
+ assertContains("res1: Int = 100", output)
+ assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+ assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+ }
+
if (System.getenv("MESOS_NATIVE_LIBRARY") != null) {
test("running on Mesos") {
val output = runInterpreter("localquiet", """
diff --git a/run b/run
index 96c7f8a095..91665f8934 100755
--- a/run
+++ b/run
@@ -23,30 +23,39 @@ fi
# values for that; it doesn't need a lot
if [ "$1" = "spark.deploy.master.Master" -o "$1" = "spark.deploy.worker.Worker" ]; then
SPARK_MEM=${SPARK_DAEMON_MEMORY:-512m}
- SPARK_DAEMON_JAVA_OPTS+=" -Dspark.akka.logLifecycleEvents=true"
- SPARK_JAVA_OPTS=$SPARK_DAEMON_JAVA_OPTS # Empty by default
+ SPARK_DAEMON_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS -Dspark.akka.logLifecycleEvents=true"
+ # Do not overwrite SPARK_JAVA_OPTS environment variable in this script
+ OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS" # Empty by default
+else
+ OUR_JAVA_OPTS="$SPARK_JAVA_OPTS"
fi
# Add java opts for master, worker, executor. The opts may be null
case "$1" in
'spark.deploy.master.Master')
- SPARK_JAVA_OPTS+=" $SPARK_MASTER_OPTS"
+ OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_MASTER_OPTS"
;;
'spark.deploy.worker.Worker')
- SPARK_JAVA_OPTS+=" $SPARK_WORKER_OPTS"
+ OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_WORKER_OPTS"
;;
'spark.executor.StandaloneExecutorBackend')
- SPARK_JAVA_OPTS+=" $SPARK_EXECUTOR_OPTS"
+ OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
;;
'spark.executor.MesosExecutorBackend')
- SPARK_JAVA_OPTS+=" $SPARK_EXECUTOR_OPTS"
+ OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
;;
'spark.repl.Main')
- SPARK_JAVA_OPTS+=" $SPARK_REPL_OPTS"
+ OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_REPL_OPTS"
;;
esac
+# Figure out whether to run our class with java or with the scala launcher.
+# In most cases, we'd prefer to execute our process with java because scala
+# creates a shell script as the parent of its Java process, which makes it
+# hard to kill the child with stuff like Process.destroy(). However, for
+# the Spark shell, the wrapper is necessary to properly reset the terminal
+# when we exit, so we allow it to set a variable to launch with scala.
if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then
if [ "$SCALA_HOME" ]; then
RUNNER="${SCALA_HOME}/bin/scala"
@@ -59,14 +68,15 @@ if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then
fi
fi
else
- if [ `command -v java` ]; then
- RUNNER="java"
+ if [ -n "${JAVA_HOME}" ]; then
+ RUNNER="${JAVA_HOME}/bin/java"
else
- if [ -z "$JAVA_HOME" ]; then
+ if [ `command -v java` ]; then
+ RUNNER="java"
+ else
echo "JAVA_HOME is not set" >&2
exit 1
fi
- RUNNER="${JAVA_HOME}/bin/java"
fi
if [ -z "$SCALA_LIBRARY_PATH" ]; then
if [ -z "$SCALA_HOME" ]; then
@@ -85,57 +95,36 @@ fi
export SPARK_MEM
# Set JAVA_OPTS to be able to load native libraries and to set heap size
-JAVA_OPTS="$SPARK_JAVA_OPTS"
-JAVA_OPTS+=" -Djava.library.path=$SPARK_LIBRARY_PATH"
-JAVA_OPTS+=" -Xms$SPARK_MEM -Xmx$SPARK_MEM"
+JAVA_OPTS="$OUR_JAVA_OPTS"
+JAVA_OPTS="$JAVA_OPTS -Djava.library.path=$SPARK_LIBRARY_PATH"
+JAVA_OPTS="$JAVA_OPTS -Xms$SPARK_MEM -Xmx$SPARK_MEM"
# Load extra JAVA_OPTS from conf/java-opts, if it exists
if [ -e $FWDIR/conf/java-opts ] ; then
- JAVA_OPTS+=" `cat $FWDIR/conf/java-opts`"
+ JAVA_OPTS="$JAVA_OPTS `cat $FWDIR/conf/java-opts`"
fi
export JAVA_OPTS
+# Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in ExecutorRunner.scala!
CORE_DIR="$FWDIR/core"
-REPL_DIR="$FWDIR/repl"
EXAMPLES_DIR="$FWDIR/examples"
-BAGEL_DIR="$FWDIR/bagel"
-STREAMING_DIR="$FWDIR/streaming"
-PYSPARK_DIR="$FWDIR/python"
+REPL_DIR="$FWDIR/repl"
# Exit if the user hasn't compiled Spark
-if [ ! -e "$REPL_DIR/target" ]; then
- echo "Failed to find Spark classes in $REPL_DIR/target" >&2
+if [ ! -e "$CORE_DIR/target" ]; then
+ echo "Failed to find Spark classes in $CORE_DIR/target" >&2
echo "You need to compile Spark before running this program" >&2
exit 1
fi
-# Build up classpath
-CLASSPATH="$SPARK_CLASSPATH"
-CLASSPATH+=":$FWDIR/conf"
-CLASSPATH+=":$CORE_DIR/target/scala-$SCALA_VERSION/classes"
-if [ -n "$SPARK_TESTING" ] ; then
- CLASSPATH+=":$CORE_DIR/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH+=":$STREAMING_DIR/target/scala-$SCALA_VERSION/test-classes"
-fi
-CLASSPATH+=":$CORE_DIR/src/main/resources"
-CLASSPATH+=":$REPL_DIR/target/scala-$SCALA_VERSION/classes"
-CLASSPATH+=":$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes"
-CLASSPATH+=":$STREAMING_DIR/target/scala-$SCALA_VERSION/classes"
-CLASSPATH+=":$STREAMING_DIR/lib/org/apache/kafka/kafka/0.7.2-spark/*" # <-- our in-project Kafka Jar
-if [ -e "$FWDIR/lib_managed" ]; then
- CLASSPATH+=":$FWDIR/lib_managed/jars/*"
- CLASSPATH+=":$FWDIR/lib_managed/bundles/*"
-fi
-CLASSPATH+=":$REPL_DIR/lib/*"
-if [ -e repl-bin/target ]; then
- for jar in `find "repl-bin/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do
- CLASSPATH+=":$jar"
- done
+if [[ "$@" = *repl* && ! -e "$REPL_DIR/target" ]]; then
+ echo "Failed to find Spark classes in $REPL_DIR/target" >&2
+ echo "You need to compile Spark repl module before running this program" >&2
+ exit 1
fi
-CLASSPATH+=":$BAGEL_DIR/target/scala-$SCALA_VERSION/classes"
-for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do
- CLASSPATH+=":$jar"
-done
-export CLASSPATH # Needed for spark-shell
+
+# Compute classpath using external script
+CLASSPATH=`$FWDIR/bin/compute-classpath.sh`
+export CLASSPATH
# Figure out the JAR file that our examples were packaged into. This includes a bit of a hack
# to avoid the -sources and -doc packages that are built by publish-local.
@@ -143,23 +132,14 @@ if [ -e "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar ];
# Use the JAR from the SBT build
export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar`
fi
-if [ -e "$EXAMPLES_DIR/target/spark-examples-"*hadoop[12].jar ]; then
+if [ -e "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar ]; then
# Use the JAR from the Maven build
- export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples-"*hadoop[12].jar`
+ export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar`
fi
-# Figure out whether to run our class with java or with the scala launcher.
-# In most cases, we'd prefer to execute our process with java because scala
-# creates a shell script as the parent of its Java process, which makes it
-# hard to kill the child with stuff like Process.destroy(). However, for
-# the Spark shell, the wrapper is necessary to properly reset the terminal
-# when we exit, so we allow it to set a variable to launch with scala.
if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then
EXTRA_ARGS="" # Java options will be passed to scala as JAVA_OPTS
else
- CLASSPATH+=":$SCALA_LIBRARY_PATH/scala-library.jar"
- CLASSPATH+=":$SCALA_LIBRARY_PATH/scala-compiler.jar"
- CLASSPATH+=":$SCALA_LIBRARY_PATH/jline.jar"
# The JVM doesn't read JAVA_OPTS by default so we need to pass it in
EXTRA_ARGS="$JAVA_OPTS"
fi
diff --git a/run2.cmd b/run2.cmd
index 27d0968571..6d43db0325 100644
--- a/run2.cmd
+++ b/run2.cmd
@@ -23,7 +23,9 @@ if "%1"=="spark.deploy.worker.Worker" set RUNNING_DAEMON=1
if "x%SPARK_DAEMON_MEMORY%" == "x" set SPARK_DAEMON_MEMORY=512m
set SPARK_DAEMON_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% -Dspark.akka.logLifecycleEvents=true
if "%RUNNING_DAEMON%"=="1" set SPARK_MEM=%SPARK_DAEMON_MEMORY%
-if "%RUNNING_DAEMON%"=="1" set SPARK_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS%
+rem Do not overwrite SPARK_JAVA_OPTS environment variable in this script
+if "%RUNNING_DAEMON%"=="0" set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS%
+if "%RUNNING_DAEMON%"=="1" set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS%
rem Check that SCALA_HOME has been specified
if not "x%SCALA_HOME%"=="x" goto scala_exists
@@ -31,37 +33,22 @@ if not "x%SCALA_HOME%"=="x" goto scala_exists
goto exit
:scala_exists
-rem If the user specifies a Mesos JAR, put it before our included one on the classpath
-set MESOS_CLASSPATH=
-if not "x%MESOS_JAR%"=="x" set MESOS_CLASSPATH=%MESOS_JAR%
-
rem Figure out how much memory to use per executor and set it as an environment
rem variable so that our process sees it and can report it to Mesos
if "x%SPARK_MEM%"=="x" set SPARK_MEM=512m
rem Set JAVA_OPTS to be able to load native libraries and to set heap size
-set JAVA_OPTS=%SPARK_JAVA_OPTS% -Djava.library.path=%SPARK_LIBRARY_PATH% -Xms%SPARK_MEM% -Xmx%SPARK_MEM%
-rem Load extra JAVA_OPTS from conf/java-opts, if it exists
-if exist "%FWDIR%conf\java-opts.cmd" call "%FWDIR%conf\java-opts.cmd"
+set JAVA_OPTS=%OUR_JAVA_OPTS% -Djava.library.path=%SPARK_LIBRARY_PATH% -Xms%SPARK_MEM% -Xmx%SPARK_MEM%
+rem Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in ExecutorRunner.scala!
set CORE_DIR=%FWDIR%core
-set REPL_DIR=%FWDIR%repl
set EXAMPLES_DIR=%FWDIR%examples
-set BAGEL_DIR=%FWDIR%bagel
-set STREAMING_DIR=%FWDIR%streaming
-set PYSPARK_DIR=%FWDIR%python
+set REPL_DIR=%FWDIR%repl
-rem Build up classpath
-set CLASSPATH=%SPARK_CLASSPATH%;%MESOS_CLASSPATH%;%FWDIR%conf;%CORE_DIR%\target\scala-%SCALA_VERSION%\classes
-set CLASSPATH=%CLASSPATH%;%CORE_DIR%\target\scala-%SCALA_VERSION%\test-classes;%CORE_DIR%\src\main\resources
-set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\classes;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\test-classes
-set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\lib\org\apache\kafka\kafka\0.7.2-spark\*
-set CLASSPATH=%CLASSPATH%;%REPL_DIR%\target\scala-%SCALA_VERSION%\classes;%EXAMPLES_DIR%\target\scala-%SCALA_VERSION%\classes
-set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\jars\*
-set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\bundles\*
-set CLASSPATH=%CLASSPATH%;%FWDIR%repl\lib\*
-set CLASSPATH=%CLASSPATH%;%FWDIR%python\lib\*
-set CLASSPATH=%CLASSPATH%;%BAGEL_DIR%\target\scala-%SCALA_VERSION%\classes
+rem Compute classpath using external script
+set DONT_PRINT_CLASSPATH=1
+call "%FWDIR%bin\compute-classpath.cmd"
+set DONT_PRINT_CLASSPATH=0
rem Figure out the JAR file that our examples were packaged into.
rem First search in the build path from SBT:
diff --git a/streaming/pom.xml b/streaming/pom.xml
index fe869ba66e..1589e3202f 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -41,6 +41,12 @@
<groupId>org.apache.flume</groupId>
<artifactId>flume-ng-sdk</artifactId>
<version>1.2.0</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.jboss.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>com.github.sgroschupf</groupId>
@@ -149,5 +155,42 @@
</plugins>
</build>
</profile>
+ <profile>
+ <id>hadoop2-yarn</id>
+ <dependencies>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-core</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2-yarn</classifier>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-client</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-api</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-common</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <configuration>
+ <classifier>hadoop2-yarn</classifier>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
</profiles>
</project>
diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
index e303e33e5e..66e67cbfa1 100644
--- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
@@ -38,11 +38,20 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
private[streaming]
class CheckpointWriter(checkpointDir: String) extends Logging {
val file = new Path(checkpointDir, "graph")
+ // The file to which we actually write - and then "move" to file.
+ private val writeFile = new Path(file.getParent, file.getName + ".next")
+ private val bakFile = new Path(file.getParent, file.getName + ".bk")
+
+ private var stopped = false
+
val conf = new Configuration()
var fs = file.getFileSystem(conf)
val maxAttempts = 3
val executor = Executors.newFixedThreadPool(1)
+  // Removed the code that validated there is only one CheckpointWriter per path 'file', since
+  // no errors were observed without it - reintroduce it?
+
class CheckpointWriteHandler(checkpointTime: Time, bytes: Array[Byte]) extends Runnable {
def run() {
var attempts = 0
@@ -51,15 +60,17 @@ class CheckpointWriter(checkpointDir: String) extends Logging {
attempts += 1
try {
logDebug("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'")
- if (fs.exists(file)) {
- val bkFile = new Path(file.getParent, file.getName + ".bk")
- FileUtil.copy(fs, file, fs, bkFile, true, true, conf)
- logDebug("Moved existing checkpoint file to " + bkFile)
- }
- val fos = fs.create(file)
+          // Writing to the checkpoint file directly is inherently thread-unsafe, so we alleviate that by writing to '.next' and then renaming it into place, which should be fast.
+ val fos = fs.create(writeFile)
fos.write(bytes)
fos.close()
- fos.close()
+ if (fs.exists(file) && fs.rename(file, bakFile)) {
+ logDebug("Moved existing checkpoint file to " + bakFile)
+ }
+ // paranoia
+ fs.delete(file, false)
+ fs.rename(writeFile, file)
+
val finishTime = System.currentTimeMillis();
logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + file +
"', took " + bytes.length + " bytes and " + (finishTime - startTime) + " milliseconds")
@@ -84,7 +95,15 @@ class CheckpointWriter(checkpointDir: String) extends Logging {
}
def stop() {
+ synchronized {
+      if (stopped) return
+ stopped = true
+ }
executor.shutdown()
+ val startTime = System.currentTimeMillis()
+ val terminated = executor.awaitTermination(10, java.util.concurrent.TimeUnit.SECONDS)
+ val endTime = System.currentTimeMillis()
+ logInfo("CheckpointWriter executor terminated ? " + terminated + ", waited for " + (endTime - startTime) + " ms.")
}
}
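
The effect of the rewritten write path is that a complete checkpoint file is
always visible: new bytes go to 'graph.next', the previous file is renamed to
'graph.bk', and only then is '.next' renamed into place. The same pattern,
sketched in Python for illustration (local filesystem semantics standing in for
the Hadoop FileSystem calls):

    import os

    def write_checkpoint(path, data):
        # Write-to-temp-then-rename, mirroring CheckpointWriter's
        # '.next' / '.bk' dance above.
        next_path = path + ".next"
        bak_path = path + ".bk"
        with open(next_path, "wb") as f:
            f.write(data)
        if os.path.exists(path):
            if os.path.exists(bak_path):
                os.remove(bak_path)
            os.rename(path, bak_path)   # keep the previous checkpoint as backup
        os.rename(next_path, path)      # readers never see a partial file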
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala
index 6ad43dd9b5..4eb5e163a2 100644
--- a/streaming/src/main/scala/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/spark/streaming/DStream.scala
@@ -442,7 +442,12 @@ abstract class DStream[T: ClassTag] (
* Return a new DStream in which each RDD has a single element generated by counting each RDD
* of this DStream.
*/
- def count(): DStream[Long] = this.map(_ => 1L).reduce(_ + _)
+ def count(): DStream[Long] = {
+ this.map(_ => (null, 1L))
+ .transform(_.union(context.sparkContext.makeRDD(Seq((null, 0L)), 1)))
+ .reduceByKey(_ + _)
+ .map(_._2)
+ }
/**
* Return a new DStream in which each RDD contains the counts of each distinct value in
@@ -458,7 +463,7 @@ abstract class DStream[T: ClassTag] (
* this DStream will be registered as an output stream and therefore materialized.
*/
def foreach(foreachFunc: RDD[T] => Unit) {
- foreach((r: RDD[T], t: Time) => foreachFunc(r))
+ this.foreach((r: RDD[T], t: Time) => foreachFunc(r))
}
/**
diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
index adb7f3a24d..3b331956f5 100644
--- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
@@ -54,8 +54,8 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
throw new Exception("Batch duration already set as " + batchDuration +
". cannot set it again.")
}
+ batchDuration = duration
}
- batchDuration = duration
}
def remember(duration: Duration) {
diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
index 7646e15521..e06f1bbfd6 100644
--- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
@@ -28,6 +28,8 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
import org.apache.hadoop.fs.Path
import twitter4j.Status
+import twitter4j.auth.Authorization
+
/**
* A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
@@ -187,10 +189,11 @@ class StreamingContext private (
* should be same.
*/
def actorStream[T: ClassTag](
- props: Props,
- name: String,
- storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2,
- supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy): DStream[T] = {
+ props: Props,
+ name: String,
+ storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2,
+ supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy
+ ): DStream[T] = {
networkStream(new ActorReceiver[T](props, name, storageLevel, supervisorStrategy))
}
@@ -198,9 +201,10 @@ class StreamingContext private (
* Create an input stream that receives messages pushed by a zeromq publisher.
* @param publisherUrl Url of remote zeromq publisher
* @param subscribe topic to subscribe to
- * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each frame has sequence
- * of byte thus it needs the converter(which might be deserializer of bytes)
- * to translate from sequence of sequence of bytes, where sequence refer to a frame
+   * @param bytesToObjects A zeroMQ stream publishes a sequence of frames for each topic,
+   *                       and each frame has a sequence of bytes; thus it needs the converter
+   *                       (which might be a deserializer of bytes) to translate from a sequence
+   *                       of sequences of bytes, where a sequence refers to a frame
 * and a sub-sequence refers to its payload.
* @param storageLevel RDD storage level. Defaults to memory-only.
*/
@@ -216,24 +220,39 @@ class StreamingContext private (
}
/**
- * Create an input stream that pulls messages form a Kafka Broker.
+ * Create an input stream that pulls messages from a Kafka Broker.
 * @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..).
* @param groupId The group id for this consumer.
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
- * in its own thread.
- * @param initialOffsets Optional initial offsets for each of the partitions to consume.
- * By default the value is pulled from zookeper.
+ * in its own thread.
* @param storageLevel Storage level to use for storing the received objects
* (default: StorageLevel.MEMORY_AND_DISK_SER_2)
*/
- def kafkaStream[T: ClassTag](
+ def kafkaStream(
zkQuorum: String,
groupId: String,
topics: Map[String, Int],
- initialOffsets: Map[KafkaPartitionKey, Long] = Map[KafkaPartitionKey, Long](),
storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2
+ ): DStream[String] = {
+ val kafkaParams = Map[String, String](
+ "zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000")
+ kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, storageLevel)
+ }
+
+ /**
+ * Create an input stream that pulls messages from a Kafka Broker.
+   * @param kafkaParams Map of Kafka configuration parameters.
+ * See: http://kafka.apache.org/configuration.html
+ * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
+ * in its own thread.
+ * @param storageLevel Storage level to use for storing the received objects
+ */
+ def kafkaStream[T: ClassTag, D <: kafka.serializer.Decoder[_]: Manifest](
+ kafkaParams: Map[String, String],
+ topics: Map[String, Int],
+ storageLevel: StorageLevel
): DStream[T] = {
- val inputStream = new KafkaInputDStream[T](this, zkQuorum, groupId, topics, initialOffsets, storageLevel)
+ val inputStream = new KafkaInputDStream[T, D](this, kafkaParams, topics, storageLevel)
registerInputStream(inputStream)
inputStream
}
@@ -363,18 +382,18 @@ class StreamingContext private (
/**
 * Create an input stream that returns tweets received from Twitter.
- * @param username Twitter username
- * @param password Twitter password
+ * @param twitterAuth Twitter4J authentication, or None to use Twitter4J's default OAuth
+ * authorization; this uses the system properties twitter4j.oauth.consumerKey,
+ * .consumerSecret, .accessToken and .accessTokenSecret.
* @param filters Set of filter strings to get only those tweets that match them
* @param storageLevel Storage level to use for storing the received objects
*/
def twitterStream(
- username: String,
- password: String,
+ twitterAuth: Option[Authorization] = None,
filters: Seq[String] = Nil,
storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2
): DStream[Status] = {
- val inputStream = new TwitterInputDStream(this, username, password, filters, storageLevel)
+ val inputStream = new TwitterInputDStream(this, twitterAuth, filters, storageLevel)
registerInputStream(inputStream)
inputStream
}
@@ -398,7 +417,8 @@ class StreamingContext private (
* it will process either one or all of the RDDs returned by the queue.
* @param queue Queue of RDDs
* @param oneAtATime Whether only one RDD should be consumed from the queue in every interval
- * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. Set as null if no RDD should be returned when empty
+   * @param defaultRDD The default RDD returned by the DStream when the queue is empty.
+   *                   Set to null if no RDD should be returned when empty
* @tparam T Type of objects in the RDD
*/
def queueStream[T: ClassTag](
diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala
index 00e5aa0603..7da732dd88 100644
--- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala
@@ -4,24 +4,21 @@ import spark.streaming._
import receivers.{ActorReceiver, ReceiverSupervisorStrategy}
import spark.streaming.dstream._
import spark.storage.StorageLevel
-
import spark.api.java.function.{Function => JFunction, Function2 => JFunction2}
import spark.api.java.{JavaSparkContext, JavaRDD}
-
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
-
import twitter4j.Status
-
import akka.actor.Props
import akka.actor.SupervisorStrategy
import akka.zeromq.Subscribe
-
import scala.collection.JavaConversions._
+
import scala.reflect.ClassTag
import java.lang.{Long => JLong, Integer => JInt}
import java.io.InputStream
import java.util.{Map => JMap}
+import twitter4j.auth.Authorization
/**
* A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
@@ -122,14 +119,16 @@ class JavaStreamingContext(val ssc: StreamingContext) {
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
*/
- def kafkaStream[T](
+ def kafkaStream(
zkQuorum: String,
groupId: String,
topics: JMap[String, JInt])
- : JavaDStream[T] = {
- implicit val cmt: ClassTag[T] =
- implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
- ssc.kafkaStream[T](zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*))
+ : JavaDStream[String] = {
+ implicit val cmt: ClassTag[String] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[String]]
+ ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*),
+ StorageLevel.MEMORY_ONLY_SER_2)
+
}
/**
@@ -137,49 +136,45 @@ class JavaStreamingContext(val ssc: StreamingContext) {
 * @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..).
* @param groupId The group id for this consumer.
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
- * in its own thread.
- * @param initialOffsets Optional initial offsets for each of the partitions to consume.
- * By default the value is pulled from zookeper.
+ * in its own thread.
+ * @param storageLevel RDD storage level. Defaults to memory-only
+ *
*/
- def kafkaStream[T](
+ def kafkaStream(
zkQuorum: String,
groupId: String,
topics: JMap[String, JInt],
- initialOffsets: JMap[KafkaPartitionKey, JLong])
- : JavaDStream[T] = {
- implicit val cmt: ClassTag[T] =
- implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
- ssc.kafkaStream[T](
- zkQuorum,
- groupId,
- Map(topics.mapValues(_.intValue()).toSeq: _*),
- Map(initialOffsets.mapValues(_.longValue()).toSeq: _*))
+ storageLevel: StorageLevel)
+ : JavaDStream[String] = {
+ implicit val cmt: ClassTag[String] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[String]]
+ ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*),
+ storageLevel)
}
/**
 * Create an input stream that pulls messages from a Kafka Broker.
- * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..).
- * @param groupId The group id for this consumer.
+ * @param typeClass Type of RDD
+ * @param decoderClass Type of kafka decoder
+   * @param kafkaParams Map of Kafka configuration parameters.
+ * See: http://kafka.apache.org/configuration.html
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
- * @param initialOffsets Optional initial offsets for each of the partitions to consume.
- * By default the value is pulled from zookeper.
* @param storageLevel RDD storage level. Defaults to memory-only
*/
- def kafkaStream[T](
- zkQuorum: String,
- groupId: String,
+ def kafkaStream[T, D <: kafka.serializer.Decoder[_]](
+ typeClass: Class[T],
+ decoderClass: Class[D],
+ kafkaParams: JMap[String, String],
topics: JMap[String, JInt],
- initialOffsets: JMap[KafkaPartitionKey, JLong],
storageLevel: StorageLevel)
: JavaDStream[T] = {
implicit val cmt: ClassTag[T] =
implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
- ssc.kafkaStream[T](
- zkQuorum,
- groupId,
+ implicit val cmd: Manifest[D] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[D]]
+ ssc.kafkaStream[T, D](
+ kafkaParams.toMap,
Map(topics.mapValues(_.intValue()).toSeq: _*),
- Map(initialOffsets.mapValues(_.longValue()).toSeq: _*),
storageLevel)
}
@@ -316,44 +311,73 @@ class JavaStreamingContext(val ssc: StreamingContext) {
/**
 * Create an input stream that returns tweets received from Twitter.
- * @param username Twitter username
- * @param password Twitter password
+ * @param twitterAuth Twitter4J Authorization object
+ * @param filters Set of filter strings to get only those tweets that match them
+ * @param storageLevel Storage level to use for storing the received objects
+ */
+ def twitterStream(
+ twitterAuth: Authorization,
+ filters: Array[String],
+ storageLevel: StorageLevel
+ ): JavaDStream[Status] = {
+ ssc.twitterStream(Some(twitterAuth), filters, storageLevel)
+ }
+
+ /**
+   * Create an input stream that returns tweets received from Twitter using Twitter4J's default
+ * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey,
+ * .consumerSecret, .accessToken and .accessTokenSecret to be set.
* @param filters Set of filter strings to get only those tweets that match them
* @param storageLevel Storage level to use for storing the received objects
*/
def twitterStream(
- username: String,
- password: String,
filters: Array[String],
storageLevel: StorageLevel
): JavaDStream[Status] = {
- ssc.twitterStream(username, password, filters, storageLevel)
+ ssc.twitterStream(None, filters, storageLevel)
}
/**
 * Create an input stream that returns tweets received from Twitter.
- * @param username Twitter username
- * @param password Twitter password
+ * @param twitterAuth Twitter4J Authorization
* @param filters Set of filter strings to get only those tweets that match them
*/
def twitterStream(
- username: String,
- password: String,
+ twitterAuth: Authorization,
filters: Array[String]
): JavaDStream[Status] = {
- ssc.twitterStream(username, password, filters)
+ ssc.twitterStream(Some(twitterAuth), filters)
+ }
+
+ /**
+   * Create an input stream that returns tweets received from Twitter using Twitter4J's default
+ * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey,
+ * .consumerSecret, .accessToken and .accessTokenSecret to be set.
+ * @param filters Set of filter strings to get only those tweets that match them
+ */
+ def twitterStream(
+ filters: Array[String]
+ ): JavaDStream[Status] = {
+ ssc.twitterStream(None, filters)
}
/**
 * Create an input stream that returns tweets received from Twitter.
- * @param username Twitter username
- * @param password Twitter password
+ * @param twitterAuth Twitter4J Authorization
*/
def twitterStream(
- username: String,
- password: String
+ twitterAuth: Authorization
): JavaDStream[Status] = {
- ssc.twitterStream(username, password)
+ ssc.twitterStream(Some(twitterAuth))
+ }
+
+ /**
+   * Create an input stream that returns tweets received from Twitter using Twitter4J's default
+ * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey,
+ * .consumerSecret, .accessToken and .accessTokenSecret to be set.
+ */
+ def twitterStream(): JavaDStream[Status] = {
+ ssc.twitterStream()
}
/**
diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala
index e093edb05b..e0f6351ef7 100644
--- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala
@@ -9,58 +9,51 @@ import java.util.concurrent.Executors
import kafka.consumer._
import kafka.message.{Message, MessageSet, MessageAndMetadata}
-import kafka.serializer.StringDecoder
+import kafka.serializer.Decoder
import kafka.utils.{Utils, ZKGroupTopicDirs}
import kafka.utils.ZkUtils._
+import kafka.utils.ZKStringSerializer
+import org.I0Itec.zkclient._
import scala.collection.Map
import scala.collection.mutable.HashMap
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
-// Key for a specific Kafka Partition: (broker, topic, group, part)
-case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int)
-
/**
* Input stream that pulls messages from a Kafka Broker.
*
- * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..).
- * @param groupId The group id for this consumer.
+ * @param kafkaParams Map of Kafka configuration parameters. See: http://kafka.apache.org/configuration.html
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
- * @param initialOffsets Optional initial offsets for each of the partitions to consume.
- * By default the value is pulled from zookeper.
* @param storageLevel RDD storage level.
*/
private[streaming]
-class KafkaInputDStream[T: ClassTag](
+class KafkaInputDStream[T: ClassTag, D <: Decoder[_]: Manifest](
@transient ssc_ : StreamingContext,
- zkQuorum: String,
- groupId: String,
+ kafkaParams: Map[String, String],
topics: Map[String, Int],
- initialOffsets: Map[KafkaPartitionKey, Long],
storageLevel: StorageLevel
) extends NetworkInputDStream[T](ssc_ ) with Logging {
def getReceiver(): NetworkReceiver[T] = {
- new KafkaReceiver(zkQuorum, groupId, topics, initialOffsets, storageLevel)
+ new KafkaReceiver[T, D](kafkaParams, topics, storageLevel)
.asInstanceOf[NetworkReceiver[T]]
}
}
private[streaming]
-class KafkaReceiver(zkQuorum: String, groupId: String,
- topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long],
- storageLevel: StorageLevel) extends NetworkReceiver[Any] {
-
- // Timeout for establishing a connection to Zookeper in ms.
- val ZK_TIMEOUT = 10000
+class KafkaReceiver[T: ClassTag, D <: Decoder[_]: Manifest](
+ kafkaParams: Map[String, String],
+ topics: Map[String, Int],
+ storageLevel: StorageLevel
+ ) extends NetworkReceiver[Any] {
// Handles pushing data into the BlockManager
lazy protected val blockGenerator = new BlockGenerator(storageLevel)
// Connection to Kafka
- var consumerConnector : ZookeeperConsumerConnector = null
+ var consumerConnector : ConsumerConnector = null
def onStop() {
blockGenerator.stop()
@@ -73,54 +66,59 @@ class KafkaReceiver(zkQuorum: String, groupId: String,
// In case we are using multiple Threads to handle Kafka Messages
val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _))
- logInfo("Starting Kafka Consumer Stream with group: " + groupId)
- logInfo("Initial offsets: " + initialOffsets.toString)
+ logInfo("Starting Kafka Consumer Stream with group: " + kafkaParams("groupid"))
- // Zookeper connection properties
+ // Kafka connection properties
val props = new Properties()
- props.put("zk.connect", zkQuorum)
- props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString)
- props.put("groupid", groupId)
+ kafkaParams.foreach(param => props.put(param._1, param._2))
// Create the connection to the cluster
- logInfo("Connecting to Zookeper: " + zkQuorum)
+ logInfo("Connecting to Zookeper: " + kafkaParams("zk.connect"))
val consumerConfig = new ConsumerConfig(props)
- consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector]
- logInfo("Connected to " + zkQuorum)
+ consumerConnector = Consumer.create(consumerConfig)
+ logInfo("Connected to " + kafkaParams("zk.connect"))
- // If specified, set the topic offset
- setOffsets(initialOffsets)
+    // When autooffset.reset is defined, it is our responsibility to try to delete the
+    // consumer group's zk node.
+ if (kafkaParams.contains("autooffset.reset")) {
+ tryZookeeperConsumerGroupCleanup(kafkaParams("zk.connect"), kafkaParams("groupid"))
+ }
// Create Threads for each Topic/Message Stream we are listening
- val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder())
+ val decoder = manifest[D].erasure.newInstance.asInstanceOf[Decoder[T]]
+ val topicMessageStreams = consumerConnector.createMessageStreams(topics, decoder)
// Start the messages handler for each partition
topicMessageStreams.values.foreach { streams =>
streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) }
}
-
- }
-
- // Overwrites the offets in Zookeper.
- private def setOffsets(offsets: Map[KafkaPartitionKey, Long]) {
- offsets.foreach { case(key, offset) =>
- val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic)
- val partitionName = key.brokerId + "-" + key.partId
- updatePersistentPath(consumerConnector.zkClient,
- topicDirs.consumerOffsetDir + "/" + partitionName, offset.toString)
- }
}
// Handles Kafka Messages
- private class MessageHandler(stream: KafkaStream[String]) extends Runnable {
+ private class MessageHandler[T: ClassManifest](stream: KafkaStream[T]) extends Runnable {
def run() {
logInfo("Starting MessageHandler.")
- stream.takeWhile { msgAndMetadata =>
+ for (msgAndMetadata <- stream) {
blockGenerator += msgAndMetadata.message
- // Keep on handling messages
-
- true
}
}
}
+
+  // It is our responsibility to delete the consumer group when specifying autooffset.reset, because
+  // Kafka 0.7.2 only honors this param when the group is not already in zookeeper.
+  //
+  // The Kafka high-level consumer doesn't currently expose setting offsets; this is a trick copied from Kafka's
+  // ConsoleConsumer. See the code related to 'autooffset.reset' when it is set to 'smallest'/'largest':
+ // https://github.com/apache/kafka/blob/0.7.2/core/src/main/scala/kafka/consumer/ConsoleConsumer.scala
+ private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) {
+ try {
+ val dir = "/consumers/" + groupId
+ logInfo("Cleaning up temporary zookeeper data under " + dir + ".")
+ val zk = new ZkClient(zkUrl, 30*1000, 30*1000, ZKStringSerializer)
+ zk.deleteRecursive(dir)
+ zk.close()
+ } catch {
+ case _ => // swallow
+ }
+ }
}
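
For illustration, the cleanup simply removes the consumer group's subtree at
/consumers/<groupId> so that Kafka 0.7.2 will honor autooffset.reset on the
next connect. A rough Python analogue using kazoo - an assumption made for
illustration only, since the patch itself uses ZkClient from Scala - with the
same swallow-everything error handling:

    from kazoo.client import KazooClient

    def try_consumer_group_cleanup(zk_url, group_id):
        # Delete /consumers/<groupId>; best effort, errors are swallowed
        # just like the 'case _ =>' above.
        try:
            zk = KazooClient(hosts=zk_url, timeout=30.0)
            zk.start()
            zk.delete("/consumers/" + group_id, recursive=True)
            zk.stop()
        except Exception:
            pass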
diff --git a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala
index 52b9968f6e..3b9cd0e89b 100644
--- a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala
@@ -201,7 +201,7 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging
case class Block(id: String, iterator: Iterator[T], metadata: Any = null)
val clock = new SystemClock()
- val blockInterval = 200L
+ val blockInterval = System.getProperty("spark.streaming.blockInterval", "200").toLong
val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer)
val blockStorageLevel = storageLevel
val blocksForPushing = new ArrayBlockingQueue[Block](1000)
diff --git a/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala
index c697498862..ff7a58be45 100644
--- a/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala
@@ -3,34 +3,45 @@ package spark.streaming.dstream
import spark._
import spark.streaming._
import storage.StorageLevel
-
import twitter4j._
-import twitter4j.auth.BasicAuthorization
+import twitter4j.auth.Authorization
+import java.util.prefs.Preferences
+import twitter4j.conf.ConfigurationBuilder
+import twitter4j.conf.PropertyConfiguration
+import twitter4j.auth.OAuthAuthorization
+import twitter4j.auth.AccessToken
/* A stream of Twitter statuses, potentially filtered by one or more keywords.
*
-* @constructor create a new Twitter stream using the supplied username and password to authenticate.
+* @constructor create a new Twitter stream using the supplied Twitter4J authentication credentials.
* An optional set of string filters can be used to restrict the set of tweets. The Twitter API is
* such that this may return a sampled subset of all tweets during each interval.
+*
+* If no Authorization object is provided, initializes OAuth authorization using the system
+* properties twitter4j.oauth.consumerKey, .consumerSecret, .accessToken and .accessTokenSecret.
*/
private[streaming]
class TwitterInputDStream(
@transient ssc_ : StreamingContext,
- username: String,
- password: String,
+ twitterAuth: Option[Authorization],
filters: Seq[String],
storageLevel: StorageLevel
) extends NetworkInputDStream[Status](ssc_) {
+
+ private def createOAuthAuthorization(): Authorization = {
+ new OAuthAuthorization(new ConfigurationBuilder().build())
+ }
+ private val authorization = twitterAuth.getOrElse(createOAuthAuthorization())
+
override def getReceiver(): NetworkReceiver[Status] = {
- new TwitterReceiver(username, password, filters, storageLevel)
+ new TwitterReceiver(authorization, filters, storageLevel)
}
}
private[streaming]
class TwitterReceiver(
- username: String,
- password: String,
+ twitterAuth: Authorization,
filters: Seq[String],
storageLevel: StorageLevel
) extends NetworkReceiver[Status] {
@@ -40,8 +51,7 @@ class TwitterReceiver(
protected override def onStart() {
blockGenerator.start()
- twitterStream = new TwitterStreamFactory()
- .getInstance(new BasicAuthorization(username, password))
+ twitterStream = new TwitterStreamFactory().getInstance(twitterAuth)
twitterStream.addListener(new StatusListener {
def onStatus(status: Status) = {
blockGenerator += status
diff --git a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala
index 3db1eaa834..b3daa5a91b 100644
--- a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala
+++ b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala
@@ -160,6 +160,7 @@ object MasterFailureTest extends Logging {
// Setup the streaming computation with the given operation
System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
var ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration, null, Nil, Map())
ssc.checkpoint(checkpointDir.toString)
val inputStream = ssc.textFileStream(testDir.toString)
@@ -206,6 +207,7 @@ object MasterFailureTest extends Logging {
// (iii) It's not timed out yet
System.clearProperty("spark.streaming.clock")
System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
ssc.start()
val startTime = System.currentTimeMillis()
while (!killed && !isLastOutputGenerated && !isTimedOut) {
@@ -358,13 +360,16 @@ class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long)
// Write the data to a local file and then move it to the target test directory
val localFile = new File(localTestDir, (i+1).toString)
val hadoopFile = new Path(testDir, (i+1).toString)
+ val tempHadoopFile = new Path(testDir, ".tmp_" + (i+1).toString)
FileUtils.writeStringToFile(localFile, input(i).toString + "\n")
var tries = 0
var done = false
while (!done && tries < maxTries) {
tries += 1
try {
- fs.copyFromLocalFile(new Path(localFile.toString), hadoopFile)
+ // fs.copyFromLocalFile(new Path(localFile.toString), hadoopFile)
+ fs.copyFromLocalFile(new Path(localFile.toString), tempHadoopFile)
+ fs.rename(tempHadoopFile, hadoopFile)
done = true
} catch {
case ioe: IOException => {
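Copying straight into the directory watched by textFileStream can expose a half-written file to the stream; the diff instead copies to a dot-prefixed temporary path (which the file stream's default filter ignores) and then renames, so the file becomes visible atomically. The old direct copy is left commented out above. A self-contained sketch of the pattern with the plain Hadoop FileSystem API (paths are hypothetical):

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.{FileSystem, Path}

    val fs = FileSystem.get(new Configuration())
    val target = new Path("/stream/input/1")
    val temp   = new Path("/stream/input/.tmp_1") // hidden while being written

    fs.copyFromLocalFile(new Path("/tmp/1"), temp) // slow, partially visible
    fs.rename(temp, target)                        // appears whole to readers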
diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala
index d8b987ec86..bd0b0e74c1 100644
--- a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala
+++ b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala
@@ -5,7 +5,7 @@ import spark.util.{RateLimitedOutputStream, IntParam}
import java.net.ServerSocket
import spark.{Logging, KryoSerializer}
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
-import io.Source
+import scala.io.Source
import java.io.IOException
/**
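The old form, import io.Source, is a relative import: inside a spark.* package it can resolve against an enclosing io package member rather than scala.io, so the fully qualified form is safer. A small illustration with hypothetical package names:

    package spark {
      package io { class Source } // a sibling io package...
      package util {
        // import io.Source       // ...would resolve to spark.io.Source here
        import scala.io.Source    // unambiguous
        object FirstLine {
          def apply(path: String): String =
            Source.fromFile(path).getLines().next()
        }
      }
    }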
diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java
index 3bed500f73..4cf10582a9 100644
--- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java
@@ -4,6 +4,7 @@ import com.google.common.base.Optional;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.io.Files;
+import kafka.serializer.StringDecoder;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.junit.After;
import org.junit.Assert;
@@ -23,7 +24,6 @@ import spark.streaming.api.java.JavaPairDStream;
import spark.streaming.api.java.JavaStreamingContext;
import spark.streaming.JavaTestUtils;
import spark.streaming.JavaCheckpointTestUtils;
-import spark.streaming.dstream.KafkaPartitionKey;
import spark.streaming.InputStreamsSuite;
import java.io.*;
@@ -1203,10 +1203,14 @@ public class JavaAPISuite implements Serializable {
@Test
public void testKafkaStream() {
HashMap<String, Integer> topics = Maps.newHashMap();
- HashMap<KafkaPartitionKey, Long> offsets = Maps.newHashMap();
JavaDStream test1 = ssc.kafkaStream("localhost:12345", "group", topics);
- JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics, offsets);
- JavaDStream test3 = ssc.kafkaStream("localhost:12345", "group", topics, offsets,
+ JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics,
+ StorageLevel.MEMORY_AND_DISK());
+
+ HashMap<String, String> kafkaParams = Maps.newHashMap();
+ kafkaParams.put("zk.connect","localhost:12345");
+ kafkaParams.put("groupid","consumer-group");
+ JavaDStream test3 = ssc.kafkaStream(String.class, StringDecoder.class, kafkaParams, topics,
StorageLevel.MEMORY_AND_DISK());
}
@@ -1263,7 +1267,7 @@ public class JavaAPISuite implements Serializable {
@Test
public void testTwitterStream() {
String[] filters = new String[] { "good", "bad", "ugly" };
- JavaDStream test = ssc.twitterStream("username", "password", filters, StorageLevel.MEMORY_ONLY());
+ JavaDStream test = ssc.twitterStream(filters, StorageLevel.MEMORY_ONLY());
}
@Test
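The Java test now reflects the reworked Kafka API: the overloads taking an explicit per-partition offsets map (and the KafkaPartitionKey import) are gone, consumer settings travel in a kafkaParams map using the 0.7-era Kafka property names, and the payload decoder is supplied as a class. The Scala side looks like the test added to InputStreamsSuite below (host and group values are placeholders):

    import spark.storage.StorageLevel
    import kafka.serializer.StringDecoder

    val topics = Map("my-topic" -> 1) // topic -> number of consumer threads
    val kafkaParams = Map(
      "zk.connect" -> "localhost:12345",
      "groupid"    -> "consumer-group")
    val stream = ssc.kafkaStream[String, StringDecoder](
      kafkaParams, topics, StorageLevel.MEMORY_AND_DISK)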
diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala
index cf2ed8b1d4..565089a853 100644
--- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala
@@ -15,6 +15,7 @@ class BasicOperationsSuite extends TestSuiteBase {
after {
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
}
test("map") {
@@ -92,9 +93,9 @@ class BasicOperationsSuite extends TestSuiteBase {
test("count") {
testOperation(
- Seq(1 to 1, 1 to 2, 1 to 3, 1 to 4),
+ Seq(Seq(), 1 to 1, 1 to 2, 1 to 3, 1 to 4),
(s: DStream[Int]) => s.count(),
- Seq(Seq(1L), Seq(2L), Seq(3L), Seq(4L))
+ Seq(Seq(0L), Seq(1L), Seq(2L), Seq(3L), Seq(4L))
)
}
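The updated expectation pins down that count() now emits 0L for an empty batch instead of producing no output. One way to get that behavior, sketched under the assumption that each batch is seeded with a zero before reducing (a hedged reconstruction, not necessarily the exact implementation):

    import spark.SparkContext
    import spark.streaming.DStream

    def countPerBatch(sc: SparkContext, s: DStream[Int]): DStream[Long] =
      s.map(_ => 1L)
       .transform(rdd => rdd.union(sc.makeRDD(Seq(0L), 1))) // never empty
       .reduce(_ + _)                                       // 0L for empty input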
diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
index 143a26d911..141842c380 100644
--- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
@@ -35,6 +35,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
}
var ssc: StreamingContext = null
@@ -329,6 +330,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
)
ssc = new StreamingContext(checkpointDir)
System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
ssc.start()
val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches)
// the first element will be re-processed data of the last batch before restart
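spark.driver.port and spark.hostPort are process-wide system properties set when a context starts, so every suite's cleanup (and the in-test restart above) must clear both, or the next StreamingContext can pick up stale values when Akka rebinds. The same two-line cleanup recurs across these suites; a hypothetical helper one could factor out:

    // Not in the patch, which repeats the two calls inline; shown only to
    // name the recurring pattern.
    def clearSparkPortProperties(): Unit = {
      System.clearProperty("spark.driver.port")
      System.clearProperty("spark.hostPort")
    }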
diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
index 67dca2ac31..b024fc9dcc 100644
--- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
@@ -41,6 +41,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
after {
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
}
@@ -242,6 +243,17 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
assert(output(i) === expectedOutput(i))
}
}
+
+ test("kafka input stream") {
+ val ssc = new StreamingContext(master, framework, batchDuration)
+ val topics = Map("my-topic" -> 1)
+ val test1 = ssc.kafkaStream("localhost:12345", "group", topics)
+ val test2 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK)
+
+ // Test specifying decoder
+ val kafkaParams = Map("zk.connect"->"localhost:12345","groupid"->"consumer-group")
+ val test3 = ssc.kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK)
+ }
}
diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
index 1b66f3bda2..80d827706f 100644
--- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
@@ -16,6 +16,7 @@ class WindowOperationsSuite extends TestSuiteBase {
after {
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
}
val largerSlideInput = Seq(