aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md2
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala32
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala31
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala23
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala13
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala664
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala23
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala52
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockInfo.scala18
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManager.scala20
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala27
-rw-r--r--core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala49
-rw-r--r--core/src/main/scala/org/apache/spark/storage/DiskStore.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala189
-rw-r--r--core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala31
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala17
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/BitSet.scala103
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala152
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala271
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala127
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala51
-rw-r--r--core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala19
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala6
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala84
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala73
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala148
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala145
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala90
-rw-r--r--docs/cluster-overview.md14
-rw-r--r--docs/ec2-scripts.md2
-rw-r--r--docs/running-on-yarn.md1
-rwxr-xr-xec2/spark_ec2.py68
-rw-r--r--pom.xml6
-rw-r--r--project/SparkBuild.scala5
-rw-r--r--repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala11
-rw-r--r--repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala35
-rwxr-xr-xspark-class13
-rw-r--r--spark-class2.cmd7
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala4
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala83
-rw-r--r--yarn/pom.xml50
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala2
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala276
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala228
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala42
-rw-r--r--yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala220
56 files changed, 2873 insertions, 722 deletions
diff --git a/README.md b/README.md
index 28ad1e4604..456b8060ef 100644
--- a/README.md
+++ b/README.md
@@ -69,7 +69,7 @@ described below.
When developing a Spark application, specify the Hadoop version by adding the
"hadoop-client" artifact to your project's dependencies. For example, if you're
-using Hadoop 1.0.1 and build your application using SBT, add this entry to
+using Hadoop 1.2.1 and build your application using SBT, add this entry to
`libraryDependencies`:
"org.apache.hadoop" % "hadoop-client" % "1.2.1"
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index e8ff4da475..10db2fa7e7 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -24,7 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.Map
import scala.collection.generic.Growable
-import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
@@ -144,6 +144,14 @@ class SparkContext(
executorEnvs ++= environment
}
+ // Set SPARK_USER for user who is running SparkContext.
+ val sparkUser = Option {
+ Option(System.getProperty("user.name")).getOrElse(System.getenv("SPARK_USER"))
+ }.getOrElse {
+ SparkContext.SPARK_UNKNOWN_USER
+ }
+ executorEnvs("SPARK_USER") = sparkUser
+
// Create and start the scheduler
private[spark] var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
@@ -255,8 +263,10 @@ class SparkContext(
conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
}
// Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar"
- for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) {
- conf.set(key.substring("spark.hadoop.".length), System.getProperty(key))
+ Utils.getSystemProperties.foreach { case (key, value) =>
+ if (key.startsWith("spark.hadoop.")) {
+ conf.set(key.substring("spark.hadoop.".length), value)
+ }
}
val bufferSize = System.getProperty("spark.buffer.size", "65536")
conf.set("io.file.buffer.size", bufferSize)
@@ -270,6 +280,12 @@ class SparkContext(
override protected def childValue(parent: Properties): Properties = new Properties(parent)
}
+ private[spark] def getLocalProperties(): Properties = localProperties.get()
+
+ private[spark] def setLocalProperties(props: Properties) {
+ localProperties.set(props)
+ }
+
def initLocalProperties() {
localProperties.set(new Properties())
}
@@ -291,7 +307,7 @@ class SparkContext(
/** Set a human readable description of the current job. */
@deprecated("use setJobGroup", "0.8.1")
def setJobDescription(value: String) {
- setJobGroup("", value)
+ setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
}
/**
@@ -589,7 +605,8 @@ class SparkContext(
val uri = new URI(path)
val key = uri.getScheme match {
case null | "file" => env.httpFileServer.addFile(new File(uri.getPath))
- case _ => path
+ case "local" => "file:" + uri.getPath
+ case _ => path
}
addedFiles(key) = System.currentTimeMillis
@@ -791,11 +808,10 @@ class SparkContext(
val cleanedFunc = clean(func)
logInfo("Starting job: " + callSite)
val start = System.nanoTime
- val result = dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
+ dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
resultHandler, localProperties.get)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
- result
}
/**
@@ -977,6 +993,8 @@ object SparkContext {
private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
+ private[spark] val SPARK_UNKNOWN_USER = "<unknown>"
+
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
def zero(initialValue: Double) = 0.0
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 6bc846aa92..fc1537f796 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -17,16 +17,39 @@
package org.apache.spark.deploy
+import java.security.PrivilegedExceptionAction
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.security.UserGroupInformation
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkContext, SparkException}
/**
* Contains util methods to interact with Hadoop from Spark.
*/
private[spark]
class SparkHadoopUtil {
+ val conf = newConfiguration()
+ UserGroupInformation.setConfiguration(conf)
+
+ def runAsUser(user: String)(func: () => Unit) {
+ // if we are already running as the user intended there is no reason to do the doAs. It
+ // will actually break secure HDFS access as it doesn't fill in the credentials. Also if
+ // the user is UNKNOWN then we shouldn't be creating a remote unknown user
+ // (this is actually the path spark on yarn takes) since SPARK_USER is initialized only
+ // in SparkContext.
+ val currentUser = Option(System.getProperty("user.name")).
+ getOrElse(SparkContext.SPARK_UNKNOWN_USER)
+ if (user != SparkContext.SPARK_UNKNOWN_USER && currentUser != user) {
+ val ugi = UserGroupInformation.createRemoteUser(user)
+ ugi.doAs(new PrivilegedExceptionAction[Unit] {
+ def run: Unit = func()
+ })
+ } else {
+ func()
+ }
+ }
/**
* Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop
@@ -42,9 +65,9 @@ class SparkHadoopUtil {
def isYarnMode(): Boolean = { false }
}
-
+
object SparkHadoopUtil {
- private val hadoop = {
+ private val hadoop = {
val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
if (yarnMode) {
try {
@@ -56,7 +79,7 @@ object SparkHadoopUtil {
new SparkHadoopUtil
}
}
-
+
def get: SparkHadoopUtil = {
hadoop
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index 8fabc95665..fff9cb60c7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -104,7 +104,7 @@ private[spark] class ExecutorRunner(
// SPARK-698: do not call the run.cmd script, as process.destroy()
// fails to kill a process tree on Windows
Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++
- command.arguments.map(substituteVariables)
+ (command.arguments ++ Seq(appId)).map(substituteVariables)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 80ff4c59cb..caee6b01ab 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -111,7 +111,7 @@ private[spark] object CoarseGrainedExecutorBackend {
def main(args: Array[String]) {
if (args.length < 4) {
- //the reason we allow the last frameworkId argument is to make it easy to kill rogue executors
+ //the reason we allow the last appid argument is to make it easy to kill rogue executors
System.err.println(
"Usage: CoarseGrainedExecutorBackend <driverUrl> <executorId> <hostname> <cores> " +
"[<appid>]")
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index b773346df3..5c9bb9db1c 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -25,8 +25,9 @@ import java.util.concurrent._
import scala.collection.JavaConversions._
import scala.collection.mutable.HashMap
-import org.apache.spark.scheduler._
import org.apache.spark._
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.scheduler._
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util.Utils
@@ -129,6 +130,8 @@ private[spark] class Executor(
// Maintains the list of running tasks.
private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
+ val sparkUser = Option(System.getenv("SPARK_USER")).getOrElse(SparkContext.SPARK_UNKNOWN_USER)
+
def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
val tr = new TaskRunner(context, taskId, serializedTask)
runningTasks.put(taskId, tr)
@@ -176,7 +179,7 @@ private[spark] class Executor(
}
}
- override def run() {
+ override def run(): Unit = SparkHadoopUtil.get.runAsUser(sparkUser) { () =>
val startTime = System.currentTimeMillis()
SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(replClassLoader)
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 32901a508f..47e958b5e6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -132,6 +132,8 @@ class HadoopRDD[K, V](
override def getPartitions: Array[Partition] = {
val jobConf = getJobConf()
+ // add the credentials here as this can be called before SparkContext initialized
+ SparkHadoopUtil.get.addCredentials(jobConf)
val inputFormat = getInputFormat(jobConf)
if (inputFormat.isInstanceOf[Configurable]) {
inputFormat.asInstanceOf[Configurable].setConf(jobConf)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala
index c7d1295215..c5d7ca0481 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala
@@ -25,6 +25,8 @@ import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
+import akka.util.duration._
+
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
@@ -127,21 +129,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext, isLocal: Boolean = f
backend.start()
if (!isLocal && System.getProperty("spark.speculation", "false").toBoolean) {
- new Thread("TaskScheduler speculation check") {
- setDaemon(true)
-
- override def run() {
- logInfo("Starting speculative execution thread")
- while (true) {
- try {
- Thread.sleep(SPECULATION_INTERVAL)
- } catch {
- case e: InterruptedException => {}
- }
- checkSpeculatableTasks()
- }
- }
- }.start()
+ logInfo("Starting speculative execution thread")
+
+ sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds,
+ SPECULATION_INTERVAL milliseconds) {
+ checkSpeculatableTasks()
+ }
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 4cef0825dd..d0b21e896e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -417,15 +417,14 @@ class DAGScheduler(
case ExecutorLost(execId) =>
handleExecutorLost(execId)
- case begin: BeginEvent =>
- listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo))
+ case BeginEvent(task, taskInfo) =>
+ listenerBus.post(SparkListenerTaskStart(task, taskInfo))
- case gettingResult: GettingResultEvent =>
- listenerBus.post(SparkListenerTaskGettingResult(gettingResult.task, gettingResult.taskInfo))
+ case GettingResultEvent(task, taskInfo) =>
+ listenerBus.post(SparkListenerTaskGettingResult(task, taskInfo))
- case completion: CompletionEvent =>
- listenerBus.post(SparkListenerTaskEnd(
- completion.task, completion.reason, completion.taskInfo, completion.taskMetrics))
+ case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) =>
+ listenerBus.post(SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics))
handleTaskCompletion(completion)
case TaskSetFailed(taskSet, reason) =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 12b0d74fb5..60927831a1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -1,280 +1,384 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler
-
-import java.io.PrintWriter
-import java.io.File
-import java.io.FileNotFoundException
-import java.text.SimpleDateFormat
-import java.util.{Date, Properties}
-import java.util.concurrent.LinkedBlockingQueue
-
-import scala.collection.mutable.{HashMap, ListBuffer}
-
-import org.apache.spark._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.executor.TaskMetrics
-
-/**
- * A logger class to record runtime information for jobs in Spark. This class outputs one log file
- * per Spark job with information such as RDD graph, tasks start/stop, shuffle information.
- *
- * @param logDirName The base directory for the log files.
- */
-class JobLogger(val logDirName: String) extends SparkListener with Logging {
-
- private val logDir = Option(System.getenv("SPARK_LOG_DIR")).getOrElse("/tmp/spark")
-
- private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
- private val stageIDToJobID = new HashMap[Int, Int]
- private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
- private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
- private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
-
- createLogDir()
- def this() = this(String.valueOf(System.currentTimeMillis()))
-
- // The following 5 functions are used only in testing.
- private[scheduler] def getLogDir = logDir
- private[scheduler] def getJobIDtoPrintWriter = jobIDToPrintWriter
- private[scheduler] def getStageIDToJobID = stageIDToJobID
- private[scheduler] def getJobIDToStages = jobIDToStages
- private[scheduler] def getEventQueue = eventQueue
-
- // Create a folder for log files, the folder's name is the creation time of the jobLogger
- protected def createLogDir() {
- val dir = new File(logDir + "/" + logDirName + "/")
- if (!dir.exists() && !dir.mkdirs()) {
- logError("Error creating log directory: " + logDir + "/" + logDirName + "/")
- }
- }
-
- // Create a log file for one job, the file name is the jobID
- protected def createLogWriter(jobID: Int) {
- try {
- val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
- jobIDToPrintWriter += (jobID -> fileWriter)
- } catch {
- case e: FileNotFoundException => e.printStackTrace()
- }
- }
-
- // Close log file, and clean the stage relationship in stageIDToJobID
- protected def closeLogWriter(jobID: Int) =
- jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
- fileWriter.close()
- jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
- stageIDToJobID -= stage.id
- })
- jobIDToPrintWriter -= jobID
- jobIDToStages -= jobID
- }
-
- // Write log information to log file, withTime parameter controls whether to recored
- // time stamp for the information
- protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
- var writeInfo = info
- if (withTime) {
- val date = new Date(System.currentTimeMillis())
- writeInfo = DATE_FORMAT.format(date) + ": " +info
- }
- jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
- }
-
- protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) =
- stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
-
- protected def buildJobDep(jobID: Int, stage: Stage) {
- if (stage.jobId == jobID) {
- jobIDToStages.get(jobID) match {
- case Some(stageList) => stageList += stage
- case None => val stageList = new ListBuffer[Stage]
- stageList += stage
- jobIDToStages += (jobID -> stageList)
- }
- stageIDToJobID += (stage.id -> jobID)
- stage.parents.foreach(buildJobDep(jobID, _))
- }
- }
-
- protected def recordStageDep(jobID: Int) {
- def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
- var rddList = new ListBuffer[RDD[_]]
- rddList += rdd
- rdd.dependencies.foreach {
- case shufDep: ShuffleDependency[_, _] =>
- case dep: Dependency[_] => rddList ++= getRddsInStage(dep.rdd)
- }
- rddList
- }
- jobIDToStages.get(jobID).foreach {_.foreach { stage =>
- var depRddDesc: String = ""
- getRddsInStage(stage.rdd).foreach { rdd =>
- depRddDesc += rdd.id + ","
- }
- var depStageDesc: String = ""
- stage.parents.foreach { stage =>
- depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
- }
- jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
- depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
- " STAGE_DEP=" + depStageDesc, false)
- }
- }
- }
-
- // Generate indents and convert to String
- protected def indentString(indent: Int) = {
- val sb = new StringBuilder()
- for (i <- 1 to indent) {
- sb.append(" ")
- }
- sb.toString()
- }
-
- protected def getRddName(rdd: RDD[_]) = {
- var rddName = rdd.getClass.getName
- if (rdd.name != null) {
- rddName = rdd.name
- }
- rddName
- }
-
- protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
- val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
- jobLogInfo(jobID, indentString(indent) + rddInfo, false)
- rdd.dependencies.foreach {
- case shufDep: ShuffleDependency[_, _] =>
- val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
- jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
- case dep: Dependency[_] => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
- }
- }
-
- protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
- val stageInfo = if (stage.isShuffleMap) {
- "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId
- } else {
- "STAGE_ID=" + stage.id + " RESULT_STAGE"
- }
- if (stage.jobId == jobID) {
- jobLogInfo(jobID, indentString(indent) + stageInfo, false)
- recordRddInStageGraph(jobID, stage.rdd, indent)
- stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
- } else {
- jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
- }
- }
-
- // Record task metrics into job log files
- protected def recordTaskMetrics(stageID: Int, status: String,
- taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
- val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
- " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
- " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
- val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
- val readMetrics = taskMetrics.shuffleReadMetrics match {
- case Some(metrics) =>
- " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
- " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
- " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
- " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
- " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
- " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
- " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
- case None => ""
- }
- val writeMetrics = taskMetrics.shuffleWriteMetrics match {
- case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
- case None => ""
- }
- stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
- }
-
- override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
- stageLogInfo(stageSubmitted.stage.stageId,"STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
- stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
- }
-
- override def onStageCompleted(stageCompleted: StageCompleted) {
- stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format(
- stageCompleted.stage.stageId))
- }
-
- override def onTaskStart(taskStart: SparkListenerTaskStart) { }
-
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
- val task = taskEnd.task
- val taskInfo = taskEnd.taskInfo
- var taskStatus = ""
- task match {
- case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
- case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
- }
- taskEnd.reason match {
- case Success => taskStatus += " STATUS=SUCCESS"
- recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics)
- case Resubmitted =>
- taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
- " STAGE_ID=" + task.stageId
- stageLogInfo(task.stageId, taskStatus)
- case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
- taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
- task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
- mapId + " REDUCE_ID=" + reduceId
- stageLogInfo(task.stageId, taskStatus)
- case OtherFailure(message) =>
- taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
- " STAGE_ID=" + task.stageId + " INFO=" + message
- stageLogInfo(task.stageId, taskStatus)
- case _ =>
- }
- }
-
- override def onJobEnd(jobEnd: SparkListenerJobEnd) {
- val job = jobEnd.job
- var info = "JOB_ID=" + job.jobId
- jobEnd.jobResult match {
- case JobSucceeded => info += " STATUS=SUCCESS"
- case JobFailed(exception, _) =>
- info += " STATUS=FAILED REASON="
- exception.getMessage.split("\\s+").foreach(info += _ + "_")
- case _ =>
- }
- jobLogInfo(job.jobId, info.substring(0, info.length - 1).toUpperCase)
- closeLogWriter(job.jobId)
- }
-
- protected def recordJobProperties(jobID: Int, properties: Properties) {
- if(properties != null) {
- val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "")
- jobLogInfo(jobID, description, false)
- }
- }
-
- override def onJobStart(jobStart: SparkListenerJobStart) {
- val job = jobStart.job
- val properties = jobStart.properties
- createLogWriter(job.jobId)
- recordJobProperties(job.jobId, properties)
- buildJobDep(job.jobId, job.finalStage)
- recordStageDep(job.jobId)
- recordStageDepGraph(job.jobId, job.finalStage)
- jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED")
- }
-}
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io.{IOException, File, FileNotFoundException, PrintWriter}
+import java.text.SimpleDateFormat
+import java.util.{Date, Properties}
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.collection.mutable.{HashMap, HashSet, ListBuffer}
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * A logger class to record runtime information for jobs in Spark. This class outputs one log file
+ * for each Spark job, containing RDD graph, tasks start/stop, shuffle information.
+ * JobLogger is a subclass of SparkListener, use addSparkListener to add JobLogger to a SparkContext
+ * after the SparkContext is created.
+ * Note that each JobLogger only works for one SparkContext
+ * @param logDirName The base directory for the log files.
+ */
+
+class JobLogger(val user: String, val logDirName: String)
+ extends SparkListener with Logging {
+
+ def this() = this(System.getProperty("user.name", "<unknown>"),
+ String.valueOf(System.currentTimeMillis()))
+
+ private val logDir =
+ if (System.getenv("SPARK_LOG_DIR") != null)
+ System.getenv("SPARK_LOG_DIR")
+ else
+ "/tmp/spark-%s".format(user)
+
+ private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
+ private val stageIDToJobID = new HashMap[Int, Int]
+ private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
+ private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
+ private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
+
+ createLogDir()
+
+ // The following 5 functions are used only in testing.
+ private[scheduler] def getLogDir = logDir
+ private[scheduler] def getJobIDtoPrintWriter = jobIDToPrintWriter
+ private[scheduler] def getStageIDToJobID = stageIDToJobID
+ private[scheduler] def getJobIDToStages = jobIDToStages
+ private[scheduler] def getEventQueue = eventQueue
+
+ /** Create a folder for log files, the folder's name is the creation time of jobLogger */
+ protected def createLogDir() {
+ val dir = new File(logDir + "/" + logDirName + "/")
+ if (dir.exists()) {
+ return
+ }
+ if (dir.mkdirs() == false) {
+ // JobLogger should throw a exception rather than continue to construct this object.
+ throw new IOException("create log directory error:" + logDir + "/" + logDirName + "/")
+ }
+ }
+
+ /**
+ * Create a log file for one job
+ * @param jobID ID of the job
+ * @exception FileNotFoundException Fail to create log file
+ */
+ protected def createLogWriter(jobID: Int) {
+ try {
+ val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
+ jobIDToPrintWriter += (jobID -> fileWriter)
+ } catch {
+ case e: FileNotFoundException => e.printStackTrace()
+ }
+ }
+
+ /**
+ * Close log file, and clean the stage relationship in stageIDToJobID
+ * @param jobID ID of the job
+ */
+ protected def closeLogWriter(jobID: Int) {
+ jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
+ fileWriter.close()
+ jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
+ stageIDToJobID -= stage.id
+ })
+ jobIDToPrintWriter -= jobID
+ jobIDToStages -= jobID
+ }
+ }
+
+ /**
+ * Write info into log file
+ * @param jobID ID of the job
+ * @param info Info to be recorded
+ * @param withTime Controls whether to record time stamp before the info, default is true
+ */
+ protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
+ var writeInfo = info
+ if (withTime) {
+ val date = new Date(System.currentTimeMillis())
+ writeInfo = DATE_FORMAT.format(date) + ": " +info
+ }
+ jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
+ }
+
+ /**
+ * Write info into log file
+ * @param stageID ID of the stage
+ * @param info Info to be recorded
+ * @param withTime Controls whether to record time stamp before the info, default is true
+ */
+ protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) {
+ stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
+ }
+
+ /**
+ * Build stage dependency for a job
+ * @param jobID ID of the job
+ * @param stage Root stage of the job
+ */
+ protected def buildJobDep(jobID: Int, stage: Stage) {
+ if (stage.jobId == jobID) {
+ jobIDToStages.get(jobID) match {
+ case Some(stageList) => stageList += stage
+ case None => val stageList = new ListBuffer[Stage]
+ stageList += stage
+ jobIDToStages += (jobID -> stageList)
+ }
+ stageIDToJobID += (stage.id -> jobID)
+ stage.parents.foreach(buildJobDep(jobID, _))
+ }
+ }
+
+ /**
+ * Record stage dependency and RDD dependency for a stage
+ * @param jobID Job ID of the stage
+ */
+ protected def recordStageDep(jobID: Int) {
+ def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
+ var rddList = new ListBuffer[RDD[_]]
+ rddList += rdd
+ rdd.dependencies.foreach {
+ case shufDep: ShuffleDependency[_, _] =>
+ case dep: Dependency[_] => rddList ++= getRddsInStage(dep.rdd)
+ }
+ rddList
+ }
+ jobIDToStages.get(jobID).foreach {_.foreach { stage =>
+ var depRddDesc: String = ""
+ getRddsInStage(stage.rdd).foreach { rdd =>
+ depRddDesc += rdd.id + ","
+ }
+ var depStageDesc: String = ""
+ stage.parents.foreach { stage =>
+ depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
+ }
+ jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
+ depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
+ " STAGE_DEP=" + depStageDesc, false)
+ }
+ }
+ }
+
+ /**
+ * Generate indents and convert to String
+ * @param indent Number of indents
+ * @return string of indents
+ */
+ protected def indentString(indent: Int): String = {
+ val sb = new StringBuilder()
+ for (i <- 1 to indent) {
+ sb.append(" ")
+ }
+ sb.toString()
+ }
+
+ /**
+ * Get RDD's name
+ * @param rdd Input RDD
+ * @return String of RDD's name
+ */
+ protected def getRddName(rdd: RDD[_]): String = {
+ var rddName = rdd.getClass.getSimpleName
+ if (rdd.name != null) {
+ rddName = rdd.name
+ }
+ rddName
+ }
+
+ /**
+ * Record RDD dependency graph in a stage
+ * @param jobID Job ID of the stage
+ * @param rdd Root RDD of the stage
+ * @param indent Indent number before info
+ */
+ protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
+ val rddInfo =
+ if (rdd.getStorageLevel != StorageLevel.NONE) {
+ "RDD_ID=" + rdd.id + " " + getRddName(rdd) + " CACHED" + " " +
+ rdd.origin + " " + rdd.generator
+ } else {
+ "RDD_ID=" + rdd.id + " " + getRddName(rdd) + " NONE" + " " +
+ rdd.origin + " " + rdd.generator
+ }
+ jobLogInfo(jobID, indentString(indent) + rddInfo, false)
+ rdd.dependencies.foreach {
+ case shufDep: ShuffleDependency[_, _] =>
+ val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
+ jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
+ case dep: Dependency[_] => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
+ }
+ }
+
+ /**
+ * Record stage dependency graph of a job
+ * @param jobID Job ID of the stage
+ * @param stage Root stage of the job
+ * @param indent Indent number before info, default is 0
+ */
+ protected def recordStageDepGraph(jobID: Int, stage: Stage, idSet: HashSet[Int], indent: Int = 0) {
+ val stageInfo = if (stage.isShuffleMap) {
+ "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId
+ } else {
+ "STAGE_ID=" + stage.id + " RESULT_STAGE"
+ }
+ if (stage.jobId == jobID) {
+ jobLogInfo(jobID, indentString(indent) + stageInfo, false)
+ if (!idSet.contains(stage.id)) {
+ idSet += stage.id
+ recordRddInStageGraph(jobID, stage.rdd, indent)
+ stage.parents.foreach(recordStageDepGraph(jobID, _, idSet, indent + 2))
+ }
+ } else {
+ jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
+ }
+ }
+
+ /**
+ * Record task metrics into job log files, including execution info and shuffle metrics
+ * @param stageID Stage ID of the task
+ * @param status Status info of the task
+ * @param taskInfo Task description info
+ * @param taskMetrics Task running metrics
+ */
+ protected def recordTaskMetrics(stageID: Int, status: String,
+ taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
+ val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
+ " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
+ " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
+ val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
+ val readMetrics = taskMetrics.shuffleReadMetrics match {
+ case Some(metrics) =>
+ " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
+ " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
+ " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
+ " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
+ " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
+ " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
+ " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
+ case None => ""
+ }
+ val writeMetrics = taskMetrics.shuffleWriteMetrics match {
+ case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
+ case None => ""
+ }
+ stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
+ }
+
+ /**
+ * When stage is submitted, record stage submit info
+ * @param stageSubmitted Stage submitted event
+ */
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
+ stageLogInfo(stageSubmitted.stage.stageId,"STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
+ stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
+ }
+
+ /**
+ * When stage is completed, record stage completion status
+ * @param stageCompleted Stage completed event
+ */
+ override def onStageCompleted(stageCompleted: StageCompleted) {
+ stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format(
+ stageCompleted.stage.stageId))
+ }
+
+ override def onTaskStart(taskStart: SparkListenerTaskStart) { }
+
+ /**
+ * When task ends, record task completion status and metrics
+ * @param taskEnd Task end event
+ */
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ val task = taskEnd.task
+ val taskInfo = taskEnd.taskInfo
+ var taskStatus = ""
+ task match {
+ case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
+ case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
+ }
+ taskEnd.reason match {
+ case Success => taskStatus += " STATUS=SUCCESS"
+ recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics)
+ case Resubmitted =>
+ taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
+ " STAGE_ID=" + task.stageId
+ stageLogInfo(task.stageId, taskStatus)
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
+ task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
+ mapId + " REDUCE_ID=" + reduceId
+ stageLogInfo(task.stageId, taskStatus)
+ case OtherFailure(message) =>
+ taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
+ " STAGE_ID=" + task.stageId + " INFO=" + message
+ stageLogInfo(task.stageId, taskStatus)
+ case _ =>
+ }
+ }
+
+ /**
+ * When job ends, recording job completion status and close log file
+ * @param jobEnd Job end event
+ */
+ override def onJobEnd(jobEnd: SparkListenerJobEnd) {
+ val job = jobEnd.job
+ var info = "JOB_ID=" + job.jobId
+ jobEnd.jobResult match {
+ case JobSucceeded => info += " STATUS=SUCCESS"
+ case JobFailed(exception, _) =>
+ info += " STATUS=FAILED REASON="
+ exception.getMessage.split("\\s+").foreach(info += _ + "_")
+ case _ =>
+ }
+ jobLogInfo(job.jobId, info.substring(0, info.length - 1).toUpperCase)
+ closeLogWriter(job.jobId)
+ }
+
+ /**
+ * Record job properties into job log file
+ * @param jobID ID of the job
+ * @param properties Properties of the job
+ */
+ protected def recordJobProperties(jobID: Int, properties: Properties) {
+ if(properties != null) {
+ val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "")
+ jobLogInfo(jobID, description, false)
+ }
+ }
+
+ /**
+ * When job starts, record job property and stage graph
+ * @param jobStart Job start event
+ */
+ override def onJobStart(jobStart: SparkListenerJobStart) {
+ val job = jobStart.job
+ val properties = jobStart.properties
+ createLogWriter(job.jobId)
+ recordJobProperties(job.jobId, properties)
+ buildJobDep(job.jobId, job.finalStage)
+ recordStageDep(job.jobId)
+ recordStageDepGraph(job.jobId, job.finalStage, new HashSet[Int])
+ jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED")
+ }
+}
+
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 24d97da6eb..1dc71a0428 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -146,26 +146,26 @@ private[spark] class ShuffleMapTask(
metrics = Some(context.taskMetrics)
val blockManager = SparkEnv.get.blockManager
- var shuffle: ShuffleBlocks = null
- var buckets: ShuffleWriterGroup = null
+ val shuffleBlockManager = blockManager.shuffleBlockManager
+ var shuffle: ShuffleWriterGroup = null
+ var success = false
try {
// Obtain all the block writers for shuffle blocks.
val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
- shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser)
- buckets = shuffle.acquireWriters(partitionId)
+ shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)
// Write the map output to its associated buckets.
for (elem <- rdd.iterator(split, context)) {
val pair = elem.asInstanceOf[Product2[Any, Any]]
val bucketId = dep.partitioner.getPartition(pair._1)
- buckets.writers(bucketId).write(pair)
+ shuffle.writers(bucketId).write(pair)
}
// Commit the writes. Get the size of each bucket block (total block size).
var totalBytes = 0L
var totalTime = 0L
- val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter =>
+ val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter =>
writer.commit()
val size = writer.fileSegment().length
totalBytes += size
@@ -179,19 +179,20 @@ private[spark] class ShuffleMapTask(
shuffleMetrics.shuffleWriteTime = totalTime
metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
+ success = true
new MapStatus(blockManager.blockManagerId, compressedSizes)
} catch { case e: Exception =>
// If there is an exception from running the task, revert the partial writes
// and throw the exception upstream to Spark.
- if (buckets != null) {
- buckets.writers.foreach(_.revertPartialWrites())
+ if (shuffle != null) {
+ shuffle.writers.foreach(_.revertPartialWrites())
}
throw e
} finally {
// Release the writers back to the shuffle block manager.
- if (shuffle != null && buckets != null) {
- buckets.writers.foreach(_.close())
- shuffle.releaseWriters(buckets)
+ if (shuffle != null && shuffle.writers != null) {
+ shuffle.writers.foreach(_.close())
+ shuffle.releaseWriters(success)
}
// Execute the callbacks on task completion.
context.executeOnCompleteCallbacks()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index f5548fc2da..3bb715e7d0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -88,8 +88,14 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
case StatusUpdate(executorId, taskId, state, data) =>
scheduler.statusUpdate(taskId, state, data.value)
if (TaskState.isFinished(state)) {
- freeCores(executorId) += 1
- makeOffers(executorId)
+ if (executorActor.contains(executorId)) {
+ freeCores(executorId) += 1
+ makeOffers(executorId)
+ } else {
+ // Ignoring the update since we don't know about the executor.
+ val msg = "Ignored task status update (%d state %s) from unknown executor %s with ID %s"
+ logWarning(msg.format(taskId, state, sender, executorId))
+ }
}
case ReviveOffers =>
@@ -176,7 +182,9 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME)
}
- private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ private val timeout = {
+ Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ }
def stopExecutors() {
try {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
index 40fdfcddb1..cec02e945c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
@@ -33,6 +33,10 @@ private[spark] class SimrSchedulerBackend(
val tmpPath = new Path(driverFilePath + "_tmp")
val filePath = new Path(driverFilePath)
+ val uiFilePath = driverFilePath + "_ui"
+ val tmpUiPath = new Path(uiFilePath + "_tmp")
+ val uiPath = new Path(uiFilePath)
+
val maxCores = System.getProperty("spark.simr.executor.cores", "1").toInt
override def start() {
@@ -47,6 +51,8 @@ private[spark] class SimrSchedulerBackend(
logInfo("Writing to HDFS file: " + driverFilePath)
logInfo("Writing Akka address: " + driverUrl)
+ logInfo("Writing to HDFS file: " + uiFilePath)
+ logInfo("Writing Spark UI Address: " + sc.ui.appUIAddress)
// Create temporary file to prevent race condition where executors get empty driverUrl file
val temp = fs.create(tmpPath, true)
@@ -56,6 +62,12 @@ private[spark] class SimrSchedulerBackend(
// "Atomic" rename
fs.rename(tmpPath, filePath)
+
+ // Write Spark UI Address to file
+ val uiTemp = fs.create(tmpUiPath, true)
+ uiTemp.writeUTF(sc.ui.appUIAddress)
+ uiTemp.close()
+ fs.rename(tmpUiPath, uiPath)
}
override def stop() {
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 55b25f145a..e748c2275d 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -27,13 +27,17 @@ import com.twitter.chill.{EmptyScalaKryoInstantiator, AllScalaRegistrar}
import org.apache.spark.{SerializableWritable, Logging}
import org.apache.spark.broadcast.HttpBroadcast
-import org.apache.spark.storage.{GetBlock,GotBlock, PutBlock, StorageLevel, TestBlockId}
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.storage._
/**
- * A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]].
+ * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]].
*/
class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging {
- private val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
+
+ private val bufferSize = {
+ System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
+ }
def newKryoOutput() = new KryoOutput(bufferSize)
@@ -42,21 +46,11 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging
val kryo = instantiator.newKryo()
val classLoader = Thread.currentThread.getContextClassLoader
- val blockId = TestBlockId("1")
- // Register some commonly used classes
- val toRegister: Seq[AnyRef] = Seq(
- ByteBuffer.allocate(1),
- StorageLevel.MEMORY_ONLY,
- PutBlock(blockId, ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY),
- GotBlock(blockId, ByteBuffer.allocate(1)),
- GetBlock(blockId),
- 1 to 10,
- 1 until 10,
- 1L to 10L,
- 1L until 10L
- )
-
- for (obj <- toRegister) kryo.register(obj.getClass)
+ // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops.
+ // Do this before we invoke the user registrator so the user registrator can override this.
+ kryo.setReferences(System.getProperty("spark.kryo.referenceTracking", "true").toBoolean)
+
+ for (cls <- KryoSerializer.toRegister) kryo.register(cls)
// Allow sending SerializableWritable
kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer())
@@ -78,10 +72,6 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging
new AllScalaRegistrar().apply(kryo)
kryo.setClassLoader(classLoader)
-
- // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops
- kryo.setReferences(System.getProperty("spark.kryo.referenceTracking", "true").toBoolean)
-
kryo
}
@@ -165,3 +155,21 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
trait KryoRegistrator {
def registerClasses(kryo: Kryo)
}
+
+private[serializer] object KryoSerializer {
+ // Commonly used classes.
+ private val toRegister: Seq[Class[_]] = Seq(
+ ByteBuffer.allocate(1).getClass,
+ classOf[StorageLevel],
+ classOf[PutBlock],
+ classOf[GotBlock],
+ classOf[GetBlock],
+ classOf[MapStatus],
+ classOf[BlockManagerId],
+ classOf[Array[Byte]],
+ (1 to 10).getClass,
+ (1 until 10).getClass,
+ (1L to 10L).getClass,
+ (1L until 10L).getClass
+ )
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala
index dbe0bda615..c8f397609a 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala
@@ -19,9 +19,7 @@ package org.apache.spark.storage
import java.util.concurrent.ConcurrentHashMap
-private[storage] trait BlockInfo {
- def level: StorageLevel
- def tellMaster: Boolean
+private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
// To save space, 'pending' and 'failed' are encoded as special sizes:
@volatile var size: Long = BlockInfo.BLOCK_PENDING
private def pending: Boolean = size == BlockInfo.BLOCK_PENDING
@@ -81,17 +79,3 @@ private object BlockInfo {
private val BLOCK_PENDING: Long = -1L
private val BLOCK_FAILED: Long = -2L
}
-
-// All shuffle blocks have the same `level` and `tellMaster` properties,
-// so we can save space by not storing them in each instance:
-private[storage] class ShuffleBlockInfo extends BlockInfo {
- // These need to be defined using 'def' instead of 'val' in order for
- // the compiler to eliminate the fields:
- def level: StorageLevel = StorageLevel.DISK_ONLY
- def tellMaster: Boolean = false
-}
-
-private[storage] class BlockInfoImpl(val level: StorageLevel, val tellMaster: Boolean)
- extends BlockInfo {
- // Intentionally left blank
-} \ No newline at end of file
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 76d537f8e8..a34c95b6f0 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -17,7 +17,7 @@
package org.apache.spark.storage
-import java.io.{InputStream, OutputStream}
+import java.io.{File, InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
import scala.collection.mutable.{HashMap, ArrayBuffer}
@@ -47,7 +47,7 @@ private[spark] class BlockManager(
extends Logging {
val shuffleBlockManager = new ShuffleBlockManager(this)
- val diskBlockManager = new DiskBlockManager(
+ val diskBlockManager = new DiskBlockManager(shuffleBlockManager,
System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
@@ -462,20 +462,10 @@ private[spark] class BlockManager(
* This is currently used for writing shuffle files out. Callers should handle error
* cases.
*/
- def getDiskWriter(blockId: BlockId, filename: String, serializer: Serializer, bufferSize: Int)
+ def getDiskWriter(blockId: BlockId, file: File, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = {
val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
- val file = diskBlockManager.createBlockFile(blockId, filename, allowAppending = true)
- val writer = new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream)
- writer.registerCloseEventHandler(() => {
- if (shuffleBlockManager.consolidateShuffleFiles) {
- diskBlockManager.mapBlockToFileSegment(blockId, writer.fileSegment())
- }
- val myInfo = new ShuffleBlockInfo()
- blockInfo.put(blockId, myInfo)
- myInfo.markReady(writer.fileSegment().length)
- })
- writer
+ new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream)
}
/**
@@ -505,7 +495,7 @@ private[spark] class BlockManager(
// to be dropped right after it got put into memory. Note, however, that other threads will
// not be able to get() this block until we call markReady on its BlockInfo.
val myInfo = {
- val tinfo = new BlockInfoImpl(level, tellMaster)
+ val tinfo = new BlockInfo(level, tellMaster)
// Do atomically !
val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 32d2dd0694..469e68fed7 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -34,20 +34,12 @@ import org.apache.spark.serializer.{SerializationStream, Serializer}
*/
abstract class BlockObjectWriter(val blockId: BlockId) {
- var closeEventHandler: () => Unit = _
-
def open(): BlockObjectWriter
- def close() {
- closeEventHandler()
- }
+ def close()
def isOpen: Boolean
- def registerCloseEventHandler(handler: () => Unit) {
- closeEventHandler = handler
- }
-
/**
* Flush the partial writes and commit them as a single atomic block. Return the
* number of bytes written for this commit.
@@ -78,11 +70,11 @@ abstract class BlockObjectWriter(val blockId: BlockId) {
/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
class DiskBlockObjectWriter(
- blockId: BlockId,
- file: File,
- serializer: Serializer,
- bufferSize: Int,
- compressStream: OutputStream => OutputStream)
+ blockId: BlockId,
+ file: File,
+ serializer: Serializer,
+ bufferSize: Int,
+ compressStream: OutputStream => OutputStream)
extends BlockObjectWriter(blockId)
with Logging
{
@@ -111,8 +103,8 @@ class DiskBlockObjectWriter(
private var fos: FileOutputStream = null
private var ts: TimeTrackingOutputStream = null
private var objOut: SerializationStream = null
- private var initialPosition = 0L
- private var lastValidPosition = 0L
+ private val initialPosition = file.length()
+ private var lastValidPosition = initialPosition
private var initialized = false
private var _timeWriting = 0L
@@ -120,7 +112,6 @@ class DiskBlockObjectWriter(
fos = new FileOutputStream(file, true)
ts = new TimeTrackingOutputStream(fos)
channel = fos.getChannel()
- initialPosition = channel.position
lastValidPosition = initialPosition
bs = compressStream(new FastBufferedOutputStream(ts, bufferSize))
objOut = serializer.newInstance().serializeStream(bs)
@@ -147,8 +138,6 @@ class DiskBlockObjectWriter(
ts = null
objOut = null
}
- // Invoke the close callback handler.
- super.close()
}
override def isOpen: Boolean = objOut != null
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index bcb58ad946..fcd2e97982 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -20,12 +20,11 @@ package org.apache.spark.storage
import java.io.File
import java.text.SimpleDateFormat
import java.util.{Date, Random}
-import java.util.concurrent.ConcurrentHashMap
import org.apache.spark.Logging
import org.apache.spark.executor.ExecutorExitCode
import org.apache.spark.network.netty.{PathResolver, ShuffleSender}
-import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
+import org.apache.spark.util.Utils
/**
* Creates and maintains the logical mapping between logical blocks and physical on-disk
@@ -35,7 +34,8 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH
*
* @param rootDirs The directories to use for storing block files. Data will be hashed among these.
*/
-private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver with Logging {
+private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootDirs: String)
+ extends PathResolver with Logging {
private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
@@ -47,54 +47,23 @@ private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver wit
private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
private var shuffleSender : ShuffleSender = null
- // Stores only Blocks which have been specifically mapped to segments of files
- // (rather than the default, which maps a Block to a whole file).
- // This keeps our bookkeeping down, since the file system itself tracks the standalone Blocks.
- private val blockToFileSegmentMap = new TimeStampedHashMap[BlockId, FileSegment]
-
- val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DISK_BLOCK_MANAGER, this.cleanup)
-
addShutdownHook()
/**
- * Creates a logical mapping from the given BlockId to a segment of a file.
- * This will cause any accesses of the logical BlockId to be directed to the specified
- * physical location.
- */
- def mapBlockToFileSegment(blockId: BlockId, fileSegment: FileSegment) {
- blockToFileSegmentMap.put(blockId, fileSegment)
- }
-
- /**
* Returns the phyiscal file segment in which the given BlockId is located.
* If the BlockId has been mapped to a specific FileSegment, that will be returned.
* Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
*/
def getBlockLocation(blockId: BlockId): FileSegment = {
- if (blockToFileSegmentMap.internalMap.containsKey(blockId)) {
- blockToFileSegmentMap.get(blockId).get
+ if (blockId.isShuffle && shuffleManager.consolidateShuffleFiles) {
+ shuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId])
} else {
val file = getFile(blockId.name)
new FileSegment(file, 0, file.length())
}
}
- /**
- * Simply returns a File to place the given Block into. This does not physically create the file.
- * If filename is given, that file will be used. Otherwise, we will use the BlockId to get
- * a unique filename.
- */
- def createBlockFile(blockId: BlockId, filename: String = "", allowAppending: Boolean): File = {
- val actualFilename = if (filename == "") blockId.name else filename
- val file = getFile(actualFilename)
- if (!allowAppending && file.exists()) {
- throw new IllegalStateException(
- "Attempted to create file that already exists: " + actualFilename)
- }
- file
- }
-
- private def getFile(filename: String): File = {
+ def getFile(filename: String): File = {
// Figure out which local directory it hashes to, and which subdirectory in that
val hash = Utils.nonNegativeHash(filename)
val dirId = hash % localDirs.length
@@ -119,6 +88,8 @@ private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver wit
new File(subDir, filename)
}
+ def getFile(blockId: BlockId): File = getFile(blockId.name)
+
private def createLocalDirs(): Array[File] = {
logDebug("Creating local directories at root dirs '" + rootDirs + "'")
val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
@@ -151,10 +122,6 @@ private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver wit
}
}
- private def cleanup(cleanupTime: Long) {
- blockToFileSegmentMap.clearOldValues(cleanupTime)
- }
-
private def addShutdownHook() {
localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index a3c496f9e0..5a1e7b4444 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -44,7 +44,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage
val bytes = _bytes.duplicate()
logDebug("Attempting to put block " + blockId)
val startTime = System.currentTimeMillis
- val file = diskManager.createBlockFile(blockId, allowAppending = false)
+ val file = diskManager.getFile(blockId)
val channel = new FileOutputStream(file).getChannel()
while (bytes.remaining > 0) {
channel.write(bytes)
@@ -64,7 +64,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage
logDebug("Attempting to write values for block " + blockId)
val startTime = System.currentTimeMillis
- val file = diskManager.createBlockFile(blockId, allowAppending = false)
+ val file = diskManager.getFile(blockId)
val outputStream = new FileOutputStream(file)
blockManager.dataSerializeStream(blockId, outputStream, values.iterator)
val length = file.length
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index 066e45a12b..2f1b049ce4 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -17,33 +17,45 @@
package org.apache.spark.storage
+import java.io.File
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicInteger
+import scala.collection.JavaConversions._
+
import org.apache.spark.serializer.Serializer
+import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap}
+import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
+import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup
-private[spark]
-class ShuffleWriterGroup(val id: Int, val fileId: Int, val writers: Array[BlockObjectWriter])
+/** A group of writers for a ShuffleMapTask, one writer per reducer. */
+private[spark] trait ShuffleWriterGroup {
+ val writers: Array[BlockObjectWriter]
-private[spark]
-trait ShuffleBlocks {
- def acquireWriters(mapId: Int): ShuffleWriterGroup
- def releaseWriters(group: ShuffleWriterGroup)
+ /** @param success Indicates all writes were successful. If false, no blocks will be recorded. */
+ def releaseWriters(success: Boolean)
}
/**
- * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one writer
- * per reducer.
+ * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one file
+ * per reducer (this set of files is called a ShuffleFileGroup).
*
* As an optimization to reduce the number of physical shuffle files produced, multiple shuffle
* blocks are aggregated into the same file. There is one "combined shuffle file" per reducer
- * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle files,
- * it releases them for another task.
+ * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle
+ * files, it releases them for another task.
* Regarding the implementation of this feature, shuffle files are identified by a 3-tuple:
* - shuffleId: The unique id given to the entire shuffle stage.
* - bucketId: The id of the output partition (i.e., reducer id)
* - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a
* time owns a particular fileId, and this id is returned to a pool when the task finishes.
+ * Each shuffle file is then mapped to a FileSegment, which is a 3-tuple (file, offset, length)
+ * that specifies where in a given file the actual block data is located.
+ *
+ * Shuffle file metadata is stored in a space-efficient manner. Rather than simply mapping
+ * ShuffleBlockIds directly to FileSegments, each ShuffleFileGroup maintains a list of offsets for
+ * each block stored in each file. In order to find the location of a shuffle block, we search the
+ * files within a ShuffleFileGroups associated with the block's reducer.
*/
private[spark]
class ShuffleBlockManager(blockManager: BlockManager) {
@@ -52,45 +64,152 @@ class ShuffleBlockManager(blockManager: BlockManager) {
val consolidateShuffleFiles =
System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean
- var nextFileId = new AtomicInteger(0)
- val unusedFileIds = new ConcurrentLinkedQueue[java.lang.Integer]()
+ private val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
+
+ /**
+ * Contains all the state related to a particular shuffle. This includes a pool of unused
+ * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle.
+ */
+ private class ShuffleState() {
+ val nextFileId = new AtomicInteger(0)
+ val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
+ val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
+ }
+
+ type ShuffleId = Int
+ private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState]
+
+ private
+ val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup)
- def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer) = {
- new ShuffleBlocks {
- // Get a group of writers for a map task.
- override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
- val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
- val fileId = getUnusedFileId()
- val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = {
+ new ShuffleWriterGroup {
+ shuffleStates.putIfAbsent(shuffleId, new ShuffleState())
+ private val shuffleState = shuffleStates(shuffleId)
+ private var fileGroup: ShuffleFileGroup = null
+
+ val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
+ fileGroup = getUnusedFileGroup()
+ Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
- if (consolidateShuffleFiles) {
- val filename = physicalFileName(shuffleId, bucketId, fileId)
- blockManager.getDiskWriter(blockId, filename, serializer, bufferSize)
- } else {
- blockManager.getDiskWriter(blockId, blockId.name, serializer, bufferSize)
+ blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize)
+ }
+ } else {
+ Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
+ val blockFile = blockManager.diskBlockManager.getFile(blockId)
+ blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize)
+ }
+ }
+
+ override def releaseWriters(success: Boolean) {
+ if (consolidateShuffleFiles) {
+ if (success) {
+ val offsets = writers.map(_.fileSegment().offset)
+ fileGroup.recordMapOutput(mapId, offsets)
}
+ recycleFileGroup(fileGroup)
}
- new ShuffleWriterGroup(mapId, fileId, writers)
}
- override def releaseWriters(group: ShuffleWriterGroup) {
- recycleFileId(group.fileId)
+ private def getUnusedFileGroup(): ShuffleFileGroup = {
+ val fileGroup = shuffleState.unusedFileGroups.poll()
+ if (fileGroup != null) fileGroup else newFileGroup()
+ }
+
+ private def newFileGroup(): ShuffleFileGroup = {
+ val fileId = shuffleState.nextFileId.getAndIncrement()
+ val files = Array.tabulate[File](numBuckets) { bucketId =>
+ val filename = physicalFileName(shuffleId, bucketId, fileId)
+ blockManager.diskBlockManager.getFile(filename)
+ }
+ val fileGroup = new ShuffleFileGroup(fileId, shuffleId, files)
+ shuffleState.allFileGroups.add(fileGroup)
+ fileGroup
}
- }
- }
- private def getUnusedFileId(): Int = {
- val fileId = unusedFileIds.poll()
- if (fileId == null) nextFileId.getAndIncrement() else fileId
+ private def recycleFileGroup(group: ShuffleFileGroup) {
+ shuffleState.unusedFileGroups.add(group)
+ }
+ }
}
- private def recycleFileId(fileId: Int) {
- if (consolidateShuffleFiles) {
- unusedFileIds.add(fileId)
+ /**
+ * Returns the physical file segment in which the given BlockId is located.
+ * This function should only be called if shuffle file consolidation is enabled, as it is
+ * an error condition if we don't find the expected block.
+ */
+ def getBlockLocation(id: ShuffleBlockId): FileSegment = {
+ // Search all file groups associated with this shuffle.
+ val shuffleState = shuffleStates(id.shuffleId)
+ for (fileGroup <- shuffleState.allFileGroups) {
+ val segment = fileGroup.getFileSegmentFor(id.mapId, id.reduceId)
+ if (segment.isDefined) { return segment.get }
}
+ throw new IllegalStateException("Failed to find shuffle block: " + id)
}
private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = {
"merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId)
}
+
+ private def cleanup(cleanupTime: Long) {
+ shuffleStates.clearOldValues(cleanupTime)
+ }
+}
+
+private[spark]
+object ShuffleBlockManager {
+ /**
+ * A group of shuffle files, one per reducer.
+ * A particular mapper will be assigned a single ShuffleFileGroup to write its output to.
+ */
+ private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) {
+ /**
+ * Stores the absolute index of each mapId in the files of this group. For instance,
+ * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0.
+ */
+ private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]()
+
+ /**
+ * Stores consecutive offsets of blocks into each reducer file, ordered by position in the file.
+ * This ordering allows us to compute block lengths by examining the following block offset.
+ * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every
+ * reducer.
+ */
+ private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) {
+ new PrimitiveVector[Long]()
+ }
+
+ def numBlocks = mapIdToIndex.size
+
+ def apply(bucketId: Int) = files(bucketId)
+
+ def recordMapOutput(mapId: Int, offsets: Array[Long]) {
+ mapIdToIndex(mapId) = numBlocks
+ for (i <- 0 until offsets.length) {
+ blockOffsetsByReducer(i) += offsets(i)
+ }
+ }
+
+ /** Returns the FileSegment associated with the given map task, or None if no entry exists. */
+ def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = {
+ val file = files(reducerId)
+ val blockOffsets = blockOffsetsByReducer(reducerId)
+ val index = mapIdToIndex.getOrElse(mapId, -1)
+ if (index >= 0) {
+ val offset = blockOffsets(index)
+ val length =
+ if (index + 1 < numBlocks) {
+ blockOffsets(index + 1) - offset
+ } else {
+ file.length() - offset
+ }
+ assert(length >= 0)
+ Some(new FileSegment(file, offset, length))
+ } else {
+ None
+ }
+ }
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
index 7dcadc3805..1e4db4f66b 100644
--- a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
@@ -38,19 +38,19 @@ object StoragePerfTester {
val blockManager = sc.env.blockManager
def writeOutputBytes(mapId: Int, total: AtomicLong) = {
- val shuffle = blockManager.shuffleBlockManager.forShuffle(1, numOutputSplits,
+ val shuffle = blockManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits,
new KryoSerializer())
- val buckets = shuffle.acquireWriters(mapId)
+ val writers = shuffle.writers
for (i <- 1 to recordsPerMap) {
- buckets.writers(i % numOutputSplits).write(writeData)
+ writers(i % numOutputSplits).write(writeData)
}
- buckets.writers.map {w =>
+ writers.map {w =>
w.commit()
total.addAndGet(w.fileSegment().length)
w.close()
}
- shuffle.releaseWriters(buckets)
+ shuffle.releaseWriters(true)
}
val start = System.currentTimeMillis()
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 35b5d5fd59..c1c7aa70e6 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -152,6 +152,22 @@ private[spark] class StagePage(parent: JobProgressUI) {
else metrics.map(m => parent.formatDuration(m.executorRunTime)).getOrElse("")
val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L)
+ var shuffleReadSortable: String = ""
+ var shuffleReadReadable: String = ""
+ if (shuffleRead) {
+ shuffleReadSortable = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => s.remoteBytesRead}.toString()
+ shuffleReadReadable = metrics.flatMap{m => m.shuffleReadMetrics}.map{s =>
+ Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")
+ }
+
+ var shuffleWriteSortable: String = ""
+ var shuffleWriteReadable: String = ""
+ if (shuffleWrite) {
+ shuffleWriteSortable = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleBytesWritten}.toString()
+ shuffleWriteReadable = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
+ Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")
+ }
+
<tr>
<td>{info.index}</td>
<td>{info.taskId}</td>
@@ -166,14 +182,17 @@ private[spark] class StagePage(parent: JobProgressUI) {
{if (gcTime > 0) parent.formatDuration(gcTime) else ""}
</td>
{if (shuffleRead) {
- <td>{metrics.flatMap{m => m.shuffleReadMetrics}.map{s =>
- Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")}</td>
+ <td sorttable_customkey={shuffleReadSortable}>
+ {shuffleReadReadable}
+ </td>
}}
{if (shuffleWrite) {
- <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
- parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")}</td>
- <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
- Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")}</td>
+ <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
+ parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")}
+ </td>
+ <td sorttable_customkey={shuffleWriteSortable}>
+ {shuffleWriteReadable}
+ </td>
}}
<td>{exception.map(e =>
<span>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index d7d0441c38..9ad6de3c6d 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -79,11 +79,14 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr
case None => "Unknown"
}
- val shuffleRead = listener.stageIdToShuffleRead.getOrElse(s.stageId, 0L) match {
+ val shuffleReadSortable = listener.stageIdToShuffleRead.getOrElse(s.stageId, 0L)
+ val shuffleRead = shuffleReadSortable match {
case 0 => ""
case b => Utils.bytesToString(b)
}
- val shuffleWrite = listener.stageIdToShuffleWrite.getOrElse(s.stageId, 0L) match {
+
+ val shuffleWriteSortable = listener.stageIdToShuffleWrite.getOrElse(s.stageId, 0L)
+ val shuffleWrite = shuffleWriteSortable match {
case 0 => ""
case b => Utils.bytesToString(b)
}
@@ -119,8 +122,8 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr
<td class="progress-cell">
{makeProgressBar(startedTasks, completedTasks, failedTasks, totalTasks)}
</td>
- <td>{shuffleRead}</td>
- <td>{shuffleWrite}</td>
+ <td sorttable_customekey={shuffleReadSortable.toString}>{shuffleRead}</td>
+ <td sorttable_customekey={shuffleWriteSortable.toString}>{shuffleWrite}</td>
</tr>
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index 3f963727d9..67a7f87a5c 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -59,7 +59,7 @@ object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext
"ShuffleMapTask", "BlockManager", "DiskBlockManager", "BroadcastVars") {
val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK,
- SHUFFLE_MAP_TASK, BLOCK_MANAGER, DISK_BLOCK_MANAGER, BROADCAST_VARS = Value
+ SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value
type MetadataCleanerType = Value
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index fd2811e44c..fe932d8ede 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -18,13 +18,12 @@
package org.apache.spark.util
import java.io._
-import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket}
+import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address}
import java.util.{Locale, Random, UUID}
-import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor}
-import java.util.regex.Pattern
+import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor}
import scala.collection.Map
-import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
import scala.io.Source
@@ -36,7 +35,7 @@ import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
import org.apache.spark.deploy.SparkHadoopUtil
import java.nio.ByteBuffer
-import org.apache.spark.{SparkEnv, SparkException, Logging}
+import org.apache.spark.{SparkException, Logging}
/**
@@ -148,7 +147,7 @@ private[spark] object Utils extends Logging {
return buf
}
- private val shutdownDeletePaths = new collection.mutable.HashSet[String]()
+ private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]()
// Register the path to be deleted via shutdown hook
def registerShutdownDeleteDir(file: File) {
@@ -818,4 +817,10 @@ private[spark] object Utils extends Logging {
// Nothing else to guard against ?
hashAbs
}
+
+ /** Returns a copy of the system properties that is thread-safe to iterator over. */
+ def getSystemProperties(): Map[String, String] = {
+ return System.getProperties().clone()
+ .asInstanceOf[java.util.Properties].toMap[String, String]
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
new file mode 100644
index 0000000000..a1a452315d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+
+/**
+ * A simple, fixed-size bit set implementation. This implementation is fast because it avoids
+ * safety/bound checking.
+ */
+class BitSet(numBits: Int) {
+
+ private[this] val words = new Array[Long](bit2words(numBits))
+ private[this] val numWords = words.length
+
+ /**
+ * Sets the bit at the specified index to true.
+ * @param index the bit index
+ */
+ def set(index: Int) {
+ val bitmask = 1L << (index & 0x3f) // mod 64 and shift
+ words(index >> 6) |= bitmask // div by 64 and mask
+ }
+
+ /**
+ * Return the value of the bit with the specified index. The value is true if the bit with
+ * the index is currently set in this BitSet; otherwise, the result is false.
+ *
+ * @param index the bit index
+ * @return the value of the bit with the specified index
+ */
+ def get(index: Int): Boolean = {
+ val bitmask = 1L << (index & 0x3f) // mod 64 and shift
+ (words(index >> 6) & bitmask) != 0 // div by 64 and mask
+ }
+
+ /** Return the number of bits set to true in this BitSet. */
+ def cardinality(): Int = {
+ var sum = 0
+ var i = 0
+ while (i < numWords) {
+ sum += java.lang.Long.bitCount(words(i))
+ i += 1
+ }
+ sum
+ }
+
+ /**
+ * Returns the index of the first bit that is set to true that occurs on or after the
+ * specified starting index. If no such bit exists then -1 is returned.
+ *
+ * To iterate over the true bits in a BitSet, use the following loop:
+ *
+ * for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i+1)) {
+ * // operate on index i here
+ * }
+ *
+ * @param fromIndex the index to start checking from (inclusive)
+ * @return the index of the next set bit, or -1 if there is no such bit
+ */
+ def nextSetBit(fromIndex: Int): Int = {
+ var wordIndex = fromIndex >> 6
+ if (wordIndex >= numWords) {
+ return -1
+ }
+
+ // Try to find the next set bit in the current word
+ val subIndex = fromIndex & 0x3f
+ var word = words(wordIndex) >> subIndex
+ if (word != 0) {
+ return (wordIndex << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word)
+ }
+
+ // Find the next set bit in the rest of the words
+ wordIndex += 1
+ while (wordIndex < numWords) {
+ word = words(wordIndex)
+ if (word != 0) {
+ return (wordIndex << 6) + java.lang.Long.numberOfTrailingZeros(word)
+ }
+ wordIndex += 1
+ }
+
+ -1
+ }
+
+ /** Return the number of longs it would take to hold numBits. */
+ private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
new file mode 100644
index 0000000000..80545c9688
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+
+/**
+ * A fast hash map implementation for nullable keys. This hash map supports insertions and updates,
+ * but not deletions. This map is about 5X faster than java.util.HashMap, while using much less
+ * space overhead.
+ *
+ * Under the hood, it uses our OpenHashSet implementation.
+ */
+private[spark]
+class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V: ClassManifest](
+ initialCapacity: Int)
+ extends Iterable[(K, V)]
+ with Serializable {
+
+ def this() = this(64)
+
+ protected var _keySet = new OpenHashSet[K](initialCapacity)
+
+ // Init in constructor (instead of in declaration) to work around a Scala compiler specialization
+ // bug that would generate two arrays (one for Object and one for specialized T).
+ private var _values: Array[V] = _
+ _values = new Array[V](_keySet.capacity)
+
+ @transient private var _oldValues: Array[V] = null
+
+ // Treat the null key differently so we can use nulls in "data" to represent empty items.
+ private var haveNullValue = false
+ private var nullValue: V = null.asInstanceOf[V]
+
+ override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size
+
+ /** Get the value for a given key */
+ def apply(k: K): V = {
+ if (k == null) {
+ nullValue
+ } else {
+ val pos = _keySet.getPos(k)
+ if (pos < 0) {
+ null.asInstanceOf[V]
+ } else {
+ _values(pos)
+ }
+ }
+ }
+
+ /** Set the value for a key */
+ def update(k: K, v: V) {
+ if (k == null) {
+ haveNullValue = true
+ nullValue = v
+ } else {
+ val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
+ _values(pos) = v
+ _keySet.rehashIfNeeded(k, grow, move)
+ _oldValues = null
+ }
+ }
+
+ /**
+ * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise,
+ * set its value to mergeValue(oldValue).
+ *
+ * @return the newly updated value.
+ */
+ def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = {
+ if (k == null) {
+ if (haveNullValue) {
+ nullValue = mergeValue(nullValue)
+ } else {
+ haveNullValue = true
+ nullValue = defaultValue
+ }
+ nullValue
+ } else {
+ val pos = _keySet.addWithoutResize(k)
+ if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
+ val newValue = defaultValue
+ _values(pos & OpenHashSet.POSITION_MASK) = newValue
+ _keySet.rehashIfNeeded(k, grow, move)
+ newValue
+ } else {
+ _values(pos) = mergeValue(_values(pos))
+ _values(pos)
+ }
+ }
+ }
+
+ override def iterator = new Iterator[(K, V)] {
+ var pos = -1
+ var nextPair: (K, V) = computeNextPair()
+
+ /** Get the next value we should return from next(), or null if we're finished iterating */
+ def computeNextPair(): (K, V) = {
+ if (pos == -1) { // Treat position -1 as looking at the null value
+ if (haveNullValue) {
+ pos += 1
+ return (null.asInstanceOf[K], nullValue)
+ }
+ pos += 1
+ }
+ pos = _keySet.nextPos(pos)
+ if (pos >= 0) {
+ val ret = (_keySet.getValue(pos), _values(pos))
+ pos += 1
+ ret
+ } else {
+ null
+ }
+ }
+
+ def hasNext = nextPair != null
+
+ def next() = {
+ val pair = nextPair
+ nextPair = computeNextPair()
+ pair
+ }
+ }
+
+ // The following member variables are declared as protected instead of private for the
+ // specialization to work (specialized class extends the non-specialized one and needs access
+ // to the "private" variables).
+ // They also should have been val's. We use var's because there is a Scala compiler bug that
+ // would throw illegal access error at runtime if they are declared as val's.
+ protected var grow = (newCapacity: Int) => {
+ _oldValues = _values
+ _values = new Array[V](newCapacity)
+ }
+
+ protected var move = (oldPos: Int, newPos: Int) => {
+ _values(newPos) = _oldValues(oldPos)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
new file mode 100644
index 0000000000..4592e4f939
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+
+/**
+ * A simple, fast hash set optimized for non-null insertion-only use case, where keys are never
+ * removed.
+ *
+ * The underlying implementation uses Scala compiler's specialization to generate optimized
+ * storage for two primitive types (Long and Int). It is much faster than Java's standard HashSet
+ * while incurring much less memory overhead. This can serve as building blocks for higher level
+ * data structures such as an optimized HashMap.
+ *
+ * This OpenHashSet is designed to serve as building blocks for higher level data structures
+ * such as an optimized hash map. Compared with standard hash set implementations, this class
+ * provides its various callbacks interfaces (e.g. allocateFunc, moveFunc) and interfaces to
+ * retrieve the position of a key in the underlying array.
+ *
+ * It uses quadratic probing with a power-of-2 hash table size, which is guaranteed
+ * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing).
+ */
+private[spark]
+class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
+ initialCapacity: Int,
+ loadFactor: Double)
+ extends Serializable {
+
+ require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
+ require(initialCapacity >= 1, "Invalid initial capacity")
+ require(loadFactor < 1.0, "Load factor must be less than 1.0")
+ require(loadFactor > 0.0, "Load factor must be greater than 0.0")
+
+ import OpenHashSet._
+
+ def this(initialCapacity: Int) = this(initialCapacity, 0.7)
+
+ def this() = this(64)
+
+ // The following member variables are declared as protected instead of private for the
+ // specialization to work (specialized class extends the non-specialized one and needs access
+ // to the "private" variables).
+
+ protected val hasher: Hasher[T] = {
+ // It would've been more natural to write the following using pattern matching. But Scala 2.9.x
+ // compiler has a bug when specialization is used together with this pattern matching, and
+ // throws:
+ // scala.tools.nsc.symtab.Types$TypeError: type mismatch;
+ // found : scala.reflect.AnyValManifest[Long]
+ // required: scala.reflect.ClassManifest[Int]
+ // at scala.tools.nsc.typechecker.Contexts$Context.error(Contexts.scala:298)
+ // at scala.tools.nsc.typechecker.Infer$Inferencer.error(Infer.scala:207)
+ // ...
+ val mt = classManifest[T]
+ if (mt == ClassManifest.Long) {
+ (new LongHasher).asInstanceOf[Hasher[T]]
+ } else if (mt == ClassManifest.Int) {
+ (new IntHasher).asInstanceOf[Hasher[T]]
+ } else {
+ new Hasher[T]
+ }
+ }
+
+ protected var _capacity = nextPowerOf2(initialCapacity)
+ protected var _mask = _capacity - 1
+ protected var _size = 0
+
+ protected var _bitset = new BitSet(_capacity)
+
+ // Init of the array in constructor (instead of in declaration) to work around a Scala compiler
+ // specialization bug that would generate two arrays (one for Object and one for specialized T).
+ protected var _data: Array[T] = _
+ _data = new Array[T](_capacity)
+
+ /** Number of elements in the set. */
+ def size: Int = _size
+
+ /** The capacity of the set (i.e. size of the underlying array). */
+ def capacity: Int = _capacity
+
+ /** Return true if this set contains the specified element. */
+ def contains(k: T): Boolean = getPos(k) != INVALID_POS
+
+ /**
+ * Add an element to the set. If the set is over capacity after the insertion, grow the set
+ * and rehash all elements.
+ */
+ def add(k: T) {
+ addWithoutResize(k)
+ rehashIfNeeded(k, grow, move)
+ }
+
+ /**
+ * Add an element to the set. This one differs from add in that it doesn't trigger rehashing.
+ * The caller is responsible for calling rehashIfNeeded.
+ *
+ * Use (retval & POSITION_MASK) to get the actual position, and
+ * (retval & EXISTENCE_MASK) != 0 for prior existence.
+ *
+ * @return The position where the key is placed, plus the highest order bit is set if the key
+ * exists previously.
+ */
+ def addWithoutResize(k: T): Int = putInto(_bitset, _data, k)
+
+ /**
+ * Rehash the set if it is overloaded.
+ * @param k A parameter unused in the function, but to force the Scala compiler to specialize
+ * this method.
+ * @param allocateFunc Callback invoked when we are allocating a new, larger array.
+ * @param moveFunc Callback invoked when we move the key from one position (in the old data array)
+ * to a new position (in the new data array).
+ */
+ def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
+ if (_size > loadFactor * _capacity) {
+ rehash(k, allocateFunc, moveFunc)
+ }
+ }
+
+ /**
+ * Return the position of the element in the underlying array, or INVALID_POS if it is not found.
+ */
+ def getPos(k: T): Int = {
+ var pos = hashcode(hasher.hash(k)) & _mask
+ var i = 1
+ while (true) {
+ if (!_bitset.get(pos)) {
+ return INVALID_POS
+ } else if (k == _data(pos)) {
+ return pos
+ } else {
+ val delta = i
+ pos = (pos + delta) & _mask
+ i += 1
+ }
+ }
+ // Never reached here
+ INVALID_POS
+ }
+
+ /** Return the value at the specified position. */
+ def getValue(pos: Int): T = _data(pos)
+
+ /**
+ * Return the next position with an element stored, starting from the given position inclusively.
+ */
+ def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos)
+
+ /**
+ * Put an entry into the set. Return the position where the key is placed. In addition, the
+ * highest bit in the returned position is set if the key exists prior to this put.
+ *
+ * This function assumes the data array has at least one empty slot.
+ */
+ private def putInto(bitset: BitSet, data: Array[T], k: T): Int = {
+ val mask = data.length - 1
+ var pos = hashcode(hasher.hash(k)) & mask
+ var i = 1
+ while (true) {
+ if (!bitset.get(pos)) {
+ // This is a new key.
+ data(pos) = k
+ bitset.set(pos)
+ _size += 1
+ return pos | NONEXISTENCE_MASK
+ } else if (data(pos) == k) {
+ // Found an existing key.
+ return pos
+ } else {
+ val delta = i
+ pos = (pos + delta) & mask
+ i += 1
+ }
+ }
+ // Never reached here
+ assert(INVALID_POS != INVALID_POS)
+ INVALID_POS
+ }
+
+ /**
+ * Double the table's size and re-hash everything. We are not really using k, but it is declared
+ * so Scala compiler can specialize this method (which leads to calling the specialized version
+ * of putInto).
+ *
+ * @param k A parameter unused in the function, but to force the Scala compiler to specialize
+ * this method.
+ * @param allocateFunc Callback invoked when we are allocating a new, larger array.
+ * @param moveFunc Callback invoked when we move the key from one position (in the old data array)
+ * to a new position (in the new data array).
+ */
+ private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
+ val newCapacity = _capacity * 2
+ require(newCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
+
+ allocateFunc(newCapacity)
+ val newData = new Array[T](newCapacity)
+ val newBitset = new BitSet(newCapacity)
+ var pos = 0
+ _size = 0
+ while (pos < _capacity) {
+ if (_bitset.get(pos)) {
+ val newPos = putInto(newBitset, newData, _data(pos))
+ moveFunc(pos, newPos & POSITION_MASK)
+ }
+ pos += 1
+ }
+ _bitset = newBitset
+ _data = newData
+ _capacity = newCapacity
+ _mask = newCapacity - 1
+ }
+
+ /**
+ * Re-hash a value to deal better with hash functions that don't differ
+ * in the lower bits, similar to java.util.HashMap
+ */
+ private def hashcode(h: Int): Int = {
+ val r = h ^ (h >>> 20) ^ (h >>> 12)
+ r ^ (r >>> 7) ^ (r >>> 4)
+ }
+
+ private def nextPowerOf2(n: Int): Int = {
+ val highBit = Integer.highestOneBit(n)
+ if (highBit == n) n else highBit << 1
+ }
+}
+
+
+private[spark]
+object OpenHashSet {
+
+ val INVALID_POS = -1
+ val NONEXISTENCE_MASK = 0x80000000
+ val POSITION_MASK = 0xEFFFFFF
+
+ /**
+ * A set of specialized hash function implementation to avoid boxing hash code computation
+ * in the specialized implementation of OpenHashSet.
+ */
+ sealed class Hasher[@specialized(Long, Int) T] {
+ def hash(o: T): Int = o.hashCode()
+ }
+
+ class LongHasher extends Hasher[Long] {
+ override def hash(o: Long): Int = (o ^ (o >>> 32)).toInt
+ }
+
+ class IntHasher extends Hasher[Int] {
+ override def hash(o: Int): Int = o
+ }
+
+ private def grow1(newSize: Int) {}
+ private def move1(oldPos: Int, newPos: Int) { }
+
+ private val grow = grow1 _
+ private val move = move1 _
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
new file mode 100644
index 0000000000..d76143e45a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+
+/**
+ * A fast hash map implementation for primitive, non-null keys. This hash map supports
+ * insertions and updates, but not deletions. This map is about an order of magnitude
+ * faster than java.util.HashMap, while using much less space overhead.
+ *
+ * Under the hood, it uses our OpenHashSet implementation.
+ */
+private[spark]
+class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest,
+ @specialized(Long, Int, Double) V: ClassManifest](
+ initialCapacity: Int)
+ extends Iterable[(K, V)]
+ with Serializable {
+
+ def this() = this(64)
+
+ require(classManifest[K] == classManifest[Long] || classManifest[K] == classManifest[Int])
+
+ // Init in constructor (instead of in declaration) to work around a Scala compiler specialization
+ // bug that would generate two arrays (one for Object and one for specialized T).
+ protected var _keySet: OpenHashSet[K] = _
+ private var _values: Array[V] = _
+ _keySet = new OpenHashSet[K](initialCapacity)
+ _values = new Array[V](_keySet.capacity)
+
+ private var _oldValues: Array[V] = null
+
+ override def size = _keySet.size
+
+ /** Get the value for a given key */
+ def apply(k: K): V = {
+ val pos = _keySet.getPos(k)
+ _values(pos)
+ }
+
+ /** Get the value for a given key, or returns elseValue if it doesn't exist. */
+ def getOrElse(k: K, elseValue: V): V = {
+ val pos = _keySet.getPos(k)
+ if (pos >= 0) _values(pos) else elseValue
+ }
+
+ /** Set the value for a key */
+ def update(k: K, v: V) {
+ val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
+ _values(pos) = v
+ _keySet.rehashIfNeeded(k, grow, move)
+ _oldValues = null
+ }
+
+ /**
+ * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise,
+ * set its value to mergeValue(oldValue).
+ *
+ * @return the newly updated value.
+ */
+ def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = {
+ val pos = _keySet.addWithoutResize(k)
+ if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
+ val newValue = defaultValue
+ _values(pos & OpenHashSet.POSITION_MASK) = newValue
+ _keySet.rehashIfNeeded(k, grow, move)
+ newValue
+ } else {
+ _values(pos) = mergeValue(_values(pos))
+ _values(pos)
+ }
+ }
+
+ override def iterator = new Iterator[(K, V)] {
+ var pos = 0
+ var nextPair: (K, V) = computeNextPair()
+
+ /** Get the next value we should return from next(), or null if we're finished iterating */
+ def computeNextPair(): (K, V) = {
+ pos = _keySet.nextPos(pos)
+ if (pos >= 0) {
+ val ret = (_keySet.getValue(pos), _values(pos))
+ pos += 1
+ ret
+ } else {
+ null
+ }
+ }
+
+ def hasNext = nextPair != null
+
+ def next() = {
+ val pair = nextPair
+ nextPair = computeNextPair()
+ pair
+ }
+ }
+
+ // The following member variables are declared as protected instead of private for the
+ // specialization to work (specialized class extends the unspecialized one and needs access
+ // to the "private" variables).
+ // They also should have been val's. We use var's because there is a Scala compiler bug that
+ // would throw illegal access error at runtime if they are declared as val's.
+ protected var grow = (newCapacity: Int) => {
+ _oldValues = _values
+ _values = new Array[V](newCapacity)
+ }
+
+ protected var move = (oldPos: Int, newPos: Int) => {
+ _values(newPos) = _oldValues(oldPos)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
new file mode 100644
index 0000000000..369519c559
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+/** Provides a simple, non-threadsafe, array-backed vector that can store primitives. */
+private[spark]
+class PrimitiveVector[@specialized(Long, Int, Double) V: ClassManifest](initialSize: Int = 64) {
+ private var numElements = 0
+ private var array: Array[V] = _
+
+ // NB: This must be separate from the declaration, otherwise the specialized parent class
+ // will get its own array with the same initial size. TODO: Figure out why...
+ array = new Array[V](initialSize)
+
+ def apply(index: Int): V = {
+ require(index < numElements)
+ array(index)
+ }
+
+ def +=(value: V) {
+ if (numElements == array.length) { resize(array.length * 2) }
+ array(numElements) = value
+ numElements += 1
+ }
+
+ def length = numElements
+
+ def getUnderlyingArray = array
+
+ /** Resizes the array, dropping elements if the total length decreases. */
+ def resize(newLength: Int) {
+ val newArray = new Array[V](newLength)
+ array.copyToArray(newArray)
+ array = newArray
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
new file mode 100644
index 0000000000..8f0954122b
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
@@ -0,0 +1,19 @@
+package org.apache.spark.deploy.worker
+
+import java.io.File
+import org.scalatest.FunSuite
+import org.apache.spark.deploy.{ExecutorState, Command, ApplicationDescription}
+
+class ExecutorRunnerTest extends FunSuite {
+ test("command includes appId") {
+ def f(s:String) = new File(s)
+ val sparkHome = sys.env("SPARK_HOME")
+ val appDesc = new ApplicationDescription("app name", 8, 500, Command("foo", Seq(),Map()),
+ sparkHome, "appUiUrl")
+ val appId = "12345-worker321-9876"
+ val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome),
+ f("ooga"), ExecutorState.RUNNING)
+
+ assert(er.buildCommandSeq().last === appId)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
index 8406093246..984881861c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
@@ -65,7 +65,7 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
val rootStageInfo = new StageInfo(rootStage)
joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStageInfo, null))
- joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName)
+ joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getSimpleName)
parentRdd.setName("MyRDD")
joblogger.getRddNameTest(parentRdd) should be ("MyRDD")
joblogger.createLogWriterTest(jobID)
@@ -91,8 +91,10 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
sc.addSparkListener(joblogger)
val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
rdd.reduceByKey(_+_).collect()
+
+ val user = System.getProperty("user.name", SparkContext.SPARK_UNKNOWN_USER)
- joblogger.getLogDir should be ("/tmp/spark")
+ joblogger.getLogDir should be ("/tmp/spark-%s".format(user))
joblogger.getJobIDtoPrintWriter.size should be (1)
joblogger.getStageIDToJobID.size should be (2)
joblogger.getStageIDToJobID.get(0) should be (Some(0))
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index d9a1d6d087..f3e592bf5c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -84,7 +84,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
i
}
- val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)}
+ val d = sc.parallelize(0 to 1e4.toInt, 64).map{i => w(i)}
d.count()
assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
listener.stageInfos.size should be (1)
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
new file mode 100644
index 0000000000..0b9056344c
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -0,0 +1,84 @@
+package org.apache.spark.storage
+
+import java.io.{FileWriter, File}
+
+import scala.collection.mutable
+
+import com.google.common.io.Files
+import org.scalatest.{BeforeAndAfterEach, FunSuite}
+
+class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach {
+
+ val rootDir0 = Files.createTempDir()
+ rootDir0.deleteOnExit()
+ val rootDir1 = Files.createTempDir()
+ rootDir1.deleteOnExit()
+ val rootDirs = rootDir0.getName + "," + rootDir1.getName
+ println("Created root dirs: " + rootDirs)
+
+ val shuffleBlockManager = new ShuffleBlockManager(null) {
+ var idToSegmentMap = mutable.Map[ShuffleBlockId, FileSegment]()
+ override def getBlockLocation(id: ShuffleBlockId) = idToSegmentMap(id)
+ }
+
+ var diskBlockManager: DiskBlockManager = _
+
+ override def beforeEach() {
+ diskBlockManager = new DiskBlockManager(shuffleBlockManager, rootDirs)
+ shuffleBlockManager.idToSegmentMap.clear()
+ }
+
+ test("basic block creation") {
+ val blockId = new TestBlockId("test")
+ assertSegmentEquals(blockId, blockId.name, 0, 0)
+
+ val newFile = diskBlockManager.getFile(blockId)
+ writeToFile(newFile, 10)
+ assertSegmentEquals(blockId, blockId.name, 0, 10)
+
+ newFile.delete()
+ }
+
+ test("block appending") {
+ val blockId = new TestBlockId("test")
+ val newFile = diskBlockManager.getFile(blockId)
+ writeToFile(newFile, 15)
+ assertSegmentEquals(blockId, blockId.name, 0, 15)
+ val newFile2 = diskBlockManager.getFile(blockId)
+ assert(newFile === newFile2)
+ writeToFile(newFile2, 12)
+ assertSegmentEquals(blockId, blockId.name, 0, 27)
+ newFile.delete()
+ }
+
+ test("block remapping") {
+ val filename = "test"
+ val blockId0 = new ShuffleBlockId(1, 2, 3)
+ val newFile = diskBlockManager.getFile(filename)
+ writeToFile(newFile, 15)
+ shuffleBlockManager.idToSegmentMap(blockId0) = new FileSegment(newFile, 0, 15)
+ assertSegmentEquals(blockId0, filename, 0, 15)
+
+ val blockId1 = new ShuffleBlockId(1, 2, 4)
+ val newFile2 = diskBlockManager.getFile(filename)
+ writeToFile(newFile2, 12)
+ shuffleBlockManager.idToSegmentMap(blockId1) = new FileSegment(newFile, 15, 12)
+ assertSegmentEquals(blockId1, filename, 15, 12)
+
+ assert(newFile === newFile2)
+ newFile.delete()
+ }
+
+ def assertSegmentEquals(blockId: BlockId, filename: String, offset: Int, length: Int) {
+ val segment = diskBlockManager.getBlockLocation(blockId)
+ assert(segment.file.getName === filename)
+ assert(segment.offset === offset)
+ assert(segment.length === length)
+ }
+
+ def writeToFile(file: File, numBytes: Int) {
+ val writer = new FileWriter(file, true)
+ for (i <- 0 until numBytes) writer.write(i)
+ writer.close()
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala
new file mode 100644
index 0000000000..0f1ab3d20e
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import org.scalatest.FunSuite
+
+
+class BitSetSuite extends FunSuite {
+
+ test("basic set and get") {
+ val setBits = Seq(0, 9, 1, 10, 90, 96)
+ val bitset = new BitSet(100)
+
+ for (i <- 0 until 100) {
+ assert(!bitset.get(i))
+ }
+
+ setBits.foreach(i => bitset.set(i))
+
+ for (i <- 0 until 100) {
+ if (setBits.contains(i)) {
+ assert(bitset.get(i))
+ } else {
+ assert(!bitset.get(i))
+ }
+ }
+ assert(bitset.cardinality() === setBits.size)
+ }
+
+ test("100% full bit set") {
+ val bitset = new BitSet(10000)
+ for (i <- 0 until 10000) {
+ assert(!bitset.get(i))
+ bitset.set(i)
+ }
+ for (i <- 0 until 10000) {
+ assert(bitset.get(i))
+ }
+ assert(bitset.cardinality() === 10000)
+ }
+
+ test("nextSetBit") {
+ val setBits = Seq(0, 9, 1, 10, 90, 96)
+ val bitset = new BitSet(100)
+ setBits.foreach(i => bitset.set(i))
+
+ assert(bitset.nextSetBit(0) === 0)
+ assert(bitset.nextSetBit(1) === 1)
+ assert(bitset.nextSetBit(2) === 9)
+ assert(bitset.nextSetBit(9) === 9)
+ assert(bitset.nextSetBit(10) === 10)
+ assert(bitset.nextSetBit(11) === 90)
+ assert(bitset.nextSetBit(80) === 90)
+ assert(bitset.nextSetBit(91) === 96)
+ assert(bitset.nextSetBit(96) === 96)
+ assert(bitset.nextSetBit(97) === -1)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
new file mode 100644
index 0000000000..ca3f684668
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
@@ -0,0 +1,148 @@
+package org.apache.spark.util.collection
+
+import scala.collection.mutable.HashSet
+import org.scalatest.FunSuite
+
+class OpenHashMapSuite extends FunSuite {
+
+ test("initialization") {
+ val goodMap1 = new OpenHashMap[String, Int](1)
+ assert(goodMap1.size === 0)
+ val goodMap2 = new OpenHashMap[String, Int](255)
+ assert(goodMap2.size === 0)
+ val goodMap3 = new OpenHashMap[String, String](256)
+ assert(goodMap3.size === 0)
+ intercept[IllegalArgumentException] {
+ new OpenHashMap[String, Int](1 << 30) // Invalid map size: bigger than 2^29
+ }
+ intercept[IllegalArgumentException] {
+ new OpenHashMap[String, Int](-1)
+ }
+ intercept[IllegalArgumentException] {
+ new OpenHashMap[String, String](0)
+ }
+ }
+
+ test("primitive value") {
+ val map = new OpenHashMap[String, Int]
+
+ for (i <- 1 to 1000) {
+ map(i.toString) = i
+ assert(map(i.toString) === i)
+ }
+
+ assert(map.size === 1000)
+ assert(map(null) === 0)
+
+ map(null) = -1
+ assert(map.size === 1001)
+ assert(map(null) === -1)
+
+ for (i <- 1 to 1000) {
+ assert(map(i.toString) === i)
+ }
+
+ // Test iterator
+ val set = new HashSet[(String, Int)]
+ for ((k, v) <- map) {
+ set.add((k, v))
+ }
+ val expected = (1 to 1000).map(x => (x.toString, x)) :+ (null.asInstanceOf[String], -1)
+ assert(set === expected.toSet)
+ }
+
+ test("non-primitive value") {
+ val map = new OpenHashMap[String, String]
+
+ for (i <- 1 to 1000) {
+ map(i.toString) = i.toString
+ assert(map(i.toString) === i.toString)
+ }
+
+ assert(map.size === 1000)
+ assert(map(null) === null)
+
+ map(null) = "-1"
+ assert(map.size === 1001)
+ assert(map(null) === "-1")
+
+ for (i <- 1 to 1000) {
+ assert(map(i.toString) === i.toString)
+ }
+
+ // Test iterator
+ val set = new HashSet[(String, String)]
+ for ((k, v) <- map) {
+ set.add((k, v))
+ }
+ val expected = (1 to 1000).map(_.toString).map(x => (x, x)) :+ (null.asInstanceOf[String], "-1")
+ assert(set === expected.toSet)
+ }
+
+ test("null keys") {
+ val map = new OpenHashMap[String, String]()
+ for (i <- 1 to 100) {
+ map(i.toString) = i.toString
+ }
+ assert(map.size === 100)
+ assert(map(null) === null)
+ map(null) = "hello"
+ assert(map.size === 101)
+ assert(map(null) === "hello")
+ }
+
+ test("null values") {
+ val map = new OpenHashMap[String, String]()
+ for (i <- 1 to 100) {
+ map(i.toString) = null
+ }
+ assert(map.size === 100)
+ assert(map("1") === null)
+ assert(map(null) === null)
+ assert(map.size === 100)
+ map(null) = null
+ assert(map.size === 101)
+ assert(map(null) === null)
+ }
+
+ test("changeValue") {
+ val map = new OpenHashMap[String, String]()
+ for (i <- 1 to 100) {
+ map(i.toString) = i.toString
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ val res = map.changeValue(i.toString, { assert(false); "" }, v => {
+ assert(v === i.toString)
+ v + "!"
+ })
+ assert(res === i + "!")
+ }
+ // Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a
+ // bug where changeValue would return the wrong result when the map grew on that insert
+ for (i <- 101 to 400) {
+ val res = map.changeValue(i.toString, { i + "!" }, v => { assert(false); v })
+ assert(res === i + "!")
+ }
+ assert(map.size === 400)
+ assert(map(null) === null)
+ map.changeValue(null, { "null!" }, v => { assert(false); v })
+ assert(map.size === 401)
+ map.changeValue(null, { assert(false); "" }, v => {
+ assert(v === "null!")
+ "null!!"
+ })
+ assert(map.size === 401)
+ }
+
+ test("inserting in capacity-1 map") {
+ val map = new OpenHashMap[String, String](1)
+ for (i <- 1 to 100) {
+ map(i.toString) = i.toString
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ assert(map(i.toString) === i.toString)
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
new file mode 100644
index 0000000000..4e11e8a628
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
@@ -0,0 +1,145 @@
+package org.apache.spark.util.collection
+
+import org.scalatest.FunSuite
+
+
+class OpenHashSetSuite extends FunSuite {
+
+ test("primitive int") {
+ val set = new OpenHashSet[Int]
+ assert(set.size === 0)
+ assert(!set.contains(10))
+ assert(!set.contains(50))
+ assert(!set.contains(999))
+ assert(!set.contains(10000))
+
+ set.add(10)
+ assert(set.contains(10))
+ assert(!set.contains(50))
+ assert(!set.contains(999))
+ assert(!set.contains(10000))
+
+ set.add(50)
+ assert(set.size === 2)
+ assert(set.contains(10))
+ assert(set.contains(50))
+ assert(!set.contains(999))
+ assert(!set.contains(10000))
+
+ set.add(999)
+ assert(set.size === 3)
+ assert(set.contains(10))
+ assert(set.contains(50))
+ assert(set.contains(999))
+ assert(!set.contains(10000))
+
+ set.add(50)
+ assert(set.size === 3)
+ assert(set.contains(10))
+ assert(set.contains(50))
+ assert(set.contains(999))
+ assert(!set.contains(10000))
+ }
+
+ test("primitive long") {
+ val set = new OpenHashSet[Long]
+ assert(set.size === 0)
+ assert(!set.contains(10L))
+ assert(!set.contains(50L))
+ assert(!set.contains(999L))
+ assert(!set.contains(10000L))
+
+ set.add(10L)
+ assert(set.size === 1)
+ assert(set.contains(10L))
+ assert(!set.contains(50L))
+ assert(!set.contains(999L))
+ assert(!set.contains(10000L))
+
+ set.add(50L)
+ assert(set.size === 2)
+ assert(set.contains(10L))
+ assert(set.contains(50L))
+ assert(!set.contains(999L))
+ assert(!set.contains(10000L))
+
+ set.add(999L)
+ assert(set.size === 3)
+ assert(set.contains(10L))
+ assert(set.contains(50L))
+ assert(set.contains(999L))
+ assert(!set.contains(10000L))
+
+ set.add(50L)
+ assert(set.size === 3)
+ assert(set.contains(10L))
+ assert(set.contains(50L))
+ assert(set.contains(999L))
+ assert(!set.contains(10000L))
+ }
+
+ test("non-primitive") {
+ val set = new OpenHashSet[String]
+ assert(set.size === 0)
+ assert(!set.contains(10.toString))
+ assert(!set.contains(50.toString))
+ assert(!set.contains(999.toString))
+ assert(!set.contains(10000.toString))
+
+ set.add(10.toString)
+ assert(set.size === 1)
+ assert(set.contains(10.toString))
+ assert(!set.contains(50.toString))
+ assert(!set.contains(999.toString))
+ assert(!set.contains(10000.toString))
+
+ set.add(50.toString)
+ assert(set.size === 2)
+ assert(set.contains(10.toString))
+ assert(set.contains(50.toString))
+ assert(!set.contains(999.toString))
+ assert(!set.contains(10000.toString))
+
+ set.add(999.toString)
+ assert(set.size === 3)
+ assert(set.contains(10.toString))
+ assert(set.contains(50.toString))
+ assert(set.contains(999.toString))
+ assert(!set.contains(10000.toString))
+
+ set.add(50.toString)
+ assert(set.size === 3)
+ assert(set.contains(10.toString))
+ assert(set.contains(50.toString))
+ assert(set.contains(999.toString))
+ assert(!set.contains(10000.toString))
+ }
+
+ test("non-primitive set growth") {
+ val set = new OpenHashSet[String]
+ for (i <- 1 to 1000) {
+ set.add(i.toString)
+ }
+ assert(set.size === 1000)
+ assert(set.capacity > 1000)
+ for (i <- 1 to 100) {
+ set.add(i.toString)
+ }
+ assert(set.size === 1000)
+ assert(set.capacity > 1000)
+ }
+
+ test("primitive set growth") {
+ val set = new OpenHashSet[Long]
+ for (i <- 1 to 1000) {
+ set.add(i.toLong)
+ }
+ assert(set.size === 1000)
+ assert(set.capacity > 1000)
+ for (i <- 1 to 100) {
+ set.add(i.toLong)
+ }
+ assert(set.size === 1000)
+ assert(set.capacity > 1000)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala
new file mode 100644
index 0000000000..dfd6aed2c4
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala
@@ -0,0 +1,90 @@
+package org.apache.spark.util.collection
+
+import scala.collection.mutable.HashSet
+import org.scalatest.FunSuite
+
+class PrimitiveKeyOpenHashSetSuite extends FunSuite {
+
+ test("initialization") {
+ val goodMap1 = new PrimitiveKeyOpenHashMap[Int, Int](1)
+ assert(goodMap1.size === 0)
+ val goodMap2 = new PrimitiveKeyOpenHashMap[Int, Int](255)
+ assert(goodMap2.size === 0)
+ val goodMap3 = new PrimitiveKeyOpenHashMap[Int, Int](256)
+ assert(goodMap3.size === 0)
+ intercept[IllegalArgumentException] {
+ new PrimitiveKeyOpenHashMap[Int, Int](1 << 30) // Invalid map size: bigger than 2^29
+ }
+ intercept[IllegalArgumentException] {
+ new PrimitiveKeyOpenHashMap[Int, Int](-1)
+ }
+ intercept[IllegalArgumentException] {
+ new PrimitiveKeyOpenHashMap[Int, Int](0)
+ }
+ }
+
+ test("basic operations") {
+ val longBase = 1000000L
+ val map = new PrimitiveKeyOpenHashMap[Long, Int]
+
+ for (i <- 1 to 1000) {
+ map(i + longBase) = i
+ assert(map(i + longBase) === i)
+ }
+
+ assert(map.size === 1000)
+
+ for (i <- 1 to 1000) {
+ assert(map(i + longBase) === i)
+ }
+
+ // Test iterator
+ val set = new HashSet[(Long, Int)]
+ for ((k, v) <- map) {
+ set.add((k, v))
+ }
+ assert(set === (1 to 1000).map(x => (x + longBase, x)).toSet)
+ }
+
+ test("null values") {
+ val map = new PrimitiveKeyOpenHashMap[Long, String]()
+ for (i <- 1 to 100) {
+ map(i.toLong) = null
+ }
+ assert(map.size === 100)
+ assert(map(1.toLong) === null)
+ }
+
+ test("changeValue") {
+ val map = new PrimitiveKeyOpenHashMap[Long, String]()
+ for (i <- 1 to 100) {
+ map(i.toLong) = i.toString
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ val res = map.changeValue(i.toLong, { assert(false); "" }, v => {
+ assert(v === i.toString)
+ v + "!"
+ })
+ assert(res === i + "!")
+ }
+ // Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a
+ // bug where changeValue would return the wrong result when the map grew on that insert
+ for (i <- 101 to 400) {
+ val res = map.changeValue(i.toLong, { i + "!" }, v => { assert(false); v })
+ assert(res === i + "!")
+ }
+ assert(map.size === 400)
+ }
+
+ test("inserting in capacity-1 map") {
+ val map = new PrimitiveKeyOpenHashMap[Long, String](1)
+ for (i <- 1 to 100) {
+ map(i.toLong) = i.toString
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ assert(map(i.toLong) === i.toString)
+ }
+ }
+}
diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md
index f679cad713..5927f736f3 100644
--- a/docs/cluster-overview.md
+++ b/docs/cluster-overview.md
@@ -13,7 +13,7 @@ object in your main program (called the _driver program_).
Specifically, to run on a cluster, the SparkContext can connect to several types of _cluster managers_
(either Spark's own standalone cluster manager or Mesos/YARN), which allocate resources across
applications. Once connected, Spark acquires *executors* on nodes in the cluster, which are
-worker processes that run computations and store data for your application.
+worker processes that run computations and store data for your application.
Next, it sends your application code (defined by JAR or Python files passed to SparkContext) to
the executors. Finally, SparkContext sends *tasks* for the executors to run.
@@ -57,6 +57,18 @@ which takes a list of JAR files (Java/Scala) or .egg and .zip libraries (Python)
worker nodes. You can also dynamically add new files to be sent to executors with `SparkContext.addJar`
and `addFile`.
+## URIs for addJar / addFile
+
+- **file:** - Absolute paths and `file:/` URIs are served by the driver's HTTP file server, and every executor
+ pulls the file from the driver HTTP server
+- **hdfs:**, **http:**, **https:**, **ftp:** - these pull down files and JARs from the URI as expected
+- **local:** - a URI starting with local:/ is expected to exist as a local file on each worker node. This
+ means that no network IO will be incurred, and works well for large files/JARs that are pushed to each worker,
+ or shared via NFS, GlusterFS, etc.
+
+Note that JARs and files are copied to the working directory for each SparkContext on the executor nodes.
+Over time this can use up a significant amount of space and will need to be cleaned up.
+
# Monitoring
Each driver program has a web UI, typically on port 4040, that displays information about running
diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md
index 1e5575d657..156a727026 100644
--- a/docs/ec2-scripts.md
+++ b/docs/ec2-scripts.md
@@ -98,7 +98,7 @@ permissions on your private key file, you can run `launch` with the
`bin/hadoop` script in that directory. Note that the data in this
HDFS goes away when you stop and restart a machine.
- There is also a *persistent HDFS* instance in
- `/root/presistent-hdfs` that will keep data across cluster restarts.
+ `/root/persistent-hdfs` that will keep data across cluster restarts.
Typically each node has relatively little space of persistent data
(about 3 GB), but you can use the `--ebs-vol-size` option to
`spark-ec2` to attach a persistent EBS volume to each node for
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index 2898af0bed..6fd1d0d150 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -21,6 +21,7 @@ The assembled JAR will be something like this:
# Preparations
- Building a YARN-enabled assembly (see above).
+- The assembled jar can be installed into HDFS or used locally.
- Your application code must be packaged into a separate JAR file.
If you want to test out the YARN deployment mode, you can use the current Spark examples. A `spark-examples_{{site.SCALA_VERSION}}-{{site.SPARK_VERSION}}` file can be generated by running `sbt/sbt assembly`. NOTE: since the documentation you're reading is for Spark version {{site.SPARK_VERSION}}, we are assuming here that you have downloaded Spark {{site.SPARK_VERSION}} or checked it out of source control. If you are using a different version of Spark, the version numbers in the jar generated by the sbt package command will obviously be different.
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 65868b76b9..1189232428 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -72,12 +72,12 @@ def parse_args():
parser.add_option("-a", "--ami", help="Amazon Machine Image ID to use")
parser.add_option("-v", "--spark-version", default="0.8.0",
help="Version of Spark to use: 'X.Y.Z' or a specific git hash")
- parser.add_option("--spark-git-repo",
- default="https://github.com/mesos/spark",
+ parser.add_option("--spark-git-repo",
+ default="https://github.com/apache/incubator-spark",
help="Github repo from which to checkout supplied commit hash")
parser.add_option("--hadoop-major-version", default="1",
help="Major version of Hadoop (default: 1)")
- parser.add_option("-D", metavar="[ADDRESS:]PORT", dest="proxy_port",
+ parser.add_option("-D", metavar="[ADDRESS:]PORT", dest="proxy_port",
help="Use SSH dynamic port forwarding to create a SOCKS proxy at " +
"the given local address (for use with login)")
parser.add_option("--resume", action="store_true", default=False,
@@ -101,6 +101,8 @@ def parse_args():
help="The SSH user you want to connect as (default: root)")
parser.add_option("--delete-groups", action="store_true", default=False,
help="When destroying a cluster, delete the security groups that were created")
+ parser.add_option("--use-existing-master", action="store_true", default=False,
+ help="Launch fresh slaves, but use an existing stopped master if possible")
(opts, args) = parser.parse_args()
if len(args) != 2:
@@ -191,7 +193,7 @@ def get_spark_ami(opts):
instance_type = "pvm"
print >> stderr,\
"Don't recognize %s, assuming type is pvm" % opts.instance_type
-
+
ami_path = "%s/%s/%s" % (AMI_PREFIX, opts.region, instance_type)
try:
ami = urllib2.urlopen(ami_path).read().strip()
@@ -215,6 +217,7 @@ def launch_cluster(conn, opts, cluster_name):
master_group.authorize(src_group=slave_group)
master_group.authorize('tcp', 22, 22, '0.0.0.0/0')
master_group.authorize('tcp', 8080, 8081, '0.0.0.0/0')
+ master_group.authorize('tcp', 19999, 19999, '0.0.0.0/0')
master_group.authorize('tcp', 50030, 50030, '0.0.0.0/0')
master_group.authorize('tcp', 50070, 50070, '0.0.0.0/0')
master_group.authorize('tcp', 60070, 60070, '0.0.0.0/0')
@@ -232,9 +235,9 @@ def launch_cluster(conn, opts, cluster_name):
slave_group.authorize('tcp', 60075, 60075, '0.0.0.0/0')
# Check if instances are already running in our groups
- active_nodes = get_existing_cluster(conn, opts, cluster_name,
- die_on_error=False)
- if any(active_nodes):
+ existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name,
+ die_on_error=False)
+ if existing_slaves or (existing_masters and not opts.use_existing_master):
print >> stderr, ("ERROR: There are already instances running in " +
"group %s or %s" % (master_group.name, slave_group.name))
sys.exit(1)
@@ -335,21 +338,28 @@ def launch_cluster(conn, opts, cluster_name):
zone, slave_res.id)
i += 1
- # Launch masters
- master_type = opts.master_instance_type
- if master_type == "":
- master_type = opts.instance_type
- if opts.zone == 'all':
- opts.zone = random.choice(conn.get_all_zones()).name
- master_res = image.run(key_name = opts.key_pair,
- security_groups = [master_group],
- instance_type = master_type,
- placement = opts.zone,
- min_count = 1,
- max_count = 1,
- block_device_map = block_map)
- master_nodes = master_res.instances
- print "Launched master in %s, regid = %s" % (zone, master_res.id)
+ # Launch or resume masters
+ if existing_masters:
+ print "Starting master..."
+ for inst in existing_masters:
+ if inst.state not in ["shutting-down", "terminated"]:
+ inst.start()
+ master_nodes = existing_masters
+ else:
+ master_type = opts.master_instance_type
+ if master_type == "":
+ master_type = opts.instance_type
+ if opts.zone == 'all':
+ opts.zone = random.choice(conn.get_all_zones()).name
+ master_res = image.run(key_name = opts.key_pair,
+ security_groups = [master_group],
+ instance_type = master_type,
+ placement = opts.zone,
+ min_count = 1,
+ max_count = 1,
+ block_device_map = block_map)
+ master_nodes = master_res.instances
+ print "Launched master in %s, regid = %s" % (zone, master_res.id)
# Return all the instances
return (master_nodes, slave_nodes)
@@ -403,8 +413,8 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
print slave.public_dns_name
ssh_write(slave.public_dns_name, opts, ['tar', 'x'], dot_ssh_tar)
- modules = ['spark', 'shark', 'ephemeral-hdfs', 'persistent-hdfs',
- 'mapreduce', 'spark-standalone']
+ modules = ['spark', 'shark', 'ephemeral-hdfs', 'persistent-hdfs',
+ 'mapreduce', 'spark-standalone', 'tachyon']
if opts.hadoop_major_version == "1":
modules = filter(lambda x: x != "mapreduce", modules)
@@ -668,12 +678,12 @@ def real_main():
print "Terminating slaves..."
for inst in slave_nodes:
inst.terminate()
-
+
# Delete security groups as well
if opts.delete_groups:
print "Deleting security groups (this will take some time)..."
group_names = [cluster_name + "-master", cluster_name + "-slaves"]
-
+
attempt = 1;
while attempt <= 3:
print "Attempt %d" % attempt
@@ -731,6 +741,7 @@ def real_main():
cluster_name + "?\nDATA ON EPHEMERAL DISKS WILL BE LOST, " +
"BUT THE CLUSTER WILL KEEP USING SPACE ON\n" +
"AMAZON EBS IF IT IS EBS-BACKED!!\n" +
+ "All data on spot-instance slaves will be lost.\n" +
"Stop cluster " + cluster_name + " (y/N): ")
if response == "y":
(master_nodes, slave_nodes) = get_existing_cluster(
@@ -742,7 +753,10 @@ def real_main():
print "Stopping slaves..."
for inst in slave_nodes:
if inst.state not in ["shutting-down", "terminated"]:
- inst.stop()
+ if inst.spot_instance_request_id:
+ inst.terminate()
+ else:
+ inst.stop()
elif action == "start":
(master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
diff --git a/pom.xml b/pom.xml
index 53ac82efd0..edcc3b35cd 100644
--- a/pom.xml
+++ b/pom.xml
@@ -385,6 +385,12 @@
<version>3.1</version>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-all</artifactId>
+ <version>1.8.5</version>
+ <scope>test</scope>
+ </dependency>
<dependency>
<groupId>org.scalacheck</groupId>
<artifactId>scalacheck_2.9.3</artifactId>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 45fd30a7c8..96232718f8 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -170,7 +170,8 @@ object SparkBuild extends Build {
"org.scalatest" %% "scalatest" % "1.9.1" % "test",
"org.scalacheck" %% "scalacheck" % "1.10.0" % "test",
"com.novocode" % "junit-interface" % "0.9" % "test",
- "org.easymock" % "easymock" % "3.1" % "test"
+ "org.easymock" % "easymock" % "3.1" % "test",
+ "org.mockito" % "mockito-all" % "1.8.5" % "test"
),
/* Workaround for issue #206 (fixed after SBT 0.11.0) */
watchTransitiveSources <<= Defaults.inDependencies[Task[Seq[File]]](watchSources.task,
@@ -265,7 +266,7 @@ object SparkBuild extends Build {
def toolsSettings = sharedSettings ++ Seq(
name := "spark-tools"
- )
+ ) ++ assemblySettings ++ extraAssemblySettings
def bagelSettings = sharedSettings ++ Seq(
name := "spark-bagel"
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
index e6e35c9b5d..870e12de34 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
@@ -878,14 +878,21 @@ class SparkIMain(val settings: Settings, protected val out: PrintWriter) extends
(message, false)
}
}
+
+ // Get a copy of the local properties from SparkContext, and set it later in the thread
+ // that triggers the execution. This is to make sure the caller of this function can pass
+ // the right thread local (inheritable) properties down into Spark.
+ val sc = org.apache.spark.repl.Main.interp.sparkContext
+ val props = if (sc != null) sc.getLocalProperties() else null
try {
val execution = lineManager.set(originalLine) {
// MATEI: set the right SparkEnv for our SparkContext, because
// this execution will happen in a separate thread
- val sc = org.apache.spark.repl.Main.interp.sparkContext
- if (sc != null && sc.env != null)
+ if (sc != null && sc.env != null) {
SparkEnv.set(sc.env)
+ sc.setLocalProperties(props)
+ }
// Execute the line
lineRep call "$export"
}
diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 8f9b632c0e..6e4504d4d5 100644
--- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -21,12 +21,14 @@ import java.io._
import java.net.URLClassLoader
import scala.collection.mutable.ArrayBuffer
-import scala.collection.JavaConversions._
-import org.scalatest.FunSuite
import com.google.common.io.Files
+import org.scalatest.FunSuite
+import org.apache.spark.SparkContext
+
class ReplSuite extends FunSuite {
+
def runInterpreter(master: String, input: String): String = {
val in = new BufferedReader(new StringReader(input + "\n"))
val out = new StringWriter()
@@ -64,6 +66,35 @@ class ReplSuite extends FunSuite {
"Interpreter output contained '" + message + "':\n" + output)
}
+ test("propagation of local properties") {
+ // A mock ILoop that doesn't install the SIGINT handler.
+ class ILoop(out: PrintWriter) extends SparkILoop(None, out, None) {
+ settings = new scala.tools.nsc.Settings
+ settings.usejavacp.value = true
+ org.apache.spark.repl.Main.interp = this
+ override def createInterpreter() {
+ intp = new SparkILoopInterpreter
+ intp.setContextClassLoader()
+ }
+ }
+
+ val out = new StringWriter()
+ val interp = new ILoop(new PrintWriter(out))
+ interp.sparkContext = new SparkContext("local", "repl-test")
+ interp.createInterpreter()
+ interp.intp.initialize()
+ interp.sparkContext.setLocalProperty("someKey", "someValue")
+
+ // Make sure the value we set in the caller to interpret is propagated in the thread that
+ // interprets the command.
+ interp.interpret("org.apache.spark.repl.Main.interp.sparkContext.getLocalProperty(\"someKey\")")
+ assert(out.toString.contains("someValue"))
+
+ interp.sparkContext.stop()
+ System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
+ }
+
test ("simple foreach with accumulator") {
val output = runInterpreter("local", """
val accum = sc.accumulator(0)
diff --git a/spark-class b/spark-class
index fb9d1a4f8e..bbeca7f245 100755
--- a/spark-class
+++ b/spark-class
@@ -110,8 +110,21 @@ if [ ! -f "$FWDIR/RELEASE" ]; then
fi
fi
+TOOLS_DIR="$FWDIR"/tools
+SPARK_TOOLS_JAR=""
+if [ -e "$TOOLS_DIR"/target/scala-$SCALA_VERSION/*assembly*[0-9Tg].jar ]; then
+ # Use the JAR from the SBT build
+ export SPARK_TOOLS_JAR=`ls "$TOOLS_DIR"/target/scala-$SCALA_VERSION/*assembly*[0-9Tg].jar`
+fi
+if [ -e "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar ]; then
+ # Use the JAR from the Maven build
+ # TODO: this also needs to become an assembly!
+ export SPARK_TOOLS_JAR=`ls "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar`
+fi
+
# Compute classpath using external script
CLASSPATH=`$FWDIR/bin/compute-classpath.sh`
+CLASSPATH="$SPARK_TOOLS_JAR:$CLASSPATH"
export CLASSPATH
if [ "$SPARK_PRINT_LAUNCH_COMMAND" == "1" ]; then
diff --git a/spark-class2.cmd b/spark-class2.cmd
index d4d853e8ad..3869d0761b 100644
--- a/spark-class2.cmd
+++ b/spark-class2.cmd
@@ -65,10 +65,17 @@ if "%FOUND_JAR%"=="0" (
)
:skip_build_test
+set TOOLS_DIR=%FWDIR%tools
+set SPARK_TOOLS_JAR=
+for %%d in ("%TOOLS_DIR%\target\scala-%SCALA_VERSION%\spark-tools*assembly*.jar") do (
+ set SPARK_TOOLS_JAR=%%d
+)
+
rem Compute classpath using external script
set DONT_PRINT_CLASSPATH=1
call "%FWDIR%bin\compute-classpath.cmd"
set DONT_PRINT_CLASSPATH=0
+set CLASSPATH=%SPARK_TOOLS_JAR%;%CLASSPATH%
rem Figure out where java is.
set RUNNER=java
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
index 8d3ac0fc65..a82862c802 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
@@ -232,11 +232,11 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
logInfo("Data handler stopped")
}
- def += (obj: T) {
+ def += (obj: T): Unit = synchronized {
currentBuffer += obj
}
- private def updateCurrentBuffer(time: Long) {
+ private def updateCurrentBuffer(time: Long): Unit = synchronized {
try {
val newBlockBuffer = currentBuffer
currentBuffer = new ArrayBuffer[T]
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index c29b75ece6..a559db468a 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -23,15 +23,15 @@ import akka.actor.IOManager
import akka.actor.Props
import akka.util.ByteString
-import dstream.SparkFlumeEvent
+import org.apache.spark.streaming.dstream.{NetworkReceiver, SparkFlumeEvent}
import java.net.{InetSocketAddress, SocketException, Socket, ServerSocket}
import java.io.{File, BufferedWriter, OutputStreamWriter}
-import java.util.concurrent.{TimeUnit, ArrayBlockingQueue}
+import java.util.concurrent.{Executors, TimeUnit, ArrayBlockingQueue}
import collection.mutable.{SynchronizedBuffer, ArrayBuffer}
import util.ManualClock
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.receivers.Receiver
-import org.apache.spark.Logging
+import org.apache.spark.{SparkContext, Logging}
import scala.util.Random
import org.apache.commons.io.FileUtils
import org.scalatest.BeforeAndAfter
@@ -44,6 +44,7 @@ import java.nio.ByteBuffer
import collection.JavaConversions._
import java.nio.charset.Charset
import com.google.common.io.Files
+import java.util.concurrent.atomic.AtomicInteger
class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
@@ -61,7 +62,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
System.clearProperty("spark.hostPort")
}
-
test("socket input stream") {
// Start the server
val testServer = new TestServer()
@@ -275,10 +275,49 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
kafka.serializer.StringDecoder,
kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK)
}
+
+ test("multi-thread receiver") {
+ // set up the test receiver
+ val numThreads = 10
+ val numRecordsPerThread = 1000
+ val numTotalRecords = numThreads * numRecordsPerThread
+ val testReceiver = new MultiThreadTestReceiver(numThreads, numRecordsPerThread)
+ MultiThreadTestReceiver.haveAllThreadsFinished = false
+
+ // set up the network stream using the test receiver
+ val ssc = new StreamingContext(master, framework, batchDuration)
+ val networkStream = ssc.networkStream[Int](testReceiver)
+ val countStream = networkStream.count
+ val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]]
+ val outputStream = new TestOutputStream(countStream, outputBuffer)
+ def output = outputBuffer.flatMap(x => x)
+ ssc.registerOutputStream(outputStream)
+ ssc.start()
+
+ // Let the data from the receiver be received
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ val startTime = System.currentTimeMillis()
+ while((!MultiThreadTestReceiver.haveAllThreadsFinished || output.sum < numTotalRecords) &&
+ System.currentTimeMillis() - startTime < 5000) {
+ Thread.sleep(100)
+ clock.addToTime(batchDuration.milliseconds)
+ }
+ Thread.sleep(1000)
+ logInfo("Stopping context")
+ ssc.stop()
+
+ // Verify whether data received was as expected
+ logInfo("--------------------------------")
+ logInfo("output.size = " + outputBuffer.size)
+ logInfo("output")
+ outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+ logInfo("--------------------------------")
+ assert(output.sum === numTotalRecords)
+ }
}
-/** This is server to test the network input stream */
+/** This is a server to test the network input stream */
class TestServer() extends Logging {
val queue = new ArrayBlockingQueue[String](100)
@@ -340,6 +379,7 @@ object TestServer {
}
}
+/** This is an actor for testing actor input stream */
class TestActor(port: Int) extends Actor with Receiver {
def bytesToString(byteString: ByteString) = byteString.utf8String
@@ -351,3 +391,36 @@ class TestActor(port: Int) extends Actor with Receiver {
pushBlock(bytesToString(bytes))
}
}
+
+/** This is a receiver to test multiple threads inserting data using block generator */
+class MultiThreadTestReceiver(numThreads: Int, numRecordsPerThread: Int)
+ extends NetworkReceiver[Int] {
+ lazy val executorPool = Executors.newFixedThreadPool(numThreads)
+ lazy val blockGenerator = new BlockGenerator(StorageLevel.MEMORY_ONLY)
+ lazy val finishCount = new AtomicInteger(0)
+
+ protected def onStart() {
+ blockGenerator.start()
+ (1 to numThreads).map(threadId => {
+ val runnable = new Runnable {
+ def run() {
+ (1 to numRecordsPerThread).foreach(i =>
+ blockGenerator += (threadId * numRecordsPerThread + i) )
+ if (finishCount.incrementAndGet == numThreads) {
+ MultiThreadTestReceiver.haveAllThreadsFinished = true
+ }
+ logInfo("Finished thread " + threadId)
+ }
+ }
+ executorPool.submit(runnable)
+ })
+ }
+
+ protected def onStop() {
+ executorPool.shutdown()
+ }
+}
+
+object MultiThreadTestReceiver {
+ var haveAllThreadsFinished = false
+}
diff --git a/yarn/pom.xml b/yarn/pom.xml
index 3bc619df07..8a065c6d7d 100644
--- a/yarn/pom.xml
+++ b/yarn/pom.xml
@@ -61,6 +61,16 @@
<groupId>org.apache.avro</groupId>
<artifactId>avro-ipc</artifactId>
</dependency>
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_2.9.3</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-all</artifactId>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
@@ -106,6 +116,46 @@
</execution>
</executions>
</plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-antrun-plugin</artifactId>
+ <executions>
+ <execution>
+ <phase>test</phase>
+ <goals>
+ <goal>run</goal>
+ </goals>
+ <configuration>
+ <exportAntProperties>true</exportAntProperties>
+ <tasks>
+ <property name="spark.classpath" refid="maven.test.classpath" />
+ <property environment="env" />
+ <fail message="Please set the SCALA_HOME (or SCALA_LIBRARY_PATH if scala is on the path) environment variables and retry.">
+ <condition>
+ <not>
+ <or>
+ <isset property="env.SCALA_HOME" />
+ <isset property="env.SCALA_LIBRARY_PATH" />
+ </or>
+ </not>
+ </condition>
+ </fail>
+ </tasks>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest-maven-plugin</artifactId>
+ <configuration>
+ <environmentVariables>
+ <SPARK_HOME>${basedir}/..</SPARK_HOME>
+ <SPARK_TESTING>1</SPARK_TESTING>
+ <SPARK_CLASSPATH>${spark.classpath}</SPARK_CLASSPATH>
+ </environmentVariables>
+ </configuration>
+ </plugin>
</plugins>
</build>
</project>
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index c1a87d3373..4302ef4cda 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -349,7 +349,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
try {
val preserveFiles = System.getProperty("spark.yarn.preserve.staging.files", "false").toBoolean
if (!preserveFiles) {
- stagingDirPath = new Path(System.getenv("SPARK_YARN_JAR_PATH")).getParent()
+ stagingDirPath = new Path(System.getenv("SPARK_YARN_STAGING_DIR"))
if (stagingDirPath == null) {
logError("Staging directory is null")
return
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 1a380ae714..4e0e060ddc 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -17,26 +17,31 @@
package org.apache.spark.deploy.yarn
-import java.net.{InetSocketAddress, URI}
+import java.net.{InetAddress, InetSocketAddress, UnknownHostException, URI}
import java.nio.ByteBuffer
+
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
+import org.apache.hadoop.fs.{FileContext, FileStatus, FileSystem, Path, FileUtil}
+import org.apache.hadoop.fs.permission.FsPermission;
import org.apache.hadoop.mapred.Master
import org.apache.hadoop.net.NetUtils
import org.apache.hadoop.io.DataOutputBuffer
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.client.YarnClientImpl
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
+import org.apache.hadoop.yarn.util.{Apps, Records}
+
import scala.collection.mutable.HashMap
+import scala.collection.mutable.Map
import scala.collection.JavaConversions._
+
import org.apache.spark.Logging
import org.apache.spark.util.Utils
-import org.apache.hadoop.yarn.util.{Apps, Records, ConverterUtils}
-import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import org.apache.spark.deploy.SparkHadoopUtil
class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl with Logging {
@@ -46,13 +51,14 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
var rpc: YarnRPC = YarnRPC.create(conf)
val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
val credentials = UserGroupInformation.getCurrentUser().getCredentials()
- private var distFiles = None: Option[String]
- private var distFilesTimeStamps = None: Option[String]
- private var distFilesFileSizes = None: Option[String]
- private var distArchives = None: Option[String]
- private var distArchivesTimeStamps = None: Option[String]
- private var distArchivesFileSizes = None: Option[String]
-
+ private val SPARK_STAGING: String = ".sparkStaging"
+ private val distCacheMgr = new ClientDistributedCacheManager()
+
+ // staging directory is private! -> rwx--------
+ val STAGING_DIR_PERMISSION: FsPermission = FsPermission.createImmutable(0700:Short)
+ // app files are world-wide readable and owner writable -> rw-r--r--
+ val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644:Short)
+
def run() {
init(yarnConf)
start()
@@ -63,8 +69,9 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
verifyClusterResources(newApp)
val appContext = createApplicationSubmissionContext(appId)
- val localResources = prepareLocalResources(appId, ".sparkStaging")
- val env = setupLaunchEnv(localResources)
+ val appStagingDir = getAppStagingDir(appId)
+ val localResources = prepareLocalResources(appStagingDir)
+ val env = setupLaunchEnv(localResources, appStagingDir)
val amContainer = createContainerLaunchContext(newApp, localResources, env)
appContext.setQueue(args.amQueue)
@@ -76,7 +83,10 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
monitorApplication(appId)
System.exit(0)
}
-
+
+ def getAppStagingDir(appId: ApplicationId): String = {
+ SPARK_STAGING + Path.SEPARATOR + appId.toString() + Path.SEPARATOR
+ }
def logClusterResourceDetails() {
val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics
@@ -116,73 +126,73 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
return appContext
}
+ /*
+ * see if two file systems are the same or not.
+ */
+ private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = {
+ val srcUri = srcFs.getUri()
+ val dstUri = destFs.getUri()
+ if (srcUri.getScheme() == null) {
+ return false
+ }
+ if (!srcUri.getScheme().equals(dstUri.getScheme())) {
+ return false
+ }
+ var srcHost = srcUri.getHost()
+ var dstHost = dstUri.getHost()
+ if ((srcHost != null) && (dstHost != null)) {
+ try {
+ srcHost = InetAddress.getByName(srcHost).getCanonicalHostName();
+ dstHost = InetAddress.getByName(dstHost).getCanonicalHostName();
+ } catch {
+ case e: UnknownHostException =>
+ return false
+ }
+ if (!srcHost.equals(dstHost)) {
+ return false
+ }
+ } else if (srcHost == null && dstHost != null) {
+ return false
+ } else if (srcHost != null && dstHost == null) {
+ return false
+ }
+ //check for ports
+ if (srcUri.getPort() != dstUri.getPort()) {
+ return false
+ }
+ return true;
+ }
+
/**
- * Copy the local file into HDFS and configure to be distributed with the
- * job via the distributed cache.
- * If a fragment is specified the file will be referenced as that fragment.
+ * Copy the file into HDFS if needed.
*/
- private def copyLocalFile(
+ private def copyRemoteFile(
dstDir: Path,
- resourceType: LocalResourceType,
originalPath: Path,
replication: Short,
- localResources: HashMap[String,LocalResource],
- fragment: String,
- appMasterOnly: Boolean = false): Unit = {
+ setPerms: Boolean = false): Path = {
val fs = FileSystem.get(conf)
- val newPath = new Path(dstDir, originalPath.getName())
- logInfo("Uploading " + originalPath + " to " + newPath)
- fs.copyFromLocalFile(false, true, originalPath, newPath)
- fs.setReplication(newPath, replication);
- val destStatus = fs.getFileStatus(newPath)
-
- val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
- amJarRsrc.setType(resourceType)
- amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION)
- amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(newPath))
- amJarRsrc.setTimestamp(destStatus.getModificationTime())
- amJarRsrc.setSize(destStatus.getLen())
- var pathURI: URI = new URI(newPath.toString() + "#" + originalPath.getName());
- if ((fragment == null) || (fragment.isEmpty())){
- localResources(originalPath.getName()) = amJarRsrc
- } else {
- localResources(fragment) = amJarRsrc
- pathURI = new URI(newPath.toString() + "#" + fragment);
- }
- val distPath = pathURI.toString()
- if (appMasterOnly == true) return
- if (resourceType == LocalResourceType.FILE) {
- distFiles match {
- case Some(path) =>
- distFilesFileSizes = Some(distFilesFileSizes.get + "," +
- destStatus.getLen().toString())
- distFilesTimeStamps = Some(distFilesTimeStamps.get + "," +
- destStatus.getModificationTime().toString())
- distFiles = Some(path + "," + distPath)
- case _ =>
- distFilesFileSizes = Some(destStatus.getLen().toString())
- distFilesTimeStamps = Some(destStatus.getModificationTime().toString())
- distFiles = Some(distPath)
- }
- } else {
- distArchives match {
- case Some(path) =>
- distArchivesTimeStamps = Some(distArchivesTimeStamps.get + "," +
- destStatus.getModificationTime().toString())
- distArchivesFileSizes = Some(distArchivesFileSizes.get + "," +
- destStatus.getLen().toString())
- distArchives = Some(path + "," + distPath)
- case _ =>
- distArchivesTimeStamps = Some(destStatus.getModificationTime().toString())
- distArchivesFileSizes = Some(destStatus.getLen().toString())
- distArchives = Some(distPath)
- }
- }
+ val remoteFs = originalPath.getFileSystem(conf);
+ var newPath = originalPath
+ if (! compareFs(remoteFs, fs)) {
+ newPath = new Path(dstDir, originalPath.getName())
+ logInfo("Uploading " + originalPath + " to " + newPath)
+ FileUtil.copy(remoteFs, originalPath, fs, newPath, false, conf);
+ fs.setReplication(newPath, replication);
+ if (setPerms) fs.setPermission(newPath, new FsPermission(APP_FILE_PERMISSION))
+ }
+ // resolve any symlinks in the URI path so using a "current" symlink
+ // to point to a specific version shows the specific version
+ // in the distributed cache configuration
+ val qualPath = fs.makeQualified(newPath)
+ val fc = FileContext.getFileContext(qualPath.toUri(), conf)
+ val destPath = fc.resolvePath(qualPath)
+ destPath
}
- def prepareLocalResources(appId: ApplicationId, sparkStagingDir: String): HashMap[String, LocalResource] = {
+ def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = {
logInfo("Preparing Local resources")
- // Upload Spark and the application JAR to the remote file system
+ // Upload Spark and the application JAR to the remote file system if necessary
// Add them as local resources to the AM
val fs = FileSystem.get(conf)
@@ -193,9 +203,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
System.exit(1)
}
}
-
- val pathSuffix = sparkStagingDir + "/" + appId.toString() + "/"
- val dst = new Path(fs.getHomeDirectory(), pathSuffix)
+ val dst = new Path(fs.getHomeDirectory(), appStagingDir)
val replication = System.getProperty("spark.yarn.submit.file.replication", "3").toShort
if (UserGroupInformation.isSecurityEnabled()) {
@@ -203,55 +211,65 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
dstFs.addDelegationTokens(delegTokenRenewer, credentials);
}
val localResources = HashMap[String, LocalResource]()
+ FileSystem.mkdirs(fs, dst, new FsPermission(STAGING_DIR_PERMISSION))
+
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+
+ if (System.getenv("SPARK_JAR") == null || args.userJar == null) {
+ logError("Error: You must set SPARK_JAR environment variable and specify a user jar!")
+ System.exit(1)
+ }
- Map("spark.jar" -> System.getenv("SPARK_JAR"), "app.jar" -> args.userJar, "log4j.properties" -> System.getenv("SPARK_LOG4J_CONF"))
+ Map(Client.SPARK_JAR -> System.getenv("SPARK_JAR"), Client.APP_JAR -> args.userJar,
+ Client.LOG4J_PROP -> System.getenv("SPARK_LOG4J_CONF"))
.foreach { case(destName, _localPath) =>
val localPath: String = if (_localPath != null) _localPath.trim() else ""
if (! localPath.isEmpty()) {
- val src = new Path(localPath)
- val newPath = new Path(dst, destName)
- logInfo("Uploading " + src + " to " + newPath)
- fs.copyFromLocalFile(false, true, src, newPath)
- fs.setReplication(newPath, replication);
- val destStatus = fs.getFileStatus(newPath)
-
- val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
- amJarRsrc.setType(LocalResourceType.FILE)
- amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION)
- amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(newPath))
- amJarRsrc.setTimestamp(destStatus.getModificationTime())
- amJarRsrc.setSize(destStatus.getLen())
- localResources(destName) = amJarRsrc
+ var localURI = new URI(localPath)
+ // if not specified assume these are in the local filesystem to keep behavior like Hadoop
+ if (localURI.getScheme() == null) {
+ localURI = new URI(FileSystem.getLocal(conf).makeQualified(new Path(localPath)).toString())
+ }
+ val setPermissions = if (destName.equals(Client.APP_JAR)) true else false
+ val destPath = copyRemoteFile(dst, new Path(localURI), replication, setPermissions)
+ distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE,
+ destName, statCache)
}
}
// handle any add jars
if ((args.addJars != null) && (!args.addJars.isEmpty())){
args.addJars.split(',').foreach { case file: String =>
- val tmpURI = new URI(file)
- val tmp = new Path(tmpURI)
- copyLocalFile(dst, LocalResourceType.FILE, tmp, replication, localResources,
- tmpURI.getFragment(), true)
+ val localURI = new URI(file.trim())
+ val localPath = new Path(localURI)
+ val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName())
+ val destPath = copyRemoteFile(dst, localPath, replication)
+ distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE,
+ linkname, statCache, true)
}
}
// handle any distributed cache files
if ((args.files != null) && (!args.files.isEmpty())){
args.files.split(',').foreach { case file: String =>
- val tmpURI = new URI(file)
- val tmp = new Path(tmpURI)
- copyLocalFile(dst, LocalResourceType.FILE, tmp, replication, localResources,
- tmpURI.getFragment())
+ val localURI = new URI(file.trim())
+ val localPath = new Path(localURI)
+ val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName())
+ val destPath = copyRemoteFile(dst, localPath, replication)
+ distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE,
+ linkname, statCache)
}
}
// handle any distributed cache archives
if ((args.archives != null) && (!args.archives.isEmpty())) {
args.archives.split(',').foreach { case file:String =>
- val tmpURI = new URI(file)
- val tmp = new Path(tmpURI)
- copyLocalFile(dst, LocalResourceType.ARCHIVE, tmp, replication,
- localResources, tmpURI.getFragment())
+ val localURI = new URI(file.trim())
+ val localPath = new Path(localURI)
+ val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName())
+ val destPath = copyRemoteFile(dst, localPath, replication)
+ distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE,
+ linkname, statCache)
}
}
@@ -259,44 +277,21 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
return localResources
}
- def setupLaunchEnv(localResources: HashMap[String, LocalResource]): HashMap[String, String] = {
+ def setupLaunchEnv(
+ localResources: HashMap[String, LocalResource],
+ stagingDir: String): HashMap[String, String] = {
logInfo("Setting up the launch environment")
- val log4jConfLocalRes = localResources.getOrElse("log4j.properties", null)
+ val log4jConfLocalRes = localResources.getOrElse(Client.LOG4J_PROP, null)
val env = new HashMap[String, String]()
Client.populateClasspath(yarnConf, log4jConfLocalRes != null, env)
env("SPARK_YARN_MODE") = "true"
- env("SPARK_YARN_JAR_PATH") =
- localResources("spark.jar").getResource().getScheme.toString() + "://" +
- localResources("spark.jar").getResource().getFile().toString()
- env("SPARK_YARN_JAR_TIMESTAMP") = localResources("spark.jar").getTimestamp().toString()
- env("SPARK_YARN_JAR_SIZE") = localResources("spark.jar").getSize().toString()
-
- env("SPARK_YARN_USERJAR_PATH") =
- localResources("app.jar").getResource().getScheme.toString() + "://" +
- localResources("app.jar").getResource().getFile().toString()
- env("SPARK_YARN_USERJAR_TIMESTAMP") = localResources("app.jar").getTimestamp().toString()
- env("SPARK_YARN_USERJAR_SIZE") = localResources("app.jar").getSize().toString()
-
- if (log4jConfLocalRes != null) {
- env("SPARK_YARN_LOG4J_PATH") =
- log4jConfLocalRes.getResource().getScheme.toString() + "://" + log4jConfLocalRes.getResource().getFile().toString()
- env("SPARK_YARN_LOG4J_TIMESTAMP") = log4jConfLocalRes.getTimestamp().toString()
- env("SPARK_YARN_LOG4J_SIZE") = log4jConfLocalRes.getSize().toString()
- }
+ env("SPARK_YARN_STAGING_DIR") = stagingDir
// set the environment variables to be passed on to the Workers
- if (distFiles != None) {
- env("SPARK_YARN_CACHE_FILES") = distFiles.get
- env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = distFilesTimeStamps.get
- env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = distFilesFileSizes.get
- }
- if (distArchives != None) {
- env("SPARK_YARN_CACHE_ARCHIVES") = distArchives.get
- env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = distArchivesTimeStamps.get
- env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") = distArchivesFileSizes.get
- }
+ distCacheMgr.setDistFilesEnv(env)
+ distCacheMgr.setDistArchivesEnv(env)
// allow users to specify some environment variables
Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV"))
@@ -365,6 +360,11 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
javaCommand = Environment.JAVA_HOME.$() + "/bin/java"
}
+ if (args.userClass == null) {
+ logError("Error: You must specify a user class!")
+ System.exit(1)
+ }
+
val commands = List[String](javaCommand +
" -server " +
JAVA_OPTS +
@@ -432,6 +432,10 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
}
object Client {
+ val SPARK_JAR: String = "spark.jar"
+ val APP_JAR: String = "app.jar"
+ val LOG4J_PROP: String = "log4j.properties"
+
def main(argStrings: Array[String]) {
// Set an env variable indicating we are running in YARN mode.
// Note that anything with SPARK prefix gets propagated to all (remote) processes
@@ -453,22 +457,22 @@ object Client {
// If log4j present, ensure ours overrides all others
if (addLog4j) {
Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
- Path.SEPARATOR + "log4j.properties")
+ Path.SEPARATOR + LOG4J_PROP)
}
// normally the users app.jar is last in case conflicts with spark jars
val userClasspathFirst = System.getProperty("spark.yarn.user.classpath.first", "false")
.toBoolean
if (userClasspathFirst) {
Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
- Path.SEPARATOR + "app.jar")
+ Path.SEPARATOR + APP_JAR)
}
Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
- Path.SEPARATOR + "spark.jar")
+ Path.SEPARATOR + SPARK_JAR)
Client.populateHadoopClasspath(conf, env)
if (!userClasspathFirst) {
Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
- Path.SEPARATOR + "app.jar")
+ Path.SEPARATOR + APP_JAR)
}
Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
Path.SEPARATOR + "*")
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
new file mode 100644
index 0000000000..07686fefd7
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.URI;
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.FileStatus
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.permission.FsAction
+import org.apache.hadoop.yarn.api.records.LocalResource
+import org.apache.hadoop.yarn.api.records.LocalResourceVisibility
+import org.apache.hadoop.yarn.api.records.LocalResourceType
+import org.apache.hadoop.yarn.util.{Records, ConverterUtils}
+
+import org.apache.spark.Logging
+
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.LinkedHashMap
+import scala.collection.mutable.Map
+
+
+/** Client side methods to setup the Hadoop distributed cache */
+class ClientDistributedCacheManager() extends Logging {
+ private val distCacheFiles: Map[String, Tuple3[String, String, String]] =
+ LinkedHashMap[String, Tuple3[String, String, String]]()
+ private val distCacheArchives: Map[String, Tuple3[String, String, String]] =
+ LinkedHashMap[String, Tuple3[String, String, String]]()
+
+
+ /**
+ * Add a resource to the list of distributed cache resources. This list can
+ * be sent to the ApplicationMaster and possibly the workers so that it can
+ * be downloaded into the Hadoop distributed cache for use by this application.
+ * Adds the LocalResource to the localResources HashMap passed in and saves
+ * the stats of the resources to they can be sent to the workers and verified.
+ *
+ * @param fs FileSystem
+ * @param conf Configuration
+ * @param destPath path to the resource
+ * @param localResources localResource hashMap to insert the resource into
+ * @param resourceType LocalResourceType
+ * @param link link presented in the distributed cache to the destination
+ * @param statCache cache to store the file/directory stats
+ * @param appMasterOnly Whether to only add the resource to the app master
+ */
+ def addResource(
+ fs: FileSystem,
+ conf: Configuration,
+ destPath: Path,
+ localResources: HashMap[String, LocalResource],
+ resourceType: LocalResourceType,
+ link: String,
+ statCache: Map[URI, FileStatus],
+ appMasterOnly: Boolean = false) = {
+ val destStatus = fs.getFileStatus(destPath)
+ val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ amJarRsrc.setType(resourceType)
+ val visibility = getVisibility(conf, destPath.toUri(), statCache)
+ amJarRsrc.setVisibility(visibility)
+ amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(destPath))
+ amJarRsrc.setTimestamp(destStatus.getModificationTime())
+ amJarRsrc.setSize(destStatus.getLen())
+ if (link == null || link.isEmpty()) throw new Exception("You must specify a valid link name")
+ localResources(link) = amJarRsrc
+
+ if (appMasterOnly == false) {
+ val uri = destPath.toUri()
+ val pathURI = new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), null, link)
+ if (resourceType == LocalResourceType.FILE) {
+ distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(),
+ destStatus.getModificationTime().toString(), visibility.name())
+ } else {
+ distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(),
+ destStatus.getModificationTime().toString(), visibility.name())
+ }
+ }
+ }
+
+ /**
+ * Adds the necessary cache file env variables to the env passed in
+ * @param env
+ */
+ def setDistFilesEnv(env: Map[String, String]) = {
+ val (keys, tupleValues) = distCacheFiles.unzip
+ val (sizes, timeStamps, visibilities) = tupleValues.unzip3
+
+ if (keys.size > 0) {
+ env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") =
+ timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_FILES_FILE_SIZES") =
+ sizes.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_FILES_VISIBILITIES") =
+ visibilities.reduceLeft[String] { (acc,n) => acc + "," + n }
+ }
+ }
+
+ /**
+ * Adds the necessary cache archive env variables to the env passed in
+ * @param env
+ */
+ def setDistArchivesEnv(env: Map[String, String]) = {
+ val (keys, tupleValues) = distCacheArchives.unzip
+ val (sizes, timeStamps, visibilities) = tupleValues.unzip3
+
+ if (keys.size > 0) {
+ env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") =
+ timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") =
+ sizes.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") =
+ visibilities.reduceLeft[String] { (acc,n) => acc + "," + n }
+ }
+ }
+
+ /**
+ * Returns the local resource visibility depending on the cache file permissions
+ * @param conf
+ * @param uri
+ * @param statCache
+ * @return LocalResourceVisibility
+ */
+ def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]):
+ LocalResourceVisibility = {
+ if (isPublic(conf, uri, statCache)) {
+ return LocalResourceVisibility.PUBLIC
+ }
+ return LocalResourceVisibility.PRIVATE
+ }
+
+ /**
+ * Returns a boolean to denote whether a cache file is visible to all(public)
+ * or not
+ * @param conf
+ * @param uri
+ * @param statCache
+ * @return true if the path in the uri is visible to all, false otherwise
+ */
+ def isPublic(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): Boolean = {
+ val fs = FileSystem.get(uri, conf)
+ val current = new Path(uri.getPath())
+ //the leaf level file should be readable by others
+ if (!checkPermissionOfOther(fs, current, FsAction.READ, statCache)) {
+ return false
+ }
+ return ancestorsHaveExecutePermissions(fs, current.getParent(), statCache)
+ }
+
+ /**
+ * Returns true if all ancestors of the specified path have the 'execute'
+ * permission set for all users (i.e. that other users can traverse
+ * the directory heirarchy to the given path)
+ * @param fs
+ * @param path
+ * @param statCache
+ * @return true if all ancestors have the 'execute' permission set for all users
+ */
+ def ancestorsHaveExecutePermissions(fs: FileSystem, path: Path,
+ statCache: Map[URI, FileStatus]): Boolean = {
+ var current = path
+ while (current != null) {
+ //the subdirs in the path should have execute permissions for others
+ if (!checkPermissionOfOther(fs, current, FsAction.EXECUTE, statCache)) {
+ return false
+ }
+ current = current.getParent()
+ }
+ return true
+ }
+
+ /**
+ * Checks for a given path whether the Other permissions on it
+ * imply the permission in the passed FsAction
+ * @param fs
+ * @param path
+ * @param action
+ * @param statCache
+ * @return true if the path in the uri is visible to all, false otherwise
+ */
+ def checkPermissionOfOther(fs: FileSystem, path: Path,
+ action: FsAction, statCache: Map[URI, FileStatus]): Boolean = {
+ val status = getFileStatus(fs, path.toUri(), statCache);
+ val perms = status.getPermission()
+ val otherAction = perms.getOtherAction()
+ if (otherAction.implies(action)) {
+ return true;
+ }
+ return false
+ }
+
+ /**
+ * Checks to see if the given uri exists in the cache, if it does it
+ * returns the existing FileStatus, otherwise it stats the uri, stores
+ * it in the cache, and returns the FileStatus.
+ * @param fs
+ * @param uri
+ * @param statCache
+ * @return FileStatus
+ */
+ def getFileStatus(fs: FileSystem, uri: URI, statCache: Map[URI, FileStatus]): FileStatus = {
+ val stat = statCache.get(uri) match {
+ case Some(existstat) => existstat
+ case None =>
+ val newStat = fs.getFileStatus(new Path(uri))
+ statCache.put(uri, newStat)
+ newStat
+ }
+ return stat
+ }
+}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
index ba352daac4..7a66532254 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
@@ -142,11 +142,12 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
rtype: LocalResourceType,
localResources: HashMap[String, LocalResource],
timestamp: String,
- size: String) = {
+ size: String,
+ vis: String) = {
val uri = new URI(file)
val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
amJarRsrc.setType(rtype)
- amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION)
+ amJarRsrc.setVisibility(LocalResourceVisibility.valueOf(vis))
amJarRsrc.setResource(ConverterUtils.getYarnUrlFromURI(uri))
amJarRsrc.setTimestamp(timestamp.toLong)
amJarRsrc.setSize(size.toLong)
@@ -158,44 +159,14 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
logInfo("Preparing Local resources")
val localResources = HashMap[String, LocalResource]()
- // Spark JAR
- val sparkJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
- sparkJarResource.setType(LocalResourceType.FILE)
- sparkJarResource.setVisibility(LocalResourceVisibility.APPLICATION)
- sparkJarResource.setResource(ConverterUtils.getYarnUrlFromURI(
- new URI(System.getenv("SPARK_YARN_JAR_PATH"))))
- sparkJarResource.setTimestamp(System.getenv("SPARK_YARN_JAR_TIMESTAMP").toLong)
- sparkJarResource.setSize(System.getenv("SPARK_YARN_JAR_SIZE").toLong)
- localResources("spark.jar") = sparkJarResource
- // User JAR
- val userJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
- userJarResource.setType(LocalResourceType.FILE)
- userJarResource.setVisibility(LocalResourceVisibility.APPLICATION)
- userJarResource.setResource(ConverterUtils.getYarnUrlFromURI(
- new URI(System.getenv("SPARK_YARN_USERJAR_PATH"))))
- userJarResource.setTimestamp(System.getenv("SPARK_YARN_USERJAR_TIMESTAMP").toLong)
- userJarResource.setSize(System.getenv("SPARK_YARN_USERJAR_SIZE").toLong)
- localResources("app.jar") = userJarResource
-
- // Log4j conf - if available
- if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
- val log4jConfResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
- log4jConfResource.setType(LocalResourceType.FILE)
- log4jConfResource.setVisibility(LocalResourceVisibility.APPLICATION)
- log4jConfResource.setResource(ConverterUtils.getYarnUrlFromURI(
- new URI(System.getenv("SPARK_YARN_LOG4J_PATH"))))
- log4jConfResource.setTimestamp(System.getenv("SPARK_YARN_LOG4J_TIMESTAMP").toLong)
- log4jConfResource.setSize(System.getenv("SPARK_YARN_LOG4J_SIZE").toLong)
- localResources("log4j.properties") = log4jConfResource
- }
-
if (System.getenv("SPARK_YARN_CACHE_FILES") != null) {
val timeStamps = System.getenv("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',')
val fileSizes = System.getenv("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',')
val distFiles = System.getenv("SPARK_YARN_CACHE_FILES").split(',')
+ val visibilities = System.getenv("SPARK_YARN_CACHE_FILES_VISIBILITIES").split(',')
for( i <- 0 to distFiles.length - 1) {
setupDistributedCache(distFiles(i), LocalResourceType.FILE, localResources, timeStamps(i),
- fileSizes(i))
+ fileSizes(i), visibilities(i))
}
}
@@ -203,9 +174,10 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
val timeStamps = System.getenv("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS").split(',')
val fileSizes = System.getenv("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES").split(',')
val distArchives = System.getenv("SPARK_YARN_CACHE_ARCHIVES").split(',')
+ val visibilities = System.getenv("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES").split(',')
for( i <- 0 to distArchives.length - 1) {
setupDistributedCache(distArchives(i), LocalResourceType.ARCHIVE, localResources,
- timeStamps(i), fileSizes(i))
+ timeStamps(i), fileSizes(i), visibilities(i))
}
}
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
new file mode 100644
index 0000000000..c0a2af0c6f
--- /dev/null
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.URI;
+
+import org.scalatest.FunSuite
+import org.scalatest.mock.MockitoSugar
+import org.mockito.Mockito.when
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.FileStatus
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.permission.FsAction
+import org.apache.hadoop.yarn.api.records.LocalResource
+import org.apache.hadoop.yarn.api.records.LocalResourceVisibility
+import org.apache.hadoop.yarn.api.records.LocalResourceType
+import org.apache.hadoop.yarn.util.{Records, ConverterUtils}
+
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.Map
+
+
+class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
+
+ class MockClientDistributedCacheManager extends ClientDistributedCacheManager {
+ override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]):
+ LocalResourceVisibility = {
+ return LocalResourceVisibility.PRIVATE
+ }
+ }
+
+ test("test getFileStatus empty") {
+ val distMgr = new ClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val uri = new URI("/tmp/testing")
+ when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus())
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ val stat = distMgr.getFileStatus(fs, uri, statCache)
+ assert(stat.getPath() === null)
+ }
+
+ test("test getFileStatus cached") {
+ val distMgr = new ClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val uri = new URI("/tmp/testing")
+ val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner",
+ null, new Path("/tmp/testing"))
+ when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus())
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus](uri -> realFileStatus)
+ val stat = distMgr.getFileStatus(fs, uri, statCache)
+ assert(stat.getPath().toString() === "/tmp/testing")
+ }
+
+ test("test addResource") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ when(fs.getFileStatus(destPath)).thenReturn(new FileStatus())
+
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link",
+ statCache, false)
+ val resource = localResources("link")
+ assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath)
+ assert(resource.getTimestamp() === 0)
+ assert(resource.getSize() === 0)
+ assert(resource.getType() === LocalResourceType.FILE)
+
+ val env = new HashMap[String, String]()
+ distMgr.setDistFilesEnv(env)
+ assert(env("SPARK_YARN_CACHE_FILES") === "file:/foo.invalid.com:8080/tmp/testing#link")
+ assert(env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === "0")
+ assert(env("SPARK_YARN_CACHE_FILES_FILE_SIZES") === "0")
+ assert(env("SPARK_YARN_CACHE_FILES_VISIBILITIES") === LocalResourceVisibility.PRIVATE.name())
+
+ distMgr.setDistArchivesEnv(env)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None)
+
+ //add another one and verify both there and order correct
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ null, new Path("/tmp/testing2"))
+ val destPath2 = new Path("file:///foo.invalid.com:8080/tmp/testing2")
+ when(fs.getFileStatus(destPath2)).thenReturn(realFileStatus)
+ distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2",
+ statCache, false)
+ val resource2 = localResources("link2")
+ assert(resource2.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource2.getResource()) === destPath2)
+ assert(resource2.getTimestamp() === 10)
+ assert(resource2.getSize() === 20)
+ assert(resource2.getType() === LocalResourceType.FILE)
+
+ val env2 = new HashMap[String, String]()
+ distMgr.setDistFilesEnv(env2)
+ val timestamps = env2("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',')
+ val files = env2("SPARK_YARN_CACHE_FILES").split(',')
+ val sizes = env2("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',')
+ val visibilities = env2("SPARK_YARN_CACHE_FILES_VISIBILITIES") .split(',')
+ assert(files(0) === "file:/foo.invalid.com:8080/tmp/testing#link")
+ assert(timestamps(0) === "0")
+ assert(sizes(0) === "0")
+ assert(visibilities(0) === LocalResourceVisibility.PRIVATE.name())
+
+ assert(files(1) === "file:/foo.invalid.com:8080/tmp/testing2#link2")
+ assert(timestamps(1) === "10")
+ assert(sizes(1) === "20")
+ assert(visibilities(1) === LocalResourceVisibility.PRIVATE.name())
+ }
+
+ test("test addResource link null") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ when(fs.getFileStatus(destPath)).thenReturn(new FileStatus())
+
+ intercept[Exception] {
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null,
+ statCache, false)
+ }
+ assert(localResources.get("link") === None)
+ assert(localResources.size === 0)
+ }
+
+ test("test addResource appmaster only") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ null, new Path("/tmp/testing"))
+ when(fs.getFileStatus(destPath)).thenReturn(realFileStatus)
+
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link",
+ statCache, true)
+ val resource = localResources("link")
+ assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath)
+ assert(resource.getTimestamp() === 10)
+ assert(resource.getSize() === 20)
+ assert(resource.getType() === LocalResourceType.ARCHIVE)
+
+ val env = new HashMap[String, String]()
+ distMgr.setDistFilesEnv(env)
+ assert(env.get("SPARK_YARN_CACHE_FILES") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_FILE_SIZES") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_VISIBILITIES") === None)
+
+ distMgr.setDistArchivesEnv(env)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None)
+ }
+
+ test("test addResource archive") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ null, new Path("/tmp/testing"))
+ when(fs.getFileStatus(destPath)).thenReturn(realFileStatus)
+
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link",
+ statCache, false)
+ val resource = localResources("link")
+ assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath)
+ assert(resource.getTimestamp() === 10)
+ assert(resource.getSize() === 20)
+ assert(resource.getType() === LocalResourceType.ARCHIVE)
+
+ val env = new HashMap[String, String]()
+
+ distMgr.setDistArchivesEnv(env)
+ assert(env("SPARK_YARN_CACHE_ARCHIVES") === "file:/foo.invalid.com:8080/tmp/testing#link")
+ assert(env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === "10")
+ assert(env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === "20")
+ assert(env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === LocalResourceVisibility.PRIVATE.name())
+
+ distMgr.setDistFilesEnv(env)
+ assert(env.get("SPARK_YARN_CACHE_FILES") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_FILE_SIZES") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_VISIBILITIES") === None)
+ }
+
+
+}