-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala                  60
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala               18
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala                         2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala          2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala         48
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala       25
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala  22
-rw-r--r--  project/MimaExcludes.scala                                                  2
8 files changed, 162 insertions(+), 17 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 21d0cc7b5c..6b63eb23e9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -23,6 +23,7 @@ import java.io.EOFException
import scala.collection.immutable.Map
+import scala.collection.mutable.ListBuffer
import scala.reflect.ClassTag
import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.mapred.FileSplit
@@ -43,6 +44,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.executor.{DataReadMethod, InputMetrics}
import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD
+import org.apache.spark.scheduler.{HDFSCacheTaskLocation, HostTaskLocation}
import org.apache.spark.util.{NextIterator, Utils}
/**
@@ -249,9 +251,21 @@ class HadoopRDD[K, V](
}
override def getPreferredLocations(split: Partition): Seq[String] = {
- // TODO: Filtering out "localhost" in case of file:// URLs
- val hadoopSplit = split.asInstanceOf[HadoopPartition]
- hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost")
+ val hsplit = split.asInstanceOf[HadoopPartition].inputSplit.value
+ val locs: Option[Seq[String]] = HadoopRDD.SPLIT_INFO_REFLECTIONS match {
+ case Some(c) =>
+ try {
+ val lsplit = c.inputSplitWithLocationInfo.cast(hsplit)
+ val infos = c.getLocationInfo.invoke(lsplit).asInstanceOf[Array[AnyRef]]
+ Some(HadoopRDD.convertSplitLocationInfo(infos))
+ } catch {
+ case e: Exception =>
+ logDebug("Failed to use InputSplitWithLocations.", e)
+ None
+ }
+ case None => None
+ }
+ locs.getOrElse(hsplit.getLocations.filter(_ != "localhost"))
}
override def checkpoint() {
@@ -261,7 +275,7 @@ class HadoopRDD[K, V](
def getConf: Configuration = getJobConf()
}
-private[spark] object HadoopRDD {
+private[spark] object HadoopRDD extends Logging {
/** Constructing Configuration objects is not threadsafe, use this lock to serialize. */
val CONFIGURATION_INSTANTIATION_LOCK = new Object()
@@ -309,4 +323,42 @@ private[spark] object HadoopRDD {
f(inputSplit, firstParent[T].iterator(split, context))
}
}
+
+ private[spark] class SplitInfoReflections {
+ val inputSplitWithLocationInfo =
+ Class.forName("org.apache.hadoop.mapred.InputSplitWithLocationInfo")
+ val getLocationInfo = inputSplitWithLocationInfo.getMethod("getLocationInfo")
+ val newInputSplit = Class.forName("org.apache.hadoop.mapreduce.InputSplit")
+ val newGetLocationInfo = newInputSplit.getMethod("getLocationInfo")
+ val splitLocationInfo = Class.forName("org.apache.hadoop.mapred.SplitLocationInfo")
+ val isInMemory = splitLocationInfo.getMethod("isInMemory")
+ val getLocation = splitLocationInfo.getMethod("getLocation")
+ }
+
+ private[spark] val SPLIT_INFO_REFLECTIONS: Option[SplitInfoReflections] = try {
+ Some(new SplitInfoReflections)
+ } catch {
+ case e: Exception =>
+ logDebug("SplitLocationInfo and other new Hadoop classes are " +
+ "unavailable. Using the older Hadoop location info code.", e)
+ None
+ }
+
+  private[spark] def convertSplitLocationInfo(infos: Array[AnyRef]): Seq[String] = {
+    val out = ListBuffer[String]()
+    infos.foreach { loc =>
+      val reflections = HadoopRDD.SPLIT_INFO_REFLECTIONS.get
+      val locationStr = reflections.getLocation.invoke(loc).asInstanceOf[String]
+      if (locationStr != "localhost") {
+        if (reflections.isInMemory.invoke(loc).asInstanceOf[Boolean]) {
+          logDebug("Partition " + locationStr + " is cached by Hadoop.")
+          out += new HDFSCacheTaskLocation(locationStr).toString
+        } else {
+          out += new HostTaskLocation(locationStr).toString
+        }
+      }
+    }
+    out.seq
+  }
}
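
The reflection dance above exists because InputSplitWithLocationInfo only appears in Hadoop 2.5+, while Spark at this point still compiles against older Hadoop versions. A minimal, self-contained sketch of the same probe pattern (the standalone object and println output are illustrative, not part of this patch; the Hadoop class and method names are the real 2.5+ API):

    // Probe for the Hadoop 2.5+ location-info API without a compile-time dependency.
    object SplitLocationProbe {
      def main(args: Array[String]): Unit = {
        val probe =
          try {
            val cls = Class.forName("org.apache.hadoop.mapred.InputSplitWithLocationInfo")
            Some((cls, cls.getMethod("getLocationInfo")))
          } catch {
            // On older Hadoop the class or method is absent; fall back quietly.
            case _: ClassNotFoundException | _: NoSuchMethodException => None
          }
        probe match {
          case Some((cls, m)) =>
            println(s"Found ${cls.getName}#${m.getName}; HDFS cache info is available.")
          case None =>
            println("Older Hadoop on the classpath; only InputSplit#getLocations is usable.")
        }
      }
    }
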
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 4c84b3f623..0cccdefc5e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -173,9 +173,21 @@ class NewHadoopRDD[K, V](
new NewHadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning)
}
- override def getPreferredLocations(split: Partition): Seq[String] = {
- val theSplit = split.asInstanceOf[NewHadoopPartition]
- theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost")
+ override def getPreferredLocations(hsplit: Partition): Seq[String] = {
+ val split = hsplit.asInstanceOf[NewHadoopPartition].serializableHadoopSplit.value
+ val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match {
+ case Some(c) =>
+ try {
+ val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]]
+ Some(HadoopRDD.convertSplitLocationInfo(infos))
+ } catch {
+        case e: Exception =>
+ logDebug("Failed to use InputSplit#getLocationInfo.", e)
+ None
+ }
+ case None => None
+ }
+ locs.getOrElse(split.getLocations.filter(_ != "localhost"))
}
def getConf: Configuration = confBroadcast.value.value
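
This mirrors the old-API change in HadoopRDD: try the reflective path first, and fall back to the stable getLocations call on any failure. A stripped-down sketch of that shared control flow, with hypothetical function parameters standing in for the two paths (the "localhost" filter applies only on the fallback, since convertSplitLocationInfo already drops localhost entries):

    // Illustrative distillation of both getPreferredLocations overrides.
    def preferredLocations(reflective: () => Seq[String], stable: () => Seq[String]): Seq[String] = {
      val locs: Option[Seq[String]] =
        try {
          Some(reflective())  // only succeeds on Hadoop 2.5+ split types
        } catch {
          case _: Exception => None
        }
      locs.getOrElse(stable().filter(_ != "localhost"))
    }
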
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index ab9e97c8fe..2aba40d152 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -208,7 +208,7 @@ abstract class RDD[T: ClassTag](
}
/**
- * Get the preferred locations of a partition (as hostnames), taking into account whether the
+ * Get the preferred locations of a partition, taking into account whether the
* RDD is checkpointed.
*/
final def preferredLocations(split: Partition): Seq[String] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 5a96f52a10..8135cdbb4c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1303,7 +1303,7 @@ class DAGScheduler(
// If the RDD has some placement preferences (as is the case for input RDDs), get those
val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList
if (!rddPrefs.isEmpty) {
- return rddPrefs.map(host => TaskLocation(host))
+ return rddPrefs.map(TaskLocation(_))
}
// If the RDD has narrow dependencies, pick the first partition of the first narrow dep
// that has any placement preferences. Ideally we would choose based on transfer sizes,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
index 67c9a6760b..10c685f29d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
@@ -22,13 +22,51 @@ package org.apache.spark.scheduler
* In the latter case, we will prefer to launch the task on that executorID, but our next level
* of preference will be executors on the same host if this is not possible.
*/
-private[spark]
-class TaskLocation private (val host: String, val executorId: Option[String]) extends Serializable {
- override def toString: String = "TaskLocation(" + host + ", " + executorId + ")"
+private[spark] sealed trait TaskLocation {
+ def host: String
+}
+
+/**
+ * A location that includes both a host and an executor id on that host.
+ */
+private[spark] case class ExecutorCacheTaskLocation(override val host: String,
+    executorId: String) extends TaskLocation
+
+/**
+ * A location on a host.
+ */
+private[spark] case class HostTaskLocation(override val host: String) extends TaskLocation {
+  override def toString: String = host
+}
+
+/**
+ * A location on a host that is cached by HDFS.
+ */
+private[spark] case class HDFSCacheTaskLocation(override val host: String)
+  extends TaskLocation {
+  override def toString: String = TaskLocation.inMemoryLocationTag + host
}
private[spark] object TaskLocation {
- def apply(host: String, executorId: String) = new TaskLocation(host, Some(executorId))
+ // We identify hosts on which the block is cached with this prefix. Because this prefix contains
+ // underscores, which are not legal characters in hostnames, there should be no potential for
+ // confusion. See RFC 952 and RFC 1123 for information about the format of hostnames.
+ val inMemoryLocationTag = "hdfs_cache_"
+
+ def apply(host: String, executorId: String) = new ExecutorCacheTaskLocation(host, executorId)
- def apply(host: String) = new TaskLocation(host, None)
+ /**
+ * Create a TaskLocation from a string returned by getPreferredLocations.
+ * These strings have the form [hostname] or hdfs_cache_[hostname], depending on whether the
+ * location is cached.
+ */
+  def apply(str: String): TaskLocation = {
+    val hstr = str.stripPrefix(inMemoryLocationTag)
+    if (hstr.equals(str)) {
+      new HostTaskLocation(str)
+    } else {
+      new HDFSCacheTaskLocation(hstr)
+    }
+  }
}
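
The string encoding is what ties the RDD side to the scheduler side: getPreferredLocations returns plain strings, and DAGScheduler turns them back into TaskLocations via TaskLocation.apply. A sketch of the intended round trip (these classes are private[spark], so this only compiles inside that package):

    import org.apache.spark.scheduler.{HDFSCacheTaskLocation, HostTaskLocation, TaskLocation}

    // Encoding: the cache status is folded into the location string.
    assert(HDFSCacheTaskLocation("host3").toString == "hdfs_cache_host3")
    assert(HostTaskLocation("host2").toString == "host2")

    // Decoding: TaskLocation.apply(str) recovers the original location type.
    assert(TaskLocation("hdfs_cache_host3") == HDFSCacheTaskLocation("host3"))
    assert(TaskLocation("host2") == HostTaskLocation("host2"))
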
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index d9d53faf84..a6c23fc85a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -181,8 +181,24 @@ private[spark] class TaskSetManager(
}
for (loc <- tasks(index).preferredLocations) {
- for (execId <- loc.executorId) {
- addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
+      loc match {
+        case e: ExecutorCacheTaskLocation =>
+          addTo(pendingTasksForExecutor.getOrElseUpdate(e.executorId, new ArrayBuffer))
+        case e: HDFSCacheTaskLocation =>
+          sched.getExecutorsAliveOnHost(e.host) match {
+            case Some(set) =>
+              for (execId <- set) {
+                addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
+              }
+              logInfo(s"Pending task $index has a cached location at ${e.host}, " +
+                s"where there are executors ${set.mkString(",")}")
+            case None =>
+              logDebug(s"Pending task $index has a cached location at ${e.host}, " +
+                "but there are no executors alive there.")
+          }
+        case _ =>
+      }
}
addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
for (rack <- sched.getRackForHost(loc.host)) {
@@ -283,7 +299,10 @@ private[spark] class TaskSetManager(
// on multiple nodes when we replicate cached blocks, as in Spark Streaming
for (index <- speculatableTasks if canRunOnHost(index)) {
val prefs = tasks(index).preferredLocations
- val executors = prefs.flatMap(_.executorId)
+      val executors = prefs.flatMap {
+        case e: ExecutorCacheTaskLocation => Some(e.executorId)
+        case _ => None
+      }
if (executors.contains(execId)) {
speculatableTasks -= index
return Some((index, TaskLocality.PROCESS_LOCAL))
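
Only ExecutorCacheTaskLocation carries an executor id, so the speculation path can filter preferred locations with a straightforward match. An illustrative snippet with made-up locations (host and executor names are placeholders):

    val prefs: Seq[TaskLocation] = Seq(
      ExecutorCacheTaskLocation("host1", "exec1"),
      HostTaskLocation("host2"),
      HDFSCacheTaskLocation("host3"))

    // Locations without an executor id contribute nothing here.
    val executors = prefs.flatMap {
      case e: ExecutorCacheTaskLocation => Some(e.executorId)
      case _ => None
    }
    assert(executors == Seq("exec1"))
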
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 93e8ddacf8..c0b07649eb 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -642,6 +642,28 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
assert(manager.resourceOffer("execC", "host3", ANY) !== None)
}
+ test("Test that locations with HDFSCacheTaskLocation are treated as PROCESS_LOCAL.") {
+ // Regression test for SPARK-2931
+ sc = new SparkContext("local", "test")
+ val sched = new FakeTaskScheduler(sc,
+ ("execA", "host1"), ("execB", "host2"), ("execC", "host3"))
+ val taskSet = FakeTask.createTaskSet(3,
+ Seq(HostTaskLocation("host1")),
+ Seq(HostTaskLocation("host2")),
+ Seq(HDFSCacheTaskLocation("host3")))
+ val clock = new FakeClock
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
+ assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY)))
+ sched.removeExecutor("execA")
+ manager.executorAdded()
+ assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY)))
+ sched.removeExecutor("execB")
+ manager.executorAdded()
+ assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY)))
+ sched.removeExecutor("execC")
+ manager.executorAdded()
+ assert(manager.myLocalityLevels.sameElements(Array(ANY)))
+ }
def createTaskResult(id: Int): DirectTaskResult[Int] = {
val valueSer = SparkEnv.get.serializer.newInstance()
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 4076ebc6fc..d499302124 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -41,6 +41,8 @@ object MimaExcludes {
MimaBuild.excludeSparkClass("mllib.linalg.Matrix") ++
MimaBuild.excludeSparkClass("mllib.linalg.Vector") ++
Seq(
+ ProblemFilters.exclude[IncompatibleTemplateDefProblem](
+ "org.apache.spark.scheduler.TaskLocation"),
// Added normL1 and normL2 to trait MultivariateStatisticalSummary
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL1"),