diff options
author | zhichao.li <zhichao.li@intel.com> | 2015-10-22 03:59:26 -0700 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2015-10-22 03:59:26 -0700 |
commit | c03b6d11589102b91f08728519e8520025db91e1 (patch) | |
tree | 104b3a1a1e97adbe5dfc8af1ec3870a8dd9c5b81 | |
parent | 1d9733271595596683a6d956a7433fa601df1cc1 (diff) | |
download | spark-c03b6d11589102b91f08728519e8520025db91e1.tar.gz spark-c03b6d11589102b91f08728519e8520025db91e1.tar.bz2 spark-c03b6d11589102b91f08728519e8520025db91e1.zip |
[SPARK-11121][CORE] Correct the TaskLocation type
Correct the logic to return `HDFSCacheTaskLocation` instance when the input `str` is a in memory location.
Author: zhichao.li <zhichao.li@intel.com>
Closes #9096 from zhichao-li/uselessBranch.
-rw-r--r-- | core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala | 2 | ||||
-rw-r--r-- | core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala | 11 |
2 files changed, 9 insertions, 4 deletions
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala index da07ce2c6e..1b65926f5c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala @@ -67,7 +67,7 @@ private[spark] object TaskLocation { if (hstr.equals(str)) { new HostTaskLocation(str) } else { - new HostTaskLocation(hstr) + new HDFSCacheTaskLocation(hstr) } } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index f0eadf2409..695523cc8a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -759,9 +759,9 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) val taskSet = FakeTask.createTaskSet(3, - Seq(HostTaskLocation("host1")), - Seq(HostTaskLocation("host2")), - Seq(HDFSCacheTaskLocation("host3"))) + Seq(TaskLocation("host1")), + Seq(TaskLocation("host2")), + Seq(TaskLocation("hdfs_cache_host3"))) val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) @@ -776,6 +776,11 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(manager.myLocalityLevels.sameElements(Array(ANY))) } + test("Test TaskLocation for different host type.") { + assert(TaskLocation("host1") === HostTaskLocation("host1")) + assert(TaskLocation("hdfs_cache_host1") === HDFSCacheTaskLocation("host1")) + } + def createTaskResult(id: Int): DirectTaskResult[Int] = { val valueSer = SparkEnv.get.serializer.newInstance() new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics) |