diff options
Diffstat (limited to 'src/library/scala/concurrent/impl')
-rw-r--r-- | src/library/scala/concurrent/impl/AbstractPromise.java | 17 | ||||
-rw-r--r-- | src/library/scala/concurrent/impl/ExecutionContextImpl.scala | 230 | ||||
-rw-r--r-- | src/library/scala/concurrent/impl/Future.scala | 34 | ||||
-rw-r--r-- | src/library/scala/concurrent/impl/Promise.scala | 181 |
4 files changed, 249 insertions, 213 deletions
diff --git a/src/library/scala/concurrent/impl/AbstractPromise.java b/src/library/scala/concurrent/impl/AbstractPromise.java deleted file mode 100644 index c2520a1692..0000000000 --- a/src/library/scala/concurrent/impl/AbstractPromise.java +++ /dev/null @@ -1,17 +0,0 @@ -/* __ *\ -** ________ ___ / / ___ Scala API ** -** / __/ __// _ | / / / _ | (c) 2003-2015, LAMP/EPFL ** -** __\ \/ /__/ __ |/ /__/ __ | http://scala-lang.org/ ** -** /____/\___/_/ |_/____/_/ | | ** -** |/ ** -\* */ - -package scala.concurrent.impl; - -import java.util.concurrent.atomic.AtomicReference; - -@Deprecated // Since 2.11.8. Extend java.util.concurrent.atomic.AtomicReference instead. -abstract class AbstractPromise extends AtomicReference<Object> { - protected final boolean updateState(Object oldState, Object newState) { return compareAndSet(oldState, newState); } - protected final Object getState() { return get(); } -} diff --git a/src/library/scala/concurrent/impl/ExecutionContextImpl.scala b/src/library/scala/concurrent/impl/ExecutionContextImpl.scala index 479720287c..7bf5cc5729 100644 --- a/src/library/scala/concurrent/impl/ExecutionContextImpl.scala +++ b/src/library/scala/concurrent/impl/ExecutionContextImpl.scala @@ -8,55 +8,87 @@ package scala.concurrent.impl - - -import java.util.concurrent.{ LinkedBlockingQueue, Callable, Executor, ExecutorService, Executors, ThreadFactory, TimeUnit, ThreadPoolExecutor } +import java.util.concurrent.{ ForkJoinPool, ForkJoinWorkerThread, ForkJoinTask, Callable, Executor, ExecutorService, ThreadFactory, TimeUnit } +import java.util.concurrent.atomic.AtomicInteger import java.util.Collection -import scala.concurrent.forkjoin._ -import scala.concurrent.{ BlockContext, ExecutionContext, Awaitable, CanAwait, ExecutionContextExecutor, ExecutionContextExecutorService } -import scala.util.control.NonFatal +import scala.concurrent.{ BlockContext, ExecutionContext, CanAwait, ExecutionContextExecutor, ExecutionContextExecutorService } +import scala.annotation.tailrec +private[scala] class ExecutionContextImpl private[impl] (val executor: Executor, val reporter: Throwable => Unit) extends ExecutionContextExecutor { + require(executor ne null, "Executor must not be null") + override def execute(runnable: Runnable) = executor execute runnable + override def reportFailure(t: Throwable) = reporter(t) +} -private[scala] class ExecutionContextImpl private[impl] (es: Executor, reporter: Throwable => Unit) extends ExecutionContextExecutor { - // Placed here since the creation of the executor needs to read this val - private[this] val uncaughtExceptionHandler: Thread.UncaughtExceptionHandler = new Thread.UncaughtExceptionHandler { - def uncaughtException(thread: Thread, cause: Throwable): Unit = reporter(cause) - } - val executor: Executor = es match { - case null => createExecutorService - case some => some - } +private[concurrent] object ExecutionContextImpl { // Implement BlockContext on FJP threads - class DefaultThreadFactory(daemonic: Boolean) extends ThreadFactory with ForkJoinPool.ForkJoinWorkerThreadFactory { + final class DefaultThreadFactory( + daemonic: Boolean, + maxThreads: Int, + prefix: String, + uncaught: Thread.UncaughtExceptionHandler) extends ThreadFactory with ForkJoinPool.ForkJoinWorkerThreadFactory { + + require(prefix ne null, "DefaultThreadFactory.prefix must be non null") + require(maxThreads > 0, "DefaultThreadFactory.maxThreads must be greater than 0") + + private final val currentNumberOfThreads = new AtomicInteger(0) + + @tailrec private final def reserveThread(): Boolean = currentNumberOfThreads.get() match { + case `maxThreads` | Int.`MaxValue` => false + case other => currentNumberOfThreads.compareAndSet(other, other + 1) || reserveThread() + } + + @tailrec private final def deregisterThread(): Boolean = currentNumberOfThreads.get() match { + case 0 => false + case other => currentNumberOfThreads.compareAndSet(other, other - 1) || deregisterThread() + } + def wire[T <: Thread](thread: T): T = { thread.setDaemon(daemonic) - thread.setUncaughtExceptionHandler(uncaughtExceptionHandler) + thread.setUncaughtExceptionHandler(uncaught) + thread.setName(prefix + "-" + thread.getId()) thread } - def newThread(runnable: Runnable): Thread = wire(new Thread(runnable)) - - def newThread(fjp: ForkJoinPool): ForkJoinWorkerThread = wire(new ForkJoinWorkerThread(fjp) with BlockContext { - override def blockOn[T](thunk: =>T)(implicit permission: CanAwait): T = { - var result: T = null.asInstanceOf[T] - ForkJoinPool.managedBlock(new ForkJoinPool.ManagedBlocker { - @volatile var isdone = false - override def block(): Boolean = { - result = try thunk finally { isdone = true } - true + // As per ThreadFactory contract newThread should return `null` if cannot create new thread. + def newThread(runnable: Runnable): Thread = + if (reserveThread()) + wire(new Thread(new Runnable { + // We have to decrement the current thread count when the thread exits + override def run() = try runnable.run() finally deregisterThread() + })) else null + + def newThread(fjp: ForkJoinPool): ForkJoinWorkerThread = + if (reserveThread()) { + wire(new ForkJoinWorkerThread(fjp) with BlockContext { + // We have to decrement the current thread count when the thread exits + final override def onTermination(exception: Throwable): Unit = deregisterThread() + final override def blockOn[T](thunk: =>T)(implicit permission: CanAwait): T = { + var result: T = null.asInstanceOf[T] + ForkJoinPool.managedBlock(new ForkJoinPool.ManagedBlocker { + @volatile var isdone = false + override def block(): Boolean = { + result = try { + // When we block, switch out the BlockContext temporarily so that nested blocking does not created N new Threads + BlockContext.withBlockContext(BlockContext.defaultBlockContext) { thunk } + } finally { + isdone = true + } + + true + } + override def isReleasable = isdone + }) + result } - override def isReleasable = isdone }) - result - } - }) + } else null } - def createExecutorService: ExecutorService = { - + def createDefaultExecutorService(reporter: Throwable => Unit): ExecutorService = { def getInt(name: String, default: String) = (try System.getProperty(name, default) catch { case e: SecurityException => default }) match { @@ -65,87 +97,79 @@ private[scala] class ExecutionContextImpl private[impl] (es: Executor, reporter: } def range(floor: Int, desired: Int, ceiling: Int) = scala.math.min(scala.math.max(floor, desired), ceiling) + val numThreads = getInt("scala.concurrent.context.numThreads", "x1") + // The hard limit on the number of active threads that the thread factory will produce + // SI-8955 Deadlocks can happen if maxNoOfThreads is too low, although we're currently not sure + // about what the exact threshhold is. numThreads + 256 is conservatively high. + val maxNoOfThreads = getInt("scala.concurrent.context.maxThreads", "x1") val desiredParallelism = range( getInt("scala.concurrent.context.minThreads", "1"), - getInt("scala.concurrent.context.numThreads", "x1"), - getInt("scala.concurrent.context.maxThreads", "x1")) - - val threadFactory = new DefaultThreadFactory(daemonic = true) - - try { - new ForkJoinPool( - desiredParallelism, - threadFactory, - uncaughtExceptionHandler, - true) // Async all the way baby - } catch { - case NonFatal(t) => - System.err.println("Failed to create ForkJoinPool for the default ExecutionContext, falling back to ThreadPoolExecutor") - t.printStackTrace(System.err) - val exec = new ThreadPoolExecutor( - desiredParallelism, - desiredParallelism, - 5L, - TimeUnit.MINUTES, - new LinkedBlockingQueue[Runnable], - threadFactory - ) - exec.allowCoreThreadTimeOut(true) - exec - } - } + numThreads, + maxNoOfThreads) - def execute(runnable: Runnable): Unit = executor match { - case fj: ForkJoinPool => - val fjt: ForkJoinTask[_] = runnable match { - case t: ForkJoinTask[_] => t - case r => new ExecutionContextImpl.AdaptedForkJoinTask(r) - } - Thread.currentThread match { - case fjw: ForkJoinWorkerThread if fjw.getPool eq fj => fjt.fork() - case _ => fj execute fjt - } - case generic => generic execute runnable - } + // The thread factory must provide additional threads to support managed blocking. + val maxExtraThreads = getInt("scala.concurrent.context.maxExtraThreads", "256") - def reportFailure(t: Throwable) = reporter(t) -} + val uncaughtExceptionHandler: Thread.UncaughtExceptionHandler = new Thread.UncaughtExceptionHandler { + override def uncaughtException(thread: Thread, cause: Throwable): Unit = reporter(cause) + } + val threadFactory = new ExecutionContextImpl.DefaultThreadFactory(daemonic = true, + maxThreads = maxNoOfThreads + maxExtraThreads, + prefix = "scala-execution-context-global", + uncaught = uncaughtExceptionHandler) -private[concurrent] object ExecutionContextImpl { + new ForkJoinPool(desiredParallelism, threadFactory, uncaughtExceptionHandler, true) { + override def execute(runnable: Runnable): Unit = { + val fjt: ForkJoinTask[_] = runnable match { + case t: ForkJoinTask[_] => t + case r => new ExecutionContextImpl.AdaptedForkJoinTask(r) + } + Thread.currentThread match { + case fjw: ForkJoinWorkerThread if fjw.getPool eq this => fjt.fork() + case _ => super.execute(fjt) + } + } + } + } final class AdaptedForkJoinTask(runnable: Runnable) extends ForkJoinTask[Unit] { - final override def setRawResult(u: Unit): Unit = () - final override def getRawResult(): Unit = () - final override def exec(): Boolean = try { runnable.run(); true } catch { - case anything: Throwable ⇒ - val t = Thread.currentThread - t.getUncaughtExceptionHandler match { - case null ⇒ - case some ⇒ some.uncaughtException(t, anything) - } - throw anything - } + final override def setRawResult(u: Unit): Unit = () + final override def getRawResult(): Unit = () + final override def exec(): Boolean = try { runnable.run(); true } catch { + case anything: Throwable => + val t = Thread.currentThread + t.getUncaughtExceptionHandler match { + case null => + case some => some.uncaughtException(t, anything) } + throw anything + } + } - def fromExecutor(e: Executor, reporter: Throwable => Unit = ExecutionContext.defaultReporter): ExecutionContextImpl = new ExecutionContextImpl(e, reporter) - def fromExecutorService(es: ExecutorService, reporter: Throwable => Unit = ExecutionContext.defaultReporter): ExecutionContextImpl with ExecutionContextExecutorService = - new ExecutionContextImpl(es, reporter) with ExecutionContextExecutorService { - final def asExecutorService: ExecutorService = executor.asInstanceOf[ExecutorService] - override def execute(command: Runnable) = executor.execute(command) - override def shutdown() { asExecutorService.shutdown() } - override def shutdownNow() = asExecutorService.shutdownNow() - override def isShutdown = asExecutorService.isShutdown - override def isTerminated = asExecutorService.isTerminated - override def awaitTermination(l: Long, timeUnit: TimeUnit) = asExecutorService.awaitTermination(l, timeUnit) - override def submit[T](callable: Callable[T]) = asExecutorService.submit(callable) - override def submit[T](runnable: Runnable, t: T) = asExecutorService.submit(runnable, t) - override def submit(runnable: Runnable) = asExecutorService.submit(runnable) - override def invokeAll[T](callables: Collection[_ <: Callable[T]]) = asExecutorService.invokeAll(callables) - override def invokeAll[T](callables: Collection[_ <: Callable[T]], l: Long, timeUnit: TimeUnit) = asExecutorService.invokeAll(callables, l, timeUnit) - override def invokeAny[T](callables: Collection[_ <: Callable[T]]) = asExecutorService.invokeAny(callables) - override def invokeAny[T](callables: Collection[_ <: Callable[T]], l: Long, timeUnit: TimeUnit) = asExecutorService.invokeAny(callables, l, timeUnit) + def fromExecutor(e: Executor, reporter: Throwable => Unit = ExecutionContext.defaultReporter): ExecutionContextImpl = + new ExecutionContextImpl(Option(e).getOrElse(createDefaultExecutorService(reporter)), reporter) + + def fromExecutorService(es: ExecutorService, reporter: Throwable => Unit = ExecutionContext.defaultReporter): + ExecutionContextImpl with ExecutionContextExecutorService = { + new ExecutionContextImpl(Option(es).getOrElse(createDefaultExecutorService(reporter)), reporter) + with ExecutionContextExecutorService { + final def asExecutorService: ExecutorService = executor.asInstanceOf[ExecutorService] + override def execute(command: Runnable) = executor.execute(command) + override def shutdown() { asExecutorService.shutdown() } + override def shutdownNow() = asExecutorService.shutdownNow() + override def isShutdown = asExecutorService.isShutdown + override def isTerminated = asExecutorService.isTerminated + override def awaitTermination(l: Long, timeUnit: TimeUnit) = asExecutorService.awaitTermination(l, timeUnit) + override def submit[T](callable: Callable[T]) = asExecutorService.submit(callable) + override def submit[T](runnable: Runnable, t: T) = asExecutorService.submit(runnable, t) + override def submit(runnable: Runnable) = asExecutorService.submit(runnable) + override def invokeAll[T](callables: Collection[_ <: Callable[T]]) = asExecutorService.invokeAll(callables) + override def invokeAll[T](callables: Collection[_ <: Callable[T]], l: Long, timeUnit: TimeUnit) = asExecutorService.invokeAll(callables, l, timeUnit) + override def invokeAny[T](callables: Collection[_ <: Callable[T]]) = asExecutorService.invokeAny(callables) + override def invokeAny[T](callables: Collection[_ <: Callable[T]], l: Long, timeUnit: TimeUnit) = asExecutorService.invokeAny(callables, l, timeUnit) + } } } diff --git a/src/library/scala/concurrent/impl/Future.scala b/src/library/scala/concurrent/impl/Future.scala deleted file mode 100644 index 042d32c234..0000000000 --- a/src/library/scala/concurrent/impl/Future.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* __ *\ -** ________ ___ / / ___ Scala API ** -** / __/ __// _ | / / / _ | (c) 2003-2013, LAMP/EPFL ** -** __\ \/ /__/ __ |/ /__/ __ | http://scala-lang.org/ ** -** /____/\___/_/ |_/____/_/ | | ** -** |/ ** -\* */ - -package scala.concurrent.impl - - - -import scala.concurrent.ExecutionContext -import scala.util.control.NonFatal -import scala.util.{ Success, Failure } - - -private[concurrent] object Future { - class PromiseCompletingRunnable[T](body: => T) extends Runnable { - val promise = new Promise.DefaultPromise[T]() - - override def run() = { - promise complete { - try Success(body) catch { case NonFatal(e) => Failure(e) } - } - } - } - - def apply[T](body: =>T)(implicit executor: ExecutionContext): scala.concurrent.Future[T] = { - val runnable = new PromiseCompletingRunnable(body) - executor.prepare.execute(runnable) - runnable.promise.future - } -} diff --git a/src/library/scala/concurrent/impl/Promise.scala b/src/library/scala/concurrent/impl/Promise.scala index 6d2fc5c87c..626540425f 100644 --- a/src/library/scala/concurrent/impl/Promise.scala +++ b/src/library/scala/concurrent/impl/Promise.scala @@ -8,17 +8,41 @@ package scala.concurrent.impl -import scala.concurrent.{ ExecutionContext, CanAwait, OnCompleteRunnable, TimeoutException, ExecutionException, blocking } +import scala.concurrent.{ ExecutionContext, CanAwait, OnCompleteRunnable, TimeoutException, ExecutionException } import scala.concurrent.Future.InternalCallbackExecutor -import scala.concurrent.duration.{ Duration, Deadline, FiniteDuration, NANOSECONDS } +import scala.concurrent.duration.{ Duration, FiniteDuration } import scala.annotation.tailrec import scala.util.control.NonFatal import scala.util.{ Try, Success, Failure } -import java.io.ObjectInputStream + import java.util.concurrent.locks.AbstractQueuedSynchronizer +import java.util.concurrent.atomic.AtomicReference private[concurrent] trait Promise[T] extends scala.concurrent.Promise[T] with scala.concurrent.Future[T] { def future: this.type = this + + import scala.concurrent.Future + import scala.concurrent.impl.Promise.DefaultPromise + + override def transform[S](f: Try[T] => Try[S])(implicit executor: ExecutionContext): Future[S] = { + val p = new DefaultPromise[S]() + onComplete { result => p.complete(try f(result) catch { case NonFatal(t) => Failure(t) }) } + p.future + } + + // If possible, link DefaultPromises to avoid space leaks + override def transformWith[S](f: Try[T] => Future[S])(implicit executor: ExecutionContext): Future[S] = { + val p = new DefaultPromise[S]() + onComplete { + v => try f(v) match { + case fut if fut eq this => p complete v.asInstanceOf[Try[S]] + case dp: DefaultPromise[_] => dp.asInstanceOf[DefaultPromise[S]].linkRootOf(p) + case fut => p completeWith fut + } catch { case NonFatal(t) => p failure t } + } + p.future + } + override def toString: String = value match { case Some(result) => "Future("+result+")" case None => "Future(<not completed>)" @@ -27,7 +51,7 @@ private[concurrent] trait Promise[T] extends scala.concurrent.Promise[T] with sc /* Precondition: `executor` is prepared, i.e., `executor` has been returned from invocation of `prepare` on some other `ExecutionContext`. */ -private class CallbackRunnable[T](val executor: ExecutionContext, val onComplete: Try[T] => Any) extends Runnable with OnCompleteRunnable { +private final class CallbackRunnable[T](val executor: ExecutionContext, val onComplete: Try[T] => Any) extends Runnable with OnCompleteRunnable { // must be filled in before running it var value: Try[T] = null @@ -93,7 +117,7 @@ private[concurrent] object Promise { * incomplete, or as complete with the same result value. * * A DefaultPromise stores its state entirely in the AnyRef cell exposed by - * AbstractPromise. The type of object stored in the cell fully describes the + * AtomicReference. The type of object stored in the cell fully describes the * current state of the promise. * * 1. List[CallbackRunnable] - The promise is incomplete and has zero or more callbacks @@ -154,8 +178,9 @@ private[concurrent] object Promise { * DefaultPromises, and `linkedRootOf` is currently only designed to be called * by Future.flatMap. */ - class DefaultPromise[T] extends AbstractPromise with Promise[T] { self => - updateState(null, Nil) // The promise is incomplete and has no callbacks + // Left non-final to enable addition of extra fields by Java/Scala converters + // in scala-java8-compat. + class DefaultPromise[T] extends AtomicReference[AnyRef](Nil) with Promise[T] { /** Get the root promise for this promise, compressing the link chain to that * promise if necessary. @@ -171,14 +196,23 @@ private[concurrent] object Promise { * be garbage collected. Also, subsequent calls to this method should be * faster as the link chain will be shorter. */ - @tailrec - private def compressedRoot(): DefaultPromise[T] = { - getState match { - case linked: DefaultPromise[_] => - val target = linked.asInstanceOf[DefaultPromise[T]].root - if (linked eq target) target else if (updateState(linked, target)) target else compressedRoot() + private def compressedRoot(): DefaultPromise[T] = + get() match { + case linked: DefaultPromise[_] => compressedRoot(linked) case _ => this } + + @tailrec + private[this] final def compressedRoot(linked: DefaultPromise[_]): DefaultPromise[T] = { + val target = linked.asInstanceOf[DefaultPromise[T]].root + if (linked eq target) target + else if (compareAndSet(linked, target)) target + else { + get() match { + case newLinked: DefaultPromise[_] => compressedRoot(newLinked) + case _ => this + } + } } /** Get the promise at the root of the chain of linked promises. Used by `compressedRoot()`. @@ -186,18 +220,16 @@ private[concurrent] object Promise { * to compress the link chain whenever possible. */ @tailrec - private def root: DefaultPromise[T] = { - getState match { + private def root: DefaultPromise[T] = + get() match { case linked: DefaultPromise[_] => linked.asInstanceOf[DefaultPromise[T]].root case _ => this } - } /** Try waiting for this promise to be completed. */ protected final def tryAwait(atMost: Duration): Boolean = if (!isCompleted) { import Duration.Undefined - import scala.concurrent.Future.InternalCallbackExecutor atMost match { case e if e eq Undefined => throw new IllegalArgumentException("cannot wait for Undefined period") case Duration.Inf => @@ -218,33 +250,33 @@ private[concurrent] object Promise { @throws(classOf[TimeoutException]) @throws(classOf[InterruptedException]) - def ready(atMost: Duration)(implicit permit: CanAwait): this.type = + final def ready(atMost: Duration)(implicit permit: CanAwait): this.type = if (tryAwait(atMost)) this else throw new TimeoutException("Futures timed out after [" + atMost + "]") @throws(classOf[Exception]) - def result(atMost: Duration)(implicit permit: CanAwait): T = + final def result(atMost: Duration)(implicit permit: CanAwait): T = ready(atMost).value.get.get // ready throws TimeoutException if timeout so value.get is safe here def value: Option[Try[T]] = value0 @tailrec - private def value0: Option[Try[T]] = getState match { + private def value0: Option[Try[T]] = get() match { case c: Try[_] => Some(c.asInstanceOf[Try[T]]) - case _: DefaultPromise[_] => compressedRoot().value0 + case dp: DefaultPromise[_] => compressedRoot(dp).value0 case _ => None } - override def isCompleted: Boolean = isCompleted0 + override final def isCompleted: Boolean = isCompleted0 @tailrec - private def isCompleted0: Boolean = getState match { + private def isCompleted0: Boolean = get() match { case _: Try[_] => true - case _: DefaultPromise[_] => compressedRoot().isCompleted0 + case dp: DefaultPromise[_] => compressedRoot(dp).isCompleted0 case _ => false } - def tryComplete(value: Try[T]): Boolean = { + final def tryComplete(value: Try[T]): Boolean = { val resolved = resolveTry(value) tryCompleteAndGetListeners(resolved) match { case null => false @@ -258,37 +290,34 @@ private[concurrent] object Promise { */ @tailrec private def tryCompleteAndGetListeners(v: Try[T]): List[CallbackRunnable[T]] = { - getState match { + get() match { case raw: List[_] => val cur = raw.asInstanceOf[List[CallbackRunnable[T]]] - if (updateState(cur, v)) cur else tryCompleteAndGetListeners(v) - case _: DefaultPromise[_] => - compressedRoot().tryCompleteAndGetListeners(v) + if (compareAndSet(cur, v)) cur else tryCompleteAndGetListeners(v) + case dp: DefaultPromise[_] => compressedRoot(dp).tryCompleteAndGetListeners(v) case _ => null } } - def onComplete[U](func: Try[T] => U)(implicit executor: ExecutionContext): Unit = { - val preparedEC = executor.prepare() - val runnable = new CallbackRunnable[T](preparedEC, func) - dispatchOrAddCallback(runnable) - } + final def onComplete[U](func: Try[T] => U)(implicit executor: ExecutionContext): Unit = + dispatchOrAddCallback(new CallbackRunnable[T](executor.prepare(), func)) /** Tries to add the callback, if already completed, it dispatches the callback to be executed. * Used by `onComplete()` to add callbacks to a promise and by `link()` to transfer callbacks - * to the root promise when linking two promises togehter. + * to the root promise when linking two promises together. */ @tailrec private def dispatchOrAddCallback(runnable: CallbackRunnable[T]): Unit = { - getState match { + get() match { case r: Try[_] => runnable.executeWithValue(r.asInstanceOf[Try[T]]) - case _: DefaultPromise[_] => compressedRoot().dispatchOrAddCallback(runnable) - case listeners: List[_] => if (updateState(listeners, runnable :: listeners)) () else dispatchOrAddCallback(runnable) + case dp: DefaultPromise[_] => compressedRoot(dp).dispatchOrAddCallback(runnable) + case listeners: List[_] => if (compareAndSet(listeners, runnable :: listeners)) () + else dispatchOrAddCallback(runnable) } } /** Link this promise to the root of another promise using `link()`. Should only be - * be called by Future.flatMap. + * be called by transformWith. */ protected[concurrent] final def linkRootOf(target: DefaultPromise[T]): Unit = link(target.compressedRoot()) @@ -303,18 +332,17 @@ private[concurrent] object Promise { */ @tailrec private def link(target: DefaultPromise[T]): Unit = if (this ne target) { - getState match { + get() match { case r: Try[_] => - if (!target.tryComplete(r.asInstanceOf[Try[T]])) { - // Currently linking is done from Future.flatMap, which should ensure only - // one promise can be completed. Therefore this situation is unexpected. + if (!target.tryComplete(r.asInstanceOf[Try[T]])) throw new IllegalStateException("Cannot link completed promises together") - } - case _: DefaultPromise[_] => - compressedRoot().link(target) - case listeners: List[_] => if (updateState(listeners, target)) { - if (!listeners.isEmpty) listeners.asInstanceOf[List[CallbackRunnable[T]]].foreach(target.dispatchOrAddCallback(_)) - } else link(target) + case dp: DefaultPromise[_] => + compressedRoot(dp).link(target) + case listeners: List[_] if compareAndSet(listeners, target) => + if (listeners.nonEmpty) + listeners.asInstanceOf[List[CallbackRunnable[T]]].foreach(target.dispatchOrAddCallback(_)) + case _ => + link(target) } } } @@ -323,23 +351,58 @@ private[concurrent] object Promise { * * Useful in Future-composition when a value to contribute is already available. */ - final class KeptPromise[T](suppliedValue: Try[T]) extends Promise[T] { + object KeptPromise { + import scala.concurrent.Future + import scala.reflect.ClassTag + + private[this] sealed trait Kept[T] extends Promise[T] { + def result: Try[T] + + override def value: Option[Try[T]] = Some(result) - val value = Some(resolveTry(suppliedValue)) + override def isCompleted: Boolean = true - override def isCompleted: Boolean = true + override def tryComplete(value: Try[T]): Boolean = false - def tryComplete(value: Try[T]): Boolean = false + override def onComplete[U](func: Try[T] => U)(implicit executor: ExecutionContext): Unit = + (new CallbackRunnable(executor.prepare(), func)).executeWithValue(result) - def onComplete[U](func: Try[T] => U)(implicit executor: ExecutionContext): Unit = { - val completedAs = value.get - val preparedEC = executor.prepare() - (new CallbackRunnable(preparedEC, func)).executeWithValue(completedAs) + override def ready(atMost: Duration)(implicit permit: CanAwait): this.type = this + + override def result(atMost: Duration)(implicit permit: CanAwait): T = result.get } - def ready(atMost: Duration)(implicit permit: CanAwait): this.type = this + private[this] final class Successful[T](val result: Success[T]) extends Kept[T] { + override def onFailure[U](pf: PartialFunction[Throwable, U])(implicit executor: ExecutionContext): Unit = () + override def failed: Future[Throwable] = KeptPromise(Failure(new NoSuchElementException("Future.failed not completed with a throwable."))).future + override def recover[U >: T](pf: PartialFunction[Throwable, U])(implicit executor: ExecutionContext): Future[U] = this + override def recoverWith[U >: T](pf: PartialFunction[Throwable, Future[U]])(implicit executor: ExecutionContext): Future[U] = this + override def fallbackTo[U >: T](that: Future[U]): Future[U] = this + } - def result(atMost: Duration)(implicit permit: CanAwait): T = value.get.get + private[this] final class Failed[T](val result: Failure[T]) extends Kept[T] { + private[this] final def thisAs[S]: Future[S] = future.asInstanceOf[Future[S]] + + override def onSuccess[U](pf: PartialFunction[T, U])(implicit executor: ExecutionContext): Unit = () + override def failed: Future[Throwable] = thisAs[Throwable] + override def foreach[U](f: T => U)(implicit executor: ExecutionContext): Unit = () + override def map[S](f: T => S)(implicit executor: ExecutionContext): Future[S] = thisAs[S] + override def flatMap[S](f: T => Future[S])(implicit executor: ExecutionContext): Future[S] = thisAs[S] + override def flatten[S](implicit ev: T <:< Future[S]): Future[S] = thisAs[S] + override def filter(p: T => Boolean)(implicit executor: ExecutionContext): Future[T] = this + override def collect[S](pf: PartialFunction[T, S])(implicit executor: ExecutionContext): Future[S] = thisAs[S] + override def zip[U](that: Future[U]): Future[(T, U)] = thisAs[(T,U)] + override def zipWith[U, R](that: Future[U])(f: (T, U) => R)(implicit executor: ExecutionContext): Future[R] = thisAs[R] + override def fallbackTo[U >: T](that: Future[U]): Future[U] = + if (this eq that) this else that.recoverWith({ case _ => this })(InternalCallbackExecutor) + override def mapTo[S](implicit tag: ClassTag[S]): Future[S] = thisAs[S] + } + + def apply[T](result: Try[T]): scala.concurrent.Promise[T] = + resolveTry(result) match { + case s @ Success(_) => new Successful(s) + case f @ Failure(_) => new Failed(f) + } } } |