aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorShixiong Zhu <shixiong@databricks.com>2015-11-25 23:31:21 -0800
committerShixiong Zhu <shixiong@databricks.com>2015-11-25 23:31:21 -0800
commitd3ef693325f91a1ed340c9756c81244a80398eb2 (patch)
tree077f5f457d58eb42b69bb90b56804db65792a6d9 /core
parent068b6438d6886ce5b4aa698383866f466d913d66 (diff)
downloadspark-d3ef693325f91a1ed340c9756c81244a80398eb2.tar.gz
spark-d3ef693325f91a1ed340c9756c81244a80398eb2.tar.bz2
spark-d3ef693325f91a1ed340c9756c81244a80398eb2.zip
[SPARK-11999][CORE] Fix the issue that ThreadUtils.newDaemonCachedThreadPool doesn't cache any task
In the previous codes, `newDaemonCachedThreadPool` uses `SynchronousQueue`, which is wrong. `SynchronousQueue` is an empty queue that cannot cache any task. This patch uses `LinkedBlockingQueue` to fix it along with other fixes to make sure `newDaemonCachedThreadPool` can use at most `maxThreadNumber` threads, and after that, cache tasks to `LinkedBlockingQueue`. Author: Shixiong Zhu <shixiong@databricks.com> Closes #9978 from zsxwing/cached-threadpool.
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/org/apache/spark/util/ThreadUtils.scala | 14
-rw-r--r--core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala | 45
2 files changed, 56 insertions, 3 deletions
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index 53283448c8..f9fbe2ff85 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -56,10 +56,18 @@ private[spark] object ThreadUtils {
* Create a cached thread pool whose max number of threads is `maxThreadNumber`. Thread names
* are formatted as prefix-ID, where ID is a unique, sequentially assigned integer.
*/
- def newDaemonCachedThreadPool(prefix: String, maxThreadNumber: Int): ThreadPoolExecutor = {
+ def newDaemonCachedThreadPool(
+ prefix: String, maxThreadNumber: Int, keepAliveSeconds: Int = 60): ThreadPoolExecutor = {
val threadFactory = namedThreadFactory(prefix)
- new ThreadPoolExecutor(
- 0, maxThreadNumber, 60L, TimeUnit.SECONDS, new SynchronousQueue[Runnable], threadFactory)
+ val threadPool = new ThreadPoolExecutor(
+ maxThreadNumber, // corePoolSize: the max number of threads to create before queuing the tasks
+ maxThreadNumber, // maximumPoolSize: because we use LinkedBlockingQueue, this one is not used
+ keepAliveSeconds,
+ TimeUnit.SECONDS,
+ new LinkedBlockingQueue[Runnable],
+ threadFactory)
+ threadPool.allowCoreThreadTimeOut(true)
+ threadPool
}
/**
diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
index 620e4debf4..92ae038967 100644
--- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
@@ -24,6 +24,8 @@ import scala.concurrent.duration._
import scala.concurrent.{Await, Future}
import scala.util.Random
+import org.scalatest.concurrent.Eventually._
+
import org.apache.spark.SparkFunSuite
class ThreadUtilsSuite extends SparkFunSuite {
@@ -59,6 +61,49 @@ class ThreadUtilsSuite extends SparkFunSuite {
}
}
+ test("newDaemonCachedThreadPool") {
+ val maxThreadNumber = 10
+ val startThreadsLatch = new CountDownLatch(maxThreadNumber)
+ val latch = new CountDownLatch(1)
+ val cachedThreadPool = ThreadUtils.newDaemonCachedThreadPool(
+ "ThreadUtilsSuite-newDaemonCachedThreadPool",
+ maxThreadNumber,
+ keepAliveSeconds = 2)
+ try {
+ for (_ <- 1 to maxThreadNumber) {
+ cachedThreadPool.execute(new Runnable {
+ override def run(): Unit = {
+ startThreadsLatch.countDown()
+ latch.await(10, TimeUnit.SECONDS)
+ }
+ })
+ }
+ startThreadsLatch.await(10, TimeUnit.SECONDS)
+ assert(cachedThreadPool.getActiveCount === maxThreadNumber)
+ assert(cachedThreadPool.getQueue.size === 0)
+
+ // Submit a new task and it should be put into the queue since the thread number reaches the
+ // limitation
+ cachedThreadPool.execute(new Runnable {
+ override def run(): Unit = {
+ latch.await(10, TimeUnit.SECONDS)
+ }
+ })
+
+ assert(cachedThreadPool.getActiveCount === maxThreadNumber)
+ assert(cachedThreadPool.getQueue.size === 1)
+
+ latch.countDown()
+ eventually(timeout(10.seconds)) {
+ // All threads should be stopped after keepAliveSeconds
+ assert(cachedThreadPool.getActiveCount === 0)
+ assert(cachedThreadPool.getPoolSize === 0)
+ }
+ } finally {
+ cachedThreadPool.shutdownNow()
+ }
+ }
+
test("sameThread") {
val callerThreadName = Thread.currentThread().getName()
val f = Future {