From e8e84ba8d3e3ebf266250cdb48176a9fef5361d9 Mon Sep 17 00:00:00 2001 From: phaller Date: Wed, 19 Dec 2012 00:22:13 +0100 Subject: New fix for #1861: Add fall-back to CPS for all unsupported uses of await This is a re-implementation of a previous fix. It is more modular, since it enables the definition of a CPS-based fall-back as a subclass of `AsyncBase`. Thus, it's possible to define fall-back-enabled subclasses of `AsyncBase` targetting not only Scala Futures. --- src/main/scala/scala/async/Async.scala | 48 ++++++++++++++++++---- src/main/scala/scala/async/AsyncAnalysis.scala | 18 ++++++-- .../scala/scala/async/AsyncWithCPSFallback.scala | 31 ++++++++++++++ src/main/scala/scala/async/FutureSystem.scala | 5 ++- src/main/scala/scala/async/TransformUtils.scala | 7 +++- 5 files changed, 93 insertions(+), 16 deletions(-) create mode 100644 src/main/scala/scala/async/AsyncWithCPSFallback.scala diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index fd28fb8..c2969a9 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -6,6 +6,7 @@ package scala.async import scala.language.experimental.macros import scala.reflect.macros.Context +import scala.util.continuations.{cpsParam, reset} object Async extends AsyncBase { @@ -60,16 +61,18 @@ abstract class AsyncBase { @deprecated("`await` must be enclosed in an `async` block", "0.1") def await[T](awaitable: futureSystem.Fut[T]): T = ??? + def awaitFallback[T, U](awaitable: futureSystem.Fut[T], p: futureSystem.Prom[U]): T @cpsParam[U, Unit] = ??? + + def fallbackEnabled = false + def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { import c.universe._ - val analyzer = AsyncAnalysis[c.type](c) + val analyzer = AsyncAnalysis[c.type](c, this) val utils = TransformUtils[c.type](c) import utils.{name, defn} - import builder.futureSystemOps - - analyzer.reportUnsupportedAwaits(body.tree) + if (!analyzer.reportUnsupportedAwaits(body.tree) || !fallbackEnabled) { // Transform to A-normal form: // - no await calls in qualifiers or arguments, // - if/match only used in statement position. @@ -91,6 +94,7 @@ abstract class AsyncBase { } val builder = ExprBuilder[c.type, futureSystem.type](c, self.futureSystem, anfTree) + import builder.futureSystemOps val asyncBlock: builder.AsyncBlock = builder.build(anfTree, renameMap) import asyncBlock.asyncStates logDiagnostics(c)(anfTree, asyncStates.map(_.toString)) @@ -139,19 +143,16 @@ abstract class AsyncBase { def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection) - def spawn(tree: Tree): Tree = - futureSystemOps.future(c.Expr[Unit](tree))(futureSystemOps.execContext).tree - val code: c.Expr[futureSystem.Fut[T]] = { val isSimple = asyncStates.size == 1 val tree = if (isSimple) - Block(Nil, spawn(body.tree)) // generate lean code for the simple case of `async { 1 + 1 }` + Block(Nil, futureSystemOps.spawn(body.tree)) // generate lean code for the simple case of `async { 1 + 1 }` else { Block(List[Tree]( stateMachine, ValDef(NoMods, name.stateMachine, stateMachineType, New(Ident(name.stateMachineT), Nil)), - spawn(Apply(selectStateMachine(name.apply), Nil)) + futureSystemOps.spawn(Apply(selectStateMachine(name.apply), Nil)) ), futureSystemOps.promiseToFuture(c.Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree) } @@ -160,6 +161,35 @@ abstract class AsyncBase { AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code.tree}") code + } else { + // replace `await` invocations with `awaitFallback` invocations + val awaitReplacer = new Transformer { + override def transform(tree: Tree): Tree = tree match { + case Apply(fun @ TypeApply(_, List(futArgTpt)), args) if fun.symbol == defn.Async_await => + val typeApp = treeCopy.TypeApply(fun, Ident(defn.Async_awaitFallback), List(TypeTree(futArgTpt.tpe), TypeTree(body.tree.tpe))) + treeCopy.Apply(tree, typeApp, args.map(arg => c.resetAllAttrs(arg.duplicate)) :+ Ident(name.result)) + case _ => + super.transform(tree) + } + } + val newBody = awaitReplacer.transform(body.tree) + + val resetBody = reify { + reset { c.Expr(c.resetAllAttrs(newBody.duplicate)).splice } + } + + val futureSystemOps = futureSystem.mkOps(c) + val code = { + val tree = Block(List( + ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree), + futureSystemOps.spawn(resetBody.tree) + ), futureSystemOps.promiseToFuture(c.Expr[futureSystem.Prom[T]](Ident(name.result))).tree) + c.Expr[futureSystem.Fut[T]](tree) + } + + AsyncUtils.vprintln(s"async CPS fallback transform expands to:\n ${code.tree}") + code + } } def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) { diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala index 8bb5bcd..bda4d5c 100644 --- a/src/main/scala/scala/async/AsyncAnalysis.scala +++ b/src/main/scala/scala/async/AsyncAnalysis.scala @@ -7,7 +7,7 @@ package scala.async import scala.reflect.macros.Context import scala.collection.mutable -private[async] final case class AsyncAnalysis[C <: Context](c: C) { +private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: AsyncBase) { import c.universe._ val utils = TransformUtils[c.type](c) @@ -20,8 +20,10 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { * * Must be called on the original tree, not on the ANF transformed tree. */ - def reportUnsupportedAwaits(tree: Tree) { - new UnsupportedAwaitAnalyzer().traverse(tree) + def reportUnsupportedAwaits(tree: Tree): Boolean = { + val analyzer = new UnsupportedAwaitAnalyzer + analyzer.traverse(tree) + analyzer.hasUnsupportedAwaits } /** @@ -39,6 +41,8 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { } private class UnsupportedAwaitAnalyzer extends AsyncTraverser { + var hasUnsupportedAwaits = false + override def nestedClass(classDef: ClassDef) { val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class" if (!reportUnsupportedAwait(classDef, s"nested $kind")) { @@ -89,10 +93,16 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { } badAwaits foreach { tree => - c.error(tree.pos, s"await must not be used under a $whyUnsupported.") + reportError(tree.pos, s"await must not be used under a $whyUnsupported.") } badAwaits.nonEmpty } + + private def reportError(pos: Position, msg: String) { + hasUnsupportedAwaits = true + if (!asyncBase.fallbackEnabled) + c.error(pos, msg) + } } private class AsyncDefinitionUseAnalyzer extends AsyncTraverser { diff --git a/src/main/scala/scala/async/AsyncWithCPSFallback.scala b/src/main/scala/scala/async/AsyncWithCPSFallback.scala new file mode 100644 index 0000000..39e43a3 --- /dev/null +++ b/src/main/scala/scala/async/AsyncWithCPSFallback.scala @@ -0,0 +1,31 @@ +package scala.async + +import scala.language.experimental.macros + +import scala.reflect.macros.Context +import scala.util.continuations._ + +object AsyncWithCPSFallback extends AsyncBase { + + import scala.concurrent.{Future, ExecutionContext} + import ExecutionContext.Implicits.global + + lazy val futureSystem = ScalaConcurrentFutureSystem + type FS = ScalaConcurrentFutureSystem.type + + /* Fall-back for `await` when it is called at an unsupported position. + */ + override def awaitFallback[T, U](awaitable: futureSystem.Fut[T], p: futureSystem.Prom[U]): T @cpsParam[U, Unit] = + shift { + (k: (T => U)) => + awaitable onComplete { + case tr => p.success(k(tr.get)) + } + } + + override def fallbackEnabled = true + + def async[T](body: T) = macro asyncImpl[T] + + override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body) +} diff --git a/src/main/scala/scala/async/FutureSystem.scala b/src/main/scala/scala/async/FutureSystem.scala index e9373b3..f0b4653 100644 --- a/src/main/scala/scala/async/FutureSystem.scala +++ b/src/main/scala/scala/async/FutureSystem.scala @@ -51,9 +51,12 @@ trait FutureSystem { /** Complete a promise with a value */ def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] + + def spawn(tree: context.Tree): context.Tree = + future(context.Expr[Unit](tree))(execContext).tree } - def mkOps(c: Context): Ops {val context: c.type} + def mkOps(c: Context): Ops { val context: c.type } } object ScalaConcurrentFutureSystem extends FutureSystem { diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index f780799..faebcc3 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -162,11 +162,14 @@ private[async] final case class TransformUtils[C <: Context](c: C) { val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe)) val NonFatalClass = c.mirror.staticModule("scala.util.control.NonFatal") - val Async_await = { + private def asyncMember(name: String) = { val asyncMod = c.mirror.staticClass("scala.async.AsyncBase") val tpe = asyncMod.asType.toType - tpe.member(c.universe.newTermName("await")).ensuring(_ != NoSymbol) + tpe.member(newTermName(name)).ensuring(_ != NoSymbol) } + + val Async_await = asyncMember("await") + val Async_awaitFallback = asyncMember("awaitFallback") } /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */ -- cgit v1.2.3