aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/org/apache
diff options
context:
space:
mode:
authorKay Ousterhout <kayousterhout@gmail.com>2013-11-13 14:38:44 -0800
committerKay Ousterhout <kayousterhout@gmail.com>2013-11-13 14:38:44 -0800
commit150615a31e0c2a5112c37cca62dd80dba8a12fab (patch)
tree6ed706d3f5a5bc9f55fd12bba7425df950ed11df /core/src/main/scala/org/apache
parent68e5ad58b7e7e3e1b42852de8d0fdf9e9b9c1a14 (diff)
parent39af914b273e35ff431844951ee8dfadcbc0c400 (diff)
downloadspark-150615a31e0c2a5112c37cca62dd80dba8a12fab.tar.gz
spark-150615a31e0c2a5112c37cca62dd80dba8a12fab.tar.bz2
spark-150615a31e0c2a5112c37cca62dd80dba8a12fab.zip
Merge remote-tracking branch 'upstream/master' into consolidate_schedulers
Conflicts: core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala
Diffstat (limited to 'core/src/main/scala/org/apache')
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala32
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala31
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala23
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala13
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala664
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala23
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala52
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockInfo.scala18
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManager.scala20
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala27
-rw-r--r--core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala49
-rw-r--r--core/src/main/scala/org/apache/spark/storage/DiskStore.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala189
-rw-r--r--core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala31
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala17
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/BitSet.scala103
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala152
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala271
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala127
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala51
29 files changed, 1455 insertions, 504 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index e8ff4da475..10db2fa7e7 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -24,7 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.Map
import scala.collection.generic.Growable
-import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
@@ -144,6 +144,14 @@ class SparkContext(
executorEnvs ++= environment
}
+ // Set SPARK_USER for user who is running SparkContext.
+ val sparkUser = Option {
+ Option(System.getProperty("user.name")).getOrElse(System.getenv("SPARK_USER"))
+ }.getOrElse {
+ SparkContext.SPARK_UNKNOWN_USER
+ }
+ executorEnvs("SPARK_USER") = sparkUser
+
// Create and start the scheduler
private[spark] var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
@@ -255,8 +263,10 @@ class SparkContext(
conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
}
// Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar"
- for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) {
- conf.set(key.substring("spark.hadoop.".length), System.getProperty(key))
+ Utils.getSystemProperties.foreach { case (key, value) =>
+ if (key.startsWith("spark.hadoop.")) {
+ conf.set(key.substring("spark.hadoop.".length), value)
+ }
}
val bufferSize = System.getProperty("spark.buffer.size", "65536")
conf.set("io.file.buffer.size", bufferSize)
@@ -270,6 +280,12 @@ class SparkContext(
override protected def childValue(parent: Properties): Properties = new Properties(parent)
}
+ private[spark] def getLocalProperties(): Properties = localProperties.get()
+
+ private[spark] def setLocalProperties(props: Properties) {
+ localProperties.set(props)
+ }
+
def initLocalProperties() {
localProperties.set(new Properties())
}
@@ -291,7 +307,7 @@ class SparkContext(
/** Set a human readable description of the current job. */
@deprecated("use setJobGroup", "0.8.1")
def setJobDescription(value: String) {
- setJobGroup("", value)
+ setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
}
/**
@@ -589,7 +605,8 @@ class SparkContext(
val uri = new URI(path)
val key = uri.getScheme match {
case null | "file" => env.httpFileServer.addFile(new File(uri.getPath))
- case _ => path
+ case "local" => "file:" + uri.getPath
+ case _ => path
}
addedFiles(key) = System.currentTimeMillis
@@ -791,11 +808,10 @@ class SparkContext(
val cleanedFunc = clean(func)
logInfo("Starting job: " + callSite)
val start = System.nanoTime
- val result = dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
+ dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
resultHandler, localProperties.get)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
- result
}
/**
@@ -977,6 +993,8 @@ object SparkContext {
private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
+ private[spark] val SPARK_UNKNOWN_USER = "<unknown>"
+
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
def zero(initialValue: Double) = 0.0
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 6bc846aa92..fc1537f796 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -17,16 +17,39 @@
package org.apache.spark.deploy
+import java.security.PrivilegedExceptionAction
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.security.UserGroupInformation
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkContext, SparkException}
/**
* Contains util methods to interact with Hadoop from Spark.
*/
private[spark]
class SparkHadoopUtil {
+ val conf = newConfiguration()
+ UserGroupInformation.setConfiguration(conf)
+
+ def runAsUser(user: String)(func: () => Unit) {
+ // if we are already running as the user intended there is no reason to do the doAs. It
+ // will actually break secure HDFS access as it doesn't fill in the credentials. Also if
+ // the user is UNKNOWN then we shouldn't be creating a remote unknown user
+ // (this is actually the path spark on yarn takes) since SPARK_USER is initialized only
+ // in SparkContext.
+ val currentUser = Option(System.getProperty("user.name")).
+ getOrElse(SparkContext.SPARK_UNKNOWN_USER)
+ if (user != SparkContext.SPARK_UNKNOWN_USER && currentUser != user) {
+ val ugi = UserGroupInformation.createRemoteUser(user)
+ ugi.doAs(new PrivilegedExceptionAction[Unit] {
+ def run: Unit = func()
+ })
+ } else {
+ func()
+ }
+ }
/**
* Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop
@@ -42,9 +65,9 @@ class SparkHadoopUtil {
def isYarnMode(): Boolean = { false }
}
-
+
object SparkHadoopUtil {
- private val hadoop = {
+ private val hadoop = {
val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
if (yarnMode) {
try {
@@ -56,7 +79,7 @@ object SparkHadoopUtil {
new SparkHadoopUtil
}
}
-
+
def get: SparkHadoopUtil = {
hadoop
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index 8fabc95665..fff9cb60c7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -104,7 +104,7 @@ private[spark] class ExecutorRunner(
// SPARK-698: do not call the run.cmd script, as process.destroy()
// fails to kill a process tree on Windows
Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++
- command.arguments.map(substituteVariables)
+ (command.arguments ++ Seq(appId)).map(substituteVariables)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 80ff4c59cb..caee6b01ab 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -111,7 +111,7 @@ private[spark] object CoarseGrainedExecutorBackend {
def main(args: Array[String]) {
if (args.length < 4) {
- //the reason we allow the last frameworkId argument is to make it easy to kill rogue executors
+ //the reason we allow the last appid argument is to make it easy to kill rogue executors
System.err.println(
"Usage: CoarseGrainedExecutorBackend <driverUrl> <executorId> <hostname> <cores> " +
"[<appid>]")
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index b773346df3..5c9bb9db1c 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -25,8 +25,9 @@ import java.util.concurrent._
import scala.collection.JavaConversions._
import scala.collection.mutable.HashMap
-import org.apache.spark.scheduler._
import org.apache.spark._
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.scheduler._
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util.Utils
@@ -129,6 +130,8 @@ private[spark] class Executor(
// Maintains the list of running tasks.
private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
+ val sparkUser = Option(System.getenv("SPARK_USER")).getOrElse(SparkContext.SPARK_UNKNOWN_USER)
+
def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
val tr = new TaskRunner(context, taskId, serializedTask)
runningTasks.put(taskId, tr)
@@ -176,7 +179,7 @@ private[spark] class Executor(
}
}
- override def run() {
+ override def run(): Unit = SparkHadoopUtil.get.runAsUser(sparkUser) { () =>
val startTime = System.currentTimeMillis()
SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(replClassLoader)
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 32901a508f..47e958b5e6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -132,6 +132,8 @@ class HadoopRDD[K, V](
override def getPartitions: Array[Partition] = {
val jobConf = getJobConf()
+ // add the credentials here as this can be called before SparkContext initialized
+ SparkHadoopUtil.get.addCredentials(jobConf)
val inputFormat = getInputFormat(jobConf)
if (inputFormat.isInstanceOf[Configurable]) {
inputFormat.asInstanceOf[Configurable].setConf(jobConf)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala
index c7d1295215..c5d7ca0481 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala
@@ -25,6 +25,8 @@ import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
+import akka.util.duration._
+
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
@@ -127,21 +129,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext, isLocal: Boolean = f
backend.start()
if (!isLocal && System.getProperty("spark.speculation", "false").toBoolean) {
- new Thread("TaskScheduler speculation check") {
- setDaemon(true)
-
- override def run() {
- logInfo("Starting speculative execution thread")
- while (true) {
- try {
- Thread.sleep(SPECULATION_INTERVAL)
- } catch {
- case e: InterruptedException => {}
- }
- checkSpeculatableTasks()
- }
- }
- }.start()
+ logInfo("Starting speculative execution thread")
+
+ sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds,
+ SPECULATION_INTERVAL milliseconds) {
+ checkSpeculatableTasks()
+ }
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 4cef0825dd..d0b21e896e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -417,15 +417,14 @@ class DAGScheduler(
case ExecutorLost(execId) =>
handleExecutorLost(execId)
- case begin: BeginEvent =>
- listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo))
+ case BeginEvent(task, taskInfo) =>
+ listenerBus.post(SparkListenerTaskStart(task, taskInfo))
- case gettingResult: GettingResultEvent =>
- listenerBus.post(SparkListenerTaskGettingResult(gettingResult.task, gettingResult.taskInfo))
+ case GettingResultEvent(task, taskInfo) =>
+ listenerBus.post(SparkListenerTaskGettingResult(task, taskInfo))
- case completion: CompletionEvent =>
- listenerBus.post(SparkListenerTaskEnd(
- completion.task, completion.reason, completion.taskInfo, completion.taskMetrics))
+ case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) =>
+ listenerBus.post(SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics))
handleTaskCompletion(completion)
case TaskSetFailed(taskSet, reason) =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 12b0d74fb5..60927831a1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -1,280 +1,384 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler
-
-import java.io.PrintWriter
-import java.io.File
-import java.io.FileNotFoundException
-import java.text.SimpleDateFormat
-import java.util.{Date, Properties}
-import java.util.concurrent.LinkedBlockingQueue
-
-import scala.collection.mutable.{HashMap, ListBuffer}
-
-import org.apache.spark._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.executor.TaskMetrics
-
-/**
- * A logger class to record runtime information for jobs in Spark. This class outputs one log file
- * per Spark job with information such as RDD graph, tasks start/stop, shuffle information.
- *
- * @param logDirName The base directory for the log files.
- */
-class JobLogger(val logDirName: String) extends SparkListener with Logging {
-
- private val logDir = Option(System.getenv("SPARK_LOG_DIR")).getOrElse("/tmp/spark")
-
- private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
- private val stageIDToJobID = new HashMap[Int, Int]
- private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
- private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
- private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
-
- createLogDir()
- def this() = this(String.valueOf(System.currentTimeMillis()))
-
- // The following 5 functions are used only in testing.
- private[scheduler] def getLogDir = logDir
- private[scheduler] def getJobIDtoPrintWriter = jobIDToPrintWriter
- private[scheduler] def getStageIDToJobID = stageIDToJobID
- private[scheduler] def getJobIDToStages = jobIDToStages
- private[scheduler] def getEventQueue = eventQueue
-
- // Create a folder for log files, the folder's name is the creation time of the jobLogger
- protected def createLogDir() {
- val dir = new File(logDir + "/" + logDirName + "/")
- if (!dir.exists() && !dir.mkdirs()) {
- logError("Error creating log directory: " + logDir + "/" + logDirName + "/")
- }
- }
-
- // Create a log file for one job, the file name is the jobID
- protected def createLogWriter(jobID: Int) {
- try {
- val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
- jobIDToPrintWriter += (jobID -> fileWriter)
- } catch {
- case e: FileNotFoundException => e.printStackTrace()
- }
- }
-
- // Close log file, and clean the stage relationship in stageIDToJobID
- protected def closeLogWriter(jobID: Int) =
- jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
- fileWriter.close()
- jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
- stageIDToJobID -= stage.id
- })
- jobIDToPrintWriter -= jobID
- jobIDToStages -= jobID
- }
-
- // Write log information to log file, withTime parameter controls whether to recored
- // time stamp for the information
- protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
- var writeInfo = info
- if (withTime) {
- val date = new Date(System.currentTimeMillis())
- writeInfo = DATE_FORMAT.format(date) + ": " +info
- }
- jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
- }
-
- protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) =
- stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
-
- protected def buildJobDep(jobID: Int, stage: Stage) {
- if (stage.jobId == jobID) {
- jobIDToStages.get(jobID) match {
- case Some(stageList) => stageList += stage
- case None => val stageList = new ListBuffer[Stage]
- stageList += stage
- jobIDToStages += (jobID -> stageList)
- }
- stageIDToJobID += (stage.id -> jobID)
- stage.parents.foreach(buildJobDep(jobID, _))
- }
- }
-
- protected def recordStageDep(jobID: Int) {
- def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
- var rddList = new ListBuffer[RDD[_]]
- rddList += rdd
- rdd.dependencies.foreach {
- case shufDep: ShuffleDependency[_, _] =>
- case dep: Dependency[_] => rddList ++= getRddsInStage(dep.rdd)
- }
- rddList
- }
- jobIDToStages.get(jobID).foreach {_.foreach { stage =>
- var depRddDesc: String = ""
- getRddsInStage(stage.rdd).foreach { rdd =>
- depRddDesc += rdd.id + ","
- }
- var depStageDesc: String = ""
- stage.parents.foreach { stage =>
- depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
- }
- jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
- depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
- " STAGE_DEP=" + depStageDesc, false)
- }
- }
- }
-
- // Generate indents and convert to String
- protected def indentString(indent: Int) = {
- val sb = new StringBuilder()
- for (i <- 1 to indent) {
- sb.append(" ")
- }
- sb.toString()
- }
-
- protected def getRddName(rdd: RDD[_]) = {
- var rddName = rdd.getClass.getName
- if (rdd.name != null) {
- rddName = rdd.name
- }
- rddName
- }
-
- protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
- val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
- jobLogInfo(jobID, indentString(indent) + rddInfo, false)
- rdd.dependencies.foreach {
- case shufDep: ShuffleDependency[_, _] =>
- val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
- jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
- case dep: Dependency[_] => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
- }
- }
-
- protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
- val stageInfo = if (stage.isShuffleMap) {
- "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId
- } else {
- "STAGE_ID=" + stage.id + " RESULT_STAGE"
- }
- if (stage.jobId == jobID) {
- jobLogInfo(jobID, indentString(indent) + stageInfo, false)
- recordRddInStageGraph(jobID, stage.rdd, indent)
- stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
- } else {
- jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
- }
- }
-
- // Record task metrics into job log files
- protected def recordTaskMetrics(stageID: Int, status: String,
- taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
- val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
- " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
- " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
- val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
- val readMetrics = taskMetrics.shuffleReadMetrics match {
- case Some(metrics) =>
- " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
- " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
- " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
- " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
- " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
- " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
- " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
- case None => ""
- }
- val writeMetrics = taskMetrics.shuffleWriteMetrics match {
- case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
- case None => ""
- }
- stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
- }
-
- override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
- stageLogInfo(stageSubmitted.stage.stageId,"STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
- stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
- }
-
- override def onStageCompleted(stageCompleted: StageCompleted) {
- stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format(
- stageCompleted.stage.stageId))
- }
-
- override def onTaskStart(taskStart: SparkListenerTaskStart) { }
-
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
- val task = taskEnd.task
- val taskInfo = taskEnd.taskInfo
- var taskStatus = ""
- task match {
- case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
- case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
- }
- taskEnd.reason match {
- case Success => taskStatus += " STATUS=SUCCESS"
- recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics)
- case Resubmitted =>
- taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
- " STAGE_ID=" + task.stageId
- stageLogInfo(task.stageId, taskStatus)
- case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
- taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
- task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
- mapId + " REDUCE_ID=" + reduceId
- stageLogInfo(task.stageId, taskStatus)
- case OtherFailure(message) =>
- taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
- " STAGE_ID=" + task.stageId + " INFO=" + message
- stageLogInfo(task.stageId, taskStatus)
- case _ =>
- }
- }
-
- override def onJobEnd(jobEnd: SparkListenerJobEnd) {
- val job = jobEnd.job
- var info = "JOB_ID=" + job.jobId
- jobEnd.jobResult match {
- case JobSucceeded => info += " STATUS=SUCCESS"
- case JobFailed(exception, _) =>
- info += " STATUS=FAILED REASON="
- exception.getMessage.split("\\s+").foreach(info += _ + "_")
- case _ =>
- }
- jobLogInfo(job.jobId, info.substring(0, info.length - 1).toUpperCase)
- closeLogWriter(job.jobId)
- }
-
- protected def recordJobProperties(jobID: Int, properties: Properties) {
- if(properties != null) {
- val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "")
- jobLogInfo(jobID, description, false)
- }
- }
-
- override def onJobStart(jobStart: SparkListenerJobStart) {
- val job = jobStart.job
- val properties = jobStart.properties
- createLogWriter(job.jobId)
- recordJobProperties(job.jobId, properties)
- buildJobDep(job.jobId, job.finalStage)
- recordStageDep(job.jobId)
- recordStageDepGraph(job.jobId, job.finalStage)
- jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED")
- }
-}
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io.{IOException, File, FileNotFoundException, PrintWriter}
+import java.text.SimpleDateFormat
+import java.util.{Date, Properties}
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.collection.mutable.{HashMap, HashSet, ListBuffer}
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * A logger class to record runtime information for jobs in Spark. This class outputs one log file
+ * for each Spark job, containing RDD graph, tasks start/stop, shuffle information.
+ * JobLogger is a subclass of SparkListener, use addSparkListener to add JobLogger to a SparkContext
+ * after the SparkContext is created.
+ * Note that each JobLogger only works for one SparkContext
+ * @param logDirName The base directory for the log files.
+ */
+
+class JobLogger(val user: String, val logDirName: String)
+ extends SparkListener with Logging {
+
+ def this() = this(System.getProperty("user.name", "<unknown>"),
+ String.valueOf(System.currentTimeMillis()))
+
+ private val logDir =
+ if (System.getenv("SPARK_LOG_DIR") != null)
+ System.getenv("SPARK_LOG_DIR")
+ else
+ "/tmp/spark-%s".format(user)
+
+ private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
+ private val stageIDToJobID = new HashMap[Int, Int]
+ private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
+ private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
+ private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
+
+ createLogDir()
+
+ // The following 5 functions are used only in testing.
+ private[scheduler] def getLogDir = logDir
+ private[scheduler] def getJobIDtoPrintWriter = jobIDToPrintWriter
+ private[scheduler] def getStageIDToJobID = stageIDToJobID
+ private[scheduler] def getJobIDToStages = jobIDToStages
+ private[scheduler] def getEventQueue = eventQueue
+
+ /** Create a folder for log files, the folder's name is the creation time of jobLogger */
+ protected def createLogDir() {
+ val dir = new File(logDir + "/" + logDirName + "/")
+ if (dir.exists()) {
+ return
+ }
+ if (dir.mkdirs() == false) {
+ // JobLogger should throw a exception rather than continue to construct this object.
+ throw new IOException("create log directory error:" + logDir + "/" + logDirName + "/")
+ }
+ }
+
+ /**
+ * Create a log file for one job
+ * @param jobID ID of the job
+ * @exception FileNotFoundException Fail to create log file
+ */
+ protected def createLogWriter(jobID: Int) {
+ try {
+ val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
+ jobIDToPrintWriter += (jobID -> fileWriter)
+ } catch {
+ case e: FileNotFoundException => e.printStackTrace()
+ }
+ }
+
+ /**
+ * Close log file, and clean the stage relationship in stageIDToJobID
+ * @param jobID ID of the job
+ */
+ protected def closeLogWriter(jobID: Int) {
+ jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
+ fileWriter.close()
+ jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
+ stageIDToJobID -= stage.id
+ })
+ jobIDToPrintWriter -= jobID
+ jobIDToStages -= jobID
+ }
+ }
+
+ /**
+ * Write info into log file
+ * @param jobID ID of the job
+ * @param info Info to be recorded
+ * @param withTime Controls whether to record time stamp before the info, default is true
+ */
+ protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
+ var writeInfo = info
+ if (withTime) {
+ val date = new Date(System.currentTimeMillis())
+ writeInfo = DATE_FORMAT.format(date) + ": " +info
+ }
+ jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
+ }
+
+ /**
+ * Write info into log file
+ * @param stageID ID of the stage
+ * @param info Info to be recorded
+ * @param withTime Controls whether to record time stamp before the info, default is true
+ */
+ protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) {
+ stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
+ }
+
+ /**
+ * Build stage dependency for a job
+ * @param jobID ID of the job
+ * @param stage Root stage of the job
+ */
+ protected def buildJobDep(jobID: Int, stage: Stage) {
+ if (stage.jobId == jobID) {
+ jobIDToStages.get(jobID) match {
+ case Some(stageList) => stageList += stage
+ case None => val stageList = new ListBuffer[Stage]
+ stageList += stage
+ jobIDToStages += (jobID -> stageList)
+ }
+ stageIDToJobID += (stage.id -> jobID)
+ stage.parents.foreach(buildJobDep(jobID, _))
+ }
+ }
+
+ /**
+ * Record stage dependency and RDD dependency for a stage
+ * @param jobID Job ID of the stage
+ */
+ protected def recordStageDep(jobID: Int) {
+ def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
+ var rddList = new ListBuffer[RDD[_]]
+ rddList += rdd
+ rdd.dependencies.foreach {
+ case shufDep: ShuffleDependency[_, _] =>
+ case dep: Dependency[_] => rddList ++= getRddsInStage(dep.rdd)
+ }
+ rddList
+ }
+ jobIDToStages.get(jobID).foreach {_.foreach { stage =>
+ var depRddDesc: String = ""
+ getRddsInStage(stage.rdd).foreach { rdd =>
+ depRddDesc += rdd.id + ","
+ }
+ var depStageDesc: String = ""
+ stage.parents.foreach { stage =>
+ depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
+ }
+ jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
+ depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
+ " STAGE_DEP=" + depStageDesc, false)
+ }
+ }
+ }
+
+ /**
+ * Generate indents and convert to String
+ * @param indent Number of indents
+ * @return string of indents
+ */
+ protected def indentString(indent: Int): String = {
+ val sb = new StringBuilder()
+ for (i <- 1 to indent) {
+ sb.append(" ")
+ }
+ sb.toString()
+ }
+
+ /**
+ * Get RDD's name
+ * @param rdd Input RDD
+ * @return String of RDD's name
+ */
+ protected def getRddName(rdd: RDD[_]): String = {
+ var rddName = rdd.getClass.getSimpleName
+ if (rdd.name != null) {
+ rddName = rdd.name
+ }
+ rddName
+ }
+
+ /**
+ * Record RDD dependency graph in a stage
+ * @param jobID Job ID of the stage
+ * @param rdd Root RDD of the stage
+ * @param indent Indent number before info
+ */
+ protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
+ val rddInfo =
+ if (rdd.getStorageLevel != StorageLevel.NONE) {
+ "RDD_ID=" + rdd.id + " " + getRddName(rdd) + " CACHED" + " " +
+ rdd.origin + " " + rdd.generator
+ } else {
+ "RDD_ID=" + rdd.id + " " + getRddName(rdd) + " NONE" + " " +
+ rdd.origin + " " + rdd.generator
+ }
+ jobLogInfo(jobID, indentString(indent) + rddInfo, false)
+ rdd.dependencies.foreach {
+ case shufDep: ShuffleDependency[_, _] =>
+ val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
+ jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
+ case dep: Dependency[_] => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
+ }
+ }
+
+ /**
+ * Record stage dependency graph of a job
+ * @param jobID Job ID of the stage
+ * @param stage Root stage of the job
+ * @param indent Indent number before info, default is 0
+ */
+ protected def recordStageDepGraph(jobID: Int, stage: Stage, idSet: HashSet[Int], indent: Int = 0) {
+ val stageInfo = if (stage.isShuffleMap) {
+ "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId
+ } else {
+ "STAGE_ID=" + stage.id + " RESULT_STAGE"
+ }
+ if (stage.jobId == jobID) {
+ jobLogInfo(jobID, indentString(indent) + stageInfo, false)
+ if (!idSet.contains(stage.id)) {
+ idSet += stage.id
+ recordRddInStageGraph(jobID, stage.rdd, indent)
+ stage.parents.foreach(recordStageDepGraph(jobID, _, idSet, indent + 2))
+ }
+ } else {
+ jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
+ }
+ }
+
+ /**
+ * Record task metrics into job log files, including execution info and shuffle metrics
+ * @param stageID Stage ID of the task
+ * @param status Status info of the task
+ * @param taskInfo Task description info
+ * @param taskMetrics Task running metrics
+ */
+ protected def recordTaskMetrics(stageID: Int, status: String,
+ taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
+ val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
+ " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
+ " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
+ val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
+ val readMetrics = taskMetrics.shuffleReadMetrics match {
+ case Some(metrics) =>
+ " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
+ " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
+ " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
+ " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
+ " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
+ " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
+ " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
+ case None => ""
+ }
+ val writeMetrics = taskMetrics.shuffleWriteMetrics match {
+ case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
+ case None => ""
+ }
+ stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
+ }
+
+ /**
+ * When stage is submitted, record stage submit info
+ * @param stageSubmitted Stage submitted event
+ */
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
+ stageLogInfo(stageSubmitted.stage.stageId,"STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
+ stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
+ }
+
+ /**
+ * When stage is completed, record stage completion status
+ * @param stageCompleted Stage completed event
+ */
+ override def onStageCompleted(stageCompleted: StageCompleted) {
+ stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format(
+ stageCompleted.stage.stageId))
+ }
+
+ override def onTaskStart(taskStart: SparkListenerTaskStart) { }
+
+ /**
+ * When task ends, record task completion status and metrics
+ * @param taskEnd Task end event
+ */
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ val task = taskEnd.task
+ val taskInfo = taskEnd.taskInfo
+ var taskStatus = ""
+ task match {
+ case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
+ case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
+ }
+ taskEnd.reason match {
+ case Success => taskStatus += " STATUS=SUCCESS"
+ recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics)
+ case Resubmitted =>
+ taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
+ " STAGE_ID=" + task.stageId
+ stageLogInfo(task.stageId, taskStatus)
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
+ task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
+ mapId + " REDUCE_ID=" + reduceId
+ stageLogInfo(task.stageId, taskStatus)
+ case OtherFailure(message) =>
+ taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
+ " STAGE_ID=" + task.stageId + " INFO=" + message
+ stageLogInfo(task.stageId, taskStatus)
+ case _ =>
+ }
+ }
+
+ /**
+ * When job ends, recording job completion status and close log file
+ * @param jobEnd Job end event
+ */
+ override def onJobEnd(jobEnd: SparkListenerJobEnd) {
+ val job = jobEnd.job
+ var info = "JOB_ID=" + job.jobId
+ jobEnd.jobResult match {
+ case JobSucceeded => info += " STATUS=SUCCESS"
+ case JobFailed(exception, _) =>
+ info += " STATUS=FAILED REASON="
+ exception.getMessage.split("\\s+").foreach(info += _ + "_")
+ case _ =>
+ }
+ jobLogInfo(job.jobId, info.substring(0, info.length - 1).toUpperCase)
+ closeLogWriter(job.jobId)
+ }
+
+ /**
+ * Record job properties into job log file
+ * @param jobID ID of the job
+ * @param properties Properties of the job
+ */
+ protected def recordJobProperties(jobID: Int, properties: Properties) {
+ if(properties != null) {
+ val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "")
+ jobLogInfo(jobID, description, false)
+ }
+ }
+
+ /**
+ * When job starts, record job property and stage graph
+ * @param jobStart Job start event
+ */
+ override def onJobStart(jobStart: SparkListenerJobStart) {
+ val job = jobStart.job
+ val properties = jobStart.properties
+ createLogWriter(job.jobId)
+ recordJobProperties(job.jobId, properties)
+ buildJobDep(job.jobId, job.finalStage)
+ recordStageDep(job.jobId)
+ recordStageDepGraph(job.jobId, job.finalStage, new HashSet[Int])
+ jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED")
+ }
+}
+
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 24d97da6eb..1dc71a0428 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -146,26 +146,26 @@ private[spark] class ShuffleMapTask(
metrics = Some(context.taskMetrics)
val blockManager = SparkEnv.get.blockManager
- var shuffle: ShuffleBlocks = null
- var buckets: ShuffleWriterGroup = null
+ val shuffleBlockManager = blockManager.shuffleBlockManager
+ var shuffle: ShuffleWriterGroup = null
+ var success = false
try {
// Obtain all the block writers for shuffle blocks.
val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
- shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser)
- buckets = shuffle.acquireWriters(partitionId)
+ shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)
// Write the map output to its associated buckets.
for (elem <- rdd.iterator(split, context)) {
val pair = elem.asInstanceOf[Product2[Any, Any]]
val bucketId = dep.partitioner.getPartition(pair._1)
- buckets.writers(bucketId).write(pair)
+ shuffle.writers(bucketId).write(pair)
}
// Commit the writes. Get the size of each bucket block (total block size).
var totalBytes = 0L
var totalTime = 0L
- val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter =>
+ val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter =>
writer.commit()
val size = writer.fileSegment().length
totalBytes += size
@@ -179,19 +179,20 @@ private[spark] class ShuffleMapTask(
shuffleMetrics.shuffleWriteTime = totalTime
metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
+ success = true
new MapStatus(blockManager.blockManagerId, compressedSizes)
} catch { case e: Exception =>
// If there is an exception from running the task, revert the partial writes
// and throw the exception upstream to Spark.
- if (buckets != null) {
- buckets.writers.foreach(_.revertPartialWrites())
+ if (shuffle != null) {
+ shuffle.writers.foreach(_.revertPartialWrites())
}
throw e
} finally {
// Release the writers back to the shuffle block manager.
- if (shuffle != null && buckets != null) {
- buckets.writers.foreach(_.close())
- shuffle.releaseWriters(buckets)
+ if (shuffle != null && shuffle.writers != null) {
+ shuffle.writers.foreach(_.close())
+ shuffle.releaseWriters(success)
}
// Execute the callbacks on task completion.
context.executeOnCompleteCallbacks()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index f5548fc2da..3bb715e7d0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -88,8 +88,14 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
case StatusUpdate(executorId, taskId, state, data) =>
scheduler.statusUpdate(taskId, state, data.value)
if (TaskState.isFinished(state)) {
- freeCores(executorId) += 1
- makeOffers(executorId)
+ if (executorActor.contains(executorId)) {
+ freeCores(executorId) += 1
+ makeOffers(executorId)
+ } else {
+ // Ignoring the update since we don't know about the executor.
+ val msg = "Ignored task status update (%d state %s) from unknown executor %s with ID %s"
+ logWarning(msg.format(taskId, state, sender, executorId))
+ }
}
case ReviveOffers =>
@@ -176,7 +182,9 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME)
}
- private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ private val timeout = {
+ Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ }
def stopExecutors() {
try {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
index 40fdfcddb1..cec02e945c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
@@ -33,6 +33,10 @@ private[spark] class SimrSchedulerBackend(
val tmpPath = new Path(driverFilePath + "_tmp")
val filePath = new Path(driverFilePath)
+ val uiFilePath = driverFilePath + "_ui"
+ val tmpUiPath = new Path(uiFilePath + "_tmp")
+ val uiPath = new Path(uiFilePath)
+
val maxCores = System.getProperty("spark.simr.executor.cores", "1").toInt
override def start() {
@@ -47,6 +51,8 @@ private[spark] class SimrSchedulerBackend(
logInfo("Writing to HDFS file: " + driverFilePath)
logInfo("Writing Akka address: " + driverUrl)
+ logInfo("Writing to HDFS file: " + uiFilePath)
+ logInfo("Writing Spark UI Address: " + sc.ui.appUIAddress)
// Create temporary file to prevent race condition where executors get empty driverUrl file
val temp = fs.create(tmpPath, true)
@@ -56,6 +62,12 @@ private[spark] class SimrSchedulerBackend(
// "Atomic" rename
fs.rename(tmpPath, filePath)
+
+ // Write Spark UI Address to file
+ val uiTemp = fs.create(tmpUiPath, true)
+ uiTemp.writeUTF(sc.ui.appUIAddress)
+ uiTemp.close()
+ fs.rename(tmpUiPath, uiPath)
}
override def stop() {
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 55b25f145a..e748c2275d 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -27,13 +27,17 @@ import com.twitter.chill.{EmptyScalaKryoInstantiator, AllScalaRegistrar}
import org.apache.spark.{SerializableWritable, Logging}
import org.apache.spark.broadcast.HttpBroadcast
-import org.apache.spark.storage.{GetBlock,GotBlock, PutBlock, StorageLevel, TestBlockId}
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.storage._
/**
- * A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]].
+ * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]].
*/
class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging {
- private val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
+
+ private val bufferSize = {
+ System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
+ }
def newKryoOutput() = new KryoOutput(bufferSize)
@@ -42,21 +46,11 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging
val kryo = instantiator.newKryo()
val classLoader = Thread.currentThread.getContextClassLoader
- val blockId = TestBlockId("1")
- // Register some commonly used classes
- val toRegister: Seq[AnyRef] = Seq(
- ByteBuffer.allocate(1),
- StorageLevel.MEMORY_ONLY,
- PutBlock(blockId, ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY),
- GotBlock(blockId, ByteBuffer.allocate(1)),
- GetBlock(blockId),
- 1 to 10,
- 1 until 10,
- 1L to 10L,
- 1L until 10L
- )
-
- for (obj <- toRegister) kryo.register(obj.getClass)
+ // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops.
+ // Do this before we invoke the user registrator so the user registrator can override this.
+ kryo.setReferences(System.getProperty("spark.kryo.referenceTracking", "true").toBoolean)
+
+ for (cls <- KryoSerializer.toRegister) kryo.register(cls)
// Allow sending SerializableWritable
kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer())
@@ -78,10 +72,6 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging
new AllScalaRegistrar().apply(kryo)
kryo.setClassLoader(classLoader)
-
- // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops
- kryo.setReferences(System.getProperty("spark.kryo.referenceTracking", "true").toBoolean)
-
kryo
}
@@ -165,3 +155,21 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
trait KryoRegistrator {
def registerClasses(kryo: Kryo)
}
+
+private[serializer] object KryoSerializer {
+ // Commonly used classes.
+ private val toRegister: Seq[Class[_]] = Seq(
+ ByteBuffer.allocate(1).getClass,
+ classOf[StorageLevel],
+ classOf[PutBlock],
+ classOf[GotBlock],
+ classOf[GetBlock],
+ classOf[MapStatus],
+ classOf[BlockManagerId],
+ classOf[Array[Byte]],
+ (1 to 10).getClass,
+ (1 until 10).getClass,
+ (1L to 10L).getClass,
+ (1L until 10L).getClass
+ )
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala
index dbe0bda615..c8f397609a 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala
@@ -19,9 +19,7 @@ package org.apache.spark.storage
import java.util.concurrent.ConcurrentHashMap
-private[storage] trait BlockInfo {
- def level: StorageLevel
- def tellMaster: Boolean
+private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
// To save space, 'pending' and 'failed' are encoded as special sizes:
@volatile var size: Long = BlockInfo.BLOCK_PENDING
private def pending: Boolean = size == BlockInfo.BLOCK_PENDING
@@ -81,17 +79,3 @@ private object BlockInfo {
private val BLOCK_PENDING: Long = -1L
private val BLOCK_FAILED: Long = -2L
}
-
-// All shuffle blocks have the same `level` and `tellMaster` properties,
-// so we can save space by not storing them in each instance:
-private[storage] class ShuffleBlockInfo extends BlockInfo {
- // These need to be defined using 'def' instead of 'val' in order for
- // the compiler to eliminate the fields:
- def level: StorageLevel = StorageLevel.DISK_ONLY
- def tellMaster: Boolean = false
-}
-
-private[storage] class BlockInfoImpl(val level: StorageLevel, val tellMaster: Boolean)
- extends BlockInfo {
- // Intentionally left blank
-} \ No newline at end of file
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 76d537f8e8..a34c95b6f0 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -17,7 +17,7 @@
package org.apache.spark.storage
-import java.io.{InputStream, OutputStream}
+import java.io.{File, InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
import scala.collection.mutable.{HashMap, ArrayBuffer}
@@ -47,7 +47,7 @@ private[spark] class BlockManager(
extends Logging {
val shuffleBlockManager = new ShuffleBlockManager(this)
- val diskBlockManager = new DiskBlockManager(
+ val diskBlockManager = new DiskBlockManager(shuffleBlockManager,
System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
@@ -462,20 +462,10 @@ private[spark] class BlockManager(
* This is currently used for writing shuffle files out. Callers should handle error
* cases.
*/
- def getDiskWriter(blockId: BlockId, filename: String, serializer: Serializer, bufferSize: Int)
+ def getDiskWriter(blockId: BlockId, file: File, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = {
val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
- val file = diskBlockManager.createBlockFile(blockId, filename, allowAppending = true)
- val writer = new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream)
- writer.registerCloseEventHandler(() => {
- if (shuffleBlockManager.consolidateShuffleFiles) {
- diskBlockManager.mapBlockToFileSegment(blockId, writer.fileSegment())
- }
- val myInfo = new ShuffleBlockInfo()
- blockInfo.put(blockId, myInfo)
- myInfo.markReady(writer.fileSegment().length)
- })
- writer
+ new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream)
}
/**
@@ -505,7 +495,7 @@ private[spark] class BlockManager(
// to be dropped right after it got put into memory. Note, however, that other threads will
// not be able to get() this block until we call markReady on its BlockInfo.
val myInfo = {
- val tinfo = new BlockInfoImpl(level, tellMaster)
+ val tinfo = new BlockInfo(level, tellMaster)
// Do atomically !
val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 32d2dd0694..469e68fed7 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -34,20 +34,12 @@ import org.apache.spark.serializer.{SerializationStream, Serializer}
*/
abstract class BlockObjectWriter(val blockId: BlockId) {
- var closeEventHandler: () => Unit = _
-
def open(): BlockObjectWriter
- def close() {
- closeEventHandler()
- }
+ def close()
def isOpen: Boolean
- def registerCloseEventHandler(handler: () => Unit) {
- closeEventHandler = handler
- }
-
/**
* Flush the partial writes and commit them as a single atomic block. Return the
* number of bytes written for this commit.
@@ -78,11 +70,11 @@ abstract class BlockObjectWriter(val blockId: BlockId) {
/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
class DiskBlockObjectWriter(
- blockId: BlockId,
- file: File,
- serializer: Serializer,
- bufferSize: Int,
- compressStream: OutputStream => OutputStream)
+ blockId: BlockId,
+ file: File,
+ serializer: Serializer,
+ bufferSize: Int,
+ compressStream: OutputStream => OutputStream)
extends BlockObjectWriter(blockId)
with Logging
{
@@ -111,8 +103,8 @@ class DiskBlockObjectWriter(
private var fos: FileOutputStream = null
private var ts: TimeTrackingOutputStream = null
private var objOut: SerializationStream = null
- private var initialPosition = 0L
- private var lastValidPosition = 0L
+ private val initialPosition = file.length()
+ private var lastValidPosition = initialPosition
private var initialized = false
private var _timeWriting = 0L
@@ -120,7 +112,6 @@ class DiskBlockObjectWriter(
fos = new FileOutputStream(file, true)
ts = new TimeTrackingOutputStream(fos)
channel = fos.getChannel()
- initialPosition = channel.position
lastValidPosition = initialPosition
bs = compressStream(new FastBufferedOutputStream(ts, bufferSize))
objOut = serializer.newInstance().serializeStream(bs)
@@ -147,8 +138,6 @@ class DiskBlockObjectWriter(
ts = null
objOut = null
}
- // Invoke the close callback handler.
- super.close()
}
override def isOpen: Boolean = objOut != null
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index bcb58ad946..fcd2e97982 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -20,12 +20,11 @@ package org.apache.spark.storage
import java.io.File
import java.text.SimpleDateFormat
import java.util.{Date, Random}
-import java.util.concurrent.ConcurrentHashMap
import org.apache.spark.Logging
import org.apache.spark.executor.ExecutorExitCode
import org.apache.spark.network.netty.{PathResolver, ShuffleSender}
-import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
+import org.apache.spark.util.Utils
/**
* Creates and maintains the logical mapping between logical blocks and physical on-disk
@@ -35,7 +34,8 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH
*
* @param rootDirs The directories to use for storing block files. Data will be hashed among these.
*/
-private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver with Logging {
+private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootDirs: String)
+ extends PathResolver with Logging {
private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
@@ -47,54 +47,23 @@ private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver wit
private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
private var shuffleSender : ShuffleSender = null
- // Stores only Blocks which have been specifically mapped to segments of files
- // (rather than the default, which maps a Block to a whole file).
- // This keeps our bookkeeping down, since the file system itself tracks the standalone Blocks.
- private val blockToFileSegmentMap = new TimeStampedHashMap[BlockId, FileSegment]
-
- val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DISK_BLOCK_MANAGER, this.cleanup)
-
addShutdownHook()
/**
- * Creates a logical mapping from the given BlockId to a segment of a file.
- * This will cause any accesses of the logical BlockId to be directed to the specified
- * physical location.
- */
- def mapBlockToFileSegment(blockId: BlockId, fileSegment: FileSegment) {
- blockToFileSegmentMap.put(blockId, fileSegment)
- }
-
- /**
* Returns the phyiscal file segment in which the given BlockId is located.
* If the BlockId has been mapped to a specific FileSegment, that will be returned.
* Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
*/
def getBlockLocation(blockId: BlockId): FileSegment = {
- if (blockToFileSegmentMap.internalMap.containsKey(blockId)) {
- blockToFileSegmentMap.get(blockId).get
+ if (blockId.isShuffle && shuffleManager.consolidateShuffleFiles) {
+ shuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId])
} else {
val file = getFile(blockId.name)
new FileSegment(file, 0, file.length())
}
}
- /**
- * Simply returns a File to place the given Block into. This does not physically create the file.
- * If filename is given, that file will be used. Otherwise, we will use the BlockId to get
- * a unique filename.
- */
- def createBlockFile(blockId: BlockId, filename: String = "", allowAppending: Boolean): File = {
- val actualFilename = if (filename == "") blockId.name else filename
- val file = getFile(actualFilename)
- if (!allowAppending && file.exists()) {
- throw new IllegalStateException(
- "Attempted to create file that already exists: " + actualFilename)
- }
- file
- }
-
- private def getFile(filename: String): File = {
+ def getFile(filename: String): File = {
// Figure out which local directory it hashes to, and which subdirectory in that
val hash = Utils.nonNegativeHash(filename)
val dirId = hash % localDirs.length
@@ -119,6 +88,8 @@ private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver wit
new File(subDir, filename)
}
+ def getFile(blockId: BlockId): File = getFile(blockId.name)
+
private def createLocalDirs(): Array[File] = {
logDebug("Creating local directories at root dirs '" + rootDirs + "'")
val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
@@ -151,10 +122,6 @@ private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver wit
}
}
- private def cleanup(cleanupTime: Long) {
- blockToFileSegmentMap.clearOldValues(cleanupTime)
- }
-
private def addShutdownHook() {
localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index a3c496f9e0..5a1e7b4444 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -44,7 +44,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage
val bytes = _bytes.duplicate()
logDebug("Attempting to put block " + blockId)
val startTime = System.currentTimeMillis
- val file = diskManager.createBlockFile(blockId, allowAppending = false)
+ val file = diskManager.getFile(blockId)
val channel = new FileOutputStream(file).getChannel()
while (bytes.remaining > 0) {
channel.write(bytes)
@@ -64,7 +64,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage
logDebug("Attempting to write values for block " + blockId)
val startTime = System.currentTimeMillis
- val file = diskManager.createBlockFile(blockId, allowAppending = false)
+ val file = diskManager.getFile(blockId)
val outputStream = new FileOutputStream(file)
blockManager.dataSerializeStream(blockId, outputStream, values.iterator)
val length = file.length
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index 066e45a12b..2f1b049ce4 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -17,33 +17,45 @@
package org.apache.spark.storage
+import java.io.File
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicInteger
+import scala.collection.JavaConversions._
+
import org.apache.spark.serializer.Serializer
+import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap}
+import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
+import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup
-private[spark]
-class ShuffleWriterGroup(val id: Int, val fileId: Int, val writers: Array[BlockObjectWriter])
+/** A group of writers for a ShuffleMapTask, one writer per reducer. */
+private[spark] trait ShuffleWriterGroup {
+ val writers: Array[BlockObjectWriter]
-private[spark]
-trait ShuffleBlocks {
- def acquireWriters(mapId: Int): ShuffleWriterGroup
- def releaseWriters(group: ShuffleWriterGroup)
+ /** @param success Indicates all writes were successful. If false, no blocks will be recorded. */
+ def releaseWriters(success: Boolean)
}
/**
- * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one writer
- * per reducer.
+ * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one file
+ * per reducer (this set of files is called a ShuffleFileGroup).
*
* As an optimization to reduce the number of physical shuffle files produced, multiple shuffle
* blocks are aggregated into the same file. There is one "combined shuffle file" per reducer
- * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle files,
- * it releases them for another task.
+ * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle
+ * files, it releases them for another task.
* Regarding the implementation of this feature, shuffle files are identified by a 3-tuple:
* - shuffleId: The unique id given to the entire shuffle stage.
* - bucketId: The id of the output partition (i.e., reducer id)
* - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a
* time owns a particular fileId, and this id is returned to a pool when the task finishes.
+ * Each shuffle file is then mapped to a FileSegment, which is a 3-tuple (file, offset, length)
+ * that specifies where in a given file the actual block data is located.
+ *
+ * Shuffle file metadata is stored in a space-efficient manner. Rather than simply mapping
+ * ShuffleBlockIds directly to FileSegments, each ShuffleFileGroup maintains a list of offsets for
+ * each block stored in each file. In order to find the location of a shuffle block, we search the
+ * files within a ShuffleFileGroups associated with the block's reducer.
*/
private[spark]
class ShuffleBlockManager(blockManager: BlockManager) {
@@ -52,45 +64,152 @@ class ShuffleBlockManager(blockManager: BlockManager) {
val consolidateShuffleFiles =
System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean
- var nextFileId = new AtomicInteger(0)
- val unusedFileIds = new ConcurrentLinkedQueue[java.lang.Integer]()
+ private val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
+
+ /**
+ * Contains all the state related to a particular shuffle. This includes a pool of unused
+ * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle.
+ */
+ private class ShuffleState() {
+ val nextFileId = new AtomicInteger(0)
+ val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
+ val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
+ }
+
+ type ShuffleId = Int
+ private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState]
+
+ private
+ val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup)
- def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer) = {
- new ShuffleBlocks {
- // Get a group of writers for a map task.
- override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
- val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
- val fileId = getUnusedFileId()
- val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = {
+ new ShuffleWriterGroup {
+ shuffleStates.putIfAbsent(shuffleId, new ShuffleState())
+ private val shuffleState = shuffleStates(shuffleId)
+ private var fileGroup: ShuffleFileGroup = null
+
+ val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
+ fileGroup = getUnusedFileGroup()
+ Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
- if (consolidateShuffleFiles) {
- val filename = physicalFileName(shuffleId, bucketId, fileId)
- blockManager.getDiskWriter(blockId, filename, serializer, bufferSize)
- } else {
- blockManager.getDiskWriter(blockId, blockId.name, serializer, bufferSize)
+ blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize)
+ }
+ } else {
+ Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
+ val blockFile = blockManager.diskBlockManager.getFile(blockId)
+ blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize)
+ }
+ }
+
+ override def releaseWriters(success: Boolean) {
+ if (consolidateShuffleFiles) {
+ if (success) {
+ val offsets = writers.map(_.fileSegment().offset)
+ fileGroup.recordMapOutput(mapId, offsets)
}
+ recycleFileGroup(fileGroup)
}
- new ShuffleWriterGroup(mapId, fileId, writers)
}
- override def releaseWriters(group: ShuffleWriterGroup) {
- recycleFileId(group.fileId)
+ private def getUnusedFileGroup(): ShuffleFileGroup = {
+ val fileGroup = shuffleState.unusedFileGroups.poll()
+ if (fileGroup != null) fileGroup else newFileGroup()
+ }
+
+ private def newFileGroup(): ShuffleFileGroup = {
+ val fileId = shuffleState.nextFileId.getAndIncrement()
+ val files = Array.tabulate[File](numBuckets) { bucketId =>
+ val filename = physicalFileName(shuffleId, bucketId, fileId)
+ blockManager.diskBlockManager.getFile(filename)
+ }
+ val fileGroup = new ShuffleFileGroup(fileId, shuffleId, files)
+ shuffleState.allFileGroups.add(fileGroup)
+ fileGroup
}
- }
- }
- private def getUnusedFileId(): Int = {
- val fileId = unusedFileIds.poll()
- if (fileId == null) nextFileId.getAndIncrement() else fileId
+ private def recycleFileGroup(group: ShuffleFileGroup) {
+ shuffleState.unusedFileGroups.add(group)
+ }
+ }
}
- private def recycleFileId(fileId: Int) {
- if (consolidateShuffleFiles) {
- unusedFileIds.add(fileId)
+ /**
+ * Returns the physical file segment in which the given BlockId is located.
+ * This function should only be called if shuffle file consolidation is enabled, as it is
+ * an error condition if we don't find the expected block.
+ */
+ def getBlockLocation(id: ShuffleBlockId): FileSegment = {
+ // Search all file groups associated with this shuffle.
+ val shuffleState = shuffleStates(id.shuffleId)
+ for (fileGroup <- shuffleState.allFileGroups) {
+ val segment = fileGroup.getFileSegmentFor(id.mapId, id.reduceId)
+ if (segment.isDefined) { return segment.get }
}
+ throw new IllegalStateException("Failed to find shuffle block: " + id)
}
private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = {
"merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId)
}
+
+ private def cleanup(cleanupTime: Long) {
+ shuffleStates.clearOldValues(cleanupTime)
+ }
+}
+
+private[spark]
+object ShuffleBlockManager {
+ /**
+ * A group of shuffle files, one per reducer.
+ * A particular mapper will be assigned a single ShuffleFileGroup to write its output to.
+ */
+ private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) {
+ /**
+ * Stores the absolute index of each mapId in the files of this group. For instance,
+ * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0.
+ */
+ private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]()
+
+ /**
+ * Stores consecutive offsets of blocks into each reducer file, ordered by position in the file.
+ * This ordering allows us to compute block lengths by examining the following block offset.
+ * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every
+ * reducer.
+ */
+ private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) {
+ new PrimitiveVector[Long]()
+ }
+
+ def numBlocks = mapIdToIndex.size
+
+ def apply(bucketId: Int) = files(bucketId)
+
+ def recordMapOutput(mapId: Int, offsets: Array[Long]) {
+ mapIdToIndex(mapId) = numBlocks
+ for (i <- 0 until offsets.length) {
+ blockOffsetsByReducer(i) += offsets(i)
+ }
+ }
+
+ /** Returns the FileSegment associated with the given map task, or None if no entry exists. */
+ def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = {
+ val file = files(reducerId)
+ val blockOffsets = blockOffsetsByReducer(reducerId)
+ val index = mapIdToIndex.getOrElse(mapId, -1)
+ if (index >= 0) {
+ val offset = blockOffsets(index)
+ val length =
+ if (index + 1 < numBlocks) {
+ blockOffsets(index + 1) - offset
+ } else {
+ file.length() - offset
+ }
+ assert(length >= 0)
+ Some(new FileSegment(file, offset, length))
+ } else {
+ None
+ }
+ }
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
index 7dcadc3805..1e4db4f66b 100644
--- a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
@@ -38,19 +38,19 @@ object StoragePerfTester {
val blockManager = sc.env.blockManager
def writeOutputBytes(mapId: Int, total: AtomicLong) = {
- val shuffle = blockManager.shuffleBlockManager.forShuffle(1, numOutputSplits,
+ val shuffle = blockManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits,
new KryoSerializer())
- val buckets = shuffle.acquireWriters(mapId)
+ val writers = shuffle.writers
for (i <- 1 to recordsPerMap) {
- buckets.writers(i % numOutputSplits).write(writeData)
+ writers(i % numOutputSplits).write(writeData)
}
- buckets.writers.map {w =>
+ writers.map {w =>
w.commit()
total.addAndGet(w.fileSegment().length)
w.close()
}
- shuffle.releaseWriters(buckets)
+ shuffle.releaseWriters(true)
}
val start = System.currentTimeMillis()
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 35b5d5fd59..c1c7aa70e6 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -152,6 +152,22 @@ private[spark] class StagePage(parent: JobProgressUI) {
else metrics.map(m => parent.formatDuration(m.executorRunTime)).getOrElse("")
val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L)
+ var shuffleReadSortable: String = ""
+ var shuffleReadReadable: String = ""
+ if (shuffleRead) {
+ shuffleReadSortable = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => s.remoteBytesRead}.toString()
+ shuffleReadReadable = metrics.flatMap{m => m.shuffleReadMetrics}.map{s =>
+ Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")
+ }
+
+ var shuffleWriteSortable: String = ""
+ var shuffleWriteReadable: String = ""
+ if (shuffleWrite) {
+ shuffleWriteSortable = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleBytesWritten}.toString()
+ shuffleWriteReadable = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
+ Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")
+ }
+
<tr>
<td>{info.index}</td>
<td>{info.taskId}</td>
@@ -166,14 +182,17 @@ private[spark] class StagePage(parent: JobProgressUI) {
{if (gcTime > 0) parent.formatDuration(gcTime) else ""}
</td>
{if (shuffleRead) {
- <td>{metrics.flatMap{m => m.shuffleReadMetrics}.map{s =>
- Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")}</td>
+ <td sorttable_customkey={shuffleReadSortable}>
+ {shuffleReadReadable}
+ </td>
}}
{if (shuffleWrite) {
- <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
- parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")}</td>
- <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
- Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")}</td>
+ <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
+ parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")}
+ </td>
+ <td sorttable_customkey={shuffleWriteSortable}>
+ {shuffleWriteReadable}
+ </td>
}}
<td>{exception.map(e =>
<span>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index d7d0441c38..9ad6de3c6d 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -79,11 +79,14 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr
case None => "Unknown"
}
- val shuffleRead = listener.stageIdToShuffleRead.getOrElse(s.stageId, 0L) match {
+ val shuffleReadSortable = listener.stageIdToShuffleRead.getOrElse(s.stageId, 0L)
+ val shuffleRead = shuffleReadSortable match {
case 0 => ""
case b => Utils.bytesToString(b)
}
- val shuffleWrite = listener.stageIdToShuffleWrite.getOrElse(s.stageId, 0L) match {
+
+ val shuffleWriteSortable = listener.stageIdToShuffleWrite.getOrElse(s.stageId, 0L)
+ val shuffleWrite = shuffleWriteSortable match {
case 0 => ""
case b => Utils.bytesToString(b)
}
@@ -119,8 +122,8 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr
<td class="progress-cell">
{makeProgressBar(startedTasks, completedTasks, failedTasks, totalTasks)}
</td>
- <td>{shuffleRead}</td>
- <td>{shuffleWrite}</td>
+ <td sorttable_customekey={shuffleReadSortable.toString}>{shuffleRead}</td>
+ <td sorttable_customekey={shuffleWriteSortable.toString}>{shuffleWrite}</td>
</tr>
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index 3f963727d9..67a7f87a5c 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -59,7 +59,7 @@ object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext
"ShuffleMapTask", "BlockManager", "DiskBlockManager", "BroadcastVars") {
val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK,
- SHUFFLE_MAP_TASK, BLOCK_MANAGER, DISK_BLOCK_MANAGER, BROADCAST_VARS = Value
+ SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value
type MetadataCleanerType = Value
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index fd2811e44c..fe932d8ede 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -18,13 +18,12 @@
package org.apache.spark.util
import java.io._
-import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket}
+import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address}
import java.util.{Locale, Random, UUID}
-import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor}
-import java.util.regex.Pattern
+import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor}
import scala.collection.Map
-import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
import scala.io.Source
@@ -36,7 +35,7 @@ import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
import org.apache.spark.deploy.SparkHadoopUtil
import java.nio.ByteBuffer
-import org.apache.spark.{SparkEnv, SparkException, Logging}
+import org.apache.spark.{SparkException, Logging}
/**
@@ -148,7 +147,7 @@ private[spark] object Utils extends Logging {
return buf
}
- private val shutdownDeletePaths = new collection.mutable.HashSet[String]()
+ private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]()
// Register the path to be deleted via shutdown hook
def registerShutdownDeleteDir(file: File) {
@@ -818,4 +817,10 @@ private[spark] object Utils extends Logging {
// Nothing else to guard against ?
hashAbs
}
+
+ /** Returns a copy of the system properties that is thread-safe to iterator over. */
+ def getSystemProperties(): Map[String, String] = {
+ return System.getProperties().clone()
+ .asInstanceOf[java.util.Properties].toMap[String, String]
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
new file mode 100644
index 0000000000..a1a452315d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+
+/**
+ * A simple, fixed-size bit set implementation. This implementation is fast because it avoids
+ * safety/bound checking.
+ */
+class BitSet(numBits: Int) {
+
+ private[this] val words = new Array[Long](bit2words(numBits))
+ private[this] val numWords = words.length
+
+ /**
+ * Sets the bit at the specified index to true.
+ * @param index the bit index
+ */
+ def set(index: Int) {
+ val bitmask = 1L << (index & 0x3f) // mod 64 and shift
+ words(index >> 6) |= bitmask // div by 64 and mask
+ }
+
+ /**
+ * Return the value of the bit with the specified index. The value is true if the bit with
+ * the index is currently set in this BitSet; otherwise, the result is false.
+ *
+ * @param index the bit index
+ * @return the value of the bit with the specified index
+ */
+ def get(index: Int): Boolean = {
+ val bitmask = 1L << (index & 0x3f) // mod 64 and shift
+ (words(index >> 6) & bitmask) != 0 // div by 64 and mask
+ }
+
+ /** Return the number of bits set to true in this BitSet. */
+ def cardinality(): Int = {
+ var sum = 0
+ var i = 0
+ while (i < numWords) {
+ sum += java.lang.Long.bitCount(words(i))
+ i += 1
+ }
+ sum
+ }
+
+ /**
+ * Returns the index of the first bit that is set to true that occurs on or after the
+ * specified starting index. If no such bit exists then -1 is returned.
+ *
+ * To iterate over the true bits in a BitSet, use the following loop:
+ *
+ * for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i+1)) {
+ * // operate on index i here
+ * }
+ *
+ * @param fromIndex the index to start checking from (inclusive)
+ * @return the index of the next set bit, or -1 if there is no such bit
+ */
+ def nextSetBit(fromIndex: Int): Int = {
+ var wordIndex = fromIndex >> 6
+ if (wordIndex >= numWords) {
+ return -1
+ }
+
+ // Try to find the next set bit in the current word
+ val subIndex = fromIndex & 0x3f
+ var word = words(wordIndex) >> subIndex
+ if (word != 0) {
+ return (wordIndex << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word)
+ }
+
+ // Find the next set bit in the rest of the words
+ wordIndex += 1
+ while (wordIndex < numWords) {
+ word = words(wordIndex)
+ if (word != 0) {
+ return (wordIndex << 6) + java.lang.Long.numberOfTrailingZeros(word)
+ }
+ wordIndex += 1
+ }
+
+ -1
+ }
+
+ /** Return the number of longs it would take to hold numBits. */
+ private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
new file mode 100644
index 0000000000..80545c9688
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+
+/**
+ * A fast hash map implementation for nullable keys. This hash map supports insertions and updates,
+ * but not deletions. This map is about 5X faster than java.util.HashMap, while using much less
+ * space overhead.
+ *
+ * Under the hood, it uses our OpenHashSet implementation.
+ */
+private[spark]
+class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V: ClassManifest](
+ initialCapacity: Int)
+ extends Iterable[(K, V)]
+ with Serializable {
+
+ def this() = this(64)
+
+ protected var _keySet = new OpenHashSet[K](initialCapacity)
+
+ // Init in constructor (instead of in declaration) to work around a Scala compiler specialization
+ // bug that would generate two arrays (one for Object and one for specialized T).
+ private var _values: Array[V] = _
+ _values = new Array[V](_keySet.capacity)
+
+ @transient private var _oldValues: Array[V] = null
+
+ // Treat the null key differently so we can use nulls in "data" to represent empty items.
+ private var haveNullValue = false
+ private var nullValue: V = null.asInstanceOf[V]
+
+ override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size
+
+ /** Get the value for a given key */
+ def apply(k: K): V = {
+ if (k == null) {
+ nullValue
+ } else {
+ val pos = _keySet.getPos(k)
+ if (pos < 0) {
+ null.asInstanceOf[V]
+ } else {
+ _values(pos)
+ }
+ }
+ }
+
+ /** Set the value for a key */
+ def update(k: K, v: V) {
+ if (k == null) {
+ haveNullValue = true
+ nullValue = v
+ } else {
+ val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
+ _values(pos) = v
+ _keySet.rehashIfNeeded(k, grow, move)
+ _oldValues = null
+ }
+ }
+
+ /**
+ * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise,
+ * set its value to mergeValue(oldValue).
+ *
+ * @return the newly updated value.
+ */
+ def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = {
+ if (k == null) {
+ if (haveNullValue) {
+ nullValue = mergeValue(nullValue)
+ } else {
+ haveNullValue = true
+ nullValue = defaultValue
+ }
+ nullValue
+ } else {
+ val pos = _keySet.addWithoutResize(k)
+ if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
+ val newValue = defaultValue
+ _values(pos & OpenHashSet.POSITION_MASK) = newValue
+ _keySet.rehashIfNeeded(k, grow, move)
+ newValue
+ } else {
+ _values(pos) = mergeValue(_values(pos))
+ _values(pos)
+ }
+ }
+ }
+
+ override def iterator = new Iterator[(K, V)] {
+ var pos = -1
+ var nextPair: (K, V) = computeNextPair()
+
+ /** Get the next value we should return from next(), or null if we're finished iterating */
+ def computeNextPair(): (K, V) = {
+ if (pos == -1) { // Treat position -1 as looking at the null value
+ if (haveNullValue) {
+ pos += 1
+ return (null.asInstanceOf[K], nullValue)
+ }
+ pos += 1
+ }
+ pos = _keySet.nextPos(pos)
+ if (pos >= 0) {
+ val ret = (_keySet.getValue(pos), _values(pos))
+ pos += 1
+ ret
+ } else {
+ null
+ }
+ }
+
+ def hasNext = nextPair != null
+
+ def next() = {
+ val pair = nextPair
+ nextPair = computeNextPair()
+ pair
+ }
+ }
+
+ // The following member variables are declared as protected instead of private for the
+ // specialization to work (specialized class extends the non-specialized one and needs access
+ // to the "private" variables).
+ // They also should have been val's. We use var's because there is a Scala compiler bug that
+ // would throw illegal access error at runtime if they are declared as val's.
+ protected var grow = (newCapacity: Int) => {
+ _oldValues = _values
+ _values = new Array[V](newCapacity)
+ }
+
+ protected var move = (oldPos: Int, newPos: Int) => {
+ _values(newPos) = _oldValues(oldPos)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
new file mode 100644
index 0000000000..4592e4f939
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+
+/**
+ * A simple, fast hash set optimized for non-null insertion-only use case, where keys are never
+ * removed.
+ *
+ * The underlying implementation uses Scala compiler's specialization to generate optimized
+ * storage for two primitive types (Long and Int). It is much faster than Java's standard HashSet
+ * while incurring much less memory overhead. This can serve as building blocks for higher level
+ * data structures such as an optimized HashMap.
+ *
+ * This OpenHashSet is designed to serve as building blocks for higher level data structures
+ * such as an optimized hash map. Compared with standard hash set implementations, this class
+ * provides its various callbacks interfaces (e.g. allocateFunc, moveFunc) and interfaces to
+ * retrieve the position of a key in the underlying array.
+ *
+ * It uses quadratic probing with a power-of-2 hash table size, which is guaranteed
+ * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing).
+ */
+private[spark]
+class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
+ initialCapacity: Int,
+ loadFactor: Double)
+ extends Serializable {
+
+ require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
+ require(initialCapacity >= 1, "Invalid initial capacity")
+ require(loadFactor < 1.0, "Load factor must be less than 1.0")
+ require(loadFactor > 0.0, "Load factor must be greater than 0.0")
+
+ import OpenHashSet._
+
+ def this(initialCapacity: Int) = this(initialCapacity, 0.7)
+
+ def this() = this(64)
+
+ // The following member variables are declared as protected instead of private for the
+ // specialization to work (specialized class extends the non-specialized one and needs access
+ // to the "private" variables).
+
+ protected val hasher: Hasher[T] = {
+ // It would've been more natural to write the following using pattern matching. But Scala 2.9.x
+ // compiler has a bug when specialization is used together with this pattern matching, and
+ // throws:
+ // scala.tools.nsc.symtab.Types$TypeError: type mismatch;
+ // found : scala.reflect.AnyValManifest[Long]
+ // required: scala.reflect.ClassManifest[Int]
+ // at scala.tools.nsc.typechecker.Contexts$Context.error(Contexts.scala:298)
+ // at scala.tools.nsc.typechecker.Infer$Inferencer.error(Infer.scala:207)
+ // ...
+ val mt = classManifest[T]
+ if (mt == ClassManifest.Long) {
+ (new LongHasher).asInstanceOf[Hasher[T]]
+ } else if (mt == ClassManifest.Int) {
+ (new IntHasher).asInstanceOf[Hasher[T]]
+ } else {
+ new Hasher[T]
+ }
+ }
+
+ protected var _capacity = nextPowerOf2(initialCapacity)
+ protected var _mask = _capacity - 1
+ protected var _size = 0
+
+ protected var _bitset = new BitSet(_capacity)
+
+ // Init of the array in constructor (instead of in declaration) to work around a Scala compiler
+ // specialization bug that would generate two arrays (one for Object and one for specialized T).
+ protected var _data: Array[T] = _
+ _data = new Array[T](_capacity)
+
+ /** Number of elements in the set. */
+ def size: Int = _size
+
+ /** The capacity of the set (i.e. size of the underlying array). */
+ def capacity: Int = _capacity
+
+ /** Return true if this set contains the specified element. */
+ def contains(k: T): Boolean = getPos(k) != INVALID_POS
+
+ /**
+ * Add an element to the set. If the set is over capacity after the insertion, grow the set
+ * and rehash all elements.
+ */
+ def add(k: T) {
+ addWithoutResize(k)
+ rehashIfNeeded(k, grow, move)
+ }
+
+ /**
+ * Add an element to the set. This one differs from add in that it doesn't trigger rehashing.
+ * The caller is responsible for calling rehashIfNeeded.
+ *
+ * Use (retval & POSITION_MASK) to get the actual position, and
+ * (retval & EXISTENCE_MASK) != 0 for prior existence.
+ *
+ * @return The position where the key is placed, plus the highest order bit is set if the key
+ * exists previously.
+ */
+ def addWithoutResize(k: T): Int = putInto(_bitset, _data, k)
+
+ /**
+ * Rehash the set if it is overloaded.
+ * @param k A parameter unused in the function, but to force the Scala compiler to specialize
+ * this method.
+ * @param allocateFunc Callback invoked when we are allocating a new, larger array.
+ * @param moveFunc Callback invoked when we move the key from one position (in the old data array)
+ * to a new position (in the new data array).
+ */
+ def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
+ if (_size > loadFactor * _capacity) {
+ rehash(k, allocateFunc, moveFunc)
+ }
+ }
+
+ /**
+ * Return the position of the element in the underlying array, or INVALID_POS if it is not found.
+ */
+ def getPos(k: T): Int = {
+ var pos = hashcode(hasher.hash(k)) & _mask
+ var i = 1
+ while (true) {
+ if (!_bitset.get(pos)) {
+ return INVALID_POS
+ } else if (k == _data(pos)) {
+ return pos
+ } else {
+ val delta = i
+ pos = (pos + delta) & _mask
+ i += 1
+ }
+ }
+ // Never reached here
+ INVALID_POS
+ }
+
+ /** Return the value at the specified position. */
+ def getValue(pos: Int): T = _data(pos)
+
+ /**
+ * Return the next position with an element stored, starting from the given position inclusively.
+ */
+ def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos)
+
+ /**
+ * Put an entry into the set. Return the position where the key is placed. In addition, the
+ * highest bit in the returned position is set if the key exists prior to this put.
+ *
+ * This function assumes the data array has at least one empty slot.
+ */
+ private def putInto(bitset: BitSet, data: Array[T], k: T): Int = {
+ val mask = data.length - 1
+ var pos = hashcode(hasher.hash(k)) & mask
+ var i = 1
+ while (true) {
+ if (!bitset.get(pos)) {
+ // This is a new key.
+ data(pos) = k
+ bitset.set(pos)
+ _size += 1
+ return pos | NONEXISTENCE_MASK
+ } else if (data(pos) == k) {
+ // Found an existing key.
+ return pos
+ } else {
+ val delta = i
+ pos = (pos + delta) & mask
+ i += 1
+ }
+ }
+ // Never reached here
+ assert(INVALID_POS != INVALID_POS)
+ INVALID_POS
+ }
+
+ /**
+ * Double the table's size and re-hash everything. We are not really using k, but it is declared
+ * so Scala compiler can specialize this method (which leads to calling the specialized version
+ * of putInto).
+ *
+ * @param k A parameter unused in the function, but to force the Scala compiler to specialize
+ * this method.
+ * @param allocateFunc Callback invoked when we are allocating a new, larger array.
+ * @param moveFunc Callback invoked when we move the key from one position (in the old data array)
+ * to a new position (in the new data array).
+ */
+ private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
+ val newCapacity = _capacity * 2
+ require(newCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
+
+ allocateFunc(newCapacity)
+ val newData = new Array[T](newCapacity)
+ val newBitset = new BitSet(newCapacity)
+ var pos = 0
+ _size = 0
+ while (pos < _capacity) {
+ if (_bitset.get(pos)) {
+ val newPos = putInto(newBitset, newData, _data(pos))
+ moveFunc(pos, newPos & POSITION_MASK)
+ }
+ pos += 1
+ }
+ _bitset = newBitset
+ _data = newData
+ _capacity = newCapacity
+ _mask = newCapacity - 1
+ }
+
+ /**
+ * Re-hash a value to deal better with hash functions that don't differ
+ * in the lower bits, similar to java.util.HashMap
+ */
+ private def hashcode(h: Int): Int = {
+ val r = h ^ (h >>> 20) ^ (h >>> 12)
+ r ^ (r >>> 7) ^ (r >>> 4)
+ }
+
+ private def nextPowerOf2(n: Int): Int = {
+ val highBit = Integer.highestOneBit(n)
+ if (highBit == n) n else highBit << 1
+ }
+}
+
+
+private[spark]
+object OpenHashSet {
+
+ val INVALID_POS = -1
+ val NONEXISTENCE_MASK = 0x80000000
+ val POSITION_MASK = 0xEFFFFFF
+
+ /**
+ * A set of specialized hash function implementation to avoid boxing hash code computation
+ * in the specialized implementation of OpenHashSet.
+ */
+ sealed class Hasher[@specialized(Long, Int) T] {
+ def hash(o: T): Int = o.hashCode()
+ }
+
+ class LongHasher extends Hasher[Long] {
+ override def hash(o: Long): Int = (o ^ (o >>> 32)).toInt
+ }
+
+ class IntHasher extends Hasher[Int] {
+ override def hash(o: Int): Int = o
+ }
+
+ private def grow1(newSize: Int) {}
+ private def move1(oldPos: Int, newPos: Int) { }
+
+ private val grow = grow1 _
+ private val move = move1 _
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
new file mode 100644
index 0000000000..d76143e45a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+
+/**
+ * A fast hash map implementation for primitive, non-null keys. This hash map supports
+ * insertions and updates, but not deletions. This map is about an order of magnitude
+ * faster than java.util.HashMap, while using much less space overhead.
+ *
+ * Under the hood, it uses our OpenHashSet implementation.
+ */
+private[spark]
+class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest,
+ @specialized(Long, Int, Double) V: ClassManifest](
+ initialCapacity: Int)
+ extends Iterable[(K, V)]
+ with Serializable {
+
+ def this() = this(64)
+
+ require(classManifest[K] == classManifest[Long] || classManifest[K] == classManifest[Int])
+
+ // Init in constructor (instead of in declaration) to work around a Scala compiler specialization
+ // bug that would generate two arrays (one for Object and one for specialized T).
+ protected var _keySet: OpenHashSet[K] = _
+ private var _values: Array[V] = _
+ _keySet = new OpenHashSet[K](initialCapacity)
+ _values = new Array[V](_keySet.capacity)
+
+ private var _oldValues: Array[V] = null
+
+ override def size = _keySet.size
+
+ /** Get the value for a given key */
+ def apply(k: K): V = {
+ val pos = _keySet.getPos(k)
+ _values(pos)
+ }
+
+ /** Get the value for a given key, or returns elseValue if it doesn't exist. */
+ def getOrElse(k: K, elseValue: V): V = {
+ val pos = _keySet.getPos(k)
+ if (pos >= 0) _values(pos) else elseValue
+ }
+
+ /** Set the value for a key */
+ def update(k: K, v: V) {
+ val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
+ _values(pos) = v
+ _keySet.rehashIfNeeded(k, grow, move)
+ _oldValues = null
+ }
+
+ /**
+ * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise,
+ * set its value to mergeValue(oldValue).
+ *
+ * @return the newly updated value.
+ */
+ def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = {
+ val pos = _keySet.addWithoutResize(k)
+ if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
+ val newValue = defaultValue
+ _values(pos & OpenHashSet.POSITION_MASK) = newValue
+ _keySet.rehashIfNeeded(k, grow, move)
+ newValue
+ } else {
+ _values(pos) = mergeValue(_values(pos))
+ _values(pos)
+ }
+ }
+
+ override def iterator = new Iterator[(K, V)] {
+ var pos = 0
+ var nextPair: (K, V) = computeNextPair()
+
+ /** Get the next value we should return from next(), or null if we're finished iterating */
+ def computeNextPair(): (K, V) = {
+ pos = _keySet.nextPos(pos)
+ if (pos >= 0) {
+ val ret = (_keySet.getValue(pos), _values(pos))
+ pos += 1
+ ret
+ } else {
+ null
+ }
+ }
+
+ def hasNext = nextPair != null
+
+ def next() = {
+ val pair = nextPair
+ nextPair = computeNextPair()
+ pair
+ }
+ }
+
+ // The following member variables are declared as protected instead of private for the
+ // specialization to work (specialized class extends the unspecialized one and needs access
+ // to the "private" variables).
+ // They also should have been val's. We use var's because there is a Scala compiler bug that
+ // would throw illegal access error at runtime if they are declared as val's.
+ protected var grow = (newCapacity: Int) => {
+ _oldValues = _values
+ _values = new Array[V](newCapacity)
+ }
+
+ protected var move = (oldPos: Int, newPos: Int) => {
+ _values(newPos) = _oldValues(oldPos)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
new file mode 100644
index 0000000000..369519c559
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+/** Provides a simple, non-threadsafe, array-backed vector that can store primitives. */
+private[spark]
+class PrimitiveVector[@specialized(Long, Int, Double) V: ClassManifest](initialSize: Int = 64) {
+ private var numElements = 0
+ private var array: Array[V] = _
+
+ // NB: This must be separate from the declaration, otherwise the specialized parent class
+ // will get its own array with the same initial size. TODO: Figure out why...
+ array = new Array[V](initialSize)
+
+ def apply(index: Int): V = {
+ require(index < numElements)
+ array(index)
+ }
+
+ def +=(value: V) {
+ if (numElements == array.length) { resize(array.length * 2) }
+ array(numElements) = value
+ numElements += 1
+ }
+
+ def length = numElements
+
+ def getUnderlyingArray = array
+
+ /** Resizes the array, dropping elements if the total length decreases. */
+ def resize(newLength: Int) {
+ val newArray = new Array[V](newLength)
+ array.copyToArray(newArray)
+ array = newArray
+ }
+}