aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bagel/pom.xml39
-rw-r--r--bagel/src/main/scala/spark/bagel/Bagel.scala6
-rw-r--r--bagel/src/test/scala/bagel/BagelSuite.scala1
-rwxr-xr-xbin/spark-daemon.sh10
-rwxr-xr-xbin/spark-daemons.sh2
-rwxr-xr-xbin/start-master.sh2
-rwxr-xr-xbin/start-slave.sh2
-rwxr-xr-xbin/start-slaves.sh11
-rwxr-xr-xbin/stop-master.sh2
-rwxr-xr-xbin/stop-slaves.sh12
-rwxr-xr-xconf/spark-env.sh.template1
-rw-r--r--core/pom.xml73
-rw-r--r--core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala3
-rw-r--r--core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala3
-rw-r--r--core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala23
-rw-r--r--core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala13
-rw-r--r--core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala13
-rw-r--r--core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala63
-rw-r--r--core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala342
-rw-r--r--core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala78
-rw-r--r--core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala272
-rw-r--r--core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala106
-rw-r--r--core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala171
-rw-r--r--core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala547
-rw-r--r--core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala42
-rw-r--r--core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala3
-rw-r--r--core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala3
-rw-r--r--core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala23
-rw-r--r--core/src/main/scala/spark/BlockStoreShuffleFetcher.scala32
-rw-r--r--core/src/main/scala/spark/ClosureCleaner.scala12
-rw-r--r--core/src/main/scala/spark/Dependency.scala4
-rw-r--r--core/src/main/scala/spark/FetchFailedException.scala25
-rw-r--r--core/src/main/scala/spark/HadoopWriter.scala12
-rw-r--r--core/src/main/scala/spark/Logging.scala4
-rw-r--r--core/src/main/scala/spark/MapOutputTracker.scala99
-rw-r--r--core/src/main/scala/spark/PairRDDFunctions.scala18
-rw-r--r--core/src/main/scala/spark/RDD.scala44
-rw-r--r--core/src/main/scala/spark/RDDCheckpointData.scala15
-rw-r--r--core/src/main/scala/spark/ShuffleFetcher.scala7
-rw-r--r--core/src/main/scala/spark/SparkContext.scala101
-rw-r--r--core/src/main/scala/spark/SparkEnv.scala35
-rw-r--r--core/src/main/scala/spark/Utils.scala149
-rw-r--r--core/src/main/scala/spark/api/java/JavaRDD.scala16
-rw-r--r--core/src/main/scala/spark/api/java/JavaRDDLike.scala15
-rw-r--r--core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala11
-rw-r--r--core/src/main/scala/spark/api/python/PythonRDD.scala2
-rw-r--r--core/src/main/scala/spark/deploy/ApplicationDescription.scala2
-rw-r--r--core/src/main/scala/spark/deploy/DeployMessage.scala19
-rw-r--r--core/src/main/scala/spark/deploy/JsonProtocol.scala5
-rw-r--r--core/src/main/scala/spark/deploy/LocalSparkCluster.scala8
-rw-r--r--core/src/main/scala/spark/deploy/client/Client.scala9
-rw-r--r--core/src/main/scala/spark/deploy/client/ClientListener.scala2
-rw-r--r--core/src/main/scala/spark/deploy/client/TestClient.scala2
-rw-r--r--core/src/main/scala/spark/deploy/master/ApplicationInfo.scala2
-rw-r--r--core/src/main/scala/spark/deploy/master/Master.scala17
-rw-r--r--core/src/main/scala/spark/deploy/master/MasterArguments.scala17
-rw-r--r--core/src/main/scala/spark/deploy/master/MasterWebUI.scala4
-rw-r--r--core/src/main/scala/spark/deploy/master/WorkerInfo.scala9
-rw-r--r--core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala6
-rw-r--r--core/src/main/scala/spark/deploy/worker/Worker.scala22
-rw-r--r--core/src/main/scala/spark/deploy/worker/WorkerArguments.scala13
-rw-r--r--core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala4
-rw-r--r--core/src/main/scala/spark/executor/Executor.scala7
-rw-r--r--core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala32
-rw-r--r--core/src/main/scala/spark/executor/TaskMetrics.scala5
-rw-r--r--core/src/main/scala/spark/network/Connection.scala192
-rw-r--r--core/src/main/scala/spark/network/ConnectionManager.scala504
-rw-r--r--core/src/main/scala/spark/network/Message.scala3
-rw-r--r--core/src/main/scala/spark/rdd/CheckpointRDD.scala31
-rw-r--r--core/src/main/scala/spark/rdd/CoGroupedRDD.scala16
-rw-r--r--core/src/main/scala/spark/rdd/NewHadoopRDD.scala2
-rw-r--r--core/src/main/scala/spark/rdd/ShuffledRDD.scala12
-rw-r--r--core/src/main/scala/spark/rdd/SubtractedRDD.scala20
-rw-r--r--core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala120
-rw-r--r--core/src/main/scala/spark/rdd/ZippedRDD.scala20
-rw-r--r--core/src/main/scala/spark/scheduler/DAGScheduler.scala18
-rw-r--r--core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala4
-rw-r--r--core/src/main/scala/spark/scheduler/InputFormatInfo.scala156
-rw-r--r--core/src/main/scala/spark/scheduler/ResultTask.scala10
-rw-r--r--core/src/main/scala/spark/scheduler/ShuffleMapTask.scala69
-rw-r--r--core/src/main/scala/spark/scheduler/SplitInfo.scala61
-rw-r--r--core/src/main/scala/spark/scheduler/TaskScheduler.scala4
-rw-r--r--core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala3
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala281
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala6
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala7
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala34
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala9
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala309
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala2
-rw-r--r--core/src/main/scala/spark/scheduler/local/LocalScheduler.scala4
-rw-r--r--core/src/main/scala/spark/serializer/Serializer.scala8
-rw-r--r--core/src/main/scala/spark/serializer/SerializerManager.scala45
-rw-r--r--core/src/main/scala/spark/storage/BlockException.scala5
-rw-r--r--core/src/main/scala/spark/storage/BlockManager.scala255
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerId.scala40
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerMaster.scala18
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerMasterActor.scala23
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerUI.scala4
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerWorker.scala18
-rw-r--r--core/src/main/scala/spark/storage/BlockMessageArray.scala1
-rw-r--r--core/src/main/scala/spark/storage/BlockObjectWriter.scala50
-rw-r--r--core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala12
-rw-r--r--core/src/main/scala/spark/storage/DiskStore.scala127
-rw-r--r--core/src/main/scala/spark/storage/MemoryStore.scala4
-rw-r--r--core/src/main/scala/spark/storage/ShuffleBlockManager.scala50
-rw-r--r--core/src/main/scala/spark/storage/StorageLevel.scala8
-rw-r--r--core/src/main/scala/spark/storage/StorageUtils.scala33
-rw-r--r--core/src/main/scala/spark/util/AkkaUtils.scala13
-rw-r--r--core/src/main/scala/spark/util/TimeStampedHashMap.scala8
-rw-r--r--core/src/main/scala/spark/util/TimedIterator.scala32
-rw-r--r--core/src/main/twirl/spark/deploy/master/app_details.scala.html11
-rw-r--r--core/src/main/twirl/spark/deploy/master/executor_row.scala.html2
-rw-r--r--core/src/main/twirl/spark/deploy/master/index.scala.html2
-rw-r--r--core/src/main/twirl/spark/deploy/master/worker_row.scala.html2
-rw-r--r--core/src/main/twirl/spark/deploy/worker/index.scala.html2
-rw-r--r--core/src/main/twirl/spark/storage/worker_table.scala.html2
-rw-r--r--core/src/test/scala/spark/DistributedSuite.scala35
-rw-r--r--core/src/test/scala/spark/JavaAPISuite.java26
-rw-r--r--core/src/test/scala/spark/LocalSparkContext.scala3
-rw-r--r--core/src/test/scala/spark/MapOutputTrackerSuite.scala3
-rw-r--r--core/src/test/scala/spark/RDDSuite.scala24
-rw-r--r--core/src/test/scala/spark/ZippedPartitionsSuite.scala34
-rw-r--r--core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala2
-rw-r--r--core/src/test/scala/spark/scheduler/SparkListenerSuite.scala1
-rw-r--r--core/src/test/scala/spark/storage/BlockManagerSuite.scala36
-rw-r--r--docs/_config.yml6
-rw-r--r--docs/building-with-maven.md4
-rw-r--r--docs/index.md2
-rw-r--r--docs/quick-start.md25
-rw-r--r--docs/running-on-yarn.md31
-rw-r--r--docs/streaming-programming-guide.md4
-rwxr-xr-xec2/spark_ec2.py21
-rw-r--r--examples/pom.xml47
-rw-r--r--examples/src/main/scala/spark/examples/LocalKMeans.scala138
-rw-r--r--examples/src/main/scala/spark/examples/MultiBroadcastTest.scala6
-rw-r--r--examples/src/main/scala/spark/examples/SimpleSkewedGroupByTest.scala4
-rw-r--r--examples/src/main/scala/spark/examples/SkewedGroupByTest.scala4
-rw-r--r--examples/src/main/scala/spark/examples/SparkHdfsLR.scala10
-rw-r--r--pom.xml77
-rw-r--r--project/SparkBuild.scala88
-rw-r--r--project/build.properties2
-rw-r--r--project/plugins.sbt8
-rw-r--r--repl-bin/pom.xml57
-rw-r--r--repl/pom.xml81
-rw-r--r--repl/src/main/scala/spark/repl/SparkILoop.scala2
-rw-r--r--repl/src/test/scala/spark/repl/ReplSuite.scala1
-rwxr-xr-xrun21
-rw-r--r--run2.cmd3
-rwxr-xr-xsbt/sbt2
-rw-r--r--streaming/pom.xml39
-rw-r--r--streaming/src/main/scala/spark/streaming/Checkpoint.scala33
-rw-r--r--streaming/src/main/scala/spark/streaming/DStreamGraph.scala2
-rw-r--r--streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala7
-rw-r--r--streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala1
-rw-r--r--streaming/src/test/scala/spark/streaming/CheckpointSuite.scala4
-rw-r--r--streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala1
-rw-r--r--streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala1
158 files changed, 5244 insertions, 976 deletions
diff --git a/bagel/pom.xml b/bagel/pom.xml
index 510cff4669..b83a0ef6c0 100644
--- a/bagel/pom.xml
+++ b/bagel/pom.xml
@@ -4,7 +4,7 @@
<parent>
<groupId>org.spark-project</groupId>
<artifactId>spark-parent</artifactId>
- <version>0.7.1-SNAPSHOT</version>
+ <version>0.8.0-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
@@ -102,5 +102,42 @@
</plugins>
</build>
</profile>
+ <profile>
+ <id>hadoop2-yarn</id>
+ <dependencies>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-core</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2-yarn</classifier>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-client</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-api</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-common</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <configuration>
+ <classifier>hadoop2-yarn</classifier>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
</profiles>
</project>
diff --git a/bagel/src/main/scala/spark/bagel/Bagel.scala b/bagel/src/main/scala/spark/bagel/Bagel.scala
index e10c03f6ba..5ecdd7d004 100644
--- a/bagel/src/main/scala/spark/bagel/Bagel.scala
+++ b/bagel/src/main/scala/spark/bagel/Bagel.scala
@@ -7,8 +7,7 @@ import scala.collection.mutable.ArrayBuffer
import storage.StorageLevel
object Bagel extends Logging {
-
- val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_ONLY
+ val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK
/**
* Runs a Bagel program.
@@ -63,8 +62,9 @@ object Bagel extends Logging {
val combinedMsgs = msgs.combineByKey(
combiner.createCombiner _, combiner.mergeMsg _, combiner.mergeCombiners _, partitioner)
val grouped = combinedMsgs.groupWith(verts)
+ val superstep_ = superstep // Create a read-only copy of superstep for capture in closure
val (processed, numMsgs, numActiveVerts) =
- comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep), storageLevel)
+ comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep_), storageLevel)
val timeTaken = System.currentTimeMillis - startTime
logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala
index 25db395c22..a09c978068 100644
--- a/bagel/src/test/scala/bagel/BagelSuite.scala
+++ b/bagel/src/test/scala/bagel/BagelSuite.scala
@@ -23,6 +23,7 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeo
}
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
}
test("halting by voting") {
diff --git a/bin/spark-daemon.sh b/bin/spark-daemon.sh
index 0c584055c7..8ee3ec481f 100755
--- a/bin/spark-daemon.sh
+++ b/bin/spark-daemon.sh
@@ -30,7 +30,7 @@
# SPARK_NICENESS The scheduling priority for daemons. Defaults to 0.
##
-usage="Usage: spark-daemon.sh [--config <conf-dir>] [--hosts hostlistfile] (start|stop) <spark-command> <args...>"
+usage="Usage: spark-daemon.sh [--config <conf-dir>] [--hosts hostlistfile] (start|stop) <spark-command> <spark-instance-number> <args...>"
# if no args specified, show usage
if [ $# -le 1 ]; then
@@ -48,6 +48,8 @@ startStop=$1
shift
command=$1
shift
+instance=$1
+shift
spark_rotate_log ()
{
@@ -92,10 +94,10 @@ if [ "$SPARK_PID_DIR" = "" ]; then
fi
# some variables
-export SPARK_LOGFILE=spark-$SPARK_IDENT_STRING-$command-$HOSTNAME.log
+export SPARK_LOGFILE=spark-$SPARK_IDENT_STRING-$command-$instance-$HOSTNAME.log
export SPARK_ROOT_LOGGER="INFO,DRFA"
-log=$SPARK_LOG_DIR/spark-$SPARK_IDENT_STRING-$command-$HOSTNAME.out
-pid=$SPARK_PID_DIR/spark-$SPARK_IDENT_STRING-$command.pid
+log=$SPARK_LOG_DIR/spark-$SPARK_IDENT_STRING-$command-$instance-$HOSTNAME.out
+pid=$SPARK_PID_DIR/spark-$SPARK_IDENT_STRING-$command-$instance.pid
# Set default scheduling priority
if [ "$SPARK_NICENESS" = "" ]; then
diff --git a/bin/spark-daemons.sh b/bin/spark-daemons.sh
index 4f9719ee80..0619097e4d 100755
--- a/bin/spark-daemons.sh
+++ b/bin/spark-daemons.sh
@@ -2,7 +2,7 @@
# Run a Spark command on all slave hosts.
-usage="Usage: spark-daemons.sh [--config confdir] [--hosts hostlistfile] [start|stop] command args..."
+usage="Usage: spark-daemons.sh [--config confdir] [--hosts hostlistfile] [start|stop] command instance-number args..."
# if no args specified, show usage
if [ $# -le 1 ]; then
diff --git a/bin/start-master.sh b/bin/start-master.sh
index 87feb261fe..83a3e1f3dc 100755
--- a/bin/start-master.sh
+++ b/bin/start-master.sh
@@ -32,4 +32,4 @@ if [ "$SPARK_PUBLIC_DNS" = "" ]; then
fi
fi
-"$bin"/spark-daemon.sh start spark.deploy.master.Master --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT
+"$bin"/spark-daemon.sh start spark.deploy.master.Master 1 --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT
diff --git a/bin/start-slave.sh b/bin/start-slave.sh
index 45a0cf7a6b..616c76e4ee 100755
--- a/bin/start-slave.sh
+++ b/bin/start-slave.sh
@@ -11,4 +11,4 @@ if [ "$SPARK_PUBLIC_DNS" = "" ]; then
fi
fi
-"$bin"/spark-daemon.sh start spark.deploy.worker.Worker $1
+"$bin"/spark-daemon.sh start spark.deploy.worker.Worker "$@"
diff --git a/bin/start-slaves.sh b/bin/start-slaves.sh
index 390247ca4a..4e05224190 100755
--- a/bin/start-slaves.sh
+++ b/bin/start-slaves.sh
@@ -21,4 +21,13 @@ fi
echo "Master IP: $SPARK_MASTER_IP"
# Launch the slaves
-exec "$bin/slaves.sh" cd "$SPARK_HOME" \; "$bin/start-slave.sh" spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT
+if [ "$SPARK_WORKER_INSTANCES" = "" ]; then
+ exec "$bin/slaves.sh" cd "$SPARK_HOME" \; "$bin/start-slave.sh" 1 spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT
+else
+ if [ "$SPARK_WORKER_WEBUI_PORT" = "" ]; then
+ SPARK_WORKER_WEBUI_PORT=8081
+ fi
+ for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do
+ "$bin/slaves.sh" cd "$SPARK_HOME" \; "$bin/start-slave.sh" $(( $i + 1 )) spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT --webui-port $(( $SPARK_WORKER_WEBUI_PORT + $i ))
+ done
+fi
diff --git a/bin/stop-master.sh b/bin/stop-master.sh
index f75167dd2c..172ee5891d 100755
--- a/bin/stop-master.sh
+++ b/bin/stop-master.sh
@@ -7,4 +7,4 @@ bin=`cd "$bin"; pwd`
. "$bin/spark-config.sh"
-"$bin"/spark-daemon.sh stop spark.deploy.master.Master \ No newline at end of file
+"$bin"/spark-daemon.sh stop spark.deploy.master.Master 1
diff --git a/bin/stop-slaves.sh b/bin/stop-slaves.sh
index 21c9ebf324..fbfc594472 100755
--- a/bin/stop-slaves.sh
+++ b/bin/stop-slaves.sh
@@ -7,4 +7,14 @@ bin=`cd "$bin"; pwd`
. "$bin/spark-config.sh"
-"$bin"/spark-daemons.sh stop spark.deploy.worker.Worker \ No newline at end of file
+if [ -f "${SPARK_CONF_DIR}/spark-env.sh" ]; then
+ . "${SPARK_CONF_DIR}/spark-env.sh"
+fi
+
+if [ "$SPARK_WORKER_INSTANCES" = "" ]; then
+ "$bin"/spark-daemons.sh stop spark.deploy.worker.Worker 1
+else
+ for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do
+ "$bin"/spark-daemons.sh stop spark.deploy.worker.Worker $(( $i + 1 ))
+ done
+fi
diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template
index 6d71ec5691..37565ca827 100755
--- a/conf/spark-env.sh.template
+++ b/conf/spark-env.sh.template
@@ -12,6 +12,7 @@
# - SPARK_WORKER_CORES, to set the number of cores to use on this machine
# - SPARK_WORKER_MEMORY, to set how much memory to use (e.g. 1000m, 2g)
# - SPARK_WORKER_PORT / SPARK_WORKER_WEBUI_PORT
+# - SPARK_WORKER_INSTANCES, to set the number of worker instances/processes to be spawned on every slave machine
#
# Finally, Spark also relies on the following variables, but these can be set
# on just the *master* (i.e. in your driver program), and will automatically
diff --git a/core/pom.xml b/core/pom.xml
index fe9c803728..9a019b5a42 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -4,7 +4,7 @@
<parent>
<groupId>org.spark-project</groupId>
<artifactId>spark-parent</artifactId>
- <version>0.7.1-SNAPSHOT</version>
+ <version>0.8.0-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
@@ -73,7 +73,7 @@
</dependency>
<dependency>
<groupId>cc.spray</groupId>
- <artifactId>spray-json_${scala.version}</artifactId>
+ <artifactId>spray-json_2.9.2</artifactId>
</dependency>
<dependency>
<groupId>org.tomdz.twirl</groupId>
@@ -81,7 +81,7 @@
</dependency>
<dependency>
<groupId>com.github.scala-incubator.io</groupId>
- <artifactId>scala-io-file_${scala.version}</artifactId>
+ <artifactId>scala-io-file_2.9.2</artifactId>
</dependency>
<dependency>
<groupId>org.apache.mesos</groupId>
@@ -279,5 +279,72 @@
</plugins>
</build>
</profile>
+ <profile>
+ <id>hadoop2-yarn</id>
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-client</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-api</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-common</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-client</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.codehaus.mojo</groupId>
+ <artifactId>build-helper-maven-plugin</artifactId>
+ <executions>
+ <execution>
+ <id>add-source</id>
+ <phase>generate-sources</phase>
+ <goals>
+ <goal>add-source</goal>
+ </goals>
+ <configuration>
+ <sources>
+ <source>src/main/scala</source>
+ <source>src/hadoop2-yarn/scala</source>
+ </sources>
+ </configuration>
+ </execution>
+ <execution>
+ <id>add-scala-test-sources</id>
+ <phase>generate-test-sources</phase>
+ <goals>
+ <goal>add-test-source</goal>
+ </goals>
+ <configuration>
+ <sources>
+ <source>src/test/scala</source>
+ </sources>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <configuration>
+ <classifier>hadoop2-yarn</classifier>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
</profiles>
</project>
diff --git a/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
index ca9f7219de..f286f2cf9c 100644
--- a/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
+++ b/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
@@ -4,4 +4,7 @@ trait HadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContext(conf, jobId)
def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContext(conf, attemptId)
+
+ def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier,
+ jobId, isMap, taskId, attemptId)
}
diff --git a/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
index de7b0f81e3..264d421d14 100644
--- a/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
+++ b/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
@@ -6,4 +6,7 @@ trait HadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContext(conf, jobId)
def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContext(conf, attemptId)
+
+ def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier,
+ jobId, isMap, taskId, attemptId)
}
diff --git a/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala
new file mode 100644
index 0000000000..a0fb4fe25d
--- /dev/null
+++ b/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala
@@ -0,0 +1,23 @@
+package spark.deploy
+import org.apache.hadoop.conf.Configuration
+
+
+/**
+ * Contains util methods to interact with Hadoop from spark.
+ */
+object SparkHadoopUtil {
+
+ def getUserNameFromEnvironment(): String = {
+ // defaulting to -D ...
+ System.getProperty("user.name")
+ }
+
+ def runAsUser(func: (Product) => Unit, args: Product) {
+
+ // Add support, if exists - for now, simply run func !
+ func(args)
+ }
+
+ // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop subsystems
+ def newConfiguration(): Configuration = new Configuration()
+}
diff --git a/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
new file mode 100644
index 0000000000..875c0a220b
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
@@ -0,0 +1,13 @@
+
+package org.apache.hadoop.mapred
+
+import org.apache.hadoop.mapreduce.TaskType
+
+trait HadoopMapRedUtil {
+ def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)
+
+ def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+
+ def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) =
+ new TaskAttemptID(jtIdentifier, jobId, if (isMap) TaskType.MAP else TaskType.REDUCE, taskId, attemptId)
+}
diff --git a/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
new file mode 100644
index 0000000000..8bc6fb6dea
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
@@ -0,0 +1,13 @@
+package org.apache.hadoop.mapreduce
+
+import org.apache.hadoop.conf.Configuration
+import task.{TaskAttemptContextImpl, JobContextImpl}
+
+trait HadoopMapReduceUtil {
+ def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)
+
+ def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+
+ def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) =
+ new TaskAttemptID(jtIdentifier, jobId, if (isMap) TaskType.MAP else TaskType.REDUCE, taskId, attemptId)
+}
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala
new file mode 100644
index 0000000000..ab1ab9d8a7
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala
@@ -0,0 +1,63 @@
+package spark.deploy
+
+import collection.mutable.HashMap
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import java.security.PrivilegedExceptionAction
+
+/**
+ * Contains util methods to interact with Hadoop from spark.
+ */
+object SparkHadoopUtil {
+
+ val yarnConf = newConfiguration()
+
+ def getUserNameFromEnvironment(): String = {
+ // defaulting to env if -D is not present ...
+ val retval = System.getProperty(Environment.USER.name, System.getenv(Environment.USER.name))
+
+ // If nothing found, default to user we are running as
+ if (retval == null) System.getProperty("user.name") else retval
+ }
+
+ def runAsUser(func: (Product) => Unit, args: Product) {
+ runAsUser(func, args, getUserNameFromEnvironment())
+ }
+
+ def runAsUser(func: (Product) => Unit, args: Product, user: String) {
+
+ // println("running as user " + jobUserName)
+
+ UserGroupInformation.setConfiguration(yarnConf)
+ val appMasterUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(user)
+ appMasterUgi.doAs(new PrivilegedExceptionAction[AnyRef] {
+ def run: AnyRef = {
+ func(args)
+ // no return value ...
+ null
+ }
+ })
+ }
+
+ // Note that all params which start with SPARK are propagated all the way through, so if in yarn mode, this MUST be set to true.
+ def isYarnMode(): Boolean = {
+ val yarnMode = System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))
+ java.lang.Boolean.valueOf(yarnMode)
+ }
+
+ // Set an env variable indicating we are running in YARN mode.
+ // Note that anything with SPARK prefix gets propagated to all (remote) processes
+ def setYarnMode() {
+ System.setProperty("SPARK_YARN_MODE", "true")
+ }
+
+ def setYarnMode(env: HashMap[String, String]) {
+ env("SPARK_YARN_MODE") = "true"
+ }
+
+ // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop subsystems
+ // Always create a new config, dont reuse yarnConf.
+ def newConfiguration(): Configuration = new YarnConfiguration(new Configuration())
+}
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala
new file mode 100644
index 0000000000..ae719267e8
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala
@@ -0,0 +1,342 @@
+package spark.deploy.yarn
+
+import java.net.Socket
+import java.util.concurrent.CopyOnWriteArrayList
+import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
+import scala.collection.JavaConversions._
+import spark.{SparkContext, Logging, Utils}
+import org.apache.hadoop.security.UserGroupInformation
+import java.security.PrivilegedExceptionAction
+
+class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging {
+
+ def this(args: ApplicationMasterArguments) = this(args, new Configuration())
+
+ private var rpc: YarnRPC = YarnRPC.create(conf)
+ private var resourceManager: AMRMProtocol = null
+ private var appAttemptId: ApplicationAttemptId = null
+ private var userThread: Thread = null
+ private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+
+ private var yarnAllocator: YarnAllocationHandler = null
+
+ def run() {
+
+ // Initialization
+ val jobUserName = Utils.getUserNameFromEnvironment()
+ logInfo("running as user " + jobUserName)
+
+ // run as user ...
+ UserGroupInformation.setConfiguration(yarnConf)
+ val appMasterUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(jobUserName)
+ appMasterUgi.doAs(new PrivilegedExceptionAction[AnyRef] {
+ def run: AnyRef = {
+ runImpl()
+ return null
+ }
+ })
+ }
+
+ private def runImpl() {
+
+ appAttemptId = getApplicationAttemptId()
+ resourceManager = registerWithResourceManager()
+ val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster()
+
+ // Compute number of threads for akka
+ val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory()
+
+ if (minimumMemory > 0) {
+ val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD
+ val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0)
+
+ if (numCore > 0) {
+ // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406
+ // TODO: Uncomment when hadoop is on a version which has this fixed.
+ // args.workerCores = numCore
+ }
+ }
+
+ // Workaround until hadoop moves to something which has
+ // https://issues.apache.org/jira/browse/HADOOP-8406
+ // ignore result
+ // This does not, unfortunately, always work reliably ... but alleviates the bug a lot of times
+ // Hence args.workerCores = numCore disabled above. Any better option ?
+ // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf)
+
+ ApplicationMaster.register(this)
+ // Start the user's JAR
+ userThread = startUserClass()
+
+ // This a bit hacky, but we need to wait until the spark.driver.port property has
+ // been set by the Thread executing the user class.
+ waitForSparkMaster()
+
+ // Allocate all containers
+ allocateWorkers()
+
+ // Wait for the user class to Finish
+ userThread.join()
+
+ // Finish the ApplicationMaster
+ finishApplicationMaster()
+ // TODO: Exit based on success/failure
+ System.exit(0)
+ }
+
+ private def getApplicationAttemptId(): ApplicationAttemptId = {
+ val envs = System.getenv()
+ val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV)
+ val containerId = ConverterUtils.toContainerId(containerIdString)
+ val appAttemptId = containerId.getApplicationAttemptId()
+ logInfo("ApplicationAttemptId: " + appAttemptId)
+ return appAttemptId
+ }
+
+ private def registerWithResourceManager(): AMRMProtocol = {
+ val rmAddress = NetUtils.createSocketAddr(yarnConf.get(
+ YarnConfiguration.RM_SCHEDULER_ADDRESS,
+ YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS))
+ logInfo("Connecting to ResourceManager at " + rmAddress)
+ return rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol]
+ }
+
+ private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
+ logInfo("Registering the ApplicationMaster")
+ val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest])
+ .asInstanceOf[RegisterApplicationMasterRequest]
+ appMasterRequest.setApplicationAttemptId(appAttemptId)
+ // Setting this to master host,port - so that the ApplicationReport at client has some sensible info.
+ // Users can then monitor stderr/stdout on that node if required.
+ appMasterRequest.setHost(Utils.localHostName())
+ appMasterRequest.setRpcPort(0)
+ // What do we provide here ? Might make sense to expose something sensible later ?
+ appMasterRequest.setTrackingUrl("")
+ return resourceManager.registerApplicationMaster(appMasterRequest)
+ }
+
+ private def waitForSparkMaster() {
+ logInfo("Waiting for spark driver to be reachable.")
+ var driverUp = false
+ while(!driverUp) {
+ val driverHost = System.getProperty("spark.driver.host")
+ val driverPort = System.getProperty("spark.driver.port")
+ try {
+ val socket = new Socket(driverHost, driverPort.toInt)
+ socket.close()
+ logInfo("Master now available: " + driverHost + ":" + driverPort)
+ driverUp = true
+ } catch {
+ case e: Exception =>
+ logError("Failed to connect to driver at " + driverHost + ":" + driverPort)
+ Thread.sleep(100)
+ }
+ }
+ }
+
+ private def startUserClass(): Thread = {
+ logInfo("Starting the user JAR in a separate Thread")
+ val mainMethod = Class.forName(args.userClass, false, Thread.currentThread.getContextClassLoader)
+ .getMethod("main", classOf[Array[String]])
+ val t = new Thread {
+ override def run() {
+ var mainArgs: Array[String] = null
+ var startIndex = 0
+
+ // I am sure there is a better 'scala' way to do this .... but I am just trying to get things to work right now !
+ if (args.userArgs.isEmpty || args.userArgs.get(0) != "yarn-standalone") {
+ // ensure that first param is ALWAYS "yarn-standalone"
+ mainArgs = new Array[String](args.userArgs.size() + 1)
+ mainArgs.update(0, "yarn-standalone")
+ startIndex = 1
+ }
+ else {
+ mainArgs = new Array[String](args.userArgs.size())
+ }
+
+ args.userArgs.copyToArray(mainArgs, startIndex, args.userArgs.size())
+
+ mainMethod.invoke(null, mainArgs)
+ }
+ }
+ t.start()
+ return t
+ }
+
+ private def allocateWorkers() {
+ logInfo("Waiting for spark context initialization")
+
+ try {
+ var sparkContext: SparkContext = null
+ ApplicationMaster.sparkContextRef.synchronized {
+ var count = 0
+ while (ApplicationMaster.sparkContextRef.get() == null) {
+ logInfo("Waiting for spark context initialization ... " + count)
+ count = count + 1
+ ApplicationMaster.sparkContextRef.wait(10000L)
+ }
+ sparkContext = ApplicationMaster.sparkContextRef.get()
+ assert(sparkContext != null)
+ this.yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args, sparkContext.preferredNodeLocationData)
+ }
+
+
+ logInfo("Allocating " + args.numWorkers + " workers.")
+ // Wait until all containers have finished
+ // TODO: This is a bit ugly. Can we make it nicer?
+ // TODO: Handle container failure
+ while(yarnAllocator.getNumWorkersRunning < args.numWorkers &&
+ // If user thread exists, then quit !
+ userThread.isAlive) {
+
+ this.yarnAllocator.allocateContainers(math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0))
+ ApplicationMaster.incrementAllocatorLoop(1)
+ Thread.sleep(100)
+ }
+ } finally {
+ // in case of exceptions, etc - ensure that count is atleast ALLOCATOR_LOOP_WAIT_COUNT :
+ // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks
+ ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT)
+ }
+ logInfo("All workers have launched.")
+
+ // Launch a progress reporter thread, else app will get killed after expiration (def: 10mins) timeout
+ if (userThread.isAlive){
+ // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse.
+
+ val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
+ // must be <= timeoutInterval/ 2.
+ // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM.
+ // so atleast 1 minute or timeoutInterval / 10 - whichever is higher.
+ val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L))
+ launchReporterThread(interval)
+ }
+ }
+
+ // TODO: We might want to extend this to allocate more containers in case they die !
+ private def launchReporterThread(_sleepTime: Long): Thread = {
+ val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime
+
+ val t = new Thread {
+ override def run() {
+ while (userThread.isAlive){
+ val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning
+ if (missingWorkerCount > 0) {
+ logInfo("Allocating " + missingWorkerCount + " containers to make up for (potentially ?) lost containers")
+ yarnAllocator.allocateContainers(missingWorkerCount)
+ }
+ else sendProgress()
+ Thread.sleep(sleepTime)
+ }
+ }
+ }
+ // setting to daemon status, though this is usually not a good idea.
+ t.setDaemon(true)
+ t.start()
+ logInfo("Started progress reporter thread - sleep time : " + sleepTime)
+ return t
+ }
+
+ private def sendProgress() {
+ logDebug("Sending progress")
+ // simulated with an allocate request with no nodes requested ...
+ yarnAllocator.allocateContainers(0)
+ }
+
+ /*
+ def printContainers(containers: List[Container]) = {
+ for (container <- containers) {
+ logInfo("Launching shell command on a new container."
+ + ", containerId=" + container.getId()
+ + ", containerNode=" + container.getNodeId().getHost()
+ + ":" + container.getNodeId().getPort()
+ + ", containerNodeURI=" + container.getNodeHttpAddress()
+ + ", containerState" + container.getState()
+ + ", containerResourceMemory"
+ + container.getResource().getMemory())
+ }
+ }
+ */
+
+ def finishApplicationMaster() {
+ val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest])
+ .asInstanceOf[FinishApplicationMasterRequest]
+ finishReq.setAppAttemptId(appAttemptId)
+ // TODO: Check if the application has failed or succeeded
+ finishReq.setFinishApplicationStatus(FinalApplicationStatus.SUCCEEDED)
+ resourceManager.finishApplicationMaster(finishReq)
+ }
+
+}
+
+object ApplicationMaster {
+ // number of times to wait for the allocator loop to complete.
+ // each loop iteration waits for 100ms, so maximum of 3 seconds.
+ // This is to ensure that we have reasonable number of containers before we start
+ // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be optimal as more
+ // containers are available. Might need to handle this better.
+ private val ALLOCATOR_LOOP_WAIT_COUNT = 30
+ def incrementAllocatorLoop(by: Int) {
+ val count = yarnAllocatorLoop.getAndAdd(by)
+ if (count >= ALLOCATOR_LOOP_WAIT_COUNT){
+ yarnAllocatorLoop.synchronized {
+ // to wake threads off wait ...
+ yarnAllocatorLoop.notifyAll()
+ }
+ }
+ }
+
+ private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]()
+
+ def register(master: ApplicationMaster) {
+ applicationMasters.add(master)
+ }
+
+ val sparkContextRef: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null)
+ val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0)
+
+ def sparkContextInitialized(sc: SparkContext): Boolean = {
+ var modified = false
+ sparkContextRef.synchronized {
+ modified = sparkContextRef.compareAndSet(null, sc)
+ sparkContextRef.notifyAll()
+ }
+
+ // Add a shutdown hook - as a best case effort in case users do not call sc.stop or do System.exit
+ // Should not really have to do this, but it helps yarn to evict resources earlier.
+ // not to mention, prevent Client declaring failure even though we exit'ed properly.
+ if (modified) {
+ Runtime.getRuntime().addShutdownHook(new Thread with Logging {
+ // This is not just to log, but also to ensure that log system is initialized for this instance when we actually are 'run'
+ logInfo("Adding shutdown hook for context " + sc)
+ override def run() {
+ logInfo("Invoking sc stop from shutdown hook")
+ sc.stop()
+ // best case ...
+ for (master <- applicationMasters) master.finishApplicationMaster
+ }
+ } )
+ }
+
+ // Wait for initialization to complete and atleast 'some' nodes can get allocated
+ yarnAllocatorLoop.synchronized {
+ while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT){
+ yarnAllocatorLoop.wait(1000L)
+ }
+ }
+ modified
+ }
+
+ def main(argStrings: Array[String]) {
+ val args = new ApplicationMasterArguments(argStrings)
+ new ApplicationMaster(args).run()
+ }
+}
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala
new file mode 100644
index 0000000000..dc89125d81
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala
@@ -0,0 +1,78 @@
+package spark.deploy.yarn
+
+import spark.util.IntParam
+import collection.mutable.ArrayBuffer
+
+class ApplicationMasterArguments(val args: Array[String]) {
+ var userJar: String = null
+ var userClass: String = null
+ var userArgs: Seq[String] = Seq[String]()
+ var workerMemory = 1024
+ var workerCores = 1
+ var numWorkers = 2
+
+ parseArgs(args.toList)
+
+ private def parseArgs(inputArgs: List[String]): Unit = {
+ val userArgsBuffer = new ArrayBuffer[String]()
+
+ var args = inputArgs
+
+ while (! args.isEmpty) {
+
+ args match {
+ case ("--jar") :: value :: tail =>
+ userJar = value
+ args = tail
+
+ case ("--class") :: value :: tail =>
+ userClass = value
+ args = tail
+
+ case ("--args") :: value :: tail =>
+ userArgsBuffer += value
+ args = tail
+
+ case ("--num-workers") :: IntParam(value) :: tail =>
+ numWorkers = value
+ args = tail
+
+ case ("--worker-memory") :: IntParam(value) :: tail =>
+ workerMemory = value
+ args = tail
+
+ case ("--worker-cores") :: IntParam(value) :: tail =>
+ workerCores = value
+ args = tail
+
+ case Nil =>
+ if (userJar == null || userClass == null) {
+ printUsageAndExit(1)
+ }
+
+ case _ =>
+ printUsageAndExit(1, args)
+ }
+ }
+
+ userArgs = userArgsBuffer.readOnly
+ }
+
+ def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+ if (unknownParam != null) {
+ System.err.println("Unknown/unsupported param " + unknownParam)
+ }
+ System.err.println(
+ "Usage: spark.deploy.yarn.ApplicationMaster [options] \n" +
+ "Options:\n" +
+ " --jar JAR_PATH Path to your application's JAR file (required)\n" +
+ " --class CLASS_NAME Name of your application's main class (required)\n" +
+ " --args ARGS Arguments to be passed to your application's main class.\n" +
+ " Mutliple invocations are possible, each will be passed in order.\n" +
+ " Note that first argument will ALWAYS be yarn-standalone : will be added if missing.\n" +
+ " --num-workers NUM Number of workers to start (Default: 2)\n" +
+ " --worker-cores NUM Number of cores for the workers (Default: 1)\n" +
+ " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n")
+ System.exit(exitCode)
+ }
+}
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala
new file mode 100644
index 0000000000..7a881e26df
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala
@@ -0,0 +1,272 @@
+package spark.deploy.yarn
+
+import java.net.{InetSocketAddress, URI}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.client.YarnClientImpl
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions._
+import spark.{Logging, Utils}
+import org.apache.hadoop.yarn.util.{Apps, Records, ConverterUtils}
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import spark.deploy.SparkHadoopUtil
+
+class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl with Logging {
+
+ def this(args: ClientArguments) = this(new Configuration(), args)
+
+ var rpc: YarnRPC = YarnRPC.create(conf)
+ val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+
+ def run() {
+ init(yarnConf)
+ start()
+ logClusterResourceDetails()
+
+ val newApp = super.getNewApplication()
+ val appId = newApp.getApplicationId()
+
+ verifyClusterResources(newApp)
+ val appContext = createApplicationSubmissionContext(appId)
+ val localResources = prepareLocalResources(appId, "spark")
+ val env = setupLaunchEnv(localResources)
+ val amContainer = createContainerLaunchContext(newApp, localResources, env)
+
+ appContext.setQueue(args.amQueue)
+ appContext.setAMContainerSpec(amContainer)
+ appContext.setUser(args.amUser)
+
+ submitApp(appContext)
+
+ monitorApplication(appId)
+ System.exit(0)
+ }
+
+
+ def logClusterResourceDetails() {
+ val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics
+ logInfo("Got Cluster metric info from ASM, numNodeManagers=" + clusterMetrics.getNumNodeManagers)
+
+ val queueInfo: QueueInfo = super.getQueueInfo(args.amQueue)
+ logInfo("Queue info .. queueName=" + queueInfo.getQueueName + ", queueCurrentCapacity=" + queueInfo.getCurrentCapacity +
+ ", queueMaxCapacity=" + queueInfo.getMaximumCapacity + ", queueApplicationCount=" + queueInfo.getApplications.size +
+ ", queueChildQueueCount=" + queueInfo.getChildQueues.size)
+ }
+
+
+ def verifyClusterResources(app: GetNewApplicationResponse) = {
+ val maxMem = app.getMaximumResourceCapability().getMemory()
+ logInfo("Max mem capabililty of resources in this cluster " + maxMem)
+
+ // If the cluster does not have enough memory resources, exit.
+ val requestedMem = (args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + args.numWorkers * args.workerMemory
+ if (requestedMem > maxMem) {
+ logError("Cluster cannot satisfy memory resource request of " + requestedMem)
+ System.exit(1)
+ }
+ }
+
+ def createApplicationSubmissionContext(appId: ApplicationId): ApplicationSubmissionContext = {
+ logInfo("Setting up application submission context for ASM")
+ val appContext = Records.newRecord(classOf[ApplicationSubmissionContext])
+ appContext.setApplicationId(appId)
+ appContext.setApplicationName("Spark")
+ return appContext
+ }
+
+ def prepareLocalResources(appId: ApplicationId, appName: String): HashMap[String, LocalResource] = {
+ logInfo("Preparing Local resources")
+ val locaResources = HashMap[String, LocalResource]()
+ // Upload Spark and the application JAR to the remote file system
+ // Add them as local resources to the AM
+ val fs = FileSystem.get(conf)
+ Map("spark.jar" -> System.getenv("SPARK_JAR"), "app.jar" -> args.userJar, "log4j.properties" -> System.getenv("SPARK_LOG4J_CONF"))
+ .foreach { case(destName, _localPath) =>
+ val localPath: String = if (_localPath != null) _localPath.trim() else ""
+ if (! localPath.isEmpty()) {
+ val src = new Path(localPath)
+ val pathSuffix = appName + "/" + appId.getId() + destName
+ val dst = new Path(fs.getHomeDirectory(), pathSuffix)
+ logInfo("Uploading " + src + " to " + dst)
+ fs.copyFromLocalFile(false, true, src, dst)
+ val destStatus = fs.getFileStatus(dst)
+
+ val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ amJarRsrc.setType(LocalResourceType.FILE)
+ amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION)
+ amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(dst))
+ amJarRsrc.setTimestamp(destStatus.getModificationTime())
+ amJarRsrc.setSize(destStatus.getLen())
+ locaResources(destName) = amJarRsrc
+ }
+ }
+ return locaResources
+ }
+
+ def setupLaunchEnv(localResources: HashMap[String, LocalResource]): HashMap[String, String] = {
+ logInfo("Setting up the launch environment")
+ val log4jConfLocalRes = localResources.getOrElse("log4j.properties", null)
+
+ val env = new HashMap[String, String]()
+ Apps.addToEnvironment(env, Environment.USER.name, args.amUser)
+
+ // If log4j present, ensure ours overrides all others
+ if (log4jConfLocalRes != null) Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./")
+
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*")
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH")
+ Client.populateHadoopClasspath(yarnConf, env)
+ SparkHadoopUtil.setYarnMode(env)
+ env("SPARK_YARN_JAR_PATH") =
+ localResources("spark.jar").getResource().getScheme.toString() + "://" +
+ localResources("spark.jar").getResource().getFile().toString()
+ env("SPARK_YARN_JAR_TIMESTAMP") = localResources("spark.jar").getTimestamp().toString()
+ env("SPARK_YARN_JAR_SIZE") = localResources("spark.jar").getSize().toString()
+
+ env("SPARK_YARN_USERJAR_PATH") =
+ localResources("app.jar").getResource().getScheme.toString() + "://" +
+ localResources("app.jar").getResource().getFile().toString()
+ env("SPARK_YARN_USERJAR_TIMESTAMP") = localResources("app.jar").getTimestamp().toString()
+ env("SPARK_YARN_USERJAR_SIZE") = localResources("app.jar").getSize().toString()
+
+ if (log4jConfLocalRes != null) {
+ env("SPARK_YARN_LOG4J_PATH") =
+ log4jConfLocalRes.getResource().getScheme.toString() + "://" + log4jConfLocalRes.getResource().getFile().toString()
+ env("SPARK_YARN_LOG4J_TIMESTAMP") = log4jConfLocalRes.getTimestamp().toString()
+ env("SPARK_YARN_LOG4J_SIZE") = log4jConfLocalRes.getSize().toString()
+ }
+
+ // Add each SPARK-* key to the environment
+ System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v }
+ return env
+ }
+
+ def userArgsToString(clientArgs: ClientArguments): String = {
+ val prefix = " --args "
+ val args = clientArgs.userArgs
+ val retval = new StringBuilder()
+ for (arg <- args){
+ retval.append(prefix).append(" '").append(arg).append("' ")
+ }
+
+ retval.toString
+ }
+
+ def createContainerLaunchContext(newApp: GetNewApplicationResponse,
+ localResources: HashMap[String, LocalResource],
+ env: HashMap[String, String]): ContainerLaunchContext = {
+ logInfo("Setting up container launch context")
+ val amContainer = Records.newRecord(classOf[ContainerLaunchContext])
+ amContainer.setLocalResources(localResources)
+ amContainer.setEnvironment(env)
+
+ val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory()
+
+ var amMemory = ((args.amMemory / minResMemory) * minResMemory) +
+ (if (0 != (args.amMemory % minResMemory)) minResMemory else 0) - YarnAllocationHandler.MEMORY_OVERHEAD
+
+ // Extra options for the JVM
+ var JAVA_OPTS = ""
+
+ // Add Xmx for am memory
+ JAVA_OPTS += "-Xmx" + amMemory + "m "
+
+ // Commenting it out for now - so that people can refer to the properties if required. Remove it once cpuset version is pushed out.
+ // The context is, default gc for server class machines end up using all cores to do gc - hence if there are multiple containers in same
+ // node, spark gc effects all other containers performance (which can also be other spark containers)
+ // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in multi-tenant environments. Not sure how default java gc behaves if it is
+ // limited to subset of cores on a node.
+ if (env.isDefinedAt("SPARK_USE_CONC_INCR_GC") && java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC"))) {
+ // In our expts, using (default) throughput collector has severe perf ramnifications in multi-tenant machines
+ JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
+ JAVA_OPTS += " -XX:+CMSIncrementalMode "
+ JAVA_OPTS += " -XX:+CMSIncrementalPacing "
+ JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
+ JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
+ }
+ if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
+ JAVA_OPTS += env("SPARK_JAVA_OPTS") + " "
+ }
+
+ // Command for the ApplicationMaster
+ val commands = List[String]("java " +
+ " -server " +
+ JAVA_OPTS +
+ " spark.deploy.yarn.ApplicationMaster" +
+ " --class " + args.userClass +
+ " --jar " + args.userJar +
+ userArgsToString(args) +
+ " --worker-memory " + args.workerMemory +
+ " --worker-cores " + args.workerCores +
+ " --num-workers " + args.numWorkers +
+ " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" +
+ " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
+ logInfo("Command for the ApplicationMaster: " + commands(0))
+ amContainer.setCommands(commands)
+
+ val capability = Records.newRecord(classOf[Resource]).asInstanceOf[Resource]
+ // Memory for the ApplicationMaster
+ capability.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+ amContainer.setResource(capability)
+
+ return amContainer
+ }
+
+ def submitApp(appContext: ApplicationSubmissionContext) = {
+ // Submit the application to the applications manager
+ logInfo("Submitting application to ASM")
+ super.submitApplication(appContext)
+ }
+
+ def monitorApplication(appId: ApplicationId): Boolean = {
+ while(true) {
+ Thread.sleep(1000)
+ val report = super.getApplicationReport(appId)
+
+ logInfo("Application report from ASM: \n" +
+ "\t application identifier: " + appId.toString() + "\n" +
+ "\t appId: " + appId.getId() + "\n" +
+ "\t clientToken: " + report.getClientToken() + "\n" +
+ "\t appDiagnostics: " + report.getDiagnostics() + "\n" +
+ "\t appMasterHost: " + report.getHost() + "\n" +
+ "\t appQueue: " + report.getQueue() + "\n" +
+ "\t appMasterRpcPort: " + report.getRpcPort() + "\n" +
+ "\t appStartTime: " + report.getStartTime() + "\n" +
+ "\t yarnAppState: " + report.getYarnApplicationState() + "\n" +
+ "\t distributedFinalState: " + report.getFinalApplicationStatus() + "\n" +
+ "\t appTrackingUrl: " + report.getTrackingUrl() + "\n" +
+ "\t appUser: " + report.getUser()
+ )
+
+ val state = report.getYarnApplicationState()
+ val dsStatus = report.getFinalApplicationStatus()
+ if (state == YarnApplicationState.FINISHED ||
+ state == YarnApplicationState.FAILED ||
+ state == YarnApplicationState.KILLED) {
+ return true
+ }
+ }
+ return true
+ }
+}
+
+object Client {
+ def main(argStrings: Array[String]) {
+ val args = new ClientArguments(argStrings)
+ SparkHadoopUtil.setYarnMode()
+ new Client(args).run
+ }
+
+ // Based on code from org.apache.hadoop.mapreduce.v2.util.MRApps
+ def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) {
+ for (c <- conf.getStrings(YarnConfiguration.YARN_APPLICATION_CLASSPATH)) {
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, c.trim)
+ }
+ }
+}
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala
new file mode 100644
index 0000000000..2e69fe3fb0
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala
@@ -0,0 +1,106 @@
+package spark.deploy.yarn
+
+import spark.util.MemoryParam
+import spark.util.IntParam
+import collection.mutable.{ArrayBuffer, HashMap}
+import spark.scheduler.{InputFormatInfo, SplitInfo}
+
+// TODO: Add code and support for ensuring that yarn resource 'asks' are location aware !
+class ClientArguments(val args: Array[String]) {
+ var userJar: String = null
+ var userClass: String = null
+ var userArgs: Seq[String] = Seq[String]()
+ var workerMemory = 1024
+ var workerCores = 1
+ var numWorkers = 2
+ var amUser = System.getProperty("user.name")
+ var amQueue = System.getProperty("QUEUE", "default")
+ var amMemory: Int = 512
+ // TODO
+ var inputFormatInfo: List[InputFormatInfo] = null
+
+ parseArgs(args.toList)
+
+ private def parseArgs(inputArgs: List[String]): Unit = {
+ val userArgsBuffer: ArrayBuffer[String] = new ArrayBuffer[String]()
+ val inputFormatMap: HashMap[String, InputFormatInfo] = new HashMap[String, InputFormatInfo]()
+
+ var args = inputArgs
+
+ while (! args.isEmpty) {
+
+ args match {
+ case ("--jar") :: value :: tail =>
+ userJar = value
+ args = tail
+
+ case ("--class") :: value :: tail =>
+ userClass = value
+ args = tail
+
+ case ("--args") :: value :: tail =>
+ userArgsBuffer += value
+ args = tail
+
+ case ("--master-memory") :: MemoryParam(value) :: tail =>
+ amMemory = value
+ args = tail
+
+ case ("--num-workers") :: IntParam(value) :: tail =>
+ numWorkers = value
+ args = tail
+
+ case ("--worker-memory") :: MemoryParam(value) :: tail =>
+ workerMemory = value
+ args = tail
+
+ case ("--worker-cores") :: IntParam(value) :: tail =>
+ workerCores = value
+ args = tail
+
+ case ("--user") :: value :: tail =>
+ amUser = value
+ args = tail
+
+ case ("--queue") :: value :: tail =>
+ amQueue = value
+ args = tail
+
+ case Nil =>
+ if (userJar == null || userClass == null) {
+ printUsageAndExit(1)
+ }
+
+ case _ =>
+ printUsageAndExit(1, args)
+ }
+ }
+
+ userArgs = userArgsBuffer.readOnly
+ inputFormatInfo = inputFormatMap.values.toList
+ }
+
+
+ def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+ if (unknownParam != null) {
+ System.err.println("Unknown/unsupported param " + unknownParam)
+ }
+ System.err.println(
+ "Usage: spark.deploy.yarn.Client [options] \n" +
+ "Options:\n" +
+ " --jar JAR_PATH Path to your application's JAR file (required)\n" +
+ " --class CLASS_NAME Name of your application's main class (required)\n" +
+ " --args ARGS Arguments to be passed to your application's main class.\n" +
+ " Mutliple invocations are possible, each will be passed in order.\n" +
+ " Note that first argument will ALWAYS be yarn-standalone : will be added if missing.\n" +
+ " --num-workers NUM Number of workers to start (Default: 2)\n" +
+ " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" +
+ " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" +
+ " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
+ " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" +
+ " --user USERNAME Run the ApplicationMaster (and slaves) as a different user\n"
+ )
+ System.exit(exitCode)
+ }
+
+}
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala
new file mode 100644
index 0000000000..a2bf0af762
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala
@@ -0,0 +1,171 @@
+package spark.deploy.yarn
+
+import java.net.URI
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records}
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.HashMap
+
+import spark.{Logging, Utils}
+
+class WorkerRunnable(container: Container, conf: Configuration, masterAddress: String,
+ slaveId: String, hostname: String, workerMemory: Int, workerCores: Int)
+ extends Runnable with Logging {
+
+ var rpc: YarnRPC = YarnRPC.create(conf)
+ var cm: ContainerManager = null
+ val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+
+ def run = {
+ logInfo("Starting Worker Container")
+ cm = connectToCM
+ startContainer
+ }
+
+ def startContainer = {
+ logInfo("Setting up ContainerLaunchContext")
+
+ val ctx = Records.newRecord(classOf[ContainerLaunchContext])
+ .asInstanceOf[ContainerLaunchContext]
+
+ ctx.setContainerId(container.getId())
+ ctx.setResource(container.getResource())
+ val localResources = prepareLocalResources
+ ctx.setLocalResources(localResources)
+
+ val env = prepareEnvironment
+ ctx.setEnvironment(env)
+
+ // Extra options for the JVM
+ var JAVA_OPTS = ""
+ // Set the JVM memory
+ val workerMemoryString = workerMemory + "m"
+ JAVA_OPTS += "-Xms" + workerMemoryString + " -Xmx" + workerMemoryString + " "
+ if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
+ JAVA_OPTS += env("SPARK_JAVA_OPTS") + " "
+ }
+ // Commenting it out for now - so that people can refer to the properties if required. Remove it once cpuset version is pushed out.
+ // The context is, default gc for server class machines end up using all cores to do gc - hence if there are multiple containers in same
+ // node, spark gc effects all other containers performance (which can also be other spark containers)
+ // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in multi-tenant environments. Not sure how default java gc behaves if it is
+ // limited to subset of cores on a node.
+/*
+ else {
+ // If no java_opts specified, default to using -XX:+CMSIncrementalMode
+ // It might be possible that other modes/config is being done in SPARK_JAVA_OPTS, so we dont want to mess with it.
+ // In our expts, using (default) throughput collector has severe perf ramnifications in multi-tennent machines
+ // The options are based on
+ // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use%20the%20Concurrent%20Low%20Pause%20Collector|outline
+ JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
+ JAVA_OPTS += " -XX:+CMSIncrementalMode "
+ JAVA_OPTS += " -XX:+CMSIncrementalPacing "
+ JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
+ JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
+ }
+*/
+
+ ctx.setUser(UserGroupInformation.getCurrentUser().getShortUserName())
+ val commands = List[String]("java " +
+ " -server " +
+ // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling.
+ // Not killing the task leaves various aspects of the worker and (to some extent) the jvm in an inconsistent state.
+ // TODO: If the OOM is not recoverable by rescheduling it on different node, then do 'something' to fail job ... akin to blacklisting trackers in mapred ?
+ " -XX:OnOutOfMemoryError='kill %p' " +
+ JAVA_OPTS +
+ " spark.executor.StandaloneExecutorBackend " +
+ masterAddress + " " +
+ slaveId + " " +
+ hostname + " " +
+ workerCores +
+ " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" +
+ " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
+ logInfo("Setting up worker with commands: " + commands)
+ ctx.setCommands(commands)
+
+ // Send the start request to the ContainerManager
+ val startReq = Records.newRecord(classOf[StartContainerRequest])
+ .asInstanceOf[StartContainerRequest]
+ startReq.setContainerLaunchContext(ctx)
+ cm.startContainer(startReq)
+ }
+
+
+ def prepareLocalResources: HashMap[String, LocalResource] = {
+ logInfo("Preparing Local resources")
+ val locaResources = HashMap[String, LocalResource]()
+
+ // Spark JAR
+ val sparkJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ sparkJarResource.setType(LocalResourceType.FILE)
+ sparkJarResource.setVisibility(LocalResourceVisibility.APPLICATION)
+ sparkJarResource.setResource(ConverterUtils.getYarnUrlFromURI(
+ new URI(System.getenv("SPARK_YARN_JAR_PATH"))))
+ sparkJarResource.setTimestamp(System.getenv("SPARK_YARN_JAR_TIMESTAMP").toLong)
+ sparkJarResource.setSize(System.getenv("SPARK_YARN_JAR_SIZE").toLong)
+ locaResources("spark.jar") = sparkJarResource
+ // User JAR
+ val userJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ userJarResource.setType(LocalResourceType.FILE)
+ userJarResource.setVisibility(LocalResourceVisibility.APPLICATION)
+ userJarResource.setResource(ConverterUtils.getYarnUrlFromURI(
+ new URI(System.getenv("SPARK_YARN_USERJAR_PATH"))))
+ userJarResource.setTimestamp(System.getenv("SPARK_YARN_USERJAR_TIMESTAMP").toLong)
+ userJarResource.setSize(System.getenv("SPARK_YARN_USERJAR_SIZE").toLong)
+ locaResources("app.jar") = userJarResource
+
+ // Log4j conf - if available
+ if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
+ val log4jConfResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ log4jConfResource.setType(LocalResourceType.FILE)
+ log4jConfResource.setVisibility(LocalResourceVisibility.APPLICATION)
+ log4jConfResource.setResource(ConverterUtils.getYarnUrlFromURI(
+ new URI(System.getenv("SPARK_YARN_LOG4J_PATH"))))
+ log4jConfResource.setTimestamp(System.getenv("SPARK_YARN_LOG4J_TIMESTAMP").toLong)
+ log4jConfResource.setSize(System.getenv("SPARK_YARN_LOG4J_SIZE").toLong)
+ locaResources("log4j.properties") = log4jConfResource
+ }
+
+
+ logInfo("Prepared Local resources " + locaResources)
+ return locaResources
+ }
+
+ def prepareEnvironment: HashMap[String, String] = {
+ val env = new HashMap[String, String]()
+ // should we add this ?
+ Apps.addToEnvironment(env, Environment.USER.name, Utils.getUserNameFromEnvironment())
+
+ // If log4j present, ensure ours overrides all others
+ if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
+ // Which is correct ?
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./log4j.properties")
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./")
+ }
+
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*")
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH")
+ Client.populateHadoopClasspath(yarnConf, env)
+
+ System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v }
+ return env
+ }
+
+ def connectToCM: ContainerManager = {
+ val cmHostPortStr = container.getNodeId().getHost() + ":" + container.getNodeId().getPort()
+ val cmAddress = NetUtils.createSocketAddr(cmHostPortStr)
+ logInfo("Connecting to ContainerManager at " + cmHostPortStr)
+ return rpc.getProxy(classOf[ContainerManager], cmAddress, conf).asInstanceOf[ContainerManager]
+ }
+
+}
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala
new file mode 100644
index 0000000000..61dd72a651
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala
@@ -0,0 +1,547 @@
+package spark.deploy.yarn
+
+import spark.{Logging, Utils}
+import spark.scheduler.SplitInfo
+import scala.collection
+import org.apache.hadoop.yarn.api.records.{AMResponse, ApplicationAttemptId, ContainerId, Priority, Resource, ResourceRequest, ContainerStatus, Container}
+import spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend}
+import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse}
+import org.apache.hadoop.yarn.util.{RackResolver, Records}
+import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap}
+import java.util.concurrent.atomic.AtomicInteger
+import org.apache.hadoop.yarn.api.AMRMProtocol
+import collection.JavaConversions._
+import collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import org.apache.hadoop.conf.Configuration
+import java.util.{Collections, Set => JSet}
+import java.lang.{Boolean => JBoolean}
+
+object AllocationType extends Enumeration ("HOST", "RACK", "ANY") {
+ type AllocationType = Value
+ val HOST, RACK, ANY = Value
+}
+
+// too many params ? refactor it 'somehow' ?
+// needs to be mt-safe
+// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive : should make it
+// more proactive and decoupled.
+// Note that right now, we assume all node asks as uniform in terms of capabilities and priority
+// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for more info
+// on how we are requesting for containers.
+private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceManager: AMRMProtocol,
+ val appAttemptId: ApplicationAttemptId,
+ val maxWorkers: Int, val workerMemory: Int, val workerCores: Int,
+ val preferredHostToCount: Map[String, Int],
+ val preferredRackToCount: Map[String, Int])
+ extends Logging {
+
+
+ // These three are locked on allocatedHostToContainersMap. Complementary data structures
+ // allocatedHostToContainersMap : containers which are running : host, Set<containerid>
+ // allocatedContainerToHostMap: container to host mapping
+ private val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]]()
+ private val allocatedContainerToHostMap = new HashMap[ContainerId, String]()
+ // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an allocated node)
+ // As with the two data structures above, tightly coupled with them, and to be locked on allocatedHostToContainersMap
+ private val allocatedRackCount = new HashMap[String, Int]()
+
+ // containers which have been released.
+ private val releasedContainerList = new CopyOnWriteArrayList[ContainerId]()
+ // containers to be released in next request to RM
+ private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean]
+
+ private val numWorkersRunning = new AtomicInteger()
+ // Used to generate a unique id per worker
+ private val workerIdCounter = new AtomicInteger()
+ private val lastResponseId = new AtomicInteger()
+
+ def getNumWorkersRunning: Int = numWorkersRunning.intValue
+
+
+ def isResourceConstraintSatisfied(container: Container): Boolean = {
+ container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+ }
+
+ def allocateContainers(workersToRequest: Int) {
+ // We need to send the request only once from what I understand ... but for now, not modifying this much.
+
+ // Keep polling the Resource Manager for containers
+ val amResp = allocateWorkerResources(workersToRequest).getAMResponse
+
+ val _allocatedContainers = amResp.getAllocatedContainers()
+ if (_allocatedContainers.size > 0) {
+
+
+ logDebug("Allocated " + _allocatedContainers.size + " containers, current count " +
+ numWorkersRunning.get() + ", to-be-released " + releasedContainerList +
+ ", pendingReleaseContainers : " + pendingReleaseContainers)
+ logDebug("Cluster Resources: " + amResp.getAvailableResources)
+
+ val hostToContainers = new HashMap[String, ArrayBuffer[Container]]()
+
+ // ignore if not satisfying constraints {
+ for (container <- _allocatedContainers) {
+ if (isResourceConstraintSatisfied(container)) {
+ // allocatedContainers += container
+
+ val host = container.getNodeId.getHost
+ val containers = hostToContainers.getOrElseUpdate(host, new ArrayBuffer[Container]())
+
+ containers += container
+ }
+ // Add all ignored containers to released list
+ else releasedContainerList.add(container.getId())
+ }
+
+ // Find the appropriate containers to use
+ // Slightly non trivial groupBy I guess ...
+ val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
+ val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
+ val offRackContainers = new HashMap[String, ArrayBuffer[Container]]()
+
+ for (candidateHost <- hostToContainers.keySet)
+ {
+ val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0)
+ val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost)
+
+ var remainingContainers = hostToContainers.get(candidateHost).getOrElse(null)
+ assert(remainingContainers != null)
+
+ if (requiredHostCount >= remainingContainers.size){
+ // Since we got <= required containers, add all to dataLocalContainers
+ dataLocalContainers.put(candidateHost, remainingContainers)
+ // all consumed
+ remainingContainers = null
+ }
+ else if (requiredHostCount > 0) {
+ // container list has more containers than we need for data locality.
+ // Split into two : data local container count of (remainingContainers.size - requiredHostCount)
+ // and rest as remainingContainer
+ val (dataLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredHostCount)
+ dataLocalContainers.put(candidateHost, dataLocal)
+ // remainingContainers = remaining
+
+ // yarn has nasty habit of allocating a tonne of containers on a host - discourage this :
+ // add remaining to release list. If we have insufficient containers, next allocation cycle
+ // will reallocate (but wont treat it as data local)
+ for (container <- remaining) releasedContainerList.add(container.getId())
+ remainingContainers = null
+ }
+
+ // now rack local
+ if (remainingContainers != null){
+ val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)
+
+ if (rack != null){
+ val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0)
+ val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) -
+ rackLocalContainers.get(rack).getOrElse(List()).size
+
+
+ if (requiredRackCount >= remainingContainers.size){
+ // Add all to dataLocalContainers
+ dataLocalContainers.put(rack, remainingContainers)
+ // all consumed
+ remainingContainers = null
+ }
+ else if (requiredRackCount > 0) {
+ // container list has more containers than we need for data locality.
+ // Split into two : data local container count of (remainingContainers.size - requiredRackCount)
+ // and rest as remainingContainer
+ val (rackLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredRackCount)
+ val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, new ArrayBuffer[Container]())
+
+ existingRackLocal ++= rackLocal
+ remainingContainers = remaining
+ }
+ }
+ }
+
+ // If still not consumed, then it is off rack host - add to that list.
+ if (remainingContainers != null){
+ offRackContainers.put(candidateHost, remainingContainers)
+ }
+ }
+
+ // Now that we have split the containers into various groups, go through them in order :
+ // first host local, then rack local and then off rack (everything else).
+ // Note that the list we create below tries to ensure that not all containers end up within a host
+ // if there are sufficiently large number of hosts/containers.
+
+ val allocatedContainers = new ArrayBuffer[Container](_allocatedContainers.size)
+ allocatedContainers ++= ClusterScheduler.prioritizeContainers(dataLocalContainers)
+ allocatedContainers ++= ClusterScheduler.prioritizeContainers(rackLocalContainers)
+ allocatedContainers ++= ClusterScheduler.prioritizeContainers(offRackContainers)
+
+ // Run each of the allocated containers
+ for (container <- allocatedContainers) {
+ val numWorkersRunningNow = numWorkersRunning.incrementAndGet()
+ val workerHostname = container.getNodeId.getHost
+ val containerId = container.getId
+
+ assert (container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD))
+
+ if (numWorkersRunningNow > maxWorkers) {
+ logInfo("Ignoring container " + containerId + " at host " + workerHostname +
+ " .. we already have required number of containers")
+ releasedContainerList.add(containerId)
+ // reset counter back to old value.
+ numWorkersRunning.decrementAndGet()
+ }
+ else {
+ // deallocate + allocate can result in reusing id's wrongly - so use a different counter (workerIdCounter)
+ val workerId = workerIdCounter.incrementAndGet().toString
+ val driverUrl = "akka://spark@%s:%s/user/%s".format(
+ System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
+ StandaloneSchedulerBackend.ACTOR_NAME)
+
+ logInfo("launching container on " + containerId + " host " + workerHostname)
+ // just to be safe, simply remove it from pendingReleaseContainers. Should not be there, but ..
+ pendingReleaseContainers.remove(containerId)
+
+ val rack = YarnAllocationHandler.lookupRack(conf, workerHostname)
+ allocatedHostToContainersMap.synchronized {
+ val containerSet = allocatedHostToContainersMap.getOrElseUpdate(workerHostname, new HashSet[ContainerId]())
+
+ containerSet += containerId
+ allocatedContainerToHostMap.put(containerId, workerHostname)
+ if (rack != null) allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1)
+ }
+
+ new Thread(
+ new WorkerRunnable(container, conf, driverUrl, workerId,
+ workerHostname, workerMemory, workerCores)
+ ).start()
+ }
+ }
+ logDebug("After allocated " + allocatedContainers.size + " containers (orig : " +
+ _allocatedContainers.size + "), current count " + numWorkersRunning.get() +
+ ", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers)
+ }
+
+
+ val completedContainers = amResp.getCompletedContainersStatuses()
+ if (completedContainers.size > 0){
+ logDebug("Completed " + completedContainers.size + " containers, current count " + numWorkersRunning.get() +
+ ", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers)
+
+ for (completedContainer <- completedContainers){
+ val containerId = completedContainer.getContainerId
+
+ // Was this released by us ? If yes, then simply remove from containerSet and move on.
+ if (pendingReleaseContainers.containsKey(containerId)) {
+ pendingReleaseContainers.remove(containerId)
+ }
+ else {
+ // simply decrement count - next iteration of ReporterThread will take care of allocating !
+ numWorkersRunning.decrementAndGet()
+ logInfo("Container completed ? nodeId: " + containerId + ", state " + completedContainer.getState +
+ " httpaddress: " + completedContainer.getDiagnostics)
+ }
+
+ allocatedHostToContainersMap.synchronized {
+ if (allocatedContainerToHostMap.containsKey(containerId)) {
+ val host = allocatedContainerToHostMap.get(containerId).getOrElse(null)
+ assert (host != null)
+
+ val containerSet = allocatedHostToContainersMap.get(host).getOrElse(null)
+ assert (containerSet != null)
+
+ containerSet -= containerId
+ if (containerSet.isEmpty) allocatedHostToContainersMap.remove(host)
+ else allocatedHostToContainersMap.update(host, containerSet)
+
+ allocatedContainerToHostMap -= containerId
+
+ // doing this within locked context, sigh ... move to outside ?
+ val rack = YarnAllocationHandler.lookupRack(conf, host)
+ if (rack != null) {
+ val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1
+ if (rackCount > 0) allocatedRackCount.put(rack, rackCount)
+ else allocatedRackCount.remove(rack)
+ }
+ }
+ }
+ }
+ logDebug("After completed " + completedContainers.size + " containers, current count " +
+ numWorkersRunning.get() + ", to-be-released " + releasedContainerList +
+ ", pendingReleaseContainers : " + pendingReleaseContainers)
+ }
+ }
+
+ def createRackResourceRequests(hostContainers: List[ResourceRequest]): List[ResourceRequest] = {
+ // First generate modified racks and new set of hosts under it : then issue requests
+ val rackToCounts = new HashMap[String, Int]()
+
+ // Within this lock - used to read/write to the rack related maps too.
+ for (container <- hostContainers) {
+ val candidateHost = container.getHostName
+ val candidateNumContainers = container.getNumContainers
+ assert(YarnAllocationHandler.ANY_HOST != candidateHost)
+
+ val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)
+ if (rack != null) {
+ var count = rackToCounts.getOrElse(rack, 0)
+ count += candidateNumContainers
+ rackToCounts.put(rack, count)
+ }
+ }
+
+ val requestedContainers: ArrayBuffer[ResourceRequest] =
+ new ArrayBuffer[ResourceRequest](rackToCounts.size)
+ for ((rack, count) <- rackToCounts){
+ requestedContainers +=
+ createResourceRequest(AllocationType.RACK, rack, count, YarnAllocationHandler.PRIORITY)
+ }
+
+ requestedContainers.toList
+ }
+
+ def allocatedContainersOnHost(host: String): Int = {
+ var retval = 0
+ allocatedHostToContainersMap.synchronized {
+ retval = allocatedHostToContainersMap.getOrElse(host, Set()).size
+ }
+ retval
+ }
+
+ def allocatedContainersOnRack(rack: String): Int = {
+ var retval = 0
+ allocatedHostToContainersMap.synchronized {
+ retval = allocatedRackCount.getOrElse(rack, 0)
+ }
+ retval
+ }
+
+ private def allocateWorkerResources(numWorkers: Int): AllocateResponse = {
+
+ var resourceRequests: List[ResourceRequest] = null
+
+ // default.
+ if (numWorkers <= 0 || preferredHostToCount.isEmpty) {
+ logDebug("numWorkers: " + numWorkers + ", host preferences ? " + preferredHostToCount.isEmpty)
+ resourceRequests = List(
+ createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY))
+ }
+ else {
+ // request for all hosts in preferred nodes and for numWorkers -
+ // candidates.size, request by default allocation policy.
+ val hostContainerRequests: ArrayBuffer[ResourceRequest] =
+ new ArrayBuffer[ResourceRequest](preferredHostToCount.size)
+ for ((candidateHost, candidateCount) <- preferredHostToCount) {
+ val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost)
+
+ if (requiredCount > 0) {
+ hostContainerRequests +=
+ createResourceRequest(AllocationType.HOST, candidateHost, requiredCount, YarnAllocationHandler.PRIORITY)
+ }
+ }
+ val rackContainerRequests: List[ResourceRequest] = createRackResourceRequests(hostContainerRequests.toList)
+
+ val anyContainerRequests: ResourceRequest =
+ createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY)
+
+ val containerRequests: ArrayBuffer[ResourceRequest] =
+ new ArrayBuffer[ResourceRequest](hostContainerRequests.size() + rackContainerRequests.size() + 1)
+
+ containerRequests ++= hostContainerRequests
+ containerRequests ++= rackContainerRequests
+ containerRequests += anyContainerRequests
+
+ resourceRequests = containerRequests.toList
+ }
+
+ val req = Records.newRecord(classOf[AllocateRequest])
+ req.setResponseId(lastResponseId.incrementAndGet)
+ req.setApplicationAttemptId(appAttemptId)
+
+ req.addAllAsks(resourceRequests)
+
+ val releasedContainerList = createReleasedContainerList()
+ req.addAllReleases(releasedContainerList)
+
+
+
+ if (numWorkers > 0) {
+ logInfo("Allocating " + numWorkers + " worker containers with " + (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + " of memory each.")
+ }
+ else {
+ logDebug("Empty allocation req .. release : " + releasedContainerList)
+ }
+
+ for (req <- resourceRequests) {
+ logInfo("rsrcRequest ... host : " + req.getHostName + ", numContainers : " + req.getNumContainers +
+ ", p = " + req.getPriority().getPriority + ", capability: " + req.getCapability)
+ }
+ resourceManager.allocate(req)
+ }
+
+
+ private def createResourceRequest(requestType: AllocationType.AllocationType,
+ resource:String, numWorkers: Int, priority: Int): ResourceRequest = {
+
+ // If hostname specified, we need atleast two requests - node local and rack local.
+ // There must be a third request - which is ANY : that will be specially handled.
+ requestType match {
+ case AllocationType.HOST => {
+ assert (YarnAllocationHandler.ANY_HOST != resource)
+
+ val hostname = resource
+ val nodeLocal = createResourceRequestImpl(hostname, numWorkers, priority)
+
+ // add to host->rack mapping
+ YarnAllocationHandler.populateRackInfo(conf, hostname)
+
+ nodeLocal
+ }
+
+ case AllocationType.RACK => {
+ val rack = resource
+ createResourceRequestImpl(rack, numWorkers, priority)
+ }
+
+ case AllocationType.ANY => {
+ createResourceRequestImpl(YarnAllocationHandler.ANY_HOST, numWorkers, priority)
+ }
+
+ case _ => throw new IllegalArgumentException("Unexpected/unsupported request type .. " + requestType)
+ }
+ }
+
+ private def createResourceRequestImpl(hostname:String, numWorkers: Int, priority: Int): ResourceRequest = {
+
+ val rsrcRequest = Records.newRecord(classOf[ResourceRequest])
+ val memCapability = Records.newRecord(classOf[Resource])
+ // There probably is some overhead here, let's reserve a bit more memory.
+ memCapability.setMemory(workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+ rsrcRequest.setCapability(memCapability)
+
+ val pri = Records.newRecord(classOf[Priority])
+ pri.setPriority(priority)
+ rsrcRequest.setPriority(pri)
+
+ rsrcRequest.setHostName(hostname)
+
+ rsrcRequest.setNumContainers(java.lang.Math.max(numWorkers, 0))
+ rsrcRequest
+ }
+
+ def createReleasedContainerList(): ArrayBuffer[ContainerId] = {
+
+ val retval = new ArrayBuffer[ContainerId](1)
+ // iterator on COW list ...
+ for (container <- releasedContainerList.iterator()){
+ retval += container
+ }
+ // remove from the original list.
+ if (! retval.isEmpty) {
+ releasedContainerList.removeAll(retval)
+ for (v <- retval) pendingReleaseContainers.put(v, true)
+ logInfo("Releasing " + retval.size + " containers. pendingReleaseContainers : " +
+ pendingReleaseContainers)
+ }
+
+ retval
+ }
+}
+
+object YarnAllocationHandler {
+
+ val ANY_HOST = "*"
+ // all requests are issued with same priority : we do not (yet) have any distinction between
+ // request types (like map/reduce in hadoop for example)
+ val PRIORITY = 1
+
+ // Additional memory overhead - in mb
+ val MEMORY_OVERHEAD = 384
+
+ // host to rack map - saved from allocation requests
+ // We are expecting this not to change.
+ // Note that it is possible for this to change : and RM will indicate that to us via update
+ // response to allocate. But we are punting on handling that for now.
+ private val hostToRack = new ConcurrentHashMap[String, String]()
+ private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]()
+
+ def newAllocator(conf: Configuration,
+ resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId,
+ args: ApplicationMasterArguments,
+ map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = {
+
+ val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)
+
+
+ new YarnAllocationHandler(conf, resourceManager, appAttemptId, args.numWorkers,
+ args.workerMemory, args.workerCores, hostToCount, rackToCount)
+ }
+
+ def newAllocator(conf: Configuration,
+ resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId,
+ maxWorkers: Int, workerMemory: Int, workerCores: Int,
+ map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = {
+
+ val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)
+
+ new YarnAllocationHandler(conf, resourceManager, appAttemptId, maxWorkers,
+ workerMemory, workerCores, hostToCount, rackToCount)
+ }
+
+ // A simple method to copy the split info map.
+ private def generateNodeToWeight(conf: Configuration, input: collection.Map[String, collection.Set[SplitInfo]]) :
+ // host to count, rack to count
+ (Map[String, Int], Map[String, Int]) = {
+
+ if (input == null) return (Map[String, Int](), Map[String, Int]())
+
+ val hostToCount = new HashMap[String, Int]
+ val rackToCount = new HashMap[String, Int]
+
+ for ((host, splits) <- input) {
+ val hostCount = hostToCount.getOrElse(host, 0)
+ hostToCount.put(host, hostCount + splits.size)
+
+ val rack = lookupRack(conf, host)
+ if (rack != null){
+ val rackCount = rackToCount.getOrElse(host, 0)
+ rackToCount.put(host, rackCount + splits.size)
+ }
+ }
+
+ (hostToCount.toMap, rackToCount.toMap)
+ }
+
+ def lookupRack(conf: Configuration, host: String): String = {
+ if (! hostToRack.contains(host)) populateRackInfo(conf, host)
+ hostToRack.get(host)
+ }
+
+ def fetchCachedHostsForRack(rack: String): Option[Set[String]] = {
+ val set = rackToHostSet.get(rack)
+ if (set == null) return None
+
+ // No better way to get a Set[String] from JSet ?
+ val convertedSet: collection.mutable.Set[String] = set
+ Some(convertedSet.toSet)
+ }
+
+ def populateRackInfo(conf: Configuration, hostname: String) {
+ Utils.checkHost(hostname)
+
+ if (!hostToRack.containsKey(hostname)) {
+ // If there are repeated failures to resolve, all to an ignore list ?
+ val rackInfo = RackResolver.resolve(conf, hostname)
+ if (rackInfo != null && rackInfo.getNetworkLocation != null) {
+ val rack = rackInfo.getNetworkLocation
+ hostToRack.put(hostname, rack)
+ if (! rackToHostSet.containsKey(rack)) {
+ rackToHostSet.putIfAbsent(rack, Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]()))
+ }
+ rackToHostSet.get(rack).add(hostname)
+
+ // Since RackResolver caches, we are disabling this for now ...
+ } /* else {
+ // right ? Else we will keep calling rack resolver in case we cant resolve rack info ...
+ hostToRack.put(hostname, null)
+ } */
+ }
+ }
+}
diff --git a/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala b/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala
new file mode 100644
index 0000000000..ed732d36bf
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -0,0 +1,42 @@
+package spark.scheduler.cluster
+
+import spark._
+import spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler}
+import org.apache.hadoop.conf.Configuration
+
+/**
+ *
+ * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of ApplicationMaster, etc is done
+ */
+private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
+
+ def this(sc: SparkContext) = this(sc, new Configuration())
+
+ // Nothing else for now ... initialize application master : which needs sparkContext to determine how to allocate
+ // Note that only the first creation of SparkContext influences (and ideally, there must be only one SparkContext, right ?)
+ // Subsequent creations are ignored - since nodes are already allocated by then.
+
+
+ // By default, rack is unknown
+ override def getRackForHost(hostPort: String): Option[String] = {
+ val host = Utils.parseHostPort(hostPort)._1
+ val retval = YarnAllocationHandler.lookupRack(conf, host)
+ if (retval != null) Some(retval) else None
+ }
+
+ // By default, if rack is unknown, return nothing
+ override def getCachedHostsForRack(rack: String): Option[Set[String]] = {
+ if (rack == None || rack == null) return None
+
+ YarnAllocationHandler.fetchCachedHostsForRack(rack)
+ }
+
+ override def postStartHook() {
+ val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc)
+ if (sparkContextInitialized){
+ // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt
+ Thread.sleep(3000L)
+ }
+ logInfo("YarnClusterScheduler.postStartHook done")
+ }
+}
diff --git a/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
index 35300cea58..a0652d7fc7 100644
--- a/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
+++ b/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
@@ -4,4 +4,7 @@ trait HadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)
def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+
+ def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier,
+ jobId, isMap, taskId, attemptId)
}
diff --git a/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
index 7afdbff320..7fdbe322fd 100644
--- a/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
+++ b/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
@@ -7,4 +7,7 @@ trait HadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)
def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+
+ def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier,
+ jobId, isMap, taskId, attemptId)
}
diff --git a/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala
new file mode 100644
index 0000000000..a0fb4fe25d
--- /dev/null
+++ b/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala
@@ -0,0 +1,23 @@
+package spark.deploy
+import org.apache.hadoop.conf.Configuration
+
+
+/**
+ * Contains util methods to interact with Hadoop from spark.
+ */
+object SparkHadoopUtil {
+
+ def getUserNameFromEnvironment(): String = {
+ // defaulting to -D ...
+ System.getProperty("user.name")
+ }
+
+ def runAsUser(func: (Product) => Unit, args: Product) {
+
+ // Add support, if exists - for now, simply run func !
+ func(args)
+ }
+
+ // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop subsystems
+ def newConfiguration(): Configuration = new Configuration()
+}
diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
index c27ed36406..e1fb02157a 100644
--- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
@@ -1,14 +1,19 @@
package spark
-import executor.{ShuffleReadMetrics, TaskMetrics}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
-import spark.storage.{DelegateBlockFetchTracker, BlockManagerId}
-import util.{CompletionIterator, TimedIterator}
+import spark.executor.{ShuffleReadMetrics, TaskMetrics}
+import spark.serializer.Serializer
+import spark.storage.BlockManagerId
+import spark.util.CompletionIterator
+
private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
- override def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) = {
+
+ override def fetch[K, V](
+ shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer) = {
+
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager
@@ -48,18 +53,17 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
}
}
- val blockFetcherItr = blockManager.getMultiple(blocksByAddress)
- val itr = new TimedIterator(blockFetcherItr.flatMap(unpackBlock)) with DelegateBlockFetchTracker
- itr.setDelegate(blockFetcherItr)
+ val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
+ val itr = blockFetcherItr.flatMap(unpackBlock)
+
CompletionIterator[(K,V), Iterator[(K,V)]](itr, {
val shuffleMetrics = new ShuffleReadMetrics
- shuffleMetrics.shuffleReadMillis = itr.getNetMillis
- shuffleMetrics.remoteFetchTime = itr.remoteFetchTime
- shuffleMetrics.fetchWaitTime = itr.fetchWaitTime
- shuffleMetrics.remoteBytesRead = itr.remoteBytesRead
- shuffleMetrics.totalBlocksFetched = itr.totalBlocks
- shuffleMetrics.localBlocksFetched = itr.numLocalBlocks
- shuffleMetrics.remoteBlocksFetched = itr.numRemoteBlocks
+ shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
+ shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
+ shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
+ shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
+ shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
+ shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
metrics.shuffleReadMetrics = Some(shuffleMetrics)
})
}
diff --git a/core/src/main/scala/spark/ClosureCleaner.scala b/core/src/main/scala/spark/ClosureCleaner.scala
index 98525b99c8..50d6a1c5c9 100644
--- a/core/src/main/scala/spark/ClosureCleaner.scala
+++ b/core/src/main/scala/spark/ClosureCleaner.scala
@@ -8,12 +8,20 @@ import scala.collection.mutable.Set
import org.objectweb.asm.{ClassReader, MethodVisitor, Type}
import org.objectweb.asm.commons.EmptyVisitor
import org.objectweb.asm.Opcodes._
+import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream}
private[spark] object ClosureCleaner extends Logging {
// Get an ASM class reader for a given class from the JAR that loaded it
private def getClassReader(cls: Class[_]): ClassReader = {
- new ClassReader(cls.getResourceAsStream(
- cls.getName.replaceFirst("^.*\\.", "") + ".class"))
+ // Copy data over, before delegating to ClassReader - else we can run out of open file handles.
+ val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
+ val resourceStream = cls.getResourceAsStream(className)
+ // todo: Fixme - continuing with earlier behavior ...
+ if (resourceStream == null) return new ClassReader(resourceStream)
+
+ val baos = new ByteArrayOutputStream(128)
+ Utils.copyStream(resourceStream, baos, true)
+ new ClassReader(new ByteArrayInputStream(baos.toByteArray))
}
// Check whether a class represents a Scala closure
diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala
index 5eea907322..2af44aa383 100644
--- a/core/src/main/scala/spark/Dependency.scala
+++ b/core/src/main/scala/spark/Dependency.scala
@@ -25,10 +25,12 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
* @param shuffleId the shuffle id
* @param rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
+ * @param serializerClass class name of the serializer to use
*/
class ShuffleDependency[K, V](
@transient rdd: RDD[(K, V)],
- val partitioner: Partitioner)
+ val partitioner: Partitioner,
+ val serializerClass: String = null)
extends Dependency(rdd) {
val shuffleId: Int = rdd.context.newShuffleId()
diff --git a/core/src/main/scala/spark/FetchFailedException.scala b/core/src/main/scala/spark/FetchFailedException.scala
index a953081d24..40b0193f19 100644
--- a/core/src/main/scala/spark/FetchFailedException.scala
+++ b/core/src/main/scala/spark/FetchFailedException.scala
@@ -3,18 +3,25 @@ package spark
import spark.storage.BlockManagerId
private[spark] class FetchFailedException(
- val bmAddress: BlockManagerId,
- val shuffleId: Int,
- val mapId: Int,
- val reduceId: Int,
+ taskEndReason: TaskEndReason,
+ message: String,
cause: Throwable)
extends Exception {
-
- override def getMessage(): String =
- "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId)
+
+ def this (bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int, cause: Throwable) =
+ this(FetchFailed(bmAddress, shuffleId, mapId, reduceId),
+ "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId),
+ cause)
+
+ def this (shuffleId: Int, reduceId: Int, cause: Throwable) =
+ this(FetchFailed(null, shuffleId, -1, reduceId),
+ "Unable to fetch locations from master: %d %d".format(shuffleId, reduceId), cause)
+
+ override def getMessage(): String = message
+
override def getCause(): Throwable = cause
- def toTaskEndReason: TaskEndReason =
- FetchFailed(bmAddress, shuffleId, mapId, reduceId)
+ def toTaskEndReason: TaskEndReason = taskEndReason
+
}
diff --git a/core/src/main/scala/spark/HadoopWriter.scala b/core/src/main/scala/spark/HadoopWriter.scala
index afcf9f6db4..5e8396edb9 100644
--- a/core/src/main/scala/spark/HadoopWriter.scala
+++ b/core/src/main/scala/spark/HadoopWriter.scala
@@ -2,14 +2,10 @@ package org.apache.hadoop.mapred
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
-import org.apache.hadoop.util.ReflectionUtils
-import org.apache.hadoop.io.NullWritable
-import org.apache.hadoop.io.Text
import java.text.SimpleDateFormat
import java.text.NumberFormat
import java.io.IOException
-import java.net.URI
import java.util.Date
import spark.Logging
@@ -24,7 +20,7 @@ import spark.SerializableWritable
* a filename to write to, etc, exactly like in a Hadoop MapReduce job.
*/
class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRedUtil with Serializable {
-
+
private val now = new Date()
private val conf = new SerializableWritable(jobConf)
@@ -106,6 +102,12 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRe
}
}
+ def commitJob() {
+ // always ? Or if cmtr.needsTaskCommit ?
+ val cmtr = getOutputCommitter()
+ cmtr.commitJob(getJobContext())
+ }
+
def cleanup() {
getOutputCommitter().cleanupJob(getJobContext())
}
diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala
index 7c1c1bb144..0fc8c31463 100644
--- a/core/src/main/scala/spark/Logging.scala
+++ b/core/src/main/scala/spark/Logging.scala
@@ -68,6 +68,10 @@ trait Logging {
if (log.isErrorEnabled) log.error(msg, throwable)
}
+ protected def isTraceEnabled(): Boolean = {
+ log.isTraceEnabled
+ }
+
// Method for ensuring that logging is initialized, to avoid having multiple
// threads do it concurrently (as SLF4J initialization is not thread safe).
protected def initLogging() { log }
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index 866d630a6d..fde597ffd1 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -1,7 +1,6 @@
package spark
import java.io._
-import java.util.concurrent.ConcurrentHashMap
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.HashMap
@@ -12,8 +11,7 @@ import akka.dispatch._
import akka.pattern.ask
import akka.remote._
import akka.util.Duration
-import akka.util.Timeout
-import akka.util.duration._
+
import spark.scheduler.MapStatus
import spark.storage.BlockManagerId
@@ -40,10 +38,12 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac
private[spark] class MapOutputTracker extends Logging {
+ private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+
// Set to the MapOutputTrackerActor living on the driver
var trackerActor: ActorRef = _
- var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
+ private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
@@ -52,7 +52,7 @@ private[spark] class MapOutputTracker extends Logging {
// Cache a serialized version of the output statuses for each shuffle to send them out faster
var cacheGeneration = generation
- val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
+ private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)
@@ -60,7 +60,6 @@ private[spark] class MapOutputTracker extends Logging {
// throw a SparkException if this fails.
def askTracker(message: Any): Any = {
try {
- val timeout = 10.seconds
val future = trackerActor.ask(message)(timeout)
return Await.result(future, timeout)
} catch {
@@ -77,10 +76,9 @@ private[spark] class MapOutputTracker extends Logging {
}
def registerShuffle(shuffleId: Int, numMaps: Int) {
- if (mapStatuses.get(shuffleId) != None) {
+ if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
- mapStatuses.put(shuffleId, new Array[MapStatus](numMaps))
}
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
@@ -101,8 +99,9 @@ private[spark] class MapOutputTracker extends Logging {
}
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
- var array = mapStatuses(shuffleId)
- if (array != null) {
+ var arrayOpt = mapStatuses.get(shuffleId)
+ if (arrayOpt.isDefined && arrayOpt.get != null) {
+ var array = arrayOpt.get
array.synchronized {
if (array(mapId) != null && array(mapId).location == bmAddress) {
array(mapId) = null
@@ -115,13 +114,14 @@ private[spark] class MapOutputTracker extends Logging {
}
// Remembers which map output locations are currently being fetched on a worker
- val fetching = new HashSet[Int]
+ private val fetching = new HashSet[Int]
// Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
+ var fetchedStatuses: Array[MapStatus] = null
fetching.synchronized {
if (fetching.contains(shuffleId)) {
// Someone else is fetching it; wait for them to be done
@@ -132,31 +132,48 @@ private[spark] class MapOutputTracker extends Logging {
case e: InterruptedException =>
}
}
- return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, mapStatuses(shuffleId))
- } else {
+ }
+
+ // Either while we waited the fetch happened successfully, or
+ // someone fetched it in between the get and the fetching.synchronized.
+ fetchedStatuses = mapStatuses.get(shuffleId).orNull
+ if (fetchedStatuses == null) {
+ // We have to do the fetch, get others to wait for us.
fetching += shuffleId
}
}
- // We won the race to fetch the output locs; do so
- logInfo("Doing the fetch; tracker actor = " + trackerActor)
- val host = System.getProperty("spark.hostname", Utils.localHostName)
- // This try-finally prevents hangs due to timeouts:
- var fetchedStatuses: Array[MapStatus] = null
- try {
- val fetchedBytes =
- askTracker(GetMapOutputStatuses(shuffleId, host)).asInstanceOf[Array[Byte]]
- fetchedStatuses = deserializeStatuses(fetchedBytes)
- logInfo("Got the output locations")
- mapStatuses.put(shuffleId, fetchedStatuses)
- } finally {
- fetching.synchronized {
- fetching -= shuffleId
- fetching.notifyAll()
+
+ if (fetchedStatuses == null) {
+ // We won the race to fetch the output locs; do so
+ logInfo("Doing the fetch; tracker actor = " + trackerActor)
+ val hostPort = Utils.localHostPort()
+ // This try-finally prevents hangs due to timeouts:
+ try {
+ val fetchedBytes =
+ askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
+ fetchedStatuses = deserializeStatuses(fetchedBytes)
+ logInfo("Got the output locations")
+ mapStatuses.put(shuffleId, fetchedStatuses)
+ } finally {
+ fetching.synchronized {
+ fetching -= shuffleId
+ fetching.notifyAll()
+ }
+ }
+ }
+ if (fetchedStatuses != null) {
+ fetchedStatuses.synchronized {
+ return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
}
}
- return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
+ else{
+ throw new FetchFailedException(null, shuffleId, -1, reduceId,
+ new Exception("Missing all output locations for shuffle " + shuffleId))
+ }
} else {
- return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
+ statuses.synchronized {
+ return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
+ }
}
}
@@ -194,7 +211,8 @@ private[spark] class MapOutputTracker extends Logging {
generationLock.synchronized {
if (newGen > generation) {
logInfo("Updating generation to " + newGen + " and clearing cache")
- mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
+ // mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
+ mapStatuses.clear()
generation = newGen
}
}
@@ -232,10 +250,13 @@ private[spark] class MapOutputTracker extends Logging {
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
- def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
+ private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
val out = new ByteArrayOutputStream
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
- objOut.writeObject(statuses)
+ // Since statuses can be modified in parallel, sync on it
+ statuses.synchronized {
+ objOut.writeObject(statuses)
+ }
objOut.close()
out.toByteArray
}
@@ -243,7 +264,10 @@ private[spark] class MapOutputTracker extends Logging {
// Opposite of serializeStatuses.
def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
- objIn.readObject().asInstanceOf[Array[MapStatus]]
+ objIn.readObject().
+ // // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present
+ // comment this out - nulls could be due to missing location ?
+ asInstanceOf[Array[MapStatus]] // .filter( _ != null )
}
}
@@ -253,14 +277,11 @@ private[spark] object MapOutputTracker {
// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
// any of the statuses is null (indicating a missing location due to a failed mapper),
// throw a FetchFailedException.
- def convertMapStatuses(
+ private def convertMapStatuses(
shuffleId: Int,
reduceId: Int,
statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
- if (statuses == null) {
- throw new FetchFailedException(null, shuffleId, -1, reduceId,
- new Exception("Missing all output locations for shuffle " + shuffleId))
- }
+ assert (statuses != null)
statuses.map {
status =>
if (status == null) {
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 07efba9e8d..2b0e697337 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -52,7 +52,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C,
partitioner: Partitioner,
- mapSideCombine: Boolean = true): RDD[(K, C)] = {
+ mapSideCombine: Boolean = true,
+ serializerClass: String = null): RDD[(K, C)] = {
if (getKeyClass().isArray) {
if (mapSideCombine) {
throw new SparkException("Cannot use map-side combining with array keys.")
@@ -67,13 +68,13 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
self.mapPartitions(aggregator.combineValuesByKey(_), true)
} else if (mapSideCombine) {
val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
- val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner)
+ val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner, serializerClass)
partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true)
} else {
// Don't apply map-side combiner.
// A sanity check to make sure mergeCombiners is not defined.
assert(mergeCombiners == null)
- val values = new ShuffledRDD[K, V](self, partitioner)
+ val values = new ShuffledRDD[K, V](self, partitioner, serializerClass)
values.mapPartitions(aggregator.combineValuesByKey(_), true)
}
}
@@ -469,7 +470,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
/**
* Return an RDD with the pairs from `this` whose keys are not in `other`.
- *
+ *
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us.
*/
@@ -545,8 +546,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" <split #> <attempt # = spark task #> */
- val attemptId = new TaskAttemptID(jobtrackerID,
- stageId, false, context.splitId, attemptNumber)
+ val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber)
val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
val format = outputFormatClass.newInstance
val committer = format.getOutputCommitter(hadoopContext)
@@ -565,11 +565,12 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* however we're only going to use this local OutputCommitter for
* setupJob/commitJob, so we just use a dummy "map" task.
*/
- val jobAttemptId = new TaskAttemptID(jobtrackerID, stageId, true, 0, 0)
+ val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, true, 0, 0)
val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
jobCommitter.setupJob(jobTaskContext)
val count = self.context.runJob(self, writeShard _).sum
+ jobCommitter.commitJob(jobTaskContext)
jobCommitter.cleanupJob(jobTaskContext)
}
@@ -637,6 +638,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
self.context.runJob(self, writeToFile _)
+ writer.commitJob()
writer.cleanup()
}
@@ -644,7 +646,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* Return an RDD with the keys of each tuple.
*/
def keys: RDD[K] = self.map(_._1)
-
+
/**
* Return an RDD with the values of each tuple.
*/
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 33dc7627a3..fd14ef17f1 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -35,6 +35,9 @@ import spark.rdd.ShuffledRDD
import spark.rdd.SubtractedRDD
import spark.rdd.UnionRDD
import spark.rdd.ZippedRDD
+import spark.rdd.ZippedPartitionsRDD2
+import spark.rdd.ZippedPartitionsRDD3
+import spark.rdd.ZippedPartitionsRDD4
import spark.storage.StorageLevel
import SparkContext._
@@ -104,7 +107,7 @@ abstract class RDD[T: ClassManifest](
// =======================================================================
/** A unique ID for this RDD (within its SparkContext). */
- val id = sc.newRddId()
+ val id: Int = sc.newRddId()
/** A friendly name for this RDD */
var name: String = null
@@ -117,7 +120,8 @@ abstract class RDD[T: ClassManifest](
/**
* Set this RDD's storage level to persist its values across operations after the first time
- * it is computed. Can only be called once on each RDD.
+ * it is computed. This can only be used to assign a new storage level if the RDD does not
+ * have a storage level set yet..
*/
def persist(newLevel: StorageLevel): RDD[T] = {
// TODO: Handle changes of StorageLevel
@@ -137,6 +141,15 @@ abstract class RDD[T: ClassManifest](
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): RDD[T] = persist()
+ /** Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. */
+ def unpersist(): RDD[T] = {
+ logInfo("Removing RDD " + id + " from persistence list")
+ sc.env.blockManager.master.removeRdd(id)
+ sc.persistentRdds.remove(id)
+ storageLevel = StorageLevel.NONE
+ this
+ }
+
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel
@@ -366,7 +379,7 @@ abstract class RDD[T: ClassManifest](
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
- @deprecated("use mapPartitionsWithIndex")
+ @deprecated("use mapPartitionsWithIndex", "0.7.0")
def mapPartitionsWithSplit[U: ClassManifest](
f: (Int, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] =
@@ -436,6 +449,31 @@ abstract class RDD[T: ClassManifest](
*/
def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)
+ /**
+ * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
+ * applying a function to the zipped partitions. Assumes that all the RDDs have the
+ * *same number of partitions*, but does *not* require them to have the same number
+ * of elements in each partition.
+ */
+ def zipPartitions[B: ClassManifest, V: ClassManifest](
+ f: (Iterator[T], Iterator[B]) => Iterator[V],
+ rdd2: RDD[B]): RDD[V] =
+ new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2)
+
+ def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest](
+ f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V],
+ rdd2: RDD[B],
+ rdd3: RDD[C]): RDD[V] =
+ new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3)
+
+ def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest](
+ f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
+ rdd2: RDD[B],
+ rdd3: RDD[C],
+ rdd4: RDD[D]): RDD[V] =
+ new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4)
+
+
// Actions (launch a job to return a value to the user program)
/**
diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala
index d00092e984..57e0405fb4 100644
--- a/core/src/main/scala/spark/RDDCheckpointData.scala
+++ b/core/src/main/scala/spark/RDDCheckpointData.scala
@@ -1,6 +1,7 @@
package spark
import org.apache.hadoop.fs.Path
+import org.apache.hadoop.conf.Configuration
import rdd.{CheckpointRDD, CoalescedRDD}
import scheduler.{ResultTask, ShuffleMapTask}
@@ -62,14 +63,20 @@ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
}
}
+ // Create the output path for the checkpoint
+ val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id)
+ val fs = path.getFileSystem(new Configuration())
+ if (!fs.mkdirs(path)) {
+ throw new SparkException("Failed to create checkpoint path " + path)
+ }
+
// Save to file, and reload it as an RDD
- val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id).toString
- rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _)
- val newRDD = new CheckpointRDD[T](rdd.context, path)
+ rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString) _)
+ val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
// Change the dependencies and partitions of the RDD
RDDCheckpointData.synchronized {
- cpFile = Some(path)
+ cpFile = Some(path.toString)
cpRDD = Some(newRDD)
rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed
diff --git a/core/src/main/scala/spark/ShuffleFetcher.scala b/core/src/main/scala/spark/ShuffleFetcher.scala
index 442e9f0269..9513a00126 100644
--- a/core/src/main/scala/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/spark/ShuffleFetcher.scala
@@ -1,13 +1,16 @@
package spark
-import executor.TaskMetrics
+import spark.executor.TaskMetrics
+import spark.serializer.Serializer
+
private[spark] abstract class ShuffleFetcher {
/**
* Fetch the shuffle outputs for a given ShuffleDependency.
* @return An iterator over the elements of the fetched shuffle outputs.
*/
- def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) : Iterator[(K,V)]
+ def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
+ serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[(K,V)]
/** Stop the fetcher */
def stop() {}
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 4957a54c1b..2ae4ad8659 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -1,47 +1,50 @@
package spark
import java.io._
-import java.util.concurrent.atomic.AtomicInteger
import java.net.URI
+import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.atomic.AtomicInteger
+import scala.collection.JavaConversions._
import scala.collection.Map
import scala.collection.generic.Growable
-import scala.collection.mutable.HashMap
-import scala.collection.JavaConversions._
+import scala.collection.mutable.{ConcurrentMap, HashMap}
+
+import akka.actor.Actor._
-import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.mapred.InputFormat
-import org.apache.hadoop.mapred.SequenceFileInputFormat
-import org.apache.hadoop.io.Writable
-import org.apache.hadoop.io.IntWritable
-import org.apache.hadoop.io.LongWritable
-import org.apache.hadoop.io.FloatWritable
-import org.apache.hadoop.io.DoubleWritable
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.ArrayWritable
import org.apache.hadoop.io.BooleanWritable
import org.apache.hadoop.io.BytesWritable
-import org.apache.hadoop.io.ArrayWritable
+import org.apache.hadoop.io.DoubleWritable
+import org.apache.hadoop.io.FloatWritable
+import org.apache.hadoop.io.IntWritable
+import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
+import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.FileInputFormat
+import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapred.SequenceFileInputFormat
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
-import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
+import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
+
import org.apache.mesos.MesosNativeLibrary
-import spark.deploy.LocalSparkCluster
-import spark.partial.ApproximateEvaluator
-import spark.partial.PartialResult
+import spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
+import spark.partial.{ApproximateEvaluator, PartialResult}
import spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD}
-import spark.scheduler._
+import spark.scheduler.{DAGScheduler, ResultTask, ShuffleMapTask, SparkListener, SplitInfo, Stage, StageInfo, TaskScheduler}
+import spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend, ClusterScheduler}
import spark.scheduler.local.LocalScheduler
-import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
-import spark.storage.BlockManagerUI
+import spark.storage.{BlockManagerUI, StorageStatus, StorageUtils, RDDInfo}
import spark.util.{MetadataCleaner, TimeStampedHashMap}
-import spark.storage.{StorageStatus, StorageUtils, RDDInfo}
+
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -59,7 +62,10 @@ class SparkContext(
val appName: String,
val sparkHome: String = null,
val jars: Seq[String] = Nil,
- val environment: Map[String, String] = Map())
+ val environment: Map[String, String] = Map(),
+ // This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc) too.
+ // This is typically generated from InputFormatInfo.computePreferredLocations .. host, set of data-local splits on host
+ val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map())
extends Logging {
// Ensure logging is initialized before we spawn any threads
@@ -67,7 +73,7 @@ class SparkContext(
// Set Spark driver host and port system properties
if (System.getProperty("spark.driver.host") == null) {
- System.setProperty("spark.driver.host", Utils.localIpAddress)
+ System.setProperty("spark.driver.host", Utils.localHostName())
}
if (System.getProperty("spark.driver.port") == null) {
System.setProperty("spark.driver.port", "0")
@@ -94,12 +100,14 @@ class SparkContext(
private[spark] val addedJars = HashMap[String, Long]()
// Keeps track of all persisted RDDs
- private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]()
+ private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup)
// Add each JAR given through the constructor
- jars.foreach { addJar(_) }
+ if (jars != null) {
+ jars.foreach { addJar(_) }
+ }
// Environment variables to pass to our executors
private[spark] val executorEnvs = HashMap[String, String]()
@@ -111,7 +119,9 @@ class SparkContext(
executorEnvs(key) = value
}
}
- executorEnvs ++= environment
+ if (environment != null) {
+ executorEnvs ++= environment
+ }
// Create and start the scheduler
private var taskScheduler: TaskScheduler = {
@@ -164,6 +174,22 @@ class SparkContext(
}
scheduler
+ case "yarn-standalone" =>
+ val scheduler = try {
+ val clazz = Class.forName("spark.scheduler.cluster.YarnClusterScheduler")
+ val cons = clazz.getConstructor(classOf[SparkContext])
+ cons.newInstance(this).asInstanceOf[ClusterScheduler]
+ } catch {
+ // TODO: Enumerate the exact reasons why it can fail
+ // But irrespective of it, it means we cannot proceed !
+ case th: Throwable => {
+ throw new SparkException("YARN mode not available ?", th)
+ }
+ }
+ val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem)
+ scheduler.initialize(backend)
+ scheduler
+
case _ =>
if (MESOS_REGEX.findFirstIn(master).isEmpty) {
logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
@@ -183,12 +209,12 @@ class SparkContext(
}
taskScheduler.start()
- private var dagScheduler = new DAGScheduler(taskScheduler)
+ @volatile private var dagScheduler = new DAGScheduler(taskScheduler)
dagScheduler.start()
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = {
- val conf = new Configuration()
+ val conf = SparkHadoopUtil.newConfiguration()
// Explicitly check for S3 environment variables
if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
@@ -207,6 +233,9 @@ class SparkContext(
private[spark] var checkpointDir: Option[String] = None
+ // Post init
+ taskScheduler.postStartHook()
+
// Methods for creating RDDs
/** Distribute a local Scala collection to form an RDD. */
@@ -471,7 +500,7 @@ class SparkContext(
*/
def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
- (blockManagerId.ip + ":" + blockManagerId.port, mem)
+ (blockManagerId.host + ":" + blockManagerId.port, mem)
}
}
@@ -479,7 +508,7 @@ class SparkContext(
* Return information about what RDDs are cached, if they are in mem or on disk, how much space
* they take, etc.
*/
- def getRDDStorageInfo : Array[RDDInfo] = {
+ def getRDDStorageInfo: Array[RDDInfo] = {
StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this)
}
@@ -490,7 +519,7 @@ class SparkContext(
/**
* Return information about blocks stored in all of the slaves
*/
- def getExecutorStorageStatus : Array[StorageStatus] = {
+ def getExecutorStorageStatus: Array[StorageStatus] = {
env.blockManager.master.getStorageStatus
}
@@ -527,10 +556,13 @@ class SparkContext(
/** Shut down the SparkContext. */
def stop() {
- if (dagScheduler != null) {
+ // Do this only if not stopped already - best case effort.
+ // prevent NPE if stopped more than once.
+ val dagSchedulerCopy = dagScheduler
+ dagScheduler = null
+ if (dagSchedulerCopy != null) {
metadataCleaner.cancel()
- dagScheduler.stop()
- dagScheduler = null
+ dagSchedulerCopy.stop()
taskScheduler = null
// TODO: Cache.stop()?
env.stop()
@@ -546,6 +578,7 @@ class SparkContext(
}
}
+
/**
* Get Spark's home location from either a value set through the constructor,
* or the spark.home Java property, or the SPARK_HOME environment variable
@@ -685,7 +718,7 @@ class SparkContext(
*/
def setCheckpointDir(dir: String, useExisting: Boolean = false) {
val path = new Path(dir)
- val fs = path.getFileSystem(new Configuration())
+ val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration())
if (!useExisting) {
if (fs.exists(path)) {
throw new Exception("Checkpoint directory '" + path + "' already exists.")
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index 7157fd2688..2fa97cd829 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -3,13 +3,14 @@ package spark
import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
import akka.remote.RemoteActorRefProvider
-import serializer.Serializer
import spark.broadcast.BroadcastManager
import spark.storage.BlockManager
import spark.storage.BlockManagerMaster
import spark.network.ConnectionManager
+import spark.serializer.{Serializer, SerializerManager}
import spark.util.AkkaUtils
+
/**
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
* including the serializer, Akka actor system, block manager, map output tracker, etc. Currently
@@ -20,6 +21,7 @@ import spark.util.AkkaUtils
class SparkEnv (
val executorId: String,
val actorSystem: ActorSystem,
+ val serializerManager: SerializerManager,
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheManager: CacheManager,
@@ -72,6 +74,16 @@ object SparkEnv extends Logging {
System.setProperty("spark.driver.port", boundPort.toString)
}
+ // set only if unset until now.
+ if (System.getProperty("spark.hostPort", null) == null) {
+ if (!isDriver){
+ // unexpected
+ Utils.logErrorWithStack("Unexpected NOT to have spark.hostPort set")
+ }
+ Utils.checkHost(hostname)
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+ }
+
val classLoader = Thread.currentThread.getContextClassLoader
// Create an instance of the class named by the given Java system property, or by
@@ -81,16 +93,23 @@ object SparkEnv extends Logging {
Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
}
- val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")
-
+ val serializerManager = new SerializerManager
+
+ val serializer = serializerManager.setDefault(
+ System.getProperty("spark.serializer", "spark.JavaSerializer"))
+
+ val closureSerializer = serializerManager.get(
+ System.getProperty("spark.closure.serializer", "spark.JavaSerializer"))
+
def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
if (isDriver) {
logInfo("Registering " + name)
actorSystem.actorOf(Props(newActor), name = name)
} else {
- val driverIp: String = System.getProperty("spark.driver.host", "localhost")
+ val driverHost: String = System.getProperty("spark.driver.host", "localhost")
val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
- val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, name)
+ Utils.checkHost(driverHost, "Expected hostname")
+ val url = "akka://spark@%s:%s/user/%s".format(driverHost, driverPort, name)
logInfo("Connecting to " + name + ": " + url)
actorSystem.actorFor(url)
}
@@ -105,9 +124,6 @@ object SparkEnv extends Logging {
val broadcastManager = new BroadcastManager(isDriver)
- val closureSerializer = instantiateClass[Serializer](
- "spark.closure.serializer", "spark.JavaSerializer")
-
val cacheManager = new CacheManager(blockManager)
// Have to assign trackerActor after initialization as MapOutputTrackerActor
@@ -142,6 +158,7 @@ object SparkEnv extends Logging {
new SparkEnv(
executorId,
actorSystem,
+ serializerManager,
serializer,
closureSerializer,
cacheManager,
@@ -153,5 +170,5 @@ object SparkEnv extends Logging {
httpFileServer,
sparkFilesDir)
}
-
+
}
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 81daacf958..9f48cbe490 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -1,18 +1,18 @@
package spark
import java.io._
-import java.net._
+import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket}
import java.util.{Locale, Random, UUID}
-import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
-import org.apache.hadoop.conf.Configuration
+import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor}
import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.JavaConversions._
import scala.io.Source
import com.google.common.io.Files
import com.google.common.util.concurrent.ThreadFactoryBuilder
-import scala.Some
import spark.serializer.SerializerInstance
+import spark.deploy.SparkHadoopUtil
+import java.util.regex.Pattern
/**
* Various utility methods used by Spark.
@@ -68,6 +68,41 @@ private object Utils extends Logging {
return buf
}
+
+ private val shutdownDeletePaths = new collection.mutable.HashSet[String]()
+
+ // Register the path to be deleted via shutdown hook
+ def registerShutdownDeleteDir(file: File) {
+ val absolutePath = file.getAbsolutePath()
+ shutdownDeletePaths.synchronized {
+ shutdownDeletePaths += absolutePath
+ }
+ }
+
+ // Is the path already registered to be deleted via a shutdown hook ?
+ def hasShutdownDeleteDir(file: File): Boolean = {
+ val absolutePath = file.getAbsolutePath()
+ shutdownDeletePaths.synchronized {
+ shutdownDeletePaths.contains(absolutePath)
+ }
+ }
+
+ // Note: if file is child of some registered path, while not equal to it, then return true; else false
+ // This is to ensure that two shutdown hooks do not try to delete each others paths - resulting in IOException
+ // and incomplete cleanup
+ def hasRootAsShutdownDeleteDir(file: File): Boolean = {
+
+ val absolutePath = file.getAbsolutePath()
+
+ val retval = shutdownDeletePaths.synchronized {
+ shutdownDeletePaths.find(path => ! absolutePath.equals(path) && absolutePath.startsWith(path) ).isDefined
+ }
+
+ if (retval) logInfo("path = " + file + ", already present as root for deletion.")
+
+ retval
+ }
+
/** Create a temporary directory inside the given parent directory */
def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = {
var attempts = 0
@@ -86,10 +121,14 @@ private object Utils extends Logging {
}
} catch { case e: IOException => ; }
}
+
+ registerShutdownDeleteDir(dir)
+
// Add a shutdown hook to delete the temp dir when the JVM exits
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) {
override def run() {
- Utils.deleteRecursively(dir)
+ // Attempt to delete if some patch which is parent of this is not already registered.
+ if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir)
}
})
return dir
@@ -168,7 +207,7 @@ private object Utils extends Logging {
case _ =>
// Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
val uri = new URI(url)
- val conf = new Configuration()
+ val conf = SparkHadoopUtil.newConfiguration()
val fs = FileSystem.get(uri, conf)
val in = fs.open(new Path(uri))
val out = new FileOutputStream(tempFile)
@@ -227,8 +266,10 @@ private object Utils extends Logging {
/**
* Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4).
+ * Note, this is typically not used from within core spark.
*/
lazy val localIpAddress: String = findLocalIpAddress()
+ lazy val localIpAddressHostname: String = getAddressHostName(localIpAddress)
private def findLocalIpAddress(): String = {
val defaultIpOverride = System.getenv("SPARK_LOCAL_IP")
@@ -266,6 +307,8 @@ private object Utils extends Logging {
* hostname it reports to the master.
*/
def setCustomHostname(hostname: String) {
+ // DEBUG code
+ Utils.checkHost(hostname)
customHostname = Some(hostname)
}
@@ -273,7 +316,97 @@ private object Utils extends Logging {
* Get the local machine's hostname.
*/
def localHostName(): String = {
- customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
+ customHostname.getOrElse(localIpAddressHostname)
+ }
+
+ def getAddressHostName(address: String): String = {
+ InetAddress.getByName(address).getHostName
+ }
+
+
+
+ def localHostPort(): String = {
+ val retval = System.getProperty("spark.hostPort", null)
+ if (retval == null) {
+ logErrorWithStack("spark.hostPort not set but invoking localHostPort")
+ return localHostName()
+ }
+
+ retval
+ }
+
+ /*
+ // Used by DEBUG code : remove when all testing done
+ private val ipPattern = Pattern.compile("^[0-9]+(\\.[0-9]+)*$")
+ def checkHost(host: String, message: String = "") {
+ // Currently catches only ipv4 pattern, this is just a debugging tool - not rigourous !
+ // if (host.matches("^[0-9]+(\\.[0-9]+)*$")) {
+ if (ipPattern.matcher(host).matches()) {
+ Utils.logErrorWithStack("Unexpected to have host " + host + " which matches IP pattern. Message " + message)
+ }
+ if (Utils.parseHostPort(host)._2 != 0){
+ Utils.logErrorWithStack("Unexpected to have host " + host + " which has port in it. Message " + message)
+ }
+ }
+
+ // Used by DEBUG code : remove when all testing done
+ def checkHostPort(hostPort: String, message: String = "") {
+ val (host, port) = Utils.parseHostPort(hostPort)
+ checkHost(host)
+ if (port <= 0){
+ Utils.logErrorWithStack("Unexpected to have port " + port + " which is not valid in " + hostPort + ". Message " + message)
+ }
+ }
+ */
+
+ // Once testing is complete in various modes, replace with this ?
+ def checkHost(host: String, message: String = "") {}
+ def checkHostPort(hostPort: String, message: String = "") {}
+
+ def getUserNameFromEnvironment(): String = {
+ SparkHadoopUtil.getUserNameFromEnvironment
+ }
+
+ // Used by DEBUG code : remove when all testing done
+ def logErrorWithStack(msg: String) {
+ try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } }
+ // temp code for debug
+ System.exit(-1)
+ }
+
+ // Typically, this will be of order of number of nodes in cluster
+ // If not, we should change it to LRUCache or something.
+ private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]()
+ def parseHostPort(hostPort: String): (String, Int) = {
+ {
+ // Check cache first.
+ var cached = hostPortParseResults.get(hostPort)
+ if (cached != null) return cached
+ }
+
+ val indx: Int = hostPort.lastIndexOf(':')
+ // This is potentially broken - when dealing with ipv6 addresses for example, sigh ... but then hadoop does not support ipv6 right now.
+ // For now, we assume that if port exists, then it is valid - not check if it is an int > 0
+ if (-1 == indx) {
+ val retval = (hostPort, 0)
+ hostPortParseResults.put(hostPort, retval)
+ return retval
+ }
+
+ val retval = (hostPort.substring(0, indx).trim(), hostPort.substring(indx + 1).trim().toInt)
+ hostPortParseResults.putIfAbsent(hostPort, retval)
+ hostPortParseResults.get(hostPort)
+ }
+
+ def addIfNoPort(hostPort: String, port: Int): String = {
+ if (port <= 0) throw new IllegalArgumentException("Invalid port specified " + port)
+
+ // This is potentially broken - when dealing with ipv6 addresses for example, sigh ... but then hadoop does not support ipv6 right now.
+ // For now, we assume that if port exists, then it is valid - not check if it is an int > 0
+ val indx: Int = hostPort.lastIndexOf(':')
+ if (-1 != indx) return hostPort
+
+ hostPort + ":" + port
}
private[spark] val daemonThreadFactory: ThreadFactory =
diff --git a/core/src/main/scala/spark/api/java/JavaRDD.scala b/core/src/main/scala/spark/api/java/JavaRDD.scala
index e29f1e5899..eb81ed64cd 100644
--- a/core/src/main/scala/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDD.scala
@@ -14,12 +14,18 @@ JavaRDDLike[T, JavaRDD[T]] {
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): JavaRDD[T] = wrapRDD(rdd.cache())
- /**
+ /**
* Set this RDD's storage level to persist its values across operations after the first time
- * it is computed. Can only be called once on each RDD.
+ * it is computed. This can only be used to assign a new storage level if the RDD does not
+ * have a storage level set yet..
*/
def persist(newLevel: StorageLevel): JavaRDD[T] = wrapRDD(rdd.persist(newLevel))
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ */
+ def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist())
+
// Transformations (return a new RDD)
/**
@@ -31,7 +37,7 @@ JavaRDDLike[T, JavaRDD[T]] {
* Return a new RDD containing the distinct elements in this RDD.
*/
def distinct(numPartitions: Int): JavaRDD[T] = wrapRDD(rdd.distinct(numPartitions))
-
+
/**
* Return a new RDD containing only the elements that satisfy a predicate.
*/
@@ -54,7 +60,7 @@ JavaRDDLike[T, JavaRDD[T]] {
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
wrapRDD(rdd.sample(withReplacement, fraction, seed))
-
+
/**
* Return the union of this RDD and another one. Any identical elements will appear multiple
* times (use `.distinct()` to eliminate them).
@@ -63,7 +69,7 @@ JavaRDDLike[T, JavaRDD[T]] {
/**
* Return an RDD with the elements from `this` that are not in `other`.
- *
+ *
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us.
*/
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index d884529d7a..9b74d1226f 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -182,6 +182,21 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classManifest))(classManifest, other.classManifest)
}
+ /**
+ * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
+ * applying a function to the zipped partitions. Assumes that all the RDDs have the
+ * *same number of partitions*, but does *not* require them to have the same number
+ * of elements in each partition.
+ */
+ def zipPartitions[U, V](
+ f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V],
+ other: JavaRDDLike[U, _]): JavaRDD[V] = {
+ def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator(
+ f.apply(asJavaIterator(x), asJavaIterator(y)).iterator())
+ JavaRDD.fromRDD(
+ rdd.zipPartitions(fn, other.rdd)(other.classManifest, f.elementType()))(f.elementType())
+ }
+
// Actions (launch a job to return a value to the user program)
/**
diff --git a/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala
new file mode 100644
index 0000000000..6044043add
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala
@@ -0,0 +1,11 @@
+package spark.api.java.function
+
+/**
+ * A function that takes two inputs and returns zero or more output records.
+ */
+abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] {
+ @throws(classOf[Exception])
+ def call(a: A, b:B) : java.lang.Iterable[C]
+
+ def elementType() : ClassManifest[C] = ClassManifest.Any.asInstanceOf[ClassManifest[C]]
+}
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 9b4d54ab4e..807119ca8c 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -277,6 +277,8 @@ private class BytesToString extends spark.api.java.function.Function[Array[Byte]
*/
class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
extends AccumulatorParam[JList[Array[Byte]]] {
+
+ Utils.checkHost(serverHost, "Expected hostname")
override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
diff --git a/core/src/main/scala/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/spark/deploy/ApplicationDescription.scala
index 6659e53b25..b6b9f9bf9d 100644
--- a/core/src/main/scala/spark/deploy/ApplicationDescription.scala
+++ b/core/src/main/scala/spark/deploy/ApplicationDescription.scala
@@ -2,7 +2,7 @@ package spark.deploy
private[spark] class ApplicationDescription(
val name: String,
- val cores: Int,
+ val maxCores: Int, /* Integer.MAX_VALUE denotes an unlimited number of cores */
val memoryPerSlave: Int,
val command: Command,
val sparkHome: String)
diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala
index 8a3e64e4c2..51274acb1e 100644
--- a/core/src/main/scala/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/spark/deploy/DeployMessage.scala
@@ -4,6 +4,7 @@ import spark.deploy.ExecutorState.ExecutorState
import spark.deploy.master.{WorkerInfo, ApplicationInfo}
import spark.deploy.worker.ExecutorRunner
import scala.collection.immutable.List
+import spark.Utils
private[spark] sealed trait DeployMessage extends Serializable
@@ -19,7 +20,10 @@ case class RegisterWorker(
memory: Int,
webUiPort: Int,
publicAddress: String)
- extends DeployMessage
+ extends DeployMessage {
+ Utils.checkHost(host, "Required hostname")
+ assert (port > 0)
+}
private[spark]
case class ExecutorStateChanged(
@@ -58,7 +62,9 @@ private[spark]
case class RegisteredApplication(appId: String) extends DeployMessage
private[spark]
-case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int)
+case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) {
+ Utils.checkHostPort(hostPort, "Required hostport")
+}
private[spark]
case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String],
@@ -81,6 +87,9 @@ private[spark]
case class MasterState(host: String, port: Int, workers: Array[WorkerInfo],
activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) {
+ Utils.checkHost(host, "Required hostname")
+ assert (port > 0)
+
def uri = "spark://" + host + ":" + port
}
@@ -92,4 +101,8 @@ private[spark] case object RequestWorkerState
private[spark]
case class WorkerState(host: String, port: Int, workerId: String, executors: List[ExecutorRunner],
finishedExecutors: List[ExecutorRunner], masterUrl: String, cores: Int, memory: Int,
- coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String)
+ coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) {
+
+ Utils.checkHost(host, "Required hostname")
+ assert (port > 0)
+}
diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala
index 38a6ebfc24..ea832101d2 100644
--- a/core/src/main/scala/spark/deploy/JsonProtocol.scala
+++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala
@@ -12,6 +12,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
def write(obj: WorkerInfo) = JsObject(
"id" -> JsString(obj.id),
"host" -> JsString(obj.host),
+ "port" -> JsNumber(obj.port),
"webuiaddress" -> JsString(obj.webUiAddress),
"cores" -> JsNumber(obj.cores),
"coresused" -> JsNumber(obj.coresUsed),
@@ -25,7 +26,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
"starttime" -> JsNumber(obj.startTime),
"id" -> JsString(obj.id),
"name" -> JsString(obj.desc.name),
- "cores" -> JsNumber(obj.desc.cores),
+ "cores" -> JsNumber(obj.desc.maxCores),
"user" -> JsString(obj.desc.user),
"memoryperslave" -> JsNumber(obj.desc.memoryPerSlave),
"submitdate" -> JsString(obj.submitDate.toString))
@@ -34,7 +35,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
implicit object AppDescriptionJsonFormat extends RootJsonWriter[ApplicationDescription] {
def write(obj: ApplicationDescription) = JsObject(
"name" -> JsString(obj.name),
- "cores" -> JsNumber(obj.cores),
+ "cores" -> JsNumber(obj.maxCores),
"memoryperslave" -> JsNumber(obj.memoryPerSlave),
"user" -> JsString(obj.user)
)
diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
index 22319a96ca..55bb61b0cc 100644
--- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
@@ -18,7 +18,7 @@ import scala.collection.mutable.ArrayBuffer
private[spark]
class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging {
- private val localIpAddress = Utils.localIpAddress
+ private val localHostname = Utils.localHostName()
private val masterActorSystems = ArrayBuffer[ActorSystem]()
private val workerActorSystems = ArrayBuffer[ActorSystem]()
@@ -26,13 +26,13 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
/* Start the Master */
- val (masterSystem, masterPort) = Master.startSystemAndActor(localIpAddress, 0, 0)
+ val (masterSystem, masterPort) = Master.startSystemAndActor(localHostname, 0, 0)
masterActorSystems += masterSystem
- val masterUrl = "spark://" + localIpAddress + ":" + masterPort
+ val masterUrl = "spark://" + localHostname + ":" + masterPort
/* Start the Workers */
for (workerNum <- 1 to numWorkers) {
- val (workerSystem, _) = Worker.startSystemAndActor(localIpAddress, 0, 0, coresPerWorker,
+ val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker,
memoryPerWorker, masterUrl, null, Some(workerNum))
workerActorSystems += workerSystem
}
diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala
index 2fc5e657f9..4af44f9c16 100644
--- a/core/src/main/scala/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/spark/deploy/client/Client.scala
@@ -3,6 +3,7 @@ package spark.deploy.client
import spark.deploy._
import akka.actor._
import akka.pattern.ask
+import akka.util.Duration
import akka.util.duration._
import akka.pattern.AskTimeoutException
import spark.{SparkException, Logging}
@@ -59,10 +60,10 @@ private[spark] class Client(
markDisconnected()
context.stop(self)
- case ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) =>
+ case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) =>
val fullId = appId + "/" + id
- logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, host, cores))
- listener.executorAdded(fullId, workerId, host, cores, memory)
+ logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores))
+ listener.executorAdded(fullId, workerId, hostPort, cores, memory)
case ExecutorUpdated(id, state, message, exitStatus) =>
val fullId = appId + "/" + id
@@ -112,7 +113,7 @@ private[spark] class Client(
def stop() {
if (actor != null) {
try {
- val timeout = 5.seconds
+ val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
val future = actor.ask(StopClient)(timeout)
Await.result(future, timeout)
} catch {
diff --git a/core/src/main/scala/spark/deploy/client/ClientListener.scala b/core/src/main/scala/spark/deploy/client/ClientListener.scala
index b7008321df..e8c4083f9d 100644
--- a/core/src/main/scala/spark/deploy/client/ClientListener.scala
+++ b/core/src/main/scala/spark/deploy/client/ClientListener.scala
@@ -12,7 +12,7 @@ private[spark] trait ClientListener {
def disconnected(): Unit
- def executorAdded(fullId: String, workerId: String, host: String, cores: Int, memory: Int): Unit
+ def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit
def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit
}
diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala
index dc004b59ca..ad92532b58 100644
--- a/core/src/main/scala/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/spark/deploy/client/TestClient.scala
@@ -16,7 +16,7 @@ private[spark] object TestClient {
System.exit(0)
}
- def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) {}
+ def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {}
def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {}
}
diff --git a/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala
index 3591a94072..70e5caab66 100644
--- a/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala
@@ -37,7 +37,7 @@ private[spark] class ApplicationInfo(
coresGranted -= exec.cores
}
- def coresLeft: Int = desc.cores - coresGranted
+ def coresLeft: Int = desc.maxCores - coresGranted
private var _retryCount = 0
diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala
index 71b9d0801d..707fe57983 100644
--- a/core/src/main/scala/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/spark/deploy/master/Master.scala
@@ -15,7 +15,7 @@ import spark.{Logging, SparkException, Utils}
import spark.util.AkkaUtils
-private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
+private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000
@@ -35,9 +35,11 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
var firstApp: Option[ApplicationInfo] = None
+ Utils.checkHost(host, "Expected hostname")
+
val masterPublicAddress = {
val envVar = System.getenv("SPARK_PUBLIC_DNS")
- if (envVar != null) envVar else ip
+ if (envVar != null) envVar else host
}
// As a temporary workaround before better ways of configuring memory, we allow users to set
@@ -46,7 +48,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean
override def preStart() {
- logInfo("Starting Spark master at spark://" + ip + ":" + port)
+ logInfo("Starting Spark master at spark://" + host + ":" + port)
// Listen for remote client disconnection events, since they don't go through Akka's watch()
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
startWebUi()
@@ -145,7 +147,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
}
case RequestMasterState => {
- sender ! MasterState(ip, port, workers.toArray, apps.toArray, completedApps.toArray)
+ sender ! MasterState(host, port, workers.toArray, apps.toArray, completedApps.toArray)
}
}
@@ -211,13 +213,13 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
worker.addExecutor(exec)
worker.actor ! LaunchExecutor(exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome)
- exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory)
+ exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)
}
def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int,
publicAddress: String): WorkerInfo = {
// There may be one or more refs to dead workers on this same node (w/ different ID's), remove them.
- workers.filter(w => (w.host == host) && (w.state == WorkerState.DEAD)).foreach(workers -= _)
+ workers.filter(w => (w.host == host && w.port == port) && (w.state == WorkerState.DEAD)).foreach(workers -= _)
val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
workers += worker
idToWorker(worker.id) = worker
@@ -273,6 +275,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
for (exec <- app.executors.values) {
exec.worker.removeExecutor(exec)
exec.worker.actor ! KillExecutor(exec.application.id, exec.id)
+ exec.state = ExecutorState.KILLED
}
app.markFinished(state)
app.driver ! ApplicationRemoved(state.toString)
@@ -307,7 +310,7 @@ private[spark] object Master {
def main(argStrings: Array[String]) {
val args = new MasterArguments(argStrings)
- val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort)
+ val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort)
actorSystem.awaitTermination()
}
diff --git a/core/src/main/scala/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/spark/deploy/master/MasterArguments.scala
index 4ceab3fc03..3d28ecabb4 100644
--- a/core/src/main/scala/spark/deploy/master/MasterArguments.scala
+++ b/core/src/main/scala/spark/deploy/master/MasterArguments.scala
@@ -7,13 +7,13 @@ import spark.Utils
* Command-line parser for the master.
*/
private[spark] class MasterArguments(args: Array[String]) {
- var ip = Utils.localHostName()
+ var host = Utils.localHostName()
var port = 7077
var webUiPort = 8080
// Check for settings in environment variables
- if (System.getenv("SPARK_MASTER_IP") != null) {
- ip = System.getenv("SPARK_MASTER_IP")
+ if (System.getenv("SPARK_MASTER_HOST") != null) {
+ host = System.getenv("SPARK_MASTER_HOST")
}
if (System.getenv("SPARK_MASTER_PORT") != null) {
port = System.getenv("SPARK_MASTER_PORT").toInt
@@ -26,7 +26,13 @@ private[spark] class MasterArguments(args: Array[String]) {
def parse(args: List[String]): Unit = args match {
case ("--ip" | "-i") :: value :: tail =>
- ip = value
+ Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
+ host = value
+ parse(tail)
+
+ case ("--host" | "-h") :: value :: tail =>
+ Utils.checkHost(value, "Please use hostname " + value)
+ host = value
parse(tail)
case ("--port" | "-p") :: IntParam(value) :: tail =>
@@ -54,7 +60,8 @@ private[spark] class MasterArguments(args: Array[String]) {
"Usage: Master [options]\n" +
"\n" +
"Options:\n" +
- " -i IP, --ip IP IP address or DNS name to listen on\n" +
+ " -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) \n" +
+ " -h HOST, --host HOST Hostname to listen on\n" +
" -p PORT, --port PORT Port to listen on (default: 7077)\n" +
" --webui-port PORT Port for web UI (default: 8080)")
System.exit(exitCode)
diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
index 54faa375fb..a4e21c8130 100644
--- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
+++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
@@ -3,7 +3,7 @@ package spark.deploy.master
import akka.actor.{ActorRef, ActorSystem}
import akka.dispatch.Await
import akka.pattern.ask
-import akka.util.Timeout
+import akka.util.{Duration, Timeout}
import akka.util.duration._
import cc.spray.Directives
import cc.spray.directives._
@@ -22,7 +22,7 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
val RESOURCE_DIR = "spark/deploy/master/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static"
- implicit val timeout = Timeout(10 seconds)
+ implicit val timeout = Timeout(Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds"))
val handler = {
get {
diff --git a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
index 23df1bb463..0c08c5f417 100644
--- a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
@@ -2,6 +2,7 @@ package spark.deploy.master
import akka.actor.ActorRef
import scala.collection.mutable
+import spark.Utils
private[spark] class WorkerInfo(
val id: String,
@@ -13,6 +14,9 @@ private[spark] class WorkerInfo(
val webUiPort: Int,
val publicAddress: String) {
+ Utils.checkHost(host, "Expected hostname")
+ assert (port > 0)
+
var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info
var state: WorkerState.Value = WorkerState.ALIVE
var coresUsed = 0
@@ -23,6 +27,11 @@ private[spark] class WorkerInfo(
def coresFree: Int = cores - coresUsed
def memoryFree: Int = memory - memoryUsed
+ def hostPort: String = {
+ assert (port > 0)
+ host + ":" + port
+ }
+
def addExecutor(exec: ExecutorInfo) {
executors(exec.fullId) = exec
coresUsed += exec.cores
diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
index de11771c8e..04a774658e 100644
--- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
@@ -21,11 +21,13 @@ private[spark] class ExecutorRunner(
val memory: Int,
val worker: ActorRef,
val workerId: String,
- val hostname: String,
+ val hostPort: String,
val sparkHome: File,
val workDir: File)
extends Logging {
+ Utils.checkHostPort(hostPort, "Expected hostport")
+
val fullId = appId + "/" + execId
var workerThread: Thread = null
var process: Process = null
@@ -68,7 +70,7 @@ private[spark] class ExecutorRunner(
/** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */
def substituteVariables(argument: String): String = argument match {
case "{{EXECUTOR_ID}}" => execId.toString
- case "{{HOSTNAME}}" => hostname
+ case "{{HOSTNAME}}" => Utils.parseHostPort(hostPort)._1
case "{{CORES}}" => cores.toString
case other => other
}
diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala
index 8919d1261c..1a7da0f7bf 100644
--- a/core/src/main/scala/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/spark/deploy/worker/Worker.scala
@@ -16,7 +16,7 @@ import spark.deploy.master.Master
import java.io.File
private[spark] class Worker(
- ip: String,
+ host: String,
port: Int,
webUiPort: Int,
cores: Int,
@@ -25,6 +25,9 @@ private[spark] class Worker(
workDirPath: String = null)
extends Actor with Logging {
+ Utils.checkHost(host, "Expected hostname")
+ assert (port > 0)
+
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs
// Send a heartbeat every (heartbeat timeout) / 4 milliseconds
@@ -39,7 +42,7 @@ private[spark] class Worker(
val finishedExecutors = new HashMap[String, ExecutorRunner]
val publicAddress = {
val envVar = System.getenv("SPARK_PUBLIC_DNS")
- if (envVar != null) envVar else ip
+ if (envVar != null) envVar else host
}
var coresUsed = 0
@@ -51,10 +54,11 @@ private[spark] class Worker(
def createWorkDir() {
workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work"))
try {
- if (!workDir.exists() && !workDir.mkdirs()) {
+ if ( (workDir.exists() && !workDir.isDirectory) || (!workDir.exists() && !workDir.mkdirs()) ) {
logError("Failed to create work directory " + workDir)
System.exit(1)
}
+ assert (workDir.isDirectory)
} catch {
case e: Exception =>
logError("Failed to create work directory " + workDir, e)
@@ -64,7 +68,7 @@ private[spark] class Worker(
override def preStart() {
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
- ip, port, cores, Utils.memoryMegabytesToString(memory)))
+ host, port, cores, Utils.memoryMegabytesToString(memory)))
sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse("."))
logInfo("Spark home: " + sparkHome)
createWorkDir()
@@ -75,7 +79,7 @@ private[spark] class Worker(
def connectToMaster() {
logInfo("Connecting to master " + masterUrl)
master = context.actorFor(Master.toAkkaUrl(masterUrl))
- master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
+ master ! RegisterWorker(workerId, host, port, cores, memory, webUiPort, publicAddress)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.watch(master) // Doesn't work with remote actors, but useful for testing
}
@@ -106,7 +110,7 @@ private[spark] class Worker(
case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) =>
logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
val manager = new ExecutorRunner(
- appId, execId, appDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir)
+ appId, execId, appDesc, cores_, memory_, self, workerId, host + ":" + port, new File(execSparkHome_), workDir)
executors(appId + "/" + execId) = manager
manager.start()
coresUsed += cores_
@@ -141,7 +145,7 @@ private[spark] class Worker(
masterDisconnected()
case RequestWorkerState => {
- sender ! WorkerState(ip, port, workerId, executors.values.toList,
+ sender ! WorkerState(host, port, workerId, executors.values.toList,
finishedExecutors.values.toList, masterUrl, cores, memory,
coresUsed, memoryUsed, masterWebUiUrl)
}
@@ -156,7 +160,7 @@ private[spark] class Worker(
}
def generateWorkerId(): String = {
- "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), ip, port)
+ "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), host, port)
}
override def postStop() {
@@ -167,7 +171,7 @@ private[spark] class Worker(
private[spark] object Worker {
def main(argStrings: Array[String]) {
val args = new WorkerArguments(argStrings)
- val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort, args.cores,
+ val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores,
args.memory, args.master, args.workDir)
actorSystem.awaitTermination()
}
diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
index 08f02bad80..2b96611ee3 100644
--- a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
@@ -9,7 +9,7 @@ import java.lang.management.ManagementFactory
* Command-line parser for the master.
*/
private[spark] class WorkerArguments(args: Array[String]) {
- var ip = Utils.localHostName()
+ var host = Utils.localHostName()
var port = 0
var webUiPort = 8081
var cores = inferDefaultCores()
@@ -38,7 +38,13 @@ private[spark] class WorkerArguments(args: Array[String]) {
def parse(args: List[String]): Unit = args match {
case ("--ip" | "-i") :: value :: tail =>
- ip = value
+ Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
+ host = value
+ parse(tail)
+
+ case ("--host" | "-h") :: value :: tail =>
+ Utils.checkHost(value, "Please use hostname " + value)
+ host = value
parse(tail)
case ("--port" | "-p") :: IntParam(value) :: tail =>
@@ -93,7 +99,8 @@ private[spark] class WorkerArguments(args: Array[String]) {
" -c CORES, --cores CORES Number of cores to use\n" +
" -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G)\n" +
" -d DIR, --work-dir DIR Directory to run apps in (default: SPARK_HOME/work)\n" +
- " -i IP, --ip IP IP address or DNS name to listen on\n" +
+ " -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h)\n" +
+ " -h HOST, --host HOST Hostname to listen on\n" +
" -p PORT, --port PORT Port to listen on (default: random)\n" +
" --webui-port PORT Port for web UI (default: 8081)")
System.exit(exitCode)
diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
index c834f87d50..3235c50d1b 100644
--- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
+++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
@@ -3,7 +3,7 @@ package spark.deploy.worker
import akka.actor.{ActorRef, ActorSystem}
import akka.dispatch.Await
import akka.pattern.ask
-import akka.util.Timeout
+import akka.util.{Duration, Timeout}
import akka.util.duration._
import cc.spray.Directives
import cc.spray.typeconversion.TwirlSupport._
@@ -22,7 +22,7 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef, workDir: File)
val RESOURCE_DIR = "spark/deploy/worker/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static"
- implicit val timeout = Timeout(10 seconds)
+ implicit val timeout = Timeout(Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds"))
val handler = {
get {
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index 3e7407b58d..344face5e6 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -17,7 +17,7 @@ import java.nio.ByteBuffer
* The Mesos executor for Spark.
*/
private[spark] class Executor(executorId: String, slaveHostname: String, properties: Seq[(String, String)]) extends Logging {
-
+
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
@@ -27,6 +27,11 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
initLogging()
+ // No ip or host:port - just hostname
+ Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname")
+ // must not have port specified.
+ assert (0 == Utils.parseHostPort(slaveHostname)._2)
+
// Make sure the local hostname we report matches the cluster scheduler's name for this host
Utils.setCustomHostname(slaveHostname)
diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
index 1047f71c6a..ebe2ac68d8 100644
--- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
+++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
@@ -12,23 +12,27 @@ import spark.scheduler.cluster.RegisteredExecutor
import spark.scheduler.cluster.LaunchTask
import spark.scheduler.cluster.RegisterExecutorFailed
import spark.scheduler.cluster.RegisterExecutor
+import spark.Utils
+import spark.deploy.SparkHadoopUtil
private[spark] class StandaloneExecutorBackend(
driverUrl: String,
executorId: String,
- hostname: String,
+ hostPort: String,
cores: Int)
extends Actor
with ExecutorBackend
with Logging {
+ Utils.checkHostPort(hostPort, "Expected hostport")
+
var executor: Executor = null
var driver: ActorRef = null
override def preStart() {
logInfo("Connecting to driver: " + driverUrl)
driver = context.actorFor(driverUrl)
- driver ! RegisterExecutor(executorId, hostname, cores)
+ driver ! RegisterExecutor(executorId, hostPort, cores)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.watch(driver) // Doesn't work with remote actors, but useful for testing
}
@@ -36,7 +40,8 @@ private[spark] class StandaloneExecutorBackend(
override def receive = {
case RegisteredExecutor(sparkProperties) =>
logInfo("Successfully registered with driver")
- executor = new Executor(executorId, hostname, sparkProperties)
+ // Make this host instead of hostPort ?
+ executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties)
case RegisterExecutorFailed(message) =>
logError("Slave registration failed: " + message)
@@ -63,11 +68,30 @@ private[spark] class StandaloneExecutorBackend(
private[spark] object StandaloneExecutorBackend {
def run(driverUrl: String, executorId: String, hostname: String, cores: Int) {
+ SparkHadoopUtil.runAsUser(run0, Tuple4[Any, Any, Any, Any] (driverUrl, executorId, hostname, cores))
+ }
+
+ // This will be run 'as' the user
+ def run0(args: Product) {
+ assert(4 == args.productArity)
+ runImpl(args.productElement(0).asInstanceOf[String],
+ args.productElement(1).asInstanceOf[String],
+ args.productElement(2).asInstanceOf[String],
+ args.productElement(3).asInstanceOf[Int])
+ }
+
+ private def runImpl(driverUrl: String, executorId: String, hostname: String, cores: Int) {
+ // Debug code
+ Utils.checkHost(hostname)
+
// Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
// before getting started with all our system properties, etc
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0)
+ // set it
+ val sparkHostPort = hostname + ":" + boundPort
+ System.setProperty("spark.hostPort", sparkHostPort)
val actor = actorSystem.actorOf(
- Props(new StandaloneExecutorBackend(driverUrl, executorId, hostname, cores)),
+ Props(new StandaloneExecutorBackend(driverUrl, executorId, sparkHostPort, cores)),
name = "Executor")
actorSystem.awaitTermination()
}
diff --git a/core/src/main/scala/spark/executor/TaskMetrics.scala b/core/src/main/scala/spark/executor/TaskMetrics.scala
index 93bbb6b458..a7c56c2371 100644
--- a/core/src/main/scala/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/spark/executor/TaskMetrics.scala
@@ -49,11 +49,6 @@ class ShuffleReadMetrics extends Serializable {
var localBlocksFetched: Int = _
/**
- * Total time to read shuffle data
- */
- var shuffleReadMillis: Long = _
-
- /**
* Total time that is spent blocked waiting for shuffle to fetch data
*/
var fetchWaitTime: Long = _
diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala
index d1451bc212..00a0433a44 100644
--- a/core/src/main/scala/spark/network/Connection.scala
+++ b/core/src/main/scala/spark/network/Connection.scala
@@ -13,7 +13,7 @@ import java.net._
private[spark]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
- val remoteConnectionManagerId: ConnectionManagerId) extends Logging {
+ val socketRemoteConnectionManagerId: ConnectionManagerId) extends Logging {
def this(channel_ : SocketChannel, selector_ : Selector) = {
this(channel_, selector_,
ConnectionManagerId.fromSocketAddress(
@@ -32,16 +32,43 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
val remoteAddress = getRemoteAddress()
+
+ // Read channels typically do not register for write and write does not for read
+ // Now, we do have write registering for read too (temporarily), but this is to detect
+ // channel close NOT to actually read/consume data on it !
+ // How does this work if/when we move to SSL ?
+
+ // What is the interest to register with selector for when we want this connection to be selected
+ def registerInterest()
+ // What is the interest to register with selector for when we want this connection to be de-selected
+ // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack, it will be
+ // SelectionKey.OP_READ (until we fix it properly)
+ def unregisterInterest()
+
+ // On receiving a read event, should we change the interest for this channel or not ?
+ // Will be true for ReceivingConnection, false for SendingConnection.
+ def changeInterestForRead(): Boolean
+
+ // On receiving a write event, should we change the interest for this channel or not ?
+ // Will be false for ReceivingConnection, true for SendingConnection.
+ // Actually, for now, should not get triggered for ReceivingConnection
+ def changeInterestForWrite(): Boolean
+
+ def getRemoteConnectionManagerId(): ConnectionManagerId = {
+ socketRemoteConnectionManagerId
+ }
def key() = channel.keyFor(selector)
def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
- def read() {
+ // Returns whether we have to register for further reads or not.
+ def read(): Boolean = {
throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString)
}
-
- def write() {
+
+ // Returns whether we have to register for further writes or not.
+ def write(): Boolean = {
throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString)
}
@@ -64,7 +91,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
if (onExceptionCallback != null) {
onExceptionCallback(this, e)
} else {
- logError("Error in connection to " + remoteConnectionManagerId +
+ logError("Error in connection to " + getRemoteConnectionManagerId() +
" and OnExceptionCallback not registered", e)
}
}
@@ -73,7 +100,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
if (onCloseCallback != null) {
onCloseCallback(this)
} else {
- logWarning("Connection to " + remoteConnectionManagerId +
+ logWarning("Connection to " + getRemoteConnectionManagerId() +
" closed and OnExceptionCallback not registered")
}
@@ -81,7 +108,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
def changeConnectionKeyInterest(ops: Int) {
if (onKeyInterestChangeCallback != null) {
- onKeyInterestChangeCallback(this, ops)
+ onKeyInterestChangeCallback(this, ops)
} else {
throw new Exception("OnKeyInterestChangeCallback not registered")
}
@@ -122,7 +149,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
messages.synchronized{
/*messages += message*/
messages.enqueue(message)
- logDebug("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]")
+ logDebug("Added [" + message + "] to outbox for sending to [" + getRemoteConnectionManagerId() + "]")
}
}
@@ -149,9 +176,9 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
}
return chunk
} else {
- /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
+ /*logInfo("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "]")*/
message.finishTime = System.currentTimeMillis
- logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId +
+ logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
"] in " + message.timeTaken )
}
}
@@ -170,15 +197,15 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
messages.enqueue(message)
nextMessageToBeUsed = nextMessageToBeUsed + 1
if (!message.started) {
- logDebug("Starting to send [" + message + "] to [" + remoteConnectionManagerId + "]")
+ logDebug("Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]")
message.started = true
message.startTime = System.currentTimeMillis
}
- logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]")
+ logTrace("Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]")
return chunk
} else {
message.finishTime = System.currentTimeMillis
- logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId +
+ logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
"] in " + message.timeTaken )
}
}
@@ -187,26 +214,39 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
}
}
- val outbox = new Outbox(1)
+ private val outbox = new Outbox(1)
val currentBuffers = new ArrayBuffer[ByteBuffer]()
/*channel.socket.setSendBufferSize(256 * 1024)*/
- override def getRemoteAddress() = address
+ override def getRemoteAddress() = address
+ val DEFAULT_INTEREST = SelectionKey.OP_READ
+
+ override def registerInterest() {
+ // Registering read too - does not really help in most cases, but for some
+ // it does - so let us keep it for now.
+ changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST)
+ }
+
+ override def unregisterInterest() {
+ changeConnectionKeyInterest(DEFAULT_INTEREST)
+ }
+
def send(message: Message) {
outbox.synchronized {
outbox.addMessage(message)
if (channel.isConnected) {
- changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ)
+ registerInterest()
}
}
}
+ // MUST be called within the selector loop
def connect() {
try{
- channel.connect(address)
channel.register(selector, SelectionKey.OP_CONNECT)
+ channel.connect(address)
logInfo("Initiating connection to [" + address + "]")
} catch {
case e: Exception => {
@@ -216,20 +256,33 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
}
}
- def finishConnect() {
+ def finishConnect(force: Boolean): Boolean = {
try {
- channel.finishConnect
- changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ)
+ // Typically, this should finish immediately since it was triggered by a connect
+ // selection - though need not necessarily always complete successfully.
+ val connected = channel.finishConnect
+ if (!force && !connected) {
+ logInfo("finish connect failed [" + address + "], " + outbox.messages.size + " messages pending")
+ return false
+ }
+
+ // Fallback to previous behavior - assume finishConnect completed
+ // This will happen only when finishConnect failed for some repeated number of times (10 or so)
+ // Is highly unlikely unless there was an unclean close of socket, etc
+ registerInterest()
logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
+ return true
} catch {
case e: Exception => {
logWarning("Error finishing connection to " + address, e)
callOnExceptionCallback(e)
+ // ignore
+ return true
}
}
}
- override def write() {
+ override def write(): Boolean = {
try{
while(true) {
if (currentBuffers.size == 0) {
@@ -239,8 +292,9 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
currentBuffers ++= chunk.buffers
}
case None => {
- changeConnectionKeyInterest(SelectionKey.OP_READ)
- return
+ // changeConnectionKeyInterest(0)
+ /*key.interestOps(0)*/
+ return false
}
}
}
@@ -254,38 +308,53 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
currentBuffers -= buffer
}
if (writtenBytes < remainingBytes) {
- return
+ // re-register for write.
+ return true
}
}
}
} catch {
case e: Exception => {
- logWarning("Error writing in connection to " + remoteConnectionManagerId, e)
+ logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
+ return false
}
}
+ // should not happen - to keep scala compiler happy
+ return true
}
- override def read() {
+ // This is a hack to determine if remote socket was closed or not.
+ // SendingConnection DOES NOT expect to receive any data - if it does, it is an error
+ // For a bunch of cases, read will return -1 in case remote socket is closed : hence we
+ // register for reads to determine that.
+ override def read(): Boolean = {
// We don't expect the other side to send anything; so, we just read to detect an error or EOF.
try {
val length = channel.read(ByteBuffer.allocate(1))
if (length == -1) { // EOF
close()
} else if (length > 0) {
- logWarning("Unexpected data read from SendingConnection to " + remoteConnectionManagerId)
+ logWarning("Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId())
}
} catch {
case e: Exception =>
- logError("Exception while reading SendingConnection to " + remoteConnectionManagerId, e)
+ logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
}
+
+ false
}
+
+ override def changeInterestForRead(): Boolean = false
+
+ override def changeInterestForWrite(): Boolean = true
}
+// Must be created within selector loop - else deadlock
private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
extends Connection(channel_, selector_) {
@@ -298,13 +367,13 @@ extends Connection(channel_, selector_) {
val newMessage = Message.create(header).asInstanceOf[BufferMessage]
newMessage.started = true
newMessage.startTime = System.currentTimeMillis
- logDebug("Starting to receive [" + newMessage + "] from [" + remoteConnectionManagerId + "]")
+ logDebug("Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]")
messages += ((newMessage.id, newMessage))
newMessage
}
val message = messages.getOrElseUpdate(header.id, createNewMessage)
- logTrace("Receiving chunk of [" + message + "] from [" + remoteConnectionManagerId + "]")
+ logTrace("Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]")
message.getChunkForReceiving(header.chunkSize)
}
@@ -316,7 +385,27 @@ extends Connection(channel_, selector_) {
messages -= message.id
}
}
-
+
+ @volatile private var inferredRemoteManagerId: ConnectionManagerId = null
+ override def getRemoteConnectionManagerId(): ConnectionManagerId = {
+ val currId = inferredRemoteManagerId
+ if (currId != null) currId else super.getRemoteConnectionManagerId()
+ }
+
+ // The reciever's remote address is the local socket on remote side : which is NOT the connection manager id of the receiver.
+ // We infer that from the messages we receive on the receiver socket.
+ private def processConnectionManagerId(header: MessageChunkHeader) {
+ val currId = inferredRemoteManagerId
+ if (header.address == null || currId != null) return
+
+ val managerId = ConnectionManagerId.fromSocketAddress(header.address)
+
+ if (managerId != null) {
+ inferredRemoteManagerId = managerId
+ }
+ }
+
+
val inbox = new Inbox()
val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
var onReceiveCallback: (Connection , Message) => Unit = null
@@ -324,17 +413,18 @@ extends Connection(channel_, selector_) {
channel.register(selector, SelectionKey.OP_READ)
- override def read() {
+ override def read(): Boolean = {
try {
while (true) {
if (currentChunk == null) {
val headerBytesRead = channel.read(headerBuffer)
if (headerBytesRead == -1) {
close()
- return
+ return false
}
if (headerBuffer.remaining > 0) {
- return
+ // re-register for read event ...
+ return true
}
headerBuffer.flip
if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) {
@@ -342,6 +432,9 @@ extends Connection(channel_, selector_) {
}
val header = MessageChunkHeader.create(headerBuffer)
headerBuffer.clear()
+
+ processConnectionManagerId(header)
+
header.typ match {
case Message.BUFFER_MESSAGE => {
if (header.totalSize == 0) {
@@ -349,7 +442,8 @@ extends Connection(channel_, selector_) {
onReceiveCallback(this, Message.create(header))
}
currentChunk = null
- return
+ // re-register for read event ...
+ return true
} else {
currentChunk = inbox.getChunk(header).orNull
}
@@ -362,10 +456,11 @@ extends Connection(channel_, selector_) {
val bytesRead = channel.read(currentChunk.buffer)
if (bytesRead == 0) {
- return
+ // re-register for read event ...
+ return true
} else if (bytesRead == -1) {
close()
- return
+ return false
}
/*logDebug("Read " + bytesRead + " bytes for the buffer")*/
@@ -376,7 +471,7 @@ extends Connection(channel_, selector_) {
if (bufferMessage.isCompletelyReceived) {
bufferMessage.flip
bufferMessage.finishTime = System.currentTimeMillis
- logDebug("Finished receiving [" + bufferMessage + "] from [" + remoteConnectionManagerId + "] in " + bufferMessage.timeTaken)
+ logDebug("Finished receiving [" + bufferMessage + "] from [" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken)
if (onReceiveCallback != null) {
onReceiveCallback(this, bufferMessage)
}
@@ -387,12 +482,31 @@ extends Connection(channel_, selector_) {
}
} catch {
case e: Exception => {
- logWarning("Error reading from connection to " + remoteConnectionManagerId, e)
+ logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
+ return false
}
}
+ // should not happen - to keep scala compiler happy
+ return true
}
def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback}
+
+ override def changeInterestForRead(): Boolean = true
+
+ override def changeInterestForWrite(): Boolean = {
+ throw new IllegalStateException("Unexpected invocation right now")
+ }
+
+ override def registerInterest() {
+ // Registering read too - does not really help in most cases, but for some
+ // it does - so let us keep it for now.
+ changeConnectionKeyInterest(SelectionKey.OP_READ)
+ }
+
+ override def unregisterInterest() {
+ changeConnectionKeyInterest(0)
+ }
}
diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala
index b6ec664d7e..0eb03630d0 100644
--- a/core/src/main/scala/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/spark/network/ConnectionManager.scala
@@ -6,12 +6,12 @@ import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
import java.net._
-import java.util.concurrent.Executors
+import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor}
+import scala.collection.mutable.HashSet
import scala.collection.mutable.HashMap
import scala.collection.mutable.SynchronizedMap
import scala.collection.mutable.SynchronizedQueue
-import scala.collection.mutable.Queue
import scala.collection.mutable.ArrayBuffer
import akka.dispatch.{Await, Promise, ExecutionContext, Future}
@@ -19,6 +19,10 @@ import akka.util.Duration
import akka.util.duration._
private[spark] case class ConnectionManagerId(host: String, port: Int) {
+ // DEBUG code
+ Utils.checkHost(host)
+ assert (port > 0)
+
def toSocketAddress() = new InetSocketAddress(host, port)
}
@@ -42,19 +46,37 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
def markDone() { completionHandler(this) }
}
- val selector = SelectorProvider.provider.openSelector()
- val handleMessageExecutor = Executors.newFixedThreadPool(System.getProperty("spark.core.connection.handler.threads","20").toInt)
- val serverChannel = ServerSocketChannel.open()
- val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
- val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
- val messageStatuses = new HashMap[Int, MessageStatus]
- val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
- val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
- val sendMessageRequests = new Queue[(Message, SendingConnection)]
+ private val selector = SelectorProvider.provider.openSelector()
+
+ private val handleMessageExecutor = new ThreadPoolExecutor(
+ System.getProperty("spark.core.connection.handler.threads.min","20").toInt,
+ System.getProperty("spark.core.connection.handler.threads.max","60").toInt,
+ System.getProperty("spark.core.connection.handler.threads.keepalive","60").toInt, TimeUnit.SECONDS,
+ new LinkedBlockingDeque[Runnable]())
+
+ private val handleReadWriteExecutor = new ThreadPoolExecutor(
+ System.getProperty("spark.core.connection.io.threads.min","4").toInt,
+ System.getProperty("spark.core.connection.io.threads.max","32").toInt,
+ System.getProperty("spark.core.connection.io.threads.keepalive","60").toInt, TimeUnit.SECONDS,
+ new LinkedBlockingDeque[Runnable]())
+
+ // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : which should be executed asap
+ private val handleConnectExecutor = new ThreadPoolExecutor(
+ System.getProperty("spark.core.connection.connect.threads.min","1").toInt,
+ System.getProperty("spark.core.connection.connect.threads.max","8").toInt,
+ System.getProperty("spark.core.connection.connect.threads.keepalive","60").toInt, TimeUnit.SECONDS,
+ new LinkedBlockingDeque[Runnable]())
+
+ private val serverChannel = ServerSocketChannel.open()
+ private val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
+ private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
+ private val messageStatuses = new HashMap[Int, MessageStatus]
+ private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
+ private val registerRequests = new SynchronizedQueue[SendingConnection]
implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())
- var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
+ private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
serverChannel.configureBlocking(false)
serverChannel.socket.setReuseAddress(true)
@@ -65,49 +87,221 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
-
- val selectorThread = new Thread("connection-manager-thread") {
+
+ private val selectorThread = new Thread("connection-manager-thread") {
override def run() = ConnectionManager.this.run()
}
selectorThread.setDaemon(true)
selectorThread.start()
- private def run() {
- try {
- while(!selectorThread.isInterrupted) {
- for ((connectionManagerId, sendingConnection) <- connectionRequests) {
- sendingConnection.connect()
- addConnection(sendingConnection)
- connectionRequests -= connectionManagerId
+ private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
+
+ private def triggerWrite(key: SelectionKey) {
+ val conn = connectionsByKey.getOrElse(key, null)
+ if (conn == null) return
+
+ writeRunnableStarted.synchronized {
+ // So that we do not trigger more write events while processing this one.
+ // The write method will re-register when done.
+ if (conn.changeInterestForWrite()) conn.unregisterInterest()
+ if (writeRunnableStarted.contains(key)) {
+ // key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE)
+ return
+ }
+
+ writeRunnableStarted += key
+ }
+ handleReadWriteExecutor.execute(new Runnable {
+ override def run() {
+ var register: Boolean = false
+ try {
+ register = conn.write()
+ } finally {
+ writeRunnableStarted.synchronized {
+ writeRunnableStarted -= key
+ if (register && conn.changeInterestForWrite()) {
+ conn.registerInterest()
+ }
+ }
}
- sendMessageRequests.synchronized {
- while (!sendMessageRequests.isEmpty) {
- val (message, connection) = sendMessageRequests.dequeue
- connection.send(message)
+ }
+ } )
+ }
+
+ private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
+
+ private def triggerRead(key: SelectionKey) {
+ val conn = connectionsByKey.getOrElse(key, null)
+ if (conn == null) return
+
+ readRunnableStarted.synchronized {
+ // So that we do not trigger more read events while processing this one.
+ // The read method will re-register when done.
+ if (conn.changeInterestForRead())conn.unregisterInterest()
+ if (readRunnableStarted.contains(key)) {
+ return
+ }
+
+ readRunnableStarted += key
+ }
+ handleReadWriteExecutor.execute(new Runnable {
+ override def run() {
+ var register: Boolean = false
+ try {
+ register = conn.read()
+ } finally {
+ readRunnableStarted.synchronized {
+ readRunnableStarted -= key
+ if (register && conn.changeInterestForRead()) {
+ conn.registerInterest()
+ }
}
}
+ }
+ } )
+ }
- while (!keyInterestChangeRequests.isEmpty) {
+ private def triggerConnect(key: SelectionKey) {
+ val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection]
+ if (conn == null) return
+
+ // prevent other events from being triggered
+ // Since we are still trying to connect, we do not need to do the additional steps in triggerWrite
+ conn.changeConnectionKeyInterest(0)
+
+ handleConnectExecutor.execute(new Runnable {
+ override def run() {
+
+ var tries: Int = 10
+ while (tries >= 0) {
+ if (conn.finishConnect(false)) return
+ // Sleep ?
+ Thread.sleep(1)
+ tries -= 1
+ }
+
+ // fallback to previous behavior : we should not really come here since this method was
+ // triggered since channel became connectable : but at times, the first finishConnect need not
+ // succeed : hence the loop to retry a few 'times'.
+ conn.finishConnect(true)
+ }
+ } )
+ }
+
+ // MUST be called within selector loop - else deadlock.
+ private def triggerForceCloseByException(key: SelectionKey, e: Exception) {
+ try {
+ key.interestOps(0)
+ } catch {
+ // ignore exceptions
+ case e: Exception => logDebug("Ignoring exception", e)
+ }
+
+ val conn = connectionsByKey.getOrElse(key, null)
+ if (conn == null) return
+
+ // Pushing to connect threadpool
+ handleConnectExecutor.execute(new Runnable {
+ override def run() {
+ try {
+ conn.callOnExceptionCallback(e)
+ } catch {
+ // ignore exceptions
+ case e: Exception => logDebug("Ignoring exception", e)
+ }
+ try {
+ conn.close()
+ } catch {
+ // ignore exceptions
+ case e: Exception => logDebug("Ignoring exception", e)
+ }
+ }
+ })
+ }
+
+
+ def run() {
+ try {
+ while(!selectorThread.isInterrupted) {
+ while (! registerRequests.isEmpty) {
+ val conn: SendingConnection = registerRequests.dequeue
+ addListeners(conn)
+ conn.connect()
+ addConnection(conn)
+ }
+
+ while(!keyInterestChangeRequests.isEmpty) {
val (key, ops) = keyInterestChangeRequests.dequeue
- val connection = connectionsByKey(key)
- val lastOps = key.interestOps()
- key.interestOps(ops)
-
- def intToOpStr(op: Int): String = {
- val opStrs = ArrayBuffer[String]()
- if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
- if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
- if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
- if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
- if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
+
+ try {
+ if (key.isValid) {
+ val connection = connectionsByKey.getOrElse(key, null)
+ if (connection != null) {
+ val lastOps = key.interestOps()
+ key.interestOps(ops)
+
+ // hot loop - prevent materialization of string if trace not enabled.
+ if (isTraceEnabled()) {
+ def intToOpStr(op: Int): String = {
+ val opStrs = ArrayBuffer[String]()
+ if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
+ if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
+ if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
+ if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
+ if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
+ }
+
+ logTrace("Changed key for connection to [" + connection.getRemoteConnectionManagerId() +
+ "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
+ }
+ }
+ } else {
+ logInfo("Key not valid ? " + key)
+ throw new CancelledKeyException()
+ }
+ } catch {
+ case e: CancelledKeyException => {
+ logInfo("key already cancelled ? " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ case e: Exception => {
+ logError("Exception processing key " + key, e)
+ triggerForceCloseByException(key, e)
+ }
}
-
- logTrace("Changed key for connection to [" + connection.remoteConnectionManagerId +
- "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
-
}
- val selectedKeysCount = selector.select()
+ val selectedKeysCount =
+ try {
+ selector.select()
+ } catch {
+ // Explicitly only dealing with CancelledKeyException here since other exceptions should be dealt with differently.
+ case e: CancelledKeyException => {
+ // Some keys within the selectors list are invalid/closed. clear them.
+ val allKeys = selector.keys().iterator()
+
+ while (allKeys.hasNext()) {
+ val key = allKeys.next()
+ try {
+ if (! key.isValid) {
+ logInfo("Key not valid ? " + key)
+ throw new CancelledKeyException()
+ }
+ } catch {
+ case e: CancelledKeyException => {
+ logInfo("key already cancelled ? " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ case e: Exception => {
+ logError("Exception processing key " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ }
+ }
+ }
+ 0
+ }
+
if (selectedKeysCount == 0) {
logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys")
}
@@ -115,20 +309,40 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
logInfo("Selector thread was interrupted!")
return
}
-
- val selectedKeys = selector.selectedKeys().iterator()
- while (selectedKeys.hasNext()) {
- val key = selectedKeys.next
- selectedKeys.remove()
- if (key.isValid) {
- if (key.isAcceptable) {
- acceptConnection(key)
- } else if (key.isConnectable) {
- connectionsByKey(key).asInstanceOf[SendingConnection].finishConnect()
- } else if (key.isReadable) {
- connectionsByKey(key).read()
- } else if (key.isWritable) {
- connectionsByKey(key).write()
+
+ if (0 != selectedKeysCount) {
+ val selectedKeys = selector.selectedKeys().iterator()
+ while (selectedKeys.hasNext()) {
+ val key = selectedKeys.next
+ selectedKeys.remove()
+ try {
+ if (key.isValid) {
+ if (key.isAcceptable) {
+ acceptConnection(key)
+ } else
+ if (key.isConnectable) {
+ triggerConnect(key)
+ } else
+ if (key.isReadable) {
+ triggerRead(key)
+ } else
+ if (key.isWritable) {
+ triggerWrite(key)
+ }
+ } else {
+ logInfo("Key not valid ? " + key)
+ throw new CancelledKeyException()
+ }
+ } catch {
+ // weird, but we saw this happening - even though key.isValid was true, key.isAcceptable would throw CancelledKeyException.
+ case e: CancelledKeyException => {
+ logInfo("key already cancelled ? " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ case e: Exception => {
+ logError("Exception processing key " + key, e)
+ triggerForceCloseByException(key, e)
+ }
}
}
}
@@ -138,94 +352,116 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
}
}
- private def acceptConnection(key: SelectionKey) {
+ def acceptConnection(key: SelectionKey) {
val serverChannel = key.channel.asInstanceOf[ServerSocketChannel]
- val newChannel = serverChannel.accept()
- val newConnection = new ReceivingConnection(newChannel, selector)
- newConnection.onReceive(receiveMessage)
- newConnection.onClose(removeConnection)
- addConnection(newConnection)
- logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
- }
- private def addConnection(connection: Connection) {
- connectionsByKey += ((connection.key, connection))
- if (connection.isInstanceOf[SendingConnection]) {
- val sendingConnection = connection.asInstanceOf[SendingConnection]
- connectionsById += ((sendingConnection.remoteConnectionManagerId, sendingConnection))
+ var newChannel = serverChannel.accept()
+
+ // accept them all in a tight loop. non blocking accept with no processing, should be fine
+ while (newChannel != null) {
+ try {
+ val newConnection = new ReceivingConnection(newChannel, selector)
+ newConnection.onReceive(receiveMessage)
+ addListeners(newConnection)
+ addConnection(newConnection)
+ logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
+ } catch {
+ // might happen in case of issues with registering with selector
+ case e: Exception => logError("Error in accept loop", e)
+ }
+
+ newChannel = serverChannel.accept()
}
+ }
+
+ private def addListeners(connection: Connection) {
connection.onKeyInterestChange(changeConnectionKeyInterest)
connection.onException(handleConnectionError)
connection.onClose(removeConnection)
}
- private def removeConnection(connection: Connection) {
+ def addConnection(connection: Connection) {
+ connectionsByKey += ((connection.key, connection))
+ }
+
+ def removeConnection(connection: Connection) {
connectionsByKey -= connection.key
- if (connection.isInstanceOf[SendingConnection]) {
- val sendingConnection = connection.asInstanceOf[SendingConnection]
- val sendingConnectionManagerId = sendingConnection.remoteConnectionManagerId
- logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
-
- connectionsById -= sendingConnectionManagerId
-
- messageStatuses.synchronized {
- messageStatuses
- .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
- logInfo("Notifying " + status)
- status.synchronized {
- status.attempted = true
- status.acked = false
- status.markDone()
- }
+
+ try {
+ if (connection.isInstanceOf[SendingConnection]) {
+ val sendingConnection = connection.asInstanceOf[SendingConnection]
+ val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
+ logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
+
+ connectionsById -= sendingConnectionManagerId
+
+ messageStatuses.synchronized {
+ messageStatuses
+ .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
+ logInfo("Notifying " + status)
+ status.synchronized {
+ status.attempted = true
+ status.acked = false
+ status.markDone()
+ }
+ })
+
+ messageStatuses.retain((i, status) => {
+ status.connectionManagerId != sendingConnectionManagerId
})
+ }
+ } else if (connection.isInstanceOf[ReceivingConnection]) {
+ val receivingConnection = connection.asInstanceOf[ReceivingConnection]
+ val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId()
+ logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
+
+ val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId)
+ if (! sendingConnectionOpt.isDefined) {
+ logError("Corresponding SendingConnectionManagerId not found")
+ return
+ }
- messageStatuses.retain((i, status) => {
- status.connectionManagerId != sendingConnectionManagerId
- })
- }
- } else if (connection.isInstanceOf[ReceivingConnection]) {
- val receivingConnection = connection.asInstanceOf[ReceivingConnection]
- val remoteConnectionManagerId = receivingConnection.remoteConnectionManagerId
- logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
-
- val sendingConnectionManagerId = connectionsById.keys.find(_.host == remoteConnectionManagerId.host).orNull
- if (sendingConnectionManagerId == null) {
- logError("Corresponding SendingConnectionManagerId not found")
- return
- }
- logInfo("Corresponding SendingConnectionManagerId is " + sendingConnectionManagerId)
-
- val sendingConnection = connectionsById(sendingConnectionManagerId)
- sendingConnection.close()
- connectionsById -= sendingConnectionManagerId
-
- messageStatuses.synchronized {
- for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) {
- logInfo("Notifying " + s)
- s.synchronized {
- s.attempted = true
- s.acked = false
- s.markDone()
+ val sendingConnection = sendingConnectionOpt.get
+ connectionsById -= remoteConnectionManagerId
+ sendingConnection.close()
+
+ val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
+
+ assert (sendingConnectionManagerId == remoteConnectionManagerId)
+
+ messageStatuses.synchronized {
+ for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) {
+ logInfo("Notifying " + s)
+ s.synchronized {
+ s.attempted = true
+ s.acked = false
+ s.markDone()
+ }
}
- }
- messageStatuses.retain((i, status) => {
- status.connectionManagerId != sendingConnectionManagerId
- })
+ messageStatuses.retain((i, status) => {
+ status.connectionManagerId != sendingConnectionManagerId
+ })
+ }
}
+ } finally {
+ // So that the selection keys can be removed.
+ wakeupSelector()
}
}
- private def handleConnectionError(connection: Connection, e: Exception) {
- logInfo("Handling connection error on connection to " + connection.remoteConnectionManagerId)
+ def handleConnectionError(connection: Connection, e: Exception) {
+ logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId())
removeConnection(connection)
}
- private def changeConnectionKeyInterest(connection: Connection, ops: Int) {
- keyInterestChangeRequests += ((connection.key, ops))
+ def changeConnectionKeyInterest(connection: Connection, ops: Int) {
+ keyInterestChangeRequests += ((connection.key, ops))
+ // so that registerations happen !
+ wakeupSelector()
}
- private def receiveMessage(connection: Connection, message: Message) {
+ def receiveMessage(connection: Connection, message: Message) {
val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
val runnable = new Runnable() {
@@ -293,18 +529,22 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
def startNewConnection(): SendingConnection = {
val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
- val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId,
- new SendingConnection(inetSocketAddress, selector, connectionManagerId))
- newConnection
+ val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId)
+ registerRequests.enqueue(newConnection)
+
+ newConnection
}
- val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress)
- val connection = connectionsById.getOrElse(lookupKey, startNewConnection())
+ // I removed the lookupKey stuff as part of merge ... should I re-add it ? We did not find it useful in our test-env ...
+ // If we do re-add it, we should consistently use it everywhere I guess ?
+ val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
message.senderAddress = id.toSocketAddress()
logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
- /*connection.send(message)*/
- sendMessageRequests.synchronized {
- sendMessageRequests += ((message, connection))
- }
+ connection.send(message)
+
+ wakeupSelector()
+ }
+
+ private def wakeupSelector() {
selector.wakeup()
}
@@ -337,6 +577,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
logWarning("All connections not cleaned up")
}
handleMessageExecutor.shutdown()
+ handleReadWriteExecutor.shutdown()
+ handleConnectExecutor.shutdown()
logInfo("ConnectionManager stopped")
}
}
diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala
index 525751b5bf..34fac9e776 100644
--- a/core/src/main/scala/spark/network/Message.scala
+++ b/core/src/main/scala/spark/network/Message.scala
@@ -17,7 +17,8 @@ private[spark] class MessageChunkHeader(
val other: Int,
val address: InetSocketAddress) {
lazy val buffer = {
- val ip = address.getAddress.getAddress()
+ // No need to change this, at 'use' time, we do a reverse lookup of the hostname. Refer to network.Connection
+ val ip = address.getAddress.getAddress()
val port = address.getPort()
ByteBuffer.
allocate(MessageChunkHeader.HEADER_SIZE).
diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
index 9e37bdf659..43ee39c993 100644
--- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
@@ -8,6 +8,7 @@ import org.apache.hadoop.util.ReflectionUtils
import org.apache.hadoop.fs.Path
import java.io.{File, IOException, EOFException}
import java.text.NumberFormat
+import spark.deploy.SparkHadoopUtil
private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {}
@@ -21,13 +22,20 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: Stri
@transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
override def getPartitions: Array[Partition] = {
- val dirContents = fs.listStatus(new Path(checkpointPath))
- val partitionFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted
- val numPartitions = partitionFiles.size
- if (numPartitions > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
- ! partitionFiles(numPartitions-1).endsWith(CheckpointRDD.splitIdToFile(numPartitions-1)))) {
- throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
- }
+ val cpath = new Path(checkpointPath)
+ val numPartitions =
+ // listStatus can throw exception if path does not exist.
+ if (fs.exists(cpath)) {
+ val dirContents = fs.listStatus(cpath)
+ val partitionFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted
+ val numPart = partitionFiles.size
+ if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
+ ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) {
+ throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
+ }
+ numPart
+ } else 0
+
Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i))
}
@@ -58,7 +66,7 @@ private[spark] object CheckpointRDD extends Logging {
def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
val outputDir = new Path(path)
- val fs = outputDir.getFileSystem(new Configuration())
+ val fs = outputDir.getFileSystem(SparkHadoopUtil.newConfiguration())
val finalOutputName = splitIdToFile(ctx.splitId)
val finalOutputPath = new Path(outputDir, finalOutputName)
@@ -83,6 +91,7 @@ private[spark] object CheckpointRDD extends Logging {
if (!fs.rename(tempOutputPath, finalOutputPath)) {
if (!fs.exists(finalOutputPath)) {
+ logInfo("Deleting tempOutputPath " + tempOutputPath)
fs.delete(tempOutputPath, false)
throw new IOException("Checkpoint failed: failed to save output of task: "
+ ctx.attemptId + " and final output path does not exist")
@@ -95,7 +104,7 @@ private[spark] object CheckpointRDD extends Logging {
}
def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = {
- val fs = path.getFileSystem(new Configuration())
+ val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration())
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
val fileInputStream = fs.open(path, bufferSize)
val serializer = SparkEnv.get.serializer.newInstance()
@@ -117,11 +126,11 @@ private[spark] object CheckpointRDD extends Logging {
val sc = new SparkContext(cluster, "CheckpointRDD Test")
val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
val path = new Path(hdfsPath, "temp")
- val fs = path.getFileSystem(new Configuration())
+ val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration())
sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _)
val cpRDD = new CheckpointRDD[Int](sc, path.toString)
assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same")
assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same")
- fs.delete(path)
+ fs.delete(path, true)
}
}
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index 9213513e80..7599ba1a02 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -29,7 +29,7 @@ private[spark] case class NarrowCoGroupSplitDep(
private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
private[spark]
-class CoGroupPartition(idx: Int, val deps: Seq[CoGroupSplitDep])
+class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
extends Partition with Serializable {
override val index: Int = idx
override def hashCode(): Int = idx
@@ -54,7 +54,8 @@ private[spark] class CoGroupAggregator
class CoGroupedRDD[K](
@transient var rdds: Seq[RDD[(K, _)]],
part: Partitioner,
- val mapSideCombine: Boolean = true)
+ val mapSideCombine: Boolean = true,
+ val serializerClass: String = null)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {
private val aggr = new CoGroupAggregator
@@ -68,9 +69,9 @@ class CoGroupedRDD[K](
logInfo("Adding shuffle dependency with " + rdd)
if (mapSideCombine) {
val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
- new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part)
+ new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part, serializerClass)
} else {
- new ShuffleDependency[Any, Any](rdd.asInstanceOf[RDD[(Any, Any)]], part)
+ new ShuffleDependency[Any, Any](rdd.asInstanceOf[RDD[(Any, Any)]], part, serializerClass)
}
}
}
@@ -88,7 +89,7 @@ class CoGroupedRDD[K](
case _ =>
new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
}
- }.toList)
+ }.toArray)
}
array
}
@@ -112,6 +113,7 @@ class CoGroupedRDD[K](
}
}
+ val ser = SparkEnv.get.serializerManager.get(serializerClass)
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
// Read them from the parent
@@ -124,12 +126,12 @@ class CoGroupedRDD[K](
val fetcher = SparkEnv.get.shuffleFetcher
if (mapSideCombine) {
// With map side combine on, for each key, the shuffle fetcher returns a list of values.
- fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics).foreach {
+ fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics, ser).foreach {
case (key, values) => getSeq(key)(depNum) ++= values
}
} else {
// With map side combine off, for each key the shuffle fetcher returns a single value.
- fetcher.fetch[K, Any](shuffleId, split.index, context.taskMetrics).foreach {
+ fetcher.fetch[K, Any](shuffleId, split.index, context.taskMetrics, ser).foreach {
case (key, value) => getSeq(key)(depNum) += value
}
}
diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
index bdd974590a..901d01ef30 100644
--- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
@@ -57,7 +57,7 @@ class NewHadoopRDD[K, V](
override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[NewHadoopPartition]
val conf = confBroadcast.value.value
- val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0)
+ val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
val format = inputFormatClass.newInstance
if (format.isInstanceOf[Configurable]) {
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index 51f02409b6..c7d1926b83 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -3,6 +3,7 @@ package spark.rdd
import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext}
import spark.SparkContext._
+
private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
override val index = idx
override def hashCode(): Int = idx
@@ -12,13 +13,15 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
* The resulting RDD from a shuffle (e.g. repartitioning of data).
* @param prev the parent RDD.
* @param part the partitioner used to partition the RDD
+ * @param serializerClass class name of the serializer to use.
* @tparam K the key class.
* @tparam V the value class.
*/
class ShuffledRDD[K, V](
- prev: RDD[(K, V)],
- part: Partitioner)
- extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part))) {
+ @transient prev: RDD[(K, V)],
+ part: Partitioner,
+ serializerClass: String = null)
+ extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part, serializerClass))) {
override val partitioner = Some(part)
@@ -28,6 +31,7 @@ class ShuffledRDD[K, V](
override def compute(split: Partition, context: TaskContext): Iterator[(K, V)] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
- SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics)
+ SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics,
+ SparkEnv.get.serializerManager.get(serializerClass))
}
}
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
index 0a02561062..8a9efc5da2 100644
--- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -12,6 +12,7 @@ import spark.SparkEnv
import spark.ShuffleDependency
import spark.OneToOneDependency
+
/**
* An optimized version of cogroup for set difference/subtraction.
*
@@ -31,7 +32,9 @@ import spark.OneToOneDependency
private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest](
@transient var rdd1: RDD[(K, V)],
@transient var rdd2: RDD[(K, W)],
- part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) {
+ part: Partitioner,
+ val serializerClass: String = null)
+ extends RDD[(K, V)](rdd1.context, Nil) {
override def getDependencies: Seq[Dependency[_]] = {
Seq(rdd1, rdd2).map { rdd =>
@@ -40,7 +43,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM
new OneToOneDependency(rdd)
} else {
logInfo("Adding shuffle dependency with " + rdd)
- new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part)
+ new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part, serializerClass)
}
}
}
@@ -56,7 +59,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM
case _ =>
new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
}
- }.toList)
+ }.toArray)
}
array
}
@@ -65,6 +68,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM
override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
val partition = p.asInstanceOf[CoGroupPartition]
+ val serializer = SparkEnv.get.serializerManager.get(serializerClass)
val map = new JHashMap[K, ArrayBuffer[V]]
def getSeq(k: K): ArrayBuffer[V] = {
val seq = map.get(k)
@@ -77,12 +81,16 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM
}
}
def integrate(dep: CoGroupSplitDep, op: ((K, V)) => Unit) = dep match {
- case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
+ case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
for (t <- rdd.iterator(itsSplit, context))
op(t.asInstanceOf[(K, V)])
- case ShuffleCoGroupSplitDep(shuffleId) =>
- for (t <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics))
+ }
+ case ShuffleCoGroupSplitDep(shuffleId) => {
+ val iter = SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index,
+ context.taskMetrics, serializer)
+ for (t <- iter)
op(t.asInstanceOf[(K, V)])
+ }
}
// the first dep is rdd1; add all values to the map
integrate(partition.deps(0), t => getSeq(t._1) += t._2)
diff --git a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala
new file mode 100644
index 0000000000..fc3f29ffcd
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala
@@ -0,0 +1,120 @@
+package spark.rdd
+
+import spark.{OneToOneDependency, RDD, SparkContext, Partition, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
+
+private[spark] class ZippedPartitionsPartition(
+ idx: Int,
+ @transient rdds: Seq[RDD[_]])
+ extends Partition {
+
+ override val index: Int = idx
+ var partitionValues = rdds.map(rdd => rdd.partitions(idx))
+ def partitions = partitionValues
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ partitionValues = rdds.map(rdd => rdd.partitions(idx))
+ oos.defaultWriteObject()
+ }
+}
+
+abstract class ZippedPartitionsBaseRDD[V: ClassManifest](
+ sc: SparkContext,
+ var rdds: Seq[RDD[_]])
+ extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) {
+
+ override def getPartitions: Array[Partition] = {
+ val sizes = rdds.map(x => x.partitions.size)
+ if (!sizes.forall(x => x == sizes(0))) {
+ throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
+ }
+ val array = new Array[Partition](sizes(0))
+ for (i <- 0 until sizes(0)) {
+ array(i) = new ZippedPartitionsPartition(i, rdds)
+ }
+ array
+ }
+
+ override def getPreferredLocations(s: Partition): Seq[String] = {
+ val splits = s.asInstanceOf[ZippedPartitionsPartition].partitions
+ val preferredLocations = rdds.zip(splits).map(x => x._1.preferredLocations(x._2))
+ preferredLocations.reduce((x, y) => x.intersect(y))
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdds = null
+ }
+}
+
+class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest](
+ sc: SparkContext,
+ f: (Iterator[A], Iterator[B]) => Iterator[V],
+ var rdd1: RDD[A],
+ var rdd2: RDD[B])
+ extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) {
+
+ override def compute(s: Partition, context: TaskContext): Iterator[V] = {
+ val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
+ f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context))
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdd1 = null
+ rdd2 = null
+ }
+}
+
+class ZippedPartitionsRDD3
+ [A: ClassManifest, B: ClassManifest, C: ClassManifest, V: ClassManifest](
+ sc: SparkContext,
+ f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V],
+ var rdd1: RDD[A],
+ var rdd2: RDD[B],
+ var rdd3: RDD[C])
+ extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) {
+
+ override def compute(s: Partition, context: TaskContext): Iterator[V] = {
+ val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
+ f(rdd1.iterator(partitions(0), context),
+ rdd2.iterator(partitions(1), context),
+ rdd3.iterator(partitions(2), context))
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdd1 = null
+ rdd2 = null
+ rdd3 = null
+ }
+}
+
+class ZippedPartitionsRDD4
+ [A: ClassManifest, B: ClassManifest, C: ClassManifest, D:ClassManifest, V: ClassManifest](
+ sc: SparkContext,
+ f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
+ var rdd1: RDD[A],
+ var rdd2: RDD[B],
+ var rdd3: RDD[C],
+ var rdd4: RDD[D])
+ extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) {
+
+ override def compute(s: Partition, context: TaskContext): Iterator[V] = {
+ val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
+ f(rdd1.iterator(partitions(0), context),
+ rdd2.iterator(partitions(1), context),
+ rdd3.iterator(partitions(2), context),
+ rdd4.iterator(partitions(3), context))
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdd1 = null
+ rdd2 = null
+ rdd3 = null
+ rdd4 = null
+ }
+}
diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala
index e80ec17aa5..35b0e06785 100644
--- a/core/src/main/scala/spark/rdd/ZippedRDD.scala
+++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala
@@ -10,17 +10,17 @@ private[spark] class ZippedPartition[T: ClassManifest, U: ClassManifest](
@transient rdd2: RDD[U]
) extends Partition {
- var split1 = rdd1.partitions(idx)
- var split2 = rdd1.partitions(idx)
+ var partition1 = rdd1.partitions(idx)
+ var partition2 = rdd2.partitions(idx)
override val index: Int = idx
- def splits = (split1, split2)
+ def partitions = (partition1, partition2)
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream) {
- // Update the reference to parent split at the time of task serialization
- split1 = rdd1.partitions(idx)
- split2 = rdd2.partitions(idx)
+ // Update the reference to parent partition at the time of task serialization
+ partition1 = rdd1.partitions(idx)
+ partition2 = rdd2.partitions(idx)
oos.defaultWriteObject()
}
}
@@ -43,13 +43,13 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
}
override def compute(s: Partition, context: TaskContext): Iterator[(T, U)] = {
- val (split1, split2) = s.asInstanceOf[ZippedPartition[T, U]].splits
- rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context))
+ val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions
+ rdd1.iterator(partition1, context).zip(rdd2.iterator(partition2, context))
}
override def getPreferredLocations(s: Partition): Seq[String] = {
- val (split1, split2) = s.asInstanceOf[ZippedPartition[T, U]].splits
- rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
+ val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions
+ rdd1.preferredLocations(partition1).intersect(rdd2.preferredLocations(partition2))
}
override def clearDependencies() {
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index c54dce51d7..1440b93e65 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -50,6 +50,11 @@ class DAGScheduler(
eventQueue.put(ExecutorLost(execId))
}
+ // Called by TaskScheduler when a host is added
+ override def executorGained(execId: String, hostPort: String) {
+ eventQueue.put(ExecutorGained(execId, hostPort))
+ }
+
// Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
override def taskSetFailed(taskSet: TaskSet, reason: String) {
eventQueue.put(TaskSetFailed(taskSet, reason))
@@ -113,7 +118,7 @@ class DAGScheduler(
if (!cacheLocs.contains(rdd.id)) {
val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map {
- locations => locations.map(_.ip).toList
+ locations => locations.map(_.hostPort).toList
}.toArray
}
cacheLocs(rdd.id)
@@ -293,6 +298,9 @@ class DAGScheduler(
submitStage(finalStage)
}
+ case ExecutorGained(execId, hostPort) =>
+ handleExecutorGained(execId, hostPort)
+
case ExecutorLost(execId) =>
handleExecutorLost(execId)
@@ -630,6 +638,14 @@ class DAGScheduler(
"(generation " + currentGeneration + ")")
}
}
+
+ private def handleExecutorGained(execId: String, hostPort: String) {
+ // remove from failedGeneration(execId) ?
+ if (failedGeneration.contains(execId)) {
+ logInfo("Host gained which was in lost list earlier: " + hostPort)
+ failedGeneration -= execId
+ }
+ }
/**
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
index ed0b9bf178..b46bb863f0 100644
--- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
@@ -32,6 +32,10 @@ private[spark] case class CompletionEvent(
taskMetrics: TaskMetrics)
extends DAGSchedulerEvent
+private[spark] case class ExecutorGained(execId: String, hostPort: String) extends DAGSchedulerEvent {
+ Utils.checkHostPort(hostPort, "Required hostport")
+}
+
private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
diff --git a/core/src/main/scala/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/spark/scheduler/InputFormatInfo.scala
new file mode 100644
index 0000000000..287f731787
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/InputFormatInfo.scala
@@ -0,0 +1,156 @@
+package spark.scheduler
+
+import spark.Logging
+import scala.collection.immutable.Set
+import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
+import org.apache.hadoop.util.ReflectionUtils
+import org.apache.hadoop.mapreduce.Job
+import org.apache.hadoop.conf.Configuration
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.collection.JavaConversions._
+
+
+/**
+ * Parses and holds information about inputFormat (and files) specified as a parameter.
+ */
+class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Class[_],
+ val path: String) extends Logging {
+
+ var mapreduceInputFormat: Boolean = false
+ var mapredInputFormat: Boolean = false
+
+ validate()
+
+ override def toString(): String = {
+ "InputFormatInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz + ", path : " + path
+ }
+
+ override def hashCode(): Int = {
+ var hashCode = inputFormatClazz.hashCode
+ hashCode = hashCode * 31 + path.hashCode
+ hashCode
+ }
+
+ // Since we are not doing canonicalization of path, this can be wrong : like relative vs absolute path
+ // .. which is fine, this is best case effort to remove duplicates - right ?
+ override def equals(other: Any): Boolean = other match {
+ case that: InputFormatInfo => {
+ // not checking config - that should be fine, right ?
+ this.inputFormatClazz == that.inputFormatClazz &&
+ this.path == that.path
+ }
+ case _ => false
+ }
+
+ private def validate() {
+ logDebug("validate InputFormatInfo : " + inputFormatClazz + ", path " + path)
+
+ try {
+ if (classOf[org.apache.hadoop.mapreduce.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) {
+ logDebug("inputformat is from mapreduce package")
+ mapreduceInputFormat = true
+ }
+ else if (classOf[org.apache.hadoop.mapred.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) {
+ logDebug("inputformat is from mapred package")
+ mapredInputFormat = true
+ }
+ else {
+ throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz +
+ " is NOT a supported input format ? does not implement either of the supported hadoop api's")
+ }
+ }
+ catch {
+ case e: ClassNotFoundException => {
+ throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + " cannot be found ?", e)
+ }
+ }
+ }
+
+
+ // This method does not expect failures, since validate has already passed ...
+ private def prefLocsFromMapreduceInputFormat(): Set[SplitInfo] = {
+ val conf = new JobConf(configuration)
+ FileInputFormat.setInputPaths(conf, path)
+
+ val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] =
+ ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], conf).asInstanceOf[
+ org.apache.hadoop.mapreduce.InputFormat[_, _]]
+ val job = new Job(conf)
+
+ val retval = new ArrayBuffer[SplitInfo]()
+ val list = instance.getSplits(job)
+ for (split <- list) {
+ retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, split)
+ }
+
+ return retval.toSet
+ }
+
+ // This method does not expect failures, since validate has already passed ...
+ private def prefLocsFromMapredInputFormat(): Set[SplitInfo] = {
+ val jobConf = new JobConf(configuration)
+ FileInputFormat.setInputPaths(jobConf, path)
+
+ val instance: org.apache.hadoop.mapred.InputFormat[_, _] =
+ ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], jobConf).asInstanceOf[
+ org.apache.hadoop.mapred.InputFormat[_, _]]
+
+ val retval = new ArrayBuffer[SplitInfo]()
+ instance.getSplits(jobConf, jobConf.getNumMapTasks()).foreach(
+ elem => retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, elem)
+ )
+
+ return retval.toSet
+ }
+
+ private def findPreferredLocations(): Set[SplitInfo] = {
+ logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat +
+ ", inputFormatClazz : " + inputFormatClazz)
+ if (mapreduceInputFormat) {
+ return prefLocsFromMapreduceInputFormat()
+ }
+ else {
+ assert(mapredInputFormat)
+ return prefLocsFromMapredInputFormat()
+ }
+ }
+}
+
+
+
+
+object InputFormatInfo {
+ /**
+ Computes the preferred locations based on input(s) and returned a location to block map.
+ Typical use of this method for allocation would follow some algo like this
+ (which is what we currently do in YARN branch) :
+ a) For each host, count number of splits hosted on that host.
+ b) Decrement the currently allocated containers on that host.
+ c) Compute rack info for each host and update rack -> count map based on (b).
+ d) Allocate nodes based on (c)
+ e) On the allocation result, ensure that we dont allocate "too many" jobs on a single node
+ (even if data locality on that is very high) : this is to prevent fragility of job if a single
+ (or small set of) hosts go down.
+
+ go to (a) until required nodes are allocated.
+
+ If a node 'dies', follow same procedure.
+
+ PS: I know the wording here is weird, hopefully it makes some sense !
+ */
+ def computePreferredLocations(formats: Seq[InputFormatInfo]): HashMap[String, HashSet[SplitInfo]] = {
+
+ val nodeToSplit = new HashMap[String, HashSet[SplitInfo]]
+ for (inputSplit <- formats) {
+ val splits = inputSplit.findPreferredLocations()
+
+ for (split <- splits){
+ val location = split.hostLocation
+ val set = nodeToSplit.getOrElseUpdate(location, new HashSet[SplitInfo])
+ set += split
+ }
+ }
+
+ nodeToSplit
+ }
+}
diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala
index beb21a76fe..89dc6640b2 100644
--- a/core/src/main/scala/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/spark/scheduler/ResultTask.scala
@@ -70,6 +70,14 @@ private[spark] class ResultTask[T, U](
rdd.partitions(partition)
}
+ // data locality is on a per host basis, not hyper specific to container (host:port). Unique on set of hosts.
+ val preferredLocs: Seq[String] = if (locs == null) Nil else locs.map(loc => Utils.parseHostPort(loc)._1).toSet.toSeq
+
+ {
+ // DEBUG code
+ preferredLocs.foreach (host => Utils.checkHost(host, "preferredLocs : " + preferredLocs))
+ }
+
override def run(attemptId: Long): U = {
val context = new TaskContext(stageId, partition, attemptId)
metrics = Some(context.taskMetrics)
@@ -80,7 +88,7 @@ private[spark] class ResultTask[T, U](
}
}
- override def preferredLocations: Seq[String] = locs
+ override def preferredLocations: Seq[String] = preferredLocs
override def toString = "ResultTask(" + stageId + ", " + partition + ")"
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 36d087a4d0..f097213ab5 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -13,9 +13,10 @@ import com.ning.compress.lzf.LZFInputStream
import com.ning.compress.lzf.LZFOutputStream
import spark._
-import executor.ShuffleWriteMetrics
+import spark.executor.ShuffleWriteMetrics
import spark.storage._
-import util.{TimeStampedHashMap, MetadataCleaner}
+import spark.util.{TimeStampedHashMap, MetadataCleaner}
+
private[spark] object ShuffleMapTask {
@@ -77,13 +78,27 @@ private[spark] class ShuffleMapTask(
var rdd: RDD[_],
var dep: ShuffleDependency[_,_],
var partition: Int,
- @transient var locs: Seq[String])
+ @transient private var locs: Seq[String])
extends Task[MapStatus](stageId)
with Externalizable
with Logging {
protected def this() = this(0, null, null, 0, null)
+ // Data locality is on a per host basis, not hyper specific to container (host:port).
+ // Unique on set of hosts.
+ // TODO(rxin): The above statement seems problematic. Even if partitions are on the same host,
+ // the worker would still need to serialize / deserialize those data when they are in
+ // different jvm processes. Often that is very costly ...
+ @transient
+ private val preferredLocs: Seq[String] =
+ if (locs == null) Nil else locs.map(loc => Utils.parseHostPort(loc)._1).toSet.toSeq
+
+ {
+ // DEBUG code
+ preferredLocs.foreach (host => Utils.checkHost(host, "preferredLocs : " + preferredLocs))
+ }
+
var split = if (rdd == null) {
null
} else {
@@ -121,40 +136,58 @@ private[spark] class ShuffleMapTask(
val taskContext = new TaskContext(stageId, partition, attemptId)
metrics = Some(taskContext.taskMetrics)
+
+ val blockManager = SparkEnv.get.blockManager
+ var shuffle: ShuffleBlocks = null
+ var buckets: ShuffleWriterGroup = null
+
try {
- // Partition the map output.
- val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
+ // Obtain all the block writers for shuffle blocks.
+ val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
+ shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser)
+ buckets = shuffle.acquireWriters(partition)
+
+ // Write the map output to its associated buckets.
for (elem <- rdd.iterator(split, taskContext)) {
val pair = elem.asInstanceOf[(Any, Any)]
val bucketId = dep.partitioner.getPartition(pair._1)
- buckets(bucketId) += pair
+ buckets.writers(bucketId).write(pair)
}
- val compressedSizes = new Array[Byte](numOutputSplits)
-
- var totalBytes = 0l
-
- val blockManager = SparkEnv.get.blockManager
- for (i <- 0 until numOutputSplits) {
- val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
- // Get a Scala iterator from Java map
- val iter: Iterator[(Any, Any)] = buckets(i).iterator
- val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
+ // Commit the writes. Get the size of each bucket block (total block size).
+ var totalBytes = 0L
+ val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter =>
+ writer.commit()
+ writer.close()
+ val size = writer.size()
totalBytes += size
- compressedSizes(i) = MapOutputTracker.compressSize(size)
+ MapOutputTracker.compressSize(size)
}
+
+ // Update shuffle metrics.
val shuffleMetrics = new ShuffleWriteMetrics
shuffleMetrics.shuffleBytesWritten = totalBytes
metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
return new MapStatus(blockManager.blockManagerId, compressedSizes)
+ } catch { case e: Exception =>
+ // If there is an exception from running the task, revert the partial writes
+ // and throw the exception upstream to Spark.
+ if (buckets != null) {
+ buckets.writers.foreach(_.revertPartialWrites())
+ }
+ throw e
} finally {
+ // Release the writers back to the shuffle block manager.
+ if (shuffle != null && buckets != null) {
+ shuffle.releaseWriters(buckets)
+ }
// Execute the callbacks on task completion.
taskContext.executeOnCompleteCallbacks()
}
}
- override def preferredLocations: Seq[String] = locs
+ override def preferredLocations: Seq[String] = preferredLocs
override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
}
diff --git a/core/src/main/scala/spark/scheduler/SplitInfo.scala b/core/src/main/scala/spark/scheduler/SplitInfo.scala
new file mode 100644
index 0000000000..6abfb7a1f7
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/SplitInfo.scala
@@ -0,0 +1,61 @@
+package spark.scheduler
+
+import collection.mutable.ArrayBuffer
+
+// information about a specific split instance : handles both split instances.
+// So that we do not need to worry about the differences.
+class SplitInfo(val inputFormatClazz: Class[_], val hostLocation: String, val path: String,
+ val length: Long, val underlyingSplit: Any) {
+ override def toString(): String = {
+ "SplitInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz +
+ ", hostLocation : " + hostLocation + ", path : " + path +
+ ", length : " + length + ", underlyingSplit " + underlyingSplit
+ }
+
+ override def hashCode(): Int = {
+ var hashCode = inputFormatClazz.hashCode
+ hashCode = hashCode * 31 + hostLocation.hashCode
+ hashCode = hashCode * 31 + path.hashCode
+ // ignore overflow ? It is hashcode anyway !
+ hashCode = hashCode * 31 + (length & 0x7fffffff).toInt
+ hashCode
+ }
+
+ // This is practically useless since most of the Split impl's dont seem to implement equals :-(
+ // So unless there is identity equality between underlyingSplits, it will always fail even if it
+ // is pointing to same block.
+ override def equals(other: Any): Boolean = other match {
+ case that: SplitInfo => {
+ this.hostLocation == that.hostLocation &&
+ this.inputFormatClazz == that.inputFormatClazz &&
+ this.path == that.path &&
+ this.length == that.length &&
+ // other split specific checks (like start for FileSplit)
+ this.underlyingSplit == that.underlyingSplit
+ }
+ case _ => false
+ }
+}
+
+object SplitInfo {
+
+ def toSplitInfo(inputFormatClazz: Class[_], path: String,
+ mapredSplit: org.apache.hadoop.mapred.InputSplit): Seq[SplitInfo] = {
+ val retval = new ArrayBuffer[SplitInfo]()
+ val length = mapredSplit.getLength
+ for (host <- mapredSplit.getLocations) {
+ retval += new SplitInfo(inputFormatClazz, host, path, length, mapredSplit)
+ }
+ retval
+ }
+
+ def toSplitInfo(inputFormatClazz: Class[_], path: String,
+ mapreduceSplit: org.apache.hadoop.mapreduce.InputSplit): Seq[SplitInfo] = {
+ val retval = new ArrayBuffer[SplitInfo]()
+ val length = mapreduceSplit.getLength
+ for (host <- mapreduceSplit.getLocations) {
+ retval += new SplitInfo(inputFormatClazz, host, path, length, mapreduceSplit)
+ }
+ retval
+ }
+}
diff --git a/core/src/main/scala/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/spark/scheduler/TaskScheduler.scala
index d549b184b0..7787b54762 100644
--- a/core/src/main/scala/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/TaskScheduler.scala
@@ -10,6 +10,10 @@ package spark.scheduler
private[spark] trait TaskScheduler {
def start(): Unit
+ // Invoked after system has successfully initialized (typically in spark context).
+ // Yarn uses this to bootstrap allocation of resources based on preferred locations, wait for slave registerations, etc.
+ def postStartHook() { }
+
// Disconnect from the cluster.
def stop(): Unit
diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
index 771518dddf..b75d3736cf 100644
--- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
+++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
@@ -14,6 +14,9 @@ private[spark] trait TaskSchedulerListener {
def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any],
taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit
+ // A node was added to the cluster.
+ def executorGained(execId: String, hostPort: String): Unit
+
// A node was lost from the cluster.
def executorLost(execId: String): Unit
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
index 26fdef101b..a9d9c5e44c 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -1,6 +1,6 @@
package spark.scheduler.cluster
-import java.io.{File, FileInputStream, FileOutputStream}
+import java.lang.{Boolean => JBoolean}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
@@ -25,6 +25,35 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
// Threshold above which we warn user initial TaskSet may be starved
val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong
+ // How often to revive offers in case there are pending tasks - that is how often to try to get
+ // tasks scheduled in case there are nodes available : default 0 is to disable it - to preserve existing behavior
+ // Note that this is required due to delayed scheduling due to data locality waits, etc.
+ // TODO: rename property ?
+ val TASK_REVIVAL_INTERVAL = System.getProperty("spark.tasks.revive.interval", "0").toLong
+
+ /*
+ This property controls how aggressive we should be to modulate waiting for host local task scheduling.
+ To elaborate, currently there is a time limit (3 sec def) to ensure that spark attempts to wait for host locality of tasks before
+ scheduling on other nodes. We have modified this in yarn branch such that offers to task set happen in prioritized order :
+ host-local, rack-local and then others
+ But once all available host local (and no pref) tasks are scheduled, instead of waiting for 3 sec before
+ scheduling to other nodes (which degrades performance for time sensitive tasks and on larger clusters), we can
+ modulate that : to also allow rack local nodes or any node. The default is still set to HOST - so that previous behavior is
+ maintained. This is to allow tuning the tension between pulling rdd data off node and scheduling computation asap.
+
+ TODO: rename property ? The value is one of
+ - HOST_LOCAL (default, no change w.r.t current behavior),
+ - RACK_LOCAL and
+ - ANY
+
+ Note that this property makes more sense when used in conjugation with spark.tasks.revive.interval > 0 : else it is not very effective.
+
+ Additional Note: For non trivial clusters, there is a 4x - 5x reduction in running time (in some of our experiments) based on whether
+ it is left at default HOST_LOCAL, RACK_LOCAL (if cluster is configured to be rack aware) or ANY.
+ If cluster is rack aware, then setting it to RACK_LOCAL gives best tradeoff and a 3x - 4x performance improvement while minimizing IO impact.
+ Also, it brings down the variance in running time drastically.
+ */
+ val TASK_SCHEDULING_AGGRESSION = TaskLocality.parse(System.getProperty("spark.tasks.schedule.aggression", "HOST_LOCAL"))
val activeTaskSets = new HashMap[String, TaskSetManager]
var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager]
@@ -33,9 +62,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
val taskIdToExecutorId = new HashMap[Long, String]
val taskSetTaskIds = new HashMap[String, HashSet[Long]]
- var hasReceivedTask = false
- var hasLaunchedTask = false
- val starvationTimer = new Timer(true)
+ @volatile private var hasReceivedTask = false
+ @volatile private var hasLaunchedTask = false
+ private val starvationTimer = new Timer(true)
// Incrementing Mesos task IDs
val nextTaskId = new AtomicLong(0)
@@ -43,11 +72,16 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
// Which executor IDs we have executors on
val activeExecutorIds = new HashSet[String]
+ // TODO: We might want to remove this and merge it with execId datastructures - but later.
+ // Which hosts in the cluster are alive (contains hostPort's) - used for hyper local and local task locality.
+ private val hostPortsAlive = new HashSet[String]
+ private val hostToAliveHostPorts = new HashMap[String, HashSet[String]]
+
// The set of executors we have on each host; this is used to compute hostsAlive, which
// in turn is used to decide when we can attain data locality on a given host
- val executorsByHost = new HashMap[String, HashSet[String]]
+ val executorsByHostPort = new HashMap[String, HashSet[String]]
- val executorIdToHost = new HashMap[String, String]
+ val executorIdToHostPort = new HashMap[String, String]
// JAR server, if any JARs were added by the user to the SparkContext
var jarServer: HttpServer = null
@@ -75,11 +109,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
override def start() {
backend.start()
- if (System.getProperty("spark.speculation", "false") == "true") {
+ if (JBoolean.getBoolean("spark.speculation")) {
new Thread("ClusterScheduler speculation check") {
setDaemon(true)
override def run() {
+ logInfo("Starting speculative execution thread")
while (true) {
try {
Thread.sleep(SPECULATION_INTERVAL)
@@ -91,6 +126,27 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}.start()
}
+
+
+ // Change to always run with some default if TASK_REVIVAL_INTERVAL <= 0 ?
+ if (TASK_REVIVAL_INTERVAL > 0) {
+ new Thread("ClusterScheduler task offer revival check") {
+ setDaemon(true)
+
+ override def run() {
+ logInfo("Starting speculative task offer revival thread")
+ while (true) {
+ try {
+ Thread.sleep(TASK_REVIVAL_INTERVAL)
+ } catch {
+ case e: InterruptedException => {}
+ }
+
+ if (hasPendingTasks()) backend.reviveOffers()
+ }
+ }
+ }.start()
+ }
}
override def submitTasks(taskSet: TaskSet) {
@@ -139,22 +195,92 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
SparkEnv.set(sc.env)
// Mark each slave as alive and remember its hostname
for (o <- offers) {
- executorIdToHost(o.executorId) = o.hostname
- if (!executorsByHost.contains(o.hostname)) {
- executorsByHost(o.hostname) = new HashSet()
+ // DEBUG Code
+ Utils.checkHostPort(o.hostPort)
+
+ executorIdToHostPort(o.executorId) = o.hostPort
+ if (! executorsByHostPort.contains(o.hostPort)) {
+ executorsByHostPort(o.hostPort) = new HashSet[String]()
}
+
+ hostPortsAlive += o.hostPort
+ hostToAliveHostPorts.getOrElseUpdate(Utils.parseHostPort(o.hostPort)._1, new HashSet[String]).add(o.hostPort)
+ executorGained(o.executorId, o.hostPort)
}
// Build a list of tasks to assign to each slave
val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
val availableCpus = offers.map(o => o.cores).toArray
var launchedTask = false
+
+
for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) {
+
+ // Split offers based on host local, rack local and off-rack tasks.
+ val hostLocalOffers = new HashMap[String, ArrayBuffer[Int]]()
+ val rackLocalOffers = new HashMap[String, ArrayBuffer[Int]]()
+ val otherOffers = new HashMap[String, ArrayBuffer[Int]]()
+
+ for (i <- 0 until offers.size) {
+ val hostPort = offers(i).hostPort
+ // DEBUG code
+ Utils.checkHostPort(hostPort)
+ val host = Utils.parseHostPort(hostPort)._1
+ val numHostLocalTasks = math.max(0, math.min(manager.numPendingTasksForHost(hostPort), availableCpus(i)))
+ if (numHostLocalTasks > 0){
+ val list = hostLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int])
+ for (j <- 0 until numHostLocalTasks) list += i
+ }
+
+ val numRackLocalTasks = math.max(0,
+ // Remove host local tasks (which are also rack local btw !) from this
+ math.min(manager.numRackLocalPendingTasksForHost(hostPort) - numHostLocalTasks, availableCpus(i)))
+ if (numRackLocalTasks > 0){
+ val list = rackLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int])
+ for (j <- 0 until numRackLocalTasks) list += i
+ }
+ if (numHostLocalTasks <= 0 && numRackLocalTasks <= 0){
+ // add to others list - spread even this across cluster.
+ val list = otherOffers.getOrElseUpdate(host, new ArrayBuffer[Int])
+ list += i
+ }
+ }
+
+ val offersPriorityList = new ArrayBuffer[Int](
+ hostLocalOffers.size + rackLocalOffers.size + otherOffers.size)
+ // First host local, then rack, then others
+ val numHostLocalOffers = {
+ val hostLocalPriorityList = ClusterScheduler.prioritizeContainers(hostLocalOffers)
+ offersPriorityList ++= hostLocalPriorityList
+ hostLocalPriorityList.size
+ }
+ val numRackLocalOffers = {
+ val rackLocalPriorityList = ClusterScheduler.prioritizeContainers(rackLocalOffers)
+ offersPriorityList ++= rackLocalPriorityList
+ rackLocalPriorityList.size
+ }
+ offersPriorityList ++= ClusterScheduler.prioritizeContainers(otherOffers)
+
+ var lastLoop = false
+ val lastLoopIndex = TASK_SCHEDULING_AGGRESSION match {
+ case TaskLocality.HOST_LOCAL => numHostLocalOffers
+ case TaskLocality.RACK_LOCAL => numRackLocalOffers + numHostLocalOffers
+ case TaskLocality.ANY => offersPriorityList.size
+ }
+
do {
launchedTask = false
- for (i <- 0 until offers.size) {
+ var loopCount = 0
+ for (i <- offersPriorityList) {
val execId = offers(i).executorId
- val host = offers(i).hostname
- manager.slaveOffer(execId, host, availableCpus(i)) match {
+ val hostPort = offers(i).hostPort
+
+ // If last loop and within the lastLoopIndex, expand scope - else use null (which will use default/existing)
+ val overrideLocality = if (lastLoop && loopCount < lastLoopIndex) TASK_SCHEDULING_AGGRESSION else null
+
+ // If last loop, override waiting for host locality - we scheduled all local tasks already and there might be more available ...
+ loopCount += 1
+
+ manager.slaveOffer(execId, hostPort, availableCpus(i), overrideLocality) match {
case Some(task) =>
tasks(i) += task
val tid = task.taskId
@@ -162,15 +288,31 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
taskSetTaskIds(manager.taskSet.id) += tid
taskIdToExecutorId(tid) = execId
activeExecutorIds += execId
- executorsByHost(host) += execId
+ executorsByHostPort(hostPort) += execId
availableCpus(i) -= 1
launchedTask = true
-
+
case None => {}
}
}
+ // Loop once more - when lastLoop = true, then we try to schedule task on all nodes irrespective of
+ // data locality (we still go in order of priority : but that would not change anything since
+ // if data local tasks had been available, we would have scheduled them already)
+ if (lastLoop) {
+ // prevent more looping
+ launchedTask = false
+ } else if (!lastLoop && !launchedTask) {
+ // Do this only if TASK_SCHEDULING_AGGRESSION != HOST_LOCAL
+ if (TASK_SCHEDULING_AGGRESSION != TaskLocality.HOST_LOCAL) {
+ // fudge launchedTask to ensure we loop once more
+ launchedTask = true
+ // dont loop anymore
+ lastLoop = true
+ }
+ }
} while (launchedTask)
}
+
if (tasks.size > 0) {
hasLaunchedTask = true
}
@@ -256,10 +398,15 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
if (jarServer != null) {
jarServer.stop()
}
+
+ // sleeping for an arbitrary 5 seconds : to ensure that messages are sent out.
+ // TODO: Do something better !
+ Thread.sleep(5000L)
}
override def defaultParallelism() = backend.defaultParallelism()
+
// Check for speculatable tasks in all our active jobs.
def checkSpeculatableTasks() {
var shouldRevive = false
@@ -273,12 +420,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
+ // Check for pending tasks in all our active jobs.
+ def hasPendingTasks(): Boolean = {
+ synchronized {
+ activeTaskSetsQueue.exists( _.hasPendingTasks() )
+ }
+ }
+
def executorLost(executorId: String, reason: ExecutorLossReason) {
var failedExecutor: Option[String] = None
+
synchronized {
if (activeExecutorIds.contains(executorId)) {
- val host = executorIdToHost(executorId)
- logError("Lost executor %s on %s: %s".format(executorId, host, reason))
+ val hostPort = executorIdToHostPort(executorId)
+ logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason))
removeExecutor(executorId)
failedExecutor = Some(executorId)
} else {
@@ -296,19 +451,95 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
- /** Get a list of hosts that currently have executors */
- def hostsAlive: scala.collection.Set[String] = executorsByHost.keySet
-
/** Remove an executor from all our data structures and mark it as lost */
private def removeExecutor(executorId: String) {
activeExecutorIds -= executorId
- val host = executorIdToHost(executorId)
- val execs = executorsByHost.getOrElse(host, new HashSet)
+ val hostPort = executorIdToHostPort(executorId)
+ if (hostPortsAlive.contains(hostPort)) {
+ // DEBUG Code
+ Utils.checkHostPort(hostPort)
+
+ hostPortsAlive -= hostPort
+ hostToAliveHostPorts.getOrElseUpdate(Utils.parseHostPort(hostPort)._1, new HashSet[String]).remove(hostPort)
+ }
+
+ val execs = executorsByHostPort.getOrElse(hostPort, new HashSet)
execs -= executorId
if (execs.isEmpty) {
- executorsByHost -= host
+ executorsByHostPort -= hostPort
}
- executorIdToHost -= executorId
- activeTaskSetsQueue.foreach(_.executorLost(executorId, host))
+ executorIdToHostPort -= executorId
+ activeTaskSetsQueue.foreach(_.executorLost(executorId, hostPort))
+ }
+
+ def executorGained(execId: String, hostPort: String) {
+ listener.executorGained(execId, hostPort)
+ }
+
+ def getExecutorsAliveOnHost(host: String): Option[Set[String]] = {
+ val retval = hostToAliveHostPorts.get(host)
+ if (retval.isDefined) {
+ return Some(retval.get.toSet)
+ }
+
+ None
+ }
+
+ // By default, rack is unknown
+ def getRackForHost(value: String): Option[String] = None
+
+ // By default, (cached) hosts for rack is unknown
+ def getCachedHostsForRack(rack: String): Option[Set[String]] = None
+}
+
+
+object ClusterScheduler {
+
+ // Used to 'spray' available containers across the available set to ensure too many containers on same host
+ // are not used up. Used in yarn mode and in task scheduling (when there are multiple containers available
+ // to execute a task)
+ // For example: yarn can returns more containers than we would have requested under ANY, this method
+ // prioritizes how to use the allocated containers.
+ // flatten the map such that the array buffer entries are spread out across the returned value.
+ // given <host, list[container]> == <h1, [c1 .. c5]>, <h2, [c1 .. c3]>, <h3, [c1, c2]>, <h4, c1>, <h5, c1>, i
+ // the return value would be something like : h1c1, h2c1, h3c1, h4c1, h5c1, h1c2, h2c2, h3c2, h1c3, h2c3, h1c4, h1c5
+ // We then 'use' the containers in this order (consuming only the top K from this list where
+ // K = number to be user). This is to ensure that if we have multiple eligible allocations,
+ // they dont end up allocating all containers on a small number of hosts - increasing probability of
+ // multiple container failure when a host goes down.
+ // Note, there is bias for keys with higher number of entries in value to be picked first (by design)
+ // Also note that invocation of this method is expected to have containers of same 'type'
+ // (host-local, rack-local, off-rack) and not across types : so that reordering is simply better from
+ // the available list - everything else being same.
+ // That is, we we first consume data local, then rack local and finally off rack nodes. So the
+ // prioritization from this method applies to within each category
+ def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = {
+ val _keyList = new ArrayBuffer[K](map.size)
+ _keyList ++= map.keys
+
+ // order keyList based on population of value in map
+ val keyList = _keyList.sortWith(
+ (left, right) => map.get(left).getOrElse(Set()).size > map.get(right).getOrElse(Set()).size
+ )
+
+ val retval = new ArrayBuffer[T](keyList.size * 2)
+ var index = 0
+ var found = true
+
+ while (found){
+ found = false
+ for (key <- keyList) {
+ val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null)
+ assert(containerList != null)
+ // Get the index'th entry for this host - if present
+ if (index < containerList.size){
+ retval += containerList.apply(index)
+ found = true
+ }
+ }
+ index += 1
+ }
+
+ retval.toList
}
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index bb289c9cf3..0b8922d139 100644
--- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -57,9 +57,9 @@ private[spark] class SparkDeploySchedulerBackend(
}
}
- override def executorAdded(executorId: String, workerId: String, host: String, cores: Int, memory: Int) {
- logInfo("Granted executor ID %s on host %s with %d cores, %s RAM".format(
- executorId, host, cores, Utils.memoryMegabytesToString(memory)))
+ override def executorAdded(executorId: String, workerId: String, hostPort: String, cores: Int, memory: Int) {
+ logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format(
+ executorId, hostPort, cores, Utils.memoryMegabytesToString(memory)))
}
override def executorRemoved(executorId: String, message: String, exitStatus: Option[Int]) {
diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
index d766067824..3335294844 100644
--- a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
@@ -3,6 +3,7 @@ package spark.scheduler.cluster
import spark.TaskState.TaskState
import java.nio.ByteBuffer
import spark.util.SerializableBuffer
+import spark.Utils
private[spark] sealed trait StandaloneClusterMessage extends Serializable
@@ -19,8 +20,10 @@ case class RegisterExecutorFailed(message: String) extends StandaloneClusterMess
// Executors to driver
private[spark]
-case class RegisterExecutor(executorId: String, host: String, cores: Int)
- extends StandaloneClusterMessage
+case class RegisterExecutor(executorId: String, hostPort: String, cores: Int)
+ extends StandaloneClusterMessage {
+ Utils.checkHostPort(hostPort, "Expected host port")
+}
private[spark]
case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, data: SerializableBuffer)
diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
index 7a428e3361..004592a540 100644
--- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
@@ -5,8 +5,9 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import akka.actor._
import akka.util.duration._
import akka.pattern.ask
+import akka.util.Duration
-import spark.{SparkException, Logging, TaskState}
+import spark.{Utils, SparkException, Logging, TaskState}
import akka.dispatch.Await
import java.util.concurrent.atomic.AtomicInteger
import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent}
@@ -24,12 +25,12 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
var totalCoreCount = new AtomicInteger(0)
class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor {
- val executorActor = new HashMap[String, ActorRef]
- val executorAddress = new HashMap[String, Address]
- val executorHost = new HashMap[String, String]
- val freeCores = new HashMap[String, Int]
- val actorToExecutorId = new HashMap[ActorRef, String]
- val addressToExecutorId = new HashMap[Address, String]
+ private val executorActor = new HashMap[String, ActorRef]
+ private val executorAddress = new HashMap[String, Address]
+ private val executorHostPort = new HashMap[String, String]
+ private val freeCores = new HashMap[String, Int]
+ private val actorToExecutorId = new HashMap[ActorRef, String]
+ private val addressToExecutorId = new HashMap[Address, String]
override def preStart() {
// Listen for remote client disconnection events, since they don't go through Akka's watch()
@@ -37,7 +38,8 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
}
def receive = {
- case RegisterExecutor(executorId, host, cores) =>
+ case RegisterExecutor(executorId, hostPort, cores) =>
+ Utils.checkHostPort(hostPort, "Host port expected " + hostPort)
if (executorActor.contains(executorId)) {
sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId)
} else {
@@ -45,7 +47,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
sender ! RegisteredExecutor(sparkProperties)
context.watch(sender)
executorActor(executorId) = sender
- executorHost(executorId) = host
+ executorHostPort(executorId) = hostPort
freeCores(executorId) = cores
executorAddress(executorId) = sender.path.address
actorToExecutorId(sender) = executorId
@@ -85,13 +87,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
// Make fake resource offers on all executors
def makeOffers() {
launchTasks(scheduler.resourceOffers(
- executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))}))
+ executorHostPort.toArray.map {case (id, hostPort) => new WorkerOffer(id, hostPort, freeCores(id))}))
}
// Make fake resource offers on just one executor
def makeOffers(executorId: String) {
launchTasks(scheduler.resourceOffers(
- Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId)))))
+ Seq(new WorkerOffer(executorId, executorHostPort(executorId), freeCores(executorId)))))
}
// Launch tasks returned by a set of resource offers
@@ -110,9 +112,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
actorToExecutorId -= executorActor(executorId)
addressToExecutorId -= executorAddress(executorId)
executorActor -= executorId
- executorHost -= executorId
+ executorHostPort -= executorId
freeCores -= executorId
- executorHost -= executorId
+ executorHostPort -= executorId
totalCoreCount.addAndGet(-numCores)
scheduler.executorLost(executorId, SlaveLost(reason))
}
@@ -128,7 +130,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
while (iterator.hasNext) {
val entry = iterator.next
val (key, value) = (entry.getKey.toString, entry.getValue.toString)
- if (key.startsWith("spark.")) {
+ if (key.startsWith("spark.") && !key.equals("spark.hostPort")) {
properties += ((key, value))
}
}
@@ -136,10 +138,11 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME)
}
+ private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+
override def stop() {
try {
if (driverActor != null) {
- val timeout = 5.seconds
val future = driverActor.ask(StopDriver)(timeout)
Await.result(future, timeout)
}
@@ -159,7 +162,6 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
// Called by subclasses when notified of a lost worker
def removeExecutor(executorId: String, reason: String) {
try {
- val timeout = 5.seconds
val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout)
Await.result(future, timeout)
} catch {
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
index dfe3c5a85b..718f26bfbd 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
@@ -1,5 +1,7 @@
package spark.scheduler.cluster
+import spark.Utils
+
/**
* Information about a running task attempt inside a TaskSet.
*/
@@ -9,8 +11,11 @@ class TaskInfo(
val index: Int,
val launchTime: Long,
val executorId: String,
- val host: String,
- val preferred: Boolean) {
+ val hostPort: String,
+ val taskLocality: TaskLocality.TaskLocality) {
+
+ Utils.checkHostPort(hostPort, "Expected hostport")
+
var finishTime: Long = 0
var failed = false
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index c9f2c48804..27e713e2c4 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -1,7 +1,6 @@
package spark.scheduler.cluster
-import java.util.Arrays
-import java.util.{HashMap => JHashMap}
+import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
@@ -14,6 +13,36 @@ import spark.scheduler._
import spark.TaskState.TaskState
import java.nio.ByteBuffer
+private[spark] object TaskLocality extends Enumeration("HOST_LOCAL", "RACK_LOCAL", "ANY") with Logging {
+
+ val HOST_LOCAL, RACK_LOCAL, ANY = Value
+
+ type TaskLocality = Value
+
+ def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = {
+
+ constraint match {
+ case TaskLocality.HOST_LOCAL => condition == TaskLocality.HOST_LOCAL
+ case TaskLocality.RACK_LOCAL => condition == TaskLocality.HOST_LOCAL || condition == TaskLocality.RACK_LOCAL
+ // For anything else, allow
+ case _ => true
+ }
+ }
+
+ def parse(str: String): TaskLocality = {
+ // better way to do this ?
+ try {
+ TaskLocality.withName(str)
+ } catch {
+ case nEx: NoSuchElementException => {
+ logWarning("Invalid task locality specified '" + str + "', defaulting to HOST_LOCAL");
+ // default to preserve earlier behavior
+ HOST_LOCAL
+ }
+ }
+ }
+}
+
/**
* Schedules the tasks within a single TaskSet in the ClusterScheduler.
*/
@@ -47,14 +76,22 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
// Last time when we launched a preferred task (for delay scheduling)
var lastPreferredLaunchTime = System.currentTimeMillis
- // List of pending tasks for each node. These collections are actually
+ // List of pending tasks for each node (hyper local to container). These collections are actually
// treated as stacks, in which new tasks are added to the end of the
// ArrayBuffer and removed from the end. This makes it faster to detect
// tasks that repeatedly fail because whenever a task failed, it is put
// back at the head of the stack. They are also only cleaned up lazily;
// when a task is launched, it remains in all the pending lists except
// the one that it was launched from, but gets removed from them later.
- val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
+ private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]]
+
+ // List of pending tasks for each node.
+ // Essentially, similar to pendingTasksForHostPort, except at host level
+ private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
+
+ // List of pending tasks for each node based on rack locality.
+ // Essentially, similar to pendingTasksForHost, except at rack level
+ private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]]
// List containing pending tasks with no locality preferences
val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
@@ -96,26 +133,117 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
addPendingTask(i)
}
+ private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler, rackLocal: Boolean = false): ArrayBuffer[String] = {
+ // DEBUG code
+ _taskPreferredLocations.foreach(h => Utils.checkHost(h, "taskPreferredLocation " + _taskPreferredLocations))
+
+ val taskPreferredLocations = if (! rackLocal) _taskPreferredLocations else {
+ // Expand set to include all 'seen' rack local hosts.
+ // This works since container allocation/management happens within master - so any rack locality information is updated in msater.
+ // Best case effort, and maybe sort of kludge for now ... rework it later ?
+ val hosts = new HashSet[String]
+ _taskPreferredLocations.foreach(h => {
+ val rackOpt = scheduler.getRackForHost(h)
+ if (rackOpt.isDefined) {
+ val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get)
+ if (hostsOpt.isDefined) {
+ hosts ++= hostsOpt.get
+ }
+ }
+
+ // Ensure that irrespective of what scheduler says, host is always added !
+ hosts += h
+ })
+
+ hosts
+ }
+
+ val retval = new ArrayBuffer[String]
+ scheduler.synchronized {
+ for (prefLocation <- taskPreferredLocations) {
+ val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(prefLocation)
+ if (aliveLocationsOpt.isDefined) {
+ retval ++= aliveLocationsOpt.get
+ }
+ }
+ }
+
+ retval
+ }
+
// Add a task to all the pending-task lists that it should be on.
private def addPendingTask(index: Int) {
- val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
- if (locations.size == 0) {
+ // We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate
+ // hostPort <-> host conversion). But not doing it for simplicity sake. If this becomes a performance issue, modify it.
+ val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched)
+ val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, true)
+
+ if (rackLocalLocations.size == 0) {
+ // Current impl ensures this.
+ assert (hostLocalLocations.size == 0)
pendingTasksWithNoPrefs += index
} else {
- for (host <- locations) {
- val list = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer())
+
+ // host locality
+ for (hostPort <- hostLocalLocations) {
+ // DEBUG Code
+ Utils.checkHostPort(hostPort)
+
+ val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer())
+ hostPortList += index
+
+ val host = Utils.parseHostPort(hostPort)._1
+ val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer())
+ hostList += index
+ }
+
+ // rack locality
+ for (rackLocalHostPort <- rackLocalLocations) {
+ // DEBUG Code
+ Utils.checkHostPort(rackLocalHostPort)
+
+ val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1
+ val list = pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer())
list += index
}
}
+
allPendingTasks += index
}
+ // Return the pending tasks list for a given host port (hyper local), or an empty list if
+ // there is no map entry for that host
+ private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = {
+ // DEBUG Code
+ Utils.checkHostPort(hostPort)
+ pendingTasksForHostPort.getOrElse(hostPort, ArrayBuffer())
+ }
+
// Return the pending tasks list for a given host, or an empty list if
// there is no map entry for that host
- private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
+ private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
+ val host = Utils.parseHostPort(hostPort)._1
pendingTasksForHost.getOrElse(host, ArrayBuffer())
}
+ // Return the pending tasks (rack level) list for a given host, or an empty list if
+ // there is no map entry for that host
+ private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
+ val host = Utils.parseHostPort(hostPort)._1
+ pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer())
+ }
+
+ // Number of pending tasks for a given host (which would be data local)
+ def numPendingTasksForHost(hostPort: String): Int = {
+ getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
+ }
+
+ // Number of pending rack local tasks for a given host
+ def numRackLocalPendingTasksForHost(hostPort: String): Int = {
+ getRackLocalPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
+ }
+
+
// Dequeue a pending task from the given list and return its index.
// Return None if the list is empty.
// This method also cleans up any tasks in the list that have already
@@ -132,26 +260,49 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
}
// Return a speculative task for a given host if any are available. The task should not have an
- // attempt running on this host, in case the host is slow. In addition, if localOnly is set, the
- // task must have a preference for this host (or no preferred locations at all).
- private def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
- val hostsAlive = sched.hostsAlive
+ // attempt running on this host, in case the host is slow. In addition, if locality is set, the
+ // task must have a preference for this host/rack/no preferred locations at all.
+ private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
+
+ assert (TaskLocality.isAllowed(locality, TaskLocality.HOST_LOCAL))
speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
- val localTask = speculatableTasks.find {
- index =>
- val locations = tasks(index).preferredLocations.toSet & hostsAlive
- val attemptLocs = taskAttempts(index).map(_.host)
- (locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host)
+
+ if (speculatableTasks.size > 0) {
+ val localTask = speculatableTasks.find {
+ index =>
+ val locations = findPreferredLocations(tasks(index).preferredLocations, sched)
+ val attemptLocs = taskAttempts(index).map(_.hostPort)
+ (locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort)
+ }
+
+ if (localTask != None) {
+ speculatableTasks -= localTask.get
+ return localTask
}
- if (localTask != None) {
- speculatableTasks -= localTask.get
- return localTask
- }
- if (!localOnly && speculatableTasks.size > 0) {
- val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.host).contains(host))
- if (nonLocalTask != None) {
- speculatableTasks -= nonLocalTask.get
- return nonLocalTask
+
+ // check for rack locality
+ if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
+ val rackTask = speculatableTasks.find {
+ index =>
+ val locations = findPreferredLocations(tasks(index).preferredLocations, sched, true)
+ val attemptLocs = taskAttempts(index).map(_.hostPort)
+ locations.contains(hostPort) && !attemptLocs.contains(hostPort)
+ }
+
+ if (rackTask != None) {
+ speculatableTasks -= rackTask.get
+ return rackTask
+ }
+ }
+
+ // Any task ...
+ if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
+ // Check for attemptLocs also ?
+ val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.hostPort).contains(hostPort))
+ if (nonLocalTask != None) {
+ speculatableTasks -= nonLocalTask.get
+ return nonLocalTask
+ }
}
}
return None
@@ -159,59 +310,103 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
// Dequeue a pending task for a given node and return its index.
// If localOnly is set to false, allow non-local tasks as well.
- private def findTask(host: String, localOnly: Boolean): Option[Int] = {
- val localTask = findTaskFromList(getPendingTasksForHost(host))
+ private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
+ val localTask = findTaskFromList(getPendingTasksForHost(hostPort))
if (localTask != None) {
return localTask
}
+
+ if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
+ val rackLocalTask = findTaskFromList(getRackLocalPendingTasksForHost(hostPort))
+ if (rackLocalTask != None) {
+ return rackLocalTask
+ }
+ }
+
+ // Look for no pref tasks AFTER rack local tasks - this has side effect that we will get to failed tasks later rather than sooner.
+ // TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down).
val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs)
if (noPrefTask != None) {
return noPrefTask
}
- if (!localOnly) {
+
+ if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
val nonLocalTask = findTaskFromList(allPendingTasks)
if (nonLocalTask != None) {
return nonLocalTask
}
}
+
// Finally, if all else has failed, find a speculative task
- return findSpeculativeTask(host, localOnly)
+ return findSpeculativeTask(hostPort, locality)
}
// Does a host count as a preferred location for a task? This is true if
// either the task has preferred locations and this host is one, or it has
// no preferred locations (in which we still count the launch as preferred).
- private def isPreferredLocation(task: Task[_], host: String): Boolean = {
+ private def isPreferredLocation(task: Task[_], hostPort: String): Boolean = {
val locs = task.preferredLocations
- return (locs.contains(host) || locs.isEmpty)
+ // DEBUG code
+ locs.foreach(h => Utils.checkHost(h, "preferredLocation " + locs))
+
+ if (locs.contains(hostPort) || locs.isEmpty) return true
+
+ val host = Utils.parseHostPort(hostPort)._1
+ locs.contains(host)
+ }
+
+ // Does a host count as a rack local preferred location for a task? (assumes host is NOT preferred location).
+ // This is true if either the task has preferred locations and this host is one, or it has
+ // no preferred locations (in which we still count the launch as preferred).
+ def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = {
+
+ val locs = task.preferredLocations
+
+ // DEBUG code
+ locs.foreach(h => Utils.checkHost(h, "preferredLocation " + locs))
+
+ val preferredRacks = new HashSet[String]()
+ for (preferredHost <- locs) {
+ val rack = sched.getRackForHost(preferredHost)
+ if (None != rack) preferredRacks += rack.get
+ }
+
+ if (preferredRacks.isEmpty) return false
+
+ val hostRack = sched.getRackForHost(hostPort)
+
+ return None != hostRack && preferredRacks.contains(hostRack.get)
}
// Respond to an offer of a single slave from the scheduler by finding a task
- def slaveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = {
+ def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
+
if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
- val time = System.currentTimeMillis
- val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
+ // If explicitly specified, use that
+ val locality = if (overrideLocality != null) overrideLocality else {
+ // expand only if we have waited for more than LOCALITY_WAIT for a host local task ...
+ val time = System.currentTimeMillis
+ if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.HOST_LOCAL else TaskLocality.ANY
+ }
- findTask(host, localOnly) match {
+ findTask(hostPort, locality) match {
case Some(index) => {
// Found a task; do some bookkeeping and return a Mesos task for it
val task = tasks(index)
val taskId = sched.newTaskId()
// Figure out whether this should count as a preferred launch
- val preferred = isPreferredLocation(task, host)
- val prefStr = if (preferred) {
- "preferred"
- } else {
- "non-preferred, not one of " + task.preferredLocations.mkString(", ")
- }
- logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format(
- taskSet.id, index, taskId, execId, host, prefStr))
+ val taskLocality = if (isPreferredLocation(task, hostPort)) TaskLocality.HOST_LOCAL else
+ if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else TaskLocality.ANY
+ val prefStr = taskLocality.toString
+ logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
+ taskSet.id, index, taskId, execId, hostPort, prefStr))
// Do various bookkeeping
copiesRunning(index) += 1
- val info = new TaskInfo(taskId, index, time, execId, host, preferred)
+ val time = System.currentTimeMillis
+ val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality)
taskInfos(taskId) = info
taskAttempts(index) = info :: taskAttempts(index)
- if (preferred) {
+ if (TaskLocality.HOST_LOCAL == taskLocality) {
lastPreferredLaunchTime = time
}
// Serialize and return the task
@@ -355,17 +550,15 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
sched.taskSetFinished(this)
}
- def executorLost(execId: String, hostname: String) {
+ def executorLost(execId: String, hostPort: String) {
logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
- val newHostsAlive = sched.hostsAlive
// If some task has preferred locations only on hostname, and there are no more executors there,
// put it in the no-prefs list to avoid the wait from delay scheduling
- if (!newHostsAlive.contains(hostname)) {
- for (index <- getPendingTasksForHost(hostname)) {
- val newLocs = tasks(index).preferredLocations.toSet & newHostsAlive
- if (newLocs.isEmpty) {
- pendingTasksWithNoPrefs += index
- }
+ for (index <- getPendingTasksForHostPort(hostPort)) {
+ val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, true)
+ if (newLocs.isEmpty) {
+ assert (findPreferredLocations(tasks(index).preferredLocations, sched).isEmpty)
+ pendingTasksWithNoPrefs += index
}
}
// Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
@@ -419,7 +612,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
!speculatableTasks.contains(index)) {
logInfo(
"Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
- taskSet.id, index, info.host, threshold))
+ taskSet.id, index, info.hostPort, threshold))
speculatableTasks += index
foundTasks = true
}
@@ -427,4 +620,8 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
}
return foundTasks
}
+
+ def hasPendingTasks(): Boolean = {
+ numTasks > 0 && tasksFinished < numTasks
+ }
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
index 3c3afcbb14..c47824315c 100644
--- a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
@@ -4,5 +4,5 @@ package spark.scheduler.cluster
* Represents free resources available on an executor.
*/
private[spark]
-class WorkerOffer(val executorId: String, val hostname: String, val cores: Int) {
+class WorkerOffer(val executorId: String, val hostPort: String, val cores: Int) {
}
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index 9e1bde3fbe..f060a940a9 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -7,7 +7,7 @@ import scala.collection.mutable.HashMap
import spark._
import spark.executor.ExecutorURLClassLoader
import spark.scheduler._
-import spark.scheduler.cluster.TaskInfo
+import spark.scheduler.cluster.{TaskLocality, TaskInfo}
/**
* A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
@@ -53,7 +53,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
def runTask(task: Task[_], idInJob: Int, attemptId: Int) {
logInfo("Running " + task)
- val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local", true)
+ val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local:1", TaskLocality.HOST_LOCAL)
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
try {
diff --git a/core/src/main/scala/spark/serializer/Serializer.scala b/core/src/main/scala/spark/serializer/Serializer.scala
index aca86ab6f0..2ad73b711d 100644
--- a/core/src/main/scala/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/spark/serializer/Serializer.scala
@@ -1,10 +1,13 @@
package spark.serializer
-import java.nio.ByteBuffer
import java.io.{EOFException, InputStream, OutputStream}
+import java.nio.ByteBuffer
+
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+
import spark.util.ByteBufferInputStream
+
/**
* A serializer. Because some serialization libraries are not thread safe, this class is used to
* create [[spark.serializer.SerializerInstance]] objects that do the actual serialization and are
@@ -14,6 +17,7 @@ trait Serializer {
def newInstance(): SerializerInstance
}
+
/**
* An instance of a serializer, for use by one thread at a time.
*/
@@ -45,6 +49,7 @@ trait SerializerInstance {
}
}
+
/**
* A stream for writing serialized objects.
*/
@@ -61,6 +66,7 @@ trait SerializationStream {
}
}
+
/**
* A stream for reading serialized objects.
*/
diff --git a/core/src/main/scala/spark/serializer/SerializerManager.scala b/core/src/main/scala/spark/serializer/SerializerManager.scala
new file mode 100644
index 0000000000..60b2aac797
--- /dev/null
+++ b/core/src/main/scala/spark/serializer/SerializerManager.scala
@@ -0,0 +1,45 @@
+package spark.serializer
+
+import java.util.concurrent.ConcurrentHashMap
+
+
+/**
+ * A service that returns a serializer object given the serializer's class name. If a previous
+ * instance of the serializer object has been created, the get method returns that instead of
+ * creating a new one.
+ */
+private[spark] class SerializerManager {
+
+ private val serializers = new ConcurrentHashMap[String, Serializer]
+ private var _default: Serializer = _
+
+ def default = _default
+
+ def setDefault(clsName: String): Serializer = {
+ _default = get(clsName)
+ _default
+ }
+
+ def get(clsName: String): Serializer = {
+ if (clsName == null) {
+ default
+ } else {
+ var serializer = serializers.get(clsName)
+ if (serializer != null) {
+ // If the serializer has been created previously, reuse that.
+ serializer
+ } else this.synchronized {
+ // Otherwise, create a new one. But make sure no other thread has attempted
+ // to create another new one at the same time.
+ serializer = serializers.get(clsName)
+ if (serializer == null) {
+ val clsLoader = Thread.currentThread.getContextClassLoader
+ serializer =
+ Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer]
+ serializers.put(clsName, serializer)
+ }
+ serializer
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/storage/BlockException.scala b/core/src/main/scala/spark/storage/BlockException.scala
new file mode 100644
index 0000000000..f275d476df
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockException.scala
@@ -0,0 +1,5 @@
+package spark.storage
+
+private[spark]
+case class BlockException(blockId: String, message: String) extends Exception(message)
+
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 5a00180922..433e939656 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -27,29 +27,35 @@ import spark.network.netty.ShuffleCopier
import io.netty.buffer.ByteBuf
private[spark]
-case class BlockException(blockId: String, message: String, ex: Exception = null)
-extends Exception(message)
-
-private[spark]
class BlockManager(
executorId: String,
actorSystem: ActorSystem,
val master: BlockManagerMaster,
- val serializer: Serializer,
+ val defaultSerializer: Serializer,
maxMemory: Long)
extends Logging {
- class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
- var pending: Boolean = true
- var size: Long = -1L
- var failed: Boolean = false
+ private class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
+ @volatile var pending: Boolean = true
+ @volatile var size: Long = -1L
+ @volatile var initThread: Thread = null
+ @volatile var failed = false
+
+ setInitThread()
+
+ private def setInitThread() {
+ // Set current thread as init thread - waitForReady will not block this thread
+ // (in case there is non trivial initialization which ends up calling waitForReady as part of
+ // initialization itself)
+ this.initThread = Thread.currentThread()
+ }
/**
* Wait for this BlockInfo to be marked as ready (i.e. block is finished writing).
* Return true if the block is available, false otherwise.
*/
def waitForReady(): Boolean = {
- if (pending) {
+ if (initThread != Thread.currentThread() && pending) {
synchronized {
while (pending) this.wait()
}
@@ -59,28 +65,37 @@ class BlockManager(
/** Mark this BlockInfo as ready (i.e. block is finished writing) */
def markReady(sizeInBytes: Long) {
+ assert (pending)
+ size = sizeInBytes
+ initThread = null
+ failed = false
+ initThread = null
+ pending = false
synchronized {
- pending = false
- failed = false
- size = sizeInBytes
this.notifyAll()
}
}
/** Mark this BlockInfo as ready but failed */
def markFailure() {
+ assert (pending)
+ size = 0
+ initThread = null
+ failed = true
+ initThread = null
+ pending = false
synchronized {
- failed = true
- pending = false
this.notifyAll()
}
}
}
+ val shuffleBlockManager = new ShuffleBlockManager(this)
+
private val blockInfo = new TimeStampedHashMap[String, BlockInfo]
private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
- private[storage] val diskStore: BlockStore =
+ private[storage] val diskStore: DiskStore =
new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
val connectionManager = new ConnectionManager(0)
@@ -103,7 +118,7 @@ class BlockManager(
val heartBeatFrequency = BlockManager.getHeartBeatFrequencyFromSystemProperties
- val host = System.getProperty("spark.hostname", Utils.localHostName())
+ val hostPort = Utils.localHostPort()
val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
@@ -214,9 +229,12 @@ class BlockManager(
* Tell the master about the current storage status of a block. This will send a block update
* message reflecting the current status, *not* the desired storage level in its block info.
* For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk.
+ *
+ * droppedMemorySize exists to account for when block is dropped from memory to disk (so it is still valid).
+ * This ensures that update in master will compensate for the increase in memory on slave.
*/
- def reportBlockStatus(blockId: String, info: BlockInfo) {
- val needReregister = !tryToReportBlockStatus(blockId, info)
+ def reportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L) {
+ val needReregister = !tryToReportBlockStatus(blockId, info, droppedMemorySize)
if (needReregister) {
logInfo("Got told to reregister updating block " + blockId)
// Reregistering will report our new block for free.
@@ -230,7 +248,7 @@ class BlockManager(
* which will be true if the block was successfully recorded and false if
* the slave needs to re-register.
*/
- private def tryToReportBlockStatus(blockId: String, info: BlockInfo): Boolean = {
+ private def tryToReportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = {
val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized {
info.level match {
case null =>
@@ -239,7 +257,7 @@ class BlockManager(
val inMem = level.useMemory && memoryStore.contains(blockId)
val onDisk = level.useDisk && diskStore.contains(blockId)
val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication)
- val memSize = if (inMem) memoryStore.getSize(blockId) else 0L
+ val memSize = if (inMem) memoryStore.getSize(blockId) else droppedMemorySize
val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L
(storageLevel, memSize, diskSize, info.tellMaster)
}
@@ -259,7 +277,7 @@ class BlockManager(
def getLocations(blockId: String): Seq[String] = {
val startTimeMs = System.currentTimeMillis
var managers = master.getLocations(blockId)
- val locations = managers.map(_.ip)
+ val locations = managers.map(_.hostPort)
logDebug("Got block locations in " + Utils.getUsedTimeMs(startTimeMs))
return locations
}
@@ -269,28 +287,26 @@ class BlockManager(
*/
def getLocations(blockIds: Array[String]): Array[Seq[String]] = {
val startTimeMs = System.currentTimeMillis
- val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray
+ val locations = master.getLocations(blockIds).map(_.map(_.hostPort).toSeq).toArray
logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
return locations
}
/**
+ * A short-circuited method to get blocks directly from disk. This is used for getting
+ * shuffle blocks. It is safe to do so without a lock on block info since disk store
+ * never deletes (recent) items.
+ */
+ def getLocalFromDisk(blockId: String, serializer: Serializer): Option[Iterator[Any]] = {
+ diskStore.getValues(blockId, serializer).orElse(
+ sys.error("Block " + blockId + " not found on disk, though it should be"))
+ }
+
+ /**
* Get block from local block manager.
*/
def getLocal(blockId: String): Option[Iterator[Any]] = {
logDebug("Getting local block " + blockId)
-
- // As an optimization for map output fetches, if the block is for a shuffle, return it
- // without acquiring a lock; the disk store never deletes (recent) items so this should work
- if (blockId.startsWith("shuffle_")) {
- return diskStore.getValues(blockId) match {
- case Some(iterator) =>
- Some(iterator)
- case None =>
- throw new Exception("Block " + blockId + " not found on disk, though it should be")
- }
- }
-
val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
@@ -341,6 +357,8 @@ class BlockManager(
case Some(bytes) =>
// Put a copy of the block back in memory before returning it. Note that we can't
// put the ByteBuffer returned by the disk store as that's a memory-mapped file.
+ // The use of rewind assumes this.
+ assert (0 == bytes.position())
val copyForMemory = ByteBuffer.allocate(bytes.limit)
copyForMemory.put(bytes)
memoryStore.putBytes(blockId, copyForMemory, level)
@@ -374,7 +392,7 @@ class BlockManager(
// As an optimization for map output fetches, if the block is for a shuffle, return it
// without acquiring a lock; the disk store never deletes (recent) items so this should work
- if (blockId.startsWith("shuffle_")) {
+ if (ShuffleBlockManager.isShuffle(blockId)) {
return diskStore.getBytes(blockId) match {
case Some(bytes) =>
Some(bytes)
@@ -413,6 +431,7 @@ class BlockManager(
// Read it as a byte buffer into memory first, then return it
diskStore.getBytes(blockId) match {
case Some(bytes) =>
+ assert (0 == bytes.position())
if (level.useMemory) {
if (level.deserialized) {
memoryStore.putBytes(blockId, bytes, level)
@@ -452,7 +471,7 @@ class BlockManager(
for (loc <- locations) {
logDebug("Getting remote block " + blockId + " from " + loc)
val data = BlockManagerWorker.syncGetBlock(
- GetBlock(blockId), ConnectionManagerId(loc.ip, loc.port))
+ GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
if (data != null) {
return Some(dataDeserialize(blockId, data))
}
@@ -475,13 +494,13 @@ class BlockManager(
* fashion as they're received. Expects a size in bytes to be provided for each block fetched,
* so that we can control the maxMegabytesInFlight for the fetch.
*/
- def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])])
+ def getMultiple(
+ blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer)
: BlockFetcherIterator = {
-
if(System.getProperty("spark.shuffle.use.netty", "false").toBoolean){
- return BlockFetcherIterator("netty",this, blocksByAddress)
+ return BlockFetcherIterator("netty",this, blocksByAddress, serializer)
} else {
- return BlockFetcherIterator("", this, blocksByAddress)
+ return BlockFetcherIterator("", this, blocksByAddress, serializer)
}
}
@@ -493,6 +512,22 @@ class BlockManager(
}
/**
+ * A short circuited method to get a block writer that can write data directly to disk.
+ * This is currently used for writing shuffle files out. Callers should handle error
+ * cases.
+ */
+ def getDiskBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
+ : BlockObjectWriter = {
+ val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize)
+ writer.registerCloseEventHandler(() => {
+ val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false)
+ blockInfo.put(blockId, myInfo)
+ myInfo.markReady(writer.size())
+ })
+ writer
+ }
+
+ /**
* Put a new block of values to the block manager. Returns its (estimated) size in bytes.
*/
def put(blockId: String, values: ArrayBuffer[Any], level: StorageLevel,
@@ -508,17 +543,26 @@ class BlockManager(
throw new IllegalArgumentException("Storage level is null or invalid")
}
- val oldBlock = blockInfo.get(blockId).orNull
- if (oldBlock != null && oldBlock.waitForReady()) {
- logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
- return oldBlock.size
- }
-
// Remember the block's storage level so that we can correctly drop it to disk if it needs
// to be dropped right after it got put into memory. Note, however, that other threads will
// not be able to get() this block until we call markReady on its BlockInfo.
- val myInfo = new BlockInfo(level, tellMaster)
- blockInfo.put(blockId, myInfo)
+ val myInfo = {
+ val tinfo = new BlockInfo(level, tellMaster)
+ // Do atomically !
+ val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo)
+
+ if (oldBlockOpt.isDefined) {
+ if (oldBlockOpt.get.waitForReady()) {
+ logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
+ return oldBlockOpt.get.size
+ }
+
+ // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ?
+ oldBlockOpt.get
+ } else {
+ tinfo
+ }
+ }
val startTimeMs = System.currentTimeMillis
@@ -538,6 +582,7 @@ class BlockManager(
logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
+ var marked = false
try {
if (level.useMemory) {
// Save it just to memory first, even if it also has useDisk set to true; we will later
@@ -562,26 +607,25 @@ class BlockManager(
// Now that the block is in either the memory or disk store, let other threads read it,
// and tell the master about it.
+ marked = true
myInfo.markReady(size)
if (tellMaster) {
reportBlockStatus(blockId, myInfo)
}
- } catch {
+ } finally {
// If we failed at putting the block to memory/disk, notify other possible readers
// that it has failed, and then remove it from the block info map.
- case e: Exception => {
+ if (! marked) {
// Note that the remove must happen before markFailure otherwise another thread
// could've inserted a new BlockInfo before we remove it.
blockInfo.remove(blockId)
myInfo.markFailure()
- logWarning("Putting block " + blockId + " failed", e)
- throw e
+ logWarning("Putting block " + blockId + " failed")
}
}
}
logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs))
-
// Replicate block if required
if (level.replication > 1) {
val remoteStartTime = System.currentTimeMillis
@@ -618,16 +662,26 @@ class BlockManager(
throw new IllegalArgumentException("Storage level is null or invalid")
}
- if (blockInfo.contains(blockId)) {
- logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
- return
- }
-
// Remember the block's storage level so that we can correctly drop it to disk if it needs
// to be dropped right after it got put into memory. Note, however, that other threads will
// not be able to get() this block until we call markReady on its BlockInfo.
- val myInfo = new BlockInfo(level, tellMaster)
- blockInfo.put(blockId, myInfo)
+ val myInfo = {
+ val tinfo = new BlockInfo(level, tellMaster)
+ // Do atomically !
+ val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo)
+
+ if (oldBlockOpt.isDefined) {
+ if (oldBlockOpt.get.waitForReady()) {
+ logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
+ return
+ }
+
+ // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ?
+ oldBlockOpt.get
+ } else {
+ tinfo
+ }
+ }
val startTimeMs = System.currentTimeMillis
@@ -646,6 +700,7 @@ class BlockManager(
logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
+ var marked = false
try {
if (level.useMemory) {
// Store it only in memory at first, even if useDisk is also set to true
@@ -656,22 +711,24 @@ class BlockManager(
diskStore.putBytes(blockId, bytes, level)
}
+ // assert (0 == bytes.position(), "" + bytes)
+
// Now that the block is in either the memory or disk store, let other threads read it,
// and tell the master about it.
+ marked = true
myInfo.markReady(bytes.limit)
if (tellMaster) {
reportBlockStatus(blockId, myInfo)
}
- } catch {
+ } finally {
// If we failed at putting the block to memory/disk, notify other possible readers
// that it has failed, and then remove it from the block info map.
- case e: Exception => {
+ if (! marked) {
// Note that the remove must happen before markFailure otherwise another thread
// could've inserted a new BlockInfo before we remove it.
blockInfo.remove(blockId)
myInfo.markFailure()
- logWarning("Putting block " + blockId + " failed", e)
- throw e
+ logWarning("Putting block " + blockId + " failed")
}
}
}
@@ -705,7 +762,7 @@ class BlockManager(
logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is "
+ data.limit() + " Bytes. To node: " + peer)
if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel),
- new ConnectionManagerId(peer.ip, peer.port))) {
+ new ConnectionManagerId(peer.host, peer.port))) {
logError("Failed to call syncPutBlock to " + peer)
}
logDebug("Replicated BlockId " + blockId + " once used " +
@@ -737,6 +794,14 @@ class BlockManager(
val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
+ // required ? As of now, this will be invoked only for blocks which are ready
+ // But in case this changes in future, adding for consistency sake.
+ if (! info.waitForReady() ) {
+ // If we get here, the block write failed.
+ logWarning("Block " + blockId + " was marked as failure. Nothing to drop")
+ return
+ }
+
val level = info.level
if (level.useDisk && !diskStore.contains(blockId)) {
logInfo("Writing block " + blockId + " to disk")
@@ -747,12 +812,13 @@ class BlockManager(
diskStore.putBytes(blockId, bytes, level)
}
}
+ val droppedMemorySize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L
val blockWasRemoved = memoryStore.remove(blockId)
if (!blockWasRemoved) {
logWarning("Block " + blockId + " could not be dropped from memory as it does not exist")
}
if (info.tellMaster) {
- reportBlockStatus(blockId, info)
+ reportBlockStatus(blockId, info, droppedMemorySize)
}
if (!level.useDisk) {
// The block is completely gone from this node; forget it so we can put() it again later.
@@ -812,7 +878,7 @@ class BlockManager(
}
def shouldCompress(blockId: String): Boolean = {
- if (blockId.startsWith("shuffle_")) {
+ if (ShuffleBlockManager.isShuffle(blockId)) {
compressShuffle
} else if (blockId.startsWith("broadcast_")) {
compressBroadcast
@@ -827,7 +893,11 @@ class BlockManager(
* Wrap an output stream for compression if block compression is enabled for its block type
*/
def wrapForCompression(blockId: String, s: OutputStream): OutputStream = {
- if (shouldCompress(blockId)) new LZFOutputStream(s) else s
+ if (shouldCompress(blockId)) {
+ (new LZFOutputStream(s)).setFinishBlockOnFlush(true)
+ } else {
+ s
+ }
}
/**
@@ -837,7 +907,10 @@ class BlockManager(
if (shouldCompress(blockId)) new LZFInputStream(s) else s
}
- def dataSerialize(blockId: String, values: Iterator[Any]): ByteBuffer = {
+ def dataSerialize(
+ blockId: String,
+ values: Iterator[Any],
+ serializer: Serializer = defaultSerializer): ByteBuffer = {
val byteStream = new FastByteArrayOutputStream(4096)
val ser = serializer.newInstance()
ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
@@ -849,7 +922,10 @@ class BlockManager(
* Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of
* the iterator is reached.
*/
- def dataDeserialize(blockId: String, bytes: ByteBuffer): Iterator[Any] = {
+ def dataDeserialize(
+ blockId: String,
+ bytes: ByteBuffer,
+ serializer: Serializer = defaultSerializer): Iterator[Any] = {
bytes.rewind()
val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true))
serializer.newInstance().deserializeStream(stream).asIterator
@@ -922,7 +998,8 @@ object BlockFetcherIterator {
class BasicBlockFetcherIterator(
private val blockManager: BlockManager,
- val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])]
+ val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+ serializer: Serializer
) extends BlockFetcherIterator {
import blockManager._
@@ -952,8 +1029,8 @@ class BasicBlockFetcherIterator(
def sendRequest(req: FetchRequest) {
logDebug("Sending request for %d blocks (%s) from %s".format(
- req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip))
- val cmId = new ConnectionManagerId(req.address.ip, req.address.port)
+ req.blocks.size, Utils.memoryBytesToString(req.size), req.address.hostPort))
+ val cmId = new ConnectionManagerId(req.address.host, req.address.port)
val blockMessageArray = new BlockMessageArray(req.blocks.map {
case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
})
@@ -973,8 +1050,8 @@ class BasicBlockFetcherIterator(
"Unexpected message " + blockMessage.getType + " received from " + cmId)
}
val blockId = blockMessage.getId
- results.put(new FetchResult(
- blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData)))
+ results.put(new FetchResult(blockId, sizeMap(blockId),
+ () => dataDeserialize(blockId, blockMessage.getData, serializer)))
_remoteBytesRead += req.size
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
@@ -1043,7 +1120,7 @@ class BasicBlockFetcherIterator(
}
def initialize(){
- // Split local and remote blocks.
+ // Split local and remote blocks.
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
fetchRequests ++= Utils.randomize(remoteRequests)
@@ -1065,7 +1142,7 @@ class BasicBlockFetcherIterator(
}
//an iterator that will read fetched blocks off the queue as they arrive.
- var resultsGotten = 0
+ @volatile private var resultsGotten = 0
def hasNext: Boolean = resultsGotten < totalBlocks
@@ -1075,7 +1152,7 @@ class BasicBlockFetcherIterator(
val result = results.take()
val stopFetchWait = System.currentTimeMillis()
_fetchWaitTime += (stopFetchWait - startFetchWait)
- bytesInFlight -= result.size
+ if (! result.failed) bytesInFlight -= result.size
while (!fetchRequests.isEmpty &&
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
sendRequest(fetchRequests.dequeue())
@@ -1097,8 +1174,9 @@ class BasicBlockFetcherIterator(
class NettyBlockFetcherIterator(
blockManager: BlockManager,
- blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])]
-) extends BasicBlockFetcherIterator(blockManager,blocksByAddress) {
+ blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+ serializer: Serializer
+) extends BasicBlockFetcherIterator(blockManager,blocksByAddress,serializer) {
import blockManager._
@@ -1107,7 +1185,7 @@ class NettyBlockFetcherIterator(
def putResult(blockId:String, blockSize:Long, blockData:ByteBuffer,
results : LinkedBlockingQueue[FetchResult]){
results.put(new FetchResult(
- blockId, blockSize, () => dataDeserialize(blockId, blockData) ))
+ blockId, blockSize, () => dataDeserialize(blockId, blockData, serializer) ))
}
def startCopiers (numCopiers: Int): List [ _ <: Thread]= {
@@ -1194,7 +1272,7 @@ class NettyBlockFetcherIterator(
var copiers : List[_ <: Thread] = null
override def initialize(){
- // Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks
+ // Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
for (request <- Utils.randomize(remoteRequests)) {
@@ -1223,13 +1301,14 @@ class NettyBlockFetcherIterator(
}
}
- def apply(t: String,
+ def apply(t: String,
blockManager: BlockManager,
- blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])]): BlockFetcherIterator = {
- val iter = if (t == "netty") { new NettyBlockFetcherIterator(blockManager,blocksByAddress) }
- else { new BasicBlockFetcherIterator(blockManager,blocksByAddress) }
+ blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+ serializer: Serializer): BlockFetcherIterator = {
+ val iter = if (t == "netty") { new NettyBlockFetcherIterator(blockManager,blocksByAddress, serializer) }
+ else { new BasicBlockFetcherIterator(blockManager,blocksByAddress, serializer) }
iter.initialize
- iter
+ iter
}
}
diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala
index f2f1e77d41..f4a2181490 100644
--- a/core/src/main/scala/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerId.scala
@@ -2,6 +2,7 @@ package spark.storage
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.util.concurrent.ConcurrentHashMap
+import spark.Utils
/**
* This class represent an unique identifier for a BlockManager.
@@ -13,7 +14,7 @@ import java.util.concurrent.ConcurrentHashMap
*/
private[spark] class BlockManagerId private (
private var executorId_ : String,
- private var ip_ : String,
+ private var host_ : String,
private var port_ : Int
) extends Externalizable {
@@ -21,32 +22,45 @@ private[spark] class BlockManagerId private (
def executorId: String = executorId_
- def ip: String = ip_
+ if (null != host_){
+ Utils.checkHost(host_, "Expected hostname")
+ assert (port_ > 0)
+ }
+
+ def hostPort: String = {
+ // DEBUG code
+ Utils.checkHost(host)
+ assert (port > 0)
+
+ host + ":" + port
+ }
+
+ def host: String = host_
def port: Int = port_
override def writeExternal(out: ObjectOutput) {
out.writeUTF(executorId_)
- out.writeUTF(ip_)
+ out.writeUTF(host_)
out.writeInt(port_)
}
override def readExternal(in: ObjectInput) {
executorId_ = in.readUTF()
- ip_ = in.readUTF()
+ host_ = in.readUTF()
port_ = in.readInt()
}
@throws(classOf[IOException])
private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this)
- override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, ip, port)
+ override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, host, port)
- override def hashCode: Int = (executorId.hashCode * 41 + ip.hashCode) * 41 + port
+ override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port
override def equals(that: Any) = that match {
case id: BlockManagerId =>
- executorId == id.executorId && port == id.port && ip == id.ip
+ executorId == id.executorId && port == id.port && host == id.host
case _ =>
false
}
@@ -55,8 +69,8 @@ private[spark] class BlockManagerId private (
private[spark] object BlockManagerId {
- def apply(execId: String, ip: String, port: Int) =
- getCachedBlockManagerId(new BlockManagerId(execId, ip, port))
+ def apply(execId: String, host: String, port: Int) =
+ getCachedBlockManagerId(new BlockManagerId(execId, host, port))
def apply(in: ObjectInput) = {
val obj = new BlockManagerId()
@@ -67,11 +81,7 @@ private[spark] object BlockManagerId {
val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()
def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
- if (blockManagerIdCache.containsKey(id)) {
- blockManagerIdCache.get(id)
- } else {
- blockManagerIdCache.put(id, id)
- id
- }
+ blockManagerIdCache.putIfAbsent(id, id)
+ blockManagerIdCache.get(id)
}
}
diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
index 036fdc3480..ac26c16867 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
@@ -15,6 +15,7 @@ import akka.util.duration._
import spark.{Logging, SparkException, Utils}
+
private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Logging {
val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt
@@ -22,7 +23,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster"
- val timeout = 10.seconds
+ val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
/** Remove a dead executor from the driver actor. This is only called on the driver side. */
def removeExecutor(execId: String) {
@@ -88,6 +89,21 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
}
/**
+ * Remove all blocks belonging to the given RDD.
+ */
+ def removeRdd(rddId: Int) {
+ val rddBlockPrefix = "rdd_" + rddId + "_"
+ // Get the list of blocks in block manager, and remove ones that are part of this RDD.
+ // The runtime complexity is linear to the number of blocks persisted in the cluster.
+ // It could be expensive if the cluster is large and has a lot of blocks persisted.
+ getStorageStatus.flatMap(_.blocks).foreach { case(blockId, status) =>
+ if (blockId.startsWith(rddBlockPrefix)) {
+ removeBlock(blockId)
+ }
+ }
+ }
+
+ /**
* Return the memory status for each block manager, in the form of a map from
* the block manager's id to two long values. The first value is the maximum
* amount of memory allocated for the block manager, while the second is the
diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala
index 2830bc6297..9b64f95df8 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala
@@ -121,7 +121,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
val toRemove = new HashSet[BlockManagerId]
for (info <- blockManagerInfo.values) {
if (info.lastSeenMs < minSeenTime) {
- logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats")
+ logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats: " +
+ (now - info.lastSeenMs) + "ms exceeds " + slaveTimeout + "ms")
toRemove += info.blockManagerId
}
}
@@ -332,8 +333,8 @@ object BlockManagerMasterActor {
// Mapping from block id to its status.
private val _blocks = new JHashMap[String, BlockStatus]
- logInfo("Registering block manager %s:%d with %s RAM".format(
- blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem)))
+ logInfo("Registering block manager %s with %s RAM".format(
+ blockManagerId.hostPort, Utils.memoryBytesToString(maxMem)))
def updateLastSeenMs() {
_lastSeenMs = System.currentTimeMillis()
@@ -358,13 +359,13 @@ object BlockManagerMasterActor {
_blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize))
if (storageLevel.useMemory) {
_remainingMem -= memSize
- logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format(
- blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize),
+ logInfo("Added %s in memory on %s (size: %s, free: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.memoryBytesToString(memSize),
Utils.memoryBytesToString(_remainingMem)))
}
if (storageLevel.useDisk) {
- logInfo("Added %s on disk on %s:%d (size: %s)".format(
- blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize)))
+ logInfo("Added %s on disk on %s (size: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.memoryBytesToString(diskSize)))
}
} else if (_blocks.containsKey(blockId)) {
// If isValid is not true, drop the block.
@@ -372,13 +373,13 @@ object BlockManagerMasterActor {
_blocks.remove(blockId)
if (blockStatus.storageLevel.useMemory) {
_remainingMem += blockStatus.memSize
- logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format(
- blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize),
+ logInfo("Removed %s on %s in memory (size: %s, free: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.memoryBytesToString(memSize),
Utils.memoryBytesToString(_remainingMem)))
}
if (blockStatus.storageLevel.useDisk) {
- logInfo("Removed %s on %s:%d on disk (size: %s)".format(
- blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize)))
+ logInfo("Removed %s on %s on disk (size: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.memoryBytesToString(diskSize)))
}
}
}
diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala
index 9e6721ec17..07da572044 100644
--- a/core/src/main/scala/spark/storage/BlockManagerUI.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala
@@ -1,7 +1,7 @@
package spark.storage
import akka.actor.{ActorRef, ActorSystem}
-import akka.util.Timeout
+import akka.util.Duration
import akka.util.duration._
import cc.spray.typeconversion.TwirlSupport._
import cc.spray.Directives
@@ -19,7 +19,7 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef,
val STATIC_RESOURCE_DIR = "spark/deploy/static"
- implicit val timeout = Timeout(10 seconds)
+ implicit val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
/** Start a HTTP server to run the Web interface */
def start() {
diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
index d2985559c1..15225f93a6 100644
--- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
@@ -19,7 +19,7 @@ import spark.network._
*/
private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends Logging {
initLogging()
-
+
blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive)
def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = {
@@ -51,7 +51,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
logDebug("Received [" + pB + "]")
putBlock(pB.id, pB.data, pB.level)
return None
- }
+ }
case BlockMessage.TYPE_GET_BLOCK => {
val gB = new GetBlock(blockMessage.getId)
logDebug("Received [" + gB + "]")
@@ -90,28 +90,26 @@ private[spark] object BlockManagerWorker extends Logging {
private var blockManagerWorker: BlockManagerWorker = null
private val DATA_TRANSFER_TIME_OUT_MS: Long = 500
private val REQUEST_RETRY_INTERVAL_MS: Long = 1000
-
+
initLogging()
-
+
def startBlockManagerWorker(manager: BlockManager) {
blockManagerWorker = new BlockManagerWorker(manager)
}
-
+
def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = {
val blockManager = blockManagerWorker.blockManager
- val connectionManager = blockManager.connectionManager
- val serializer = blockManager.serializer
+ val connectionManager = blockManager.connectionManager
val blockMessage = BlockMessage.fromPutBlock(msg)
val blockMessageArray = new BlockMessageArray(blockMessage)
val resultMessage = connectionManager.sendMessageReliablySync(
toConnManagerId, blockMessageArray.toBufferMessage)
return (resultMessage != None)
}
-
+
def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = {
val blockManager = blockManagerWorker.blockManager
- val connectionManager = blockManager.connectionManager
- val serializer = blockManager.serializer
+ val connectionManager = blockManager.connectionManager
val blockMessage = BlockMessage.fromGetBlock(msg)
val blockMessageArray = new BlockMessageArray(blockMessage)
val responseMessage = connectionManager.sendMessageReliablySync(
diff --git a/core/src/main/scala/spark/storage/BlockMessageArray.scala b/core/src/main/scala/spark/storage/BlockMessageArray.scala
index a25decb123..ee0c5ff9a2 100644
--- a/core/src/main/scala/spark/storage/BlockMessageArray.scala
+++ b/core/src/main/scala/spark/storage/BlockMessageArray.scala
@@ -115,6 +115,7 @@ private[spark] object BlockMessageArray {
val newBuffer = ByteBuffer.allocate(totalSize)
newBuffer.clear()
bufferMessage.buffers.foreach(buffer => {
+ assert (0 == buffer.position())
newBuffer.put(buffer)
buffer.rewind()
})
diff --git a/core/src/main/scala/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/spark/storage/BlockObjectWriter.scala
new file mode 100644
index 0000000000..42e2b07d5c
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockObjectWriter.scala
@@ -0,0 +1,50 @@
+package spark.storage
+
+import java.nio.ByteBuffer
+
+
+/**
+ * An interface for writing JVM objects to some underlying storage. This interface allows
+ * appending data to an existing block, and can guarantee atomicity in the case of faults
+ * as it allows the caller to revert partial writes.
+ *
+ * This interface does not support concurrent writes.
+ */
+abstract class BlockObjectWriter(val blockId: String) {
+
+ var closeEventHandler: () => Unit = _
+
+ def open(): BlockObjectWriter
+
+ def close() {
+ closeEventHandler()
+ }
+
+ def isOpen: Boolean
+
+ def registerCloseEventHandler(handler: () => Unit) {
+ closeEventHandler = handler
+ }
+
+ /**
+ * Flush the partial writes and commit them as a single atomic block. Return the
+ * number of bytes written for this commit.
+ */
+ def commit(): Long
+
+ /**
+ * Reverts writes that haven't been flushed yet. Callers should invoke this function
+ * when there are runtime exceptions.
+ */
+ def revertPartialWrites()
+
+ /**
+ * Writes an object.
+ */
+ def write(value: Any)
+
+ /**
+ * Size of the valid writes, in bytes.
+ */
+ def size(): Long
+}
diff --git a/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala b/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala
deleted file mode 100644
index f6c28dce52..0000000000
--- a/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala
+++ /dev/null
@@ -1,12 +0,0 @@
-package spark.storage
-
-private[spark] trait DelegateBlockFetchTracker extends BlockFetchTracker {
- var delegate : BlockFetchTracker = _
- def setDelegate(d: BlockFetchTracker) {delegate = d}
- def totalBlocks = delegate.totalBlocks
- def numLocalBlocks = delegate.numLocalBlocks
- def numRemoteBlocks = delegate.numRemoteBlocks
- def remoteFetchTime = delegate.remoteFetchTime
- def fetchWaitTime = delegate.fetchWaitTime
- def remoteBytesRead = delegate.remoteBytesRead
-}
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index cc5bf29a32..82bcbd5bc2 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -1,28 +1,84 @@
package spark.storage
+import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile}
import java.nio.ByteBuffer
-import java.io.{File, FileOutputStream, RandomAccessFile}
+import java.nio.channels.FileChannel
import java.nio.channels.FileChannel.MapMode
import java.util.{Random, Date}
import java.text.SimpleDateFormat
-import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-
import scala.collection.mutable.ArrayBuffer
-import spark.executor.ExecutorExitCode
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import spark.Utils
+import spark.executor.ExecutorExitCode
+import spark.serializer.{Serializer, SerializationStream}
import spark.Logging
import spark.network.netty.ShuffleSender
import spark.network.netty.PathResolver
+
/**
* Stores BlockManager blocks on disk.
*/
private class DiskStore(blockManager: BlockManager, rootDirs: String)
extends BlockStore(blockManager) with Logging {
+ class DiskBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int)
+ extends BlockObjectWriter(blockId) {
+
+ private val f: File = createFile(blockId /*, allowAppendExisting */)
+
+ // The file channel, used for repositioning / truncating the file.
+ private var channel: FileChannel = null
+ private var bs: OutputStream = null
+ private var objOut: SerializationStream = null
+ private var lastValidPosition = 0L
+
+ override def open(): DiskBlockObjectWriter = {
+ val fos = new FileOutputStream(f, true)
+ channel = fos.getChannel()
+ bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos))
+ objOut = serializer.newInstance().serializeStream(bs)
+ this
+ }
+
+ override def close() {
+ objOut.close()
+ bs.close()
+ channel = null
+ bs = null
+ objOut = null
+ // Invoke the close callback handler.
+ super.close()
+ }
+
+ override def isOpen: Boolean = objOut != null
+
+ // Flush the partial writes, and set valid length to be the length of the entire file.
+ // Return the number of bytes written for this commit.
+ override def commit(): Long = {
+ bs.flush()
+ val prevPos = lastValidPosition
+ lastValidPosition = channel.position()
+ lastValidPosition - prevPos
+ }
+
+ override def revertPartialWrites() {
+ // Discard current writes. We do this by flushing the outstanding writes and
+ // truncate the file to the last valid position.
+ bs.flush()
+ channel.truncate(lastValidPosition)
+ }
+
+ override def write(value: Any) {
+ objOut.writeObject(value)
+ }
+
+ override def size(): Long = lastValidPosition
+ }
+
val MAX_DIR_CREATION_ATTEMPTS: Int = 10
val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
@@ -42,11 +98,20 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
startShuffleBlockSender()
}
+ def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
+ : BlockObjectWriter = {
+ new DiskBlockObjectWriter(blockId, serializer, bufferSize)
+ }
+
+
override def getSize(blockId: String): Long = {
getFile(blockId).length()
}
- override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) {
+ // So that we do not modify the input offsets !
+ // duplicate does not copy buffer, so inexpensive
+ val bytes = _bytes.duplicate()
logDebug("Attempting to put block " + blockId)
val startTime = System.currentTimeMillis
val file = createFile(blockId)
@@ -60,6 +125,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
blockId, Utils.memoryBytesToString(bytes.limit), (finishTime - startTime)))
}
+ private def getFileBytes(file: File): ByteBuffer = {
+ val length = file.length()
+ val channel = new RandomAccessFile(file, "r").getChannel()
+ val buffer = try {
+ channel.map(MapMode.READ_ONLY, 0, length)
+ } finally {
+ channel.close()
+ }
+
+ buffer
+ }
+
override def putValues(
blockId: String,
values: ArrayBuffer[Any],
@@ -72,18 +149,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
val file = createFile(blockId)
val fileOut = blockManager.wrapForCompression(blockId,
new FastBufferedOutputStream(new FileOutputStream(file)))
- val objOut = blockManager.serializer.newInstance().serializeStream(fileOut)
+ val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut)
objOut.writeAll(values.iterator)
objOut.close()
val length = file.length()
+
+ val timeTaken = System.currentTimeMillis - startTime
logDebug("Block %s stored as %s file on disk in %d ms".format(
- blockId, Utils.memoryBytesToString(length), (System.currentTimeMillis - startTime)))
+ blockId, Utils.memoryBytesToString(length), timeTaken))
if (returnValues) {
// Return a byte buffer for the contents of the file
- val channel = new RandomAccessFile(file, "r").getChannel()
- val buffer = channel.map(MapMode.READ_ONLY, 0, length)
- channel.close()
+ val buffer = getFileBytes(file)
PutResult(length, Right(buffer))
} else {
PutResult(length, null)
@@ -92,10 +169,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
override def getBytes(blockId: String): Option[ByteBuffer] = {
val file = getFile(blockId)
- val length = file.length().toInt
- val channel = new RandomAccessFile(file, "r").getChannel()
- val bytes = channel.map(MapMode.READ_ONLY, 0, length)
- channel.close()
+ val bytes = getFileBytes(file)
Some(bytes)
}
@@ -103,11 +177,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes))
}
+ /**
+ * A version of getValues that allows a custom serializer. This is used as part of the
+ * shuffle short-circuit code.
+ */
+ def getValues(blockId: String, serializer: Serializer): Option[Iterator[Any]] = {
+ getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer))
+ }
+
override def remove(blockId: String): Boolean = {
val file = getFile(blockId)
if (file.exists()) {
file.delete()
- true
} else {
false
}
@@ -117,9 +198,9 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
getFile(blockId).exists()
}
- private def createFile(blockId: String): File = {
+ private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = {
val file = getFile(blockId)
- if (file.exists()) {
+ if (!allowAppendExisting && file.exists()) {
throw new Exception("File for block " + blockId + " already exists on disk: " + file)
}
file
@@ -167,8 +248,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
localDir = new File(rootDir, "spark-local-" + localDirId)
if (!localDir.exists) {
- localDir.mkdirs()
- foundLocalDir = true
+ foundLocalDir = localDir.mkdirs()
}
} catch {
case e: Exception =>
@@ -186,11 +266,14 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
private def addShutdownHook() {
+ localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
override def run() {
logDebug("Shutdown hook called")
try {
- localDirs.foreach(localDir => Utils.deleteRecursively(localDir))
+ localDirs.foreach { localDir =>
+ if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
+ }
if (useNetty && shuffleSender != null)
shuffleSender.stop
} catch {
@@ -211,7 +294,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
thisInstance.getFile(blockId).getAbsolutePath()
}
- }
+ }
shuffleSender = new Thread {
override def run() = {
val sender = new ShuffleSender(port,pResolver)
@@ -221,7 +304,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
shuffleSender.setDaemon(true)
shuffleSender.start
-
+
} catch {
case interrupted: InterruptedException =>
logInfo("Runner thread for ShuffleBlockSender interrupted")
diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala
index 949588476c..eba5ee507f 100644
--- a/core/src/main/scala/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/spark/storage/MemoryStore.scala
@@ -31,7 +31,9 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
- override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) {
+ // Work on a duplicate - since the original input might be used elsewhere.
+ val bytes = _bytes.duplicate()
bytes.rewind()
if (level.deserialized) {
val values = blockManager.dataDeserialize(blockId, bytes)
diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
new file mode 100644
index 0000000000..49eabfb0d2
--- /dev/null
+++ b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
@@ -0,0 +1,50 @@
+package spark.storage
+
+import spark.serializer.Serializer
+
+
+private[spark]
+class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter])
+
+
+private[spark]
+trait ShuffleBlocks {
+ def acquireWriters(mapId: Int): ShuffleWriterGroup
+ def releaseWriters(group: ShuffleWriterGroup)
+}
+
+
+private[spark]
+class ShuffleBlockManager(blockManager: BlockManager) {
+
+ def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = {
+ new ShuffleBlocks {
+ // Get a group of writers for a map task.
+ override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
+ val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
+ val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
+ blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open()
+ }
+ new ShuffleWriterGroup(mapId, writers)
+ }
+
+ override def releaseWriters(group: ShuffleWriterGroup) = {
+ // Nothing really to release here.
+ }
+ }
+ }
+}
+
+
+private[spark]
+object ShuffleBlockManager {
+
+ // Returns the block id for a given shuffle block.
+ def blockId(shuffleId: Int, bucketId: Int, groupId: Int): String = {
+ "shuffle_" + shuffleId + "_" + groupId + "_" + bucketId
+ }
+
+ // Returns true if the block is a shuffle block.
+ def isShuffle(blockId: String): Boolean = blockId.startsWith("shuffle_")
+}
diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala
index 3b5a77ab22..cc0c354e7e 100644
--- a/core/src/main/scala/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/spark/storage/StorageLevel.scala
@@ -123,11 +123,7 @@ object StorageLevel {
val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]()
private[spark] def getCachedStorageLevel(level: StorageLevel): StorageLevel = {
- if (storageLevelCache.containsKey(level)) {
- storageLevelCache.get(level)
- } else {
- storageLevelCache.put(level, level)
- level
- }
+ storageLevelCache.putIfAbsent(level, level)
+ storageLevelCache.get(level)
}
}
diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala
index dec47a9d41..8f52168c24 100644
--- a/core/src/main/scala/spark/storage/StorageUtils.scala
+++ b/core/src/main/scala/spark/storage/StorageUtils.scala
@@ -4,9 +4,9 @@ import spark.{Utils, SparkContext}
import BlockManagerMasterActor.BlockStatus
private[spark]
-case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
+case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
blocks: Map[String, BlockStatus]) {
-
+
def memUsed(blockPrefix: String = "") = {
blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize).
reduceOption(_+_).getOrElse(0l)
@@ -22,35 +22,40 @@ case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
}
case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel,
- numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long) {
+ numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long)
+ extends Ordered[RDDInfo] {
override def toString = {
import Utils.memoryBytesToString
"RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; DiskSize: %s".format(name, id,
storageLevel.toString, numCachedPartitions, numPartitions, memoryBytesToString(memSize), memoryBytesToString(diskSize))
}
+
+ override def compare(that: RDDInfo) = {
+ this.id - that.id
+ }
}
/* Helper methods for storage-related objects */
private[spark]
object StorageUtils {
- /* Given the current storage status of the BlockManager, returns information for each RDD */
- def rddInfoFromStorageStatus(storageStatusList: Array[StorageStatus],
+ /* Given the current storage status of the BlockManager, returns information for each RDD */
+ def rddInfoFromStorageStatus(storageStatusList: Array[StorageStatus],
sc: SparkContext) : Array[RDDInfo] = {
- rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc)
+ rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc)
}
- /* Given a list of BlockStatus objets, returns information for each RDD */
- def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus],
+ /* Given a list of BlockStatus objets, returns information for each RDD */
+ def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus],
sc: SparkContext) : Array[RDDInfo] = {
// Group by rddId, ignore the partition name
- val groupedRddBlocks = infos.groupBy { case(k, v) =>
+ val groupedRddBlocks = infos.filterKeys(_.startsWith("rdd_")).groupBy { case(k, v) =>
k.substring(0,k.lastIndexOf('_'))
}.mapValues(_.values.toArray)
// For each RDD, generate an RDDInfo object
- groupedRddBlocks.map { case(rddKey, rddBlocks) =>
+ val rddInfos = groupedRddBlocks.map { case(rddKey, rddBlocks) =>
// Add up memory and disk sizes
val memSize = rddBlocks.map(_.memSize).reduce(_ + _)
@@ -65,10 +70,14 @@ object StorageUtils {
RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, rdd.partitions.size, memSize, diskSize)
}.toArray
+
+ scala.util.Sorting.quickSort(rddInfos)
+
+ rddInfos
}
- /* Removes all BlockStatus object that are not part of a block prefix */
- def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus],
+ /* Removes all BlockStatus object that are not part of a block prefix */
+ def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus],
prefix: String) : Array[StorageStatus] = {
storageStatusList.map { status =>
diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala
index 3e805b7831..9fb7e001ba 100644
--- a/core/src/main/scala/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/spark/util/AkkaUtils.scala
@@ -11,7 +11,7 @@ import cc.spray.{SprayCanRootService, HttpService}
import cc.spray.can.server.HttpServer
import cc.spray.io.pipelines.MessageHandlerDispatch.SingletonHandler
import akka.dispatch.Await
-import spark.SparkException
+import spark.{Utils, SparkException}
import java.util.concurrent.TimeoutException
/**
@@ -31,7 +31,10 @@ private[spark] object AkkaUtils {
val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt
val akkaTimeout = System.getProperty("spark.akka.timeout", "20").toInt
val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt
- val lifecycleEvents = System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean
+ val lifecycleEvents = if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off"
+ // 10 seconds is the default akka timeout, but in a cluster, we need higher by default.
+ val akkaWriteTimeout = System.getProperty("spark.akka.writeTimeout", "30").toInt
+
val akkaConf = ConfigFactory.parseString("""
akka.daemonic = on
akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"]
@@ -45,8 +48,9 @@ private[spark] object AkkaUtils {
akka.remote.netty.execution-pool-size = %d
akka.actor.default-dispatcher.throughput = %d
akka.remote.log-remote-lifecycle-events = %s
+ akka.remote.netty.write-timeout = %ds
""".format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize,
- if (lifecycleEvents) "on" else "off"))
+ lifecycleEvents, akkaWriteTimeout))
val actorSystem = ActorSystem(name, akkaConf, getClass.getClassLoader)
@@ -60,8 +64,9 @@ private[spark] object AkkaUtils {
/**
* Creates a Spray HTTP server bound to a given IP and port with a given Spray Route object to
* handle requests. Returns the bound port or throws a SparkException on failure.
+ * TODO: Not changing ip to host here - is it required ?
*/
- def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route,
+ def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route,
name: String = "HttpServer"): ActorRef = {
val ioWorker = new IoWorker(actorSystem).start()
val httpService = actorSystem.actorOf(Props(new HttpService(route)))
diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala
index 188f8910da..92dfaa6e6f 100644
--- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala
+++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala
@@ -3,6 +3,7 @@ package spark.util
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConversions
import scala.collection.mutable.Map
+import spark.scheduler.MapStatus
/**
* This is a custom implementation of scala.collection.mutable.Map which stores the insertion
@@ -42,6 +43,13 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging {
this
}
+ // Should we return previous value directly or as Option ?
+ def putIfAbsent(key: A, value: B): Option[B] = {
+ val prev = internalMap.putIfAbsent(key, (value, currentTime))
+ if (prev != null) Some(prev._1) else None
+ }
+
+
override def -= (key: A): this.type = {
internalMap.remove(key)
this
diff --git a/core/src/main/scala/spark/util/TimedIterator.scala b/core/src/main/scala/spark/util/TimedIterator.scala
deleted file mode 100644
index 539b01f4ce..0000000000
--- a/core/src/main/scala/spark/util/TimedIterator.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-package spark.util
-
-/**
- * A utility for tracking the total time an iterator takes to iterate through its elements.
- *
- * In general, this should only be used if you expect it to take a considerable amount of time
- * (eg. milliseconds) to get each element -- otherwise, the timing won't be very accurate,
- * and you are probably just adding more overhead
- */
-class TimedIterator[+A](val sub: Iterator[A]) extends Iterator[A] {
- private var netMillis = 0l
- private var nElems = 0
- def hasNext = {
- val start = System.currentTimeMillis()
- val r = sub.hasNext
- val end = System.currentTimeMillis()
- netMillis += (end - start)
- r
- }
- def next = {
- val start = System.currentTimeMillis()
- val r = sub.next
- val end = System.currentTimeMillis()
- netMillis += (end - start)
- nElems += 1
- r
- }
-
- def getNetMillis = netMillis
- def getAverageTimePerItem = netMillis / nElems.toDouble
-
-}
diff --git a/core/src/main/twirl/spark/deploy/master/app_details.scala.html b/core/src/main/twirl/spark/deploy/master/app_details.scala.html
index 301a7e2124..66147e213f 100644
--- a/core/src/main/twirl/spark/deploy/master/app_details.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/app_details.scala.html
@@ -9,15 +9,12 @@
<li><strong>ID:</strong> @app.id</li>
<li><strong>Description:</strong> @app.desc.name</li>
<li><strong>User:</strong> @app.desc.user</li>
- <li><strong>Cores:</strong>
- @app.desc.cores
- (@app.coresGranted Granted
- @if(app.desc.cores == Integer.MAX_VALUE) {
-
+ <li><strong>Cores:</strong>
+ @if(app.desc.maxCores == Integer.MAX_VALUE) {
+ Unlimited (@app.coresGranted granted)
} else {
- , @app.coresLeft
+ @app.desc.maxCores (@app.coresGranted granted, @app.coresLeft left)
}
- )
</li>
<li><strong>Memory per Slave:</strong> @app.desc.memoryPerSlave</li>
<li><strong>Submit Date:</strong> @app.submitDate</li>
diff --git a/core/src/main/twirl/spark/deploy/master/executor_row.scala.html b/core/src/main/twirl/spark/deploy/master/executor_row.scala.html
index d2d80fad48..21e72c7aab 100644
--- a/core/src/main/twirl/spark/deploy/master/executor_row.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/executor_row.scala.html
@@ -3,7 +3,7 @@
<tr>
<td>@executor.id</td>
<td>
- <a href="@executor.worker.webUiAddress">@executor.worker.id</href>
+ <a href="@executor.worker.webUiAddress">@executor.worker.id</a>
</td>
<td>@executor.cores</td>
<td>@executor.memory</td>
diff --git a/core/src/main/twirl/spark/deploy/master/index.scala.html b/core/src/main/twirl/spark/deploy/master/index.scala.html
index ac51a39a51..b9b9f08810 100644
--- a/core/src/main/twirl/spark/deploy/master/index.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/index.scala.html
@@ -2,7 +2,7 @@
@import spark.deploy.master._
@import spark.Utils
-@spark.common.html.layout(title = "Spark Master on " + state.host) {
+@spark.common.html.layout(title = "Spark Master on " + state.host + ":" + state.port) {
<!-- Cluster Details -->
<div class="row">
diff --git a/core/src/main/twirl/spark/deploy/master/worker_row.scala.html b/core/src/main/twirl/spark/deploy/master/worker_row.scala.html
index be69e9bf02..46277ca421 100644
--- a/core/src/main/twirl/spark/deploy/master/worker_row.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/worker_row.scala.html
@@ -4,7 +4,7 @@
<tr>
<td>
- <a href="@worker.webUiAddress">@worker.id</href>
+ <a href="@worker.webUiAddress">@worker.id</a>
</td>
<td>@{worker.host}:@{worker.port}</td>
<td>@worker.state</td>
diff --git a/core/src/main/twirl/spark/deploy/worker/index.scala.html b/core/src/main/twirl/spark/deploy/worker/index.scala.html
index c39f769a73..0e66af9284 100644
--- a/core/src/main/twirl/spark/deploy/worker/index.scala.html
+++ b/core/src/main/twirl/spark/deploy/worker/index.scala.html
@@ -1,7 +1,7 @@
@(worker: spark.deploy.WorkerState)
@import spark.Utils
-@spark.common.html.layout(title = "Spark Worker on " + worker.host) {
+@spark.common.html.layout(title = "Spark Worker on " + worker.host + ":" + worker.port) {
<!-- Worker Details -->
<div class="row">
diff --git a/core/src/main/twirl/spark/storage/worker_table.scala.html b/core/src/main/twirl/spark/storage/worker_table.scala.html
index d54b8de4cc..cd72a688c1 100644
--- a/core/src/main/twirl/spark/storage/worker_table.scala.html
+++ b/core/src/main/twirl/spark/storage/worker_table.scala.html
@@ -12,7 +12,7 @@
<tbody>
@for(status <- workersStatusList) {
<tr>
- <td>@(status.blockManagerId.ip + ":" + status.blockManagerId.port)</td>
+ <td>@(status.blockManagerId.host + ":" + status.blockManagerId.port)</td>
<td>
@(Utils.memoryBytesToString(status.memUsed(prefix)))
(@(Utils.memoryBytesToString(status.memRemaining)) Total Available)
diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala
index 4104b33c8b..4df3bb5b67 100644
--- a/core/src/test/scala/spark/DistributedSuite.scala
+++ b/core/src/test/scala/spark/DistributedSuite.scala
@@ -3,8 +3,10 @@ package spark
import network.ConnectionManagerId
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
+import org.scalatest.concurrent.Timeouts._
import org.scalatest.matchers.ShouldMatchers
import org.scalatest.prop.Checkers
+import org.scalatest.time.{Span, Millis}
import org.scalacheck.Arbitrary._
import org.scalacheck.Gen
import org.scalacheck.Prop._
@@ -153,7 +155,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
val blockManager = SparkEnv.get.blockManager
blockManager.master.getLocations(blockId).foreach(id => {
val bytes = BlockManagerWorker.syncGetBlock(
- GetBlock(blockId), ConnectionManagerId(id.ip, id.port))
+ GetBlock(blockId), ConnectionManagerId(id.host, id.port))
val deserialized = blockManager.dataDeserialize(blockId, bytes).asInstanceOf[Iterator[Int]].toList
assert(deserialized === (1 to 100).toList)
})
@@ -252,12 +254,35 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
assert(data2.count === 2)
}
}
+
+ test("unpersist RDDs") {
+ DistributedSuite.amMaster = true
+ sc = new SparkContext("local-cluster[3,1,512]", "test")
+ val data = sc.parallelize(Seq(true, false, false, false), 4)
+ data.persist(StorageLevel.MEMORY_ONLY_2)
+ data.count
+ assert(sc.persistentRdds.isEmpty === false)
+ data.unpersist()
+ assert(sc.persistentRdds.isEmpty === true)
+
+ failAfter(Span(3000, Millis)) {
+ try {
+ while (! sc.getRDDStorageInfo.isEmpty) {
+ Thread.sleep(200)
+ }
+ } catch {
+ case _ => { Thread.sleep(10) }
+ // Do nothing. We might see exceptions because block manager
+ // is racing this thread to remove entries from the driver.
+ }
+ }
+ }
}
object DistributedSuite {
// Indicates whether this JVM is marked for failure.
var mark = false
-
+
// Set by test to remember if we are in the driver program so we can assert
// that we are not.
var amMaster = false
@@ -274,9 +299,9 @@ object DistributedSuite {
// Act like an identity function, but if mark was set to true previously, fail,
// crashing the entire JVM.
def failOnMarkedIdentity(item: Boolean): Boolean = {
- if (mark) {
+ if (mark) {
System.exit(42)
- }
+ }
item
- }
+ }
}
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index d3dcd3bbeb..93bb69b41c 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -633,6 +633,32 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void zipPartitions() {
+ JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2);
+ JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("1", "2", "3", "4"), 2);
+ FlatMapFunction2<Iterator<Integer>, Iterator<String>, Integer> sizesFn =
+ new FlatMapFunction2<Iterator<Integer>, Iterator<String>, Integer>() {
+ @Override
+ public Iterable<Integer> call(Iterator<Integer> i, Iterator<String> s) {
+ int sizeI = 0;
+ int sizeS = 0;
+ while (i.hasNext()) {
+ sizeI += 1;
+ i.next();
+ }
+ while (s.hasNext()) {
+ sizeS += 1;
+ s.next();
+ }
+ return Arrays.asList(sizeI, sizeS);
+ }
+ };
+
+ JavaRDD<Integer> sizes = rdd1.zipPartitions(sizesFn, rdd2);
+ Assert.assertEquals("[3, 2, 3, 2]", sizes.collect().toString());
+ }
+
+ @Test
public void accumulators() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
diff --git a/core/src/test/scala/spark/LocalSparkContext.scala b/core/src/test/scala/spark/LocalSparkContext.scala
index ff00dd05dd..76d5258b02 100644
--- a/core/src/test/scala/spark/LocalSparkContext.scala
+++ b/core/src/test/scala/spark/LocalSparkContext.scala
@@ -27,6 +27,7 @@ object LocalSparkContext {
sc.stop()
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
}
/** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */
@@ -38,4 +39,4 @@ object LocalSparkContext {
}
}
-} \ No newline at end of file
+}
diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
index 3abc584b6a..e95818db61 100644
--- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
@@ -81,6 +81,9 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("remote fetch") {
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", "localhost", 0)
+ System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
+ System.setProperty("spark.hostPort", "localhost:" + boundPort)
+
val masterTracker = new MapOutputTracker()
masterTracker.trackerActor = actorSystem.actorOf(
Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker")
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 7fbdd44340..cee6312572 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -2,6 +2,8 @@ package spark
import scala.collection.mutable.HashMap
import org.scalatest.FunSuite
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.time.{Span, Millis}
import spark.SparkContext._
import spark.rdd.{CoalescedRDD, CoGroupedRDD, PartitionPruningRDD, ShuffledRDD}
@@ -100,6 +102,28 @@ class RDDSuite extends FunSuite with LocalSparkContext {
assert(rdd.collect().toList === List(1, 2, 3, 4))
}
+ test("unpersist RDD") {
+ sc = new SparkContext("local", "test")
+ val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
+ rdd.count
+ assert(sc.persistentRdds.isEmpty == false)
+ rdd.unpersist()
+ assert(sc.persistentRdds.isEmpty == true)
+
+ failAfter(Span(3000, Millis)) {
+ try {
+ while (! sc.getRDDStorageInfo.isEmpty) {
+ Thread.sleep(200)
+ }
+ } catch {
+ case e: Exception =>
+ // Do nothing. We might see exceptions because block manager
+ // is racing this thread to remove entries from the driver.
+ }
+ }
+ assert(sc.getRDDStorageInfo.isEmpty == true)
+ }
+
test("caching with failures") {
sc = new SparkContext("local", "test")
val onlySplit = new Partition { override def index: Int = 0 }
diff --git a/core/src/test/scala/spark/ZippedPartitionsSuite.scala b/core/src/test/scala/spark/ZippedPartitionsSuite.scala
new file mode 100644
index 0000000000..5f60aa75d7
--- /dev/null
+++ b/core/src/test/scala/spark/ZippedPartitionsSuite.scala
@@ -0,0 +1,34 @@
+package spark
+
+import scala.collection.immutable.NumericRange
+
+import org.scalatest.FunSuite
+import org.scalatest.prop.Checkers
+import org.scalacheck.Arbitrary._
+import org.scalacheck.Gen
+import org.scalacheck.Prop._
+
+import SparkContext._
+
+
+object ZippedPartitionsSuite {
+ def procZippedData(i: Iterator[Int], s: Iterator[String], d: Iterator[Double]) : Iterator[Int] = {
+ Iterator(i.toArray.size, s.toArray.size, d.toArray.size)
+ }
+}
+
+class ZippedPartitionsSuite extends FunSuite with LocalSparkContext {
+ test("print sizes") {
+ sc = new SparkContext("local", "test")
+ val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2)
+ val data3 = sc.makeRDD(Array(1.0, 2.0), 2)
+
+ val zippedRDD = data1.zipPartitions(ZippedPartitionsSuite.procZippedData, data2, data3)
+
+ val obtainedSizes = zippedRDD.collect()
+ val expectedSizes = Array(2, 3, 1, 2, 3, 1)
+ assert(obtainedSizes.size == 6)
+ assert(obtainedSizes.zip(expectedSizes).forall(x => x._1 == x._2))
+ }
+}
diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
index 6da58a0f6e..c0f8986de8 100644
--- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
@@ -271,7 +271,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
// have the 2nd attempt pass
complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
// we can see both result blocks now
- assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.ip) === Array("hostA", "hostB"))
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB"))
complete(taskSets(3), Seq((Success, 43)))
assert(results === Map(0 -> 42, 1 -> 43))
}
diff --git a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
index 2f5af10e69..42a87d8b90 100644
--- a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
@@ -57,7 +57,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
taskMetrics.shuffleReadMetrics should be ('defined)
val sm = taskMetrics.shuffleReadMetrics.get
sm.totalBlocksFetched should be > (0)
- sm.shuffleReadMillis should be > (0l)
sm.localBlocksFetched should be > (0)
sm.remoteBlocksFetched should be (0)
sm.remoteBytesRead should be (0l)
diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
index b8c0f6fb76..9fe0de665c 100644
--- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
@@ -15,6 +15,8 @@ import org.scalatest.time.SpanSugar._
import spark.JavaSerializer
import spark.KryoSerializer
import spark.SizeEstimator
+import spark.Utils
+import spark.util.AkkaUtils
import spark.util.ByteBufferInputStream
class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester {
@@ -31,7 +33,11 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
val serializer = new KryoSerializer
before {
- actorSystem = ActorSystem("test")
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0)
+ this.actorSystem = actorSystem
+ System.setProperty("spark.driver.port", boundPort.toString)
+ System.setProperty("spark.hostPort", "localhost:" + boundPort)
+
master = new BlockManagerMaster(
actorSystem.actorOf(Props(new spark.storage.BlockManagerMasterActor(true))))
@@ -44,6 +50,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
after {
+ System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
+
if (store != null) {
store.stop()
store = null
@@ -198,6 +207,31 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
}
+ test("removing rdd") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
+ val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ val a3 = new Array[Byte](400)
+ // Putting a1, a2 and a3 in memory.
+ store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY)
+ store.putSingle("nonrddblock", a3, StorageLevel.MEMORY_ONLY)
+ master.removeRdd(0)
+
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("rdd_0_0") should be (None)
+ master.getLocations("rdd_0_0") should have size 0
+ }
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("rdd_0_1") should be (None)
+ master.getLocations("rdd_0_1") should have size 0
+ }
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("nonrddblock") should not be (None)
+ master.getLocations("nonrddblock") should have size (1)
+ }
+ }
+
test("reregistration on heart beat") {
val heartBeat = PrivateMethod[Unit]('heartBeat)
store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
diff --git a/docs/_config.yml b/docs/_config.yml
index f99d5bb376..5c135a0242 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -3,8 +3,8 @@ markdown: kramdown
# These allow the documentation to be updated with nerw releases
# of Spark, Scala, and Mesos.
-SPARK_VERSION: 0.7.1-SNAPSHOT
-SPARK_VERSION_SHORT: 0.7.1
-SCALA_VERSION: 2.9.2
+SPARK_VERSION: 0.8.0-SNAPSHOT
+SPARK_VERSION_SHORT: 0.8.0
+SCALA_VERSION: 2.9.3
MESOS_VERSION: 0.9.0-incubating
SPARK_ISSUE_TRACKER_URL: https://spark-project.atlassian.net
diff --git a/docs/building-with-maven.md b/docs/building-with-maven.md
index c2eeafd07a..04cd79d039 100644
--- a/docs/building-with-maven.md
+++ b/docs/building-with-maven.md
@@ -42,10 +42,10 @@ To run a specific test suite:
You might run into the following errors if you're using a vanilla installation of Maven:
- [INFO] Compiling 203 Scala sources and 9 Java sources to /Users/andyk/Development/spark/core/target/scala-2.9.2/classes...
+ [INFO] Compiling 203 Scala sources and 9 Java sources to /Users/me/Development/spark/core/target/scala-{{site.SCALA_VERSION}}/classes...
[ERROR] PermGen space -> [Help 1]
- [INFO] Compiling 203 Scala sources and 9 Java sources to /Users/andyk/Development/spark/core/target/scala-2.9.2/classes...
+ [INFO] Compiling 203 Scala sources and 9 Java sources to /Users/me/Development/spark/core/target/scala-{{site.SCALA_VERSION}}/classes...
[ERROR] Java heap space -> [Help 1]
To fix these, you can do the following:
diff --git a/docs/index.md b/docs/index.md
index 51d505e1fa..0c4add45dc 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -18,7 +18,7 @@ or you will need to set the `SCALA_HOME` environment variable to point
to where you've installed Scala. Scala must also be accessible through one
of these methods on slave nodes on your cluster.
-Spark uses [Simple Build Tool](https://github.com/harrah/xsbt/wiki), which is bundled with it. To compile the code, go into the top-level Spark directory and run
+Spark uses [Simple Build Tool](http://www.scala-sbt.org), which is bundled with it. To compile the code, go into the top-level Spark directory and run
sbt/sbt package
diff --git a/docs/quick-start.md b/docs/quick-start.md
index 5c80d2ed3a..335643536a 100644
--- a/docs/quick-start.md
+++ b/docs/quick-start.md
@@ -53,8 +53,8 @@ scala> textFile.filter(line => line.contains("Spark")).count() // How many lines
res3: Long = 15
{% endhighlight %}
-## Transformations
-RDD transformations can be used for more complex computations. Let's say we want to find the line with the most words:
+## More On RDD Operations
+RDD actions and transformations can be used for more complex computations. Let's say we want to find the line with the most words:
{% highlight scala %}
scala> textFile.map(line => line.split(" ").size).reduce((a, b) => if (a > b) a else b)
@@ -113,8 +113,8 @@ import SparkContext._
object SimpleJob {
def main(args: Array[String]) {
- val logFile = "/var/log/syslog" // Should be some file on your system
- val sc = new SparkContext("local", "Simple Job", "$YOUR_SPARK_HOME",
+ val logFile = "$YOUR_SPARK_HOME/README.md" // Should be some file on your system
+ val sc = new SparkContext("local", "Simple Job", "YOUR_SPARK_HOME",
List("target/scala-{{site.SCALA_VERSION}}/simple-project_{{site.SCALA_VERSION}}-1.0.jar"))
val logData = sc.textFile(logFile, 2).cache()
val numAs = logData.filter(line => line.contains("a")).count()
@@ -124,7 +124,7 @@ object SimpleJob {
}
{% endhighlight %}
-This job simply counts the number of lines containing 'a' and the number containing 'b' in a system log file. Unlike the earlier examples with the Spark shell, which initializes its own SparkContext, we initialize a SparkContext as part of the job. We pass the SparkContext constructor four arguments, the type of scheduler we want to use (in this case, a local scheduler), a name for the job, the directory where Spark is installed, and a name for the jar file containing the job's sources. The final two arguments are needed in a distributed setting, where Spark is running across several nodes, so we include them for completeness. Spark will automatically ship the jar files you list to slave nodes.
+This job simply counts the number of lines containing 'a' and the number containing 'b' in the Spark README. Note that you'll need to replace $YOUR_SPARK_HOME with the location where Spark is installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkContext, we initialize a SparkContext as part of the job. We pass the SparkContext constructor four arguments, the type of scheduler we want to use (in this case, a local scheduler), a name for the job, the directory where Spark is installed, and a name for the jar file containing the job's sources. The final two arguments are needed in a distributed setting, where Spark is running across several nodes, so we include them for completeness. Spark will automatically ship the jar files you list to slave nodes.
This file depends on the Spark API, so we'll also include an sbt configuration file, `simple.sbt` which explains that Spark is a dependency. This file also adds two repositories which host Spark dependencies:
@@ -156,7 +156,7 @@ $ find .
$ sbt package
$ sbt run
...
-Lines with a: 8422, Lines with b: 1836
+Lines with a: 46, Lines with b: 23
{% endhighlight %}
This example only runs the job locally; for a tutorial on running jobs across several machines, see the [Standalone Mode](spark-standalone.html) documentation, and consider using a distributed input source, such as HDFS.
@@ -173,7 +173,7 @@ import spark.api.java.function.Function;
public class SimpleJob {
public static void main(String[] args) {
- String logFile = "/var/log/syslog"; // Should be some file on your system
+ String logFile = "$YOUR_SPARK_HOME/README.md"; // Should be some file on your system
JavaSparkContext sc = new JavaSparkContext("local", "Simple Job",
"$YOUR_SPARK_HOME", new String[]{"target/simple-project-1.0.jar"});
JavaRDD<String> logData = sc.textFile(logFile).cache();
@@ -191,7 +191,7 @@ public class SimpleJob {
}
{% endhighlight %}
-This job simply counts the number of lines containing 'a' and the number containing 'b' in a system log file. Note that like in the Scala example, we initialize a SparkContext, though we use the special `JavaSparkContext` class to get a Java-friendly one. We also create RDDs (represented by `JavaRDD`) and run transformations on them. Finally, we pass functions to Spark by creating classes that extend `spark.api.java.function.Function`. The [Java programming guide](java-programming-guide.html) describes these differences in more detail.
+This job simply counts the number of lines containing 'a' and the number containing 'b' in a system log file. Note that you'll need to replace $YOUR_SPARK_HOME with the location where Spark is installed. As with the Scala example, we initialize a SparkContext, though we use the special `JavaSparkContext` class to get a Java-friendly one. We also create RDDs (represented by `JavaRDD`) and run transformations on them. Finally, we pass functions to Spark by creating classes that extend `spark.api.java.function.Function`. The [Java programming guide](java-programming-guide.html) describes these differences in more detail.
To build the job, we also write a Maven `pom.xml` file that lists Spark as a dependency. Note that Spark artifacts are tagged with a Scala version.
@@ -239,7 +239,7 @@ Now, we can execute the job using Maven:
$ mvn package
$ mvn exec:java -Dexec.mainClass="SimpleJob"
...
-Lines with a: 8422, Lines with b: 1836
+Lines with a: 46, Lines with b: 23
{% endhighlight %}
This example only runs the job locally; for a tutorial on running jobs across several machines, see the [Standalone Mode](spark-standalone.html) documentation, and consider using a distributed input source, such as HDFS.
@@ -253,7 +253,7 @@ As an example, we'll create a simple Spark job, `SimpleJob.py`:
"""SimpleJob.py"""
from pyspark import SparkContext
-logFile = "/var/log/syslog" # Should be some file on your system
+logFile = "$YOUR_SPARK_HOME/README.md" # Should be some file on your system
sc = SparkContext("local", "Simple job")
logData = sc.textFile(logFile).cache()
@@ -265,7 +265,8 @@ print "Lines with a: %i, lines with b: %i" % (numAs, numBs)
This job simply counts the number of lines containing 'a' and the number containing 'b' in a system log file.
-Like in the Scala and Java examples, we use a SparkContext to create RDDs.
+Note that you'll need to replace $YOUR_SPARK_HOME with the location where Spark is installed.
+As with the Scala and Java examples, we use a SparkContext to create RDDs.
We can pass Python functions to Spark, which are automatically serialized along with any variables that they reference.
For jobs that use custom classes or third-party libraries, we can add those code dependencies to SparkContext to ensure that they will be available on remote machines; this is described in more detail in the [Python programming guide](python-programming-guide.html).
`SimpleJob` is simple enough that we do not need to specify any code dependencies.
@@ -276,7 +277,7 @@ We can run this job using the `pyspark` script:
$ cd $SPARK_HOME
$ ./pyspark SimpleJob.py
...
-Lines with a: 8422, Lines with b: 1836
+Lines with a: 46, Lines with b: 23
{% endhighlight python %}
This example only runs the job locally; for a tutorial on running jobs across several machines, see the [Standalone Mode](spark-standalone.html) documentation, and consider using a distributed input source, such as HDFS.
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index c2957e6cb4..26424bbe52 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -5,18 +5,25 @@ title: Launching Spark on YARN
Experimental support for running over a [YARN (Hadoop
NextGen)](http://hadoop.apache.org/docs/r2.0.2-alpha/hadoop-yarn/hadoop-yarn-site/YARN.html)
-cluster was added to Spark in version 0.6.0. Because YARN depends on version
-2.0 of the Hadoop libraries, this currently requires checking out a separate
-branch of Spark, called `yarn`, which you can do as follows:
+cluster was added to Spark in version 0.6.0. This was merged into master as part of 0.7 effort.
+To build spark core with YARN support, please use the hadoop2-yarn profile.
+Ex: mvn -Phadoop2-yarn clean install
- git clone git://github.com/mesos/spark
- cd spark
- git checkout -b yarn --track origin/yarn
+# Building spark core consolidated jar.
+
+Currently, only sbt can buid a consolidated jar which contains the entire spark code - which is required for launching jars on yarn.
+To do this via sbt - though (right now) is a manual process of enabling it in project/SparkBuild.scala.
+Please comment out the
+ HADOOP_VERSION, HADOOP_MAJOR_VERSION and HADOOP_YARN
+variables before the line 'For Hadoop 2 YARN support'
+Next, uncomment the subsequent 3 variable declaration lines (for these three variables) which enable hadoop yarn support.
+
+Currnetly, it is a TODO to add support for maven assembly.
# Preparations
-- In order to distribute Spark within the cluster, it must be packaged into a single JAR file. This can be done by running `sbt/sbt assembly`
+- Building spark core assembled jar (see above).
- Your application code must be packaged into a separate JAR file.
If you want to test out the YARN deployment mode, you can use the current Spark examples. A `spark-examples_{{site.SCALA_VERSION}}-{{site.SPARK_VERSION}}` file can be generated by running `sbt/sbt package`. NOTE: since the documentation you're reading is for Spark version {{site.SPARK_VERSION}}, we are assuming here that you have downloaded Spark {{site.SPARK_VERSION}} or checked it out of source control. If you are using a different version of Spark, the version numbers in the jar generated by the sbt package command will obviously be different.
@@ -30,8 +37,11 @@ The command to launch the YARN Client is as follows:
--class <APP_MAIN_CLASS> \
--args <APP_MAIN_ARGUMENTS> \
--num-workers <NUMBER_OF_WORKER_MACHINES> \
+ --master-memory <MEMORY_FOR_MASTER> \
--worker-memory <MEMORY_PER_WORKER> \
- --worker-cores <CORES_PER_WORKER>
+ --worker-cores <CORES_PER_WORKER> \
+ --user <hadoop_user> \
+ --queue <queue_name>
For example:
@@ -40,8 +50,9 @@ For example:
--class spark.examples.SparkPi \
--args standalone \
--num-workers 3 \
+ --master-memory 4g \
--worker-memory 2g \
- --worker-cores 2
+ --worker-cores 1
The above starts a YARN Client programs which periodically polls the Application Master for status updates and displays them in the console. The client will exit once your application has finished running.
@@ -49,3 +60,5 @@ The above starts a YARN Client programs which periodically polls the Application
- When your application instantiates a Spark context it must use a special "standalone" master url. This starts the scheduler without forcing it to connect to a cluster. A good way to handle this is to pass "standalone" as an argument to your program, as shown in the example above.
- YARN does not support requesting container resources based on the number of cores. Thus the numbers of cores given via command line arguments cannot be guaranteed.
+- Currently, we have not yet integrated with hadoop security. If --user is present, the hadoop_user specified will be used to run the tasks on the cluster. If unspecified, current user will be used (which should be valid in cluster).
+ Once hadoop security support is added, and if hadoop cluster is enabled with security, additional restrictions would apply via delegation tokens passed.
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index b30699cf3d..f5788dc467 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -83,7 +83,7 @@ DStreams support many of the transformations available on normal Spark RDD's:
<tr>
<td> <b>groupByKey</b>([<i>numTasks</i>]) </td>
<td> When called on a DStream of (K, V) pairs, returns a new DStream of (K, Seq[V]) pairs by grouping together all the values of each key in the RDDs of the source DStream. <br />
- <b>Note:</b> By default, this uses Spark's default number of parallel tasks (2 for local machine, 8 for a cluser) to do the grouping. You can pass an optional <code>numTasks</code> argument to set a different number of tasks.
+ <b>Note:</b> By default, this uses Spark's default number of parallel tasks (2 for local machine, 8 for a cluster) to do the grouping. You can pass an optional <code>numTasks</code> argument to set a different number of tasks.
</td>
</tr>
<tr>
@@ -132,7 +132,7 @@ Spark Streaming features windowed computations, which allow you to apply transfo
<td> <b>groupByKeyAndWindow</b>(<i>windowDuration</i>, <i>slideDuration</i>, [<i>numTasks</i>])
</td>
<td> When called on a DStream of (K, V) pairs, returns a new DStream of (K, Seq[V]) pairs by grouping together values of each key over batches in a sliding window. <br />
-<b>Note:</b> By default, this uses Spark's default number of parallel tasks (2 for local machine, 8 for a cluser) to do the grouping. You can pass an optional <code>numTasks</code> argument to set a different number of tasks.</td>
+<b>Note:</b> By default, this uses Spark's default number of parallel tasks (2 for local machine, 8 for a cluster) to do the grouping. You can pass an optional <code>numTasks</code> argument to set a different number of tasks.</td>
</tr>
<tr>
<td> <b>reduceByKeyAndWindow</b>(<i>func</i>, <i>windowDuration</i>, <i>slideDuration</i>, [<i>numTasks</i>]) </td>
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 571d27fde6..9f2daad2b6 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -540,11 +540,24 @@ def scp(host, opts, local_file, dest_file):
(opts.identity_file, local_file, opts.user, host, dest_file), shell=True)
-# Run a command on a host through ssh, throwing an exception if ssh fails
+# Run a command on a host through ssh, retrying up to two times
+# and then throwing an exception if ssh continues to fail.
def ssh(host, opts, command):
- subprocess.check_call(
- "ssh -t -o StrictHostKeyChecking=no -i %s %s@%s '%s'" %
- (opts.identity_file, opts.user, host, command), shell=True)
+ tries = 0
+ while True:
+ try:
+ return subprocess.check_call(
+ "ssh -t -o StrictHostKeyChecking=no -i %s %s@%s '%s'" %
+ (opts.identity_file, opts.user, host, command), shell=True)
+ except subprocess.CalledProcessError as e:
+ if (tries > 2):
+ raise e
+ print "Error connecting to host {0}, sleeping 30".format(e)
+ time.sleep(30)
+ tries = tries + 1
+
+
+
# Gets a list of zones to launch instances in
diff --git a/examples/pom.xml b/examples/pom.xml
index 39cc47c709..c42d2bcdb9 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -4,7 +4,7 @@
<parent>
<groupId>org.spark-project</groupId>
<artifactId>spark-parent</artifactId>
- <version>0.7.1-SNAPSHOT</version>
+ <version>0.8.0-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
@@ -22,7 +22,7 @@
<dependency>
<groupId>com.twitter</groupId>
<artifactId>algebird-core_2.9.2</artifactId>
- <version>0.1.8</version>
+ <version>0.1.11</version>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
@@ -118,5 +118,48 @@
</plugins>
</build>
</profile>
+ <profile>
+ <id>hadoop2-yarn</id>
+ <dependencies>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-core</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2-yarn</classifier>
+ </dependency>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-streaming</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2-yarn</classifier>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-client</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-api</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-common</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <configuration>
+ <classifier>hadoop2-yarn</classifier>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
</profiles>
</project>
diff --git a/examples/src/main/scala/spark/examples/LocalKMeans.scala b/examples/src/main/scala/spark/examples/LocalKMeans.scala
index b07e799cef..4849f216fb 100644
--- a/examples/src/main/scala/spark/examples/LocalKMeans.scala
+++ b/examples/src/main/scala/spark/examples/LocalKMeans.scala
@@ -10,73 +10,73 @@ import scala.collection.mutable.HashSet
* K-means clustering.
*/
object LocalKMeans {
- val N = 1000
- val R = 1000 // Scaling factor
- val D = 10
- val K = 10
- val convergeDist = 0.001
- val rand = new Random(42)
-
- def generateData = {
- def generatePoint(i: Int) = {
- Vector(D, _ => rand.nextDouble * R)
- }
- Array.tabulate(N)(generatePoint)
- }
-
- def closestPoint(p: Vector, centers: HashMap[Int, Vector]): Int = {
- var index = 0
- var bestIndex = 0
- var closest = Double.PositiveInfinity
-
- for (i <- 1 to centers.size) {
- val vCurr = centers.get(i).get
- val tempDist = p.squaredDist(vCurr)
- if (tempDist < closest) {
- closest = tempDist
- bestIndex = i
- }
- }
-
- return bestIndex
- }
-
- def main(args: Array[String]) {
- val data = generateData
- var points = new HashSet[Vector]
- var kPoints = new HashMap[Int, Vector]
- var tempDist = 1.0
-
- while (points.size < K) {
- points.add(data(rand.nextInt(N)))
- }
-
- val iter = points.iterator
- for (i <- 1 to points.size) {
- kPoints.put(i, iter.next())
- }
-
- println("Initial centers: " + kPoints)
-
- while(tempDist > convergeDist) {
- var closest = data.map (p => (closestPoint(p, kPoints), (p, 1)))
-
- var mappings = closest.groupBy[Int] (x => x._1)
-
- var pointStats = mappings.map(pair => pair._2.reduceLeft [(Int, (Vector, Int))] {case ((id1, (x1, y1)), (id2, (x2, y2))) => (id1, (x1 + x2, y1+y2))})
-
- var newPoints = pointStats.map {mapping => (mapping._1, mapping._2._1/mapping._2._2)}
-
- tempDist = 0.0
- for (mapping <- newPoints) {
- tempDist += kPoints.get(mapping._1).get.squaredDist(mapping._2)
- }
-
- for (newP <- newPoints) {
- kPoints.put(newP._1, newP._2)
- }
- }
-
- println("Final centers: " + kPoints)
- }
+ val N = 1000
+ val R = 1000 // Scaling factor
+ val D = 10
+ val K = 10
+ val convergeDist = 0.001
+ val rand = new Random(42)
+
+ def generateData = {
+ def generatePoint(i: Int) = {
+ Vector(D, _ => rand.nextDouble * R)
+ }
+ Array.tabulate(N)(generatePoint)
+ }
+
+ def closestPoint(p: Vector, centers: HashMap[Int, Vector]): Int = {
+ var index = 0
+ var bestIndex = 0
+ var closest = Double.PositiveInfinity
+
+ for (i <- 1 to centers.size) {
+ val vCurr = centers.get(i).get
+ val tempDist = p.squaredDist(vCurr)
+ if (tempDist < closest) {
+ closest = tempDist
+ bestIndex = i
+ }
+ }
+
+ return bestIndex
+ }
+
+ def main(args: Array[String]) {
+ val data = generateData
+ var points = new HashSet[Vector]
+ var kPoints = new HashMap[Int, Vector]
+ var tempDist = 1.0
+
+ while (points.size < K) {
+ points.add(data(rand.nextInt(N)))
+ }
+
+ val iter = points.iterator
+ for (i <- 1 to points.size) {
+ kPoints.put(i, iter.next())
+ }
+
+ println("Initial centers: " + kPoints)
+
+ while(tempDist > convergeDist) {
+ var closest = data.map (p => (closestPoint(p, kPoints), (p, 1)))
+
+ var mappings = closest.groupBy[Int] (x => x._1)
+
+ var pointStats = mappings.map(pair => pair._2.reduceLeft [(Int, (Vector, Int))] {case ((id1, (x1, y1)), (id2, (x2, y2))) => (id1, (x1 + x2, y1+y2))})
+
+ var newPoints = pointStats.map {mapping => (mapping._1, mapping._2._1/mapping._2._2)}
+
+ tempDist = 0.0
+ for (mapping <- newPoints) {
+ tempDist += kPoints.get(mapping._1).get.squaredDist(mapping._2)
+ }
+
+ for (newP <- newPoints) {
+ kPoints.put(newP._1, newP._2)
+ }
+ }
+
+ println("Final centers: " + kPoints)
+ }
}
diff --git a/examples/src/main/scala/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/spark/examples/MultiBroadcastTest.scala
index 92cd81c487..a0aaf60918 100644
--- a/examples/src/main/scala/spark/examples/MultiBroadcastTest.scala
+++ b/examples/src/main/scala/spark/examples/MultiBroadcastTest.scala
@@ -8,7 +8,7 @@ object MultiBroadcastTest {
System.err.println("Usage: BroadcastTest <master> [<slices>] [numElem]")
System.exit(1)
}
-
+
val sc = new SparkContext(args(0), "Broadcast Test",
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
@@ -19,7 +19,7 @@ object MultiBroadcastTest {
for (i <- 0 until arr1.length) {
arr1(i) = i
}
-
+
var arr2 = new Array[Int](num)
for (i <- 0 until arr2.length) {
arr2(i) = i
@@ -30,7 +30,7 @@ object MultiBroadcastTest {
sc.parallelize(1 to 10, slices).foreach {
i => println(barr1.value.size + barr2.value.size)
}
-
+
System.exit(0)
}
}
diff --git a/examples/src/main/scala/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/spark/examples/SimpleSkewedGroupByTest.scala
index 0d17bda004..461b84a2c6 100644
--- a/examples/src/main/scala/spark/examples/SimpleSkewedGroupByTest.scala
+++ b/examples/src/main/scala/spark/examples/SimpleSkewedGroupByTest.scala
@@ -11,7 +11,7 @@ object SimpleSkewedGroupByTest {
"[numMappers] [numKVPairs] [valSize] [numReducers] [ratio]")
System.exit(1)
}
-
+
var numMappers = if (args.length > 1) args(1).toInt else 2
var numKVPairs = if (args.length > 2) args(2).toInt else 1000
var valSize = if (args.length > 3) args(3).toInt else 1000
@@ -20,7 +20,7 @@ object SimpleSkewedGroupByTest {
val sc = new SparkContext(args(0), "GroupBy Test",
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
-
+
val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
val ranGen = new Random
var result = new Array[(Int, Array[Byte])](numKVPairs)
diff --git a/examples/src/main/scala/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/spark/examples/SkewedGroupByTest.scala
index 83be3fc27b..435675f9de 100644
--- a/examples/src/main/scala/spark/examples/SkewedGroupByTest.scala
+++ b/examples/src/main/scala/spark/examples/SkewedGroupByTest.scala
@@ -10,7 +10,7 @@ object SkewedGroupByTest {
System.err.println("Usage: GroupByTest <master> [numMappers] [numKVPairs] [KeySize] [numReducers]")
System.exit(1)
}
-
+
var numMappers = if (args.length > 1) args(1).toInt else 2
var numKVPairs = if (args.length > 2) args(2).toInt else 1000
var valSize = if (args.length > 3) args(3).toInt else 1000
@@ -18,7 +18,7 @@ object SkewedGroupByTest {
val sc = new SparkContext(args(0), "GroupBy Test",
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
-
+
val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p =>
val ranGen = new Random
diff --git a/examples/src/main/scala/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/spark/examples/SparkHdfsLR.scala
index 0f42f405a0..3d080a0257 100644
--- a/examples/src/main/scala/spark/examples/SparkHdfsLR.scala
+++ b/examples/src/main/scala/spark/examples/SparkHdfsLR.scala
@@ -4,6 +4,8 @@ import java.util.Random
import scala.math.exp
import spark.util.Vector
import spark._
+import spark.deploy.SparkHadoopUtil
+import spark.scheduler.InputFormatInfo
/**
* Logistic regression based classification.
@@ -32,9 +34,13 @@ object SparkHdfsLR {
System.err.println("Usage: SparkHdfsLR <master> <file> <iters>")
System.exit(1)
}
+ val inputPath = args(1)
+ val conf = SparkHadoopUtil.newConfiguration()
val sc = new SparkContext(args(0), "SparkHdfsLR",
- System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
- val lines = sc.textFile(args(1))
+ System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")), Map(),
+ InputFormatInfo.computePreferredLocations(
+ Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath))))
+ val lines = sc.textFile(inputPath)
val points = lines.map(parsePoint _).cache()
val ITERATIONS = args(2).toInt
diff --git a/pom.xml b/pom.xml
index 08d1fc12e0..3936165d78 100644
--- a/pom.xml
+++ b/pom.xml
@@ -3,7 +3,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>org.spark-project</groupId>
<artifactId>spark-parent</artifactId>
- <version>0.7.1-SNAPSHOT</version>
+ <version>0.8.0-SNAPSHOT</version>
<packaging>pom</packaging>
<name>Spark Project Parent POM</name>
<url>http://spark-project.org/</url>
@@ -51,7 +51,7 @@
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<java.version>1.5</java.version>
- <scala.version>2.9.2</scala.version>
+ <scala.version>2.9.3</scala.version>
<mesos.version>0.9.0-incubating</mesos.version>
<akka.version>2.0.3</akka.version>
<spray.version>1.0-M2.1</spray.version>
@@ -238,7 +238,7 @@
</dependency>
<dependency>
<groupId>cc.spray</groupId>
- <artifactId>spray-json_${scala.version}</artifactId>
+ <artifactId>spray-json_2.9.2</artifactId>
<version>${spray.json.version}</version>
</dependency>
<dependency>
@@ -248,7 +248,7 @@
</dependency>
<dependency>
<groupId>com.github.scala-incubator.io</groupId>
- <artifactId>scala-io-file_${scala.version}</artifactId>
+ <artifactId>scala-io-file_2.9.2</artifactId>
<version>0.4.1</version>
</dependency>
<dependency>
@@ -277,7 +277,7 @@
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.version}</artifactId>
- <version>1.8</version>
+ <version>1.9.1</version>
<scope>test</scope>
</dependency>
<dependency>
@@ -289,7 +289,7 @@
<dependency>
<groupId>org.scalacheck</groupId>
<artifactId>scalacheck_${scala.version}</artifactId>
- <version>1.9</version>
+ <version>1.10.0</version>
<scope>test</scope>
</dependency>
<dependency>
@@ -417,8 +417,9 @@
<configuration>
<reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
<junitxml>.</junitxml>
- <filereports>WDF TestSuite.txt</filereports>
+ <filereports>${project.build.directory}/SparkTestSuite.txt</filereports>
<argLine>-Xms64m -Xmx1024m</argLine>
+ <stderr/>
</configuration>
<executions>
<execution>
@@ -512,7 +513,6 @@
<profiles>
<profile>
<id>hadoop1</id>
-
<properties>
<hadoop.major.version>1</hadoop.major.version>
</properties>
@@ -558,5 +558,66 @@
</dependencies>
</dependencyManagement>
</profile>
+
+ <profile>
+ <id>hadoop2-yarn</id>
+ <properties>
+ <hadoop.major.version>2</hadoop.major.version>
+ <!-- 0.23.* is same as 2.0.* - except hardened to run production jobs -->
+ <!-- <yarn.version>0.23.7</yarn.version> -->
+ <yarn.version>2.0.2-alpha</yarn.version>
+ </properties>
+
+ <repositories>
+ <repository>
+ <id>maven-root</id>
+ <name>Maven root repository</name>
+ <url>http://repo1.maven.org/maven2/</url>
+ <releases>
+ <enabled>true</enabled>
+ </releases>
+ <snapshots>
+ <enabled>false</enabled>
+ </snapshots>
+ </repository>
+ </repositories>
+
+ <dependencyManagement>
+ <dependencies>
+ <!-- TODO: check versions, bringover from yarn branch ! -->
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-client</artifactId>
+ <version>${yarn.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-api</artifactId>
+ <version>${yarn.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-common</artifactId>
+ <version>${yarn.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-client</artifactId>
+ <version>${yarn.version}</version>
+ </dependency>
+ <!-- Specify Avro version because Kafka also has it as a dependency -->
+ <dependency>
+ <groupId>org.apache.avro</groupId>
+ <artifactId>avro</artifactId>
+ <version>1.7.1.cloudera.2</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.avro</groupId>
+ <artifactId>avro-ipc</artifactId>
+ <version>1.7.1.cloudera.2</version>
+ </dependency>
+ </dependencies>
+ </dependencyManagement>
+ </profile>
</profiles>
</project>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index e3645653ee..dbfe5b0aa6 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -1,3 +1,4 @@
+
import sbt._
import sbt.Classpaths.publishTask
import Keys._
@@ -12,16 +13,23 @@ object SparkBuild extends Build {
// "1.0.4" for Apache releases, or "0.20.2-cdh3u5" for Cloudera Hadoop.
val HADOOP_VERSION = "1.0.4"
val HADOOP_MAJOR_VERSION = "1"
+ val HADOOP_YARN = false
// For Hadoop 2 versions such as "2.0.0-mr1-cdh4.1.1", set the HADOOP_MAJOR_VERSION to "2"
//val HADOOP_VERSION = "2.0.0-mr1-cdh4.1.1"
//val HADOOP_MAJOR_VERSION = "2"
+ //val HADOOP_YARN = false
+
+ // For Hadoop 2 YARN support
+ //val HADOOP_VERSION = "2.0.2-alpha"
+ //val HADOOP_MAJOR_VERSION = "2"
+ //val HADOOP_YARN = true
lazy val root = Project("root", file("."), settings = rootSettings) aggregate(core, repl, examples, bagel, streaming)
lazy val core = Project("core", file("core"), settings = coreSettings)
- lazy val repl = Project("repl", file("repl"), settings = replSettings) dependsOn (core) dependsOn (streaming)
+ lazy val repl = Project("repl", file("repl"), settings = replSettings) dependsOn (core)
lazy val examples = Project("examples", file("examples"), settings = examplesSettings) dependsOn (core) dependsOn (streaming)
@@ -35,8 +43,8 @@ object SparkBuild extends Build {
def sharedSettings = Defaults.defaultSettings ++ Seq(
organization := "org.spark-project",
- version := "0.7.1-SNAPSHOT",
- scalaVersion := "2.9.2",
+ version := "0.8.0-SNAPSHOT",
+ scalaVersion := "2.9.3",
scalacOptions := Seq("-unchecked", "-optimize", "-deprecation"),
unmanagedJars in Compile <<= baseDirectory map { base => (base / "lib" ** "*.jar").classpath },
retrieveManaged := true,
@@ -44,7 +52,14 @@ object SparkBuild extends Build {
transitiveClassifiers in Scope.GlobalScope := Seq("sources"),
testListeners <<= target.map(t => Seq(new eu.henkelmann.sbt.JUnitXmlTestsListener(t.getAbsolutePath))),
- // shared between both core and streaming.
+ // Fork new JVMs for tests and set Java options for those
+ fork := true,
+ javaOptions += "-Xmx2g",
+
+ // Only allow one test at a time, even across projects, since they run in the same JVM
+ concurrentRestrictions in Global += Tags.limit(Tags.Test, 1),
+
+ // Shared between both core and streaming.
resolvers ++= Seq("Akka Repository" at "http://repo.akka.io/releases/"),
// For Sonatype publishing
@@ -92,13 +107,13 @@ object SparkBuild extends Build {
*/
libraryDependencies ++= Seq(
- "org.eclipse.jetty" % "jetty-server" % "7.5.3.v20111011",
- "org.scalatest" %% "scalatest" % "1.8" % "test",
- "org.scalacheck" %% "scalacheck" % "1.9" % "test",
- "com.novocode" % "junit-interface" % "0.8" % "test",
+ "io.netty" % "netty" % "3.5.3.Final",
+ "org.eclipse.jetty" % "jetty-server" % "7.6.8.v20121106",
+ "org.scalatest" %% "scalatest" % "1.9.1" % "test",
+ "org.scalacheck" %% "scalacheck" % "1.10.0" % "test",
+ "com.novocode" % "junit-interface" % "0.9" % "test",
"org.easymock" % "easymock" % "3.1" % "test"
),
- parallelExecution := false,
/* Workaround for issue #206 (fixed after SBT 0.11.0) */
watchTransitiveSources <<= Defaults.inDependencies[Task[Seq[File]]](watchSources.task,
const(std.TaskExtra.constant(Nil)), aggregate = true, includeRoot = true) apply { _.join.map(_.flatten) },
@@ -114,6 +129,9 @@ object SparkBuild extends Build {
val slf4jVersion = "1.6.1"
+ val excludeJackson = ExclusionRule(organization = "org.codehaus.jackson")
+ val excludeNetty = ExclusionRule(organization = "org.jboss.netty")
+
def coreSettings = sharedSettings ++ Seq(
name := "spark-core",
resolvers ++= Seq(
@@ -124,27 +142,52 @@ object SparkBuild extends Build {
),
libraryDependencies ++= Seq(
+ "io.netty" % "netty" % "3.5.3.Final",
"com.google.guava" % "guava" % "11.0.1",
"log4j" % "log4j" % "1.2.16",
"org.slf4j" % "slf4j-api" % slf4jVersion,
"org.slf4j" % "slf4j-log4j12" % slf4jVersion,
+ "commons-daemon" % "commons-daemon" % "1.0.10",
"com.ning" % "compress-lzf" % "0.8.4",
- "org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION,
"asm" % "asm-all" % "3.3.1",
"com.google.protobuf" % "protobuf-java" % "2.4.1",
"de.javakaffee" % "kryo-serializers" % "0.22",
- "com.typesafe.akka" % "akka-actor" % "2.0.3",
- "com.typesafe.akka" % "akka-remote" % "2.0.3",
- "com.typesafe.akka" % "akka-slf4j" % "2.0.3",
+ "com.typesafe.akka" % "akka-actor" % "2.0.3" excludeAll(excludeNetty),
+ "com.typesafe.akka" % "akka-remote" % "2.0.3" excludeAll(excludeNetty),
+ "com.typesafe.akka" % "akka-slf4j" % "2.0.3" excludeAll(excludeNetty),
"it.unimi.dsi" % "fastutil" % "6.4.4",
"colt" % "colt" % "1.2.0",
- "cc.spray" % "spray-can" % "1.0-M2.1",
- "cc.spray" % "spray-server" % "1.0-M2.1",
- "cc.spray" %% "spray-json" % "1.1.1",
+ "cc.spray" % "spray-can" % "1.0-M2.1" excludeAll(excludeNetty),
+ "cc.spray" % "spray-server" % "1.0-M2.1" excludeAll(excludeNetty),
+ "cc.spray" % "spray-json_2.9.2" % "1.1.1" excludeAll(excludeNetty),
"org.apache.mesos" % "mesos" % "0.9.0-incubating",
"io.netty" % "netty-all" % "4.0.0.Beta2"
- ) ++ (if (HADOOP_MAJOR_VERSION == "2") Some("org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION) else None).toSeq,
- unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / ("src/hadoop" + HADOOP_MAJOR_VERSION + "/scala") }
+ ) ++ (
+ if (HADOOP_MAJOR_VERSION == "2") {
+ if (HADOOP_YARN) {
+ Seq(
+ // Exclude rule required for all ?
+ "org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION excludeAll(excludeJackson, excludeNetty),
+ "org.apache.hadoop" % "hadoop-yarn-api" % HADOOP_VERSION excludeAll(excludeJackson, excludeNetty),
+ "org.apache.hadoop" % "hadoop-yarn-common" % HADOOP_VERSION excludeAll(excludeJackson, excludeNetty),
+ "org.apache.hadoop" % "hadoop-yarn-client" % HADOOP_VERSION excludeAll(excludeJackson, excludeNetty)
+ )
+ } else {
+ Seq(
+ "org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION excludeAll(excludeJackson, excludeNetty),
+ "org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION excludeAll(excludeJackson, excludeNetty)
+ )
+ }
+ } else {
+ Seq("org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION excludeAll(excludeJackson, excludeNetty) )
+ }),
+ unmanagedSourceDirectories in Compile <+= baseDirectory{ _ /
+ ( if (HADOOP_YARN && HADOOP_MAJOR_VERSION == "2") {
+ "src/hadoop2-yarn/scala"
+ } else {
+ "src/hadoop" + HADOOP_MAJOR_VERSION + "/scala"
+ } )
+ }
) ++ assemblySettings ++ extraAssemblySettings ++ Twirl.settings
def rootSettings = sharedSettings ++ Seq(
@@ -158,7 +201,7 @@ object SparkBuild extends Build {
def examplesSettings = sharedSettings ++ Seq(
name := "spark-examples",
- libraryDependencies ++= Seq("com.twitter" % "algebird-core_2.9.2" % "0.1.8")
+ libraryDependencies ++= Seq("com.twitter" % "algebird-core_2.9.2" % "0.1.11")
)
def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel")
@@ -166,16 +209,17 @@ object SparkBuild extends Build {
def streamingSettings = sharedSettings ++ Seq(
name := "spark-streaming",
libraryDependencies ++= Seq(
- "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile",
+ "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty),
"com.github.sgroschupf" % "zkclient" % "0.1",
- "org.twitter4j" % "twitter4j-stream" % "3.0.3",
- "com.typesafe.akka" % "akka-zeromq" % "2.0.3"
+ "org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty),
+ "com.typesafe.akka" % "akka-zeromq" % "2.0.3" excludeAll(excludeNetty)
)
) ++ assemblySettings ++ extraAssemblySettings
def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq(
mergeStrategy in assembly := {
case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
+ case m if m.toLowerCase.matches("meta-inf/.*\\.sf$") => MergeStrategy.discard
case "reference.conf" => MergeStrategy.concat
case _ => MergeStrategy.first
}
diff --git a/project/build.properties b/project/build.properties
index d4287112c6..9b860e23c5 100644
--- a/project/build.properties
+++ b/project/build.properties
@@ -1 +1 @@
-sbt.version=0.11.3
+sbt.version=0.12.3
diff --git a/project/plugins.sbt b/project/plugins.sbt
index 4d0e696a11..d4f2442872 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -4,13 +4,13 @@ resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/release
resolvers += "Spray Repository" at "http://repo.spray.cc/"
-addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.8.3")
+addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.8.5")
-addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.1.0-RC1")
+addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.1.1")
-addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.0.0")
+addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.2.0")
-addSbtPlugin("cc.spray" %% "sbt-twirl" % "0.5.2")
+addSbtPlugin("io.spray" %% "sbt-twirl" % "0.6.1")
// For Sonatype publishing
//resolvers += Resolver.url("sbt-plugin-releases", new URL("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases/"))(Resolver.ivyStylePatterns)
diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml
index dd720e2291..7a7280313e 100644
--- a/repl-bin/pom.xml
+++ b/repl-bin/pom.xml
@@ -4,7 +4,7 @@
<parent>
<groupId>org.spark-project</groupId>
<artifactId>spark-parent</artifactId>
- <version>0.7.1-SNAPSHOT</version>
+ <version>0.8.0-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
@@ -154,6 +154,61 @@
</dependencies>
</profile>
<profile>
+ <id>hadoop2-yarn</id>
+ <properties>
+ <classifier>hadoop2-yarn</classifier>
+ </properties>
+ <dependencies>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-core</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2-yarn</classifier>
+ </dependency>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-bagel</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2-yarn</classifier>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-examples</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2-yarn</classifier>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-repl</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2-yarn</classifier>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-client</artifactId>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-api</artifactId>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-common</artifactId>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-client</artifactId>
+ <scope>runtime</scope>
+ </dependency>
+ </dependencies>
+ </profile>
+ <profile>
<id>deb</id>
<build>
<plugins>
diff --git a/repl/pom.xml b/repl/pom.xml
index a3e4606edc..92a2020b48 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -4,7 +4,7 @@
<parent>
<groupId>org.spark-project</groupId>
<artifactId>spark-parent</artifactId>
- <version>0.7.1-SNAPSHOT</version>
+ <version>0.8.0-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
@@ -97,13 +97,6 @@
<scope>runtime</scope>
</dependency>
<dependency>
- <groupId>org.spark-project</groupId>
- <artifactId>spark-streaming</artifactId>
- <version>${project.version}</version>
- <classifier>hadoop1</classifier>
- <scope>runtime</scope>
- </dependency>
- <dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-core</artifactId>
<scope>provided</scope>
@@ -148,20 +141,84 @@
<scope>runtime</scope>
</dependency>
<dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-core</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-client</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.avro</groupId>
+ <artifactId>avro</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.avro</groupId>
+ <artifactId>avro-ipc</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <configuration>
+ <classifier>hadoop2</classifier>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ <profile>
+ <id>hadoop2-yarn</id>
+ <properties>
+ <classifier>hadoop2-yarn</classifier>
+ </properties>
+ <dependencies>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-core</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2-yarn</classifier>
+ </dependency>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-bagel</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2-yarn</classifier>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-examples</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2-yarn</classifier>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
<groupId>org.spark-project</groupId>
<artifactId>spark-streaming</artifactId>
<version>${project.version}</version>
- <classifier>hadoop2</classifier>
+ <classifier>hadoop2-yarn</classifier>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
- <artifactId>hadoop-core</artifactId>
+ <artifactId>hadoop-client</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
- <artifactId>hadoop-client</artifactId>
+ <artifactId>hadoop-yarn-api</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-common</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
@@ -181,7 +238,7 @@
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<configuration>
- <classifier>hadoop2</classifier>
+ <classifier>hadoop2-yarn</classifier>
</configuration>
</plugin>
</plugins>
diff --git a/repl/src/main/scala/spark/repl/SparkILoop.scala b/repl/src/main/scala/spark/repl/SparkILoop.scala
index cd7b5128b2..23556dbc8f 100644
--- a/repl/src/main/scala/spark/repl/SparkILoop.scala
+++ b/repl/src/main/scala/spark/repl/SparkILoop.scala
@@ -200,7 +200,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
____ __
/ __/__ ___ _____/ /__
_\ \/ _ \/ _ `/ __/ '_/
- /___/ .__/\_,_/_/ /_/\_\ version 0.7.1
+ /___/ .__/\_,_/_/ /_/\_\ version 0.8.0
/_/
""")
import Properties._
diff --git a/repl/src/test/scala/spark/repl/ReplSuite.scala b/repl/src/test/scala/spark/repl/ReplSuite.scala
index 43559b96d3..1c64f9b98d 100644
--- a/repl/src/test/scala/spark/repl/ReplSuite.scala
+++ b/repl/src/test/scala/spark/repl/ReplSuite.scala
@@ -32,6 +32,7 @@ class ReplSuite extends FunSuite {
interp.sparkContext.stop()
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
return out.toString
}
diff --git a/run b/run
index 2c29cc4a66..0a58ac4a36 100755
--- a/run
+++ b/run
@@ -1,6 +1,6 @@
#!/bin/bash
-SCALA_VERSION=2.9.2
+SCALA_VERSION=2.9.3
# Figure out where the Scala framework is installed
FWDIR="$(cd `dirname $0`; pwd)"
@@ -22,6 +22,7 @@ fi
# values for that; it doesn't need a lot
if [ "$1" = "spark.deploy.master.Master" -o "$1" = "spark.deploy.worker.Worker" ]; then
SPARK_MEM=${SPARK_DAEMON_MEMORY:-512m}
+ SPARK_DAEMON_JAVA_OPTS+=" -Dspark.akka.logLifecycleEvents=true"
SPARK_JAVA_OPTS=$SPARK_DAEMON_JAVA_OPTS # Empty by default
fi
@@ -46,14 +47,15 @@ case "$1" in
esac
if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then
- if [ `command -v scala` ]; then
- RUNNER="scala"
+ if [ "$SCALA_HOME" ]; then
+ RUNNER="${SCALA_HOME}/bin/scala"
else
- if [ -z "$SCALA_HOME" ]; then
- echo "SCALA_HOME is not set" >&2
+ if [ `command -v scala` ]; then
+ RUNNER="scala"
+ else
+ echo "SCALA_HOME is not set and scala is not in PATH" >&2
exit 1
fi
- RUNNER="${SCALA_HOME}/bin/scala"
fi
else
if [ `command -v java` ]; then
@@ -93,6 +95,7 @@ export JAVA_OPTS
CORE_DIR="$FWDIR/core"
REPL_DIR="$FWDIR/repl"
+REPL_BIN_DIR="$FWDIR/repl-bin"
EXAMPLES_DIR="$FWDIR/examples"
BAGEL_DIR="$FWDIR/bagel"
STREAMING_DIR="$FWDIR/streaming"
@@ -123,8 +126,8 @@ if [ -e "$FWDIR/lib_managed" ]; then
CLASSPATH+=":$FWDIR/lib_managed/bundles/*"
fi
CLASSPATH+=":$REPL_DIR/lib/*"
-if [ -e repl-bin/target ]; then
- for jar in `find "repl-bin/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do
+if [ -e $REPL_BIN_DIR/target ]; then
+ for jar in `find "$REPL_BIN_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do
CLASSPATH+=":$jar"
done
fi
@@ -132,7 +135,6 @@ CLASSPATH+=":$BAGEL_DIR/target/scala-$SCALA_VERSION/classes"
for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do
CLASSPATH+=":$jar"
done
-export CLASSPATH # Needed for spark-shell
# Figure out the JAR file that our examples were packaged into. This includes a bit of a hack
# to avoid the -sources and -doc packages that are built by publish-local.
@@ -161,4 +163,5 @@ else
EXTRA_ARGS="$JAVA_OPTS"
fi
+export CLASSPATH # Needed for spark-shell
exec "$RUNNER" -cp "$CLASSPATH" $EXTRA_ARGS "$@"
diff --git a/run2.cmd b/run2.cmd
index cb20a4b7a2..d2d4807971 100644
--- a/run2.cmd
+++ b/run2.cmd
@@ -1,6 +1,6 @@
@echo off
-set SCALA_VERSION=2.9.2
+set SCALA_VERSION=2.9.3
rem Figure out where the Spark framework is installed
set FWDIR=%~dp0
@@ -21,6 +21,7 @@ set RUNNING_DAEMON=0
if "%1"=="spark.deploy.master.Master" set RUNNING_DAEMON=1
if "%1"=="spark.deploy.worker.Worker" set RUNNING_DAEMON=1
if "x%SPARK_DAEMON_MEMORY%" == "x" set SPARK_DAEMON_MEMORY=512m
+set SPARK_DAEMON_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% -Dspark.akka.logLifecycleEvents=true
if "%RUNNING_DAEMON%"=="1" set SPARK_MEM=%SPARK_DAEMON_MEMORY%
if "%RUNNING_DAEMON%"=="1" set SPARK_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS%
diff --git a/sbt/sbt b/sbt/sbt
index 8f426d18e8..850c58e1e9 100755
--- a/sbt/sbt
+++ b/sbt/sbt
@@ -5,4 +5,4 @@ if [ "$MESOS_HOME" != "" ]; then
fi
export SPARK_HOME=$(cd "$(dirname $0)/.."; pwd)
export SPARK_TESTING=1 # To put test classes on classpath
-java -Xmx1200M -XX:MaxPermSize=250m $EXTRA_ARGS -jar $SPARK_HOME/sbt/sbt-launch-*.jar "$@"
+java -Xmx1200m -XX:MaxPermSize=250m -XX:ReservedCodeCacheSize=128m $EXTRA_ARGS -jar $SPARK_HOME/sbt/sbt-launch-*.jar "$@"
diff --git a/streaming/pom.xml b/streaming/pom.xml
index ec077e8089..08ff3e2ae1 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -4,7 +4,7 @@
<parent>
<groupId>org.spark-project</groupId>
<artifactId>spark-parent</artifactId>
- <version>0.7.1-SNAPSHOT</version>
+ <version>0.8.0-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
@@ -149,5 +149,42 @@
</plugins>
</build>
</profile>
+ <profile>
+ <id>hadoop2-yarn</id>
+ <dependencies>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-core</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2-yarn</classifier>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-client</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-api</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-common</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <configuration>
+ <classifier>hadoop2-yarn</classifier>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
</profiles>
</project>
diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
index e303e33e5e..66e67cbfa1 100644
--- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
@@ -38,11 +38,20 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
private[streaming]
class CheckpointWriter(checkpointDir: String) extends Logging {
val file = new Path(checkpointDir, "graph")
+ // The file to which we actually write - and then "move" to file.
+ private val writeFile = new Path(file.getParent, file.getName + ".next")
+ private val bakFile = new Path(file.getParent, file.getName + ".bk")
+
+ private var stopped = false
+
val conf = new Configuration()
var fs = file.getFileSystem(conf)
val maxAttempts = 3
val executor = Executors.newFixedThreadPool(1)
+ // Removed code which validates whether there is only one CheckpointWriter per path 'file' since
+ // I did not notice any errors - reintroduce it ?
+
class CheckpointWriteHandler(checkpointTime: Time, bytes: Array[Byte]) extends Runnable {
def run() {
var attempts = 0
@@ -51,15 +60,17 @@ class CheckpointWriter(checkpointDir: String) extends Logging {
attempts += 1
try {
logDebug("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'")
- if (fs.exists(file)) {
- val bkFile = new Path(file.getParent, file.getName + ".bk")
- FileUtil.copy(fs, file, fs, bkFile, true, true, conf)
- logDebug("Moved existing checkpoint file to " + bkFile)
- }
- val fos = fs.create(file)
+ // This is inherently thread unsafe .. so alleviating it by writing to '.new' and then doing moves : which should be pretty fast.
+ val fos = fs.create(writeFile)
fos.write(bytes)
fos.close()
- fos.close()
+ if (fs.exists(file) && fs.rename(file, bakFile)) {
+ logDebug("Moved existing checkpoint file to " + bakFile)
+ }
+ // paranoia
+ fs.delete(file, false)
+ fs.rename(writeFile, file)
+
val finishTime = System.currentTimeMillis();
logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + file +
"', took " + bytes.length + " bytes and " + (finishTime - startTime) + " milliseconds")
@@ -84,7 +95,15 @@ class CheckpointWriter(checkpointDir: String) extends Logging {
}
def stop() {
+ synchronized {
+ if (stopped) return ;
+ stopped = true
+ }
executor.shutdown()
+ val startTime = System.currentTimeMillis()
+ val terminated = executor.awaitTermination(10, java.util.concurrent.TimeUnit.SECONDS)
+ val endTime = System.currentTimeMillis()
+ logInfo("CheckpointWriter executor terminated ? " + terminated + ", waited for " + (endTime - startTime) + " ms.")
}
}
diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
index adb7f3a24d..3b331956f5 100644
--- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
@@ -54,8 +54,8 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
throw new Exception("Batch duration already set as " + batchDuration +
". cannot set it again.")
}
+ batchDuration = duration
}
- batchDuration = duration
}
def remember(duration: Duration) {
diff --git a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala
index f673e5be15..426a9b6f71 100644
--- a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala
+++ b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala
@@ -159,6 +159,7 @@ object MasterFailureTest extends Logging {
// Setup the streaming computation with the given operation
System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
var ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration, null, Nil, Map())
ssc.checkpoint(checkpointDir.toString)
val inputStream = ssc.textFileStream(testDir.toString)
@@ -205,6 +206,7 @@ object MasterFailureTest extends Logging {
// (iii) Its not timed out yet
System.clearProperty("spark.streaming.clock")
System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
ssc.start()
val startTime = System.currentTimeMillis()
while (!killed && !isLastOutputGenerated && !isTimedOut) {
@@ -357,13 +359,16 @@ class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long)
// Write the data to a local file and then move it to the target test directory
val localFile = new File(localTestDir, (i+1).toString)
val hadoopFile = new Path(testDir, (i+1).toString)
+ val tempHadoopFile = new Path(testDir, ".tmp_" + (i+1).toString)
FileUtils.writeStringToFile(localFile, input(i).toString + "\n")
var tries = 0
var done = false
while (!done && tries < maxTries) {
tries += 1
try {
- fs.copyFromLocalFile(new Path(localFile.toString), hadoopFile)
+ // fs.copyFromLocalFile(new Path(localFile.toString), hadoopFile)
+ fs.copyFromLocalFile(new Path(localFile.toString), tempHadoopFile)
+ fs.rename(tempHadoopFile, hadoopFile)
done = true
} catch {
case ioe: IOException => {
diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala
index cf2ed8b1d4..e7352deb81 100644
--- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala
@@ -15,6 +15,7 @@ class BasicOperationsSuite extends TestSuiteBase {
after {
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
}
test("map") {
diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
index cac86deeaf..607dea77ec 100644
--- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
@@ -31,6 +31,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
}
var ssc: StreamingContext = null
@@ -325,6 +326,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
)
ssc = new StreamingContext(checkpointDir)
System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
ssc.start()
val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches)
// the first element will be re-processed data of the last batch before restart
@@ -350,4 +352,4 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]]
outputStream.output
}
-} \ No newline at end of file
+}
diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
index 67dca2ac31..0acb6db6f2 100644
--- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
@@ -41,6 +41,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
after {
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
}
diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
index 1b66f3bda2..80d827706f 100644
--- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
@@ -16,6 +16,7 @@ class WindowOperationsSuite extends TestSuiteBase {
after {
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
}
val largerSlideInput = Seq(