aboutsummaryrefslogtreecommitdiff
path: root/core/src
diff options
context:
space:
mode:
authorPrashant Sharma <prashant.s@imaginea.com>2013-07-03 11:43:26 +0530
committerPrashant Sharma <prashant.s@imaginea.com>2013-07-03 11:43:26 +0530
commita5f1f6a907b116325c56d38157ec2df76150951e (patch)
tree27de949c24a61b2301c7690db9e28992f49ea39c /core/src
parentb7794813b181f13801596e8d8c3b4471c0c84f20 (diff)
parent6d60fe571a405eb9306a2be1817901316a46f892 (diff)
downloadspark-a5f1f6a907b116325c56d38157ec2df76150951e.tar.gz
spark-a5f1f6a907b116325c56d38157ec2df76150951e.tar.bz2
spark-a5f1f6a907b116325c56d38157ec2df76150951e.zip
Merge branch 'master' into master-merge
Conflicts: core/pom.xml core/src/main/scala/spark/MapOutputTracker.scala core/src/main/scala/spark/RDD.scala core/src/main/scala/spark/RDDCheckpointData.scala core/src/main/scala/spark/SparkContext.scala core/src/main/scala/spark/Utils.scala core/src/main/scala/spark/api/python/PythonRDD.scala core/src/main/scala/spark/deploy/client/Client.scala core/src/main/scala/spark/deploy/master/MasterWebUI.scala core/src/main/scala/spark/deploy/worker/Worker.scala core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala core/src/main/scala/spark/rdd/BlockRDD.scala core/src/main/scala/spark/rdd/ZippedRDD.scala core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala core/src/main/scala/spark/storage/BlockManager.scala core/src/main/scala/spark/storage/BlockManagerMaster.scala core/src/main/scala/spark/storage/BlockManagerMasterActor.scala core/src/main/scala/spark/storage/BlockManagerUI.scala core/src/main/scala/spark/util/AkkaUtils.scala core/src/test/scala/spark/SizeEstimatorSuite.scala pom.xml project/SparkBuild.scala repl/src/main/scala/spark/repl/SparkILoop.scala repl/src/test/scala/spark/repl/ReplSuite.scala streaming/src/main/scala/spark/streaming/StreamingContext.scala streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala
Diffstat (limited to 'core/src')
-rw-r--r--core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala3
-rw-r--r--core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala3
-rw-r--r--core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala23
-rw-r--r--core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala13
-rw-r--r--core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala13
-rw-r--r--core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala63
-rw-r--r--core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala329
-rw-r--r--core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala77
-rw-r--r--core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala272
-rw-r--r--core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala105
-rw-r--r--core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala171
-rw-r--r--core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala547
-rw-r--r--core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala42
-rw-r--r--core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala3
-rw-r--r--core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala3
-rw-r--r--core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala23
-rw-r--r--core/src/main/java/spark/network/netty/FileClient.java72
-rw-r--r--core/src/main/java/spark/network/netty/FileClientChannelInitializer.java24
-rw-r--r--core/src/main/java/spark/network/netty/FileClientHandler.java43
-rw-r--r--core/src/main/java/spark/network/netty/FileServer.java86
-rw-r--r--core/src/main/java/spark/network/netty/FileServerChannelInitializer.java25
-rw-r--r--core/src/main/java/spark/network/netty/FileServerHandler.java65
-rwxr-xr-xcore/src/main/java/spark/network/netty/PathResolver.java12
-rw-r--r--core/src/main/scala/spark/BlockStoreShuffleFetcher.scala33
-rw-r--r--core/src/main/scala/spark/ClosureCleaner.scala23
-rw-r--r--core/src/main/scala/spark/Dependency.scala4
-rw-r--r--core/src/main/scala/spark/FetchFailedException.scala25
-rw-r--r--core/src/main/scala/spark/HadoopWriter.scala12
-rw-r--r--core/src/main/scala/spark/Logging.scala4
-rw-r--r--core/src/main/scala/spark/MapOutputTracker.scala99
-rw-r--r--core/src/main/scala/spark/PairRDDFunctions.scala73
-rw-r--r--core/src/main/scala/spark/RDD.scala129
-rw-r--r--core/src/main/scala/spark/RDDCheckpointData.scala15
-rw-r--r--core/src/main/scala/spark/SequenceFileRDDFunctions.scala15
-rw-r--r--core/src/main/scala/spark/ShuffleFetcher.scala7
-rw-r--r--core/src/main/scala/spark/SizeEstimator.scala2
-rw-r--r--core/src/main/scala/spark/SparkContext.scala167
-rw-r--r--core/src/main/scala/spark/SparkEnv.scala68
-rw-r--r--core/src/main/scala/spark/TaskEndReason.scala16
-rw-r--r--core/src/main/scala/spark/Utils.scala296
-rw-r--r--core/src/main/scala/spark/api/java/JavaPairRDD.scala11
-rw-r--r--core/src/main/scala/spark/api/java/JavaRDD.scala9
-rw-r--r--core/src/main/scala/spark/api/java/JavaRDDLike.scala50
-rw-r--r--core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala11
-rw-r--r--core/src/main/scala/spark/api/python/PythonRDD.scala109
-rw-r--r--core/src/main/scala/spark/api/python/PythonWorkerFactory.scala113
-rw-r--r--core/src/main/scala/spark/deploy/ApplicationDescription.scala5
-rw-r--r--core/src/main/scala/spark/deploy/DeployMessage.scala19
-rw-r--r--core/src/main/scala/spark/deploy/JsonProtocol.scala5
-rw-r--r--core/src/main/scala/spark/deploy/LocalSparkCluster.scala8
-rw-r--r--core/src/main/scala/spark/deploy/client/Client.scala9
-rw-r--r--core/src/main/scala/spark/deploy/client/ClientListener.scala2
-rw-r--r--core/src/main/scala/spark/deploy/client/TestClient.scala4
-rw-r--r--core/src/main/scala/spark/deploy/master/ApplicationInfo.scala6
-rw-r--r--core/src/main/scala/spark/deploy/master/Master.scala19
-rw-r--r--core/src/main/scala/spark/deploy/master/MasterArguments.scala17
-rw-r--r--core/src/main/scala/spark/deploy/master/MasterWebUI.scala4
-rw-r--r--core/src/main/scala/spark/deploy/master/WorkerInfo.scala9
-rw-r--r--core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala36
-rw-r--r--core/src/main/scala/spark/deploy/worker/Worker.scala25
-rw-r--r--core/src/main/scala/spark/deploy/worker/WorkerArguments.scala13
-rw-r--r--core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala3
-rw-r--r--core/src/main/scala/spark/executor/Executor.scala53
-rw-r--r--core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala32
-rw-r--r--core/src/main/scala/spark/executor/TaskMetrics.scala17
-rw-r--r--core/src/main/scala/spark/network/BufferMessage.scala94
-rw-r--r--core/src/main/scala/spark/network/Connection.scala291
-rw-r--r--core/src/main/scala/spark/network/ConnectionManager.scala549
-rw-r--r--core/src/main/scala/spark/network/ConnectionManagerId.scala21
-rw-r--r--core/src/main/scala/spark/network/Message.scala178
-rw-r--r--core/src/main/scala/spark/network/MessageChunk.scala25
-rw-r--r--core/src/main/scala/spark/network/MessageChunkHeader.scala58
-rw-r--r--core/src/main/scala/spark/network/netty/FileHeader.scala57
-rw-r--r--core/src/main/scala/spark/network/netty/ShuffleCopier.scala101
-rw-r--r--core/src/main/scala/spark/network/netty/ShuffleSender.scala53
-rw-r--r--core/src/main/scala/spark/rdd/BlockRDD.scala10
-rw-r--r--core/src/main/scala/spark/rdd/CheckpointRDD.scala31
-rw-r--r--core/src/main/scala/spark/rdd/CoGroupedRDD.scala20
-rw-r--r--core/src/main/scala/spark/rdd/EmptyRDD.scala16
-rw-r--r--core/src/main/scala/spark/rdd/JdbcRDD.scala103
-rw-r--r--core/src/main/scala/spark/rdd/NewHadoopRDD.scala2
-rw-r--r--core/src/main/scala/spark/rdd/PipedRDD.scala27
-rw-r--r--core/src/main/scala/spark/rdd/ShuffledRDD.scala10
-rw-r--r--core/src/main/scala/spark/rdd/SubtractedRDD.scala18
-rw-r--r--core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala138
-rw-r--r--core/src/main/scala/spark/rdd/ZippedRDD.scala23
-rw-r--r--core/src/main/scala/spark/scheduler/ActiveJob.scala5
-rw-r--r--core/src/main/scala/spark/scheduler/DAGScheduler.scala59
-rw-r--r--core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala9
-rw-r--r--core/src/main/scala/spark/scheduler/InputFormatInfo.scala156
-rw-r--r--core/src/main/scala/spark/scheduler/JobLogger.scala306
-rw-r--r--core/src/main/scala/spark/scheduler/ResultTask.scala9
-rw-r--r--core/src/main/scala/spark/scheduler/ShuffleMapTask.scala62
-rw-r--r--core/src/main/scala/spark/scheduler/SparkListener.scala50
-rw-r--r--core/src/main/scala/spark/scheduler/SplitInfo.scala61
-rw-r--r--core/src/main/scala/spark/scheduler/Stage.scala4
-rw-r--r--core/src/main/scala/spark/scheduler/TaskScheduler.scala4
-rw-r--r--core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala3
-rw-r--r--core/src/main/scala/spark/scheduler/TaskSet.scala11
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala364
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala747
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/Pool.scala104
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/Schedulable.scala27
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala115
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala11
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala64
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala7
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala9
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala7
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala34
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala9
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala431
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala2
-rw-r--r--core/src/main/scala/spark/scheduler/local/LocalScheduler.scala225
-rw-r--r--core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala172
-rw-r--r--core/src/main/scala/spark/serializer/Serializer.scala8
-rw-r--r--core/src/main/scala/spark/serializer/SerializerManager.scala45
-rw-r--r--core/src/main/scala/spark/storage/BlockException.scala5
-rw-r--r--core/src/main/scala/spark/storage/BlockFetchTracker.scala12
-rw-r--r--core/src/main/scala/spark/storage/BlockFetcherIterator.scala330
-rw-r--r--core/src/main/scala/spark/storage/BlockManager.scala465
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerId.scala62
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerMaster.scala24
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerMasterActor.scala211
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerMessages.scala3
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala8
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerUI.scala25
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerWorker.scala28
-rw-r--r--core/src/main/scala/spark/storage/BlockMessageArray.scala1
-rw-r--r--core/src/main/scala/spark/storage/BlockObjectWriter.scala50
-rw-r--r--core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala12
-rw-r--r--core/src/main/scala/spark/storage/DiskStore.scala190
-rw-r--r--core/src/main/scala/spark/storage/MemoryStore.scala4
-rw-r--r--core/src/main/scala/spark/storage/ShuffleBlockManager.scala50
-rw-r--r--core/src/main/scala/spark/storage/StorageLevel.scala8
-rw-r--r--core/src/main/scala/spark/storage/StorageUtils.scala47
-rw-r--r--core/src/main/scala/spark/util/AkkaUtils.scala23
-rw-r--r--core/src/main/scala/spark/util/BoundedPriorityQueue.scala45
-rw-r--r--core/src/main/scala/spark/util/StatCounter.scala26
-rw-r--r--core/src/main/scala/spark/util/TimeStampedHashMap.scala8
-rw-r--r--core/src/main/scala/spark/util/TimedIterator.scala32
-rw-r--r--core/src/main/twirl/spark/deploy/master/app_details.scala.html12
-rw-r--r--core/src/main/twirl/spark/deploy/master/executor_row.scala.html2
-rw-r--r--core/src/main/twirl/spark/deploy/master/index.scala.html2
-rw-r--r--core/src/main/twirl/spark/deploy/master/worker_row.scala.html2
-rw-r--r--core/src/main/twirl/spark/deploy/worker/index.scala.html2
-rw-r--r--core/src/main/twirl/spark/storage/worker_table.scala.html2
-rw-r--r--core/src/test/resources/fairscheduler.xml14
-rw-r--r--core/src/test/scala/spark/CheckpointSuite.scala10
-rw-r--r--core/src/test/scala/spark/DistributedSuite.scala67
-rw-r--r--core/src/test/scala/spark/FileSuite.scala46
-rw-r--r--core/src/test/scala/spark/JavaAPISuite.java71
-rw-r--r--core/src/test/scala/spark/LocalSparkContext.scala3
-rw-r--r--core/src/test/scala/spark/MapOutputTrackerSuite.scala36
-rw-r--r--core/src/test/scala/spark/PairRDDFunctionsSuite.scala287
-rw-r--r--core/src/test/scala/spark/PartitioningSuite.scala30
-rw-r--r--core/src/test/scala/spark/PipedRDDSuite.scala45
-rw-r--r--core/src/test/scala/spark/RDDSuite.scala98
-rw-r--r--core/src/test/scala/spark/SharedSparkContext.scala25
-rw-r--r--core/src/test/scala/spark/ShuffleNettySuite.scala17
-rw-r--r--core/src/test/scala/spark/ShuffleSuite.scala322
-rw-r--r--core/src/test/scala/spark/SizeEstimatorSuite.scala2
-rw-r--r--core/src/test/scala/spark/SortingSuite.scala23
-rw-r--r--core/src/test/scala/spark/UnpersistSuite.scala30
-rw-r--r--core/src/test/scala/spark/UtilsSuite.scala53
-rw-r--r--core/src/test/scala/spark/ZippedPartitionsSuite.scala33
-rw-r--r--core/src/test/scala/spark/rdd/JdbcRDDSuite.scala56
-rw-r--r--core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala250
-rw-r--r--core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala20
-rw-r--r--core/src/test/scala/spark/scheduler/JobLoggerSuite.scala104
-rw-r--r--core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala206
-rw-r--r--core/src/test/scala/spark/scheduler/SparkListenerSuite.scala3
-rw-r--r--core/src/test/scala/spark/storage/BlockManagerSuite.scala76
173 files changed, 9792 insertions, 2365 deletions
diff --git a/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
index ca9f7219de..f286f2cf9c 100644
--- a/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
+++ b/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
@@ -4,4 +4,7 @@ trait HadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContext(conf, jobId)
def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContext(conf, attemptId)
+
+ def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier,
+ jobId, isMap, taskId, attemptId)
}
diff --git a/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
index de7b0f81e3..264d421d14 100644
--- a/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
+++ b/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
@@ -6,4 +6,7 @@ trait HadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContext(conf, jobId)
def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContext(conf, attemptId)
+
+ def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier,
+ jobId, isMap, taskId, attemptId)
}
diff --git a/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala
new file mode 100644
index 0000000000..a0fb4fe25d
--- /dev/null
+++ b/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala
@@ -0,0 +1,23 @@
+package spark.deploy
+import org.apache.hadoop.conf.Configuration
+
+
+/**
+ * Contains util methods to interact with Hadoop from spark.
+ */
+object SparkHadoopUtil {
+
+ def getUserNameFromEnvironment(): String = {
+ // defaulting to -D ...
+ System.getProperty("user.name")
+ }
+
+ def runAsUser(func: (Product) => Unit, args: Product) {
+
+ // Add support, if exists - for now, simply run func !
+ func(args)
+ }
+
+ // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop subsystems
+ def newConfiguration(): Configuration = new Configuration()
+}
diff --git a/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
new file mode 100644
index 0000000000..875c0a220b
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
@@ -0,0 +1,13 @@
+
+package org.apache.hadoop.mapred
+
+import org.apache.hadoop.mapreduce.TaskType
+
+trait HadoopMapRedUtil {
+ def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)
+
+ def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+
+ def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) =
+ new TaskAttemptID(jtIdentifier, jobId, if (isMap) TaskType.MAP else TaskType.REDUCE, taskId, attemptId)
+}
diff --git a/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
new file mode 100644
index 0000000000..8bc6fb6dea
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
@@ -0,0 +1,13 @@
+package org.apache.hadoop.mapreduce
+
+import org.apache.hadoop.conf.Configuration
+import task.{TaskAttemptContextImpl, JobContextImpl}
+
+trait HadoopMapReduceUtil {
+ def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)
+
+ def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+
+ def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) =
+ new TaskAttemptID(jtIdentifier, jobId, if (isMap) TaskType.MAP else TaskType.REDUCE, taskId, attemptId)
+}
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala
new file mode 100644
index 0000000000..ab1ab9d8a7
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala
@@ -0,0 +1,63 @@
+package spark.deploy
+
+import collection.mutable.HashMap
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import java.security.PrivilegedExceptionAction
+
+/**
+ * Contains util methods to interact with Hadoop from spark.
+ */
+object SparkHadoopUtil {
+
+ val yarnConf = newConfiguration()
+
+ def getUserNameFromEnvironment(): String = {
+ // defaulting to env if -D is not present ...
+ val retval = System.getProperty(Environment.USER.name, System.getenv(Environment.USER.name))
+
+ // If nothing found, default to user we are running as
+ if (retval == null) System.getProperty("user.name") else retval
+ }
+
+ def runAsUser(func: (Product) => Unit, args: Product) {
+ runAsUser(func, args, getUserNameFromEnvironment())
+ }
+
+ def runAsUser(func: (Product) => Unit, args: Product, user: String) {
+
+ // println("running as user " + jobUserName)
+
+ UserGroupInformation.setConfiguration(yarnConf)
+ val appMasterUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(user)
+ appMasterUgi.doAs(new PrivilegedExceptionAction[AnyRef] {
+ def run: AnyRef = {
+ func(args)
+ // no return value ...
+ null
+ }
+ })
+ }
+
+ // Note that all params which start with SPARK are propagated all the way through, so if in yarn mode, this MUST be set to true.
+ def isYarnMode(): Boolean = {
+ val yarnMode = System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))
+ java.lang.Boolean.valueOf(yarnMode)
+ }
+
+ // Set an env variable indicating we are running in YARN mode.
+ // Note that anything with SPARK prefix gets propagated to all (remote) processes
+ def setYarnMode() {
+ System.setProperty("SPARK_YARN_MODE", "true")
+ }
+
+ def setYarnMode(env: HashMap[String, String]) {
+ env("SPARK_YARN_MODE") = "true"
+ }
+
+ // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop subsystems
+ // Always create a new config, dont reuse yarnConf.
+ def newConfiguration(): Configuration = new YarnConfiguration(new Configuration())
+}
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala
new file mode 100644
index 0000000000..aa72c1e5fe
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala
@@ -0,0 +1,329 @@
+package spark.deploy.yarn
+
+import java.net.Socket
+import java.util.concurrent.CopyOnWriteArrayList
+import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
+import scala.collection.JavaConversions._
+import spark.{SparkContext, Logging, Utils}
+import org.apache.hadoop.security.UserGroupInformation
+import java.security.PrivilegedExceptionAction
+
+class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging {
+
+ def this(args: ApplicationMasterArguments) = this(args, new Configuration())
+
+ private var rpc: YarnRPC = YarnRPC.create(conf)
+ private var resourceManager: AMRMProtocol = null
+ private var appAttemptId: ApplicationAttemptId = null
+ private var userThread: Thread = null
+ private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+
+ private var yarnAllocator: YarnAllocationHandler = null
+
+ def run() {
+
+ // Initialization
+ val jobUserName = Utils.getUserNameFromEnvironment()
+ logInfo("running as user " + jobUserName)
+
+ // run as user ...
+ UserGroupInformation.setConfiguration(yarnConf)
+ val appMasterUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(jobUserName)
+ appMasterUgi.doAs(new PrivilegedExceptionAction[AnyRef] {
+ def run: AnyRef = {
+ runImpl()
+ return null
+ }
+ })
+ }
+
+ private def runImpl() {
+
+ appAttemptId = getApplicationAttemptId()
+ resourceManager = registerWithResourceManager()
+ val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster()
+
+ // Compute number of threads for akka
+ val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory()
+
+ if (minimumMemory > 0) {
+ val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD
+ val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0)
+
+ if (numCore > 0) {
+ // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406
+ // TODO: Uncomment when hadoop is on a version which has this fixed.
+ // args.workerCores = numCore
+ }
+ }
+
+ // Workaround until hadoop moves to something which has
+ // https://issues.apache.org/jira/browse/HADOOP-8406
+ // ignore result
+ // This does not, unfortunately, always work reliably ... but alleviates the bug a lot of times
+ // Hence args.workerCores = numCore disabled above. Any better option ?
+ // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf)
+
+ ApplicationMaster.register(this)
+ // Start the user's JAR
+ userThread = startUserClass()
+
+ // This a bit hacky, but we need to wait until the spark.driver.port property has
+ // been set by the Thread executing the user class.
+ waitForSparkMaster()
+
+ // Allocate all containers
+ allocateWorkers()
+
+ // Wait for the user class to Finish
+ userThread.join()
+
+ // Finish the ApplicationMaster
+ finishApplicationMaster()
+ // TODO: Exit based on success/failure
+ System.exit(0)
+ }
+
+ private def getApplicationAttemptId(): ApplicationAttemptId = {
+ val envs = System.getenv()
+ val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV)
+ val containerId = ConverterUtils.toContainerId(containerIdString)
+ val appAttemptId = containerId.getApplicationAttemptId()
+ logInfo("ApplicationAttemptId: " + appAttemptId)
+ return appAttemptId
+ }
+
+ private def registerWithResourceManager(): AMRMProtocol = {
+ val rmAddress = NetUtils.createSocketAddr(yarnConf.get(
+ YarnConfiguration.RM_SCHEDULER_ADDRESS,
+ YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS))
+ logInfo("Connecting to ResourceManager at " + rmAddress)
+ return rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol]
+ }
+
+ private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
+ logInfo("Registering the ApplicationMaster")
+ val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest])
+ .asInstanceOf[RegisterApplicationMasterRequest]
+ appMasterRequest.setApplicationAttemptId(appAttemptId)
+ // Setting this to master host,port - so that the ApplicationReport at client has some sensible info.
+ // Users can then monitor stderr/stdout on that node if required.
+ appMasterRequest.setHost(Utils.localHostName())
+ appMasterRequest.setRpcPort(0)
+ // What do we provide here ? Might make sense to expose something sensible later ?
+ appMasterRequest.setTrackingUrl("")
+ return resourceManager.registerApplicationMaster(appMasterRequest)
+ }
+
+ private def waitForSparkMaster() {
+ logInfo("Waiting for spark driver to be reachable.")
+ var driverUp = false
+ while(!driverUp) {
+ val driverHost = System.getProperty("spark.driver.host")
+ val driverPort = System.getProperty("spark.driver.port")
+ try {
+ val socket = new Socket(driverHost, driverPort.toInt)
+ socket.close()
+ logInfo("Master now available: " + driverHost + ":" + driverPort)
+ driverUp = true
+ } catch {
+ case e: Exception =>
+ logError("Failed to connect to driver at " + driverHost + ":" + driverPort)
+ Thread.sleep(100)
+ }
+ }
+ }
+
+ private def startUserClass(): Thread = {
+ logInfo("Starting the user JAR in a separate Thread")
+ val mainMethod = Class.forName(args.userClass, false, Thread.currentThread.getContextClassLoader)
+ .getMethod("main", classOf[Array[String]])
+ val t = new Thread {
+ override def run() {
+ // Copy
+ var mainArgs: Array[String] = new Array[String](args.userArgs.size())
+ args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size())
+ mainMethod.invoke(null, mainArgs)
+ }
+ }
+ t.start()
+ return t
+ }
+
+ private def allocateWorkers() {
+ logInfo("Waiting for spark context initialization")
+
+ try {
+ var sparkContext: SparkContext = null
+ ApplicationMaster.sparkContextRef.synchronized {
+ var count = 0
+ while (ApplicationMaster.sparkContextRef.get() == null) {
+ logInfo("Waiting for spark context initialization ... " + count)
+ count = count + 1
+ ApplicationMaster.sparkContextRef.wait(10000L)
+ }
+ sparkContext = ApplicationMaster.sparkContextRef.get()
+ assert(sparkContext != null)
+ this.yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args, sparkContext.preferredNodeLocationData)
+ }
+
+
+ logInfo("Allocating " + args.numWorkers + " workers.")
+ // Wait until all containers have finished
+ // TODO: This is a bit ugly. Can we make it nicer?
+ // TODO: Handle container failure
+ while(yarnAllocator.getNumWorkersRunning < args.numWorkers &&
+ // If user thread exists, then quit !
+ userThread.isAlive) {
+
+ this.yarnAllocator.allocateContainers(math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0))
+ ApplicationMaster.incrementAllocatorLoop(1)
+ Thread.sleep(100)
+ }
+ } finally {
+ // in case of exceptions, etc - ensure that count is atleast ALLOCATOR_LOOP_WAIT_COUNT :
+ // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks
+ ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT)
+ }
+ logInfo("All workers have launched.")
+
+ // Launch a progress reporter thread, else app will get killed after expiration (def: 10mins) timeout
+ if (userThread.isAlive){
+ // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse.
+
+ val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
+ // must be <= timeoutInterval/ 2.
+ // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM.
+ // so atleast 1 minute or timeoutInterval / 10 - whichever is higher.
+ val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L))
+ launchReporterThread(interval)
+ }
+ }
+
+ // TODO: We might want to extend this to allocate more containers in case they die !
+ private def launchReporterThread(_sleepTime: Long): Thread = {
+ val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime
+
+ val t = new Thread {
+ override def run() {
+ while (userThread.isAlive){
+ val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning
+ if (missingWorkerCount > 0) {
+ logInfo("Allocating " + missingWorkerCount + " containers to make up for (potentially ?) lost containers")
+ yarnAllocator.allocateContainers(missingWorkerCount)
+ }
+ else sendProgress()
+ Thread.sleep(sleepTime)
+ }
+ }
+ }
+ // setting to daemon status, though this is usually not a good idea.
+ t.setDaemon(true)
+ t.start()
+ logInfo("Started progress reporter thread - sleep time : " + sleepTime)
+ return t
+ }
+
+ private def sendProgress() {
+ logDebug("Sending progress")
+ // simulated with an allocate request with no nodes requested ...
+ yarnAllocator.allocateContainers(0)
+ }
+
+ /*
+ def printContainers(containers: List[Container]) = {
+ for (container <- containers) {
+ logInfo("Launching shell command on a new container."
+ + ", containerId=" + container.getId()
+ + ", containerNode=" + container.getNodeId().getHost()
+ + ":" + container.getNodeId().getPort()
+ + ", containerNodeURI=" + container.getNodeHttpAddress()
+ + ", containerState" + container.getState()
+ + ", containerResourceMemory"
+ + container.getResource().getMemory())
+ }
+ }
+ */
+
+ def finishApplicationMaster() {
+ val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest])
+ .asInstanceOf[FinishApplicationMasterRequest]
+ finishReq.setAppAttemptId(appAttemptId)
+ // TODO: Check if the application has failed or succeeded
+ finishReq.setFinishApplicationStatus(FinalApplicationStatus.SUCCEEDED)
+ resourceManager.finishApplicationMaster(finishReq)
+ }
+
+}
+
+object ApplicationMaster {
+ // number of times to wait for the allocator loop to complete.
+ // each loop iteration waits for 100ms, so maximum of 3 seconds.
+ // This is to ensure that we have reasonable number of containers before we start
+ // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be optimal as more
+ // containers are available. Might need to handle this better.
+ private val ALLOCATOR_LOOP_WAIT_COUNT = 30
+ def incrementAllocatorLoop(by: Int) {
+ val count = yarnAllocatorLoop.getAndAdd(by)
+ if (count >= ALLOCATOR_LOOP_WAIT_COUNT){
+ yarnAllocatorLoop.synchronized {
+ // to wake threads off wait ...
+ yarnAllocatorLoop.notifyAll()
+ }
+ }
+ }
+
+ private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]()
+
+ def register(master: ApplicationMaster) {
+ applicationMasters.add(master)
+ }
+
+ val sparkContextRef: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null)
+ val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0)
+
+ def sparkContextInitialized(sc: SparkContext): Boolean = {
+ var modified = false
+ sparkContextRef.synchronized {
+ modified = sparkContextRef.compareAndSet(null, sc)
+ sparkContextRef.notifyAll()
+ }
+
+ // Add a shutdown hook - as a best case effort in case users do not call sc.stop or do System.exit
+ // Should not really have to do this, but it helps yarn to evict resources earlier.
+ // not to mention, prevent Client declaring failure even though we exit'ed properly.
+ if (modified) {
+ Runtime.getRuntime().addShutdownHook(new Thread with Logging {
+ // This is not just to log, but also to ensure that log system is initialized for this instance when we actually are 'run'
+ logInfo("Adding shutdown hook for context " + sc)
+ override def run() {
+ logInfo("Invoking sc stop from shutdown hook")
+ sc.stop()
+ // best case ...
+ for (master <- applicationMasters) master.finishApplicationMaster
+ }
+ } )
+ }
+
+ // Wait for initialization to complete and atleast 'some' nodes can get allocated
+ yarnAllocatorLoop.synchronized {
+ while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT){
+ yarnAllocatorLoop.wait(1000L)
+ }
+ }
+ modified
+ }
+
+ def main(argStrings: Array[String]) {
+ val args = new ApplicationMasterArguments(argStrings)
+ new ApplicationMaster(args).run()
+ }
+}
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala
new file mode 100644
index 0000000000..1b00208511
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala
@@ -0,0 +1,77 @@
+package spark.deploy.yarn
+
+import spark.util.IntParam
+import collection.mutable.ArrayBuffer
+
+class ApplicationMasterArguments(val args: Array[String]) {
+ var userJar: String = null
+ var userClass: String = null
+ var userArgs: Seq[String] = Seq[String]()
+ var workerMemory = 1024
+ var workerCores = 1
+ var numWorkers = 2
+
+ parseArgs(args.toList)
+
+ private def parseArgs(inputArgs: List[String]): Unit = {
+ val userArgsBuffer = new ArrayBuffer[String]()
+
+ var args = inputArgs
+
+ while (! args.isEmpty) {
+
+ args match {
+ case ("--jar") :: value :: tail =>
+ userJar = value
+ args = tail
+
+ case ("--class") :: value :: tail =>
+ userClass = value
+ args = tail
+
+ case ("--args") :: value :: tail =>
+ userArgsBuffer += value
+ args = tail
+
+ case ("--num-workers") :: IntParam(value) :: tail =>
+ numWorkers = value
+ args = tail
+
+ case ("--worker-memory") :: IntParam(value) :: tail =>
+ workerMemory = value
+ args = tail
+
+ case ("--worker-cores") :: IntParam(value) :: tail =>
+ workerCores = value
+ args = tail
+
+ case Nil =>
+ if (userJar == null || userClass == null) {
+ printUsageAndExit(1)
+ }
+
+ case _ =>
+ printUsageAndExit(1, args)
+ }
+ }
+
+ userArgs = userArgsBuffer.readOnly
+ }
+
+ def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+ if (unknownParam != null) {
+ System.err.println("Unknown/unsupported param " + unknownParam)
+ }
+ System.err.println(
+ "Usage: spark.deploy.yarn.ApplicationMaster [options] \n" +
+ "Options:\n" +
+ " --jar JAR_PATH Path to your application's JAR file (required)\n" +
+ " --class CLASS_NAME Name of your application's main class (required)\n" +
+ " --args ARGS Arguments to be passed to your application's main class.\n" +
+ " Mutliple invocations are possible, each will be passed in order.\n" +
+ " --num-workers NUM Number of workers to start (Default: 2)\n" +
+ " --worker-cores NUM Number of cores for the workers (Default: 1)\n" +
+ " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n")
+ System.exit(exitCode)
+ }
+}
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala
new file mode 100644
index 0000000000..7a881e26df
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala
@@ -0,0 +1,272 @@
+package spark.deploy.yarn
+
+import java.net.{InetSocketAddress, URI}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.client.YarnClientImpl
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions._
+import spark.{Logging, Utils}
+import org.apache.hadoop.yarn.util.{Apps, Records, ConverterUtils}
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import spark.deploy.SparkHadoopUtil
+
+class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl with Logging {
+
+ def this(args: ClientArguments) = this(new Configuration(), args)
+
+ var rpc: YarnRPC = YarnRPC.create(conf)
+ val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+
+ def run() {
+ init(yarnConf)
+ start()
+ logClusterResourceDetails()
+
+ val newApp = super.getNewApplication()
+ val appId = newApp.getApplicationId()
+
+ verifyClusterResources(newApp)
+ val appContext = createApplicationSubmissionContext(appId)
+ val localResources = prepareLocalResources(appId, "spark")
+ val env = setupLaunchEnv(localResources)
+ val amContainer = createContainerLaunchContext(newApp, localResources, env)
+
+ appContext.setQueue(args.amQueue)
+ appContext.setAMContainerSpec(amContainer)
+ appContext.setUser(args.amUser)
+
+ submitApp(appContext)
+
+ monitorApplication(appId)
+ System.exit(0)
+ }
+
+
+ def logClusterResourceDetails() {
+ val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics
+ logInfo("Got Cluster metric info from ASM, numNodeManagers=" + clusterMetrics.getNumNodeManagers)
+
+ val queueInfo: QueueInfo = super.getQueueInfo(args.amQueue)
+ logInfo("Queue info .. queueName=" + queueInfo.getQueueName + ", queueCurrentCapacity=" + queueInfo.getCurrentCapacity +
+ ", queueMaxCapacity=" + queueInfo.getMaximumCapacity + ", queueApplicationCount=" + queueInfo.getApplications.size +
+ ", queueChildQueueCount=" + queueInfo.getChildQueues.size)
+ }
+
+
+ def verifyClusterResources(app: GetNewApplicationResponse) = {
+ val maxMem = app.getMaximumResourceCapability().getMemory()
+ logInfo("Max mem capabililty of resources in this cluster " + maxMem)
+
+ // If the cluster does not have enough memory resources, exit.
+ val requestedMem = (args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + args.numWorkers * args.workerMemory
+ if (requestedMem > maxMem) {
+ logError("Cluster cannot satisfy memory resource request of " + requestedMem)
+ System.exit(1)
+ }
+ }
+
+ def createApplicationSubmissionContext(appId: ApplicationId): ApplicationSubmissionContext = {
+ logInfo("Setting up application submission context for ASM")
+ val appContext = Records.newRecord(classOf[ApplicationSubmissionContext])
+ appContext.setApplicationId(appId)
+ appContext.setApplicationName("Spark")
+ return appContext
+ }
+
+ def prepareLocalResources(appId: ApplicationId, appName: String): HashMap[String, LocalResource] = {
+ logInfo("Preparing Local resources")
+ val locaResources = HashMap[String, LocalResource]()
+ // Upload Spark and the application JAR to the remote file system
+ // Add them as local resources to the AM
+ val fs = FileSystem.get(conf)
+ Map("spark.jar" -> System.getenv("SPARK_JAR"), "app.jar" -> args.userJar, "log4j.properties" -> System.getenv("SPARK_LOG4J_CONF"))
+ .foreach { case(destName, _localPath) =>
+ val localPath: String = if (_localPath != null) _localPath.trim() else ""
+ if (! localPath.isEmpty()) {
+ val src = new Path(localPath)
+ val pathSuffix = appName + "/" + appId.getId() + destName
+ val dst = new Path(fs.getHomeDirectory(), pathSuffix)
+ logInfo("Uploading " + src + " to " + dst)
+ fs.copyFromLocalFile(false, true, src, dst)
+ val destStatus = fs.getFileStatus(dst)
+
+ val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ amJarRsrc.setType(LocalResourceType.FILE)
+ amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION)
+ amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(dst))
+ amJarRsrc.setTimestamp(destStatus.getModificationTime())
+ amJarRsrc.setSize(destStatus.getLen())
+ locaResources(destName) = amJarRsrc
+ }
+ }
+ return locaResources
+ }
+
+ def setupLaunchEnv(localResources: HashMap[String, LocalResource]): HashMap[String, String] = {
+ logInfo("Setting up the launch environment")
+ val log4jConfLocalRes = localResources.getOrElse("log4j.properties", null)
+
+ val env = new HashMap[String, String]()
+ Apps.addToEnvironment(env, Environment.USER.name, args.amUser)
+
+ // If log4j present, ensure ours overrides all others
+ if (log4jConfLocalRes != null) Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./")
+
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*")
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH")
+ Client.populateHadoopClasspath(yarnConf, env)
+ SparkHadoopUtil.setYarnMode(env)
+ env("SPARK_YARN_JAR_PATH") =
+ localResources("spark.jar").getResource().getScheme.toString() + "://" +
+ localResources("spark.jar").getResource().getFile().toString()
+ env("SPARK_YARN_JAR_TIMESTAMP") = localResources("spark.jar").getTimestamp().toString()
+ env("SPARK_YARN_JAR_SIZE") = localResources("spark.jar").getSize().toString()
+
+ env("SPARK_YARN_USERJAR_PATH") =
+ localResources("app.jar").getResource().getScheme.toString() + "://" +
+ localResources("app.jar").getResource().getFile().toString()
+ env("SPARK_YARN_USERJAR_TIMESTAMP") = localResources("app.jar").getTimestamp().toString()
+ env("SPARK_YARN_USERJAR_SIZE") = localResources("app.jar").getSize().toString()
+
+ if (log4jConfLocalRes != null) {
+ env("SPARK_YARN_LOG4J_PATH") =
+ log4jConfLocalRes.getResource().getScheme.toString() + "://" + log4jConfLocalRes.getResource().getFile().toString()
+ env("SPARK_YARN_LOG4J_TIMESTAMP") = log4jConfLocalRes.getTimestamp().toString()
+ env("SPARK_YARN_LOG4J_SIZE") = log4jConfLocalRes.getSize().toString()
+ }
+
+ // Add each SPARK-* key to the environment
+ System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v }
+ return env
+ }
+
+ def userArgsToString(clientArgs: ClientArguments): String = {
+ val prefix = " --args "
+ val args = clientArgs.userArgs
+ val retval = new StringBuilder()
+ for (arg <- args){
+ retval.append(prefix).append(" '").append(arg).append("' ")
+ }
+
+ retval.toString
+ }
+
+ def createContainerLaunchContext(newApp: GetNewApplicationResponse,
+ localResources: HashMap[String, LocalResource],
+ env: HashMap[String, String]): ContainerLaunchContext = {
+ logInfo("Setting up container launch context")
+ val amContainer = Records.newRecord(classOf[ContainerLaunchContext])
+ amContainer.setLocalResources(localResources)
+ amContainer.setEnvironment(env)
+
+ val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory()
+
+ var amMemory = ((args.amMemory / minResMemory) * minResMemory) +
+ (if (0 != (args.amMemory % minResMemory)) minResMemory else 0) - YarnAllocationHandler.MEMORY_OVERHEAD
+
+ // Extra options for the JVM
+ var JAVA_OPTS = ""
+
+ // Add Xmx for am memory
+ JAVA_OPTS += "-Xmx" + amMemory + "m "
+
+ // Commenting it out for now - so that people can refer to the properties if required. Remove it once cpuset version is pushed out.
+ // The context is, default gc for server class machines end up using all cores to do gc - hence if there are multiple containers in same
+ // node, spark gc effects all other containers performance (which can also be other spark containers)
+ // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in multi-tenant environments. Not sure how default java gc behaves if it is
+ // limited to subset of cores on a node.
+ if (env.isDefinedAt("SPARK_USE_CONC_INCR_GC") && java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC"))) {
+ // In our expts, using (default) throughput collector has severe perf ramnifications in multi-tenant machines
+ JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
+ JAVA_OPTS += " -XX:+CMSIncrementalMode "
+ JAVA_OPTS += " -XX:+CMSIncrementalPacing "
+ JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
+ JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
+ }
+ if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
+ JAVA_OPTS += env("SPARK_JAVA_OPTS") + " "
+ }
+
+ // Command for the ApplicationMaster
+ val commands = List[String]("java " +
+ " -server " +
+ JAVA_OPTS +
+ " spark.deploy.yarn.ApplicationMaster" +
+ " --class " + args.userClass +
+ " --jar " + args.userJar +
+ userArgsToString(args) +
+ " --worker-memory " + args.workerMemory +
+ " --worker-cores " + args.workerCores +
+ " --num-workers " + args.numWorkers +
+ " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" +
+ " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
+ logInfo("Command for the ApplicationMaster: " + commands(0))
+ amContainer.setCommands(commands)
+
+ val capability = Records.newRecord(classOf[Resource]).asInstanceOf[Resource]
+ // Memory for the ApplicationMaster
+ capability.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+ amContainer.setResource(capability)
+
+ return amContainer
+ }
+
+ def submitApp(appContext: ApplicationSubmissionContext) = {
+ // Submit the application to the applications manager
+ logInfo("Submitting application to ASM")
+ super.submitApplication(appContext)
+ }
+
+ def monitorApplication(appId: ApplicationId): Boolean = {
+ while(true) {
+ Thread.sleep(1000)
+ val report = super.getApplicationReport(appId)
+
+ logInfo("Application report from ASM: \n" +
+ "\t application identifier: " + appId.toString() + "\n" +
+ "\t appId: " + appId.getId() + "\n" +
+ "\t clientToken: " + report.getClientToken() + "\n" +
+ "\t appDiagnostics: " + report.getDiagnostics() + "\n" +
+ "\t appMasterHost: " + report.getHost() + "\n" +
+ "\t appQueue: " + report.getQueue() + "\n" +
+ "\t appMasterRpcPort: " + report.getRpcPort() + "\n" +
+ "\t appStartTime: " + report.getStartTime() + "\n" +
+ "\t yarnAppState: " + report.getYarnApplicationState() + "\n" +
+ "\t distributedFinalState: " + report.getFinalApplicationStatus() + "\n" +
+ "\t appTrackingUrl: " + report.getTrackingUrl() + "\n" +
+ "\t appUser: " + report.getUser()
+ )
+
+ val state = report.getYarnApplicationState()
+ val dsStatus = report.getFinalApplicationStatus()
+ if (state == YarnApplicationState.FINISHED ||
+ state == YarnApplicationState.FAILED ||
+ state == YarnApplicationState.KILLED) {
+ return true
+ }
+ }
+ return true
+ }
+}
+
+object Client {
+ def main(argStrings: Array[String]) {
+ val args = new ClientArguments(argStrings)
+ SparkHadoopUtil.setYarnMode()
+ new Client(args).run
+ }
+
+ // Based on code from org.apache.hadoop.mapreduce.v2.util.MRApps
+ def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) {
+ for (c <- conf.getStrings(YarnConfiguration.YARN_APPLICATION_CLASSPATH)) {
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, c.trim)
+ }
+ }
+}
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala
new file mode 100644
index 0000000000..24110558e7
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala
@@ -0,0 +1,105 @@
+package spark.deploy.yarn
+
+import spark.util.MemoryParam
+import spark.util.IntParam
+import collection.mutable.{ArrayBuffer, HashMap}
+import spark.scheduler.{InputFormatInfo, SplitInfo}
+
+// TODO: Add code and support for ensuring that yarn resource 'asks' are location aware !
+class ClientArguments(val args: Array[String]) {
+ var userJar: String = null
+ var userClass: String = null
+ var userArgs: Seq[String] = Seq[String]()
+ var workerMemory = 1024
+ var workerCores = 1
+ var numWorkers = 2
+ var amUser = System.getProperty("user.name")
+ var amQueue = System.getProperty("QUEUE", "default")
+ var amMemory: Int = 512
+ // TODO
+ var inputFormatInfo: List[InputFormatInfo] = null
+
+ parseArgs(args.toList)
+
+ private def parseArgs(inputArgs: List[String]): Unit = {
+ val userArgsBuffer: ArrayBuffer[String] = new ArrayBuffer[String]()
+ val inputFormatMap: HashMap[String, InputFormatInfo] = new HashMap[String, InputFormatInfo]()
+
+ var args = inputArgs
+
+ while (! args.isEmpty) {
+
+ args match {
+ case ("--jar") :: value :: tail =>
+ userJar = value
+ args = tail
+
+ case ("--class") :: value :: tail =>
+ userClass = value
+ args = tail
+
+ case ("--args") :: value :: tail =>
+ userArgsBuffer += value
+ args = tail
+
+ case ("--master-memory") :: MemoryParam(value) :: tail =>
+ amMemory = value
+ args = tail
+
+ case ("--num-workers") :: IntParam(value) :: tail =>
+ numWorkers = value
+ args = tail
+
+ case ("--worker-memory") :: MemoryParam(value) :: tail =>
+ workerMemory = value
+ args = tail
+
+ case ("--worker-cores") :: IntParam(value) :: tail =>
+ workerCores = value
+ args = tail
+
+ case ("--user") :: value :: tail =>
+ amUser = value
+ args = tail
+
+ case ("--queue") :: value :: tail =>
+ amQueue = value
+ args = tail
+
+ case Nil =>
+ if (userJar == null || userClass == null) {
+ printUsageAndExit(1)
+ }
+
+ case _ =>
+ printUsageAndExit(1, args)
+ }
+ }
+
+ userArgs = userArgsBuffer.readOnly
+ inputFormatInfo = inputFormatMap.values.toList
+ }
+
+
+ def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+ if (unknownParam != null) {
+ System.err.println("Unknown/unsupported param " + unknownParam)
+ }
+ System.err.println(
+ "Usage: spark.deploy.yarn.Client [options] \n" +
+ "Options:\n" +
+ " --jar JAR_PATH Path to your application's JAR file (required)\n" +
+ " --class CLASS_NAME Name of your application's main class (required)\n" +
+ " --args ARGS Arguments to be passed to your application's main class.\n" +
+ " Mutliple invocations are possible, each will be passed in order.\n" +
+ " --num-workers NUM Number of workers to start (Default: 2)\n" +
+ " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" +
+ " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" +
+ " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
+ " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" +
+ " --user USERNAME Run the ApplicationMaster (and slaves) as a different user\n"
+ )
+ System.exit(exitCode)
+ }
+
+}
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala
new file mode 100644
index 0000000000..a2bf0af762
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala
@@ -0,0 +1,171 @@
+package spark.deploy.yarn
+
+import java.net.URI
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records}
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.HashMap
+
+import spark.{Logging, Utils}
+
+class WorkerRunnable(container: Container, conf: Configuration, masterAddress: String,
+ slaveId: String, hostname: String, workerMemory: Int, workerCores: Int)
+ extends Runnable with Logging {
+
+ var rpc: YarnRPC = YarnRPC.create(conf)
+ var cm: ContainerManager = null
+ val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+
+ def run = {
+ logInfo("Starting Worker Container")
+ cm = connectToCM
+ startContainer
+ }
+
+ def startContainer = {
+ logInfo("Setting up ContainerLaunchContext")
+
+ val ctx = Records.newRecord(classOf[ContainerLaunchContext])
+ .asInstanceOf[ContainerLaunchContext]
+
+ ctx.setContainerId(container.getId())
+ ctx.setResource(container.getResource())
+ val localResources = prepareLocalResources
+ ctx.setLocalResources(localResources)
+
+ val env = prepareEnvironment
+ ctx.setEnvironment(env)
+
+ // Extra options for the JVM
+ var JAVA_OPTS = ""
+ // Set the JVM memory
+ val workerMemoryString = workerMemory + "m"
+ JAVA_OPTS += "-Xms" + workerMemoryString + " -Xmx" + workerMemoryString + " "
+ if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
+ JAVA_OPTS += env("SPARK_JAVA_OPTS") + " "
+ }
+ // Commenting it out for now - so that people can refer to the properties if required. Remove it once cpuset version is pushed out.
+ // The context is, default gc for server class machines end up using all cores to do gc - hence if there are multiple containers in same
+ // node, spark gc effects all other containers performance (which can also be other spark containers)
+ // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in multi-tenant environments. Not sure how default java gc behaves if it is
+ // limited to subset of cores on a node.
+/*
+ else {
+ // If no java_opts specified, default to using -XX:+CMSIncrementalMode
+ // It might be possible that other modes/config is being done in SPARK_JAVA_OPTS, so we dont want to mess with it.
+ // In our expts, using (default) throughput collector has severe perf ramnifications in multi-tennent machines
+ // The options are based on
+ // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use%20the%20Concurrent%20Low%20Pause%20Collector|outline
+ JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
+ JAVA_OPTS += " -XX:+CMSIncrementalMode "
+ JAVA_OPTS += " -XX:+CMSIncrementalPacing "
+ JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
+ JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
+ }
+*/
+
+ ctx.setUser(UserGroupInformation.getCurrentUser().getShortUserName())
+ val commands = List[String]("java " +
+ " -server " +
+ // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling.
+ // Not killing the task leaves various aspects of the worker and (to some extent) the jvm in an inconsistent state.
+ // TODO: If the OOM is not recoverable by rescheduling it on different node, then do 'something' to fail job ... akin to blacklisting trackers in mapred ?
+ " -XX:OnOutOfMemoryError='kill %p' " +
+ JAVA_OPTS +
+ " spark.executor.StandaloneExecutorBackend " +
+ masterAddress + " " +
+ slaveId + " " +
+ hostname + " " +
+ workerCores +
+ " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" +
+ " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
+ logInfo("Setting up worker with commands: " + commands)
+ ctx.setCommands(commands)
+
+ // Send the start request to the ContainerManager
+ val startReq = Records.newRecord(classOf[StartContainerRequest])
+ .asInstanceOf[StartContainerRequest]
+ startReq.setContainerLaunchContext(ctx)
+ cm.startContainer(startReq)
+ }
+
+
+ def prepareLocalResources: HashMap[String, LocalResource] = {
+ logInfo("Preparing Local resources")
+ val locaResources = HashMap[String, LocalResource]()
+
+ // Spark JAR
+ val sparkJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ sparkJarResource.setType(LocalResourceType.FILE)
+ sparkJarResource.setVisibility(LocalResourceVisibility.APPLICATION)
+ sparkJarResource.setResource(ConverterUtils.getYarnUrlFromURI(
+ new URI(System.getenv("SPARK_YARN_JAR_PATH"))))
+ sparkJarResource.setTimestamp(System.getenv("SPARK_YARN_JAR_TIMESTAMP").toLong)
+ sparkJarResource.setSize(System.getenv("SPARK_YARN_JAR_SIZE").toLong)
+ locaResources("spark.jar") = sparkJarResource
+ // User JAR
+ val userJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ userJarResource.setType(LocalResourceType.FILE)
+ userJarResource.setVisibility(LocalResourceVisibility.APPLICATION)
+ userJarResource.setResource(ConverterUtils.getYarnUrlFromURI(
+ new URI(System.getenv("SPARK_YARN_USERJAR_PATH"))))
+ userJarResource.setTimestamp(System.getenv("SPARK_YARN_USERJAR_TIMESTAMP").toLong)
+ userJarResource.setSize(System.getenv("SPARK_YARN_USERJAR_SIZE").toLong)
+ locaResources("app.jar") = userJarResource
+
+ // Log4j conf - if available
+ if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
+ val log4jConfResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ log4jConfResource.setType(LocalResourceType.FILE)
+ log4jConfResource.setVisibility(LocalResourceVisibility.APPLICATION)
+ log4jConfResource.setResource(ConverterUtils.getYarnUrlFromURI(
+ new URI(System.getenv("SPARK_YARN_LOG4J_PATH"))))
+ log4jConfResource.setTimestamp(System.getenv("SPARK_YARN_LOG4J_TIMESTAMP").toLong)
+ log4jConfResource.setSize(System.getenv("SPARK_YARN_LOG4J_SIZE").toLong)
+ locaResources("log4j.properties") = log4jConfResource
+ }
+
+
+ logInfo("Prepared Local resources " + locaResources)
+ return locaResources
+ }
+
+ def prepareEnvironment: HashMap[String, String] = {
+ val env = new HashMap[String, String]()
+ // should we add this ?
+ Apps.addToEnvironment(env, Environment.USER.name, Utils.getUserNameFromEnvironment())
+
+ // If log4j present, ensure ours overrides all others
+ if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
+ // Which is correct ?
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./log4j.properties")
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./")
+ }
+
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*")
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH")
+ Client.populateHadoopClasspath(yarnConf, env)
+
+ System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v }
+ return env
+ }
+
+ def connectToCM: ContainerManager = {
+ val cmHostPortStr = container.getNodeId().getHost() + ":" + container.getNodeId().getPort()
+ val cmAddress = NetUtils.createSocketAddr(cmHostPortStr)
+ logInfo("Connecting to ContainerManager at " + cmHostPortStr)
+ return rpc.getProxy(classOf[ContainerManager], cmAddress, conf).asInstanceOf[ContainerManager]
+ }
+
+}
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala
new file mode 100644
index 0000000000..61dd72a651
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala
@@ -0,0 +1,547 @@
+package spark.deploy.yarn
+
+import spark.{Logging, Utils}
+import spark.scheduler.SplitInfo
+import scala.collection
+import org.apache.hadoop.yarn.api.records.{AMResponse, ApplicationAttemptId, ContainerId, Priority, Resource, ResourceRequest, ContainerStatus, Container}
+import spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend}
+import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse}
+import org.apache.hadoop.yarn.util.{RackResolver, Records}
+import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap}
+import java.util.concurrent.atomic.AtomicInteger
+import org.apache.hadoop.yarn.api.AMRMProtocol
+import collection.JavaConversions._
+import collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import org.apache.hadoop.conf.Configuration
+import java.util.{Collections, Set => JSet}
+import java.lang.{Boolean => JBoolean}
+
+object AllocationType extends Enumeration ("HOST", "RACK", "ANY") {
+ type AllocationType = Value
+ val HOST, RACK, ANY = Value
+}
+
+// too many params ? refactor it 'somehow' ?
+// needs to be mt-safe
+// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive : should make it
+// more proactive and decoupled.
+// Note that right now, we assume all node asks as uniform in terms of capabilities and priority
+// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for more info
+// on how we are requesting for containers.
+private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceManager: AMRMProtocol,
+ val appAttemptId: ApplicationAttemptId,
+ val maxWorkers: Int, val workerMemory: Int, val workerCores: Int,
+ val preferredHostToCount: Map[String, Int],
+ val preferredRackToCount: Map[String, Int])
+ extends Logging {
+
+
+ // These three are locked on allocatedHostToContainersMap. Complementary data structures
+ // allocatedHostToContainersMap : containers which are running : host, Set<containerid>
+ // allocatedContainerToHostMap: container to host mapping
+ private val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]]()
+ private val allocatedContainerToHostMap = new HashMap[ContainerId, String]()
+ // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an allocated node)
+ // As with the two data structures above, tightly coupled with them, and to be locked on allocatedHostToContainersMap
+ private val allocatedRackCount = new HashMap[String, Int]()
+
+ // containers which have been released.
+ private val releasedContainerList = new CopyOnWriteArrayList[ContainerId]()
+ // containers to be released in next request to RM
+ private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean]
+
+ private val numWorkersRunning = new AtomicInteger()
+ // Used to generate a unique id per worker
+ private val workerIdCounter = new AtomicInteger()
+ private val lastResponseId = new AtomicInteger()
+
+ def getNumWorkersRunning: Int = numWorkersRunning.intValue
+
+
+ def isResourceConstraintSatisfied(container: Container): Boolean = {
+ container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+ }
+
+ def allocateContainers(workersToRequest: Int) {
+ // We need to send the request only once from what I understand ... but for now, not modifying this much.
+
+ // Keep polling the Resource Manager for containers
+ val amResp = allocateWorkerResources(workersToRequest).getAMResponse
+
+ val _allocatedContainers = amResp.getAllocatedContainers()
+ if (_allocatedContainers.size > 0) {
+
+
+ logDebug("Allocated " + _allocatedContainers.size + " containers, current count " +
+ numWorkersRunning.get() + ", to-be-released " + releasedContainerList +
+ ", pendingReleaseContainers : " + pendingReleaseContainers)
+ logDebug("Cluster Resources: " + amResp.getAvailableResources)
+
+ val hostToContainers = new HashMap[String, ArrayBuffer[Container]]()
+
+ // ignore if not satisfying constraints {
+ for (container <- _allocatedContainers) {
+ if (isResourceConstraintSatisfied(container)) {
+ // allocatedContainers += container
+
+ val host = container.getNodeId.getHost
+ val containers = hostToContainers.getOrElseUpdate(host, new ArrayBuffer[Container]())
+
+ containers += container
+ }
+ // Add all ignored containers to released list
+ else releasedContainerList.add(container.getId())
+ }
+
+ // Find the appropriate containers to use
+ // Slightly non trivial groupBy I guess ...
+ val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
+ val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
+ val offRackContainers = new HashMap[String, ArrayBuffer[Container]]()
+
+ for (candidateHost <- hostToContainers.keySet)
+ {
+ val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0)
+ val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost)
+
+ var remainingContainers = hostToContainers.get(candidateHost).getOrElse(null)
+ assert(remainingContainers != null)
+
+ if (requiredHostCount >= remainingContainers.size){
+ // Since we got <= required containers, add all to dataLocalContainers
+ dataLocalContainers.put(candidateHost, remainingContainers)
+ // all consumed
+ remainingContainers = null
+ }
+ else if (requiredHostCount > 0) {
+ // container list has more containers than we need for data locality.
+ // Split into two : data local container count of (remainingContainers.size - requiredHostCount)
+ // and rest as remainingContainer
+ val (dataLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredHostCount)
+ dataLocalContainers.put(candidateHost, dataLocal)
+ // remainingContainers = remaining
+
+ // yarn has nasty habit of allocating a tonne of containers on a host - discourage this :
+ // add remaining to release list. If we have insufficient containers, next allocation cycle
+ // will reallocate (but wont treat it as data local)
+ for (container <- remaining) releasedContainerList.add(container.getId())
+ remainingContainers = null
+ }
+
+ // now rack local
+ if (remainingContainers != null){
+ val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)
+
+ if (rack != null){
+ val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0)
+ val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) -
+ rackLocalContainers.get(rack).getOrElse(List()).size
+
+
+ if (requiredRackCount >= remainingContainers.size){
+ // Add all to dataLocalContainers
+ dataLocalContainers.put(rack, remainingContainers)
+ // all consumed
+ remainingContainers = null
+ }
+ else if (requiredRackCount > 0) {
+ // container list has more containers than we need for data locality.
+ // Split into two : data local container count of (remainingContainers.size - requiredRackCount)
+ // and rest as remainingContainer
+ val (rackLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredRackCount)
+ val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, new ArrayBuffer[Container]())
+
+ existingRackLocal ++= rackLocal
+ remainingContainers = remaining
+ }
+ }
+ }
+
+ // If still not consumed, then it is off rack host - add to that list.
+ if (remainingContainers != null){
+ offRackContainers.put(candidateHost, remainingContainers)
+ }
+ }
+
+ // Now that we have split the containers into various groups, go through them in order :
+ // first host local, then rack local and then off rack (everything else).
+ // Note that the list we create below tries to ensure that not all containers end up within a host
+ // if there are sufficiently large number of hosts/containers.
+
+ val allocatedContainers = new ArrayBuffer[Container](_allocatedContainers.size)
+ allocatedContainers ++= ClusterScheduler.prioritizeContainers(dataLocalContainers)
+ allocatedContainers ++= ClusterScheduler.prioritizeContainers(rackLocalContainers)
+ allocatedContainers ++= ClusterScheduler.prioritizeContainers(offRackContainers)
+
+ // Run each of the allocated containers
+ for (container <- allocatedContainers) {
+ val numWorkersRunningNow = numWorkersRunning.incrementAndGet()
+ val workerHostname = container.getNodeId.getHost
+ val containerId = container.getId
+
+ assert (container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD))
+
+ if (numWorkersRunningNow > maxWorkers) {
+ logInfo("Ignoring container " + containerId + " at host " + workerHostname +
+ " .. we already have required number of containers")
+ releasedContainerList.add(containerId)
+ // reset counter back to old value.
+ numWorkersRunning.decrementAndGet()
+ }
+ else {
+ // deallocate + allocate can result in reusing id's wrongly - so use a different counter (workerIdCounter)
+ val workerId = workerIdCounter.incrementAndGet().toString
+ val driverUrl = "akka://spark@%s:%s/user/%s".format(
+ System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
+ StandaloneSchedulerBackend.ACTOR_NAME)
+
+ logInfo("launching container on " + containerId + " host " + workerHostname)
+ // just to be safe, simply remove it from pendingReleaseContainers. Should not be there, but ..
+ pendingReleaseContainers.remove(containerId)
+
+ val rack = YarnAllocationHandler.lookupRack(conf, workerHostname)
+ allocatedHostToContainersMap.synchronized {
+ val containerSet = allocatedHostToContainersMap.getOrElseUpdate(workerHostname, new HashSet[ContainerId]())
+
+ containerSet += containerId
+ allocatedContainerToHostMap.put(containerId, workerHostname)
+ if (rack != null) allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1)
+ }
+
+ new Thread(
+ new WorkerRunnable(container, conf, driverUrl, workerId,
+ workerHostname, workerMemory, workerCores)
+ ).start()
+ }
+ }
+ logDebug("After allocated " + allocatedContainers.size + " containers (orig : " +
+ _allocatedContainers.size + "), current count " + numWorkersRunning.get() +
+ ", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers)
+ }
+
+
+ val completedContainers = amResp.getCompletedContainersStatuses()
+ if (completedContainers.size > 0){
+ logDebug("Completed " + completedContainers.size + " containers, current count " + numWorkersRunning.get() +
+ ", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers)
+
+ for (completedContainer <- completedContainers){
+ val containerId = completedContainer.getContainerId
+
+ // Was this released by us ? If yes, then simply remove from containerSet and move on.
+ if (pendingReleaseContainers.containsKey(containerId)) {
+ pendingReleaseContainers.remove(containerId)
+ }
+ else {
+ // simply decrement count - next iteration of ReporterThread will take care of allocating !
+ numWorkersRunning.decrementAndGet()
+ logInfo("Container completed ? nodeId: " + containerId + ", state " + completedContainer.getState +
+ " httpaddress: " + completedContainer.getDiagnostics)
+ }
+
+ allocatedHostToContainersMap.synchronized {
+ if (allocatedContainerToHostMap.containsKey(containerId)) {
+ val host = allocatedContainerToHostMap.get(containerId).getOrElse(null)
+ assert (host != null)
+
+ val containerSet = allocatedHostToContainersMap.get(host).getOrElse(null)
+ assert (containerSet != null)
+
+ containerSet -= containerId
+ if (containerSet.isEmpty) allocatedHostToContainersMap.remove(host)
+ else allocatedHostToContainersMap.update(host, containerSet)
+
+ allocatedContainerToHostMap -= containerId
+
+ // doing this within locked context, sigh ... move to outside ?
+ val rack = YarnAllocationHandler.lookupRack(conf, host)
+ if (rack != null) {
+ val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1
+ if (rackCount > 0) allocatedRackCount.put(rack, rackCount)
+ else allocatedRackCount.remove(rack)
+ }
+ }
+ }
+ }
+ logDebug("After completed " + completedContainers.size + " containers, current count " +
+ numWorkersRunning.get() + ", to-be-released " + releasedContainerList +
+ ", pendingReleaseContainers : " + pendingReleaseContainers)
+ }
+ }
+
+ def createRackResourceRequests(hostContainers: List[ResourceRequest]): List[ResourceRequest] = {
+ // First generate modified racks and new set of hosts under it : then issue requests
+ val rackToCounts = new HashMap[String, Int]()
+
+ // Within this lock - used to read/write to the rack related maps too.
+ for (container <- hostContainers) {
+ val candidateHost = container.getHostName
+ val candidateNumContainers = container.getNumContainers
+ assert(YarnAllocationHandler.ANY_HOST != candidateHost)
+
+ val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)
+ if (rack != null) {
+ var count = rackToCounts.getOrElse(rack, 0)
+ count += candidateNumContainers
+ rackToCounts.put(rack, count)
+ }
+ }
+
+ val requestedContainers: ArrayBuffer[ResourceRequest] =
+ new ArrayBuffer[ResourceRequest](rackToCounts.size)
+ for ((rack, count) <- rackToCounts){
+ requestedContainers +=
+ createResourceRequest(AllocationType.RACK, rack, count, YarnAllocationHandler.PRIORITY)
+ }
+
+ requestedContainers.toList
+ }
+
+ def allocatedContainersOnHost(host: String): Int = {
+ var retval = 0
+ allocatedHostToContainersMap.synchronized {
+ retval = allocatedHostToContainersMap.getOrElse(host, Set()).size
+ }
+ retval
+ }
+
+ def allocatedContainersOnRack(rack: String): Int = {
+ var retval = 0
+ allocatedHostToContainersMap.synchronized {
+ retval = allocatedRackCount.getOrElse(rack, 0)
+ }
+ retval
+ }
+
+ private def allocateWorkerResources(numWorkers: Int): AllocateResponse = {
+
+ var resourceRequests: List[ResourceRequest] = null
+
+ // default.
+ if (numWorkers <= 0 || preferredHostToCount.isEmpty) {
+ logDebug("numWorkers: " + numWorkers + ", host preferences ? " + preferredHostToCount.isEmpty)
+ resourceRequests = List(
+ createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY))
+ }
+ else {
+ // request for all hosts in preferred nodes and for numWorkers -
+ // candidates.size, request by default allocation policy.
+ val hostContainerRequests: ArrayBuffer[ResourceRequest] =
+ new ArrayBuffer[ResourceRequest](preferredHostToCount.size)
+ for ((candidateHost, candidateCount) <- preferredHostToCount) {
+ val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost)
+
+ if (requiredCount > 0) {
+ hostContainerRequests +=
+ createResourceRequest(AllocationType.HOST, candidateHost, requiredCount, YarnAllocationHandler.PRIORITY)
+ }
+ }
+ val rackContainerRequests: List[ResourceRequest] = createRackResourceRequests(hostContainerRequests.toList)
+
+ val anyContainerRequests: ResourceRequest =
+ createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY)
+
+ val containerRequests: ArrayBuffer[ResourceRequest] =
+ new ArrayBuffer[ResourceRequest](hostContainerRequests.size() + rackContainerRequests.size() + 1)
+
+ containerRequests ++= hostContainerRequests
+ containerRequests ++= rackContainerRequests
+ containerRequests += anyContainerRequests
+
+ resourceRequests = containerRequests.toList
+ }
+
+ val req = Records.newRecord(classOf[AllocateRequest])
+ req.setResponseId(lastResponseId.incrementAndGet)
+ req.setApplicationAttemptId(appAttemptId)
+
+ req.addAllAsks(resourceRequests)
+
+ val releasedContainerList = createReleasedContainerList()
+ req.addAllReleases(releasedContainerList)
+
+
+
+ if (numWorkers > 0) {
+ logInfo("Allocating " + numWorkers + " worker containers with " + (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + " of memory each.")
+ }
+ else {
+ logDebug("Empty allocation req .. release : " + releasedContainerList)
+ }
+
+ for (req <- resourceRequests) {
+ logInfo("rsrcRequest ... host : " + req.getHostName + ", numContainers : " + req.getNumContainers +
+ ", p = " + req.getPriority().getPriority + ", capability: " + req.getCapability)
+ }
+ resourceManager.allocate(req)
+ }
+
+
+ private def createResourceRequest(requestType: AllocationType.AllocationType,
+ resource:String, numWorkers: Int, priority: Int): ResourceRequest = {
+
+ // If hostname specified, we need atleast two requests - node local and rack local.
+ // There must be a third request - which is ANY : that will be specially handled.
+ requestType match {
+ case AllocationType.HOST => {
+ assert (YarnAllocationHandler.ANY_HOST != resource)
+
+ val hostname = resource
+ val nodeLocal = createResourceRequestImpl(hostname, numWorkers, priority)
+
+ // add to host->rack mapping
+ YarnAllocationHandler.populateRackInfo(conf, hostname)
+
+ nodeLocal
+ }
+
+ case AllocationType.RACK => {
+ val rack = resource
+ createResourceRequestImpl(rack, numWorkers, priority)
+ }
+
+ case AllocationType.ANY => {
+ createResourceRequestImpl(YarnAllocationHandler.ANY_HOST, numWorkers, priority)
+ }
+
+ case _ => throw new IllegalArgumentException("Unexpected/unsupported request type .. " + requestType)
+ }
+ }
+
+ private def createResourceRequestImpl(hostname:String, numWorkers: Int, priority: Int): ResourceRequest = {
+
+ val rsrcRequest = Records.newRecord(classOf[ResourceRequest])
+ val memCapability = Records.newRecord(classOf[Resource])
+ // There probably is some overhead here, let's reserve a bit more memory.
+ memCapability.setMemory(workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+ rsrcRequest.setCapability(memCapability)
+
+ val pri = Records.newRecord(classOf[Priority])
+ pri.setPriority(priority)
+ rsrcRequest.setPriority(pri)
+
+ rsrcRequest.setHostName(hostname)
+
+ rsrcRequest.setNumContainers(java.lang.Math.max(numWorkers, 0))
+ rsrcRequest
+ }
+
+ def createReleasedContainerList(): ArrayBuffer[ContainerId] = {
+
+ val retval = new ArrayBuffer[ContainerId](1)
+ // iterator on COW list ...
+ for (container <- releasedContainerList.iterator()){
+ retval += container
+ }
+ // remove from the original list.
+ if (! retval.isEmpty) {
+ releasedContainerList.removeAll(retval)
+ for (v <- retval) pendingReleaseContainers.put(v, true)
+ logInfo("Releasing " + retval.size + " containers. pendingReleaseContainers : " +
+ pendingReleaseContainers)
+ }
+
+ retval
+ }
+}
+
+object YarnAllocationHandler {
+
+ val ANY_HOST = "*"
+ // all requests are issued with same priority : we do not (yet) have any distinction between
+ // request types (like map/reduce in hadoop for example)
+ val PRIORITY = 1
+
+ // Additional memory overhead - in mb
+ val MEMORY_OVERHEAD = 384
+
+ // host to rack map - saved from allocation requests
+ // We are expecting this not to change.
+ // Note that it is possible for this to change : and RM will indicate that to us via update
+ // response to allocate. But we are punting on handling that for now.
+ private val hostToRack = new ConcurrentHashMap[String, String]()
+ private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]()
+
+ def newAllocator(conf: Configuration,
+ resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId,
+ args: ApplicationMasterArguments,
+ map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = {
+
+ val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)
+
+
+ new YarnAllocationHandler(conf, resourceManager, appAttemptId, args.numWorkers,
+ args.workerMemory, args.workerCores, hostToCount, rackToCount)
+ }
+
+ def newAllocator(conf: Configuration,
+ resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId,
+ maxWorkers: Int, workerMemory: Int, workerCores: Int,
+ map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = {
+
+ val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)
+
+ new YarnAllocationHandler(conf, resourceManager, appAttemptId, maxWorkers,
+ workerMemory, workerCores, hostToCount, rackToCount)
+ }
+
+ // A simple method to copy the split info map.
+ private def generateNodeToWeight(conf: Configuration, input: collection.Map[String, collection.Set[SplitInfo]]) :
+ // host to count, rack to count
+ (Map[String, Int], Map[String, Int]) = {
+
+ if (input == null) return (Map[String, Int](), Map[String, Int]())
+
+ val hostToCount = new HashMap[String, Int]
+ val rackToCount = new HashMap[String, Int]
+
+ for ((host, splits) <- input) {
+ val hostCount = hostToCount.getOrElse(host, 0)
+ hostToCount.put(host, hostCount + splits.size)
+
+ val rack = lookupRack(conf, host)
+ if (rack != null){
+ val rackCount = rackToCount.getOrElse(host, 0)
+ rackToCount.put(host, rackCount + splits.size)
+ }
+ }
+
+ (hostToCount.toMap, rackToCount.toMap)
+ }
+
+ def lookupRack(conf: Configuration, host: String): String = {
+ if (! hostToRack.contains(host)) populateRackInfo(conf, host)
+ hostToRack.get(host)
+ }
+
+ def fetchCachedHostsForRack(rack: String): Option[Set[String]] = {
+ val set = rackToHostSet.get(rack)
+ if (set == null) return None
+
+ // No better way to get a Set[String] from JSet ?
+ val convertedSet: collection.mutable.Set[String] = set
+ Some(convertedSet.toSet)
+ }
+
+ def populateRackInfo(conf: Configuration, hostname: String) {
+ Utils.checkHost(hostname)
+
+ if (!hostToRack.containsKey(hostname)) {
+ // If there are repeated failures to resolve, all to an ignore list ?
+ val rackInfo = RackResolver.resolve(conf, hostname)
+ if (rackInfo != null && rackInfo.getNetworkLocation != null) {
+ val rack = rackInfo.getNetworkLocation
+ hostToRack.put(hostname, rack)
+ if (! rackToHostSet.containsKey(rack)) {
+ rackToHostSet.putIfAbsent(rack, Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]()))
+ }
+ rackToHostSet.get(rack).add(hostname)
+
+ // Since RackResolver caches, we are disabling this for now ...
+ } /* else {
+ // right ? Else we will keep calling rack resolver in case we cant resolve rack info ...
+ hostToRack.put(hostname, null)
+ } */
+ }
+ }
+}
diff --git a/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala b/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala
new file mode 100644
index 0000000000..ed732d36bf
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -0,0 +1,42 @@
+package spark.scheduler.cluster
+
+import spark._
+import spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler}
+import org.apache.hadoop.conf.Configuration
+
+/**
+ *
+ * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of ApplicationMaster, etc is done
+ */
+private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
+
+ def this(sc: SparkContext) = this(sc, new Configuration())
+
+ // Nothing else for now ... initialize application master : which needs sparkContext to determine how to allocate
+ // Note that only the first creation of SparkContext influences (and ideally, there must be only one SparkContext, right ?)
+ // Subsequent creations are ignored - since nodes are already allocated by then.
+
+
+ // By default, rack is unknown
+ override def getRackForHost(hostPort: String): Option[String] = {
+ val host = Utils.parseHostPort(hostPort)._1
+ val retval = YarnAllocationHandler.lookupRack(conf, host)
+ if (retval != null) Some(retval) else None
+ }
+
+ // By default, if rack is unknown, return nothing
+ override def getCachedHostsForRack(rack: String): Option[Set[String]] = {
+ if (rack == None || rack == null) return None
+
+ YarnAllocationHandler.fetchCachedHostsForRack(rack)
+ }
+
+ override def postStartHook() {
+ val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc)
+ if (sparkContextInitialized){
+ // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt
+ Thread.sleep(3000L)
+ }
+ logInfo("YarnClusterScheduler.postStartHook done")
+ }
+}
diff --git a/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
index 35300cea58..a0652d7fc7 100644
--- a/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
+++ b/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
@@ -4,4 +4,7 @@ trait HadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)
def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+
+ def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier,
+ jobId, isMap, taskId, attemptId)
}
diff --git a/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
index 7afdbff320..7fdbe322fd 100644
--- a/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
+++ b/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
@@ -7,4 +7,7 @@ trait HadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)
def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+
+ def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier,
+ jobId, isMap, taskId, attemptId)
}
diff --git a/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala
new file mode 100644
index 0000000000..a0fb4fe25d
--- /dev/null
+++ b/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala
@@ -0,0 +1,23 @@
+package spark.deploy
+import org.apache.hadoop.conf.Configuration
+
+
+/**
+ * Contains util methods to interact with Hadoop from spark.
+ */
+object SparkHadoopUtil {
+
+ def getUserNameFromEnvironment(): String = {
+ // defaulting to -D ...
+ System.getProperty("user.name")
+ }
+
+ def runAsUser(func: (Product) => Unit, args: Product) {
+
+ // Add support, if exists - for now, simply run func !
+ func(args)
+ }
+
+ // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop subsystems
+ def newConfiguration(): Configuration = new Configuration()
+}
diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java
new file mode 100644
index 0000000000..a4bb4bc701
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileClient.java
@@ -0,0 +1,72 @@
+package spark.network.netty;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.oio.OioEventLoopGroup;
+import io.netty.channel.socket.oio.OioSocketChannel;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+class FileClient {
+
+ private Logger LOG = LoggerFactory.getLogger(this.getClass().getName());
+ private FileClientHandler handler = null;
+ private Channel channel = null;
+ private Bootstrap bootstrap = null;
+ private int connectTimeout = 60*1000; // 1 min
+
+ public FileClient(FileClientHandler handler, int connectTimeout) {
+ this.handler = handler;
+ this.connectTimeout = connectTimeout;
+ }
+
+ public void init() {
+ bootstrap = new Bootstrap();
+ bootstrap.group(new OioEventLoopGroup())
+ .channel(OioSocketChannel.class)
+ .option(ChannelOption.SO_KEEPALIVE, true)
+ .option(ChannelOption.TCP_NODELAY, true)
+ .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout)
+ .handler(new FileClientChannelInitializer(handler));
+ }
+
+ public void connect(String host, int port) {
+ try {
+ // Start the connection attempt.
+ channel = bootstrap.connect(host, port).sync().channel();
+ // ChannelFuture cf = channel.closeFuture();
+ //cf.addListener(new ChannelCloseListener(this));
+ } catch (InterruptedException e) {
+ close();
+ }
+ }
+
+ public void waitForClose() {
+ try {
+ channel.closeFuture().sync();
+ } catch (InterruptedException e) {
+ LOG.warn("FileClient interrupted", e);
+ }
+ }
+
+ public void sendRequest(String file) {
+ //assert(file == null);
+ //assert(channel == null);
+ channel.write(file + "\r\n");
+ }
+
+ public void close() {
+ if(channel != null) {
+ channel.close();
+ channel = null;
+ }
+ if ( bootstrap!=null) {
+ bootstrap.shutdown();
+ bootstrap = null;
+ }
+ }
+}
diff --git a/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java
new file mode 100644
index 0000000000..af25baf641
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java
@@ -0,0 +1,24 @@
+package spark.network.netty;
+
+import io.netty.buffer.BufType;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.codec.string.StringEncoder;
+
+
+class FileClientChannelInitializer extends ChannelInitializer<SocketChannel> {
+
+ private FileClientHandler fhandler;
+
+ public FileClientChannelInitializer(FileClientHandler handler) {
+ fhandler = handler;
+ }
+
+ @Override
+ public void initChannel(SocketChannel channel) {
+ // file no more than 2G
+ channel.pipeline()
+ .addLast("encoder", new StringEncoder(BufType.BYTE))
+ .addLast("handler", fhandler);
+ }
+}
diff --git a/core/src/main/java/spark/network/netty/FileClientHandler.java b/core/src/main/java/spark/network/netty/FileClientHandler.java
new file mode 100644
index 0000000000..9fc9449827
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileClientHandler.java
@@ -0,0 +1,43 @@
+package spark.network.netty;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundByteHandlerAdapter;
+
+
+abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
+
+ private FileHeader currentHeader = null;
+
+ private volatile boolean handlerCalled = false;
+
+ public boolean isComplete() {
+ return handlerCalled;
+ }
+
+ public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header);
+ public abstract void handleError(String blockId);
+
+ @Override
+ public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) {
+ // Use direct buffer if possible.
+ return ctx.alloc().ioBuffer();
+ }
+
+ @Override
+ public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) {
+ // get header
+ if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) {
+ currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE()));
+ }
+ // get file
+ if(in.readableBytes() >= currentHeader.fileLen()) {
+ handle(ctx, in, currentHeader);
+ handlerCalled = true;
+ currentHeader = null;
+ ctx.close();
+ }
+ }
+
+}
+
diff --git a/core/src/main/java/spark/network/netty/FileServer.java b/core/src/main/java/spark/network/netty/FileServer.java
new file mode 100644
index 0000000000..dd3a557ae5
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileServer.java
@@ -0,0 +1,86 @@
+package spark.network.netty;
+
+import java.net.InetSocketAddress;
+
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.oio.OioEventLoopGroup;
+import io.netty.channel.socket.oio.OioServerSocketChannel;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Server that accept the path of a file an echo back its content.
+ */
+class FileServer {
+
+ private Logger LOG = LoggerFactory.getLogger(this.getClass().getName());
+
+ private ServerBootstrap bootstrap = null;
+ private ChannelFuture channelFuture = null;
+ private int port = 0;
+ private Thread blockingThread = null;
+
+ public FileServer(PathResolver pResolver, int port) {
+ InetSocketAddress addr = new InetSocketAddress(port);
+
+ // Configure the server.
+ bootstrap = new ServerBootstrap();
+ bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup())
+ .channel(OioServerSocketChannel.class)
+ .option(ChannelOption.SO_BACKLOG, 100)
+ .option(ChannelOption.SO_RCVBUF, 1500)
+ .childHandler(new FileServerChannelInitializer(pResolver));
+ // Start the server.
+ channelFuture = bootstrap.bind(addr);
+ try {
+ // Get the address we bound to.
+ InetSocketAddress boundAddress =
+ ((InetSocketAddress) channelFuture.sync().channel().localAddress());
+ this.port = boundAddress.getPort();
+ } catch (InterruptedException ie) {
+ this.port = 0;
+ }
+ }
+
+ /**
+ * Start the file server asynchronously in a new thread.
+ */
+ public void start() {
+ blockingThread = new Thread() {
+ public void run() {
+ try {
+ channelFuture.channel().closeFuture().sync();
+ LOG.info("FileServer exiting");
+ } catch (InterruptedException e) {
+ LOG.error("File server start got interrupted", e);
+ }
+ // NOTE: bootstrap is shutdown in stop()
+ }
+ };
+ blockingThread.setDaemon(true);
+ blockingThread.start();
+ }
+
+ public int getPort() {
+ return port;
+ }
+
+ public void stop() {
+ // Close the bound channel.
+ if (channelFuture != null) {
+ channelFuture.channel().close();
+ channelFuture = null;
+ }
+ // Shutdown bootstrap.
+ if (bootstrap != null) {
+ bootstrap.shutdown();
+ bootstrap = null;
+ }
+ // TODO: Shutdown all accepted channels as well ?
+ }
+}
diff --git a/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java
new file mode 100644
index 0000000000..8f1f5c65cd
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java
@@ -0,0 +1,25 @@
+package spark.network.netty;
+
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.codec.DelimiterBasedFrameDecoder;
+import io.netty.handler.codec.Delimiters;
+import io.netty.handler.codec.string.StringDecoder;
+
+
+class FileServerChannelInitializer extends ChannelInitializer<SocketChannel> {
+
+ PathResolver pResolver;
+
+ public FileServerChannelInitializer(PathResolver pResolver) {
+ this.pResolver = pResolver;
+ }
+
+ @Override
+ public void initChannel(SocketChannel channel) {
+ channel.pipeline()
+ .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter()))
+ .addLast("strDecoder", new StringDecoder())
+ .addLast("handler", new FileServerHandler(pResolver));
+ }
+}
diff --git a/core/src/main/java/spark/network/netty/FileServerHandler.java b/core/src/main/java/spark/network/netty/FileServerHandler.java
new file mode 100644
index 0000000000..a78eddb1b5
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileServerHandler.java
@@ -0,0 +1,65 @@
+package spark.network.netty;
+
+import java.io.File;
+import java.io.FileInputStream;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundMessageHandlerAdapter;
+import io.netty.channel.DefaultFileRegion;
+
+
+class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
+
+ PathResolver pResolver;
+
+ public FileServerHandler(PathResolver pResolver){
+ this.pResolver = pResolver;
+ }
+
+ @Override
+ public void messageReceived(ChannelHandlerContext ctx, String blockId) {
+ String path = pResolver.getAbsolutePath(blockId);
+ // if getFilePath returns null, close the channel
+ if (path == null) {
+ //ctx.close();
+ return;
+ }
+ File file = new File(path);
+ if (file.exists()) {
+ if (!file.isFile()) {
+ //logger.info("Not a file : " + file.getAbsolutePath());
+ ctx.write(new FileHeader(0, blockId).buffer());
+ ctx.flush();
+ return;
+ }
+ long length = file.length();
+ if (length > Integer.MAX_VALUE || length <= 0) {
+ //logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length);
+ ctx.write(new FileHeader(0, blockId).buffer());
+ ctx.flush();
+ return;
+ }
+ int len = new Long(length).intValue();
+ //logger.info("Sending block "+blockId+" filelen = "+len);
+ //logger.info("header = "+ (new FileHeader(len, blockId)).buffer());
+ ctx.write((new FileHeader(len, blockId)).buffer());
+ try {
+ ctx.sendFile(new DefaultFileRegion(new FileInputStream(file)
+ .getChannel(), 0, file.length()));
+ } catch (Exception e) {
+ //logger.warning("Exception when sending file : " + file.getAbsolutePath());
+ e.printStackTrace();
+ }
+ } else {
+ //logger.warning("File not found: " + file.getAbsolutePath());
+ ctx.write(new FileHeader(0, blockId).buffer());
+ }
+ ctx.flush();
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+ cause.printStackTrace();
+ ctx.close();
+ }
+}
diff --git a/core/src/main/java/spark/network/netty/PathResolver.java b/core/src/main/java/spark/network/netty/PathResolver.java
new file mode 100755
index 0000000000..302411672c
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/PathResolver.java
@@ -0,0 +1,12 @@
+package spark.network.netty;
+
+
+public interface PathResolver {
+ /**
+ * Get the absolute path of the file
+ *
+ * @param fileId
+ * @return the absolute path of file
+ */
+ public String getAbsolutePath(String fileId);
+}
diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
index c27ed36406..3239f4c385 100644
--- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
@@ -1,14 +1,19 @@
package spark
-import executor.{ShuffleReadMetrics, TaskMetrics}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
-import spark.storage.{DelegateBlockFetchTracker, BlockManagerId}
-import util.{CompletionIterator, TimedIterator}
+import spark.executor.{ShuffleReadMetrics, TaskMetrics}
+import spark.serializer.Serializer
+import spark.storage.BlockManagerId
+import spark.util.CompletionIterator
+
private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
- override def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) = {
+
+ override def fetch[K, V](
+ shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer) = {
+
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager
@@ -48,18 +53,18 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
}
}
- val blockFetcherItr = blockManager.getMultiple(blocksByAddress)
- val itr = new TimedIterator(blockFetcherItr.flatMap(unpackBlock)) with DelegateBlockFetchTracker
- itr.setDelegate(blockFetcherItr)
+ val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
+ val itr = blockFetcherItr.flatMap(unpackBlock)
+
CompletionIterator[(K,V), Iterator[(K,V)]](itr, {
val shuffleMetrics = new ShuffleReadMetrics
- shuffleMetrics.shuffleReadMillis = itr.getNetMillis
- shuffleMetrics.remoteFetchTime = itr.remoteFetchTime
- shuffleMetrics.fetchWaitTime = itr.fetchWaitTime
- shuffleMetrics.remoteBytesRead = itr.remoteBytesRead
- shuffleMetrics.totalBlocksFetched = itr.totalBlocks
- shuffleMetrics.localBlocksFetched = itr.numLocalBlocks
- shuffleMetrics.remoteBlocksFetched = itr.numRemoteBlocks
+ shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
+ shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
+ shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
+ shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
+ shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
+ shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
+ shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
metrics.shuffleReadMetrics = Some(shuffleMetrics)
})
}
diff --git a/core/src/main/scala/spark/ClosureCleaner.scala b/core/src/main/scala/spark/ClosureCleaner.scala
index 98525b99c8..d5e7132ff9 100644
--- a/core/src/main/scala/spark/ClosureCleaner.scala
+++ b/core/src/main/scala/spark/ClosureCleaner.scala
@@ -5,15 +5,22 @@ import java.lang.reflect.Field
import scala.collection.mutable.Map
import scala.collection.mutable.Set
-import org.objectweb.asm.{ClassReader, MethodVisitor, Type}
-import org.objectweb.asm.commons.EmptyVisitor
+import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
import org.objectweb.asm.Opcodes._
+import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream}
private[spark] object ClosureCleaner extends Logging {
// Get an ASM class reader for a given class from the JAR that loaded it
private def getClassReader(cls: Class[_]): ClassReader = {
- new ClassReader(cls.getResourceAsStream(
- cls.getName.replaceFirst("^.*\\.", "") + ".class"))
+ // Copy data over, before delegating to ClassReader - else we can run out of open file handles.
+ val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
+ val resourceStream = cls.getResourceAsStream(className)
+ // todo: Fixme - continuing with earlier behavior ...
+ if (resourceStream == null) return new ClassReader(resourceStream)
+
+ val baos = new ByteArrayOutputStream(128)
+ Utils.copyStream(resourceStream, baos, true)
+ new ClassReader(new ByteArrayInputStream(baos.toByteArray))
}
// Check whether a class represents a Scala closure
@@ -154,10 +161,10 @@ private[spark] object ClosureCleaner extends Logging {
}
}
-private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor {
+private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) {
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
- return new EmptyVisitor {
+ return new MethodVisitor(ASM4) {
override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) {
if (op == GETFIELD) {
for (cl <- output.keys if cl.getName == owner.replace('/', '.')) {
@@ -180,7 +187,7 @@ private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) exten
}
}
-private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor {
+private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) {
var myName: String = null
override def visit(version: Int, access: Int, name: String, sig: String,
@@ -190,7 +197,7 @@ private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisi
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
- return new EmptyVisitor {
+ return new MethodVisitor(ASM4) {
override def visitMethodInsn(op: Int, owner: String, name: String,
desc: String) {
val argTypes = Type.getArgumentTypes(desc)
diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala
index 5eea907322..2af44aa383 100644
--- a/core/src/main/scala/spark/Dependency.scala
+++ b/core/src/main/scala/spark/Dependency.scala
@@ -25,10 +25,12 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
* @param shuffleId the shuffle id
* @param rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
+ * @param serializerClass class name of the serializer to use
*/
class ShuffleDependency[K, V](
@transient rdd: RDD[(K, V)],
- val partitioner: Partitioner)
+ val partitioner: Partitioner,
+ val serializerClass: String = null)
extends Dependency(rdd) {
val shuffleId: Int = rdd.context.newShuffleId()
diff --git a/core/src/main/scala/spark/FetchFailedException.scala b/core/src/main/scala/spark/FetchFailedException.scala
index a953081d24..40b0193f19 100644
--- a/core/src/main/scala/spark/FetchFailedException.scala
+++ b/core/src/main/scala/spark/FetchFailedException.scala
@@ -3,18 +3,25 @@ package spark
import spark.storage.BlockManagerId
private[spark] class FetchFailedException(
- val bmAddress: BlockManagerId,
- val shuffleId: Int,
- val mapId: Int,
- val reduceId: Int,
+ taskEndReason: TaskEndReason,
+ message: String,
cause: Throwable)
extends Exception {
-
- override def getMessage(): String =
- "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId)
+
+ def this (bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int, cause: Throwable) =
+ this(FetchFailed(bmAddress, shuffleId, mapId, reduceId),
+ "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId),
+ cause)
+
+ def this (shuffleId: Int, reduceId: Int, cause: Throwable) =
+ this(FetchFailed(null, shuffleId, -1, reduceId),
+ "Unable to fetch locations from master: %d %d".format(shuffleId, reduceId), cause)
+
+ override def getMessage(): String = message
+
override def getCause(): Throwable = cause
- def toTaskEndReason: TaskEndReason =
- FetchFailed(bmAddress, shuffleId, mapId, reduceId)
+ def toTaskEndReason: TaskEndReason = taskEndReason
+
}
diff --git a/core/src/main/scala/spark/HadoopWriter.scala b/core/src/main/scala/spark/HadoopWriter.scala
index afcf9f6db4..5e8396edb9 100644
--- a/core/src/main/scala/spark/HadoopWriter.scala
+++ b/core/src/main/scala/spark/HadoopWriter.scala
@@ -2,14 +2,10 @@ package org.apache.hadoop.mapred
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
-import org.apache.hadoop.util.ReflectionUtils
-import org.apache.hadoop.io.NullWritable
-import org.apache.hadoop.io.Text
import java.text.SimpleDateFormat
import java.text.NumberFormat
import java.io.IOException
-import java.net.URI
import java.util.Date
import spark.Logging
@@ -24,7 +20,7 @@ import spark.SerializableWritable
* a filename to write to, etc, exactly like in a Hadoop MapReduce job.
*/
class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRedUtil with Serializable {
-
+
private val now = new Date()
private val conf = new SerializableWritable(jobConf)
@@ -106,6 +102,12 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRe
}
}
+ def commitJob() {
+ // always ? Or if cmtr.needsTaskCommit ?
+ val cmtr = getOutputCommitter()
+ cmtr.commitJob(getJobContext())
+ }
+
def cleanup() {
getOutputCommitter().cleanupJob(getJobContext())
}
diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala
index 7c1c1bb144..0fc8c31463 100644
--- a/core/src/main/scala/spark/Logging.scala
+++ b/core/src/main/scala/spark/Logging.scala
@@ -68,6 +68,10 @@ trait Logging {
if (log.isErrorEnabled) log.error(msg, throwable)
}
+ protected def isTraceEnabled(): Boolean = {
+ log.isTraceEnabled
+ }
+
// Method for ensuring that logging is initialized, to avoid having multiple
// threads do it concurrently (as SLF4J initialization is not thread safe).
protected def initLogging() { log }
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index 50708d9cb1..0fc6427307 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -1,7 +1,6 @@
package spark
import java.io._
-import java.util.concurrent.ConcurrentHashMap
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.HashMap
@@ -11,6 +10,7 @@ import akka.actor._
import scala.concurrent.Await
import akka.pattern.ask
import akka.remote._
+
import scala.concurrent.duration.Duration
import akka.util.Timeout
import scala.concurrent.duration._
@@ -40,10 +40,12 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac
private[spark] class MapOutputTracker extends Logging {
+ private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+
// Set to the MapOutputTrackerActor living on the driver
var trackerActor: ActorRef = _
- var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
+ private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
@@ -52,7 +54,7 @@ private[spark] class MapOutputTracker extends Logging {
// Cache a serialized version of the output statuses for each shuffle to send them out faster
var cacheGeneration = generation
- val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
+ private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)
@@ -60,7 +62,6 @@ private[spark] class MapOutputTracker extends Logging {
// throw a SparkException if this fails.
def askTracker(message: Any): Any = {
try {
- val timeout = 10.seconds
val future = trackerActor.ask(message)(timeout)
return Await.result(future, timeout)
} catch {
@@ -77,10 +78,9 @@ private[spark] class MapOutputTracker extends Logging {
}
def registerShuffle(shuffleId: Int, numMaps: Int) {
- if (mapStatuses.get(shuffleId) != None) {
+ if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
- mapStatuses.put(shuffleId, new Array[MapStatus](numMaps))
}
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
@@ -101,8 +101,9 @@ private[spark] class MapOutputTracker extends Logging {
}
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
- var array = mapStatuses(shuffleId)
- if (array != null) {
+ var arrayOpt = mapStatuses.get(shuffleId)
+ if (arrayOpt.isDefined && arrayOpt.get != null) {
+ var array = arrayOpt.get
array.synchronized {
if (array(mapId) != null && array(mapId).location == bmAddress) {
array(mapId) = null
@@ -115,13 +116,14 @@ private[spark] class MapOutputTracker extends Logging {
}
// Remembers which map output locations are currently being fetched on a worker
- val fetching = new HashSet[Int]
+ private val fetching = new HashSet[Int]
// Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
+ var fetchedStatuses: Array[MapStatus] = null
fetching.synchronized {
if (fetching.contains(shuffleId)) {
// Someone else is fetching it; wait for them to be done
@@ -132,31 +134,48 @@ private[spark] class MapOutputTracker extends Logging {
case e: InterruptedException =>
}
}
- return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, mapStatuses(shuffleId))
- } else {
+ }
+
+ // Either while we waited the fetch happened successfully, or
+ // someone fetched it in between the get and the fetching.synchronized.
+ fetchedStatuses = mapStatuses.get(shuffleId).orNull
+ if (fetchedStatuses == null) {
+ // We have to do the fetch, get others to wait for us.
fetching += shuffleId
}
}
- // We won the race to fetch the output locs; do so
- logInfo("Doing the fetch; tracker actor = " + trackerActor)
- val host = System.getProperty("spark.hostname", Utils.localHostName)
- // This try-finally prevents hangs due to timeouts:
- var fetchedStatuses: Array[MapStatus] = null
- try {
- val fetchedBytes =
- askTracker(GetMapOutputStatuses(shuffleId, host)).asInstanceOf[Array[Byte]]
- fetchedStatuses = deserializeStatuses(fetchedBytes)
- logInfo("Got the output locations")
- mapStatuses.put(shuffleId, fetchedStatuses)
- } finally {
- fetching.synchronized {
- fetching -= shuffleId
- fetching.notifyAll()
+
+ if (fetchedStatuses == null) {
+ // We won the race to fetch the output locs; do so
+ logInfo("Doing the fetch; tracker actor = " + trackerActor)
+ val hostPort = Utils.localHostPort()
+ // This try-finally prevents hangs due to timeouts:
+ try {
+ val fetchedBytes =
+ askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
+ fetchedStatuses = deserializeStatuses(fetchedBytes)
+ logInfo("Got the output locations")
+ mapStatuses.put(shuffleId, fetchedStatuses)
+ } finally {
+ fetching.synchronized {
+ fetching -= shuffleId
+ fetching.notifyAll()
+ }
+ }
+ }
+ if (fetchedStatuses != null) {
+ fetchedStatuses.synchronized {
+ return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
}
}
- return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
+ else{
+ throw new FetchFailedException(null, shuffleId, -1, reduceId,
+ new Exception("Missing all output locations for shuffle " + shuffleId))
+ }
} else {
- return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
+ statuses.synchronized {
+ return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
+ }
}
}
@@ -194,7 +213,8 @@ private[spark] class MapOutputTracker extends Logging {
generationLock.synchronized {
if (newGen > generation) {
logInfo("Updating generation to " + newGen + " and clearing cache")
- mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
+ // mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
+ mapStatuses.clear()
generation = newGen
}
}
@@ -232,10 +252,13 @@ private[spark] class MapOutputTracker extends Logging {
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
- def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
+ private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
val out = new ByteArrayOutputStream
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
- objOut.writeObject(statuses)
+ // Since statuses can be modified in parallel, sync on it
+ statuses.synchronized {
+ objOut.writeObject(statuses)
+ }
objOut.close()
out.toByteArray
}
@@ -243,7 +266,10 @@ private[spark] class MapOutputTracker extends Logging {
// Opposite of serializeStatuses.
def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
- objIn.readObject().asInstanceOf[Array[MapStatus]]
+ objIn.readObject().
+ // // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present
+ // comment this out - nulls could be due to missing location ?
+ asInstanceOf[Array[MapStatus]] // .filter( _ != null )
}
}
@@ -253,16 +279,13 @@ private[spark] object MapOutputTracker {
// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
// any of the statuses is null (indicating a missing location due to a failed mapper),
// throw a FetchFailedException.
- def convertMapStatuses(
+ private def convertMapStatuses(
shuffleId: Int,
reduceId: Int,
statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
- if (statuses == null) {
- throw new FetchFailedException(null, shuffleId, -1, reduceId,
- new Exception("Missing all output locations for shuffle " + shuffleId))
- }
+ assert (statuses != null)
statuses.map {
- status =>
+ status =>
if (status == null) {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing an output location for shuffle " + shuffleId))
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 2052d05788..fe812fe530 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -1,5 +1,6 @@
package spark
+import java.nio.ByteBuffer
import java.util.{Date, HashMap => JHashMap}
import java.text.SimpleDateFormat
@@ -11,6 +12,8 @@ import scala.reflect.{ ClassTag, classTag}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.compress.CompressionCodec
+import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.mapred.FileOutputCommitter
import org.apache.hadoop.mapred.FileOutputFormat
import org.apache.hadoop.mapred.HadoopWriter
@@ -18,7 +21,7 @@ import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
-import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil, TaskAttemptID, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil}
import spark.partial.BoundedDouble
import spark.partial.PartialResult
@@ -53,7 +56,8 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C,
partitioner: Partitioner,
- mapSideCombine: Boolean = true): RDD[(K, C)] = {
+ mapSideCombine: Boolean = true,
+ serializerClass: String = null): RDD[(K, C)] = {
if (getKeyClass().isArray) {
if (mapSideCombine) {
throw new SparkException("Cannot use map-side combining with array keys.")
@@ -62,19 +66,18 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](
throw new SparkException("Default partitioner cannot partition array keys.")
}
}
- val aggregator =
- new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
+ val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (self.partitioner == Some(partitioner)) {
self.mapPartitions(aggregator.combineValuesByKey(_), true)
} else if (mapSideCombine) {
val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
- val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner)
+ val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner, serializerClass)
partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true)
} else {
// Don't apply map-side combiner.
// A sanity check to make sure mergeCombiners is not defined.
assert(mergeCombiners == null)
- val values = new ShuffledRDD[K, V](self, partitioner)
+ val values = new ShuffledRDD[K, V](self, partitioner, serializerClass)
values.mapPartitions(aggregator.combineValuesByKey(_), true)
}
}
@@ -95,7 +98,16 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](
* list concatenation, 0 for addition, or 1 for multiplication.).
*/
def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = {
- combineByKey[V]({v: V => func(zeroValue, v)}, func, func, partitioner)
+ // Serialize the zero value to a byte array so that we can get a new clone of it on each key
+ val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue)
+ val zeroArray = new Array[Byte](zeroBuffer.limit)
+ zeroBuffer.get(zeroArray)
+
+ // When deserializing, use a lazy val to create just one instance of the serializer per task
+ lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance()
+ def createZero() = cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray))
+
+ combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner)
}
/**
@@ -185,11 +197,13 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](
* partitioning of the resulting key-value pair RDD by passing a Partitioner.
*/
def groupByKey(partitioner: Partitioner): RDD[(K, Seq[V])] = {
+ // groupByKey shouldn't use map side combine because map side combine does not
+ // reduce the amount of data shuffled and requires all map side data be inserted
+ // into a hash table, leading to more objects in the old gen.
def createCombiner(v: V) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
- def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2
val bufs = combineByKey[ArrayBuffer[V]](
- createCombiner _, mergeValue _, mergeCombiners _, partitioner)
+ createCombiner _, mergeValue _, null, partitioner, mapSideCombine=false)
bufs.asInstanceOf[RDD[(K, Seq[V])]]
}
@@ -516,6 +530,16 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](
}
/**
+ * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
+ * supporting the key and value types K and V in this RDD. Compress the result with the
+ * supplied codec.
+ */
+ def saveAsHadoopFile[F <: OutputFormat[K, V]](
+ path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassManifest[F]) {
+ saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]], codec)
+ }
+
+ /**
* Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
* (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
*/
@@ -546,8 +570,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" <split #> <attempt # = spark task #> */
- val attemptId = new TaskAttemptID(jobtrackerID,
- stageId, false, context.splitId, attemptNumber)
+ val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber)
val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
val format = outputFormatClass.newInstance
val committer = format.getOutputCommitter(hadoopContext)
@@ -566,16 +589,31 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](
* however we're only going to use this local OutputCommitter for
* setupJob/commitJob, so we just use a dummy "map" task.
*/
- val jobAttemptId = new TaskAttemptID(jobtrackerID, stageId, true, 0, 0)
+ val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, true, 0, 0)
val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
jobCommitter.setupJob(jobTaskContext)
val count = self.context.runJob(self, writeShard _).sum
+ jobCommitter.commitJob(jobTaskContext)
jobCommitter.cleanupJob(jobTaskContext)
}
/**
* Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
+ * supporting the key and value types K and V in this RDD. Compress with the supplied codec.
+ */
+ def saveAsHadoopFile(
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[_ <: OutputFormat[_, _]],
+ codec: Class[_ <: CompressionCodec]) {
+ saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass,
+ new JobConf(self.context.hadoopConfiguration), Some(codec))
+ }
+
+ /**
+ * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
* supporting the key and value types K and V in this RDD.
*/
def saveAsHadoopFile(
@@ -583,11 +621,19 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[_ <: OutputFormat[_, _]],
- conf: JobConf = new JobConf(self.context.hadoopConfiguration)) {
+ conf: JobConf = new JobConf(self.context.hadoopConfiguration),
+ codec: Option[Class[_ <: CompressionCodec]] = None) {
conf.setOutputKeyClass(keyClass)
conf.setOutputValueClass(valueClass)
// conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug
conf.set("mapred.output.format.class", outputFormatClass.getName)
+ for (c <- codec) {
+ conf.setCompressMapOutput(true)
+ conf.set("mapred.output.compress", "true")
+ conf.setMapOutputCompressorClass(c)
+ conf.set("mapred.output.compression.codec", c.getCanonicalName)
+ conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString)
+ }
conf.setOutputCommitter(classOf[FileOutputCommitter])
FileOutputFormat.setOutputPath(conf, HadoopWriter.createPathFromString(path, conf))
saveAsHadoopDataset(conf)
@@ -638,6 +684,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](
}
self.context.runJob(self, writeToFile _)
+ writer.commitJob()
writer.cleanup()
}
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 6ee075315a..e88290fdb2 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -1,22 +1,23 @@
package spark
-import java.net.URL
-import java.util.{Date, Random}
-import java.util.{HashMap => JHashMap}
+import java.util.Random
import scala.collection.Map
import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.mutable.ArrayBuffer
+
import scala.collection.mutable.HashMap
import scala.reflect.{classTag, ClassTag}
import org.apache.hadoop.io.BytesWritable
+import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.mapred.TextOutputFormat
import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
+import spark.broadcast.Broadcast
import spark.Partitioner._
import spark.partial.BoundedDouble
import spark.partial.CountEvaluator
@@ -33,10 +34,13 @@ import spark.rdd.MapPartitionsWithIndexRDD
import spark.rdd.PipedRDD
import spark.rdd.SampledRDD
import spark.rdd.ShuffledRDD
-import spark.rdd.SubtractedRDD
import spark.rdd.UnionRDD
import spark.rdd.ZippedRDD
+import spark.rdd.ZippedPartitionsRDD2
+import spark.rdd.ZippedPartitionsRDD3
+import spark.rdd.ZippedPartitionsRDD4
import spark.storage.StorageLevel
+import spark.util.BoundedPriorityQueue
import SparkContext._
@@ -105,7 +109,7 @@ abstract class RDD[T: ClassTag](
// =======================================================================
/** A unique ID for this RDD (within its SparkContext). */
- val id = sc.newRddId()
+ val id: Int = sc.newRddId()
/** A friendly name for this RDD */
var name: String = null
@@ -116,9 +120,18 @@ abstract class RDD[T: ClassTag](
this
}
+ /** User-defined generator of this RDD*/
+ var generator = Utils.getCallSiteInfo.firstUserClass
+
+ /** Reset generator*/
+ def setGenerator(_generator: String) = {
+ generator = _generator
+ }
+
/**
* Set this RDD's storage level to persist its values across operations after the first time
- * it is computed. Can only be called once on each RDD.
+ * it is computed. This can only be used to assign a new storage level if the RDD does not
+ * have a storage level set yet..
*/
def persist(newLevel: StorageLevel): RDD[T] = {
// TODO: Handle changes of StorageLevel
@@ -138,6 +151,20 @@ abstract class RDD[T: ClassTag](
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): RDD[T] = persist()
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ *
+ * @param blocking Whether to block until all blocks are deleted.
+ * @return This RDD.
+ */
+ def unpersist(blocking: Boolean = true): RDD[T] = {
+ logInfo("Removing RDD " + id + " from persistence list")
+ sc.env.blockManager.master.removeRdd(id, blocking)
+ sc.persistentRdds.remove(id)
+ storageLevel = StorageLevel.NONE
+ this
+ }
+
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel
@@ -257,8 +284,8 @@ abstract class RDD[T: ClassTag](
def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
var fraction = 0.0
var total = 0
- var multiplier = 3.0
- var initialCount = count()
+ val multiplier = 3.0
+ val initialCount = count()
var maxSelected = 0
if (initialCount > Integer.MAX_VALUE - 1) {
@@ -339,13 +366,36 @@ abstract class RDD[T: ClassTag](
/**
* Return an RDD created by piping elements to a forked external process.
*/
- def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command)
+ def pipe(command: String, env: Map[String, String]): RDD[String] =
+ new PipedRDD(this, command, env)
+
/**
* Return an RDD created by piping elements to a forked external process.
- */
- def pipe(command: Seq[String], env: Map[String, String]): RDD[String] =
- new PipedRDD(this, command, env)
+ * The print behavior can be customized by providing two functions.
+ *
+ * @param command command to run in forked process.
+ * @param env environment variables to set.
+ * @param printPipeContext Before piping elements, this function is called as an oppotunity
+ * to pipe context data. Print line function (like out.println) will be
+ * passed as printPipeContext's parameter.
+ * @param printRDDElement Use this function to customize how to pipe elements. This function
+ * will be called with each RDD element as the 1st parameter, and the
+ * print line function (like out.println()) as the 2nd parameter.
+ * An example of pipe the RDD data of groupBy() in a streaming way,
+ * instead of constructing a huge String to concat all the elements:
+ * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) =
+ * for (e <- record._2){f(e)}
+ * @return the result RDD
+ */
+ def pipe(
+ command: Seq[String],
+ env: Map[String, String] = Map(),
+ printPipeContext: (String => Unit) => Unit = null,
+ printRDDElement: (T, String => Unit) => Unit = null): RDD[String] =
+ new PipedRDD(this, command, env,
+ if (printPipeContext ne null) sc.clean(printPipeContext) else null,
+ if (printRDDElement ne null) sc.clean(printRDDElement) else null)
/**
* Return a new RDD by applying a function to each partition of this RDD.
@@ -437,6 +487,31 @@ abstract class RDD[T: ClassTag](
*/
def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)
+ /**
+ * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
+ * applying a function to the zipped partitions. Assumes that all the RDDs have the
+ * *same number of partitions*, but does *not* require them to have the same number
+ * of elements in each partition.
+ */
+ def zipPartitions[B: ClassManifest, V: ClassManifest](
+ f: (Iterator[T], Iterator[B]) => Iterator[V],
+ rdd2: RDD[B]): RDD[V] =
+ new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2)
+
+ def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest](
+ f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V],
+ rdd2: RDD[B],
+ rdd3: RDD[C]): RDD[V] =
+ new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3)
+
+ def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest](
+ f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
+ rdd2: RDD[B],
+ rdd3: RDD[C],
+ rdd4: RDD[D]): RDD[V] =
+ new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4)
+
+
// Actions (launch a job to return a value to the user program)
/**
@@ -452,7 +527,7 @@ abstract class RDD[T: ClassTag](
*/
def foreachPartition(f: Iterator[T] => Unit) {
val cleanF = sc.clean(f)
- sc.runJob(this, (iter: Iterator[T]) => f(iter))
+ sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
}
/**
@@ -685,6 +760,24 @@ abstract class RDD[T: ClassTag](
}
/**
+ * Returns the top K elements from this RDD as defined by
+ * the specified implicit Ordering[T].
+ * @param num the number of top elements to return
+ * @param ord the implicit ordering for T
+ * @return an array of top elements
+ */
+ def top(num: Int)(implicit ord: Ordering[T]): Array[T] = {
+ mapPartitions { items =>
+ val queue = new BoundedPriorityQueue[T](num)
+ queue ++= items
+ Iterator.single(queue)
+ }.reduce { (queue1, queue2) =>
+ queue1 ++= queue2
+ queue1
+ }.toArray
+ }
+
+ /**
* Save this RDD as a text file, using string representations of elements.
*/
def saveAsTextFile(path: String) {
@@ -693,6 +786,14 @@ abstract class RDD[T: ClassTag](
}
/**
+ * Save this RDD as a compressed text file, using string representations of elements.
+ */
+ def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) {
+ this.map(x => (NullWritable.get(), new Text(x.toString)))
+ .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path, codec)
+ }
+
+ /**
* Save this RDD as a SequenceFile of serialized objects.
*/
def saveAsObjectFile(path: String) {
@@ -750,7 +851,7 @@ abstract class RDD[T: ClassTag](
private var storageLevel: StorageLevel = StorageLevel.NONE
/** Record user function generating this RDD. */
- private[spark] val origin = Utils.getSparkCallSite
+ private[spark] val origin = Utils.formatSparkCallSite
private[spark] def elementClassTag: ClassTag[T] = classTag[T]
diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala
index 083ba9b8fa..5e7bb029eb 100644
--- a/core/src/main/scala/spark/RDDCheckpointData.scala
+++ b/core/src/main/scala/spark/RDDCheckpointData.scala
@@ -3,6 +3,7 @@ package spark
import scala.reflect.ClassTag
import org.apache.hadoop.fs.Path
+import org.apache.hadoop.conf.Configuration
import rdd.{CheckpointRDD, CoalescedRDD}
@@ -66,14 +67,20 @@ private[spark] class RDDCheckpointData[T: ClassTag](rdd: RDD[T])
}
}
+ // Create the output path for the checkpoint
+ val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id)
+ val fs = path.getFileSystem(new Configuration())
+ if (!fs.mkdirs(path)) {
+ throw new SparkException("Failed to create checkpoint path " + path)
+ }
+
// Save to file, and reload it as an RDD
- val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id).toString
- rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _)
- val newRDD = new CheckpointRDD[T](rdd.context, path)
+ rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString) _)
+ val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
// Change the dependencies and partitions of the RDD
RDDCheckpointData.synchronized {
- cpFile = Some(path)
+ cpFile = Some(path.toString)
cpRDD = Some(newRDD)
rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed
diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
index 883a0152bb..edfde37da3 100644
--- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
+++ b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
@@ -19,6 +19,7 @@ import org.apache.hadoop.mapred.TextOutputFormat
import org.apache.hadoop.mapred.SequenceFileOutputFormat
import org.apache.hadoop.mapred.OutputCommitter
import org.apache.hadoop.mapred.FileOutputCommitter
+import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.io.Writable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.BytesWritable
@@ -63,7 +64,7 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag
* byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported
* file system.
*/
- def saveAsSequenceFile(path: String) {
+ def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) {
def anyToWritable[U <% Writable](u: U): Writable = u
val keyClass = getWritableClass[K]
@@ -73,14 +74,18 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag
logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," + valueClass.getSimpleName + ")" )
val format = classOf[SequenceFileOutputFormat[Writable, Writable]]
+ val jobConf = new JobConf(self.context.hadoopConfiguration)
if (!convertKey && !convertValue) {
- self.saveAsHadoopFile(path, keyClass, valueClass, format)
+ self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec)
} else if (!convertKey && convertValue) {
- self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile(path, keyClass, valueClass, format)
+ self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile(
+ path, keyClass, valueClass, format, jobConf, codec)
} else if (convertKey && !convertValue) {
- self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile(path, keyClass, valueClass, format)
+ self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile(
+ path, keyClass, valueClass, format, jobConf, codec)
} else if (convertKey && convertValue) {
- self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile(path, keyClass, valueClass, format)
+ self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile(
+ path, keyClass, valueClass, format, jobConf, codec)
}
}
}
diff --git a/core/src/main/scala/spark/ShuffleFetcher.scala b/core/src/main/scala/spark/ShuffleFetcher.scala
index 442e9f0269..9513a00126 100644
--- a/core/src/main/scala/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/spark/ShuffleFetcher.scala
@@ -1,13 +1,16 @@
package spark
-import executor.TaskMetrics
+import spark.executor.TaskMetrics
+import spark.serializer.Serializer
+
private[spark] abstract class ShuffleFetcher {
/**
* Fetch the shuffle outputs for a given ShuffleDependency.
* @return An iterator over the elements of the fetched shuffle outputs.
*/
- def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) : Iterator[(K,V)]
+ def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
+ serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[(K,V)]
/** Stop the fetcher */
def stop() {}
diff --git a/core/src/main/scala/spark/SizeEstimator.scala b/core/src/main/scala/spark/SizeEstimator.scala
index d4e1157250..f8a4c4e489 100644
--- a/core/src/main/scala/spark/SizeEstimator.scala
+++ b/core/src/main/scala/spark/SizeEstimator.scala
@@ -198,7 +198,7 @@ private[spark] object SizeEstimator extends Logging {
val elem = JArray.get(array, index)
size += SizeEstimator.estimate(elem, state.visited)
}
- state.size += ((length / 100.0) * size).toLong
+ state.size += ((length / (ARRAY_SAMPLE_SIZE * 1.0)) * size).toLong
}
}
}
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 7272a592a5..ef6de87193 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -1,48 +1,56 @@
package spark
import java.io._
-import java.util.concurrent.atomic.AtomicInteger
import java.net.URI
+import java.util.Properties
+import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.atomic.AtomicInteger
+import scala.collection.JavaConversions._
import scala.collection.Map
import scala.collection.generic.Growable
import scala.collection.mutable.HashMap
import scala.collection.JavaConversions._
+
import scala.reflect.{ ClassTag, classTag}
-import org.apache.hadoop.fs.Path
+import scala.util.DynamicVariable
+import scala.collection.mutable.{ConcurrentMap, HashMap}
+
+import akka.actor.Actor._
+
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.mapred.InputFormat
-import org.apache.hadoop.mapred.SequenceFileInputFormat
-import org.apache.hadoop.io.Writable
-import org.apache.hadoop.io.IntWritable
-import org.apache.hadoop.io.LongWritable
-import org.apache.hadoop.io.FloatWritable
-import org.apache.hadoop.io.DoubleWritable
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.ArrayWritable
import org.apache.hadoop.io.BooleanWritable
import org.apache.hadoop.io.BytesWritable
-import org.apache.hadoop.io.ArrayWritable
+import org.apache.hadoop.io.DoubleWritable
+import org.apache.hadoop.io.FloatWritable
+import org.apache.hadoop.io.IntWritable
+import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
+import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.FileInputFormat
+import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapred.SequenceFileInputFormat
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
-import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
+import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
+
import org.apache.mesos.MesosNativeLibrary
-import spark.deploy.LocalSparkCluster
-import spark.partial.ApproximateEvaluator
-import spark.partial.PartialResult
+import spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
+import spark.partial.{ApproximateEvaluator, PartialResult}
import spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD}
-import spark.scheduler._
+import spark.scheduler.{DAGScheduler, ResultTask, ShuffleMapTask, SparkListener, SplitInfo, Stage, StageInfo, TaskScheduler}
+import spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend, ClusterScheduler}
import spark.scheduler.local.LocalScheduler
-import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
-import spark.storage.BlockManagerUI
+import spark.storage.{BlockManagerUI, StorageStatus, StorageUtils, RDDInfo}
import spark.util.{MetadataCleaner, TimeStampedHashMap}
-import spark.storage.{StorageStatus, StorageUtils, RDDInfo}
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -60,7 +68,10 @@ class SparkContext(
val appName: String,
val sparkHome: String = null,
val jars: Seq[String] = Nil,
- val environment: Map[String, String] = Map())
+ val environment: Map[String, String] = Map(),
+ // This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc) too.
+ // This is typically generated from InputFormatInfo.computePreferredLocations .. host, set of data-local splits on host
+ val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map())
extends Logging {
// Ensure logging is initialized before we spawn any threads
@@ -68,7 +79,7 @@ class SparkContext(
// Set Spark driver host and port system properties
if (System.getProperty("spark.driver.host") == null) {
- System.setProperty("spark.driver.host", Utils.localIpAddress)
+ System.setProperty("spark.driver.host", Utils.localHostName())
}
if (System.getProperty("spark.driver.port") == null) {
System.setProperty("spark.driver.port", "0")
@@ -95,24 +106,29 @@ class SparkContext(
private[spark] val addedJars = HashMap[String, Long]()
// Keeps track of all persisted RDDs
- private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]()
+ private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup)
// Add each JAR given through the constructor
- jars.foreach { addJar(_) }
+ if (jars != null) {
+ jars.foreach { addJar(_) }
+ }
// Environment variables to pass to our executors
private[spark] val executorEnvs = HashMap[String, String]()
// Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner
- for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS",
- "SPARK_TESTING")) {
+ for (key <- Seq("SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", "SPARK_TESTING")) {
val value = System.getenv(key)
if (value != null) {
executorEnvs(key) = value
}
}
- executorEnvs ++= environment
+ // Since memory can be set with a system property too, use that
+ executorEnvs("SPARK_MEM") = SparkContext.executorMemoryRequested + "m"
+ if (environment != null) {
+ executorEnvs ++= environment
+ }
// Create and start the scheduler
private var taskScheduler: TaskScheduler = {
@@ -144,14 +160,12 @@ class SparkContext(
scheduler
case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
- // Check to make sure SPARK_MEM <= memoryPerSlave. Otherwise Spark will just hang.
+ // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang.
val memoryPerSlaveInt = memoryPerSlave.toInt
- val sparkMemEnv = System.getenv("SPARK_MEM")
- val sparkMemEnvInt = if (sparkMemEnv != null) Utils.memoryStringToMb(sparkMemEnv) else 512
- if (sparkMemEnvInt > memoryPerSlaveInt) {
+ if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) {
throw new SparkException(
- "Slave memory (%d MB) cannot be smaller than SPARK_MEM (%d MB)".format(
- memoryPerSlaveInt, sparkMemEnvInt))
+ "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format(
+ memoryPerSlaveInt, SparkContext.executorMemoryRequested))
}
val scheduler = new ClusterScheduler(this)
@@ -165,6 +179,22 @@ class SparkContext(
}
scheduler
+ case "yarn-standalone" =>
+ val scheduler = try {
+ val clazz = Class.forName("spark.scheduler.cluster.YarnClusterScheduler")
+ val cons = clazz.getConstructor(classOf[SparkContext])
+ cons.newInstance(this).asInstanceOf[ClusterScheduler]
+ } catch {
+ // TODO: Enumerate the exact reasons why it can fail
+ // But irrespective of it, it means we cannot proceed !
+ case th: Throwable => {
+ throw new SparkException("YARN mode not available ?", th)
+ }
+ }
+ val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem)
+ scheduler.initialize(backend)
+ scheduler
+
case _ =>
if (MESOS_REGEX.findFirstIn(master).isEmpty) {
logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
@@ -184,12 +214,12 @@ class SparkContext(
}
taskScheduler.start()
- private var dagScheduler = new DAGScheduler(taskScheduler)
+ @volatile private var dagScheduler = new DAGScheduler(taskScheduler)
dagScheduler.start()
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = {
- val conf = new Configuration()
+ val conf = SparkHadoopUtil.newConfiguration()
// Explicitly check for S3 environment variables
if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
@@ -208,6 +238,22 @@ class SparkContext(
private[spark] var checkpointDir: Option[String] = None
+ // Thread Local variable that can be used by users to pass information down the stack
+ private val localProperties = new DynamicVariable[Properties](null)
+
+ def initLocalProperties() {
+ localProperties.value = new Properties()
+ }
+
+ def addLocalProperties(key: String, value: String) {
+ if(localProperties.value == null) {
+ localProperties.value = new Properties()
+ }
+ localProperties.value.setProperty(key,value)
+ }
+ // Post init
+ taskScheduler.postStartHook()
+
// Methods for creating RDDs
/** Distribute a local Scala collection to form an RDD. */
@@ -472,7 +518,7 @@ class SparkContext(
*/
def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
- (blockManagerId.ip + ":" + blockManagerId.port, mem)
+ (blockManagerId.host + ":" + blockManagerId.port, mem)
}
}
@@ -480,7 +526,7 @@ class SparkContext(
* Return information about what RDDs are cached, if they are in mem or on disk, how much space
* they take, etc.
*/
- def getRDDStorageInfo : Array[RDDInfo] = {
+ def getRDDStorageInfo: Array[RDDInfo] = {
StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this)
}
@@ -491,7 +537,7 @@ class SparkContext(
/**
* Return information about blocks stored in all of the slaves
*/
- def getExecutorStorageStatus : Array[StorageStatus] = {
+ def getExecutorStorageStatus: Array[StorageStatus] = {
env.blockManager.master.getStorageStatus
}
@@ -509,13 +555,18 @@ class SparkContext(
* filesystems), or an HTTP, HTTPS or FTP URI.
*/
def addJar(path: String) {
- val uri = new URI(path)
- val key = uri.getScheme match {
- case null | "file" => env.httpFileServer.addJar(new File(uri.getPath))
- case _ => path
+ if (null == path) {
+ logWarning("null specified as parameter to addJar",
+ new SparkException("null specified as parameter to addJar"))
+ } else {
+ val uri = new URI(path)
+ val key = uri.getScheme match {
+ case null | "file" => env.httpFileServer.addJar(new File(uri.getPath))
+ case _ => path
+ }
+ addedJars(key) = System.currentTimeMillis
+ logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key))
}
- addedJars(key) = System.currentTimeMillis
- logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key))
}
/**
@@ -528,10 +579,13 @@ class SparkContext(
/** Shut down the SparkContext. */
def stop() {
- if (dagScheduler != null) {
+ // Do this only if not stopped already - best case effort.
+ // prevent NPE if stopped more than once.
+ val dagSchedulerCopy = dagScheduler
+ dagScheduler = null
+ if (dagSchedulerCopy != null) {
metadataCleaner.cancel()
- dagScheduler.stop()
- dagScheduler = null
+ dagSchedulerCopy.stop()
taskScheduler = null
// TODO: Cache.stop()?
env.stop()
@@ -547,6 +601,7 @@ class SparkContext(
}
}
+
/**
* Get Spark's home location from either a value set through the constructor,
* or the spark.home Java property, or the SPARK_HOME environment variable
@@ -576,10 +631,10 @@ class SparkContext(
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
- val callSite = Utils.getSparkCallSite
+ val callSite = Utils.formatSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
- val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler)
+ val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler, localProperties.value)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
result
@@ -658,12 +713,11 @@ class SparkContext(
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
- timeout: Long
- ): PartialResult[R] = {
- val callSite = Utils.getSparkCallSite
+ timeout: Long): PartialResult[R] = {
+ val callSite = Utils.formatSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
- val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout)
+ val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.value)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
result
}
@@ -686,7 +740,7 @@ class SparkContext(
*/
def setCheckpointDir(dir: String, useExisting: Boolean = false) {
val path = new Path(dir)
- val fs = path.getFileSystem(new Configuration())
+ val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration())
if (!useExisting) {
if (fs.exists(path)) {
throw new Exception("Checkpoint directory '" + path + "' already exists.")
@@ -829,6 +883,15 @@ object SparkContext {
/** Find the JAR that contains the class of a particular object */
def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass)
+
+ /** Get the amount of memory per executor requested through system properties or SPARK_MEM */
+ private[spark] val executorMemoryRequested = {
+ // TODO: Might need to add some extra memory for the non-heap parts of the JVM
+ Option(System.getProperty("spark.executor.memory"))
+ .orElse(Option(System.getenv("SPARK_MEM")))
+ .map(Utils.memoryStringToMb)
+ .getOrElse(512)
+ }
}
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index 144ddea35f..89d52064e1 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -1,14 +1,19 @@
package spark
+import collection.mutable
+import serializer.Serializer
+
import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
import akka.remote.RemoteActorRefProvider
-import serializer.Serializer
import spark.broadcast.BroadcastManager
import spark.storage.BlockManager
import spark.storage.BlockManagerMaster
import spark.network.ConnectionManager
+import spark.serializer.{Serializer, SerializerManager}
import spark.util.AkkaUtils
+import spark.api.python.PythonWorkerFactory
+
/**
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
@@ -20,6 +25,7 @@ import spark.util.AkkaUtils
class SparkEnv (
val executorId: String,
val actorSystem: ActorSystem,
+ val serializerManager: SerializerManager,
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheManager: CacheManager,
@@ -29,10 +35,16 @@ class SparkEnv (
val blockManager: BlockManager,
val connectionManager: ConnectionManager,
val httpFileServer: HttpFileServer,
- val sparkFilesDir: String
- ) {
+ val sparkFilesDir: String,
+ // To be set only as part of initialization of SparkContext.
+ // (executorId, defaultHostPort) => executorHostPort
+ // If executorId is NOT found, return defaultHostPort
+ var executorIdToHostPort: Option[(String, String) => String]) {
+
+ private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
def stop() {
+ pythonWorkers.foreach { case(key, worker) => worker.stop() }
httpFileServer.stop()
mapOutputTracker.stop()
shuffleFetcher.stop()
@@ -45,6 +57,23 @@ class SparkEnv (
// UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it.
//actorSystem.awaitTermination()
}
+
+ def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
+ synchronized {
+ val key = (pythonExec, envVars)
+ pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create()
+ }
+ }
+
+ def resolveExecutorIdToHostPort(executorId: String, defaultHostPort: String): String = {
+ val env = SparkEnv.get
+ if (env.executorIdToHostPort.isEmpty) {
+ // default to using host, not host port. Relevant to non cluster modes.
+ return defaultHostPort
+ }
+
+ env.executorIdToHostPort.get(executorId, defaultHostPort)
+ }
}
object SparkEnv extends Logging {
@@ -73,6 +102,16 @@ object SparkEnv extends Logging {
System.setProperty("spark.driver.port", boundPort.toString)
}
+ // set only if unset until now.
+ if (System.getProperty("spark.hostPort", null) == null) {
+ if (!isDriver){
+ // unexpected
+ Utils.logErrorWithStack("Unexpected NOT to have spark.hostPort set")
+ }
+ Utils.checkHost(hostname)
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+ }
+
val classLoader = Thread.currentThread.getContextClassLoader
// Create an instance of the class named by the given Java system property, or by
@@ -82,16 +121,23 @@ object SparkEnv extends Logging {
Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
}
- val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")
-
+ val serializerManager = new SerializerManager
+
+ val serializer = serializerManager.setDefault(
+ System.getProperty("spark.serializer", "spark.JavaSerializer"))
+
+ val closureSerializer = serializerManager.get(
+ System.getProperty("spark.closure.serializer", "spark.JavaSerializer"))
+
def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
if (isDriver) {
logInfo("Registering " + name)
actorSystem.actorOf(Props(newActor), name = name)
} else {
- val driverIp: String = System.getProperty("spark.driver.host", "localhost")
+ val driverHost: String = System.getProperty("spark.driver.host", "localhost")
val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
- val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, name)
+ Utils.checkHost(driverHost, "Expected hostname")
+ val url = "akka://spark@%s:%s/user/%s".format(driverHost, driverPort, name)
logInfo("Connecting to " + name + ": " + url)
actorSystem.actorFor(url)
}
@@ -106,9 +152,6 @@ object SparkEnv extends Logging {
val broadcastManager = new BroadcastManager(isDriver)
- val closureSerializer = instantiateClass[Serializer](
- "spark.closure.serializer", "spark.JavaSerializer")
-
val cacheManager = new CacheManager(blockManager)
// Have to assign trackerActor after initialization as MapOutputTrackerActor
@@ -143,6 +186,7 @@ object SparkEnv extends Logging {
new SparkEnv(
executorId,
actorSystem,
+ serializerManager,
serializer,
closureSerializer,
cacheManager,
@@ -152,7 +196,7 @@ object SparkEnv extends Logging {
blockManager,
connectionManager,
httpFileServer,
- sparkFilesDir)
+ sparkFilesDir,
+ None)
}
-
}
diff --git a/core/src/main/scala/spark/TaskEndReason.scala b/core/src/main/scala/spark/TaskEndReason.scala
index 420c54bc9a..8140cba084 100644
--- a/core/src/main/scala/spark/TaskEndReason.scala
+++ b/core/src/main/scala/spark/TaskEndReason.scala
@@ -14,9 +14,19 @@ private[spark] case object Success extends TaskEndReason
private[spark]
case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it
-private[spark]
-case class FetchFailed(bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason
+private[spark] case class FetchFailed(
+ bmAddress: BlockManagerId,
+ shuffleId: Int,
+ mapId: Int,
+ reduceId: Int)
+ extends TaskEndReason
-private[spark] case class ExceptionFailure(exception: Throwable) extends TaskEndReason
+private[spark] case class ExceptionFailure(
+ className: String,
+ description: String,
+ stackTrace: Array[StackTraceElement])
+ extends TaskEndReason
private[spark] case class OtherFailure(message: String) extends TaskEndReason
+
+private[spark] case class TaskResultTooBigFailure() extends TaskEndReason
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index cdccb8b336..e02507f83e 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -1,14 +1,16 @@
package spark
import java.io._
-import java.net._
+import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket}
import java.util.{Locale, Random, UUID}
-import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
+
+import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor}
+import java.util.regex.Pattern
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.JavaConversions._
import scala.io.Source
import scala.reflect.ClassTag
@@ -18,11 +20,14 @@ import com.google.common.io.Files
import com.google.common.util.concurrent.ThreadFactoryBuilder
import spark.serializer.SerializerInstance
+import spark.deploy.SparkHadoopUtil
+
/**
* Various utility methods used by Spark.
*/
private object Utils extends Logging {
+
/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
@@ -73,6 +78,40 @@ private object Utils extends Logging {
return buf
}
+ private val shutdownDeletePaths = new collection.mutable.HashSet[String]()
+
+ // Register the path to be deleted via shutdown hook
+ def registerShutdownDeleteDir(file: File) {
+ val absolutePath = file.getAbsolutePath()
+ shutdownDeletePaths.synchronized {
+ shutdownDeletePaths += absolutePath
+ }
+ }
+
+ // Is the path already registered to be deleted via a shutdown hook ?
+ def hasShutdownDeleteDir(file: File): Boolean = {
+ val absolutePath = file.getAbsolutePath()
+ shutdownDeletePaths.synchronized {
+ shutdownDeletePaths.contains(absolutePath)
+ }
+ }
+
+ // Note: if file is child of some registered path, while not equal to it, then return true;
+ // else false. This is to ensure that two shutdown hooks do not try to delete each others
+ // paths - resulting in IOException and incomplete cleanup.
+ def hasRootAsShutdownDeleteDir(file: File): Boolean = {
+ val absolutePath = file.getAbsolutePath()
+ val retval = shutdownDeletePaths.synchronized {
+ shutdownDeletePaths.find { path =>
+ !absolutePath.equals(path) && absolutePath.startsWith(path)
+ }.isDefined
+ }
+ if (retval) {
+ logInfo("path = " + file + ", already present as root for deletion.")
+ }
+ retval
+ }
+
/** Create a temporary directory inside the given parent directory */
def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = {
var attempts = 0
@@ -81,8 +120,8 @@ private object Utils extends Logging {
while (dir == null) {
attempts += 1
if (attempts > maxAttempts) {
- throw new IOException("Failed to create a temp directory after " + maxAttempts +
- " attempts!")
+ throw new IOException("Failed to create a temp directory (under " + root + ") after " +
+ maxAttempts + " attempts!")
}
try {
dir = new File(root, "spark-" + UUID.randomUUID.toString)
@@ -91,13 +130,17 @@ private object Utils extends Logging {
}
} catch { case e: IOException => ; }
}
+
+ registerShutdownDeleteDir(dir)
+
// Add a shutdown hook to delete the temp dir when the JVM exits
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) {
override def run() {
- Utils.deleteRecursively(dir)
+ // Attempt to delete if some patch which is parent of this is not already registered.
+ if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir)
}
})
- return dir
+ dir
}
/** Copy all data from an InputStream to an OutputStream */
@@ -140,40 +183,35 @@ private object Utils extends Logging {
Utils.copyStream(in, out, true)
if (targetFile.exists && !Files.equal(tempFile, targetFile)) {
tempFile.delete()
- throw new SparkException("File " + targetFile + " exists and does not match contents of" +
- " " + url)
+ throw new SparkException(
+ "File " + targetFile + " exists and does not match contents of" + " " + url)
} else {
Files.move(tempFile, targetFile)
}
case "file" | null =>
- val sourceFile = if (uri.isAbsolute) {
- new File(uri)
- } else {
- new File(url)
- }
- if (targetFile.exists && !Files.equal(sourceFile, targetFile)) {
- throw new SparkException("File " + targetFile + " exists and does not match contents of" +
- " " + url)
- } else {
- // Remove the file if it already exists
- targetFile.delete()
- // Symlink the file locally.
- if (uri.isAbsolute) {
- // url is absolute, i.e. it starts with "file:///". Extract the source
- // file's absolute path from the url.
- val sourceFile = new File(uri)
- logInfo("Symlinking " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath)
- FileUtil.symLink(sourceFile.getAbsolutePath, targetFile.getAbsolutePath)
+ // In the case of a local file, copy the local file to the target directory.
+ // Note the difference between uri vs url.
+ val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url)
+ if (targetFile.exists) {
+ // If the target file already exists, warn the user if
+ if (!Files.equal(sourceFile, targetFile)) {
+ throw new SparkException(
+ "File " + targetFile + " exists and does not match contents of" + " " + url)
} else {
- // url is not absolute, i.e. itself is the path to the source file.
- logInfo("Symlinking " + url + " to " + targetFile.getAbsolutePath)
- FileUtil.symLink(url, targetFile.getAbsolutePath)
+ // Do nothing if the file contents are the same, i.e. this file has been copied
+ // previously.
+ logInfo(sourceFile.getAbsolutePath + " has been previously copied to "
+ + targetFile.getAbsolutePath)
}
+ } else {
+ // The file does not exist in the target directory. Copy it there.
+ logInfo("Copying " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath)
+ Files.copy(sourceFile, targetFile)
}
case _ =>
// Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
val uri = new URI(url)
- val conf = new Configuration()
+ val conf = SparkHadoopUtil.newConfiguration()
val fs = FileSystem.get(uri, conf)
val in = fs.open(new Path(uri))
val out = new FileOutputStream(tempFile)
@@ -232,8 +270,10 @@ private object Utils extends Logging {
/**
* Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4).
+ * Note, this is typically not used from within core spark.
*/
lazy val localIpAddress: String = findLocalIpAddress()
+ lazy val localIpAddressHostname: String = getAddressHostName(localIpAddress)
private def findLocalIpAddress(): String = {
val defaultIpOverride = System.getenv("SPARK_LOCAL_IP")
@@ -271,6 +311,8 @@ private object Utils extends Logging {
* hostname it reports to the master.
*/
def setCustomHostname(hostname: String) {
+ // DEBUG code
+ Utils.checkHost(hostname)
customHostname = Some(hostname)
}
@@ -278,7 +320,91 @@ private object Utils extends Logging {
* Get the local machine's hostname.
*/
def localHostName(): String = {
- customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
+ customHostname.getOrElse(localIpAddressHostname)
+ }
+
+ def getAddressHostName(address: String): String = {
+ InetAddress.getByName(address).getHostName
+ }
+
+ def localHostPort(): String = {
+ val retval = System.getProperty("spark.hostPort", null)
+ if (retval == null) {
+ logErrorWithStack("spark.hostPort not set but invoking localHostPort")
+ return localHostName()
+ }
+
+ retval
+ }
+
+/*
+ // Used by DEBUG code : remove when all testing done
+ private val ipPattern = Pattern.compile("^[0-9]+(\\.[0-9]+)*$")
+ def checkHost(host: String, message: String = "") {
+ // Currently catches only ipv4 pattern, this is just a debugging tool - not rigourous !
+ // if (host.matches("^[0-9]+(\\.[0-9]+)*$")) {
+ if (ipPattern.matcher(host).matches()) {
+ Utils.logErrorWithStack("Unexpected to have host " + host + " which matches IP pattern. Message " + message)
+ }
+ if (Utils.parseHostPort(host)._2 != 0){
+ Utils.logErrorWithStack("Unexpected to have host " + host + " which has port in it. Message " + message)
+ }
+ }
+
+ // Used by DEBUG code : remove when all testing done
+ def checkHostPort(hostPort: String, message: String = "") {
+ val (host, port) = Utils.parseHostPort(hostPort)
+ checkHost(host)
+ if (port <= 0){
+ Utils.logErrorWithStack("Unexpected to have port " + port + " which is not valid in " + hostPort + ". Message " + message)
+ }
+ }
+
+ // Used by DEBUG code : remove when all testing done
+ def logErrorWithStack(msg: String) {
+ try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } }
+ // temp code for debug
+ System.exit(-1)
+ }
+*/
+
+ // Once testing is complete in various modes, replace with this ?
+ def checkHost(host: String, message: String = "") {}
+ def checkHostPort(hostPort: String, message: String = "") {}
+
+ // Used by DEBUG code : remove when all testing done
+ def logErrorWithStack(msg: String) {
+ try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } }
+ }
+
+ def getUserNameFromEnvironment(): String = {
+ SparkHadoopUtil.getUserNameFromEnvironment
+ }
+
+ // Typically, this will be of order of number of nodes in cluster
+ // If not, we should change it to LRUCache or something.
+ private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]()
+
+ def parseHostPort(hostPort: String): (String, Int) = {
+ {
+ // Check cache first.
+ var cached = hostPortParseResults.get(hostPort)
+ if (cached != null) return cached
+ }
+
+ val indx: Int = hostPort.lastIndexOf(':')
+ // This is potentially broken - when dealing with ipv6 addresses for example, sigh ...
+ // but then hadoop does not support ipv6 right now.
+ // For now, we assume that if port exists, then it is valid - not check if it is an int > 0
+ if (-1 == indx) {
+ val retval = (hostPort, 0)
+ hostPortParseResults.put(hostPort, retval)
+ return retval
+ }
+
+ val retval = (hostPort.substring(0, indx).trim(), hostPort.substring(indx + 1).trim().toInt)
+ hostPortParseResults.putIfAbsent(hostPort, retval)
+ hostPortParseResults.get(hostPort)
}
private[spark] val daemonThreadFactory: ThreadFactory =
@@ -400,13 +526,45 @@ private object Utils extends Logging {
execute(command, new File("."))
}
+ /**
+ * Execute a command and get its output, throwing an exception if it yields a code other than 0.
+ */
+ def executeAndGetOutput(command: Seq[String], workingDir: File = new File(".")): String = {
+ val process = new ProcessBuilder(command: _*)
+ .directory(workingDir)
+ .start()
+ new Thread("read stderr for " + command(0)) {
+ override def run() {
+ for (line <- Source.fromInputStream(process.getErrorStream).getLines) {
+ System.err.println(line)
+ }
+ }
+ }.start()
+ val output = new StringBuffer
+ val stdoutThread = new Thread("read stdout for " + command(0)) {
+ override def run() {
+ for (line <- Source.fromInputStream(process.getInputStream).getLines) {
+ output.append(line)
+ }
+ }
+ }
+ stdoutThread.start()
+ val exitCode = process.waitFor()
+ stdoutThread.join() // Wait for it to finish reading output
+ if (exitCode != 0) {
+ throw new SparkException("Process " + command + " exited with code " + exitCode)
+ }
+ output.toString
+ }
+ private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String,
+ val firstUserLine: Int, val firstUserClass: String)
/**
* When called inside a class in the spark package, returns the name of the user code class
* (outside the spark package) that called into Spark, as well as which Spark method they called.
* This is used, for example, to tell users where in their code each RDD got created.
*/
- def getSparkCallSite: String = {
+ def getCallSiteInfo: CallSiteInfo = {
val trace = Thread.currentThread.getStackTrace().filter( el =>
(!el.getMethodName.contains("getStackTrace")))
@@ -418,6 +576,7 @@ private object Utils extends Logging {
var firstUserFile = "<unknown>"
var firstUserLine = 0
var finished = false
+ var firstUserClass = "<unknown>"
for (el <- trace) {
if (!finished) {
@@ -432,13 +591,19 @@ private object Utils extends Logging {
else {
firstUserLine = el.getLineNumber
firstUserFile = el.getFileName
+ firstUserClass = el.getClassName
finished = true
}
}
}
- "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine)
+ new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass)
}
+ def formatSparkCallSite = {
+ val callSiteInfo = getCallSiteInfo
+ "%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile,
+ callSiteInfo.firstUserLine)
+ }
/**
* Try to find a free port to bind to on the local host. This should ideally never be needed,
* except that, unfortunately, some of the networking libraries we currently rely on (e.g. Spray)
@@ -480,4 +645,67 @@ private object Utils extends Logging {
}
return false
}
+
+ def isSpace(c: Char): Boolean = {
+ " \t\r\n".indexOf(c) != -1
+ }
+
+ /**
+ * Split a string of potentially quoted arguments from the command line the way that a shell
+ * would do it to determine arguments to a command. For example, if the string is 'a "b c" d',
+ * then it would be parsed as three arguments: 'a', 'b c' and 'd'.
+ */
+ def splitCommandString(s: String): Seq[String] = {
+ val buf = new ArrayBuffer[String]
+ var inWord = false
+ var inSingleQuote = false
+ var inDoubleQuote = false
+ var curWord = new StringBuilder
+ def endWord() {
+ buf += curWord.toString
+ curWord.clear()
+ }
+ var i = 0
+ while (i < s.length) {
+ var nextChar = s.charAt(i)
+ if (inDoubleQuote) {
+ if (nextChar == '"') {
+ inDoubleQuote = false
+ } else if (nextChar == '\\') {
+ if (i < s.length - 1) {
+ // Append the next character directly, because only " and \ may be escaped in
+ // double quotes after the shell's own expansion
+ curWord.append(s.charAt(i + 1))
+ i += 1
+ }
+ } else {
+ curWord.append(nextChar)
+ }
+ } else if (inSingleQuote) {
+ if (nextChar == '\'') {
+ inSingleQuote = false
+ } else {
+ curWord.append(nextChar)
+ }
+ // Backslashes are not treated specially in single quotes
+ } else if (nextChar == '"') {
+ inWord = true
+ inDoubleQuote = true
+ } else if (nextChar == '\'') {
+ inWord = true
+ inSingleQuote = true
+ } else if (!isSpace(nextChar)) {
+ curWord.append(nextChar)
+ inWord = true
+ } else if (inWord && isSpace(nextChar)) {
+ endWord()
+ inWord = false
+ }
+ i += 1
+ }
+ if (inWord || inDoubleQuote || inSingleQuote) {
+ endWord()
+ }
+ return buf
+ }
}
diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
index 89c6d05383..0fa8162f3c 100644
--- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
@@ -7,6 +7,7 @@ import scala.Tuple2
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
+import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
@@ -460,6 +461,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass)
}
+ /** Output the RDD to any Hadoop-supported file system, compressing with the supplied codec. */
+ def saveAsHadoopFile[F <: OutputFormat[_, _]](
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[F],
+ codec: Class[_ <: CompressionCodec]) {
+ rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, codec)
+ }
+
/** Output the RDD to any Hadoop-supported file system. */
def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]](
path: String,
diff --git a/core/src/main/scala/spark/api/java/JavaRDD.scala b/core/src/main/scala/spark/api/java/JavaRDD.scala
index 032506383c..6f44e018e9 100644
--- a/core/src/main/scala/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDD.scala
@@ -17,10 +17,16 @@ JavaRDDLike[T, JavaRDD[T]] {
/**
* Set this RDD's storage level to persist its values across operations after the first time
- * it is computed. Can only be called once on each RDD.
+ * it is computed. This can only be used to assign a new storage level if the RDD does not
+ * have a storage level set yet..
*/
def persist(newLevel: StorageLevel): JavaRDD[T] = wrapRDD(rdd.persist(newLevel))
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ */
+ def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist())
+
// Transformations (return a new RDD)
/**
@@ -81,7 +87,6 @@ JavaRDDLike[T, JavaRDD[T]] {
*/
def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] =
wrapRDD(rdd.subtract(other, p))
-
}
object JavaRDD {
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index a6555081b3..3fe2011f4c 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -1,10 +1,11 @@
package spark.api.java
-import java.util.{List => JList}
+import java.util.{List => JList, Comparator}
import scala.Tuple2
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
+import org.apache.hadoop.io.compress.CompressionCodec
import spark.{SparkContext, Partition, RDD, TaskContext}
import spark.api.java.JavaPairRDD._
import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
@@ -183,6 +184,21 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classTag))(classTag, other.classTag)
}
+ /**
+ * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
+ * applying a function to the zipped partitions. Assumes that all the RDDs have the
+ * *same number of partitions*, but does *not* require them to have the same number
+ * of elements in each partition.
+ */
+ def zipPartitions[U, V](
+ f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V],
+ other: JavaRDDLike[U, _]): JavaRDD[V] = {
+ def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator(
+ f.apply(asJavaIterator(x), asJavaIterator(y)).iterator())
+ JavaRDD.fromRDD(
+ rdd.zipPartitions(fn, other.rdd)(other.classTag, f.elementType()))(f.elementType())
+ }
+
// Actions (launch a job to return a value to the user program)
/**
@@ -296,6 +312,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
*/
def saveAsTextFile(path: String) = rdd.saveAsTextFile(path)
+
+ /**
+ * Save this RDD as a compressed text file, using string representations of elements.
+ */
+ def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) =
+ rdd.saveAsTextFile(path, codec)
+
/**
* Save this RDD as a SequenceFile of serialized objects.
*/
@@ -337,4 +360,29 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def toDebugString(): String = {
rdd.toDebugString
}
+
+ /**
+ * Returns the top K elements from this RDD as defined by
+ * the specified Comparator[T].
+ * @param num the number of top elements to return
+ * @param comp the comparator that defines the order
+ * @return an array of top elements
+ */
+ def top(num: Int, comp: Comparator[T]): JList[T] = {
+ import scala.collection.JavaConversions._
+ val topElems = rdd.top(num)(Ordering.comparatorToOrdering(comp))
+ val arr: java.util.Collection[T] = topElems.toSeq
+ new java.util.ArrayList(arr)
+ }
+
+ /**
+ * Returns the top K elements from this RDD using the
+ * natural ordering for T.
+ * @param num the number of top elements to return
+ * @return an array of top elements
+ */
+ def top(num: Int): JList[T] = {
+ val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]]
+ top(num, comp)
+ }
}
diff --git a/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala
new file mode 100644
index 0000000000..6044043add
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/function/FlatMapFunction2.scala
@@ -0,0 +1,11 @@
+package spark.api.java.function
+
+/**
+ * A function that takes two inputs and returns zero or more output records.
+ */
+abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] {
+ @throws(classOf[Exception])
+ def call(a: A, b:B) : java.lang.Iterable[C]
+
+ def elementType() : ClassManifest[C] = ClassManifest.Any.asInstanceOf[ClassManifest[C]]
+}
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 220047c360..3d1e45cb2c 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -2,9 +2,10 @@ package spark.api.python
import java.io._
import java.net._
-import java.util.{List => JList, ArrayList => JArrayList, Collections}
+import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
import scala.collection.JavaConversions._
+
import scala.io.Source
import scala.reflect.ClassTag
@@ -17,16 +18,18 @@ import spark.rdd.PipedRDD
private[spark] class PythonRDD[T: ClassTag](
parent: RDD[T],
command: Seq[String],
- envVars: java.util.Map[String, String],
+ envVars: JMap[String, String],
preservePartitoning: Boolean,
pythonExec: String,
broadcastVars: JList[Broadcast[Array[Byte]]],
accumulator: Accumulator[JList[Array[Byte]]])
extends RDD[Array[Byte]](parent) {
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
- def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
+ def this(parent: RDD[T], command: String, envVars: JMap[String, String],
preservePartitoning: Boolean, pythonExec: String,
broadcastVars: JList[Broadcast[Array[Byte]]],
accumulator: Accumulator[JList[Array[Byte]]]) =
@@ -37,68 +40,57 @@ private[spark] class PythonRDD[T: ClassTag](
override val partitioner = if (preservePartitoning) parent.partitioner else None
- override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
- val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
-
- val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py"))
- // Add the environmental variables to the process.
- val currentEnvVars = pb.environment()
- for ((variable, value) <- envVars) {
- currentEnvVars.put(variable, value)
- }
-
- val proc = pb.start()
+ override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
+ val startTime = System.currentTimeMillis
val env = SparkEnv.get
-
- // Start a thread to print the process's stderr to ours
- new Thread("stderr reader for " + pythonExec) {
- override def run() {
- for (line <- Source.fromInputStream(proc.getErrorStream).getLines) {
- System.err.println(line)
- }
- }
- }.start()
+ val worker = env.createPythonWorker(pythonExec, envVars.toMap)
// Start a thread to feed the process input from our parent's iterator
new Thread("stdin writer for " + pythonExec) {
override def run() {
SparkEnv.set(env)
- val out = new PrintWriter(proc.getOutputStream)
- val dOut = new DataOutputStream(proc.getOutputStream)
+ val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
+ val dataOut = new DataOutputStream(stream)
+ val printOut = new PrintWriter(stream)
// Partition index
- dOut.writeInt(split.index)
+ dataOut.writeInt(split.index)
// sparkFilesDir
- PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut)
+ PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
// Broadcast variables
- dOut.writeInt(broadcastVars.length)
+ dataOut.writeInt(broadcastVars.length)
for (broadcast <- broadcastVars) {
- dOut.writeLong(broadcast.id)
- dOut.writeInt(broadcast.value.length)
- dOut.write(broadcast.value)
- dOut.flush()
+ dataOut.writeLong(broadcast.id)
+ dataOut.writeInt(broadcast.value.length)
+ dataOut.write(broadcast.value)
}
+ dataOut.flush()
// Serialized user code
for (elem <- command) {
- out.println(elem)
+ printOut.println(elem)
}
- out.flush()
+ printOut.flush()
// Data values
for (elem <- parent.iterator(split, context)) {
- PythonRDD.writeAsPickle(elem, dOut)
+ PythonRDD.writeAsPickle(elem, dataOut)
}
- dOut.flush()
- out.flush()
- proc.getOutputStream.close()
+ dataOut.flush()
+ printOut.flush()
+ worker.shutdownOutput()
}
}.start()
// Return an iterator that read lines from the process's stdout
- val stream = new DataInputStream(proc.getInputStream)
+ val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
return new Iterator[Array[Byte]] {
def next(): Array[Byte] = {
val obj = _nextObj
- _nextObj = read()
+ if (hasNext) {
+ // FIXME: can deadlock if worker is waiting for us to
+ // respond to current message (currently irrelevant because
+ // output is shutdown before we read any input)
+ _nextObj = read()
+ }
obj
}
@@ -109,6 +101,17 @@ private[spark] class PythonRDD[T: ClassTag](
val obj = new Array[Byte](length)
stream.readFully(obj)
obj
+ case -3 =>
+ // Timing data from worker
+ val bootTime = stream.readLong()
+ val initTime = stream.readLong()
+ val finishTime = stream.readLong()
+ val boot = bootTime - startTime
+ val init = initTime - bootTime
+ val finish = finishTime - initTime
+ val total = finishTime - startTime
+ logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish))
+ read
case -2 =>
// Signals that an exception has been thrown in python
val exLength = stream.readInt()
@@ -116,23 +119,21 @@ private[spark] class PythonRDD[T: ClassTag](
stream.readFully(obj)
throw new PythonException(new String(obj))
case -1 =>
- // We've finished the data section of the output, but we can still read some
- // accumulator updates; let's do that, breaking when we get EOFException
- while (true) {
- val len2 = stream.readInt()
+ // We've finished the data section of the output, but we can still
+ // read some accumulator updates; let's do that, breaking when we
+ // get a negative length record.
+ var len2 = stream.readInt()
+ while (len2 >= 0) {
val update = new Array[Byte](len2)
stream.readFully(update)
accumulator += Collections.singletonList(update)
+ len2 = stream.readInt()
}
new Array[Byte](0)
}
} catch {
case eof: EOFException => {
- val exitStatus = proc.waitFor()
- if (exitStatus != 0) {
- throw new Exception("Subprocess exited with status " + exitStatus)
- }
- new Array[Byte](0)
+ throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
}
case e : Throwable => throw e
}
@@ -160,7 +161,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
override def compute(split: Partition, context: TaskContext) =
prev.iterator(split, context).grouped(2).map {
case Seq(a, b) => (a, b)
- case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
+ case x => throw new SparkException("PairwiseRDD: unexpected value: " + x)
}
val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
}
@@ -216,7 +217,7 @@ private[spark] object PythonRDD {
dOut.write(s)
dOut.writeByte(Pickle.STOP)
} else {
- throw new Exception("Unexpected RDD type")
+ throw new SparkException("Unexpected RDD type")
}
}
@@ -279,6 +280,10 @@ private class BytesToString extends spark.api.java.function.Function[Array[Byte]
class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
extends AccumulatorParam[JList[Array[Byte]]] {
+ Utils.checkHost(serverHost, "Expected hostname")
+
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+
override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
@@ -291,7 +296,7 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
// This happens on the master, where we pass the updates to Python through a socket
val socket = new Socket(serverHost, serverPort)
val in = socket.getInputStream
- val out = new DataOutputStream(socket.getOutputStream)
+ val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
out.writeInt(val2.size)
for (array <- val2) {
out.writeInt(array.length)
diff --git a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala
new file mode 100644
index 0000000000..85d1dfeac8
--- /dev/null
+++ b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala
@@ -0,0 +1,113 @@
+package spark.api.python
+
+import java.io.{DataInputStream, IOException}
+import java.net.{Socket, SocketException, InetAddress}
+
+import scala.collection.JavaConversions._
+
+import spark._
+
+private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
+ extends Logging {
+ var daemon: Process = null
+ val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
+ var daemonPort: Int = 0
+
+ def create(): Socket = {
+ synchronized {
+ // Start the daemon if it hasn't been started
+ startDaemon()
+
+ // Attempt to connect, restart and retry once if it fails
+ try {
+ new Socket(daemonHost, daemonPort)
+ } catch {
+ case exc: SocketException => {
+ logWarning("Python daemon unexpectedly quit, attempting to restart")
+ stopDaemon()
+ startDaemon()
+ new Socket(daemonHost, daemonPort)
+ }
+ case e => throw e
+ }
+ }
+ }
+
+ def stop() {
+ stopDaemon()
+ }
+
+ private def startDaemon() {
+ synchronized {
+ // Is it already running?
+ if (daemon != null) {
+ return
+ }
+
+ try {
+ // Create and start the daemon
+ val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
+ val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py"))
+ val workerEnv = pb.environment()
+ workerEnv.putAll(envVars)
+ daemon = pb.start()
+
+ // Redirect the stderr to ours
+ new Thread("stderr reader for " + pythonExec) {
+ override def run() {
+ scala.util.control.Exception.ignoring(classOf[IOException]) {
+ // FIXME HACK: We copy the stream on the level of bytes to
+ // attempt to dodge encoding problems.
+ val in = daemon.getErrorStream
+ var buf = new Array[Byte](1024)
+ var len = in.read(buf)
+ while (len != -1) {
+ System.err.write(buf, 0, len)
+ len = in.read(buf)
+ }
+ }
+ }
+ }.start()
+
+ val in = new DataInputStream(daemon.getInputStream)
+ daemonPort = in.readInt()
+
+ // Redirect further stdout output to our stderr
+ new Thread("stdout reader for " + pythonExec) {
+ override def run() {
+ scala.util.control.Exception.ignoring(classOf[IOException]) {
+ // FIXME HACK: We copy the stream on the level of bytes to
+ // attempt to dodge encoding problems.
+ var buf = new Array[Byte](1024)
+ var len = in.read(buf)
+ while (len != -1) {
+ System.err.write(buf, 0, len)
+ len = in.read(buf)
+ }
+ }
+ }
+ }.start()
+ } catch {
+ case e => {
+ stopDaemon()
+ throw e
+ }
+ }
+
+ // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly
+ // detect our disappearance.
+ }
+ }
+
+ private def stopDaemon() {
+ synchronized {
+ // Request shutdown of existing daemon by sending SIGTERM
+ if (daemon != null) {
+ daemon.destroy()
+ }
+
+ daemon = null
+ daemonPort = 0
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/spark/deploy/ApplicationDescription.scala
index 6659e53b25..02193c7008 100644
--- a/core/src/main/scala/spark/deploy/ApplicationDescription.scala
+++ b/core/src/main/scala/spark/deploy/ApplicationDescription.scala
@@ -2,10 +2,11 @@ package spark.deploy
private[spark] class ApplicationDescription(
val name: String,
- val cores: Int,
+ val maxCores: Int, /* Integer.MAX_VALUE denotes an unlimited number of cores */
val memoryPerSlave: Int,
val command: Command,
- val sparkHome: String)
+ val sparkHome: String,
+ val appUiUrl: String)
extends Serializable {
val user = System.getProperty("user.name", "<unknown>")
diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala
index 8a3e64e4c2..51274acb1e 100644
--- a/core/src/main/scala/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/spark/deploy/DeployMessage.scala
@@ -4,6 +4,7 @@ import spark.deploy.ExecutorState.ExecutorState
import spark.deploy.master.{WorkerInfo, ApplicationInfo}
import spark.deploy.worker.ExecutorRunner
import scala.collection.immutable.List
+import spark.Utils
private[spark] sealed trait DeployMessage extends Serializable
@@ -19,7 +20,10 @@ case class RegisterWorker(
memory: Int,
webUiPort: Int,
publicAddress: String)
- extends DeployMessage
+ extends DeployMessage {
+ Utils.checkHost(host, "Required hostname")
+ assert (port > 0)
+}
private[spark]
case class ExecutorStateChanged(
@@ -58,7 +62,9 @@ private[spark]
case class RegisteredApplication(appId: String) extends DeployMessage
private[spark]
-case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int)
+case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) {
+ Utils.checkHostPort(hostPort, "Required hostport")
+}
private[spark]
case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String],
@@ -81,6 +87,9 @@ private[spark]
case class MasterState(host: String, port: Int, workers: Array[WorkerInfo],
activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) {
+ Utils.checkHost(host, "Required hostname")
+ assert (port > 0)
+
def uri = "spark://" + host + ":" + port
}
@@ -92,4 +101,8 @@ private[spark] case object RequestWorkerState
private[spark]
case class WorkerState(host: String, port: Int, workerId: String, executors: List[ExecutorRunner],
finishedExecutors: List[ExecutorRunner], masterUrl: String, cores: Int, memory: Int,
- coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String)
+ coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) {
+
+ Utils.checkHost(host, "Required hostname")
+ assert (port > 0)
+}
diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala
index 702defb628..88b03a007c 100644
--- a/core/src/main/scala/spark/deploy/JsonProtocol.scala
+++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala
@@ -12,6 +12,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
def write(obj: WorkerInfo) = JsObject(
"id" -> JsString(obj.id),
"host" -> JsString(obj.host),
+ "port" -> JsNumber(obj.port),
"webuiaddress" -> JsString(obj.webUiAddress),
"cores" -> JsNumber(obj.cores),
"coresused" -> JsNumber(obj.coresUsed),
@@ -25,7 +26,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
"starttime" -> JsNumber(obj.startTime),
"id" -> JsString(obj.id),
"name" -> JsString(obj.desc.name),
- "cores" -> JsNumber(obj.desc.cores),
+ "cores" -> JsNumber(obj.desc.maxCores),
"user" -> JsString(obj.desc.user),
"memoryperslave" -> JsNumber(obj.desc.memoryPerSlave),
"submitdate" -> JsString(obj.submitDate.toString))
@@ -34,7 +35,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
implicit object AppDescriptionJsonFormat extends RootJsonWriter[ApplicationDescription] {
def write(obj: ApplicationDescription) = JsObject(
"name" -> JsString(obj.name),
- "cores" -> JsNumber(obj.cores),
+ "cores" -> JsNumber(obj.maxCores),
"memoryperslave" -> JsNumber(obj.memoryPerSlave),
"user" -> JsString(obj.user)
)
diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
index 6abaaeaa3f..2b0b3b10e7 100644
--- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
@@ -18,7 +18,7 @@ import scala.collection.mutable.ArrayBuffer
private[spark]
class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging {
- private val localIpAddress = Utils.localIpAddress
+ private val localHostname = Utils.localHostName()
private val masterActorSystems = ArrayBuffer[ActorSystem]()
private val workerActorSystems = ArrayBuffer[ActorSystem]()
@@ -26,13 +26,13 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
/* Start the Master */
- val (masterSystem, masterPort) = Master.startSystemAndActor(localIpAddress, 0, 0)
+ val (masterSystem, masterPort) = Master.startSystemAndActor(localHostname, 0, 0)
masterActorSystems += masterSystem
- val masterUrl = "spark://" + localIpAddress + ":" + masterPort
+ val masterUrl = "spark://" + localHostname + ":" + masterPort
/* Start the Workers */
for (workerNum <- 1 to numWorkers) {
- val (workerSystem, _) = Worker.startSystemAndActor(localIpAddress, 0, 0, coresPerWorker,
+ val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker,
memoryPerWorker, masterUrl, null, Some(workerNum))
workerActorSystems += workerSystem
}
diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala
index a38218a391..690bb20e50 100644
--- a/core/src/main/scala/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/spark/deploy/client/Client.scala
@@ -4,6 +4,7 @@ import spark.deploy._
import akka.actor._
import akka.pattern.ask
import scala.concurrent.duration._
+
import akka.pattern.AskTimeoutException
import spark.{SparkException, Logging}
import akka.remote.RemoteClientLifeCycleEvent
@@ -59,10 +60,10 @@ private[spark] class Client(
markDisconnected()
context.stop(self)
- case ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) =>
+ case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) =>
val fullId = appId + "/" + id
- logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, host, cores))
- listener.executorAdded(fullId, workerId, host, cores, memory)
+ logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores))
+ listener.executorAdded(fullId, workerId, hostPort, cores, memory)
case ExecutorUpdated(id, state, message, exitStatus) =>
val fullId = appId + "/" + id
@@ -112,7 +113,7 @@ private[spark] class Client(
def stop() {
if (actor != null) {
try {
- val timeout = 5.seconds
+ val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
val future = actor.ask(StopClient)(timeout)
Await.result(future, timeout)
} catch {
diff --git a/core/src/main/scala/spark/deploy/client/ClientListener.scala b/core/src/main/scala/spark/deploy/client/ClientListener.scala
index b7008321df..e8c4083f9d 100644
--- a/core/src/main/scala/spark/deploy/client/ClientListener.scala
+++ b/core/src/main/scala/spark/deploy/client/ClientListener.scala
@@ -12,7 +12,7 @@ private[spark] trait ClientListener {
def disconnected(): Unit
- def executorAdded(fullId: String, workerId: String, host: String, cores: Int, memory: Int): Unit
+ def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit
def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit
}
diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala
index dc004b59ca..f195082808 100644
--- a/core/src/main/scala/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/spark/deploy/client/TestClient.scala
@@ -16,7 +16,7 @@ private[spark] object TestClient {
System.exit(0)
}
- def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) {}
+ def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {}
def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {}
}
@@ -25,7 +25,7 @@ private[spark] object TestClient {
val url = args(0)
val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0)
val desc = new ApplicationDescription(
- "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home")
+ "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "ignored")
val listener = new TestListener
val client = new Client(actorSystem, url, desc, listener)
client.start()
diff --git a/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala
index 3591a94072..785c16e2be 100644
--- a/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala
@@ -10,7 +10,8 @@ private[spark] class ApplicationInfo(
val id: String,
val desc: ApplicationDescription,
val submitDate: Date,
- val driver: ActorRef)
+ val driver: ActorRef,
+ val appUiUrl: String)
{
var state = ApplicationState.WAITING
var executors = new mutable.HashMap[Int, ExecutorInfo]
@@ -37,7 +38,7 @@ private[spark] class ApplicationInfo(
coresGranted -= exec.cores
}
- def coresLeft: Int = desc.cores - coresGranted
+ def coresLeft: Int = desc.maxCores - coresGranted
private var _retryCount = 0
@@ -60,4 +61,5 @@ private[spark] class ApplicationInfo(
System.currentTimeMillis() - startTime
}
}
+
}
diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala
index d1428bcfc6..770cfe9d05 100644
--- a/core/src/main/scala/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/spark/deploy/master/Master.scala
@@ -15,7 +15,7 @@ import spark.{Logging, SparkException, Utils}
import spark.util.AkkaUtils
-private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
+private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000
@@ -35,9 +35,11 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
var firstApp: Option[ApplicationInfo] = None
+ Utils.checkHost(host, "Expected hostname")
+
val masterPublicAddress = {
val envVar = System.getenv("SPARK_PUBLIC_DNS")
- if (envVar != null) envVar else ip
+ if (envVar != null) envVar else host
}
// As a temporary workaround before better ways of configuring memory, we allow users to set
@@ -46,7 +48,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean
override def preStart() {
- logInfo("Starting Spark master at spark://" + ip + ":" + port)
+ logInfo("Starting Spark master at spark://" + host + ":" + port)
// Listen for remote client disconnection events, since they don't go through Akka's watch()
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
startWebUi()
@@ -146,7 +148,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
}
case RequestMasterState => {
- sender ! MasterState(ip, port, workers.toArray, apps.toArray, completedApps.toArray)
+ sender ! MasterState(host, port, workers.toArray, apps.toArray, completedApps.toArray)
}
}
@@ -212,13 +214,13 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
worker.addExecutor(exec)
worker.actor ! LaunchExecutor(exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome)
- exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory)
+ exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)
}
def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int,
publicAddress: String): WorkerInfo = {
// There may be one or more refs to dead workers on this same node (w/ different ID's), remove them.
- workers.filter(w => (w.host == host) && (w.state == WorkerState.DEAD)).foreach(workers -= _)
+ workers.filter(w => (w.host == host && w.port == port) && (w.state == WorkerState.DEAD)).foreach(workers -= _)
val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
workers += worker
idToWorker(worker.id) = worker
@@ -243,7 +245,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
def addApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = {
val now = System.currentTimeMillis()
val date = new Date(now)
- val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver)
+ val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl)
apps += app
idToApp(app.id) = app
actorToApp(driver) = app
@@ -274,6 +276,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
for (exec <- app.executors.values) {
exec.worker.removeExecutor(exec)
exec.worker.actor ! KillExecutor(exec.application.id, exec.id)
+ exec.state = ExecutorState.KILLED
}
app.markFinished(state)
app.driver ! ApplicationRemoved(state.toString)
@@ -308,7 +311,7 @@ private[spark] object Master {
def main(argStrings: Array[String]) {
val args = new MasterArguments(argStrings)
- val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort)
+ val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort)
actorSystem.awaitTermination()
}
diff --git a/core/src/main/scala/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/spark/deploy/master/MasterArguments.scala
index 4ceab3fc03..3d28ecabb4 100644
--- a/core/src/main/scala/spark/deploy/master/MasterArguments.scala
+++ b/core/src/main/scala/spark/deploy/master/MasterArguments.scala
@@ -7,13 +7,13 @@ import spark.Utils
* Command-line parser for the master.
*/
private[spark] class MasterArguments(args: Array[String]) {
- var ip = Utils.localHostName()
+ var host = Utils.localHostName()
var port = 7077
var webUiPort = 8080
// Check for settings in environment variables
- if (System.getenv("SPARK_MASTER_IP") != null) {
- ip = System.getenv("SPARK_MASTER_IP")
+ if (System.getenv("SPARK_MASTER_HOST") != null) {
+ host = System.getenv("SPARK_MASTER_HOST")
}
if (System.getenv("SPARK_MASTER_PORT") != null) {
port = System.getenv("SPARK_MASTER_PORT").toInt
@@ -26,7 +26,13 @@ private[spark] class MasterArguments(args: Array[String]) {
def parse(args: List[String]): Unit = args match {
case ("--ip" | "-i") :: value :: tail =>
- ip = value
+ Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
+ host = value
+ parse(tail)
+
+ case ("--host" | "-h") :: value :: tail =>
+ Utils.checkHost(value, "Please use hostname " + value)
+ host = value
parse(tail)
case ("--port" | "-p") :: IntParam(value) :: tail =>
@@ -54,7 +60,8 @@ private[spark] class MasterArguments(args: Array[String]) {
"Usage: Master [options]\n" +
"\n" +
"Options:\n" +
- " -i IP, --ip IP IP address or DNS name to listen on\n" +
+ " -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) \n" +
+ " -h HOST, --host HOST Hostname to listen on\n" +
" -p PORT, --port PORT Port to listen on (default: 7077)\n" +
" --webui-port PORT Port for web UI (default: 8080)")
System.exit(exitCode)
diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
index fe859d48c3..34cee87853 100644
--- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
+++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
@@ -3,6 +3,7 @@ package spark.deploy.master
import akka.actor.{ActorRef, ActorContext, ActorRefFactory}
import scala.concurrent.Await
import akka.pattern.ask
+
import akka.util.Timeout
import scala.concurrent.duration._
import spray.routing.Directives
@@ -25,8 +26,7 @@ class MasterWebUI(master: ActorRef)(implicit val context: ActorContext) extends
val RESOURCE_DIR = "spark/deploy/master/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static"
- implicit val timeout = Timeout(10 seconds)
-
+ implicit val timeout = Timeout(Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds"))
val handler = {
get {
diff --git a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
index 23df1bb463..0c08c5f417 100644
--- a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
@@ -2,6 +2,7 @@ package spark.deploy.master
import akka.actor.ActorRef
import scala.collection.mutable
+import spark.Utils
private[spark] class WorkerInfo(
val id: String,
@@ -13,6 +14,9 @@ private[spark] class WorkerInfo(
val webUiPort: Int,
val publicAddress: String) {
+ Utils.checkHost(host, "Expected hostname")
+ assert (port > 0)
+
var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info
var state: WorkerState.Value = WorkerState.ALIVE
var coresUsed = 0
@@ -23,6 +27,11 @@ private[spark] class WorkerInfo(
def coresFree: Int = cores - coresUsed
def memoryFree: Int = memory - memoryUsed
+ def hostPort: String = {
+ assert (port > 0)
+ host + ":" + port
+ }
+
def addExecutor(exec: ExecutorInfo) {
executors(exec.fullId) = exec
coresUsed += exec.cores
diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
index de11771c8e..d7f58b2cb1 100644
--- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
@@ -1,6 +1,7 @@
package spark.deploy.worker
import java.io._
+import java.lang.System.getenv
import spark.deploy.{ExecutorState, ExecutorStateChanged, ApplicationDescription}
import akka.actor.ActorRef
import spark.{Utils, Logging}
@@ -21,11 +22,13 @@ private[spark] class ExecutorRunner(
val memory: Int,
val worker: ActorRef,
val workerId: String,
- val hostname: String,
+ val hostPort: String,
val sparkHome: File,
val workDir: File)
extends Logging {
+ Utils.checkHostPort(hostPort, "Expected hostport")
+
val fullId = appId + "/" + execId
var workerThread: Thread = null
var process: Process = null
@@ -38,7 +41,7 @@ private[spark] class ExecutorRunner(
workerThread.start()
// Shutdown hook that kills actors on shutdown.
- shutdownHook = new Thread() {
+ shutdownHook = new Thread() {
override def run() {
if (process != null) {
logInfo("Shutdown hook killing child process.")
@@ -68,16 +71,36 @@ private[spark] class ExecutorRunner(
/** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */
def substituteVariables(argument: String): String = argument match {
case "{{EXECUTOR_ID}}" => execId.toString
- case "{{HOSTNAME}}" => hostname
+ case "{{HOSTNAME}}" => Utils.parseHostPort(hostPort)._1
case "{{CORES}}" => cores.toString
case other => other
}
def buildCommandSeq(): Seq[String] = {
val command = appDesc.command
- val script = if (System.getProperty("os.name").startsWith("Windows")) "run.cmd" else "run"
- val runScript = new File(sparkHome, script).getCanonicalPath
- Seq(runScript, command.mainClass) ++ (command.arguments ++ Seq(appId)).map(substituteVariables)
+ val runner = Option(getenv("JAVA_HOME")).map(_ + "/bin/java").getOrElse("java")
+ // SPARK-698: do not call the run.cmd script, as process.destroy()
+ // fails to kill a process tree on Windows
+ Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++
+ command.arguments.map(substituteVariables)
+ }
+
+ /**
+ * Attention: this must always be aligned with the environment variables in the run scripts and
+ * the way the JAVA_OPTS are assembled there.
+ */
+ def buildJavaOpts(): Seq[String] = {
+ val libraryOpts = Option(getenv("SPARK_LIBRARY_PATH"))
+ .map(p => List("-Djava.library.path=" + p))
+ .getOrElse(Nil)
+ val userOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString).getOrElse(Nil)
+ val memoryOpts = Seq("-Xms" + memory + "M", "-Xmx" + memory + "M")
+
+ // Figure out our classpath with the external compute-classpath script
+ val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh"
+ val classPath = Utils.executeAndGetOutput(Seq(sparkHome + "/bin/compute-classpath" + ext))
+
+ Seq("-cp", classPath) ++ libraryOpts ++ userOpts ++ memoryOpts
}
/** Spawn a thread that will redirect a given stream to a file */
@@ -113,7 +136,6 @@ private[spark] class ExecutorRunner(
for ((key, value) <- appDesc.command.environment) {
env.put(key, value)
}
- env.put("SPARK_MEM", memory.toString + "m")
// In case we are running this from within the Spark Shell, avoid creating a "scala"
// parent process for the executor command
env.put("SPARK_LAUNCH_WITH_SCALA", "0")
diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala
index 5bcf00443c..b5dfd16e67 100644
--- a/core/src/main/scala/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/spark/deploy/worker/Worker.scala
@@ -17,7 +17,7 @@ import java.io.File
private[spark] class Worker(
- ip: String,
+ host: String,
port: Int,
webUiPort: Int,
cores: Int,
@@ -26,6 +26,9 @@ private[spark] class Worker(
workDirPath: String = null)
extends Actor with Logging {
+ Utils.checkHost(host, "Expected hostname")
+ assert (port > 0)
+
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs
// Send a heartbeat every (heartbeat timeout) / 4 milliseconds
@@ -40,7 +43,7 @@ private[spark] class Worker(
val finishedExecutors = new HashMap[String, ExecutorRunner]
val publicAddress = {
val envVar = System.getenv("SPARK_PUBLIC_DNS")
- if (envVar != null) envVar else ip
+ if (envVar != null) envVar else host
}
var coresUsed = 0
@@ -52,10 +55,14 @@ private[spark] class Worker(
def createWorkDir() {
workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work"))
try {
- if (!workDir.exists() && !workDir.mkdirs()) {
+ // This sporadically fails - not sure why ... !workDir.exists() && !workDir.mkdirs()
+ // So attempting to create and then check if directory was created or not.
+ workDir.mkdirs()
+ if ( !workDir.exists() || !workDir.isDirectory) {
logError("Failed to create work directory " + workDir)
System.exit(1)
}
+ assert (workDir.isDirectory)
} catch {
case e: Exception =>
logError("Failed to create work directory " + workDir, e)
@@ -65,7 +72,7 @@ private[spark] class Worker(
override def preStart() {
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
- ip, port, cores, Utils.memoryMegabytesToString(memory)))
+ host, port, cores, Utils.memoryMegabytesToString(memory)))
sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse("."))
logInfo("Spark home: " + sparkHome)
createWorkDir()
@@ -76,7 +83,7 @@ private[spark] class Worker(
def connectToMaster() {
logInfo("Connecting to master " + masterUrl)
master = context.actorFor(Master.toAkkaUrl(masterUrl))
- master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
+ master ! RegisterWorker(workerId, host, port, cores, memory, webUiPort, publicAddress)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.watch(master) // Doesn't work with remote actors, but useful for testing
}
@@ -108,7 +115,7 @@ private[spark] class Worker(
case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) =>
logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
val manager = new ExecutorRunner(
- appId, execId, appDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir)
+ appId, execId, appDesc, cores_, memory_, self, workerId, host + ":" + port, new File(execSparkHome_), workDir)
executors(appId + "/" + execId) = manager
manager.start()
coresUsed += cores_
@@ -143,7 +150,7 @@ private[spark] class Worker(
masterDisconnected()
case RequestWorkerState => {
- sender ! WorkerState(ip, port, workerId, executors.values.toList,
+ sender ! WorkerState(host, port, workerId, executors.values.toList,
finishedExecutors.values.toList, masterUrl, cores, memory,
coresUsed, memoryUsed, masterWebUiUrl)
}
@@ -158,7 +165,7 @@ private[spark] class Worker(
}
def generateWorkerId(): String = {
- "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), ip, port)
+ "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), host, port)
}
override def postStop() {
@@ -169,7 +176,7 @@ private[spark] class Worker(
private[spark] object Worker {
def main(argStrings: Array[String]) {
val args = new WorkerArguments(argStrings)
- val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort, args.cores,
+ val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores,
args.memory, args.master, args.workDir)
actorSystem.awaitTermination()
}
diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
index 08f02bad80..2b96611ee3 100644
--- a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
@@ -9,7 +9,7 @@ import java.lang.management.ManagementFactory
* Command-line parser for the master.
*/
private[spark] class WorkerArguments(args: Array[String]) {
- var ip = Utils.localHostName()
+ var host = Utils.localHostName()
var port = 0
var webUiPort = 8081
var cores = inferDefaultCores()
@@ -38,7 +38,13 @@ private[spark] class WorkerArguments(args: Array[String]) {
def parse(args: List[String]): Unit = args match {
case ("--ip" | "-i") :: value :: tail =>
- ip = value
+ Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
+ host = value
+ parse(tail)
+
+ case ("--host" | "-h") :: value :: tail =>
+ Utils.checkHost(value, "Please use hostname " + value)
+ host = value
parse(tail)
case ("--port" | "-p") :: IntParam(value) :: tail =>
@@ -93,7 +99,8 @@ private[spark] class WorkerArguments(args: Array[String]) {
" -c CORES, --cores CORES Number of cores to use\n" +
" -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G)\n" +
" -d DIR, --work-dir DIR Directory to run apps in (default: SPARK_HOME/work)\n" +
- " -i IP, --ip IP IP address or DNS name to listen on\n" +
+ " -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h)\n" +
+ " -h HOST, --host HOST Hostname to listen on\n" +
" -p PORT, --port PORT Port to listen on (default: random)\n" +
" --webui-port PORT Port for web UI (default: 8081)")
System.exit(exitCode)
diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
index 33a2a9516e..cc2ab6187a 100644
--- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
+++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
@@ -3,6 +3,7 @@ package spark.deploy.worker
import akka.actor.{ActorRef, ActorContext}
import scala.concurrent.Await
import akka.pattern.ask
+
import akka.util.Timeout
import scala.concurrent.duration._
import spray.routing.Directives
@@ -25,7 +26,7 @@ class WorkerWebUI(worker: ActorRef, workDir: File)(implicit val context: ActorCo
val RESOURCE_DIR = "spark/deploy/worker/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static"
- implicit val timeout = Timeout(10 seconds)
+ implicit val timeout = Timeout(Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds"))
val handler = {
get {
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index 3e7407b58d..2bf55ea9a9 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -17,7 +17,7 @@ import java.nio.ByteBuffer
* The Mesos executor for Spark.
*/
private[spark] class Executor(executorId: String, slaveHostname: String, properties: Seq[(String, String)]) extends Logging {
-
+
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
@@ -27,6 +27,11 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
initLogging()
+ // No ip or host:port - just hostname
+ Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname")
+ // must not have port specified.
+ assert (0 == Utils.parseHostPort(slaveHostname)._2)
+
// Make sure the local hostname we report matches the cluster scheduler's name for this host
Utils.setCustomHostname(slaveHostname)
@@ -37,7 +42,8 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
// Create our ClassLoader and set it on this thread
private val urlClassLoader = createClassLoader()
- Thread.currentThread.setContextClassLoader(urlClassLoader)
+ private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
+ Thread.currentThread.setContextClassLoader(replClassLoader)
// Make any thread terminations due to uncaught exceptions kill the entire
// executor process to avoid surprising stalls.
@@ -67,6 +73,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
// Initialize Spark environment (using system properties read above)
val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false)
SparkEnv.set(env)
+ private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size")
// Start worker thread pool
val threadPool = new ThreadPoolExecutor(
@@ -82,7 +89,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
override def run() {
val startTime = System.currentTimeMillis()
SparkEnv.set(env)
- Thread.currentThread.setContextClassLoader(urlClassLoader)
+ Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo("Running task ID " + taskId)
context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
@@ -98,6 +105,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
val value = task.run(taskId.toInt)
val taskFinish = System.currentTimeMillis()
task.metrics.foreach{ m =>
+ m.hostname = Utils.localHostName
m.executorDeserializeTime = (taskStart - startTime).toInt
m.executorRunTime = (taskFinish - taskStart).toInt
}
@@ -108,6 +116,10 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
val result = new TaskResult(value, accumUpdates, task.metrics.getOrElse(null))
val serializedResult = ser.serialize(result)
logInfo("Serialized size of result for " + taskId + " is " + serializedResult.limit)
+ if (serializedResult.limit >= (akkaFrameSize - 1024)) {
+ context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(TaskResultTooBigFailure()))
+ return
+ }
context.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
logInfo("Finished task ID " + taskId)
} catch {
@@ -117,7 +129,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
}
case t: Throwable => {
- val reason = ExceptionFailure(t)
+ val reason = ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace)
context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
// TODO: Should we exit the whole executor here? On the one hand, the failed task may
@@ -142,26 +154,31 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
val urls = currentJars.keySet.map { uri =>
new File(uri.split("/").last).toURI.toURL
}.toArray
- loader = new URLClassLoader(urls, loader)
+ new ExecutorURLClassLoader(urls, loader)
+ }
- // If the REPL is in use, add another ClassLoader that will read
- // new classes defined by the REPL as the user types code
+ /**
+ * If the REPL is in use, add another ClassLoader that will read
+ * new classes defined by the REPL as the user types code
+ */
+ private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = {
val classUri = System.getProperty("spark.repl.class.uri")
if (classUri != null) {
logInfo("Using REPL class URI: " + classUri)
- loader = {
- try {
- val klass = Class.forName("spark.repl.ExecutorClassLoader")
- .asInstanceOf[Class[_ <: ClassLoader]]
- val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
- constructor.newInstance(classUri, loader)
- } catch {
- case _: ClassNotFoundException => loader
- }
+ try {
+ val klass = Class.forName("spark.repl.ExecutorClassLoader")
+ .asInstanceOf[Class[_ <: ClassLoader]]
+ val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
+ return constructor.newInstance(classUri, parent)
+ } catch {
+ case _: ClassNotFoundException =>
+ logError("Could not find spark.repl.ExecutorClassLoader on classpath!")
+ System.exit(1)
+ null
}
+ } else {
+ return parent
}
-
- return new ExecutorURLClassLoader(Array(), loader)
}
/**
diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
index 1047f71c6a..ebe2ac68d8 100644
--- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
+++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
@@ -12,23 +12,27 @@ import spark.scheduler.cluster.RegisteredExecutor
import spark.scheduler.cluster.LaunchTask
import spark.scheduler.cluster.RegisterExecutorFailed
import spark.scheduler.cluster.RegisterExecutor
+import spark.Utils
+import spark.deploy.SparkHadoopUtil
private[spark] class StandaloneExecutorBackend(
driverUrl: String,
executorId: String,
- hostname: String,
+ hostPort: String,
cores: Int)
extends Actor
with ExecutorBackend
with Logging {
+ Utils.checkHostPort(hostPort, "Expected hostport")
+
var executor: Executor = null
var driver: ActorRef = null
override def preStart() {
logInfo("Connecting to driver: " + driverUrl)
driver = context.actorFor(driverUrl)
- driver ! RegisterExecutor(executorId, hostname, cores)
+ driver ! RegisterExecutor(executorId, hostPort, cores)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.watch(driver) // Doesn't work with remote actors, but useful for testing
}
@@ -36,7 +40,8 @@ private[spark] class StandaloneExecutorBackend(
override def receive = {
case RegisteredExecutor(sparkProperties) =>
logInfo("Successfully registered with driver")
- executor = new Executor(executorId, hostname, sparkProperties)
+ // Make this host instead of hostPort ?
+ executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties)
case RegisterExecutorFailed(message) =>
logError("Slave registration failed: " + message)
@@ -63,11 +68,30 @@ private[spark] class StandaloneExecutorBackend(
private[spark] object StandaloneExecutorBackend {
def run(driverUrl: String, executorId: String, hostname: String, cores: Int) {
+ SparkHadoopUtil.runAsUser(run0, Tuple4[Any, Any, Any, Any] (driverUrl, executorId, hostname, cores))
+ }
+
+ // This will be run 'as' the user
+ def run0(args: Product) {
+ assert(4 == args.productArity)
+ runImpl(args.productElement(0).asInstanceOf[String],
+ args.productElement(1).asInstanceOf[String],
+ args.productElement(2).asInstanceOf[String],
+ args.productElement(3).asInstanceOf[Int])
+ }
+
+ private def runImpl(driverUrl: String, executorId: String, hostname: String, cores: Int) {
+ // Debug code
+ Utils.checkHost(hostname)
+
// Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
// before getting started with all our system properties, etc
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0)
+ // set it
+ val sparkHostPort = hostname + ":" + boundPort
+ System.setProperty("spark.hostPort", sparkHostPort)
val actor = actorSystem.actorOf(
- Props(new StandaloneExecutorBackend(driverUrl, executorId, hostname, cores)),
+ Props(new StandaloneExecutorBackend(driverUrl, executorId, sparkHostPort, cores)),
name = "Executor")
actorSystem.awaitTermination()
}
diff --git a/core/src/main/scala/spark/executor/TaskMetrics.scala b/core/src/main/scala/spark/executor/TaskMetrics.scala
index 93bbb6b458..1dc13754f9 100644
--- a/core/src/main/scala/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/spark/executor/TaskMetrics.scala
@@ -2,6 +2,11 @@ package spark.executor
class TaskMetrics extends Serializable {
/**
+ * Host's name the task runs on
+ */
+ var hostname: String = _
+
+ /**
* Time taken on the executor to deserialize this task
*/
var executorDeserializeTime: Int = _
@@ -34,9 +39,14 @@ object TaskMetrics {
class ShuffleReadMetrics extends Serializable {
/**
+ * Time when shuffle finishs
+ */
+ var shuffleFinishTime: Long = _
+
+ /**
* Total number of blocks fetched in a shuffle (remote or local)
*/
- var totalBlocksFetched : Int = _
+ var totalBlocksFetched: Int = _
/**
* Number of remote blocks fetched in a shuffle
@@ -49,11 +59,6 @@ class ShuffleReadMetrics extends Serializable {
var localBlocksFetched: Int = _
/**
- * Total time to read shuffle data
- */
- var shuffleReadMillis: Long = _
-
- /**
* Total time that is spent blocked waiting for shuffle to fetch data
*/
var fetchWaitTime: Long = _
diff --git a/core/src/main/scala/spark/network/BufferMessage.scala b/core/src/main/scala/spark/network/BufferMessage.scala
new file mode 100644
index 0000000000..7b0e489a6c
--- /dev/null
+++ b/core/src/main/scala/spark/network/BufferMessage.scala
@@ -0,0 +1,94 @@
+package spark.network
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+import spark.storage.BlockManager
+
+
+private[spark]
+class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
+ extends Message(Message.BUFFER_MESSAGE, id_) {
+
+ val initialSize = currentSize()
+ var gotChunkForSendingOnce = false
+
+ def size = initialSize
+
+ def currentSize() = {
+ if (buffers == null || buffers.isEmpty) {
+ 0
+ } else {
+ buffers.map(_.remaining).reduceLeft(_ + _)
+ }
+ }
+
+ def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
+ if (maxChunkSize <= 0) {
+ throw new Exception("Max chunk size is " + maxChunkSize)
+ }
+
+ if (size == 0 && gotChunkForSendingOnce == false) {
+ val newChunk = new MessageChunk(
+ new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
+ gotChunkForSendingOnce = true
+ return Some(newChunk)
+ }
+
+ while(!buffers.isEmpty) {
+ val buffer = buffers(0)
+ if (buffer.remaining == 0) {
+ BlockManager.dispose(buffer)
+ buffers -= buffer
+ } else {
+ val newBuffer = if (buffer.remaining <= maxChunkSize) {
+ buffer.duplicate()
+ } else {
+ buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
+ }
+ buffer.position(buffer.position + newBuffer.remaining)
+ val newChunk = new MessageChunk(new MessageChunkHeader(
+ typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ gotChunkForSendingOnce = true
+ return Some(newChunk)
+ }
+ }
+ None
+ }
+
+ def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
+ // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
+ if (buffers.size > 1) {
+ throw new Exception("Attempting to get chunk from message with multiple data buffers")
+ }
+ val buffer = buffers(0)
+ if (buffer.remaining > 0) {
+ if (buffer.remaining < chunkSize) {
+ throw new Exception("Not enough space in data buffer for receiving chunk")
+ }
+ val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
+ buffer.position(buffer.position + newBuffer.remaining)
+ val newChunk = new MessageChunk(new MessageChunkHeader(
+ typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ return Some(newChunk)
+ }
+ None
+ }
+
+ def flip() {
+ buffers.foreach(_.flip)
+ }
+
+ def hasAckId() = (ackId != 0)
+
+ def isCompletelyReceived() = !buffers(0).hasRemaining
+
+ override def toString = {
+ if (hasAckId) {
+ "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
+ } else {
+ "BufferMessage(id = " + id + ", size = " + size + ")"
+ }
+ }
+} \ No newline at end of file
diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala
index d1451bc212..6e28f677a3 100644
--- a/core/src/main/scala/spark/network/Connection.scala
+++ b/core/src/main/scala/spark/network/Connection.scala
@@ -13,12 +13,13 @@ import java.net._
private[spark]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
- val remoteConnectionManagerId: ConnectionManagerId) extends Logging {
+ val socketRemoteConnectionManagerId: ConnectionManagerId)
+ extends Logging {
+
def this(channel_ : SocketChannel, selector_ : Selector) = {
this(channel_, selector_,
- ConnectionManagerId.fromSocketAddress(
- channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
- ))
+ ConnectionManagerId.fromSocketAddress(
+ channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]))
}
channel.configureBlocking(false)
@@ -33,16 +34,47 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
val remoteAddress = getRemoteAddress()
+ // Read channels typically do not register for write and write does not for read
+ // Now, we do have write registering for read too (temporarily), but this is to detect
+ // channel close NOT to actually read/consume data on it !
+ // How does this work if/when we move to SSL ?
+
+ // What is the interest to register with selector for when we want this connection to be selected
+ def registerInterest()
+
+ // What is the interest to register with selector for when we want this connection to
+ // be de-selected
+ // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack,
+ // it will be SelectionKey.OP_READ (until we fix it properly)
+ def unregisterInterest()
+
+ // On receiving a read event, should we change the interest for this channel or not ?
+ // Will be true for ReceivingConnection, false for SendingConnection.
+ def changeInterestForRead(): Boolean
+
+ // On receiving a write event, should we change the interest for this channel or not ?
+ // Will be false for ReceivingConnection, true for SendingConnection.
+ // Actually, for now, should not get triggered for ReceivingConnection
+ def changeInterestForWrite(): Boolean
+
+ def getRemoteConnectionManagerId(): ConnectionManagerId = {
+ socketRemoteConnectionManagerId
+ }
+
def key() = channel.keyFor(selector)
def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
- def read() {
- throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString)
+ // Returns whether we have to register for further reads or not.
+ def read(): Boolean = {
+ throw new UnsupportedOperationException(
+ "Cannot read on connection of type " + this.getClass.toString)
}
-
- def write() {
- throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString)
+
+ // Returns whether we have to register for further writes or not.
+ def write(): Boolean = {
+ throw new UnsupportedOperationException(
+ "Cannot write on connection of type " + this.getClass.toString)
}
def close() {
@@ -54,26 +86,32 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
callOnCloseCallback()
}
- def onClose(callback: Connection => Unit) {onCloseCallback = callback}
+ def onClose(callback: Connection => Unit) {
+ onCloseCallback = callback
+ }
- def onException(callback: (Connection, Exception) => Unit) {onExceptionCallback = callback}
+ def onException(callback: (Connection, Exception) => Unit) {
+ onExceptionCallback = callback
+ }
- def onKeyInterestChange(callback: (Connection, Int) => Unit) {onKeyInterestChangeCallback = callback}
+ def onKeyInterestChange(callback: (Connection, Int) => Unit) {
+ onKeyInterestChangeCallback = callback
+ }
def callOnExceptionCallback(e: Exception) {
if (onExceptionCallback != null) {
onExceptionCallback(this, e)
} else {
- logError("Error in connection to " + remoteConnectionManagerId +
+ logError("Error in connection to " + getRemoteConnectionManagerId() +
" and OnExceptionCallback not registered", e)
}
}
-
+
def callOnCloseCallback() {
if (onCloseCallback != null) {
onCloseCallback(this)
} else {
- logWarning("Connection to " + remoteConnectionManagerId +
+ logWarning("Connection to " + getRemoteConnectionManagerId() +
" closed and OnExceptionCallback not registered")
}
@@ -81,7 +119,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
def changeConnectionKeyInterest(ops: Int) {
if (onKeyInterestChangeCallback != null) {
- onKeyInterestChangeCallback(this, ops)
+ onKeyInterestChangeCallback(this, ops)
} else {
throw new Exception("OnKeyInterestChangeCallback not registered")
}
@@ -105,24 +143,25 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
print(" (" + position + ", " + length + ")")
buffer.position(curPosition)
}
-
}
-private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
- remoteId_ : ConnectionManagerId)
-extends Connection(SocketChannel.open, selector_, remoteId_) {
+private[spark]
+class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
+ remoteId_ : ConnectionManagerId)
+ extends Connection(SocketChannel.open, selector_, remoteId_) {
class Outbox(fair: Int = 0) {
val messages = new Queue[Message]()
- val defaultChunkSize = 65536 //32768 //16384
+ val defaultChunkSize = 65536 //32768 //16384
var nextMessageToBeUsed = 0
def addMessage(message: Message) {
- messages.synchronized{
+ messages.synchronized{
/*messages += message*/
messages.enqueue(message)
- logDebug("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]")
+ logDebug("Added [" + message + "] to outbox for sending to " +
+ "[" + getRemoteConnectionManagerId() + "]")
}
}
@@ -147,18 +186,18 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
message.started = true
message.startTime = System.currentTimeMillis
}
- return chunk
+ return chunk
} else {
- /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
+ /*logInfo("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "]")*/
message.finishTime = System.currentTimeMillis
- logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId +
+ logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
"] in " + message.timeTaken )
}
}
}
None
}
-
+
private def getChunkRR(): Option[MessageChunk] = {
messages.synchronized {
while (!messages.isEmpty) {
@@ -170,15 +209,17 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
messages.enqueue(message)
nextMessageToBeUsed = nextMessageToBeUsed + 1
if (!message.started) {
- logDebug("Starting to send [" + message + "] to [" + remoteConnectionManagerId + "]")
+ logDebug(
+ "Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]")
message.started = true
message.startTime = System.currentTimeMillis
}
- logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]")
- return chunk
+ logTrace(
+ "Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]")
+ return chunk
} else {
message.finishTime = System.currentTimeMillis
- logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId +
+ logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
"] in " + message.timeTaken )
}
}
@@ -186,27 +227,40 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
None
}
}
-
- val outbox = new Outbox(1)
+
+ private val outbox = new Outbox(1)
val currentBuffers = new ArrayBuffer[ByteBuffer]()
/*channel.socket.setSendBufferSize(256 * 1024)*/
- override def getRemoteAddress() = address
+ override def getRemoteAddress() = address
+
+ val DEFAULT_INTEREST = SelectionKey.OP_READ
+
+ override def registerInterest() {
+ // Registering read too - does not really help in most cases, but for some
+ // it does - so let us keep it for now.
+ changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST)
+ }
+
+ override def unregisterInterest() {
+ changeConnectionKeyInterest(DEFAULT_INTEREST)
+ }
def send(message: Message) {
outbox.synchronized {
outbox.addMessage(message)
if (channel.isConnected) {
- changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ)
+ registerInterest()
}
}
}
+ // MUST be called within the selector loop
def connect() {
try{
- channel.connect(address)
channel.register(selector, SelectionKey.OP_CONNECT)
+ channel.connect(address)
logInfo("Initiating connection to [" + address + "]")
} catch {
case e: Exception => {
@@ -216,36 +270,52 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
}
}
- def finishConnect() {
+ def finishConnect(force: Boolean): Boolean = {
try {
- channel.finishConnect
- changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ)
+ // Typically, this should finish immediately since it was triggered by a connect
+ // selection - though need not necessarily always complete successfully.
+ val connected = channel.finishConnect
+ if (!force && !connected) {
+ logInfo(
+ "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending")
+ return false
+ }
+
+ // Fallback to previous behavior - assume finishConnect completed
+ // This will happen only when finishConnect failed for some repeated number of times
+ // (10 or so)
+ // Is highly unlikely unless there was an unclean close of socket, etc
+ registerInterest()
logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
+ return true
} catch {
case e: Exception => {
logWarning("Error finishing connection to " + address, e)
callOnExceptionCallback(e)
+ // ignore
+ return true
}
}
}
- override def write() {
- try{
- while(true) {
+ override def write(): Boolean = {
+ try {
+ while (true) {
if (currentBuffers.size == 0) {
outbox.synchronized {
outbox.getChunk() match {
case Some(chunk) => {
- currentBuffers ++= chunk.buffers
+ currentBuffers ++= chunk.buffers
}
case None => {
- changeConnectionKeyInterest(SelectionKey.OP_READ)
- return
+ // changeConnectionKeyInterest(0)
+ /*key.interestOps(0)*/
+ return false
}
}
}
}
-
+
if (currentBuffers.size > 0) {
val buffer = currentBuffers(0)
val remainingBytes = buffer.remaining
@@ -254,69 +324,109 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
currentBuffers -= buffer
}
if (writtenBytes < remainingBytes) {
- return
+ // re-register for write.
+ return true
}
}
}
} catch {
- case e: Exception => {
- logWarning("Error writing in connection to " + remoteConnectionManagerId, e)
+ case e: Exception => {
+ logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
+ return false
}
}
+ // should not happen - to keep scala compiler happy
+ return true
}
- override def read() {
+ // This is a hack to determine if remote socket was closed or not.
+ // SendingConnection DOES NOT expect to receive any data - if it does, it is an error
+ // For a bunch of cases, read will return -1 in case remote socket is closed : hence we
+ // register for reads to determine that.
+ override def read(): Boolean = {
// We don't expect the other side to send anything; so, we just read to detect an error or EOF.
try {
val length = channel.read(ByteBuffer.allocate(1))
if (length == -1) { // EOF
close()
} else if (length > 0) {
- logWarning("Unexpected data read from SendingConnection to " + remoteConnectionManagerId)
+ logWarning(
+ "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId())
}
} catch {
case e: Exception =>
- logError("Exception while reading SendingConnection to " + remoteConnectionManagerId, e)
+ logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
}
+
+ false
}
+
+ override def changeInterestForRead(): Boolean = false
+
+ override def changeInterestForWrite(): Boolean = true
}
-private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
-extends Connection(channel_, selector_) {
-
+// Must be created within selector loop - else deadlock
+private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
+ extends Connection(channel_, selector_) {
+
class Inbox() {
val messages = new HashMap[Int, BufferMessage]()
-
+
def getChunk(header: MessageChunkHeader): Option[MessageChunk] = {
-
+
def createNewMessage: BufferMessage = {
val newMessage = Message.create(header).asInstanceOf[BufferMessage]
newMessage.started = true
newMessage.startTime = System.currentTimeMillis
- logDebug("Starting to receive [" + newMessage + "] from [" + remoteConnectionManagerId + "]")
+ logDebug(
+ "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]")
messages += ((newMessage.id, newMessage))
newMessage
}
-
+
val message = messages.getOrElseUpdate(header.id, createNewMessage)
- logTrace("Receiving chunk of [" + message + "] from [" + remoteConnectionManagerId + "]")
+ logTrace(
+ "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]")
message.getChunkForReceiving(header.chunkSize)
}
-
+
def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = {
- messages.get(chunk.header.id)
+ messages.get(chunk.header.id)
}
def removeMessage(message: Message) {
messages -= message.id
}
}
-
+
+ @volatile private var inferredRemoteManagerId: ConnectionManagerId = null
+
+ override def getRemoteConnectionManagerId(): ConnectionManagerId = {
+ val currId = inferredRemoteManagerId
+ if (currId != null) currId else super.getRemoteConnectionManagerId()
+ }
+
+ // The reciever's remote address is the local socket on remote side : which is NOT
+ // the connection manager id of the receiver.
+ // We infer that from the messages we receive on the receiver socket.
+ private def processConnectionManagerId(header: MessageChunkHeader) {
+ val currId = inferredRemoteManagerId
+ if (header.address == null || currId != null) return
+
+ val managerId = ConnectionManagerId.fromSocketAddress(header.address)
+
+ if (managerId != null) {
+ inferredRemoteManagerId = managerId
+ }
+ }
+
+
val inbox = new Inbox()
val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
var onReceiveCallback: (Connection , Message) => Unit = null
@@ -324,24 +434,29 @@ extends Connection(channel_, selector_) {
channel.register(selector, SelectionKey.OP_READ)
- override def read() {
+ override def read(): Boolean = {
try {
while (true) {
if (currentChunk == null) {
val headerBytesRead = channel.read(headerBuffer)
if (headerBytesRead == -1) {
close()
- return
+ return false
}
if (headerBuffer.remaining > 0) {
- return
+ // re-register for read event ...
+ return true
}
headerBuffer.flip
if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) {
- throw new Exception("Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
+ throw new Exception(
+ "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
}
val header = MessageChunkHeader.create(headerBuffer)
headerBuffer.clear()
+
+ processConnectionManagerId(header)
+
header.typ match {
case Message.BUFFER_MESSAGE => {
if (header.totalSize == 0) {
@@ -349,7 +464,8 @@ extends Connection(channel_, selector_) {
onReceiveCallback(this, Message.create(header))
}
currentChunk = null
- return
+ // re-register for read event ...
+ return true
} else {
currentChunk = inbox.getChunk(header).orNull
}
@@ -357,26 +473,28 @@ extends Connection(channel_, selector_) {
case _ => throw new Exception("Message of unknown type received")
}
}
-
+
if (currentChunk == null) throw new Exception("No message chunk to receive data")
-
+
val bytesRead = channel.read(currentChunk.buffer)
if (bytesRead == 0) {
- return
+ // re-register for read event ...
+ return true
} else if (bytesRead == -1) {
close()
- return
+ return false
}
/*logDebug("Read " + bytesRead + " bytes for the buffer")*/
-
+
if (currentChunk.buffer.remaining == 0) {
/*println("Filled buffer at " + System.currentTimeMillis)*/
val bufferMessage = inbox.getMessageForChunk(currentChunk).get
if (bufferMessage.isCompletelyReceived) {
bufferMessage.flip
bufferMessage.finishTime = System.currentTimeMillis
- logDebug("Finished receiving [" + bufferMessage + "] from [" + remoteConnectionManagerId + "] in " + bufferMessage.timeTaken)
+ logDebug("Finished receiving [" + bufferMessage + "] from " +
+ "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken)
if (onReceiveCallback != null) {
onReceiveCallback(this, bufferMessage)
}
@@ -386,13 +504,32 @@ extends Connection(channel_, selector_) {
}
}
} catch {
- case e: Exception => {
- logWarning("Error reading from connection to " + remoteConnectionManagerId, e)
+ case e: Exception => {
+ logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
+ return false
}
}
+ // should not happen - to keep scala compiler happy
+ return true
}
-
+
def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback}
+
+ override def changeInterestForRead(): Boolean = true
+
+ override def changeInterestForWrite(): Boolean = {
+ throw new IllegalStateException("Unexpected invocation right now")
+ }
+
+ override def registerInterest() {
+ // Registering read too - does not really help in most cases, but for some
+ // it does - so let us keep it for now.
+ changeConnectionKeyInterest(SelectionKey.OP_READ)
+ }
+
+ override def unregisterInterest() {
+ changeConnectionKeyInterest(0)
+ }
}
diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala
index 8f8892b8c7..cc5c62a542 100644
--- a/core/src/main/scala/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/spark/network/ConnectionManager.scala
@@ -6,28 +6,19 @@ import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
import java.net._
-import java.util.concurrent.Executors
+import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor}
+import scala.collection.mutable.HashSet
import scala.collection.mutable.HashMap
import scala.collection.mutable.SynchronizedMap
import scala.collection.mutable.SynchronizedQueue
-import scala.collection.mutable.Queue
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{Await, Promise, ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.concurrent.duration._
-private[spark] case class ConnectionManagerId(host: String, port: Int) {
- def toSocketAddress() = new InetSocketAddress(host, port)
-}
-private[spark] object ConnectionManagerId {
- def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
- new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
- }
-}
-
private[spark] class ConnectionManager(port: Int) extends Logging {
class MessageStatus(
@@ -41,73 +32,263 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
def markDone() { completionHandler(this) }
}
-
- val selector = SelectorProvider.provider.openSelector()
- val handleMessageExecutor = Executors.newFixedThreadPool(System.getProperty("spark.core.connection.handler.threads","20").toInt)
- val serverChannel = ServerSocketChannel.open()
- val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
- val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
- val messageStatuses = new HashMap[Int, MessageStatus]
- val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
- val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
- val sendMessageRequests = new Queue[(Message, SendingConnection)]
+
+ private val selector = SelectorProvider.provider.openSelector()
+
+ private val handleMessageExecutor = new ThreadPoolExecutor(
+ System.getProperty("spark.core.connection.handler.threads.min","20").toInt,
+ System.getProperty("spark.core.connection.handler.threads.max","60").toInt,
+ System.getProperty("spark.core.connection.handler.threads.keepalive","60").toInt, TimeUnit.SECONDS,
+ new LinkedBlockingDeque[Runnable]())
+
+ private val handleReadWriteExecutor = new ThreadPoolExecutor(
+ System.getProperty("spark.core.connection.io.threads.min","4").toInt,
+ System.getProperty("spark.core.connection.io.threads.max","32").toInt,
+ System.getProperty("spark.core.connection.io.threads.keepalive","60").toInt, TimeUnit.SECONDS,
+ new LinkedBlockingDeque[Runnable]())
+
+ // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : which should be executed asap
+ private val handleConnectExecutor = new ThreadPoolExecutor(
+ System.getProperty("spark.core.connection.connect.threads.min","1").toInt,
+ System.getProperty("spark.core.connection.connect.threads.max","8").toInt,
+ System.getProperty("spark.core.connection.connect.threads.keepalive","60").toInt, TimeUnit.SECONDS,
+ new LinkedBlockingDeque[Runnable]())
+
+ private val serverChannel = ServerSocketChannel.open()
+ private val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
+ private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
+ private val messageStatuses = new HashMap[Int, MessageStatus]
+ private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
+ private val registerRequests = new SynchronizedQueue[SendingConnection]
implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())
- var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
+ private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
serverChannel.configureBlocking(false)
serverChannel.socket.setReuseAddress(true)
- serverChannel.socket.setReceiveBufferSize(256 * 1024)
+ serverChannel.socket.setReceiveBufferSize(256 * 1024)
serverChannel.socket.bind(new InetSocketAddress(port))
serverChannel.register(selector, SelectionKey.OP_ACCEPT)
val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
-
- val selectorThread = new Thread("connection-manager-thread") {
+
+ private val selectorThread = new Thread("connection-manager-thread") {
override def run() = ConnectionManager.this.run()
}
selectorThread.setDaemon(true)
selectorThread.start()
- private def run() {
- try {
- while(!selectorThread.isInterrupted) {
- for ((connectionManagerId, sendingConnection) <- connectionRequests) {
- sendingConnection.connect()
- addConnection(sendingConnection)
- connectionRequests -= connectionManagerId
+ private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
+
+ private def triggerWrite(key: SelectionKey) {
+ val conn = connectionsByKey.getOrElse(key, null)
+ if (conn == null) return
+
+ writeRunnableStarted.synchronized {
+ // So that we do not trigger more write events while processing this one.
+ // The write method will re-register when done.
+ if (conn.changeInterestForWrite()) conn.unregisterInterest()
+ if (writeRunnableStarted.contains(key)) {
+ // key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE)
+ return
+ }
+
+ writeRunnableStarted += key
+ }
+ handleReadWriteExecutor.execute(new Runnable {
+ override def run() {
+ var register: Boolean = false
+ try {
+ register = conn.write()
+ } finally {
+ writeRunnableStarted.synchronized {
+ writeRunnableStarted -= key
+ if (register && conn.changeInterestForWrite()) {
+ conn.registerInterest()
+ }
+ }
}
- sendMessageRequests.synchronized {
- while (!sendMessageRequests.isEmpty) {
- val (message, connection) = sendMessageRequests.dequeue
- connection.send(message)
+ }
+ } )
+ }
+
+ private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
+
+ private def triggerRead(key: SelectionKey) {
+ val conn = connectionsByKey.getOrElse(key, null)
+ if (conn == null) return
+
+ readRunnableStarted.synchronized {
+ // So that we do not trigger more read events while processing this one.
+ // The read method will re-register when done.
+ if (conn.changeInterestForRead())conn.unregisterInterest()
+ if (readRunnableStarted.contains(key)) {
+ return
+ }
+
+ readRunnableStarted += key
+ }
+ handleReadWriteExecutor.execute(new Runnable {
+ override def run() {
+ var register: Boolean = false
+ try {
+ register = conn.read()
+ } finally {
+ readRunnableStarted.synchronized {
+ readRunnableStarted -= key
+ if (register && conn.changeInterestForRead()) {
+ conn.registerInterest()
+ }
}
}
+ }
+ } )
+ }
+
+ private def triggerConnect(key: SelectionKey) {
+ val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection]
+ if (conn == null) return
+
+ // prevent other events from being triggered
+ // Since we are still trying to connect, we do not need to do the additional steps in triggerWrite
+ conn.changeConnectionKeyInterest(0)
+
+ handleConnectExecutor.execute(new Runnable {
+ override def run() {
+
+ var tries: Int = 10
+ while (tries >= 0) {
+ if (conn.finishConnect(false)) return
+ // Sleep ?
+ Thread.sleep(1)
+ tries -= 1
+ }
+
+ // fallback to previous behavior : we should not really come here since this method was
+ // triggered since channel became connectable : but at times, the first finishConnect need not
+ // succeed : hence the loop to retry a few 'times'.
+ conn.finishConnect(true)
+ }
+ } )
+ }
+
+ // MUST be called within selector loop - else deadlock.
+ private def triggerForceCloseByException(key: SelectionKey, e: Exception) {
+ try {
+ key.interestOps(0)
+ } catch {
+ // ignore exceptions
+ case e: Exception => logDebug("Ignoring exception", e)
+ }
+
+ val conn = connectionsByKey.getOrElse(key, null)
+ if (conn == null) return
+
+ // Pushing to connect threadpool
+ handleConnectExecutor.execute(new Runnable {
+ override def run() {
+ try {
+ conn.callOnExceptionCallback(e)
+ } catch {
+ // ignore exceptions
+ case e: Exception => logDebug("Ignoring exception", e)
+ }
+ try {
+ conn.close()
+ } catch {
+ // ignore exceptions
+ case e: Exception => logDebug("Ignoring exception", e)
+ }
+ }
+ })
+ }
+
- while (!keyInterestChangeRequests.isEmpty) {
+ def run() {
+ try {
+ while(!selectorThread.isInterrupted) {
+ while (! registerRequests.isEmpty) {
+ val conn: SendingConnection = registerRequests.dequeue
+ addListeners(conn)
+ conn.connect()
+ addConnection(conn)
+ }
+
+ while(!keyInterestChangeRequests.isEmpty) {
val (key, ops) = keyInterestChangeRequests.dequeue
- val connection = connectionsByKey(key)
- val lastOps = key.interestOps()
- key.interestOps(ops)
-
- def intToOpStr(op: Int): String = {
- val opStrs = ArrayBuffer[String]()
- if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
- if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
- if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
- if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
- if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
+
+ try {
+ if (key.isValid) {
+ val connection = connectionsByKey.getOrElse(key, null)
+ if (connection != null) {
+ val lastOps = key.interestOps()
+ key.interestOps(ops)
+
+ // hot loop - prevent materialization of string if trace not enabled.
+ if (isTraceEnabled()) {
+ def intToOpStr(op: Int): String = {
+ val opStrs = ArrayBuffer[String]()
+ if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
+ if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
+ if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
+ if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
+ if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
+ }
+
+ logTrace("Changed key for connection to [" + connection.getRemoteConnectionManagerId() +
+ "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
+ }
+ }
+ } else {
+ logInfo("Key not valid ? " + key)
+ throw new CancelledKeyException()
+ }
+ } catch {
+ case e: CancelledKeyException => {
+ logInfo("key already cancelled ? " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ case e: Exception => {
+ logError("Exception processing key " + key, e)
+ triggerForceCloseByException(key, e)
+ }
}
-
- logTrace("Changed key for connection to [" + connection.remoteConnectionManagerId +
- "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
-
}
- val selectedKeysCount = selector.select()
+ val selectedKeysCount =
+ try {
+ selector.select()
+ } catch {
+ // Explicitly only dealing with CancelledKeyException here since other exceptions should be dealt with differently.
+ case e: CancelledKeyException => {
+ // Some keys within the selectors list are invalid/closed. clear them.
+ val allKeys = selector.keys().iterator()
+
+ while (allKeys.hasNext()) {
+ val key = allKeys.next()
+ try {
+ if (! key.isValid) {
+ logInfo("Key not valid ? " + key)
+ throw new CancelledKeyException()
+ }
+ } catch {
+ case e: CancelledKeyException => {
+ logInfo("key already cancelled ? " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ case e: Exception => {
+ logError("Exception processing key " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ }
+ }
+ }
+ 0
+ }
+
if (selectedKeysCount == 0) {
logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys")
}
@@ -115,20 +296,40 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
logInfo("Selector thread was interrupted!")
return
}
-
- val selectedKeys = selector.selectedKeys().iterator()
- while (selectedKeys.hasNext()) {
- val key = selectedKeys.next
- selectedKeys.remove()
- if (key.isValid) {
- if (key.isAcceptable) {
- acceptConnection(key)
- } else if (key.isConnectable) {
- connectionsByKey(key).asInstanceOf[SendingConnection].finishConnect()
- } else if (key.isReadable) {
- connectionsByKey(key).read()
- } else if (key.isWritable) {
- connectionsByKey(key).write()
+
+ if (0 != selectedKeysCount) {
+ val selectedKeys = selector.selectedKeys().iterator()
+ while (selectedKeys.hasNext()) {
+ val key = selectedKeys.next
+ selectedKeys.remove()
+ try {
+ if (key.isValid) {
+ if (key.isAcceptable) {
+ acceptConnection(key)
+ } else
+ if (key.isConnectable) {
+ triggerConnect(key)
+ } else
+ if (key.isReadable) {
+ triggerRead(key)
+ } else
+ if (key.isWritable) {
+ triggerWrite(key)
+ }
+ } else {
+ logInfo("Key not valid ? " + key)
+ throw new CancelledKeyException()
+ }
+ } catch {
+ // weird, but we saw this happening - even though key.isValid was true, key.isAcceptable would throw CancelledKeyException.
+ case e: CancelledKeyException => {
+ logInfo("key already cancelled ? " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ case e: Exception => {
+ logError("Exception processing key " + key, e)
+ triggerForceCloseByException(key, e)
+ }
}
}
}
@@ -137,97 +338,119 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
case e: Exception => logError("Error in select loop", e)
}
}
-
- private def acceptConnection(key: SelectionKey) {
+
+ def acceptConnection(key: SelectionKey) {
val serverChannel = key.channel.asInstanceOf[ServerSocketChannel]
- val newChannel = serverChannel.accept()
- val newConnection = new ReceivingConnection(newChannel, selector)
- newConnection.onReceive(receiveMessage)
- newConnection.onClose(removeConnection)
- addConnection(newConnection)
- logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
- }
- private def addConnection(connection: Connection) {
- connectionsByKey += ((connection.key, connection))
- if (connection.isInstanceOf[SendingConnection]) {
- val sendingConnection = connection.asInstanceOf[SendingConnection]
- connectionsById += ((sendingConnection.remoteConnectionManagerId, sendingConnection))
+ var newChannel = serverChannel.accept()
+
+ // accept them all in a tight loop. non blocking accept with no processing, should be fine
+ while (newChannel != null) {
+ try {
+ val newConnection = new ReceivingConnection(newChannel, selector)
+ newConnection.onReceive(receiveMessage)
+ addListeners(newConnection)
+ addConnection(newConnection)
+ logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
+ } catch {
+ // might happen in case of issues with registering with selector
+ case e: Exception => logError("Error in accept loop", e)
+ }
+
+ newChannel = serverChannel.accept()
}
+ }
+
+ private def addListeners(connection: Connection) {
connection.onKeyInterestChange(changeConnectionKeyInterest)
connection.onException(handleConnectionError)
connection.onClose(removeConnection)
}
- private def removeConnection(connection: Connection) {
+ def addConnection(connection: Connection) {
+ connectionsByKey += ((connection.key, connection))
+ }
+
+ def removeConnection(connection: Connection) {
connectionsByKey -= connection.key
- if (connection.isInstanceOf[SendingConnection]) {
- val sendingConnection = connection.asInstanceOf[SendingConnection]
- val sendingConnectionManagerId = sendingConnection.remoteConnectionManagerId
- logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
-
- connectionsById -= sendingConnectionManagerId
-
- messageStatuses.synchronized {
- messageStatuses
- .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
- logInfo("Notifying " + status)
- status.synchronized {
- status.attempted = true
- status.acked = false
- status.markDone()
- }
+
+ try {
+ if (connection.isInstanceOf[SendingConnection]) {
+ val sendingConnection = connection.asInstanceOf[SendingConnection]
+ val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
+ logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
+
+ connectionsById -= sendingConnectionManagerId
+
+ messageStatuses.synchronized {
+ messageStatuses
+ .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
+ logInfo("Notifying " + status)
+ status.synchronized {
+ status.attempted = true
+ status.acked = false
+ status.markDone()
+ }
+ })
+
+ messageStatuses.retain((i, status) => {
+ status.connectionManagerId != sendingConnectionManagerId
})
+ }
+ } else if (connection.isInstanceOf[ReceivingConnection]) {
+ val receivingConnection = connection.asInstanceOf[ReceivingConnection]
+ val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId()
+ logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
+
+ val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId)
+ if (! sendingConnectionOpt.isDefined) {
+ logError("Corresponding SendingConnectionManagerId not found")
+ return
+ }
- messageStatuses.retain((i, status) => {
- status.connectionManagerId != sendingConnectionManagerId
- })
- }
- } else if (connection.isInstanceOf[ReceivingConnection]) {
- val receivingConnection = connection.asInstanceOf[ReceivingConnection]
- val remoteConnectionManagerId = receivingConnection.remoteConnectionManagerId
- logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
-
- val sendingConnectionManagerId = connectionsById.keys.find(_.host == remoteConnectionManagerId.host).orNull
- if (sendingConnectionManagerId == null) {
- logError("Corresponding SendingConnectionManagerId not found")
- return
- }
- logInfo("Corresponding SendingConnectionManagerId is " + sendingConnectionManagerId)
-
- val sendingConnection = connectionsById(sendingConnectionManagerId)
- sendingConnection.close()
- connectionsById -= sendingConnectionManagerId
-
- messageStatuses.synchronized {
- for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) {
- logInfo("Notifying " + s)
- s.synchronized {
- s.attempted = true
- s.acked = false
- s.markDone()
+ val sendingConnection = sendingConnectionOpt.get
+ connectionsById -= remoteConnectionManagerId
+ sendingConnection.close()
+
+ val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
+
+ assert (sendingConnectionManagerId == remoteConnectionManagerId)
+
+ messageStatuses.synchronized {
+ for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) {
+ logInfo("Notifying " + s)
+ s.synchronized {
+ s.attempted = true
+ s.acked = false
+ s.markDone()
+ }
}
- }
- messageStatuses.retain((i, status) => {
- status.connectionManagerId != sendingConnectionManagerId
- })
+ messageStatuses.retain((i, status) => {
+ status.connectionManagerId != sendingConnectionManagerId
+ })
+ }
}
+ } finally {
+ // So that the selection keys can be removed.
+ wakeupSelector()
}
}
- private def handleConnectionError(connection: Connection, e: Exception) {
- logInfo("Handling connection error on connection to " + connection.remoteConnectionManagerId)
+ def handleConnectionError(connection: Connection, e: Exception) {
+ logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId())
removeConnection(connection)
}
- private def changeConnectionKeyInterest(connection: Connection, ops: Int) {
- keyInterestChangeRequests += ((connection.key, ops))
+ def changeConnectionKeyInterest(connection: Connection, ops: Int) {
+ keyInterestChangeRequests += ((connection.key, ops))
+ // so that registerations happen !
+ wakeupSelector()
}
- private def receiveMessage(connection: Connection, message: Message) {
+ def receiveMessage(connection: Connection, message: Message) {
val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
- logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
+ logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
val runnable = new Runnable() {
val creationTime = System.currentTimeMillis
def run() {
@@ -247,11 +470,11 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
if (bufferMessage.hasAckId) {
val sentMessageStatus = messageStatuses.synchronized {
messageStatuses.get(bufferMessage.ackId) match {
- case Some(status) => {
- messageStatuses -= bufferMessage.ackId
+ case Some(status) => {
+ messageStatuses -= bufferMessage.ackId
status
}
- case None => {
+ case None => {
throw new Exception("Could not find reference for received ack message " + message.id)
null
}
@@ -271,7 +494,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
logDebug("Not calling back as callback is null")
None
}
-
+
if (ackMessage.isDefined) {
if (!ackMessage.get.isInstanceOf[BufferMessage]) {
logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass())
@@ -281,7 +504,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
}
}
- sendMessage(connectionManagerId, ackMessage.getOrElse {
+ sendMessage(connectionManagerId, ackMessage.getOrElse {
Message.createBufferMessage(bufferMessage.id)
})
}
@@ -293,18 +516,22 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
def startNewConnection(): SendingConnection = {
val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
- val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId,
- new SendingConnection(inetSocketAddress, selector, connectionManagerId))
- newConnection
+ val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId)
+ registerRequests.enqueue(newConnection)
+
+ newConnection
}
- val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress)
- val connection = connectionsById.getOrElse(lookupKey, startNewConnection())
+ // I removed the lookupKey stuff as part of merge ... should I re-add it ? We did not find it useful in our test-env ...
+ // If we do re-add it, we should consistently use it everywhere I guess ?
+ val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
message.senderAddress = id.toSocketAddress()
logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
- /*connection.send(message)*/
- sendMessageRequests.synchronized {
- sendMessageRequests += ((message, connection))
- }
+ connection.send(message)
+
+ wakeupSelector()
+ }
+
+ private def wakeupSelector() {
selector.wakeup()
}
@@ -337,6 +564,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
logWarning("All connections not cleaned up")
}
handleMessageExecutor.shutdown()
+ handleReadWriteExecutor.shutdown()
+ handleConnectExecutor.shutdown()
logInfo("ConnectionManager stopped")
}
}
@@ -346,17 +575,17 @@ private[spark] object ConnectionManager {
def main(args: Array[String]) {
val manager = new ConnectionManager(9999)
- manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
println("Received [" + msg + "] from [" + id + "]")
None
})
-
+
/*testSequentialSending(manager)*/
/*System.gc()*/
/*testParallelSending(manager)*/
/*System.gc()*/
-
+
/*testParallelDecreasingSending(manager)*/
/*System.gc()*/
@@ -368,9 +597,9 @@ private[spark] object ConnectionManager {
println("--------------------------")
println("Sequential Sending")
println("--------------------------")
- val size = 10 * 1024 * 1024
+ val size = 10 * 1024 * 1024
val count = 10
-
+
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
buffer.flip
@@ -386,7 +615,7 @@ private[spark] object ConnectionManager {
println("--------------------------")
println("Parallel Sending")
println("--------------------------")
- val size = 10 * 1024 * 1024
+ val size = 10 * 1024 * 1024
val count = 10
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
@@ -401,12 +630,12 @@ private[spark] object ConnectionManager {
if (!g.isDefined) println("Failed")
})
val finishTime = System.currentTimeMillis
-
+
val mb = size * count / 1024.0 / 1024.0
val ms = finishTime - startTime
val tput = mb * 1000.0 / ms
println("--------------------------")
- println("Started at " + startTime + ", finished at " + finishTime)
+ println("Started at " + startTime + ", finished at " + finishTime)
println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)")
println("--------------------------")
println()
@@ -416,7 +645,7 @@ private[spark] object ConnectionManager {
println("--------------------------")
println("Parallel Decreasing Sending")
println("--------------------------")
- val size = 10 * 1024 * 1024
+ val size = 10 * 1024 * 1024
val count = 10
val buffers = Array.tabulate(count)(i => ByteBuffer.allocate(size * (i + 1)).put(Array.tabulate[Byte](size * (i + 1))(x => x.toByte)))
buffers.foreach(_.flip)
@@ -431,7 +660,7 @@ private[spark] object ConnectionManager {
if (!g.isDefined) println("Failed")
})
val finishTime = System.currentTimeMillis
-
+
val ms = finishTime - startTime
val tput = mb * 1000.0 / ms
println("--------------------------")
@@ -445,7 +674,7 @@ private[spark] object ConnectionManager {
println("--------------------------")
println("Continuous Sending")
println("--------------------------")
- val size = 10 * 1024 * 1024
+ val size = 10 * 1024 * 1024
val count = 10
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
diff --git a/core/src/main/scala/spark/network/ConnectionManagerId.scala b/core/src/main/scala/spark/network/ConnectionManagerId.scala
new file mode 100644
index 0000000000..b554e84251
--- /dev/null
+++ b/core/src/main/scala/spark/network/ConnectionManagerId.scala
@@ -0,0 +1,21 @@
+package spark.network
+
+import java.net.InetSocketAddress
+
+import spark.Utils
+
+
+private[spark] case class ConnectionManagerId(host: String, port: Int) {
+ // DEBUG code
+ Utils.checkHost(host)
+ assert (port > 0)
+
+ def toSocketAddress() = new InetSocketAddress(host, port)
+}
+
+
+private[spark] object ConnectionManagerId {
+ def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
+ new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
+ }
+}
diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala
index 525751b5bf..d4f03610eb 100644
--- a/core/src/main/scala/spark/network/Message.scala
+++ b/core/src/main/scala/spark/network/Message.scala
@@ -1,55 +1,10 @@
package spark.network
-import spark._
-
-import scala.collection.mutable.ArrayBuffer
-
import java.nio.ByteBuffer
-import java.net.InetAddress
import java.net.InetSocketAddress
-import storage.BlockManager
-
-private[spark] class MessageChunkHeader(
- val typ: Long,
- val id: Int,
- val totalSize: Int,
- val chunkSize: Int,
- val other: Int,
- val address: InetSocketAddress) {
- lazy val buffer = {
- val ip = address.getAddress.getAddress()
- val port = address.getPort()
- ByteBuffer.
- allocate(MessageChunkHeader.HEADER_SIZE).
- putLong(typ).
- putInt(id).
- putInt(totalSize).
- putInt(chunkSize).
- putInt(other).
- putInt(ip.size).
- put(ip).
- putInt(port).
- position(MessageChunkHeader.HEADER_SIZE).
- flip.asInstanceOf[ByteBuffer]
- }
-
- override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
- " and sizes " + totalSize + " / " + chunkSize + " bytes"
-}
-private[spark] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
- val size = if (buffer == null) 0 else buffer.remaining
- lazy val buffers = {
- val ab = new ArrayBuffer[ByteBuffer]()
- ab += header.buffer
- if (buffer != null) {
- ab += buffer
- }
- ab
- }
+import scala.collection.mutable.ArrayBuffer
- override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
-}
private[spark] abstract class Message(val typ: Long, val id: Int) {
var senderAddress: InetSocketAddress = null
@@ -58,120 +13,16 @@ private[spark] abstract class Message(val typ: Long, val id: Int) {
var finishTime = -1L
def size: Int
-
+
def getChunkForSending(maxChunkSize: Int): Option[MessageChunk]
-
+
def getChunkForReceiving(chunkSize: Int): Option[MessageChunk]
-
+
def timeTaken(): String = (finishTime - startTime).toString + " ms"
override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
}
-private[spark] class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
-extends Message(Message.BUFFER_MESSAGE, id_) {
-
- val initialSize = currentSize()
- var gotChunkForSendingOnce = false
-
- def size = initialSize
-
- def currentSize() = {
- if (buffers == null || buffers.isEmpty) {
- 0
- } else {
- buffers.map(_.remaining).reduceLeft(_ + _)
- }
- }
-
- def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
- if (maxChunkSize <= 0) {
- throw new Exception("Max chunk size is " + maxChunkSize)
- }
-
- if (size == 0 && gotChunkForSendingOnce == false) {
- val newChunk = new MessageChunk(new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
- gotChunkForSendingOnce = true
- return Some(newChunk)
- }
-
- while(!buffers.isEmpty) {
- val buffer = buffers(0)
- if (buffer.remaining == 0) {
- BlockManager.dispose(buffer)
- buffers -= buffer
- } else {
- val newBuffer = if (buffer.remaining <= maxChunkSize) {
- buffer.duplicate()
- } else {
- buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
- }
- buffer.position(buffer.position + newBuffer.remaining)
- val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
- gotChunkForSendingOnce = true
- return Some(newChunk)
- }
- }
- None
- }
-
- def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
- // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
- if (buffers.size > 1) {
- throw new Exception("Attempting to get chunk from message with multiple data buffers")
- }
- val buffer = buffers(0)
- if (buffer.remaining > 0) {
- if (buffer.remaining < chunkSize) {
- throw new Exception("Not enough space in data buffer for receiving chunk")
- }
- val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
- buffer.position(buffer.position + newBuffer.remaining)
- val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
- return Some(newChunk)
- }
- None
- }
-
- def flip() {
- buffers.foreach(_.flip)
- }
-
- def hasAckId() = (ackId != 0)
-
- def isCompletelyReceived() = !buffers(0).hasRemaining
-
- override def toString = {
- if (hasAckId) {
- "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
- } else {
- "BufferMessage(id = " + id + ", size = " + size + ")"
- }
- }
-}
-
-private[spark] object MessageChunkHeader {
- val HEADER_SIZE = 40
-
- def create(buffer: ByteBuffer): MessageChunkHeader = {
- if (buffer.remaining != HEADER_SIZE) {
- throw new IllegalArgumentException("Cannot convert buffer data to Message")
- }
- val typ = buffer.getLong()
- val id = buffer.getInt()
- val totalSize = buffer.getInt()
- val chunkSize = buffer.getInt()
- val other = buffer.getInt()
- val ipSize = buffer.getInt()
- val ipBytes = new Array[Byte](ipSize)
- buffer.get(ipBytes)
- val ip = InetAddress.getByAddress(ipBytes)
- val port = buffer.getInt()
- new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
- }
-}
private[spark] object Message {
val BUFFER_MESSAGE = 1111111111L
@@ -180,14 +31,16 @@ private[spark] object Message {
def getNewId() = synchronized {
lastId += 1
- if (lastId == 0) lastId += 1
+ if (lastId == 0) {
+ lastId += 1
+ }
lastId
}
def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = {
if (dataBuffers == null) {
return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId)
- }
+ }
if (dataBuffers.exists(_ == null)) {
throw new Exception("Attempting to create buffer message with null buffer")
}
@@ -196,7 +49,7 @@ private[spark] object Message {
def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage =
createBufferMessage(dataBuffers, 0)
-
+
def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = {
if (dataBuffer == null) {
return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId)
@@ -204,15 +57,18 @@ private[spark] object Message {
return createBufferMessage(Array(dataBuffer), ackId)
}
}
-
- def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
+
+ def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
createBufferMessage(dataBuffer, 0)
-
- def createBufferMessage(ackId: Int): BufferMessage = createBufferMessage(new Array[ByteBuffer](0), ackId)
+
+ def createBufferMessage(ackId: Int): BufferMessage = {
+ createBufferMessage(new Array[ByteBuffer](0), ackId)
+ }
def create(header: MessageChunkHeader): Message = {
val newMessage: Message = header.typ match {
- case BUFFER_MESSAGE => new BufferMessage(header.id, ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
+ case BUFFER_MESSAGE => new BufferMessage(header.id,
+ ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
}
newMessage.senderAddress = header.address
newMessage
diff --git a/core/src/main/scala/spark/network/MessageChunk.scala b/core/src/main/scala/spark/network/MessageChunk.scala
new file mode 100644
index 0000000000..aaf9204d0e
--- /dev/null
+++ b/core/src/main/scala/spark/network/MessageChunk.scala
@@ -0,0 +1,25 @@
+package spark.network
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+
+private[network]
+class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
+
+ val size = if (buffer == null) 0 else buffer.remaining
+
+ lazy val buffers = {
+ val ab = new ArrayBuffer[ByteBuffer]()
+ ab += header.buffer
+ if (buffer != null) {
+ ab += buffer
+ }
+ ab
+ }
+
+ override def toString = {
+ "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
+ }
+}
diff --git a/core/src/main/scala/spark/network/MessageChunkHeader.scala b/core/src/main/scala/spark/network/MessageChunkHeader.scala
new file mode 100644
index 0000000000..3693d509d6
--- /dev/null
+++ b/core/src/main/scala/spark/network/MessageChunkHeader.scala
@@ -0,0 +1,58 @@
+package spark.network
+
+import java.net.InetAddress
+import java.net.InetSocketAddress
+import java.nio.ByteBuffer
+
+
+private[spark] class MessageChunkHeader(
+ val typ: Long,
+ val id: Int,
+ val totalSize: Int,
+ val chunkSize: Int,
+ val other: Int,
+ val address: InetSocketAddress) {
+ lazy val buffer = {
+ // No need to change this, at 'use' time, we do a reverse lookup of the hostname.
+ // Refer to network.Connection
+ val ip = address.getAddress.getAddress()
+ val port = address.getPort()
+ ByteBuffer.
+ allocate(MessageChunkHeader.HEADER_SIZE).
+ putLong(typ).
+ putInt(id).
+ putInt(totalSize).
+ putInt(chunkSize).
+ putInt(other).
+ putInt(ip.size).
+ put(ip).
+ putInt(port).
+ position(MessageChunkHeader.HEADER_SIZE).
+ flip.asInstanceOf[ByteBuffer]
+ }
+
+ override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
+ " and sizes " + totalSize + " / " + chunkSize + " bytes"
+}
+
+
+private[spark] object MessageChunkHeader {
+ val HEADER_SIZE = 40
+
+ def create(buffer: ByteBuffer): MessageChunkHeader = {
+ if (buffer.remaining != HEADER_SIZE) {
+ throw new IllegalArgumentException("Cannot convert buffer data to Message")
+ }
+ val typ = buffer.getLong()
+ val id = buffer.getInt()
+ val totalSize = buffer.getInt()
+ val chunkSize = buffer.getInt()
+ val other = buffer.getInt()
+ val ipSize = buffer.getInt()
+ val ipBytes = new Array[Byte](ipSize)
+ buffer.get(ipBytes)
+ val ip = InetAddress.getByAddress(ipBytes)
+ val port = buffer.getInt()
+ new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
+ }
+}
diff --git a/core/src/main/scala/spark/network/netty/FileHeader.scala b/core/src/main/scala/spark/network/netty/FileHeader.scala
new file mode 100644
index 0000000000..aed4254234
--- /dev/null
+++ b/core/src/main/scala/spark/network/netty/FileHeader.scala
@@ -0,0 +1,57 @@
+package spark.network.netty
+
+import io.netty.buffer._
+
+import spark.Logging
+
+private[spark] class FileHeader (
+ val fileLen: Int,
+ val blockId: String) extends Logging {
+
+ lazy val buffer = {
+ val buf = Unpooled.buffer()
+ buf.capacity(FileHeader.HEADER_SIZE)
+ buf.writeInt(fileLen)
+ buf.writeInt(blockId.length)
+ blockId.foreach((x: Char) => buf.writeByte(x))
+ //padding the rest of header
+ if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) {
+ buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes)
+ } else {
+ throw new Exception("too long header " + buf.readableBytes)
+ logInfo("too long header")
+ }
+ buf
+ }
+
+}
+
+private[spark] object FileHeader {
+
+ val HEADER_SIZE = 40
+
+ def getFileLenOffset = 0
+ def getFileLenSize = Integer.SIZE/8
+
+ def create(buf: ByteBuf): FileHeader = {
+ val length = buf.readInt
+ val idLength = buf.readInt
+ val idBuilder = new StringBuilder(idLength)
+ for (i <- 1 to idLength) {
+ idBuilder += buf.readByte().asInstanceOf[Char]
+ }
+ val blockId = idBuilder.toString()
+ new FileHeader(length, blockId)
+ }
+
+
+ def main (args:Array[String]){
+
+ val header = new FileHeader(25,"block_0");
+ val buf = header.buffer;
+ val newheader = FileHeader.create(buf);
+ System.out.println("id="+newheader.blockId+",size="+newheader.fileLen)
+
+ }
+}
+
diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala
new file mode 100644
index 0000000000..8d5194a737
--- /dev/null
+++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala
@@ -0,0 +1,101 @@
+package spark.network.netty
+
+import java.util.concurrent.Executors
+
+import io.netty.buffer.ByteBuf
+import io.netty.channel.ChannelHandlerContext
+import io.netty.util.CharsetUtil
+
+import spark.Logging
+import spark.network.ConnectionManagerId
+
+import scala.collection.JavaConverters._
+
+
+private[spark] class ShuffleCopier extends Logging {
+
+ def getBlock(host: String, port: Int, blockId: String,
+ resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+
+ val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
+ val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt
+ val fc = new FileClient(handler, connectTimeout)
+
+ try {
+ fc.init()
+ fc.connect(host, port)
+ fc.sendRequest(blockId)
+ fc.waitForClose()
+ fc.close()
+ } catch {
+ // Handle any socket-related exceptions in FileClient
+ case e: Exception => {
+ logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e)
+ handler.handleError(blockId)
+ }
+ }
+ }
+
+ def getBlock(cmId: ConnectionManagerId, blockId: String,
+ resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+ getBlock(cmId.host, cmId.port, blockId, resultCollectCallback)
+ }
+
+ def getBlocks(cmId: ConnectionManagerId,
+ blocks: Seq[(String, Long)],
+ resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+
+ for ((blockId, size) <- blocks) {
+ getBlock(cmId, blockId, resultCollectCallback)
+ }
+ }
+}
+
+
+private[spark] object ShuffleCopier extends Logging {
+
+ private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit)
+ extends FileClientHandler with Logging {
+
+ override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
+ logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)");
+ resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
+ }
+
+ override def handleError(blockId: String) {
+ if (!isComplete) {
+ resultCollectCallBack(blockId, -1, null)
+ }
+ }
+ }
+
+ def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
+ if (size != -1) {
+ logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
+ }
+ }
+
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println("Usage: ShuffleCopier <host> <port> <shuffle_block_id> <threads>")
+ System.exit(1)
+ }
+ val host = args(0)
+ val port = args(1).toInt
+ val file = args(2)
+ val threads = if (args.length > 3) args(3).toInt else 10
+
+ val copiers = Executors.newFixedThreadPool(80)
+ val tasks = (for (i <- Range(0, threads)) yield {
+ Executors.callable(new Runnable() {
+ def run() {
+ val copier = new ShuffleCopier()
+ copier.getBlock(host, port, file, echoResultCollectCallBack)
+ }
+ })
+ }).asJava
+ copiers.invokeAll(tasks)
+ copiers.shutdown
+ System.exit(0)
+ }
+}
diff --git a/core/src/main/scala/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/spark/network/netty/ShuffleSender.scala
new file mode 100644
index 0000000000..d6fa4b1e80
--- /dev/null
+++ b/core/src/main/scala/spark/network/netty/ShuffleSender.scala
@@ -0,0 +1,53 @@
+package spark.network.netty
+
+import java.io.File
+
+import spark.Logging
+
+
+private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {
+
+ val server = new FileServer(pResolver, portIn)
+ server.start()
+
+ def stop() {
+ server.stop()
+ }
+
+ def port: Int = server.getPort()
+}
+
+
+/**
+ * An application for testing the shuffle sender as a standalone program.
+ */
+private[spark] object ShuffleSender {
+
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println(
+ "Usage: ShuffleSender <port> <subDirsPerLocalDir> <list of shuffle_block_directories>")
+ System.exit(1)
+ }
+
+ val port = args(0).toInt
+ val subDirsPerLocalDir = args(1).toInt
+ val localDirs = args.drop(2).map(new File(_))
+
+ val pResovler = new PathResolver {
+ override def getAbsolutePath(blockId: String): String = {
+ if (!blockId.startsWith("shuffle_")) {
+ throw new Exception("Block " + blockId + " is not a shuffle block")
+ }
+ // Figure out which local directory it hashes to, and which subdirectory in that
+ val hash = math.abs(blockId.hashCode)
+ val dirId = hash % localDirs.length
+ val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
+ val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
+ val file = new File(subDir, blockId)
+ return file.getAbsolutePath
+ }
+ }
+ val sender = new ShuffleSender(port, pResovler)
+ }
+}
diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala
index f44d37a91f..3e60860b3e 100644
--- a/core/src/main/scala/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/spark/rdd/BlockRDD.scala
@@ -1,8 +1,9 @@
package spark.rdd
-import scala.collection.mutable.HashMap
import scala.reflect.ClassTag
+
import spark.{RDD, SparkContext, SparkEnv, Partition, TaskContext}
+import spark.storage.BlockManager
private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition {
val index = idx
@@ -12,12 +13,7 @@ private[spark]
class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[String])
extends RDD[T](sc, Nil) {
- @transient lazy val locations_ = {
- val blockManager = SparkEnv.get.blockManager
- /*val locations = blockIds.map(id => blockManager.getLocations(id))*/
- val locations = blockManager.getLocations(blockIds)
- HashMap(blockIds.zip(locations):_*)
- }
+ @transient lazy val locations_ = BlockManager.blockIdsToExecutorLocations(blockIds, SparkEnv.get)
override def getPartitions: Array[Partition] = (0 until blockIds.size).map(i => {
new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition]
diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
index 700a4160c8..efd29fa561 100644
--- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
@@ -9,6 +9,7 @@ import org.apache.hadoop.util.ReflectionUtils
import org.apache.hadoop.fs.Path
import java.io.{File, IOException, EOFException}
import java.text.NumberFormat
+import spark.deploy.SparkHadoopUtil
private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {}
@@ -22,13 +23,20 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
@transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
override def getPartitions: Array[Partition] = {
- val dirContents = fs.listStatus(new Path(checkpointPath))
- val partitionFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted
- val numPartitions = partitionFiles.size
- if (numPartitions > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
- ! partitionFiles(numPartitions-1).endsWith(CheckpointRDD.splitIdToFile(numPartitions-1)))) {
- throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
- }
+ val cpath = new Path(checkpointPath)
+ val numPartitions =
+ // listStatus can throw exception if path does not exist.
+ if (fs.exists(cpath)) {
+ val dirContents = fs.listStatus(cpath)
+ val partitionFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted
+ val numPart = partitionFiles.size
+ if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
+ ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) {
+ throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
+ }
+ numPart
+ } else 0
+
Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i))
}
@@ -36,7 +44,7 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
checkpointData.get.cpFile = Some(checkpointPath)
override def getPreferredLocations(split: Partition): Seq[String] = {
- val status = fs.getFileStatus(new Path(checkpointPath))
+ val status = fs.getFileStatus(new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index)))
val locations = fs.getFileBlockLocations(status, 0, status.getLen)
locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
}
@@ -59,7 +67,7 @@ private[spark] object CheckpointRDD extends Logging {
def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
val outputDir = new Path(path)
- val fs = outputDir.getFileSystem(new Configuration())
+ val fs = outputDir.getFileSystem(SparkHadoopUtil.newConfiguration())
val finalOutputName = splitIdToFile(ctx.splitId)
val finalOutputPath = new Path(outputDir, finalOutputName)
@@ -84,6 +92,7 @@ private[spark] object CheckpointRDD extends Logging {
if (!fs.rename(tempOutputPath, finalOutputPath)) {
if (!fs.exists(finalOutputPath)) {
+ logInfo("Deleting tempOutputPath " + tempOutputPath)
fs.delete(tempOutputPath, false)
throw new IOException("Checkpoint failed: failed to save output of task: "
+ ctx.attemptId + " and final output path does not exist")
@@ -96,7 +105,7 @@ private[spark] object CheckpointRDD extends Logging {
}
def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = {
- val fs = path.getFileSystem(new Configuration())
+ val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration())
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
val fileInputStream = fs.open(path, bufferSize)
val serializer = SparkEnv.get.serializer.newInstance()
@@ -118,7 +127,7 @@ private[spark] object CheckpointRDD extends Logging {
val sc = new SparkContext(cluster, "CheckpointRDD Test")
val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
val path = new Path(hdfsPath, "temp")
- val fs = path.getFileSystem(new Configuration())
+ val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration())
sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _)
val cpRDD = new CheckpointRDD[Int](sc, path.toString)
assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same")
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index a6235491ca..8966f9f86e 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -6,7 +6,7 @@ import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions
import scala.collection.mutable.ArrayBuffer
-import spark.{Aggregator, Logging, Partition, Partitioner, RDD, SparkEnv, TaskContext}
+import spark.{Aggregator, Partition, Partitioner, RDD, SparkEnv, TaskContext}
import spark.{Dependency, OneToOneDependency, ShuffleDependency}
@@ -49,12 +49,17 @@ private[spark] class CoGroupAggregator
*
* @param rdds parent RDDs.
* @param part partitioner used to partition the shuffle output.
- * @param mapSideCombine flag indicating whether to merge values before shuffle step.
+ * @param mapSideCombine flag indicating whether to merge values before shuffle step. If the flag
+ * is on, Spark does an extra pass over the data on the map side to merge
+ * all values belonging to the same key together. This can reduce the amount
+ * of data shuffled if and only if the number of distinct keys is very small,
+ * and the ratio of key size to value size is also very small.
*/
class CoGroupedRDD[K](
@transient var rdds: Seq[RDD[(K, _)]],
part: Partitioner,
- val mapSideCombine: Boolean = true)
+ val mapSideCombine: Boolean = false,
+ val serializerClass: String = null)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {
private val aggr = new CoGroupAggregator
@@ -68,9 +73,9 @@ class CoGroupedRDD[K](
logInfo("Adding shuffle dependency with " + rdd)
if (mapSideCombine) {
val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
- new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part)
+ new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part, serializerClass)
} else {
- new ShuffleDependency[Any, Any](rdd.asInstanceOf[RDD[(Any, Any)]], part)
+ new ShuffleDependency[Any, Any](rdd.asInstanceOf[RDD[(Any, Any)]], part, serializerClass)
}
}
}
@@ -112,6 +117,7 @@ class CoGroupedRDD[K](
}
}
+ val ser = SparkEnv.get.serializerManager.get(serializerClass)
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
// Read them from the parent
@@ -124,12 +130,12 @@ class CoGroupedRDD[K](
val fetcher = SparkEnv.get.shuffleFetcher
if (mapSideCombine) {
// With map side combine on, for each key, the shuffle fetcher returns a list of values.
- fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics).foreach {
+ fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics, ser).foreach {
case (key, values) => getSeq(key)(depNum) ++= values
}
} else {
// With map side combine off, for each key the shuffle fetcher returns a single value.
- fetcher.fetch[K, Any](shuffleId, split.index, context.taskMetrics).foreach {
+ fetcher.fetch[K, Any](shuffleId, split.index, context.taskMetrics, ser).foreach {
case (key, value) => getSeq(key)(depNum) += value
}
}
diff --git a/core/src/main/scala/spark/rdd/EmptyRDD.scala b/core/src/main/scala/spark/rdd/EmptyRDD.scala
new file mode 100644
index 0000000000..e4dd3a7fa7
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/EmptyRDD.scala
@@ -0,0 +1,16 @@
+package spark.rdd
+
+import spark.{RDD, SparkContext, SparkEnv, Partition, TaskContext}
+
+
+/**
+ * An RDD that is empty, i.e. has no element in it.
+ */
+class EmptyRDD[T: ClassManifest](sc: SparkContext) extends RDD[T](sc, Nil) {
+
+ override def getPartitions: Array[Partition] = Array.empty
+
+ override def compute(split: Partition, context: TaskContext): Iterator[T] = {
+ throw new UnsupportedOperationException("empty RDD")
+ }
+}
diff --git a/core/src/main/scala/spark/rdd/JdbcRDD.scala b/core/src/main/scala/spark/rdd/JdbcRDD.scala
new file mode 100644
index 0000000000..a50f407737
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/JdbcRDD.scala
@@ -0,0 +1,103 @@
+package spark.rdd
+
+import java.sql.{Connection, ResultSet}
+
+import spark.{Logging, Partition, RDD, SparkContext, TaskContext}
+import spark.util.NextIterator
+
+private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
+ override def index = idx
+}
+
+/**
+ * An RDD that executes an SQL query on a JDBC connection and reads results.
+ * For usage example, see test case JdbcRDDSuite.
+ *
+ * @param getConnection a function that returns an open Connection.
+ * The RDD takes care of closing the connection.
+ * @param sql the text of the query.
+ * The query must contain two ? placeholders for parameters used to partition the results.
+ * E.g. "select title, author from books where ? <= id and id <= ?"
+ * @param lowerBound the minimum value of the first placeholder
+ * @param upperBound the maximum value of the second placeholder
+ * The lower and upper bounds are inclusive.
+ * @param numPartitions the number of partitions.
+ * Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
+ * the query would be executed twice, once with (1, 10) and once with (11, 20)
+ * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
+ * This should only call getInt, getString, etc; the RDD takes care of calling next.
+ * The default maps a ResultSet to an array of Object.
+ */
+class JdbcRDD[T: ClassManifest](
+ sc: SparkContext,
+ getConnection: () => Connection,
+ sql: String,
+ lowerBound: Long,
+ upperBound: Long,
+ numPartitions: Int,
+ mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _)
+ extends RDD[T](sc, Nil) with Logging {
+
+ override def getPartitions: Array[Partition] = {
+ // bounds are inclusive, hence the + 1 here and - 1 on end
+ val length = 1 + upperBound - lowerBound
+ (0 until numPartitions).map(i => {
+ val start = lowerBound + ((i * length) / numPartitions).toLong
+ val end = lowerBound + (((i + 1) * length) / numPartitions).toLong - 1
+ new JdbcPartition(i, start, end)
+ }).toArray
+ }
+
+ override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] {
+ context.addOnCompleteCallback{ () => closeIfNeeded() }
+ val part = thePart.asInstanceOf[JdbcPartition]
+ val conn = getConnection()
+ val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
+
+ // setFetchSize(Integer.MIN_VALUE) is a mysql driver specific way to force streaming results,
+ // rather than pulling entire resultset into memory.
+ // see http://dev.mysql.com/doc/refman/5.0/en/connector-j-reference-implementation-notes.html
+ if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) {
+ stmt.setFetchSize(Integer.MIN_VALUE)
+ logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ")
+ }
+
+ stmt.setLong(1, part.lower)
+ stmt.setLong(2, part.upper)
+ val rs = stmt.executeQuery()
+
+ override def getNext: T = {
+ if (rs.next()) {
+ mapRow(rs)
+ } else {
+ finished = true
+ null.asInstanceOf[T]
+ }
+ }
+
+ override def close() {
+ try {
+ if (null != rs && ! rs.isClosed()) rs.close()
+ } catch {
+ case e: Exception => logWarning("Exception closing resultset", e)
+ }
+ try {
+ if (null != stmt && ! stmt.isClosed()) stmt.close()
+ } catch {
+ case e: Exception => logWarning("Exception closing statement", e)
+ }
+ try {
+ if (null != conn && ! stmt.isClosed()) conn.close()
+ logInfo("closed connection")
+ } catch {
+ case e: Exception => logWarning("Exception closing connection", e)
+ }
+ }
+ }
+}
+
+object JdbcRDD {
+ def resultSetToObjectArray(rs: ResultSet) = {
+ Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
+ }
+}
diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
index bdd974590a..901d01ef30 100644
--- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
@@ -57,7 +57,7 @@ class NewHadoopRDD[K, V](
override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[NewHadoopPartition]
val conf = confBroadcast.value.value
- val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0)
+ val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
val format = inputFormatClass.newInstance
if (format.isInstanceOf[Configurable]) {
diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala
index 34d32eb85a..349e6162c4 100644
--- a/core/src/main/scala/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/spark/rdd/PipedRDD.scala
@@ -10,6 +10,7 @@ import scala.io.Source
import scala.reflect.ClassTag
import spark.{RDD, SparkEnv, Partition, TaskContext}
+import spark.broadcast.Broadcast
/**
@@ -19,14 +20,21 @@ import spark.{RDD, SparkEnv, Partition, TaskContext}
class PipedRDD[T: ClassTag](
prev: RDD[T],
command: Seq[String],
- envVars: Map[String, String])
+ envVars: Map[String, String],
+ printPipeContext: (String => Unit) => Unit,
+ printRDDElement: (T, String => Unit) => Unit)
extends RDD[String](prev) {
- def this(prev: RDD[T], command: Seq[String]) = this(prev, command, Map())
-
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
- def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command))
+ def this(
+ prev: RDD[T],
+ command: String,
+ envVars: Map[String, String] = Map(),
+ printPipeContext: (String => Unit) => Unit = null,
+ printRDDElement: (T, String => Unit) => Unit = null) =
+ this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement)
+
override def getPartitions: Array[Partition] = firstParent[T].partitions
@@ -53,8 +61,17 @@ class PipedRDD[T: ClassTag](
override def run() {
SparkEnv.set(env)
val out = new PrintWriter(proc.getOutputStream)
+
+ // input the pipe context firstly
+ if (printPipeContext != null) {
+ printPipeContext(out.println(_))
+ }
for (elem <- firstParent[T].iterator(split, context)) {
- out.println(elem)
+ if (printRDDElement != null) {
+ printRDDElement(elem, out.println(_))
+ } else {
+ out.println(elem)
+ }
}
out.close()
}
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index 4e33b7dd5c..c7d1926b83 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -3,6 +3,7 @@ package spark.rdd
import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext}
import spark.SparkContext._
+
private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
override val index = idx
override def hashCode(): Int = idx
@@ -12,13 +13,15 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
* The resulting RDD from a shuffle (e.g. repartitioning of data).
* @param prev the parent RDD.
* @param part the partitioner used to partition the RDD
+ * @param serializerClass class name of the serializer to use.
* @tparam K the key class.
* @tparam V the value class.
*/
class ShuffledRDD[K, V](
@transient prev: RDD[(K, V)],
- part: Partitioner)
- extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part))) {
+ part: Partitioner,
+ serializerClass: String = null)
+ extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part, serializerClass))) {
override val partitioner = Some(part)
@@ -28,6 +31,7 @@ class ShuffledRDD[K, V](
override def compute(split: Partition, context: TaskContext): Iterator[(K, V)] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
- SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics)
+ SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics,
+ SparkEnv.get.serializerManager.get(serializerClass))
}
}
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
index 5e56900b18..9274839bca 100644
--- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -15,6 +15,7 @@ import spark.SparkEnv
import spark.ShuffleDependency
import spark.OneToOneDependency
+
/**
* An optimized version of cogroup for set difference/subtraction.
*
@@ -34,7 +35,9 @@ import spark.OneToOneDependency
private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
@transient var rdd1: RDD[(K, V)],
@transient var rdd2: RDD[(K, W)],
- part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) {
+ part: Partitioner,
+ val serializerClass: String = null)
+ extends RDD[(K, V)](rdd1.context, Nil) {
override def getDependencies: Seq[Dependency[_]] = {
Seq(rdd1, rdd2).map { rdd =>
@@ -43,7 +46,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
new OneToOneDependency(rdd)
} else {
logInfo("Adding shuffle dependency with " + rdd)
- new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part)
+ new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part, serializerClass)
}
}
}
@@ -68,6 +71,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
val partition = p.asInstanceOf[CoGroupPartition]
+ val serializer = SparkEnv.get.serializerManager.get(serializerClass)
val map = new JHashMap[K, ArrayBuffer[V]]
def getSeq(k: K): ArrayBuffer[V] = {
val seq = map.get(k)
@@ -80,12 +84,16 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
}
}
def integrate(dep: CoGroupSplitDep, op: ((K, V)) => Unit) = dep match {
- case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
+ case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
for (t <- rdd.iterator(itsSplit, context))
op(t.asInstanceOf[(K, V)])
- case ShuffleCoGroupSplitDep(shuffleId) =>
- for (t <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics))
+ }
+ case ShuffleCoGroupSplitDep(shuffleId) => {
+ val iter = SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index,
+ context.taskMetrics, serializer)
+ for (t <- iter)
op(t.asInstanceOf[(K, V)])
+ }
}
// the first dep is rdd1; add all values to the map
integrate(partition.deps(0), t => getSeq(t._1) += t._2)
diff --git a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala
new file mode 100644
index 0000000000..b234428ab2
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala
@@ -0,0 +1,138 @@
+package spark.rdd
+
+import spark.{Utils, OneToOneDependency, RDD, SparkContext, Partition, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
+
+private[spark] class ZippedPartitionsPartition(
+ idx: Int,
+ @transient rdds: Seq[RDD[_]])
+ extends Partition {
+
+ override val index: Int = idx
+ var partitionValues = rdds.map(rdd => rdd.partitions(idx))
+ def partitions = partitionValues
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ partitionValues = rdds.map(rdd => rdd.partitions(idx))
+ oos.defaultWriteObject()
+ }
+}
+
+abstract class ZippedPartitionsBaseRDD[V: ClassManifest](
+ sc: SparkContext,
+ var rdds: Seq[RDD[_]])
+ extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) {
+
+ override def getPartitions: Array[Partition] = {
+ val sizes = rdds.map(x => x.partitions.size)
+ if (!sizes.forall(x => x == sizes(0))) {
+ throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
+ }
+ val array = new Array[Partition](sizes(0))
+ for (i <- 0 until sizes(0)) {
+ array(i) = new ZippedPartitionsPartition(i, rdds)
+ }
+ array
+ }
+
+ override def getPreferredLocations(s: Partition): Seq[String] = {
+ // Note that as number of rdd's increase and/or number of slaves in cluster increase, the computed preferredLocations below
+ // become diminishingly small : so we might need to look at alternate strategies to alleviate this.
+ // If there are no (or very small number of preferred locations), we will end up transferred the blocks to 'any' node in the
+ // cluster - paying with n/w and cache cost.
+ // Maybe pick a node which figures max amount of time ?
+ // Choose node which is hosting 'larger' of some subset of blocks ?
+ // Look at rack locality to ensure chosen host is atleast rack local to both hosting node ?, etc (would be good to defer this if possible)
+ val splits = s.asInstanceOf[ZippedPartitionsPartition].partitions
+ val rddSplitZip = rdds.zip(splits)
+
+ // exact match.
+ val exactMatchPreferredLocations = rddSplitZip.map(x => x._1.preferredLocations(x._2))
+ val exactMatchLocations = exactMatchPreferredLocations.reduce((x, y) => x.intersect(y))
+
+ // Remove exact match and then do host local match.
+ val exactMatchHosts = exactMatchLocations.map(Utils.parseHostPort(_)._1)
+ val matchPreferredHosts = exactMatchPreferredLocations.map(locs => locs.map(Utils.parseHostPort(_)._1))
+ .reduce((x, y) => x.intersect(y))
+ val otherNodeLocalLocations = matchPreferredHosts.filter { s => !exactMatchHosts.contains(s) }
+
+ otherNodeLocalLocations ++ exactMatchLocations
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdds = null
+ }
+}
+
+class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest](
+ sc: SparkContext,
+ f: (Iterator[A], Iterator[B]) => Iterator[V],
+ var rdd1: RDD[A],
+ var rdd2: RDD[B])
+ extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) {
+
+ override def compute(s: Partition, context: TaskContext): Iterator[V] = {
+ val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
+ f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context))
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdd1 = null
+ rdd2 = null
+ }
+}
+
+class ZippedPartitionsRDD3
+ [A: ClassManifest, B: ClassManifest, C: ClassManifest, V: ClassManifest](
+ sc: SparkContext,
+ f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V],
+ var rdd1: RDD[A],
+ var rdd2: RDD[B],
+ var rdd3: RDD[C])
+ extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) {
+
+ override def compute(s: Partition, context: TaskContext): Iterator[V] = {
+ val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
+ f(rdd1.iterator(partitions(0), context),
+ rdd2.iterator(partitions(1), context),
+ rdd3.iterator(partitions(2), context))
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdd1 = null
+ rdd2 = null
+ rdd3 = null
+ }
+}
+
+class ZippedPartitionsRDD4
+ [A: ClassManifest, B: ClassManifest, C: ClassManifest, D:ClassManifest, V: ClassManifest](
+ sc: SparkContext,
+ f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
+ var rdd1: RDD[A],
+ var rdd2: RDD[B],
+ var rdd3: RDD[C],
+ var rdd4: RDD[D])
+ extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) {
+
+ override def compute(s: Partition, context: TaskContext): Iterator[V] = {
+ val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
+ f(rdd1.iterator(partitions(0), context),
+ rdd2.iterator(partitions(1), context),
+ rdd3.iterator(partitions(2), context),
+ rdd4.iterator(partitions(3), context))
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdd1 = null
+ rdd2 = null
+ rdd3 = null
+ rdd4 = null
+ }
+}
diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala
index 1b438cd505..be05fb71f9 100644
--- a/core/src/main/scala/spark/rdd/ZippedRDD.scala
+++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala
@@ -1,5 +1,7 @@
package spark.rdd
+import spark.{Utils, OneToOneDependency, RDD, SparkContext, Partition, TaskContext}
+
import java.io.{ObjectOutputStream, IOException}
import scala.reflect.ClassTag
@@ -50,8 +52,27 @@ class ZippedRDD[T: ClassTag, U: ClassTag](
}
override def getPreferredLocations(s: Partition): Seq[String] = {
+ // Note that as number of slaves in cluster increase, the computed preferredLocations can become small : so we might need
+ // to look at alternate strategies to alleviate this. (If there are no (or very small number of preferred locations), we
+ // will end up transferred the blocks to 'any' node in the cluster - paying with n/w and cache cost.
+ // Maybe pick one or the other ? (so that atleast one block is local ?).
+ // Choose node which is hosting 'larger' of the blocks ?
+ // Look at rack locality to ensure chosen host is atleast rack local to both hosting node ?, etc (would be good to defer this if possible)
val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions
- rdd1.preferredLocations(partition1).intersect(rdd2.preferredLocations(partition2))
+ val pref1 = rdd1.preferredLocations(partition1)
+ val pref2 = rdd2.preferredLocations(partition2)
+
+ // exact match - instance local and host local.
+ val exactMatchLocations = pref1.intersect(pref2)
+
+ // remove locations which are already handled via exactMatchLocations, and intersect where both partitions are node local.
+ val otherNodeLocalPref1 = pref1.filter(loc => ! exactMatchLocations.contains(loc)).map(loc => Utils.parseHostPort(loc)._1)
+ val otherNodeLocalPref2 = pref2.filter(loc => ! exactMatchLocations.contains(loc)).map(loc => Utils.parseHostPort(loc)._1)
+ val otherNodeLocalLocations = otherNodeLocalPref1.intersect(otherNodeLocalPref2)
+
+
+ // Can have mix of instance local (hostPort) and node local (host) locations as preference !
+ exactMatchLocations ++ otherNodeLocalLocations
}
override def clearDependencies() {
diff --git a/core/src/main/scala/spark/scheduler/ActiveJob.scala b/core/src/main/scala/spark/scheduler/ActiveJob.scala
index 5a4e9a582d..105eaecb22 100644
--- a/core/src/main/scala/spark/scheduler/ActiveJob.scala
+++ b/core/src/main/scala/spark/scheduler/ActiveJob.scala
@@ -2,6 +2,8 @@ package spark.scheduler
import spark.TaskContext
+import java.util.Properties
+
/**
* Tracks information about an active job in the DAGScheduler.
*/
@@ -11,7 +13,8 @@ private[spark] class ActiveJob(
val func: (TaskContext, Iterator[_]) => _,
val partitions: Array[Int],
val callSite: String,
- val listener: JobListener) {
+ val listener: JobListener,
+ val properties: Properties) {
val numPartitions = partitions.length
val finished = Array.fill[Boolean](numPartitions)(false)
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index b838cf84a8..1164c40c43 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -4,6 +4,7 @@ import cluster.TaskInfo
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit
+import java.util.Properties
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import scala.reflect.ClassTag
@@ -13,7 +14,7 @@ import spark.executor.TaskMetrics
import spark.partial.ApproximateActionListener
import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
-import spark.storage.BlockManagerMaster
+import spark.storage.{BlockManager, BlockManagerMaster}
import spark.util.{MetadataCleaner, TimeStampedHashMap}
/**
@@ -51,6 +52,11 @@ class DAGScheduler(
eventQueue.put(ExecutorLost(execId))
}
+ // Called by TaskScheduler when a host is added
+ override def executorGained(execId: String, hostPort: String) {
+ eventQueue.put(ExecutorGained(execId, hostPort))
+ }
+
// Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
override def taskSetFailed(taskSet: TaskSet, reason: String) {
eventQueue.put(TaskSetFailed(taskSet, reason))
@@ -89,6 +95,8 @@ class DAGScheduler(
// stray messages to detect.
val failedGeneration = new HashMap[String, Long]
+ val idToActiveJob = new HashMap[Int, ActiveJob]
+
val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
val running = new HashSet[Stage] // Stages we are running right now
val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures
@@ -113,9 +121,8 @@ class DAGScheduler(
private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
if (!cacheLocs.contains(rdd.id)) {
val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
- cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map {
- locations => locations.map(_.ip).toList
- }.toArray
+ val locs = BlockManager.blockIdsToExecutorLocations(blockIds, env, blockManagerMaster)
+ cacheLocs(rdd.id) = blockIds.map(locs.getOrElse(_, Nil))
}
cacheLocs(rdd.id)
}
@@ -222,13 +229,14 @@ class DAGScheduler(
partitions: Seq[Int],
callSite: String,
allowLocal: Boolean,
- resultHandler: (Int, U) => Unit)
+ resultHandler: (Int, U) => Unit,
+ properties: Properties = null)
: (JobSubmitted, JobWaiter[U]) =
{
assert(partitions.size > 0)
val waiter = new JobWaiter(partitions.size, resultHandler)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
- val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter)
+ val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
return (toSubmit, waiter)
}
@@ -238,13 +246,14 @@ class DAGScheduler(
partitions: Seq[Int],
callSite: String,
allowLocal: Boolean,
- resultHandler: (Int, U) => Unit)
+ resultHandler: (Int, U) => Unit,
+ properties: Properties = null)
{
if (partitions.size == 0) {
return
}
val (toSubmit, waiter) = prepareJob(
- finalRdd, func, partitions, callSite, allowLocal, resultHandler)
+ finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties)
eventQueue.put(toSubmit)
waiter.awaitResult() match {
case JobSucceeded => {}
@@ -259,13 +268,14 @@ class DAGScheduler(
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
callSite: String,
- timeout: Long)
+ timeout: Long,
+ properties: Properties = null)
: PartialResult[R] =
{
val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.partitions.size).toArray
- eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener))
+ eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener, properties))
return listener.awaitResult() // Will throw an exception if the job fails
}
@@ -275,10 +285,10 @@ class DAGScheduler(
*/
private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
event match {
- case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
+ case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) =>
val runId = nextRunId.getAndIncrement()
val finalStage = newStage(finalRDD, None, runId)
- val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
+ val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener, properties)
clearCacheLocs()
logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
" output partitions (allowLocal=" + allowLocal + ")")
@@ -289,15 +299,22 @@ class DAGScheduler(
// Compute very short actions like first() or take() with no parent stages locally.
runLocally(job)
} else {
+ sparkListeners.foreach(_.onJobStart(SparkListenerJobStart(job, properties)))
+ idToActiveJob(runId) = job
activeJobs += job
resultStageToJob(finalStage) = job
submitStage(finalStage)
}
+ case ExecutorGained(execId, hostPort) =>
+ handleExecutorGained(execId, hostPort)
+
case ExecutorLost(execId) =>
handleExecutorLost(execId)
case completion: CompletionEvent =>
+ sparkListeners.foreach(_.onTaskEnd(SparkListenerTaskEnd(completion.task,
+ completion.reason, completion.taskInfo, completion.taskMetrics)))
handleTaskCompletion(completion)
case TaskSetFailed(taskSet, reason) =>
@@ -308,6 +325,7 @@ class DAGScheduler(
for (job <- activeJobs) {
val error = new SparkException("Job cancelled because SparkContext was shut down")
job.listener.jobFailed(error)
+ sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error))))
}
return true
}
@@ -455,11 +473,13 @@ class DAGScheduler(
}
}
if (tasks.size > 0) {
+ sparkListeners.foreach(_.onStageSubmitted(SparkListenerStageSubmitted(stage, tasks.size)))
logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
myPending ++= tasks
logDebug("New pending tasks: " + myPending)
+ val properties = idToActiveJob(stage.priority).properties
taskSched.submitTasks(
- new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority))
+ new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority, properties))
if (!stage.submissionTime.isDefined) {
stage.submissionTime = Some(System.currentTimeMillis())
}
@@ -508,6 +528,7 @@ class DAGScheduler(
activeJobs -= job
resultStageToJob -= stage
markStageAsFinished(stage)
+ sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobSucceeded)))
}
job.listener.taskSucceeded(rt.outputId, event.result)
}
@@ -631,6 +652,14 @@ class DAGScheduler(
"(generation " + currentGeneration + ")")
}
}
+
+ private def handleExecutorGained(execId: String, hostPort: String) {
+ // remove from failedGeneration(execId) ?
+ if (failedGeneration.contains(execId)) {
+ logInfo("Host gained which was in lost list earlier: " + hostPort)
+ failedGeneration -= execId
+ }
+ }
/**
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
@@ -640,7 +669,9 @@ class DAGScheduler(
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
- job.listener.jobFailed(new SparkException("Job failed: " + reason))
+ val error = new SparkException("Job failed: " + reason)
+ job.listener.jobFailed(error)
+ sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error))))
activeJobs -= job
resultStageToJob -= resultStage
}
diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
index ed0b9bf178..acad915f13 100644
--- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
@@ -1,5 +1,7 @@
package spark.scheduler
+import java.util.Properties
+
import spark.scheduler.cluster.TaskInfo
import scala.collection.mutable.Map
@@ -20,7 +22,8 @@ private[spark] case class JobSubmitted(
partitions: Array[Int],
allowLocal: Boolean,
callSite: String,
- listener: JobListener)
+ listener: JobListener,
+ properties: Properties = null)
extends DAGSchedulerEvent
private[spark] case class CompletionEvent(
@@ -32,6 +35,10 @@ private[spark] case class CompletionEvent(
taskMetrics: TaskMetrics)
extends DAGSchedulerEvent
+private[spark] case class ExecutorGained(execId: String, hostPort: String) extends DAGSchedulerEvent {
+ Utils.checkHostPort(hostPort, "Required hostport")
+}
+
private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
diff --git a/core/src/main/scala/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/spark/scheduler/InputFormatInfo.scala
new file mode 100644
index 0000000000..287f731787
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/InputFormatInfo.scala
@@ -0,0 +1,156 @@
+package spark.scheduler
+
+import spark.Logging
+import scala.collection.immutable.Set
+import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
+import org.apache.hadoop.util.ReflectionUtils
+import org.apache.hadoop.mapreduce.Job
+import org.apache.hadoop.conf.Configuration
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.collection.JavaConversions._
+
+
+/**
+ * Parses and holds information about inputFormat (and files) specified as a parameter.
+ */
+class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Class[_],
+ val path: String) extends Logging {
+
+ var mapreduceInputFormat: Boolean = false
+ var mapredInputFormat: Boolean = false
+
+ validate()
+
+ override def toString(): String = {
+ "InputFormatInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz + ", path : " + path
+ }
+
+ override def hashCode(): Int = {
+ var hashCode = inputFormatClazz.hashCode
+ hashCode = hashCode * 31 + path.hashCode
+ hashCode
+ }
+
+ // Since we are not doing canonicalization of path, this can be wrong : like relative vs absolute path
+ // .. which is fine, this is best case effort to remove duplicates - right ?
+ override def equals(other: Any): Boolean = other match {
+ case that: InputFormatInfo => {
+ // not checking config - that should be fine, right ?
+ this.inputFormatClazz == that.inputFormatClazz &&
+ this.path == that.path
+ }
+ case _ => false
+ }
+
+ private def validate() {
+ logDebug("validate InputFormatInfo : " + inputFormatClazz + ", path " + path)
+
+ try {
+ if (classOf[org.apache.hadoop.mapreduce.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) {
+ logDebug("inputformat is from mapreduce package")
+ mapreduceInputFormat = true
+ }
+ else if (classOf[org.apache.hadoop.mapred.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) {
+ logDebug("inputformat is from mapred package")
+ mapredInputFormat = true
+ }
+ else {
+ throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz +
+ " is NOT a supported input format ? does not implement either of the supported hadoop api's")
+ }
+ }
+ catch {
+ case e: ClassNotFoundException => {
+ throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + " cannot be found ?", e)
+ }
+ }
+ }
+
+
+ // This method does not expect failures, since validate has already passed ...
+ private def prefLocsFromMapreduceInputFormat(): Set[SplitInfo] = {
+ val conf = new JobConf(configuration)
+ FileInputFormat.setInputPaths(conf, path)
+
+ val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] =
+ ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], conf).asInstanceOf[
+ org.apache.hadoop.mapreduce.InputFormat[_, _]]
+ val job = new Job(conf)
+
+ val retval = new ArrayBuffer[SplitInfo]()
+ val list = instance.getSplits(job)
+ for (split <- list) {
+ retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, split)
+ }
+
+ return retval.toSet
+ }
+
+ // This method does not expect failures, since validate has already passed ...
+ private def prefLocsFromMapredInputFormat(): Set[SplitInfo] = {
+ val jobConf = new JobConf(configuration)
+ FileInputFormat.setInputPaths(jobConf, path)
+
+ val instance: org.apache.hadoop.mapred.InputFormat[_, _] =
+ ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], jobConf).asInstanceOf[
+ org.apache.hadoop.mapred.InputFormat[_, _]]
+
+ val retval = new ArrayBuffer[SplitInfo]()
+ instance.getSplits(jobConf, jobConf.getNumMapTasks()).foreach(
+ elem => retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, elem)
+ )
+
+ return retval.toSet
+ }
+
+ private def findPreferredLocations(): Set[SplitInfo] = {
+ logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat +
+ ", inputFormatClazz : " + inputFormatClazz)
+ if (mapreduceInputFormat) {
+ return prefLocsFromMapreduceInputFormat()
+ }
+ else {
+ assert(mapredInputFormat)
+ return prefLocsFromMapredInputFormat()
+ }
+ }
+}
+
+
+
+
+object InputFormatInfo {
+ /**
+ Computes the preferred locations based on input(s) and returned a location to block map.
+ Typical use of this method for allocation would follow some algo like this
+ (which is what we currently do in YARN branch) :
+ a) For each host, count number of splits hosted on that host.
+ b) Decrement the currently allocated containers on that host.
+ c) Compute rack info for each host and update rack -> count map based on (b).
+ d) Allocate nodes based on (c)
+ e) On the allocation result, ensure that we dont allocate "too many" jobs on a single node
+ (even if data locality on that is very high) : this is to prevent fragility of job if a single
+ (or small set of) hosts go down.
+
+ go to (a) until required nodes are allocated.
+
+ If a node 'dies', follow same procedure.
+
+ PS: I know the wording here is weird, hopefully it makes some sense !
+ */
+ def computePreferredLocations(formats: Seq[InputFormatInfo]): HashMap[String, HashSet[SplitInfo]] = {
+
+ val nodeToSplit = new HashMap[String, HashSet[SplitInfo]]
+ for (inputSplit <- formats) {
+ val splits = inputSplit.findPreferredLocations()
+
+ for (split <- splits){
+ val location = split.hostLocation
+ val set = nodeToSplit.getOrElseUpdate(location, new HashSet[SplitInfo])
+ set += split
+ }
+ }
+
+ nodeToSplit
+ }
+}
diff --git a/core/src/main/scala/spark/scheduler/JobLogger.scala b/core/src/main/scala/spark/scheduler/JobLogger.scala
new file mode 100644
index 0000000000..178bfaba3d
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/JobLogger.scala
@@ -0,0 +1,306 @@
+package spark.scheduler
+
+import java.io.PrintWriter
+import java.io.File
+import java.io.FileNotFoundException
+import java.text.SimpleDateFormat
+import java.util.{Date, Properties}
+import java.util.concurrent.LinkedBlockingQueue
+import scala.collection.mutable.{Map, HashMap, ListBuffer}
+import scala.io.Source
+import spark._
+import spark.executor.TaskMetrics
+import spark.scheduler.cluster.TaskInfo
+
+// Used to record runtime information for each job, including RDD graph
+// tasks' start/stop shuffle information and information from outside
+
+class JobLogger(val logDirName: String) extends SparkListener with Logging {
+ private val logDir =
+ if (System.getenv("SPARK_LOG_DIR") != null)
+ System.getenv("SPARK_LOG_DIR")
+ else
+ "/tmp/spark"
+ private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
+ private val stageIDToJobID = new HashMap[Int, Int]
+ private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
+ private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
+ private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
+
+ createLogDir()
+ def this() = this(String.valueOf(System.currentTimeMillis()))
+
+ def getLogDir = logDir
+ def getJobIDtoPrintWriter = jobIDToPrintWriter
+ def getStageIDToJobID = stageIDToJobID
+ def getJobIDToStages = jobIDToStages
+ def getEventQueue = eventQueue
+
+ new Thread("JobLogger") {
+ setDaemon(true)
+ override def run() {
+ while (true) {
+ val event = eventQueue.take
+ logDebug("Got event of type " + event.getClass.getName)
+ event match {
+ case SparkListenerJobStart(job, properties) =>
+ processJobStartEvent(job, properties)
+ case SparkListenerStageSubmitted(stage, taskSize) =>
+ processStageSubmittedEvent(stage, taskSize)
+ case StageCompleted(stageInfo) =>
+ processStageCompletedEvent(stageInfo)
+ case SparkListenerJobEnd(job, result) =>
+ processJobEndEvent(job, result)
+ case SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics) =>
+ processTaskEndEvent(task, reason, taskInfo, taskMetrics)
+ case _ =>
+ }
+ }
+ }
+ }.start()
+
+ // Create a folder for log files, the folder's name is the creation time of the jobLogger
+ protected def createLogDir() {
+ val dir = new File(logDir + "/" + logDirName + "/")
+ if (dir.exists()) {
+ return
+ }
+ if (dir.mkdirs() == false) {
+ logError("create log directory error:" + logDir + "/" + logDirName + "/")
+ }
+ }
+
+ // Create a log file for one job, the file name is the jobID
+ protected def createLogWriter(jobID: Int) {
+ try{
+ val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
+ jobIDToPrintWriter += (jobID -> fileWriter)
+ } catch {
+ case e: FileNotFoundException => e.printStackTrace()
+ }
+ }
+
+ // Close log file, and clean the stage relationship in stageIDToJobID
+ protected def closeLogWriter(jobID: Int) =
+ jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
+ fileWriter.close()
+ jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
+ stageIDToJobID -= stage.id
+ })
+ jobIDToPrintWriter -= jobID
+ jobIDToStages -= jobID
+ }
+
+ // Write log information to log file, withTime parameter controls whether to recored
+ // time stamp for the information
+ protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
+ var writeInfo = info
+ if (withTime) {
+ val date = new Date(System.currentTimeMillis())
+ writeInfo = DATE_FORMAT.format(date) + ": " +info
+ }
+ jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
+ }
+
+ protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) =
+ stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
+
+ protected def buildJobDep(jobID: Int, stage: Stage) {
+ if (stage.priority == jobID) {
+ jobIDToStages.get(jobID) match {
+ case Some(stageList) => stageList += stage
+ case None => val stageList = new ListBuffer[Stage]
+ stageList += stage
+ jobIDToStages += (jobID -> stageList)
+ }
+ stageIDToJobID += (stage.id -> jobID)
+ stage.parents.foreach(buildJobDep(jobID, _))
+ }
+ }
+
+ protected def recordStageDep(jobID: Int) {
+ def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
+ var rddList = new ListBuffer[RDD[_]]
+ rddList += rdd
+ rdd.dependencies.foreach{ dep => dep match {
+ case shufDep: ShuffleDependency[_,_] =>
+ case _ => rddList ++= getRddsInStage(dep.rdd)
+ }
+ }
+ rddList
+ }
+ jobIDToStages.get(jobID).foreach {_.foreach { stage =>
+ var depRddDesc: String = ""
+ getRddsInStage(stage.rdd).foreach { rdd =>
+ depRddDesc += rdd.id + ","
+ }
+ var depStageDesc: String = ""
+ stage.parents.foreach { stage =>
+ depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
+ }
+ jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
+ depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
+ " STAGE_DEP=" + depStageDesc, false)
+ }
+ }
+ }
+
+ // Generate indents and convert to String
+ protected def indentString(indent: Int) = {
+ val sb = new StringBuilder()
+ for (i <- 1 to indent) {
+ sb.append(" ")
+ }
+ sb.toString()
+ }
+
+ protected def getRddName(rdd: RDD[_]) = {
+ var rddName = rdd.getClass.getName
+ if (rdd.name != null) {
+ rddName = rdd.name
+ }
+ rddName
+ }
+
+ protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
+ val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
+ jobLogInfo(jobID, indentString(indent) + rddInfo, false)
+ rdd.dependencies.foreach{ dep => dep match {
+ case shufDep: ShuffleDependency[_,_] =>
+ val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
+ jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
+ case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
+ }
+ }
+ }
+
+ protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
+ var stageInfo: String = ""
+ if (stage.isShuffleMap) {
+ stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" +
+ stage.shuffleDep.get.shuffleId
+ }else{
+ stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE"
+ }
+ if (stage.priority == jobID) {
+ jobLogInfo(jobID, indentString(indent) + stageInfo, false)
+ recordRddInStageGraph(jobID, stage.rdd, indent)
+ stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
+ } else
+ jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.priority, false)
+ }
+
+ // Record task metrics into job log files
+ protected def recordTaskMetrics(stageID: Int, status: String,
+ taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
+ val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
+ " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
+ " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
+ val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
+ val readMetrics =
+ taskMetrics.shuffleReadMetrics match {
+ case Some(metrics) =>
+ " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
+ " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
+ " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
+ " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
+ " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
+ " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
+ " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
+ case None => ""
+ }
+ val writeMetrics =
+ taskMetrics.shuffleWriteMetrics match {
+ case Some(metrics) =>
+ " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
+ case None => ""
+ }
+ stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
+ }
+
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
+ eventQueue.put(stageSubmitted)
+ }
+
+ protected def processStageSubmittedEvent(stage: Stage, taskSize: Int) {
+ stageLogInfo(stage.id, "STAGE_ID=" + stage.id + " STATUS=SUBMITTED" + " TASK_SIZE=" + taskSize)
+ }
+
+ override def onStageCompleted(stageCompleted: StageCompleted) {
+ eventQueue.put(stageCompleted)
+ }
+
+ protected def processStageCompletedEvent(stageInfo: StageInfo) {
+ stageLogInfo(stageInfo.stage.id, "STAGE_ID=" +
+ stageInfo.stage.id + " STATUS=COMPLETED")
+
+ }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ eventQueue.put(taskEnd)
+ }
+
+ protected def processTaskEndEvent(task: Task[_], reason: TaskEndReason,
+ taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
+ var taskStatus = ""
+ task match {
+ case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
+ case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
+ }
+ reason match {
+ case Success => taskStatus += " STATUS=SUCCESS"
+ recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskMetrics)
+ case Resubmitted =>
+ taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
+ " STAGE_ID=" + task.stageId
+ stageLogInfo(task.stageId, taskStatus)
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
+ task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
+ mapId + " REDUCE_ID=" + reduceId
+ stageLogInfo(task.stageId, taskStatus)
+ case OtherFailure(message) =>
+ taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
+ " STAGE_ID=" + task.stageId + " INFO=" + message
+ stageLogInfo(task.stageId, taskStatus)
+ case _ =>
+ }
+ }
+
+ override def onJobEnd(jobEnd: SparkListenerJobEnd) {
+ eventQueue.put(jobEnd)
+ }
+
+ protected def processJobEndEvent(job: ActiveJob, reason: JobResult) {
+ var info = "JOB_ID=" + job.runId
+ reason match {
+ case JobSucceeded => info += " STATUS=SUCCESS"
+ case JobFailed(exception) =>
+ info += " STATUS=FAILED REASON="
+ exception.getMessage.split("\\s+").foreach(info += _ + "_")
+ case _ =>
+ }
+ jobLogInfo(job.runId, info.substring(0, info.length - 1).toUpperCase)
+ closeLogWriter(job.runId)
+ }
+
+ protected def recordJobProperties(jobID: Int, properties: Properties) {
+ if(properties != null) {
+ val annotation = properties.getProperty("spark.job.annotation", "")
+ jobLogInfo(jobID, annotation, false)
+ }
+ }
+
+ override def onJobStart(jobStart: SparkListenerJobStart) {
+ eventQueue.put(jobStart)
+ }
+
+ protected def processJobStartEvent(job: ActiveJob, properties: Properties) {
+ createLogWriter(job.runId)
+ recordJobProperties(job.runId, properties)
+ buildJobDep(job.runId, job.finalStage)
+ recordStageDep(job.runId)
+ recordStageDepGraph(job.runId, job.finalStage)
+ jobLogInfo(job.runId, "JOB_ID=" + job.runId + " STATUS=STARTED")
+ }
+}
diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala
index beb21a76fe..83166bce22 100644
--- a/core/src/main/scala/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/spark/scheduler/ResultTask.scala
@@ -70,6 +70,13 @@ private[spark] class ResultTask[T, U](
rdd.partitions(partition)
}
+ private val preferredLocs: Seq[String] = if (locs == null) Nil else locs.toSet.toSeq
+
+ {
+ // DEBUG code
+ preferredLocs.foreach (hostPort => Utils.checkHost(Utils.parseHostPort(hostPort)._1, "preferredLocs : " + preferredLocs))
+ }
+
override def run(attemptId: Long): U = {
val context = new TaskContext(stageId, partition, attemptId)
metrics = Some(context.taskMetrics)
@@ -80,7 +87,7 @@ private[spark] class ResultTask[T, U](
}
}
- override def preferredLocations: Seq[String] = locs
+ override def preferredLocations: Seq[String] = preferredLocs
override def toString = "ResultTask(" + stageId + ", " + partition + ")"
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 36d087a4d0..95647389c3 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -13,9 +13,10 @@ import com.ning.compress.lzf.LZFInputStream
import com.ning.compress.lzf.LZFOutputStream
import spark._
-import executor.ShuffleWriteMetrics
+import spark.executor.ShuffleWriteMetrics
import spark.storage._
-import util.{TimeStampedHashMap, MetadataCleaner}
+import spark.util.{TimeStampedHashMap, MetadataCleaner}
+
private[spark] object ShuffleMapTask {
@@ -77,13 +78,20 @@ private[spark] class ShuffleMapTask(
var rdd: RDD[_],
var dep: ShuffleDependency[_,_],
var partition: Int,
- @transient var locs: Seq[String])
+ @transient private var locs: Seq[String])
extends Task[MapStatus](stageId)
with Externalizable
with Logging {
protected def this() = this(0, null, null, 0, null)
+ @transient private val preferredLocs: Seq[String] = if (locs == null) Nil else locs.toSet.toSeq
+
+ {
+ // DEBUG code
+ preferredLocs.foreach (hostPort => Utils.checkHost(Utils.parseHostPort(hostPort)._1, "preferredLocs : " + preferredLocs))
+ }
+
var split = if (rdd == null) {
null
} else {
@@ -121,40 +129,58 @@ private[spark] class ShuffleMapTask(
val taskContext = new TaskContext(stageId, partition, attemptId)
metrics = Some(taskContext.taskMetrics)
+
+ val blockManager = SparkEnv.get.blockManager
+ var shuffle: ShuffleBlocks = null
+ var buckets: ShuffleWriterGroup = null
+
try {
- // Partition the map output.
- val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
+ // Obtain all the block writers for shuffle blocks.
+ val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
+ shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser)
+ buckets = shuffle.acquireWriters(partition)
+
+ // Write the map output to its associated buckets.
for (elem <- rdd.iterator(split, taskContext)) {
val pair = elem.asInstanceOf[(Any, Any)]
val bucketId = dep.partitioner.getPartition(pair._1)
- buckets(bucketId) += pair
+ buckets.writers(bucketId).write(pair)
}
- val compressedSizes = new Array[Byte](numOutputSplits)
-
- var totalBytes = 0l
-
- val blockManager = SparkEnv.get.blockManager
- for (i <- 0 until numOutputSplits) {
- val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
- // Get a Scala iterator from Java map
- val iter: Iterator[(Any, Any)] = buckets(i).iterator
- val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
+ // Commit the writes. Get the size of each bucket block (total block size).
+ var totalBytes = 0L
+ val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter =>
+ writer.commit()
+ writer.close()
+ val size = writer.size()
totalBytes += size
- compressedSizes(i) = MapOutputTracker.compressSize(size)
+ MapOutputTracker.compressSize(size)
}
+
+ // Update shuffle metrics.
val shuffleMetrics = new ShuffleWriteMetrics
shuffleMetrics.shuffleBytesWritten = totalBytes
metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
return new MapStatus(blockManager.blockManagerId, compressedSizes)
+ } catch { case e: Exception =>
+ // If there is an exception from running the task, revert the partial writes
+ // and throw the exception upstream to Spark.
+ if (buckets != null) {
+ buckets.writers.foreach(_.revertPartialWrites())
+ }
+ throw e
} finally {
+ // Release the writers back to the shuffle block manager.
+ if (shuffle != null && buckets != null) {
+ shuffle.releaseWriters(buckets)
+ }
// Execute the callbacks on task completion.
taskContext.executeOnCompleteCallbacks()
}
}
- override def preferredLocations: Seq[String] = locs
+ override def preferredLocations: Seq[String] = preferredLocs
override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
}
diff --git a/core/src/main/scala/spark/scheduler/SparkListener.scala b/core/src/main/scala/spark/scheduler/SparkListener.scala
index a65140b145..bac984b5c9 100644
--- a/core/src/main/scala/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/spark/scheduler/SparkListener.scala
@@ -1,27 +1,59 @@
package spark.scheduler
+import java.util.Properties
import spark.scheduler.cluster.TaskInfo
import spark.util.Distribution
-import spark.{Utils, Logging}
+import spark.{Logging, SparkContext, TaskEndReason, Utils}
import spark.executor.TaskMetrics
-trait SparkListener {
- /**
- * called when a stage is completed, with information on the completed stage
- */
- def onStageCompleted(stageCompleted: StageCompleted)
-}
-
sealed trait SparkListenerEvents
+case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int) extends SparkListenerEvents
+
case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents
+case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo,
+ taskMetrics: TaskMetrics) extends SparkListenerEvents
+
+case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null)
+ extends SparkListenerEvents
+
+case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult)
+ extends SparkListenerEvents
+
+trait SparkListener {
+ /**
+ * Called when a stage is completed, with information on the completed stage
+ */
+ def onStageCompleted(stageCompleted: StageCompleted) { }
+
+ /**
+ * Called when a stage is submitted
+ */
+ def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { }
+
+ /**
+ * Called when a task ends
+ */
+ def onTaskEnd(taskEnd: SparkListenerTaskEnd) { }
+
+ /**
+ * Called when a job starts
+ */
+ def onJobStart(jobStart: SparkListenerJobStart) { }
+
+ /**
+ * Called when a job ends
+ */
+ def onJobEnd(jobEnd: SparkListenerJobEnd) { }
+
+}
/**
* Simple SparkListener that logs a few summary statistics when each stage completes
*/
class StatsReportListener extends SparkListener with Logging {
- def onStageCompleted(stageCompleted: StageCompleted) {
+ override def onStageCompleted(stageCompleted: StageCompleted) {
import spark.scheduler.StatsReportListener._
implicit val sc = stageCompleted
this.logInfo("Finished stage: " + stageCompleted.stageInfo)
diff --git a/core/src/main/scala/spark/scheduler/SplitInfo.scala b/core/src/main/scala/spark/scheduler/SplitInfo.scala
new file mode 100644
index 0000000000..6abfb7a1f7
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/SplitInfo.scala
@@ -0,0 +1,61 @@
+package spark.scheduler
+
+import collection.mutable.ArrayBuffer
+
+// information about a specific split instance : handles both split instances.
+// So that we do not need to worry about the differences.
+class SplitInfo(val inputFormatClazz: Class[_], val hostLocation: String, val path: String,
+ val length: Long, val underlyingSplit: Any) {
+ override def toString(): String = {
+ "SplitInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz +
+ ", hostLocation : " + hostLocation + ", path : " + path +
+ ", length : " + length + ", underlyingSplit " + underlyingSplit
+ }
+
+ override def hashCode(): Int = {
+ var hashCode = inputFormatClazz.hashCode
+ hashCode = hashCode * 31 + hostLocation.hashCode
+ hashCode = hashCode * 31 + path.hashCode
+ // ignore overflow ? It is hashcode anyway !
+ hashCode = hashCode * 31 + (length & 0x7fffffff).toInt
+ hashCode
+ }
+
+ // This is practically useless since most of the Split impl's dont seem to implement equals :-(
+ // So unless there is identity equality between underlyingSplits, it will always fail even if it
+ // is pointing to same block.
+ override def equals(other: Any): Boolean = other match {
+ case that: SplitInfo => {
+ this.hostLocation == that.hostLocation &&
+ this.inputFormatClazz == that.inputFormatClazz &&
+ this.path == that.path &&
+ this.length == that.length &&
+ // other split specific checks (like start for FileSplit)
+ this.underlyingSplit == that.underlyingSplit
+ }
+ case _ => false
+ }
+}
+
+object SplitInfo {
+
+ def toSplitInfo(inputFormatClazz: Class[_], path: String,
+ mapredSplit: org.apache.hadoop.mapred.InputSplit): Seq[SplitInfo] = {
+ val retval = new ArrayBuffer[SplitInfo]()
+ val length = mapredSplit.getLength
+ for (host <- mapredSplit.getLocations) {
+ retval += new SplitInfo(inputFormatClazz, host, path, length, mapredSplit)
+ }
+ retval
+ }
+
+ def toSplitInfo(inputFormatClazz: Class[_], path: String,
+ mapreduceSplit: org.apache.hadoop.mapreduce.InputSplit): Seq[SplitInfo] = {
+ val retval = new ArrayBuffer[SplitInfo]()
+ val length = mapreduceSplit.getLength
+ for (host <- mapreduceSplit.getLocations) {
+ retval += new SplitInfo(inputFormatClazz, host, path, length, mapreduceSplit)
+ }
+ retval
+ }
+}
diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala
index 552061e46b..7fc9e13fd9 100644
--- a/core/src/main/scala/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/spark/scheduler/Stage.scala
@@ -26,7 +26,7 @@ private[spark] class Stage(
val parents: List[Stage],
val priority: Int)
extends Logging {
-
+
val isShuffleMap = shuffleDep != None
val numPartitions = rdd.partitions.size
val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
@@ -60,7 +60,7 @@ private[spark] class Stage(
numAvailableOutputs -= 1
}
}
-
+
def removeOutputsOnExecutor(execId: String) {
var becameUnavailable = false
for (partition <- 0 until numPartitions) {
diff --git a/core/src/main/scala/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/spark/scheduler/TaskScheduler.scala
index d549b184b0..7787b54762 100644
--- a/core/src/main/scala/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/TaskScheduler.scala
@@ -10,6 +10,10 @@ package spark.scheduler
private[spark] trait TaskScheduler {
def start(): Unit
+ // Invoked after system has successfully initialized (typically in spark context).
+ // Yarn uses this to bootstrap allocation of resources based on preferred locations, wait for slave registerations, etc.
+ def postStartHook() { }
+
// Disconnect from the cluster.
def stop(): Unit
diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
index 771518dddf..b75d3736cf 100644
--- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
+++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
@@ -14,6 +14,9 @@ private[spark] trait TaskSchedulerListener {
def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any],
taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit
+ // A node was added to the cluster.
+ def executorGained(execId: String, hostPort: String): Unit
+
// A node was lost from the cluster.
def executorLost(execId: String): Unit
diff --git a/core/src/main/scala/spark/scheduler/TaskSet.scala b/core/src/main/scala/spark/scheduler/TaskSet.scala
index a3002ca477..e4b5fcaedb 100644
--- a/core/src/main/scala/spark/scheduler/TaskSet.scala
+++ b/core/src/main/scala/spark/scheduler/TaskSet.scala
@@ -1,11 +1,18 @@
package spark.scheduler
+import java.util.Properties
+
/**
* A set of tasks submitted together to the low-level TaskScheduler, usually representing
* missing partitions of a particular stage.
*/
-private[spark] class TaskSet(val tasks: Array[Task[_]], val stageId: Int, val attempt: Int, val priority: Int) {
- val id: String = stageId + "." + attempt
+private[spark] class TaskSet(
+ val tasks: Array[Task[_]],
+ val stageId: Int,
+ val attempt: Int,
+ val priority: Int,
+ val properties: Properties) {
+ val id: String = stageId + "." + attempt
override def toString: String = "TaskSet " + id
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
index 26fdef101b..3a0c29b27f 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -1,6 +1,6 @@
package spark.scheduler.cluster
-import java.io.{File, FileInputStream, FileOutputStream}
+import java.lang.{Boolean => JBoolean}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
@@ -25,17 +25,45 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
// Threshold above which we warn user initial TaskSet may be starved
val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong
+ // How often to revive offers in case there are pending tasks - that is how often to try to get
+ // tasks scheduled in case there are nodes available : default 0 is to disable it - to preserve existing behavior
+ // Note that this is required due to delayed scheduling due to data locality waits, etc.
+ // TODO: rename property ?
+ val TASK_REVIVAL_INTERVAL = System.getProperty("spark.tasks.revive.interval", "0").toLong
+
+ /*
+ This property controls how aggressive we should be to modulate waiting for node local task scheduling.
+ To elaborate, currently there is a time limit (3 sec def) to ensure that spark attempts to wait for node locality of tasks before
+ scheduling on other nodes. We have modified this in yarn branch such that offers to task set happen in prioritized order :
+ node-local, rack-local and then others
+ But once all available node local (and no pref) tasks are scheduled, instead of waiting for 3 sec before
+ scheduling to other nodes (which degrades performance for time sensitive tasks and on larger clusters), we can
+ modulate that : to also allow rack local nodes or any node. The default is still set to HOST - so that previous behavior is
+ maintained. This is to allow tuning the tension between pulling rdd data off node and scheduling computation asap.
+
+ TODO: rename property ? The value is one of
+ - NODE_LOCAL (default, no change w.r.t current behavior),
+ - RACK_LOCAL and
+ - ANY
+
+ Note that this property makes more sense when used in conjugation with spark.tasks.revive.interval > 0 : else it is not very effective.
+
+ Additional Note: For non trivial clusters, there is a 4x - 5x reduction in running time (in some of our experiments) based on whether
+ it is left at default NODE_LOCAL, RACK_LOCAL (if cluster is configured to be rack aware) or ANY.
+ If cluster is rack aware, then setting it to RACK_LOCAL gives best tradeoff and a 3x - 4x performance improvement while minimizing IO impact.
+ Also, it brings down the variance in running time drastically.
+ */
+ val TASK_SCHEDULING_AGGRESSION = TaskLocality.parse(System.getProperty("spark.tasks.schedule.aggression", "NODE_LOCAL"))
val activeTaskSets = new HashMap[String, TaskSetManager]
- var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager]
val taskIdToTaskSetId = new HashMap[Long, String]
val taskIdToExecutorId = new HashMap[Long, String]
val taskSetTaskIds = new HashMap[String, HashSet[Long]]
- var hasReceivedTask = false
- var hasLaunchedTask = false
- val starvationTimer = new Timer(true)
+ @volatile private var hasReceivedTask = false
+ @volatile private var hasLaunchedTask = false
+ private val starvationTimer = new Timer(true)
// Incrementing Mesos task IDs
val nextTaskId = new AtomicLong(0)
@@ -43,11 +71,16 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
// Which executor IDs we have executors on
val activeExecutorIds = new HashSet[String]
+ // TODO: We might want to remove this and merge it with execId datastructures - but later.
+ // Which hosts in the cluster are alive (contains hostPort's) - used for process local and node local task locality.
+ private val hostPortsAlive = new HashSet[String]
+ private val hostToAliveHostPorts = new HashMap[String, HashSet[String]]
+
// The set of executors we have on each host; this is used to compute hostsAlive, which
// in turn is used to decide when we can attain data locality on a given host
- val executorsByHost = new HashMap[String, HashSet[String]]
+ private val executorsByHostPort = new HashMap[String, HashSet[String]]
- val executorIdToHost = new HashMap[String, String]
+ private val executorIdToHostPort = new HashMap[String, String]
// JAR server, if any JARs were added by the user to the SparkContext
var jarServer: HttpServer = null
@@ -62,24 +95,50 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
val mapOutputTracker = SparkEnv.get.mapOutputTracker
+ var schedulableBuilder: SchedulableBuilder = null
+ var rootPool: Pool = null
+
override def setListener(listener: TaskSchedulerListener) {
this.listener = listener
}
def initialize(context: SchedulerBackend) {
backend = context
+ //default scheduler is FIFO
+ val schedulingMode = System.getProperty("spark.cluster.schedulingmode", "FIFO")
+ //temporarily set rootPool name to empty
+ rootPool = new Pool("", SchedulingMode.withName(schedulingMode), 0, 0)
+ schedulableBuilder = {
+ schedulingMode match {
+ case "FIFO" =>
+ new FIFOSchedulableBuilder(rootPool)
+ case "FAIR" =>
+ new FairSchedulableBuilder(rootPool)
+ }
+ }
+ schedulableBuilder.buildPools()
+ // resolve executorId to hostPort mapping.
+ def executorToHostPort(executorId: String, defaultHostPort: String): String = {
+ executorIdToHostPort.getOrElse(executorId, defaultHostPort)
+ }
+
+ // Unfortunately, this means that SparkEnv is indirectly referencing ClusterScheduler
+ // Will that be a design violation ?
+ SparkEnv.get.executorIdToHostPort = Some(executorToHostPort)
}
+
def newTaskId(): Long = nextTaskId.getAndIncrement()
override def start() {
backend.start()
- if (System.getProperty("spark.speculation", "false") == "true") {
+ if (JBoolean.getBoolean("spark.speculation")) {
new Thread("ClusterScheduler speculation check") {
setDaemon(true)
override def run() {
+ logInfo("Starting speculative execution thread")
while (true) {
try {
Thread.sleep(SPECULATION_INTERVAL)
@@ -91,15 +150,36 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}.start()
}
+
+
+ // Change to always run with some default if TASK_REVIVAL_INTERVAL <= 0 ?
+ if (TASK_REVIVAL_INTERVAL > 0) {
+ new Thread("ClusterScheduler task offer revival check") {
+ setDaemon(true)
+
+ override def run() {
+ logInfo("Starting speculative task offer revival thread")
+ while (true) {
+ try {
+ Thread.sleep(TASK_REVIVAL_INTERVAL)
+ } catch {
+ case e: InterruptedException => {}
+ }
+
+ if (hasPendingTasks()) backend.reviveOffers()
+ }
+ }
+ }.start()
+ }
}
override def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
- val manager = new TaskSetManager(this, taskSet)
+ val manager = new ClusterTaskSetManager(this, taskSet)
activeTaskSets(taskSet.id) = manager
- activeTaskSetsQueue += manager
+ schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
taskSetTaskIds(taskSet.id) = new HashSet[Long]()
if (hasReceivedTask == false) {
@@ -122,7 +202,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
def taskSetFinished(manager: TaskSetManager) {
this.synchronized {
activeTaskSets -= manager.taskSet.id
- activeTaskSetsQueue -= manager
+ manager.parent.removeSchedulable(manager)
+ logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id)
taskSetTaskIds.remove(manager.taskSet.id)
@@ -139,22 +220,128 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
SparkEnv.set(sc.env)
// Mark each slave as alive and remember its hostname
for (o <- offers) {
- executorIdToHost(o.executorId) = o.hostname
- if (!executorsByHost.contains(o.hostname)) {
- executorsByHost(o.hostname) = new HashSet()
+ // DEBUG Code
+ Utils.checkHostPort(o.hostPort)
+
+ executorIdToHostPort(o.executorId) = o.hostPort
+ if (! executorsByHostPort.contains(o.hostPort)) {
+ executorsByHostPort(o.hostPort) = new HashSet[String]()
}
+
+ hostPortsAlive += o.hostPort
+ hostToAliveHostPorts.getOrElseUpdate(Utils.parseHostPort(o.hostPort)._1, new HashSet[String]).add(o.hostPort)
+ executorGained(o.executorId, o.hostPort)
}
// Build a list of tasks to assign to each slave
val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
+ // merge availableCpus into nodeToAvailableCpus block ?
val availableCpus = offers.map(o => o.cores).toArray
+ val nodeToAvailableCpus = {
+ val map = new HashMap[String, Int]()
+ for (offer <- offers) {
+ val hostPort = offer.hostPort
+ val cores = offer.cores
+ // DEBUG code
+ Utils.checkHostPort(hostPort)
+
+ val host = Utils.parseHostPort(hostPort)._1
+
+ map.put(host, map.getOrElse(host, 0) + cores)
+ }
+
+ map
+ }
var launchedTask = false
- for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) {
+ val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
+ for (manager <- sortedTaskSetQueue)
+ {
+ logInfo("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks))
+ }
+ for (manager <- sortedTaskSetQueue) {
+
+ // Split offers based on node local, rack local and off-rack tasks.
+ val processLocalOffers = new HashMap[String, ArrayBuffer[Int]]()
+ val nodeLocalOffers = new HashMap[String, ArrayBuffer[Int]]()
+ val rackLocalOffers = new HashMap[String, ArrayBuffer[Int]]()
+ val otherOffers = new HashMap[String, ArrayBuffer[Int]]()
+
+ for (i <- 0 until offers.size) {
+ val hostPort = offers(i).hostPort
+ // DEBUG code
+ Utils.checkHostPort(hostPort)
+
+ val numProcessLocalTasks = math.max(0, math.min(manager.numPendingTasksForHostPort(hostPort), availableCpus(i)))
+ if (numProcessLocalTasks > 0){
+ val list = processLocalOffers.getOrElseUpdate(hostPort, new ArrayBuffer[Int])
+ for (j <- 0 until numProcessLocalTasks) list += i
+ }
+
+ val host = Utils.parseHostPort(hostPort)._1
+ val numNodeLocalTasks = math.max(0,
+ // Remove process local tasks (which are also host local btw !) from this
+ math.min(manager.numPendingTasksForHost(hostPort) - numProcessLocalTasks, nodeToAvailableCpus(host)))
+ if (numNodeLocalTasks > 0){
+ val list = nodeLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int])
+ for (j <- 0 until numNodeLocalTasks) list += i
+ }
+
+ val numRackLocalTasks = math.max(0,
+ // Remove node local tasks (which are also rack local btw !) from this
+ math.min(manager.numRackLocalPendingTasksForHost(hostPort) - numProcessLocalTasks - numNodeLocalTasks, nodeToAvailableCpus(host)))
+ if (numRackLocalTasks > 0){
+ val list = rackLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int])
+ for (j <- 0 until numRackLocalTasks) list += i
+ }
+ if (numNodeLocalTasks <= 0 && numRackLocalTasks <= 0){
+ // add to others list - spread even this across cluster.
+ val list = otherOffers.getOrElseUpdate(host, new ArrayBuffer[Int])
+ list += i
+ }
+ }
+
+ val offersPriorityList = new ArrayBuffer[Int](
+ processLocalOffers.size + nodeLocalOffers.size + rackLocalOffers.size + otherOffers.size)
+
+ // First process local, then host local, then rack, then others
+
+ // numNodeLocalOffers contains count of both process local and host offers.
+ val numNodeLocalOffers = {
+ val processLocalPriorityList = ClusterScheduler.prioritizeContainers(processLocalOffers)
+ offersPriorityList ++= processLocalPriorityList
+
+ val nodeLocalPriorityList = ClusterScheduler.prioritizeContainers(nodeLocalOffers)
+ offersPriorityList ++= nodeLocalPriorityList
+
+ processLocalPriorityList.size + nodeLocalPriorityList.size
+ }
+ val numRackLocalOffers = {
+ val rackLocalPriorityList = ClusterScheduler.prioritizeContainers(rackLocalOffers)
+ offersPriorityList ++= rackLocalPriorityList
+ rackLocalPriorityList.size
+ }
+ offersPriorityList ++= ClusterScheduler.prioritizeContainers(otherOffers)
+
+ var lastLoop = false
+ val lastLoopIndex = TASK_SCHEDULING_AGGRESSION match {
+ case TaskLocality.NODE_LOCAL => numNodeLocalOffers
+ case TaskLocality.RACK_LOCAL => numRackLocalOffers + numNodeLocalOffers
+ case TaskLocality.ANY => offersPriorityList.size
+ }
+
do {
launchedTask = false
- for (i <- 0 until offers.size) {
+ var loopCount = 0
+ for (i <- offersPriorityList) {
val execId = offers(i).executorId
- val host = offers(i).hostname
- manager.slaveOffer(execId, host, availableCpus(i)) match {
+ val hostPort = offers(i).hostPort
+
+ // If last loop and within the lastLoopIndex, expand scope - else use null (which will use default/existing)
+ val overrideLocality = if (lastLoop && loopCount < lastLoopIndex) TASK_SCHEDULING_AGGRESSION else null
+
+ // If last loop, override waiting for host locality - we scheduled all local tasks already and there might be more available ...
+ loopCount += 1
+
+ manager.slaveOffer(execId, hostPort, availableCpus(i), overrideLocality) match {
case Some(task) =>
tasks(i) += task
val tid = task.taskId
@@ -162,15 +349,31 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
taskSetTaskIds(manager.taskSet.id) += tid
taskIdToExecutorId(tid) = execId
activeExecutorIds += execId
- executorsByHost(host) += execId
+ executorsByHostPort(hostPort) += execId
availableCpus(i) -= 1
launchedTask = true
case None => {}
+ }
+ }
+ // Loop once more - when lastLoop = true, then we try to schedule task on all nodes irrespective of
+ // data locality (we still go in order of priority : but that would not change anything since
+ // if data local tasks had been available, we would have scheduled them already)
+ if (lastLoop) {
+ // prevent more looping
+ launchedTask = false
+ } else if (!lastLoop && !launchedTask) {
+ // Do this only if TASK_SCHEDULING_AGGRESSION != NODE_LOCAL
+ if (TASK_SCHEDULING_AGGRESSION != TaskLocality.NODE_LOCAL) {
+ // fudge launchedTask to ensure we loop once more
+ launchedTask = true
+ // dont loop anymore
+ lastLoop = true
}
}
} while (launchedTask)
}
+
if (tasks.size > 0) {
hasLaunchedTask = true
}
@@ -223,6 +426,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
backend.reviveOffers()
}
if (taskFailed) {
+
// Also revive offers if a task had failed for some reason other than host lost
backend.reviveOffers()
}
@@ -256,29 +460,40 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
if (jarServer != null) {
jarServer.stop()
}
+
+ // sleeping for an arbitrary 5 seconds : to ensure that messages are sent out.
+ // TODO: Do something better !
+ Thread.sleep(5000L)
}
override def defaultParallelism() = backend.defaultParallelism()
+
// Check for speculatable tasks in all our active jobs.
def checkSpeculatableTasks() {
var shouldRevive = false
synchronized {
- for (ts <- activeTaskSetsQueue) {
- shouldRevive |= ts.checkSpeculatableTasks()
- }
+ shouldRevive = rootPool.checkSpeculatableTasks()
}
if (shouldRevive) {
backend.reviveOffers()
}
}
+ // Check for pending tasks in all our active jobs.
+ def hasPendingTasks(): Boolean = {
+ synchronized {
+ rootPool.hasPendingTasks()
+ }
+ }
+
def executorLost(executorId: String, reason: ExecutorLossReason) {
var failedExecutor: Option[String] = None
+
synchronized {
if (activeExecutorIds.contains(executorId)) {
- val host = executorIdToHost(executorId)
- logError("Lost executor %s on %s: %s".format(executorId, host, reason))
+ val hostPort = executorIdToHostPort(executorId)
+ logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason))
removeExecutor(executorId)
failedExecutor = Some(executorId)
} else {
@@ -296,19 +511,104 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
- /** Get a list of hosts that currently have executors */
- def hostsAlive: scala.collection.Set[String] = executorsByHost.keySet
-
/** Remove an executor from all our data structures and mark it as lost */
private def removeExecutor(executorId: String) {
activeExecutorIds -= executorId
- val host = executorIdToHost(executorId)
- val execs = executorsByHost.getOrElse(host, new HashSet)
+ val hostPort = executorIdToHostPort(executorId)
+ if (hostPortsAlive.contains(hostPort)) {
+ // DEBUG Code
+ Utils.checkHostPort(hostPort)
+
+ hostPortsAlive -= hostPort
+ hostToAliveHostPorts.getOrElseUpdate(Utils.parseHostPort(hostPort)._1, new HashSet[String]).remove(hostPort)
+ }
+
+ val execs = executorsByHostPort.getOrElse(hostPort, new HashSet)
execs -= executorId
if (execs.isEmpty) {
- executorsByHost -= host
+ executorsByHostPort -= hostPort
}
- executorIdToHost -= executorId
- activeTaskSetsQueue.foreach(_.executorLost(executorId, host))
+ executorIdToHostPort -= executorId
+ rootPool.executorLost(executorId, hostPort)
+ }
+
+ def executorGained(execId: String, hostPort: String) {
+ listener.executorGained(execId, hostPort)
+ }
+
+ def getExecutorsAliveOnHost(host: String): Option[Set[String]] = {
+ Utils.checkHost(host)
+
+ val retval = hostToAliveHostPorts.get(host)
+ if (retval.isDefined) {
+ return Some(retval.get.toSet)
+ }
+
+ None
+ }
+
+ def isExecutorAliveOnHostPort(hostPort: String): Boolean = {
+ // Even if hostPort is a host, it does not matter - it is just a specific check.
+ // But we do have to ensure that only hostPort get into hostPortsAlive !
+ // So no check against Utils.checkHostPort
+ hostPortsAlive.contains(hostPort)
+ }
+
+ // By default, rack is unknown
+ def getRackForHost(value: String): Option[String] = None
+
+ // By default, (cached) hosts for rack is unknown
+ def getCachedHostsForRack(rack: String): Option[Set[String]] = None
+}
+
+
+object ClusterScheduler {
+
+ // Used to 'spray' available containers across the available set to ensure too many containers on same host
+ // are not used up. Used in yarn mode and in task scheduling (when there are multiple containers available
+ // to execute a task)
+ // For example: yarn can returns more containers than we would have requested under ANY, this method
+ // prioritizes how to use the allocated containers.
+ // flatten the map such that the array buffer entries are spread out across the returned value.
+ // given <host, list[container]> == <h1, [c1 .. c5]>, <h2, [c1 .. c3]>, <h3, [c1, c2]>, <h4, c1>, <h5, c1>, i
+ // the return value would be something like : h1c1, h2c1, h3c1, h4c1, h5c1, h1c2, h2c2, h3c2, h1c3, h2c3, h1c4, h1c5
+ // We then 'use' the containers in this order (consuming only the top K from this list where
+ // K = number to be user). This is to ensure that if we have multiple eligible allocations,
+ // they dont end up allocating all containers on a small number of hosts - increasing probability of
+ // multiple container failure when a host goes down.
+ // Note, there is bias for keys with higher number of entries in value to be picked first (by design)
+ // Also note that invocation of this method is expected to have containers of same 'type'
+ // (host-local, rack-local, off-rack) and not across types : so that reordering is simply better from
+ // the available list - everything else being same.
+ // That is, we we first consume data local, then rack local and finally off rack nodes. So the
+ // prioritization from this method applies to within each category
+ def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = {
+ val _keyList = new ArrayBuffer[K](map.size)
+ _keyList ++= map.keys
+
+ // order keyList based on population of value in map
+ val keyList = _keyList.sortWith(
+ (left, right) => map.get(left).getOrElse(Set()).size > map.get(right).getOrElse(Set()).size
+ )
+
+ val retval = new ArrayBuffer[T](keyList.size * 2)
+ var index = 0
+ var found = true
+
+ while (found){
+ found = false
+ for (key <- keyList) {
+ val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null)
+ assert(containerList != null)
+ // Get the index'th entry for this host - if present
+ if (index < containerList.size){
+ retval += containerList.apply(index)
+ found = true
+ }
+ }
+ index += 1
+ }
+
+ retval.toList
}
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala
new file mode 100644
index 0000000000..d72b0bfc9f
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -0,0 +1,747 @@
+package spark.scheduler.cluster
+
+import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.math.max
+import scala.math.min
+
+import spark._
+import spark.scheduler._
+import spark.TaskState.TaskState
+import java.nio.ByteBuffer
+
+private[spark] object TaskLocality extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") with Logging {
+
+ // process local is expected to be used ONLY within tasksetmanager for now.
+ val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
+
+ type TaskLocality = Value
+
+ def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = {
+
+ // Must not be the constraint.
+ assert (constraint != TaskLocality.PROCESS_LOCAL)
+
+ constraint match {
+ case TaskLocality.NODE_LOCAL => condition == TaskLocality.NODE_LOCAL
+ case TaskLocality.RACK_LOCAL => condition == TaskLocality.NODE_LOCAL || condition == TaskLocality.RACK_LOCAL
+ // For anything else, allow
+ case _ => true
+ }
+ }
+
+ def parse(str: String): TaskLocality = {
+ // better way to do this ?
+ try {
+ val retval = TaskLocality.withName(str)
+ // Must not specify PROCESS_LOCAL !
+ assert (retval != TaskLocality.PROCESS_LOCAL)
+
+ retval
+ } catch {
+ case nEx: NoSuchElementException => {
+ logWarning("Invalid task locality specified '" + str + "', defaulting to NODE_LOCAL");
+ // default to preserve earlier behavior
+ NODE_LOCAL
+ }
+ }
+ }
+}
+
+/**
+ * Schedules the tasks within a single TaskSet in the ClusterScheduler.
+ */
+private[spark] class ClusterTaskSetManager(
+ sched: ClusterScheduler,
+ val taskSet: TaskSet)
+ extends TaskSetManager
+ with Logging {
+
+ // Maximum time to wait to run a task in a preferred location (in ms)
+ val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
+
+ // CPUs to request per task
+ val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble
+
+ // Maximum times a task is allowed to fail before failing the job
+ val MAX_TASK_FAILURES = 4
+
+ // Quantile of tasks at which to start speculation
+ val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
+ val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
+
+ // Serializer for closures and tasks.
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+
+ val tasks = taskSet.tasks
+ val numTasks = tasks.length
+ val copiesRunning = new Array[Int](numTasks)
+ val finished = new Array[Boolean](numTasks)
+ val numFailures = new Array[Int](numTasks)
+ val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
+ var tasksFinished = 0
+
+ var weight = 1
+ var minShare = 0
+ var runningTasks = 0
+ var priority = taskSet.priority
+ var stageId = taskSet.stageId
+ var name = "TaskSet_"+taskSet.stageId.toString
+ var parent:Schedulable = null
+
+ // Last time when we launched a preferred task (for delay scheduling)
+ var lastPreferredLaunchTime = System.currentTimeMillis
+
+ // List of pending tasks for each node (process local to container). These collections are actually
+ // treated as stacks, in which new tasks are added to the end of the
+ // ArrayBuffer and removed from the end. This makes it faster to detect
+ // tasks that repeatedly fail because whenever a task failed, it is put
+ // back at the head of the stack. They are also only cleaned up lazily;
+ // when a task is launched, it remains in all the pending lists except
+ // the one that it was launched from, but gets removed from them later.
+ private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]]
+
+ // List of pending tasks for each node.
+ // Essentially, similar to pendingTasksForHostPort, except at host level
+ private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
+
+ // List of pending tasks for each node based on rack locality.
+ // Essentially, similar to pendingTasksForHost, except at rack level
+ private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]]
+
+ // List containing pending tasks with no locality preferences
+ val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
+
+ // List containing all pending tasks (also used as a stack, as above)
+ val allPendingTasks = new ArrayBuffer[Int]
+
+ // Tasks that can be speculated. Since these will be a small fraction of total
+ // tasks, we'll just hold them in a HashSet.
+ val speculatableTasks = new HashSet[Int]
+
+ // Task index, start and finish time for each task attempt (indexed by task ID)
+ val taskInfos = new HashMap[Long, TaskInfo]
+
+ // Did the job fail?
+ var failed = false
+ var causeOfFailure = ""
+
+ // How frequently to reprint duplicate exceptions in full, in milliseconds
+ val EXCEPTION_PRINT_INTERVAL =
+ System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong
+ // Map of recent exceptions (identified by string representation and
+ // top stack frame) to duplicate count (how many times the same
+ // exception has appeared) and time the full exception was
+ // printed. This should ideally be an LRU map that can drop old
+ // exceptions automatically.
+ val recentExceptions = HashMap[String, (Int, Long)]()
+
+ // Figure out the current map output tracker generation and set it on all tasks
+ val generation = sched.mapOutputTracker.getGeneration
+ logDebug("Generation for " + taskSet.id + ": " + generation)
+ for (t <- tasks) {
+ t.generation = generation
+ }
+
+ // Add all our tasks to the pending lists. We do this in reverse order
+ // of task index so that tasks with low indices get launched first.
+ for (i <- (0 until numTasks).reverse) {
+ addPendingTask(i)
+ }
+
+ // Note that it follows the hierarchy.
+ // if we search for NODE_LOCAL, the output will include PROCESS_LOCAL and
+ // if we search for RACK_LOCAL, it will include PROCESS_LOCAL & NODE_LOCAL
+ private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler,
+ taskLocality: TaskLocality.TaskLocality): HashSet[String] = {
+
+ if (TaskLocality.PROCESS_LOCAL == taskLocality) {
+ // straight forward comparison ! Special case it.
+ val retval = new HashSet[String]()
+ scheduler.synchronized {
+ for (location <- _taskPreferredLocations) {
+ if (scheduler.isExecutorAliveOnHostPort(location)) {
+ retval += location
+ }
+ }
+ }
+
+ return retval
+ }
+
+ val taskPreferredLocations =
+ if (TaskLocality.NODE_LOCAL == taskLocality) {
+ _taskPreferredLocations
+ } else {
+ assert (TaskLocality.RACK_LOCAL == taskLocality)
+ // Expand set to include all 'seen' rack local hosts.
+ // This works since container allocation/management happens within master - so any rack locality information is updated in msater.
+ // Best case effort, and maybe sort of kludge for now ... rework it later ?
+ val hosts = new HashSet[String]
+ _taskPreferredLocations.foreach(h => {
+ val rackOpt = scheduler.getRackForHost(h)
+ if (rackOpt.isDefined) {
+ val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get)
+ if (hostsOpt.isDefined) {
+ hosts ++= hostsOpt.get
+ }
+ }
+
+ // Ensure that irrespective of what scheduler says, host is always added !
+ hosts += h
+ })
+
+ hosts
+ }
+
+ val retval = new HashSet[String]
+ scheduler.synchronized {
+ for (prefLocation <- taskPreferredLocations) {
+ val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(Utils.parseHostPort(prefLocation)._1)
+ if (aliveLocationsOpt.isDefined) {
+ retval ++= aliveLocationsOpt.get
+ }
+ }
+ }
+
+ retval
+ }
+
+ // Add a task to all the pending-task lists that it should be on.
+ private def addPendingTask(index: Int) {
+ // We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate
+ // hostPort <-> host conversion). But not doing it for simplicity sake. If this becomes a performance issue, modify it.
+ val processLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.PROCESS_LOCAL)
+ val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
+ val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
+
+ if (rackLocalLocations.size == 0) {
+ // Current impl ensures this.
+ assert (processLocalLocations.size == 0)
+ assert (hostLocalLocations.size == 0)
+ pendingTasksWithNoPrefs += index
+ } else {
+
+ // process local locality
+ for (hostPort <- processLocalLocations) {
+ // DEBUG Code
+ Utils.checkHostPort(hostPort)
+
+ val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer())
+ hostPortList += index
+ }
+
+ // host locality (includes process local)
+ for (hostPort <- hostLocalLocations) {
+ // DEBUG Code
+ Utils.checkHostPort(hostPort)
+
+ val host = Utils.parseHostPort(hostPort)._1
+ val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer())
+ hostList += index
+ }
+
+ // rack locality (includes process local and host local)
+ for (rackLocalHostPort <- rackLocalLocations) {
+ // DEBUG Code
+ Utils.checkHostPort(rackLocalHostPort)
+
+ val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1
+ val list = pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer())
+ list += index
+ }
+ }
+
+ allPendingTasks += index
+ }
+
+ // Return the pending tasks list for a given host port (process local), or an empty list if
+ // there is no map entry for that host
+ private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = {
+ // DEBUG Code
+ Utils.checkHostPort(hostPort)
+ pendingTasksForHostPort.getOrElse(hostPort, ArrayBuffer())
+ }
+
+ // Return the pending tasks list for a given host, or an empty list if
+ // there is no map entry for that host
+ private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
+ val host = Utils.parseHostPort(hostPort)._1
+ pendingTasksForHost.getOrElse(host, ArrayBuffer())
+ }
+
+ // Return the pending tasks (rack level) list for a given host, or an empty list if
+ // there is no map entry for that host
+ private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
+ val host = Utils.parseHostPort(hostPort)._1
+ pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer())
+ }
+
+ // Number of pending tasks for a given host Port (which would be process local)
+ def numPendingTasksForHostPort(hostPort: String): Int = {
+ getPendingTasksForHostPort(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
+ }
+
+ // Number of pending tasks for a given host (which would be data local)
+ def numPendingTasksForHost(hostPort: String): Int = {
+ getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
+ }
+
+ // Number of pending rack local tasks for a given host
+ def numRackLocalPendingTasksForHost(hostPort: String): Int = {
+ getRackLocalPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
+ }
+
+
+ // Dequeue a pending task from the given list and return its index.
+ // Return None if the list is empty.
+ // This method also cleans up any tasks in the list that have already
+ // been launched, since we want that to happen lazily.
+ private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
+ while (!list.isEmpty) {
+ val index = list.last
+ list.trimEnd(1)
+ if (copiesRunning(index) == 0 && !finished(index)) {
+ return Some(index)
+ }
+ }
+ return None
+ }
+
+ // Return a speculative task for a given host if any are available. The task should not have an
+ // attempt running on this host, in case the host is slow. In addition, if locality is set, the
+ // task must have a preference for this host/rack/no preferred locations at all.
+ private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
+
+ assert (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL))
+ speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
+
+ if (speculatableTasks.size > 0) {
+ val localTask = speculatableTasks.find {
+ index =>
+ val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
+ val attemptLocs = taskAttempts(index).map(_.hostPort)
+ (locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort)
+ }
+
+ if (localTask != None) {
+ speculatableTasks -= localTask.get
+ return localTask
+ }
+
+ // check for rack locality
+ if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
+ val rackTask = speculatableTasks.find {
+ index =>
+ val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
+ val attemptLocs = taskAttempts(index).map(_.hostPort)
+ locations.contains(hostPort) && !attemptLocs.contains(hostPort)
+ }
+
+ if (rackTask != None) {
+ speculatableTasks -= rackTask.get
+ return rackTask
+ }
+ }
+
+ // Any task ...
+ if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
+ // Check for attemptLocs also ?
+ val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.hostPort).contains(hostPort))
+ if (nonLocalTask != None) {
+ speculatableTasks -= nonLocalTask.get
+ return nonLocalTask
+ }
+ }
+ }
+ return None
+ }
+
+ // Dequeue a pending task for a given node and return its index.
+ // If localOnly is set to false, allow non-local tasks as well.
+ private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
+ val processLocalTask = findTaskFromList(getPendingTasksForHostPort(hostPort))
+ if (processLocalTask != None) {
+ return processLocalTask
+ }
+
+ val localTask = findTaskFromList(getPendingTasksForHost(hostPort))
+ if (localTask != None) {
+ return localTask
+ }
+
+ if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
+ val rackLocalTask = findTaskFromList(getRackLocalPendingTasksForHost(hostPort))
+ if (rackLocalTask != None) {
+ return rackLocalTask
+ }
+ }
+
+ // Look for no pref tasks AFTER rack local tasks - this has side effect that we will get to failed tasks later rather than sooner.
+ // TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down).
+ val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs)
+ if (noPrefTask != None) {
+ return noPrefTask
+ }
+
+ if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
+ val nonLocalTask = findTaskFromList(allPendingTasks)
+ if (nonLocalTask != None) {
+ return nonLocalTask
+ }
+ }
+
+ // Finally, if all else has failed, find a speculative task
+ return findSpeculativeTask(hostPort, locality)
+ }
+
+ private def isProcessLocalLocation(task: Task[_], hostPort: String): Boolean = {
+ Utils.checkHostPort(hostPort)
+
+ val locs = task.preferredLocations
+
+ locs.contains(hostPort)
+ }
+
+ private def isHostLocalLocation(task: Task[_], hostPort: String): Boolean = {
+ val locs = task.preferredLocations
+
+ // If no preference, consider it as host local
+ if (locs.isEmpty) return true
+
+ val host = Utils.parseHostPort(hostPort)._1
+ locs.find(h => Utils.parseHostPort(h)._1 == host).isDefined
+ }
+
+ // Does a host count as a rack local preferred location for a task? (assumes host is NOT preferred location).
+ // This is true if either the task has preferred locations and this host is one, or it has
+ // no preferred locations (in which we still count the launch as preferred).
+ private def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = {
+
+ val locs = task.preferredLocations
+
+ val preferredRacks = new HashSet[String]()
+ for (preferredHost <- locs) {
+ val rack = sched.getRackForHost(preferredHost)
+ if (None != rack) preferredRacks += rack.get
+ }
+
+ if (preferredRacks.isEmpty) return false
+
+ val hostRack = sched.getRackForHost(hostPort)
+
+ return None != hostRack && preferredRacks.contains(hostRack.get)
+ }
+
+ // Respond to an offer of a single slave from the scheduler by finding a task
+ def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
+
+ if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
+ // If explicitly specified, use that
+ val locality = if (overrideLocality != null) overrideLocality else {
+ // expand only if we have waited for more than LOCALITY_WAIT for a host local task ...
+ val time = System.currentTimeMillis
+ if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.NODE_LOCAL else TaskLocality.ANY
+ }
+
+ findTask(hostPort, locality) match {
+ case Some(index) => {
+ // Found a task; do some bookkeeping and return a Mesos task for it
+ val task = tasks(index)
+ val taskId = sched.newTaskId()
+ // Figure out whether this should count as a preferred launch
+ val taskLocality =
+ if (isProcessLocalLocation(task, hostPort)) TaskLocality.PROCESS_LOCAL else
+ if (isHostLocalLocation(task, hostPort)) TaskLocality.NODE_LOCAL else
+ if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else
+ TaskLocality.ANY
+ val prefStr = taskLocality.toString
+ logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
+ taskSet.id, index, taskId, execId, hostPort, prefStr))
+ // Do various bookkeeping
+ copiesRunning(index) += 1
+ val time = System.currentTimeMillis
+ val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality)
+ taskInfos(taskId) = info
+ taskAttempts(index) = info :: taskAttempts(index)
+ if (TaskLocality.NODE_LOCAL == taskLocality) {
+ lastPreferredLaunchTime = time
+ }
+ // Serialize and return the task
+ val startTime = System.currentTimeMillis
+ val serializedTask = Task.serializeWithDependencies(
+ task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+ val timeTaken = System.currentTimeMillis - startTime
+ increaseRunningTasks(1)
+ logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
+ taskSet.id, index, serializedTask.limit, timeTaken))
+ val taskName = "task %s:%d".format(taskSet.id, index)
+ return Some(new TaskDescription(taskId, execId, taskName, serializedTask))
+ }
+ case _ =>
+ }
+ }
+ return None
+ }
+
+ def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ state match {
+ case TaskState.FINISHED =>
+ taskFinished(tid, state, serializedData)
+ case TaskState.LOST =>
+ taskLost(tid, state, serializedData)
+ case TaskState.FAILED =>
+ taskLost(tid, state, serializedData)
+ case TaskState.KILLED =>
+ taskLost(tid, state, serializedData)
+ case _ =>
+ }
+ }
+
+ def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ val info = taskInfos(tid)
+ if (info.failed) {
+ // We might get two task-lost messages for the same task in coarse-grained Mesos mode,
+ // or even from Mesos itself when acks get delayed.
+ return
+ }
+ val index = info.index
+ info.markSuccessful()
+ decreaseRunningTasks(1)
+ if (!finished(index)) {
+ tasksFinished += 1
+ logInfo("Finished TID %s in %d ms (progress: %d/%d)".format(
+ tid, info.duration, tasksFinished, numTasks))
+ // Deserialize task result and pass it to the scheduler
+ try {
+ val result = ser.deserialize[TaskResult[_]](serializedData)
+ result.metrics.resultSize = serializedData.limit()
+ sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
+ } catch {
+ case cnf: ClassNotFoundException =>
+ val loader = Thread.currentThread().getContextClassLoader
+ throw new SparkException("ClassNotFound with classloader: " + loader, cnf)
+ case ex => throw ex
+ }
+ // Mark finished and stop if we've finished all the tasks
+ finished(index) = true
+ if (tasksFinished == numTasks) {
+ sched.taskSetFinished(this)
+ }
+ } else {
+ logInfo("Ignoring task-finished event for TID " + tid +
+ " because task " + index + " is already finished")
+ }
+ }
+
+ def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ val info = taskInfos(tid)
+ if (info.failed) {
+ // We might get two task-lost messages for the same task in coarse-grained Mesos mode,
+ // or even from Mesos itself when acks get delayed.
+ return
+ }
+ val index = info.index
+ info.markFailed()
+ decreaseRunningTasks(1)
+ if (!finished(index)) {
+ logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
+ copiesRunning(index) -= 1
+ // Check if the problem is a map output fetch failure. In that case, this
+ // task will never succeed on any node, so tell the scheduler about it.
+ if (serializedData != null && serializedData.limit() > 0) {
+ val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader)
+ reason match {
+ case fetchFailed: FetchFailed =>
+ logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
+ sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
+ finished(index) = true
+ tasksFinished += 1
+ sched.taskSetFinished(this)
+ decreaseRunningTasks(runningTasks)
+ return
+
+ case taskResultTooBig: TaskResultTooBigFailure =>
+ logInfo("Loss was due to task %s result exceeding Akka frame size; " +
+ "aborting job".format(tid))
+ abort("Task %s result exceeded Akka frame size".format(tid))
+ return
+
+ case ef: ExceptionFailure =>
+ val key = ef.description
+ val now = System.currentTimeMillis
+ val (printFull, dupCount) = {
+ if (recentExceptions.contains(key)) {
+ val (dupCount, printTime) = recentExceptions(key)
+ if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
+ recentExceptions(key) = (0, now)
+ (true, 0)
+ } else {
+ recentExceptions(key) = (dupCount + 1, printTime)
+ (false, dupCount + 1)
+ }
+ } else {
+ recentExceptions(key) = (0, now)
+ (true, 0)
+ }
+ }
+ if (printFull) {
+ val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
+ logInfo("Loss was due to %s\n%s\n%s".format(
+ ef.className, ef.description, locs.mkString("\n")))
+ } else {
+ logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
+ }
+
+ case _ => {}
+ }
+ }
+ // On non-fetch failures, re-enqueue the task as pending for a max number of retries
+ addPendingTask(index)
+ // Count failed attempts only on FAILED and LOST state (not on KILLED)
+ if (state == TaskState.FAILED || state == TaskState.LOST) {
+ numFailures(index) += 1
+ if (numFailures(index) > MAX_TASK_FAILURES) {
+ logError("Task %s:%d failed more than %d times; aborting job".format(
+ taskSet.id, index, MAX_TASK_FAILURES))
+ abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
+ }
+ }
+ } else {
+ logInfo("Ignoring task-lost event for TID " + tid +
+ " because task " + index + " is already finished")
+ }
+ }
+
+ def error(message: String) {
+ // Save the error message
+ abort("Error: " + message)
+ }
+
+ def abort(message: String) {
+ failed = true
+ causeOfFailure = message
+ // TODO: Kill running tasks if we were not terminated due to a Mesos error
+ sched.listener.taskSetFailed(taskSet, message)
+ decreaseRunningTasks(runningTasks)
+ sched.taskSetFinished(this)
+ }
+
+ override def increaseRunningTasks(taskNum: Int) {
+ runningTasks += taskNum
+ if (parent != null) {
+ parent.increaseRunningTasks(taskNum)
+ }
+ }
+
+ override def decreaseRunningTasks(taskNum: Int) {
+ runningTasks -= taskNum
+ if (parent != null) {
+ parent.decreaseRunningTasks(taskNum)
+ }
+ }
+
+ //TODO: for now we just find Pool not TaskSetManager, we can extend this function in future if needed
+ override def getSchedulableByName(name: String): Schedulable = {
+ return null
+ }
+
+ override def addSchedulable(schedulable:Schedulable) {
+ //nothing
+ }
+
+ override def removeSchedulable(schedulable:Schedulable) {
+ //nothing
+ }
+
+ override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
+ var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
+ sortedTaskSetQueue += this
+ return sortedTaskSetQueue
+ }
+
+ override def executorLost(execId: String, hostPort: String) {
+ logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
+
+ // If some task has preferred locations only on hostname, and there are no more executors there,
+ // put it in the no-prefs list to avoid the wait from delay scheduling
+
+ // host local tasks - should we push this to rack local or no pref list ? For now, preserving behavior and moving to
+ // no prefs list. Note, this was done due to impliations related to 'waiting' for data local tasks, etc.
+ // Note: NOT checking process local list - since host local list is super set of that. We need to ad to no prefs only if
+ // there is no host local node for the task (not if there is no process local node for the task)
+ for (index <- getPendingTasksForHost(Utils.parseHostPort(hostPort)._1)) {
+ // val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
+ val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
+ if (newLocs.isEmpty) {
+ pendingTasksWithNoPrefs += index
+ }
+ }
+
+ // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
+ if (tasks(0).isInstanceOf[ShuffleMapTask]) {
+ for ((tid, info) <- taskInfos if info.executorId == execId) {
+ val index = taskInfos(tid).index
+ if (finished(index)) {
+ finished(index) = false
+ copiesRunning(index) -= 1
+ tasksFinished -= 1
+ addPendingTask(index)
+ // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
+ // stage finishes when a total of tasks.size tasks finish.
+ sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
+ }
+ }
+ }
+ // Also re-enqueue any tasks that were running on the node
+ for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
+ taskLost(tid, TaskState.KILLED, null)
+ }
+ }
+
+ /**
+ * Check for tasks to be speculated and return true if there are any. This is called periodically
+ * by the ClusterScheduler.
+ *
+ * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
+ * we don't scan the whole task set. It might also help to make this sorted by launch time.
+ */
+ override def checkSpeculatableTasks(): Boolean = {
+ // Can't speculate if we only have one task, or if all tasks have finished.
+ if (numTasks == 1 || tasksFinished == numTasks) {
+ return false
+ }
+ var foundTasks = false
+ val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
+ logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
+ if (tasksFinished >= minFinishedForSpeculation) {
+ val time = System.currentTimeMillis()
+ val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
+ Arrays.sort(durations)
+ val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1))
+ val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
+ // TODO: Threshold should also look at standard deviation of task durations and have a lower
+ // bound based on that.
+ logDebug("Task length threshold for speculation: " + threshold)
+ for ((tid, info) <- taskInfos) {
+ val index = info.index
+ if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
+ !speculatableTasks.contains(index)) {
+ logInfo(
+ "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
+ taskSet.id, index, info.hostPort, threshold))
+ speculatableTasks += index
+ foundTasks = true
+ }
+ }
+ }
+ return foundTasks
+ }
+
+ override def hasPendingTasks(): Boolean = {
+ numTasks > 0 && tasksFinished < numTasks
+ }
+}
diff --git a/core/src/main/scala/spark/scheduler/cluster/Pool.scala b/core/src/main/scala/spark/scheduler/cluster/Pool.scala
new file mode 100644
index 0000000000..941ba7a3f1
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/cluster/Pool.scala
@@ -0,0 +1,104 @@
+package spark.scheduler.cluster
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+
+import spark.Logging
+import spark.scheduler.cluster.SchedulingMode.SchedulingMode
+
+/**
+ * An Schedulable entity that represent collection of Pools or TaskSetManagers
+ */
+
+private[spark] class Pool(
+ val poolName: String,
+ val schedulingMode: SchedulingMode,
+ initMinShare: Int,
+ initWeight: Int)
+ extends Schedulable
+ with Logging {
+
+ var schedulableQueue = new ArrayBuffer[Schedulable]
+ var schedulableNameToSchedulable = new HashMap[String, Schedulable]
+
+ var weight = initWeight
+ var minShare = initMinShare
+ var runningTasks = 0
+
+ var priority = 0
+ var stageId = 0
+ var name = poolName
+ var parent:Schedulable = null
+
+ var taskSetSchedulingAlgorithm: SchedulingAlgorithm = {
+ schedulingMode match {
+ case SchedulingMode.FAIR =>
+ new FairSchedulingAlgorithm()
+ case SchedulingMode.FIFO =>
+ new FIFOSchedulingAlgorithm()
+ }
+ }
+
+ override def addSchedulable(schedulable: Schedulable) {
+ schedulableQueue += schedulable
+ schedulableNameToSchedulable(schedulable.name) = schedulable
+ schedulable.parent= this
+ }
+
+ override def removeSchedulable(schedulable: Schedulable) {
+ schedulableQueue -= schedulable
+ schedulableNameToSchedulable -= schedulable.name
+ }
+
+ override def getSchedulableByName(schedulableName: String): Schedulable = {
+ if (schedulableNameToSchedulable.contains(schedulableName)) {
+ return schedulableNameToSchedulable(schedulableName)
+ }
+ for (schedulable <- schedulableQueue) {
+ var sched = schedulable.getSchedulableByName(schedulableName)
+ if (sched != null) {
+ return sched
+ }
+ }
+ return null
+ }
+
+ override def executorLost(executorId: String, host: String) {
+ schedulableQueue.foreach(_.executorLost(executorId, host))
+ }
+
+ override def checkSpeculatableTasks(): Boolean = {
+ var shouldRevive = false
+ for (schedulable <- schedulableQueue) {
+ shouldRevive |= schedulable.checkSpeculatableTasks()
+ }
+ return shouldRevive
+ }
+
+ override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
+ var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
+ val sortedSchedulableQueue = schedulableQueue.sortWith(taskSetSchedulingAlgorithm.comparator)
+ for (schedulable <- sortedSchedulableQueue) {
+ sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue()
+ }
+ return sortedTaskSetQueue
+ }
+
+ override def increaseRunningTasks(taskNum: Int) {
+ runningTasks += taskNum
+ if (parent != null) {
+ parent.increaseRunningTasks(taskNum)
+ }
+ }
+
+ override def decreaseRunningTasks(taskNum: Int) {
+ runningTasks -= taskNum
+ if (parent != null) {
+ parent.decreaseRunningTasks(taskNum)
+ }
+ }
+
+ override def hasPendingTasks(): Boolean = {
+ schedulableQueue.exists(_.hasPendingTasks())
+ }
+}
diff --git a/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala b/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala
new file mode 100644
index 0000000000..2dd9c0564f
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/cluster/Schedulable.scala
@@ -0,0 +1,27 @@
+package spark.scheduler.cluster
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * An interface for schedulable entities.
+ * there are two type of Schedulable entities(Pools and TaskSetManagers)
+ */
+private[spark] trait Schedulable {
+ var parent: Schedulable
+ def weight: Int
+ def minShare: Int
+ def runningTasks: Int
+ def priority: Int
+ def stageId: Int
+ def name: String
+
+ def increaseRunningTasks(taskNum: Int): Unit
+ def decreaseRunningTasks(taskNum: Int): Unit
+ def addSchedulable(schedulable: Schedulable): Unit
+ def removeSchedulable(schedulable: Schedulable): Unit
+ def getSchedulableByName(name: String): Schedulable
+ def executorLost(executorId: String, host: String): Unit
+ def checkSpeculatableTasks(): Boolean
+ def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager]
+ def hasPendingTasks(): Boolean
+}
diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala
new file mode 100644
index 0000000000..18cc15c2a5
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala
@@ -0,0 +1,115 @@
+package spark.scheduler.cluster
+
+import java.io.{File, FileInputStream, FileOutputStream}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.util.control.Breaks._
+import scala.xml._
+
+import spark.Logging
+import spark.scheduler.cluster.SchedulingMode.SchedulingMode
+
+import java.util.Properties
+
+/**
+ * An interface to build Schedulable tree
+ * buildPools: build the tree nodes(pools)
+ * addTaskSetManager: build the leaf nodes(TaskSetManagers)
+ */
+private[spark] trait SchedulableBuilder {
+ def buildPools()
+ def addTaskSetManager(manager: Schedulable, properties: Properties)
+}
+
+private[spark] class FIFOSchedulableBuilder(val rootPool: Pool) extends SchedulableBuilder with Logging {
+
+ override def buildPools() {
+ //nothing
+ }
+
+ override def addTaskSetManager(manager: Schedulable, properties: Properties) {
+ rootPool.addSchedulable(manager)
+ }
+}
+
+private[spark] class FairSchedulableBuilder(val rootPool: Pool) extends SchedulableBuilder with Logging {
+
+ val schedulerAllocFile = System.getProperty("spark.fairscheduler.allocation.file","unspecified")
+ val FAIR_SCHEDULER_PROPERTIES = "spark.scheduler.cluster.fair.pool"
+ val DEFAULT_POOL_NAME = "default"
+ val MINIMUM_SHARES_PROPERTY = "minShare"
+ val SCHEDULING_MODE_PROPERTY = "schedulingMode"
+ val WEIGHT_PROPERTY = "weight"
+ val POOL_NAME_PROPERTY = "@name"
+ val POOLS_PROPERTY = "pool"
+ val DEFAULT_SCHEDULING_MODE = SchedulingMode.FIFO
+ val DEFAULT_MINIMUM_SHARE = 2
+ val DEFAULT_WEIGHT = 1
+
+ override def buildPools() {
+ val file = new File(schedulerAllocFile)
+ if (file.exists()) {
+ val xml = XML.loadFile(file)
+ for (poolNode <- (xml \\ POOLS_PROPERTY)) {
+
+ val poolName = (poolNode \ POOL_NAME_PROPERTY).text
+ var schedulingMode = DEFAULT_SCHEDULING_MODE
+ var minShare = DEFAULT_MINIMUM_SHARE
+ var weight = DEFAULT_WEIGHT
+
+ val xmlSchedulingMode = (poolNode \ SCHEDULING_MODE_PROPERTY).text
+ if (xmlSchedulingMode != "") {
+ try {
+ schedulingMode = SchedulingMode.withName(xmlSchedulingMode)
+ } catch {
+ case e: Exception => logInfo("Error xml schedulingMode, using default schedulingMode")
+ }
+ }
+
+ val xmlMinShare = (poolNode \ MINIMUM_SHARES_PROPERTY).text
+ if (xmlMinShare != "") {
+ minShare = xmlMinShare.toInt
+ }
+
+ val xmlWeight = (poolNode \ WEIGHT_PROPERTY).text
+ if (xmlWeight != "") {
+ weight = xmlWeight.toInt
+ }
+
+ val pool = new Pool(poolName, schedulingMode, minShare, weight)
+ rootPool.addSchedulable(pool)
+ logInfo("Create new pool with name:%s,schedulingMode:%s,minShare:%d,weight:%d".format(
+ poolName, schedulingMode, minShare, weight))
+ }
+ }
+
+ //finally create "default" pool
+ if (rootPool.getSchedulableByName(DEFAULT_POOL_NAME) == null) {
+ val pool = new Pool(DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)
+ rootPool.addSchedulable(pool)
+ logInfo("Create default pool with name:%s,schedulingMode:%s,minShare:%d,weight:%d".format(
+ DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT))
+ }
+ }
+
+ override def addTaskSetManager(manager: Schedulable, properties: Properties) {
+ var poolName = DEFAULT_POOL_NAME
+ var parentPool = rootPool.getSchedulableByName(poolName)
+ if (properties != null) {
+ poolName = properties.getProperty(FAIR_SCHEDULER_PROPERTIES, DEFAULT_POOL_NAME)
+ parentPool = rootPool.getSchedulableByName(poolName)
+ if (parentPool == null) {
+ //we will create a new pool that user has configured in app instead of being defined in xml file
+ parentPool = new Pool(poolName,DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)
+ rootPool.addSchedulable(parentPool)
+ logInfo("Create pool with name:%s,schedulingMode:%s,minShare:%d,weight:%d".format(
+ poolName, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT))
+ }
+ }
+ parentPool.addSchedulable(manager)
+ logInfo("Added task set " + manager.name + " tasks to pool "+poolName)
+ }
+}
diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala
index 9ac875de3a..8844057a5c 100644
--- a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala
@@ -1,6 +1,6 @@
package spark.scheduler.cluster
-import spark.Utils
+import spark.{SparkContext, Utils}
/**
* A backend interface for cluster scheduling systems that allows plugging in different ones under
@@ -14,14 +14,7 @@ private[spark] trait SchedulerBackend {
def defaultParallelism(): Int
// Memory used by each executor (in megabytes)
- protected val executorMemory = {
- // TODO: Might need to add some extra memory for the non-heap parts of the JVM
- Option(System.getProperty("spark.executor.memory"))
- .orElse(Option(System.getenv("SPARK_MEM")))
- .map(Utils.memoryStringToMb)
- .getOrElse(512)
- }
-
+ protected val executorMemory: Int = SparkContext.executorMemoryRequested
// TODO: Probably want to add a killTask too
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala
new file mode 100644
index 0000000000..f33310a34a
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/cluster/SchedulingAlgorithm.scala
@@ -0,0 +1,64 @@
+package spark.scheduler.cluster
+
+/**
+ * An interface for sort algorithm
+ * FIFO: FIFO algorithm between TaskSetManagers
+ * FS: FS algorithm between Pools, and FIFO or FS within Pools
+ */
+private[spark] trait SchedulingAlgorithm {
+ def comparator(s1: Schedulable, s2: Schedulable): Boolean
+}
+
+private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm {
+ override def comparator(s1: Schedulable, s2: Schedulable): Boolean = {
+ val priority1 = s1.priority
+ val priority2 = s2.priority
+ var res = math.signum(priority1 - priority2)
+ if (res == 0) {
+ val stageId1 = s1.stageId
+ val stageId2 = s2.stageId
+ res = math.signum(stageId1 - stageId2)
+ }
+ if (res < 0) {
+ return true
+ } else {
+ return false
+ }
+ }
+}
+
+private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm {
+ override def comparator(s1: Schedulable, s2: Schedulable): Boolean = {
+ val minShare1 = s1.minShare
+ val minShare2 = s2.minShare
+ val runningTasks1 = s1.runningTasks
+ val runningTasks2 = s2.runningTasks
+ val s1Needy = runningTasks1 < minShare1
+ val s2Needy = runningTasks2 < minShare2
+ val minShareRatio1 = runningTasks1.toDouble / math.max(minShare1, 1.0).toDouble
+ val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0).toDouble
+ val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble
+ val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble
+ var res:Boolean = true
+ var compare:Int = 0
+
+ if (s1Needy && !s2Needy) {
+ return true
+ } else if (!s1Needy && s2Needy) {
+ return false
+ } else if (s1Needy && s2Needy) {
+ compare = minShareRatio1.compareTo(minShareRatio2)
+ } else {
+ compare = taskToWeightRatio1.compareTo(taskToWeightRatio2)
+ }
+
+ if (compare < 0) {
+ return true
+ } else if (compare > 0) {
+ return false
+ } else {
+ return s1.name < s2.name
+ }
+ }
+}
+
diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala
new file mode 100644
index 0000000000..6e0c6793e0
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/cluster/SchedulingMode.scala
@@ -0,0 +1,7 @@
+package spark.scheduler.cluster
+
+object SchedulingMode extends Enumeration("FAIR","FIFO"){
+
+ type SchedulingMode = Value
+ val FAIR,FIFO = Value
+}
diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index bb289c9cf3..170ede0f44 100644
--- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -31,7 +31,8 @@ private[spark] class SparkDeploySchedulerBackend(
val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs)
val sparkHome = sc.getSparkHome().getOrElse(
throw new IllegalArgumentException("must supply spark home for spark standalone"))
- val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome)
+ val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome,
+ sc.ui.appUIAddress)
client = new Client(sc.env.actorSystem, master, appDesc, this)
client.start()
@@ -57,9 +58,9 @@ private[spark] class SparkDeploySchedulerBackend(
}
}
- override def executorAdded(executorId: String, workerId: String, host: String, cores: Int, memory: Int) {
- logInfo("Granted executor ID %s on host %s with %d cores, %s RAM".format(
- executorId, host, cores, Utils.memoryMegabytesToString(memory)))
+ override def executorAdded(executorId: String, workerId: String, hostPort: String, cores: Int, memory: Int) {
+ logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format(
+ executorId, hostPort, cores, Utils.memoryMegabytesToString(memory)))
}
override def executorRemoved(executorId: String, message: String, exitStatus: Option[Int]) {
diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
index d766067824..3335294844 100644
--- a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
@@ -3,6 +3,7 @@ package spark.scheduler.cluster
import spark.TaskState.TaskState
import java.nio.ByteBuffer
import spark.util.SerializableBuffer
+import spark.Utils
private[spark] sealed trait StandaloneClusterMessage extends Serializable
@@ -19,8 +20,10 @@ case class RegisterExecutorFailed(message: String) extends StandaloneClusterMess
// Executors to driver
private[spark]
-case class RegisterExecutor(executorId: String, host: String, cores: Int)
- extends StandaloneClusterMessage
+case class RegisterExecutor(executorId: String, hostPort: String, cores: Int)
+ extends StandaloneClusterMessage {
+ Utils.checkHostPort(hostPort, "Expected host port")
+}
private[spark]
case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, data: SerializableBuffer)
diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
index a06d853b46..16131215c8 100644
--- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
@@ -6,8 +6,9 @@ import akka.actor._
import scala.concurrent.duration._
import akka.pattern.ask
-import spark.{SparkException, Logging, TaskState}
+import spark.{Utils, SparkException, Logging, TaskState}
import scala.concurrent.Await
+
import java.util.concurrent.atomic.AtomicInteger
import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent}
@@ -24,12 +25,12 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
var totalCoreCount = new AtomicInteger(0)
class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor {
- val executorActor = new HashMap[String, ActorRef]
- val executorAddress = new HashMap[String, Address]
- val executorHost = new HashMap[String, String]
- val freeCores = new HashMap[String, Int]
- val actorToExecutorId = new HashMap[ActorRef, String]
- val addressToExecutorId = new HashMap[Address, String]
+ private val executorActor = new HashMap[String, ActorRef]
+ private val executorAddress = new HashMap[String, Address]
+ private val executorHostPort = new HashMap[String, String]
+ private val freeCores = new HashMap[String, Int]
+ private val actorToExecutorId = new HashMap[ActorRef, String]
+ private val addressToExecutorId = new HashMap[Address, String]
override def preStart() {
// Listen for remote client disconnection events, since they don't go through Akka's watch()
@@ -37,7 +38,8 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
}
def receive = {
- case RegisterExecutor(executorId, host, cores) =>
+ case RegisterExecutor(executorId, hostPort, cores) =>
+ Utils.checkHostPort(hostPort, "Host port expected " + hostPort)
if (executorActor.contains(executorId)) {
sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId)
} else {
@@ -45,7 +47,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
sender ! RegisteredExecutor(sparkProperties)
context.watch(sender)
executorActor(executorId) = sender
- executorHost(executorId) = host
+ executorHostPort(executorId) = hostPort
freeCores(executorId) = cores
executorAddress(executorId) = sender.path.address
actorToExecutorId(sender) = executorId
@@ -85,13 +87,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
// Make fake resource offers on all executors
def makeOffers() {
launchTasks(scheduler.resourceOffers(
- executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))}))
+ executorHostPort.toArray.map {case (id, hostPort) => new WorkerOffer(id, hostPort, freeCores(id))}))
}
// Make fake resource offers on just one executor
def makeOffers(executorId: String) {
launchTasks(scheduler.resourceOffers(
- Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId)))))
+ Seq(new WorkerOffer(executorId, executorHostPort(executorId), freeCores(executorId)))))
}
// Launch tasks returned by a set of resource offers
@@ -110,9 +112,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
actorToExecutorId -= executorActor(executorId)
addressToExecutorId -= executorAddress(executorId)
executorActor -= executorId
- executorHost -= executorId
+ executorHostPort -= executorId
freeCores -= executorId
- executorHost -= executorId
+ executorHostPort -= executorId
totalCoreCount.addAndGet(-numCores)
scheduler.executorLost(executorId, SlaveLost(reason))
}
@@ -128,7 +130,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
while (iterator.hasNext) {
val entry = iterator.next
val (key, value) = (entry.getKey.toString, entry.getValue.toString)
- if (key.startsWith("spark.")) {
+ if (key.startsWith("spark.") && !key.equals("spark.hostPort")) {
properties += ((key, value))
}
}
@@ -136,10 +138,11 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME)
}
+ private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+
override def stop() {
try {
if (driverActor != null) {
- val timeout = 5.seconds
val future = driverActor.ask(StopDriver)(timeout)
Await.result(future, timeout)
}
@@ -159,7 +162,6 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
// Called by subclasses when notified of a lost worker
def removeExecutor(executorId: String, reason: String) {
try {
- val timeout = 5.seconds
val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout)
Await.result(future, timeout)
} catch {
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
index dfe3c5a85b..718f26bfbd 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
@@ -1,5 +1,7 @@
package spark.scheduler.cluster
+import spark.Utils
+
/**
* Information about a running task attempt inside a TaskSet.
*/
@@ -9,8 +11,11 @@ class TaskInfo(
val index: Int,
val launchTime: Long,
val executorId: String,
- val host: String,
- val preferred: Boolean) {
+ val hostPort: String,
+ val taskLocality: TaskLocality.TaskLocality) {
+
+ Utils.checkHostPort(hostPort, "Expected hostport")
+
var finishTime: Long = 0
var failed = false
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index c9f2c48804..b4dd75d90f 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -1,430 +1,17 @@
package spark.scheduler.cluster
-import java.util.Arrays
-import java.util.{HashMap => JHashMap}
-
import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-import scala.math.max
-import scala.math.min
-
-import spark._
import spark.scheduler._
import spark.TaskState.TaskState
import java.nio.ByteBuffer
-/**
- * Schedules the tasks within a single TaskSet in the ClusterScheduler.
- */
-private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSet) extends Logging {
-
- // Maximum time to wait to run a task in a preferred location (in ms)
- val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
-
- // CPUs to request per task
- val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble
-
- // Maximum times a task is allowed to fail before failing the job
- val MAX_TASK_FAILURES = 4
-
- // Quantile of tasks at which to start speculation
- val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
- val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
-
- // Serializer for closures and tasks.
- val ser = SparkEnv.get.closureSerializer.newInstance()
-
- val priority = taskSet.priority
- val tasks = taskSet.tasks
- val numTasks = tasks.length
- val copiesRunning = new Array[Int](numTasks)
- val finished = new Array[Boolean](numTasks)
- val numFailures = new Array[Int](numTasks)
- val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
- var tasksFinished = 0
-
- // Last time when we launched a preferred task (for delay scheduling)
- var lastPreferredLaunchTime = System.currentTimeMillis
-
- // List of pending tasks for each node. These collections are actually
- // treated as stacks, in which new tasks are added to the end of the
- // ArrayBuffer and removed from the end. This makes it faster to detect
- // tasks that repeatedly fail because whenever a task failed, it is put
- // back at the head of the stack. They are also only cleaned up lazily;
- // when a task is launched, it remains in all the pending lists except
- // the one that it was launched from, but gets removed from them later.
- val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
-
- // List containing pending tasks with no locality preferences
- val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
-
- // List containing all pending tasks (also used as a stack, as above)
- val allPendingTasks = new ArrayBuffer[Int]
-
- // Tasks that can be speculated. Since these will be a small fraction of total
- // tasks, we'll just hold them in a HashSet.
- val speculatableTasks = new HashSet[Int]
-
- // Task index, start and finish time for each task attempt (indexed by task ID)
- val taskInfos = new HashMap[Long, TaskInfo]
-
- // Did the job fail?
- var failed = false
- var causeOfFailure = ""
-
- // How frequently to reprint duplicate exceptions in full, in milliseconds
- val EXCEPTION_PRINT_INTERVAL =
- System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong
- // Map of recent exceptions (identified by string representation and
- // top stack frame) to duplicate count (how many times the same
- // exception has appeared) and time the full exception was
- // printed. This should ideally be an LRU map that can drop old
- // exceptions automatically.
- val recentExceptions = HashMap[String, (Int, Long)]()
-
- // Figure out the current map output tracker generation and set it on all tasks
- val generation = sched.mapOutputTracker.getGeneration
- logDebug("Generation for " + taskSet.id + ": " + generation)
- for (t <- tasks) {
- t.generation = generation
- }
-
- // Add all our tasks to the pending lists. We do this in reverse order
- // of task index so that tasks with low indices get launched first.
- for (i <- (0 until numTasks).reverse) {
- addPendingTask(i)
- }
-
- // Add a task to all the pending-task lists that it should be on.
- private def addPendingTask(index: Int) {
- val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
- if (locations.size == 0) {
- pendingTasksWithNoPrefs += index
- } else {
- for (host <- locations) {
- val list = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer())
- list += index
- }
- }
- allPendingTasks += index
- }
-
- // Return the pending tasks list for a given host, or an empty list if
- // there is no map entry for that host
- private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
- pendingTasksForHost.getOrElse(host, ArrayBuffer())
- }
-
- // Dequeue a pending task from the given list and return its index.
- // Return None if the list is empty.
- // This method also cleans up any tasks in the list that have already
- // been launched, since we want that to happen lazily.
- private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
- while (!list.isEmpty) {
- val index = list.last
- list.trimEnd(1)
- if (copiesRunning(index) == 0 && !finished(index)) {
- return Some(index)
- }
- }
- return None
- }
-
- // Return a speculative task for a given host if any are available. The task should not have an
- // attempt running on this host, in case the host is slow. In addition, if localOnly is set, the
- // task must have a preference for this host (or no preferred locations at all).
- private def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
- val hostsAlive = sched.hostsAlive
- speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
- val localTask = speculatableTasks.find {
- index =>
- val locations = tasks(index).preferredLocations.toSet & hostsAlive
- val attemptLocs = taskAttempts(index).map(_.host)
- (locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host)
- }
- if (localTask != None) {
- speculatableTasks -= localTask.get
- return localTask
- }
- if (!localOnly && speculatableTasks.size > 0) {
- val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.host).contains(host))
- if (nonLocalTask != None) {
- speculatableTasks -= nonLocalTask.get
- return nonLocalTask
- }
- }
- return None
- }
-
- // Dequeue a pending task for a given node and return its index.
- // If localOnly is set to false, allow non-local tasks as well.
- private def findTask(host: String, localOnly: Boolean): Option[Int] = {
- val localTask = findTaskFromList(getPendingTasksForHost(host))
- if (localTask != None) {
- return localTask
- }
- val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs)
- if (noPrefTask != None) {
- return noPrefTask
- }
- if (!localOnly) {
- val nonLocalTask = findTaskFromList(allPendingTasks)
- if (nonLocalTask != None) {
- return nonLocalTask
- }
- }
- // Finally, if all else has failed, find a speculative task
- return findSpeculativeTask(host, localOnly)
- }
-
- // Does a host count as a preferred location for a task? This is true if
- // either the task has preferred locations and this host is one, or it has
- // no preferred locations (in which we still count the launch as preferred).
- private def isPreferredLocation(task: Task[_], host: String): Boolean = {
- val locs = task.preferredLocations
- return (locs.contains(host) || locs.isEmpty)
- }
-
- // Respond to an offer of a single slave from the scheduler by finding a task
- def slaveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = {
- if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
- val time = System.currentTimeMillis
- val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
-
- findTask(host, localOnly) match {
- case Some(index) => {
- // Found a task; do some bookkeeping and return a Mesos task for it
- val task = tasks(index)
- val taskId = sched.newTaskId()
- // Figure out whether this should count as a preferred launch
- val preferred = isPreferredLocation(task, host)
- val prefStr = if (preferred) {
- "preferred"
- } else {
- "non-preferred, not one of " + task.preferredLocations.mkString(", ")
- }
- logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format(
- taskSet.id, index, taskId, execId, host, prefStr))
- // Do various bookkeeping
- copiesRunning(index) += 1
- val info = new TaskInfo(taskId, index, time, execId, host, preferred)
- taskInfos(taskId) = info
- taskAttempts(index) = info :: taskAttempts(index)
- if (preferred) {
- lastPreferredLaunchTime = time
- }
- // Serialize and return the task
- val startTime = System.currentTimeMillis
- val serializedTask = Task.serializeWithDependencies(
- task, sched.sc.addedFiles, sched.sc.addedJars, ser)
- val timeTaken = System.currentTimeMillis - startTime
- logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
- taskSet.id, index, serializedTask.limit, timeTaken))
- val taskName = "task %s:%d".format(taskSet.id, index)
- return Some(new TaskDescription(taskId, execId, taskName, serializedTask))
- }
- case _ =>
- }
- }
- return None
- }
-
- def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- state match {
- case TaskState.FINISHED =>
- taskFinished(tid, state, serializedData)
- case TaskState.LOST =>
- taskLost(tid, state, serializedData)
- case TaskState.FAILED =>
- taskLost(tid, state, serializedData)
- case TaskState.KILLED =>
- taskLost(tid, state, serializedData)
- case _ =>
- }
- }
-
- def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- val info = taskInfos(tid)
- if (info.failed) {
- // We might get two task-lost messages for the same task in coarse-grained Mesos mode,
- // or even from Mesos itself when acks get delayed.
- return
- }
- val index = info.index
- info.markSuccessful()
- if (!finished(index)) {
- tasksFinished += 1
- logInfo("Finished TID %s in %d ms (progress: %d/%d)".format(
- tid, info.duration, tasksFinished, numTasks))
- // Deserialize task result and pass it to the scheduler
- val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
- result.metrics.resultSize = serializedData.limit()
- sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
- // Mark finished and stop if we've finished all the tasks
- finished(index) = true
- if (tasksFinished == numTasks) {
- sched.taskSetFinished(this)
- }
- } else {
- logInfo("Ignoring task-finished event for TID " + tid +
- " because task " + index + " is already finished")
- }
- }
-
- def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- val info = taskInfos(tid)
- if (info.failed) {
- // We might get two task-lost messages for the same task in coarse-grained Mesos mode,
- // or even from Mesos itself when acks get delayed.
- return
- }
- val index = info.index
- info.markFailed()
- if (!finished(index)) {
- logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
- copiesRunning(index) -= 1
- // Check if the problem is a map output fetch failure. In that case, this
- // task will never succeed on any node, so tell the scheduler about it.
- if (serializedData != null && serializedData.limit() > 0) {
- val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader)
- reason match {
- case fetchFailed: FetchFailed =>
- logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
- sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
- finished(index) = true
- tasksFinished += 1
- sched.taskSetFinished(this)
- return
-
- case ef: ExceptionFailure =>
- val key = ef.exception.toString
- val now = System.currentTimeMillis
- val (printFull, dupCount) = {
- if (recentExceptions.contains(key)) {
- val (dupCount, printTime) = recentExceptions(key)
- if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
- recentExceptions(key) = (0, now)
- (true, 0)
- } else {
- recentExceptions(key) = (dupCount + 1, printTime)
- (false, dupCount + 1)
- }
- } else {
- recentExceptions(key) = (0, now)
- (true, 0)
- }
- }
- if (printFull) {
- val locs = ef.exception.getStackTrace.map(loc => "\tat %s".format(loc.toString))
- logInfo("Loss was due to %s\n%s".format(ef.exception.toString, locs.mkString("\n")))
- } else {
- logInfo("Loss was due to %s [duplicate %d]".format(ef.exception.toString, dupCount))
- }
-
- case _ => {}
- }
- }
- // On non-fetch failures, re-enqueue the task as pending for a max number of retries
- addPendingTask(index)
- // Count failed attempts only on FAILED and LOST state (not on KILLED)
- if (state == TaskState.FAILED || state == TaskState.LOST) {
- numFailures(index) += 1
- if (numFailures(index) > MAX_TASK_FAILURES) {
- logError("Task %s:%d failed more than %d times; aborting job".format(
- taskSet.id, index, MAX_TASK_FAILURES))
- abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
- }
- }
- } else {
- logInfo("Ignoring task-lost event for TID " + tid +
- " because task " + index + " is already finished")
- }
- }
-
- def error(message: String) {
- // Save the error message
- abort("Error: " + message)
- }
-
- def abort(message: String) {
- failed = true
- causeOfFailure = message
- // TODO: Kill running tasks if we were not terminated due to a Mesos error
- sched.listener.taskSetFailed(taskSet, message)
- sched.taskSetFinished(this)
- }
-
- def executorLost(execId: String, hostname: String) {
- logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
- val newHostsAlive = sched.hostsAlive
- // If some task has preferred locations only on hostname, and there are no more executors there,
- // put it in the no-prefs list to avoid the wait from delay scheduling
- if (!newHostsAlive.contains(hostname)) {
- for (index <- getPendingTasksForHost(hostname)) {
- val newLocs = tasks(index).preferredLocations.toSet & newHostsAlive
- if (newLocs.isEmpty) {
- pendingTasksWithNoPrefs += index
- }
- }
- }
- // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
- if (tasks(0).isInstanceOf[ShuffleMapTask]) {
- for ((tid, info) <- taskInfos if info.executorId == execId) {
- val index = taskInfos(tid).index
- if (finished(index)) {
- finished(index) = false
- copiesRunning(index) -= 1
- tasksFinished -= 1
- addPendingTask(index)
- // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
- // stage finishes when a total of tasks.size tasks finish.
- sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
- }
- }
- }
- // Also re-enqueue any tasks that were running on the node
- for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
- taskLost(tid, TaskState.KILLED, null)
- }
- }
-
- /**
- * Check for tasks to be speculated and return true if there are any. This is called periodically
- * by the ClusterScheduler.
- *
- * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
- * we don't scan the whole task set. It might also help to make this sorted by launch time.
- */
- def checkSpeculatableTasks(): Boolean = {
- // Can't speculate if we only have one task, or if all tasks have finished.
- if (numTasks == 1 || tasksFinished == numTasks) {
- return false
- }
- var foundTasks = false
- val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
- logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
- if (tasksFinished >= minFinishedForSpeculation) {
- val time = System.currentTimeMillis()
- val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
- Arrays.sort(durations)
- val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1))
- val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
- // TODO: Threshold should also look at standard deviation of task durations and have a lower
- // bound based on that.
- logDebug("Task length threshold for speculation: " + threshold)
- for ((tid, info) <- taskInfos) {
- val index = info.index
- if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
- !speculatableTasks.contains(index)) {
- logInfo(
- "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
- taskSet.id, index, info.host, threshold))
- speculatableTasks += index
- foundTasks = true
- }
- }
- }
- return foundTasks
- }
+private[spark] trait TaskSetManager extends Schedulable {
+ def taskSet: TaskSet
+ def slaveOffer(execId: String, hostPort: String, availableCpus: Double,
+ overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription]
+ def numPendingTasksForHostPort(hostPort: String): Int
+ def numRackLocalPendingTasksForHost(hostPort :String): Int
+ def numPendingTasksForHost(hostPort: String): Int
+ def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer)
+ def error(message: String)
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
index 3c3afcbb14..c47824315c 100644
--- a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
@@ -4,5 +4,5 @@ package spark.scheduler.cluster
* Represents free resources available on an executor.
*/
private[spark]
-class WorkerOffer(val executorId: String, val hostname: String, val cores: Int) {
+class WorkerOffer(val executorId: String, val hostPort: String, val cores: Int) {
}
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index 9e1bde3fbe..93d4318b29 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -2,19 +2,50 @@ package spark.scheduler.local
import java.io.File
import java.util.concurrent.atomic.AtomicInteger
+import java.nio.ByteBuffer
+import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
import spark._
+import spark.TaskState.TaskState
import spark.executor.ExecutorURLClassLoader
import spark.scheduler._
-import spark.scheduler.cluster.TaskInfo
+import spark.scheduler.cluster._
+import akka.actor._
/**
- * A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
+ * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
* the scheduler also allows each task to fail up to maxFailures times, which is useful for
* testing fault recovery.
*/
-private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext)
+
+private[spark] case class LocalReviveOffers()
+private[spark] case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
+
+private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging {
+ def receive = {
+ case LocalReviveOffers =>
+ launchTask(localScheduler.resourceOffer(freeCores))
+ case LocalStatusUpdate(taskId, state, serializeData) =>
+ freeCores += 1
+ localScheduler.statusUpdate(taskId, state, serializeData)
+ launchTask(localScheduler.resourceOffer(freeCores))
+ }
+
+ def launchTask(tasks : Seq[TaskDescription]) {
+ for (task <- tasks) {
+ freeCores -= 1
+ localScheduler.threadPool.submit(new Runnable {
+ def run() {
+ localScheduler.runTask(task.taskId,task.serializedTask)
+ }
+ })
+ }
+ }
+}
+
+private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
extends TaskScheduler
with Logging {
@@ -30,87 +61,127 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader)
- // TODO: Need to take into account stage priority in scheduling
+ var schedulableBuilder: SchedulableBuilder = null
+ var rootPool: Pool = null
+ val activeTaskSets = new HashMap[String, TaskSetManager]
+ val taskIdToTaskSetId = new HashMap[Long, String]
+ val taskSetTaskIds = new HashMap[String, HashSet[Long]]
+
+ var localActor: ActorRef = null
+
+ override def start() {
+ //default scheduler is FIFO
+ val schedulingMode = System.getProperty("spark.cluster.schedulingmode", "FIFO")
+ //temporarily set rootPool name to empty
+ rootPool = new Pool("", SchedulingMode.withName(schedulingMode), 0, 0)
+ schedulableBuilder = {
+ schedulingMode match {
+ case "FIFO" =>
+ new FIFOSchedulableBuilder(rootPool)
+ case "FAIR" =>
+ new FairSchedulableBuilder(rootPool)
+ }
+ }
+ schedulableBuilder.buildPools()
- override def start() { }
+ localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test")
+ }
override def setListener(listener: TaskSchedulerListener) {
this.listener = listener
}
override def submitTasks(taskSet: TaskSet) {
- val tasks = taskSet.tasks
- val failCount = new Array[Int](tasks.size)
-
- def submitTask(task: Task[_], idInJob: Int) {
- val myAttemptId = attemptId.getAndIncrement()
- threadPool.submit(new Runnable {
- def run() {
- runTask(task, idInJob, myAttemptId)
- }
- })
+ synchronized {
+ var manager = new LocalTaskSetManager(this, taskSet)
+ schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
+ activeTaskSets(taskSet.id) = manager
+ taskSetTaskIds(taskSet.id) = new HashSet[Long]()
+ localActor ! LocalReviveOffers
}
+ }
+
+ def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
+ synchronized {
+ var freeCpuCores = freeCores
+ val tasks = new ArrayBuffer[TaskDescription](freeCores)
+ val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
+ for (manager <- sortedTaskSetQueue) {
+ logDebug("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks))
+ }
- def runTask(task: Task[_], idInJob: Int, attemptId: Int) {
- logInfo("Running " + task)
- val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local", true)
- // Set the Spark execution environment for the worker thread
- SparkEnv.set(env)
- try {
- Accumulators.clear()
- Thread.currentThread().setContextClassLoader(classLoader)
-
- // Serialize and deserialize the task so that accumulators are changed to thread-local ones;
- // this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
- val ser = SparkEnv.get.closureSerializer.newInstance()
- val bytes = Task.serializeWithDependencies(task, sc.addedFiles, sc.addedJars, ser)
- logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes")
- val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
- updateDependencies(taskFiles, taskJars) // Download any files added with addFile
- val deserStart = System.currentTimeMillis()
- val deserializedTask = ser.deserialize[Task[_]](
- taskBytes, Thread.currentThread.getContextClassLoader)
- val deserTime = System.currentTimeMillis() - deserStart
-
- // Run it
- val result: Any = deserializedTask.run(attemptId)
-
- // Serialize and deserialize the result to emulate what the Mesos
- // executor does. This is useful to catch serialization errors early
- // on in development (so when users move their local Spark programs
- // to the cluster, they don't get surprised by serialization errors).
- val serResult = ser.serialize(result)
- deserializedTask.metrics.get.resultSize = serResult.limit()
- val resultToReturn = ser.deserialize[Any](serResult)
- val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
- ser.serialize(Accumulators.values))
- logInfo("Finished " + task)
- info.markSuccessful()
- deserializedTask.metrics.get.executorRunTime = info.duration.toInt //close enough
- deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
-
- // If the threadpool has not already been shutdown, notify DAGScheduler
- if (!Thread.currentThread().isInterrupted)
- listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, deserializedTask.metrics.getOrElse(null))
- } catch {
- case t: Throwable => {
- logError("Exception in task " + idInJob, t)
- failCount.synchronized {
- failCount(idInJob) += 1
- if (failCount(idInJob) <= maxFailures) {
- submitTask(task, idInJob)
- } else {
- // TODO: Do something nicer here to return all the way to the user
- if (!Thread.currentThread().isInterrupted)
- listener.taskEnded(task, new ExceptionFailure(t), null, null, info, null)
+ var launchTask = false
+ for (manager <- sortedTaskSetQueue) {
+ do {
+ launchTask = false
+ manager.slaveOffer(null,null,freeCpuCores) match {
+ case Some(task) =>
+ tasks += task
+ taskIdToTaskSetId(task.taskId) = manager.taskSet.id
+ taskSetTaskIds(manager.taskSet.id) += task.taskId
+ freeCpuCores -= 1
+ launchTask = true
+ case None => {}
}
- }
- }
+ } while(launchTask)
}
+ return tasks
}
+ }
- for ((task, i) <- tasks.zipWithIndex) {
- submitTask(task, i)
+ def taskSetFinished(manager: TaskSetManager) {
+ synchronized {
+ activeTaskSets -= manager.taskSet.id
+ manager.parent.removeSchedulable(manager)
+ logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
+ taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
+ taskSetTaskIds -= manager.taskSet.id
+ }
+ }
+
+ def runTask(taskId: Long, bytes: ByteBuffer) {
+ logInfo("Running " + taskId)
+ val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
+ // Set the Spark execution environment for the worker thread
+ SparkEnv.set(env)
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+ try {
+ Accumulators.clear()
+ Thread.currentThread().setContextClassLoader(classLoader)
+
+ // Serialize and deserialize the task so that accumulators are changed to thread-local ones;
+ // this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
+ val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
+ updateDependencies(taskFiles, taskJars) // Download any files added with addFile
+ val deserStart = System.currentTimeMillis()
+ val deserializedTask = ser.deserialize[Task[_]](
+ taskBytes, Thread.currentThread.getContextClassLoader)
+ val deserTime = System.currentTimeMillis() - deserStart
+
+ // Run it
+ val result: Any = deserializedTask.run(taskId)
+
+ // Serialize and deserialize the result to emulate what the Mesos
+ // executor does. This is useful to catch serialization errors early
+ // on in development (so when users move their local Spark programs
+ // to the cluster, they don't get surprised by serialization errors).
+ val serResult = ser.serialize(result)
+ deserializedTask.metrics.get.resultSize = serResult.limit()
+ val resultToReturn = ser.deserialize[Any](serResult)
+ val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
+ ser.serialize(Accumulators.values))
+ logInfo("Finished " + taskId)
+ deserializedTask.metrics.get.executorRunTime = deserTime.toInt//info.duration.toInt //close enough
+ deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
+
+ val taskResult = new TaskResult(result, accumUpdates, deserializedTask.metrics.getOrElse(null))
+ val serializedResult = ser.serialize(taskResult)
+ localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult)
+ } catch {
+ case t: Throwable => {
+ val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace)
+ localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure))
+ }
}
}
@@ -126,6 +197,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentFiles(name) = timestamp
}
+
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
@@ -141,7 +213,16 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
}
}
- override def stop() {
+ def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) {
+ synchronized {
+ val taskSetId = taskIdToTaskSetId(taskId)
+ val taskSetManager = activeTaskSets(taskSetId)
+ taskSetTaskIds(taskSetId) -= taskId
+ taskSetManager.statusUpdate(taskId, state, serializedData)
+ }
+ }
+
+ override def stop() {
threadPool.shutdownNow()
}
diff --git a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala
new file mode 100644
index 0000000000..70b69bb26f
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala
@@ -0,0 +1,172 @@
+package spark.scheduler.local
+
+import java.io.File
+import java.util.concurrent.atomic.AtomicInteger
+import java.nio.ByteBuffer
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+
+import spark._
+import spark.TaskState.TaskState
+import spark.scheduler._
+import spark.scheduler.cluster._
+
+private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) extends TaskSetManager with Logging {
+ var parent: Schedulable = null
+ var weight: Int = 1
+ var minShare: Int = 0
+ var runningTasks: Int = 0
+ var priority: Int = taskSet.priority
+ var stageId: Int = taskSet.stageId
+ var name: String = "TaskSet_"+taskSet.stageId.toString
+
+
+ var failCount = new Array[Int](taskSet.tasks.size)
+ val taskInfos = new HashMap[Long, TaskInfo]
+ val numTasks = taskSet.tasks.size
+ var numFinished = 0
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+ val copiesRunning = new Array[Int](numTasks)
+ val finished = new Array[Boolean](numTasks)
+ val numFailures = new Array[Int](numTasks)
+ val MAX_TASK_FAILURES = sched.maxFailures
+
+ def increaseRunningTasks(taskNum: Int): Unit = {
+ runningTasks += taskNum
+ if (parent != null) {
+ parent.increaseRunningTasks(taskNum)
+ }
+ }
+
+ def decreaseRunningTasks(taskNum: Int): Unit = {
+ runningTasks -= taskNum
+ if (parent != null) {
+ parent.decreaseRunningTasks(taskNum)
+ }
+ }
+
+ def addSchedulable(schedulable: Schedulable): Unit = {
+ //nothing
+ }
+
+ def removeSchedulable(schedulable: Schedulable): Unit = {
+ //nothing
+ }
+
+ def getSchedulableByName(name: String): Schedulable = {
+ return null
+ }
+
+ def executorLost(executorId: String, host: String): Unit = {
+ //nothing
+ }
+
+ def checkSpeculatableTasks(): Boolean = {
+ return true
+ }
+
+ def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
+ var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
+ sortedTaskSetQueue += this
+ return sortedTaskSetQueue
+ }
+
+ def hasPendingTasks(): Boolean = {
+ return true
+ }
+
+ def findTask(): Option[Int] = {
+ for (i <- 0 to numTasks-1) {
+ if (copiesRunning(i) == 0 && !finished(i)) {
+ return Some(i)
+ }
+ }
+ return None
+ }
+
+ def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
+ SparkEnv.set(sched.env)
+ logDebug("availableCpus:%d,numFinished:%d,numTasks:%d".format(availableCpus.toInt, numFinished, numTasks))
+ if (availableCpus > 0 && numFinished < numTasks) {
+ findTask() match {
+ case Some(index) =>
+ val taskId = sched.attemptId.getAndIncrement()
+ val task = taskSet.tasks(index)
+ val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
+ taskInfos(taskId) = info
+ val bytes = Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+ logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes")
+ val taskName = "task %s:%d".format(taskSet.id, index)
+ copiesRunning(index) += 1
+ increaseRunningTasks(1)
+ return Some(new TaskDescription(taskId, null, taskName, bytes))
+ case None => {}
+ }
+ }
+ return None
+ }
+
+ def numPendingTasksForHostPort(hostPort: String): Int = {
+ return 0
+ }
+
+ def numRackLocalPendingTasksForHost(hostPort :String): Int = {
+ return 0
+ }
+
+ def numPendingTasksForHost(hostPort: String): Int = {
+ return 0
+ }
+
+ def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ state match {
+ case TaskState.FINISHED =>
+ taskEnded(tid, state, serializedData)
+ case TaskState.FAILED =>
+ taskFailed(tid, state, serializedData)
+ case _ => {}
+ }
+ }
+
+ def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ val info = taskInfos(tid)
+ val index = info.index
+ val task = taskSet.tasks(index)
+ info.markSuccessful()
+ val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
+ result.metrics.resultSize = serializedData.limit()
+ sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
+ numFinished += 1
+ decreaseRunningTasks(1)
+ finished(index) = true
+ if (numFinished == numTasks) {
+ sched.taskSetFinished(this)
+ }
+ }
+
+ def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ val info = taskInfos(tid)
+ val index = info.index
+ val task = taskSet.tasks(index)
+ info.markFailed()
+ decreaseRunningTasks(1)
+ val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](serializedData, getClass.getClassLoader)
+ if (!finished(index)) {
+ copiesRunning(index) -= 1
+ numFailures(index) += 1
+ val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString))
+ logInfo("Loss was due to %s\n%s\n%s".format(reason.className, reason.description, locs.mkString("\n")))
+ if (numFailures(index) > MAX_TASK_FAILURES) {
+ val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(taskSet.id, index, 4, reason.description)
+ decreaseRunningTasks(runningTasks)
+ sched.listener.taskSetFailed(taskSet, errorMessage)
+ // need to delete failed Taskset from schedule queue
+ sched.taskSetFinished(this)
+ }
+ }
+ }
+
+ def error(message: String) {
+ }
+}
diff --git a/core/src/main/scala/spark/serializer/Serializer.scala b/core/src/main/scala/spark/serializer/Serializer.scala
index aca86ab6f0..2ad73b711d 100644
--- a/core/src/main/scala/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/spark/serializer/Serializer.scala
@@ -1,10 +1,13 @@
package spark.serializer
-import java.nio.ByteBuffer
import java.io.{EOFException, InputStream, OutputStream}
+import java.nio.ByteBuffer
+
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+
import spark.util.ByteBufferInputStream
+
/**
* A serializer. Because some serialization libraries are not thread safe, this class is used to
* create [[spark.serializer.SerializerInstance]] objects that do the actual serialization and are
@@ -14,6 +17,7 @@ trait Serializer {
def newInstance(): SerializerInstance
}
+
/**
* An instance of a serializer, for use by one thread at a time.
*/
@@ -45,6 +49,7 @@ trait SerializerInstance {
}
}
+
/**
* A stream for writing serialized objects.
*/
@@ -61,6 +66,7 @@ trait SerializationStream {
}
}
+
/**
* A stream for reading serialized objects.
*/
diff --git a/core/src/main/scala/spark/serializer/SerializerManager.scala b/core/src/main/scala/spark/serializer/SerializerManager.scala
new file mode 100644
index 0000000000..60b2aac797
--- /dev/null
+++ b/core/src/main/scala/spark/serializer/SerializerManager.scala
@@ -0,0 +1,45 @@
+package spark.serializer
+
+import java.util.concurrent.ConcurrentHashMap
+
+
+/**
+ * A service that returns a serializer object given the serializer's class name. If a previous
+ * instance of the serializer object has been created, the get method returns that instead of
+ * creating a new one.
+ */
+private[spark] class SerializerManager {
+
+ private val serializers = new ConcurrentHashMap[String, Serializer]
+ private var _default: Serializer = _
+
+ def default = _default
+
+ def setDefault(clsName: String): Serializer = {
+ _default = get(clsName)
+ _default
+ }
+
+ def get(clsName: String): Serializer = {
+ if (clsName == null) {
+ default
+ } else {
+ var serializer = serializers.get(clsName)
+ if (serializer != null) {
+ // If the serializer has been created previously, reuse that.
+ serializer
+ } else this.synchronized {
+ // Otherwise, create a new one. But make sure no other thread has attempted
+ // to create another new one at the same time.
+ serializer = serializers.get(clsName)
+ if (serializer == null) {
+ val clsLoader = Thread.currentThread.getContextClassLoader
+ serializer =
+ Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer]
+ serializers.put(clsName, serializer)
+ }
+ serializer
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/storage/BlockException.scala b/core/src/main/scala/spark/storage/BlockException.scala
new file mode 100644
index 0000000000..f275d476df
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockException.scala
@@ -0,0 +1,5 @@
+package spark.storage
+
+private[spark]
+case class BlockException(blockId: String, message: String) extends Exception(message)
+
diff --git a/core/src/main/scala/spark/storage/BlockFetchTracker.scala b/core/src/main/scala/spark/storage/BlockFetchTracker.scala
index 993aece1f7..0718156b1b 100644
--- a/core/src/main/scala/spark/storage/BlockFetchTracker.scala
+++ b/core/src/main/scala/spark/storage/BlockFetchTracker.scala
@@ -1,10 +1,10 @@
package spark.storage
private[spark] trait BlockFetchTracker {
- def totalBlocks : Int
- def numLocalBlocks: Int
- def numRemoteBlocks: Int
- def remoteFetchTime : Long
- def fetchWaitTime: Long
- def remoteBytesRead : Long
+ def totalBlocks : Int
+ def numLocalBlocks: Int
+ def numRemoteBlocks: Int
+ def remoteFetchTime : Long
+ def fetchWaitTime: Long
+ def remoteBytesRead : Long
}
diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
new file mode 100644
index 0000000000..bec876213e
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
@@ -0,0 +1,330 @@
+package spark.storage
+
+import java.nio.ByteBuffer
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashSet
+import scala.collection.mutable.Queue
+
+import io.netty.buffer.ByteBuf
+
+import spark.Logging
+import spark.Utils
+import spark.SparkException
+import spark.network.BufferMessage
+import spark.network.ConnectionManagerId
+import spark.network.netty.ShuffleCopier
+import spark.serializer.Serializer
+
+
+/**
+ * A block fetcher iterator interface. There are two implementations:
+ *
+ * BasicBlockFetcherIterator: uses a custom-built NIO communication layer.
+ * NettyBlockFetcherIterator: uses Netty (OIO) as the communication layer.
+ *
+ * Eventually we would like the two to converge and use a single NIO-based communication layer,
+ * but extensive tests show that under some circumstances (e.g. large shuffles with lots of cores),
+ * NIO would perform poorly and thus the need for the Netty OIO one.
+ */
+
+private[storage]
+trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])]
+ with Logging with BlockFetchTracker {
+ def initialize()
+}
+
+
+private[storage]
+object BlockFetcherIterator {
+
+ // A request to fetch one or more blocks, complete with their sizes
+ class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
+ val size = blocks.map(_._2).sum
+ }
+
+ // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
+ // the block (since we want all deserializaton to happen in the calling thread); can also
+ // represent a fetch failure if size == -1.
+ class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
+ def failed: Boolean = size == -1
+ }
+
+ class BasicBlockFetcherIterator(
+ private val blockManager: BlockManager,
+ val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+ serializer: Serializer)
+ extends BlockFetcherIterator {
+
+ import blockManager._
+
+ private var _remoteBytesRead = 0l
+ private var _remoteFetchTime = 0l
+ private var _fetchWaitTime = 0l
+
+ if (blocksByAddress == null) {
+ throw new IllegalArgumentException("BlocksByAddress is null")
+ }
+
+ // Total number blocks fetched (local + remote). Also number of FetchResults expected
+ protected var _numBlocksToFetch = 0
+
+ protected var startTime = System.currentTimeMillis
+
+ // This represents the number of local blocks, also counting zero-sized blocks
+ private var numLocal = 0
+ // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks
+ protected val localBlocksToFetch = new ArrayBuffer[String]()
+
+ // This represents the number of remote blocks, also counting zero-sized blocks
+ private var numRemote = 0
+ // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks
+ protected val remoteBlocksToFetch = new HashSet[String]()
+
+ // A queue to hold our results.
+ protected val results = new LinkedBlockingQueue[FetchResult]
+
+ // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
+ // the number of bytes in flight is limited to maxBytesInFlight
+ private val fetchRequests = new Queue[FetchRequest]
+
+ // Current bytes in flight from our requests
+ private var bytesInFlight = 0L
+
+ protected def sendRequest(req: FetchRequest) {
+ logDebug("Sending request for %d blocks (%s) from %s".format(
+ req.blocks.size, Utils.memoryBytesToString(req.size), req.address.hostPort))
+ val cmId = new ConnectionManagerId(req.address.host, req.address.port)
+ val blockMessageArray = new BlockMessageArray(req.blocks.map {
+ case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
+ })
+ bytesInFlight += req.size
+ val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
+ val fetchStart = System.currentTimeMillis()
+ val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
+ future.onSuccess {
+ case Some(message) => {
+ val fetchDone = System.currentTimeMillis()
+ _remoteFetchTime += fetchDone - fetchStart
+ val bufferMessage = message.asInstanceOf[BufferMessage]
+ val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
+ for (blockMessage <- blockMessageArray) {
+ if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
+ throw new SparkException(
+ "Unexpected message " + blockMessage.getType + " received from " + cmId)
+ }
+ val blockId = blockMessage.getId
+ results.put(new FetchResult(blockId, sizeMap(blockId),
+ () => dataDeserialize(blockId, blockMessage.getData, serializer)))
+ _remoteBytesRead += req.size
+ logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
+ }
+ }
+ case None => {
+ logError("Could not get block(s) from " + cmId)
+ for ((blockId, size) <- req.blocks) {
+ results.put(new FetchResult(blockId, -1, null))
+ }
+ }
+ }
+ }
+
+ protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
+ // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
+ // at most maxBytesInFlight in order to limit the amount of data in flight.
+ val remoteRequests = new ArrayBuffer[FetchRequest]
+ for ((address, blockInfos) <- blocksByAddress) {
+ if (address == blockManagerId) {
+ numLocal = blockInfos.size
+ // Filter out zero-sized blocks
+ localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
+ _numBlocksToFetch += localBlocksToFetch.size
+ } else {
+ numRemote += blockInfos.size
+ // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
+ // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
+ // nodes, rather than blocking on reading output from one node.
+ val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
+ logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
+ val iterator = blockInfos.iterator
+ var curRequestSize = 0L
+ var curBlocks = new ArrayBuffer[(String, Long)]
+ while (iterator.hasNext) {
+ val (blockId, size) = iterator.next()
+ // Skip empty blocks
+ if (size > 0) {
+ curBlocks += ((blockId, size))
+ remoteBlocksToFetch += blockId
+ _numBlocksToFetch += 1
+ curRequestSize += size
+ } else if (size < 0) {
+ throw new BlockException(blockId, "Negative block size " + size)
+ }
+ if (curRequestSize >= minRequestSize) {
+ // Add this FetchRequest
+ remoteRequests += new FetchRequest(address, curBlocks)
+ curRequestSize = 0
+ curBlocks = new ArrayBuffer[(String, Long)]
+ }
+ }
+ // Add in the final request
+ if (!curBlocks.isEmpty) {
+ remoteRequests += new FetchRequest(address, curBlocks)
+ }
+ }
+ }
+ logInfo("Getting " + _numBlocksToFetch + " non-zero-bytes blocks out of " +
+ totalBlocks + " blocks")
+ remoteRequests
+ }
+
+ protected def getLocalBlocks() {
+ // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
+ // these all at once because they will just memory-map some files, so they won't consume
+ // any memory that might exceed our maxBytesInFlight
+ for (id <- localBlocksToFetch) {
+ getLocalFromDisk(id, serializer) match {
+ case Some(iter) => {
+ // Pass 0 as size since it's not in flight
+ results.put(new FetchResult(id, 0, () => iter))
+ logDebug("Got local block " + id)
+ }
+ case None => {
+ throw new BlockException(id, "Could not get block " + id + " from local machine")
+ }
+ }
+ }
+ }
+
+ override def initialize() {
+ // Split local and remote blocks.
+ val remoteRequests = splitLocalRemoteBlocks()
+ // Add the remote requests into our queue in a random order
+ fetchRequests ++= Utils.randomize(remoteRequests)
+
+ // Send out initial requests for blocks, up to our maxBytesInFlight
+ while (!fetchRequests.isEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
+ }
+
+ val numGets = remoteRequests.size - fetchRequests.size
+ logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
+
+ // Get Local Blocks
+ startTime = System.currentTimeMillis
+ getLocalBlocks()
+ logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+ }
+
+ //an iterator that will read fetched blocks off the queue as they arrive.
+ @volatile protected var resultsGotten = 0
+
+ override def hasNext: Boolean = resultsGotten < _numBlocksToFetch
+
+ override def next(): (String, Option[Iterator[Any]]) = {
+ resultsGotten += 1
+ val startFetchWait = System.currentTimeMillis()
+ val result = results.take()
+ val stopFetchWait = System.currentTimeMillis()
+ _fetchWaitTime += (stopFetchWait - startFetchWait)
+ if (! result.failed) bytesInFlight -= result.size
+ while (!fetchRequests.isEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
+ }
+ (result.blockId, if (result.failed) None else Some(result.deserialize()))
+ }
+
+ // Implementing BlockFetchTracker trait.
+ override def totalBlocks: Int = numLocal + numRemote
+ override def numLocalBlocks: Int = numLocal
+ override def numRemoteBlocks: Int = numRemote
+ override def remoteFetchTime: Long = _remoteFetchTime
+ override def fetchWaitTime: Long = _fetchWaitTime
+ override def remoteBytesRead: Long = _remoteBytesRead
+ }
+ // End of BasicBlockFetcherIterator
+
+ class NettyBlockFetcherIterator(
+ blockManager: BlockManager,
+ blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+ serializer: Serializer)
+ extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) {
+
+ import blockManager._
+
+ val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest]
+
+ private def startCopiers(numCopiers: Int): List[_ <: Thread] = {
+ (for ( i <- Range(0,numCopiers) ) yield {
+ val copier = new Thread {
+ override def run(){
+ try {
+ while(!isInterrupted && !fetchRequestsSync.isEmpty) {
+ sendRequest(fetchRequestsSync.take())
+ }
+ } catch {
+ case x: InterruptedException => logInfo("Copier Interrupted")
+ //case _ => throw new SparkException("Exception Throw in Shuffle Copier")
+ }
+ }
+ }
+ copier.start
+ copier
+ }).toList
+ }
+
+ // keep this to interrupt the threads when necessary
+ private def stopCopiers() {
+ for (copier <- copiers) {
+ copier.interrupt()
+ }
+ }
+
+ override protected def sendRequest(req: FetchRequest) {
+
+ def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) {
+ val fetchResult = new FetchResult(blockId, blockSize,
+ () => dataDeserialize(blockId, blockData.nioBuffer, serializer))
+ results.put(fetchResult)
+ }
+
+ logDebug("Sending request for %d blocks (%s) from %s".format(
+ req.blocks.size, Utils.memoryBytesToString(req.size), req.address.host))
+ val cmId = new ConnectionManagerId(req.address.host, req.address.nettyPort)
+ val cpier = new ShuffleCopier
+ cpier.getBlocks(cmId, req.blocks, putResult)
+ logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host )
+ }
+
+ private var copiers: List[_ <: Thread] = null
+
+ override def initialize() {
+ // Split Local Remote Blocks and set numBlocksToFetch
+ val remoteRequests = splitLocalRemoteBlocks()
+ // Add the remote requests into our queue in a random order
+ for (request <- Utils.randomize(remoteRequests)) {
+ fetchRequestsSync.put(request)
+ }
+
+ copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt)
+ logInfo("Started " + fetchRequestsSync.size + " remote gets in " +
+ Utils.getUsedTimeMs(startTime))
+
+ // Get Local Blocks
+ startTime = System.currentTimeMillis
+ getLocalBlocks()
+ logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+ }
+
+ override def next(): (String, Option[Iterator[Any]]) = {
+ resultsGotten += 1
+ val result = results.take()
+ // If all the results has been retrieved, copiers will exit automatically
+ (result.blockId, if (result.failed) None else Some(result.deserialize()))
+ }
+ }
+ // End of NettyBlockFetcherIterator
+}
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index d3f6cd78dc..4bb4927b4a 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -2,10 +2,8 @@ package spark.storage
import java.io.{InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
-import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
-import scala.collection.JavaConversions._
+import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet}
import akka.actor.{ActorSystem, Cancellable, Props}
import scala.concurrent.{Await, Future}
@@ -16,7 +14,7 @@ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
-import spark.{Logging, SizeEstimator, SparkEnv, SparkException, Utils}
+import spark.{Logging, SparkEnv, SparkException, Utils}
import spark.network._
import spark.serializer.Serializer
import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap}
@@ -24,30 +22,35 @@ import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStam
import sun.nio.ch.DirectBuffer
-private[spark]
-case class BlockException(blockId: String, message: String, ex: Exception = null)
-extends Exception(message)
-
-private[spark]
-class BlockManager(
+private[spark] class BlockManager(
executorId: String,
actorSystem: ActorSystem,
val master: BlockManagerMaster,
- val serializer: Serializer,
+ val defaultSerializer: Serializer,
maxMemory: Long)
extends Logging {
- class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
- var pending: Boolean = true
- var size: Long = -1L
- var failed: Boolean = false
+ private class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
+ @volatile var pending: Boolean = true
+ @volatile var size: Long = -1L
+ @volatile var initThread: Thread = null
+ @volatile var failed = false
+
+ setInitThread()
+
+ private def setInitThread() {
+ // Set current thread as init thread - waitForReady will not block this thread
+ // (in case there is non trivial initialization which ends up calling waitForReady as part of
+ // initialization itself)
+ this.initThread = Thread.currentThread()
+ }
/**
* Wait for this BlockInfo to be marked as ready (i.e. block is finished writing).
* Return true if the block is available, false otherwise.
*/
def waitForReady(): Boolean = {
- if (pending) {
+ if (initThread != Thread.currentThread() && pending) {
synchronized {
while (pending) this.wait()
}
@@ -57,35 +60,51 @@ class BlockManager(
/** Mark this BlockInfo as ready (i.e. block is finished writing) */
def markReady(sizeInBytes: Long) {
+ assert (pending)
+ size = sizeInBytes
+ initThread = null
+ failed = false
+ initThread = null
+ pending = false
synchronized {
- pending = false
- failed = false
- size = sizeInBytes
this.notifyAll()
}
}
/** Mark this BlockInfo as ready but failed */
def markFailure() {
+ assert (pending)
+ size = 0
+ initThread = null
+ failed = true
+ initThread = null
+ pending = false
synchronized {
- failed = true
- pending = false
this.notifyAll()
}
}
}
+ val shuffleBlockManager = new ShuffleBlockManager(this)
+
private val blockInfo = new TimeStampedHashMap[String, BlockInfo]
private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
- private[storage] val diskStore: BlockStore =
+ private[storage] val diskStore: DiskStore =
new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
+ // If we use Netty for shuffle, start a new Netty-based shuffle sender service.
+ private val nettyPort: Int = {
+ val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean
+ val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt
+ if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0
+ }
+
val connectionManager = new ConnectionManager(0)
implicit val futureExecContext = connectionManager.futureExecContext
val blockManagerId = BlockManagerId(
- executorId, connectionManager.id.host, connectionManager.id.port)
+ executorId, connectionManager.id.host, connectionManager.id.port, nettyPort)
// Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory
// for receiving shuffle outputs)
@@ -101,7 +120,7 @@ class BlockManager(
val heartBeatFrequency = BlockManager.getHeartBeatFrequencyFromSystemProperties
- val host = System.getProperty("spark.hostname", Utils.localHostName())
+ val hostPort = Utils.localHostPort()
val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
@@ -212,9 +231,12 @@ class BlockManager(
* Tell the master about the current storage status of a block. This will send a block update
* message reflecting the current status, *not* the desired storage level in its block info.
* For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk.
+ *
+ * droppedMemorySize exists to account for when block is dropped from memory to disk (so it is still valid).
+ * This ensures that update in master will compensate for the increase in memory on slave.
*/
- def reportBlockStatus(blockId: String, info: BlockInfo) {
- val needReregister = !tryToReportBlockStatus(blockId, info)
+ def reportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L) {
+ val needReregister = !tryToReportBlockStatus(blockId, info, droppedMemorySize)
if (needReregister) {
logInfo("Got told to reregister updating block " + blockId)
// Reregistering will report our new block for free.
@@ -228,7 +250,7 @@ class BlockManager(
* which will be true if the block was successfully recorded and false if
* the slave needs to re-register.
*/
- private def tryToReportBlockStatus(blockId: String, info: BlockInfo): Boolean = {
+ private def tryToReportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = {
val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized {
info.level match {
case null =>
@@ -237,7 +259,7 @@ class BlockManager(
val inMem = level.useMemory && memoryStore.contains(blockId)
val onDisk = level.useDisk && diskStore.contains(blockId)
val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication)
- val memSize = if (inMem) memoryStore.getSize(blockId) else 0L
+ val memSize = if (inMem) memoryStore.getSize(blockId) else droppedMemorySize
val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L
(storageLevel, memSize, diskSize, info.tellMaster)
}
@@ -250,26 +272,24 @@ class BlockManager(
}
}
-
/**
- * Get locations of the block.
+ * Get locations of an array of blocks.
*/
- def getLocations(blockId: String): Seq[String] = {
+ def getLocationBlockIds(blockIds: Array[String]): Array[Seq[BlockManagerId]] = {
val startTimeMs = System.currentTimeMillis
- var managers = master.getLocations(blockId)
- val locations = managers.map(_.ip)
- logDebug("Got block locations in " + Utils.getUsedTimeMs(startTimeMs))
- return locations
+ val locations = master.getLocations(blockIds).toArray
+ logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
+ locations
}
/**
- * Get locations of an array of blocks.
+ * A short-circuited method to get blocks directly from disk. This is used for getting
+ * shuffle blocks. It is safe to do so without a lock on block info since disk store
+ * never deletes (recent) items.
*/
- def getLocations(blockIds: Array[String]): Array[Seq[String]] = {
- val startTimeMs = System.currentTimeMillis
- val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray
- logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
- return locations
+ def getLocalFromDisk(blockId: String, serializer: Serializer): Option[Iterator[Any]] = {
+ diskStore.getValues(blockId, serializer).orElse(
+ sys.error("Block " + blockId + " not found on disk, though it should be"))
}
/**
@@ -277,18 +297,6 @@ class BlockManager(
*/
def getLocal(blockId: String): Option[Iterator[Any]] = {
logDebug("Getting local block " + blockId)
-
- // As an optimization for map output fetches, if the block is for a shuffle, return it
- // without acquiring a lock; the disk store never deletes (recent) items so this should work
- if (blockId.startsWith("shuffle_")) {
- return diskStore.getValues(blockId) match {
- case Some(iterator) =>
- Some(iterator)
- case None =>
- throw new Exception("Block " + blockId + " not found on disk, though it should be")
- }
- }
-
val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
@@ -339,6 +347,8 @@ class BlockManager(
case Some(bytes) =>
// Put a copy of the block back in memory before returning it. Note that we can't
// put the ByteBuffer returned by the disk store as that's a memory-mapped file.
+ // The use of rewind assumes this.
+ assert (0 == bytes.position())
val copyForMemory = ByteBuffer.allocate(bytes.limit)
copyForMemory.put(bytes)
memoryStore.putBytes(blockId, copyForMemory, level)
@@ -372,7 +382,7 @@ class BlockManager(
// As an optimization for map output fetches, if the block is for a shuffle, return it
// without acquiring a lock; the disk store never deletes (recent) items so this should work
- if (blockId.startsWith("shuffle_")) {
+ if (ShuffleBlockManager.isShuffle(blockId)) {
return diskStore.getBytes(blockId) match {
case Some(bytes) =>
Some(bytes)
@@ -411,6 +421,7 @@ class BlockManager(
// Read it as a byte buffer into memory first, then return it
diskStore.getBytes(blockId) match {
case Some(bytes) =>
+ assert (0 == bytes.position())
if (level.useMemory) {
if (level.deserialized) {
memoryStore.putBytes(blockId, bytes, level)
@@ -450,7 +461,7 @@ class BlockManager(
for (loc <- locations) {
logDebug("Getting remote block " + blockId + " from " + loc)
val data = BlockManagerWorker.syncGetBlock(
- GetBlock(blockId), ConnectionManagerId(loc.ip, loc.port))
+ GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
if (data != null) {
return Some(dataDeserialize(blockId, data))
}
@@ -473,9 +484,19 @@ class BlockManager(
* fashion as they're received. Expects a size in bytes to be provided for each block fetched,
* so that we can control the maxMegabytesInFlight for the fetch.
*/
- def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])])
+ def getMultiple(
+ blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer)
: BlockFetcherIterator = {
- return new BlockFetcherIterator(this, blocksByAddress)
+
+ val iter =
+ if (System.getProperty("spark.shuffle.use.netty", "false").toBoolean) {
+ new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer)
+ } else {
+ new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer)
+ }
+
+ iter.initialize()
+ iter
}
def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
@@ -486,6 +507,22 @@ class BlockManager(
}
/**
+ * A short circuited method to get a block writer that can write data directly to disk.
+ * This is currently used for writing shuffle files out. Callers should handle error
+ * cases.
+ */
+ def getDiskBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
+ : BlockObjectWriter = {
+ val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize)
+ writer.registerCloseEventHandler(() => {
+ val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false)
+ blockInfo.put(blockId, myInfo)
+ myInfo.markReady(writer.size())
+ })
+ writer
+ }
+
+ /**
* Put a new block of values to the block manager. Returns its (estimated) size in bytes.
*/
def put(blockId: String, values: ArrayBuffer[Any], level: StorageLevel,
@@ -501,17 +538,26 @@ class BlockManager(
throw new IllegalArgumentException("Storage level is null or invalid")
}
- val oldBlock = blockInfo.get(blockId).orNull
- if (oldBlock != null && oldBlock.waitForReady()) {
- logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
- return oldBlock.size
- }
-
// Remember the block's storage level so that we can correctly drop it to disk if it needs
// to be dropped right after it got put into memory. Note, however, that other threads will
// not be able to get() this block until we call markReady on its BlockInfo.
- val myInfo = new BlockInfo(level, tellMaster)
- blockInfo.put(blockId, myInfo)
+ val myInfo = {
+ val tinfo = new BlockInfo(level, tellMaster)
+ // Do atomically !
+ val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo)
+
+ if (oldBlockOpt.isDefined) {
+ if (oldBlockOpt.get.waitForReady()) {
+ logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
+ return oldBlockOpt.get.size
+ }
+
+ // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ?
+ oldBlockOpt.get
+ } else {
+ tinfo
+ }
+ }
val startTimeMs = System.currentTimeMillis
@@ -531,6 +577,7 @@ class BlockManager(
logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
+ var marked = false
try {
if (level.useMemory) {
// Save it just to memory first, even if it also has useDisk set to true; we will later
@@ -555,26 +602,25 @@ class BlockManager(
// Now that the block is in either the memory or disk store, let other threads read it,
// and tell the master about it.
+ marked = true
myInfo.markReady(size)
if (tellMaster) {
reportBlockStatus(blockId, myInfo)
}
- } catch {
+ } finally {
// If we failed at putting the block to memory/disk, notify other possible readers
// that it has failed, and then remove it from the block info map.
- case e: Exception => {
+ if (! marked) {
// Note that the remove must happen before markFailure otherwise another thread
// could've inserted a new BlockInfo before we remove it.
blockInfo.remove(blockId)
myInfo.markFailure()
- logWarning("Putting block " + blockId + " failed", e)
- throw e
+ logWarning("Putting block " + blockId + " failed")
}
}
}
logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs))
-
// Replicate block if required
if (level.replication > 1) {
val remoteStartTime = System.currentTimeMillis
@@ -611,16 +657,26 @@ class BlockManager(
throw new IllegalArgumentException("Storage level is null or invalid")
}
- if (blockInfo.contains(blockId)) {
- logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
- return
- }
-
// Remember the block's storage level so that we can correctly drop it to disk if it needs
// to be dropped right after it got put into memory. Note, however, that other threads will
// not be able to get() this block until we call markReady on its BlockInfo.
- val myInfo = new BlockInfo(level, tellMaster)
- blockInfo.put(blockId, myInfo)
+ val myInfo = {
+ val tinfo = new BlockInfo(level, tellMaster)
+ // Do atomically !
+ val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo)
+
+ if (oldBlockOpt.isDefined) {
+ if (oldBlockOpt.get.waitForReady()) {
+ logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
+ return
+ }
+
+ // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ?
+ oldBlockOpt.get
+ } else {
+ tinfo
+ }
+ }
val startTimeMs = System.currentTimeMillis
@@ -639,6 +695,7 @@ class BlockManager(
logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
+ var marked = false
try {
if (level.useMemory) {
// Store it only in memory at first, even if useDisk is also set to true
@@ -649,22 +706,24 @@ class BlockManager(
diskStore.putBytes(blockId, bytes, level)
}
+ // assert (0 == bytes.position(), "" + bytes)
+
// Now that the block is in either the memory or disk store, let other threads read it,
// and tell the master about it.
+ marked = true
myInfo.markReady(bytes.limit)
if (tellMaster) {
reportBlockStatus(blockId, myInfo)
}
- } catch {
+ } finally {
// If we failed at putting the block to memory/disk, notify other possible readers
// that it has failed, and then remove it from the block info map.
- case e: Exception => {
+ if (! marked) {
// Note that the remove must happen before markFailure otherwise another thread
// could've inserted a new BlockInfo before we remove it.
blockInfo.remove(blockId)
myInfo.markFailure()
- logWarning("Putting block " + blockId + " failed", e)
- throw e
+ logWarning("Putting block " + blockId + " failed")
}
}
}
@@ -698,7 +757,7 @@ class BlockManager(
logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is "
+ data.limit() + " Bytes. To node: " + peer)
if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel),
- new ConnectionManagerId(peer.ip, peer.port))) {
+ new ConnectionManagerId(peer.host, peer.port))) {
logError("Failed to call syncPutBlock to " + peer)
}
logDebug("Replicated BlockId " + blockId + " once used " +
@@ -730,6 +789,14 @@ class BlockManager(
val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
+ // required ? As of now, this will be invoked only for blocks which are ready
+ // But in case this changes in future, adding for consistency sake.
+ if (! info.waitForReady() ) {
+ // If we get here, the block write failed.
+ logWarning("Block " + blockId + " was marked as failure. Nothing to drop")
+ return
+ }
+
val level = info.level
if (level.useDisk && !diskStore.contains(blockId)) {
logInfo("Writing block " + blockId + " to disk")
@@ -740,12 +807,13 @@ class BlockManager(
diskStore.putBytes(blockId, bytes, level)
}
}
+ val droppedMemorySize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L
val blockWasRemoved = memoryStore.remove(blockId)
if (!blockWasRemoved) {
logWarning("Block " + blockId + " could not be dropped from memory as it does not exist")
}
if (info.tellMaster) {
- reportBlockStatus(blockId, info)
+ reportBlockStatus(blockId, info, droppedMemorySize)
}
if (!level.useDisk) {
// The block is completely gone from this node; forget it so we can put() it again later.
@@ -758,9 +826,23 @@ class BlockManager(
}
/**
+ * Remove all blocks belonging to the given RDD.
+ * @return The number of blocks removed.
+ */
+ def removeRdd(rddId: Int): Int = {
+ // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps
+ // from RDD.id to blocks.
+ logInfo("Removing RDD " + rddId)
+ val rddPrefix = "rdd_" + rddId + "_"
+ val blocksToRemove = blockInfo.filter(_._1.startsWith(rddPrefix)).map(_._1)
+ blocksToRemove.foreach(blockId => removeBlock(blockId, false))
+ blocksToRemove.size
+ }
+
+ /**
* Remove a block from both memory and disk.
*/
- def removeBlock(blockId: String) {
+ def removeBlock(blockId: String, tellMaster: Boolean = true) {
logInfo("Removing block " + blockId)
val info = blockInfo.get(blockId).orNull
if (info != null) info.synchronized {
@@ -772,7 +854,7 @@ class BlockManager(
"the disk or memory store")
}
blockInfo.remove(blockId)
- if (info.tellMaster) {
+ if (tellMaster && info.tellMaster) {
reportBlockStatus(blockId, info)
}
} else {
@@ -805,7 +887,7 @@ class BlockManager(
}
def shouldCompress(blockId: String): Boolean = {
- if (blockId.startsWith("shuffle_")) {
+ if (ShuffleBlockManager.isShuffle(blockId)) {
compressShuffle
} else if (blockId.startsWith("broadcast_")) {
compressBroadcast
@@ -820,7 +902,11 @@ class BlockManager(
* Wrap an output stream for compression if block compression is enabled for its block type
*/
def wrapForCompression(blockId: String, s: OutputStream): OutputStream = {
- if (shouldCompress(blockId)) new LZFOutputStream(s) else s
+ if (shouldCompress(blockId)) {
+ (new LZFOutputStream(s)).setFinishBlockOnFlush(true)
+ } else {
+ s
+ }
}
/**
@@ -830,7 +916,10 @@ class BlockManager(
if (shouldCompress(blockId)) new LZFInputStream(s) else s
}
- def dataSerialize(blockId: String, values: Iterator[Any]): ByteBuffer = {
+ def dataSerialize(
+ blockId: String,
+ values: Iterator[Any],
+ serializer: Serializer = defaultSerializer): ByteBuffer = {
val byteStream = new FastByteArrayOutputStream(4096)
val ser = serializer.newInstance()
ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
@@ -842,7 +931,10 @@ class BlockManager(
* Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of
* the iterator is reached.
*/
- def dataDeserialize(blockId: String, bytes: ByteBuffer): Iterator[Any] = {
+ def dataDeserialize(
+ blockId: String,
+ bytes: ByteBuffer,
+ serializer: Serializer = defaultSerializer): Iterator[Any] = {
bytes.rewind()
val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true))
serializer.newInstance().deserializeStream(stream).asIterator
@@ -862,8 +954,8 @@ class BlockManager(
}
}
-private[spark]
-object BlockManager extends Logging {
+
+private[spark] object BlockManager extends Logging {
val ID_GENERATOR = new IdGenerator
@@ -873,7 +965,8 @@ object BlockManager extends Logging {
}
def getHeartBeatFrequencyFromSystemProperties: Long =
- System.getProperty("spark.storage.blockManagerHeartBeatMs", "10000").toLong
+
+ System.getProperty("spark.storage.blockManagerTimeoutIntervalMs", "60000").toLong / 4
def getDisableHeartBeatsForTesting: Boolean =
System.getProperty("spark.test.disableBlockManagerHeartBeat", "false").toBoolean
@@ -892,177 +985,43 @@ object BlockManager extends Logging {
}
}
}
-}
-
-class BlockFetcherIterator(
- private val blockManager: BlockManager,
- val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])]
-) extends Iterator[(String, Option[Iterator[Any]])] with Logging with BlockFetchTracker {
- import blockManager._
-
- private var _remoteBytesRead = 0l
- private var _remoteFetchTime = 0l
- private var _fetchWaitTime = 0l
-
- if (blocksByAddress == null) {
- throw new IllegalArgumentException("BlocksByAddress is null")
- }
- val totalBlocks = blocksByAddress.map(_._2.size).sum
- logDebug("Getting " + totalBlocks + " blocks")
- var startTime = System.currentTimeMillis
- val localBlockIds = new ArrayBuffer[String]()
- val remoteBlockIds = new HashSet[String]()
-
- // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
- // the block (since we want all deserializaton to happen in the calling thread); can also
- // represent a fetch failure if size == -1.
- class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
- def failed: Boolean = size == -1
- }
-
- // A queue to hold our results.
- val results = new LinkedBlockingQueue[FetchResult]
-
- // A request to fetch one or more blocks, complete with their sizes
- class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
- val size = blocks.map(_._2).sum
- }
+ def blockIdsToExecutorLocations(blockIds: Array[String], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null): HashMap[String, List[String]] = {
+ // env == null and blockManagerMaster != null is used in tests
+ assert (env != null || blockManagerMaster != null)
+ val locationBlockIds: Seq[Seq[BlockManagerId]] =
+ if (env != null) {
+ env.blockManager.getLocationBlockIds(blockIds)
+ } else {
+ blockManagerMaster.getLocations(blockIds)
+ }
- // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
- // the number of bytes in flight is limited to maxBytesInFlight
- val fetchRequests = new Queue[FetchRequest]
+ // Convert from block master locations to executor locations (we need that for task scheduling)
+ val executorLocations = new HashMap[String, List[String]]()
+ for (i <- 0 until blockIds.length) {
+ val blockId = blockIds(i)
+ val blockLocations = locationBlockIds(i)
- // Current bytes in flight from our requests
- var bytesInFlight = 0L
+ val executors = new HashSet[String]()
- def sendRequest(req: FetchRequest) {
- logDebug("Sending request for %d blocks (%s) from %s".format(
- req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip))
- val cmId = new ConnectionManagerId(req.address.ip, req.address.port)
- val blockMessageArray = new BlockMessageArray(req.blocks.map {
- case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
- })
- bytesInFlight += req.size
- val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
- val fetchStart = System.currentTimeMillis()
- val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
- future.onSuccess {
- case Some(message) => {
- val fetchDone = System.currentTimeMillis()
- _remoteFetchTime += fetchDone - fetchStart
- val bufferMessage = message.asInstanceOf[BufferMessage]
- val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
- for (blockMessage <- blockMessageArray) {
- if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
- throw new SparkException(
- "Unexpected message " + blockMessage.getType + " received from " + cmId)
- }
- val blockId = blockMessage.getId
- results.put(new FetchResult(
- blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData)))
- _remoteBytesRead += req.size
- logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
+ if (env != null) {
+ for (bkLocation <- blockLocations) {
+ val executorHostPort = env.resolveExecutorIdToHostPort(bkLocation.executorId, bkLocation.host)
+ executors += executorHostPort
+ // logInfo("bkLocation = " + bkLocation + ", executorHostPort = " + executorHostPort)
}
- }
- case None => {
- logError("Could not get block(s) from " + cmId)
- for ((blockId, size) <- req.blocks) {
- results.put(new FetchResult(blockId, -1, null))
+ } else {
+ // Typically while testing, etc - revert to simply using host.
+ for (bkLocation <- blockLocations) {
+ executors += bkLocation.host
+ // logInfo("bkLocation = " + bkLocation + ", executorHostPort = " + executorHostPort)
}
}
- }
- }
- // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
- // at most maxBytesInFlight in order to limit the amount of data in flight.
- val remoteRequests = new ArrayBuffer[FetchRequest]
- for ((address, blockInfos) <- blocksByAddress) {
- if (address == blockManagerId) {
- localBlockIds ++= blockInfos.map(_._1)
- } else {
- remoteBlockIds ++= blockInfos.map(_._1)
- // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
- // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
- // nodes, rather than blocking on reading output from one node.
- val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
- logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
- val iterator = blockInfos.iterator
- var curRequestSize = 0L
- var curBlocks = new ArrayBuffer[(String, Long)]
- while (iterator.hasNext) {
- val (blockId, size) = iterator.next()
- curBlocks += ((blockId, size))
- curRequestSize += size
- if (curRequestSize >= minRequestSize) {
- // Add this FetchRequest
- remoteRequests += new FetchRequest(address, curBlocks)
- curRequestSize = 0
- curBlocks = new ArrayBuffer[(String, Long)]
- }
- }
- // Add in the final request
- if (!curBlocks.isEmpty) {
- remoteRequests += new FetchRequest(address, curBlocks)
- }
+ executorLocations.put(blockId, executors.toSeq.toList)
}
- }
- // Add the remote requests into our queue in a random order
- fetchRequests ++= Utils.randomize(remoteRequests)
- // Send out initial requests for blocks, up to our maxBytesInFlight
- while (!fetchRequests.isEmpty &&
- (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
- sendRequest(fetchRequests.dequeue())
+ executorLocations
}
- val numGets = remoteBlockIds.size - fetchRequests.size
- logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
-
- // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
- // these all at once because they will just memory-map some files, so they won't consume
- // any memory that might exceed our maxBytesInFlight
- startTime = System.currentTimeMillis
- for (id <- localBlockIds) {
- getLocal(id) match {
- case Some(iter) => {
- results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
- logDebug("Got local block " + id)
- }
- case None => {
- throw new BlockException(id, "Could not get block " + id + " from local machine")
- }
- }
- }
- logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
-
- //an iterator that will read fetched blocks off the queue as they arrive.
- var resultsGotten = 0
-
- def hasNext: Boolean = resultsGotten < totalBlocks
-
- def next(): (String, Option[Iterator[Any]]) = {
- resultsGotten += 1
- val startFetchWait = System.currentTimeMillis()
- val result = results.take()
- val stopFetchWait = System.currentTimeMillis()
- _fetchWaitTime += (stopFetchWait - startFetchWait)
- bytesInFlight -= result.size
- while (!fetchRequests.isEmpty &&
- (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
- sendRequest(fetchRequests.dequeue())
- }
- (result.blockId, if (result.failed) None else Some(result.deserialize()))
- }
-
-
- //methods to profile the block fetching
- def numLocalBlocks = localBlockIds.size
- def numRemoteBlocks = remoteBlockIds.size
-
- def remoteFetchTime = _remoteFetchTime
- def fetchWaitTime = _fetchWaitTime
-
- def remoteBytesRead = _remoteBytesRead
-
}
diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala
index f2f1e77d41..1e557d6148 100644
--- a/core/src/main/scala/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerId.scala
@@ -2,51 +2,70 @@ package spark.storage
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.util.concurrent.ConcurrentHashMap
+import spark.Utils
/**
* This class represent an unique identifier for a BlockManager.
* The first 2 constructors of this class is made private to ensure that
- * BlockManagerId objects can be created only using the factory method in
- * [[spark.storage.BlockManager$]]. This allows de-duplication of ID objects.
+ * BlockManagerId objects can be created only using the apply method in
+ * the companion object. This allows de-duplication of ID objects.
* Also, constructor parameters are private to ensure that parameters cannot
* be modified from outside this class.
*/
private[spark] class BlockManagerId private (
private var executorId_ : String,
- private var ip_ : String,
- private var port_ : Int
+ private var host_ : String,
+ private var port_ : Int,
+ private var nettyPort_ : Int
) extends Externalizable {
- private def this() = this(null, null, 0) // For deserialization only
+ private def this() = this(null, null, 0, 0) // For deserialization only
def executorId: String = executorId_
- def ip: String = ip_
+ if (null != host_){
+ Utils.checkHost(host_, "Expected hostname")
+ assert (port_ > 0)
+ }
+
+ def hostPort: String = {
+ // DEBUG code
+ Utils.checkHost(host)
+ assert (port > 0)
+
+ host + ":" + port
+ }
+
+ def host: String = host_
def port: Int = port_
+ def nettyPort: Int = nettyPort_
+
override def writeExternal(out: ObjectOutput) {
out.writeUTF(executorId_)
- out.writeUTF(ip_)
+ out.writeUTF(host_)
out.writeInt(port_)
+ out.writeInt(nettyPort_)
}
override def readExternal(in: ObjectInput) {
executorId_ = in.readUTF()
- ip_ = in.readUTF()
+ host_ = in.readUTF()
port_ = in.readInt()
+ nettyPort_ = in.readInt()
}
@throws(classOf[IOException])
private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this)
- override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, ip, port)
+ override def toString = "BlockManagerId(%s, %s, %d, %d)".format(executorId, host, port, nettyPort)
- override def hashCode: Int = (executorId.hashCode * 41 + ip.hashCode) * 41 + port
+ override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port + nettyPort
override def equals(that: Any) = that match {
case id: BlockManagerId =>
- executorId == id.executorId && port == id.port && ip == id.ip
+ executorId == id.executorId && port == id.port && host == id.host && nettyPort == id.nettyPort
case _ =>
false
}
@@ -55,8 +74,17 @@ private[spark] class BlockManagerId private (
private[spark] object BlockManagerId {
- def apply(execId: String, ip: String, port: Int) =
- getCachedBlockManagerId(new BlockManagerId(execId, ip, port))
+ /**
+ * Returns a [[spark.storage.BlockManagerId]] for the given configuraiton.
+ *
+ * @param execId ID of the executor.
+ * @param host Host name of the block manager.
+ * @param port Port of the block manager.
+ * @param nettyPort Optional port for the Netty-based shuffle sender.
+ * @return A new [[spark.storage.BlockManagerId]].
+ */
+ def apply(execId: String, host: String, port: Int, nettyPort: Int) =
+ getCachedBlockManagerId(new BlockManagerId(execId, host, port, nettyPort))
def apply(in: ObjectInput) = {
val obj = new BlockManagerId()
@@ -67,11 +95,7 @@ private[spark] object BlockManagerId {
val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()
def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
- if (blockManagerIdCache.containsKey(id)) {
- blockManagerIdCache.get(id)
- } else {
- blockManagerIdCache.put(id, id)
- id
- }
+ blockManagerIdCache.putIfAbsent(id, id)
+ blockManagerIdCache.get(id)
}
}
diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
index 4e55936d28..6a9278292e 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
@@ -9,10 +9,13 @@ import scala.util.Random
import akka.actor.{Actor, ActorRef, ActorSystem, Props}
import scala.concurrent.Await
+import scala.concurrent.Future
+import scala.concurrent.ExecutionContext.Implicits.global
+
import akka.pattern.ask
import scala.concurrent.duration._
-import spark.{Logging, SparkException, Utils}
+import spark.{Logging, SparkException}
private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Logging {
@@ -21,7 +24,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster"
- val timeout = 10.seconds
+ val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
/** Remove a dead executor from the driver actor. This is only called on the driver side. */
def removeExecutor(execId: String) {
@@ -87,6 +90,19 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
}
/**
+ * Remove all blocks belonging to the given RDD.
+ */
+ def removeRdd(rddId: Int, blocking: Boolean) {
+ val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId))
+ future onFailure {
+ case e: Throwable => logError("Failed to remove RDD " + rddId, e)
+ }
+ if (blocking) {
+ Await.result(future, timeout)
+ }
+ }
+
+ /**
* Return the memory status for each block manager, in the form of a map from
* the block manager's id to two long values. The first value is the maximum
* amount of memory allocated for the block manager, while the second is the
@@ -97,7 +113,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
}
def getStorageStatus: Array[StorageStatus] = {
- askDriverWithReply[ArrayBuffer[StorageStatus]](GetStorageStatus).toArray
+ askDriverWithReply[Array[StorageStatus]](GetStorageStatus)
}
/** Stop the driver actor, called only on the Spark driver node */
@@ -134,7 +150,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
val future = driverActor.ask(message)(timeout)
val result = Await.result(future, timeout)
if (result == null) {
- throw new Exception("BlockManagerMaster returned null")
+ throw new SparkException("BlockManagerMaster returned null")
}
return result.asInstanceOf[T]
} catch {
diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala
index 2d39e2c15c..6b5e38124b 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala
@@ -2,14 +2,16 @@ package spark.storage
import java.util.{HashMap => JHashMap}
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.collection.mutable
import scala.collection.JavaConversions._
-import scala.util.Random
import akka.actor.{Actor, ActorRef, Cancellable}
+import akka.pattern.ask
+
import scala.concurrent.duration._
+import scala.concurrent.Future
-import spark.{Logging, Utils}
+import spark.{Logging, Utils, SparkException}
/**
* BlockManagerMasterActor is an actor on the master node to track statuses of
@@ -20,13 +22,16 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
// Mapping from block manager id to the block manager's information.
private val blockManagerInfo =
- new HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo]
+ new mutable.HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo]
// Mapping from executor ID to block manager ID.
- private val blockManagerIdByExecutor = new HashMap[String, BlockManagerId]
+ private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId]
// Mapping from block id to the set of block managers that have the block.
- private val blockLocations = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]]
+ private val blockLocations = new JHashMap[String, mutable.HashSet[BlockManagerId]]
+
+ val akkaTimeout = Duration.create(
+ System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
initLogging()
@@ -34,7 +39,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
"" + (BlockManager.getHeartBeatFrequencyFromSystemProperties * 3)).toLong
val checkTimeoutInterval = System.getProperty("spark.storage.blockManagerTimeoutIntervalMs",
- "5000").toLong
+ "60000").toLong
var timeoutCheckingTask: Cancellable = null
@@ -50,28 +55,34 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
def receive = {
case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) =>
register(blockManagerId, maxMemSize, slaveActor)
+ sender ! true
case UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) =>
+ // TODO: Ideally we want to handle all the message replies in receive instead of in the
+ // individual private methods.
updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size)
case GetLocations(blockId) =>
- getLocations(blockId)
+ sender ! getLocations(blockId)
case GetLocationsMultipleBlockIds(blockIds) =>
- getLocationsMultipleBlockIds(blockIds)
+ sender ! getLocationsMultipleBlockIds(blockIds)
case GetPeers(blockManagerId, size) =>
- getPeersDeterministic(blockManagerId, size)
- /*getPeers(blockManagerId, size)*/
+ sender ! getPeers(blockManagerId, size)
case GetMemoryStatus =>
- getMemoryStatus
+ sender ! memoryStatus
case GetStorageStatus =>
- getStorageStatus
+ sender ! storageStatus
+
+ case RemoveRdd(rddId) =>
+ sender ! removeRdd(rddId)
case RemoveBlock(blockId) =>
- removeBlock(blockId)
+ removeBlockFromWorkers(blockId)
+ sender ! true
case RemoveExecutor(execId) =>
removeExecutor(execId)
@@ -81,7 +92,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
logInfo("Stopping BlockManagerMaster")
sender ! true
if (timeoutCheckingTask != null) {
- timeoutCheckingTask.cancel
+ timeoutCheckingTask.cancel()
}
context.stop(self)
@@ -89,13 +100,36 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
expireDeadHosts()
case HeartBeat(blockManagerId) =>
- heartBeat(blockManagerId)
+ sender ! heartBeat(blockManagerId)
case other =>
- logInfo("Got unknown message: " + other)
+ logWarning("Got unknown message: " + other)
+ }
+
+ private def removeRdd(rddId: Int): Future[Seq[Int]] = {
+ // First remove the metadata for the given RDD, and then asynchronously remove the blocks
+ // from the slaves.
+
+ val prefix = "rdd_" + rddId + "_"
+ // Find all blocks for the given RDD, remove the block from both blockLocations and
+ // the blockManagerInfo that is tracking the blocks.
+ val blocks = blockLocations.keySet().filter(_.startsWith(prefix))
+ blocks.foreach { blockId =>
+ val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId)
+ bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId)))
+ blockLocations.remove(blockId)
+ }
+
+ // Ask the slaves to remove the RDD, and put the result in a sequence of Futures.
+ // The dispatcher is used as an implicit argument into the Future sequence construction.
+ import context.dispatcher
+ val removeMsg = RemoveRdd(rddId)
+ Future.sequence(blockManagerInfo.values.map { bm =>
+ bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
+ }.toSeq)
}
- def removeBlockManager(blockManagerId: BlockManagerId) {
+ private def removeBlockManager(blockManagerId: BlockManagerId) {
val info = blockManagerInfo(blockManagerId)
// Remove the block manager from blockManagerIdByExecutor.
@@ -106,7 +140,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
val iterator = info.blocks.keySet.iterator
while (iterator.hasNext) {
val blockId = iterator.next
- val locations = blockLocations.get(blockId)._2
+ val locations = blockLocations.get(blockId)
locations -= blockManagerId
if (locations.size == 0) {
blockLocations.remove(locations)
@@ -114,11 +148,11 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
}
}
- def expireDeadHosts() {
+ private def expireDeadHosts() {
logTrace("Checking for hosts with no recent heart beats in BlockManagerMaster.")
val now = System.currentTimeMillis()
val minSeenTime = now - slaveTimeout
- val toRemove = new HashSet[BlockManagerId]
+ val toRemove = new mutable.HashSet[BlockManagerId]
for (info <- blockManagerInfo.values) {
if (info.lastSeenMs < minSeenTime) {
logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats: " +
@@ -129,31 +163,26 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
toRemove.foreach(removeBlockManager)
}
- def removeExecutor(execId: String) {
+ private def removeExecutor(execId: String) {
logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.")
blockManagerIdByExecutor.get(execId).foreach(removeBlockManager)
- sender ! true
}
- def heartBeat(blockManagerId: BlockManagerId) {
+ private def heartBeat(blockManagerId: BlockManagerId): Boolean = {
if (!blockManagerInfo.contains(blockManagerId)) {
- if (blockManagerId.executorId == "<driver>" && !isLocal) {
- sender ! true
- } else {
- sender ! false
- }
+ blockManagerId.executorId == "<driver>" && !isLocal
} else {
blockManagerInfo(blockManagerId).updateLastSeenMs()
- sender ! true
+ true
}
}
// Remove a block from the slaves that have it. This can only be used to remove
// blocks that the master knows about.
- private def removeBlock(blockId: String) {
- val block = blockLocations.get(blockId)
- if (block != null) {
- block._2.foreach { blockManagerId: BlockManagerId =>
+ private def removeBlockFromWorkers(blockId: String) {
+ val locations = blockLocations.get(blockId)
+ if (locations != null) {
+ locations.foreach { blockManagerId: BlockManagerId =>
val blockManager = blockManagerInfo.get(blockManagerId)
if (blockManager.isDefined) {
// Remove the block from the slave's BlockManager.
@@ -163,23 +192,20 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
}
}
}
- sender ! true
}
// Return a map from the block manager id to max memory and remaining memory.
- private def getMemoryStatus() {
- val res = blockManagerInfo.map { case(blockManagerId, info) =>
+ private def memoryStatus: Map[BlockManagerId, (Long, Long)] = {
+ blockManagerInfo.map { case(blockManagerId, info) =>
(blockManagerId, (info.maxMem, info.remainingMem))
}.toMap
- sender ! res
}
- private def getStorageStatus() {
- val res = blockManagerInfo.map { case(blockManagerId, info) =>
+ private def storageStatus: Array[StorageStatus] = {
+ blockManagerInfo.map { case(blockManagerId, info) =>
import collection.JavaConverters._
StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala.toMap)
- }
- sender ! res
+ }.toArray
}
private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
@@ -188,7 +214,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
} else if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
case Some(manager) =>
- // A block manager of the same host name already exists
+ // A block manager of the same executor already exists.
+ // This should never happen. Let's just quit.
logError("Got two different block manager registrations on " + id.executorId)
System.exit(1)
case None =>
@@ -197,7 +224,6 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
blockManagerInfo(id) = new BlockManagerMasterActor.BlockManagerInfo(
id, System.currentTimeMillis(), maxMemSize, slaveActor)
}
- sender ! true
}
private def updateBlockInfo(
@@ -226,12 +252,12 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize)
- var locations: HashSet[BlockManagerId] = null
+ var locations: mutable.HashSet[BlockManagerId] = null
if (blockLocations.containsKey(blockId)) {
- locations = blockLocations.get(blockId)._2
+ locations = blockLocations.get(blockId)
} else {
- locations = new HashSet[BlockManagerId]
- blockLocations.put(blockId, (storageLevel.replication, locations))
+ locations = new mutable.HashSet[BlockManagerId]
+ blockLocations.put(blockId, locations)
}
if (storageLevel.isValid) {
@@ -247,70 +273,24 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
sender ! true
}
- private def getLocations(blockId: String) {
- val startTimeMs = System.currentTimeMillis()
- val tmp = " " + blockId + " "
- if (blockLocations.containsKey(blockId)) {
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- res.appendAll(blockLocations.get(blockId)._2)
- sender ! res.toSeq
- } else {
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- sender ! res
- }
+ private def getLocations(blockId: String): Seq[BlockManagerId] = {
+ if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty
}
- private def getLocationsMultipleBlockIds(blockIds: Array[String]) {
- def getLocations(blockId: String): Seq[BlockManagerId] = {
- val tmp = blockId
- if (blockLocations.containsKey(blockId)) {
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- res.appendAll(blockLocations.get(blockId)._2)
- return res.toSeq
- } else {
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- return res.toSeq
- }
- }
-
- var res: ArrayBuffer[Seq[BlockManagerId]] = new ArrayBuffer[Seq[BlockManagerId]]
- for (blockId <- blockIds) {
- res.append(getLocations(blockId))
- }
- sender ! res.toSeq
+ private def getLocationsMultipleBlockIds(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
+ blockIds.map(blockId => getLocations(blockId))
}
- private def getPeers(blockManagerId: BlockManagerId, size: Int) {
- var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- res.appendAll(peers)
- res -= blockManagerId
- val rand = new Random(System.currentTimeMillis())
- while (res.length > size) {
- res.remove(rand.nextInt(res.length))
- }
- sender ! res.toSeq
- }
-
- private def getPeersDeterministic(blockManagerId: BlockManagerId, size: Int) {
- var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+ private def getPeers(blockManagerId: BlockManagerId, size: Int): Seq[BlockManagerId] = {
+ val peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
val selfIndex = peers.indexOf(blockManagerId)
if (selfIndex == -1) {
- throw new Exception("Self index for " + blockManagerId + " not found")
+ throw new SparkException("Self index for " + blockManagerId + " not found")
}
// Note that this logic will select the same node multiple times if there aren't enough peers
- var index = selfIndex
- while (res.size < size) {
- index += 1
- if (index == selfIndex) {
- throw new Exception("More peer expected than available")
- }
- res += peers(index % peers.size)
- }
- sender ! res.toSeq
+ Array.tabulate[BlockManagerId](size) { i => peers((selfIndex + i + 1) % peers.length) }.toSeq
}
}
@@ -333,8 +313,8 @@ object BlockManagerMasterActor {
// Mapping from block id to its status.
private val _blocks = new JHashMap[String, BlockStatus]
- logInfo("Registering block manager %s:%d with %s RAM".format(
- blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem)))
+ logInfo("Registering block manager %s with %s RAM".format(
+ blockManagerId.hostPort, Utils.memoryBytesToString(maxMem)))
def updateLastSeenMs() {
_lastSeenMs = System.currentTimeMillis()
@@ -359,13 +339,13 @@ object BlockManagerMasterActor {
_blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize))
if (storageLevel.useMemory) {
_remainingMem -= memSize
- logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format(
- blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize),
+ logInfo("Added %s in memory on %s (size: %s, free: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.memoryBytesToString(memSize),
Utils.memoryBytesToString(_remainingMem)))
}
if (storageLevel.useDisk) {
- logInfo("Added %s on disk on %s:%d (size: %s)".format(
- blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize)))
+ logInfo("Added %s on disk on %s (size: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.memoryBytesToString(diskSize)))
}
} else if (_blocks.containsKey(blockId)) {
// If isValid is not true, drop the block.
@@ -373,17 +353,24 @@ object BlockManagerMasterActor {
_blocks.remove(blockId)
if (blockStatus.storageLevel.useMemory) {
_remainingMem += blockStatus.memSize
- logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format(
- blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize),
+ logInfo("Removed %s on %s in memory (size: %s, free: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.memoryBytesToString(memSize),
Utils.memoryBytesToString(_remainingMem)))
}
if (blockStatus.storageLevel.useDisk) {
- logInfo("Removed %s on %s:%d on disk (size: %s)".format(
- blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize)))
+ logInfo("Removed %s on %s on disk (size: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.memoryBytesToString(diskSize)))
}
}
}
+ def removeBlock(blockId: String) {
+ if (_blocks.containsKey(blockId)) {
+ _remainingMem += _blocks.get(blockId).memSize
+ _blocks.remove(blockId)
+ }
+ }
+
def remainingMem: Long = _remainingMem
def lastSeenMs: Long = _lastSeenMs
diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala
index cff48d9909..0010726c8d 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala
@@ -16,6 +16,9 @@ sealed trait ToBlockManagerSlave
private[spark]
case class RemoveBlock(blockId: String) extends ToBlockManagerSlave
+// Remove all blocks belonging to a specific RDD.
+private[spark] case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave
+
//////////////////////////////////////////////////////////////////////////////////
// Messages from slaves to the master.
diff --git a/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala
index f570cdc52d..b264d1deb5 100644
--- a/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala
@@ -11,6 +11,12 @@ import spark.{Logging, SparkException, Utils}
*/
class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor {
override def receive = {
- case RemoveBlock(blockId) => blockManager.removeBlock(blockId)
+
+ case RemoveBlock(blockId) =>
+ blockManager.removeBlock(blockId)
+
+ case RemoveRdd(rddId) =>
+ val numBlocksRemoved = blockManager.removeRdd(rddId)
+ sender ! numBlocksRemoved
}
}
diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala
index a3397a0fb4..631455abcd 100644
--- a/core/src/main/scala/spark/storage/BlockManagerUI.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala
@@ -1,10 +1,12 @@
package spark.storage
import akka.actor.{ActorRef, ActorSystem}
+
import akka.util.Timeout
import scala.concurrent.duration._
import spray.httpx.TwirlSupport._
import spray.routing.Directives
+
import spark.{Logging, SparkContext}
import spark.util.AkkaUtils
import spark.Utils
@@ -20,20 +22,21 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef,
implicit val implicitActorSystem = actorSystem
val STATIC_RESOURCE_DIR = "spark/deploy/static"
- implicit val timeout = Timeout(10 seconds)
+ implicit val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ val host = Utils.localHostName()
+ val port = if (System.getProperty("spark.ui.port") != null) {
+ System.getProperty("spark.ui.port").toInt
+ } else {
+ // TODO: Unfortunately, it's not possible to pass port 0 to spray and figure out which
+ // random port it bound to, so we have to try to find a local one by creating a socket.
+ Utils.findFreePort()
+ }
/** Start a HTTP server to run the Web interface */
def start() {
try {
- val port = if (System.getProperty("spark.ui.port") != null) {
- System.getProperty("spark.ui.port").toInt
- } else {
- // TODO: Unfortunately, it's not possible to pass port 0 to spray and figure out which
- // random port it bound to, so we have to try to find a local one by creating a socket.
- Utils.findFreePort()
- }
- AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, handler)
- logInfo("Started BlockManager web UI at http://%s:%d".format(Utils.localHostName(), port))
+ AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, handler, "BlockManagerHTTPServer")
+ logInfo("Started BlockManager web UI at http://%s:%d".format(host, port))
} catch {
case e: Exception =>
logError("Failed to create BlockManager WebUI", e)
@@ -74,4 +77,6 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef,
}
}
}
+
+ private[spark] def appUIAddress = "http://" + host + ":" + port
}
diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
index d2985559c1..3057ade233 100644
--- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
@@ -2,13 +2,7 @@ package spark.storage
import java.nio.ByteBuffer
-import scala.actors._
-import scala.actors.Actor._
-import scala.actors.remote._
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
-import scala.util.Random
-
-import spark.{Logging, Utils, SparkEnv}
+import spark.{Logging, Utils}
import spark.network._
/**
@@ -19,7 +13,7 @@ import spark.network._
*/
private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends Logging {
initLogging()
-
+
blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive)
def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = {
@@ -51,7 +45,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
logDebug("Received [" + pB + "]")
putBlock(pB.id, pB.data, pB.level)
return None
- }
+ }
case BlockMessage.TYPE_GET_BLOCK => {
val gB = new GetBlock(blockMessage.getId)
logDebug("Received [" + gB + "]")
@@ -88,30 +82,26 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
private[spark] object BlockManagerWorker extends Logging {
private var blockManagerWorker: BlockManagerWorker = null
- private val DATA_TRANSFER_TIME_OUT_MS: Long = 500
- private val REQUEST_RETRY_INTERVAL_MS: Long = 1000
-
+
initLogging()
-
+
def startBlockManagerWorker(manager: BlockManager) {
blockManagerWorker = new BlockManagerWorker(manager)
}
-
+
def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = {
val blockManager = blockManagerWorker.blockManager
- val connectionManager = blockManager.connectionManager
- val serializer = blockManager.serializer
+ val connectionManager = blockManager.connectionManager
val blockMessage = BlockMessage.fromPutBlock(msg)
val blockMessageArray = new BlockMessageArray(blockMessage)
val resultMessage = connectionManager.sendMessageReliablySync(
toConnManagerId, blockMessageArray.toBufferMessage)
return (resultMessage != None)
}
-
+
def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = {
val blockManager = blockManagerWorker.blockManager
- val connectionManager = blockManager.connectionManager
- val serializer = blockManager.serializer
+ val connectionManager = blockManager.connectionManager
val blockMessage = BlockMessage.fromGetBlock(msg)
val blockMessageArray = new BlockMessageArray(blockMessage)
val responseMessage = connectionManager.sendMessageReliablySync(
diff --git a/core/src/main/scala/spark/storage/BlockMessageArray.scala b/core/src/main/scala/spark/storage/BlockMessageArray.scala
index a25decb123..ee0c5ff9a2 100644
--- a/core/src/main/scala/spark/storage/BlockMessageArray.scala
+++ b/core/src/main/scala/spark/storage/BlockMessageArray.scala
@@ -115,6 +115,7 @@ private[spark] object BlockMessageArray {
val newBuffer = ByteBuffer.allocate(totalSize)
newBuffer.clear()
bufferMessage.buffers.foreach(buffer => {
+ assert (0 == buffer.position())
newBuffer.put(buffer)
buffer.rewind()
})
diff --git a/core/src/main/scala/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/spark/storage/BlockObjectWriter.scala
new file mode 100644
index 0000000000..42e2b07d5c
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockObjectWriter.scala
@@ -0,0 +1,50 @@
+package spark.storage
+
+import java.nio.ByteBuffer
+
+
+/**
+ * An interface for writing JVM objects to some underlying storage. This interface allows
+ * appending data to an existing block, and can guarantee atomicity in the case of faults
+ * as it allows the caller to revert partial writes.
+ *
+ * This interface does not support concurrent writes.
+ */
+abstract class BlockObjectWriter(val blockId: String) {
+
+ var closeEventHandler: () => Unit = _
+
+ def open(): BlockObjectWriter
+
+ def close() {
+ closeEventHandler()
+ }
+
+ def isOpen: Boolean
+
+ def registerCloseEventHandler(handler: () => Unit) {
+ closeEventHandler = handler
+ }
+
+ /**
+ * Flush the partial writes and commit them as a single atomic block. Return the
+ * number of bytes written for this commit.
+ */
+ def commit(): Long
+
+ /**
+ * Reverts writes that haven't been flushed yet. Callers should invoke this function
+ * when there are runtime exceptions.
+ */
+ def revertPartialWrites()
+
+ /**
+ * Writes an object.
+ */
+ def write(value: Any)
+
+ /**
+ * Size of the valid writes, in bytes.
+ */
+ def size(): Long
+}
diff --git a/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala b/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala
deleted file mode 100644
index f6c28dce52..0000000000
--- a/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala
+++ /dev/null
@@ -1,12 +0,0 @@
-package spark.storage
-
-private[spark] trait DelegateBlockFetchTracker extends BlockFetchTracker {
- var delegate : BlockFetchTracker = _
- def setDelegate(d: BlockFetchTracker) {delegate = d}
- def totalBlocks = delegate.totalBlocks
- def numLocalBlocks = delegate.numLocalBlocks
- def numRemoteBlocks = delegate.numRemoteBlocks
- def remoteFetchTime = delegate.remoteFetchTime
- def fetchWaitTime = delegate.fetchWaitTime
- def remoteBytesRead = delegate.remoteBytesRead
-}
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index ddbf8821ad..da859eebcb 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -1,41 +1,126 @@
package spark.storage
+import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile}
import java.nio.ByteBuffer
-import java.io.{File, FileOutputStream, RandomAccessFile}
+import java.nio.channels.FileChannel
import java.nio.channels.FileChannel.MapMode
import java.util.{Random, Date}
import java.text.SimpleDateFormat
-import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-
import scala.collection.mutable.ArrayBuffer
-import spark.executor.ExecutorExitCode
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import spark.Utils
+import spark.executor.ExecutorExitCode
+import spark.serializer.{Serializer, SerializationStream}
+import spark.Logging
+import spark.network.netty.ShuffleSender
+import spark.network.netty.PathResolver
+
/**
* Stores BlockManager blocks on disk.
*/
private class DiskStore(blockManager: BlockManager, rootDirs: String)
- extends BlockStore(blockManager) {
+ extends BlockStore(blockManager) with Logging {
+
+ class DiskBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int)
+ extends BlockObjectWriter(blockId) {
+
+ private val f: File = createFile(blockId /*, allowAppendExisting */)
+
+ // The file channel, used for repositioning / truncating the file.
+ private var channel: FileChannel = null
+ private var bs: OutputStream = null
+ private var objOut: SerializationStream = null
+ private var lastValidPosition = 0L
+ private var initialized = false
+
+ override def open(): DiskBlockObjectWriter = {
+ val fos = new FileOutputStream(f, true)
+ channel = fos.getChannel()
+ bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize))
+ objOut = serializer.newInstance().serializeStream(bs)
+ initialized = true
+ this
+ }
+
+ override def close() {
+ if (initialized) {
+ objOut.close()
+ bs.close()
+ channel = null
+ bs = null
+ objOut = null
+ }
+ // Invoke the close callback handler.
+ super.close()
+ }
- val MAX_DIR_CREATION_ATTEMPTS: Int = 10
- val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
+ override def isOpen: Boolean = objOut != null
+ // Flush the partial writes, and set valid length to be the length of the entire file.
+ // Return the number of bytes written for this commit.
+ override def commit(): Long = {
+ if (initialized) {
+ // NOTE: Flush the serializer first and then the compressed/buffered output stream
+ objOut.flush()
+ bs.flush()
+ val prevPos = lastValidPosition
+ lastValidPosition = channel.position()
+ lastValidPosition - prevPos
+ } else {
+ // lastValidPosition is zero if stream is uninitialized
+ lastValidPosition
+ }
+ }
+
+ override def revertPartialWrites() {
+ if (initialized) {
+ // Discard current writes. We do this by flushing the outstanding writes and
+ // truncate the file to the last valid position.
+ objOut.flush()
+ bs.flush()
+ channel.truncate(lastValidPosition)
+ }
+ }
+
+ override def write(value: Any) {
+ if (!initialized) {
+ open()
+ }
+ objOut.writeObject(value)
+ }
+
+ override def size(): Long = lastValidPosition
+ }
+
+ private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
+ private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
+
+ private var shuffleSender : ShuffleSender = null
// Create one local directory for each path mentioned in spark.local.dir; then, inside this
// directory, create multiple subdirectories that we will hash files into, in order to avoid
// having really large inodes at the top level.
- val localDirs = createLocalDirs()
- val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
+ private val localDirs: Array[File] = createLocalDirs()
+ private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
addShutdownHook()
+ def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
+ : BlockObjectWriter = {
+ new DiskBlockObjectWriter(blockId, serializer, bufferSize)
+ }
+
override def getSize(blockId: String): Long = {
getFile(blockId).length()
}
- override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) {
+ // So that we do not modify the input offsets !
+ // duplicate does not copy buffer, so inexpensive
+ val bytes = _bytes.duplicate()
logDebug("Attempting to put block " + blockId)
val startTime = System.currentTimeMillis
val file = createFile(blockId)
@@ -49,6 +134,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
blockId, Utils.memoryBytesToString(bytes.limit), (finishTime - startTime)))
}
+ private def getFileBytes(file: File): ByteBuffer = {
+ val length = file.length()
+ val channel = new RandomAccessFile(file, "r").getChannel()
+ val buffer = try {
+ channel.map(MapMode.READ_ONLY, 0, length)
+ } finally {
+ channel.close()
+ }
+
+ buffer
+ }
+
override def putValues(
blockId: String,
values: ArrayBuffer[Any],
@@ -61,18 +158,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
val file = createFile(blockId)
val fileOut = blockManager.wrapForCompression(blockId,
new FastBufferedOutputStream(new FileOutputStream(file)))
- val objOut = blockManager.serializer.newInstance().serializeStream(fileOut)
+ val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut)
objOut.writeAll(values.iterator)
objOut.close()
val length = file.length()
+
+ val timeTaken = System.currentTimeMillis - startTime
logDebug("Block %s stored as %s file on disk in %d ms".format(
- blockId, Utils.memoryBytesToString(length), (System.currentTimeMillis - startTime)))
+ blockId, Utils.memoryBytesToString(length), timeTaken))
if (returnValues) {
// Return a byte buffer for the contents of the file
- val channel = new RandomAccessFile(file, "r").getChannel()
- val buffer = channel.map(MapMode.READ_ONLY, 0, length)
- channel.close()
+ val buffer = getFileBytes(file)
PutResult(length, Right(buffer))
} else {
PutResult(length, null)
@@ -81,10 +178,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
override def getBytes(blockId: String): Option[ByteBuffer] = {
val file = getFile(blockId)
- val length = file.length().toInt
- val channel = new RandomAccessFile(file, "r").getChannel()
- val bytes = channel.map(MapMode.READ_ONLY, 0, length)
- channel.close()
+ val bytes = getFileBytes(file)
Some(bytes)
}
@@ -92,11 +186,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes))
}
+ /**
+ * A version of getValues that allows a custom serializer. This is used as part of the
+ * shuffle short-circuit code.
+ */
+ def getValues(blockId: String, serializer: Serializer): Option[Iterator[Any]] = {
+ getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer))
+ }
+
override def remove(blockId: String): Boolean = {
val file = getFile(blockId)
if (file.exists()) {
file.delete()
- true
} else {
false
}
@@ -106,10 +207,13 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
getFile(blockId).exists()
}
- private def createFile(blockId: String): File = {
+ private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = {
val file = getFile(blockId)
- if (file.exists()) {
- throw new Exception("File for block " + blockId + " already exists on disk: " + file)
+ if (!allowAppendExisting && file.exists()) {
+ // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task
+ // was rescheduled on the same machine as the old task.
+ logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting")
+ file.delete()
}
file
}
@@ -144,8 +248,8 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
private def createLocalDirs(): Array[File] = {
logDebug("Creating local directories at root dirs '" + rootDirs + "'")
val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
- rootDirs.split(",").map(rootDir => {
- var foundLocalDir: Boolean = false
+ rootDirs.split(",").map { rootDir =>
+ var foundLocalDir = false
var localDir: File = null
var localDirId: String = null
var tries = 0
@@ -156,12 +260,11 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
localDir = new File(rootDir, "spark-local-" + localDirId)
if (!localDir.exists) {
- localDir.mkdirs()
- foundLocalDir = true
+ foundLocalDir = localDir.mkdirs()
}
} catch {
case e: Exception =>
- logWarning("Attempt " + tries + " to create local dir failed", e)
+ logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
}
}
if (!foundLocalDir) {
@@ -171,19 +274,40 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
logInfo("Created local directory at " + localDir)
localDir
- })
+ }
}
private def addShutdownHook() {
+ localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
override def run() {
logDebug("Shutdown hook called")
- try {
- localDirs.foreach(localDir => Utils.deleteRecursively(localDir))
- } catch {
- case t: Throwable => logError("Exception while deleting local spark dirs", t)
+ localDirs.foreach { localDir =>
+ try {
+ if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
+ } catch {
+ case t: Throwable =>
+ logError("Exception while deleting local spark dir: " + localDir, t)
+ }
+ }
+ if (shuffleSender != null) {
+ shuffleSender.stop
}
}
})
}
+
+ private[storage] def startShuffleBlockSender(port: Int): Int = {
+ val pResolver = new PathResolver {
+ override def getAbsolutePath(blockId: String): String = {
+ if (!blockId.startsWith("shuffle_")) {
+ return null
+ }
+ DiskStore.this.getFile(blockId).getAbsolutePath()
+ }
+ }
+ shuffleSender = new ShuffleSender(port, pResolver)
+ logInfo("Created ShuffleSender binding to port : "+ shuffleSender.port)
+ shuffleSender.port
+ }
}
diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala
index 949588476c..eba5ee507f 100644
--- a/core/src/main/scala/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/spark/storage/MemoryStore.scala
@@ -31,7 +31,9 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
- override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) {
+ // Work on a duplicate - since the original input might be used elsewhere.
+ val bytes = _bytes.duplicate()
bytes.rewind()
if (level.deserialized) {
val values = blockManager.dataDeserialize(blockId, bytes)
diff --git a/core/src/main/scala/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
new file mode 100644
index 0000000000..44638e0c2d
--- /dev/null
+++ b/core/src/main/scala/spark/storage/ShuffleBlockManager.scala
@@ -0,0 +1,50 @@
+package spark.storage
+
+import spark.serializer.Serializer
+
+
+private[spark]
+class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter])
+
+
+private[spark]
+trait ShuffleBlocks {
+ def acquireWriters(mapId: Int): ShuffleWriterGroup
+ def releaseWriters(group: ShuffleWriterGroup)
+}
+
+
+private[spark]
+class ShuffleBlockManager(blockManager: BlockManager) {
+
+ def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = {
+ new ShuffleBlocks {
+ // Get a group of writers for a map task.
+ override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
+ val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
+ val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
+ blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
+ }
+ new ShuffleWriterGroup(mapId, writers)
+ }
+
+ override def releaseWriters(group: ShuffleWriterGroup) = {
+ // Nothing really to release here.
+ }
+ }
+ }
+}
+
+
+private[spark]
+object ShuffleBlockManager {
+
+ // Returns the block id for a given shuffle block.
+ def blockId(shuffleId: Int, bucketId: Int, groupId: Int): String = {
+ "shuffle_" + shuffleId + "_" + groupId + "_" + bucketId
+ }
+
+ // Returns true if the block is a shuffle block.
+ def isShuffle(blockId: String): Boolean = blockId.startsWith("shuffle_")
+}
diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala
index 3b5a77ab22..cc0c354e7e 100644
--- a/core/src/main/scala/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/spark/storage/StorageLevel.scala
@@ -123,11 +123,7 @@ object StorageLevel {
val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]()
private[spark] def getCachedStorageLevel(level: StorageLevel): StorageLevel = {
- if (storageLevelCache.containsKey(level)) {
- storageLevelCache.get(level)
- } else {
- storageLevelCache.put(level, level)
- level
- }
+ storageLevelCache.putIfAbsent(level, level)
+ storageLevelCache.get(level)
}
}
diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala
index dec47a9d41..950c0cdf35 100644
--- a/core/src/main/scala/spark/storage/StorageUtils.scala
+++ b/core/src/main/scala/spark/storage/StorageUtils.scala
@@ -4,9 +4,9 @@ import spark.{Utils, SparkContext}
import BlockManagerMasterActor.BlockStatus
private[spark]
-case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
+case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
blocks: Map[String, BlockStatus]) {
-
+
def memUsed(blockPrefix: String = "") = {
blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize).
reduceOption(_+_).getOrElse(0l)
@@ -22,53 +22,62 @@ case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
}
case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel,
- numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long) {
+ numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long)
+ extends Ordered[RDDInfo] {
override def toString = {
import Utils.memoryBytesToString
"RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; DiskSize: %s".format(name, id,
storageLevel.toString, numCachedPartitions, numPartitions, memoryBytesToString(memSize), memoryBytesToString(diskSize))
}
+
+ override def compare(that: RDDInfo) = {
+ this.id - that.id
+ }
}
/* Helper methods for storage-related objects */
private[spark]
object StorageUtils {
- /* Given the current storage status of the BlockManager, returns information for each RDD */
- def rddInfoFromStorageStatus(storageStatusList: Array[StorageStatus],
+ /* Given the current storage status of the BlockManager, returns information for each RDD */
+ def rddInfoFromStorageStatus(storageStatusList: Array[StorageStatus],
sc: SparkContext) : Array[RDDInfo] = {
- rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc)
+ rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc)
}
- /* Given a list of BlockStatus objets, returns information for each RDD */
- def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus],
+ /* Given a list of BlockStatus objets, returns information for each RDD */
+ def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus],
sc: SparkContext) : Array[RDDInfo] = {
// Group by rddId, ignore the partition name
- val groupedRddBlocks = infos.groupBy { case(k, v) =>
+ val groupedRddBlocks = infos.filterKeys(_.startsWith("rdd_")).groupBy { case(k, v) =>
k.substring(0,k.lastIndexOf('_'))
}.mapValues(_.values.toArray)
// For each RDD, generate an RDDInfo object
- groupedRddBlocks.map { case(rddKey, rddBlocks) =>
-
+ val rddInfos = groupedRddBlocks.map { case (rddKey, rddBlocks) =>
// Add up memory and disk sizes
val memSize = rddBlocks.map(_.memSize).reduce(_ + _)
val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _)
// Find the id of the RDD, e.g. rdd_1 => 1
val rddId = rddKey.split("_").last.toInt
- // Get the friendly name for the rdd, if available.
- val rdd = sc.persistentRdds(rddId)
- val rddName = Option(rdd.name).getOrElse(rddKey)
- val rddStorageLevel = rdd.getStorageLevel
- RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, rdd.partitions.size, memSize, diskSize)
- }.toArray
+ // Get the friendly name and storage level for the RDD, if available
+ sc.persistentRdds.get(rddId).map { r =>
+ val rddName = Option(r.name).getOrElse(rddKey)
+ val rddStorageLevel = r.getStorageLevel
+ RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size, memSize, diskSize)
+ }
+ }.flatten.toArray
+
+ scala.util.Sorting.quickSort(rddInfos)
+
+ rddInfos
}
- /* Removes all BlockStatus object that are not part of a block prefix */
- def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus],
+ /* Removes all BlockStatus object that are not part of a block prefix */
+ def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus],
prefix: String) : Array[StorageStatus] = {
storageStatusList.map { status =>
diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala
index e16915c8e9..ea39888c21 100644
--- a/core/src/main/scala/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/spark/util/AkkaUtils.scala
@@ -5,13 +5,15 @@ import com.typesafe.config.ConfigFactory
import scala.concurrent.duration._
import akka.pattern.ask
import akka.remote.RemoteActorRefProvider
+
import spray.routing.Route
import spray.io.IOExtension
import spray.routing.HttpServiceActor
import spray.can.server.{HttpServer, ServerSettings}
import spray.io.SingletonHandler
import scala.concurrent.Await
-import spark.SparkException
+import spark.{Utils, SparkException}
+
import java.util.concurrent.TimeoutException
/**
@@ -29,9 +31,14 @@ private[spark] object AkkaUtils {
def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = {
val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt
val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt
- val akkaTimeout = System.getProperty("spark.akka.timeout", "20").toInt
+
+ val akkaTimeout = System.getProperty("spark.akka.timeout", "60").toInt
+
val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt
- val lifecycleEvents = System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean
+ val lifecycleEvents = if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off"
+ // 10 seconds is the default akka timeout, but in a cluster, we need higher by default.
+ val akkaWriteTimeout = System.getProperty("spark.akka.writeTimeout", "30").toInt
+
val akkaConf = ConfigFactory.parseString("""
akka.daemonic = on
akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"]
@@ -45,10 +52,11 @@ private[spark] object AkkaUtils {
akka.remote.netty.execution-pool-size = %d
akka.actor.default-dispatcher.throughput = %d
akka.remote.log-remote-lifecycle-events = %s
+ akka.remote.netty.write-timeout = %ds
""".format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize,
- if (lifecycleEvents) "on" else "off"))
+ lifecycleEvents, akkaWriteTimeout))
- val actorSystem = ActorSystem(name, akkaConf, getClass.getClassLoader)
+ val actorSystem = ActorSystem(name, akkaConf)
// Figure out the port number we bound to, in case port was passed as 0. This is a bit of a
// hack because Akka doesn't let you figure out the port through the public API yet.
@@ -60,12 +68,13 @@ private[spark] object AkkaUtils {
/**
* Creates a Spray HTTP server bound to a given IP and port with a given Spray Route object to
* handle requests. Returns the bound port or throws a SparkException on failure.
+ * TODO: Not changing ip to host here - is it required ?
*/
- def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route) = {
+ def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route, name: String = "HttpServer") = {
val ioWorker = IOExtension(actorSystem).ioBridge()
val httpService = actorSystem.actorOf(Props(HttpServiceActor(route)))
val server = actorSystem.actorOf(
- Props(new HttpServer(ioWorker, SingletonHandler(httpService), ServerSettings())), name = "HttpServer")
+ Props(new HttpServer(ioWorker, SingletonHandler(httpService), ServerSettings())), name = name)
actorSystem.registerOnTermination { actorSystem.stop(ioWorker) }
val timeout = 3.seconds
val future = server.ask(HttpServer.Bind(ip, port))(timeout)
diff --git a/core/src/main/scala/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala
new file mode 100644
index 0000000000..4bc5db8bb7
--- /dev/null
+++ b/core/src/main/scala/spark/util/BoundedPriorityQueue.scala
@@ -0,0 +1,45 @@
+package spark.util
+
+import java.io.Serializable
+import java.util.{PriorityQueue => JPriorityQueue}
+import scala.collection.generic.Growable
+import scala.collection.JavaConverters._
+
+/**
+ * Bounded priority queue. This class wraps the original PriorityQueue
+ * class and modifies it such that only the top K elements are retained.
+ * The top K elements are defined by an implicit Ordering[A].
+ */
+class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A])
+ extends Iterable[A] with Growable[A] with Serializable {
+
+ private val underlying = new JPriorityQueue[A](maxSize, ord)
+
+ override def iterator: Iterator[A] = underlying.iterator.asScala
+
+ override def ++=(xs: TraversableOnce[A]): this.type = {
+ xs.foreach { this += _ }
+ this
+ }
+
+ override def +=(elem: A): this.type = {
+ if (size < maxSize) underlying.offer(elem)
+ else maybeReplaceLowest(elem)
+ this
+ }
+
+ override def +=(elem1: A, elem2: A, elems: A*): this.type = {
+ this += elem1 += elem2 ++= elems
+ }
+
+ override def clear() { underlying.clear() }
+
+ private def maybeReplaceLowest(a: A): Boolean = {
+ val head = underlying.peek()
+ if (head != null && ord.gt(a, head)) {
+ underlying.poll()
+ underlying.offer(a)
+ } else false
+ }
+}
+
diff --git a/core/src/main/scala/spark/util/StatCounter.scala b/core/src/main/scala/spark/util/StatCounter.scala
index 5f80180339..2b980340b7 100644
--- a/core/src/main/scala/spark/util/StatCounter.scala
+++ b/core/src/main/scala/spark/util/StatCounter.scala
@@ -37,17 +37,23 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
if (other == this) {
merge(other.copy()) // Avoid overwriting fields in a weird order
} else {
- val delta = other.mu - mu
- if (other.n * 10 < n) {
- mu = mu + (delta * other.n) / (n + other.n)
- } else if (n * 10 < other.n) {
- mu = other.mu - (delta * n) / (n + other.n)
- } else {
- mu = (mu * n + other.mu * other.n) / (n + other.n)
+ if (n == 0) {
+ mu = other.mu
+ m2 = other.m2
+ n = other.n
+ } else if (other.n != 0) {
+ val delta = other.mu - mu
+ if (other.n * 10 < n) {
+ mu = mu + (delta * other.n) / (n + other.n)
+ } else if (n * 10 < other.n) {
+ mu = other.mu - (delta * n) / (n + other.n)
+ } else {
+ mu = (mu * n + other.mu * other.n) / (n + other.n)
+ }
+ m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
+ n += other.n
}
- m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
- n += other.n
- this
+ this
}
}
diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala
index 4afba0a4c3..e95ca1fc8e 100644
--- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala
+++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala
@@ -3,6 +3,7 @@ package spark.util
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConversions
import scala.collection.mutable.Map
+import spark.scheduler.MapStatus
/**
* This is a custom implementation of scala.collection.mutable.Map which stores the insertion
@@ -42,6 +43,13 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging {
this
}
+ // Should we return previous value directly or as Option ?
+ def putIfAbsent(key: A, value: B): Option[B] = {
+ val prev = internalMap.putIfAbsent(key, (value, currentTime))
+ if (prev != null) Some(prev._1) else None
+ }
+
+
override def -= (key: A): this.type = {
internalMap.remove(key)
this
diff --git a/core/src/main/scala/spark/util/TimedIterator.scala b/core/src/main/scala/spark/util/TimedIterator.scala
deleted file mode 100644
index 539b01f4ce..0000000000
--- a/core/src/main/scala/spark/util/TimedIterator.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-package spark.util
-
-/**
- * A utility for tracking the total time an iterator takes to iterate through its elements.
- *
- * In general, this should only be used if you expect it to take a considerable amount of time
- * (eg. milliseconds) to get each element -- otherwise, the timing won't be very accurate,
- * and you are probably just adding more overhead
- */
-class TimedIterator[+A](val sub: Iterator[A]) extends Iterator[A] {
- private var netMillis = 0l
- private var nElems = 0
- def hasNext = {
- val start = System.currentTimeMillis()
- val r = sub.hasNext
- val end = System.currentTimeMillis()
- netMillis += (end - start)
- r
- }
- def next = {
- val start = System.currentTimeMillis()
- val r = sub.next
- val end = System.currentTimeMillis()
- netMillis += (end - start)
- nElems += 1
- r
- }
-
- def getNetMillis = netMillis
- def getAverageTimePerItem = netMillis / nElems.toDouble
-
-}
diff --git a/core/src/main/twirl/spark/deploy/master/app_details.scala.html b/core/src/main/twirl/spark/deploy/master/app_details.scala.html
index 301a7e2124..5e5e5de551 100644
--- a/core/src/main/twirl/spark/deploy/master/app_details.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/app_details.scala.html
@@ -9,19 +9,17 @@
<li><strong>ID:</strong> @app.id</li>
<li><strong>Description:</strong> @app.desc.name</li>
<li><strong>User:</strong> @app.desc.user</li>
- <li><strong>Cores:</strong>
- @app.desc.cores
- (@app.coresGranted Granted
- @if(app.desc.cores == Integer.MAX_VALUE) {
-
+ <li><strong>Cores:</strong>
+ @if(app.desc.maxCores == Integer.MAX_VALUE) {
+ Unlimited (@app.coresGranted granted)
} else {
- , @app.coresLeft
+ @app.desc.maxCores (@app.coresGranted granted, @app.coresLeft left)
}
- )
</li>
<li><strong>Memory per Slave:</strong> @app.desc.memoryPerSlave</li>
<li><strong>Submit Date:</strong> @app.submitDate</li>
<li><strong>State:</strong> @app.state</li>
+ <li><strong><a href="@app.appUiUrl">Application Detail UI</a></strong></li>
</ul>
</div>
</div>
diff --git a/core/src/main/twirl/spark/deploy/master/executor_row.scala.html b/core/src/main/twirl/spark/deploy/master/executor_row.scala.html
index d2d80fad48..21e72c7aab 100644
--- a/core/src/main/twirl/spark/deploy/master/executor_row.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/executor_row.scala.html
@@ -3,7 +3,7 @@
<tr>
<td>@executor.id</td>
<td>
- <a href="@executor.worker.webUiAddress">@executor.worker.id</href>
+ <a href="@executor.worker.webUiAddress">@executor.worker.id</a>
</td>
<td>@executor.cores</td>
<td>@executor.memory</td>
diff --git a/core/src/main/twirl/spark/deploy/master/index.scala.html b/core/src/main/twirl/spark/deploy/master/index.scala.html
index ac51a39a51..b9b9f08810 100644
--- a/core/src/main/twirl/spark/deploy/master/index.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/index.scala.html
@@ -2,7 +2,7 @@
@import spark.deploy.master._
@import spark.Utils
-@spark.common.html.layout(title = "Spark Master on " + state.host) {
+@spark.common.html.layout(title = "Spark Master on " + state.host + ":" + state.port) {
<!-- Cluster Details -->
<div class="row">
diff --git a/core/src/main/twirl/spark/deploy/master/worker_row.scala.html b/core/src/main/twirl/spark/deploy/master/worker_row.scala.html
index be69e9bf02..46277ca421 100644
--- a/core/src/main/twirl/spark/deploy/master/worker_row.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/worker_row.scala.html
@@ -4,7 +4,7 @@
<tr>
<td>
- <a href="@worker.webUiAddress">@worker.id</href>
+ <a href="@worker.webUiAddress">@worker.id</a>
</td>
<td>@{worker.host}:@{worker.port}</td>
<td>@worker.state</td>
diff --git a/core/src/main/twirl/spark/deploy/worker/index.scala.html b/core/src/main/twirl/spark/deploy/worker/index.scala.html
index c39f769a73..0e66af9284 100644
--- a/core/src/main/twirl/spark/deploy/worker/index.scala.html
+++ b/core/src/main/twirl/spark/deploy/worker/index.scala.html
@@ -1,7 +1,7 @@
@(worker: spark.deploy.WorkerState)
@import spark.Utils
-@spark.common.html.layout(title = "Spark Worker on " + worker.host) {
+@spark.common.html.layout(title = "Spark Worker on " + worker.host + ":" + worker.port) {
<!-- Worker Details -->
<div class="row">
diff --git a/core/src/main/twirl/spark/storage/worker_table.scala.html b/core/src/main/twirl/spark/storage/worker_table.scala.html
index d54b8de4cc..cd72a688c1 100644
--- a/core/src/main/twirl/spark/storage/worker_table.scala.html
+++ b/core/src/main/twirl/spark/storage/worker_table.scala.html
@@ -12,7 +12,7 @@
<tbody>
@for(status <- workersStatusList) {
<tr>
- <td>@(status.blockManagerId.ip + ":" + status.blockManagerId.port)</td>
+ <td>@(status.blockManagerId.host + ":" + status.blockManagerId.port)</td>
<td>
@(Utils.memoryBytesToString(status.memUsed(prefix)))
(@(Utils.memoryBytesToString(status.memRemaining)) Total Available)
diff --git a/core/src/test/resources/fairscheduler.xml b/core/src/test/resources/fairscheduler.xml
new file mode 100644
index 0000000000..5a688b0ebb
--- /dev/null
+++ b/core/src/test/resources/fairscheduler.xml
@@ -0,0 +1,14 @@
+<allocations>
+<pool name="1">
+ <minShare>2</minShare>
+ <weight>1</weight>
+ <schedulingMode>FIFO</schedulingMode>
+</pool>
+<pool name="2">
+ <minShare>3</minShare>
+ <weight>1</weight>
+ <schedulingMode>FIFO</schedulingMode>
+</pool>
+<pool name="3">
+</pool>
+</allocations>
diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala
index 8836c68ae6..6785787b7e 100644
--- a/core/src/test/scala/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/spark/CheckpointSuite.scala
@@ -28,6 +28,16 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
}
}
+ test("basic checkpointing") {
+ val parCollection = sc.makeRDD(1 to 4)
+ val flatMappedRDD = parCollection.flatMap(x => 1 to x)
+ flatMappedRDD.checkpoint()
+ assert(flatMappedRDD.dependencies.head.rdd == parCollection)
+ val result = flatMappedRDD.collect()
+ assert(flatMappedRDD.dependencies.head.rdd != parCollection)
+ assert(flatMappedRDD.collect() === result)
+ }
+
test("RDDs with one-to-one dependencies") {
testCheckpointing(_.map(x => x.toString))
testCheckpointing(_.flatMap(x => 1 to x))
diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala
index 46b74fe5ee..0866fb47b3 100644
--- a/core/src/test/scala/spark/DistributedSuite.scala
+++ b/core/src/test/scala/spark/DistributedSuite.scala
@@ -3,8 +3,10 @@ package spark
import network.ConnectionManagerId
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
+import org.scalatest.concurrent.Timeouts._
import org.scalatest.matchers.ShouldMatchers
import org.scalatest.prop.Checkers
+import org.scalatest.time.{Span, Millis}
import org.scalacheck.Arbitrary._
import org.scalacheck.Gen
import org.scalacheck.Prop._
@@ -16,7 +18,13 @@ import scala.collection.mutable.ArrayBuffer
import SparkContext._
import storage.{GetBlock, BlockManagerWorker, StorageLevel}
-class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter with LocalSparkContext {
+
+class NotSerializableClass
+class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {}
+
+
+class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
+ with LocalSparkContext {
val clusterUrl = "local-cluster[2,1,512]"
@@ -25,6 +33,24 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
System.clearProperty("spark.storage.memoryFraction")
}
+ test("task throws not serializable exception") {
+ // Ensures that executors do not crash when an exn is not serializable. If executors crash,
+ // this test will hang. Correct behavior is that executors don't crash but fail tasks
+ // and the scheduler throws a SparkException.
+
+ // numSlaves must be less than numPartitions
+ val numSlaves = 3
+ val numPartitions = 10
+
+ sc = new SparkContext("local-cluster[%s,1,512]".format(numSlaves), "test")
+ val data = sc.parallelize(1 to 100, numPartitions).
+ map(x => throw new NotSerializableExn(new NotSerializableClass))
+ intercept[SparkException] {
+ data.count()
+ }
+ resetSparkContext()
+ }
+
test("local-cluster format") {
sc = new SparkContext("local-cluster[2,1,512]", "test")
assert(sc.parallelize(1 to 2, 2).count() == 2)
@@ -153,7 +179,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
val blockManager = SparkEnv.get.blockManager
blockManager.master.getLocations(blockId).foreach(id => {
val bytes = BlockManagerWorker.syncGetBlock(
- GetBlock(blockId), ConnectionManagerId(id.ip, id.port))
+ GetBlock(blockId), ConnectionManagerId(id.host, id.port))
val deserialized = blockManager.dataDeserialize(blockId, bytes).asInstanceOf[Iterator[Int]].toList
assert(deserialized === (1 to 100).toList)
})
@@ -196,7 +222,6 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
sc = new SparkContext(clusterUrl, "test")
val data = sc.parallelize(Seq(true, true), 2)
assert(data.count === 2) // force executors to start
- val masterId = SparkEnv.get.blockManager.blockManagerId
assert(data.map(markNodeIfIdentity).collect.size === 2)
assert(data.map(failOnMarkedIdentity).collect.size === 2)
}
@@ -252,6 +277,42 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
assert(data2.count === 2)
}
}
+
+ test("unpersist RDDs") {
+ DistributedSuite.amMaster = true
+ sc = new SparkContext("local-cluster[3,1,512]", "test")
+ val data = sc.parallelize(Seq(true, false, false, false), 4)
+ data.persist(StorageLevel.MEMORY_ONLY_2)
+ data.count
+ assert(sc.persistentRdds.isEmpty === false)
+ data.unpersist()
+ assert(sc.persistentRdds.isEmpty === true)
+
+ failAfter(Span(3000, Millis)) {
+ try {
+ while (! sc.getRDDStorageInfo.isEmpty) {
+ Thread.sleep(200)
+ }
+ } catch {
+ case _ => { Thread.sleep(10) }
+ // Do nothing. We might see exceptions because block manager
+ // is racing this thread to remove entries from the driver.
+ }
+ }
+ }
+
+ test("job should fail if TaskResult exceeds Akka frame size") {
+ // We must use local-cluster mode since results are returned differently
+ // when running under LocalScheduler:
+ sc = new SparkContext("local-cluster[1,1,512]", "test")
+ val akkaFrameSize =
+ sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
+ val rdd = sc.parallelize(Seq(1)).map{x => new Array[Byte](akkaFrameSize)}
+ val exception = intercept[SparkException] {
+ rdd.reduce((x, y) => x)
+ }
+ exception.getMessage should endWith("result exceeded Akka frame size")
+ }
}
object DistributedSuite {
diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala
index 91b48c7456..e61ff7793d 100644
--- a/core/src/test/scala/spark/FileSuite.scala
+++ b/core/src/test/scala/spark/FileSuite.scala
@@ -7,6 +7,8 @@ import scala.io.Source
import com.google.common.io.Files
import org.scalatest.FunSuite
import org.apache.hadoop.io._
+import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodec, GzipCodec}
+
import SparkContext._
@@ -26,6 +28,28 @@ class FileSuite extends FunSuite with LocalSparkContext {
assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4"))
}
+ test("text files (compressed)") {
+ sc = new SparkContext("local", "test")
+ val tempDir = Files.createTempDir()
+ val normalDir = new File(tempDir, "output_normal").getAbsolutePath
+ val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath
+ val codec = new DefaultCodec()
+
+ val data = sc.parallelize("a" * 10000, 1)
+ data.saveAsTextFile(normalDir)
+ data.saveAsTextFile(compressedOutputDir, classOf[DefaultCodec])
+
+ val normalFile = new File(normalDir, "part-00000")
+ val normalContent = sc.textFile(normalDir).collect
+ assert(normalContent === Array.fill(10000)("a"))
+
+ val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension)
+ val compressedContent = sc.textFile(compressedOutputDir).collect
+ assert(compressedContent === Array.fill(10000)("a"))
+
+ assert(compressedFile.length < normalFile.length)
+ }
+
test("SequenceFiles") {
sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
@@ -37,6 +61,28 @@ class FileSuite extends FunSuite with LocalSparkContext {
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
}
+ test("SequenceFile (compressed)") {
+ sc = new SparkContext("local", "test")
+ val tempDir = Files.createTempDir()
+ val normalDir = new File(tempDir, "output_normal").getAbsolutePath
+ val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath
+ val codec = new DefaultCodec()
+
+ val data = sc.parallelize(Seq.fill(100)("abc"), 1).map(x => (x, x))
+ data.saveAsSequenceFile(normalDir)
+ data.saveAsSequenceFile(compressedOutputDir, Some(classOf[DefaultCodec]))
+
+ val normalFile = new File(normalDir, "part-00000")
+ val normalContent = sc.sequenceFile[String, String](normalDir).collect
+ assert(normalContent === Array.fill(100)("abc", "abc"))
+
+ val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension)
+ val compressedContent = sc.sequenceFile[String, String](compressedOutputDir).collect
+ assert(compressedContent === Array.fill(100)("abc", "abc"))
+
+ assert(compressedFile.length < normalFile.length)
+ }
+
test("SequenceFile with writable key") {
sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index d3dcd3bbeb..d306124fca 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -8,6 +8,7 @@ import java.util.*;
import scala.Tuple2;
import com.google.common.base.Charsets;
+import org.apache.hadoop.io.compress.DefaultCodec;
import com.google.common.io.Files;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
@@ -474,6 +475,19 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void textFilesCompressed() throws IOException {
+ File tempDir = Files.createTempDir();
+ String outputDir = new File(tempDir, "output").getAbsolutePath();
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4));
+ rdd.saveAsTextFile(outputDir, DefaultCodec.class);
+
+ // Try reading it in as a text file RDD
+ List<String> expected = Arrays.asList("1", "2", "3", "4");
+ JavaRDD<String> readRDD = sc.textFile(outputDir);
+ Assert.assertEquals(expected, readRDD.collect());
+ }
+
+ @Test
public void sequenceFile() {
File tempDir = Files.createTempDir();
String outputDir = new File(tempDir, "output").getAbsolutePath();
@@ -620,6 +634,37 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void hadoopFileCompressed() {
+ File tempDir = Files.createTempDir();
+ String outputDir = new File(tempDir, "output_compressed").getAbsolutePath();
+ List<Tuple2<Integer, String>> pairs = Arrays.asList(
+ new Tuple2<Integer, String>(1, "a"),
+ new Tuple2<Integer, String>(2, "aa"),
+ new Tuple2<Integer, String>(3, "aaa")
+ );
+ JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs);
+
+ rdd.map(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
+ @Override
+ public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> pair) {
+ return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2()));
+ }
+ }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class,
+ DefaultCodec.class);
+
+ JavaPairRDD<IntWritable, Text> output = sc.hadoopFile(outputDir,
+ SequenceFileInputFormat.class, IntWritable.class, Text.class);
+
+ Assert.assertEquals(pairs.toString(), output.map(new Function<Tuple2<IntWritable, Text>,
+ String>() {
+ @Override
+ public String call(Tuple2<IntWritable, Text> x) {
+ return x.toString();
+ }
+ }).collect().toString());
+ }
+
+ @Test
public void zip() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() {
@@ -633,6 +678,32 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void zipPartitions() {
+ JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2);
+ JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("1", "2", "3", "4"), 2);
+ FlatMapFunction2<Iterator<Integer>, Iterator<String>, Integer> sizesFn =
+ new FlatMapFunction2<Iterator<Integer>, Iterator<String>, Integer>() {
+ @Override
+ public Iterable<Integer> call(Iterator<Integer> i, Iterator<String> s) {
+ int sizeI = 0;
+ int sizeS = 0;
+ while (i.hasNext()) {
+ sizeI += 1;
+ i.next();
+ }
+ while (s.hasNext()) {
+ sizeS += 1;
+ s.next();
+ }
+ return Arrays.asList(sizeI, sizeS);
+ }
+ };
+
+ JavaRDD<Integer> sizes = rdd1.zipPartitions(sizesFn, rdd2);
+ Assert.assertEquals("[3, 2, 3, 2]", sizes.collect().toString());
+ }
+
+ @Test
public void accumulators() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
diff --git a/core/src/test/scala/spark/LocalSparkContext.scala b/core/src/test/scala/spark/LocalSparkContext.scala
index ff00dd05dd..76d5258b02 100644
--- a/core/src/test/scala/spark/LocalSparkContext.scala
+++ b/core/src/test/scala/spark/LocalSparkContext.scala
@@ -27,6 +27,7 @@ object LocalSparkContext {
sc.stop()
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
}
/** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */
@@ -38,4 +39,4 @@ object LocalSparkContext {
}
}
-} \ No newline at end of file
+}
diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
index 3abc584b6a..6e585e1c3a 100644
--- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
@@ -8,7 +8,7 @@ import spark.storage.BlockManagerId
import spark.util.AkkaUtils
class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
-
+
test("compressSize") {
assert(MapOutputTracker.compressSize(0L) === 0)
assert(MapOutputTracker.compressSize(1L) === 1)
@@ -45,13 +45,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
- tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000),
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
Array(compressedSize1000, compressedSize10000)))
- tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000),
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
Array(compressedSize10000, compressedSize1000)))
val statuses = tracker.getServerStatuses(10, 0)
- assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000),
- (BlockManagerId("b", "hostB", 1000), size10000)))
+ assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000, 0), size1000),
+ (BlockManagerId("b", "hostB", 1000, 0), size10000)))
tracker.stop()
}
@@ -64,14 +64,14 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
- tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000),
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
Array(compressedSize1000, compressedSize1000, compressedSize1000)))
- tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000),
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
Array(compressedSize10000, compressedSize1000, compressedSize1000)))
// As if we had two simulatenous fetch failures
- tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
- tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
+ tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
+ tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
// The remaining reduce task might try to grab the output despite the shuffle failure;
// this should cause it to fail, and the scheduler will ignore the failure due to the
@@ -80,16 +80,20 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
}
test("remote fetch") {
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", "localhost", 0)
+ val hostname = "localhost"
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0)
+ System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+
val masterTracker = new MapOutputTracker()
masterTracker.trackerActor = actorSystem.actorOf(
Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker")
-
- val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", "localhost", 0)
+
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0)
val slaveTracker = new MapOutputTracker()
slaveTracker.trackerActor = slaveSystem.actorFor(
"akka://spark@localhost:" + boundPort + "/user/MapOutputTracker")
-
+
masterTracker.registerShuffle(10, 1)
masterTracker.incrementGeneration()
slaveTracker.updateGeneration(masterTracker.getGeneration)
@@ -98,13 +102,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
masterTracker.registerMapOutput(10, 0, new MapStatus(
- BlockManagerId("a", "hostA", 1000), Array(compressedSize1000)))
+ BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000)))
masterTracker.incrementGeneration()
slaveTracker.updateGeneration(masterTracker.getGeneration)
assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
- Seq((BlockManagerId("a", "hostA", 1000), size1000)))
+ Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
- masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
+ masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
masterTracker.incrementGeneration()
slaveTracker.updateGeneration(masterTracker.getGeneration)
intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
diff --git a/core/src/test/scala/spark/PairRDDFunctionsSuite.scala b/core/src/test/scala/spark/PairRDDFunctionsSuite.scala
new file mode 100644
index 0000000000..682d2745bf
--- /dev/null
+++ b/core/src/test/scala/spark/PairRDDFunctionsSuite.scala
@@ -0,0 +1,287 @@
+package spark
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashSet
+
+import org.scalatest.FunSuite
+import org.scalatest.prop.Checkers
+import org.scalacheck.Arbitrary._
+import org.scalacheck.Gen
+import org.scalacheck.Prop._
+
+import com.google.common.io.Files
+
+import spark.rdd.ShuffledRDD
+import spark.SparkContext._
+
+class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
+ test("groupByKey") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
+ val groups = pairs.groupByKey().collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("groupByKey with duplicates") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val groups = pairs.groupByKey().collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("groupByKey with negative key hash codes") {
+ val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1)))
+ val groups = pairs.groupByKey().collect()
+ assert(groups.size === 2)
+ val valuesForMinus1 = groups.find(_._1 == -1).get._2
+ assert(valuesForMinus1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("groupByKey with many output partitions") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
+ val groups = pairs.groupByKey(10).collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("reduceByKey") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.reduceByKey(_+_).collect()
+ assert(sums.toSet === Set((1, 7), (2, 1)))
+ }
+
+ test("reduceByKey with collectAsMap") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.reduceByKey(_+_).collectAsMap()
+ assert(sums.size === 2)
+ assert(sums(1) === 7)
+ assert(sums(2) === 1)
+ }
+
+ test("reduceByKey with many output partitons") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.reduceByKey(_+_, 10).collect()
+ assert(sums.toSet === Set((1, 7), (2, 1)))
+ }
+
+ test("reduceByKey with partitioner") {
+ val p = new Partitioner() {
+ def numPartitions = 2
+ def getPartition(key: Any) = key.asInstanceOf[Int]
+ }
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p)
+ val sums = pairs.reduceByKey(_+_)
+ assert(sums.collect().toSet === Set((1, 4), (0, 1)))
+ assert(sums.partitioner === Some(p))
+ // count the dependencies to make sure there is only 1 ShuffledRDD
+ val deps = new HashSet[RDD[_]]()
+ def visit(r: RDD[_]) {
+ for (dep <- r.dependencies) {
+ deps += dep.rdd
+ visit(dep.rdd)
+ }
+ }
+ visit(sums)
+ assert(deps.size === 2) // ShuffledRDD, ParallelCollection
+ }
+
+ test("join") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.join(rdd2).collect()
+ assert(joined.size === 4)
+ assert(joined.toSet === Set(
+ (1, (1, 'x')),
+ (1, (2, 'x')),
+ (2, (1, 'y')),
+ (2, (1, 'z'))
+ ))
+ }
+
+ test("join all-to-all") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y')))
+ val joined = rdd1.join(rdd2).collect()
+ assert(joined.size === 6)
+ assert(joined.toSet === Set(
+ (1, (1, 'x')),
+ (1, (1, 'y')),
+ (1, (2, 'x')),
+ (1, (2, 'y')),
+ (1, (3, 'x')),
+ (1, (3, 'y'))
+ ))
+ }
+
+ test("leftOuterJoin") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.leftOuterJoin(rdd2).collect()
+ assert(joined.size === 5)
+ assert(joined.toSet === Set(
+ (1, (1, Some('x'))),
+ (1, (2, Some('x'))),
+ (2, (1, Some('y'))),
+ (2, (1, Some('z'))),
+ (3, (1, None))
+ ))
+ }
+
+ test("rightOuterJoin") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.rightOuterJoin(rdd2).collect()
+ assert(joined.size === 5)
+ assert(joined.toSet === Set(
+ (1, (Some(1), 'x')),
+ (1, (Some(2), 'x')),
+ (2, (Some(1), 'y')),
+ (2, (Some(1), 'z')),
+ (4, (None, 'w'))
+ ))
+ }
+
+ test("join with no matches") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
+ val joined = rdd1.join(rdd2).collect()
+ assert(joined.size === 0)
+ }
+
+ test("join with many output partitions") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.join(rdd2, 10).collect()
+ assert(joined.size === 4)
+ assert(joined.toSet === Set(
+ (1, (1, 'x')),
+ (1, (2, 'x')),
+ (2, (1, 'y')),
+ (2, (1, 'z'))
+ ))
+ }
+
+ test("groupWith") {
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.groupWith(rdd2).collect()
+ assert(joined.size === 4)
+ assert(joined.toSet === Set(
+ (1, (ArrayBuffer(1, 2), ArrayBuffer('x'))),
+ (2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))),
+ (3, (ArrayBuffer(1), ArrayBuffer())),
+ (4, (ArrayBuffer(), ArrayBuffer('w')))
+ ))
+ }
+
+ test("zero-partition RDD") {
+ val emptyDir = Files.createTempDir()
+ val file = sc.textFile(emptyDir.getAbsolutePath)
+ assert(file.partitions.size == 0)
+ assert(file.collect().toList === Nil)
+ // Test that a shuffle on the file works, because this used to be a bug
+ assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
+ }
+
+ test("keys and values") {
+ val rdd = sc.parallelize(Array((1, "a"), (2, "b")))
+ assert(rdd.keys.collect().toList === List(1, 2))
+ assert(rdd.values.collect().toList === List("a", "b"))
+ }
+
+ test("default partitioner uses partition size") {
+ // specify 2000 partitions
+ val a = sc.makeRDD(Array(1, 2, 3, 4), 2000)
+ // do a map, which loses the partitioner
+ val b = a.map(a => (a, (a * 2).toString))
+ // then a group by, and see we didn't revert to 2 partitions
+ val c = b.groupByKey()
+ assert(c.partitions.size === 2000)
+ }
+
+ test("default partitioner uses largest partitioner") {
+ val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2)
+ val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000)
+ val c = a.join(b)
+ assert(c.partitions.size === 2000)
+ }
+
+ test("subtract") {
+ val a = sc.parallelize(Array(1, 2, 3), 2)
+ val b = sc.parallelize(Array(2, 3, 4), 4)
+ val c = a.subtract(b)
+ assert(c.collect().toSet === Set(1))
+ assert(c.partitions.size === a.partitions.size)
+ }
+
+ test("subtract with narrow dependency") {
+ // use a deterministic partitioner
+ val p = new Partitioner() {
+ def numPartitions = 5
+ def getPartition(key: Any) = key.asInstanceOf[Int]
+ }
+ // partitionBy so we have a narrow dependency
+ val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
+ // more partitions/no partitioner so a shuffle dependency
+ val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
+ val c = a.subtract(b)
+ assert(c.collect().toSet === Set((1, "a"), (3, "c")))
+ // Ideally we could keep the original partitioner...
+ assert(c.partitioner === None)
+ }
+
+ test("subtractByKey") {
+ val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2)
+ val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4)
+ val c = a.subtractByKey(b)
+ assert(c.collect().toSet === Set((1, "a"), (1, "a")))
+ assert(c.partitions.size === a.partitions.size)
+ }
+
+ test("subtractByKey with narrow dependency") {
+ // use a deterministic partitioner
+ val p = new Partitioner() {
+ def numPartitions = 5
+ def getPartition(key: Any) = key.asInstanceOf[Int]
+ }
+ // partitionBy so we have a narrow dependency
+ val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
+ // more partitions/no partitioner so a shuffle dependency
+ val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
+ val c = a.subtractByKey(b)
+ assert(c.collect().toSet === Set((1, "a"), (1, "a")))
+ assert(c.partitioner.get === p)
+ }
+
+ test("foldByKey") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.foldByKey(0)(_+_).collect()
+ assert(sums.toSet === Set((1, 7), (2, 1)))
+ }
+
+ test("foldByKey with mutable result type") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val bufs = pairs.mapValues(v => ArrayBuffer(v)).cache()
+ // Fold the values using in-place mutation
+ val sums = bufs.foldByKey(new ArrayBuffer[Int])(_ ++= _).collect()
+ assert(sums.toSet === Set((1, ArrayBuffer(1, 2, 3, 1)), (2, ArrayBuffer(1))))
+ // Check that the mutable objects in the original RDD were not changed
+ assert(bufs.collect().toSet === Set(
+ (1, ArrayBuffer(1)),
+ (1, ArrayBuffer(2)),
+ (1, ArrayBuffer(3)),
+ (1, ArrayBuffer(1)),
+ (2, ArrayBuffer(1))))
+ }
+}
diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala
index 60db759c25..99e433e3bd 100644
--- a/core/src/test/scala/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/spark/PartitioningSuite.scala
@@ -1,13 +1,13 @@
package spark
import org.scalatest.FunSuite
-
import scala.collection.mutable.ArrayBuffer
-
import SparkContext._
+import spark.util.StatCounter
+import scala.math.abs
+
+class PartitioningSuite extends FunSuite with SharedSparkContext {
-class PartitioningSuite extends FunSuite with LocalSparkContext {
-
test("HashPartitioner equality") {
val p2 = new HashPartitioner(2)
val p4 = new HashPartitioner(4)
@@ -21,8 +21,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
}
test("RangePartitioner equality") {
- sc = new SparkContext("local", "test")
-
// Make an RDD where all the elements are the same so that the partition range bounds
// are deterministically all the same.
val rdd = sc.parallelize(Seq(1, 1, 1, 1)).map(x => (x, x))
@@ -50,7 +48,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
}
test("HashPartitioner not equal to RangePartitioner") {
- sc = new SparkContext("local", "test")
val rdd = sc.parallelize(1 to 10).map(x => (x, x))
val rangeP2 = new RangePartitioner(2, rdd)
val hashP2 = new HashPartitioner(2)
@@ -61,8 +58,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
}
test("partitioner preservation") {
- sc = new SparkContext("local", "test")
-
val rdd = sc.parallelize(1 to 10, 4).map(x => (x, x))
val grouped2 = rdd.groupByKey(2)
@@ -101,7 +96,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
}
test("partitioning Java arrays should fail") {
- sc = new SparkContext("local", "test")
val arrs: RDD[Array[Int]] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => Array(x))
val arrPairs: RDD[(Array[Int], Int)] =
sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x))
@@ -120,4 +114,20 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array"))
assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array"))
}
+
+ test("zero-length partitions should be correctly handled") {
+ // Create RDD with some consecutive empty partitions (including the "first" one)
+ val rdd: RDD[Double] = sc
+ .parallelize(Array(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8)
+ .filter(_ >= 0.0)
+
+ // Run the partitions, including the consecutive empty ones, through StatCounter
+ val stats: StatCounter = rdd.stats();
+ assert(abs(6.0 - stats.sum) < 0.01);
+ assert(abs(6.0/2 - rdd.mean) < 0.01);
+ assert(abs(1.0 - rdd.variance) < 0.01);
+ assert(abs(1.0 - rdd.stdev) < 0.01);
+
+ // Add other tests here for classes that should be able to handle empty partitions correctly
+ }
}
diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala
index a6344edf8f..1c9ca50811 100644
--- a/core/src/test/scala/spark/PipedRDDSuite.scala
+++ b/core/src/test/scala/spark/PipedRDDSuite.scala
@@ -3,10 +3,9 @@ package spark
import org.scalatest.FunSuite
import SparkContext._
-class PipedRDDSuite extends FunSuite with LocalSparkContext {
-
+class PipedRDDSuite extends FunSuite with SharedSparkContext {
+
test("basic pipe") {
- sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe(Seq("cat"))
@@ -19,8 +18,45 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext {
assert(c(3) === "4")
}
+ test("advanced pipe") {
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ val bl = sc.broadcast(List("0"))
+
+ val piped = nums.pipe(Seq("cat"),
+ Map[String, String](),
+ (f: String => Unit) => {bl.value.map(f(_));f("\u0001")},
+ (i:Int, f: String=> Unit) => f(i + "_"))
+
+ val c = piped.collect()
+
+ assert(c.size === 8)
+ assert(c(0) === "0")
+ assert(c(1) === "\u0001")
+ assert(c(2) === "1_")
+ assert(c(3) === "2_")
+ assert(c(4) === "0")
+ assert(c(5) === "\u0001")
+ assert(c(6) === "3_")
+ assert(c(7) === "4_")
+
+ val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2)
+ val d = nums1.groupBy(str=>str.split("\t")(0)).
+ pipe(Seq("cat"),
+ Map[String, String](),
+ (f: String => Unit) => {bl.value.map(f(_));f("\u0001")},
+ (i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}).collect()
+ assert(d.size === 8)
+ assert(d(0) === "0")
+ assert(d(1) === "\u0001")
+ assert(d(2) === "b\t2_")
+ assert(d(3) === "b\t4_")
+ assert(d(4) === "0")
+ assert(d(5) === "\u0001")
+ assert(d(6) === "a\t1_")
+ assert(d(7) === "a\t3_")
+ }
+
test("pipe with env variable") {
- sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
val c = piped.collect()
@@ -30,7 +66,6 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext {
}
test("pipe with non-zero exit status") {
- sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe("cat nonexistent_file")
intercept[SparkException] {
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 7fbdd44340..d8db69b1c9 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -2,13 +2,14 @@ package spark
import scala.collection.mutable.HashMap
import org.scalatest.FunSuite
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.time.{Span, Millis}
import spark.SparkContext._
-import spark.rdd.{CoalescedRDD, CoGroupedRDD, PartitionPruningRDD, ShuffledRDD}
+import spark.rdd.{CoalescedRDD, CoGroupedRDD, EmptyRDD, PartitionPruningRDD, ShuffledRDD}
-class RDDSuite extends FunSuite with LocalSparkContext {
+class RDDSuite extends FunSuite with SharedSparkContext {
test("basic operations") {
- sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(nums.collect().toList === List(1, 2, 3, 4))
val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
@@ -44,7 +45,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("SparkContext.union") {
- sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(sc.union(nums).collect().toList === List(1, 2, 3, 4))
assert(sc.union(nums, nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4))
@@ -53,7 +53,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("aggregate") {
- sc = new SparkContext("local", "test")
val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3)))
type StringMap = HashMap[String, Int]
val emptyMap = new StringMap {
@@ -73,27 +72,7 @@ class RDDSuite extends FunSuite with LocalSparkContext {
assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
}
- test("basic checkpointing") {
- import java.io.File
- val checkpointDir = File.createTempFile("temp", "")
- checkpointDir.delete()
-
- sc = new SparkContext("local", "test")
- sc.setCheckpointDir(checkpointDir.toString)
- val parCollection = sc.makeRDD(1 to 4)
- val flatMappedRDD = parCollection.flatMap(x => 1 to x)
- flatMappedRDD.checkpoint()
- assert(flatMappedRDD.dependencies.head.rdd == parCollection)
- val result = flatMappedRDD.collect()
- Thread.sleep(1000)
- assert(flatMappedRDD.dependencies.head.rdd != parCollection)
- assert(flatMappedRDD.collect() === result)
-
- checkpointDir.deleteOnExit()
- }
-
test("basic caching") {
- sc = new SparkContext("local", "test")
val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
assert(rdd.collect().toList === List(1, 2, 3, 4))
assert(rdd.collect().toList === List(1, 2, 3, 4))
@@ -101,7 +80,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("caching with failures") {
- sc = new SparkContext("local", "test")
val onlySplit = new Partition { override def index: Int = 0 }
var shouldFail = true
val rdd = new RDD[Int](sc, Nil) {
@@ -123,38 +101,26 @@ class RDDSuite extends FunSuite with LocalSparkContext {
assert(rdd.collect().toList === List(1, 2, 3, 4))
}
- test("cogrouped RDDs") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.makeRDD(Array((1, "one"), (1, "another one"), (2, "two"), (3, "three")), 2)
- val rdd2 = sc.makeRDD(Array((1, "one1"), (1, "another one1"), (2, "two1")), 2)
-
- // Use cogroup function
- val cogrouped = rdd1.cogroup(rdd2).collectAsMap()
- assert(cogrouped(1) === (Seq("one", "another one"), Seq("one1", "another one1")))
- assert(cogrouped(2) === (Seq("two"), Seq("two1")))
- assert(cogrouped(3) === (Seq("three"), Seq()))
-
- // Construct CoGroupedRDD directly, with map side combine enabled
- val cogrouped1 = new CoGroupedRDD[Int](
- Seq(rdd1.asInstanceOf[RDD[(Int, Any)]], rdd2.asInstanceOf[RDD[(Int, Any)]]),
- new HashPartitioner(3),
- true).collectAsMap()
- assert(cogrouped1(1).toSeq === Seq(Seq("one", "another one"), Seq("one1", "another one1")))
- assert(cogrouped1(2).toSeq === Seq(Seq("two"), Seq("two1")))
- assert(cogrouped1(3).toSeq === Seq(Seq("three"), Seq()))
+ test("empty RDD") {
+ val empty = new EmptyRDD[Int](sc)
+ assert(empty.count === 0)
+ assert(empty.collect().size === 0)
- // Construct CoGroupedRDD directly, with map side combine disabled
- val cogrouped2 = new CoGroupedRDD[Int](
- Seq(rdd1.asInstanceOf[RDD[(Int, Any)]], rdd2.asInstanceOf[RDD[(Int, Any)]]),
- new HashPartitioner(3),
- false).collectAsMap()
- assert(cogrouped2(1).toSeq === Seq(Seq("one", "another one"), Seq("one1", "another one1")))
- assert(cogrouped2(2).toSeq === Seq(Seq("two"), Seq("two1")))
- assert(cogrouped2(3).toSeq === Seq(Seq("three"), Seq()))
+ val thrown = intercept[UnsupportedOperationException]{
+ empty.reduce(_+_)
+ }
+ assert(thrown.getMessage.contains("empty"))
+
+ val emptyKv = new EmptyRDD[(Int, Int)](sc)
+ val rdd = sc.parallelize(1 to 2, 2).map(x => (x, x))
+ assert(rdd.join(emptyKv).collect().size === 0)
+ assert(rdd.rightOuterJoin(emptyKv).collect().size === 0)
+ assert(rdd.leftOuterJoin(emptyKv).collect().size === 2)
+ assert(rdd.cogroup(emptyKv).collect().size === 2)
+ assert(rdd.union(emptyKv).collect().size === 2)
}
- test("coalesced RDDs") {
- sc = new SparkContext("local", "test")
+ test("cogrouped RDDs") {
val data = sc.parallelize(1 to 10, 10)
val coalesced1 = data.coalesce(2)
@@ -192,7 +158,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("zipped RDDs") {
- sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val zipped = nums.zip(nums.map(_ + 1.0))
assert(zipped.glom().map(_.toList).collect().toList ===
@@ -204,7 +169,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("partition pruning") {
- sc = new SparkContext("local", "test")
val data = sc.parallelize(1 to 10, 10)
// Note that split number starts from 0, so > 8 means only 10th partition left.
val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8)
@@ -216,7 +180,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
test("mapWith") {
import java.util.Random
- sc = new SparkContext("local", "test")
val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
val randoms = ones.mapWith(
(index: Int) => new Random(index + 42))
@@ -235,7 +198,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
test("flatMapWith") {
import java.util.Random
- sc = new SparkContext("local", "test")
val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
val randoms = ones.flatMapWith(
(index: Int) => new Random(index + 42))
@@ -257,7 +219,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
test("filterWith") {
import java.util.Random
- sc = new SparkContext("local", "test")
val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
val sample = ints.filterWith(
(index: Int) => new Random(index + 42))
@@ -273,4 +234,21 @@ class RDDSuite extends FunSuite with LocalSparkContext {
assert(sample.size === checkSample.size)
for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
}
+
+ test("top with predefined ordering") {
+ val nums = Array.range(1, 100000)
+ val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2)
+ val topK = ints.top(5)
+ assert(topK.size === 5)
+ assert(topK.sorted === nums.sorted.takeRight(5))
+ }
+
+ test("top with custom ordering") {
+ val words = Vector("a", "b", "c", "d")
+ implicit val ord = implicitly[Ordering[String]].reverse
+ val rdd = sc.makeRDD(words, 2)
+ val topK = rdd.top(2)
+ assert(topK.size === 2)
+ assert(topK.sorted === Array("b", "a"))
+ }
}
diff --git a/core/src/test/scala/spark/SharedSparkContext.scala b/core/src/test/scala/spark/SharedSparkContext.scala
new file mode 100644
index 0000000000..1da79f9824
--- /dev/null
+++ b/core/src/test/scala/spark/SharedSparkContext.scala
@@ -0,0 +1,25 @@
+package spark
+
+import org.scalatest.Suite
+import org.scalatest.BeforeAndAfterAll
+
+/** Shares a local `SparkContext` between all tests in a suite and closes it at the end */
+trait SharedSparkContext extends BeforeAndAfterAll { self: Suite =>
+
+ @transient private var _sc: SparkContext = _
+
+ def sc: SparkContext = _sc
+
+ override def beforeAll() {
+ _sc = new SparkContext("local", "test")
+ super.beforeAll()
+ }
+
+ override def afterAll() {
+ if (_sc != null) {
+ LocalSparkContext.stop(_sc)
+ _sc = null
+ }
+ super.afterAll()
+ }
+}
diff --git a/core/src/test/scala/spark/ShuffleNettySuite.scala b/core/src/test/scala/spark/ShuffleNettySuite.scala
new file mode 100644
index 0000000000..bfaffa953e
--- /dev/null
+++ b/core/src/test/scala/spark/ShuffleNettySuite.scala
@@ -0,0 +1,17 @@
+package spark
+
+import org.scalatest.BeforeAndAfterAll
+
+
+class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll {
+
+ // This test suite should run all tests in ShuffleSuite with Netty shuffle mode.
+
+ override def beforeAll(configMap: Map[String, Any]) {
+ System.setProperty("spark.shuffle.use.netty", "true")
+ }
+
+ override def afterAll(configMap: Map[String, Any]) {
+ System.setProperty("spark.shuffle.use.netty", "false")
+ }
+}
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 2b2a90defa..950218fa28 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -16,54 +16,9 @@ import spark.rdd.ShuffledRDD
import spark.SparkContext._
class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
-
- test("groupByKey") {
- sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
- val groups = pairs.groupByKey().collect()
- assert(groups.size === 2)
- val valuesFor1 = groups.find(_._1 == 1).get._2
- assert(valuesFor1.toList.sorted === List(1, 2, 3))
- val valuesFor2 = groups.find(_._1 == 2).get._2
- assert(valuesFor2.toList.sorted === List(1))
- }
-
- test("groupByKey with duplicates") {
- sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
- val groups = pairs.groupByKey().collect()
- assert(groups.size === 2)
- val valuesFor1 = groups.find(_._1 == 1).get._2
- assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
- val valuesFor2 = groups.find(_._1 == 2).get._2
- assert(valuesFor2.toList.sorted === List(1))
- }
-
- test("groupByKey with negative key hash codes") {
- sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1)))
- val groups = pairs.groupByKey().collect()
- assert(groups.size === 2)
- val valuesForMinus1 = groups.find(_._1 == -1).get._2
- assert(valuesForMinus1.toList.sorted === List(1, 2, 3))
- val valuesFor2 = groups.find(_._1 == 2).get._2
- assert(valuesFor2.toList.sorted === List(1))
- }
-
- test("groupByKey with many output partitions") {
- sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
- val groups = pairs.groupByKey(10).collect()
- assert(groups.size === 2)
- val valuesFor1 = groups.find(_._1 == 1).get._2
- assert(valuesFor1.toList.sorted === List(1, 2, 3))
- val valuesFor2 = groups.find(_._1 == 2).get._2
- assert(valuesFor2.toList.sorted === List(1))
- }
-
test("groupByKey with compression") {
try {
- System.setProperty("spark.blockManager.compress", "true")
+ System.setProperty("spark.shuffle.compress", "true")
sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 4)
val groups = pairs.groupByKey(4).collect()
@@ -77,239 +32,100 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
}
}
- test("reduceByKey") {
- sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
- val sums = pairs.reduceByKey(_+_).collect()
- assert(sums.toSet === Set((1, 7), (2, 1)))
- }
-
- test("reduceByKey with collectAsMap") {
- sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
- val sums = pairs.reduceByKey(_+_).collectAsMap()
- assert(sums.size === 2)
- assert(sums(1) === 7)
- assert(sums(2) === 1)
- }
+ test("shuffle non-zero block size") {
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+ val NUM_BLOCKS = 3
- test("reduceByKey with many output partitons") {
- sc = new SparkContext("local", "test")
- val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
- val sums = pairs.reduceByKey(_+_, 10).collect()
- assert(sums.toSet === Set((1, 7), (2, 1)))
- }
-
- test("reduceByKey with partitioner") {
- sc = new SparkContext("local", "test")
- val p = new Partitioner() {
- def numPartitions = 2
- def getPartition(key: Any) = key.asInstanceOf[Int]
+ val a = sc.parallelize(1 to 10, 2)
+ val b = a.map { x =>
+ (x, new ShuffleSuite.NonJavaSerializableClass(x * 2))
}
- val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p)
- val sums = pairs.reduceByKey(_+_)
- assert(sums.collect().toSet === Set((1, 4), (0, 1)))
- assert(sums.partitioner === Some(p))
- // count the dependencies to make sure there is only 1 ShuffledRDD
- val deps = new HashSet[RDD[_]]()
- def visit(r: RDD[_]) {
- for (dep <- r.dependencies) {
- deps += dep.rdd
- visit(dep.rdd)
- }
+ // If the Kryo serializer is not used correctly, the shuffle would fail because the
+ // default Java serializer cannot handle the non serializable class.
+ val c = new ShuffledRDD(b, new HashPartitioner(NUM_BLOCKS),
+ classOf[spark.KryoSerializer].getName)
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+
+ assert(c.count === 10)
+
+ // All blocks must have non-zero size
+ (0 until NUM_BLOCKS).foreach { id =>
+ val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
+ assert(statuses.forall(s => s._2 > 0))
}
- visit(sums)
- assert(deps.size === 2) // ShuffledRDD, ParallelCollection
}
- test("join") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
- val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
- val joined = rdd1.join(rdd2).collect()
- assert(joined.size === 4)
- assert(joined.toSet === Set(
- (1, (1, 'x')),
- (1, (2, 'x')),
- (2, (1, 'y')),
- (2, (1, 'z'))
- ))
+ test("shuffle serializer") {
+ // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+ val a = sc.parallelize(1 to 10, 2)
+ val b = a.map { x =>
+ (x, new ShuffleSuite.NonJavaSerializableClass(x * 2))
+ }
+ // If the Kryo serializer is not used correctly, the shuffle would fail because the
+ // default Java serializer cannot handle the non serializable class.
+ val c = new ShuffledRDD(b, new HashPartitioner(3), classOf[spark.KryoSerializer].getName)
+ assert(c.count === 10)
}
- test("join all-to-all") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3)))
- val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y')))
- val joined = rdd1.join(rdd2).collect()
- assert(joined.size === 6)
- assert(joined.toSet === Set(
- (1, (1, 'x')),
- (1, (1, 'y')),
- (1, (2, 'x')),
- (1, (2, 'y')),
- (1, (3, 'x')),
- (1, (3, 'y'))
- ))
- }
+ test("zero sized blocks") {
+ // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
- test("leftOuterJoin") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
- val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
- val joined = rdd1.leftOuterJoin(rdd2).collect()
- assert(joined.size === 5)
- assert(joined.toSet === Set(
- (1, (1, Some('x'))),
- (1, (2, Some('x'))),
- (2, (1, Some('y'))),
- (2, (1, Some('z'))),
- (3, (1, None))
- ))
- }
+ // 10 partitions from 4 keys
+ val NUM_BLOCKS = 10
+ val a = sc.parallelize(1 to 4, NUM_BLOCKS)
+ val b = a.map(x => (x, x*2))
- test("rightOuterJoin") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
- val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
- val joined = rdd1.rightOuterJoin(rdd2).collect()
- assert(joined.size === 5)
- assert(joined.toSet === Set(
- (1, (Some(1), 'x')),
- (1, (Some(2), 'x')),
- (2, (Some(1), 'y')),
- (2, (Some(1), 'z')),
- (4, (None, 'w'))
- ))
- }
+ // NOTE: The default Java serializer doesn't create zero-sized blocks.
+ // So, use Kryo
+ val c = new ShuffledRDD(b, new HashPartitioner(10), classOf[spark.KryoSerializer].getName)
- test("join with no matches") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
- val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
- val joined = rdd1.join(rdd2).collect()
- assert(joined.size === 0)
- }
-
- test("join with many output partitions") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
- val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
- val joined = rdd1.join(rdd2, 10).collect()
- assert(joined.size === 4)
- assert(joined.toSet === Set(
- (1, (1, 'x')),
- (1, (2, 'x')),
- (2, (1, 'y')),
- (2, (1, 'z'))
- ))
- }
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ assert(c.count === 4)
- test("groupWith") {
- sc = new SparkContext("local", "test")
- val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
- val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
- val joined = rdd1.groupWith(rdd2).collect()
- assert(joined.size === 4)
- assert(joined.toSet === Set(
- (1, (ArrayBuffer(1, 2), ArrayBuffer('x'))),
- (2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))),
- (3, (ArrayBuffer(1), ArrayBuffer())),
- (4, (ArrayBuffer(), ArrayBuffer('w')))
- ))
- }
+ val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
+ val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
+ statuses.map(x => x._2)
+ }
+ val nonEmptyBlocks = blockSizes.filter(x => x > 0)
- test("zero-partition RDD") {
- sc = new SparkContext("local", "test")
- val emptyDir = Files.createTempDir()
- val file = sc.textFile(emptyDir.getAbsolutePath)
- assert(file.partitions.size == 0)
- assert(file.collect().toList === Nil)
- // Test that a shuffle on the file works, because this used to be a bug
- assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
+ // We should have at most 4 non-zero sized partitions
+ assert(nonEmptyBlocks.size <= 4)
}
- test("keys and values") {
- sc = new SparkContext("local", "test")
- val rdd = sc.parallelize(Array((1, "a"), (2, "b")))
- assert(rdd.keys.collect().toList === List(1, 2))
- assert(rdd.values.collect().toList === List("a", "b"))
- }
+ test("zero sized blocks without kryo") {
+ // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
- test("default partitioner uses partition size") {
- sc = new SparkContext("local", "test")
- // specify 2000 partitions
- val a = sc.makeRDD(Array(1, 2, 3, 4), 2000)
- // do a map, which loses the partitioner
- val b = a.map(a => (a, (a * 2).toString))
- // then a group by, and see we didn't revert to 2 partitions
- val c = b.groupByKey()
- assert(c.partitions.size === 2000)
- }
+ // 10 partitions from 4 keys
+ val NUM_BLOCKS = 10
+ val a = sc.parallelize(1 to 4, NUM_BLOCKS)
+ val b = a.map(x => (x, x*2))
- test("default partitioner uses largest partitioner") {
- sc = new SparkContext("local", "test")
- val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2)
- val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000)
- val c = a.join(b)
- assert(c.partitions.size === 2000)
- }
+ // NOTE: The default Java serializer should create zero-sized blocks
+ val c = new ShuffledRDD(b, new HashPartitioner(10))
- test("subtract") {
- sc = new SparkContext("local", "test")
- val a = sc.parallelize(Array(1, 2, 3), 2)
- val b = sc.parallelize(Array(2, 3, 4), 4)
- val c = a.subtract(b)
- assert(c.collect().toSet === Set(1))
- assert(c.partitions.size === a.partitions.size)
- }
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ assert(c.count === 4)
- test("subtract with narrow dependency") {
- sc = new SparkContext("local", "test")
- // use a deterministic partitioner
- val p = new Partitioner() {
- def numPartitions = 5
- def getPartition(key: Any) = key.asInstanceOf[Int]
+ val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
+ val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
+ statuses.map(x => x._2)
}
- // partitionBy so we have a narrow dependency
- val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
- // more partitions/no partitioner so a shuffle dependency
- val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
- val c = a.subtract(b)
- assert(c.collect().toSet === Set((1, "a"), (3, "c")))
- // Ideally we could keep the original partitioner...
- assert(c.partitioner === None)
- }
-
- test("subtractByKey") {
- sc = new SparkContext("local", "test")
- val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2)
- val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4)
- val c = a.subtractByKey(b)
- assert(c.collect().toSet === Set((1, "a"), (1, "a")))
- assert(c.partitions.size === a.partitions.size)
- }
+ val nonEmptyBlocks = blockSizes.filter(x => x > 0)
- test("subtractByKey with narrow dependency") {
- sc = new SparkContext("local", "test")
- // use a deterministic partitioner
- val p = new Partitioner() {
- def numPartitions = 5
- def getPartition(key: Any) = key.asInstanceOf[Int]
- }
- // partitionBy so we have a narrow dependency
- val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
- // more partitions/no partitioner so a shuffle dependency
- val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
- val c = a.subtractByKey(b)
- assert(c.collect().toSet === Set((1, "a"), (1, "a")))
- assert(c.partitioner.get === p)
+ // We should have at most 4 non-zero sized partitions
+ assert(nonEmptyBlocks.size <= 4)
}
-
}
object ShuffleSuite {
+
def mergeCombineException(x: Int, y: Int): Int = {
throw new SparkException("Exception for map-side combine.")
x + y
}
+
+ class NonJavaSerializableClass(val value: Int)
}
diff --git a/core/src/test/scala/spark/SizeEstimatorSuite.scala b/core/src/test/scala/spark/SizeEstimatorSuite.scala
index 9f3aa6628d..c385965c35 100644
--- a/core/src/test/scala/spark/SizeEstimatorSuite.scala
+++ b/core/src/test/scala/spark/SizeEstimatorSuite.scala
@@ -78,7 +78,6 @@ class SizeEstimatorSuite
// Arrays containing nulls should just have one pointer per element
expectResult(56)(SizeEstimator.estimate(new Array[String](10)))
expectResult(56)(SizeEstimator.estimate(new Array[AnyRef](10)))
-
// For object arrays with non-null elements, each object should take one pointer plus
// however many bytes that class takes. (Note that Array.fill calls the code in its
// second parameter separately for each object, so we get distinct objects.)
@@ -115,7 +114,6 @@ class SizeEstimatorSuite
expectResult(48)(SizeEstimator.estimate(DummyString("a")))
expectResult(48)(SizeEstimator.estimate(DummyString("ab")))
expectResult(56)(SizeEstimator.estimate(DummyString("abcdefgh")))
-
resetOrClear("os.arch", arch)
}
diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala
index 495f957e53..f7bf207c68 100644
--- a/core/src/test/scala/spark/SortingSuite.scala
+++ b/core/src/test/scala/spark/SortingSuite.scala
@@ -5,16 +5,14 @@ import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.ShouldMatchers
import SparkContext._
-class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers with Logging {
-
+class SortingSuite extends FunSuite with SharedSparkContext with ShouldMatchers with Logging {
+
test("sortByKey") {
- sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2)
- assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
+ assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
}
test("large array") {
- sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
@@ -24,7 +22,6 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
}
test("large array with one split") {
- sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
@@ -32,9 +29,8 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
assert(sorted.partitions.size === 1)
assert(sorted.collect() === pairArr.sortBy(_._1))
}
-
+
test("large array with many partitions") {
- sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
@@ -42,9 +38,8 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
assert(sorted.partitions.size === 20)
assert(sorted.collect() === pairArr.sortBy(_._1))
}
-
+
test("sort descending") {
- sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
@@ -52,15 +47,13 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
}
test("sort descending with one split") {
- sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 1)
assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
}
-
+
test("sort descending with many partitions") {
- sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
@@ -68,7 +61,6 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
}
test("more partitions than elements") {
- sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 30)
@@ -76,14 +68,12 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
}
test("empty RDD") {
- sc = new SparkContext("local", "test")
val pairArr = new Array[(Int, Int)](0)
val pairs = sc.parallelize(pairArr, 2)
assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
}
test("partition balancing") {
- sc = new SparkContext("local", "test")
val pairArr = (1 to 1000).map(x => (x, x)).toArray
val sorted = sc.parallelize(pairArr, 4).sortByKey()
assert(sorted.collect() === pairArr.sortBy(_._1))
@@ -99,7 +89,6 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
}
test("partition balancing for descending sort") {
- sc = new SparkContext("local", "test")
val pairArr = (1 to 1000).map(x => (x, x)).toArray
val sorted = sc.parallelize(pairArr, 4).sortByKey(false)
assert(sorted.collect() === pairArr.sortBy(_._1).reverse)
diff --git a/core/src/test/scala/spark/UnpersistSuite.scala b/core/src/test/scala/spark/UnpersistSuite.scala
new file mode 100644
index 0000000000..94776e7572
--- /dev/null
+++ b/core/src/test/scala/spark/UnpersistSuite.scala
@@ -0,0 +1,30 @@
+package spark
+
+import org.scalatest.FunSuite
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.time.{Span, Millis}
+import spark.SparkContext._
+
+class UnpersistSuite extends FunSuite with LocalSparkContext {
+ test("unpersist RDD") {
+ sc = new SparkContext("local", "test")
+ val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
+ rdd.count
+ assert(sc.persistentRdds.isEmpty === false)
+ rdd.unpersist()
+ assert(sc.persistentRdds.isEmpty === true)
+
+ failAfter(Span(3000, Millis)) {
+ try {
+ while (! sc.getRDDStorageInfo.isEmpty) {
+ Thread.sleep(200)
+ }
+ } catch {
+ case _ => { Thread.sleep(10) }
+ // Do nothing. We might see exceptions because block manager
+ // is racing this thread to remove entries from the driver.
+ }
+ }
+ assert(sc.getRDDStorageInfo.isEmpty === true)
+ }
+}
diff --git a/core/src/test/scala/spark/UtilsSuite.scala b/core/src/test/scala/spark/UtilsSuite.scala
index ed4701574f..4a113e16bf 100644
--- a/core/src/test/scala/spark/UtilsSuite.scala
+++ b/core/src/test/scala/spark/UtilsSuite.scala
@@ -27,24 +27,49 @@ class UtilsSuite extends FunSuite {
assert(os.toByteArray.toList.equals(bytes.toList))
}
- test("memoryStringToMb"){
- assert(Utils.memoryStringToMb("1") == 0)
- assert(Utils.memoryStringToMb("1048575") == 0)
- assert(Utils.memoryStringToMb("3145728") == 3)
+ test("memoryStringToMb") {
+ assert(Utils.memoryStringToMb("1") === 0)
+ assert(Utils.memoryStringToMb("1048575") === 0)
+ assert(Utils.memoryStringToMb("3145728") === 3)
- assert(Utils.memoryStringToMb("1024k") == 1)
- assert(Utils.memoryStringToMb("5000k") == 4)
- assert(Utils.memoryStringToMb("4024k") == Utils.memoryStringToMb("4024K"))
+ assert(Utils.memoryStringToMb("1024k") === 1)
+ assert(Utils.memoryStringToMb("5000k") === 4)
+ assert(Utils.memoryStringToMb("4024k") === Utils.memoryStringToMb("4024K"))
- assert(Utils.memoryStringToMb("1024m") == 1024)
- assert(Utils.memoryStringToMb("5000m") == 5000)
- assert(Utils.memoryStringToMb("4024m") == Utils.memoryStringToMb("4024M"))
+ assert(Utils.memoryStringToMb("1024m") === 1024)
+ assert(Utils.memoryStringToMb("5000m") === 5000)
+ assert(Utils.memoryStringToMb("4024m") === Utils.memoryStringToMb("4024M"))
- assert(Utils.memoryStringToMb("2g") == 2048)
- assert(Utils.memoryStringToMb("3g") == Utils.memoryStringToMb("3G"))
+ assert(Utils.memoryStringToMb("2g") === 2048)
+ assert(Utils.memoryStringToMb("3g") === Utils.memoryStringToMb("3G"))
- assert(Utils.memoryStringToMb("2t") == 2097152)
- assert(Utils.memoryStringToMb("3t") == Utils.memoryStringToMb("3T"))
+ assert(Utils.memoryStringToMb("2t") === 2097152)
+ assert(Utils.memoryStringToMb("3t") === Utils.memoryStringToMb("3T"))
+ }
+
+ test("splitCommandString") {
+ assert(Utils.splitCommandString("") === Seq())
+ assert(Utils.splitCommandString("a") === Seq("a"))
+ assert(Utils.splitCommandString("aaa") === Seq("aaa"))
+ assert(Utils.splitCommandString("a b c") === Seq("a", "b", "c"))
+ assert(Utils.splitCommandString(" a b\t c ") === Seq("a", "b", "c"))
+ assert(Utils.splitCommandString("a 'b c'") === Seq("a", "b c"))
+ assert(Utils.splitCommandString("a 'b c' d") === Seq("a", "b c", "d"))
+ assert(Utils.splitCommandString("'b c'") === Seq("b c"))
+ assert(Utils.splitCommandString("a \"b c\"") === Seq("a", "b c"))
+ assert(Utils.splitCommandString("a \"b c\" d") === Seq("a", "b c", "d"))
+ assert(Utils.splitCommandString("\"b c\"") === Seq("b c"))
+ assert(Utils.splitCommandString("a 'b\" c' \"d' e\"") === Seq("a", "b\" c", "d' e"))
+ assert(Utils.splitCommandString("a\t'b\nc'\nd") === Seq("a", "b\nc", "d"))
+ assert(Utils.splitCommandString("a \"b\\\\c\"") === Seq("a", "b\\c"))
+ assert(Utils.splitCommandString("a \"b\\\"c\"") === Seq("a", "b\"c"))
+ assert(Utils.splitCommandString("a 'b\\\"c'") === Seq("a", "b\\\"c"))
+ assert(Utils.splitCommandString("'a'b") === Seq("ab"))
+ assert(Utils.splitCommandString("'a''b'") === Seq("ab"))
+ assert(Utils.splitCommandString("\"a\"b") === Seq("ab"))
+ assert(Utils.splitCommandString("\"a\"\"b\"") === Seq("ab"))
+ assert(Utils.splitCommandString("''") === Seq(""))
+ assert(Utils.splitCommandString("\"\"") === Seq(""))
}
}
diff --git a/core/src/test/scala/spark/ZippedPartitionsSuite.scala b/core/src/test/scala/spark/ZippedPartitionsSuite.scala
new file mode 100644
index 0000000000..96cb295f45
--- /dev/null
+++ b/core/src/test/scala/spark/ZippedPartitionsSuite.scala
@@ -0,0 +1,33 @@
+package spark
+
+import scala.collection.immutable.NumericRange
+
+import org.scalatest.FunSuite
+import org.scalatest.prop.Checkers
+import org.scalacheck.Arbitrary._
+import org.scalacheck.Gen
+import org.scalacheck.Prop._
+
+import SparkContext._
+
+
+object ZippedPartitionsSuite {
+ def procZippedData(i: Iterator[Int], s: Iterator[String], d: Iterator[Double]) : Iterator[Int] = {
+ Iterator(i.toArray.size, s.toArray.size, d.toArray.size)
+ }
+}
+
+class ZippedPartitionsSuite extends FunSuite with SharedSparkContext {
+ test("print sizes") {
+ val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2)
+ val data3 = sc.makeRDD(Array(1.0, 2.0), 2)
+
+ val zippedRDD = data1.zipPartitions(ZippedPartitionsSuite.procZippedData, data2, data3)
+
+ val obtainedSizes = zippedRDD.collect()
+ val expectedSizes = Array(2, 3, 1, 2, 3, 1)
+ assert(obtainedSizes.size == 6)
+ assert(obtainedSizes.zip(expectedSizes).forall(x => x._1 == x._2))
+ }
+}
diff --git a/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala
new file mode 100644
index 0000000000..6afb0fa9bc
--- /dev/null
+++ b/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala
@@ -0,0 +1,56 @@
+package spark
+
+import org.scalatest.{ BeforeAndAfter, FunSuite }
+import spark.SparkContext._
+import spark.rdd.JdbcRDD
+import java.sql._
+
+class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
+
+ before {
+ Class.forName("org.apache.derby.jdbc.EmbeddedDriver")
+ val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true")
+ try {
+ val create = conn.createStatement
+ create.execute("""
+ CREATE TABLE FOO(
+ ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1),
+ DATA INTEGER
+ )""")
+ create.close
+ val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)")
+ (1 to 100).foreach { i =>
+ insert.setInt(1, i * 2)
+ insert.executeUpdate
+ }
+ insert.close
+ } catch {
+ case e: SQLException if e.getSQLState == "X0Y32" =>
+ // table exists
+ } finally {
+ conn.close
+ }
+ }
+
+ test("basic functionality") {
+ sc = new SparkContext("local", "test")
+ val rdd = new JdbcRDD(
+ sc,
+ () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") },
+ "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?",
+ 1, 100, 3,
+ (r: ResultSet) => { r.getInt(1) } ).cache
+
+ assert(rdd.count === 100)
+ assert(rdd.reduce(_+_) === 10100)
+ }
+
+ after {
+ try {
+ DriverManager.getConnection("jdbc:derby:;shutdown=true")
+ } catch {
+ case se: SQLException if se.getSQLState == "XJ015" =>
+ // normal shutdown
+ }
+ }
+}
diff --git a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala
new file mode 100644
index 0000000000..8e1ad27e14
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala
@@ -0,0 +1,250 @@
+package spark.scheduler
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+
+import spark._
+import spark.scheduler._
+import spark.scheduler.cluster._
+import scala.collection.mutable.ArrayBuffer
+
+import java.util.Properties
+
+class DummyTaskSetManager(
+ initPriority: Int,
+ initStageId: Int,
+ initNumTasks: Int,
+ clusterScheduler: ClusterScheduler,
+ taskSet: TaskSet)
+ extends ClusterTaskSetManager(clusterScheduler,taskSet) {
+
+ parent = null
+ weight = 1
+ minShare = 2
+ runningTasks = 0
+ priority = initPriority
+ stageId = initStageId
+ name = "TaskSet_"+stageId
+ override val numTasks = initNumTasks
+ tasksFinished = 0
+
+ override def increaseRunningTasks(taskNum: Int) {
+ runningTasks += taskNum
+ if (parent != null) {
+ parent.increaseRunningTasks(taskNum)
+ }
+ }
+
+ override def decreaseRunningTasks(taskNum: Int) {
+ runningTasks -= taskNum
+ if (parent != null) {
+ parent.decreaseRunningTasks(taskNum)
+ }
+ }
+
+ override def addSchedulable(schedulable: Schedulable) {
+ }
+
+ override def removeSchedulable(schedulable: Schedulable) {
+ }
+
+ override def getSchedulableByName(name: String): Schedulable = {
+ return null
+ }
+
+ override def executorLost(executorId: String, host: String): Unit = {
+ }
+
+ override def slaveOffer(execId: String, host: String, avaiableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
+ if (tasksFinished + runningTasks < numTasks) {
+ increaseRunningTasks(1)
+ return Some(new TaskDescription(0, execId, "task 0:0", null))
+ }
+ return None
+ }
+
+ override def checkSpeculatableTasks(): Boolean = {
+ return true
+ }
+
+ def taskFinished() {
+ decreaseRunningTasks(1)
+ tasksFinished +=1
+ if (tasksFinished == numTasks) {
+ parent.removeSchedulable(this)
+ }
+ }
+
+ def abort() {
+ decreaseRunningTasks(runningTasks)
+ parent.removeSchedulable(this)
+ }
+}
+
+class DummyTask(stageId: Int) extends Task[Int](stageId)
+{
+ def run(attemptId: Long): Int = {
+ return 0
+ }
+}
+
+class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging {
+
+ def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): DummyTaskSetManager = {
+ new DummyTaskSetManager(priority, stage, numTasks, cs , taskSet)
+ }
+
+ def resourceOffer(rootPool: Pool): Int = {
+ val taskSetQueue = rootPool.getSortedTaskSetQueue()
+ /* Just for Test*/
+ for (manager <- taskSetQueue) {
+ logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks))
+ }
+ for (taskSet <- taskSetQueue) {
+ taskSet.slaveOffer("execId_1", "hostname_1", 1) match {
+ case Some(task) =>
+ return taskSet.stageId
+ case None => {}
+ }
+ }
+ -1
+ }
+
+ def checkTaskSetId(rootPool: Pool, expectedTaskSetId: Int) {
+ assert(resourceOffer(rootPool) === expectedTaskSetId)
+ }
+
+ test("FIFO Scheduler Test") {
+ sc = new SparkContext("local", "ClusterSchedulerSuite")
+ val clusterScheduler = new ClusterScheduler(sc)
+ var tasks = ArrayBuffer[Task[_]]()
+ val task = new DummyTask(0)
+ tasks += task
+ val taskSet = new TaskSet(tasks.toArray,0,0,0,null)
+
+ val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0)
+ val schedulableBuilder = new FIFOSchedulableBuilder(rootPool)
+ schedulableBuilder.buildPools()
+
+ val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, clusterScheduler, taskSet)
+ val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, clusterScheduler, taskSet)
+ val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, clusterScheduler, taskSet)
+ schedulableBuilder.addTaskSetManager(taskSetManager0, null)
+ schedulableBuilder.addTaskSetManager(taskSetManager1, null)
+ schedulableBuilder.addTaskSetManager(taskSetManager2, null)
+
+ checkTaskSetId(rootPool, 0)
+ resourceOffer(rootPool)
+ checkTaskSetId(rootPool, 1)
+ resourceOffer(rootPool)
+ taskSetManager1.abort()
+ checkTaskSetId(rootPool, 2)
+ }
+
+ test("Fair Scheduler Test") {
+ sc = new SparkContext("local", "ClusterSchedulerSuite")
+ val clusterScheduler = new ClusterScheduler(sc)
+ var tasks = ArrayBuffer[Task[_]]()
+ val task = new DummyTask(0)
+ tasks += task
+ val taskSet = new TaskSet(tasks.toArray,0,0,0,null)
+
+ val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
+ System.setProperty("spark.fairscheduler.allocation.file", xmlPath)
+ val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
+ val schedulableBuilder = new FairSchedulableBuilder(rootPool)
+ schedulableBuilder.buildPools()
+
+ assert(rootPool.getSchedulableByName("default") != null)
+ assert(rootPool.getSchedulableByName("1") != null)
+ assert(rootPool.getSchedulableByName("2") != null)
+ assert(rootPool.getSchedulableByName("3") != null)
+ assert(rootPool.getSchedulableByName("1").minShare === 2)
+ assert(rootPool.getSchedulableByName("1").weight === 1)
+ assert(rootPool.getSchedulableByName("2").minShare === 3)
+ assert(rootPool.getSchedulableByName("2").weight === 1)
+ assert(rootPool.getSchedulableByName("3").minShare === 2)
+ assert(rootPool.getSchedulableByName("3").weight === 1)
+
+ val properties1 = new Properties()
+ properties1.setProperty("spark.scheduler.cluster.fair.pool","1")
+ val properties2 = new Properties()
+ properties2.setProperty("spark.scheduler.cluster.fair.pool","2")
+
+ val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, clusterScheduler, taskSet)
+ val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, clusterScheduler, taskSet)
+ val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, clusterScheduler, taskSet)
+ schedulableBuilder.addTaskSetManager(taskSetManager10, properties1)
+ schedulableBuilder.addTaskSetManager(taskSetManager11, properties1)
+ schedulableBuilder.addTaskSetManager(taskSetManager12, properties1)
+
+ val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, clusterScheduler, taskSet)
+ val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, clusterScheduler, taskSet)
+ schedulableBuilder.addTaskSetManager(taskSetManager23, properties2)
+ schedulableBuilder.addTaskSetManager(taskSetManager24, properties2)
+
+ checkTaskSetId(rootPool, 0)
+ checkTaskSetId(rootPool, 3)
+ checkTaskSetId(rootPool, 3)
+ checkTaskSetId(rootPool, 1)
+ checkTaskSetId(rootPool, 4)
+ checkTaskSetId(rootPool, 2)
+ checkTaskSetId(rootPool, 2)
+ checkTaskSetId(rootPool, 4)
+
+ taskSetManager12.taskFinished()
+ assert(rootPool.getSchedulableByName("1").runningTasks === 3)
+ taskSetManager24.abort()
+ assert(rootPool.getSchedulableByName("2").runningTasks === 2)
+ }
+
+ test("Nested Pool Test") {
+ sc = new SparkContext("local", "ClusterSchedulerSuite")
+ val clusterScheduler = new ClusterScheduler(sc)
+ var tasks = ArrayBuffer[Task[_]]()
+ val task = new DummyTask(0)
+ tasks += task
+ val taskSet = new TaskSet(tasks.toArray,0,0,0,null)
+
+ val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
+ val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1)
+ val pool1 = new Pool("1", SchedulingMode.FAIR, 4, 1)
+ rootPool.addSchedulable(pool0)
+ rootPool.addSchedulable(pool1)
+
+ val pool00 = new Pool("00", SchedulingMode.FAIR, 2, 2)
+ val pool01 = new Pool("01", SchedulingMode.FAIR, 1, 1)
+ pool0.addSchedulable(pool00)
+ pool0.addSchedulable(pool01)
+
+ val pool10 = new Pool("10", SchedulingMode.FAIR, 2, 2)
+ val pool11 = new Pool("11", SchedulingMode.FAIR, 2, 1)
+ pool1.addSchedulable(pool10)
+ pool1.addSchedulable(pool11)
+
+ val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, clusterScheduler, taskSet)
+ val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, clusterScheduler, taskSet)
+ pool00.addSchedulable(taskSetManager000)
+ pool00.addSchedulable(taskSetManager001)
+
+ val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, clusterScheduler, taskSet)
+ val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, clusterScheduler, taskSet)
+ pool01.addSchedulable(taskSetManager010)
+ pool01.addSchedulable(taskSetManager011)
+
+ val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, clusterScheduler, taskSet)
+ val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, clusterScheduler, taskSet)
+ pool10.addSchedulable(taskSetManager100)
+ pool10.addSchedulable(taskSetManager101)
+
+ val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, clusterScheduler, taskSet)
+ val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, clusterScheduler, taskSet)
+ pool11.addSchedulable(taskSetManager110)
+ pool11.addSchedulable(taskSetManager111)
+
+ checkTaskSetId(rootPool, 0)
+ checkTaskSetId(rootPool, 4)
+ checkTaskSetId(rootPool, 6)
+ checkTaskSetId(rootPool, 2)
+ }
+}
diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
index 6da58a0f6e..30e6fef950 100644
--- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
@@ -44,7 +44,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
override def submitTasks(taskSet: TaskSet) = {
// normally done by TaskSetManager
taskSet.tasks.foreach(_.generation = mapOutputTracker.getGeneration)
- taskSets += taskSet
+ taskSets += taskSet
}
override def setListener(listener: TaskSchedulerListener) = {}
override def defaultParallelism() = 2
@@ -164,7 +164,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
}
}
}
-
+
/** Sends the rdd to the scheduler for scheduling. */
private def submit(
rdd: RDD[_],
@@ -174,7 +174,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
listener: JobListener = listener) {
runEvent(JobSubmitted(rdd, func, partitions, allowLocal, null, listener))
}
-
+
/** Sends TaskSetFailed to the scheduler. */
private def failed(taskSet: TaskSet, message: String) {
runEvent(TaskSetFailed(taskSet, message))
@@ -209,11 +209,11 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener))
assert(results === Map(0 -> 42))
}
-
+
test("run trivial job w/ dependency") {
val baseRdd = makeRdd(1, Nil)
val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
- submit(finalRdd, Array(0))
+ submit(finalRdd, Array(0))
complete(taskSets(0), Seq((Success, 42)))
assert(results === Map(0 -> 42))
}
@@ -250,7 +250,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
complete(taskSets(1), Seq((Success, 42)))
assert(results === Map(0 -> 42))
}
-
+
test("run trivial shuffle with fetch failure") {
val shuffleMapRdd = makeRdd(2, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
@@ -271,7 +271,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
// have the 2nd attempt pass
complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
// we can see both result blocks now
- assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.ip) === Array("hostA", "hostB"))
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB"))
complete(taskSets(3), Seq((Success, 43)))
assert(results === Map(0 -> 42, 1 -> 43))
}
@@ -385,12 +385,12 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
assert(results === Map(0 -> 42))
}
- /** Assert that the supplied TaskSet has exactly the given preferredLocations. */
+ /** Assert that the supplied TaskSet has exactly the given preferredLocations. Note, converts taskSet's locations to host only. */
private def assertLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) {
assert(locations.size === taskSet.tasks.size)
for ((expectLocs, taskLocs) <-
taskSet.tasks.map(_.preferredLocations).zip(locations)) {
- assert(expectLocs === taskLocs)
+ assert(expectLocs.map(loc => spark.Utils.parseHostPort(loc)._1) === taskLocs)
}
}
@@ -398,6 +398,6 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))
private def makeBlockManagerId(host: String): BlockManagerId =
- BlockManagerId("exec-" + host, host, 12345)
+ BlockManagerId("exec-" + host, host, 12345, 0)
}
diff --git a/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala
new file mode 100644
index 0000000000..699901f1a1
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/JobLoggerSuite.scala
@@ -0,0 +1,104 @@
+package spark.scheduler
+
+import java.util.Properties
+import java.util.concurrent.LinkedBlockingQueue
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import scala.collection.mutable
+import spark._
+import spark.SparkContext._
+
+
+class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+
+ test("inner method") {
+ sc = new SparkContext("local", "joblogger")
+ val joblogger = new JobLogger {
+ def createLogWriterTest(jobID: Int) = createLogWriter(jobID)
+ def closeLogWriterTest(jobID: Int) = closeLogWriter(jobID)
+ def getRddNameTest(rdd: RDD[_]) = getRddName(rdd)
+ def buildJobDepTest(jobID: Int, stage: Stage) = buildJobDep(jobID, stage)
+ }
+ type MyRDD = RDD[(Int, Int)]
+ def makeRdd(
+ numPartitions: Int,
+ dependencies: List[Dependency[_]]
+ ): MyRDD = {
+ val maxPartition = numPartitions - 1
+ return new MyRDD(sc, dependencies) {
+ override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+ throw new RuntimeException("should not be reached")
+ override def getPartitions = (0 to maxPartition).map(i => new Partition {
+ override def index = i
+ }).toArray
+ }
+ }
+ val jobID = 5
+ val parentRdd = makeRdd(4, Nil)
+ val shuffleDep = new ShuffleDependency(parentRdd, null)
+ val rootRdd = makeRdd(4, List(shuffleDep))
+ val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID)
+ val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID)
+
+ joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4))
+ joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName)
+ parentRdd.setName("MyRDD")
+ joblogger.getRddNameTest(parentRdd) should be ("MyRDD")
+ joblogger.createLogWriterTest(jobID)
+ joblogger.getJobIDtoPrintWriter.size should be (1)
+ joblogger.buildJobDepTest(jobID, rootStage)
+ joblogger.getJobIDToStages.get(jobID).get.size should be (2)
+ joblogger.getStageIDToJobID.get(0) should be (Some(jobID))
+ joblogger.getStageIDToJobID.get(1) should be (Some(jobID))
+ joblogger.closeLogWriterTest(jobID)
+ joblogger.getStageIDToJobID.size should be (0)
+ joblogger.getJobIDToStages.size should be (0)
+ joblogger.getJobIDtoPrintWriter.size should be (0)
+ }
+
+ test("inner variables") {
+ sc = new SparkContext("local[4]", "joblogger")
+ val joblogger = new JobLogger {
+ override protected def closeLogWriter(jobID: Int) =
+ getJobIDtoPrintWriter.get(jobID).foreach { fileWriter =>
+ fileWriter.close()
+ }
+ }
+ sc.addSparkListener(joblogger)
+ val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
+ rdd.reduceByKey(_+_).collect()
+
+ joblogger.getLogDir should be ("/tmp/spark")
+ joblogger.getJobIDtoPrintWriter.size should be (1)
+ joblogger.getStageIDToJobID.size should be (2)
+ joblogger.getStageIDToJobID.get(0) should be (Some(0))
+ joblogger.getStageIDToJobID.get(1) should be (Some(0))
+ joblogger.getJobIDToStages.size should be (1)
+ }
+
+
+ test("interface functions") {
+ sc = new SparkContext("local[4]", "joblogger")
+ val joblogger = new JobLogger {
+ var onTaskEndCount = 0
+ var onJobEndCount = 0
+ var onJobStartCount = 0
+ var onStageCompletedCount = 0
+ var onStageSubmittedCount = 0
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = onTaskEndCount += 1
+ override def onJobEnd(jobEnd: SparkListenerJobEnd) = onJobEndCount += 1
+ override def onJobStart(jobStart: SparkListenerJobStart) = onJobStartCount += 1
+ override def onStageCompleted(stageCompleted: StageCompleted) = onStageCompletedCount += 1
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = onStageSubmittedCount += 1
+ }
+ sc.addSparkListener(joblogger)
+ val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
+ rdd.reduceByKey(_+_).collect()
+
+ joblogger.onJobStartCount should be (1)
+ joblogger.onJobEndCount should be (1)
+ joblogger.onTaskEndCount should be (8)
+ joblogger.onStageSubmittedCount should be (2)
+ joblogger.onStageCompletedCount should be (2)
+ }
+}
diff --git a/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala
new file mode 100644
index 0000000000..8bd813fd14
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala
@@ -0,0 +1,206 @@
+package spark.scheduler
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+
+import spark._
+import spark.scheduler._
+import spark.scheduler.cluster._
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ConcurrentMap, HashMap}
+import java.util.concurrent.Semaphore
+import java.util.concurrent.CountDownLatch
+import java.util.Properties
+
+class Lock() {
+ var finished = false
+ def jobWait() = {
+ synchronized {
+ while(!finished) {
+ this.wait()
+ }
+ }
+ }
+
+ def jobFinished() = {
+ synchronized {
+ finished = true
+ this.notifyAll()
+ }
+ }
+}
+
+object TaskThreadInfo {
+ val threadToLock = HashMap[Int, Lock]()
+ val threadToRunning = HashMap[Int, Boolean]()
+ val threadToStarted = HashMap[Int, CountDownLatch]()
+}
+
+/*
+ * 1. each thread contains one job.
+ * 2. each job contains one stage.
+ * 3. each stage only contains one task.
+ * 4. each task(launched) must be lanched orderly(using threadToStarted) to make sure
+ * it will get cpu core resource, and will wait to finished after user manually
+ * release "Lock" and then cluster will contain another free cpu cores.
+ * 5. each task(pending) must use "sleep" to make sure it has been added to taskSetManager queue,
+ * thus it will be scheduled later when cluster has free cpu cores.
+ */
+class LocalSchedulerSuite extends FunSuite with LocalSparkContext {
+
+ def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) {
+
+ TaskThreadInfo.threadToRunning(threadIndex) = false
+ val nums = sc.parallelize(threadIndex to threadIndex, 1)
+ TaskThreadInfo.threadToLock(threadIndex) = new Lock()
+ TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1)
+ new Thread {
+ if (poolName != null) {
+ sc.addLocalProperties("spark.scheduler.cluster.fair.pool",poolName)
+ }
+ override def run() {
+ val ans = nums.map(number => {
+ TaskThreadInfo.threadToRunning(number) = true
+ TaskThreadInfo.threadToStarted(number).countDown()
+ TaskThreadInfo.threadToLock(number).jobWait()
+ TaskThreadInfo.threadToRunning(number) = false
+ number
+ }).collect()
+ assert(ans.toList === List(threadIndex))
+ sem.release()
+ }
+ }.start()
+ }
+
+ test("Local FIFO scheduler end-to-end test") {
+ System.setProperty("spark.cluster.schedulingmode", "FIFO")
+ sc = new SparkContext("local[4]", "test")
+ val sem = new Semaphore(0)
+
+ createThread(1,null,sc,sem)
+ TaskThreadInfo.threadToStarted(1).await()
+ createThread(2,null,sc,sem)
+ TaskThreadInfo.threadToStarted(2).await()
+ createThread(3,null,sc,sem)
+ TaskThreadInfo.threadToStarted(3).await()
+ createThread(4,null,sc,sem)
+ TaskThreadInfo.threadToStarted(4).await()
+ // thread 5 and 6 (stage pending)must meet following two points
+ // 1. stages (taskSetManager) of jobs in thread 5 and 6 should be add to taskSetManager
+ // queue before executing TaskThreadInfo.threadToLock(1).jobFinished()
+ // 2. priority of stage in thread 5 should be prior to priority of stage in thread 6
+ // So I just use "sleep" 1s here for each thread.
+ // TODO: any better solution?
+ createThread(5,null,sc,sem)
+ Thread.sleep(1000)
+ createThread(6,null,sc,sem)
+ Thread.sleep(1000)
+
+ assert(TaskThreadInfo.threadToRunning(1) === true)
+ assert(TaskThreadInfo.threadToRunning(2) === true)
+ assert(TaskThreadInfo.threadToRunning(3) === true)
+ assert(TaskThreadInfo.threadToRunning(4) === true)
+ assert(TaskThreadInfo.threadToRunning(5) === false)
+ assert(TaskThreadInfo.threadToRunning(6) === false)
+
+ TaskThreadInfo.threadToLock(1).jobFinished()
+ TaskThreadInfo.threadToStarted(5).await()
+
+ assert(TaskThreadInfo.threadToRunning(1) === false)
+ assert(TaskThreadInfo.threadToRunning(2) === true)
+ assert(TaskThreadInfo.threadToRunning(3) === true)
+ assert(TaskThreadInfo.threadToRunning(4) === true)
+ assert(TaskThreadInfo.threadToRunning(5) === true)
+ assert(TaskThreadInfo.threadToRunning(6) === false)
+
+ TaskThreadInfo.threadToLock(3).jobFinished()
+ TaskThreadInfo.threadToStarted(6).await()
+
+ assert(TaskThreadInfo.threadToRunning(1) === false)
+ assert(TaskThreadInfo.threadToRunning(2) === true)
+ assert(TaskThreadInfo.threadToRunning(3) === false)
+ assert(TaskThreadInfo.threadToRunning(4) === true)
+ assert(TaskThreadInfo.threadToRunning(5) === true)
+ assert(TaskThreadInfo.threadToRunning(6) === true)
+
+ TaskThreadInfo.threadToLock(2).jobFinished()
+ TaskThreadInfo.threadToLock(4).jobFinished()
+ TaskThreadInfo.threadToLock(5).jobFinished()
+ TaskThreadInfo.threadToLock(6).jobFinished()
+ sem.acquire(6)
+ }
+
+ test("Local fair scheduler end-to-end test") {
+ sc = new SparkContext("local[8]", "LocalSchedulerSuite")
+ val sem = new Semaphore(0)
+ System.setProperty("spark.cluster.schedulingmode", "FAIR")
+ val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
+ System.setProperty("spark.fairscheduler.allocation.file", xmlPath)
+
+ createThread(10,"1",sc,sem)
+ TaskThreadInfo.threadToStarted(10).await()
+ createThread(20,"2",sc,sem)
+ TaskThreadInfo.threadToStarted(20).await()
+ createThread(30,"3",sc,sem)
+ TaskThreadInfo.threadToStarted(30).await()
+
+ assert(TaskThreadInfo.threadToRunning(10) === true)
+ assert(TaskThreadInfo.threadToRunning(20) === true)
+ assert(TaskThreadInfo.threadToRunning(30) === true)
+
+ createThread(11,"1",sc,sem)
+ TaskThreadInfo.threadToStarted(11).await()
+ createThread(21,"2",sc,sem)
+ TaskThreadInfo.threadToStarted(21).await()
+ createThread(31,"3",sc,sem)
+ TaskThreadInfo.threadToStarted(31).await()
+
+ assert(TaskThreadInfo.threadToRunning(11) === true)
+ assert(TaskThreadInfo.threadToRunning(21) === true)
+ assert(TaskThreadInfo.threadToRunning(31) === true)
+
+ createThread(12,"1",sc,sem)
+ TaskThreadInfo.threadToStarted(12).await()
+ createThread(22,"2",sc,sem)
+ TaskThreadInfo.threadToStarted(22).await()
+ createThread(32,"3",sc,sem)
+
+ assert(TaskThreadInfo.threadToRunning(12) === true)
+ assert(TaskThreadInfo.threadToRunning(22) === true)
+ assert(TaskThreadInfo.threadToRunning(32) === false)
+
+ TaskThreadInfo.threadToLock(10).jobFinished()
+ TaskThreadInfo.threadToStarted(32).await()
+
+ assert(TaskThreadInfo.threadToRunning(32) === true)
+
+ //1. Similar with above scenario, sleep 1s for stage of 23 and 33 to be added to taskSetManager
+ // queue so that cluster will assign free cpu core to stage 23 after stage 11 finished.
+ //2. priority of 23 and 33 will be meaningless as using fair scheduler here.
+ createThread(23,"2",sc,sem)
+ createThread(33,"3",sc,sem)
+ Thread.sleep(1000)
+
+ TaskThreadInfo.threadToLock(11).jobFinished()
+ TaskThreadInfo.threadToStarted(23).await()
+
+ assert(TaskThreadInfo.threadToRunning(23) === true)
+ assert(TaskThreadInfo.threadToRunning(33) === false)
+
+ TaskThreadInfo.threadToLock(12).jobFinished()
+ TaskThreadInfo.threadToStarted(33).await()
+
+ assert(TaskThreadInfo.threadToRunning(33) === true)
+
+ TaskThreadInfo.threadToLock(20).jobFinished()
+ TaskThreadInfo.threadToLock(21).jobFinished()
+ TaskThreadInfo.threadToLock(22).jobFinished()
+ TaskThreadInfo.threadToLock(23).jobFinished()
+ TaskThreadInfo.threadToLock(30).jobFinished()
+ TaskThreadInfo.threadToLock(31).jobFinished()
+ TaskThreadInfo.threadToLock(32).jobFinished()
+ TaskThreadInfo.threadToLock(33).jobFinished()
+
+ sem.acquire(11)
+ }
+}
diff --git a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
index 2f5af10e69..48aa67c543 100644
--- a/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/spark/scheduler/SparkListenerSuite.scala
@@ -57,7 +57,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
taskMetrics.shuffleReadMetrics should be ('defined)
val sm = taskMetrics.shuffleReadMetrics.get
sm.totalBlocksFetched should be > (0)
- sm.shuffleReadMillis should be > (0l)
sm.localBlocksFetched should be > (0)
sm.remoteBlocksFetched should be (0)
sm.remoteBytesRead should be (0l)
@@ -78,7 +77,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
class SaveStageInfo extends SparkListener {
val stageInfos = mutable.Buffer[StageInfo]()
- def onStageCompleted(stage: StageCompleted) {
+ override def onStageCompleted(stage: StageCompleted) {
stageInfos += stage.stageInfo
}
}
diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
index b8c0f6fb76..b9d5f9668e 100644
--- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
@@ -15,8 +15,10 @@ import org.scalatest.time.SpanSugar._
import spark.JavaSerializer
import spark.KryoSerializer
import spark.SizeEstimator
+import spark.util.AkkaUtils
import spark.util.ByteBufferInputStream
+
class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester {
var store: BlockManager = null
var store2: BlockManager = null
@@ -31,7 +33,11 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
val serializer = new KryoSerializer
before {
- actorSystem = ActorSystem("test")
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0)
+ this.actorSystem = actorSystem
+ System.setProperty("spark.driver.port", boundPort.toString)
+ System.setProperty("spark.hostPort", "localhost:" + boundPort)
+
master = new BlockManagerMaster(
actorSystem.actorOf(Props(new spark.storage.BlockManagerMasterActor(true))))
@@ -41,9 +47,14 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
oldHeartBeat = System.setProperty("spark.storage.disableBlockManagerHeartBeat", "true")
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
+ // Set some value ...
+ System.setProperty("spark.hostPort", spark.Utils.localHostName() + ":" + 1111)
}
after {
+ System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
+
if (store != null) {
store.stop()
store = null
@@ -88,9 +99,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("BlockManagerId object caching") {
- val id1 = BlockManagerId("e1", "XXX", 1)
- val id2 = BlockManagerId("e1", "XXX", 1) // this should return the same object as id1
- val id3 = BlockManagerId("e1", "XXX", 2) // this should return a different object
+ val id1 = BlockManagerId("e1", "XXX", 1, 0)
+ val id2 = BlockManagerId("e1", "XXX", 1, 0) // this should return the same object as id1
+ val id3 = BlockManagerId("e1", "XXX", 2, 0) // this should return a different object
assert(id2 === id1, "id2 is not same as id1")
assert(id2.eq(id1), "id2 is not the same object as id1")
assert(id3 != id1, "id3 is same as id1")
@@ -113,7 +124,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
// Putting a1, a2 and a3 in memory and telling master only about a1 and a2
store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY)
- store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY, false)
+ store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY, tellMaster = false)
// Checking whether blocks are in memory
assert(store.getSingle("a1") != None, "a1 was not in store")
@@ -159,7 +170,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
// Putting a1, a2 and a3 in memory and telling master only about a1 and a2
store.putSingle("a1-to-remove", a1, StorageLevel.MEMORY_ONLY)
store.putSingle("a2-to-remove", a2, StorageLevel.MEMORY_ONLY)
- store.putSingle("a3-to-remove", a3, StorageLevel.MEMORY_ONLY, false)
+ store.putSingle("a3-to-remove", a3, StorageLevel.MEMORY_ONLY, tellMaster = false)
// Checking whether blocks are in memory and memory size
val memStatus = master.getMemoryStatus.head._2
@@ -198,6 +209,39 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
}
+ test("removing rdd") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
+ val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ val a3 = new Array[Byte](400)
+ // Putting a1, a2 and a3 in memory.
+ store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY)
+ store.putSingle("nonrddblock", a3, StorageLevel.MEMORY_ONLY)
+ master.removeRdd(0, blocking = false)
+
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("rdd_0_0") should be (None)
+ master.getLocations("rdd_0_0") should have size 0
+ }
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("rdd_0_1") should be (None)
+ master.getLocations("rdd_0_1") should have size 0
+ }
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("nonrddblock") should not be (None)
+ master.getLocations("nonrddblock") should have size (1)
+ }
+
+ store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY)
+ master.removeRdd(0, blocking = true)
+ store.getSingle("rdd_0_0") should be (None)
+ master.getLocations("rdd_0_0") should have size 0
+ store.getSingle("rdd_0_1") should be (None)
+ master.getLocations("rdd_0_1") should have size 0
+ }
+
test("reregistration on heart beat") {
val heartBeat = PrivateMethod[Unit]('heartBeat)
store = new BlockManager("<driver>", actorSystem, master, serializer, 2000)
@@ -226,7 +270,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
master.removeExecutor(store.blockManagerId.executorId)
assert(master.getLocations("a1").size == 0, "a1 was not removed from master")
- store.putSingle("a2", a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY)
store.waitForAsyncReregister()
assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master")
@@ -244,7 +288,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
master.removeExecutor(store.blockManagerId.executorId)
val t1 = new Thread {
override def run() {
- store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, true)
+ store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
}
}
val t2 = new Thread {
@@ -454,9 +498,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
val list1 = List(new Array[Byte](200), new Array[Byte](200))
val list2 = List(new Array[Byte](200), new Array[Byte](200))
val list3 = List(new Array[Byte](200), new Array[Byte](200))
- store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, true)
- store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY, true)
- store.put("list3", list3.iterator, StorageLevel.MEMORY_ONLY, true)
+ store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
+ store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
+ store.put("list3", list3.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
assert(store.get("list2") != None, "list2 was not in store")
assert(store.get("list2").get.size == 2)
assert(store.get("list3") != None, "list3 was not in store")
@@ -465,7 +509,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
assert(store.get("list2") != None, "list2 was not in store")
assert(store.get("list2").get.size == 2)
// At this point list2 was gotten last, so LRU will getSingle rid of list3
- store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, true)
+ store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
assert(store.get("list1") != None, "list1 was not in store")
assert(store.get("list1").get.size == 2)
assert(store.get("list2") != None, "list2 was not in store")
@@ -480,9 +524,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
val list3 = List(new Array[Byte](200), new Array[Byte](200))
val list4 = List(new Array[Byte](200), new Array[Byte](200))
// First store list1 and list2, both in memory, and list3, on disk only
- store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY_SER, true)
- store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY_SER, true)
- store.put("list3", list3.iterator, StorageLevel.DISK_ONLY, true)
+ store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true)
+ store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY_SER, tellMaster = true)
+ store.put("list3", list3.iterator, StorageLevel.DISK_ONLY, tellMaster = true)
// At this point LRU should not kick in because list3 is only on disk
assert(store.get("list1") != None, "list2 was not in store")
assert(store.get("list1").get.size === 2)
@@ -497,7 +541,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
assert(store.get("list3") != None, "list1 was not in store")
assert(store.get("list3").get.size === 2)
// Now let's add in list4, which uses both disk and memory; list1 should drop out
- store.put("list4", list4.iterator, StorageLevel.MEMORY_AND_DISK_SER, true)
+ store.put("list4", list4.iterator, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)
assert(store.get("list1") === None, "list1 was in store")
assert(store.get("list2") != None, "list3 was not in store")
assert(store.get("list2").get.size === 2)