summaryrefslogtreecommitdiff
path: root/src/library/scala/concurrent/impl/ExecutionContextImpl.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/library/scala/concurrent/impl/ExecutionContextImpl.scala')
-rw-r--r--src/library/scala/concurrent/impl/ExecutionContextImpl.scala230
1 files changed, 127 insertions, 103 deletions
diff --git a/src/library/scala/concurrent/impl/ExecutionContextImpl.scala b/src/library/scala/concurrent/impl/ExecutionContextImpl.scala
index 479720287c..19233d7531 100644
--- a/src/library/scala/concurrent/impl/ExecutionContextImpl.scala
+++ b/src/library/scala/concurrent/impl/ExecutionContextImpl.scala
@@ -8,55 +8,87 @@
package scala.concurrent.impl
-
-
-import java.util.concurrent.{ LinkedBlockingQueue, Callable, Executor, ExecutorService, Executors, ThreadFactory, TimeUnit, ThreadPoolExecutor }
+import java.util.concurrent.{ ForkJoinPool, ForkJoinWorkerThread, ForkJoinTask, Callable, Executor, ExecutorService, ThreadFactory, TimeUnit }
+import java.util.concurrent.atomic.AtomicInteger
import java.util.Collection
-import scala.concurrent.forkjoin._
-import scala.concurrent.{ BlockContext, ExecutionContext, Awaitable, CanAwait, ExecutionContextExecutor, ExecutionContextExecutorService }
-import scala.util.control.NonFatal
+import scala.concurrent.{ BlockContext, ExecutionContext, CanAwait, ExecutionContextExecutor, ExecutionContextExecutorService }
+import scala.annotation.tailrec
+private[scala] class ExecutionContextImpl private[impl] (val executor: Executor, val reporter: Throwable => Unit) extends ExecutionContextExecutor {
+ require(executor ne null, "Executor must not be null")
+ override def execute(runnable: Runnable) = executor execute runnable
+ override def reportFailure(t: Throwable) = reporter(t)
+}
-private[scala] class ExecutionContextImpl private[impl] (es: Executor, reporter: Throwable => Unit) extends ExecutionContextExecutor {
- // Placed here since the creation of the executor needs to read this val
- private[this] val uncaughtExceptionHandler: Thread.UncaughtExceptionHandler = new Thread.UncaughtExceptionHandler {
- def uncaughtException(thread: Thread, cause: Throwable): Unit = reporter(cause)
- }
- val executor: Executor = es match {
- case null => createExecutorService
- case some => some
- }
+private[concurrent] object ExecutionContextImpl {
// Implement BlockContext on FJP threads
- class DefaultThreadFactory(daemonic: Boolean) extends ThreadFactory with ForkJoinPool.ForkJoinWorkerThreadFactory {
+ final class DefaultThreadFactory(
+ daemonic: Boolean,
+ maxThreads: Int,
+ prefix: String,
+ uncaught: Thread.UncaughtExceptionHandler) extends ThreadFactory with ForkJoinPool.ForkJoinWorkerThreadFactory {
+
+ require(prefix ne null, "DefaultThreadFactory.prefix must be non null")
+ require(maxThreads > 0, "DefaultThreadFactory.maxThreads must be greater than 0")
+
+ private final val currentNumberOfThreads = new AtomicInteger(0)
+
+ @tailrec private final def reserveThread(): Boolean = currentNumberOfThreads.get() match {
+ case `maxThreads` | Int.`MaxValue` => false
+ case other => currentNumberOfThreads.compareAndSet(other, other + 1) || reserveThread()
+ }
+
+ @tailrec private final def deregisterThread(): Boolean = currentNumberOfThreads.get() match {
+ case 0 => false
+ case other => currentNumberOfThreads.compareAndSet(other, other - 1) || deregisterThread()
+ }
+
def wire[T <: Thread](thread: T): T = {
thread.setDaemon(daemonic)
- thread.setUncaughtExceptionHandler(uncaughtExceptionHandler)
+ thread.setUncaughtExceptionHandler(uncaught)
+ thread.setName(prefix + "-" + thread.getId())
thread
}
- def newThread(runnable: Runnable): Thread = wire(new Thread(runnable))
-
- def newThread(fjp: ForkJoinPool): ForkJoinWorkerThread = wire(new ForkJoinWorkerThread(fjp) with BlockContext {
- override def blockOn[T](thunk: =>T)(implicit permission: CanAwait): T = {
- var result: T = null.asInstanceOf[T]
- ForkJoinPool.managedBlock(new ForkJoinPool.ManagedBlocker {
- @volatile var isdone = false
- override def block(): Boolean = {
- result = try thunk finally { isdone = true }
- true
+ // As per ThreadFactory contract newThread should return `null` if cannot create new thread.
+ def newThread(runnable: Runnable): Thread =
+ if (reserveThread())
+ wire(new Thread(new Runnable {
+ // We have to decrement the current thread count when the thread exits
+ override def run() = try runnable.run() finally deregisterThread()
+ })) else null
+
+ def newThread(fjp: ForkJoinPool): ForkJoinWorkerThread =
+ if (reserveThread()) {
+ wire(new ForkJoinWorkerThread(fjp) with BlockContext {
+ // We have to decrement the current thread count when the thread exits
+ final override def onTermination(exception: Throwable): Unit = deregisterThread()
+ final override def blockOn[T](thunk: =>T)(implicit permission: CanAwait): T = {
+ var result: T = null.asInstanceOf[T]
+ ForkJoinPool.managedBlock(new ForkJoinPool.ManagedBlocker {
+ @volatile var isdone = false
+ override def block(): Boolean = {
+ result = try {
+ // When we block, switch out the BlockContext temporarily so that nested blocking does not created N new Threads
+ BlockContext.withBlockContext(BlockContext.defaultBlockContext) { thunk }
+ } finally {
+ isdone = true
+ }
+
+ true
+ }
+ override def isReleasable = isdone
+ })
+ result
}
- override def isReleasable = isdone
})
- result
- }
- })
+ } else null
}
- def createExecutorService: ExecutorService = {
-
+ def createDefaultExecutorService(reporter: Throwable => Unit): ExecutorService = {
def getInt(name: String, default: String) = (try System.getProperty(name, default) catch {
case e: SecurityException => default
}) match {
@@ -65,87 +97,79 @@ private[scala] class ExecutionContextImpl private[impl] (es: Executor, reporter:
}
def range(floor: Int, desired: Int, ceiling: Int) = scala.math.min(scala.math.max(floor, desired), ceiling)
+ val numThreads = getInt("scala.concurrent.context.numThreads", "x1")
+ // The hard limit on the number of active threads that the thread factory will produce
+ // SI-8955 Deadlocks can happen if maxNoOfThreads is too low, although we're currently not sure
+ // about what the exact threshold is. numThreads + 256 is conservatively high.
+ val maxNoOfThreads = getInt("scala.concurrent.context.maxThreads", "x1")
val desiredParallelism = range(
getInt("scala.concurrent.context.minThreads", "1"),
- getInt("scala.concurrent.context.numThreads", "x1"),
- getInt("scala.concurrent.context.maxThreads", "x1"))
-
- val threadFactory = new DefaultThreadFactory(daemonic = true)
-
- try {
- new ForkJoinPool(
- desiredParallelism,
- threadFactory,
- uncaughtExceptionHandler,
- true) // Async all the way baby
- } catch {
- case NonFatal(t) =>
- System.err.println("Failed to create ForkJoinPool for the default ExecutionContext, falling back to ThreadPoolExecutor")
- t.printStackTrace(System.err)
- val exec = new ThreadPoolExecutor(
- desiredParallelism,
- desiredParallelism,
- 5L,
- TimeUnit.MINUTES,
- new LinkedBlockingQueue[Runnable],
- threadFactory
- )
- exec.allowCoreThreadTimeOut(true)
- exec
- }
- }
+ numThreads,
+ maxNoOfThreads)
- def execute(runnable: Runnable): Unit = executor match {
- case fj: ForkJoinPool =>
- val fjt: ForkJoinTask[_] = runnable match {
- case t: ForkJoinTask[_] => t
- case r => new ExecutionContextImpl.AdaptedForkJoinTask(r)
- }
- Thread.currentThread match {
- case fjw: ForkJoinWorkerThread if fjw.getPool eq fj => fjt.fork()
- case _ => fj execute fjt
- }
- case generic => generic execute runnable
- }
+ // The thread factory must provide additional threads to support managed blocking.
+ val maxExtraThreads = getInt("scala.concurrent.context.maxExtraThreads", "256")
- def reportFailure(t: Throwable) = reporter(t)
-}
+ val uncaughtExceptionHandler: Thread.UncaughtExceptionHandler = new Thread.UncaughtExceptionHandler {
+ override def uncaughtException(thread: Thread, cause: Throwable): Unit = reporter(cause)
+ }
+ val threadFactory = new ExecutionContextImpl.DefaultThreadFactory(daemonic = true,
+ maxThreads = maxNoOfThreads + maxExtraThreads,
+ prefix = "scala-execution-context-global",
+ uncaught = uncaughtExceptionHandler)
-private[concurrent] object ExecutionContextImpl {
+ new ForkJoinPool(desiredParallelism, threadFactory, uncaughtExceptionHandler, true) {
+ override def execute(runnable: Runnable): Unit = {
+ val fjt: ForkJoinTask[_] = runnable match {
+ case t: ForkJoinTask[_] => t
+ case r => new ExecutionContextImpl.AdaptedForkJoinTask(r)
+ }
+ Thread.currentThread match {
+ case fjw: ForkJoinWorkerThread if fjw.getPool eq this => fjt.fork()
+ case _ => super.execute(fjt)
+ }
+ }
+ }
+ }
final class AdaptedForkJoinTask(runnable: Runnable) extends ForkJoinTask[Unit] {
- final override def setRawResult(u: Unit): Unit = ()
- final override def getRawResult(): Unit = ()
- final override def exec(): Boolean = try { runnable.run(); true } catch {
- case anything: Throwable ⇒
- val t = Thread.currentThread
- t.getUncaughtExceptionHandler match {
- case null ⇒
- case some ⇒ some.uncaughtException(t, anything)
- }
- throw anything
- }
+ final override def setRawResult(u: Unit): Unit = ()
+ final override def getRawResult(): Unit = ()
+ final override def exec(): Boolean = try { runnable.run(); true } catch {
+ case anything: Throwable =>
+ val t = Thread.currentThread
+ t.getUncaughtExceptionHandler match {
+ case null =>
+ case some => some.uncaughtException(t, anything)
}
+ throw anything
+ }
+ }
- def fromExecutor(e: Executor, reporter: Throwable => Unit = ExecutionContext.defaultReporter): ExecutionContextImpl = new ExecutionContextImpl(e, reporter)
- def fromExecutorService(es: ExecutorService, reporter: Throwable => Unit = ExecutionContext.defaultReporter): ExecutionContextImpl with ExecutionContextExecutorService =
- new ExecutionContextImpl(es, reporter) with ExecutionContextExecutorService {
- final def asExecutorService: ExecutorService = executor.asInstanceOf[ExecutorService]
- override def execute(command: Runnable) = executor.execute(command)
- override def shutdown() { asExecutorService.shutdown() }
- override def shutdownNow() = asExecutorService.shutdownNow()
- override def isShutdown = asExecutorService.isShutdown
- override def isTerminated = asExecutorService.isTerminated
- override def awaitTermination(l: Long, timeUnit: TimeUnit) = asExecutorService.awaitTermination(l, timeUnit)
- override def submit[T](callable: Callable[T]) = asExecutorService.submit(callable)
- override def submit[T](runnable: Runnable, t: T) = asExecutorService.submit(runnable, t)
- override def submit(runnable: Runnable) = asExecutorService.submit(runnable)
- override def invokeAll[T](callables: Collection[_ <: Callable[T]]) = asExecutorService.invokeAll(callables)
- override def invokeAll[T](callables: Collection[_ <: Callable[T]], l: Long, timeUnit: TimeUnit) = asExecutorService.invokeAll(callables, l, timeUnit)
- override def invokeAny[T](callables: Collection[_ <: Callable[T]]) = asExecutorService.invokeAny(callables)
- override def invokeAny[T](callables: Collection[_ <: Callable[T]], l: Long, timeUnit: TimeUnit) = asExecutorService.invokeAny(callables, l, timeUnit)
+ def fromExecutor(e: Executor, reporter: Throwable => Unit = ExecutionContext.defaultReporter): ExecutionContextImpl =
+ new ExecutionContextImpl(Option(e).getOrElse(createDefaultExecutorService(reporter)), reporter)
+
+ def fromExecutorService(es: ExecutorService, reporter: Throwable => Unit = ExecutionContext.defaultReporter):
+ ExecutionContextImpl with ExecutionContextExecutorService = {
+ new ExecutionContextImpl(Option(es).getOrElse(createDefaultExecutorService(reporter)), reporter)
+ with ExecutionContextExecutorService {
+ final def asExecutorService: ExecutorService = executor.asInstanceOf[ExecutorService]
+ override def execute(command: Runnable) = executor.execute(command)
+ override def shutdown() { asExecutorService.shutdown() }
+ override def shutdownNow() = asExecutorService.shutdownNow()
+ override def isShutdown = asExecutorService.isShutdown
+ override def isTerminated = asExecutorService.isTerminated
+ override def awaitTermination(l: Long, timeUnit: TimeUnit) = asExecutorService.awaitTermination(l, timeUnit)
+ override def submit[T](callable: Callable[T]) = asExecutorService.submit(callable)
+ override def submit[T](runnable: Runnable, t: T) = asExecutorService.submit(runnable, t)
+ override def submit(runnable: Runnable) = asExecutorService.submit(runnable)
+ override def invokeAll[T](callables: Collection[_ <: Callable[T]]) = asExecutorService.invokeAll(callables)
+ override def invokeAll[T](callables: Collection[_ <: Callable[T]], l: Long, timeUnit: TimeUnit) = asExecutorService.invokeAll(callables, l, timeUnit)
+ override def invokeAny[T](callables: Collection[_ <: Callable[T]]) = asExecutorService.invokeAny(callables)
+ override def invokeAny[T](callables: Collection[_ <: Callable[T]], l: Long, timeUnit: TimeUnit) = asExecutorService.invokeAny(callables, l, timeUnit)
+ }
}
}