From ffd7b9649cde61553f16c25f64c3072ecfa6c713 Mon Sep 17 00:00:00 2001 From: Philipp Haller Date: Tue, 26 Feb 2013 01:16:40 +0100 Subject: Remove CPS dependency from default async implementation - move all CPS-related code to `continuations` sub package - fix CPS-based async implementation - enable testing of CPS-based async implementation --- src/main/scala/scala/async/Async.scala | 37 +-------- .../scala/scala/async/AsyncWithCPSFallback.scala | 31 -------- src/main/scala/scala/async/FutureSystem.scala | 8 ++ src/main/scala/scala/async/TransformUtils.scala | 3 +- .../continuations/AsyncBaseWithCPSFallback.scala | 90 ++++++++++++++++++++++ .../async/continuations/AsyncWithCPSFallback.scala | 20 +++++ .../scala/async/continuations/CPSBasedAsync.scala | 21 +++++ .../async/continuations/CPSBasedAsyncBase.scala | 21 +++++ .../continuations/ScalaConcurrentCPSFallback.scala | 31 ++++++++ src/test/scala/scala/async/run/cps/CPSSpec.scala | 49 ++++++++++++ 10 files changed, 244 insertions(+), 67 deletions(-) delete mode 100644 src/main/scala/scala/async/AsyncWithCPSFallback.scala create mode 100644 src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala create mode 100644 src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala create mode 100644 src/main/scala/scala/async/continuations/CPSBasedAsync.scala create mode 100644 src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala create mode 100644 src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala create mode 100644 src/test/scala/scala/async/run/cps/CPSSpec.scala diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index 8323ac5..5a71ebe 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -6,7 +6,6 @@ package scala.async import scala.language.experimental.macros import scala.reflect.macros.Context -import scala.util.continuations.{cpsParam, reset} object Async extends AsyncBase { @@ -61,9 +60,7 @@ abstract class AsyncBase { @deprecated("`await` must be enclosed in an `async` block", "0.1") def await[T](awaitable: futureSystem.Fut[T]): T = ??? - def awaitFallback[T, U](awaitable: futureSystem.Fut[T], p: futureSystem.Prom[U]): T @cpsParam[U, Unit] = ??? - - def fallbackEnabled = false + protected[async] def fallbackEnabled = false def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { import c.universe._ @@ -72,7 +69,8 @@ abstract class AsyncBase { val utils = TransformUtils[c.type](c) import utils.{name, defn} - if (!analyzer.reportUnsupportedAwaits(body.tree) || !fallbackEnabled) { + analyzer.reportUnsupportedAwaits(body.tree) + // Transform to A-normal form: // - no await calls in qualifiers or arguments, // - if/match only used in statement position. @@ -162,35 +160,6 @@ abstract class AsyncBase { AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code.tree}") code - } else { - // replace `await` invocations with `awaitFallback` invocations - val awaitReplacer = new Transformer { - override def transform(tree: Tree): Tree = tree match { - case Apply(fun @ TypeApply(_, List(futArgTpt)), args) if fun.symbol == defn.Async_await => - val typeApp = treeCopy.TypeApply(fun, Ident(defn.Async_awaitFallback), List(TypeTree(futArgTpt.tpe), TypeTree(body.tree.tpe))) - treeCopy.Apply(tree, typeApp, args.map(arg => c.resetAllAttrs(arg.duplicate)) :+ Ident(name.result)) - case _ => - super.transform(tree) - } - } - val newBody = awaitReplacer.transform(body.tree) - - val resetBody = reify { - reset { c.Expr(c.resetAllAttrs(newBody.duplicate)).splice } - } - - val futureSystemOps = futureSystem.mkOps(c) - val code = { - val tree = Block(List( - ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree), - futureSystemOps.spawn(resetBody.tree) - ), futureSystemOps.promiseToFuture(c.Expr[futureSystem.Prom[T]](Ident(name.result))).tree) - c.Expr[futureSystem.Fut[T]](tree) - } - - AsyncUtils.vprintln(s"async CPS fallback transform expands to:\n ${code.tree}") - code - } } def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) { diff --git a/src/main/scala/scala/async/AsyncWithCPSFallback.scala b/src/main/scala/scala/async/AsyncWithCPSFallback.scala deleted file mode 100644 index 39e43a3..0000000 --- a/src/main/scala/scala/async/AsyncWithCPSFallback.scala +++ /dev/null @@ -1,31 +0,0 @@ -package scala.async - -import scala.language.experimental.macros - -import scala.reflect.macros.Context -import scala.util.continuations._ - -object AsyncWithCPSFallback extends AsyncBase { - - import scala.concurrent.{Future, ExecutionContext} - import ExecutionContext.Implicits.global - - lazy val futureSystem = ScalaConcurrentFutureSystem - type FS = ScalaConcurrentFutureSystem.type - - /* Fall-back for `await` when it is called at an unsupported position. - */ - override def awaitFallback[T, U](awaitable: futureSystem.Fut[T], p: futureSystem.Prom[U]): T @cpsParam[U, Unit] = - shift { - (k: (T => U)) => - awaitable onComplete { - case tr => p.success(k(tr.get)) - } - } - - override def fallbackEnabled = true - - def async[T](body: T) = macro asyncImpl[T] - - override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body) -} diff --git a/src/main/scala/scala/async/FutureSystem.scala b/src/main/scala/scala/async/FutureSystem.scala index f0b4653..a050bec 100644 --- a/src/main/scala/scala/async/FutureSystem.scala +++ b/src/main/scala/scala/async/FutureSystem.scala @@ -54,6 +54,8 @@ trait FutureSystem { def spawn(tree: context.Tree): context.Tree = future(context.Expr[Unit](tree))(execContext).tree + + def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] } def mkOps(c: Context): Ops { val context: c.type } @@ -101,6 +103,10 @@ object ScalaConcurrentFutureSystem extends FutureSystem { prom.splice.complete(value.splice) context.literalUnit.splice } + + def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = reify { + future.splice.asInstanceOf[Fut[A]] + } } } @@ -145,5 +151,7 @@ object IdentityFutureSystem extends FutureSystem { prom.splice.a = value.splice.get context.literalUnit.splice } + + def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = ??? } } diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index 38c33a4..090a334 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -181,8 +181,7 @@ private[async] final case class TransformUtils[C <: Context](c: C) { tpe.member(newTermName(name)).ensuring(_ != NoSymbol) } - val Async_await = asyncMember("await") - val Async_awaitFallback = asyncMember("awaitFallback") + val Async_await = asyncMember("await") } /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */ diff --git a/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala b/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala new file mode 100644 index 0000000..a669cfa --- /dev/null +++ b/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala @@ -0,0 +1,90 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async +package continuations + +import scala.language.experimental.macros + +import scala.reflect.macros.Context +import scala.util.continuations._ + +trait AsyncBaseWithCPSFallback extends AsyncBase { + + /* Fall-back for `await` using CPS plugin. + * + * Note: This method is public, but is intended only for internal use. + */ + def awaitFallback[T](awaitable: futureSystem.Fut[T]): T @cps[futureSystem.Fut[Any]] + + override protected[async] def fallbackEnabled = true + + /* Implements `async { ... }` using the CPS plugin. + */ + protected def cpsBasedAsyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { + import c.universe._ + + def lookupMember(name: String) = { + val asyncTrait = c.mirror.staticClass("scala.async.continuations.AsyncBaseWithCPSFallback") + val tpe = asyncTrait.asType.toType + tpe.member(newTermName(name)).ensuring(_ != NoSymbol) + } + + AsyncUtils.vprintln("AsyncBaseWithCPSFallback.cpsBasedAsyncImpl") + + val utils = TransformUtils[c.type](c) + val futureSystemOps = futureSystem.mkOps(c) + val awaitSym = utils.defn.Async_await + val awaitFallbackSym = lookupMember("awaitFallback") + + // replace `await` invocations with `awaitFallback` invocations + val awaitReplacer = new Transformer { + override def transform(tree: Tree): Tree = tree match { + case Apply(fun @ TypeApply(_, List(futArgTpt)), args) if fun.symbol == awaitSym => + val typeApp = treeCopy.TypeApply(fun, Ident(awaitFallbackSym), List(TypeTree(futArgTpt.tpe))) + treeCopy.Apply(tree, typeApp, args.map(arg => c.resetAllAttrs(arg.duplicate))) + case _ => + super.transform(tree) + } + } + val bodyWithAwaitFallback = awaitReplacer.transform(body.tree) + + /* generate an expression that looks like this: + reset { + val f = future { ... } + ... + val x = awaitFallback(f) + ... + future { expr } + }.asInstanceOf[Future[T]] + */ + + val bodyWithFuture = { + val tree = bodyWithAwaitFallback match { + case Block(stmts, expr) => Block(stmts, futureSystemOps.spawn(expr)) + case expr => futureSystemOps.spawn(expr) + } + c.Expr[futureSystem.Fut[Any]](c.resetAllAttrs(tree.duplicate)) + } + + val bodyWithReset: c.Expr[futureSystem.Fut[Any]] = reify { + reset { bodyWithFuture.splice } + } + val bodyWithCast = futureSystemOps.castTo[T](bodyWithReset) + + AsyncUtils.vprintln(s"CPS-based async transform expands to:\n${bodyWithCast.tree}") + bodyWithCast + } + + override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { + AsyncUtils.vprintln("AsyncBaseWithCPSFallback.asyncImpl") + + val analyzer = AsyncAnalysis[c.type](c, this) + + if (!analyzer.reportUnsupportedAwaits(body.tree)) + super.asyncImpl[T](c)(body) // no unsupported awaits + else + cpsBasedAsyncImpl[T](c)(body) // fallback to CPS + } +} diff --git a/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala b/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala new file mode 100644 index 0000000..fe6e1a6 --- /dev/null +++ b/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala @@ -0,0 +1,20 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async +package continuations + +import scala.language.experimental.macros + +import scala.reflect.macros.Context +import scala.concurrent.Future + +trait AsyncWithCPSFallback extends AsyncBaseWithCPSFallback with ScalaConcurrentCPSFallback + +object AsyncWithCPSFallback extends AsyncWithCPSFallback { + + def async[T](body: T) = macro asyncImpl[T] + + override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body) +} diff --git a/src/main/scala/scala/async/continuations/CPSBasedAsync.scala b/src/main/scala/scala/async/continuations/CPSBasedAsync.scala new file mode 100644 index 0000000..922d1ac --- /dev/null +++ b/src/main/scala/scala/async/continuations/CPSBasedAsync.scala @@ -0,0 +1,21 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async +package continuations + +import scala.language.experimental.macros + +import scala.reflect.macros.Context +import scala.concurrent.Future + +trait CPSBasedAsync extends CPSBasedAsyncBase with ScalaConcurrentCPSFallback + +object CPSBasedAsync extends CPSBasedAsync { + + def async[T](body: T) = macro asyncImpl[T] + + override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body) + +} diff --git a/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala b/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala new file mode 100644 index 0000000..4e8ec80 --- /dev/null +++ b/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala @@ -0,0 +1,21 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async +package continuations + +import scala.language.experimental.macros + +import scala.reflect.macros.Context +import scala.util.continuations._ + +/* Specializes `AsyncBaseWithCPSFallback` to always fall back to CPS, yielding a purely CPS-based + * implementation of async/await. + */ +trait CPSBasedAsyncBase extends AsyncBaseWithCPSFallback { + + override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = + super.cpsBasedAsyncImpl[T](c)(body) + +} diff --git a/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala b/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala new file mode 100644 index 0000000..018ad05 --- /dev/null +++ b/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala @@ -0,0 +1,31 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async +package continuations + +import scala.util.continuations._ +import scala.concurrent.{Future, Promise, ExecutionContext} + +trait ScalaConcurrentCPSFallback { + self: AsyncBaseWithCPSFallback => + + import ExecutionContext.Implicits.global + + lazy val futureSystem = ScalaConcurrentFutureSystem + type FS = ScalaConcurrentFutureSystem.type + + /* Fall-back for `await` when it is called at an unsupported position. + */ + override def awaitFallback[T](awaitable: futureSystem.Fut[T]): T @cps[Future[Any]] = + shift { + (k: (T => Future[Any])) => + val fr = Promise[Any]() + awaitable onComplete { + case tr => fr completeWith k(tr.get) + } + fr.future + } + +} diff --git a/src/test/scala/scala/async/run/cps/CPSSpec.scala b/src/test/scala/scala/async/run/cps/CPSSpec.scala new file mode 100644 index 0000000..b56c6ad --- /dev/null +++ b/src/test/scala/scala/async/run/cps/CPSSpec.scala @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async +package run +package cps + +import scala.concurrent.{Future, Promise, ExecutionContext, future, Await} +import scala.concurrent.duration._ +import scala.async.continuations.CPSBasedAsync._ +import scala.util.continuations._ + +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.junit.Test + +@RunWith(classOf[JUnit4]) +class CPSSpec { + + import ExecutionContext.Implicits.global + + def m1(y: Int): Future[Int] = async { + val f = future { y + 2 } + val f2 = future { y + 3 } + val x1 = await(f) + val x2 = await(f2) + x1 + x2 + } + + def m2(y: Int): Future[Int] = async { + val f = future { y + 2 } + val res = await(f) + if (y > 0) res + 2 + else res - 2 + } + + @Test + def testCPSFallback() { + val fut1 = m1(10) + val res1 = Await.result(fut1, 2.seconds) + assert(res1 == 25, s"expected 25, got $res1") + + val fut2 = m2(10) + val res2 = Await.result(fut2, 2.seconds) + assert(res2 == 14, s"expected 14, got $res2") + } + +} -- cgit v1.2.3