diff options
author | jinxing <jinxing6042@126.com> | 2017-03-23 23:25:56 -0700 |
---|---|---|
committer | Kay Ousterhout <kayousterhout@gmail.com> | 2017-03-23 23:25:56 -0700 |
commit | 19596c28b6ef6e7abe0cfccfd2269c2fddf1fdee (patch) | |
tree | 3634264ca51c7cb76ae1b6c8de710fdeed08ac0a /core | |
parent | bb823ca4b479a00030c4919c2d857d254b2a44d8 (diff) | |
download | spark-19596c28b6ef6e7abe0cfccfd2269c2fddf1fdee.tar.gz spark-19596c28b6ef6e7abe0cfccfd2269c2fddf1fdee.tar.bz2 spark-19596c28b6ef6e7abe0cfccfd2269c2fddf1fdee.zip |
[SPARK-16929] Improve performance when check speculatable tasks.
## What changes were proposed in this pull request?
1. Use a MedianHeap to record durations of successful tasks. When check speculatable tasks, we can get the median duration with O(1) time complexity.
2. `checkSpeculatableTasks` will synchronize `TaskSchedulerImpl`. If `checkSpeculatableTasks` doesn't finish with 100ms, then the possibility exists for that thread to release and then immediately re-acquire the lock. Change `scheduleAtFixedRate` to be `scheduleWithFixedDelay` when call method of `checkSpeculatableTasks`.
## How was this patch tested?
Added MedianHeapSuite.
Author: jinxing <jinxing6042@126.com>
Closes #16867 from jinxing64/SPARK-16929.
Diffstat (limited to 'core')
5 files changed, 176 insertions, 6 deletions
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 8257c70d67..d6225a0873 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -174,7 +174,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( if (!isLocal && conf.getBoolean("spark.speculation", false)) { logInfo("Starting speculative execution thread") - speculationScheduler.scheduleAtFixedRate(new Runnable { + speculationScheduler.scheduleWithFixedDelay(new Runnable { override def run(): Unit = Utils.tryOrStopSparkContext(sc) { checkSpeculatableTasks() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index fd93a1f5c5..f4a21bca79 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -19,11 +19,10 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.nio.ByteBuffer -import java.util.Arrays import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.math.{max, min} +import scala.math.max import scala.util.control.NonFatal import org.apache.spark._ @@ -31,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.scheduler.SchedulingMode._ import org.apache.spark.TaskState.TaskState import org.apache.spark.util.{AccumulatorV2, Clock, SystemClock, Utils} +import org.apache.spark.util.collection.MedianHeap /** * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of @@ -63,6 +63,8 @@ private[spark] class TaskSetManager( // Limit of bytes for total size of results (default is 1GB) val maxResultSize = Utils.getMaxResultSize(conf) + val speculationEnabled = conf.getBoolean("spark.speculation", false) + // Serializer for closures and tasks. val env = SparkEnv.get val ser = env.closureSerializer.newInstance() @@ -141,6 +143,11 @@ private[spark] class TaskSetManager( // Task index, start and finish time for each task attempt (indexed by task ID) private val taskInfos = new HashMap[Long, TaskInfo] + // Use a MedianHeap to record durations of successful tasks so we know when to launch + // speculative tasks. This is only used when speculation is enabled, to avoid the overhead + // of inserting into the heap when the heap won't be used. + val successfulTaskDurations = new MedianHeap() + // How frequently to reprint duplicate exceptions in full, in milliseconds val EXCEPTION_PRINT_INTERVAL = conf.getLong("spark.logging.exceptionPrintInterval", 10000) @@ -698,6 +705,9 @@ private[spark] class TaskSetManager( val info = taskInfos(tid) val index = info.index info.markFinished(TaskState.FINISHED, clock.getTimeMillis()) + if (speculationEnabled) { + successfulTaskDurations.insert(info.duration) + } removeRunningTask(tid) // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not @@ -919,11 +929,10 @@ private[spark] class TaskSetManager( var foundTasks = false val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) + if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) { val time = clock.getTimeMillis() - val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray - Arrays.sort(durations) - val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.length - 1)) + var medianDuration = successfulTaskDurations.median val threshold = max(SPECULATION_MULTIPLIER * medianDuration, minTimeToSpeculation) // TODO: Threshold should also look at standard deviation of task durations and have a lower // bound based on that. diff --git a/core/src/main/scala/org/apache/spark/util/collection/MedianHeap.scala b/core/src/main/scala/org/apache/spark/util/collection/MedianHeap.scala new file mode 100644 index 0000000000..6e57c3c5be --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/MedianHeap.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import scala.collection.mutable.PriorityQueue + +/** + * MedianHeap is designed to be used to quickly track the median of a group of numbers + * that may contain duplicates. Inserting a new number has O(log n) time complexity and + * determining the median has O(1) time complexity. + * The basic idea is to maintain two heaps: a smallerHalf and a largerHalf. The smallerHalf + * stores the smaller half of all numbers while the largerHalf stores the larger half. + * The sizes of two heaps need to be balanced each time when a new number is inserted so + * that their sizes will not be different by more than 1. Therefore each time when + * findMedian() is called we check if two heaps have the same size. If they do, we should + * return the average of the two top values of heaps. Otherwise we return the top of the + * heap which has one more element. + */ +private[spark] class MedianHeap(implicit val ord: Ordering[Double]) { + + /** + * Stores all the numbers less than the current median in a smallerHalf, + * i.e median is the maximum, at the root. + */ + private[this] var smallerHalf = PriorityQueue.empty[Double](ord) + + /** + * Stores all the numbers greater than the current median in a largerHalf, + * i.e median is the minimum, at the root. + */ + private[this] var largerHalf = PriorityQueue.empty[Double](ord.reverse) + + def isEmpty(): Boolean = { + smallerHalf.isEmpty && largerHalf.isEmpty + } + + def size(): Int = { + smallerHalf.size + largerHalf.size + } + + def insert(x: Double): Unit = { + // If both heaps are empty, we arbitrarily insert it into a heap, let's say, the largerHalf. + if (isEmpty) { + largerHalf.enqueue(x) + } else { + // If the number is larger than current median, it should be inserted into largerHalf, + // otherwise smallerHalf. + if (x > median) { + largerHalf.enqueue(x) + } else { + smallerHalf.enqueue(x) + } + } + rebalance() + } + + private[this] def rebalance(): Unit = { + if (largerHalf.size - smallerHalf.size > 1) { + smallerHalf.enqueue(largerHalf.dequeue()) + } + if (smallerHalf.size - largerHalf.size > 1) { + largerHalf.enqueue(smallerHalf.dequeue) + } + } + + def median: Double = { + if (isEmpty) { + throw new NoSuchElementException("MedianHeap is empty.") + } + if (largerHalf.size == smallerHalf.size) { + (largerHalf.head + smallerHalf.head) / 2.0 + } else if (largerHalf.size > smallerHalf.size) { + largerHalf.head + } else { + smallerHalf.head + } + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index f36bcd8504..064af381a7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -893,6 +893,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val taskSet = FakeTask.createTaskSet(4) // Set the speculation multiplier to be 0 so speculative tasks are launched immediately sc.conf.set("spark.speculation.multiplier", "0.0") + sc.conf.set("spark.speculation", "true") val clock = new ManualClock() val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => @@ -948,6 +949,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Set the speculation multiplier to be 0 so speculative tasks are launched immediately sc.conf.set("spark.speculation.multiplier", "0.0") sc.conf.set("spark.speculation.quantile", "0.6") + sc.conf.set("spark.speculation", "true") val clock = new ManualClock() val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => diff --git a/core/src/test/scala/org/apache/spark/util/collection/MedianHeapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/MedianHeapSuite.scala new file mode 100644 index 0000000000..c2a3ee95f1 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/MedianHeapSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.util.NoSuchElementException + +import org.apache.spark.SparkFunSuite + +class MedianHeapSuite extends SparkFunSuite { + + test("If no numbers in MedianHeap, NoSuchElementException is thrown.") { + val medianHeap = new MedianHeap() + intercept[NoSuchElementException] { + medianHeap.median + } + } + + test("Median should be correct when size of MedianHeap is even") { + val array = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + val medianHeap = new MedianHeap() + array.foreach(medianHeap.insert(_)) + assert(medianHeap.size() === 10) + assert(medianHeap.median === 4.5) + } + + test("Median should be correct when size of MedianHeap is odd") { + val array = Array(0, 1, 2, 3, 4, 5, 6, 7, 8) + val medianHeap = new MedianHeap() + array.foreach(medianHeap.insert(_)) + assert(medianHeap.size() === 9) + assert(medianHeap.median === 4) + } + + test("Median should be correct though there are duplicated numbers inside.") { + val array = Array(0, 0, 1, 1, 2, 3, 4) + val medianHeap = new MedianHeap() + array.foreach(medianHeap.insert(_)) + assert(medianHeap.size === 7) + assert(medianHeap.median === 1) + } + + test("Median should be correct when input data is skewed.") { + val medianHeap = new MedianHeap() + (0 until 10).foreach(_ => medianHeap.insert(5)) + assert(medianHeap.median === 5) + (0 until 100).foreach(_ => medianHeap.insert(10)) + assert(medianHeap.median === 10) + (0 until 1000).foreach(_ => medianHeap.insert(0)) + assert(medianHeap.median === 0) + } +} |