From a36175feb5bfce59909fa4f3d9d5df6753b6ee3a Mon Sep 17 00:00:00 2001 From: Aleksandar Prokopec Date: Wed, 7 Dec 2011 11:59:35 +0100 Subject: Implementation of the fj task and future. --- .../scala/concurrent/ExecutionContext.scala | 92 -------------- .../scala/concurrent/ForkJoinTaskImpl.scala | 133 +++++++++++++++++++++ src/library/scala/concurrent/Future.scala | 23 ++-- src/library/scala/concurrent/Task.scala | 13 ++ src/library/scala/concurrent/package.scala | 20 +++- src/library/scala/package.scala | 3 +- 6 files changed, 175 insertions(+), 109 deletions(-) create mode 100644 src/library/scala/concurrent/ForkJoinTaskImpl.scala create mode 100644 src/library/scala/concurrent/Task.scala diff --git a/src/library/scala/concurrent/ExecutionContext.scala b/src/library/scala/concurrent/ExecutionContext.scala index 230f0de388..34c14147f5 100644 --- a/src/library/scala/concurrent/ExecutionContext.scala +++ b/src/library/scala/concurrent/ExecutionContext.scala @@ -21,98 +21,6 @@ trait ExecutionContext { } -trait Task[T] { - - def start(): Unit - - def future: Future[T] - -} - - -/* DONE: The challenge is to make ForkJoinPromise inherit from RecursiveAction - * to avoid an object allocation per promise. This requires turning DefaultPromise - * into a trait, i.e., removing its constructor parameters. - */ -private[concurrent] class ForkJoinTaskImpl[T](context: ForkJoinExecutionContext, body: () => T, within: Timeout) extends FJTask[T] with Task[T] { - - val timeout = within - implicit val dispatcher = context - - // body of RecursiveTask - def compute(): T = - body() - - def start(): Unit = - fork() - - def future: Future[T] = { - null - } - - // TODO FIXME: handle timeouts - def await(atMost: Duration): this.type = - await - - def await: this.type = { - this.join() - this - } - - def tryCancel(): Unit = - tryUnfork() -} - -private[concurrent] final class ForkJoinExecutionContext extends ExecutionContext { - val pool = new ForkJoinPool - - @inline - private def executeForkJoinTask(task: RecursiveAction) { - if (Thread.currentThread.isInstanceOf[ForkJoinWorkerThread]) - task.fork() - else - pool execute task - } - - def execute(task: Runnable) { - val action = new RecursiveAction { def compute() { task.run() } } - executeForkJoinTask(action) - } - - def makeTask[T](body: () => T)(implicit timeout: Timeout): Task[T] = { - new ForkJoinTaskImpl(this, body, timeout) - } - - def makePromise[T](timeout: Timeout): Promise[T] = - null - - def blockingCall[T](body: Blockable[T]): T = - body.block()(CanBlockEvidence) - -} - -/** - * Implements a blocking execution context - */ -/* -private[concurrent] class BlockingExecutionContext extends ExecutionContext { - //val pool = makeCachedThreadPool // TODO FIXME: need to merge thread pool factory methods from Heather's parcolls repo - - def execute(task: Runnable) { - /* TODO - val p = newPromise(task.run()) - p.start() - pool execute p - */ - } - - // TODO FIXME: implement - def newPromise[T](body: => T): Promise[T] = { - throw new Exception("not yet implemented") - } -} -*/ - object ExecutionContext { lazy val forNonBlocking = new ForkJoinExecutionContext diff --git a/src/library/scala/concurrent/ForkJoinTaskImpl.scala b/src/library/scala/concurrent/ForkJoinTaskImpl.scala new file mode 100644 index 0000000000..6a33ca162a --- /dev/null +++ b/src/library/scala/concurrent/ForkJoinTaskImpl.scala @@ -0,0 +1,133 @@ +package scala.concurrent + + + +import scala.concurrent.forkjoin.{ ForkJoinPool, RecursiveTask => FJTask, RecursiveAction, ForkJoinWorkerThread } +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater +import scala.annotation.tailrec + + + +/* DONE: The challenge is to make ForkJoinPromise inherit from RecursiveAction + * to avoid an object allocation per promise. This requires turning DefaultPromise + * into a trait, i.e., removing its constructor parameters. + */ +private[concurrent] class ForkJoinTaskImpl[T](val executionContext: ForkJoinExecutionContext, val body: () => T, val timeout: Timeout) +extends FJTask[T] with Task[T] with Future[T] { + + private val updater = AtomicReferenceFieldUpdater.newUpdater(classOf[ForkJoinTaskImpl[T]], classOf[FJState[T]], "state") + @volatile private var state: State[T] = _ + + updater.set(this, Pending(List())) + + private def casState(oldv: State[T], newv: State[T]) = { + updater.compareAndSet(this, oldv, newv) + } + + @tailrec private def trySucceedState(res: T): Unit = updater.get(this) match { + case p @ Pending(cbs) => if (!casState(p, Success(res))) trySucceedState(res) + case _ => // return + } + + @tailrec private def tryFailState(t: Throwable): Unit = updater.get(this) match { + case p @ Pending(cbs) => if (!casState(p, Failure(t))) tryFailState(t) + case _ => // return + } + + // body of RecursiveTask + def compute(): T = { + try { + val res = body() + trySucceedState(res) + } catch handledFutureException andThen { + t => tryFailState(t) + } finally tryFailState(new ExecutionException) + } + + def start(): Unit = { + Thread.currentThread match { + case fj: ForkJoinWorkerThread if fj.pool eq executionContext.pool => fork() + case _ => executionContext.pool.execute(this) + } + } + + def future: Future[T] = this + + def onComplete[U](callback: Either[Throwable, T] => U): this.type = { + @tailrec def tryAddCallback(): Either[Throwable, T] = { + updater.get(this) match { + case p @ Pending(lst) => + val pt = p.asInstanceOf[Pending[T]] + if (casState(pt, Pending(callback :: pt.lst))) null + else tryAddCallback() + case Success(res) => Right(res) + case Failure(t) => Left(t) + } + } + + val res = tryAddCallback() + if (res != null) dispatchTask new Runnable { + override def run() = + try callback(res) + catch handledFutureException + } + } + + private def dispatchTask[U](r: Runnable) = executionContext execute r + + def isTimedout: Boolean = false // TODO + + // TODO FIXME: handle timeouts + def await(atMost: Duration): this.type = + await + + def await: this.type = { + this.join() + this + } + + def tryCancel(): Unit = + tryUnfork() + +} + + +private[concurrent] sealed abstract class FJState[T] + + +case class Pending[T](callbacks: List[Either[Throwable, T] => Any]) extends FJState[T] + + +case class Success[T](result: T) extends FJState[T] + + +case class Failure[T](throwable: Throwable) extends FJState[T] + + +private[concurrent] final class ForkJoinExecutionContext extends ExecutionContext { + val pool = new ForkJoinPool + + @inline + private def executeForkJoinTask(task: RecursiveAction) { + if (Thread.currentThread.isInstanceOf[ForkJoinWorkerThread]) + task.fork() + else + pool execute task + } + + def execute(task: Runnable) { + val action = new RecursiveAction { def compute() { task.run() } } + executeForkJoinTask(action) + } + + def makeTask[T](body: () => T)(implicit timeout: Timeout): Task[T] = { + new ForkJoinTaskImpl(this, body, timeout) + } + + def makePromise[T](timeout: Timeout): Promise[T] = + null + + def blockingCall[T](body: Blockable[T]): T = + body.block()(CanBlockEvidence) + +} diff --git a/src/library/scala/concurrent/Future.scala b/src/library/scala/concurrent/Future.scala index 2393efcef6..b65d777d67 100644 --- a/src/library/scala/concurrent/Future.scala +++ b/src/library/scala/concurrent/Future.scala @@ -108,7 +108,7 @@ self => * $multipleCallbacks */ def onTimeout[U](body: =>U): this.type = onComplete { - case Left(te: TimeoutException) => body + case Left(te: FutureTimeoutException) => body } /** When this future is completed, either through an exception, a timeout, or a value, @@ -124,9 +124,13 @@ self => /* Miscellaneous */ + /** The execution context of the future. + */ + def executionContext: ExecutionContext + /** Creates a new promise. */ - def newPromise[S]: Promise[S] + def newPromise[S]: Promise[S] = executionContext promise /** Tests whether this Future's timeout has expired. * @@ -162,10 +166,10 @@ self => def timeout = self.timeout } - def timedout: Future[TimeoutException] = new Future[TimeoutException] { + def timedout: Future[FutureTimeoutException] = new Future[FutureTimeoutException] { def newPromise[S] = self.newPromise[S] - def onComplete[U](func: Either[Throwable, TimeoutException] => U) = self.onComplete { - case Left(te: TimeoutException) => func(Right(te)) + def onComplete[U](func: Either[Throwable, FutureTimeoutException] => U) = self.onComplete { + case Left(te: FutureTimeoutException) => func(Right(te)) case _ => // do nothing } def isTimedout = self.isTimedout @@ -273,12 +277,3 @@ self => } -/** A timeout exception. - * - * Futures are failed with a timeout exception when their timeout expires. - * - * Each timeout exception contains an origin future which originally timed out. - */ -class TimeoutException(origin: Future[T], message: String) extends java.util.concurrent.TimeoutException(message) { - def this(origin: Future[T]) = this(origin, "Future timed out.") -} diff --git a/src/library/scala/concurrent/Task.scala b/src/library/scala/concurrent/Task.scala new file mode 100644 index 0000000000..98c7da77d2 --- /dev/null +++ b/src/library/scala/concurrent/Task.scala @@ -0,0 +1,13 @@ +package scala.concurrent + + + +trait Task[T] { + + def start(): Unit + + def future: Future[T] + +} + + diff --git a/src/library/scala/concurrent/package.scala b/src/library/scala/concurrent/package.scala index 74e8b71eff..51bb1ac3e0 100644 --- a/src/library/scala/concurrent/package.scala +++ b/src/library/scala/concurrent/package.scala @@ -16,8 +16,10 @@ package scala /** This package object contains primitives for parallel programming. */ package object concurrent { - - type MessageDispatcher = ExecutionContext // TODO FIXME: change futures to use execution context + + type ExecutionException = java.util.concurrent.ExecutionException + type CancellationException = java.util.concurrent.CancellationException + type TimeoutException = java.util.concurrent.TimeoutException private[concurrent] def currentExecutionContext: ThreadLocal[ExecutionContext] = new ThreadLocal[ExecutionContext] { override protected def initialValue = null @@ -59,6 +61,10 @@ package object concurrent { def future[T](body: =>T): Future[T] = null // TODO + val handledFutureException: PartialFunction[Throwable, Throwable] = { + case t: Throwable if isFutureThrowable => t + } + // TODO rename appropriately and make public private[concurrent] def isFutureThrowable(t: Throwable) = t match { case e: Error => false @@ -74,4 +80,14 @@ package concurrent { private[concurrent] trait CanBlock + /** A timeout exception. + * + * Futures are failed with a timeout exception when their timeout expires. + * + * Each timeout exception contains an origin future which originally timed out. + */ + class FutureTimeoutException(origin: Future[T], message: String) extends TimeoutException(message) { + def this(origin: Future[T]) = this(origin, "Future timed out.") + } + } diff --git a/src/library/scala/package.scala b/src/library/scala/package.scala index 0c5d10b15e..915ce6a648 100644 --- a/src/library/scala/package.scala +++ b/src/library/scala/package.scala @@ -27,7 +27,8 @@ package object scala { type NoSuchElementException = java.util.NoSuchElementException type NumberFormatException = java.lang.NumberFormatException type AbstractMethodError = java.lang.AbstractMethodError - + type InterruptedException = java.lang.InterruptedException + @deprecated("instead of `@serializable class C`, use `class C extends Serializable`", "2.9.0") type serializable = annotation.serializable -- cgit v1.2.3