aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoseph E. Gonzalez <joseph.e.gonzalez@gmail.com>2013-10-14 22:56:42 -0700
committerJoseph E. Gonzalez <joseph.e.gonzalez@gmail.com>2013-10-14 22:56:42 -0700
commitef7c369092b33a9fcf4be00d98521b92321d63c3 (patch)
tree9fcb38b7f5fc667b5cd28939b3c26f6fa0b254bf
parent67bb39c54b25e1b13edad01c25f7183e95f8a400 (diff)
parent3b11f43e36e2aca2346db7542c52fcbbeee70da2 (diff)
downloadspark-ef7c369092b33a9fcf4be00d98521b92321d63c3.tar.gz
spark-ef7c369092b33a9fcf4be00d98521b92321d63c3.tar.bz2
spark-ef7c369092b33a9fcf4be00d98521b92321d63c3.zip
merged with upstream changes
-rw-r--r--README.md1
-rw-r--r--assembly/pom.xml20
-rw-r--r--bagel/pom.xml10
-rwxr-xr-xbin/stop-slaves.sh2
-rw-r--r--core/pom.xml21
-rw-r--r--core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java3
-rw-r--r--core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java6
-rw-r--r--core/src/main/scala/org/apache/spark/Aggregator.scala49
-rw-r--r--core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/CacheManager.scala27
-rw-r--r--core/src/main/scala/org/apache/spark/MapOutputTracker.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala81
-rw-r--r--core/src/main/scala/org/apache/spark/TaskEndReason.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala13
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala29
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala34
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala420
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/client/Client.scala84
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala53
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala90
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala45
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/Master.scala228
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala46
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala53
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala26
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala203
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala42
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala136
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala85
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala13
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala175
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala26
-rw-r--r--core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala18
-rw-r--r--core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala22
-rw-r--r--core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala27
-rw-r--r--core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala22
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala120
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDD.scala38
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Pool.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/Pool.scala)10
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/Schedulable.scala)8
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulableBuilder.scala)101
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingAlgorithm.scala)2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingMode.scala)2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala18
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/TaskDescription.scala)2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/TaskInfo.scala)2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/TaskLocality.scala)2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/TaskSetManager.scala)5
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala57
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala171
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala124
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala)20
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackend.scala)22
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala21
-rw-r--r--core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala30
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockException.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala24
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockId.scala96
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManager.scala155
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala21
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockMessage.scala38
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockStore.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/storage/DiskStore.scala37
-rw-r--r--core/src/main/scala/org/apache/spark/storage/MemoryStore.scala48
-rw-r--r--core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/storage/StorageUtils.scala47
-rw-r--r--core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala23
-rw-r--r--core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala230
-rw-r--r--core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala36
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala13
-rw-r--r--core/src/test/scala/org/apache/spark/CacheManagerSuite.scala12
-rw-r--r--core/src/test/scala/org/apache/spark/CheckpointSuite.scala6
-rw-r--r--core/src/test/scala/org/apache/spark/DistributedSuite.scala29
-rw-r--r--core/src/test/scala/org/apache/spark/ThreadingSuite.scala45
-rw-r--r--core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala7
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala39
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala19
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala30
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala15
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala58
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala114
-rw-r--r--core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala21
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala114
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala102
-rw-r--r--core/src/test/scala/org/apache/spark/ui/UISuite.scala7
-rw-r--r--core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala154
-rw-r--r--core/src/test/scala/org/apache/spark/util/UtilsSuite.scala11
-rw-r--r--docker/README.md5
-rwxr-xr-xdocker/build22
-rw-r--r--docker/spark-test/README.md10
-rw-r--r--docker/spark-test/base/Dockerfile38
-rwxr-xr-xdocker/spark-test/build22
-rw-r--r--docker/spark-test/master/Dockerfile21
-rwxr-xr-xdocker/spark-test/master/default_cmd22
-rw-r--r--docker/spark-test/worker/Dockerfile22
-rwxr-xr-xdocker/spark-test/worker/default_cmd22
-rw-r--r--docs/_config.yml4
-rwxr-xr-xdocs/_layouts/global.html4
-rw-r--r--docs/mllib-guide.md24
-rw-r--r--docs/python-programming-guide.md2
-rw-r--r--docs/running-on-yarn.md10
-rw-r--r--docs/spark-standalone.md75
-rw-r--r--docs/streaming-programming-guide.md5
-rw-r--r--docs/tuning.md2
-rw-r--r--ec2/README2
-rwxr-xr-xec2/spark_ec2.py118
-rw-r--r--examples/pom.xml32
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala2
-rwxr-xr-xmake-distribution.sh2
-rw-r--r--mllib/pom.xml10
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala199
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java85
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala75
-rw-r--r--pom.xml21
-rw-r--r--project/SparkBuild.scala21
-rw-r--r--python/pyspark/rdd.py70
-rw-r--r--python/pyspark/serializers.py4
-rw-r--r--python/pyspark/shell.py2
-rw-r--r--repl-bin/pom.xml12
-rw-r--r--repl/pom.xml22
-rw-r--r--repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala2
-rwxr-xr-xspark-class12
-rw-r--r--streaming/pom.xml11
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala11
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala14
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala4
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala4
-rw-r--r--tools/pom.xml10
-rw-r--r--yarn/pom.xml8
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala55
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala167
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala29
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala63
183 files changed, 4890 insertions, 1255 deletions
diff --git a/README.md b/README.md
index c522c5d8a0..139bdc070c 100644
--- a/README.md
+++ b/README.md
@@ -114,4 +114,3 @@ submitting any copyrighted material via pull request, email, or other means
you agree to license the material under the project's open source license and
warrant that you have the legal authority to do so.
-
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 808a829e19..09df8c1fd7 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -21,12 +21,12 @@
<parent>
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent</artifactId>
- <version>0.8.0-SNAPSHOT</version>
+ <version>0.9.0-incubating-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-assembly</artifactId>
+ <artifactId>spark-assembly_2.9.3</artifactId>
<name>Spark Project Assembly</name>
<url>http://spark.incubator.apache.org/</url>
@@ -41,27 +41,27 @@
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-core</artifactId>
+ <artifactId>spark-core_2.9.3</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-bagel</artifactId>
+ <artifactId>spark-bagel_2.9.3</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-mllib</artifactId>
+ <artifactId>spark-mllib_2.9.3</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-repl</artifactId>
+ <artifactId>spark-repl_2.9.3</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-streaming</artifactId>
+ <artifactId>spark-streaming_2.9.3</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
@@ -104,13 +104,13 @@
</goals>
<configuration>
<transformers>
- <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
+ <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer" />
<transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
<resource>META-INF/services/org.apache.hadoop.fs.FileSystem</resource>
</transformer>
</transformers>
<transformers>
- <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
+ <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer" />
<transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
<resource>reference.conf</resource>
</transformer>
@@ -128,7 +128,7 @@
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-yarn</artifactId>
+ <artifactId>spark-yarn_2.9.3</artifactId>
<version>${project.version}</version>
</dependency>
</dependencies>
diff --git a/bagel/pom.xml b/bagel/pom.xml
index 51173c32b2..0e552c880f 100644
--- a/bagel/pom.xml
+++ b/bagel/pom.xml
@@ -21,12 +21,12 @@
<parent>
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent</artifactId>
- <version>0.8.0-SNAPSHOT</version>
+ <version>0.9.0-incubating-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-bagel</artifactId>
+ <artifactId>spark-bagel_2.9.3</artifactId>
<packaging>jar</packaging>
<name>Spark Project Bagel</name>
<url>http://spark.incubator.apache.org/</url>
@@ -34,7 +34,7 @@
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-core</artifactId>
+ <artifactId>spark-core_2.9.3</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
@@ -43,12 +43,12 @@
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
- <artifactId>scalatest_${scala.version}</artifactId>
+ <artifactId>scalatest_2.9.3</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalacheck</groupId>
- <artifactId>scalacheck_${scala.version}</artifactId>
+ <artifactId>scalacheck_2.9.3</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
diff --git a/bin/stop-slaves.sh b/bin/stop-slaves.sh
index 03e416a132..fcb8555d4e 100755
--- a/bin/stop-slaves.sh
+++ b/bin/stop-slaves.sh
@@ -17,8 +17,6 @@
# limitations under the License.
#
-# Starts the master on the machine this script is executed on.
-
bin=`dirname "$0"`
bin=`cd "$bin"; pwd`
diff --git a/core/pom.xml b/core/pom.xml
index 14cd520aaf..8621d257e5 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -21,12 +21,12 @@
<parent>
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent</artifactId>
- <version>0.8.0-SNAPSHOT</version>
+ <version>0.9.0-incubating-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-core</artifactId>
+ <artifactId>spark-core_2.9.3</artifactId>
<packaging>jar</packaging>
<name>Spark Project Core</name>
<url>http://spark.incubator.apache.org/</url>
@@ -39,7 +39,6 @@
<dependency>
<groupId>net.java.dev.jets3t</groupId>
<artifactId>jets3t</artifactId>
- <version>0.7.1</version>
</dependency>
<dependency>
<groupId>org.apache.avro</groupId>
@@ -50,6 +49,10 @@
<artifactId>avro-ipc</artifactId>
</dependency>
<dependency>
+ <groupId>org.apache.zookeeper</groupId>
+ <artifactId>zookeeper</artifactId>
+ </dependency>
+ <dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-server</artifactId>
</dependency>
@@ -162,12 +165,12 @@
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
- <artifactId>scalatest_${scala.version}</artifactId>
+ <artifactId>scalatest_2.9.3</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalacheck</groupId>
- <artifactId>scalacheck_${scala.version}</artifactId>
+ <artifactId>scalacheck_2.9.3</artifactId>
<scope>test</scope>
</dependency>
<dependency>
@@ -202,14 +205,14 @@
<configuration>
<exportAntProperties>true</exportAntProperties>
<tasks>
- <property name="spark.classpath" refid="maven.test.classpath"/>
- <property environment="env"/>
+ <property name="spark.classpath" refid="maven.test.classpath" />
+ <property environment="env" />
<fail message="Please set the SCALA_HOME (or SCALA_LIBRARY_PATH if scala is on the path) environment variables and retry.">
<condition>
<not>
<or>
- <isset property="env.SCALA_HOME"/>
- <isset property="env.SCALA_LIBRARY_PATH"/>
+ <isset property="env.SCALA_HOME" />
+ <isset property="env.SCALA_LIBRARY_PATH" />
</or>
</not>
</condition>
diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java
index c4aa2669e0..8a09210245 100644
--- a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java
+++ b/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java
@@ -21,6 +21,7 @@ import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundByteHandlerAdapter;
+import org.apache.spark.storage.BlockId;
abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
@@ -33,7 +34,7 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
}
public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header);
- public abstract void handleError(String blockId);
+ public abstract void handleError(BlockId blockId);
@Override
public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) {
diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java
index d3d57a0255..cfd8132891 100644
--- a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java
+++ b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java
@@ -24,6 +24,7 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundMessageHandlerAdapter;
import io.netty.channel.DefaultFileRegion;
+import org.apache.spark.storage.BlockId;
class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
@@ -34,8 +35,9 @@ class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
}
@Override
- public void messageReceived(ChannelHandlerContext ctx, String blockId) {
- String path = pResolver.getAbsolutePath(blockId);
+ public void messageReceived(ChannelHandlerContext ctx, String blockIdString) {
+ BlockId blockId = BlockId.apply(blockIdString);
+ String path = pResolver.getAbsolutePath(blockId.name());
// if getFilePath returns null, close the channel
if (path == null) {
//ctx.close();
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index 3ef402926e..1a2ec55876 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -17,43 +17,42 @@
package org.apache.spark
-import java.util.{HashMap => JHashMap}
+import org.apache.spark.util.AppendOnlyMap
-import scala.collection.JavaConversions._
-
-/** A set of functions used to aggregate data.
- *
- * @param createCombiner function to create the initial value of the aggregation.
- * @param mergeValue function to merge a new value into the aggregation result.
- * @param mergeCombiners function to merge outputs from multiple mergeValue function.
- */
+/**
+ * A set of functions used to aggregate data.
+ *
+ * @param createCombiner function to create the initial value of the aggregation.
+ * @param mergeValue function to merge a new value into the aggregation result.
+ * @param mergeCombiners function to merge outputs from multiple mergeValue function.
+ */
case class Aggregator[K, V, C] (
createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C) {
def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = {
- val combiners = new JHashMap[K, C]
- for (kv <- iter) {
- val oldC = combiners.get(kv._1)
- if (oldC == null) {
- combiners.put(kv._1, createCombiner(kv._2))
- } else {
- combiners.put(kv._1, mergeValue(oldC, kv._2))
- }
+ val combiners = new AppendOnlyMap[K, C]
+ var kv: Product2[K, V] = null
+ val update = (hadValue: Boolean, oldValue: C) => {
+ if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
+ }
+ while (iter.hasNext) {
+ kv = iter.next()
+ combiners.changeValue(kv._1, update)
}
combiners.iterator
}
def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = {
- val combiners = new JHashMap[K, C]
- iter.foreach { case(k, c) =>
- val oldC = combiners.get(k)
- if (oldC == null) {
- combiners.put(k, c)
- } else {
- combiners.put(k, mergeCombiners(oldC, c))
- }
+ val combiners = new AppendOnlyMap[K, C]
+ var kc: (K, C) = null
+ val update = (hadValue: Boolean, oldValue: C) => {
+ if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2
+ }
+ while (iter.hasNext) {
+ kc = iter.next()
+ combiners.changeValue(kc._1, update)
}
combiners.iterator
}
diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
index 908ff56a6b..f8af6b0fbe 100644
--- a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.HashMap
import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics}
import org.apache.spark.serializer.Serializer
-import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
import org.apache.spark.util.CompletionIterator
@@ -45,12 +45,12 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
}
- val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map {
+ val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
case (address, splits) =>
- (address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2)))
+ (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
}
- def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[T] = {
+ def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = {
val blockId = blockPair._1
val blockOption = blockPair._2
blockOption match {
@@ -58,9 +58,8 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
block.asInstanceOf[Iterator[T]]
}
case None => {
- val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
blockId match {
- case regex(shufId, mapId, _) =>
+ case ShuffleBlockId(shufId, mapId, _) =>
val address = statuses(mapId.toInt)._1
throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
case _ =>
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index 68b99ca125..221bb68c61 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -18,7 +18,7 @@
package org.apache.spark
import scala.collection.mutable.{ArrayBuffer, HashSet}
-import org.apache.spark.storage.{BlockManager, StorageLevel}
+import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, RDDBlockId}
import org.apache.spark.rdd.RDD
@@ -26,28 +26,29 @@ import org.apache.spark.rdd.RDD
sure a node doesn't load two copies of an RDD at once.
*/
private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
- private val loading = new HashSet[String]
+
+ /** Keys of RDD splits that are being computed/loaded. */
+ private val loading = new HashSet[RDDBlockId]()
/** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */
def getOrCompute[T](rdd: RDD[T], split: Partition, context: TaskContext, storageLevel: StorageLevel)
: Iterator[T] = {
- val key = "rdd_%d_%d".format(rdd.id, split.index)
- logInfo("Cache key is " + key)
+ val key = RDDBlockId(rdd.id, split.index)
+ logDebug("Looking for partition " + key)
blockManager.get(key) match {
- case Some(cachedValues) =>
- // Partition is in cache, so just return its values
- logInfo("Found partition in cache!")
- return cachedValues.asInstanceOf[Iterator[T]]
+ case Some(values) =>
+ // Partition is already materialized, so just return its values
+ return values.asInstanceOf[Iterator[T]]
case None =>
// Mark the split as loading (unless someone else marks it first)
loading.synchronized {
if (loading.contains(key)) {
- logInfo("Loading contains " + key + ", waiting...")
+ logInfo("Another thread is loading %s, waiting for it to finish...".format(key))
while (loading.contains(key)) {
try {loading.wait()} catch {case _ : Throwable =>}
}
- logInfo("Loading no longer contains " + key + ", so returning cached result")
+ logInfo("Finished waiting for %s".format(key))
// See whether someone else has successfully loaded it. The main way this would fail
// is for the RDD-level cache eviction policy if someone else has loaded the same RDD
// partition but we didn't want to make space for it. However, that case is unlikely
@@ -57,7 +58,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
case Some(values) =>
return values.asInstanceOf[Iterator[T]]
case None =>
- logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
+ logInfo("Whoever was loading %s failed; we'll try it ourselves".format(key))
loading.add(key)
}
} else {
@@ -66,13 +67,13 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
}
try {
// If we got here, we have to load the split
- logInfo("Computing partition " + split)
+ logInfo("Partition %s not found, computing it".format(key))
val computedValues = rdd.computeOrReadCheckpoint(split, context)
// Persist the result, so long as the task is not running locally
if (context.runningLocally) { return computedValues }
val elements = new ArrayBuffer[Any]
elements ++= computedValues
- blockManager.put(key, elements, storageLevel, true)
+ blockManager.put(key, elements, storageLevel, tellMaster = true)
return elements.iterator.asInstanceOf[Iterator[T]]
} finally {
loading.synchronized {
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index ae7cf2a893..1e3f1ebfaf 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -32,7 +32,7 @@ import akka.util.Duration
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.util.{Utils, MetadataCleaner, TimeStampedHashMap}
+import org.apache.spark.util.{MetadataCleanerType, Utils, MetadataCleaner, TimeStampedHashMap}
private[spark] sealed trait MapOutputTrackerMessage
@@ -71,7 +71,7 @@ private[spark] class MapOutputTracker extends Logging {
var cacheEpoch = epoch
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
- val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)
+ val metadataCleaner = new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup)
// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index e22b56427e..f3723a4f9d 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -51,17 +51,22 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor
import org.apache.mesos.MesosNativeLibrary
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.LocalSparkCluster
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend,
- ClusterScheduler, Schedulable, SchedulingMode}
+ ClusterScheduler}
import org.apache.spark.scheduler.local.LocalScheduler
-import org.apache.spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
+import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
import org.apache.spark.storage.{StorageUtils, BlockManagerSource}
import org.apache.spark.ui.SparkUI
-import org.apache.spark.util.{ClosureCleaner, Utils, MetadataCleaner, TimeStampedHashMap}
+import org.apache.spark.util._
+import org.apache.spark.scheduler.StageInfo
+import org.apache.spark.storage.RDDInfo
+import org.apache.spark.storage.StorageStatus
+import scala.Some
import org.apache.spark.scheduler.StageInfo
import org.apache.spark.storage.RDDInfo
import org.apache.spark.storage.StorageStatus
@@ -83,9 +88,11 @@ class SparkContext(
val sparkHome: String = null,
val jars: Seq[String] = Nil,
val environment: Map[String, String] = Map(),
- // This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc) too.
- // This is typically generated from InputFormatInfo.computePreferredLocations .. host, set of data-local splits on host
- val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map())
+ // This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc)
+ // too. This is typically generated from InputFormatInfo.computePreferredLocations .. host, set
+ // of data-local splits on host
+ val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] =
+ scala.collection.immutable.Map())
extends Logging {
// Ensure logging is initialized before we spawn any threads
@@ -116,7 +123,7 @@ class SparkContext(
// Keeps track of all persisted RDDs
private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
- private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup)
+ private[spark] val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup)
// Initalize the Spark UI
private[spark] val ui = new SparkUI(this)
@@ -145,7 +152,7 @@ class SparkContext(
}
// Create and start the scheduler
- private var taskScheduler: TaskScheduler = {
+ private[spark] var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
// Regular expression for local[N, maxRetries], used in tests with failing tasks
@@ -153,7 +160,7 @@ class SparkContext(
// Regular expression for simulating a Spark cluster of [N, cores, memory] locally
val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
// Regular expression for connecting to Spark deploy clusters
- val SPARK_REGEX = """(spark://.*)""".r
+ val SPARK_REGEX = """spark://(.*)""".r
//Regular expression for connection to Mesos cluster
val MESOS_REGEX = """(mesos://.*)""".r
@@ -169,7 +176,8 @@ class SparkContext(
case SPARK_REGEX(sparkUrl) =>
val scheduler = new ClusterScheduler(this)
- val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName)
+ val masterUrls = sparkUrl.split(",").map("spark://" + _)
+ val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName)
scheduler.initialize(backend)
scheduler
@@ -185,8 +193,8 @@ class SparkContext(
val scheduler = new ClusterScheduler(this)
val localCluster = new LocalSparkCluster(
numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
- val sparkUrl = localCluster.start()
- val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName)
+ val masterUrls = localCluster.start()
+ val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName)
scheduler.initialize(backend)
backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
localCluster.stop()
@@ -228,7 +236,7 @@ class SparkContext(
}
taskScheduler.start()
- @volatile private var dagScheduler = new DAGScheduler(taskScheduler)
+ @volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler)
dagScheduler.start()
ui.start()
@@ -238,7 +246,8 @@ class SparkContext(
val env = SparkEnv.get
val conf = env.hadoop.newConfiguration()
// Explicitly check for S3 environment variables
- if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
+ if (System.getenv("AWS_ACCESS_KEY_ID") != null &&
+ System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
conf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
conf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
@@ -256,7 +265,9 @@ class SparkContext(
private[spark] var checkpointDir: Option[String] = None
// Thread Local variable that can be used by users to pass information down the stack
- private val localProperties = new ThreadLocal[Properties]
+ private val localProperties = new InheritableThreadLocal[Properties] {
+ override protected def childValue(parent: Properties): Properties = new Properties(parent)
+ }
def initLocalProperties() {
localProperties.set(new Properties())
@@ -273,6 +284,9 @@ class SparkContext(
}
}
+ def getLocalProperty(key: String): String =
+ Option(localProperties.get).map(_.getProperty(key)).getOrElse(null)
+
/** Set a human readable description of the current job. */
def setJobDescription(value: String) {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
@@ -321,7 +335,7 @@ class SparkContext(
}
/**
- * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf giving its InputFormat and any
+ * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and any
* other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable,
* etc).
*/
@@ -332,6 +346,8 @@ class SparkContext(
valueClass: Class[V],
minSplits: Int = defaultMinSplits
): RDD[(K, V)] = {
+ // Add necessary security credentials to the JobConf before broadcasting it.
+ SparkEnv.get.hadoop.addCredentials(conf)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
}
@@ -342,10 +358,18 @@ class SparkContext(
keyClass: Class[K],
valueClass: Class[V],
minSplits: Int = defaultMinSplits
- ) : RDD[(K, V)] = {
- val conf = new JobConf(hadoopConfiguration)
- FileInputFormat.setInputPaths(conf, path)
- new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
+ ): RDD[(K, V)] = {
+ // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it.
+ val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration))
+ val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path)
+ new HadoopRDD(
+ this,
+ confBroadcast,
+ Some(setInputPathsFunc),
+ inputFormatClass,
+ keyClass,
+ valueClass,
+ minSplits)
}
/**
@@ -643,10 +667,21 @@ class SparkContext(
key = uri.getScheme match {
case null | "file" =>
if (env.hadoop.isYarnMode()) {
- logWarning("local jar specified as parameter to addJar under Yarn mode")
- return
+ // In order for this to work on yarn the user must specify the --addjars option to
+ // the client to upload the file into the distributed cache to make it show up in the
+ // current working directory.
+ val fileName = new Path(uri.getPath).getName()
+ try {
+ env.httpFileServer.addJar(new File(fileName))
+ } catch {
+ case e: Exception => {
+ logError("Error adding jar (" + e + "), was the --addJars option used?")
+ throw e
+ }
+ }
+ } else {
+ env.httpFileServer.addJar(new File(uri.getPath))
}
- env.httpFileServer.addJar(new File(uri.getPath))
case _ =>
path
}
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 03bf268863..8466c2a004 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -46,6 +46,10 @@ private[spark] case class ExceptionFailure(
metrics: Option[TaskMetrics])
extends TaskEndReason
-private[spark] case class OtherFailure(message: String) extends TaskEndReason
+/**
+ * The task finished successfully, but the result was lost from the executor's block manager before
+ * it was fetched.
+ */
+private[spark] case object TaskResultLost extends TaskEndReason
-private[spark] case class TaskResultTooBigFailure() extends TaskEndReason
+private[spark] case class OtherFailure(message: String) extends TaskEndReason
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 7e6e691f11..7a3568c5ef 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -68,6 +68,16 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
new JavaRDD(rdd.map(f)(f.returnType()))(f.returnType())
/**
+ * Return a new RDD by applying a function to each partition of this RDD, while tracking the index
+ * of the original partition.
+ */
+ def mapPartitionsWithIndex[R: ClassManifest](
+ f: JFunction2[Int, java.util.Iterator[T], java.util.Iterator[R]],
+ preservesPartitioning: Boolean = false): JavaRDD[R] =
+ new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))),
+ preservesPartitioning))
+
+ /**
* Return a new RDD by applying a function to all elements of this RDD.
*/
def map[R](f: DoubleFunction[T]): JavaDoubleRDD =
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala
index b090c6edf3..2be4e323be 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala
@@ -17,12 +17,13 @@
package org.apache.spark.api.python
-import org.apache.spark.Partitioner
import java.util.Arrays
+
+import org.apache.spark.Partitioner
import org.apache.spark.util.Utils
/**
- * A [[org.apache.spark.Partitioner]] that performs handling of byte arrays, for use by the Python API.
+ * A [[org.apache.spark.Partitioner]] that performs handling of long-valued keys, for use by the Python API.
*
* Stores the unique id() of the Python-side partitioning function so that it is incorporated into
* equality comparisons. Correctness requires that the id is a unique identifier for the
@@ -30,6 +31,7 @@ import org.apache.spark.util.Utils
* function). This can be ensured by using the Python id() function and maintaining a reference
* to the Python partitioning function so that its id() is not reused.
*/
+
private[spark] class PythonPartitioner(
override val numPartitions: Int,
val pyPartitionFunctionId: Long)
@@ -37,7 +39,9 @@ private[spark] class PythonPartitioner(
override def getPartition(key: Any): Int = key match {
case null => 0
- case key: Array[Byte] => Utils.nonNegativeMod(Arrays.hashCode(key), numPartitions)
+ // we don't trust the Python partition function to return valid partition ID's so
+ // let's do a modulo numPartitions in any case
+ case key: Long => Utils.nonNegativeMod(key.toInt, numPartitions)
case _ => Utils.nonNegativeMod(key.hashCode(), numPartitions)
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index ccd3833964..1f8ad688a6 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -187,14 +187,14 @@ private class PythonException(msg: String) extends Exception(msg)
* This is used by PySpark's shuffle operations.
*/
private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
- RDD[(Array[Byte], Array[Byte])](prev) {
+ RDD[(Long, Array[Byte])](prev) {
override def getPartitions = prev.partitions
override def compute(split: Partition, context: TaskContext) =
prev.iterator(split, context).grouped(2).map {
- case Seq(a, b) => (a, b)
+ case Seq(a, b) => (Utils.deserializeLongValue(a), b)
case x => throw new SparkException("PairwiseRDD: unexpected value: " + x)
}
- val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
+ val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this)
}
private[spark] object PythonRDD {
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala
index 93e7815ab5..b6c484bfe1 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala
@@ -26,7 +26,7 @@ import scala.collection.mutable.{ListBuffer, Map, Set}
import scala.math
import org.apache.spark._
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.Utils
private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
@@ -36,7 +36,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
def value = value_
- def blockId: String = "broadcast_" + id
+ def blockId = BroadcastBlockId(id)
MultiTracker.synchronized {
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 9db26ae6de..609464e38d 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -25,16 +25,15 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import org.apache.spark.{HttpServer, Logging, SparkEnv}
import org.apache.spark.io.CompressionCodec
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{Utils, MetadataCleaner, TimeStampedHashSet}
-
+import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
def value = value_
- def blockId: String = "broadcast_" + id
+ def blockId = BroadcastBlockId(id)
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
@@ -82,7 +81,7 @@ private object HttpBroadcast extends Logging {
private var server: HttpServer = null
private val files = new TimeStampedHashSet[String]
- private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup)
+ private val cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup)
private lazy val compressionCodec = CompressionCodec.createCodec()
@@ -121,7 +120,7 @@ private object HttpBroadcast extends Logging {
}
def write(id: Long, value: Any) {
- val file = new File(broadcastDir, "broadcast-" + id)
+ val file = new File(broadcastDir, BroadcastBlockId(id).name)
val out: OutputStream = {
if (compress) {
compressionCodec.compressedOutputStream(new FileOutputStream(file))
@@ -137,7 +136,7 @@ private object HttpBroadcast extends Logging {
}
def read[T](id: Long): T = {
- val url = serverUri + "/broadcast-" + id
+ val url = serverUri + "/" + BroadcastBlockId(id).name
val in = {
if (compress) {
compressionCodec.compressedInputStream(new URL(url).openStream())
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala
index 80c97ca073..e6674d49a7 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala
@@ -19,13 +19,11 @@ package org.apache.spark.broadcast
import java.io._
import java.net._
-import java.util.{Comparator, Random, UUID}
-import scala.collection.mutable.{ListBuffer, Map, Set}
-import scala.math
+import scala.collection.mutable.{ListBuffer, Set}
import org.apache.spark._
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.Utils
private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
@@ -33,7 +31,7 @@ extends Broadcast[T](id) with Logging with Serializable {
def value = value_
- def blockId = "broadcast_" + id
+ def blockId = BroadcastBlockId(id)
MultiTracker.synchronized {
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index 1cfff5e565..275331724a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -21,12 +21,14 @@ import scala.collection.immutable.List
import org.apache.spark.deploy.ExecutorState.ExecutorState
import org.apache.spark.deploy.master.{WorkerInfo, ApplicationInfo}
+import org.apache.spark.deploy.master.RecoveryState.MasterState
import org.apache.spark.deploy.worker.ExecutorRunner
import org.apache.spark.util.Utils
private[deploy] sealed trait DeployMessage extends Serializable
+/** Contains messages sent between Scheduler actor nodes. */
private[deploy] object DeployMessages {
// Worker to Master
@@ -52,17 +54,20 @@ private[deploy] object DeployMessages {
exitStatus: Option[Int])
extends DeployMessage
+ case class WorkerSchedulerStateResponse(id: String, executors: List[ExecutorDescription])
+
case class Heartbeat(workerId: String) extends DeployMessage
// Master to Worker
- case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage
+ case class RegisteredWorker(masterUrl: String, masterWebUiUrl: String) extends DeployMessage
case class RegisterWorkerFailed(message: String) extends DeployMessage
- case class KillExecutor(appId: String, execId: Int) extends DeployMessage
+ case class KillExecutor(masterUrl: String, appId: String, execId: Int) extends DeployMessage
case class LaunchExecutor(
+ masterUrl: String,
appId: String,
execId: Int,
appDesc: ApplicationDescription,
@@ -76,9 +81,11 @@ private[deploy] object DeployMessages {
case class RegisterApplication(appDescription: ApplicationDescription)
extends DeployMessage
+ case class MasterChangeAcknowledged(appId: String)
+
// Master to Client
- case class RegisteredApplication(appId: String) extends DeployMessage
+ case class RegisteredApplication(appId: String, masterUrl: String) extends DeployMessage
// TODO(matei): replace hostPort with host
case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) {
@@ -94,6 +101,10 @@ private[deploy] object DeployMessages {
case object StopClient
+ // Master to Worker & Client
+
+ case class MasterChanged(masterUrl: String, masterWebUiUrl: String)
+
// MasterWebUI To Master
case object RequestMasterState
@@ -101,7 +112,8 @@ private[deploy] object DeployMessages {
// Master to MasterWebUI
case class MasterStateResponse(host: String, port: Int, workers: Array[WorkerInfo],
- activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) {
+ activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo],
+ status: MasterState) {
Utils.checkHost(host, "Required hostname")
assert (port > 0)
@@ -123,12 +135,7 @@ private[deploy] object DeployMessages {
assert (port > 0)
}
- // Actor System to Master
-
- case object CheckForWorkerTimeOut
-
- case object RequestWebUIPort
-
- case class WebUIPortResponse(webUIBoundPort: Int)
+ // Actor System to Worker
+ case object SendHeartbeat
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala
new file mode 100644
index 0000000000..2abf0b69dd
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+/**
+ * Used to send state on-the-wire about Executors from Worker to Master.
+ * This state is sufficient for the Master to reconstruct its internal data structures during
+ * failover.
+ */
+private[spark] class ExecutorDescription(
+ val appId: String,
+ val execId: Int,
+ val cores: Int,
+ val state: ExecutorState.Value)
+ extends Serializable {
+
+ override def toString: String =
+ "ExecutorState(appId=%s, execId=%d, cores=%d, state=%s)".format(appId, execId, cores, state)
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
new file mode 100644
index 0000000000..668032a3a2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
@@ -0,0 +1,420 @@
+/*
+ *
+ * * Licensed to the Apache Software Foundation (ASF) under one or more
+ * * contributor license agreements. See the NOTICE file distributed with
+ * * this work for additional information regarding copyright ownership.
+ * * The ASF licenses this file to You under the Apache License, Version 2.0
+ * * (the "License"); you may not use this file except in compliance with
+ * * the License. You may obtain a copy of the License at
+ * *
+ * * http://www.apache.org/licenses/LICENSE-2.0
+ * *
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS,
+ * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * * See the License for the specific language governing permissions and
+ * * limitations under the License.
+ *
+ */
+
+package org.apache.spark.deploy
+
+import java.io._
+import java.net.URL
+import java.util.concurrent.TimeoutException
+
+import scala.concurrent.{Await, future, promise}
+import scala.concurrent.duration._
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.collection.mutable.ListBuffer
+import scala.sys.process._
+
+import net.liftweb.json.JsonParser
+
+import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.deploy.master.RecoveryState
+
+/**
+ * This suite tests the fault tolerance of the Spark standalone scheduler, mainly the Master.
+ * In order to mimic a real distributed cluster more closely, Docker is used.
+ * Execute using
+ * ./spark-class org.apache.spark.deploy.FaultToleranceTest
+ *
+ * Make sure that that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS:
+ * - spark.deploy.recoveryMode=ZOOKEEPER
+ * - spark.deploy.zookeeper.url=172.17.42.1:2181
+ * Note that 172.17.42.1 is the default docker ip for the host and 2181 is the default ZK port.
+ *
+ * Unfortunately, due to the Docker dependency this suite cannot be run automatically without a
+ * working installation of Docker. In addition to having Docker, the following are assumed:
+ * - Docker can run without sudo (see http://docs.docker.io/en/latest/use/basics/)
+ * - The docker images tagged spark-test-master and spark-test-worker are built from the
+ * docker/ directory. Run 'docker/spark-test/build' to generate these.
+ */
+private[spark] object FaultToleranceTest extends App with Logging {
+ val masters = ListBuffer[TestMasterInfo]()
+ val workers = ListBuffer[TestWorkerInfo]()
+ var sc: SparkContext = _
+
+ var numPassed = 0
+ var numFailed = 0
+
+ val sparkHome = System.getenv("SPARK_HOME")
+ assertTrue(sparkHome != null, "Run with a valid SPARK_HOME")
+
+ val containerSparkHome = "/opt/spark"
+ val dockerMountDir = "%s:%s".format(sparkHome, containerSparkHome)
+
+ System.setProperty("spark.driver.host", "172.17.42.1") // default docker host ip
+
+ def afterEach() {
+ if (sc != null) {
+ sc.stop()
+ sc = null
+ }
+ terminateCluster()
+ }
+
+ test("sanity-basic") {
+ addMasters(1)
+ addWorkers(1)
+ createClient()
+ assertValidClusterState()
+ }
+
+ test("sanity-many-masters") {
+ addMasters(3)
+ addWorkers(3)
+ createClient()
+ assertValidClusterState()
+ }
+
+ test("single-master-halt") {
+ addMasters(3)
+ addWorkers(2)
+ createClient()
+ assertValidClusterState()
+
+ killLeader()
+ delay(30 seconds)
+ assertValidClusterState()
+ createClient()
+ assertValidClusterState()
+ }
+
+ test("single-master-restart") {
+ addMasters(1)
+ addWorkers(2)
+ createClient()
+ assertValidClusterState()
+
+ killLeader()
+ addMasters(1)
+ delay(30 seconds)
+ assertValidClusterState()
+
+ killLeader()
+ addMasters(1)
+ delay(30 seconds)
+ assertValidClusterState()
+ }
+
+ test("cluster-failure") {
+ addMasters(2)
+ addWorkers(2)
+ createClient()
+ assertValidClusterState()
+
+ terminateCluster()
+ addMasters(2)
+ addWorkers(2)
+ assertValidClusterState()
+ }
+
+ test("all-but-standby-failure") {
+ addMasters(2)
+ addWorkers(2)
+ createClient()
+ assertValidClusterState()
+
+ killLeader()
+ workers.foreach(_.kill())
+ workers.clear()
+ delay(30 seconds)
+ addWorkers(2)
+ assertValidClusterState()
+ }
+
+ test("rolling-outage") {
+ addMasters(1)
+ delay()
+ addMasters(1)
+ delay()
+ addMasters(1)
+ addWorkers(2)
+ createClient()
+ assertValidClusterState()
+ assertTrue(getLeader == masters.head)
+
+ (1 to 3).foreach { _ =>
+ killLeader()
+ delay(30 seconds)
+ assertValidClusterState()
+ assertTrue(getLeader == masters.head)
+ addMasters(1)
+ }
+ }
+
+ def test(name: String)(fn: => Unit) {
+ try {
+ fn
+ numPassed += 1
+ logInfo("Passed: " + name)
+ } catch {
+ case e: Exception =>
+ numFailed += 1
+ logError("FAILED: " + name, e)
+ }
+ afterEach()
+ }
+
+ def addMasters(num: Int) {
+ (1 to num).foreach { _ => masters += SparkDocker.startMaster(dockerMountDir) }
+ }
+
+ def addWorkers(num: Int) {
+ val masterUrls = getMasterUrls(masters)
+ (1 to num).foreach { _ => workers += SparkDocker.startWorker(dockerMountDir, masterUrls) }
+ }
+
+ /** Creates a SparkContext, which constructs a Client to interact with our cluster. */
+ def createClient() = {
+ if (sc != null) { sc.stop() }
+ // Counter-hack: Because of a hack in SparkEnv#createFromSystemProperties() that changes this
+ // property, we need to reset it.
+ System.setProperty("spark.driver.port", "0")
+ sc = new SparkContext(getMasterUrls(masters), "fault-tolerance", containerSparkHome)
+ }
+
+ def getMasterUrls(masters: Seq[TestMasterInfo]): String = {
+ "spark://" + masters.map(master => master.ip + ":7077").mkString(",")
+ }
+
+ def getLeader: TestMasterInfo = {
+ val leaders = masters.filter(_.state == RecoveryState.ALIVE)
+ assertTrue(leaders.size == 1)
+ leaders(0)
+ }
+
+ def killLeader(): Unit = {
+ masters.foreach(_.readState())
+ val leader = getLeader
+ masters -= leader
+ leader.kill()
+ }
+
+ def delay(secs: Duration = 5.seconds) = Thread.sleep(secs.toMillis)
+
+ def terminateCluster() {
+ masters.foreach(_.kill())
+ workers.foreach(_.kill())
+ masters.clear()
+ workers.clear()
+ }
+
+ /** This includes Client retry logic, so it may take a while if the cluster is recovering. */
+ def assertUsable() = {
+ val f = future {
+ try {
+ val res = sc.parallelize(0 until 10).collect()
+ assertTrue(res.toList == (0 until 10))
+ true
+ } catch {
+ case e: Exception =>
+ logError("assertUsable() had exception", e)
+ e.printStackTrace()
+ false
+ }
+ }
+
+ // Avoid waiting indefinitely (e.g., we could register but get no executors).
+ assertTrue(Await.result(f, 120 seconds))
+ }
+
+ /**
+ * Asserts that the cluster is usable and that the expected masters and workers
+ * are all alive in a proper configuration (e.g., only one leader).
+ */
+ def assertValidClusterState() = {
+ assertUsable()
+ var numAlive = 0
+ var numStandby = 0
+ var numLiveApps = 0
+ var liveWorkerIPs: Seq[String] = List()
+
+ def stateValid(): Boolean = {
+ (workers.map(_.ip) -- liveWorkerIPs).isEmpty &&
+ numAlive == 1 && numStandby == masters.size - 1 && numLiveApps >= 1
+ }
+
+ val f = future {
+ try {
+ while (!stateValid()) {
+ Thread.sleep(1000)
+
+ numAlive = 0
+ numStandby = 0
+ numLiveApps = 0
+
+ masters.foreach(_.readState())
+
+ for (master <- masters) {
+ master.state match {
+ case RecoveryState.ALIVE =>
+ numAlive += 1
+ liveWorkerIPs = master.liveWorkerIPs
+ case RecoveryState.STANDBY =>
+ numStandby += 1
+ case _ => // ignore
+ }
+
+ numLiveApps += master.numLiveApps
+ }
+ }
+ true
+ } catch {
+ case e: Exception =>
+ logError("assertValidClusterState() had exception", e)
+ false
+ }
+ }
+
+ try {
+ assertTrue(Await.result(f, 120 seconds))
+ } catch {
+ case e: TimeoutException =>
+ logError("Master states: " + masters.map(_.state))
+ logError("Num apps: " + numLiveApps)
+ logError("IPs expected: " + workers.map(_.ip) + " / found: " + liveWorkerIPs)
+ throw new RuntimeException("Failed to get into acceptable cluster state after 2 min.", e)
+ }
+ }
+
+ def assertTrue(bool: Boolean, message: String = "") {
+ if (!bool) {
+ throw new IllegalStateException("Assertion failed: " + message)
+ }
+ }
+
+ logInfo("Ran %s tests, %s passed and %s failed".format(numPassed+numFailed, numPassed, numFailed))
+}
+
+private[spark] class TestMasterInfo(val ip: String, val dockerId: DockerId, val logFile: File)
+ extends Logging {
+
+ implicit val formats = net.liftweb.json.DefaultFormats
+ var state: RecoveryState.Value = _
+ var liveWorkerIPs: List[String] = _
+ var numLiveApps = 0
+
+ logDebug("Created master: " + this)
+
+ def readState() {
+ try {
+ val masterStream = new InputStreamReader(new URL("http://%s:8080/json".format(ip)).openStream)
+ val json = JsonParser.parse(masterStream, closeAutomatically = true)
+
+ val workers = json \ "workers"
+ val liveWorkers = workers.children.filter(w => (w \ "state").extract[String] == "ALIVE")
+ liveWorkerIPs = liveWorkers.map(w => (w \ "host").extract[String])
+
+ numLiveApps = (json \ "activeapps").children.size
+
+ val status = json \\ "status"
+ val stateString = status.extract[String]
+ state = RecoveryState.values.filter(state => state.toString == stateString).head
+ } catch {
+ case e: Exception =>
+ // ignore, no state update
+ logWarning("Exception", e)
+ }
+ }
+
+ def kill() { Docker.kill(dockerId) }
+
+ override def toString: String =
+ "[ip=%s, id=%s, logFile=%s, state=%s]".
+ format(ip, dockerId.id, logFile.getAbsolutePath, state)
+}
+
+private[spark] class TestWorkerInfo(val ip: String, val dockerId: DockerId, val logFile: File)
+ extends Logging {
+
+ implicit val formats = net.liftweb.json.DefaultFormats
+
+ logDebug("Created worker: " + this)
+
+ def kill() { Docker.kill(dockerId) }
+
+ override def toString: String =
+ "[ip=%s, id=%s, logFile=%s]".format(ip, dockerId, logFile.getAbsolutePath)
+}
+
+private[spark] object SparkDocker {
+ def startMaster(mountDir: String): TestMasterInfo = {
+ val cmd = Docker.makeRunCmd("spark-test-master", mountDir = mountDir)
+ val (ip, id, outFile) = startNode(cmd)
+ new TestMasterInfo(ip, id, outFile)
+ }
+
+ def startWorker(mountDir: String, masters: String): TestWorkerInfo = {
+ val cmd = Docker.makeRunCmd("spark-test-worker", args = masters, mountDir = mountDir)
+ val (ip, id, outFile) = startNode(cmd)
+ new TestWorkerInfo(ip, id, outFile)
+ }
+
+ private def startNode(dockerCmd: ProcessBuilder) : (String, DockerId, File) = {
+ val ipPromise = promise[String]()
+ val outFile = File.createTempFile("fault-tolerance-test", "")
+ outFile.deleteOnExit()
+ val outStream: FileWriter = new FileWriter(outFile)
+ def findIpAndLog(line: String): Unit = {
+ if (line.startsWith("CONTAINER_IP=")) {
+ val ip = line.split("=")(1)
+ ipPromise.success(ip)
+ }
+
+ outStream.write(line + "\n")
+ outStream.flush()
+ }
+
+ dockerCmd.run(ProcessLogger(findIpAndLog _))
+ val ip = Await.result(ipPromise.future, 30 seconds)
+ val dockerId = Docker.getLastProcessId
+ (ip, dockerId, outFile)
+ }
+}
+
+private[spark] class DockerId(val id: String) {
+ override def toString = id
+}
+
+private[spark] object Docker extends Logging {
+ def makeRunCmd(imageTag: String, args: String = "", mountDir: String = ""): ProcessBuilder = {
+ val mountCmd = if (mountDir != "") { " -v " + mountDir } else ""
+
+ val cmd = "docker run %s %s %s".format(mountCmd, imageTag, args)
+ logDebug("Run command: " + cmd)
+ cmd
+ }
+
+ def kill(dockerId: DockerId) : Unit = {
+ "docker kill %s".format(dockerId.id).!
+ }
+
+ def getLastProcessId: DockerId = {
+ var id: String = null
+ "docker ps -l -q".!(ProcessLogger(line => id = line))
+ new DockerId(id)
+ }
+} \ No newline at end of file
diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
index 87a703427c..e607b8c6f4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
@@ -41,6 +41,7 @@ private[spark] object JsonProtocol {
("starttime" -> obj.startTime) ~
("id" -> obj.id) ~
("name" -> obj.desc.name) ~
+ ("appuiurl" -> obj.appUiUrl) ~
("cores" -> obj.desc.maxCores) ~
("user" -> obj.desc.user) ~
("memoryperslave" -> obj.desc.memoryPerSlave) ~
@@ -64,14 +65,15 @@ private[spark] object JsonProtocol {
}
def writeMasterState(obj: MasterStateResponse) = {
- ("url" -> ("spark://" + obj.uri)) ~
+ ("url" -> obj.uri) ~
("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~
("cores" -> obj.workers.map(_.cores).sum) ~
("coresused" -> obj.workers.map(_.coresUsed).sum) ~
("memory" -> obj.workers.map(_.memory).sum) ~
("memoryused" -> obj.workers.map(_.memoryUsed).sum) ~
("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~
- ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo))
+ ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) ~
+ ("status" -> obj.status.toString)
}
def writeWorkerState(obj: WorkerStateResponse) = {
diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
index 10161c8204..308a2bfa22 100644
--- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
@@ -39,22 +39,23 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
private val masterActorSystems = ArrayBuffer[ActorSystem]()
private val workerActorSystems = ArrayBuffer[ActorSystem]()
- def start(): String = {
+ def start(): Array[String] = {
logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
/* Start the Master */
val (masterSystem, masterPort, _) = Master.startSystemAndActor(localHostname, 0, 0)
masterActorSystems += masterSystem
val masterUrl = "spark://" + localHostname + ":" + masterPort
+ val masters = Array(masterUrl)
/* Start the Workers */
for (workerNum <- 1 to numWorkers) {
val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker,
- memoryPerWorker, masterUrl, null, Some(workerNum))
+ memoryPerWorker, masters, null, Some(workerNum))
workerActorSystems += workerSystem
}
- return masterUrl
+ return masters
}
def stop() {
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 0a5f4c368f..993ba6bd3d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -16,6 +16,9 @@
*/
package org.apache.spark.deploy
+
+import com.google.common.collect.MapMaker
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.JobConf
@@ -24,11 +27,16 @@ import org.apache.hadoop.mapred.JobConf
* Contains util methods to interact with Hadoop from spark.
*/
class SparkHadoopUtil {
+ // A general, soft-reference map for metadata needed during HadoopRDD split computation
+ // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
+ private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
- // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop subsystems
+ // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop
+ // subsystems
def newConfiguration(): Configuration = new Configuration()
- // add any user credentials to the job conf which are necessary for running on a secure Hadoop cluster
+ // Add any user credentials to the job conf which are necessary for running on a secure Hadoop
+ // cluster
def addCredentials(conf: JobConf) {}
def isYarnMode(): Boolean = { false }
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
index a342dd724a..77422f61ec 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
@@ -23,6 +23,7 @@ import akka.actor._
import akka.actor.Terminated
import akka.pattern.ask
import akka.util.Duration
+import akka.util.duration._
import akka.remote.RemoteClientDisconnected
import akka.remote.RemoteClientLifeCycleEvent
import akka.remote.RemoteClientShutdown
@@ -37,41 +38,81 @@ import org.apache.spark.deploy.master.Master
/**
* The main class used to talk to a Spark deploy cluster. Takes a master URL, an app description,
* and a listener for cluster events, and calls back the listener when various events occur.
+ *
+ * @param masterUrls Each url should look like spark://host:port.
*/
private[spark] class Client(
actorSystem: ActorSystem,
- masterUrl: String,
+ masterUrls: Array[String],
appDescription: ApplicationDescription,
listener: ClientListener)
extends Logging {
+ val REGISTRATION_TIMEOUT = 20.seconds
+ val REGISTRATION_RETRIES = 3
+
var actor: ActorRef = null
var appId: String = null
+ var registered = false
+ var activeMasterUrl: String = null
class ClientActor extends Actor with Logging {
var master: ActorRef = null
var masterAddress: Address = null
var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times
+ var alreadyDead = false // To avoid calling listener.dead() multiple times
override def preStart() {
- logInfo("Connecting to master " + masterUrl)
try {
- master = context.actorFor(Master.toAkkaUrl(masterUrl))
- masterAddress = master.path.address
- master ! RegisterApplication(appDescription)
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
- context.watch(master) // Doesn't work with remote actors, but useful for testing
+ registerWithMaster()
} catch {
case e: Exception =>
- logError("Failed to connect to master", e)
+ logWarning("Failed to connect to master", e)
markDisconnected()
context.stop(self)
}
}
+ def tryRegisterAllMasters() {
+ for (masterUrl <- masterUrls) {
+ logInfo("Connecting to master " + masterUrl + "...")
+ val actor = context.actorFor(Master.toAkkaUrl(masterUrl))
+ actor ! RegisterApplication(appDescription)
+ }
+ }
+
+ def registerWithMaster() {
+ tryRegisterAllMasters()
+
+ var retries = 0
+ lazy val retryTimer: Cancellable =
+ context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) {
+ retries += 1
+ if (registered) {
+ retryTimer.cancel()
+ } else if (retries >= REGISTRATION_RETRIES) {
+ logError("All masters are unresponsive! Giving up.")
+ markDead()
+ } else {
+ tryRegisterAllMasters()
+ }
+ }
+ retryTimer // start timer
+ }
+
+ def changeMaster(url: String) {
+ activeMasterUrl = url
+ master = context.actorFor(Master.toAkkaUrl(url))
+ masterAddress = master.path.address
+ context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+ context.watch(master) // Doesn't work with remote actors, but useful for testing
+ }
+
override def receive = {
- case RegisteredApplication(appId_) =>
+ case RegisteredApplication(appId_, masterUrl) =>
appId = appId_
+ registered = true
+ changeMaster(masterUrl)
listener.connected(appId)
case ApplicationRemoved(message) =>
@@ -92,23 +133,27 @@ private[spark] class Client(
listener.executorRemoved(fullId, message.getOrElse(""), exitStatus)
}
+ case MasterChanged(masterUrl, masterWebUiUrl) =>
+ logInfo("Master has changed, new master is at " + masterUrl)
+ context.unwatch(master)
+ changeMaster(masterUrl)
+ alreadyDisconnected = false
+ sender ! MasterChangeAcknowledged(appId)
+
case Terminated(actor_) if actor_ == master =>
- logError("Connection to master failed; stopping client")
+ logWarning("Connection to master failed; waiting for master to reconnect...")
markDisconnected()
- context.stop(self)
case RemoteClientDisconnected(transport, address) if address == masterAddress =>
- logError("Connection to master failed; stopping client")
+ logWarning("Connection to master failed; waiting for master to reconnect...")
markDisconnected()
- context.stop(self)
case RemoteClientShutdown(transport, address) if address == masterAddress =>
- logError("Connection to master failed; stopping client")
+ logWarning("Connection to master failed; waiting for master to reconnect...")
markDisconnected()
- context.stop(self)
case StopClient =>
- markDisconnected()
+ markDead()
sender ! true
context.stop(self)
}
@@ -122,6 +167,13 @@ private[spark] class Client(
alreadyDisconnected = true
}
}
+
+ def markDead() {
+ if (!alreadyDead) {
+ listener.dead()
+ alreadyDead = true
+ }
+ }
}
def start() {
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala b/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala
index 4605368c11..be7a11bd15 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala
@@ -27,8 +27,12 @@ package org.apache.spark.deploy.client
private[spark] trait ClientListener {
def connected(appId: String): Unit
+ /** Disconnection may be a temporary state, as we fail over to a new Master. */
def disconnected(): Unit
+ /** Dead means that we couldn't find any Masters to connect to, and have given up. */
+ def dead(): Unit
+
def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit
def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
index d5e9a0e095..5b62d3ba6c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
@@ -33,6 +33,11 @@ private[spark] object TestClient {
System.exit(0)
}
+ def dead() {
+ logInfo("Could not connect to master")
+ System.exit(0)
+ }
+
def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {}
def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {}
@@ -44,7 +49,7 @@ private[spark] object TestClient {
val desc = new ApplicationDescription(
"TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "ignored")
val listener = new TestListener
- val client = new Client(actorSystem, url, desc, listener)
+ val client = new Client(actorSystem, Array(url), desc, listener)
client.start()
actorSystem.awaitTermination()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
index bd5327627a..5150b7c7de 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
@@ -29,23 +29,46 @@ private[spark] class ApplicationInfo(
val submitDate: Date,
val driver: ActorRef,
val appUiUrl: String)
-{
- var state = ApplicationState.WAITING
- var executors = new mutable.HashMap[Int, ExecutorInfo]
- var coresGranted = 0
- var endTime = -1L
- val appSource = new ApplicationSource(this)
-
- private var nextExecutorId = 0
-
- def newExecutorId(): Int = {
- val id = nextExecutorId
- nextExecutorId += 1
- id
+ extends Serializable {
+
+ @transient var state: ApplicationState.Value = _
+ @transient var executors: mutable.HashMap[Int, ExecutorInfo] = _
+ @transient var coresGranted: Int = _
+ @transient var endTime: Long = _
+ @transient var appSource: ApplicationSource = _
+
+ @transient private var nextExecutorId: Int = _
+
+ init()
+
+ private def readObject(in: java.io.ObjectInputStream) : Unit = {
+ in.defaultReadObject()
+ init()
+ }
+
+ private def init() {
+ state = ApplicationState.WAITING
+ executors = new mutable.HashMap[Int, ExecutorInfo]
+ coresGranted = 0
+ endTime = -1L
+ appSource = new ApplicationSource(this)
+ nextExecutorId = 0
+ }
+
+ private def newExecutorId(useID: Option[Int] = None): Int = {
+ useID match {
+ case Some(id) =>
+ nextExecutorId = math.max(nextExecutorId, id + 1)
+ id
+ case None =>
+ val id = nextExecutorId
+ nextExecutorId += 1
+ id
+ }
}
- def addExecutor(worker: WorkerInfo, cores: Int): ExecutorInfo = {
- val exec = new ExecutorInfo(newExecutorId(), this, worker, cores, desc.memoryPerSlave)
+ def addExecutor(worker: WorkerInfo, cores: Int, useID: Option[Int] = None): ExecutorInfo = {
+ val exec = new ExecutorInfo(newExecutorId(useID), this, worker, cores, desc.memoryPerSlave)
executors(exec.id) = exec
coresGranted += cores
exec
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala
index 5a24042e14..c87b66f047 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala
@@ -34,7 +34,7 @@ class ApplicationSource(val application: ApplicationInfo) extends Source {
override def getValue: Long = application.duration
})
- metricRegistry.register(MetricRegistry.name("cores", "number"), new Gauge[Int] {
+ metricRegistry.register(MetricRegistry.name("cores"), new Gauge[Int] {
override def getValue: Int = application.coresGranted
})
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
index 7e804223cf..fedf879eff 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
@@ -18,11 +18,11 @@
package org.apache.spark.deploy.master
private[spark] object ApplicationState
- extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED") {
+ extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED", "UNKNOWN") {
type ApplicationState = Value
- val WAITING, RUNNING, FINISHED, FAILED = Value
+ val WAITING, RUNNING, FINISHED, FAILED, UNKNOWN = Value
val MAX_NUM_RETRY = 10
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala
index cf384a985e..76db61dd61 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy.master
-import org.apache.spark.deploy.ExecutorState
+import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
private[spark] class ExecutorInfo(
val id: Int,
@@ -28,5 +28,10 @@ private[spark] class ExecutorInfo(
var state = ExecutorState.LAUNCHING
+ /** Copy all state (non-val) variables from the given on-the-wire ExecutorDescription. */
+ def copyState(execDesc: ExecutorDescription) {
+ state = execDesc.state
+ }
+
def fullId: String = application.id + "/" + id
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
new file mode 100644
index 0000000000..c0849ef324
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import java.io._
+
+import scala.Serializable
+
+import akka.serialization.Serialization
+import org.apache.spark.Logging
+
+/**
+ * Stores data in a single on-disk directory with one file per application and worker.
+ * Files are deleted when applications and workers are removed.
+ *
+ * @param dir Directory to store files. Created if non-existent (but not recursively).
+ * @param serialization Used to serialize our objects.
+ */
+private[spark] class FileSystemPersistenceEngine(
+ val dir: String,
+ val serialization: Serialization)
+ extends PersistenceEngine with Logging {
+
+ new File(dir).mkdir()
+
+ override def addApplication(app: ApplicationInfo) {
+ val appFile = new File(dir + File.separator + "app_" + app.id)
+ serializeIntoFile(appFile, app)
+ }
+
+ override def removeApplication(app: ApplicationInfo) {
+ new File(dir + File.separator + "app_" + app.id).delete()
+ }
+
+ override def addWorker(worker: WorkerInfo) {
+ val workerFile = new File(dir + File.separator + "worker_" + worker.id)
+ serializeIntoFile(workerFile, worker)
+ }
+
+ override def removeWorker(worker: WorkerInfo) {
+ new File(dir + File.separator + "worker_" + worker.id).delete()
+ }
+
+ override def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo]) = {
+ val sortedFiles = new File(dir).listFiles().sortBy(_.getName)
+ val appFiles = sortedFiles.filter(_.getName.startsWith("app_"))
+ val apps = appFiles.map(deserializeFromFile[ApplicationInfo])
+ val workerFiles = sortedFiles.filter(_.getName.startsWith("worker_"))
+ val workers = workerFiles.map(deserializeFromFile[WorkerInfo])
+ (apps, workers)
+ }
+
+ private def serializeIntoFile(file: File, value: Serializable) {
+ val created = file.createNewFile()
+ if (!created) { throw new IllegalStateException("Could not create file: " + file) }
+
+ val serializer = serialization.findSerializerFor(value)
+ val serialized = serializer.toBinary(value)
+
+ val out = new FileOutputStream(file)
+ out.write(serialized)
+ out.close()
+ }
+
+ def deserializeFromFile[T <: Serializable](file: File)(implicit m: Manifest[T]): T = {
+ val fileData = new Array[Byte](file.length().asInstanceOf[Int])
+ val dis = new DataInputStream(new FileInputStream(file))
+ dis.readFully(fileData)
+ dis.close()
+
+ val clazz = m.erasure.asInstanceOf[Class[T]]
+ val serializer = serialization.serializerFor(clazz)
+ serializer.fromBinary(fileData).asInstanceOf[T]
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
new file mode 100644
index 0000000000..f25a1ad3bf
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import akka.actor.{Actor, ActorRef}
+
+import org.apache.spark.deploy.master.MasterMessages.ElectedLeader
+
+/**
+ * A LeaderElectionAgent keeps track of whether the current Master is the leader, meaning it
+ * is the only Master serving requests.
+ * In addition to the API provided, the LeaderElectionAgent will use of the following messages
+ * to inform the Master of leader changes:
+ * [[org.apache.spark.deploy.master.MasterMessages.ElectedLeader ElectedLeader]]
+ * [[org.apache.spark.deploy.master.MasterMessages.RevokedLeadership RevokedLeadership]]
+ */
+private[spark] trait LeaderElectionAgent extends Actor {
+ val masterActor: ActorRef
+}
+
+/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */
+private[spark] class MonarchyLeaderAgent(val masterActor: ActorRef) extends LeaderElectionAgent {
+ override def preStart() {
+ masterActor ! ElectedLeader
+ }
+
+ override def receive = {
+ case _ =>
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index bde59905bc..cd916672ac 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -27,24 +27,26 @@ import akka.actor.Terminated
import akka.dispatch.Await
import akka.pattern.ask
import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, RemoteClientShutdown}
+import akka.serialization.SerializationExtension
import akka.util.duration._
-import akka.util.Timeout
+import akka.util.{Duration, Timeout}
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
+import org.apache.spark.deploy.master.MasterMessages._
import org.apache.spark.deploy.master.ui.MasterWebUI
import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.util.{Utils, AkkaUtils}
-import akka.util.{Duration, Timeout}
-
+import org.apache.spark.util.{AkkaUtils, Utils}
private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000
val RETAINED_APPLICATIONS = System.getProperty("spark.deploy.retainedApplications", "200").toInt
val REAPER_ITERATIONS = System.getProperty("spark.dead.worker.persistence", "15").toInt
-
+ val RECOVERY_DIR = System.getProperty("spark.deploy.recoveryDirectory", "")
+ val RECOVERY_MODE = System.getProperty("spark.deploy.recoveryMode", "NONE")
+
var nextAppNumber = 0
val workers = new HashSet[WorkerInfo]
val idToWorker = new HashMap[String, WorkerInfo]
@@ -74,51 +76,116 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
if (envVar != null) envVar else host
}
+ val masterUrl = "spark://" + host + ":" + port
+ var masterWebUiUrl: String = _
+
+ var state = RecoveryState.STANDBY
+
+ var persistenceEngine: PersistenceEngine = _
+
+ var leaderElectionAgent: ActorRef = _
+
// As a temporary workaround before better ways of configuring memory, we allow users to set
// a flag that will perform round-robin scheduling across the nodes (spreading out each app
// among all the nodes) instead of trying to consolidate each app onto a small # of nodes.
val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean
override def preStart() {
- logInfo("Starting Spark master at spark://" + host + ":" + port)
+ logInfo("Starting Spark master at " + masterUrl)
// Listen for remote client disconnection events, since they don't go through Akka's watch()
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
webUi.start()
+ masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort.get
context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut)
masterMetricsSystem.registerSource(masterSource)
masterMetricsSystem.start()
applicationMetricsSystem.start()
+
+ persistenceEngine = RECOVERY_MODE match {
+ case "ZOOKEEPER" =>
+ logInfo("Persisting recovery state to ZooKeeper")
+ new ZooKeeperPersistenceEngine(SerializationExtension(context.system))
+ case "FILESYSTEM" =>
+ logInfo("Persisting recovery state to directory: " + RECOVERY_DIR)
+ new FileSystemPersistenceEngine(RECOVERY_DIR, SerializationExtension(context.system))
+ case _ =>
+ new BlackHolePersistenceEngine()
+ }
+
+ leaderElectionAgent = context.actorOf(Props(
+ RECOVERY_MODE match {
+ case "ZOOKEEPER" =>
+ new ZooKeeperLeaderElectionAgent(self, masterUrl)
+ case _ =>
+ new MonarchyLeaderAgent(self)
+ }))
+ }
+
+ override def preRestart(reason: Throwable, message: Option[Any]) {
+ super.preRestart(reason, message) // calls postStop()!
+ logError("Master actor restarted due to exception", reason)
}
override def postStop() {
webUi.stop()
masterMetricsSystem.stop()
applicationMetricsSystem.stop()
+ persistenceEngine.close()
+ context.stop(leaderElectionAgent)
}
override def receive = {
- case RegisterWorker(id, host, workerPort, cores, memory, worker_webUiPort, publicAddress) => {
+ case ElectedLeader => {
+ val (storedApps, storedWorkers) = persistenceEngine.readPersistedData()
+ state = if (storedApps.isEmpty && storedWorkers.isEmpty)
+ RecoveryState.ALIVE
+ else
+ RecoveryState.RECOVERING
+
+ logInfo("I have been elected leader! New state: " + state)
+
+ if (state == RecoveryState.RECOVERING) {
+ beginRecovery(storedApps, storedWorkers)
+ context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis) { completeRecovery() }
+ }
+ }
+
+ case RevokedLeadership => {
+ logError("Leadership has been revoked -- master shutting down.")
+ System.exit(0)
+ }
+
+ case RegisterWorker(id, host, workerPort, cores, memory, webUiPort, publicAddress) => {
logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
host, workerPort, cores, Utils.megabytesToString(memory)))
- if (idToWorker.contains(id)) {
+ if (state == RecoveryState.STANDBY) {
+ // ignore, don't send response
+ } else if (idToWorker.contains(id)) {
sender ! RegisterWorkerFailed("Duplicate worker ID")
} else {
- addWorker(id, host, workerPort, cores, memory, worker_webUiPort, publicAddress)
+ val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
+ registerWorker(worker)
context.watch(sender) // This doesn't work with remote actors but helps for testing
- sender ! RegisteredWorker("http://" + masterPublicAddress + ":" + webUi.boundPort.get)
+ persistenceEngine.addWorker(worker)
+ sender ! RegisteredWorker(masterUrl, masterWebUiUrl)
schedule()
}
}
case RegisterApplication(description) => {
- logInfo("Registering app " + description.name)
- val app = addApplication(description, sender)
- logInfo("Registered app " + description.name + " with ID " + app.id)
- waitingApps += app
- context.watch(sender) // This doesn't work with remote actors but helps for testing
- sender ! RegisteredApplication(app.id)
- schedule()
+ if (state == RecoveryState.STANDBY) {
+ // ignore, don't send response
+ } else {
+ logInfo("Registering app " + description.name)
+ val app = createApplication(description, sender)
+ registerApplication(app)
+ logInfo("Registered app " + description.name + " with ID " + app.id)
+ context.watch(sender) // This doesn't work with remote actors but helps for testing
+ persistenceEngine.addApplication(app)
+ sender ! RegisteredApplication(app.id, masterUrl)
+ schedule()
+ }
}
case ExecutorStateChanged(appId, execId, state, message, exitStatus) => {
@@ -158,27 +225,63 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
}
}
+ case MasterChangeAcknowledged(appId) => {
+ idToApp.get(appId) match {
+ case Some(app) =>
+ logInfo("Application has been re-registered: " + appId)
+ app.state = ApplicationState.WAITING
+ case None =>
+ logWarning("Master change ack from unknown app: " + appId)
+ }
+
+ if (canCompleteRecovery) { completeRecovery() }
+ }
+
+ case WorkerSchedulerStateResponse(workerId, executors) => {
+ idToWorker.get(workerId) match {
+ case Some(worker) =>
+ logInfo("Worker has been re-registered: " + workerId)
+ worker.state = WorkerState.ALIVE
+
+ val validExecutors = executors.filter(exec => idToApp.get(exec.appId).isDefined)
+ for (exec <- validExecutors) {
+ val app = idToApp.get(exec.appId).get
+ val execInfo = app.addExecutor(worker, exec.cores, Some(exec.execId))
+ worker.addExecutor(execInfo)
+ execInfo.copyState(exec)
+ }
+ case None =>
+ logWarning("Scheduler state from unknown worker: " + workerId)
+ }
+
+ if (canCompleteRecovery) { completeRecovery() }
+ }
+
case Terminated(actor) => {
// The disconnected actor could've been either a worker or an app; remove whichever of
// those we have an entry for in the corresponding actor hashmap
actorToWorker.get(actor).foreach(removeWorker)
actorToApp.get(actor).foreach(finishApplication)
+ if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
}
case RemoteClientDisconnected(transport, address) => {
// The disconnected client could've been either a worker or an app; remove whichever it was
addressToWorker.get(address).foreach(removeWorker)
addressToApp.get(address).foreach(finishApplication)
+ if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
}
case RemoteClientShutdown(transport, address) => {
// The disconnected client could've been either a worker or an app; remove whichever it was
addressToWorker.get(address).foreach(removeWorker)
addressToApp.get(address).foreach(finishApplication)
+ if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
}
case RequestMasterState => {
- sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray)
+ sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray,
+ state)
}
case CheckForWorkerTimeOut => {
@@ -190,6 +293,50 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
}
}
+ def canCompleteRecovery =
+ workers.count(_.state == WorkerState.UNKNOWN) == 0 &&
+ apps.count(_.state == ApplicationState.UNKNOWN) == 0
+
+ def beginRecovery(storedApps: Seq[ApplicationInfo], storedWorkers: Seq[WorkerInfo]) {
+ for (app <- storedApps) {
+ logInfo("Trying to recover app: " + app.id)
+ try {
+ registerApplication(app)
+ app.state = ApplicationState.UNKNOWN
+ app.driver ! MasterChanged(masterUrl, masterWebUiUrl)
+ } catch {
+ case e: Exception => logInfo("App " + app.id + " had exception on reconnect")
+ }
+ }
+
+ for (worker <- storedWorkers) {
+ logInfo("Trying to recover worker: " + worker.id)
+ try {
+ registerWorker(worker)
+ worker.state = WorkerState.UNKNOWN
+ worker.actor ! MasterChanged(masterUrl, masterWebUiUrl)
+ } catch {
+ case e: Exception => logInfo("Worker " + worker.id + " had exception on reconnect")
+ }
+ }
+ }
+
+ def completeRecovery() {
+ // Ensure "only-once" recovery semantics using a short synchronization period.
+ synchronized {
+ if (state != RecoveryState.RECOVERING) { return }
+ state = RecoveryState.COMPLETING_RECOVERY
+ }
+
+ // Kill off any workers and apps that didn't respond to us.
+ workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker)
+ apps.filter(_.state == ApplicationState.UNKNOWN).foreach(finishApplication)
+
+ state = RecoveryState.ALIVE
+ schedule()
+ logInfo("Recovery complete - resuming operations!")
+ }
+
/**
* Can an app use the given worker? True if the worker has enough memory and we haven't already
* launched an executor for the app on it (right now the standalone backend doesn't like having
@@ -204,6 +351,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
* every time a new app joins or resource availability changes.
*/
def schedule() {
+ if (state != RecoveryState.ALIVE) { return }
// Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app
// in the queue, then the second app, etc.
if (spreadOutApps) {
@@ -251,14 +399,13 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) {
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
worker.addExecutor(exec)
- worker.actor ! LaunchExecutor(
+ worker.actor ! LaunchExecutor(masterUrl,
exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome)
exec.application.driver ! ExecutorAdded(
exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)
}
- def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int,
- publicAddress: String): WorkerInfo = {
+ def registerWorker(worker: WorkerInfo): Unit = {
// There may be one or more refs to dead workers on this same node (w/ different ID's),
// remove them.
workers.filter { w =>
@@ -266,12 +413,17 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
}.foreach { w =>
workers -= w
}
- val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
+
+ val workerAddress = worker.actor.path.address
+ if (addressToWorker.contains(workerAddress)) {
+ logInfo("Attempted to re-register worker at same address: " + workerAddress)
+ return
+ }
+
workers += worker
idToWorker(worker.id) = worker
- actorToWorker(sender) = worker
- addressToWorker(sender.path.address) = worker
- worker
+ actorToWorker(worker.actor) = worker
+ addressToWorker(workerAddress) = worker
}
def removeWorker(worker: WorkerInfo) {
@@ -286,25 +438,36 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
exec.id, ExecutorState.LOST, Some("worker lost"), None)
exec.application.removeExecutor(exec)
}
+ persistenceEngine.removeWorker(worker)
}
- def addApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = {
+ def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = {
val now = System.currentTimeMillis()
val date = new Date(now)
- val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl)
+ new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl)
+ }
+
+ def registerApplication(app: ApplicationInfo): Unit = {
+ val appAddress = app.driver.path.address
+ if (addressToWorker.contains(appAddress)) {
+ logInfo("Attempted to re-register application at same address: " + appAddress)
+ return
+ }
+
applicationMetricsSystem.registerSource(app.appSource)
apps += app
idToApp(app.id) = app
- actorToApp(driver) = app
- addressToApp(driver.path.address) = app
+ actorToApp(app.driver) = app
+ addressToApp(appAddress) = app
if (firstApp == None) {
firstApp = Some(app)
}
+ // TODO: What is firstApp?? Can we remove it?
val workersAlive = workers.filter(_.state == WorkerState.ALIVE).toArray
- if (workersAlive.size > 0 && !workersAlive.exists(_.memoryFree >= desc.memoryPerSlave)) {
+ if (workersAlive.size > 0 && !workersAlive.exists(_.memoryFree >= app.desc.memoryPerSlave)) {
logWarning("Could not find any workers with enough memory for " + firstApp.get.id)
}
- app
+ waitingApps += app
}
def finishApplication(app: ApplicationInfo) {
@@ -329,13 +492,14 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
waitingApps -= app
for (exec <- app.executors.values) {
exec.worker.removeExecutor(exec)
- exec.worker.actor ! KillExecutor(exec.application.id, exec.id)
+ exec.worker.actor ! KillExecutor(masterUrl, exec.application.id, exec.id)
exec.state = ExecutorState.KILLED
}
app.markFinished(state)
if (state != ApplicationState.FINISHED) {
app.driver ! ApplicationRemoved(state.toString)
}
+ persistenceEngine.removeApplication(app)
schedule()
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
new file mode 100644
index 0000000000..74a9f8cd82
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+sealed trait MasterMessages extends Serializable
+
+/** Contains messages seen only by the Master and its associated entities. */
+private[master] object MasterMessages {
+
+ // LeaderElectionAgent to Master
+
+ case object ElectedLeader
+
+ case object RevokedLeadership
+
+ // Actor System to LeaderElectionAgent
+
+ case object CheckLeader
+
+ // Actor System to Master
+
+ case object CheckForWorkerTimeOut
+
+ case class BeginRecovery(storedApps: Seq[ApplicationInfo], storedWorkers: Seq[WorkerInfo])
+
+ case object CompleteRecovery
+
+ case object RequestWebUIPort
+
+ case class WebUIPortResponse(webUIBoundPort: Int)
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala
index 23d1cb77da..36c1b87b7f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala
@@ -26,17 +26,17 @@ private[spark] class MasterSource(val master: Master) extends Source {
val sourceName = "master"
// Gauge for worker numbers in cluster
- metricRegistry.register(MetricRegistry.name("workers","number"), new Gauge[Int] {
+ metricRegistry.register(MetricRegistry.name("workers"), new Gauge[Int] {
override def getValue: Int = master.workers.size
})
// Gauge for application numbers in cluster
- metricRegistry.register(MetricRegistry.name("apps", "number"), new Gauge[Int] {
+ metricRegistry.register(MetricRegistry.name("apps"), new Gauge[Int] {
override def getValue: Int = master.apps.size
})
// Gauge for waiting application numbers in cluster
- metricRegistry.register(MetricRegistry.name("waitingApps", "number"), new Gauge[Int] {
+ metricRegistry.register(MetricRegistry.name("waitingApps"), new Gauge[Int] {
override def getValue: Int = master.waitingApps.size
})
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
new file mode 100644
index 0000000000..94b986caf2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+/**
+ * Allows Master to persist any state that is necessary in order to recover from a failure.
+ * The following semantics are required:
+ * - addApplication and addWorker are called before completing registration of a new app/worker.
+ * - removeApplication and removeWorker are called at any time.
+ * Given these two requirements, we will have all apps and workers persisted, but
+ * we might not have yet deleted apps or workers that finished (so their liveness must be verified
+ * during recovery).
+ */
+private[spark] trait PersistenceEngine {
+ def addApplication(app: ApplicationInfo)
+
+ def removeApplication(app: ApplicationInfo)
+
+ def addWorker(worker: WorkerInfo)
+
+ def removeWorker(worker: WorkerInfo)
+
+ /**
+ * Returns the persisted data sorted by their respective ids (which implies that they're
+ * sorted by time of creation).
+ */
+ def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo])
+
+ def close() {}
+}
+
+private[spark] class BlackHolePersistenceEngine extends PersistenceEngine {
+ override def addApplication(app: ApplicationInfo) {}
+ override def removeApplication(app: ApplicationInfo) {}
+ override def addWorker(worker: WorkerInfo) {}
+ override def removeWorker(worker: WorkerInfo) {}
+ override def readPersistedData() = (Nil, Nil)
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala
new file mode 100644
index 0000000000..b91be821f0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+private[spark] object RecoveryState
+ extends Enumeration("STANDBY", "ALIVE", "RECOVERING", "COMPLETING_RECOVERY") {
+
+ type MasterState = Value
+
+ val STANDBY, ALIVE, RECOVERING, COMPLETING_RECOVERY = Value
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala b/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala
new file mode 100644
index 0000000000..81e15c534f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import scala.collection.JavaConversions._
+import scala.concurrent.ops._
+
+import org.apache.spark.Logging
+import org.apache.zookeeper._
+import org.apache.zookeeper.data.Stat
+import org.apache.zookeeper.Watcher.Event.KeeperState
+
+/**
+ * Provides a Scala-side interface to the standard ZooKeeper client, with the addition of retry
+ * logic. If the ZooKeeper session expires or otherwise dies, a new ZooKeeper session will be
+ * created. If ZooKeeper remains down after several retries, the given
+ * [[org.apache.spark.deploy.master.SparkZooKeeperWatcher SparkZooKeeperWatcher]] will be
+ * informed via zkDown().
+ *
+ * Additionally, all commands sent to ZooKeeper will be retried until they either fail too many
+ * times or a semantic exception is thrown (e.g.., "node already exists").
+ */
+private[spark] class SparkZooKeeperSession(zkWatcher: SparkZooKeeperWatcher) extends Logging {
+ val ZK_URL = System.getProperty("spark.deploy.zookeeper.url", "")
+
+ val ZK_ACL = ZooDefs.Ids.OPEN_ACL_UNSAFE
+ val ZK_TIMEOUT_MILLIS = 30000
+ val RETRY_WAIT_MILLIS = 5000
+ val ZK_CHECK_PERIOD_MILLIS = 10000
+ val MAX_RECONNECT_ATTEMPTS = 3
+
+ private var zk: ZooKeeper = _
+
+ private val watcher = new ZooKeeperWatcher()
+ private var reconnectAttempts = 0
+ private var closed = false
+
+ /** Connect to ZooKeeper to start the session. Must be called before anything else. */
+ def connect() {
+ connectToZooKeeper()
+
+ new Thread() {
+ override def run() = sessionMonitorThread()
+ }.start()
+ }
+
+ def sessionMonitorThread(): Unit = {
+ while (!closed) {
+ Thread.sleep(ZK_CHECK_PERIOD_MILLIS)
+ if (zk.getState != ZooKeeper.States.CONNECTED) {
+ reconnectAttempts += 1
+ val attemptsLeft = MAX_RECONNECT_ATTEMPTS - reconnectAttempts
+ if (attemptsLeft <= 0) {
+ logError("Could not connect to ZooKeeper: system failure")
+ zkWatcher.zkDown()
+ close()
+ } else {
+ logWarning("ZooKeeper connection failed, retrying " + attemptsLeft + " more times...")
+ connectToZooKeeper()
+ }
+ }
+ }
+ }
+
+ def close() {
+ if (!closed && zk != null) { zk.close() }
+ closed = true
+ }
+
+ private def connectToZooKeeper() {
+ if (zk != null) zk.close()
+ zk = new ZooKeeper(ZK_URL, ZK_TIMEOUT_MILLIS, watcher)
+ }
+
+ /**
+ * Attempts to maintain a live ZooKeeper exception despite (very) transient failures.
+ * Mainly useful for handling the natural ZooKeeper session expiration.
+ */
+ private class ZooKeeperWatcher extends Watcher {
+ def process(event: WatchedEvent) {
+ if (closed) { return }
+
+ event.getState match {
+ case KeeperState.SyncConnected =>
+ reconnectAttempts = 0
+ zkWatcher.zkSessionCreated()
+ case KeeperState.Expired =>
+ connectToZooKeeper()
+ case KeeperState.Disconnected =>
+ logWarning("ZooKeeper disconnected, will retry...")
+ }
+ }
+ }
+
+ def create(path: String, bytes: Array[Byte], createMode: CreateMode): String = {
+ retry {
+ zk.create(path, bytes, ZK_ACL, createMode)
+ }
+ }
+
+ def exists(path: String, watcher: Watcher = null): Stat = {
+ retry {
+ zk.exists(path, watcher)
+ }
+ }
+
+ def getChildren(path: String, watcher: Watcher = null): List[String] = {
+ retry {
+ zk.getChildren(path, watcher).toList
+ }
+ }
+
+ def getData(path: String): Array[Byte] = {
+ retry {
+ zk.getData(path, false, null)
+ }
+ }
+
+ def delete(path: String, version: Int = -1): Unit = {
+ retry {
+ zk.delete(path, version)
+ }
+ }
+
+ /**
+ * Creates the given directory (non-recursively) if it doesn't exist.
+ * All znodes are created in PERSISTENT mode with no data.
+ */
+ def mkdir(path: String) {
+ if (exists(path) == null) {
+ try {
+ create(path, "".getBytes, CreateMode.PERSISTENT)
+ } catch {
+ case e: Exception =>
+ // If the exception caused the directory not to be created, bubble it up,
+ // otherwise ignore it.
+ if (exists(path) == null) { throw e }
+ }
+ }
+ }
+
+ /**
+ * Recursively creates all directories up to the given one.
+ * All znodes are created in PERSISTENT mode with no data.
+ */
+ def mkdirRecursive(path: String) {
+ var fullDir = ""
+ for (dentry <- path.split("/").tail) {
+ fullDir += "/" + dentry
+ mkdir(fullDir)
+ }
+ }
+
+ /**
+ * Retries the given function up to 3 times. The assumption is that failure is transient,
+ * UNLESS it is a semantic exception (i.e., trying to get data from a node that doesn't exist),
+ * in which case the exception will be thrown without retries.
+ *
+ * @param fn Block to execute, possibly multiple times.
+ */
+ def retry[T](fn: => T, n: Int = MAX_RECONNECT_ATTEMPTS): T = {
+ try {
+ fn
+ } catch {
+ case e: KeeperException.NoNodeException => throw e
+ case e: KeeperException.NodeExistsException => throw e
+ case e if n > 0 =>
+ logError("ZooKeeper exception, " + n + " more retries...", e)
+ Thread.sleep(RETRY_WAIT_MILLIS)
+ retry(fn, n-1)
+ }
+ }
+}
+
+trait SparkZooKeeperWatcher {
+ /**
+ * Called whenever a ZK session is created --
+ * this will occur when we create our first session as well as each time
+ * the session expires or errors out.
+ */
+ def zkSessionCreated()
+
+ /**
+ * Called if ZK appears to be completely down (i.e., not just a transient error).
+ * We will no longer attempt to reconnect to ZK, and the SparkZooKeeperSession is considered dead.
+ */
+ def zkDown()
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
index 6219f11f2a..e05f587b58 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
@@ -22,28 +22,44 @@ import scala.collection.mutable
import org.apache.spark.util.Utils
private[spark] class WorkerInfo(
- val id: String,
- val host: String,
- val port: Int,
- val cores: Int,
- val memory: Int,
- val actor: ActorRef,
- val webUiPort: Int,
- val publicAddress: String) {
+ val id: String,
+ val host: String,
+ val port: Int,
+ val cores: Int,
+ val memory: Int,
+ val actor: ActorRef,
+ val webUiPort: Int,
+ val publicAddress: String)
+ extends Serializable {
Utils.checkHost(host, "Expected hostname")
assert (port > 0)
- var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info
- var state: WorkerState.Value = WorkerState.ALIVE
- var coresUsed = 0
- var memoryUsed = 0
+ @transient var executors: mutable.HashMap[String, ExecutorInfo] = _ // fullId => info
+ @transient var state: WorkerState.Value = _
+ @transient var coresUsed: Int = _
+ @transient var memoryUsed: Int = _
- var lastHeartbeat = System.currentTimeMillis()
+ @transient var lastHeartbeat: Long = _
+
+ init()
def coresFree: Int = cores - coresUsed
def memoryFree: Int = memory - memoryUsed
+ private def readObject(in: java.io.ObjectInputStream) : Unit = {
+ in.defaultReadObject()
+ init()
+ }
+
+ private def init() {
+ executors = new mutable.HashMap
+ state = WorkerState.ALIVE
+ coresUsed = 0
+ memoryUsed = 0
+ lastHeartbeat = System.currentTimeMillis()
+ }
+
def hostPort: String = {
assert (port > 0)
host + ":" + port
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala
index b5ee6dca79..c8d34f25e2 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala
@@ -17,8 +17,10 @@
package org.apache.spark.deploy.master
-private[spark] object WorkerState extends Enumeration("ALIVE", "DEAD", "DECOMMISSIONED") {
+private[spark] object WorkerState
+ extends Enumeration("ALIVE", "DEAD", "DECOMMISSIONED", "UNKNOWN") {
+
type WorkerState = Value
- val ALIVE, DEAD, DECOMMISSIONED = Value
+ val ALIVE, DEAD, DECOMMISSIONED, UNKNOWN = Value
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
new file mode 100644
index 0000000000..7809013e83
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import akka.actor.ActorRef
+import org.apache.zookeeper._
+import org.apache.zookeeper.Watcher.Event.EventType
+
+import org.apache.spark.deploy.master.MasterMessages._
+import org.apache.spark.Logging
+
+private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, masterUrl: String)
+ extends LeaderElectionAgent with SparkZooKeeperWatcher with Logging {
+
+ val WORKING_DIR = System.getProperty("spark.deploy.zookeeper.dir", "/spark") + "/leader_election"
+
+ private val watcher = new ZooKeeperWatcher()
+ private val zk = new SparkZooKeeperSession(this)
+ private var status = LeadershipStatus.NOT_LEADER
+ private var myLeaderFile: String = _
+ private var leaderUrl: String = _
+
+ override def preStart() {
+ logInfo("Starting ZooKeeper LeaderElection agent")
+ zk.connect()
+ }
+
+ override def zkSessionCreated() {
+ synchronized {
+ zk.mkdirRecursive(WORKING_DIR)
+ myLeaderFile =
+ zk.create(WORKING_DIR + "/master_", masterUrl.getBytes, CreateMode.EPHEMERAL_SEQUENTIAL)
+ self ! CheckLeader
+ }
+ }
+
+ override def preRestart(reason: scala.Throwable, message: scala.Option[scala.Any]) {
+ logError("LeaderElectionAgent failed, waiting " + zk.ZK_TIMEOUT_MILLIS + "...", reason)
+ Thread.sleep(zk.ZK_TIMEOUT_MILLIS)
+ super.preRestart(reason, message)
+ }
+
+ override def zkDown() {
+ logError("ZooKeeper down! LeaderElectionAgent shutting down Master.")
+ System.exit(1)
+ }
+
+ override def postStop() {
+ zk.close()
+ }
+
+ override def receive = {
+ case CheckLeader => checkLeader()
+ }
+
+ private class ZooKeeperWatcher extends Watcher {
+ def process(event: WatchedEvent) {
+ if (event.getType == EventType.NodeDeleted) {
+ logInfo("Leader file disappeared, a master is down!")
+ self ! CheckLeader
+ }
+ }
+ }
+
+ /** Uses ZK leader election. Navigates several ZK potholes along the way. */
+ def checkLeader() {
+ val masters = zk.getChildren(WORKING_DIR).toList
+ val leader = masters.sorted.head
+ val leaderFile = WORKING_DIR + "/" + leader
+
+ // Setup a watch for the current leader.
+ zk.exists(leaderFile, watcher)
+
+ try {
+ leaderUrl = new String(zk.getData(leaderFile))
+ } catch {
+ // A NoNodeException may be thrown if old leader died since the start of this method call.
+ // This is fine -- just check again, since we're guaranteed to see the new values.
+ case e: KeeperException.NoNodeException =>
+ logInfo("Leader disappeared while reading it -- finding next leader")
+ checkLeader()
+ return
+ }
+
+ // Synchronization used to ensure no interleaving between the creation of a new session and the
+ // checking of a leader, which could cause us to delete our real leader file erroneously.
+ synchronized {
+ val isLeader = myLeaderFile == leaderFile
+ if (!isLeader && leaderUrl == masterUrl) {
+ // We found a different master file pointing to this process.
+ // This can happen in the following two cases:
+ // (1) The master process was restarted on the same node.
+ // (2) The ZK server died between creating the node and returning the name of the node.
+ // For this case, we will end up creating a second file, and MUST explicitly delete the
+ // first one, since our ZK session is still open.
+ // Note that this deletion will cause a NodeDeleted event to be fired so we check again for
+ // leader changes.
+ assert(leaderFile < myLeaderFile)
+ logWarning("Cleaning up old ZK master election file that points to this master.")
+ zk.delete(leaderFile)
+ } else {
+ updateLeadershipStatus(isLeader)
+ }
+ }
+ }
+
+ def updateLeadershipStatus(isLeader: Boolean) {
+ if (isLeader && status == LeadershipStatus.NOT_LEADER) {
+ status = LeadershipStatus.LEADER
+ masterActor ! ElectedLeader
+ } else if (!isLeader && status == LeadershipStatus.LEADER) {
+ status = LeadershipStatus.NOT_LEADER
+ masterActor ! RevokedLeadership
+ }
+ }
+
+ private object LeadershipStatus extends Enumeration {
+ type LeadershipStatus = Value
+ val LEADER, NOT_LEADER = Value
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
new file mode 100644
index 0000000000..a0233a7271
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import org.apache.spark.Logging
+import org.apache.zookeeper._
+
+import akka.serialization.Serialization
+
+class ZooKeeperPersistenceEngine(serialization: Serialization)
+ extends PersistenceEngine
+ with SparkZooKeeperWatcher
+ with Logging
+{
+ val WORKING_DIR = System.getProperty("spark.deploy.zookeeper.dir", "/spark") + "/master_status"
+
+ val zk = new SparkZooKeeperSession(this)
+
+ zk.connect()
+
+ override def zkSessionCreated() {
+ zk.mkdirRecursive(WORKING_DIR)
+ }
+
+ override def zkDown() {
+ logError("PersistenceEngine disconnected from ZooKeeper -- ZK looks down.")
+ }
+
+ override def addApplication(app: ApplicationInfo) {
+ serializeIntoFile(WORKING_DIR + "/app_" + app.id, app)
+ }
+
+ override def removeApplication(app: ApplicationInfo) {
+ zk.delete(WORKING_DIR + "/app_" + app.id)
+ }
+
+ override def addWorker(worker: WorkerInfo) {
+ serializeIntoFile(WORKING_DIR + "/worker_" + worker.id, worker)
+ }
+
+ override def removeWorker(worker: WorkerInfo) {
+ zk.delete(WORKING_DIR + "/worker_" + worker.id)
+ }
+
+ override def close() {
+ zk.close()
+ }
+
+ override def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo]) = {
+ val sortedFiles = zk.getChildren(WORKING_DIR).toList.sorted
+ val appFiles = sortedFiles.filter(_.startsWith("app_"))
+ val apps = appFiles.map(deserializeFromFile[ApplicationInfo])
+ val workerFiles = sortedFiles.filter(_.startsWith("worker_"))
+ val workers = workerFiles.map(deserializeFromFile[WorkerInfo])
+ (apps, workers)
+ }
+
+ private def serializeIntoFile(path: String, value: Serializable) {
+ val serializer = serialization.findSerializerFor(value)
+ val serialized = serializer.toBinary(value)
+ zk.create(path, serialized, CreateMode.PERSISTENT)
+ }
+
+ def deserializeFromFile[T <: Serializable](filename: String)(implicit m: Manifest[T]): T = {
+ val fileData = zk.getData("/spark/master_status/" + filename)
+ val clazz = m.erasure.asInstanceOf[Class[T]]
+ val serializer = serialization.serializerFor(clazz)
+ serializer.fromBinary(fileData).asInstanceOf[T]
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index e3dc30eefc..8fabc95665 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -43,7 +43,8 @@ private[spark] class ExecutorRunner(
val workerId: String,
val host: String,
val sparkHome: File,
- val workDir: File)
+ val workDir: File,
+ var state: ExecutorState.Value)
extends Logging {
val fullId = appId + "/" + execId
@@ -83,7 +84,8 @@ private[spark] class ExecutorRunner(
process.destroy()
process.waitFor()
}
- worker ! ExecutorStateChanged(appId, execId, ExecutorState.KILLED, None, None)
+ state = ExecutorState.KILLED
+ worker ! ExecutorStateChanged(appId, execId, state, None, None)
Runtime.getRuntime.removeShutdownHook(shutdownHook)
}
}
@@ -180,9 +182,9 @@ private[spark] class ExecutorRunner(
// long-lived processes only. However, in the future, we might restart the executor a few
// times on the same machine.
val exitCode = process.waitFor()
+ state = ExecutorState.FAILED
val message = "Command exited with code " + exitCode
- worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message),
- Some(exitCode))
+ worker ! ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode))
} catch {
case interrupted: InterruptedException =>
logInfo("Runner thread for executor " + fullId + " interrupted")
@@ -192,8 +194,9 @@ private[spark] class ExecutorRunner(
if (process != null) {
process.destroy()
}
+ state = ExecutorState.FAILED
val message = e.getClass + ": " + e.getMessage
- worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message), None)
+ worker ! ExecutorStateChanged(appId, execId, state, Some(message), None)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 09530beb3b..216d9d44ac 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -23,26 +23,28 @@ import java.io.File
import scala.collection.mutable.HashMap
-import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated}
+import akka.actor._
import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected}
import akka.util.duration._
-import org.apache.spark.{Logging}
-import org.apache.spark.deploy.ExecutorState
+import org.apache.spark.Logging
+import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.Master
import org.apache.spark.deploy.worker.ui.WorkerWebUI
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.{Utils, AkkaUtils}
-
+/**
+ * @param masterUrls Each url should look like spark://host:port.
+ */
private[spark] class Worker(
host: String,
port: Int,
webUiPort: Int,
cores: Int,
memory: Int,
- masterUrl: String,
+ masterUrls: Array[String],
workDirPath: String = null)
extends Actor with Logging {
@@ -54,8 +56,18 @@ private[spark] class Worker(
// Send a heartbeat every (heartbeat timeout) / 4 milliseconds
val HEARTBEAT_MILLIS = System.getProperty("spark.worker.timeout", "60").toLong * 1000 / 4
+ val REGISTRATION_TIMEOUT = 20.seconds
+ val REGISTRATION_RETRIES = 3
+
+ // Index into masterUrls that we're currently trying to register with.
+ var masterIndex = 0
+
+ val masterLock: Object = new Object()
var master: ActorRef = null
- var masterWebUiUrl : String = ""
+ var activeMasterUrl: String = ""
+ var activeMasterWebUiUrl : String = ""
+ @volatile var registered = false
+ @volatile var connected = false
val workerId = generateWorkerId()
var sparkHome: File = null
var workDir: File = null
@@ -95,6 +107,7 @@ private[spark] class Worker(
}
override def preStart() {
+ assert(!registered)
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
host, port, cores, Utils.megabytesToString(memory)))
sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse("."))
@@ -103,44 +116,98 @@ private[spark] class Worker(
webUi = new WorkerWebUI(this, workDir, Some(webUiPort))
webUi.start()
- connectToMaster()
+ registerWithMaster()
metricsSystem.registerSource(workerSource)
metricsSystem.start()
}
- def connectToMaster() {
- logInfo("Connecting to master " + masterUrl)
- master = context.actorFor(Master.toAkkaUrl(masterUrl))
- master ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort.get, publicAddress)
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
- context.watch(master) // Doesn't work with remote actors, but useful for testing
+ def changeMaster(url: String, uiUrl: String) {
+ masterLock.synchronized {
+ activeMasterUrl = url
+ activeMasterWebUiUrl = uiUrl
+ master = context.actorFor(Master.toAkkaUrl(activeMasterUrl))
+ context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+ context.watch(master) // Doesn't work with remote actors, but useful for testing
+ connected = true
+ }
+ }
+
+ def tryRegisterAllMasters() {
+ for (masterUrl <- masterUrls) {
+ logInfo("Connecting to master " + masterUrl + "...")
+ val actor = context.actorFor(Master.toAkkaUrl(masterUrl))
+ actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort.get,
+ publicAddress)
+ }
+ }
+
+ def registerWithMaster() {
+ tryRegisterAllMasters()
+
+ var retries = 0
+ lazy val retryTimer: Cancellable =
+ context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) {
+ retries += 1
+ if (registered) {
+ retryTimer.cancel()
+ } else if (retries >= REGISTRATION_RETRIES) {
+ logError("All masters are unresponsive! Giving up.")
+ System.exit(1)
+ } else {
+ tryRegisterAllMasters()
+ }
+ }
+ retryTimer // start timer
}
override def receive = {
- case RegisteredWorker(url) =>
- masterWebUiUrl = url
- logInfo("Successfully registered with master")
- context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis) {
- master ! Heartbeat(workerId)
+ case RegisteredWorker(masterUrl, masterWebUiUrl) =>
+ logInfo("Successfully registered with master " + masterUrl)
+ registered = true
+ changeMaster(masterUrl, masterWebUiUrl)
+ context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat)
+
+ case SendHeartbeat =>
+ masterLock.synchronized {
+ if (connected) { master ! Heartbeat(workerId) }
}
+ case MasterChanged(masterUrl, masterWebUiUrl) =>
+ logInfo("Master has changed, new master is at " + masterUrl)
+ context.unwatch(master)
+ changeMaster(masterUrl, masterWebUiUrl)
+
+ val execs = executors.values.
+ map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state))
+ sender ! WorkerSchedulerStateResponse(workerId, execs.toList)
+
case RegisterWorkerFailed(message) =>
- logError("Worker registration failed: " + message)
- System.exit(1)
-
- case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) =>
- logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
- val manager = new ExecutorRunner(
- appId, execId, appDesc, cores_, memory_, self, workerId, host, new File(execSparkHome_), workDir)
- executors(appId + "/" + execId) = manager
- manager.start()
- coresUsed += cores_
- memoryUsed += memory_
- master ! ExecutorStateChanged(appId, execId, ExecutorState.RUNNING, None, None)
+ if (!registered) {
+ logError("Worker registration failed: " + message)
+ System.exit(1)
+ }
+
+ case LaunchExecutor(masterUrl, appId, execId, appDesc, cores_, memory_, execSparkHome_) =>
+ if (masterUrl != activeMasterUrl) {
+ logWarning("Invalid Master (" + masterUrl + ") attempted to launch executor.")
+ } else {
+ logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
+ val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_,
+ self, workerId, host, new File(execSparkHome_), workDir, ExecutorState.RUNNING)
+ executors(appId + "/" + execId) = manager
+ manager.start()
+ coresUsed += cores_
+ memoryUsed += memory_
+ masterLock.synchronized {
+ master ! ExecutorStateChanged(appId, execId, manager.state, None, None)
+ }
+ }
case ExecutorStateChanged(appId, execId, state, message, exitStatus) =>
- master ! ExecutorStateChanged(appId, execId, state, message, exitStatus)
+ masterLock.synchronized {
+ master ! ExecutorStateChanged(appId, execId, state, message, exitStatus)
+ }
val fullId = appId + "/" + execId
if (ExecutorState.isFinished(state)) {
val executor = executors(fullId)
@@ -153,32 +220,39 @@ private[spark] class Worker(
memoryUsed -= executor.memory
}
- case KillExecutor(appId, execId) =>
- val fullId = appId + "/" + execId
- executors.get(fullId) match {
- case Some(executor) =>
- logInfo("Asked to kill executor " + fullId)
- executor.kill()
- case None =>
- logInfo("Asked to kill unknown executor " + fullId)
+ case KillExecutor(masterUrl, appId, execId) =>
+ if (masterUrl != activeMasterUrl) {
+ logWarning("Invalid Master (" + masterUrl + ") attempted to launch executor " + execId)
+ } else {
+ val fullId = appId + "/" + execId
+ executors.get(fullId) match {
+ case Some(executor) =>
+ logInfo("Asked to kill executor " + fullId)
+ executor.kill()
+ case None =>
+ logInfo("Asked to kill unknown executor " + fullId)
+ }
}
- case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
+ case Terminated(actor_) if actor_ == master =>
+ masterDisconnected()
+
+ case RemoteClientDisconnected(transport, address) if address == master.path.address =>
+ masterDisconnected()
+
+ case RemoteClientShutdown(transport, address) if address == master.path.address =>
masterDisconnected()
case RequestWorkerState => {
sender ! WorkerStateResponse(host, port, workerId, executors.values.toList,
- finishedExecutors.values.toList, masterUrl, cores, memory,
- coresUsed, memoryUsed, masterWebUiUrl)
+ finishedExecutors.values.toList, activeMasterUrl, cores, memory,
+ coresUsed, memoryUsed, activeMasterWebUiUrl)
}
}
def masterDisconnected() {
- // TODO: It would be nice to try to reconnect to the master, but just shut down for now.
- // (Note that if reconnecting we would also need to assign IDs differently.)
- logError("Connection to master failed! Shutting down.")
- executors.values.foreach(_.kill())
- System.exit(1)
+ logError("Connection to master failed! Waiting for master to reconnect...")
+ connected = false
}
def generateWorkerId(): String = {
@@ -196,17 +270,18 @@ private[spark] object Worker {
def main(argStrings: Array[String]) {
val args = new WorkerArguments(argStrings)
val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores,
- args.memory, args.master, args.workDir)
+ args.memory, args.masters, args.workDir)
actorSystem.awaitTermination()
}
def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int,
- masterUrl: String, workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
+ masterUrls: Array[String], workDir: String, workerNumber: Option[Int] = None)
+ : (ActorSystem, Int) = {
// The LocalSparkCluster runs multiple local sparkWorkerX actor systems
val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
val actor = actorSystem.actorOf(Props(new Worker(host, boundPort, webUiPort, cores, memory,
- masterUrl, workDir)), name = "Worker")
+ masterUrls, workDir)), name = "Worker")
(actorSystem, boundPort)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
index 0ae89a864f..3ed528e6b3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
@@ -29,7 +29,7 @@ private[spark] class WorkerArguments(args: Array[String]) {
var webUiPort = 8081
var cores = inferDefaultCores()
var memory = inferDefaultMemory()
- var master: String = null
+ var masters: Array[String] = null
var workDir: String = null
// Check for settings in environment variables
@@ -86,14 +86,14 @@ private[spark] class WorkerArguments(args: Array[String]) {
printUsageAndExit(0)
case value :: tail =>
- if (master != null) { // Two positional arguments were given
+ if (masters != null) { // Two positional arguments were given
printUsageAndExit(1)
}
- master = value
+ masters = value.stripPrefix("spark://").split(",").map("spark://" + _)
parse(tail)
case Nil =>
- if (master == null) { // No positional argument was given
+ if (masters == null) { // No positional argument was given
printUsageAndExit(1)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala
index df269fd047..b7ddd8c816 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala
@@ -25,27 +25,27 @@ private[spark] class WorkerSource(val worker: Worker) extends Source {
val sourceName = "worker"
val metricRegistry = new MetricRegistry()
- metricRegistry.register(MetricRegistry.name("executors", "number"), new Gauge[Int] {
+ metricRegistry.register(MetricRegistry.name("executors"), new Gauge[Int] {
override def getValue: Int = worker.executors.size
})
// Gauge for cores used of this worker
- metricRegistry.register(MetricRegistry.name("coresUsed", "number"), new Gauge[Int] {
+ metricRegistry.register(MetricRegistry.name("coresUsed"), new Gauge[Int] {
override def getValue: Int = worker.coresUsed
})
// Gauge for memory used of this worker
- metricRegistry.register(MetricRegistry.name("memUsed", "MBytes"), new Gauge[Int] {
+ metricRegistry.register(MetricRegistry.name("memUsed_MB"), new Gauge[Int] {
override def getValue: Int = worker.memoryUsed
})
// Gauge for cores free of this worker
- metricRegistry.register(MetricRegistry.name("coresFree", "number"), new Gauge[Int] {
+ metricRegistry.register(MetricRegistry.name("coresFree"), new Gauge[Int] {
override def getValue: Int = worker.coresFree
})
// Gauge for memory free of this worker
- metricRegistry.register(MetricRegistry.name("memFree", "MBytes"), new Gauge[Int] {
+ metricRegistry.register(MetricRegistry.name("memFree_MB"), new Gauge[Int] {
override def getValue: Int = worker.memoryFree
})
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index 95d6007f3b..800f1cafcc 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -105,7 +105,7 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I
val logText = <node>{Utils.offsetBytes(path, startByte, endByte)}</node>
- val linkToMaster = <p><a href={worker.masterWebUiUrl}>Back to Master</a></p>
+ val linkToMaster = <p><a href={worker.activeMasterWebUiUrl}>Back to Master</a></p>
val range = <span>Bytes {startByte.toString} - {endByte.toString} of {logLength}</span>
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index ceae3b8289..eff0c0f274 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -17,7 +17,7 @@
package org.apache.spark.executor
-import java.io.{File}
+import java.io.File
import java.lang.management.ManagementFactory
import java.nio.ByteBuffer
import java.util.concurrent._
@@ -27,11 +27,11 @@ import scala.collection.mutable.HashMap
import org.apache.spark.scheduler._
import org.apache.spark._
+import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util.Utils
-
/**
- * The Mesos executor for Spark.
+ * Spark executor used with Mesos and the standalone scheduler.
*/
private[spark] class Executor(
executorId: String,
@@ -167,12 +167,20 @@ private[spark] class Executor(
// we need to serialize the task metrics first. If TaskMetrics had a custom serialized format, we could
// just change the relevants bytes in the byte buffer
val accumUpdates = Accumulators.values
- val result = new TaskResult(value, accumUpdates, task.metrics.getOrElse(null))
- val serializedResult = ser.serialize(result)
- logInfo("Serialized size of result for " + taskId + " is " + serializedResult.limit)
- if (serializedResult.limit >= (akkaFrameSize - 1024)) {
- context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(TaskResultTooBigFailure()))
- return
+ val directResult = new DirectTaskResult(value, accumUpdates, task.metrics.getOrElse(null))
+ val serializedDirectResult = ser.serialize(directResult)
+ logInfo("Serialized size of result for " + taskId + " is " + serializedDirectResult.limit)
+ val serializedResult = {
+ if (serializedDirectResult.limit >= akkaFrameSize - 1024) {
+ logInfo("Storing result for " + taskId + " in local BlockManager")
+ val blockId = TaskResultBlockId(taskId)
+ env.blockManager.putBytes(
+ blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
+ ser.serialize(new IndirectTaskResult[Any](blockId))
+ } else {
+ logInfo("Sending result for " + taskId + " directly to driver")
+ serializedDirectResult
+ }
}
context.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
logInfo("Finished task ID " + taskId)
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
index 18c9dc1c0a..34ed9c8f73 100644
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
@@ -43,31 +43,31 @@ class ExecutorSource(val executor: Executor, executorId: String) extends Source
val sourceName = "executor.%s".format(executorId)
// Gauge for executor thread pool's actively executing task counts
- metricRegistry.register(MetricRegistry.name("threadpool", "activeTask", "count"), new Gauge[Int] {
+ metricRegistry.register(MetricRegistry.name("threadpool", "activeTasks"), new Gauge[Int] {
override def getValue: Int = executor.threadPool.getActiveCount()
})
// Gauge for executor thread pool's approximate total number of tasks that have been completed
- metricRegistry.register(MetricRegistry.name("threadpool", "completeTask", "count"), new Gauge[Long] {
+ metricRegistry.register(MetricRegistry.name("threadpool", "completeTasks"), new Gauge[Long] {
override def getValue: Long = executor.threadPool.getCompletedTaskCount()
})
// Gauge for executor thread pool's current number of threads
- metricRegistry.register(MetricRegistry.name("threadpool", "currentPool", "size"), new Gauge[Int] {
+ metricRegistry.register(MetricRegistry.name("threadpool", "currentPool_size"), new Gauge[Int] {
override def getValue: Int = executor.threadPool.getPoolSize()
})
// Gauge got executor thread pool's largest number of threads that have ever simultaneously been in th pool
- metricRegistry.register(MetricRegistry.name("threadpool", "maxPool", "size"), new Gauge[Int] {
+ metricRegistry.register(MetricRegistry.name("threadpool", "maxPool_size"), new Gauge[Int] {
override def getValue: Int = executor.threadPool.getMaximumPoolSize()
})
// Gauge for file system stats of this executor
for (scheme <- Array("hdfs", "file")) {
- registerFileSystemStat(scheme, "bytesRead", _.getBytesRead(), 0L)
- registerFileSystemStat(scheme, "bytesWritten", _.getBytesWritten(), 0L)
- registerFileSystemStat(scheme, "readOps", _.getReadOps(), 0)
- registerFileSystemStat(scheme, "largeReadOps", _.getLargeReadOps(), 0)
- registerFileSystemStat(scheme, "writeOps", _.getWriteOps(), 0)
+ registerFileSystemStat(scheme, "read_bytes", _.getBytesRead(), 0L)
+ registerFileSystemStat(scheme, "write_bytes", _.getBytesWritten(), 0L)
+ registerFileSystemStat(scheme, "read_ops", _.getReadOps(), 0)
+ registerFileSystemStat(scheme, "largeRead_ops", _.getLargeReadOps(), 0)
+ registerFileSystemStat(scheme, "write_ops", _.getWriteOps(), 0)
}
}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
index 3c29700920..1b9fa1e53a 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
@@ -20,17 +20,18 @@ package org.apache.spark.network.netty
import io.netty.buffer._
import org.apache.spark.Logging
+import org.apache.spark.storage.{TestBlockId, BlockId}
private[spark] class FileHeader (
val fileLen: Int,
- val blockId: String) extends Logging {
+ val blockId: BlockId) extends Logging {
lazy val buffer = {
val buf = Unpooled.buffer()
buf.capacity(FileHeader.HEADER_SIZE)
buf.writeInt(fileLen)
- buf.writeInt(blockId.length)
- blockId.foreach((x: Char) => buf.writeByte(x))
+ buf.writeInt(blockId.name.length)
+ blockId.name.foreach((x: Char) => buf.writeByte(x))
//padding the rest of header
if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) {
buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes)
@@ -57,18 +58,15 @@ private[spark] object FileHeader {
for (i <- 1 to idLength) {
idBuilder += buf.readByte().asInstanceOf[Char]
}
- val blockId = idBuilder.toString()
+ val blockId = BlockId(idBuilder.toString())
new FileHeader(length, blockId)
}
-
- def main (args:Array[String]){
-
- val header = new FileHeader(25,"block_0");
- val buf = header.buffer;
- val newheader = FileHeader.create(buf);
- System.out.println("id="+newheader.blockId+",size="+newheader.fileLen)
-
+ def main (args:Array[String]) {
+ val header = new FileHeader(25, TestBlockId("my_block"))
+ val buf = header.buffer
+ val newHeader = FileHeader.create(buf)
+ System.out.println("id=" + newHeader.blockId + ",size=" + newHeader.fileLen)
}
}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
index 9493ccffd9..481ff8c3e0 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
@@ -27,12 +27,13 @@ import org.apache.spark.Logging
import org.apache.spark.network.ConnectionManagerId
import scala.collection.JavaConverters._
+import org.apache.spark.storage.BlockId
private[spark] class ShuffleCopier extends Logging {
- def getBlock(host: String, port: Int, blockId: String,
- resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+ def getBlock(host: String, port: Int, blockId: BlockId,
+ resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt
@@ -41,7 +42,7 @@ private[spark] class ShuffleCopier extends Logging {
try {
fc.init()
fc.connect(host, port)
- fc.sendRequest(blockId)
+ fc.sendRequest(blockId.name)
fc.waitForClose()
fc.close()
} catch {
@@ -53,14 +54,14 @@ private[spark] class ShuffleCopier extends Logging {
}
}
- def getBlock(cmId: ConnectionManagerId, blockId: String,
- resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+ def getBlock(cmId: ConnectionManagerId, blockId: BlockId,
+ resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
getBlock(cmId.host, cmId.port, blockId, resultCollectCallback)
}
def getBlocks(cmId: ConnectionManagerId,
- blocks: Seq[(String, Long)],
- resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+ blocks: Seq[(BlockId, Long)],
+ resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
for ((blockId, size) <- blocks) {
getBlock(cmId, blockId, resultCollectCallback)
@@ -71,7 +72,7 @@ private[spark] class ShuffleCopier extends Logging {
private[spark] object ShuffleCopier extends Logging {
- private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit)
+ private class ShuffleClientHandler(resultCollectCallBack: (BlockId, Long, ByteBuf) => Unit)
extends FileClientHandler with Logging {
override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
@@ -79,14 +80,14 @@ private[spark] object ShuffleCopier extends Logging {
resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
}
- override def handleError(blockId: String) {
+ override def handleError(blockId: BlockId) {
if (!isComplete) {
resultCollectCallBack(blockId, -1, null)
}
}
}
- def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
+ def echoResultCollectCallBack(blockId: BlockId, size: Long, content: ByteBuf) {
if (size != -1) {
logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
}
@@ -99,7 +100,7 @@ private[spark] object ShuffleCopier extends Logging {
}
val host = args(0)
val port = args(1).toInt
- val file = args(2)
+ val blockId = BlockId(args(2))
val threads = if (args.length > 3) args(3).toInt else 10
val copiers = Executors.newFixedThreadPool(80)
@@ -107,12 +108,12 @@ private[spark] object ShuffleCopier extends Logging {
Executors.callable(new Runnable() {
def run() {
val copier = new ShuffleCopier()
- copier.getBlock(host, port, file, echoResultCollectCallBack)
+ copier.getBlock(host, port, blockId, echoResultCollectCallBack)
}
})
}).asJava
copiers.invokeAll(tasks)
- copiers.shutdown
+ copiers.shutdown()
System.exit(0)
}
}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
index 8afcbe190a..1586dff254 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
@@ -21,6 +21,7 @@ import java.io.File
import org.apache.spark.Logging
import org.apache.spark.util.Utils
+import org.apache.spark.storage.BlockId
private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {
@@ -53,8 +54,9 @@ private[spark] object ShuffleSender {
val localDirs = args.drop(2).map(new File(_))
val pResovler = new PathResolver {
- override def getAbsolutePath(blockId: String): String = {
- if (!blockId.startsWith("shuffle_")) {
+ override def getAbsolutePath(blockIdString: String): String = {
+ val blockId = BlockId(blockIdString)
+ if (!blockId.isShuffle) {
throw new Exception("Block " + blockId + " is not a shuffle block")
}
// Figure out which local directory it hashes to, and which subdirectory in that
@@ -62,7 +64,7 @@ private[spark] object ShuffleSender {
val dirId = hash % localDirs.length
val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
- val file = new File(subDir, blockId)
+ val file = new File(subDir, blockId.name)
return file.getAbsolutePath
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
index bca6956a18..44ea573a7c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -18,14 +18,14 @@
package org.apache.spark.rdd
import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext}
-import org.apache.spark.storage.BlockManager
+import org.apache.spark.storage.{BlockId, BlockManager}
-private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition {
+private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends Partition {
val index = idx
}
private[spark]
-class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String])
+class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[BlockId])
extends RDD[T](sc, Nil) {
@transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get)
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 0187256a8e..d237797aa6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -18,13 +18,12 @@
package org.apache.spark.rdd
import java.io.{ObjectOutputStream, IOException}
-import java.util.{HashMap => JHashMap}
-import scala.collection.JavaConversions
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext}
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
+import org.apache.spark.util.AppendOnlyMap
private[spark] sealed trait CoGroupSplitDep extends Serializable
@@ -105,17 +104,14 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
val split = s.asInstanceOf[CoGroupPartition]
val numRdds = split.deps.size
// e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
- val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
+ val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]]
- def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
- val seq = map.get(k)
- if (seq != null) {
- seq
- } else {
- val seq = Array.fill(numRdds)(new ArrayBuffer[Any])
- map.put(k, seq)
- seq
- }
+ val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => {
+ if (hadVal) oldVal else Array.fill(numRdds)(new ArrayBuffer[Any])
+ }
+
+ val getSeq = (k: K) => {
+ map.changeValue(k, update)
}
val ser = SparkEnv.get.serializerManager.get(serializerClass)
@@ -134,7 +130,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
}
}
}
- JavaConversions.mapAsScalaMap(map).iterator
+ map.iterator
}
override def clearDependencies() {
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 2cb6734e41..2d394abfd9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -19,6 +19,7 @@ package org.apache.spark.rdd
import java.io.EOFException
+import org.apache.hadoop.mapred.FileInputFormat
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.InputSplit
import org.apache.hadoop.mapred.JobConf
@@ -26,7 +27,9 @@ import org.apache.hadoop.mapred.RecordReader
import org.apache.hadoop.mapred.Reporter
import org.apache.hadoop.util.ReflectionUtils
-import org.apache.spark.{Logging, Partition, SerializableWritable, SparkContext, SparkEnv, TaskContext}
+import org.apache.spark.{Logging, Partition, SerializableWritable, SparkContext, SparkEnv,
+ TaskContext}
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.util.NextIterator
import org.apache.hadoop.conf.{Configuration, Configurable}
@@ -45,29 +48,95 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp
}
/**
- * An RDD that reads a Hadoop dataset as specified by a JobConf (e.g. files in HDFS, the local file
- * system, or S3, tables in HBase, etc).
+ * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS,
+ * sources in HBase, or S3).
+ *
+ * @param sc The SparkContext to associate the RDD with.
+ * @param broadCastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed
+ * variabe references an instance of JobConf, then that JobConf will be used for the Hadoop job.
+ * Otherwise, a new JobConf will be created on each slave using the enclosed Configuration.
+ * @param initLocalJobConfFuncOpt Optional closure used to initialize any JobConf that HadoopRDD
+ * creates.
+ * @param inputFormatClass Storage format of the data to be read.
+ * @param keyClass Class of the key associated with the inputFormatClass.
+ * @param valueClass Class of the value associated with the inputFormatClass.
+ * @param minSplits Minimum number of Hadoop Splits (HadoopRDD partitions) to generate.
*/
class HadoopRDD[K, V](
sc: SparkContext,
- @transient conf: JobConf,
+ broadcastedConf: Broadcast[SerializableWritable[Configuration]],
+ initLocalJobConfFuncOpt: Option[JobConf => Unit],
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
minSplits: Int)
extends RDD[(K, V)](sc, Nil) with Logging {
- // A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it
- private val confBroadcast = sc.broadcast(new SerializableWritable(conf))
+ def this(
+ sc: SparkContext,
+ conf: JobConf,
+ inputFormatClass: Class[_ <: InputFormat[K, V]],
+ keyClass: Class[K],
+ valueClass: Class[V],
+ minSplits: Int) = {
+ this(
+ sc,
+ sc.broadcast(new SerializableWritable(conf))
+ .asInstanceOf[Broadcast[SerializableWritable[Configuration]]],
+ None /* initLocalJobConfFuncOpt */,
+ inputFormatClass,
+ keyClass,
+ valueClass,
+ minSplits)
+ }
+
+ protected val jobConfCacheKey = "rdd_%d_job_conf".format(id)
+
+ protected val inputFormatCacheKey = "rdd_%d_input_format".format(id)
+
+ // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads.
+ protected def getJobConf(): JobConf = {
+ val conf: Configuration = broadcastedConf.value.value
+ if (conf.isInstanceOf[JobConf]) {
+ // A user-broadcasted JobConf was provided to the HadoopRDD, so always use it.
+ return conf.asInstanceOf[JobConf]
+ } else if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) {
+ // getJobConf() has been called previously, so there is already a local cache of the JobConf
+ // needed by this RDD.
+ return HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf]
+ } else {
+ // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the
+ // local process. The local cache is accessed through HadoopRDD.putCachedMetadata().
+ // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects.
+ val newJobConf = new JobConf(broadcastedConf.value.value)
+ initLocalJobConfFuncOpt.map(f => f(newJobConf))
+ HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
+ return newJobConf
+ }
+ }
+
+ protected def getInputFormat(conf: JobConf): InputFormat[K, V] = {
+ if (HadoopRDD.containsCachedMetadata(inputFormatCacheKey)) {
+ return HadoopRDD.getCachedMetadata(inputFormatCacheKey).asInstanceOf[InputFormat[K, V]]
+ }
+ // Once an InputFormat for this RDD is created, cache it so that only one reflection call is
+ // done in each local process.
+ val newInputFormat = ReflectionUtils.newInstance(inputFormatClass.asInstanceOf[Class[_]], conf)
+ .asInstanceOf[InputFormat[K, V]]
+ if (newInputFormat.isInstanceOf[Configurable]) {
+ newInputFormat.asInstanceOf[Configurable].setConf(conf)
+ }
+ HadoopRDD.putCachedMetadata(inputFormatCacheKey, newInputFormat)
+ return newInputFormat
+ }
override def getPartitions: Array[Partition] = {
- val env = SparkEnv.get
- env.hadoop.addCredentials(conf)
- val inputFormat = createInputFormat(conf)
+ val jobConf = getJobConf()
+ val inputFormat = getInputFormat(jobConf)
if (inputFormat.isInstanceOf[Configurable]) {
- inputFormat.asInstanceOf[Configurable].setConf(conf)
+ inputFormat.asInstanceOf[Configurable].setConf(jobConf)
}
- val inputSplits = inputFormat.getSplits(conf, minSplits)
+ val inputSplits = inputFormat.getSplits(jobConf, minSplits)
val array = new Array[Partition](inputSplits.size)
for (i <- 0 until inputSplits.size) {
array(i) = new HadoopPartition(id, i, inputSplits(i))
@@ -75,22 +144,14 @@ class HadoopRDD[K, V](
array
}
- def createInputFormat(conf: JobConf): InputFormat[K, V] = {
- ReflectionUtils.newInstance(inputFormatClass.asInstanceOf[Class[_]], conf)
- .asInstanceOf[InputFormat[K, V]]
- }
-
override def compute(theSplit: Partition, context: TaskContext) = new NextIterator[(K, V)] {
val split = theSplit.asInstanceOf[HadoopPartition]
logInfo("Input split: " + split.inputSplit)
var reader: RecordReader[K, V] = null
- val conf = confBroadcast.value.value
- val fmt = createInputFormat(conf)
- if (fmt.isInstanceOf[Configurable]) {
- fmt.asInstanceOf[Configurable].setConf(conf)
- }
- reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)
+ val jobConf = getJobConf()
+ val inputFormat = getInputFormat(jobConf)
+ reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
// Register an on-task-completion callback to close the input stream.
context.addOnCompleteCallback{ () => closeIfNeeded() }
@@ -127,5 +188,18 @@ class HadoopRDD[K, V](
// Do nothing. Hadoop RDD should not be checkpointed.
}
- def getConf: Configuration = confBroadcast.value.value
+ def getConf: Configuration = getJobConf()
+}
+
+private[spark] object HadoopRDD {
+ /**
+ * The three methods below are helpers for accessing the local map, a property of the SparkEnv of
+ * the local process.
+ */
+ def getCachedMetadata(key: String) = SparkEnv.get.hadoop.hadoopJobMetadata.get(key)
+
+ def containsCachedMetadata(key: String) = SparkEnv.get.hadoop.hadoopJobMetadata.containsKey(key)
+
+ def putCachedMetadata(key: String, value: Any) =
+ SparkEnv.get.hadoop.hadoopJobMetadata.put(key, value)
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index e9ef52bf3b..6776220835 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -753,24 +753,42 @@ abstract class RDD[T: ClassManifest](
}
/**
- * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
- * it will be slow if a lot of partitions are required. In that case, use collect() to get the
- * whole RDD instead.
+ * Take the first num elements of the RDD. It works by first scanning one partition, and use the
+ * results from that partition to estimate the number of additional partitions needed to satisfy
+ * the limit.
*/
def take(num: Int): Array[T] = {
if (num == 0) {
return new Array[T](0)
}
+
val buf = new ArrayBuffer[T]
- var p = 0
- while (buf.size < num && p < partitions.size) {
+ val totalParts = this.partitions.length
+ var partsScanned = 0
+ while (buf.size < num && partsScanned < totalParts) {
+ // The number of partitions to try in this iteration. It is ok for this number to be
+ // greater than totalParts because we actually cap it at totalParts in runJob.
+ var numPartsToTry = 1
+ if (partsScanned > 0) {
+ // If we didn't find any rows after the first iteration, just try all partitions next.
+ // Otherwise, interpolate the number of partitions we need to try, but overestimate it
+ // by 50%.
+ if (buf.size == 0) {
+ numPartsToTry = totalParts - 1
+ } else {
+ numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt
+ }
+ }
+ numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
+
val left = num - buf.size
- val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, Array(p), true)
- buf ++= res(0)
- if (buf.size == num)
- return buf.toArray
- p += 1
+ val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+ val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p, allowLocal = true)
+
+ res.foreach(buf ++= _.take(num - buf.size))
+ partsScanned += numPartsToTry
}
+
return buf.toArray
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 3e3f04f087..5c51852985 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -28,9 +28,8 @@ import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
-import org.apache.spark.scheduler.cluster.TaskInfo
-import org.apache.spark.storage.{BlockManager, BlockManagerMaster}
-import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap}
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
/**
* The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of
@@ -115,7 +114,7 @@ class DAGScheduler(
private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
- private val listenerBus = new SparkListenerBus()
+ private[spark] val listenerBus = new SparkListenerBus()
// Contains the locations that each RDD's partitions are cached on
private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]]
@@ -139,7 +138,7 @@ class DAGScheduler(
val activeJobs = new HashSet[ActiveJob]
val resultStageToJob = new HashMap[Stage, ActiveJob]
- val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup)
+ val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup)
// Start a thread to run the DAGScheduler event loop
def start() {
@@ -157,7 +156,7 @@ class DAGScheduler(
private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
if (!cacheLocs.contains(rdd.id)) {
- val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
+ val blockIds = rdd.partitions.indices.map(index=> RDDBlockId(rdd.id, index)).toArray[BlockId]
val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster)
cacheLocs(rdd.id) = blockIds.map { id =>
locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId))
@@ -553,7 +552,7 @@ class DAGScheduler(
SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head)
} catch {
case e: NotSerializableException =>
- abortStage(stage, e.toString)
+ abortStage(stage, "Task not serializable: " + e.toString)
running -= stage
return
}
@@ -705,6 +704,9 @@ class DAGScheduler(
case ExceptionFailure(className, description, stackTrace, metrics) =>
// Do nothing here, left up to the TaskScheduler to decide how to handle user failures
+ case TaskResultLost =>
+ // Do nothing here; the TaskScheduler handles these failures and resubmits the task.
+
case other =>
// Unrecognized failure - abort all jobs depending on this stage
abortStage(stageIdToStage(task.stageId), task + " failed: " + other)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 0d99670648..10ff1b4376 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -19,7 +19,6 @@ package org.apache.spark.scheduler
import java.util.Properties
-import org.apache.spark.scheduler.cluster.TaskInfo
import scala.collection.mutable.Map
import org.apache.spark._
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
index 446d490cc9..151514896f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
@@ -27,23 +27,23 @@ private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: Spar
val metricRegistry = new MetricRegistry()
val sourceName = "%s.DAGScheduler".format(sc.appName)
- metricRegistry.register(MetricRegistry.name("stage", "failedStages", "number"), new Gauge[Int] {
+ metricRegistry.register(MetricRegistry.name("stage", "failedStages"), new Gauge[Int] {
override def getValue: Int = dagScheduler.failed.size
})
- metricRegistry.register(MetricRegistry.name("stage", "runningStages", "number"), new Gauge[Int] {
+ metricRegistry.register(MetricRegistry.name("stage", "runningStages"), new Gauge[Int] {
override def getValue: Int = dagScheduler.running.size
})
- metricRegistry.register(MetricRegistry.name("stage", "waitingStages", "number"), new Gauge[Int] {
+ metricRegistry.register(MetricRegistry.name("stage", "waitingStages"), new Gauge[Int] {
override def getValue: Int = dagScheduler.waiting.size
})
- metricRegistry.register(MetricRegistry.name("job", "allJobs", "number"), new Gauge[Int] {
+ metricRegistry.register(MetricRegistry.name("job", "allJobs"), new Gauge[Int] {
override def getValue: Int = dagScheduler.nextJobId.get()
})
- metricRegistry.register(MetricRegistry.name("job", "activeJobs", "number"), new Gauge[Int] {
+ metricRegistry.register(MetricRegistry.name("job", "activeJobs"), new Gauge[Int] {
override def getValue: Int = dagScheduler.activeJobs.size
})
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index c8b78bf00a..3628b1b078 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -30,7 +30,6 @@ import scala.io.Source
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.scheduler.cluster.TaskInfo
// Used to record runtime information for each job, including RDD graph
// tasks' start/stop shuffle information and information from outside
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
index 35b32600da..9eb8d48501 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/Pool.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
@@ -15,13 +15,13 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import org.apache.spark.Logging
-import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
+import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
/**
* An Schedulable entity that represent collection of Pools or TaskSetManagers
@@ -45,7 +45,7 @@ private[spark] class Pool(
var priority = 0
var stageId = 0
var name = poolName
- var parent:Schedulable = null
+ var parent: Pool = null
var taskSetSchedulingAlgorithm: SchedulingAlgorithm = {
schedulingMode match {
@@ -101,14 +101,14 @@ private[spark] class Pool(
return sortedTaskSetQueue
}
- override def increaseRunningTasks(taskNum: Int) {
+ def increaseRunningTasks(taskNum: Int) {
runningTasks += taskNum
if (parent != null) {
parent.increaseRunningTasks(taskNum)
}
}
- override def decreaseRunningTasks(taskNum: Int) {
+ def decreaseRunningTasks(taskNum: Int) {
runningTasks -= taskNum
if (parent != null) {
parent.decreaseRunningTasks(taskNum)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 07e8317e3a..6dd422bbf6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -23,7 +23,7 @@ import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.RDDCheckpointData
-import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap}
+import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap}
private[spark] object ResultTask {
@@ -32,7 +32,7 @@ private[spark] object ResultTask {
// expensive on the master node if it needs to launch thousands of tasks.
val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
- val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.clearOldValues)
+ val metadataCleaner = new MetadataCleaner(MetadataCleanerType.RESULT_TASK, serializedInfoCache.clearOldValues)
def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = {
synchronized {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/Schedulable.scala b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
index f4726450ec..1c7ea2dccc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/Schedulable.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
@@ -15,9 +15,9 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
-import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
+import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import scala.collection.mutable.ArrayBuffer
/**
@@ -25,7 +25,7 @@ import scala.collection.mutable.ArrayBuffer
* there are two type of Schedulable entities(Pools and TaskSetManagers)
*/
private[spark] trait Schedulable {
- var parent: Schedulable
+ var parent: Pool
// child queues
def schedulableQueue: ArrayBuffer[Schedulable]
def schedulingMode: SchedulingMode
@@ -36,8 +36,6 @@ private[spark] trait Schedulable {
def stageId: Int
def name: String
- def increaseRunningTasks(taskNum: Int): Unit
- def decreaseRunningTasks(taskNum: Int): Unit
def addSchedulable(schedulable: Schedulable): Unit
def removeSchedulable(schedulable: Schedulable): Unit
def getSchedulableByName(name: String): Schedulable
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
index f80823317b..4e25086ec9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulableBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
@@ -15,16 +15,14 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
-import java.io.{File, FileInputStream, FileOutputStream, FileNotFoundException}
-import java.util.Properties
-
-import scala.xml.XML
+import java.io.{FileInputStream, InputStream}
+import java.util.{NoSuchElementException, Properties}
import org.apache.spark.Logging
-import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
+import scala.xml.XML
/**
* An interface to build Schedulable tree
@@ -51,7 +49,8 @@ private[spark] class FIFOSchedulableBuilder(val rootPool: Pool)
private[spark] class FairSchedulableBuilder(val rootPool: Pool)
extends SchedulableBuilder with Logging {
- val schedulerAllocFile = System.getProperty("spark.scheduler.allocation.file")
+ val schedulerAllocFile = Option(System.getProperty("spark.scheduler.allocation.file"))
+ val DEFAULT_SCHEDULER_FILE = "fairscheduler.xml"
val FAIR_SCHEDULER_PROPERTIES = "spark.scheduler.pool"
val DEFAULT_POOL_NAME = "default"
val MINIMUM_SHARES_PROPERTY = "minShare"
@@ -64,48 +63,26 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool)
val DEFAULT_WEIGHT = 1
override def buildPools() {
- if (schedulerAllocFile != null) {
- val file = new File(schedulerAllocFile)
- if (file.exists()) {
- val xml = XML.loadFile(file)
- for (poolNode <- (xml \\ POOLS_PROPERTY)) {
-
- val poolName = (poolNode \ POOL_NAME_PROPERTY).text
- var schedulingMode = DEFAULT_SCHEDULING_MODE
- var minShare = DEFAULT_MINIMUM_SHARE
- var weight = DEFAULT_WEIGHT
-
- val xmlSchedulingMode = (poolNode \ SCHEDULING_MODE_PROPERTY).text
- if (xmlSchedulingMode != "") {
- try {
- schedulingMode = SchedulingMode.withName(xmlSchedulingMode)
- } catch {
- case e: Exception => logInfo("Error xml schedulingMode, using default schedulingMode")
- }
- }
-
- val xmlMinShare = (poolNode \ MINIMUM_SHARES_PROPERTY).text
- if (xmlMinShare != "") {
- minShare = xmlMinShare.toInt
- }
-
- val xmlWeight = (poolNode \ WEIGHT_PROPERTY).text
- if (xmlWeight != "") {
- weight = xmlWeight.toInt
- }
-
- val pool = new Pool(poolName, schedulingMode, minShare, weight)
- rootPool.addSchedulable(pool)
- logInfo("Created pool %s, schedulingMode: %s, minShare: %d, weight: %d".format(
- poolName, schedulingMode, minShare, weight))
+ var is: Option[InputStream] = None
+ try {
+ is = Option {
+ schedulerAllocFile.map { f =>
+ new FileInputStream(f)
+ }.getOrElse {
+ getClass.getClassLoader.getResourceAsStream(DEFAULT_SCHEDULER_FILE)
}
- } else {
- throw new java.io.FileNotFoundException(
- "Fair scheduler allocation file not found: " + schedulerAllocFile)
}
+
+ is.foreach { i => buildFairSchedulerPool(i) }
+ } finally {
+ is.foreach(_.close())
}
// finally create "default" pool
+ buildDefaultPool()
+ }
+
+ private def buildDefaultPool() {
if (rootPool.getSchedulableByName(DEFAULT_POOL_NAME) == null) {
val pool = new Pool(DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE,
DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)
@@ -115,6 +92,42 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool)
}
}
+ private def buildFairSchedulerPool(is: InputStream) {
+ val xml = XML.load(is)
+ for (poolNode <- (xml \\ POOLS_PROPERTY)) {
+
+ val poolName = (poolNode \ POOL_NAME_PROPERTY).text
+ var schedulingMode = DEFAULT_SCHEDULING_MODE
+ var minShare = DEFAULT_MINIMUM_SHARE
+ var weight = DEFAULT_WEIGHT
+
+ val xmlSchedulingMode = (poolNode \ SCHEDULING_MODE_PROPERTY).text
+ if (xmlSchedulingMode != "") {
+ try {
+ schedulingMode = SchedulingMode.withName(xmlSchedulingMode)
+ } catch {
+ case e: NoSuchElementException =>
+ logWarning("Error xml schedulingMode, using default schedulingMode")
+ }
+ }
+
+ val xmlMinShare = (poolNode \ MINIMUM_SHARES_PROPERTY).text
+ if (xmlMinShare != "") {
+ minShare = xmlMinShare.toInt
+ }
+
+ val xmlWeight = (poolNode \ WEIGHT_PROPERTY).text
+ if (xmlWeight != "") {
+ weight = xmlWeight.toInt
+ }
+
+ val pool = new Pool(poolName, schedulingMode, minShare, weight)
+ rootPool.addSchedulable(pool)
+ logInfo("Created pool %s, schedulingMode: %s, minShare: %d, weight: %d".format(
+ poolName, schedulingMode, minShare, weight))
+ }
+ }
+
override def addTaskSetManager(manager: Schedulable, properties: Properties) {
var poolName = DEFAULT_POOL_NAME
var parentPool = rootPool.getSchedulableByName(poolName)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingAlgorithm.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala
index cbeed4731a..3418640b8c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingAlgorithm.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
/**
* An interface for sort algorithm
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingMode.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala
index 34811389a0..0a786deb16 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingMode.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
/**
* "FAIR" and "FIFO" determines which policy is used
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index d23df0dd2b..3b9d5679fb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -25,7 +25,7 @@ import scala.collection.mutable.HashMap
import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.storage._
-import org.apache.spark.util.{TimeStampedHashMap, MetadataCleaner}
+import org.apache.spark.util.{MetadataCleanerType, TimeStampedHashMap, MetadataCleaner}
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.RDDCheckpointData
@@ -37,7 +37,7 @@ private[spark] object ShuffleMapTask {
// expensive on the master node if it needs to launch thousands of tasks.
val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
- val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.clearOldValues)
+ val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_MAP_TASK, serializedInfoCache.clearOldValues)
def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = {
synchronized {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index c3cf4b8907..62b521ad45 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -18,7 +18,6 @@
package org.apache.spark.scheduler
import java.util.Properties
-import org.apache.spark.scheduler.cluster.TaskInfo
import org.apache.spark.util.{Utils, Distribution}
import org.apache.spark.{Logging, SparkContext, TaskEndReason}
import org.apache.spark.executor.TaskMetrics
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index a65e1ecd6d..4d3e4a17ba 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -70,5 +70,23 @@ private[spark] class SparkListenerBus() extends Logging {
queueFullErrorMessageLogged = true
}
}
+
+ /**
+ * Waits until there are no more events in the queue, or until the specified time has elapsed.
+ * Used for testing only. Returns true if the queue has emptied and false is the specified time
+ * elapsed before the queue emptied.
+ */
+ def waitUntilEmpty(timeoutMillis: Int): Boolean = {
+ val finishTime = System.currentTimeMillis + timeoutMillis
+ while (!eventQueue.isEmpty()) {
+ if (System.currentTimeMillis > finishTime) {
+ return false
+ }
+ /* Sleep rather than using wait/notify, because this is used only for testing and wait/notify
+ * add overhead in the general case. */
+ Thread.sleep(10)
+ }
+ return true
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index 72cb1c9ce8..b6f11969e5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -17,8 +17,8 @@
package org.apache.spark.scheduler
-import org.apache.spark.scheduler.cluster.TaskInfo
import scala.collection._
+
import org.apache.spark.executor.TaskMetrics
case class StageInfo(
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
index 309ac2f6c9..5190d234d4 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskDescription.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import java.nio.ByteBuffer
import org.apache.spark.util.SerializableBuffer
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index 9685fb1a67..7c2a422aff 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import org.apache.spark.util.Utils
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskLocality.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala
index 5d4130e14a..47b0f387aa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskLocality.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
private[spark] object TaskLocality
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index 5c7e5bb977..7e468d0d67 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -24,14 +24,20 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.{SparkEnv}
import java.nio.ByteBuffer
import org.apache.spark.util.Utils
+import org.apache.spark.storage.BlockId
// Task result. Also contains updates to accumulator variables.
-// TODO: Use of distributed cache to return result is a hack to get around
-// what seems to be a bug with messages over 60KB in libprocess; fix it
+private[spark] sealed trait TaskResult[T]
+
+/** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
+private[spark]
+case class IndirectTaskResult[T](blockId: BlockId) extends TaskResult[T] with Serializable
+
+/** A TaskResult that contains the task's return value and accumulator updates. */
private[spark]
-class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: TaskMetrics)
- extends Externalizable
-{
+class DirectTaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: TaskMetrics)
+ extends TaskResult[T] with Externalizable {
+
def this() = this(null.asInstanceOf[T], null, null)
override def writeExternal(out: ObjectOutput) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 63be8ba3f5..7c2a9f03d7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -17,10 +17,11 @@
package org.apache.spark.scheduler
-import org.apache.spark.scheduler.cluster.Pool
-import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
+import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
+
/**
* Low-level task scheduler interface, implemented by both ClusterScheduler and LocalScheduler.
+ * Each TaskScheduler schedulers task for a single SparkContext.
* These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,
* and are responsible for sending the tasks to the cluster, running them, retrying if there
* are failures, and mitigating stragglers. They return events to the DAGScheduler through
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala
index 83be051c1a..593fa9fb93 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala
@@ -17,7 +17,6 @@
package org.apache.spark.scheduler
-import org.apache.spark.scheduler.cluster.TaskInfo
import scala.collection.mutable.Map
import org.apache.spark.TaskEndReason
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 648a3ef922..90f6bcefac 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -15,12 +15,11 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import java.nio.ByteBuffer
import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler.TaskSet
/**
* Tracks and schedules the tasks within a single TaskSet. This class keeps track of the status of
@@ -45,7 +44,5 @@ private[spark] trait TaskSetManager extends Schedulable {
maxLocality: TaskLocality.TaskLocality)
: Option[TaskDescription]
- def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer)
-
def error(message: String)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
index 919acce828..1a844b7e7e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
@@ -18,6 +18,9 @@
package org.apache.spark.scheduler.cluster
import java.lang.{Boolean => JBoolean}
+import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicLong
+import java.util.{TimerTask, Timer}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
@@ -26,10 +29,7 @@ import scala.collection.mutable.HashSet
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
-import java.nio.ByteBuffer
-import java.util.concurrent.atomic.AtomicLong
-import java.util.{TimerTask, Timer}
+import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
/**
* The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call
@@ -55,7 +55,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
// Threshold above which we warn user initial TaskSet may be starved
val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong
- val activeTaskSets = new HashMap[String, TaskSetManager]
+ // ClusterTaskSetManagers are not thread safe, so any access to one should be synchronized
+ // on this class.
+ val activeTaskSets = new HashMap[String, ClusterTaskSetManager]
val taskIdToTaskSetId = new HashMap[Long, String]
val taskIdToExecutorId = new HashMap[Long, String]
@@ -65,7 +67,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
@volatile private var hasLaunchedTask = false
private val starvationTimer = new Timer(true)
- // Incrementing Mesos task IDs
+ // Incrementing task IDs
val nextTaskId = new AtomicLong(0)
// Which executor IDs we have executors on
@@ -96,6 +98,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
val schedulingMode: SchedulingMode = SchedulingMode.withName(
System.getProperty("spark.scheduler.mode", "FIFO"))
+ // This is a var so that we can reset it for testing purposes.
+ private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this)
+
override def setListener(listener: TaskSchedulerListener) {
this.listener = listener
}
@@ -234,7 +239,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- var taskSetToUpdate: Option[TaskSetManager] = None
var failedExecutor: Option[String] = None
var taskFailed = false
synchronized {
@@ -249,9 +253,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
taskIdToTaskSetId.get(tid) match {
case Some(taskSetId) =>
- if (activeTaskSets.contains(taskSetId)) {
- taskSetToUpdate = Some(activeTaskSets(taskSetId))
- }
if (TaskState.isFinished(state)) {
taskIdToTaskSetId.remove(tid)
if (taskSetTaskIds.contains(taskSetId)) {
@@ -262,6 +263,15 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
if (state == TaskState.FAILED) {
taskFailed = true
}
+ activeTaskSets.get(taskSetId).foreach { taskSet =>
+ if (state == TaskState.FINISHED) {
+ taskSet.removeRunningTask(tid)
+ taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
+ } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
+ taskSet.removeRunningTask(tid)
+ taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
+ }
+ }
case None =>
logInfo("Ignoring update from TID " + tid + " because its task set is gone")
}
@@ -269,10 +279,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
case e: Exception => logError("Exception in statusUpdate", e)
}
}
- // Update the task set and DAGScheduler without holding a lock on this, since that can deadlock
- if (taskSetToUpdate != None) {
- taskSetToUpdate.get.statusUpdate(tid, state, serializedData)
- }
+ // Update the DAGScheduler without holding a lock on this, since that can deadlock
if (failedExecutor != None) {
listener.executorLost(failedExecutor.get)
backend.reviveOffers()
@@ -283,6 +290,25 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
+ def handleSuccessfulTask(
+ taskSetManager: ClusterTaskSetManager,
+ tid: Long,
+ taskResult: DirectTaskResult[_]) = synchronized {
+ taskSetManager.handleSuccessfulTask(tid, taskResult)
+ }
+
+ def handleFailedTask(
+ taskSetManager: ClusterTaskSetManager,
+ tid: Long,
+ taskState: TaskState,
+ reason: Option[TaskEndReason]) = synchronized {
+ taskSetManager.handleFailedTask(tid, taskState, reason)
+ if (taskState == TaskState.FINISHED) {
+ // The task finished successfully but the result was lost, so we should revive offers.
+ backend.reviveOffers()
+ }
+ }
+
def error(message: String) {
synchronized {
if (activeTaskSets.size > 0) {
@@ -311,6 +337,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
if (jarServer != null) {
jarServer.stop()
}
+ if (taskResultGetter != null) {
+ taskResultGetter.stop()
+ }
// sleeping for an arbitrary 5 seconds : to ensure that messages are sent out.
// TODO: Do something better !
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
index 0ac3d7bcfd..936167c13f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -25,15 +25,12 @@ import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.math.max
import scala.math.min
+import scala.Some
-import org.apache.spark.{FetchFailed, Logging, Resubmitted, SparkEnv, Success, TaskEndReason, TaskState}
-import org.apache.spark.{ExceptionFailure, SparkException, TaskResultTooBigFailure}
+import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv,
+ SparkException, Success, TaskEndReason, TaskResultLost, TaskState}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler._
-import scala.Some
-import org.apache.spark.FetchFailed
-import org.apache.spark.ExceptionFailure
-import org.apache.spark.TaskResultTooBigFailure
import org.apache.spark.util.{SystemClock, Clock}
@@ -71,18 +68,20 @@ private[spark] class ClusterTaskSetManager(
val tasks = taskSet.tasks
val numTasks = tasks.length
val copiesRunning = new Array[Int](numTasks)
- val finished = new Array[Boolean](numTasks)
+ val successful = new Array[Boolean](numTasks)
val numFailures = new Array[Int](numTasks)
val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
- var tasksFinished = 0
+ var tasksSuccessful = 0
var weight = 1
var minShare = 0
- var runningTasks = 0
var priority = taskSet.priority
var stageId = taskSet.stageId
var name = "TaskSet_"+taskSet.stageId.toString
- var parent: Schedulable = null
+ var parent: Pool = null
+
+ var runningTasks = 0
+ private val runningTasksSet = new HashSet[Long]
// Set of pending tasks for each executor. These collections are actually
// treated as stacks, in which new tasks are added to the end of the
@@ -223,7 +222,7 @@ private[spark] class ClusterTaskSetManager(
while (!list.isEmpty) {
val index = list.last
list.trimEnd(1)
- if (copiesRunning(index) == 0 && !finished(index)) {
+ if (copiesRunning(index) == 0 && !successful(index)) {
return Some(index)
}
}
@@ -243,7 +242,7 @@ private[spark] class ClusterTaskSetManager(
private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
: Option[(Int, TaskLocality.Value)] =
{
- speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
+ speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
if (!speculatableTasks.isEmpty) {
// Check for process-local or preference-less tasks; note that tasks can be process-local
@@ -344,7 +343,7 @@ private[spark] class ClusterTaskSetManager(
maxLocality: TaskLocality.TaskLocality)
: Option[TaskDescription] =
{
- if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
+ if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) {
val curTime = clock.getTime()
var allowedLocality = getAllowedLocalityLevel(curTime)
@@ -375,7 +374,7 @@ private[spark] class ClusterTaskSetManager(
val serializedTask = Task.serializeWithDependencies(
task, sched.sc.addedFiles, sched.sc.addedJars, ser)
val timeTaken = clock.getTime() - startTime
- increaseRunningTasks(1)
+ addRunningTask(taskId)
logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
taskSet.id, index, serializedTask.limit, timeTaken))
val taskName = "task %s:%d".format(taskSet.id, index)
@@ -417,94 +416,61 @@ private[spark] class ClusterTaskSetManager(
index
}
- /** Called by cluster scheduler when one of our tasks changes state */
- override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- SparkEnv.set(env)
- state match {
- case TaskState.FINISHED =>
- taskFinished(tid, state, serializedData)
- case TaskState.LOST =>
- taskLost(tid, state, serializedData)
- case TaskState.FAILED =>
- taskLost(tid, state, serializedData)
- case TaskState.KILLED =>
- taskLost(tid, state, serializedData)
- case _ =>
- }
- }
-
- def taskStarted(task: Task[_], info: TaskInfo) {
+ private def taskStarted(task: Task[_], info: TaskInfo) {
sched.listener.taskStarted(task, info)
}
- def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ /**
+ * Marks the task as successful and notifies the listener that a task has ended.
+ */
+ def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
val info = taskInfos(tid)
- if (info.failed) {
- // We might get two task-lost messages for the same task in coarse-grained Mesos mode,
- // or even from Mesos itself when acks get delayed.
- return
- }
val index = info.index
info.markSuccessful()
- decreaseRunningTasks(1)
- if (!finished(index)) {
- tasksFinished += 1
+ removeRunningTask(tid)
+ if (!successful(index)) {
logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
- tid, info.duration, info.host, tasksFinished, numTasks))
- // Deserialize task result and pass it to the scheduler
- try {
- val result = ser.deserialize[TaskResult[_]](serializedData)
- result.metrics.resultSize = serializedData.limit()
- sched.listener.taskEnded(
- tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
- } catch {
- case cnf: ClassNotFoundException =>
- val loader = Thread.currentThread().getContextClassLoader
- throw new SparkException("ClassNotFound with classloader: " + loader, cnf)
- case ex => throw ex
- }
- // Mark finished and stop if we've finished all the tasks
- finished(index) = true
- if (tasksFinished == numTasks) {
+ tid, info.duration, info.host, tasksSuccessful, numTasks))
+ sched.listener.taskEnded(
+ tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
+
+ // Mark successful and stop if all the tasks have succeeded.
+ tasksSuccessful += 1
+ successful(index) = true
+ if (tasksSuccessful == numTasks) {
sched.taskSetFinished(this)
}
} else {
- logInfo("Ignoring task-finished event for TID " + tid +
- " because task " + index + " is already finished")
+ logInfo("Ignorning task-finished event for TID " + tid + " because task " +
+ index + " has already completed successfully")
}
}
- def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ /**
+ * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the listener.
+ */
+ def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
val info = taskInfos(tid)
if (info.failed) {
- // We might get two task-lost messages for the same task in coarse-grained Mesos mode,
- // or even from Mesos itself when acks get delayed.
return
}
+ removeRunningTask(tid)
val index = info.index
info.markFailed()
- decreaseRunningTasks(1)
- if (!finished(index)) {
+ if (!successful(index)) {
logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
copiesRunning(index) -= 1
// Check if the problem is a map output fetch failure. In that case, this
// task will never succeed on any node, so tell the scheduler about it.
- if (serializedData != null && serializedData.limit() > 0) {
- val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader)
- reason match {
+ reason.foreach {
+ _ match {
case fetchFailed: FetchFailed =>
logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
- finished(index) = true
- tasksFinished += 1
+ successful(index) = true
+ tasksSuccessful += 1
sched.taskSetFinished(this)
- decreaseRunningTasks(runningTasks)
- return
-
- case taskResultTooBig: TaskResultTooBigFailure =>
- logInfo("Loss was due to task %s result exceeding Akka frame size; aborting job".format(
- tid))
- abort("Task %s result exceeded Akka frame size".format(tid))
+ removeAllRunningTasks()
return
case ef: ExceptionFailure =>
@@ -534,13 +500,16 @@ private[spark] class ClusterTaskSetManager(
logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
}
+ case TaskResultLost =>
+ logInfo("Lost result for TID %s on host %s".format(tid, info.host))
+ sched.listener.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
+
case _ => {}
}
}
// On non-fetch failures, re-enqueue the task as pending for a max number of retries
addPendingTask(index)
- // Count failed attempts only on FAILED and LOST state (not on KILLED)
- if (state == TaskState.FAILED || state == TaskState.LOST) {
+ if (state != TaskState.KILLED) {
numFailures(index) += 1
if (numFailures(index) > MAX_TASK_FAILURES) {
logError("Task %s:%d failed more than %d times; aborting job".format(
@@ -564,22 +533,36 @@ private[spark] class ClusterTaskSetManager(
causeOfFailure = message
// TODO: Kill running tasks if we were not terminated due to a Mesos error
sched.listener.taskSetFailed(taskSet, message)
- decreaseRunningTasks(runningTasks)
+ removeAllRunningTasks()
sched.taskSetFinished(this)
}
- override def increaseRunningTasks(taskNum: Int) {
- runningTasks += taskNum
- if (parent != null) {
- parent.increaseRunningTasks(taskNum)
+ /** If the given task ID is not in the set of running tasks, adds it.
+ *
+ * Used to keep track of the number of running tasks, for enforcing scheduling policies.
+ */
+ def addRunningTask(tid: Long) {
+ if (runningTasksSet.add(tid) && parent != null) {
+ parent.increaseRunningTasks(1)
+ }
+ runningTasks = runningTasksSet.size
+ }
+
+ /** If the given task ID is in the set of running tasks, removes it. */
+ def removeRunningTask(tid: Long) {
+ if (runningTasksSet.remove(tid) && parent != null) {
+ parent.decreaseRunningTasks(1)
}
+ runningTasks = runningTasksSet.size
}
- override def decreaseRunningTasks(taskNum: Int) {
- runningTasks -= taskNum
+ private def removeAllRunningTasks() {
+ val numRunningTasks = runningTasksSet.size
+ runningTasksSet.clear()
if (parent != null) {
- parent.decreaseRunningTasks(taskNum)
+ parent.decreaseRunningTasks(numRunningTasks)
}
+ runningTasks = 0
}
override def getSchedulableByName(name: String): Schedulable = {
@@ -615,10 +598,10 @@ private[spark] class ClusterTaskSetManager(
if (tasks(0).isInstanceOf[ShuffleMapTask]) {
for ((tid, info) <- taskInfos if info.executorId == execId) {
val index = taskInfos(tid).index
- if (finished(index)) {
- finished(index) = false
+ if (successful(index)) {
+ successful(index) = false
copiesRunning(index) -= 1
- tasksFinished -= 1
+ tasksSuccessful -= 1
addPendingTask(index)
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
// stage finishes when a total of tasks.size tasks finish.
@@ -628,7 +611,7 @@ private[spark] class ClusterTaskSetManager(
}
// Also re-enqueue any tasks that were running on the node
for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
- taskLost(tid, TaskState.KILLED, null)
+ handleFailedTask(tid, TaskState.KILLED, None)
}
}
@@ -641,24 +624,24 @@ private[spark] class ClusterTaskSetManager(
*/
override def checkSpeculatableTasks(): Boolean = {
// Can't speculate if we only have one task, or if all tasks have finished.
- if (numTasks == 1 || tasksFinished == numTasks) {
+ if (numTasks == 1 || tasksSuccessful == numTasks) {
return false
}
var foundTasks = false
val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
- if (tasksFinished >= minFinishedForSpeculation) {
+ if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
val time = clock.getTime()
val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
Arrays.sort(durations)
- val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1))
+ val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1))
val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
// TODO: Threshold should also look at standard deviation of task durations and have a lower
// bound based on that.
logDebug("Task length threshold for speculation: " + threshold)
for ((tid, info) <- taskInfos) {
val index = info.index
- if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
+ if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
!speculatableTasks.contains(index)) {
logInfo(
"Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
@@ -672,7 +655,7 @@ private[spark] class ClusterTaskSetManager(
}
override def hasPendingTasks(): Boolean = {
- numTasks > 0 && tasksFinished < numTasks
+ numTasks > 0 && tasksSuccessful < numTasks
}
private def getLocalityWait(level: TaskLocality.TaskLocality): Long = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index 9c49768c0c..cb88159b8d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -26,7 +26,7 @@ import org.apache.spark.util.Utils
private[spark] class SparkDeploySchedulerBackend(
scheduler: ClusterScheduler,
sc: SparkContext,
- master: String,
+ masters: Array[String],
appName: String)
extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem)
with ClientListener
@@ -52,7 +52,7 @@ private[spark] class SparkDeploySchedulerBackend(
val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome,
"http://" + sc.ui.appUIAddress)
- client = new Client(sc.env.actorSystem, master, appDesc, this)
+ client = new Client(sc.env.actorSystem, masters, appDesc, this)
client.start()
}
@@ -71,8 +71,14 @@ private[spark] class SparkDeploySchedulerBackend(
override def disconnected() {
if (!stopping) {
- logError("Disconnected from Spark cluster!")
- scheduler.error("Disconnected from Spark cluster")
+ logWarning("Disconnected from Spark cluster! Waiting for reconnection...")
+ }
+ }
+
+ override def dead() {
+ if (!stopping) {
+ logError("Spark cluster looks dead, giving up.")
+ scheduler.error("Spark cluster looks down")
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala
index 9c36d221f6..c0b836bf1a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala
@@ -20,6 +20,7 @@ package org.apache.spark.scheduler.cluster
import java.nio.ByteBuffer
import org.apache.spark.TaskState.TaskState
+import org.apache.spark.scheduler.TaskDescription
import org.apache.spark.util.{Utils, SerializableBuffer}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
index b4ea0be415..f3aeea43d5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
@@ -29,6 +29,7 @@ import akka.util.Duration
import akka.util.duration._
import org.apache.spark.{SparkException, Logging, TaskState}
+import org.apache.spark.scheduler.TaskDescription
import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._
import org.apache.spark.util.Utils
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
new file mode 100644
index 0000000000..feec8ecfe4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import java.nio.ByteBuffer
+import java.util.concurrent.{LinkedBlockingDeque, ThreadFactory, ThreadPoolExecutor, TimeUnit}
+
+import org.apache.spark._
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
+import org.apache.spark.serializer.SerializerInstance
+
+/**
+ * Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
+ */
+private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
+ extends Logging {
+ private val MIN_THREADS = System.getProperty("spark.resultGetter.minThreads", "4").toInt
+ private val MAX_THREADS = System.getProperty("spark.resultGetter.maxThreads", "4").toInt
+ private val getTaskResultExecutor = new ThreadPoolExecutor(
+ MIN_THREADS,
+ MAX_THREADS,
+ 0L,
+ TimeUnit.SECONDS,
+ new LinkedBlockingDeque[Runnable],
+ new ResultResolverThreadFactory)
+
+ class ResultResolverThreadFactory extends ThreadFactory {
+ private var counter = 0
+ private var PREFIX = "Result resolver thread"
+
+ override def newThread(r: Runnable): Thread = {
+ val thread = new Thread(r, "%s-%s".format(PREFIX, counter))
+ counter += 1
+ thread.setDaemon(true)
+ return thread
+ }
+ }
+
+ protected val serializer = new ThreadLocal[SerializerInstance] {
+ override def initialValue(): SerializerInstance = {
+ return sparkEnv.closureSerializer.newInstance()
+ }
+ }
+
+ def enqueueSuccessfulTask(
+ taskSetManager: ClusterTaskSetManager, tid: Long, serializedData: ByteBuffer) {
+ getTaskResultExecutor.execute(new Runnable {
+ override def run() {
+ try {
+ val result = serializer.get().deserialize[TaskResult[_]](serializedData) match {
+ case directResult: DirectTaskResult[_] => directResult
+ case IndirectTaskResult(blockId) =>
+ logDebug("Fetching indirect task result for TID %s".format(tid))
+ val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
+ if (!serializedTaskResult.isDefined) {
+ /* We won't be able to get the task result if the machine that ran the task failed
+ * between when the task ended and when we tried to fetch the result, or if the
+ * block manager had to flush the result. */
+ scheduler.handleFailedTask(
+ taskSetManager, tid, TaskState.FINISHED, Some(TaskResultLost))
+ return
+ }
+ val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
+ serializedTaskResult.get)
+ sparkEnv.blockManager.master.removeBlock(blockId)
+ deserializedResult
+ }
+ result.metrics.resultSize = serializedData.limit()
+ scheduler.handleSuccessfulTask(taskSetManager, tid, result)
+ } catch {
+ case cnf: ClassNotFoundException =>
+ val loader = Thread.currentThread.getContextClassLoader
+ taskSetManager.abort("ClassNotFound with classloader: " + loader)
+ case ex =>
+ taskSetManager.abort("Exception while deserializing and fetching task: %s".format(ex))
+ }
+ }
+ })
+ }
+
+ def enqueueFailedTask(taskSetManager: ClusterTaskSetManager, tid: Long, taskState: TaskState,
+ serializedData: ByteBuffer) {
+ var reason: Option[TaskEndReason] = None
+ getTaskResultExecutor.execute(new Runnable {
+ override def run() {
+ try {
+ if (serializedData != null && serializedData.limit() > 0) {
+ reason = Some(serializer.get().deserialize[TaskEndReason](
+ serializedData, getClass.getClassLoader))
+ }
+ } catch {
+ case cnd: ClassNotFoundException =>
+ // Log an error but keep going here -- the task failed, so not catastropic if we can't
+ // deserialize the reason.
+ val loader = Thread.currentThread.getContextClassLoader
+ logError(
+ "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
+ case ex => {}
+ }
+ scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
+ }
+ })
+ }
+
+ def stop() {
+ getTaskResultExecutor.shutdownNow()
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 3dbe61d706..8f2eef9a53 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -15,22 +15,22 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.mesos
+package org.apache.spark.scheduler.cluster.mesos
-import com.google.protobuf.ByteString
+import java.io.File
+import java.util.{ArrayList => JArrayList, List => JList}
+import java.util.Collections
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.collection.JavaConversions._
+import com.google.protobuf.ByteString
import org.apache.mesos.{Scheduler => MScheduler}
import org.apache.mesos._
import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
-import org.apache.spark.{SparkException, Logging, SparkContext}
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
-import scala.collection.JavaConversions._
-import java.io.File
-import org.apache.spark.scheduler.cluster._
-import java.util.{ArrayList => JArrayList, List => JList}
-import java.util.Collections
-import org.apache.spark.TaskState
+import org.apache.spark.{SparkException, Logging, SparkContext, TaskState}
+import org.apache.spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend}
/**
* A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds
diff --git a/core/src/main/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index 541f86e338..50cbc2ca92 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -15,22 +15,24 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.mesos
+package org.apache.spark.scheduler.cluster.mesos
-import com.google.protobuf.ByteString
+import java.io.File
+import java.util.{ArrayList => JArrayList, List => JList}
+import java.util.Collections
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.collection.JavaConversions._
+import com.google.protobuf.ByteString
import org.apache.mesos.{Scheduler => MScheduler}
import org.apache.mesos._
import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
-import org.apache.spark.{SparkException, Logging, SparkContext}
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
-import scala.collection.JavaConversions._
-import java.io.File
-import org.apache.spark.scheduler.cluster._
-import java.util.{ArrayList => JArrayList, List => JList}
-import java.util.Collections
-import org.apache.spark.TaskState
+import org.apache.spark.{Logging, SparkException, SparkContext, TaskState}
+import org.apache.spark.scheduler.TaskDescription
+import org.apache.spark.scheduler.cluster.{ClusterScheduler, ExecutorExited, ExecutorLossReason}
+import org.apache.spark.scheduler.cluster.{SchedulerBackend, SlaveLost, WorkerOffer}
import org.apache.spark.util.Utils
/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
index 8cb4d1396f..4d1bb1c639 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
@@ -31,8 +31,7 @@ import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.executor.ExecutorURLClassLoader
import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster._
-import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
+import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import akka.actor._
import org.apache.spark.util.Utils
@@ -92,7 +91,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
var rootPool: Pool = null
val schedulingMode: SchedulingMode = SchedulingMode.withName(
System.getProperty("spark.scheduler.mode", "FIFO"))
- val activeTaskSets = new HashMap[String, TaskSetManager]
+ val activeTaskSets = new HashMap[String, LocalTaskSetManager]
val taskIdToTaskSetId = new HashMap[Long, String]
val taskSetTaskIds = new HashMap[String, HashSet[Long]]
@@ -211,7 +210,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
deserializedTask.metrics.get.executorRunTime = serviceTime.toInt
deserializedTask.metrics.get.jvmGCTime = getTotalGCTime - startGCTime
deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
- val taskResult = new TaskResult(result, accumUpdates, deserializedTask.metrics.getOrElse(null))
+ val taskResult = new DirectTaskResult(
+ result, accumUpdates, deserializedTask.metrics.getOrElse(null))
val serializedResult = ser.serialize(taskResult)
localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult)
} catch {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
index e52cb998bd..c2e2399ccb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
@@ -21,16 +21,16 @@ import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
-import org.apache.spark.{ExceptionFailure, Logging, SparkEnv, Success, TaskState}
+import org.apache.spark.{ExceptionFailure, Logging, SparkEnv, SparkException, Success, TaskState}
import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler.{Task, TaskResult, TaskSet}
-import org.apache.spark.scheduler.cluster.{Schedulable, TaskDescription, TaskInfo, TaskLocality, TaskSetManager}
+import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Pool, Schedulable, Task,
+ TaskDescription, TaskInfo, TaskLocality, TaskResult, TaskSet, TaskSetManager}
private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet)
extends TaskSetManager with Logging {
- var parent: Schedulable = null
+ var parent: Pool = null
var weight: Int = 1
var minShare: Int = 0
var runningTasks: Int = 0
@@ -49,14 +49,14 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
val numFailures = new Array[Int](numTasks)
val MAX_TASK_FAILURES = sched.maxFailures
- override def increaseRunningTasks(taskNum: Int): Unit = {
+ def increaseRunningTasks(taskNum: Int): Unit = {
runningTasks += taskNum
if (parent != null) {
parent.increaseRunningTasks(taskNum)
}
}
- override def decreaseRunningTasks(taskNum: Int): Unit = {
+ def decreaseRunningTasks(taskNum: Int): Unit = {
runningTasks -= taskNum
if (parent != null) {
parent.decreaseRunningTasks(taskNum)
@@ -132,7 +132,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
return None
}
- override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
SparkEnv.set(env)
state match {
case TaskState.FINISHED =>
@@ -152,7 +152,12 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
val index = info.index
val task = taskSet.tasks(index)
info.markSuccessful()
- val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
+ val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) match {
+ case directResult: DirectTaskResult[_] => directResult
+ case IndirectTaskResult(blockId) => {
+ throw new SparkException("Expect only DirectTaskResults when using LocalScheduler")
+ }
+ }
result.metrics.resultSize = serializedData.limit()
sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
numFinished += 1
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 3feafde8b6..263ff59ba6 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -23,12 +23,11 @@ import java.io.{EOFException, InputStream, OutputStream}
import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer}
import com.esotericsoftware.kryo.{KryoException, Kryo}
import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
-import com.twitter.chill.ScalaKryoInstantiator
+import com.twitter.chill.{EmptyScalaKryoInstantiator, AllScalaRegistrar}
import org.apache.spark.{SerializableWritable, Logging}
-import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock, StorageLevel}
-
import org.apache.spark.broadcast.HttpBroadcast
+import org.apache.spark.storage.{GetBlock,GotBlock, PutBlock, StorageLevel, TestBlockId}
/**
* A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]].
@@ -38,20 +37,23 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging
def newKryoOutput() = new KryoOutput(bufferSize)
- def newKryoInput() = new KryoInput(bufferSize)
-
def newKryo(): Kryo = {
- val instantiator = new ScalaKryoInstantiator
+ val instantiator = new EmptyScalaKryoInstantiator
val kryo = instantiator.newKryo()
val classLoader = Thread.currentThread.getContextClassLoader
+ val blockId = TestBlockId("1")
// Register some commonly used classes
val toRegister: Seq[AnyRef] = Seq(
ByteBuffer.allocate(1),
StorageLevel.MEMORY_ONLY,
- PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY),
- GotBlock("1", ByteBuffer.allocate(1)),
- GetBlock("1")
+ PutBlock(blockId, ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY),
+ GotBlock(blockId, ByteBuffer.allocate(1)),
+ GetBlock(blockId),
+ 1 to 10,
+ 1 until 10,
+ 1L to 10L,
+ 1L until 10L
)
for (obj <- toRegister) kryo.register(obj.getClass)
@@ -67,6 +69,10 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging
reg.registerClasses(kryo)
}
+ // Register Chill's classes; we do this after our ranges and the user's own classes to let
+ // our code override the generic serialziers in Chill for things like Seq
+ new AllScalaRegistrar().apply(kryo)
+
kryo.setClassLoader(classLoader)
// Allow disabling Kryo reference tracking if user knows their object graphs don't have loops
@@ -114,8 +120,10 @@ class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends Deser
private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance with Logging {
val kryo = ks.newKryo()
- val output = ks.newKryoOutput()
- val input = ks.newKryoInput()
+
+ // Make these lazy vals to avoid creating a buffer unless we use them
+ lazy val output = ks.newKryoOutput()
+ lazy val input = new KryoInput()
def serialize[T](t: T): ByteBuffer = {
output.clear()
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockException.scala b/core/src/main/scala/org/apache/spark/storage/BlockException.scala
index 290dbce4f5..0d0a2dadc7 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockException.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockException.scala
@@ -18,5 +18,5 @@
package org.apache.spark.storage
private[spark]
-case class BlockException(blockId: String, message: String) extends Exception(message)
+case class BlockException(blockId: BlockId, message: String) extends Exception(message)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
index 3aeda3879d..e51c5b30a3 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
@@ -47,7 +47,7 @@ import org.apache.spark.util.Utils
*/
private[storage]
-trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])]
+trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])]
with Logging with BlockFetchTracker {
def initialize()
}
@@ -57,20 +57,20 @@ private[storage]
object BlockFetcherIterator {
// A request to fetch one or more blocks, complete with their sizes
- class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
+ class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) {
val size = blocks.map(_._2).sum
}
// A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
// the block (since we want all deserializaton to happen in the calling thread); can also
// represent a fetch failure if size == -1.
- class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
+ class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) {
def failed: Boolean = size == -1
}
class BasicBlockFetcherIterator(
private val blockManager: BlockManager,
- val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+ val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer)
extends BlockFetcherIterator {
@@ -92,12 +92,12 @@ object BlockFetcherIterator {
// This represents the number of local blocks, also counting zero-sized blocks
private var numLocal = 0
// BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks
- protected val localBlocksToFetch = new ArrayBuffer[String]()
+ protected val localBlocksToFetch = new ArrayBuffer[BlockId]()
// This represents the number of remote blocks, also counting zero-sized blocks
private var numRemote = 0
// BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks
- protected val remoteBlocksToFetch = new HashSet[String]()
+ protected val remoteBlocksToFetch = new HashSet[BlockId]()
// A queue to hold our results.
protected val results = new LinkedBlockingQueue[FetchResult]
@@ -167,7 +167,7 @@ object BlockFetcherIterator {
logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
val iterator = blockInfos.iterator
var curRequestSize = 0L
- var curBlocks = new ArrayBuffer[(String, Long)]
+ var curBlocks = new ArrayBuffer[(BlockId, Long)]
while (iterator.hasNext) {
val (blockId, size) = iterator.next()
// Skip empty blocks
@@ -183,7 +183,7 @@ object BlockFetcherIterator {
// Add this FetchRequest
remoteRequests += new FetchRequest(address, curBlocks)
curRequestSize = 0
- curBlocks = new ArrayBuffer[(String, Long)]
+ curBlocks = new ArrayBuffer[(BlockId, Long)]
}
}
// Add in the final request
@@ -241,7 +241,7 @@ object BlockFetcherIterator {
override def hasNext: Boolean = resultsGotten < _numBlocksToFetch
- override def next(): (String, Option[Iterator[Any]]) = {
+ override def next(): (BlockId, Option[Iterator[Any]]) = {
resultsGotten += 1
val startFetchWait = System.currentTimeMillis()
val result = results.take()
@@ -267,7 +267,7 @@ object BlockFetcherIterator {
class NettyBlockFetcherIterator(
blockManager: BlockManager,
- blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+ blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer)
extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) {
@@ -303,7 +303,7 @@ object BlockFetcherIterator {
override protected def sendRequest(req: FetchRequest) {
- def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) {
+ def putResult(blockId: BlockId, blockSize: Long, blockData: ByteBuf) {
val fetchResult = new FetchResult(blockId, blockSize,
() => dataDeserialize(blockId, blockData.nioBuffer, serializer))
results.put(fetchResult)
@@ -337,7 +337,7 @@ object BlockFetcherIterator {
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
}
- override def next(): (String, Option[Iterator[Any]]) = {
+ override def next(): (BlockId, Option[Iterator[Any]]) = {
resultsGotten += 1
val result = results.take()
// If all the results has been retrieved, copiers will exit automatically
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
new file mode 100644
index 0000000000..c7efc67a4a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+/**
+ * Identifies a particular Block of data, usually associated with a single file.
+ * A Block can be uniquely identified by its filename, but each type of Block has a different
+ * set of keys which produce its unique name.
+ *
+ * If your BlockId should be serializable, be sure to add it to the BlockId.fromString() method.
+ */
+private[spark] sealed abstract class BlockId {
+ /** A globally unique identifier for this Block. Can be used for ser/de. */
+ def name: String
+
+ // convenience methods
+ def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None
+ def isRDD = isInstanceOf[RDDBlockId]
+ def isShuffle = isInstanceOf[ShuffleBlockId]
+ def isBroadcast = isInstanceOf[BroadcastBlockId]
+
+ override def toString = name
+ override def hashCode = name.hashCode
+ override def equals(other: Any): Boolean = other match {
+ case o: BlockId => getClass == o.getClass && name.equals(o.name)
+ case _ => false
+ }
+}
+
+private[spark] case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
+ def name = "rdd_" + rddId + "_" + splitIndex
+}
+
+private[spark]
+case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
+ def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
+}
+
+private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
+ def name = "broadcast_" + broadcastId
+}
+
+private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId {
+ def name = "taskresult_" + taskId
+}
+
+private[spark] case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId {
+ def name = "input-" + streamId + "-" + uniqueId
+}
+
+// Intended only for testing purposes
+private[spark] case class TestBlockId(id: String) extends BlockId {
+ def name = "test_" + id
+}
+
+private[spark] object BlockId {
+ val RDD = "rdd_([0-9]+)_([0-9]+)".r
+ val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
+ val BROADCAST = "broadcast_([0-9]+)".r
+ val TASKRESULT = "taskresult_([0-9]+)".r
+ val STREAM = "input-([0-9]+)-([0-9]+)".r
+ val TEST = "test_(.*)".r
+
+ /** Converts a BlockId "name" String back into a BlockId. */
+ def apply(id: String) = id match {
+ case RDD(rddId, splitIndex) =>
+ RDDBlockId(rddId.toInt, splitIndex.toInt)
+ case SHUFFLE(shuffleId, mapId, reduceId) =>
+ ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
+ case BROADCAST(broadcastId) =>
+ BroadcastBlockId(broadcastId.toLong)
+ case TASKRESULT(taskId) =>
+ TaskResultBlockId(taskId.toLong)
+ case STREAM(streamId, uniqueId) =>
+ StreamBlockId(streamId.toInt, uniqueId.toLong)
+ case TEST(value) =>
+ TestBlockId(value)
+ case _ =>
+ throw new IllegalStateException("Unrecognized BlockId: " + id)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 60fdc5f2ee..801f88a3db 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -37,7 +37,6 @@ import org.apache.spark.util._
import sun.nio.ch.DirectBuffer
-
private[spark] class BlockManager(
executorId: String,
actorSystem: ActorSystem,
@@ -103,7 +102,7 @@ private[spark] class BlockManager(
val shuffleBlockManager = new ShuffleBlockManager(this)
- private val blockInfo = new TimeStampedHashMap[String, BlockInfo]
+ private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
private[storage] val diskStore: DiskStore =
@@ -154,7 +153,8 @@ private[spark] class BlockManager(
var heartBeatTask: Cancellable = null
- val metadataCleaner = new MetadataCleaner("BlockManager", this.dropOldBlocks)
+ private val metadataCleaner = new MetadataCleaner(MetadataCleanerType.BLOCK_MANAGER, this.dropOldNonBroadcastBlocks)
+ private val broadcastCleaner = new MetadataCleaner(MetadataCleanerType.BROADCAST_VARS, this.dropOldBroadcastBlocks)
initialize()
// The compression codec to use. Note that the "lazy" val is necessary because we want to delay
@@ -248,7 +248,7 @@ private[spark] class BlockManager(
/**
* Get storage level of local block. If no info exists for the block, then returns null.
*/
- def getLevel(blockId: String): StorageLevel = blockInfo.get(blockId).map(_.level).orNull
+ def getLevel(blockId: BlockId): StorageLevel = blockInfo.get(blockId).map(_.level).orNull
/**
* Tell the master about the current storage status of a block. This will send a block update
@@ -258,7 +258,7 @@ private[spark] class BlockManager(
* droppedMemorySize exists to account for when block is dropped from memory to disk (so it is still valid).
* This ensures that update in master will compensate for the increase in memory on slave.
*/
- def reportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L) {
+ def reportBlockStatus(blockId: BlockId, info: BlockInfo, droppedMemorySize: Long = 0L) {
val needReregister = !tryToReportBlockStatus(blockId, info, droppedMemorySize)
if (needReregister) {
logInfo("Got told to reregister updating block " + blockId)
@@ -273,7 +273,7 @@ private[spark] class BlockManager(
* which will be true if the block was successfully recorded and false if
* the slave needs to re-register.
*/
- private def tryToReportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = {
+ private def tryToReportBlockStatus(blockId: BlockId, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = {
val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized {
info.level match {
case null =>
@@ -298,7 +298,7 @@ private[spark] class BlockManager(
/**
* Get locations of an array of blocks.
*/
- def getLocationBlockIds(blockIds: Array[String]): Array[Seq[BlockManagerId]] = {
+ def getLocationBlockIds(blockIds: Array[BlockId]): Array[Seq[BlockManagerId]] = {
val startTimeMs = System.currentTimeMillis
val locations = master.getLocations(blockIds).toArray
logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
@@ -310,7 +310,7 @@ private[spark] class BlockManager(
* shuffle blocks. It is safe to do so without a lock on block info since disk store
* never deletes (recent) items.
*/
- def getLocalFromDisk(blockId: String, serializer: Serializer): Option[Iterator[Any]] = {
+ def getLocalFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
diskStore.getValues(blockId, serializer).orElse(
sys.error("Block " + blockId + " not found on disk, though it should be"))
}
@@ -318,7 +318,7 @@ private[spark] class BlockManager(
/**
* Get block from local block manager.
*/
- def getLocal(blockId: String): Option[Iterator[Any]] = {
+ def getLocal(blockId: BlockId): Option[Iterator[Any]] = {
logDebug("Getting local block " + blockId)
val info = blockInfo.get(blockId).orNull
if (info != null) {
@@ -399,13 +399,13 @@ private[spark] class BlockManager(
/**
* Get block from the local block manager as serialized bytes.
*/
- def getLocalBytes(blockId: String): Option[ByteBuffer] = {
+ def getLocalBytes(blockId: BlockId): Option[ByteBuffer] = {
// TODO: This whole thing is very similar to getLocal; we need to refactor it somehow
logDebug("Getting local block " + blockId + " as bytes")
// As an optimization for map output fetches, if the block is for a shuffle, return it
// without acquiring a lock; the disk store never deletes (recent) items so this should work
- if (ShuffleBlockManager.isShuffle(blockId)) {
+ if (blockId.isShuffle) {
return diskStore.getBytes(blockId) match {
case Some(bytes) =>
Some(bytes)
@@ -472,7 +472,7 @@ private[spark] class BlockManager(
/**
* Get block from remote block managers.
*/
- def getRemote(blockId: String): Option[Iterator[Any]] = {
+ def getRemote(blockId: BlockId): Option[Iterator[Any]] = {
if (blockId == null) {
throw new IllegalArgumentException("Block Id is null")
}
@@ -484,7 +484,7 @@ private[spark] class BlockManager(
for (loc <- locations) {
logDebug("Getting remote block " + blockId + " from " + loc)
val data = BlockManagerWorker.syncGetBlock(
- GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
+ GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
if (data != null) {
return Some(dataDeserialize(blockId, data))
}
@@ -495,10 +495,45 @@ private[spark] class BlockManager(
}
/**
+ * Get block from remote block managers as serialized bytes.
+ */
+ def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = {
+ // TODO: As with getLocalBytes, this is very similar to getRemote and perhaps should be
+ // refactored.
+ if (blockId == null) {
+ throw new IllegalArgumentException("Block Id is null")
+ }
+ logDebug("Getting remote block " + blockId + " as bytes")
+
+ val locations = master.getLocations(blockId)
+ for (loc <- locations) {
+ logDebug("Getting remote block " + blockId + " from " + loc)
+ val data = BlockManagerWorker.syncGetBlock(
+ GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
+ if (data != null) {
+ return Some(data)
+ }
+ logDebug("The value of block " + blockId + " is null")
+ }
+ logDebug("Block " + blockId + " not found")
+ return None
+ }
+
+ /**
* Get a block from the block manager (either local or remote).
*/
- def get(blockId: String): Option[Iterator[Any]] = {
- getLocal(blockId).orElse(getRemote(blockId))
+ def get(blockId: BlockId): Option[Iterator[Any]] = {
+ val local = getLocal(blockId)
+ if (local.isDefined) {
+ logInfo("Found block %s locally".format(blockId))
+ return local
+ }
+ val remote = getRemote(blockId)
+ if (remote.isDefined) {
+ logInfo("Found block %s remotely".format(blockId))
+ return remote
+ }
+ None
}
/**
@@ -508,7 +543,7 @@ private[spark] class BlockManager(
* so that we can control the maxMegabytesInFlight for the fetch.
*/
def getMultiple(
- blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer)
+ blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], serializer: Serializer)
: BlockFetcherIterator = {
val iter =
@@ -522,7 +557,7 @@ private[spark] class BlockManager(
iter
}
- def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
+ def put(blockId: BlockId, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
: Long = {
val elements = new ArrayBuffer[Any]
elements ++= values
@@ -534,7 +569,7 @@ private[spark] class BlockManager(
* This is currently used for writing shuffle files out. Callers should handle error
* cases.
*/
- def getDiskBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
+ def getDiskBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = {
val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize)
writer.registerCloseEventHandler(() => {
@@ -548,7 +583,7 @@ private[spark] class BlockManager(
/**
* Put a new block of values to the block manager. Returns its (estimated) size in bytes.
*/
- def put(blockId: String, values: ArrayBuffer[Any], level: StorageLevel,
+ def put(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
tellMaster: Boolean = true) : Long = {
if (blockId == null) {
@@ -668,7 +703,7 @@ private[spark] class BlockManager(
* Put a new block of serialized bytes to the block manager.
*/
def putBytes(
- blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) {
+ blockId: BlockId, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) {
if (blockId == null) {
throw new IllegalArgumentException("Block Id is null")
@@ -769,7 +804,7 @@ private[spark] class BlockManager(
* Replicate block to another node.
*/
var cachedPeers: Seq[BlockManagerId] = null
- private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) {
+ private def replicate(blockId: BlockId, data: ByteBuffer, level: StorageLevel) {
val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
if (cachedPeers == null) {
cachedPeers = master.getPeers(blockManagerId, level.replication - 1)
@@ -792,14 +827,14 @@ private[spark] class BlockManager(
/**
* Read a block consisting of a single object.
*/
- def getSingle(blockId: String): Option[Any] = {
+ def getSingle(blockId: BlockId): Option[Any] = {
get(blockId).map(_.next())
}
/**
* Write a block consisting of a single object.
*/
- def putSingle(blockId: String, value: Any, level: StorageLevel, tellMaster: Boolean = true) {
+ def putSingle(blockId: BlockId, value: Any, level: StorageLevel, tellMaster: Boolean = true) {
put(blockId, Iterator(value), level, tellMaster)
}
@@ -807,7 +842,7 @@ private[spark] class BlockManager(
* Drop a block from memory, possibly putting it on disk if applicable. Called when the memory
* store reaches its limit and needs to free up space.
*/
- def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) {
+ def dropFromMemory(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer]) {
logInfo("Dropping block " + blockId + " from memory")
val info = blockInfo.get(blockId).orNull
if (info != null) {
@@ -856,16 +891,15 @@ private[spark] class BlockManager(
// TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps
// from RDD.id to blocks.
logInfo("Removing RDD " + rddId)
- val rddPrefix = "rdd_" + rddId + "_"
- val blocksToRemove = blockInfo.filter(_._1.startsWith(rddPrefix)).map(_._1)
- blocksToRemove.foreach(blockId => removeBlock(blockId, false))
+ val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId)
+ blocksToRemove.foreach(blockId => removeBlock(blockId, tellMaster = false))
blocksToRemove.size
}
/**
* Remove a block from both memory and disk.
*/
- def removeBlock(blockId: String, tellMaster: Boolean = true) {
+ def removeBlock(blockId: BlockId, tellMaster: Boolean = true) {
logInfo("Removing block " + blockId)
val info = blockInfo.get(blockId).orNull
if (info != null) info.synchronized {
@@ -886,13 +920,13 @@ private[spark] class BlockManager(
}
}
- def dropOldBlocks(cleanupTime: Long) {
- logInfo("Dropping blocks older than " + cleanupTime)
+ private def dropOldNonBroadcastBlocks(cleanupTime: Long) {
+ logInfo("Dropping non broadcast blocks older than " + cleanupTime)
val iterator = blockInfo.internalMap.entrySet().iterator()
while (iterator.hasNext) {
val entry = iterator.next()
val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2)
- if (time < cleanupTime) {
+ if (time < cleanupTime && !id.isBroadcast) {
info.synchronized {
val level = info.level
if (level.useMemory) {
@@ -909,34 +943,52 @@ private[spark] class BlockManager(
}
}
- def shouldCompress(blockId: String): Boolean = {
- if (ShuffleBlockManager.isShuffle(blockId)) {
- compressShuffle
- } else if (blockId.startsWith("broadcast_")) {
- compressBroadcast
- } else if (blockId.startsWith("rdd_")) {
- compressRdds
- } else {
- false // Won't happen in a real cluster, but it can in tests
+ private def dropOldBroadcastBlocks(cleanupTime: Long) {
+ logInfo("Dropping broadcast blocks older than " + cleanupTime)
+ val iterator = blockInfo.internalMap.entrySet().iterator()
+ while (iterator.hasNext) {
+ val entry = iterator.next()
+ val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2)
+ if (time < cleanupTime && id.isBroadcast) {
+ info.synchronized {
+ val level = info.level
+ if (level.useMemory) {
+ memoryStore.remove(id)
+ }
+ if (level.useDisk) {
+ diskStore.remove(id)
+ }
+ iterator.remove()
+ logInfo("Dropped block " + id)
+ }
+ reportBlockStatus(id, info)
+ }
}
}
+ def shouldCompress(blockId: BlockId): Boolean = blockId match {
+ case ShuffleBlockId(_, _, _) => compressShuffle
+ case BroadcastBlockId(_) => compressBroadcast
+ case RDDBlockId(_, _) => compressRdds
+ case _ => false
+ }
+
/**
* Wrap an output stream for compression if block compression is enabled for its block type
*/
- def wrapForCompression(blockId: String, s: OutputStream): OutputStream = {
+ def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s
}
/**
* Wrap an input stream for compression if block compression is enabled for its block type
*/
- def wrapForCompression(blockId: String, s: InputStream): InputStream = {
+ def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
}
def dataSerialize(
- blockId: String,
+ blockId: BlockId,
values: Iterator[Any],
serializer: Serializer = defaultSerializer): ByteBuffer = {
val byteStream = new FastByteArrayOutputStream(4096)
@@ -951,7 +1003,7 @@ private[spark] class BlockManager(
* the iterator is reached.
*/
def dataDeserialize(
- blockId: String,
+ blockId: BlockId,
bytes: ByteBuffer,
serializer: Serializer = defaultSerializer): Iterator[Any] = {
bytes.rewind()
@@ -969,6 +1021,7 @@ private[spark] class BlockManager(
memoryStore.clear()
diskStore.clear()
metadataCleaner.cancel()
+ broadcastCleaner.cancel()
logInfo("BlockManager stopped")
}
}
@@ -1005,10 +1058,10 @@ private[spark] object BlockManager extends Logging {
}
def blockIdsToBlockManagers(
- blockIds: Array[String],
+ blockIds: Array[BlockId],
env: SparkEnv,
blockManagerMaster: BlockManagerMaster = null)
- : Map[String, Seq[BlockManagerId]] =
+ : Map[BlockId, Seq[BlockManagerId]] =
{
// env == null and blockManagerMaster != null is used in tests
assert (env != null || blockManagerMaster != null)
@@ -1018,7 +1071,7 @@ private[spark] object BlockManager extends Logging {
blockManagerMaster.getLocations(blockIds)
}
- val blockManagers = new HashMap[String, Seq[BlockManagerId]]
+ val blockManagers = new HashMap[BlockId, Seq[BlockManagerId]]
for (i <- 0 until blockIds.length) {
blockManagers(blockIds(i)) = blockLocations(i)
}
@@ -1026,19 +1079,19 @@ private[spark] object BlockManager extends Logging {
}
def blockIdsToExecutorIds(
- blockIds: Array[String],
+ blockIds: Array[BlockId],
env: SparkEnv,
blockManagerMaster: BlockManagerMaster = null)
- : Map[String, Seq[String]] =
+ : Map[BlockId, Seq[String]] =
{
blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.executorId))
}
def blockIdsToHosts(
- blockIds: Array[String],
+ blockIds: Array[BlockId],
env: SparkEnv,
blockManagerMaster: BlockManagerMaster = null)
- : Map[String, Seq[String]] =
+ : Map[BlockId, Seq[String]] =
{
blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host))
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index cf463d6ffc..94038649b3 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -60,7 +60,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
def updateBlockInfo(
blockManagerId: BlockManagerId,
- blockId: String,
+ blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
diskSize: Long): Boolean = {
@@ -71,12 +71,12 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
}
/** Get locations of the blockId from the driver */
- def getLocations(blockId: String): Seq[BlockManagerId] = {
+ def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId))
}
/** Get locations of multiple blockIds from the driver */
- def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
+ def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = {
askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
}
@@ -94,7 +94,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
* Remove a block from the slaves that have it. This can only be used to remove
* blocks that the driver knows about.
*/
- def removeBlock(blockId: String) {
+ def removeBlock(blockId: BlockId) {
askDriverWithReply(RemoveBlock(blockId))
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index c7b23ab094..633230c0a8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -48,7 +48,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId]
// Mapping from block id to the set of block managers that have the block.
- private val blockLocations = new JHashMap[String, mutable.HashSet[BlockManagerId]]
+ private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]]
val akkaTimeout = Duration.create(
System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
@@ -129,10 +129,9 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
// First remove the metadata for the given RDD, and then asynchronously remove the blocks
// from the slaves.
- val prefix = "rdd_" + rddId + "_"
// Find all blocks for the given RDD, remove the block from both blockLocations and
// the blockManagerInfo that is tracking the blocks.
- val blocks = blockLocations.keySet().filter(_.startsWith(prefix))
+ val blocks = blockLocations.keys.flatMap(_.asRDDId).filter(_.rddId == rddId)
blocks.foreach { blockId =>
val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId)
bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId)))
@@ -198,7 +197,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
// Remove a block from the slaves that have it. This can only be used to remove
// blocks that the master knows about.
- private def removeBlockFromWorkers(blockId: String) {
+ private def removeBlockFromWorkers(blockId: BlockId) {
val locations = blockLocations.get(blockId)
if (locations != null) {
locations.foreach { blockManagerId: BlockManagerId =>
@@ -247,7 +246,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
private def updateBlockInfo(
blockManagerId: BlockManagerId,
- blockId: String,
+ blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
diskSize: Long) {
@@ -292,11 +291,11 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
sender ! true
}
- private def getLocations(blockId: String): Seq[BlockManagerId] = {
+ private def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty
}
- private def getLocationsMultipleBlockIds(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
+ private def getLocationsMultipleBlockIds(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = {
blockIds.map(blockId => getLocations(blockId))
}
@@ -330,7 +329,7 @@ object BlockManagerMasterActor {
private var _remainingMem: Long = maxMem
// Mapping from block id to its status.
- private val _blocks = new JHashMap[String, BlockStatus]
+ private val _blocks = new JHashMap[BlockId, BlockStatus]
logInfo("Registering block manager %s with %s RAM".format(
blockManagerId.hostPort, Utils.bytesToString(maxMem)))
@@ -339,7 +338,7 @@ object BlockManagerMasterActor {
_lastSeenMs = System.currentTimeMillis()
}
- def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long,
+ def updateBlockInfo(blockId: BlockId, storageLevel: StorageLevel, memSize: Long,
diskSize: Long) {
updateLastSeenMs()
@@ -383,7 +382,7 @@ object BlockManagerMasterActor {
}
}
- def removeBlock(blockId: String) {
+ def removeBlock(blockId: BlockId) {
if (_blocks.containsKey(blockId)) {
_remainingMem += _blocks.get(blockId).memSize
_blocks.remove(blockId)
@@ -394,7 +393,7 @@ object BlockManagerMasterActor {
def lastSeenMs: Long = _lastSeenMs
- def blocks: JHashMap[String, BlockStatus] = _blocks
+ def blocks: JHashMap[BlockId, BlockStatus] = _blocks
override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 24333a179c..45f51da288 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -30,7 +30,7 @@ private[storage] object BlockManagerMessages {
// Remove a block from the slaves that have it. This can only be used to remove
// blocks that the master knows about.
- case class RemoveBlock(blockId: String) extends ToBlockManagerSlave
+ case class RemoveBlock(blockId: BlockId) extends ToBlockManagerSlave
// Remove all blocks belonging to a specific RDD.
case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave
@@ -51,7 +51,7 @@ private[storage] object BlockManagerMessages {
class UpdateBlockInfo(
var blockManagerId: BlockManagerId,
- var blockId: String,
+ var blockId: BlockId,
var storageLevel: StorageLevel,
var memSize: Long,
var diskSize: Long)
@@ -62,7 +62,7 @@ private[storage] object BlockManagerMessages {
override def writeExternal(out: ObjectOutput) {
blockManagerId.writeExternal(out)
- out.writeUTF(blockId)
+ out.writeUTF(blockId.name)
storageLevel.writeExternal(out)
out.writeLong(memSize)
out.writeLong(diskSize)
@@ -70,7 +70,7 @@ private[storage] object BlockManagerMessages {
override def readExternal(in: ObjectInput) {
blockManagerId = BlockManagerId(in)
- blockId = in.readUTF()
+ blockId = BlockId(in.readUTF())
storageLevel = StorageLevel(in)
memSize = in.readLong()
diskSize = in.readLong()
@@ -79,7 +79,7 @@ private[storage] object BlockManagerMessages {
object UpdateBlockInfo {
def apply(blockManagerId: BlockManagerId,
- blockId: String,
+ blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
diskSize: Long): UpdateBlockInfo = {
@@ -87,14 +87,14 @@ private[storage] object BlockManagerMessages {
}
// For pattern-matching
- def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = {
+ def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, BlockId, StorageLevel, Long, Long)] = {
Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize))
}
}
- case class GetLocations(blockId: String) extends ToBlockManagerMaster
+ case class GetLocations(blockId: BlockId) extends ToBlockManagerMaster
- case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster
+ case class GetLocationsMultipleBlockIds(blockIds: Array[BlockId]) extends ToBlockManagerMaster
case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala
index acc3951088..365866d1e3 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala
@@ -28,7 +28,7 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: Spar
val metricRegistry = new MetricRegistry()
val sourceName = "%s.BlockManager".format(sc.appName)
- metricRegistry.register(MetricRegistry.name("memory", "maxMem", "MBytes"), new Gauge[Long] {
+ metricRegistry.register(MetricRegistry.name("memory", "maxMem_MB"), new Gauge[Long] {
override def getValue: Long = {
val storageStatusList = blockManager.master.getStorageStatus
val maxMem = storageStatusList.map(_.maxMem).reduce(_ + _)
@@ -36,7 +36,7 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: Spar
}
})
- metricRegistry.register(MetricRegistry.name("memory", "remainingMem", "MBytes"), new Gauge[Long] {
+ metricRegistry.register(MetricRegistry.name("memory", "remainingMem_MB"), new Gauge[Long] {
override def getValue: Long = {
val storageStatusList = blockManager.master.getStorageStatus
val remainingMem = storageStatusList.map(_.memRemaining).reduce(_ + _)
@@ -44,7 +44,7 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: Spar
}
})
- metricRegistry.register(MetricRegistry.name("memory", "memUsed", "MBytes"), new Gauge[Long] {
+ metricRegistry.register(MetricRegistry.name("memory", "memUsed_MB"), new Gauge[Long] {
override def getValue: Long = {
val storageStatusList = blockManager.master.getStorageStatus
val maxMem = storageStatusList.map(_.maxMem).reduce(_ + _)
@@ -53,7 +53,7 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: Spar
}
})
- metricRegistry.register(MetricRegistry.name("disk", "diskSpaceUsed", "MBytes"), new Gauge[Long] {
+ metricRegistry.register(MetricRegistry.name("disk", "diskSpaceUsed_MB"), new Gauge[Long] {
override def getValue: Long = {
val storageStatusList = blockManager.master.getStorageStatus
val diskSpaceUsed = storageStatusList
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
index 678c38203c..0c66addf9d 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
@@ -77,7 +77,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
}
}
- private def putBlock(id: String, bytes: ByteBuffer, level: StorageLevel) {
+ private def putBlock(id: BlockId, bytes: ByteBuffer, level: StorageLevel) {
val startTimeMs = System.currentTimeMillis()
logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes)
blockManager.putBytes(id, bytes, level)
@@ -85,7 +85,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
+ " with data size: " + bytes.limit)
}
- private def getBlock(id: String): ByteBuffer = {
+ private def getBlock(id: BlockId): ByteBuffer = {
val startTimeMs = System.currentTimeMillis()
logDebug("GetBlock " + id + " started from " + startTimeMs)
val buffer = blockManager.getLocalBytes(id) match {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
index d8fa6a91d1..80dcb5a207 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
@@ -24,9 +24,9 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.network._
-private[spark] case class GetBlock(id: String)
-private[spark] case class GotBlock(id: String, data: ByteBuffer)
-private[spark] case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel)
+private[spark] case class GetBlock(id: BlockId)
+private[spark] case class GotBlock(id: BlockId, data: ByteBuffer)
+private[spark] case class PutBlock(id: BlockId, data: ByteBuffer, level: StorageLevel)
private[spark] class BlockMessage() {
// Un-initialized: typ = 0
@@ -34,7 +34,7 @@ private[spark] class BlockMessage() {
// GotBlock: typ = 2
// PutBlock: typ = 3
private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED
- private var id: String = null
+ private var id: BlockId = null
private var data: ByteBuffer = null
private var level: StorageLevel = null
@@ -74,7 +74,7 @@ private[spark] class BlockMessage() {
for (i <- 1 to idLength) {
idBuilder += buffer.getChar()
}
- id = idBuilder.toString()
+ id = BlockId(idBuilder.toString)
if (typ == BlockMessage.TYPE_PUT_BLOCK) {
@@ -109,28 +109,17 @@ private[spark] class BlockMessage() {
set(buffer)
}
- def getType: Int = {
- return typ
- }
-
- def getId: String = {
- return id
- }
-
- def getData: ByteBuffer = {
- return data
- }
-
- def getLevel: StorageLevel = {
- return level
- }
+ def getType: Int = typ
+ def getId: BlockId = id
+ def getData: ByteBuffer = data
+ def getLevel: StorageLevel = level
def toBufferMessage: BufferMessage = {
val startTime = System.currentTimeMillis
val buffers = new ArrayBuffer[ByteBuffer]()
- var buffer = ByteBuffer.allocate(4 + 4 + id.length() * 2)
- buffer.putInt(typ).putInt(id.length())
- id.foreach((x: Char) => buffer.putChar(x))
+ var buffer = ByteBuffer.allocate(4 + 4 + id.name.length * 2)
+ buffer.putInt(typ).putInt(id.name.length)
+ id.name.foreach((x: Char) => buffer.putChar(x))
buffer.flip()
buffers += buffer
@@ -212,7 +201,8 @@ private[spark] object BlockMessage {
def main(args: Array[String]) {
val B = new BlockMessage()
- B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2))
+ val blockId = TestBlockId("ABC")
+ B.set(new PutBlock(blockId, ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2))
val bMsg = B.toBufferMessage
val C = new BlockMessage()
C.set(bMsg)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
index 0aaf846b5b..6ce9127c74 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
@@ -111,14 +111,15 @@ private[spark] object BlockMessageArray {
}
def main(args: Array[String]) {
- val blockMessages =
+ val blockMessages =
(0 until 10).map { i =>
if (i % 2 == 0) {
val buffer = ByteBuffer.allocate(100)
buffer.clear
- BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY_SER))
+ BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer,
+ StorageLevel.MEMORY_ONLY_SER))
} else {
- BlockMessage.fromGetBlock(GetBlock(i.toString))
+ BlockMessage.fromGetBlock(GetBlock(TestBlockId(i.toString)))
}
}
val blockMessageArray = new BlockMessageArray(blockMessages)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 39f103297f..2a67800c45 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -25,7 +25,7 @@ package org.apache.spark.storage
*
* This interface does not support concurrent writes.
*/
-abstract class BlockObjectWriter(val blockId: String) {
+abstract class BlockObjectWriter(val blockId: BlockId) {
var closeEventHandler: () => Unit = _
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
index fa834371f4..ea42656240 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
@@ -27,7 +27,7 @@ import org.apache.spark.Logging
*/
private[spark]
abstract class BlockStore(val blockManager: BlockManager) extends Logging {
- def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel)
+ def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel)
/**
* Put in a block and, possibly, also return its content as either bytes or another Iterator.
@@ -36,26 +36,26 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging {
* @return a PutResult that contains the size of the data, as well as the values put if
* returnValues is true (if not, the result's data field can be null)
*/
- def putValues(blockId: String, values: ArrayBuffer[Any], level: StorageLevel,
+ def putValues(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
returnValues: Boolean) : PutResult
/**
* Return the size of a block in bytes.
*/
- def getSize(blockId: String): Long
+ def getSize(blockId: BlockId): Long
- def getBytes(blockId: String): Option[ByteBuffer]
+ def getBytes(blockId: BlockId): Option[ByteBuffer]
- def getValues(blockId: String): Option[Iterator[Any]]
+ def getValues(blockId: BlockId): Option[Iterator[Any]]
/**
* Remove a block, if it exists.
* @param blockId the block to remove.
* @return True if the block was found and removed, False otherwise.
*/
- def remove(blockId: String): Boolean
+ def remove(blockId: BlockId): Boolean
- def contains(blockId: String): Boolean
+ def contains(blockId: BlockId): Boolean
def clear() { }
}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index 63447baf8c..b7ca61e938 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -42,7 +42,7 @@ import org.apache.spark.util.Utils
private class DiskStore(blockManager: BlockManager, rootDirs: String)
extends BlockStore(blockManager) with Logging {
- class DiskBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int)
+ class DiskBlockObjectWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
extends BlockObjectWriter(blockId) {
private val f: File = createFile(blockId /*, allowAppendExisting */)
@@ -124,16 +124,16 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
addShutdownHook()
- def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
+ def getBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = {
new DiskBlockObjectWriter(blockId, serializer, bufferSize)
}
- override def getSize(blockId: String): Long = {
+ override def getSize(blockId: BlockId): Long = {
getFile(blockId).length()
}
- override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) {
+ override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) {
// So that we do not modify the input offsets !
// duplicate does not copy buffer, so inexpensive
val bytes = _bytes.duplicate()
@@ -163,7 +163,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
override def putValues(
- blockId: String,
+ blockId: BlockId,
values: ArrayBuffer[Any],
level: StorageLevel,
returnValues: Boolean)
@@ -192,13 +192,13 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
}
- override def getBytes(blockId: String): Option[ByteBuffer] = {
+ override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
val file = getFile(blockId)
val bytes = getFileBytes(file)
Some(bytes)
}
- override def getValues(blockId: String): Option[Iterator[Any]] = {
+ override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes))
}
@@ -206,11 +206,11 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
* A version of getValues that allows a custom serializer. This is used as part of the
* shuffle short-circuit code.
*/
- def getValues(blockId: String, serializer: Serializer): Option[Iterator[Any]] = {
+ def getValues(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer))
}
- override def remove(blockId: String): Boolean = {
+ override def remove(blockId: BlockId): Boolean = {
val file = getFile(blockId)
if (file.exists()) {
file.delete()
@@ -219,11 +219,11 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
}
- override def contains(blockId: String): Boolean = {
+ override def contains(blockId: BlockId): Boolean = {
getFile(blockId).exists()
}
- private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = {
+ private def createFile(blockId: BlockId, allowAppendExisting: Boolean = false): File = {
val file = getFile(blockId)
if (!allowAppendExisting && file.exists()) {
// NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task
@@ -234,7 +234,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
file
}
- private def getFile(blockId: String): File = {
+ private def getFile(blockId: BlockId): File = {
logDebug("Getting file for block " + blockId)
// Figure out which local directory it hashes to, and which subdirectory in that
@@ -258,7 +258,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
}
- new File(subDir, blockId)
+ new File(subDir, blockId.name)
}
private def createLocalDirs(): Array[File] = {
@@ -307,7 +307,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
}
if (shuffleSender != null) {
- shuffleSender.stop
+ shuffleSender.stop()
}
}
})
@@ -315,11 +315,10 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
private[storage] def startShuffleBlockSender(port: Int): Int = {
val pResolver = new PathResolver {
- override def getAbsolutePath(blockId: String): String = {
- if (!blockId.startsWith("shuffle_")) {
- return null
- }
- DiskStore.this.getFile(blockId).getAbsolutePath()
+ override def getAbsolutePath(blockIdString: String): String = {
+ val blockId = BlockId(blockIdString)
+ if (!blockId.isShuffle) null
+ else DiskStore.this.getFile(blockId).getAbsolutePath
}
}
shuffleSender = new ShuffleSender(port, pResolver)
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index 3b3b2342fa..05f676c6e2 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -30,10 +30,10 @@ import org.apache.spark.util.{SizeEstimator, Utils}
private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
extends BlockStore(blockManager) {
- case class Entry(value: Any, size: Long, deserialized: Boolean, var dropPending: Boolean = false)
+ case class Entry(value: Any, size: Long, deserialized: Boolean)
- private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true)
- private var currentMemory = 0L
+ private val entries = new LinkedHashMap[BlockId, Entry](32, 0.75f, true)
+ @volatile private var currentMemory = 0L
// Object used to ensure that only one thread is putting blocks and if necessary, dropping
// blocks from the memory store.
private val putLock = new Object()
@@ -42,13 +42,13 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
def freeMemory: Long = maxMemory - currentMemory
- override def getSize(blockId: String): Long = {
+ override def getSize(blockId: BlockId): Long = {
entries.synchronized {
entries.get(blockId).size
}
}
- override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) {
+ override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) {
// Work on a duplicate - since the original input might be used elsewhere.
val bytes = _bytes.duplicate()
bytes.rewind()
@@ -64,7 +64,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
override def putValues(
- blockId: String,
+ blockId: BlockId,
values: ArrayBuffer[Any],
level: StorageLevel,
returnValues: Boolean)
@@ -81,7 +81,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
- override def getBytes(blockId: String): Option[ByteBuffer] = {
+ override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
val entry = entries.synchronized {
entries.get(blockId)
}
@@ -94,7 +94,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
- override def getValues(blockId: String): Option[Iterator[Any]] = {
+ override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
val entry = entries.synchronized {
entries.get(blockId)
}
@@ -108,11 +108,10 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
- override def remove(blockId: String): Boolean = {
+ override def remove(blockId: BlockId): Boolean = {
entries.synchronized {
- val entry = entries.get(blockId)
+ val entry = entries.remove(blockId)
if (entry != null) {
- entries.remove(blockId)
currentMemory -= entry.size
logInfo("Block %s of size %d dropped from memory (free %d)".format(
blockId, entry.size, freeMemory))
@@ -126,19 +125,16 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
override def clear() {
entries.synchronized {
entries.clear()
+ currentMemory = 0
}
logInfo("MemoryStore cleared")
}
/**
- * Return the RDD ID that a given block ID is from, or null if it is not an RDD block.
+ * Return the RDD ID that a given block ID is from, or None if it is not an RDD block.
*/
- private def getRddId(blockId: String): String = {
- if (blockId.startsWith("rdd_")) {
- blockId.split('_')(1)
- } else {
- null
- }
+ private def getRddId(blockId: BlockId): Option[Int] = {
+ blockId.asRDDId.map(_.rddId)
}
/**
@@ -151,7 +147,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
* blocks to free memory for one block, another thread may use up the freed space for
* another block.
*/
- private def tryToPut(blockId: String, value: Any, size: Long, deserialized: Boolean): Boolean = {
+ private def tryToPut(blockId: BlockId, value: Any, size: Long, deserialized: Boolean): Boolean = {
// TODO: Its possible to optimize the locking by locking entries only when selecting blocks
// to be dropped. Once the to-be-dropped blocks have been selected, and lock on entries has been
// released, it must be ensured that those to-be-dropped blocks are not double counted for
@@ -160,8 +156,10 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
putLock.synchronized {
if (ensureFreeSpace(blockId, size)) {
val entry = new Entry(value, size, deserialized)
- entries.synchronized { entries.put(blockId, entry) }
- currentMemory += size
+ entries.synchronized {
+ entries.put(blockId, entry)
+ currentMemory += size
+ }
if (deserialized) {
logInfo("Block %s stored as values to memory (estimated size %s, free %s)".format(
blockId, Utils.bytesToString(size), Utils.bytesToString(freeMemory)))
@@ -193,7 +191,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
* Assumes that a lock is held by the caller to ensure only one thread is dropping blocks.
* Otherwise, the freed space may fill up before the caller puts in their new value.
*/
- private def ensureFreeSpace(blockIdToAdd: String, space: Long): Boolean = {
+ private def ensureFreeSpace(blockIdToAdd: BlockId, space: Long): Boolean = {
logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format(
space, currentMemory, maxMemory))
@@ -205,7 +203,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
if (maxMemory - currentMemory < space) {
val rddToAdd = getRddId(blockIdToAdd)
- val selectedBlocks = new ArrayBuffer[String]()
+ val selectedBlocks = new ArrayBuffer[BlockId]()
var selectedMemory = 0L
// This is synchronized to ensure that the set of entries is not changed
@@ -216,7 +214,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) {
val pair = iterator.next()
val blockId = pair.getKey
- if (rddToAdd != null && rddToAdd == getRddId(blockId)) {
+ if (rddToAdd != None && rddToAdd == getRddId(blockId)) {
logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " +
"block from the same RDD")
return false
@@ -250,7 +248,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
return true
}
- override def contains(blockId: String): Boolean = {
+ override def contains(blockId: BlockId): Boolean = {
entries.synchronized { entries.containsKey(blockId) }
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index 9da11efb57..f39fcd87fb 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -30,7 +30,6 @@ trait ShuffleBlocks {
def releaseWriters(group: ShuffleWriterGroup)
}
-
private[spark]
class ShuffleBlockManager(blockManager: BlockManager) {
@@ -40,7 +39,7 @@ class ShuffleBlockManager(blockManager: BlockManager) {
override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
- val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
+ val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
}
new ShuffleWriterGroup(mapId, writers)
@@ -52,16 +51,3 @@ class ShuffleBlockManager(blockManager: BlockManager) {
}
}
}
-
-
-private[spark]
-object ShuffleBlockManager {
-
- // Returns the block id for a given shuffle block.
- def blockId(shuffleId: Int, bucketId: Int, groupId: Int): String = {
- "shuffle_" + shuffleId + "_" + groupId + "_" + bucketId
- }
-
- // Returns true if the block is a shuffle block.
- def isShuffle(blockId: String): Boolean = blockId.startsWith("shuffle_")
-}
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
index 2bb7715696..1720007e4e 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
@@ -23,20 +23,24 @@ import org.apache.spark.util.Utils
private[spark]
case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
- blocks: Map[String, BlockStatus]) {
+ blocks: Map[BlockId, BlockStatus]) {
- def memUsed(blockPrefix: String = "") = {
- blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize).
- reduceOption(_+_).getOrElse(0l)
- }
+ def memUsed() = blocks.values.map(_.memSize).reduceOption(_+_).getOrElse(0L)
- def diskUsed(blockPrefix: String = "") = {
- blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.diskSize).
- reduceOption(_+_).getOrElse(0l)
- }
+ def memUsedByRDD(rddId: Int) =
+ rddBlocks.filterKeys(_.rddId == rddId).values.map(_.memSize).reduceOption(_+_).getOrElse(0L)
+
+ def diskUsed() = blocks.values.map(_.diskSize).reduceOption(_+_).getOrElse(0L)
+
+ def diskUsedByRDD(rddId: Int) =
+ rddBlocks.filterKeys(_.rddId == rddId).values.map(_.diskSize).reduceOption(_+_).getOrElse(0L)
def memRemaining : Long = maxMem - memUsed()
+ def rddBlocks = blocks.flatMap {
+ case (rdd: RDDBlockId, status) => Some(rdd, status)
+ case _ => None
+ }
}
case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel,
@@ -60,7 +64,7 @@ object StorageUtils {
/* Returns RDD-level information, compiled from a list of StorageStatus objects */
def rddInfoFromStorageStatus(storageStatusList: Seq[StorageStatus],
sc: SparkContext) : Array[RDDInfo] = {
- rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc)
+ rddInfoFromBlockStatusList(storageStatusList.flatMap(_.rddBlocks).toMap[RDDBlockId, BlockStatus], sc)
}
/* Returns a map of blocks to their locations, compiled from a list of StorageStatus objects */
@@ -71,26 +75,21 @@ object StorageUtils {
}
/* Given a list of BlockStatus objets, returns information for each RDD */
- def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus],
+ def rddInfoFromBlockStatusList(infos: Map[RDDBlockId, BlockStatus],
sc: SparkContext) : Array[RDDInfo] = {
// Group by rddId, ignore the partition name
- val groupedRddBlocks = infos.filterKeys(_.startsWith("rdd_")).groupBy { case(k, v) =>
- k.substring(0,k.lastIndexOf('_'))
- }.mapValues(_.values.toArray)
+ val groupedRddBlocks = infos.groupBy { case(k, v) => k.rddId }.mapValues(_.values.toArray)
// For each RDD, generate an RDDInfo object
- val rddInfos = groupedRddBlocks.map { case (rddKey, rddBlocks) =>
+ val rddInfos = groupedRddBlocks.map { case (rddId, rddBlocks) =>
// Add up memory and disk sizes
val memSize = rddBlocks.map(_.memSize).reduce(_ + _)
val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _)
- // Find the id of the RDD, e.g. rdd_1 => 1
- val rddId = rddKey.split("_").last.toInt
-
// Get the friendly name and storage level for the RDD, if available
sc.persistentRdds.get(rddId).map { r =>
- val rddName = Option(r.name).getOrElse(rddKey)
+ val rddName = Option(r.name).getOrElse(rddId.toString)
val rddStorageLevel = r.getStorageLevel
RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size, memSize, diskSize)
}
@@ -101,16 +100,14 @@ object StorageUtils {
rddInfos
}
- /* Removes all BlockStatus object that are not part of a block prefix */
- def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus],
- prefix: String) : Array[StorageStatus] = {
+ /* Filters storage status by a given RDD id. */
+ def filterStorageStatusByRDD(storageStatusList: Array[StorageStatus], rddId: Int)
+ : Array[StorageStatus] = {
storageStatusList.map { status =>
- val newBlocks = status.blocks.filterKeys(_.startsWith(prefix))
+ val newBlocks = status.rddBlocks.filterKeys(_.rddId == rddId).toMap[BlockId, BlockStatus]
//val newRemainingMem = status.maxMem - newBlocks.values.map(_.memSize).reduce(_ + _)
StorageStatus(status.blockManagerId, status.maxMem, newBlocks)
}
-
}
-
}
diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
index f2ae8dd97d..860e680576 100644
--- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
@@ -36,11 +36,11 @@ private[spark] object ThreadingTest {
val numBlocksPerProducer = 20000
private[spark] class ProducerThread(manager: BlockManager, id: Int) extends Thread {
- val queue = new ArrayBlockingQueue[(String, Seq[Int])](100)
+ val queue = new ArrayBlockingQueue[(BlockId, Seq[Int])](100)
override def run() {
for (i <- 1 to numBlocksPerProducer) {
- val blockId = "b-" + id + "-" + i
+ val blockId = TestBlockId("b-" + id + "-" + i)
val blockSize = Random.nextInt(1000)
val block = (1 to blockSize).map(_ => Random.nextInt())
val level = randomLevel()
@@ -64,7 +64,7 @@ private[spark] object ThreadingTest {
private[spark] class ConsumerThread(
manager: BlockManager,
- queue: ArrayBlockingQueue[(String, Seq[Int])]
+ queue: ArrayBlockingQueue[(BlockId, Seq[Int])]
) extends Thread {
var numBlockConsumed = 0
diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
index 3ec9760ed0..fcd1b518d0 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
@@ -21,7 +21,7 @@ import scala.util.Random
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
-import org.apache.spark.scheduler.cluster.SchedulingMode
+import org.apache.spark.scheduler.SchedulingMode
/**
@@ -35,7 +35,7 @@ private[spark] object UIWorkloadGenerator {
def main(args: Array[String]) {
if (args.length < 2) {
- println("usage: ./spark-class spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]")
+ println("usage: ./spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]")
System.exit(1)
}
val master = args(0)
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
index d1868dcf78..42e9be6e19 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
@@ -26,8 +26,8 @@ import org.eclipse.jetty.server.Handler
import org.apache.spark.{ExceptionFailure, Logging, SparkContext}
import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.scheduler.cluster.TaskInfo
import org.apache.spark.scheduler.{SparkListenerTaskStart, SparkListenerTaskEnd, SparkListener}
+import org.apache.spark.scheduler.TaskInfo
import org.apache.spark.ui.JettyUtils._
import org.apache.spark.ui.Page.Executors
import org.apache.spark.ui.UIUtils
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
index 3b428effaf..b39c0e9769 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
@@ -21,7 +21,7 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.{NodeSeq, Node}
-import org.apache.spark.scheduler.cluster.SchedulingMode
+import org.apache.spark.scheduler.SchedulingMode
import org.apache.spark.ui.Page._
import org.apache.spark.ui.UIUtils._
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index 5d46f38a2a..eb3b4e8522 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -21,10 +21,8 @@ import scala.Seq
import scala.collection.mutable.{ListBuffer, HashMap, HashSet}
import org.apache.spark.{ExceptionFailure, SparkContext, Success}
-import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster.TaskInfo
import org.apache.spark.executor.TaskMetrics
-import collection.mutable
+import org.apache.spark.scheduler._
/**
* Tracks task-level information to be displayed in the UI.
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
index 6aecef5120..e7eab374ad 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
@@ -32,8 +32,8 @@ import org.apache.spark.ui.JettyUtils._
import org.apache.spark.{ExceptionFailure, SparkContext, Success}
import org.apache.spark.scheduler._
import collection.mutable
-import org.apache.spark.scheduler.cluster.SchedulingMode
-import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
+import org.apache.spark.scheduler.SchedulingMode
+import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.util.Utils
/** Web UI showing progress status of all jobs in the given SparkContext. */
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
index b3d3666944..06810d8dbc 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
@@ -21,8 +21,7 @@ import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.xml.Node
-import org.apache.spark.scheduler.Stage
-import org.apache.spark.scheduler.cluster.Schedulable
+import org.apache.spark.scheduler.{Schedulable, Stage}
import org.apache.spark.ui.UIUtils
/** Table showing list of pools */
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index a9969ab1c0..163a3746ea 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -23,12 +23,12 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.Node
+import org.apache.spark.{ExceptionFailure}
+import org.apache.spark.executor.TaskMetrics
import org.apache.spark.ui.UIUtils._
import org.apache.spark.ui.Page._
import org.apache.spark.util.{Utils, Distribution}
-import org.apache.spark.{ExceptionFailure}
-import org.apache.spark.scheduler.cluster.TaskInfo
-import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.scheduler.TaskInfo
/** Page showing statistics and task list for a given stage */
private[spark] class StagePage(parent: JobProgressUI) {
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index 32776eaa25..07db8622da 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -22,8 +22,7 @@ import java.util.Date
import scala.xml.Node
import scala.collection.mutable.HashSet
-import org.apache.spark.scheduler.cluster.{SchedulingMode, TaskInfo}
-import org.apache.spark.scheduler.Stage
+import org.apache.spark.scheduler.{SchedulingMode, Stage, TaskInfo}
import org.apache.spark.ui.UIUtils
import org.apache.spark.util.Utils
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
index 43c1257677..b83cd54f3c 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
@@ -21,7 +21,7 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.Node
-import org.apache.spark.storage.{StorageStatus, StorageUtils}
+import org.apache.spark.storage.{BlockId, StorageStatus, StorageUtils}
import org.apache.spark.storage.BlockManagerMasterActor.BlockStatus
import org.apache.spark.ui.UIUtils._
import org.apache.spark.ui.Page._
@@ -33,21 +33,20 @@ private[spark] class RDDPage(parent: BlockManagerUI) {
val sc = parent.sc
def render(request: HttpServletRequest): Seq[Node] = {
- val id = request.getParameter("id")
- val prefix = "rdd_" + id.toString
+ val id = request.getParameter("id").toInt
val storageStatusList = sc.getExecutorStorageStatus
- val filteredStorageStatusList = StorageUtils.
- filterStorageStatusByPrefix(storageStatusList, prefix)
+ val filteredStorageStatusList = StorageUtils.filterStorageStatusByRDD(storageStatusList, id)
val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head
val workerHeaders = Seq("Host", "Memory Usage", "Disk Usage")
- val workers = filteredStorageStatusList.map((prefix, _))
+ val workers = filteredStorageStatusList.map((id, _))
val workerTable = listingTable(workerHeaders, workerRow, workers)
val blockHeaders = Seq("Block Name", "Storage Level", "Size in Memory", "Size on Disk",
"Executors")
- val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).toArray.sortWith(_._1 < _._1)
+ val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).toArray.
+ sortWith(_._1.name < _._1.name)
val blockLocations = StorageUtils.blockLocationsFromStorageStatus(filteredStorageStatusList)
val blocks = blockStatuses.map {
case(id, status) => (id, status, blockLocations.get(id).getOrElse(Seq("UNKNOWN")))
@@ -99,7 +98,7 @@ private[spark] class RDDPage(parent: BlockManagerUI) {
headerSparkPage(content, parent.sc, "RDD Storage Info for " + rddInfo.name, Storage)
}
- def blockRow(row: (String, BlockStatus, Seq[String])): Seq[Node] = {
+ def blockRow(row: (BlockId, BlockStatus, Seq[String])): Seq[Node] = {
val (id, block, locations) = row
<tr>
<td>{id}</td>
@@ -118,15 +117,15 @@ private[spark] class RDDPage(parent: BlockManagerUI) {
</tr>
}
- def workerRow(worker: (String, StorageStatus)): Seq[Node] = {
- val (prefix, status) = worker
+ def workerRow(worker: (Int, StorageStatus)): Seq[Node] = {
+ val (rddId, status) = worker
<tr>
<td>{status.blockManagerId.host + ":" + status.blockManagerId.port}</td>
<td>
- {Utils.bytesToString(status.memUsed(prefix))}
+ {Utils.bytesToString(status.memUsedByRDD(rddId))}
({Utils.bytesToString(status.memRemaining)} Remaining)
</td>
- <td>{Utils.bytesToString(status.diskUsed(prefix))}</td>
+ <td>{Utils.bytesToString(status.diskUsedByRDD(rddId))}</td>
</tr>
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
new file mode 100644
index 0000000000..f60deafc6f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+/**
+ * A simple open hash table optimized for the append-only use case, where keys
+ * are never removed, but the value for each key may be changed.
+ *
+ * This implementation uses quadratic probing with a power-of-2 hash table
+ * size, which is guaranteed to explore all spaces for each key (see
+ * http://en.wikipedia.org/wiki/Quadratic_probing).
+ *
+ * TODO: Cache the hash values of each key? java.util.HashMap does that.
+ */
+private[spark]
+class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] with Serializable {
+ require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
+ require(initialCapacity >= 1, "Invalid initial capacity")
+
+ private var capacity = nextPowerOf2(initialCapacity)
+ private var mask = capacity - 1
+ private var curSize = 0
+
+ // Holds keys and values in the same array for memory locality; specifically, the order of
+ // elements is key0, value0, key1, value1, key2, value2, etc.
+ private var data = new Array[AnyRef](2 * capacity)
+
+ // Treat the null key differently so we can use nulls in "data" to represent empty items.
+ private var haveNullValue = false
+ private var nullValue: V = null.asInstanceOf[V]
+
+ private val LOAD_FACTOR = 0.7
+
+ /** Get the value for a given key */
+ def apply(key: K): V = {
+ val k = key.asInstanceOf[AnyRef]
+ if (k.eq(null)) {
+ return nullValue
+ }
+ var pos = rehash(k.hashCode) & mask
+ var i = 1
+ while (true) {
+ val curKey = data(2 * pos)
+ if (k.eq(curKey) || k == curKey) {
+ return data(2 * pos + 1).asInstanceOf[V]
+ } else if (curKey.eq(null)) {
+ return null.asInstanceOf[V]
+ } else {
+ val delta = i
+ pos = (pos + delta) & mask
+ i += 1
+ }
+ }
+ return null.asInstanceOf[V]
+ }
+
+ /** Set the value for a key */
+ def update(key: K, value: V): Unit = {
+ val k = key.asInstanceOf[AnyRef]
+ if (k.eq(null)) {
+ if (!haveNullValue) {
+ incrementSize()
+ }
+ nullValue = value
+ haveNullValue = true
+ return
+ }
+ val isNewEntry = putInto(data, k, value.asInstanceOf[AnyRef])
+ if (isNewEntry) {
+ incrementSize()
+ }
+ }
+
+ /**
+ * Set the value for key to updateFunc(hadValue, oldValue), where oldValue will be the old value
+ * for key, if any, or null otherwise. Returns the newly updated value.
+ */
+ def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
+ val k = key.asInstanceOf[AnyRef]
+ if (k.eq(null)) {
+ if (!haveNullValue) {
+ incrementSize()
+ }
+ nullValue = updateFunc(haveNullValue, nullValue)
+ haveNullValue = true
+ return nullValue
+ }
+ var pos = rehash(k.hashCode) & mask
+ var i = 1
+ while (true) {
+ val curKey = data(2 * pos)
+ if (k.eq(curKey) || k == curKey) {
+ val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
+ data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
+ return newValue
+ } else if (curKey.eq(null)) {
+ val newValue = updateFunc(false, null.asInstanceOf[V])
+ data(2 * pos) = k
+ data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
+ incrementSize()
+ return newValue
+ } else {
+ val delta = i
+ pos = (pos + delta) & mask
+ i += 1
+ }
+ }
+ null.asInstanceOf[V] // Never reached but needed to keep compiler happy
+ }
+
+ /** Iterator method from Iterable */
+ override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] {
+ var pos = -1
+
+ /** Get the next value we should return from next(), or null if we're finished iterating */
+ def nextValue(): (K, V) = {
+ if (pos == -1) { // Treat position -1 as looking at the null value
+ if (haveNullValue) {
+ return (null.asInstanceOf[K], nullValue)
+ }
+ pos += 1
+ }
+ while (pos < capacity) {
+ if (!data(2 * pos).eq(null)) {
+ return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
+ }
+ pos += 1
+ }
+ null
+ }
+
+ override def hasNext: Boolean = nextValue() != null
+
+ override def next(): (K, V) = {
+ val value = nextValue()
+ if (value == null) {
+ throw new NoSuchElementException("End of iterator")
+ }
+ pos += 1
+ value
+ }
+ }
+
+ override def size: Int = curSize
+
+ /** Increase table size by 1, rehashing if necessary */
+ private def incrementSize() {
+ curSize += 1
+ if (curSize > LOAD_FACTOR * capacity) {
+ growTable()
+ }
+ }
+
+ /**
+ * Re-hash a value to deal better with hash functions that don't differ
+ * in the lower bits, similar to java.util.HashMap
+ */
+ private def rehash(h: Int): Int = {
+ val r = h ^ (h >>> 20) ^ (h >>> 12)
+ r ^ (r >>> 7) ^ (r >>> 4)
+ }
+
+ /**
+ * Put an entry into a table represented by data, returning true if
+ * this increases the size of the table or false otherwise. Assumes
+ * that "data" has at least one empty slot.
+ */
+ private def putInto(data: Array[AnyRef], key: AnyRef, value: AnyRef): Boolean = {
+ val mask = (data.length / 2) - 1
+ var pos = rehash(key.hashCode) & mask
+ var i = 1
+ while (true) {
+ val curKey = data(2 * pos)
+ if (curKey.eq(null)) {
+ data(2 * pos) = key
+ data(2 * pos + 1) = value.asInstanceOf[AnyRef]
+ return true
+ } else if (curKey.eq(key) || curKey == key) {
+ data(2 * pos + 1) = value.asInstanceOf[AnyRef]
+ return false
+ } else {
+ val delta = i
+ pos = (pos + delta) & mask
+ i += 1
+ }
+ }
+ return false // Never reached but needed to keep compiler happy
+ }
+
+ /** Double the table's size and re-hash everything */
+ private def growTable() {
+ val newCapacity = capacity * 2
+ if (newCapacity >= (1 << 30)) {
+ // We can't make the table this big because we want an array of 2x
+ // that size for our data, but array sizes are at most Int.MaxValue
+ throw new Exception("Can't make capacity bigger than 2^29 elements")
+ }
+ val newData = new Array[AnyRef](2 * newCapacity)
+ var pos = 0
+ while (pos < capacity) {
+ if (!data(2 * pos).eq(null)) {
+ putInto(newData, data(2 * pos), data(2 * pos + 1))
+ }
+ pos += 1
+ }
+ data = newData
+ capacity = newCapacity
+ mask = newCapacity - 1
+ }
+
+ private def nextPowerOf2(n: Int): Int = {
+ val highBit = Integer.highestOneBit(n)
+ if (highBit == n) n else highBit << 1
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index a430a75451..0ce1394c77 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -17,7 +17,6 @@
package org.apache.spark.util
-import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors}
import java.util.{TimerTask, Timer}
import org.apache.spark.Logging
@@ -25,11 +24,14 @@ import org.apache.spark.Logging
/**
* Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries)
*/
-class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging {
+class MetadataCleaner(cleanerType: MetadataCleanerType.MetadataCleanerType, cleanupFunc: (Long) => Unit) extends Logging {
+ val name = cleanerType.toString
+
private val delaySeconds = MetadataCleaner.getDelaySeconds
private val periodSeconds = math.max(10, delaySeconds / 10)
private val timer = new Timer(name + " cleanup timer", true)
+
private val task = new TimerTask {
override def run() {
try {
@@ -53,9 +55,37 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging
}
}
+object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext", "HttpBroadcast", "DagScheduler", "ResultTask",
+ "ShuffleMapTask", "BlockManager", "BroadcastVars") {
+
+ val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK, SHUFFLE_MAP_TASK, BLOCK_MANAGER, BROADCAST_VARS = Value
+
+ type MetadataCleanerType = Value
+
+ def systemProperty(which: MetadataCleanerType.MetadataCleanerType) = "spark.cleaner.ttl." + which.toString
+}
object MetadataCleaner {
+
+ // using only sys props for now : so that workers can also get to it while preserving earlier behavior.
def getDelaySeconds = System.getProperty("spark.cleaner.ttl", "-1").toInt
- def setDelaySeconds(delay: Int) { System.setProperty("spark.cleaner.ttl", delay.toString) }
+
+ def getDelaySeconds(cleanerType: MetadataCleanerType.MetadataCleanerType): Int = {
+ System.getProperty(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds.toString).toInt
+ }
+
+ def setDelaySeconds(cleanerType: MetadataCleanerType.MetadataCleanerType, delay: Int) {
+ System.setProperty(MetadataCleanerType.systemProperty(cleanerType), delay.toString)
+ }
+
+ def setDelaySeconds(delay: Int, resetAll: Boolean = true) {
+ // override for all ?
+ System.setProperty("spark.cleaner.ttl", delay.toString)
+ if (resetAll) {
+ for (cleanerType <- MetadataCleanerType.values) {
+ System.clearProperty(MetadataCleanerType.systemProperty(cleanerType))
+ }
+ }
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 886f071503..f384875cc9 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -70,6 +70,19 @@ private[spark] object Utils extends Logging {
return ois.readObject.asInstanceOf[T]
}
+ /** Deserialize a Long value (used for {@link org.apache.spark.api.python.PythonPartitioner}) */
+ def deserializeLongValue(bytes: Array[Byte]) : Long = {
+ // Note: we assume that we are given a Long value encoded in network (big-endian) byte order
+ var result = bytes(7) & 0xFFL
+ result = result + ((bytes(6) & 0xFFL) << 8)
+ result = result + ((bytes(5) & 0xFFL) << 16)
+ result = result + ((bytes(4) & 0xFFL) << 24)
+ result = result + ((bytes(3) & 0xFFL) << 32)
+ result = result + ((bytes(2) & 0xFFL) << 40)
+ result = result + ((bytes(1) & 0xFFL) << 48)
+ result + ((bytes(0) & 0xFFL) << 56)
+ }
+
/** Serialize via nested stream using specific serializer */
def serializeViaNestedStream(os: OutputStream, ser: SerializerInstance)(f: SerializationStream => Unit) = {
val osWrapper = ser.serializeStream(new OutputStream {
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index 3a7171c488..ced036c58d 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -23,7 +23,7 @@ import org.scalatest.{BeforeAndAfter, FunSuite}
import org.scalatest.mock.EasyMockSugar
import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.{BlockManager, StorageLevel}
+import org.apache.spark.storage.{BlockManager, RDDBlockId, StorageLevel}
// TODO: Test the CacheManager's thread-safety aspects
class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar {
@@ -52,9 +52,9 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
test("get uncached rdd") {
expecting {
- blockManager.get("rdd_0_0").andReturn(None)
- blockManager.put("rdd_0_0", ArrayBuffer[Any](1, 2, 3, 4), StorageLevel.MEMORY_ONLY, true).
- andReturn(0)
+ blockManager.get(RDDBlockId(0, 0)).andReturn(None)
+ blockManager.put(RDDBlockId(0, 0), ArrayBuffer[Any](1, 2, 3, 4), StorageLevel.MEMORY_ONLY,
+ true).andReturn(0)
}
whenExecuting(blockManager) {
@@ -66,7 +66,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
test("get cached rdd") {
expecting {
- blockManager.get("rdd_0_0").andReturn(Some(ArrayBuffer(5, 6, 7).iterator))
+ blockManager.get(RDDBlockId(0, 0)).andReturn(Some(ArrayBuffer(5, 6, 7).iterator))
}
whenExecuting(blockManager) {
@@ -79,7 +79,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
test("get uncached local rdd") {
expecting {
// Local computation should not persist the resulting value, so don't expect a put().
- blockManager.get("rdd_0_0").andReturn(None)
+ blockManager.get(RDDBlockId(0, 0)).andReturn(None)
}
whenExecuting(blockManager) {
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index d9103aebb7..7ca5f16202 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -21,7 +21,7 @@ import org.scalatest.FunSuite
import java.io.File
import org.apache.spark.rdd._
import org.apache.spark.SparkContext._
-import storage.StorageLevel
+import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
import org.apache.spark.util.Utils
class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
@@ -83,7 +83,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
}
test("BlockRDD") {
- val blockId = "id"
+ val blockId = TestBlockId("id")
val blockManager = SparkEnv.get.blockManager
blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY)
val blockRDD = new BlockRDD[String](sc, Array(blockId))
@@ -191,7 +191,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
}
test("CheckpointRDD with zero partitions") {
- val rdd = new BlockRDD[Int](sc, Array[String]())
+ val rdd = new BlockRDD[Int](sc, Array[BlockId]())
assert(rdd.partitions.size === 0)
assert(rdd.isCheckpointed === false)
rdd.checkpoint()
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 7a856d4081..480bac84f3 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -18,24 +18,14 @@
package org.apache.spark
import network.ConnectionManagerId
-import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Timeouts._
+import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
-import org.scalatest.prop.Checkers
import org.scalatest.time.{Span, Millis}
-import org.scalacheck.Arbitrary._
-import org.scalacheck.Gen
-import org.scalacheck.Prop._
-import org.eclipse.jetty.server.{Server, Request, Handler}
-
-import com.google.common.io.Files
-
-import scala.collection.mutable.ArrayBuffer
import SparkContext._
-import storage.{GetBlock, BlockManagerWorker, StorageLevel}
-import ui.JettyUtils
+import org.apache.spark.storage.{BlockManagerWorker, GetBlock, RDDBlockId, StorageLevel}
class NotSerializableClass
@@ -193,7 +183,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
// Get all the locations of the first partition and try to fetch the partitions
// from those locations.
- val blockIds = data.partitions.indices.map(index => "rdd_%d_%d".format(data.id, index)).toArray
+ val blockIds = data.partitions.indices.map(index => RDDBlockId(data.id, index)).toArray
val blockId = blockIds(0)
val blockManager = SparkEnv.get.blockManager
blockManager.master.getLocations(blockId).foreach(id => {
@@ -319,19 +309,6 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
}
}
}
-
- test("job should fail if TaskResult exceeds Akka frame size") {
- // We must use local-cluster mode since results are returned differently
- // when running under LocalScheduler:
- sc = new SparkContext("local-cluster[1,1,512]", "test")
- val akkaFrameSize =
- sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
- val rdd = sc.parallelize(Seq(1)).map{x => new Array[Byte](akkaFrameSize)}
- val exception = intercept[SparkException] {
- rdd.reduce((x, y) => x)
- }
- exception.getMessage should endWith("result exceeded Akka frame size")
- }
}
object DistributedSuite {
diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala
index 69383ddfb8..75d6493e33 100644
--- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala
@@ -40,7 +40,7 @@ object ThreadingSuiteState {
}
class ThreadingSuite extends FunSuite with LocalSparkContext {
-
+
test("accessing SparkContext form a different thread") {
sc = new SparkContext("local", "test")
val nums = sc.parallelize(1 to 10, 2)
@@ -149,4 +149,47 @@ class ThreadingSuite extends FunSuite with LocalSparkContext {
fail("One or more threads didn't see runningThreads = 4")
}
}
+
+ test("set local properties in different thread") {
+ sc = new SparkContext("local", "test")
+ val sem = new Semaphore(0)
+
+ val threads = (1 to 5).map { i =>
+ new Thread() {
+ override def run() {
+ sc.setLocalProperty("test", i.toString)
+ assert(sc.getLocalProperty("test") === i.toString)
+ sem.release()
+ }
+ }
+ }
+
+ threads.foreach(_.start())
+
+ sem.acquire(5)
+ assert(sc.getLocalProperty("test") === null)
+ }
+
+ test("set and get local properties in parent-children thread") {
+ sc = new SparkContext("local", "test")
+ sc.setLocalProperty("test", "parent")
+ val sem = new Semaphore(0)
+
+ val threads = (1 to 5).map { i =>
+ new Thread() {
+ override def run() {
+ assert(sc.getLocalProperty("test") === "parent")
+ sc.setLocalProperty("test", i.toString)
+ assert(sc.getLocalProperty("test") === i.toString)
+ sem.release()
+ }
+ }
+ }
+
+ threads.foreach(_.start())
+
+ sem.acquire(5)
+ assert(sc.getLocalProperty("test") === "parent")
+ assert(sc.getLocalProperty("Foo") === null)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
index 05f8545c7b..0b38e239f9 100644
--- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
@@ -25,7 +25,7 @@ import net.liftweb.json.JsonAST.JValue
import org.scalatest.FunSuite
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse}
-import org.apache.spark.deploy.master.{ApplicationInfo, WorkerInfo}
+import org.apache.spark.deploy.master.{ApplicationInfo, RecoveryState, WorkerInfo}
import org.apache.spark.deploy.worker.ExecutorRunner
class JsonProtocolSuite extends FunSuite {
@@ -53,7 +53,8 @@ class JsonProtocolSuite extends FunSuite {
val workers = Array[WorkerInfo](createWorkerInfo(), createWorkerInfo())
val activeApps = Array[ApplicationInfo](createAppInfo())
val completedApps = Array[ApplicationInfo]()
- val stateResponse = new MasterStateResponse("host", 8080, workers, activeApps, completedApps)
+ val stateResponse = new MasterStateResponse("host", 8080, workers, activeApps, completedApps,
+ RecoveryState.ALIVE)
val output = JsonProtocol.writeMasterState(stateResponse)
assertValidJson(output)
}
@@ -79,7 +80,7 @@ class JsonProtocolSuite extends FunSuite {
}
def createExecutorRunner() : ExecutorRunner = {
new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host",
- new File("sparkHome"), new File("workDir"))
+ new File("sparkHome"), new File("workDir"), ExecutorState.RUNNING)
}
def assertValidJson(json: JValue) {
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index c1df5e151e..6d1bc5e296 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -25,7 +25,6 @@ import org.apache.spark.SparkContext._
import org.apache.spark.rdd._
import scala.collection.parallel.mutable
import org.apache.spark._
-import org.apache.spark.rdd.CoalescedRDDPartition
class RDDSuite extends FunSuite with SharedSparkContext {
@@ -321,6 +320,44 @@ class RDDSuite extends FunSuite with SharedSparkContext {
for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
}
+ test("take") {
+ var nums = sc.makeRDD(Range(1, 1000), 1)
+ assert(nums.take(0).size === 0)
+ assert(nums.take(1) === Array(1))
+ assert(nums.take(3) === Array(1, 2, 3))
+ assert(nums.take(500) === (1 to 500).toArray)
+ assert(nums.take(501) === (1 to 501).toArray)
+ assert(nums.take(999) === (1 to 999).toArray)
+ assert(nums.take(1000) === (1 to 999).toArray)
+
+ nums = sc.makeRDD(Range(1, 1000), 2)
+ assert(nums.take(0).size === 0)
+ assert(nums.take(1) === Array(1))
+ assert(nums.take(3) === Array(1, 2, 3))
+ assert(nums.take(500) === (1 to 500).toArray)
+ assert(nums.take(501) === (1 to 501).toArray)
+ assert(nums.take(999) === (1 to 999).toArray)
+ assert(nums.take(1000) === (1 to 999).toArray)
+
+ nums = sc.makeRDD(Range(1, 1000), 100)
+ assert(nums.take(0).size === 0)
+ assert(nums.take(1) === Array(1))
+ assert(nums.take(3) === Array(1, 2, 3))
+ assert(nums.take(500) === (1 to 500).toArray)
+ assert(nums.take(501) === (1 to 501).toArray)
+ assert(nums.take(999) === (1 to 999).toArray)
+ assert(nums.take(1000) === (1 to 999).toArray)
+
+ nums = sc.makeRDD(Range(1, 1000), 1000)
+ assert(nums.take(0).size === 0)
+ assert(nums.take(1) === Array(1))
+ assert(nums.take(3) === Array(1, 2, 3))
+ assert(nums.take(500) === (1 to 500).toArray)
+ assert(nums.take(501) === (1 to 501).toArray)
+ assert(nums.take(999) === (1 to 999).toArray)
+ assert(nums.take(1000) === (1 to 999).toArray)
+ }
+
test("top with predefined ordering") {
val nums = Array.range(1, 100000)
val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 94f66c94c6..3952ee9264 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -30,11 +30,9 @@ import org.apache.spark.Partition
import org.apache.spark.TaskContext
import org.apache.spark.{Dependency, ShuffleDependency, OneToOneDependency}
import org.apache.spark.{FetchFailed, Success, TaskEndReason}
-import org.apache.spark.storage.{BlockManagerId, BlockManagerMaster}
+import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
-import org.apache.spark.scheduler.cluster.Pool
-import org.apache.spark.scheduler.cluster.SchedulingMode
-import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
+import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
/**
* Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler
@@ -77,15 +75,10 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
// stub out BlockManagerMaster.getLocations to use our cacheLocations
val blockManagerMaster = new BlockManagerMaster(null) {
- override def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
- blockIds.map { name =>
- val pieces = name.split("_")
- if (pieces(0) == "rdd") {
- val key = pieces(1).toInt -> pieces(2).toInt
- cacheLocations.getOrElse(key, Seq())
- } else {
- Seq()
- }
+ override def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = {
+ blockIds.map {
+ _.asRDDId.map(id => (id.rddId -> id.splitIndex)).flatMap(key => cacheLocations.get(key)).
+ getOrElse(Seq())
}.toSeq
}
override def removeExecutor(execId: String) {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index aac7c207cb..a549417a47 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -23,10 +23,6 @@ import scala.collection.mutable
import org.scalatest.matchers.ShouldMatchers
import org.apache.spark.SparkContext._
-/**
- *
- */
-
class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
test("local metrics") {
@@ -42,7 +38,9 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
}
val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)}
- d.count
+ d.count()
+ val WAIT_TIMEOUT_MILLIS = 10000
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
listener.stageInfos.size should be (1)
val d2 = d.map{i => w(i) -> i * 2}.setName("shuffle input 1")
@@ -52,19 +50,27 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
val d4 = d2.cogroup(d3, 64).map{case(k,(v1,v2)) => w(k) -> (v1.size, v2.size)}
d4.setName("A Cogroup")
- d4.collectAsMap
+ d4.collectAsMap()
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
listener.stageInfos.size should be (4)
- listener.stageInfos.foreach {stageInfo =>
- //small test, so some tasks might take less than 1 millisecond, but average should be greater than 1 ms
+ listener.stageInfos.foreach { stageInfo =>
+ /* small test, so some tasks might take less than 1 millisecond, but average should be greater
+ * than 0 ms. */
checkNonZeroAvg(stageInfo.taskInfos.map{_._1.duration}, stageInfo + " duration")
- checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorRunTime.toLong}, stageInfo + " executorRunTime")
- checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong}, stageInfo + " executorDeserializeTime")
+ checkNonZeroAvg(
+ stageInfo.taskInfos.map{_._2.executorRunTime.toLong},
+ stageInfo + " executorRunTime")
+ checkNonZeroAvg(
+ stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong},
+ stageInfo + " executorDeserializeTime")
if (stageInfo.stage.rdd.name == d4.name) {
- checkNonZeroAvg(stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime}, stageInfo + " fetchWaitTime")
+ checkNonZeroAvg(
+ stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime},
+ stageInfo + " fetchWaitTime")
}
- stageInfo.taskInfos.foreach{case (taskInfo, taskMetrics) =>
+ stageInfo.taskInfos.foreach { case (taskInfo, taskMetrics) =>
taskMetrics.resultSize should be > (0l)
if (isStage(stageInfo, Set(d2.name, d3.name), Set(d4.name))) {
taskMetrics.shuffleWriteMetrics should be ('defined)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala
index 1b50ce06b3..95d3553d91 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala
@@ -43,16 +43,16 @@ class FakeTaskSetManager(
stageId = initStageId
name = "TaskSet_"+stageId
override val numTasks = initNumTasks
- tasksFinished = 0
+ tasksSuccessful = 0
- override def increaseRunningTasks(taskNum: Int) {
+ def increaseRunningTasks(taskNum: Int) {
runningTasks += taskNum
if (parent != null) {
parent.increaseRunningTasks(taskNum)
}
}
- override def decreaseRunningTasks(taskNum: Int) {
+ def decreaseRunningTasks(taskNum: Int) {
runningTasks -= taskNum
if (parent != null) {
parent.decreaseRunningTasks(taskNum)
@@ -79,7 +79,7 @@ class FakeTaskSetManager(
maxLocality: TaskLocality.TaskLocality)
: Option[TaskDescription] =
{
- if (tasksFinished + runningTasks < numTasks) {
+ if (tasksSuccessful + runningTasks < numTasks) {
increaseRunningTasks(1)
return Some(new TaskDescription(0, execId, "task 0:0", 0, null))
}
@@ -92,8 +92,8 @@ class FakeTaskSetManager(
def taskFinished() {
decreaseRunningTasks(1)
- tasksFinished +=1
- if (tasksFinished == numTasks) {
+ tasksSuccessful +=1
+ if (tasksSuccessful == numTasks) {
parent.removeSchedulable(this)
}
}
@@ -114,7 +114,8 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
val taskSetQueue = rootPool.getSortedTaskSetQueue()
/* Just for Test*/
for (manager <- taskSetQueue) {
- logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks))
+ logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(
+ manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks))
}
for (taskSet <- taskSetQueue) {
taskSet.resourceOffer("execId_1", "hostname_1", 1, TaskLocality.ANY) match {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
index ff70a2cdf0..80d0c5a5e9 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
@@ -40,6 +40,7 @@ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /*
val startedTasks = new ArrayBuffer[Long]
val endedTasks = new mutable.HashMap[Long, TaskEndReason]
val finishedManagers = new ArrayBuffer[TaskSetManager]
+ val taskSetsFailed = new ArrayBuffer[String]
val executors = new mutable.HashMap[String, String] ++ liveExecutors
@@ -63,7 +64,9 @@ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /*
def executorLost(execId: String) {}
- def taskSetFailed(taskSet: TaskSet, reason: String) {}
+ def taskSetFailed(taskSet: TaskSet, reason: String) {
+ taskSetsFailed += taskSet.id
+ }
}
def removeExecutor(execId: String): Unit = executors -= execId
@@ -101,7 +104,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
assert(manager.resourceOffer("exec1", "host1", 2, PROCESS_LOCAL) === None)
// Tell it the task has finished
- manager.statusUpdate(0, TaskState.FINISHED, createTaskResult(0))
+ manager.handleSuccessfulTask(0, createTaskResult(0))
assert(sched.endedTasks(0) === Success)
assert(sched.finishedManagers.contains(manager))
}
@@ -125,14 +128,14 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None)
// Finish the first two tasks
- manager.statusUpdate(0, TaskState.FINISHED, createTaskResult(0))
- manager.statusUpdate(1, TaskState.FINISHED, createTaskResult(1))
+ manager.handleSuccessfulTask(0, createTaskResult(0))
+ manager.handleSuccessfulTask(1, createTaskResult(1))
assert(sched.endedTasks(0) === Success)
assert(sched.endedTasks(1) === Success)
assert(!sched.finishedManagers.contains(manager))
// Finish the last task
- manager.statusUpdate(2, TaskState.FINISHED, createTaskResult(2))
+ manager.handleSuccessfulTask(2, createTaskResult(2))
assert(sched.endedTasks(2) === Success)
assert(sched.finishedManagers.contains(manager))
}
@@ -253,6 +256,47 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None)
}
+ test("task result lost") {
+ sc = new SparkContext("local", "test")
+ val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
+ val taskSet = createTaskSet(1)
+ val clock = new FakeClock
+ val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
+
+ // Tell it the task has finished but the result was lost.
+ manager.handleFailedTask(0, TaskState.FINISHED, Some(TaskResultLost))
+ assert(sched.endedTasks(0) === TaskResultLost)
+
+ // Re-offer the host -- now we should get task 0 again.
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
+ }
+
+ test("repeated failures lead to task set abortion") {
+ sc = new SparkContext("local", "test")
+ val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
+ val taskSet = createTaskSet(1)
+ val clock = new FakeClock
+ val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+
+ // Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted
+ // after the last failure.
+ (0 until manager.MAX_TASK_FAILURES).foreach { index =>
+ val offerResult = manager.resourceOffer("exec1", "host1", 1, ANY)
+ assert(offerResult != None,
+ "Expect resource offer on iteration %s to return a task".format(index))
+ assert(offerResult.get.index === 0)
+ manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, Some(TaskResultLost))
+ if (index < manager.MAX_TASK_FAILURES) {
+ assert(!sched.taskSetsFailed.contains(taskSet.id))
+ } else {
+ assert(sched.taskSetsFailed.contains(taskSet.id))
+ }
+ }
+ }
+
+
/**
* Utility method to create a TaskSet, potentially setting a particular sequence of preferred
* locations for each task (given as varargs) if this sequence is not empty.
@@ -267,7 +311,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
new TaskSet(tasks, 0, 0, 0, null)
}
- def createTaskResult(id: Int): ByteBuffer = {
- ByteBuffer.wrap(Utils.serialize(new TaskResult[Int](id, mutable.Map.empty, new TaskMetrics)))
+ def createTaskResult(id: Int): DirectTaskResult[Int] = {
+ new DirectTaskResult[Int](id, mutable.Map.empty, new TaskMetrics)
}
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
new file mode 100644
index 0000000000..ee150a3107
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import java.nio.ByteBuffer
+
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.{LocalSparkContext, SparkContext, SparkEnv}
+import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
+import org.apache.spark.storage.TaskResultBlockId
+
+/**
+ * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter.
+ *
+ * Used to test the case where a BlockManager evicts the task result (or dies) before the
+ * TaskResult is retrieved.
+ */
+class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
+ extends TaskResultGetter(sparkEnv, scheduler) {
+ var removedResult = false
+
+ override def enqueueSuccessfulTask(
+ taskSetManager: ClusterTaskSetManager, tid: Long, serializedData: ByteBuffer) {
+ if (!removedResult) {
+ // Only remove the result once, since we'd like to test the case where the task eventually
+ // succeeds.
+ serializer.get().deserialize[TaskResult[_]](serializedData) match {
+ case IndirectTaskResult(blockId) =>
+ sparkEnv.blockManager.master.removeBlock(blockId)
+ case directResult: DirectTaskResult[_] =>
+ taskSetManager.abort("Internal error: expect only indirect results")
+ }
+ serializedData.rewind()
+ removedResult = true
+ }
+ super.enqueueSuccessfulTask(taskSetManager, tid, serializedData)
+ }
+}
+
+/**
+ * Tests related to handling task results (both direct and indirect).
+ */
+class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndAfterAll
+ with LocalSparkContext {
+
+ override def beforeAll {
+ // Set the Akka frame size to be as small as possible (it must be an integer, so 1 is as small
+ // as we can make it) so the tests don't take too long.
+ System.setProperty("spark.akka.frameSize", "1")
+ }
+
+ before {
+ // Use local-cluster mode because results are returned differently when running with the
+ // LocalScheduler.
+ sc = new SparkContext("local-cluster[1,1,512]", "test")
+ }
+
+ override def afterAll {
+ System.clearProperty("spark.akka.frameSize")
+ }
+
+ test("handling results smaller than Akka frame size") {
+ val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x)
+ assert(result === 2)
+ }
+
+ test("handling results larger than Akka frame size") {
+ val akkaFrameSize =
+ sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
+ val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x)
+ assert(result === 1.to(akkaFrameSize).toArray)
+
+ val RESULT_BLOCK_ID = TaskResultBlockId(0)
+ assert(sc.env.blockManager.master.getLocations(RESULT_BLOCK_ID).size === 0,
+ "Expect result to be removed from the block manager.")
+ }
+
+ test("task retried if result missing from block manager") {
+ // If this test hangs, it's probably because no resource offers were made after the task
+ // failed.
+ val scheduler: ClusterScheduler = sc.taskScheduler match {
+ case clusterScheduler: ClusterScheduler =>
+ clusterScheduler
+ case _ =>
+ assert(false, "Expect local cluster to use ClusterScheduler")
+ throw new ClassCastException
+ }
+ scheduler.taskResultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler)
+ val akkaFrameSize =
+ sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
+ val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x)
+ assert(result === 1.to(akkaFrameSize).toArray)
+
+ // Make sure two tasks were run (one failed one, and a second retried one).
+ assert(scheduler.nextTaskId.get() === 2)
+ }
+}
+
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index 0164dda0ba..c016c51171 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -103,6 +103,27 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
check(List(mutable.HashMap("one" -> 1, "two" -> 2),mutable.HashMap(1->"one",2->"two",3->"three")))
}
+ test("ranges") {
+ val ser = (new KryoSerializer).newInstance()
+ def check[T](t: T) {
+ assert(ser.deserialize[T](ser.serialize(t)) === t)
+ // Check that very long ranges don't get written one element at a time
+ assert(ser.serialize(t).limit < 100)
+ }
+ check(1 to 1000000)
+ check(1 to 1000000 by 2)
+ check(1 until 1000000)
+ check(1 until 1000000 by 2)
+ check(1L to 1000000L)
+ check(1L to 1000000L by 2L)
+ check(1L until 1000000L)
+ check(1L until 1000000L by 2L)
+ check(1.0 to 1000000.0 by 1.0)
+ check(1.0 to 1000000.0 by 2.0)
+ check(1.0 until 1000000.0 by 1.0)
+ check(1.0 until 1000000.0 by 2.0)
+ }
+
test("custom registrator") {
System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName)
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
new file mode 100644
index 0000000000..cb76275e39
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import org.scalatest.FunSuite
+
+class BlockIdSuite extends FunSuite {
+ def assertSame(id1: BlockId, id2: BlockId) {
+ assert(id1.name === id2.name)
+ assert(id1.hashCode === id2.hashCode)
+ assert(id1 === id2)
+ }
+
+ def assertDifferent(id1: BlockId, id2: BlockId) {
+ assert(id1.name != id2.name)
+ assert(id1.hashCode != id2.hashCode)
+ assert(id1 != id2)
+ }
+
+ test("test-bad-deserialization") {
+ try {
+ // Try to deserialize an invalid block id.
+ BlockId("myblock")
+ fail()
+ } catch {
+ case e: IllegalStateException => // OK
+ case _ => fail()
+ }
+ }
+
+ test("rdd") {
+ val id = RDDBlockId(1, 2)
+ assertSame(id, RDDBlockId(1, 2))
+ assertDifferent(id, RDDBlockId(1, 1))
+ assert(id.name === "rdd_1_2")
+ assert(id.asRDDId.get.rddId === 1)
+ assert(id.asRDDId.get.splitIndex === 2)
+ assert(id.isRDD)
+ assertSame(id, BlockId(id.toString))
+ }
+
+ test("shuffle") {
+ val id = ShuffleBlockId(1, 2, 3)
+ assertSame(id, ShuffleBlockId(1, 2, 3))
+ assertDifferent(id, ShuffleBlockId(3, 2, 3))
+ assert(id.name === "shuffle_1_2_3")
+ assert(id.asRDDId === None)
+ assert(id.shuffleId === 1)
+ assert(id.mapId === 2)
+ assert(id.reduceId === 3)
+ assert(id.isShuffle)
+ assertSame(id, BlockId(id.toString))
+ }
+
+ test("broadcast") {
+ val id = BroadcastBlockId(42)
+ assertSame(id, BroadcastBlockId(42))
+ assertDifferent(id, BroadcastBlockId(123))
+ assert(id.name === "broadcast_42")
+ assert(id.asRDDId === None)
+ assert(id.broadcastId === 42)
+ assert(id.isBroadcast)
+ assertSame(id, BlockId(id.toString))
+ }
+
+ test("taskresult") {
+ val id = TaskResultBlockId(60)
+ assertSame(id, TaskResultBlockId(60))
+ assertDifferent(id, TaskResultBlockId(61))
+ assert(id.name === "taskresult_60")
+ assert(id.asRDDId === None)
+ assert(id.taskId === 60)
+ assert(!id.isRDD)
+ assertSame(id, BlockId(id.toString))
+ }
+
+ test("stream") {
+ val id = StreamBlockId(1, 100)
+ assertSame(id, StreamBlockId(1, 100))
+ assertDifferent(id, StreamBlockId(2, 101))
+ assert(id.name === "input-1-100")
+ assert(id.asRDDId === None)
+ assert(id.streamId === 1)
+ assert(id.uniqueId === 100)
+ assert(!id.isBroadcast)
+ assertSame(id, BlockId(id.toString))
+ }
+
+ test("test") {
+ val id = TestBlockId("abc")
+ assertSame(id, TestBlockId("abc"))
+ assertDifferent(id, TestBlockId("ab"))
+ assert(id.name === "test_abc")
+ assert(id.asRDDId === None)
+ assert(id.id === "abc")
+ assert(!id.isShuffle)
+ assertSame(id, BlockId(id.toString))
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 038a9acb85..484a654108 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -32,7 +32,6 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark.util.{SizeEstimator, Utils, AkkaUtils, ByteBufferInputStream}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
-
class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester {
var store: BlockManager = null
var store2: BlockManager = null
@@ -46,6 +45,10 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
System.setProperty("spark.kryoserializer.buffer.mb", "1")
val serializer = new KryoSerializer
+ // Implicitly convert strings to BlockIds for test clarity.
+ implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value)
+ def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId)
+
before {
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0)
this.actorSystem = actorSystem
@@ -229,31 +232,31 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
// Putting a1, a2 and a3 in memory.
- store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY)
- store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 0), a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 1), a2, StorageLevel.MEMORY_ONLY)
store.putSingle("nonrddblock", a3, StorageLevel.MEMORY_ONLY)
master.removeRdd(0, blocking = false)
eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
- store.getSingle("rdd_0_0") should be (None)
- master.getLocations("rdd_0_0") should have size 0
+ store.getSingle(rdd(0, 0)) should be (None)
+ master.getLocations(rdd(0, 0)) should have size 0
}
eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
- store.getSingle("rdd_0_1") should be (None)
- master.getLocations("rdd_0_1") should have size 0
+ store.getSingle(rdd(0, 1)) should be (None)
+ master.getLocations(rdd(0, 1)) should have size 0
}
eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
store.getSingle("nonrddblock") should not be (None)
master.getLocations("nonrddblock") should have size (1)
}
- store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY)
- store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 0), a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 1), a2, StorageLevel.MEMORY_ONLY)
master.removeRdd(0, blocking = true)
- store.getSingle("rdd_0_0") should be (None)
- master.getLocations("rdd_0_0") should have size 0
- store.getSingle("rdd_0_1") should be (None)
- master.getLocations("rdd_0_1") should have size 0
+ store.getSingle(rdd(0, 0)) should be (None)
+ master.getLocations(rdd(0, 0)) should have size 0
+ store.getSingle(rdd(0, 1)) should be (None)
+ master.getLocations(rdd(0, 1)) should have size 0
}
test("reregistration on heart beat") {
@@ -372,41 +375,41 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
- store.putSingle("rdd_0_1", a1, StorageLevel.MEMORY_ONLY)
- store.putSingle("rdd_0_2", a2, StorageLevel.MEMORY_ONLY)
- store.putSingle("rdd_0_3", a3, StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 1), a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 2), a2, StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 3), a3, StorageLevel.MEMORY_ONLY)
// Even though we accessed rdd_0_3 last, it should not have replaced partitions 1 and 2
// from the same RDD
- assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store")
- assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store")
- assert(store.getSingle("rdd_0_1") != None, "rdd_0_1 was not in store")
+ assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store")
+ assert(store.getSingle(rdd(0, 2)) != None, "rdd_0_2 was not in store")
+ assert(store.getSingle(rdd(0, 1)) != None, "rdd_0_1 was not in store")
// Check that rdd_0_3 doesn't replace them even after further accesses
- assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store")
- assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store")
- assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store")
+ assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store")
+ assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store")
+ assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store")
}
test("in-memory LRU for partitions of multiple RDDs") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
- store.putSingle("rdd_0_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
- store.putSingle("rdd_0_2", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
- store.putSingle("rdd_1_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 2), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(1, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
// At this point rdd_1_1 should've replaced rdd_0_1
- assert(store.memoryStore.contains("rdd_1_1"), "rdd_1_1 was not in store")
- assert(!store.memoryStore.contains("rdd_0_1"), "rdd_0_1 was in store")
- assert(store.memoryStore.contains("rdd_0_2"), "rdd_0_2 was not in store")
+ assert(store.memoryStore.contains(rdd(1, 1)), "rdd_1_1 was not in store")
+ assert(!store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was in store")
+ assert(store.memoryStore.contains(rdd(0, 2)), "rdd_0_2 was not in store")
// Do a get() on rdd_0_2 so that it is the most recently used item
- assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store")
+ assert(store.getSingle(rdd(0, 2)) != None, "rdd_0_2 was not in store")
// Put in more partitions from RDD 0; they should replace rdd_1_1
- store.putSingle("rdd_0_3", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
- store.putSingle("rdd_0_4", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 3), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 4), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
// Now rdd_1_1 should be dropped to add rdd_0_3, but then rdd_0_2 should *not* be dropped
// when we try to add rdd_0_4.
- assert(!store.memoryStore.contains("rdd_1_1"), "rdd_1_1 was in store")
- assert(!store.memoryStore.contains("rdd_0_1"), "rdd_0_1 was in store")
- assert(!store.memoryStore.contains("rdd_0_4"), "rdd_0_4 was in store")
- assert(store.memoryStore.contains("rdd_0_2"), "rdd_0_2 was not in store")
- assert(store.memoryStore.contains("rdd_0_3"), "rdd_0_3 was not in store")
+ assert(!store.memoryStore.contains(rdd(1, 1)), "rdd_1_1 was in store")
+ assert(!store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was in store")
+ assert(!store.memoryStore.contains(rdd(0, 4)), "rdd_0_4 was in store")
+ assert(store.memoryStore.contains(rdd(0, 2)), "rdd_0_2 was not in store")
+ assert(store.memoryStore.contains(rdd(0, 3)), "rdd_0_3 was not in store")
}
test("on-disk storage") {
@@ -590,43 +593,46 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
try {
System.setProperty("spark.shuffle.compress", "true")
store = new BlockManager("exec1", actorSystem, master, serializer, 2000)
- store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
- assert(store.memoryStore.getSize("shuffle_0_0_0") <= 100, "shuffle_0_0_0 was not compressed")
+ store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100,
+ "shuffle_0_0_0 was not compressed")
store.stop()
store = null
System.setProperty("spark.shuffle.compress", "false")
store = new BlockManager("exec2", actorSystem, master, serializer, 2000)
- store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
- assert(store.memoryStore.getSize("shuffle_0_0_0") >= 1000, "shuffle_0_0_0 was compressed")
+ store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 1000,
+ "shuffle_0_0_0 was compressed")
store.stop()
store = null
System.setProperty("spark.broadcast.compress", "true")
store = new BlockManager("exec3", actorSystem, master, serializer, 2000)
- store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
- assert(store.memoryStore.getSize("broadcast_0") <= 100, "broadcast_0 was not compressed")
+ store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 100,
+ "broadcast_0 was not compressed")
store.stop()
store = null
System.setProperty("spark.broadcast.compress", "false")
store = new BlockManager("exec4", actorSystem, master, serializer, 2000)
- store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
- assert(store.memoryStore.getSize("broadcast_0") >= 1000, "broadcast_0 was compressed")
+ store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 1000, "broadcast_0 was compressed")
store.stop()
store = null
System.setProperty("spark.rdd.compress", "true")
store = new BlockManager("exec5", actorSystem, master, serializer, 2000)
- store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
- assert(store.memoryStore.getSize("rdd_0_0") <= 100, "rdd_0_0 was not compressed")
+ store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize(rdd(0, 0)) <= 100, "rdd_0_0 was not compressed")
store.stop()
store = null
System.setProperty("spark.rdd.compress", "false")
store = new BlockManager("exec6", actorSystem, master, serializer, 2000)
- store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
- assert(store.memoryStore.getSize("rdd_0_0") >= 1000, "rdd_0_0 was compressed")
+ store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize(rdd(0, 0)) >= 1000, "rdd_0_0 was compressed")
store.stop()
store = null
diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
index 07c9f2382b..8f0ec6683b 100644
--- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
@@ -26,7 +26,12 @@ class UISuite extends FunSuite {
test("jetty port increases under contention") {
val startPort = 4040
val server = new Server(startPort)
- server.start()
+
+ Try { server.start() } match {
+ case Success(s) =>
+ case Failure(e) =>
+ // Either case server port is busy hence setup for test complete
+ }
val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("localhost", startPort, Seq())
val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("localhost", startPort, Seq())
diff --git a/core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala
new file mode 100644
index 0000000000..7177919a58
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import scala.collection.mutable.HashSet
+
+import org.scalatest.FunSuite
+
+class AppendOnlyMapSuite extends FunSuite {
+ test("initialization") {
+ val goodMap1 = new AppendOnlyMap[Int, Int](1)
+ assert(goodMap1.size === 0)
+ val goodMap2 = new AppendOnlyMap[Int, Int](255)
+ assert(goodMap2.size === 0)
+ val goodMap3 = new AppendOnlyMap[Int, Int](256)
+ assert(goodMap3.size === 0)
+ intercept[IllegalArgumentException] {
+ new AppendOnlyMap[Int, Int](1 << 30) // Invalid map size: bigger than 2^29
+ }
+ intercept[IllegalArgumentException] {
+ new AppendOnlyMap[Int, Int](-1)
+ }
+ intercept[IllegalArgumentException] {
+ new AppendOnlyMap[Int, Int](0)
+ }
+ }
+
+ test("object keys and values") {
+ val map = new AppendOnlyMap[String, String]()
+ for (i <- 1 to 100) {
+ map("" + i) = "" + i
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ assert(map("" + i) === "" + i)
+ }
+ assert(map("0") === null)
+ assert(map("101") === null)
+ assert(map(null) === null)
+ val set = new HashSet[(String, String)]
+ for ((k, v) <- map) { // Test the foreach method
+ set += ((k, v))
+ }
+ assert(set === (1 to 100).map(_.toString).map(x => (x, x)).toSet)
+ }
+
+ test("primitive keys and values") {
+ val map = new AppendOnlyMap[Int, Int]()
+ for (i <- 1 to 100) {
+ map(i) = i
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ assert(map(i) === i)
+ }
+ assert(map(0) === null)
+ assert(map(101) === null)
+ val set = new HashSet[(Int, Int)]
+ for ((k, v) <- map) { // Test the foreach method
+ set += ((k, v))
+ }
+ assert(set === (1 to 100).map(x => (x, x)).toSet)
+ }
+
+ test("null keys") {
+ val map = new AppendOnlyMap[String, String]()
+ for (i <- 1 to 100) {
+ map("" + i) = "" + i
+ }
+ assert(map.size === 100)
+ assert(map(null) === null)
+ map(null) = "hello"
+ assert(map.size === 101)
+ assert(map(null) === "hello")
+ }
+
+ test("null values") {
+ val map = new AppendOnlyMap[String, String]()
+ for (i <- 1 to 100) {
+ map("" + i) = null
+ }
+ assert(map.size === 100)
+ assert(map("1") === null)
+ assert(map(null) === null)
+ assert(map.size === 100)
+ map(null) = null
+ assert(map.size === 101)
+ assert(map(null) === null)
+ }
+
+ test("changeValue") {
+ val map = new AppendOnlyMap[String, String]()
+ for (i <- 1 to 100) {
+ map("" + i) = "" + i
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ val res = map.changeValue("" + i, (hadValue, oldValue) => {
+ assert(hadValue === true)
+ assert(oldValue === "" + i)
+ oldValue + "!"
+ })
+ assert(res === i + "!")
+ }
+ // Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a
+ // bug where changeValue would return the wrong result when the map grew on that insert
+ for (i <- 101 to 400) {
+ val res = map.changeValue("" + i, (hadValue, oldValue) => {
+ assert(hadValue === false)
+ i + "!"
+ })
+ assert(res === i + "!")
+ }
+ assert(map.size === 400)
+ assert(map(null) === null)
+ map.changeValue(null, (hadValue, oldValue) => {
+ assert(hadValue === false)
+ "null!"
+ })
+ assert(map.size === 401)
+ map.changeValue(null, (hadValue, oldValue) => {
+ assert(hadValue === true)
+ assert(oldValue === "null!")
+ "null!!"
+ })
+ assert(map.size === 401)
+ }
+
+ test("inserting in capacity-1 map") {
+ val map = new AppendOnlyMap[String, String](1)
+ for (i <- 1 to 100) {
+ map("" + i) = "" + i
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ assert(map("" + i) === "" + i)
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index e2859caf58..4684c8c972 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.util
import com.google.common.base.Charsets
import com.google.common.io.Files
import java.io.{ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream, File}
+import java.nio.{ByteBuffer, ByteOrder}
import org.scalatest.FunSuite
import org.apache.commons.io.FileUtils
import scala.util.Random
@@ -135,5 +136,15 @@ class UtilsSuite extends FunSuite {
FileUtils.deleteDirectory(tmpDir2)
}
+
+ test("deserialize long value") {
+ val testval : Long = 9730889947L
+ val bbuf = ByteBuffer.allocate(8)
+ assert(bbuf.hasArray)
+ bbuf.order(ByteOrder.BIG_ENDIAN)
+ bbuf.putLong(testval)
+ assert(bbuf.array.length === 8)
+ assert(Utils.deserializeLongValue(bbuf.array) === testval)
+ }
}
diff --git a/docker/README.md b/docker/README.md
new file mode 100644
index 0000000000..bf59e77d11
--- /dev/null
+++ b/docker/README.md
@@ -0,0 +1,5 @@
+Spark docker files
+===========
+
+Drawn from Matt Massie's docker files (https://github.com/massie/dockerfiles),
+as well as some updates from Andre Schumacher (https://github.com/AndreSchumacher/docker). \ No newline at end of file
diff --git a/docker/build b/docker/build
new file mode 100755
index 0000000000..253a2fc8dd
--- /dev/null
+++ b/docker/build
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+docker images > /dev/null || { echo Please install docker in non-sudo mode. ; exit; }
+
+./spark-test/build \ No newline at end of file
diff --git a/docker/spark-test/README.md b/docker/spark-test/README.md
new file mode 100644
index 0000000000..addea277aa
--- /dev/null
+++ b/docker/spark-test/README.md
@@ -0,0 +1,10 @@
+Spark Docker files usable for testing and development purposes.
+
+These images are intended to be run like so:
+docker run -v $SPARK_HOME:/opt/spark spark-test-master
+docker run -v $SPARK_HOME:/opt/spark spark-test-worker <master_ip>
+
+Using this configuration, the containers will have their Spark directories
+mounted to your actual SPARK_HOME, allowing you to modify and recompile
+your Spark source and have them immediately usable in the docker images
+(without rebuilding them).
diff --git a/docker/spark-test/base/Dockerfile b/docker/spark-test/base/Dockerfile
new file mode 100644
index 0000000000..60962776dd
--- /dev/null
+++ b/docker/spark-test/base/Dockerfile
@@ -0,0 +1,38 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+FROM ubuntu:precise
+
+RUN echo "deb http://archive.ubuntu.com/ubuntu precise main universe" > /etc/apt/sources.list
+
+# Upgrade package index
+RUN apt-get update
+
+# install a few other useful packages plus Open Jdk 7
+RUN apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server
+
+ENV SCALA_VERSION 2.9.3
+ENV SPARK_VERSION 0.8.1
+ENV CDH_VERSION cdh4
+ENV SCALA_HOME /opt/scala-$SCALA_VERSION
+ENV SPARK_HOME /opt/spark
+ENV PATH $SPARK_HOME:$SCALA_HOME/bin:$PATH
+
+# Install Scala
+ADD http://www.scala-lang.org/files/archive/scala-$SCALA_VERSION.tgz /
+RUN (cd / && gunzip < scala-$SCALA_VERSION.tgz)|(cd /opt && tar -xvf -)
+RUN rm /scala-$SCALA_VERSION.tgz
diff --git a/docker/spark-test/build b/docker/spark-test/build
new file mode 100755
index 0000000000..6f9e197433
--- /dev/null
+++ b/docker/spark-test/build
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+docker build -t spark-test-base spark-test/base/
+docker build -t spark-test-master spark-test/master/
+docker build -t spark-test-worker spark-test/worker/
diff --git a/docker/spark-test/master/Dockerfile b/docker/spark-test/master/Dockerfile
new file mode 100644
index 0000000000..f729534ab6
--- /dev/null
+++ b/docker/spark-test/master/Dockerfile
@@ -0,0 +1,21 @@
+# Spark Master
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+FROM spark-test-base
+ADD default_cmd /root/
+CMD ["/root/default_cmd"]
diff --git a/docker/spark-test/master/default_cmd b/docker/spark-test/master/default_cmd
new file mode 100755
index 0000000000..a5b1303c2e
--- /dev/null
+++ b/docker/spark-test/master/default_cmd
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+IP=$(ip -o -4 addr list eth0 | perl -n -e 'if (m{inet\s([\d\.]+)\/\d+\s}xms) { print $1 }')
+echo "CONTAINER_IP=$IP"
+/opt/spark/spark-class org.apache.spark.deploy.master.Master -i $IP
diff --git a/docker/spark-test/worker/Dockerfile b/docker/spark-test/worker/Dockerfile
new file mode 100644
index 0000000000..890febe7b6
--- /dev/null
+++ b/docker/spark-test/worker/Dockerfile
@@ -0,0 +1,22 @@
+# Spark Worker
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+FROM spark-test-base
+ENV SPARK_WORKER_PORT 8888
+ADD default_cmd /root/
+ENTRYPOINT ["/root/default_cmd"]
diff --git a/docker/spark-test/worker/default_cmd b/docker/spark-test/worker/default_cmd
new file mode 100755
index 0000000000..ab6336f70c
--- /dev/null
+++ b/docker/spark-test/worker/default_cmd
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+IP=$(ip -o -4 addr list eth0 | perl -n -e 'if (m{inet\s([\d\.]+)\/\d+\s}xms) { print $1 }')
+echo "CONTAINER_IP=$IP"
+/opt/spark/spark-class org.apache.spark.deploy.worker.Worker $1
diff --git a/docs/_config.yml b/docs/_config.yml
index b061764b36..48ecb8d0c9 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -3,8 +3,8 @@ markdown: kramdown
# These allow the documentation to be updated with nerw releases
# of Spark, Scala, and Mesos.
-SPARK_VERSION: 0.8.0-SNAPSHOT
-SPARK_VERSION_SHORT: 0.8.0
+SPARK_VERSION: 0.9.0-incubating-SNAPSHOT
+SPARK_VERSION_SHORT: 0.9.0-SNAPSHOT
SCALA_VERSION: 2.9.3
MESOS_VERSION: 0.13.0
SPARK_ISSUE_TRACKER_URL: https://spark-project.atlassian.net
diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html
index 238ad26de0..0c1d657cde 100755
--- a/docs/_layouts/global.html
+++ b/docs/_layouts/global.html
@@ -6,7 +6,7 @@
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge,chrome=1">
- <title>{{ page.title }} - Spark {{site.SPARK_VERSION}} Documentation</title>
+ <title>{{ page.title }} - Spark {{site.SPARK_VERSION_SHORT}} Documentation</title>
<meta name="description" content="">
<link rel="stylesheet" href="css/bootstrap.min.css">
@@ -109,7 +109,7 @@
</ul>
</li>
</ul>
- <!--<p class="navbar-text pull-right"><span class="version-text">v{{site.SPARK_VERSION}}</span></p>-->
+ <!--<p class="navbar-text pull-right"><span class="version-text">v{{site.SPARK_VERSION_SHORT}}</span></p>-->
</div>
</div>
</div>
diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md
index f991d86c8d..c1ff9c417c 100644
--- a/docs/mllib-guide.md
+++ b/docs/mllib-guide.md
@@ -144,10 +144,9 @@ Available algorithms for clustering:
# Collaborative Filtering
-[Collaborative
-filtering](http://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering)
+[Collaborative filtering](http://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering)
is commonly used for recommender systems. These techniques aim to fill in the
-missing entries of a user-product association matrix. MLlib currently supports
+missing entries of a user-item association matrix. MLlib currently supports
model-based collaborative filtering, in which users and products are described
by a small set of latent factors that can be used to predict missing entries.
In particular, we implement the [alternating least squares
@@ -158,7 +157,24 @@ following parameters:
* *numBlocks* is the number of blacks used to parallelize computation (set to -1 to auto-configure).
* *rank* is the number of latent factors in our model.
* *iterations* is the number of iterations to run.
-* *lambda* specifies the regularization parameter in ALS.
+* *lambda* specifies the regularization parameter in ALS.
+* *implicitPrefs* specifies whether to use the *explicit feedback* ALS variant or one adapted for *implicit feedback* data
+* *alpha* is a parameter applicable to the implicit feedback variant of ALS that governs the *baseline* confidence in preference observations
+
+## Explicit vs Implicit Feedback
+
+The standard approach to matrix factorization based collaborative filtering treats
+the entries in the user-item matrix as *explicit* preferences given by the user to the item.
+
+It is common in many real-world use cases to only have access to *implicit feedback*
+(e.g. views, clicks, purchases, likes, shares etc.). The approach used in MLlib to deal with
+such data is taken from
+[Collaborative Filtering for Implicit Feedback Datasets](http://research.yahoo.com/pub/2433).
+Essentially instead of trying to model the matrix of ratings directly, this approach treats the data as
+a combination of binary preferences and *confidence values*. The ratings are then related
+to the level of confidence in observed user preferences, rather than explicit ratings given to items.
+The model then tries to find latent factors that can be used to predict the expected preference of a user
+for an item.
Available algorithms for collaborative filtering:
diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md
index f67a1cc49c..6c2336ad0c 100644
--- a/docs/python-programming-guide.md
+++ b/docs/python-programming-guide.md
@@ -16,7 +16,7 @@ This guide will show how to use the Spark features described there in Python.
There are a few key differences between the Python and Scala APIs:
* Python is dynamically typed, so RDDs can hold objects of multiple types.
-* PySpark does not yet support a few API calls, such as `lookup`, `sort`, and non-text input files, though these will be added in future releases.
+* PySpark does not yet support a few API calls, such as `lookup` and non-text input files, though these will be added in future releases.
In PySpark, RDDs support the same methods as their Scala counterparts but take Python functions and return Python collection types.
Short functions can be passed to RDD methods using Python's [`lambda`](http://www.diveintopython.net/power_of_introspection/lambda_functions.html) syntax:
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index c611db0af4..2898af0bed 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -34,6 +34,8 @@ Environment variables:
System Properties:
* 'spark.yarn.applicationMaster.waitTries', property to set the number of times the ApplicationMaster waits for the the spark master and then also the number of tries it waits for the Spark Context to be intialized. Default is 10.
+* 'spark.yarn.submit.file.replication', the HDFS replication level for the files uploaded into HDFS for the application. These include things like the spark jar, the app jar, and any distributed cache files/archives.
+* 'spark.yarn.preserve.staging.files', set to true to preserve the staged files(spark jar, app jar, distributed cache files) at the end of the job rather then delete them.
# Launching Spark on YARN
@@ -50,7 +52,11 @@ The command to launch the YARN Client is as follows:
--master-memory <MEMORY_FOR_MASTER> \
--worker-memory <MEMORY_PER_WORKER> \
--worker-cores <CORES_PER_WORKER> \
- --queue <queue_name>
+ --name <application_name> \
+ --queue <queue_name> \
+ --addJars <any_local_files_used_in_SparkContext.addJar> \
+ --files <files_for_distributed_cache> \
+ --archives <archives_for_distributed_cache>
For example:
@@ -83,3 +89,5 @@ The above starts a YARN Client programs which periodically polls the Application
- When your application instantiates a Spark context it must use a special "yarn-standalone" master url. This starts the scheduler without forcing it to connect to a cluster. A good way to handle this is to pass "yarn-standalone" as an argument to your program, as shown in the example above.
- We do not requesting container resources based on the number of cores. Thus the numbers of cores given via command line arguments cannot be guaranteed.
- The local directories used for spark will be the local directories configured for YARN (Hadoop Yarn config yarn.nodemanager.local-dirs). If the user specifies spark.local.dir, it will be ignored.
+- The --files and --archives options support specifying file names with the # similar to Hadoop. For example you can specify: --files localtest.txt#appSees.txt and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name appSees.txt and your application should use the name as appSees.txt to reference it when running on YARN.
+- The --addJars option allows the SparkContext.addJar function to work if you are using it with local files. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files.
diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md
index 81cdbefd0c..17066ef0dd 100644
--- a/docs/spark-standalone.md
+++ b/docs/spark-standalone.md
@@ -3,6 +3,9 @@ layout: global
title: Spark Standalone Mode
---
+* This will become a table of contents (this text will be scraped).
+{:toc}
+
In addition to running on the Mesos or YARN cluster managers, Spark also provides a simple standalone deploy mode. You can launch a standalone cluster either manually, by starting a master and workers by hand, or use our provided [launch scripts](#cluster-launch-scripts). It is also possible to run these daemons on a single machine for testing.
# Installing Spark Standalone to a Cluster
@@ -169,3 +172,75 @@ In addition, detailed log output for each job is also written to the work direct
You can run Spark alongside your existing Hadoop cluster by just launching it as a separate service on the same machines. To access Hadoop data from Spark, just use a hdfs:// URL (typically `hdfs://<namenode>:9000/path`, but you can find the right URL on your Hadoop Namenode's web UI). Alternatively, you can set up a separate cluster for Spark, and still have it access HDFS over the network; this will be slower than disk-local access, but may not be a concern if you are still running in the same local area network (e.g. you place a few Spark machines on each rack that you have Hadoop on).
+
+# High Availability
+
+By default, standalone scheduling clusters are resilient to Worker failures (insofar as Spark itself is resilient to losing work by moving it to other workers). However, the scheduler uses a Master to make scheduling decisions, and this (by default) creates a single point of failure: if the Master crashes, no new applications can be created. In order to circumvent this, we have two high availability schemes, detailed below.
+
+## Standby Masters with ZooKeeper
+
+**Overview**
+
+Utilizing ZooKeeper to provide leader election and some state storage, you can launch multiple Masters in your cluster connected to the same ZooKeeper instance. One will be elected "leader" and the others will remain in standby mode. If the current leader dies, another Master will be elected, recover the old Master's state, and then resume scheduling. The entire recovery process (from the time the the first leader goes down) should take between 1 and 2 minutes. Note that this delay only affects scheduling _new_ applications -- applications that were already running during Master failover are unaffected.
+
+Learn more about getting started with ZooKeeper [here](http://zookeeper.apache.org/doc/trunk/zookeeperStarted.html).
+
+**Configuration**
+
+In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env using this configuration:
+
+<table class="table">
+ <tr><th style="width:21%">System property</th><th>Meaning</th></tr>
+ <tr>
+ <td><code>spark.deploy.recoveryMode</code></td>
+ <td>Set to ZOOKEEPER to enable standby Master recovery mode (default: NONE).</td>
+ </tr>
+ <tr>
+ <td><code>spark.deploy.zookeeper.url</code></td>
+ <td>The ZooKeeper cluster url (e.g., 192.168.1.100:2181,192.168.1.101:2181).</td>
+ </tr>
+ <tr>
+ <td><code>spark.deploy.zookeeper.dir</code></td>
+ <td>The directory in ZooKeeper to store recovery state (default: /spark).</td>
+ </tr>
+</table>
+
+Possible gotcha: If you have multiple Masters in your cluster but fail to correctly configure the Masters to use ZooKeeper, the Masters will fail to discover each other and think they're all leaders. This will not lead to a healthy cluster state (as all Masters will schedule independently).
+
+**Details**
+
+After you have a ZooKeeper cluster set up, enabling high availability is straightforward. Simply start multiple Master processes on different nodes with the same ZooKeeper configuration (ZooKeeper URL and directory). Masters can be added and removed at any time.
+
+In order to schedule new applications or add Workers to the cluster, they need to know the IP address of the current leader. This can be accomplished by simply passing in a list of Masters where you used to pass in a single one. For example, you might start your SparkContext pointing to ``spark://host1:port1,host2:port2``. This would cause your SparkContext to try registering with both Masters -- if ``host1`` goes down, this configuration would still be correct as we'd find the new leader, ``host2``.
+
+There's an important distinction to be made between "registering with a Master" and normal operation. When starting up, an application or Worker needs to be able to find and register with the current lead Master. Once it successfully registers, though, it is "in the system" (i.e., stored in ZooKeeper). If failover occurs, the new leader will contact all previously registered applications and Workers to inform them of the change in leadership, so they need not even have known of the existence of the new Master at startup.
+
+Due to this property, new Masters can be created at any time, and the only thing you need to worry about is that _new_ applications and Workers can find it to register with in case it becomes the leader. Once registered, you're taken care of.
+
+## Single-Node Recovery with Local File System
+
+**Overview**
+
+ZooKeeper is the best way to go for production-level high availability, but if you just want to be able to restart the Master if it goes down, FILESYSTEM mode can take care of it. When applications and Workers register, they have enough state written to the provided directory so that they can be recovered upon a restart of the Master process.
+
+**Configuration**
+
+In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env using this configuration:
+
+<table class="table">
+ <tr><th style="width:21%">System property</th><th>Meaning</th></tr>
+ <tr>
+ <td><code>spark.deploy.recoveryMode</code></td>
+ <td>Set to FILESYSTEM to enable single-node recovery mode (default: NONE).</td>
+ </tr>
+ <tr>
+ <td><code>spark.deploy.recoveryDirectory</code></td>
+ <td>The directory in which Spark will store recovery state, accessible from the Master's perspective.</td>
+ </tr>
+</table>
+
+**Details**
+
+* This solution can be used in tandem with a process monitor/manager like [monit](http://mmonit.com/monit/), or just to enable manual recovery via restart.
+* While filesystem recovery seems straightforwardly better than not doing any recovery at all, this mode may be suboptimal for certain development or experimental purposes. In particular, killing a master via stop-master.sh does not clean up its recovery state, so whenever you start a new Master, it will enter recovery mode. This could increase the startup time by up to 1 minute if it needs to wait for all previously-registered Workers/clients to timeout.
+* While it's not officially supported, you could mount an NFS directory as the recovery directory. If the original Master node dies completely, you could then start a Master on a different node, which would correctly recover all previously registered Workers/applications (equivalent to ZooKeeper recovery). Future applications will have to be able to find the new Master, however, in order to register.
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index c7df172024..835b257238 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -122,12 +122,12 @@ Spark Streaming features windowed computations, which allow you to apply transfo
<table class="table">
<tr><th style="width:30%">Transformation</th><th>Meaning</th></tr>
<tr>
- <td> <b>window</b>(<i>windowDuration</i>, </i>slideDuration</i>) </td>
+ <td> <b>window</b>(<i>windowDuration</i>, <i>slideDuration</i>) </td>
<td> Return a new DStream which is computed based on windowed batches of the source DStream. <i>windowDuration</i> is the width of the window and <i>slideTime</i> is the frequency during which the window is calculated. Both times must be multiples of the batch interval.
</td>
</tr>
<tr>
- <td> <b>countByWindow</b>(<i>windowDuration</i>, </i>slideDuration</i>) </td>
+ <td> <b>countByWindow</b>(<i>windowDuration</i>, <i>slideDuration</i>) </td>
<td> Return a sliding count of elements in the stream. <i>windowDuration</i> and <i>slideDuration</i> are exactly as defined in <code>window()</code>.
</td>
</tr>
@@ -161,7 +161,6 @@ Spark Streaming features windowed computations, which allow you to apply transfo
<i>windowDuration</i> and <i>slideDuration</i> are exactly as defined in <code>window()</code>.
</td>
</tr>
-
</table>
A complete list of DStream operations is available in the API documentation of [DStream](api/streaming/index.html#org.apache.spark.streaming.DStream) and [PairDStreamFunctions](api/streaming/index.html#org.apache.spark.streaming.PairDStreamFunctions).
diff --git a/docs/tuning.md b/docs/tuning.md
index 28d88a2659..f491ae9b95 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -175,7 +175,7 @@ To further tune garbage collection, we first need to understand some basic infor
* Java Heap space is divided in to two regions Young and Old. The Young generation is meant to hold short-lived objects
while the Old generation is intended for objects with longer lifetimes.
-* The Young generation is further divided into three regions [Eden, Survivor1, Survivor2].
+* The Young generation is further divided into three regions \[Eden, Survivor1, Survivor2\].
* A simplified description of the garbage collection procedure: When Eden is full, a minor GC is run on Eden and objects
that are alive from Eden and Survivor1 are copied to Survivor2. The Survivor regions are swapped. If an object is old
diff --git a/ec2/README b/ec2/README
index 0add81312c..433da37b4c 100644
--- a/ec2/README
+++ b/ec2/README
@@ -1,4 +1,4 @@
This folder contains a script, spark-ec2, for launching Spark clusters on
Amazon EC2. Usage instructions are available online at:
-http://spark-project.org/docs/latest/ec2-scripts.html
+http://spark.incubator.apache.org/docs/latest/ec2-scripts.html
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 419d0fe13f..65868b76b9 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -23,6 +23,7 @@ from __future__ import with_statement
import logging
import os
+import pipes
import random
import shutil
import subprocess
@@ -36,6 +37,9 @@ import boto
from boto.ec2.blockdevicemapping import BlockDeviceMapping, EBSBlockDeviceType
from boto import ec2
+class UsageError(Exception):
+ pass
+
# A URL prefix from which to fetch AMI information
AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/v2/ami-list"
@@ -66,7 +70,7 @@ def parse_args():
"slaves across multiple (an additional $0.01/Gb for bandwidth" +
"between zones applies)")
parser.add_option("-a", "--ami", help="Amazon Machine Image ID to use")
- parser.add_option("-v", "--spark-version", default="0.7.3",
+ parser.add_option("-v", "--spark-version", default="0.8.0",
help="Version of Spark to use: 'X.Y.Z' or a specific git hash")
parser.add_option("--spark-git-repo",
default="https://github.com/mesos/spark",
@@ -103,11 +107,7 @@ def parse_args():
parser.print_help()
sys.exit(1)
(action, cluster_name) = args
- if opts.identity_file == None and action in ['launch', 'login', 'start']:
- print >> stderr, ("ERROR: The -i or --identity-file argument is " +
- "required for " + action)
- sys.exit(1)
-
+
# Boto config check
# http://boto.cloudhackers.com/en/latest/boto_config_tut.html
home_dir = os.getenv('HOME')
@@ -155,7 +155,7 @@ def is_active(instance):
# Return correct versions of Spark and Shark, given the supplied Spark version
def get_spark_shark_version(opts):
- spark_shark_map = {"0.7.3": "0.7.0"}
+ spark_shark_map = {"0.7.3": "0.7.1", "0.8.0": "0.8.0"}
version = opts.spark_version.replace("v", "")
if version not in spark_shark_map:
print >> stderr, "Don't know about Spark version: %s" % version
@@ -364,12 +364,12 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
slave_nodes = []
for res in reservations:
active = [i for i in res.instances if is_active(i)]
- if len(active) > 0:
- group_names = [g.name for g in res.groups]
+ for inst in active:
+ group_names = [g.name for g in inst.groups]
if group_names == [cluster_name + "-master"]:
- master_nodes += res.instances
+ master_nodes.append(inst)
elif group_names == [cluster_name + "-slaves"]:
- slave_nodes += res.instances
+ slave_nodes.append(inst)
if any((master_nodes, slave_nodes)):
print ("Found %d master(s), %d slaves" %
(len(master_nodes), len(slave_nodes)))
@@ -390,10 +390,18 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
master = master_nodes[0].public_dns_name
if deploy_ssh_key:
- print "Copying SSH key %s to master..." % opts.identity_file
- ssh(master, opts, 'mkdir -p ~/.ssh')
- scp(master, opts, opts.identity_file, '~/.ssh/id_rsa')
- ssh(master, opts, 'chmod 600 ~/.ssh/id_rsa')
+ print "Generating cluster's SSH key on master..."
+ key_setup = """
+ [ -f ~/.ssh/id_rsa ] ||
+ (ssh-keygen -q -t rsa -N '' -f ~/.ssh/id_rsa &&
+ cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys)
+ """
+ ssh(master, opts, key_setup)
+ dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh'])
+ print "Transferring cluster's SSH key to slaves..."
+ for slave in slave_nodes:
+ print slave.public_dns_name
+ ssh_write(slave.public_dns_name, opts, ['tar', 'x'], dot_ssh_tar)
modules = ['spark', 'shark', 'ephemeral-hdfs', 'persistent-hdfs',
'mapreduce', 'spark-standalone']
@@ -535,18 +543,33 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules):
dest.write(text)
dest.close()
# rsync the whole directory over to the master machine
- command = (("rsync -rv -e 'ssh -o StrictHostKeyChecking=no -i %s' " +
- "'%s/' '%s@%s:/'") % (opts.identity_file, tmp_dir, opts.user, active_master))
- subprocess.check_call(command, shell=True)
+ command = [
+ 'rsync', '-rv',
+ '-e', stringify_command(ssh_command(opts)),
+ "%s/" % tmp_dir,
+ "%s@%s:/" % (opts.user, active_master)
+ ]
+ subprocess.check_call(command)
# Remove the temp directory we created above
shutil.rmtree(tmp_dir)
-# Copy a file to a given host through scp, throwing an exception if scp fails
-def scp(host, opts, local_file, dest_file):
- subprocess.check_call(
- "scp -q -o StrictHostKeyChecking=no -i %s '%s' '%s@%s:%s'" %
- (opts.identity_file, local_file, opts.user, host, dest_file), shell=True)
+def stringify_command(parts):
+ if isinstance(parts, str):
+ return parts
+ else:
+ return ' '.join(map(pipes.quote, parts))
+
+
+def ssh_args(opts):
+ parts = ['-o', 'StrictHostKeyChecking=no']
+ if opts.identity_file is not None:
+ parts += ['-i', opts.identity_file]
+ return parts
+
+
+def ssh_command(opts):
+ return ['ssh'] + ssh_args(opts)
# Run a command on a host through ssh, retrying up to two times
@@ -556,18 +579,42 @@ def ssh(host, opts, command):
while True:
try:
return subprocess.check_call(
- "ssh -t -o StrictHostKeyChecking=no -i %s %s@%s '%s'" %
- (opts.identity_file, opts.user, host, command), shell=True)
+ ssh_command(opts) + ['-t', '%s@%s' % (opts.user, host), stringify_command(command)])
except subprocess.CalledProcessError as e:
if (tries > 2):
- raise e
- print "Couldn't connect to host {0}, waiting 30 seconds".format(e)
+ # If this was an ssh failure, provide the user with hints.
+ if e.returncode == 255:
+ raise UsageError("Failed to SSH to remote host {0}.\nPlease check that you have provided the correct --identity-file and --key-pair parameters and try again.".format(host))
+ else:
+ raise e
+ print >> stderr, "Error executing remote command, retrying after 30 seconds: {0}".format(e)
time.sleep(30)
tries = tries + 1
+def ssh_read(host, opts, command):
+ return subprocess.check_output(
+ ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)])
+def ssh_write(host, opts, command, input):
+ tries = 0
+ while True:
+ proc = subprocess.Popen(
+ ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)],
+ stdin=subprocess.PIPE)
+ proc.stdin.write(input)
+ proc.stdin.close()
+ status = proc.wait()
+ if status == 0:
+ break
+ elif (tries > 2):
+ raise RuntimeError("ssh_write failed with error %s" % proc.returncode)
+ else:
+ print >> stderr, "Error {0} while executing remote command, retrying after 30 seconds".format(status)
+ time.sleep(30)
+ tries = tries + 1
+
# Gets a list of zones to launch instances in
def get_zones(conn, opts):
@@ -586,7 +633,7 @@ def get_partition(total, num_partitions, current_partitions):
return num_slaves_this_zone
-def main():
+def real_main():
(opts, action, cluster_name) = parse_args()
try:
conn = ec2.connect_to_region(opts.region)
@@ -669,11 +716,11 @@ def main():
conn, opts, cluster_name)
master = master_nodes[0].public_dns_name
print "Logging into master " + master + "..."
- proxy_opt = ""
+ proxy_opt = []
if opts.proxy_port != None:
- proxy_opt = "-D " + opts.proxy_port
- subprocess.check_call("ssh -o StrictHostKeyChecking=no -i %s %s %s@%s" %
- (opts.identity_file, proxy_opt, opts.user, master), shell=True)
+ proxy_opt = ['-D', opts.proxy_port]
+ subprocess.check_call(
+ ssh_command(opts) + proxy_opt + ['-t', "%s@%s" % (opts.user, master)])
elif action == "get-master":
(master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
@@ -715,6 +762,13 @@ def main():
sys.exit(1)
+def main():
+ try:
+ real_main()
+ except UsageError, e:
+ print >> stderr, "\nError:\n", e
+
+
if __name__ == "__main__":
logging.basicConfig()
main()
diff --git a/examples/pom.xml b/examples/pom.xml
index e48f5b50ab..b8c020a321 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -21,38 +21,46 @@
<parent>
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent</artifactId>
- <version>0.8.0-SNAPSHOT</version>
+ <version>0.9.0-incubating-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-examples</artifactId>
+ <artifactId>spark-examples_2.9.3</artifactId>
<packaging>jar</packaging>
<name>Spark Project Examples</name>
<url>http://spark.incubator.apache.org/</url>
+ <repositories>
+ <!-- A repository in the local filesystem for the Kafka JAR, which we modified for Scala 2.9 -->
+ <repository>
+ <id>lib</id>
+ <url>file://${project.basedir}/lib</url>
+ </repository>
+ </repositories>
+
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-core</artifactId>
+ <artifactId>spark-core_2.9.3</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-streaming</artifactId>
+ <artifactId>spark-streaming_2.9.3</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-mllib</artifactId>
+ <artifactId>spark-mllib_2.9.3</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-bagel</artifactId>
+ <artifactId>spark-bagel_2.9.3</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
@@ -72,6 +80,12 @@
</exclusions>
</dependency>
<dependency>
+ <groupId>org.apache.kafka</groupId>
+ <artifactId>kafka</artifactId>
+ <version>0.7.2-spark</version> <!-- Comes from our in-project repository -->
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-server</artifactId>
</dependency>
@@ -82,12 +96,12 @@
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
- <artifactId>scalatest_${scala.version}</artifactId>
+ <artifactId>scalatest_2.9.3</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalacheck</groupId>
- <artifactId>scalacheck_${scala.version}</artifactId>
+ <artifactId>scalacheck_2.9.3</artifactId>
<scope>test</scope>
</dependency>
<dependency>
@@ -161,7 +175,7 @@
</goals>
<configuration>
<transformers>
- <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
+ <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer" />
<transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
<resource>reference.conf</resource>
</transformer>
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala
index f7bf75b4e5..bc2db39c12 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala
@@ -21,8 +21,6 @@ import java.util.Random
import org.apache.spark.SparkContext
import org.apache.spark.util.Vector
import org.apache.spark.SparkContext._
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
/**
* K-means clustering.
diff --git a/make-distribution.sh b/make-distribution.sh
index bffb19843c..32bbdb90a5 100755
--- a/make-distribution.sh
+++ b/make-distribution.sh
@@ -95,7 +95,7 @@ cp $FWDIR/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/jars/"
# Copy other things
mkdir "$DISTDIR"/conf
-cp "$FWDIR/conf/*.template" "$DISTDIR"/conf
+cp "$FWDIR"/conf/*.template "$DISTDIR"/conf
cp -r "$FWDIR/bin" "$DISTDIR"
cp -r "$FWDIR/python" "$DISTDIR"
cp "$FWDIR/spark-class" "$DISTDIR"
diff --git a/mllib/pom.xml b/mllib/pom.xml
index 966caf6835..f472082ad1 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -21,12 +21,12 @@
<parent>
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent</artifactId>
- <version>0.8.0-SNAPSHOT</version>
+ <version>0.9.0-incubating-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-mllib</artifactId>
+ <artifactId>spark-mllib_2.9.3</artifactId>
<packaging>jar</packaging>
<name>Spark Project ML Library</name>
<url>http://spark.incubator.apache.org/</url>
@@ -34,7 +34,7 @@
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-core</artifactId>
+ <artifactId>spark-core_2.9.3</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
@@ -48,12 +48,12 @@
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
- <artifactId>scalatest_${scala.version}</artifactId>
+ <artifactId>scalatest_2.9.3</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalacheck</groupId>
- <artifactId>scalacheck_${scala.version}</artifactId>
+ <artifactId>scalacheck_2.9.3</artifactId>
<scope>test</scope>
</dependency>
<dependency>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index be002d02bc..36853acab5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -21,7 +21,8 @@ import scala.collection.mutable.{ArrayBuffer, BitSet}
import scala.util.Random
import scala.util.Sorting
-import org.apache.spark.{HashPartitioner, Partitioner, SparkContext}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.{Logging, HashPartitioner, Partitioner, SparkContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.KryoRegistrator
@@ -61,6 +62,12 @@ case class Rating(val user: Int, val product: Int, val rating: Double)
/**
* Alternating Least Squares matrix factorization.
*
+ * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices,
+ * `X` and `Y`, i.e. `Xt * Y = R`. Typically these approximations are called 'factor' matrices.
+ * The general approach is iterative. During each iteration, one of the factor matrices is held
+ * constant, while the other is solved for using least squares. The newly-solved factor matrix is
+ * then held constant while solving for the other factor matrix.
+ *
* This is a blocked implementation of the ALS factorization algorithm that groups the two sets
* of factors (referred to as "users" and "products") into blocks and reduces communication by only
* sending one copy of each user vector to each product block on each iteration, and only for the
@@ -70,11 +77,21 @@ case class Rating(val user: Int, val product: Int, val rating: Double)
* vectors it receives from each user block it will depend on). This allows us to send only an
* array of feature vectors between each user block and product block, and have the product block
* find the users' ratings and update the products based on these messages.
+ *
+ * For implicit preference data, the algorithm used is based on
+ * "Collaborative Filtering for Implicit Feedback Datasets", available at
+ * [[http://research.yahoo.com/pub/2433]], adapted for the blocked approach used here.
+ *
+ * Essentially instead of finding the low-rank approximations to the rating matrix `R`,
+ * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if r > 0
+ * and 0 if r = 0. The ratings then act as 'confidence' values related to strength of indicated user
+ * preferences rather than explicit ratings given to items.
*/
-class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var lambda: Double)
- extends Serializable
+class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var lambda: Double,
+ var implicitPrefs: Boolean, var alpha: Double)
+ extends Serializable with Logging
{
- def this() = this(-1, 10, 10, 0.01)
+ def this() = this(-1, 10, 10, 0.01, false, 1.0)
/**
* Set the number of blocks to parallelize the computation into; pass -1 for an auto-configured
@@ -103,6 +120,16 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
this
}
+ def setImplicitPrefs(implicitPrefs: Boolean): ALS = {
+ this.implicitPrefs = implicitPrefs
+ this
+ }
+
+ def setAlpha(alpha: Double): ALS = {
+ this.alpha = alpha
+ this
+ }
+
/**
* Run ALS with the configured parameters on an input RDD of (user, product, rating) triples.
* Returns a MatrixFactorizationModel with feature vectors for each user and product.
@@ -147,19 +174,24 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
}
}
- for (iter <- 0 until iterations) {
+ for (iter <- 1 to iterations) {
// perform ALS update
- products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda)
- users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda)
+ logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations))
+ // YtY / XtX is an Option[DoubleMatrix] and is only required for the implicit feedback model
+ val YtY = computeYtY(users)
+ val YtYb = ratings.context.broadcast(YtY)
+ products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda,
+ alpha, YtYb)
+ logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations))
+ val XtX = computeYtY(products)
+ val XtXb = ratings.context.broadcast(XtX)
+ users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda,
+ alpha, XtXb)
}
// Flatten and cache the two final RDDs to un-block them
- val usersOut = users.join(userOutLinks).flatMap { case (b, (factors, outLinkBlock)) =>
- for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i))
- }
- val productsOut = products.join(productOutLinks).flatMap { case (b, (factors, outLinkBlock)) =>
- for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i))
- }
+ val usersOut = unblockFactors(users, userOutLinks)
+ val productsOut = unblockFactors(products, productOutLinks)
usersOut.persist()
productsOut.persist()
@@ -168,6 +200,40 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
}
/**
+ * Computes the (`rank x rank`) matrix `YtY`, where `Y` is the (`nui x rank`) matrix of factors
+ * for each user (or product), in a distributed fashion. Here `reduceByKeyLocally` is used as
+ * the driver program requires `YtY` to broadcast it to the slaves
+ * @param factors the (block-distributed) user or product factor vectors
+ * @return Option[YtY] - whose value is only used in the implicit preference model
+ */
+ def computeYtY(factors: RDD[(Int, Array[Array[Double]])]) = {
+ if (implicitPrefs) {
+ Option(
+ factors.flatMapValues{ case factorArray =>
+ factorArray.map{ vector =>
+ val x = new DoubleMatrix(vector)
+ x.mmul(x.transpose())
+ }
+ }.reduceByKeyLocally((a, b) => a.addi(b))
+ .values
+ .reduce((a, b) => a.addi(b))
+ )
+ } else {
+ None
+ }
+ }
+
+ /**
+ * Flatten out blocked user or product factors into an RDD of (id, factor vector) pairs
+ */
+ def unblockFactors(blockedFactors: RDD[(Int, Array[Array[Double]])],
+ outLinks: RDD[(Int, OutLinkBlock)]) = {
+ blockedFactors.join(outLinks).flatMap{ case (b, (factors, outLinkBlock)) =>
+ for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i))
+ }
+ }
+
+ /**
* Make the out-links table for a block of the users (or products) dataset given the list of
* (user, product, rating) values for the users in that block (or the opposite for products).
*/
@@ -251,7 +317,9 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
userInLinks: RDD[(Int, InLinkBlock)],
partitioner: Partitioner,
rank: Int,
- lambda: Double)
+ lambda: Double,
+ alpha: Double,
+ YtY: Broadcast[Option[DoubleMatrix]])
: RDD[(Int, Array[Array[Double]])] =
{
val numBlocks = products.partitions.size
@@ -265,7 +333,9 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
toSend.zipWithIndex.map{ case (buf, idx) => (idx, (bid, buf.toArray)) }
}.groupByKey(partitioner)
.join(userInLinks)
- .mapValues{ case (messages, inLinkBlock) => updateBlock(messages, inLinkBlock, rank, lambda) }
+ .mapValues{ case (messages, inLinkBlock) =>
+ updateBlock(messages, inLinkBlock, rank, lambda, alpha, YtY)
+ }
}
/**
@@ -273,7 +343,7 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
* it received from each product and its InLinkBlock.
*/
def updateBlock(messages: Seq[(Int, Array[Array[Double]])], inLinkBlock: InLinkBlock,
- rank: Int, lambda: Double)
+ rank: Int, lambda: Double, alpha: Double, YtY: Broadcast[Option[DoubleMatrix]])
: Array[Array[Double]] =
{
// Sort the incoming block factor messages by block ID and make them an array
@@ -298,8 +368,14 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
fillXtX(x, tempXtX)
val (us, rs) = inLinkBlock.ratingsForBlock(productBlock)(p)
for (i <- 0 until us.length) {
- userXtX(us(i)).addi(tempXtX)
- SimpleBlas.axpy(rs(i), x, userXy(us(i)))
+ implicitPrefs match {
+ case false =>
+ userXtX(us(i)).addi(tempXtX)
+ SimpleBlas.axpy(rs(i), x, userXy(us(i)))
+ case true =>
+ userXtX(us(i)).addi(tempXtX.mul(alpha * rs(i)))
+ SimpleBlas.axpy(1 + alpha * rs(i), x, userXy(us(i)))
+ }
}
}
}
@@ -311,7 +387,10 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
// Add regularization
(0 until rank).foreach(i => fullXtX.data(i*rank + i) += lambda)
// Solve the resulting matrix, which is symmetric and positive-definite
- Solve.solvePositive(fullXtX, userXy(index)).data
+ implicitPrefs match {
+ case false => Solve.solvePositive(fullXtX, userXy(index)).data
+ case true => Solve.solvePositive(fullXtX.add(YtY.value.get), userXy(index)).data
+ }
}
}
@@ -381,7 +460,7 @@ object ALS {
blocks: Int)
: MatrixFactorizationModel =
{
- new ALS(blocks, rank, iterations, lambda).run(ratings)
+ new ALS(blocks, rank, iterations, lambda, false, 1.0).run(ratings)
}
/**
@@ -419,6 +498,68 @@ object ALS {
train(ratings, rank, iterations, 0.01, -1)
}
+ /**
+ * Train a matrix factorization model given an RDD of 'implicit preferences' given by users
+ * to some products, in the form of (userID, productID, preference) pairs. We approximate the
+ * ratings matrix as the product of two lower-rank matrices of a given rank (number of features).
+ * To solve for these features, we run a given number of iterations of ALS. This is done using
+ * a level of parallelism given by `blocks`.
+ *
+ * @param ratings RDD of (userID, productID, rating) pairs
+ * @param rank number of features to use
+ * @param iterations number of iterations of ALS (recommended: 10-20)
+ * @param lambda regularization factor (recommended: 0.01)
+ * @param blocks level of parallelism to split computation into
+ * @param alpha confidence parameter (only applies when immplicitPrefs = true)
+ */
+ def trainImplicit(
+ ratings: RDD[Rating],
+ rank: Int,
+ iterations: Int,
+ lambda: Double,
+ blocks: Int,
+ alpha: Double)
+ : MatrixFactorizationModel =
+ {
+ new ALS(blocks, rank, iterations, lambda, true, alpha).run(ratings)
+ }
+
+ /**
+ * Train a matrix factorization model given an RDD of 'implicit preferences' given by users to
+ * some products, in the form of (userID, productID, preference) pairs. We approximate the
+ * ratings matrix as the product of two lower-rank matrices of a given rank (number of features).
+ * To solve for these features, we run a given number of iterations of ALS. The level of
+ * parallelism is determined automatically based on the number of partitions in `ratings`.
+ *
+ * @param ratings RDD of (userID, productID, rating) pairs
+ * @param rank number of features to use
+ * @param iterations number of iterations of ALS (recommended: 10-20)
+ * @param lambda regularization factor (recommended: 0.01)
+ */
+ def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, alpha: Double)
+ : MatrixFactorizationModel =
+ {
+ trainImplicit(ratings, rank, iterations, lambda, -1, alpha)
+ }
+
+ /**
+ * Train a matrix factorization model given an RDD of 'implicit preferences' ratings given by
+ * users to some products, in the form of (userID, productID, rating) pairs. We approximate the
+ * ratings matrix as the product of two lower-rank matrices of a given rank (number of features).
+ * To solve for these features, we run a given number of iterations of ALS. The level of
+ * parallelism is determined automatically based on the number of partitions in `ratings`.
+ * Model parameters `alpha` and `lambda` are set to reasonable default values
+ *
+ * @param ratings RDD of (userID, productID, rating) pairs
+ * @param rank number of features to use
+ * @param iterations number of iterations of ALS (recommended: 10-20)
+ */
+ def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int)
+ : MatrixFactorizationModel =
+ {
+ trainImplicit(ratings, rank, iterations, 0.01, -1, 1.0)
+ }
+
private class ALSRegistrator extends KryoRegistrator {
override def registerClasses(kryo: Kryo) {
kryo.register(classOf[Rating])
@@ -426,29 +567,37 @@ object ALS {
}
def main(args: Array[String]) {
- if (args.length != 5 && args.length != 6) {
- println("Usage: ALS <master> <ratings_file> <rank> <iterations> <output_dir> [<blocks>]")
+ if (args.length < 5 || args.length > 9) {
+ println("Usage: ALS <master> <ratings_file> <rank> <iterations> <output_dir> " +
+ "[<lambda>] [<implicitPrefs>] [<alpha>] [<blocks>]")
System.exit(1)
}
val (master, ratingsFile, rank, iters, outputDir) =
(args(0), args(1), args(2).toInt, args(3).toInt, args(4))
- val blocks = if (args.length == 6) args(5).toInt else -1
+ val lambda = if (args.length >= 6) args(5).toDouble else 0.01
+ val implicitPrefs = if (args.length >= 7) args(6).toBoolean else false
+ val alpha = if (args.length >= 8) args(7).toDouble else 1
+ val blocks = if (args.length == 9) args(8).toInt else -1
+
System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
System.setProperty("spark.kryo.registrator", classOf[ALSRegistrator].getName)
System.setProperty("spark.kryo.referenceTracking", "false")
System.setProperty("spark.kryoserializer.buffer.mb", "8")
System.setProperty("spark.locality.wait", "10000")
+
val sc = new SparkContext(master, "ALS")
val ratings = sc.textFile(ratingsFile).map { line =>
val fields = line.split(',')
Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble)
}
- val model = ALS.train(ratings, rank, iters, 0.01, blocks)
+ val model = new ALS(rank = rank, iterations = iters, lambda = lambda,
+ numBlocks = blocks, implicitPrefs = implicitPrefs, alpha = alpha).run(ratings)
+
model.userFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") }
.saveAsTextFile(outputDir + "/userFeatures")
model.productFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") }
.saveAsTextFile(outputDir + "/productFeatures")
println("Final user/product features written to " + outputDir)
- System.exit(0)
+ sc.stop()
}
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
index 3323f6cee2..eafee060cd 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.recommendation;
import java.io.Serializable;
import java.util.List;
+import java.lang.Math;
import scala.Tuple2;
@@ -48,7 +49,7 @@ public class JavaALSSuite implements Serializable {
}
void validatePrediction(MatrixFactorizationModel model, int users, int products, int features,
- DoubleMatrix trueRatings, double matchThreshold) {
+ DoubleMatrix trueRatings, double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) {
DoubleMatrix predictedU = new DoubleMatrix(users, features);
List<scala.Tuple2<Object, double[]>> userFeatures = model.userFeatures().toJavaRDD().collect();
for (int i = 0; i < features; ++i) {
@@ -68,12 +69,32 @@ public class JavaALSSuite implements Serializable {
DoubleMatrix predictedRatings = predictedU.mmul(predictedP.transpose());
- for (int u = 0; u < users; ++u) {
- for (int p = 0; p < products; ++p) {
- double prediction = predictedRatings.get(u, p);
- double correct = trueRatings.get(u, p);
- Assert.assertTrue(Math.abs(prediction - correct) < matchThreshold);
+ if (!implicitPrefs) {
+ for (int u = 0; u < users; ++u) {
+ for (int p = 0; p < products; ++p) {
+ double prediction = predictedRatings.get(u, p);
+ double correct = trueRatings.get(u, p);
+ Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
+ prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold);
+ }
}
+ } else {
+ // For implicit prefs we use the confidence-weighted RMSE to test (ref Mahout's implicit ALS tests)
+ double sqErr = 0.0;
+ double denom = 0.0;
+ for (int u = 0; u < users; ++u) {
+ for (int p = 0; p < products; ++p) {
+ double prediction = predictedRatings.get(u, p);
+ double truePref = truePrefs.get(u, p);
+ double confidence = 1.0 + /* alpha = */ 1.0 * trueRatings.get(u, p);
+ double err = confidence * (truePref - prediction) * (truePref - prediction);
+ sqErr += err;
+ denom += 1.0;
+ }
+ }
+ double rmse = Math.sqrt(sqErr / denom);
+ Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f",
+ rmse, matchThreshold), Math.abs(rmse) < matchThreshold);
}
}
@@ -81,30 +102,62 @@ public class JavaALSSuite implements Serializable {
public void runALSUsingStaticMethods() {
int features = 1;
int iterations = 15;
- int users = 10;
- int products = 10;
- scala.Tuple2<List<Rating>, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
- users, products, features, 0.7);
+ int users = 50;
+ int products = 100;
+ scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ users, products, features, 0.7, false);
JavaRDD<Rating> data = sc.parallelize(testData._1());
MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
- validatePrediction(model, users, products, features, testData._2(), 0.3);
+ validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3());
}
@Test
public void runALSUsingConstructor() {
int features = 2;
int iterations = 15;
- int users = 20;
- int products = 30;
- scala.Tuple2<List<Rating>, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
- users, products, features, 0.7);
+ int users = 100;
+ int products = 200;
+ scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ users, products, features, 0.7, false);
JavaRDD<Rating> data = sc.parallelize(testData._1());
MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations)
.run(data.rdd());
- validatePrediction(model, users, products, features, testData._2(), 0.3);
+ validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3());
+ }
+
+ @Test
+ public void runImplicitALSUsingStaticMethods() {
+ int features = 1;
+ int iterations = 15;
+ int users = 80;
+ int products = 160;
+ scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ users, products, features, 0.7, true);
+
+ JavaRDD<Rating> data = sc.parallelize(testData._1());
+ MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
+ validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
+ }
+
+ @Test
+ public void runImplicitALSUsingConstructor() {
+ int features = 2;
+ int iterations = 15;
+ int users = 100;
+ int products = 200;
+ scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+ users, products, features, 0.7, true);
+
+ JavaRDD<Rating> data = sc.parallelize(testData._1());
+
+ MatrixFactorizationModel model = new ALS().setRank(features)
+ .setIterations(iterations)
+ .setImplicitPrefs(true)
+ .run(data.rdd());
+ validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index 347ef238f4..fafc5ec5f2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -34,16 +34,19 @@ object ALSSuite {
users: Int,
products: Int,
features: Int,
- samplingRate: Double): (java.util.List[Rating], DoubleMatrix) = {
- val (sampledRatings, trueRatings) = generateRatings(users, products, features, samplingRate)
- (seqAsJavaList(sampledRatings), trueRatings)
+ samplingRate: Double,
+ implicitPrefs: Boolean): (java.util.List[Rating], DoubleMatrix, DoubleMatrix) = {
+ val (sampledRatings, trueRatings, truePrefs) =
+ generateRatings(users, products, features, samplingRate, implicitPrefs)
+ (seqAsJavaList(sampledRatings), trueRatings, truePrefs)
}
def generateRatings(
users: Int,
products: Int,
features: Int,
- samplingRate: Double): (Seq[Rating], DoubleMatrix) = {
+ samplingRate: Double,
+ implicitPrefs: Boolean = false): (Seq[Rating], DoubleMatrix, DoubleMatrix) = {
val rand = new Random(42)
// Create a random matrix with uniform values from -1 to 1
@@ -52,14 +55,20 @@ object ALSSuite {
val userMatrix = randomMatrix(users, features)
val productMatrix = randomMatrix(features, products)
- val trueRatings = userMatrix.mmul(productMatrix)
+ val (trueRatings, truePrefs) = implicitPrefs match {
+ case true =>
+ val raw = new DoubleMatrix(users, products, Array.fill(users * products)(rand.nextInt(10).toDouble): _*)
+ val prefs = new DoubleMatrix(users, products, raw.data.map(v => if (v > 0) 1.0 else 0.0): _*)
+ (raw, prefs)
+ case false => (userMatrix.mmul(productMatrix), null)
+ }
val sampledRatings = {
for (u <- 0 until users; p <- 0 until products if rand.nextDouble() < samplingRate)
yield Rating(u, p, trueRatings.get(u, p))
}
- (sampledRatings, trueRatings)
+ (sampledRatings, trueRatings, truePrefs)
}
}
@@ -78,11 +87,19 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
}
test("rank-1 matrices") {
- testALS(10, 20, 1, 15, 0.7, 0.3)
+ testALS(50, 100, 1, 15, 0.7, 0.3)
}
test("rank-2 matrices") {
- testALS(20, 30, 2, 15, 0.7, 0.3)
+ testALS(100, 200, 2, 15, 0.7, 0.3)
+ }
+
+ test("rank-1 matrices implicit") {
+ testALS(80, 160, 1, 15, 0.7, 0.4, true)
+ }
+
+ test("rank-2 matrices implicit") {
+ testALS(100, 200, 2, 15, 0.7, 0.4, true)
}
/**
@@ -96,11 +113,14 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
* @param matchThreshold max difference allowed to consider a predicted rating correct
*/
def testALS(users: Int, products: Int, features: Int, iterations: Int,
- samplingRate: Double, matchThreshold: Double)
+ samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false)
{
- val (sampledRatings, trueRatings) = ALSSuite.generateRatings(users, products,
- features, samplingRate)
- val model = ALS.train(sc.parallelize(sampledRatings), features, iterations)
+ val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products,
+ features, samplingRate, implicitPrefs)
+ val model = implicitPrefs match {
+ case false => ALS.train(sc.parallelize(sampledRatings), features, iterations)
+ case true => ALS.trainImplicit(sc.parallelize(sampledRatings), features, iterations)
+ }
val predictedU = new DoubleMatrix(users, features)
for ((u, vec) <- model.userFeatures.collect(); i <- 0 until features) {
@@ -112,12 +132,31 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
}
val predictedRatings = predictedU.mmul(predictedP.transpose)
- for (u <- 0 until users; p <- 0 until products) {
- val prediction = predictedRatings.get(u, p)
- val correct = trueRatings.get(u, p)
- if (math.abs(prediction - correct) > matchThreshold) {
- fail("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format(
- u, p, correct, prediction, trueRatings, predictedRatings, predictedU, predictedP))
+ if (!implicitPrefs) {
+ for (u <- 0 until users; p <- 0 until products) {
+ val prediction = predictedRatings.get(u, p)
+ val correct = trueRatings.get(u, p)
+ if (math.abs(prediction - correct) > matchThreshold) {
+ fail("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format(
+ u, p, correct, prediction, trueRatings, predictedRatings, predictedU, predictedP))
+ }
+ }
+ } else {
+ // For implicit prefs we use the confidence-weighted RMSE to test (ref Mahout's tests)
+ var sqErr = 0.0
+ var denom = 0.0
+ for (u <- 0 until users; p <- 0 until products) {
+ val prediction = predictedRatings.get(u, p)
+ val truePref = truePrefs.get(u, p)
+ val confidence = 1 + 1.0 * trueRatings.get(u, p)
+ val err = confidence * (truePref - prediction) * (truePref - prediction)
+ sqErr += err
+ denom += 1
+ }
+ val rmse = math.sqrt(sqErr / denom)
+ if (math.abs(rmse) > matchThreshold) {
+ fail("Model failed to predict RMSE: %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format(
+ rmse, truePrefs, predictedRatings, predictedU, predictedP))
}
}
}
diff --git a/pom.xml b/pom.xml
index d42bf7166f..a665bde894 100644
--- a/pom.xml
+++ b/pom.xml
@@ -25,7 +25,7 @@
</parent>
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent</artifactId>
- <version>0.8.0-SNAPSHOT</version>
+ <version>0.9.0-incubating-SNAPSHOT</version>
<packaging>pom</packaging>
<name>Spark Project Parent POM</name>
<url>http://spark.incubator.apache.org/</url>
@@ -40,6 +40,7 @@
<connection>scm:git:git@github.com:apache/incubator-spark.git</connection>
<developerConnection>scm:git:https://git-wip-us.apache.org/repos/asf/incubator-spark.git</developerConnection>
<url>scm:git:git@github.com:apache/incubator-spark.git</url>
+ <tag>HEAD</tag>
</scm>
<developers>
<developer>
@@ -323,7 +324,7 @@
<dependency>
<groupId>org.scalatest</groupId>
- <artifactId>scalatest_${scala.version}</artifactId>
+ <artifactId>scalatest_2.9.3</artifactId>
<version>1.9.1</version>
<scope>test</scope>
</dependency>
@@ -335,7 +336,7 @@
</dependency>
<dependency>
<groupId>org.scalacheck</groupId>
- <artifactId>scalacheck_${scala.version}</artifactId>
+ <artifactId>scalacheck_2.9.3</artifactId>
<version>1.10.0</version>
<scope>test</scope>
</dependency>
@@ -346,6 +347,17 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>org.apache.zookeeper</groupId>
+ <artifactId>zookeeper</artifactId>
+ <version>3.4.5</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.jboss.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-client</artifactId>
<version>${hadoop.version}</version>
@@ -558,7 +570,6 @@
<useZincServer>true</useZincServer>
<args>
<arg>-unchecked</arg>
- <arg>-optimise</arg>
<arg>-deprecation</arg>
</args>
<jvmArgs>
@@ -605,7 +616,7 @@
<junitxml>.</junitxml>
<filereports>${project.build.directory}/SparkTestSuite.txt</filereports>
<argLine>-Xms64m -Xmx3g</argLine>
- <stderr/>
+ <stderr />
</configuration>
<executions>
<execution>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 7dc6c58401..fee9d1c6b9 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -81,9 +81,9 @@ object SparkBuild extends Build {
def sharedSettings = Defaults.defaultSettings ++ Seq(
organization := "org.apache.spark",
- version := "0.8.0-SNAPSHOT",
+ version := "0.9.0-incubating-SNAPSHOT",
scalaVersion := "2.9.3",
- scalacOptions := Seq("-unchecked", "-optimize", "-deprecation",
+ scalacOptions := Seq("-Xmax-classfile-name", "120", "-unchecked", "-deprecation",
"-target:" + SCALAC_JVM_VERSION),
javacOptions := Seq("-target", JAVAC_JVM_VERSION, "-source", JAVAC_JVM_VERSION),
unmanagedJars in Compile <<= baseDirectory map { base => (base / "lib" ** "*.jar").classpath },
@@ -99,6 +99,9 @@ object SparkBuild extends Build {
// Only allow one test at a time, even across projects, since they run in the same JVM
concurrentRestrictions in Global += Tags.limit(Tags.Test, 1),
+ // also check the local Maven repository ~/.m2
+ resolvers ++= Seq(Resolver.file("Local Maven Repo", file(Path.userHome + "/.m2/repository"))),
+
// Shared between both core and streaming.
resolvers ++= Seq("Akka Repository" at "http://repo.akka.io/releases/"),
@@ -155,6 +158,7 @@ object SparkBuild extends Build {
*/
+
libraryDependencies ++= Seq(
"org.eclipse.jetty" % "jetty-server" % "7.6.8.v20121106",
"org.scalatest" %% "scalatest" % "1.9.1" % "test",
@@ -177,6 +181,7 @@ object SparkBuild extends Build {
val slf4jVersion = "1.7.2"
+ val excludeCglib = ExclusionRule(organization = "org.sonatype.sisu.inject")
val excludeJackson = ExclusionRule(organization = "org.codehaus.jackson")
val excludeNetty = ExclusionRule(organization = "org.jboss.netty")
val excludeAsm = ExclusionRule(organization = "asm")
@@ -209,10 +214,11 @@ object SparkBuild extends Build {
"org.apache.mesos" % "mesos" % "0.13.0",
"io.netty" % "netty-all" % "4.0.0.Beta2",
"org.apache.derby" % "derby" % "10.4.2.0" % "test",
- "org.apache.hadoop" % "hadoop-client" % hadoopVersion excludeAll(excludeJackson, excludeNetty, excludeAsm),
+ "org.apache.hadoop" % "hadoop-client" % hadoopVersion excludeAll(excludeJackson, excludeNetty, excludeAsm, excludeCglib),
"net.java.dev.jets3t" % "jets3t" % "0.7.1",
"org.apache.avro" % "avro" % "1.7.4",
"org.apache.avro" % "avro-ipc" % "1.7.4" excludeAll(excludeNetty),
+ "org.apache.zookeeper" % "zookeeper" % "3.4.5" excludeAll(excludeNetty),
"com.codahale.metrics" % "metrics-core" % "3.0.0",
"com.codahale.metrics" % "metrics-jvm" % "3.0.0",
"com.codahale.metrics" % "metrics-json" % "3.0.0",
@@ -247,6 +253,7 @@ object SparkBuild extends Build {
exclude("log4j","log4j")
exclude("org.apache.cassandra.deps", "avro")
excludeAll(excludeSnappy)
+ excludeAll(excludeCglib)
)
) ++ assemblySettings ++ extraAssemblySettings
@@ -293,10 +300,10 @@ object SparkBuild extends Build {
def yarnEnabledSettings = Seq(
libraryDependencies ++= Seq(
// Exclude rule required for all ?
- "org.apache.hadoop" % "hadoop-client" % hadoopVersion excludeAll(excludeJackson, excludeNetty, excludeAsm),
- "org.apache.hadoop" % "hadoop-yarn-api" % hadoopVersion excludeAll(excludeJackson, excludeNetty, excludeAsm),
- "org.apache.hadoop" % "hadoop-yarn-common" % hadoopVersion excludeAll(excludeJackson, excludeNetty, excludeAsm),
- "org.apache.hadoop" % "hadoop-yarn-client" % hadoopVersion excludeAll(excludeJackson, excludeNetty, excludeAsm)
+ "org.apache.hadoop" % "hadoop-client" % hadoopVersion excludeAll(excludeJackson, excludeNetty, excludeAsm, excludeCglib),
+ "org.apache.hadoop" % "hadoop-yarn-api" % hadoopVersion excludeAll(excludeJackson, excludeNetty, excludeAsm, excludeCglib),
+ "org.apache.hadoop" % "hadoop-yarn-common" % hadoopVersion excludeAll(excludeJackson, excludeNetty, excludeAsm, excludeCglib),
+ "org.apache.hadoop" % "hadoop-yarn-client" % hadoopVersion excludeAll(excludeJackson, excludeNetty, excludeAsm, excludeCglib)
)
)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 58e1849cad..7019fb8bee 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -29,7 +29,7 @@ from threading import Thread
from pyspark import cloudpickle
from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
- read_from_pickle_file
+ read_from_pickle_file, pack_long
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
@@ -117,8 +117,6 @@ class RDD(object):
else:
return None
- # TODO persist(self, storageLevel)
-
def map(self, f, preservesPartitioning=False):
"""
Return a new RDD containing the distinct elements in this RDD.
@@ -227,7 +225,7 @@ class RDD(object):
total = num
samples = self.sample(withReplacement, fraction, seed).collect()
-
+
# If the first sample didn't turn out large enough, keep trying to take samples;
# this shouldn't happen often because we use a big multiplier for their initial size.
# See: scala/spark/RDD.scala
@@ -263,7 +261,55 @@ class RDD(object):
raise TypeError
return self.union(other)
- # TODO: sort
+ def sortByKey(self, ascending=True, numPartitions=None, keyfunc = lambda x: x):
+ """
+ Sorts this RDD, which is assumed to consist of (key, value) pairs.
+
+ >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
+ >>> sc.parallelize(tmp).sortByKey(True, 2).collect()
+ [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
+ >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)]
+ >>> tmp2.extend([('whose', 6), ('fleece', 7), ('was', 8), ('white', 9)])
+ >>> sc.parallelize(tmp2).sortByKey(True, 3, keyfunc=lambda k: k.lower()).collect()
+ [('a', 3), ('fleece', 7), ('had', 2), ('lamb', 5), ('little', 4), ('Mary', 1), ('was', 8), ('white', 9), ('whose', 6)]
+ """
+ if numPartitions is None:
+ numPartitions = self.ctx.defaultParallelism
+
+ bounds = list()
+
+ # first compute the boundary of each part via sampling: we want to partition
+ # the key-space into bins such that the bins have roughly the same
+ # number of (key, value) pairs falling into them
+ if numPartitions > 1:
+ rddSize = self.count()
+ maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner
+ fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
+
+ samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
+ samples = sorted(samples, reverse=(not ascending), key=keyfunc)
+
+ # we have numPartitions many parts but one of the them has
+ # an implicit boundary
+ for i in range(0, numPartitions - 1):
+ index = (len(samples) - 1) * (i + 1) / numPartitions
+ bounds.append(samples[index])
+
+ def rangePartitionFunc(k):
+ p = 0
+ while p < len(bounds) and keyfunc(k) > bounds[p]:
+ p += 1
+ if ascending:
+ return p
+ else:
+ return numPartitions-1-p
+
+ def mapFunc(iterator):
+ yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+
+ return (self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc)
+ .mapPartitions(mapFunc,preservesPartitioning=True)
+ .flatMap(lambda x: x, preservesPartitioning=True))
def glom(self):
"""
@@ -425,7 +471,7 @@ class RDD(object):
3
"""
return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
-
+
def stats(self):
"""
Return a L{StatCounter} object that captures the mean, variance
@@ -462,7 +508,7 @@ class RDD(object):
0.816...
"""
return self.stats().stdev()
-
+
def sampleStdev(self):
"""
Compute the sample standard deviation of this RDD's elements (which corrects for bias in
@@ -690,11 +736,13 @@ class RDD(object):
# form the hash buckets in Python, transferring O(numPartitions) objects
# to Java. Each object is a (splitNumber, [objects]) pair.
def add_shuffle_key(split, iterator):
+
buckets = defaultdict(list)
+
for (k, v) in iterator:
buckets[partitionFunc(k) % numPartitions].append((k, v))
for (split, items) in buckets.iteritems():
- yield str(split)
+ yield pack_long(split)
yield dump_pickle(Batch(items))
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
@@ -830,9 +878,9 @@ class RDD(object):
>>> y = sc.parallelize([("a", 3), ("c", None)])
>>> sorted(x.subtractByKey(y).collect())
[('b', 4), ('b', 5)]
- """
- filter_func = lambda tpl: len(tpl[1][0]) > 0 and len(tpl[1][1]) == 0
- map_func = lambda tpl: [(tpl[0], val) for val in tpl[1][0]]
+ """
+ filter_func = lambda (key, vals): len(vals[0]) > 0 and len(vals[1]) == 0
+ map_func = lambda (key, vals): [(key, val) for val in vals[0]]
return self.cogroup(other, numPartitions).filter(filter_func).flatMap(map_func)
def subtract(self, other, numPartitions=None):
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index fecacd1241..54fed1c9c7 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -67,6 +67,10 @@ def write_long(value, stream):
stream.write(struct.pack("!q", value))
+def pack_long(value):
+ return struct.pack("!q", value)
+
+
def read_int(stream):
length = stream.read(4)
if length == "":
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index dc205b306f..a475959090 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -35,7 +35,7 @@ print """Welcome to
____ __
/ __/__ ___ _____/ /__
_\ \/ _ \/ _ `/ __/ '_/
- /__ / .__/\_,_/_/ /_/\_\ version 0.8.0
+ /__ / .__/\_,_/_/ /_/\_\ version 0.9.0-SNAPSHOT
/_/
"""
print "Using Python version %s (%s, %s)" % (
diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml
index 3685561501..f6bf94be6b 100644
--- a/repl-bin/pom.xml
+++ b/repl-bin/pom.xml
@@ -21,12 +21,12 @@
<parent>
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent</artifactId>
- <version>0.8.0-SNAPSHOT</version>
+ <version>0.9.0-incubating-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-repl-bin</artifactId>
+ <artifactId>spark-repl-bin_2.9.3</artifactId>
<packaging>pom</packaging>
<name>Spark Project REPL binary packaging</name>
<url>http://spark.incubator.apache.org/</url>
@@ -40,18 +40,18 @@
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-core</artifactId>
+ <artifactId>spark-core_2.9.3</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-bagel</artifactId>
+ <artifactId>spark-bagel_2.9.3</artifactId>
<version>${project.version}</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-repl</artifactId>
+ <artifactId>spark-repl_2.9.3</artifactId>
<version>${project.version}</version>
<scope>runtime</scope>
</dependency>
@@ -89,7 +89,7 @@
</goals>
<configuration>
<transformers>
- <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
+ <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer" />
<transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
<resource>reference.conf</resource>
</transformer>
diff --git a/repl/pom.xml b/repl/pom.xml
index 3123b37780..49d86621dd 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -21,12 +21,12 @@
<parent>
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent</artifactId>
- <version>0.8.0-SNAPSHOT</version>
+ <version>0.9.0-incubating-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-repl</artifactId>
+ <artifactId>spark-repl_2.9.3</artifactId>
<packaging>jar</packaging>
<name>Spark Project REPL</name>
<url>http://spark.incubator.apache.org/</url>
@@ -39,18 +39,18 @@
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-core</artifactId>
+ <artifactId>spark-core_2.9.3</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-bagel</artifactId>
+ <artifactId>spark-bagel_2.9.3</artifactId>
<version>${project.version}</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-mllib</artifactId>
+ <artifactId>spark-mllib_2.9.3</artifactId>
<version>${project.version}</version>
<scope>runtime</scope>
</dependency>
@@ -76,12 +76,12 @@
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
- <artifactId>scalatest_${scala.version}</artifactId>
+ <artifactId>scalatest_2.9.3</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalacheck</groupId>
- <artifactId>scalacheck_${scala.version}</artifactId>
+ <artifactId>scalacheck_2.9.3</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
@@ -101,14 +101,14 @@
<configuration>
<exportAntProperties>true</exportAntProperties>
<tasks>
- <property name="spark.classpath" refid="maven.test.classpath"/>
- <property environment="env"/>
+ <property name="spark.classpath" refid="maven.test.classpath" />
+ <property environment="env" />
<fail message="Please set the SCALA_HOME (or SCALA_LIBRARY_PATH if scala is on the path) environment variables and retry.">
<condition>
<not>
<or>
- <isset property="env.SCALA_HOME"/>
- <isset property="env.SCALA_LIBRARY_PATH"/>
+ <isset property="env.SCALA_HOME" />
+ <isset property="env.SCALA_LIBRARY_PATH" />
</or>
</not>
</condition>
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
index 193ccb48ee..36f54a22cf 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
@@ -200,7 +200,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
____ __
/ __/__ ___ _____/ /__
_\ \/ _ \/ _ `/ __/ '_/
- /___/ .__/\_,_/_/ /_/\_\ version 0.8.0
+ /___/ .__/\_,_/_/ /_/\_\ version 0.9.0-SNAPSHOT
/_/
""")
import Properties._
diff --git a/spark-class b/spark-class
index 037abda3b7..e111ef6da7 100755
--- a/spark-class
+++ b/spark-class
@@ -37,7 +37,7 @@ fi
# If this is a standalone cluster daemon, reset SPARK_JAVA_OPTS and SPARK_MEM to reasonable
# values for that; it doesn't need a lot
-if [ "$1" = "spark.deploy.master.Master" -o "$1" = "spark.deploy.worker.Worker" ]; then
+if [ "$1" = "org.apache.spark.deploy.master.Master" -o "$1" = "org.apache.spark.deploy.worker.Worker" ]; then
SPARK_MEM=${SPARK_DAEMON_MEMORY:-512m}
SPARK_DAEMON_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS -Dspark.akka.logLifecycleEvents=true"
# Do not overwrite SPARK_JAVA_OPTS environment variable in this script
@@ -49,19 +49,19 @@ fi
# Add java opts for master, worker, executor. The opts maybe null
case "$1" in
- 'spark.deploy.master.Master')
+ 'org.apache.spark.deploy.master.Master')
OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_MASTER_OPTS"
;;
- 'spark.deploy.worker.Worker')
+ 'org.apache.spark.deploy.worker.Worker')
OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_WORKER_OPTS"
;;
- 'spark.executor.StandaloneExecutorBackend')
+ 'org.apache.spark.executor.StandaloneExecutorBackend')
OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
;;
- 'spark.executor.MesosExecutorBackend')
+ 'org.apache.spark.executor.MesosExecutorBackend')
OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
;;
- 'spark.repl.Main')
+ 'org.apache.spark.repl.Main')
OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_REPL_OPTS"
;;
esac
diff --git a/streaming/pom.xml b/streaming/pom.xml
index 7bea069b61..3b25fb49fb 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -21,12 +21,12 @@
<parent>
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent</artifactId>
- <version>0.8.0-SNAPSHOT</version>
+ <version>0.9.0-incubating-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-streaming</artifactId>
+ <artifactId>spark-streaming_2.9.3</artifactId>
<packaging>jar</packaging>
<name>Spark Project Streaming</name>
<url>http://spark.incubator.apache.org/</url>
@@ -42,7 +42,7 @@
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-core</artifactId>
+ <artifactId>spark-core_2.9.3</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
@@ -58,6 +58,7 @@
<groupId>org.apache.kafka</groupId>
<artifactId>kafka</artifactId>
<version>0.7.2-spark</version> <!-- Comes from our in-project repository -->
+ <scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.flume</groupId>
@@ -91,12 +92,12 @@
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
- <artifactId>scalatest_${scala.version}</artifactId>
+ <artifactId>scalatest_2.9.3</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalacheck</groupId>
- <artifactId>scalacheck_${scala.version}</artifactId>
+ <artifactId>scalacheck_2.9.3</artifactId>
<scope>test</scope>
</dependency>
<dependency>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala
index aae79a4e6f..b97fb7e6e3 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala
@@ -30,10 +30,11 @@ import akka.actor._
import akka.pattern.ask
import akka.util.duration._
import akka.dispatch._
+import org.apache.spark.storage.BlockId
private[streaming] sealed trait NetworkInputTrackerMessage
private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage
-private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage
+private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[BlockId], metadata: Any) extends NetworkInputTrackerMessage
private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage
/**
@@ -48,7 +49,7 @@ class NetworkInputTracker(
val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*)
val receiverExecutor = new ReceiverExecutor()
val receiverInfo = new HashMap[Int, ActorRef]
- val receivedBlockIds = new HashMap[Int, Queue[String]]
+ val receivedBlockIds = new HashMap[Int, Queue[BlockId]]
val timeout = 5000.milliseconds
var currentTime: Time = null
@@ -67,9 +68,9 @@ class NetworkInputTracker(
}
/** Return all the blocks received from a receiver. */
- def getBlockIds(receiverId: Int, time: Time): Array[String] = synchronized {
+ def getBlockIds(receiverId: Int, time: Time): Array[BlockId] = synchronized {
val queue = receivedBlockIds.synchronized {
- receivedBlockIds.getOrElse(receiverId, new Queue[String]())
+ receivedBlockIds.getOrElse(receiverId, new Queue[BlockId]())
}
val result = queue.synchronized {
queue.dequeueAll(x => true)
@@ -92,7 +93,7 @@ class NetworkInputTracker(
case AddBlocks(streamId, blockIds, metadata) => {
val tmp = receivedBlockIds.synchronized {
if (!receivedBlockIds.contains(streamId)) {
- receivedBlockIds += ((streamId, new Queue[String]))
+ receivedBlockIds += ((streamId, new Queue[BlockId]))
}
receivedBlockIds(streamId)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
index 31f9891560..8d3ac0fc65 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
@@ -31,7 +31,7 @@ import org.apache.spark.streaming.util.{RecurringTimer, SystemClock}
import org.apache.spark.streaming._
import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.rdd.{RDD, BlockRDD}
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{BlockId, StorageLevel, StreamBlockId}
/**
* Abstract class for defining any InputDStream that has to start a receiver on worker
@@ -69,7 +69,7 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming
val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime)
Some(new BlockRDD[T](ssc.sc, blockIds))
} else {
- Some(new BlockRDD[T](ssc.sc, Array[String]()))
+ Some(new BlockRDD[T](ssc.sc, Array[BlockId]()))
}
}
}
@@ -77,7 +77,7 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming
private[streaming] sealed trait NetworkReceiverMessage
private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage
-private[streaming] case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage
+private[streaming] case class ReportBlock(blockId: BlockId, metadata: Any) extends NetworkReceiverMessage
private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage
/**
@@ -158,7 +158,7 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
/**
* Pushes a block (as an ArrayBuffer filled with data) into the block manager.
*/
- def pushBlock(blockId: String, arrayBuffer: ArrayBuffer[T], metadata: Any, level: StorageLevel) {
+ def pushBlock(blockId: BlockId, arrayBuffer: ArrayBuffer[T], metadata: Any, level: StorageLevel) {
env.blockManager.put(blockId, arrayBuffer.asInstanceOf[ArrayBuffer[Any]], level)
actor ! ReportBlock(blockId, metadata)
}
@@ -166,7 +166,7 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
/**
* Pushes a block (as bytes) into the block manager.
*/
- def pushBlock(blockId: String, bytes: ByteBuffer, metadata: Any, level: StorageLevel) {
+ def pushBlock(blockId: BlockId, bytes: ByteBuffer, metadata: Any, level: StorageLevel) {
env.blockManager.putBytes(blockId, bytes, level)
actor ! ReportBlock(blockId, metadata)
}
@@ -209,7 +209,7 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
class BlockGenerator(storageLevel: StorageLevel)
extends Serializable with Logging {
- case class Block(id: String, buffer: ArrayBuffer[T], metadata: Any = null)
+ case class Block(id: BlockId, buffer: ArrayBuffer[T], metadata: Any = null)
val clock = new SystemClock()
val blockInterval = System.getProperty("spark.streaming.blockInterval", "200").toLong
@@ -241,7 +241,7 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
val newBlockBuffer = currentBuffer
currentBuffer = new ArrayBuffer[T]
if (newBlockBuffer.size > 0) {
- val blockId = "input-" + NetworkReceiver.this.streamId + "-" + (time - blockInterval)
+ val blockId = StreamBlockId(NetworkReceiver.this.streamId, time - blockInterval)
val newBlock = new Block(blockId, newBlockBuffer)
blocksForPushing.add(newBlock)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala
index c91f12ecd7..10ed4ef78d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala
@@ -18,7 +18,7 @@
package org.apache.spark.streaming.dstream
import org.apache.spark.Logging
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming.StreamingContext
import java.net.InetSocketAddress
@@ -71,7 +71,7 @@ class RawNetworkReceiver(host: String, port: Int, storageLevel: StorageLevel)
var nextBlockNumber = 0
while (true) {
val buffer = queue.take()
- val blockId = "input-" + streamId + "-" + nextBlockNumber
+ val blockId = StreamBlockId(streamId, nextBlockNumber)
nextBlockNumber += 1
pushBlock(blockId, buffer, null, storageLevel)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala
index 4b5d8c467e..ef0f85a717 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala
@@ -21,7 +21,7 @@ import akka.actor.{ Actor, PoisonPill, Props, SupervisorStrategy }
import akka.actor.{ actorRef2Scala, ActorRef }
import akka.actor.{ PossiblyHarmful, OneForOneStrategy }
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming.dstream.NetworkReceiver
import java.util.concurrent.atomic.AtomicInteger
@@ -159,7 +159,7 @@ private[streaming] class ActorReceiver[T: ClassManifest](
protected def pushBlock(iter: Iterator[T]) {
val buffer = new ArrayBuffer[T]
buffer ++= iter
- pushBlock("block-" + streamId + "-" + System.nanoTime(), buffer, null, storageLevel)
+ pushBlock(StreamBlockId(streamId, System.nanoTime()), buffer, null, storageLevel)
}
protected def onStart() = {
diff --git a/tools/pom.xml b/tools/pom.xml
index 77646a6816..f1c489beea 100644
--- a/tools/pom.xml
+++ b/tools/pom.xml
@@ -20,12 +20,12 @@
<parent>
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent</artifactId>
- <version>0.8.0-SNAPSHOT</version>
+ <version>0.9.0-incubating-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-tools</artifactId>
+ <artifactId>spark-tools_2.9.3</artifactId>
<packaging>jar</packaging>
<name>Spark Project Tools</name>
<url>http://spark.incubator.apache.org/</url>
@@ -33,17 +33,17 @@
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-core</artifactId>
+ <artifactId>spark-core_2.9.3</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-streaming</artifactId>
+ <artifactId>spark-streaming_2.9.3</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
- <artifactId>scalatest_${scala.version}</artifactId>
+ <artifactId>scalatest_2.9.3</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
diff --git a/yarn/pom.xml b/yarn/pom.xml
index 21b650d1ea..3bc619df07 100644
--- a/yarn/pom.xml
+++ b/yarn/pom.xml
@@ -20,12 +20,12 @@
<parent>
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent</artifactId>
- <version>0.8.0-SNAPSHOT</version>
+ <version>0.9.0-incubating-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-yarn</artifactId>
+ <artifactId>spark-yarn_2.9.3</artifactId>
<packaging>jar</packaging>
<name>Spark Project YARN Support</name>
<url>http://spark.incubator.apache.org/</url>
@@ -33,7 +33,7 @@
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-core</artifactId>
+ <artifactId>spark-core_2.9.3</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
@@ -97,7 +97,7 @@
</goals>
<configuration>
<transformers>
- <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
+ <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer" />
<transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
<resource>reference.conf</resource>
</transformer>
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 858b58d338..c1a87d3373 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -17,22 +17,25 @@
package org.apache.spark.deploy.yarn
+import java.io.IOException;
import java.net.Socket
+import java.security.PrivilegedExceptionAction
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.util.ShutdownHookManager
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
-import scala.collection.JavaConversions._
import org.apache.spark.{SparkContext, Logging}
import org.apache.spark.util.Utils
import org.apache.hadoop.security.UserGroupInformation
-import java.security.PrivilegedExceptionAction
+import scala.collection.JavaConversions._
class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging {
@@ -43,18 +46,26 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
private var appAttemptId: ApplicationAttemptId = null
private var userThread: Thread = null
private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+ private val fs = FileSystem.get(yarnConf)
private var yarnAllocator: YarnAllocationHandler = null
private var isFinished:Boolean = false
private var uiAddress: String = ""
+ private val maxAppAttempts: Int = conf.getInt(YarnConfiguration.RM_AM_MAX_RETRIES,
+ YarnConfiguration.DEFAULT_RM_AM_MAX_RETRIES)
+ private var isLastAMRetry: Boolean = true
def run() {
// setup the directories so things go to yarn approved directories rather
// then user specified and /tmp
System.setProperty("spark.local.dir", getLocalDirs())
+
+ // use priority 30 as its higher then HDFS. Its same priority as MapReduce is using
+ ShutdownHookManager.get().addShutdownHook(new AppMasterShutdownHook(this), 30)
appAttemptId = getApplicationAttemptId()
+ isLastAMRetry = appAttemptId.getAttemptId() >= maxAppAttempts;
resourceManager = registerWithResourceManager()
// Workaround until hadoop moves to something which has
@@ -183,6 +194,8 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
// It need shutdown hook to set SUCCEEDED
successed = true
} finally {
+ logDebug("finishing main")
+ isLastAMRetry = true;
if (successed) {
ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED)
} else {
@@ -229,8 +242,6 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
}
}
-
-
private def allocateWorkers() {
try {
logInfo("Allocating " + args.numWorkers + " workers.")
@@ -329,6 +340,40 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
resourceManager.finishApplicationMaster(finishReq)
}
+
+ /**
+ * clean up the staging directory.
+ */
+ private def cleanupStagingDir() {
+ var stagingDirPath: Path = null
+ try {
+ val preserveFiles = System.getProperty("spark.yarn.preserve.staging.files", "false").toBoolean
+ if (!preserveFiles) {
+ stagingDirPath = new Path(System.getenv("SPARK_YARN_JAR_PATH")).getParent()
+ if (stagingDirPath == null) {
+ logError("Staging directory is null")
+ return
+ }
+ logInfo("Deleting staging directory " + stagingDirPath)
+ fs.delete(stagingDirPath, true)
+ }
+ } catch {
+ case e: IOException =>
+ logError("Failed to cleanup staging dir " + stagingDirPath, e)
+ }
+ }
+
+ // The shutdown hook that runs when a signal is received AND during normal
+ // close of the JVM.
+ class AppMasterShutdownHook(appMaster: ApplicationMaster) extends Runnable {
+
+ def run() {
+ logInfo("AppMaster received a signal.")
+ // we need to clean up staging dir before HDFS is shut down
+ // make sure we don't delete it until this is the last AM
+ if (appMaster.isLastAMRetry) appMaster.cleanupStagingDir()
+ }
+ }
}
@@ -368,6 +413,8 @@ object ApplicationMaster {
// Add a shutdown hook - as a best case effort in case users do not call sc.stop or do System.exit
// Should not really have to do this, but it helps yarn to evict resources earlier.
// not to mention, prevent Client declaring failure even though we exit'ed properly.
+ // Note that this will unfortunately not properly clean up the staging files because it gets called to
+ // late and the filesystem is already shutdown.
if (modified) {
Runtime.getRuntime().addShutdownHook(new Thread with Logging {
// This is not just to log, but also to ensure that log system is initialized for this instance when we actually are 'run'
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 844c707834..8afb3e39cb 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -45,7 +45,13 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
var rpc: YarnRPC = YarnRPC.create(conf)
val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
- val credentials = UserGroupInformation.getCurrentUser().getCredentials();
+ val credentials = UserGroupInformation.getCurrentUser().getCredentials()
+ private var distFiles = None: Option[String]
+ private var distFilesTimeStamps = None: Option[String]
+ private var distFilesFileSizes = None: Option[String]
+ private var distArchives = None: Option[String]
+ private var distArchivesTimeStamps = None: Option[String]
+ private var distArchivesFileSizes = None: Option[String]
def run() {
init(yarnConf)
@@ -57,7 +63,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
verifyClusterResources(newApp)
val appContext = createApplicationSubmissionContext(appId)
- val localResources = prepareLocalResources(appId, "spark")
+ val localResources = prepareLocalResources(appId, ".sparkStaging")
val env = setupLaunchEnv(localResources)
val amContainer = createContainerLaunchContext(newApp, localResources, env)
@@ -106,13 +112,76 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
logInfo("Setting up application submission context for ASM")
val appContext = Records.newRecord(classOf[ApplicationSubmissionContext])
appContext.setApplicationId(appId)
- appContext.setApplicationName("Spark")
+ appContext.setApplicationName(args.appName)
return appContext
}
-
- def prepareLocalResources(appId: ApplicationId, appName: String): HashMap[String, LocalResource] = {
+
+ /**
+ * Copy the local file into HDFS and configure to be distributed with the
+ * job via the distributed cache.
+ * If a fragment is specified the file will be referenced as that fragment.
+ */
+ private def copyLocalFile(
+ dstDir: Path,
+ resourceType: LocalResourceType,
+ originalPath: Path,
+ replication: Short,
+ localResources: HashMap[String,LocalResource],
+ fragment: String,
+ appMasterOnly: Boolean = false): Unit = {
+ val fs = FileSystem.get(conf)
+ val newPath = new Path(dstDir, originalPath.getName())
+ logInfo("Uploading " + originalPath + " to " + newPath)
+ fs.copyFromLocalFile(false, true, originalPath, newPath)
+ fs.setReplication(newPath, replication);
+ val destStatus = fs.getFileStatus(newPath)
+
+ val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ amJarRsrc.setType(resourceType)
+ amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION)
+ amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(newPath))
+ amJarRsrc.setTimestamp(destStatus.getModificationTime())
+ amJarRsrc.setSize(destStatus.getLen())
+ var pathURI: URI = new URI(newPath.toString() + "#" + originalPath.getName());
+ if ((fragment == null) || (fragment.isEmpty())){
+ localResources(originalPath.getName()) = amJarRsrc
+ } else {
+ localResources(fragment) = amJarRsrc
+ pathURI = new URI(newPath.toString() + "#" + fragment);
+ }
+ val distPath = pathURI.toString()
+ if (appMasterOnly == true) return
+ if (resourceType == LocalResourceType.FILE) {
+ distFiles match {
+ case Some(path) =>
+ distFilesFileSizes = Some(distFilesFileSizes.get + "," +
+ destStatus.getLen().toString())
+ distFilesTimeStamps = Some(distFilesTimeStamps.get + "," +
+ destStatus.getModificationTime().toString())
+ distFiles = Some(path + "," + distPath)
+ case _ =>
+ distFilesFileSizes = Some(destStatus.getLen().toString())
+ distFilesTimeStamps = Some(destStatus.getModificationTime().toString())
+ distFiles = Some(distPath)
+ }
+ } else {
+ distArchives match {
+ case Some(path) =>
+ distArchivesTimeStamps = Some(distArchivesTimeStamps.get + "," +
+ destStatus.getModificationTime().toString())
+ distArchivesFileSizes = Some(distArchivesFileSizes.get + "," +
+ destStatus.getLen().toString())
+ distArchives = Some(path + "," + distPath)
+ case _ =>
+ distArchivesTimeStamps = Some(destStatus.getModificationTime().toString())
+ distArchivesFileSizes = Some(destStatus.getLen().toString())
+ distArchives = Some(distPath)
+ }
+ }
+ }
+
+ def prepareLocalResources(appId: ApplicationId, sparkStagingDir: String): HashMap[String, LocalResource] = {
logInfo("Preparing Local resources")
- val locaResources = HashMap[String, LocalResource]()
// Upload Spark and the application JAR to the remote file system
// Add them as local resources to the AM
val fs = FileSystem.get(conf)
@@ -125,33 +194,69 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
}
}
+ val pathSuffix = sparkStagingDir + "/" + appId.toString() + "/"
+ val dst = new Path(fs.getHomeDirectory(), pathSuffix)
+ val replication = System.getProperty("spark.yarn.submit.file.replication", "3").toShort
+
+ if (UserGroupInformation.isSecurityEnabled()) {
+ val dstFs = dst.getFileSystem(conf)
+ dstFs.addDelegationTokens(delegTokenRenewer, credentials);
+ }
+ val localResources = HashMap[String, LocalResource]()
+
Map("spark.jar" -> System.getenv("SPARK_JAR"), "app.jar" -> args.userJar, "log4j.properties" -> System.getenv("SPARK_LOG4J_CONF"))
.foreach { case(destName, _localPath) =>
val localPath: String = if (_localPath != null) _localPath.trim() else ""
if (! localPath.isEmpty()) {
val src = new Path(localPath)
- val pathSuffix = appName + "/" + appId.getId() + destName
- val dst = new Path(fs.getHomeDirectory(), pathSuffix)
- logInfo("Uploading " + src + " to " + dst)
- fs.copyFromLocalFile(false, true, src, dst)
- val destStatus = fs.getFileStatus(dst)
-
- // get tokens for anything we upload to hdfs
- if (UserGroupInformation.isSecurityEnabled()) {
- fs.addDelegationTokens(delegTokenRenewer, credentials);
- }
+ val newPath = new Path(dst, destName)
+ logInfo("Uploading " + src + " to " + newPath)
+ fs.copyFromLocalFile(false, true, src, newPath)
+ fs.setReplication(newPath, replication);
+ val destStatus = fs.getFileStatus(newPath)
val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
amJarRsrc.setType(LocalResourceType.FILE)
amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION)
- amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(dst))
+ amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(newPath))
amJarRsrc.setTimestamp(destStatus.getModificationTime())
amJarRsrc.setSize(destStatus.getLen())
- locaResources(destName) = amJarRsrc
+ localResources(destName) = amJarRsrc
}
}
+
+ // handle any add jars
+ if ((args.addJars != null) && (!args.addJars.isEmpty())){
+ args.addJars.split(',').foreach { case file: String =>
+ val tmpURI = new URI(file)
+ val tmp = new Path(tmpURI)
+ copyLocalFile(dst, LocalResourceType.FILE, tmp, replication, localResources,
+ tmpURI.getFragment(), true)
+ }
+ }
+
+ // handle any distributed cache files
+ if ((args.files != null) && (!args.files.isEmpty())){
+ args.files.split(',').foreach { case file: String =>
+ val tmpURI = new URI(file)
+ val tmp = new Path(tmpURI)
+ copyLocalFile(dst, LocalResourceType.FILE, tmp, replication, localResources,
+ tmpURI.getFragment())
+ }
+ }
+
+ // handle any distributed cache archives
+ if ((args.archives != null) && (!args.archives.isEmpty())) {
+ args.archives.split(',').foreach { case file:String =>
+ val tmpURI = new URI(file)
+ val tmp = new Path(tmpURI)
+ copyLocalFile(dst, LocalResourceType.ARCHIVE, tmp, replication,
+ localResources, tmpURI.getFragment())
+ }
+ }
+
UserGroupInformation.getCurrentUser().addCredentials(credentials);
- return locaResources
+ return localResources
}
def setupLaunchEnv(localResources: HashMap[String, LocalResource]): HashMap[String, String] = {
@@ -160,11 +265,10 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
val env = new HashMap[String, String]()
- // If log4j present, ensure ours overrides all others
- if (log4jConfLocalRes != null) Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./")
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$())
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name,
+ Environment.PWD.$() + Path.SEPARATOR + "*")
- Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*")
- Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH")
Client.populateHadoopClasspath(yarnConf, env)
env("SPARK_YARN_MODE") = "true"
env("SPARK_YARN_JAR_PATH") =
@@ -186,6 +290,18 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
env("SPARK_YARN_LOG4J_SIZE") = log4jConfLocalRes.getSize().toString()
}
+ // set the environment variables to be passed on to the Workers
+ if (distFiles != None) {
+ env("SPARK_YARN_CACHE_FILES") = distFiles.get
+ env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = distFilesTimeStamps.get
+ env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = distFilesFileSizes.get
+ }
+ if (distArchives != None) {
+ env("SPARK_YARN_CACHE_ARCHIVES") = distArchives.get
+ env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = distArchivesTimeStamps.get
+ env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") = distArchivesFileSizes.get
+ }
+
// allow users to specify some environment variables
Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV"))
@@ -224,8 +340,8 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
// Add Xmx for am memory
JAVA_OPTS += "-Xmx" + amMemory + "m "
- JAVA_OPTS += " -Djava.io.tmpdir=" + new Path(Environment.PWD.$(),
- YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR)
+ JAVA_OPTS += " -Djava.io.tmpdir=" +
+ new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + " "
// Commenting it out for now - so that people can refer to the properties if required. Remove it once cpuset version is pushed out.
@@ -241,6 +357,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
}
+
if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
JAVA_OPTS += env("SPARK_JAVA_OPTS") + " "
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
index cd651904d2..852dbd7dab 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -24,6 +24,9 @@ import org.apache.spark.scheduler.{InputFormatInfo, SplitInfo}
// TODO: Add code and support for ensuring that yarn resource 'asks' are location aware !
class ClientArguments(val args: Array[String]) {
+ var addJars: String = null
+ var files: String = null
+ var archives: String = null
var userJar: String = null
var userClass: String = null
var userArgs: Seq[String] = Seq[String]()
@@ -32,6 +35,7 @@ class ClientArguments(val args: Array[String]) {
var numWorkers = 2
var amQueue = System.getProperty("QUEUE", "default")
var amMemory: Int = 512
+ var appName: String = "Spark"
// TODO
var inputFormatInfo: List[InputFormatInfo] = null
@@ -78,6 +82,21 @@ class ClientArguments(val args: Array[String]) {
amQueue = value
args = tail
+ case ("--name") :: value :: tail =>
+ appName = value
+
+ case ("--addJars") :: value :: tail =>
+ addJars = value
+ args = tail
+
+ case ("--files") :: value :: tail =>
+ files = value
+ args = tail
+
+ case ("--archives") :: value :: tail =>
+ archives = value
+ args = tail
+
case Nil =>
if (userJar == null || userClass == null) {
printUsageAndExit(1)
@@ -92,7 +111,7 @@ class ClientArguments(val args: Array[String]) {
inputFormatInfo = inputFormatMap.values.toList
}
-
+
def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
if (unknownParam != null) {
System.err.println("Unknown/unsupported param " + unknownParam)
@@ -108,9 +127,13 @@ class ClientArguments(val args: Array[String]) {
" --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" +
" --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" +
" --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
- " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')"
+ " --name NAME The name of your application (Default: Spark)\n" +
+ " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" +
+ " --addJars jars Comma separated list of local jars that want SparkContext.addJar to work with.\n" +
+ " --files files Comma separated list of files to be distributed with the job.\n" +
+ " --archives archives Comma separated list of archives to be distributed with the job."
)
System.exit(exitCode)
}
-
+
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
index 6229167cb4..8dac9e02ac 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
@@ -77,8 +77,9 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
JAVA_OPTS += env("SPARK_JAVA_OPTS") + " "
}
- JAVA_OPTS += " -Djava.io.tmpdir=" + new Path(Environment.PWD.$(),
- YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR)
+ JAVA_OPTS += " -Djava.io.tmpdir=" +
+ new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + " "
+
// Commenting it out for now - so that people can refer to the properties if required. Remove it once cpuset version is pushed out.
// The context is, default gc for server class machines end up using all cores to do gc - hence if there are multiple containers in same
@@ -136,11 +137,26 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
startReq.setContainerLaunchContext(ctx)
cm.startContainer(startReq)
}
+
+ private def setupDistributedCache(file: String,
+ rtype: LocalResourceType,
+ localResources: HashMap[String, LocalResource],
+ timestamp: String,
+ size: String) = {
+ val uri = new URI(file)
+ val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ amJarRsrc.setType(rtype)
+ amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION)
+ amJarRsrc.setResource(ConverterUtils.getYarnUrlFromURI(uri))
+ amJarRsrc.setTimestamp(timestamp.toLong)
+ amJarRsrc.setSize(size.toLong)
+ localResources(uri.getFragment()) = amJarRsrc
+ }
def prepareLocalResources: HashMap[String, LocalResource] = {
logInfo("Preparing Local resources")
- val locaResources = HashMap[String, LocalResource]()
+ val localResources = HashMap[String, LocalResource]()
// Spark JAR
val sparkJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
@@ -150,7 +166,7 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
new URI(System.getenv("SPARK_YARN_JAR_PATH"))))
sparkJarResource.setTimestamp(System.getenv("SPARK_YARN_JAR_TIMESTAMP").toLong)
sparkJarResource.setSize(System.getenv("SPARK_YARN_JAR_SIZE").toLong)
- locaResources("spark.jar") = sparkJarResource
+ localResources("spark.jar") = sparkJarResource
// User JAR
val userJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
userJarResource.setType(LocalResourceType.FILE)
@@ -159,7 +175,7 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
new URI(System.getenv("SPARK_YARN_USERJAR_PATH"))))
userJarResource.setTimestamp(System.getenv("SPARK_YARN_USERJAR_TIMESTAMP").toLong)
userJarResource.setSize(System.getenv("SPARK_YARN_USERJAR_SIZE").toLong)
- locaResources("app.jar") = userJarResource
+ localResources("app.jar") = userJarResource
// Log4j conf - if available
if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
@@ -170,26 +186,39 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
new URI(System.getenv("SPARK_YARN_LOG4J_PATH"))))
log4jConfResource.setTimestamp(System.getenv("SPARK_YARN_LOG4J_TIMESTAMP").toLong)
log4jConfResource.setSize(System.getenv("SPARK_YARN_LOG4J_SIZE").toLong)
- locaResources("log4j.properties") = log4jConfResource
+ localResources("log4j.properties") = log4jConfResource
+ }
+
+ if (System.getenv("SPARK_YARN_CACHE_FILES") != null) {
+ val timeStamps = System.getenv("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',')
+ val fileSizes = System.getenv("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',')
+ val distFiles = System.getenv("SPARK_YARN_CACHE_FILES").split(',')
+ for( i <- 0 to distFiles.length - 1) {
+ setupDistributedCache(distFiles(i), LocalResourceType.FILE, localResources, timeStamps(i),
+ fileSizes(i))
+ }
}
+ if (System.getenv("SPARK_YARN_CACHE_ARCHIVES") != null) {
+ val timeStamps = System.getenv("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS").split(',')
+ val fileSizes = System.getenv("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES").split(',')
+ val distArchives = System.getenv("SPARK_YARN_CACHE_ARCHIVES").split(',')
+ for( i <- 0 to distArchives.length - 1) {
+ setupDistributedCache(distArchives(i), LocalResourceType.ARCHIVE, localResources,
+ timeStamps(i), fileSizes(i))
+ }
+ }
- logInfo("Prepared Local resources " + locaResources)
- return locaResources
+ logInfo("Prepared Local resources " + localResources)
+ return localResources
}
def prepareEnvironment: HashMap[String, String] = {
val env = new HashMap[String, String]()
- // If log4j present, ensure ours overrides all others
- if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
- // Which is correct ?
- Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./log4j.properties")
- Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./")
- }
-
- Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*")
- Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH")
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$())
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name,
+ Environment.PWD.$() + Path.SEPARATOR + "*")
Client.populateHadoopClasspath(yarnConf, env)
// allow users to specify some environment variables