summaryrefslogtreecommitdiff
path: root/src/library/scala/concurrent/impl
diff options
context:
space:
mode:
Diffstat (limited to 'src/library/scala/concurrent/impl')
-rw-r--r--src/library/scala/concurrent/impl/AbstractPromise.java17
-rw-r--r--src/library/scala/concurrent/impl/ExecutionContextImpl.scala230
-rw-r--r--src/library/scala/concurrent/impl/Future.scala34
-rw-r--r--src/library/scala/concurrent/impl/Promise.scala181
4 files changed, 249 insertions, 213 deletions
diff --git a/src/library/scala/concurrent/impl/AbstractPromise.java b/src/library/scala/concurrent/impl/AbstractPromise.java
deleted file mode 100644
index c2520a1692..0000000000
--- a/src/library/scala/concurrent/impl/AbstractPromise.java
+++ /dev/null
@@ -1,17 +0,0 @@
-/* __ *\
-** ________ ___ / / ___ Scala API **
-** / __/ __// _ | / / / _ | (c) 2003-2015, LAMP/EPFL **
-** __\ \/ /__/ __ |/ /__/ __ | http://scala-lang.org/ **
-** /____/\___/_/ |_/____/_/ | | **
-** |/ **
-\* */
-
-package scala.concurrent.impl;
-
-import java.util.concurrent.atomic.AtomicReference;
-
-@Deprecated // Since 2.11.8. Extend java.util.concurrent.atomic.AtomicReference instead.
-abstract class AbstractPromise extends AtomicReference<Object> {
- protected final boolean updateState(Object oldState, Object newState) { return compareAndSet(oldState, newState); }
- protected final Object getState() { return get(); }
-}
diff --git a/src/library/scala/concurrent/impl/ExecutionContextImpl.scala b/src/library/scala/concurrent/impl/ExecutionContextImpl.scala
index 479720287c..19233d7531 100644
--- a/src/library/scala/concurrent/impl/ExecutionContextImpl.scala
+++ b/src/library/scala/concurrent/impl/ExecutionContextImpl.scala
@@ -8,55 +8,87 @@
package scala.concurrent.impl
-
-
-import java.util.concurrent.{ LinkedBlockingQueue, Callable, Executor, ExecutorService, Executors, ThreadFactory, TimeUnit, ThreadPoolExecutor }
+import java.util.concurrent.{ ForkJoinPool, ForkJoinWorkerThread, ForkJoinTask, Callable, Executor, ExecutorService, ThreadFactory, TimeUnit }
+import java.util.concurrent.atomic.AtomicInteger
import java.util.Collection
-import scala.concurrent.forkjoin._
-import scala.concurrent.{ BlockContext, ExecutionContext, Awaitable, CanAwait, ExecutionContextExecutor, ExecutionContextExecutorService }
-import scala.util.control.NonFatal
+import scala.concurrent.{ BlockContext, ExecutionContext, CanAwait, ExecutionContextExecutor, ExecutionContextExecutorService }
+import scala.annotation.tailrec
+private[scala] class ExecutionContextImpl private[impl] (val executor: Executor, val reporter: Throwable => Unit) extends ExecutionContextExecutor {
+ require(executor ne null, "Executor must not be null")
+ override def execute(runnable: Runnable) = executor execute runnable
+ override def reportFailure(t: Throwable) = reporter(t)
+}
-private[scala] class ExecutionContextImpl private[impl] (es: Executor, reporter: Throwable => Unit) extends ExecutionContextExecutor {
- // Placed here since the creation of the executor needs to read this val
- private[this] val uncaughtExceptionHandler: Thread.UncaughtExceptionHandler = new Thread.UncaughtExceptionHandler {
- def uncaughtException(thread: Thread, cause: Throwable): Unit = reporter(cause)
- }
- val executor: Executor = es match {
- case null => createExecutorService
- case some => some
- }
+private[concurrent] object ExecutionContextImpl {
// Implement BlockContext on FJP threads
- class DefaultThreadFactory(daemonic: Boolean) extends ThreadFactory with ForkJoinPool.ForkJoinWorkerThreadFactory {
+ final class DefaultThreadFactory(
+ daemonic: Boolean,
+ maxThreads: Int,
+ prefix: String,
+ uncaught: Thread.UncaughtExceptionHandler) extends ThreadFactory with ForkJoinPool.ForkJoinWorkerThreadFactory {
+
+ require(prefix ne null, "DefaultThreadFactory.prefix must be non null")
+ require(maxThreads > 0, "DefaultThreadFactory.maxThreads must be greater than 0")
+
+ private final val currentNumberOfThreads = new AtomicInteger(0)
+
+ @tailrec private final def reserveThread(): Boolean = currentNumberOfThreads.get() match {
+ case `maxThreads` | Int.`MaxValue` => false
+ case other => currentNumberOfThreads.compareAndSet(other, other + 1) || reserveThread()
+ }
+
+ @tailrec private final def deregisterThread(): Boolean = currentNumberOfThreads.get() match {
+ case 0 => false
+ case other => currentNumberOfThreads.compareAndSet(other, other - 1) || deregisterThread()
+ }
+
def wire[T <: Thread](thread: T): T = {
thread.setDaemon(daemonic)
- thread.setUncaughtExceptionHandler(uncaughtExceptionHandler)
+ thread.setUncaughtExceptionHandler(uncaught)
+ thread.setName(prefix + "-" + thread.getId())
thread
}
- def newThread(runnable: Runnable): Thread = wire(new Thread(runnable))
-
- def newThread(fjp: ForkJoinPool): ForkJoinWorkerThread = wire(new ForkJoinWorkerThread(fjp) with BlockContext {
- override def blockOn[T](thunk: =>T)(implicit permission: CanAwait): T = {
- var result: T = null.asInstanceOf[T]
- ForkJoinPool.managedBlock(new ForkJoinPool.ManagedBlocker {
- @volatile var isdone = false
- override def block(): Boolean = {
- result = try thunk finally { isdone = true }
- true
+ // As per ThreadFactory contract newThread should return `null` if cannot create new thread.
+ def newThread(runnable: Runnable): Thread =
+ if (reserveThread())
+ wire(new Thread(new Runnable {
+ // We have to decrement the current thread count when the thread exits
+ override def run() = try runnable.run() finally deregisterThread()
+ })) else null
+
+ def newThread(fjp: ForkJoinPool): ForkJoinWorkerThread =
+ if (reserveThread()) {
+ wire(new ForkJoinWorkerThread(fjp) with BlockContext {
+ // We have to decrement the current thread count when the thread exits
+ final override def onTermination(exception: Throwable): Unit = deregisterThread()
+ final override def blockOn[T](thunk: =>T)(implicit permission: CanAwait): T = {
+ var result: T = null.asInstanceOf[T]
+ ForkJoinPool.managedBlock(new ForkJoinPool.ManagedBlocker {
+ @volatile var isdone = false
+ override def block(): Boolean = {
+ result = try {
+ // When we block, switch out the BlockContext temporarily so that nested blocking does not created N new Threads
+ BlockContext.withBlockContext(BlockContext.defaultBlockContext) { thunk }
+ } finally {
+ isdone = true
+ }
+
+ true
+ }
+ override def isReleasable = isdone
+ })
+ result
}
- override def isReleasable = isdone
})
- result
- }
- })
+ } else null
}
- def createExecutorService: ExecutorService = {
-
+ def createDefaultExecutorService(reporter: Throwable => Unit): ExecutorService = {
def getInt(name: String, default: String) = (try System.getProperty(name, default) catch {
case e: SecurityException => default
}) match {
@@ -65,87 +97,79 @@ private[scala] class ExecutionContextImpl private[impl] (es: Executor, reporter:
}
def range(floor: Int, desired: Int, ceiling: Int) = scala.math.min(scala.math.max(floor, desired), ceiling)
+ val numThreads = getInt("scala.concurrent.context.numThreads", "x1")
+ // The hard limit on the number of active threads that the thread factory will produce
+ // SI-8955 Deadlocks can happen if maxNoOfThreads is too low, although we're currently not sure
+ // about what the exact threshold is. numThreads + 256 is conservatively high.
+ val maxNoOfThreads = getInt("scala.concurrent.context.maxThreads", "x1")
val desiredParallelism = range(
getInt("scala.concurrent.context.minThreads", "1"),
- getInt("scala.concurrent.context.numThreads", "x1"),
- getInt("scala.concurrent.context.maxThreads", "x1"))
-
- val threadFactory = new DefaultThreadFactory(daemonic = true)
-
- try {
- new ForkJoinPool(
- desiredParallelism,
- threadFactory,
- uncaughtExceptionHandler,
- true) // Async all the way baby
- } catch {
- case NonFatal(t) =>
- System.err.println("Failed to create ForkJoinPool for the default ExecutionContext, falling back to ThreadPoolExecutor")
- t.printStackTrace(System.err)
- val exec = new ThreadPoolExecutor(
- desiredParallelism,
- desiredParallelism,
- 5L,
- TimeUnit.MINUTES,
- new LinkedBlockingQueue[Runnable],
- threadFactory
- )
- exec.allowCoreThreadTimeOut(true)
- exec
- }
- }
+ numThreads,
+ maxNoOfThreads)
- def execute(runnable: Runnable): Unit = executor match {
- case fj: ForkJoinPool =>
- val fjt: ForkJoinTask[_] = runnable match {
- case t: ForkJoinTask[_] => t
- case r => new ExecutionContextImpl.AdaptedForkJoinTask(r)
- }
- Thread.currentThread match {
- case fjw: ForkJoinWorkerThread if fjw.getPool eq fj => fjt.fork()
- case _ => fj execute fjt
- }
- case generic => generic execute runnable
- }
+ // The thread factory must provide additional threads to support managed blocking.
+ val maxExtraThreads = getInt("scala.concurrent.context.maxExtraThreads", "256")
- def reportFailure(t: Throwable) = reporter(t)
-}
+ val uncaughtExceptionHandler: Thread.UncaughtExceptionHandler = new Thread.UncaughtExceptionHandler {
+ override def uncaughtException(thread: Thread, cause: Throwable): Unit = reporter(cause)
+ }
+ val threadFactory = new ExecutionContextImpl.DefaultThreadFactory(daemonic = true,
+ maxThreads = maxNoOfThreads + maxExtraThreads,
+ prefix = "scala-execution-context-global",
+ uncaught = uncaughtExceptionHandler)
-private[concurrent] object ExecutionContextImpl {
+ new ForkJoinPool(desiredParallelism, threadFactory, uncaughtExceptionHandler, true) {
+ override def execute(runnable: Runnable): Unit = {
+ val fjt: ForkJoinTask[_] = runnable match {
+ case t: ForkJoinTask[_] => t
+ case r => new ExecutionContextImpl.AdaptedForkJoinTask(r)
+ }
+ Thread.currentThread match {
+ case fjw: ForkJoinWorkerThread if fjw.getPool eq this => fjt.fork()
+ case _ => super.execute(fjt)
+ }
+ }
+ }
+ }
final class AdaptedForkJoinTask(runnable: Runnable) extends ForkJoinTask[Unit] {
- final override def setRawResult(u: Unit): Unit = ()
- final override def getRawResult(): Unit = ()
- final override def exec(): Boolean = try { runnable.run(); true } catch {
- case anything: Throwable ⇒
- val t = Thread.currentThread
- t.getUncaughtExceptionHandler match {
- case null ⇒
- case some ⇒ some.uncaughtException(t, anything)
- }
- throw anything
- }
+ final override def setRawResult(u: Unit): Unit = ()
+ final override def getRawResult(): Unit = ()
+ final override def exec(): Boolean = try { runnable.run(); true } catch {
+ case anything: Throwable =>
+ val t = Thread.currentThread
+ t.getUncaughtExceptionHandler match {
+ case null =>
+ case some => some.uncaughtException(t, anything)
}
+ throw anything
+ }
+ }
- def fromExecutor(e: Executor, reporter: Throwable => Unit = ExecutionContext.defaultReporter): ExecutionContextImpl = new ExecutionContextImpl(e, reporter)
- def fromExecutorService(es: ExecutorService, reporter: Throwable => Unit = ExecutionContext.defaultReporter): ExecutionContextImpl with ExecutionContextExecutorService =
- new ExecutionContextImpl(es, reporter) with ExecutionContextExecutorService {
- final def asExecutorService: ExecutorService = executor.asInstanceOf[ExecutorService]
- override def execute(command: Runnable) = executor.execute(command)
- override def shutdown() { asExecutorService.shutdown() }
- override def shutdownNow() = asExecutorService.shutdownNow()
- override def isShutdown = asExecutorService.isShutdown
- override def isTerminated = asExecutorService.isTerminated
- override def awaitTermination(l: Long, timeUnit: TimeUnit) = asExecutorService.awaitTermination(l, timeUnit)
- override def submit[T](callable: Callable[T]) = asExecutorService.submit(callable)
- override def submit[T](runnable: Runnable, t: T) = asExecutorService.submit(runnable, t)
- override def submit(runnable: Runnable) = asExecutorService.submit(runnable)
- override def invokeAll[T](callables: Collection[_ <: Callable[T]]) = asExecutorService.invokeAll(callables)
- override def invokeAll[T](callables: Collection[_ <: Callable[T]], l: Long, timeUnit: TimeUnit) = asExecutorService.invokeAll(callables, l, timeUnit)
- override def invokeAny[T](callables: Collection[_ <: Callable[T]]) = asExecutorService.invokeAny(callables)
- override def invokeAny[T](callables: Collection[_ <: Callable[T]], l: Long, timeUnit: TimeUnit) = asExecutorService.invokeAny(callables, l, timeUnit)
+ def fromExecutor(e: Executor, reporter: Throwable => Unit = ExecutionContext.defaultReporter): ExecutionContextImpl =
+ new ExecutionContextImpl(Option(e).getOrElse(createDefaultExecutorService(reporter)), reporter)
+
+ def fromExecutorService(es: ExecutorService, reporter: Throwable => Unit = ExecutionContext.defaultReporter):
+ ExecutionContextImpl with ExecutionContextExecutorService = {
+ new ExecutionContextImpl(Option(es).getOrElse(createDefaultExecutorService(reporter)), reporter)
+ with ExecutionContextExecutorService {
+ final def asExecutorService: ExecutorService = executor.asInstanceOf[ExecutorService]
+ override def execute(command: Runnable) = executor.execute(command)
+ override def shutdown() { asExecutorService.shutdown() }
+ override def shutdownNow() = asExecutorService.shutdownNow()
+ override def isShutdown = asExecutorService.isShutdown
+ override def isTerminated = asExecutorService.isTerminated
+ override def awaitTermination(l: Long, timeUnit: TimeUnit) = asExecutorService.awaitTermination(l, timeUnit)
+ override def submit[T](callable: Callable[T]) = asExecutorService.submit(callable)
+ override def submit[T](runnable: Runnable, t: T) = asExecutorService.submit(runnable, t)
+ override def submit(runnable: Runnable) = asExecutorService.submit(runnable)
+ override def invokeAll[T](callables: Collection[_ <: Callable[T]]) = asExecutorService.invokeAll(callables)
+ override def invokeAll[T](callables: Collection[_ <: Callable[T]], l: Long, timeUnit: TimeUnit) = asExecutorService.invokeAll(callables, l, timeUnit)
+ override def invokeAny[T](callables: Collection[_ <: Callable[T]]) = asExecutorService.invokeAny(callables)
+ override def invokeAny[T](callables: Collection[_ <: Callable[T]], l: Long, timeUnit: TimeUnit) = asExecutorService.invokeAny(callables, l, timeUnit)
+ }
}
}
diff --git a/src/library/scala/concurrent/impl/Future.scala b/src/library/scala/concurrent/impl/Future.scala
deleted file mode 100644
index 042d32c234..0000000000
--- a/src/library/scala/concurrent/impl/Future.scala
+++ /dev/null
@@ -1,34 +0,0 @@
-/* __ *\
-** ________ ___ / / ___ Scala API **
-** / __/ __// _ | / / / _ | (c) 2003-2013, LAMP/EPFL **
-** __\ \/ /__/ __ |/ /__/ __ | http://scala-lang.org/ **
-** /____/\___/_/ |_/____/_/ | | **
-** |/ **
-\* */
-
-package scala.concurrent.impl
-
-
-
-import scala.concurrent.ExecutionContext
-import scala.util.control.NonFatal
-import scala.util.{ Success, Failure }
-
-
-private[concurrent] object Future {
- class PromiseCompletingRunnable[T](body: => T) extends Runnable {
- val promise = new Promise.DefaultPromise[T]()
-
- override def run() = {
- promise complete {
- try Success(body) catch { case NonFatal(e) => Failure(e) }
- }
- }
- }
-
- def apply[T](body: =>T)(implicit executor: ExecutionContext): scala.concurrent.Future[T] = {
- val runnable = new PromiseCompletingRunnable(body)
- executor.prepare.execute(runnable)
- runnable.promise.future
- }
-}
diff --git a/src/library/scala/concurrent/impl/Promise.scala b/src/library/scala/concurrent/impl/Promise.scala
index 6d2fc5c87c..7fcc8c9f2d 100644
--- a/src/library/scala/concurrent/impl/Promise.scala
+++ b/src/library/scala/concurrent/impl/Promise.scala
@@ -8,17 +8,41 @@
package scala.concurrent.impl
-import scala.concurrent.{ ExecutionContext, CanAwait, OnCompleteRunnable, TimeoutException, ExecutionException, blocking }
+import scala.concurrent.{ ExecutionContext, CanAwait, OnCompleteRunnable, TimeoutException, ExecutionException }
import scala.concurrent.Future.InternalCallbackExecutor
-import scala.concurrent.duration.{ Duration, Deadline, FiniteDuration, NANOSECONDS }
+import scala.concurrent.duration.{ Duration, FiniteDuration }
import scala.annotation.tailrec
import scala.util.control.NonFatal
import scala.util.{ Try, Success, Failure }
-import java.io.ObjectInputStream
+
import java.util.concurrent.locks.AbstractQueuedSynchronizer
+import java.util.concurrent.atomic.AtomicReference
private[concurrent] trait Promise[T] extends scala.concurrent.Promise[T] with scala.concurrent.Future[T] {
def future: this.type = this
+
+ import scala.concurrent.Future
+ import scala.concurrent.impl.Promise.DefaultPromise
+
+ override def transform[S](f: Try[T] => Try[S])(implicit executor: ExecutionContext): Future[S] = {
+ val p = new DefaultPromise[S]()
+ onComplete { result => p.complete(try f(result) catch { case NonFatal(t) => Failure(t) }) }
+ p.future
+ }
+
+ // If possible, link DefaultPromises to avoid space leaks
+ override def transformWith[S](f: Try[T] => Future[S])(implicit executor: ExecutionContext): Future[S] = {
+ val p = new DefaultPromise[S]()
+ onComplete {
+ v => try f(v) match {
+ case fut if fut eq this => p complete v.asInstanceOf[Try[S]]
+ case dp: DefaultPromise[_] => dp.asInstanceOf[DefaultPromise[S]].linkRootOf(p)
+ case fut => p completeWith fut
+ } catch { case NonFatal(t) => p failure t }
+ }
+ p.future
+ }
+
override def toString: String = value match {
case Some(result) => "Future("+result+")"
case None => "Future(<not completed>)"
@@ -27,7 +51,7 @@ private[concurrent] trait Promise[T] extends scala.concurrent.Promise[T] with sc
/* Precondition: `executor` is prepared, i.e., `executor` has been returned from invocation of `prepare` on some other `ExecutionContext`.
*/
-private class CallbackRunnable[T](val executor: ExecutionContext, val onComplete: Try[T] => Any) extends Runnable with OnCompleteRunnable {
+private final class CallbackRunnable[T](val executor: ExecutionContext, val onComplete: Try[T] => Any) extends Runnable with OnCompleteRunnable {
// must be filled in before running it
var value: Try[T] = null
@@ -93,7 +117,7 @@ private[concurrent] object Promise {
* incomplete, or as complete with the same result value.
*
* A DefaultPromise stores its state entirely in the AnyRef cell exposed by
- * AbstractPromise. The type of object stored in the cell fully describes the
+ * AtomicReference. The type of object stored in the cell fully describes the
* current state of the promise.
*
* 1. List[CallbackRunnable] - The promise is incomplete and has zero or more callbacks
@@ -154,8 +178,9 @@ private[concurrent] object Promise {
* DefaultPromises, and `linkedRootOf` is currently only designed to be called
* by Future.flatMap.
*/
- class DefaultPromise[T] extends AbstractPromise with Promise[T] { self =>
- updateState(null, Nil) // The promise is incomplete and has no callbacks
+ // Left non-final to enable addition of extra fields by Java/Scala converters
+ // in scala-java8-compat.
+ class DefaultPromise[T] extends AtomicReference[AnyRef](Nil) with Promise[T] {
/** Get the root promise for this promise, compressing the link chain to that
* promise if necessary.
@@ -171,14 +196,23 @@ private[concurrent] object Promise {
* be garbage collected. Also, subsequent calls to this method should be
* faster as the link chain will be shorter.
*/
- @tailrec
- private def compressedRoot(): DefaultPromise[T] = {
- getState match {
- case linked: DefaultPromise[_] =>
- val target = linked.asInstanceOf[DefaultPromise[T]].root
- if (linked eq target) target else if (updateState(linked, target)) target else compressedRoot()
+ private def compressedRoot(): DefaultPromise[T] =
+ get() match {
+ case linked: DefaultPromise[_] => compressedRoot(linked)
case _ => this
}
+
+ @tailrec
+ private[this] final def compressedRoot(linked: DefaultPromise[_]): DefaultPromise[T] = {
+ val target = linked.asInstanceOf[DefaultPromise[T]].root
+ if (linked eq target) target
+ else if (compareAndSet(linked, target)) target
+ else {
+ get() match {
+ case newLinked: DefaultPromise[_] => compressedRoot(newLinked)
+ case _ => this
+ }
+ }
}
/** Get the promise at the root of the chain of linked promises. Used by `compressedRoot()`.
@@ -186,18 +220,16 @@ private[concurrent] object Promise {
* to compress the link chain whenever possible.
*/
@tailrec
- private def root: DefaultPromise[T] = {
- getState match {
+ private def root: DefaultPromise[T] =
+ get() match {
case linked: DefaultPromise[_] => linked.asInstanceOf[DefaultPromise[T]].root
case _ => this
}
- }
/** Try waiting for this promise to be completed.
*/
protected final def tryAwait(atMost: Duration): Boolean = if (!isCompleted) {
import Duration.Undefined
- import scala.concurrent.Future.InternalCallbackExecutor
atMost match {
case e if e eq Undefined => throw new IllegalArgumentException("cannot wait for Undefined period")
case Duration.Inf =>
@@ -218,33 +250,33 @@ private[concurrent] object Promise {
@throws(classOf[TimeoutException])
@throws(classOf[InterruptedException])
- def ready(atMost: Duration)(implicit permit: CanAwait): this.type =
+ final def ready(atMost: Duration)(implicit permit: CanAwait): this.type =
if (tryAwait(atMost)) this
else throw new TimeoutException("Futures timed out after [" + atMost + "]")
@throws(classOf[Exception])
- def result(atMost: Duration)(implicit permit: CanAwait): T =
+ final def result(atMost: Duration)(implicit permit: CanAwait): T =
ready(atMost).value.get.get // ready throws TimeoutException if timeout so value.get is safe here
def value: Option[Try[T]] = value0
@tailrec
- private def value0: Option[Try[T]] = getState match {
+ private def value0: Option[Try[T]] = get() match {
case c: Try[_] => Some(c.asInstanceOf[Try[T]])
- case _: DefaultPromise[_] => compressedRoot().value0
+ case dp: DefaultPromise[_] => compressedRoot(dp).value0
case _ => None
}
- override def isCompleted: Boolean = isCompleted0
+ override final def isCompleted: Boolean = isCompleted0
@tailrec
- private def isCompleted0: Boolean = getState match {
+ private def isCompleted0: Boolean = get() match {
case _: Try[_] => true
- case _: DefaultPromise[_] => compressedRoot().isCompleted0
+ case dp: DefaultPromise[_] => compressedRoot(dp).isCompleted0
case _ => false
}
- def tryComplete(value: Try[T]): Boolean = {
+ final def tryComplete(value: Try[T]): Boolean = {
val resolved = resolveTry(value)
tryCompleteAndGetListeners(resolved) match {
case null => false
@@ -258,37 +290,34 @@ private[concurrent] object Promise {
*/
@tailrec
private def tryCompleteAndGetListeners(v: Try[T]): List[CallbackRunnable[T]] = {
- getState match {
+ get() match {
case raw: List[_] =>
val cur = raw.asInstanceOf[List[CallbackRunnable[T]]]
- if (updateState(cur, v)) cur else tryCompleteAndGetListeners(v)
- case _: DefaultPromise[_] =>
- compressedRoot().tryCompleteAndGetListeners(v)
+ if (compareAndSet(cur, v)) cur else tryCompleteAndGetListeners(v)
+ case dp: DefaultPromise[_] => compressedRoot(dp).tryCompleteAndGetListeners(v)
case _ => null
}
}
- def onComplete[U](func: Try[T] => U)(implicit executor: ExecutionContext): Unit = {
- val preparedEC = executor.prepare()
- val runnable = new CallbackRunnable[T](preparedEC, func)
- dispatchOrAddCallback(runnable)
- }
+ final def onComplete[U](func: Try[T] => U)(implicit executor: ExecutionContext): Unit =
+ dispatchOrAddCallback(new CallbackRunnable[T](executor.prepare(), func))
/** Tries to add the callback, if already completed, it dispatches the callback to be executed.
* Used by `onComplete()` to add callbacks to a promise and by `link()` to transfer callbacks
- * to the root promise when linking two promises togehter.
+ * to the root promise when linking two promises together.
*/
@tailrec
private def dispatchOrAddCallback(runnable: CallbackRunnable[T]): Unit = {
- getState match {
+ get() match {
case r: Try[_] => runnable.executeWithValue(r.asInstanceOf[Try[T]])
- case _: DefaultPromise[_] => compressedRoot().dispatchOrAddCallback(runnable)
- case listeners: List[_] => if (updateState(listeners, runnable :: listeners)) () else dispatchOrAddCallback(runnable)
+ case dp: DefaultPromise[_] => compressedRoot(dp).dispatchOrAddCallback(runnable)
+ case listeners: List[_] => if (compareAndSet(listeners, runnable :: listeners)) ()
+ else dispatchOrAddCallback(runnable)
}
}
/** Link this promise to the root of another promise using `link()`. Should only be
- * be called by Future.flatMap.
+ * be called by transformWith.
*/
protected[concurrent] final def linkRootOf(target: DefaultPromise[T]): Unit = link(target.compressedRoot())
@@ -303,18 +332,17 @@ private[concurrent] object Promise {
*/
@tailrec
private def link(target: DefaultPromise[T]): Unit = if (this ne target) {
- getState match {
+ get() match {
case r: Try[_] =>
- if (!target.tryComplete(r.asInstanceOf[Try[T]])) {
- // Currently linking is done from Future.flatMap, which should ensure only
- // one promise can be completed. Therefore this situation is unexpected.
+ if (!target.tryComplete(r.asInstanceOf[Try[T]]))
throw new IllegalStateException("Cannot link completed promises together")
- }
- case _: DefaultPromise[_] =>
- compressedRoot().link(target)
- case listeners: List[_] => if (updateState(listeners, target)) {
- if (!listeners.isEmpty) listeners.asInstanceOf[List[CallbackRunnable[T]]].foreach(target.dispatchOrAddCallback(_))
- } else link(target)
+ case dp: DefaultPromise[_] =>
+ compressedRoot(dp).link(target)
+ case listeners: List[_] if compareAndSet(listeners, target) =>
+ if (listeners.nonEmpty)
+ listeners.asInstanceOf[List[CallbackRunnable[T]]].foreach(target.dispatchOrAddCallback(_))
+ case _ =>
+ link(target)
}
}
}
@@ -323,23 +351,58 @@ private[concurrent] object Promise {
*
* Useful in Future-composition when a value to contribute is already available.
*/
- final class KeptPromise[T](suppliedValue: Try[T]) extends Promise[T] {
+ object KeptPromise {
+ import scala.concurrent.Future
+ import scala.reflect.ClassTag
+
+ private[this] sealed trait Kept[T] extends Promise[T] {
+ def result: Try[T]
+
+ override def value: Option[Try[T]] = Some(result)
- val value = Some(resolveTry(suppliedValue))
+ override def isCompleted: Boolean = true
- override def isCompleted: Boolean = true
+ override def tryComplete(value: Try[T]): Boolean = false
- def tryComplete(value: Try[T]): Boolean = false
+ override def onComplete[U](func: Try[T] => U)(implicit executor: ExecutionContext): Unit =
+ (new CallbackRunnable(executor.prepare(), func)).executeWithValue(result)
- def onComplete[U](func: Try[T] => U)(implicit executor: ExecutionContext): Unit = {
- val completedAs = value.get
- val preparedEC = executor.prepare()
- (new CallbackRunnable(preparedEC, func)).executeWithValue(completedAs)
+ override def ready(atMost: Duration)(implicit permit: CanAwait): this.type = this
+
+ override def result(atMost: Duration)(implicit permit: CanAwait): T = result.get
}
- def ready(atMost: Duration)(implicit permit: CanAwait): this.type = this
+ private[this] final class Successful[T](val result: Success[T]) extends Kept[T] {
+ override def onFailure[U](pf: PartialFunction[Throwable, U])(implicit executor: ExecutionContext): Unit = ()
+ override def failed: Future[Throwable] = KeptPromise(Failure(new NoSuchElementException("Future.failed not completed with a throwable."))).future
+ override def recover[U >: T](pf: PartialFunction[Throwable, U])(implicit executor: ExecutionContext): Future[U] = this
+ override def recoverWith[U >: T](pf: PartialFunction[Throwable, Future[U]])(implicit executor: ExecutionContext): Future[U] = this
+ override def fallbackTo[U >: T](that: Future[U]): Future[U] = this
+ }
- def result(atMost: Duration)(implicit permit: CanAwait): T = value.get.get
+ private[this] final class Failed[T](val result: Failure[T]) extends Kept[T] {
+ private[this] final def thisAs[S]: Future[S] = future.asInstanceOf[Future[S]]
+
+ override def onSuccess[U](pf: PartialFunction[T, U])(implicit executor: ExecutionContext): Unit = ()
+ override def failed: Future[Throwable] = KeptPromise(Success(result.exception)).future
+ override def foreach[U](f: T => U)(implicit executor: ExecutionContext): Unit = ()
+ override def map[S](f: T => S)(implicit executor: ExecutionContext): Future[S] = thisAs[S]
+ override def flatMap[S](f: T => Future[S])(implicit executor: ExecutionContext): Future[S] = thisAs[S]
+ override def flatten[S](implicit ev: T <:< Future[S]): Future[S] = thisAs[S]
+ override def filter(p: T => Boolean)(implicit executor: ExecutionContext): Future[T] = this
+ override def collect[S](pf: PartialFunction[T, S])(implicit executor: ExecutionContext): Future[S] = thisAs[S]
+ override def zip[U](that: Future[U]): Future[(T, U)] = thisAs[(T,U)]
+ override def zipWith[U, R](that: Future[U])(f: (T, U) => R)(implicit executor: ExecutionContext): Future[R] = thisAs[R]
+ override def fallbackTo[U >: T](that: Future[U]): Future[U] =
+ if (this eq that) this else that.recoverWith({ case _ => this })(InternalCallbackExecutor)
+ override def mapTo[S](implicit tag: ClassTag[S]): Future[S] = thisAs[S]
+ }
+
+ def apply[T](result: Try[T]): scala.concurrent.Promise[T] =
+ resolveTry(result) match {
+ case s @ Success(_) => new Successful(s)
+ case f @ Failure(_) => new Failed(f)
+ }
}
}