aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorphaller <hallerp@gmail.com>2012-12-19 00:22:13 +0100
committerphaller <hallerp@gmail.com>2012-12-19 00:22:13 +0100
commite8e84ba8d3e3ebf266250cdb48176a9fef5361d9 (patch)
treefc0fd882aa34b8b5e5571e26164c62b551e2c9fe
parent3859c11bca3ce9d9d2b18ec629dcc4b6218cf2e0 (diff)
downloadscala-async-e8e84ba8d3e3ebf266250cdb48176a9fef5361d9.tar.gz
scala-async-e8e84ba8d3e3ebf266250cdb48176a9fef5361d9.tar.bz2
scala-async-e8e84ba8d3e3ebf266250cdb48176a9fef5361d9.zip
New fix for #1861: Add fall-back to CPS for all unsupported uses of await
This is a re-implementation of a previous fix. It is more modular, since it enables the definition of a CPS-based fall-back as a subclass of `AsyncBase`. Thus, it's possible to define fall-back-enabled subclasses of `AsyncBase` targetting not only Scala Futures.
-rw-r--r--src/main/scala/scala/async/Async.scala48
-rw-r--r--src/main/scala/scala/async/AsyncAnalysis.scala18
-rw-r--r--src/main/scala/scala/async/AsyncWithCPSFallback.scala31
-rw-r--r--src/main/scala/scala/async/FutureSystem.scala5
-rw-r--r--src/main/scala/scala/async/TransformUtils.scala7
5 files changed, 93 insertions, 16 deletions
diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala
index fd28fb8..c2969a9 100644
--- a/src/main/scala/scala/async/Async.scala
+++ b/src/main/scala/scala/async/Async.scala
@@ -6,6 +6,7 @@ package scala.async
import scala.language.experimental.macros
import scala.reflect.macros.Context
+import scala.util.continuations.{cpsParam, reset}
object Async extends AsyncBase {
@@ -60,16 +61,18 @@ abstract class AsyncBase {
@deprecated("`await` must be enclosed in an `async` block", "0.1")
def await[T](awaitable: futureSystem.Fut[T]): T = ???
+ def awaitFallback[T, U](awaitable: futureSystem.Fut[T], p: futureSystem.Prom[U]): T @cpsParam[U, Unit] = ???
+
+ def fallbackEnabled = false
+
def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = {
import c.universe._
- val analyzer = AsyncAnalysis[c.type](c)
+ val analyzer = AsyncAnalysis[c.type](c, this)
val utils = TransformUtils[c.type](c)
import utils.{name, defn}
- import builder.futureSystemOps
-
- analyzer.reportUnsupportedAwaits(body.tree)
+ if (!analyzer.reportUnsupportedAwaits(body.tree) || !fallbackEnabled) {
// Transform to A-normal form:
// - no await calls in qualifiers or arguments,
// - if/match only used in statement position.
@@ -91,6 +94,7 @@ abstract class AsyncBase {
}
val builder = ExprBuilder[c.type, futureSystem.type](c, self.futureSystem, anfTree)
+ import builder.futureSystemOps
val asyncBlock: builder.AsyncBlock = builder.build(anfTree, renameMap)
import asyncBlock.asyncStates
logDiagnostics(c)(anfTree, asyncStates.map(_.toString))
@@ -139,19 +143,16 @@ abstract class AsyncBase {
def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection)
- def spawn(tree: Tree): Tree =
- futureSystemOps.future(c.Expr[Unit](tree))(futureSystemOps.execContext).tree
-
val code: c.Expr[futureSystem.Fut[T]] = {
val isSimple = asyncStates.size == 1
val tree =
if (isSimple)
- Block(Nil, spawn(body.tree)) // generate lean code for the simple case of `async { 1 + 1 }`
+ Block(Nil, futureSystemOps.spawn(body.tree)) // generate lean code for the simple case of `async { 1 + 1 }`
else {
Block(List[Tree](
stateMachine,
ValDef(NoMods, name.stateMachine, stateMachineType, New(Ident(name.stateMachineT), Nil)),
- spawn(Apply(selectStateMachine(name.apply), Nil))
+ futureSystemOps.spawn(Apply(selectStateMachine(name.apply), Nil))
),
futureSystemOps.promiseToFuture(c.Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree)
}
@@ -160,6 +161,35 @@ abstract class AsyncBase {
AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code.tree}")
code
+ } else {
+ // replace `await` invocations with `awaitFallback` invocations
+ val awaitReplacer = new Transformer {
+ override def transform(tree: Tree): Tree = tree match {
+ case Apply(fun @ TypeApply(_, List(futArgTpt)), args) if fun.symbol == defn.Async_await =>
+ val typeApp = treeCopy.TypeApply(fun, Ident(defn.Async_awaitFallback), List(TypeTree(futArgTpt.tpe), TypeTree(body.tree.tpe)))
+ treeCopy.Apply(tree, typeApp, args.map(arg => c.resetAllAttrs(arg.duplicate)) :+ Ident(name.result))
+ case _ =>
+ super.transform(tree)
+ }
+ }
+ val newBody = awaitReplacer.transform(body.tree)
+
+ val resetBody = reify {
+ reset { c.Expr(c.resetAllAttrs(newBody.duplicate)).splice }
+ }
+
+ val futureSystemOps = futureSystem.mkOps(c)
+ val code = {
+ val tree = Block(List(
+ ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree),
+ futureSystemOps.spawn(resetBody.tree)
+ ), futureSystemOps.promiseToFuture(c.Expr[futureSystem.Prom[T]](Ident(name.result))).tree)
+ c.Expr[futureSystem.Fut[T]](tree)
+ }
+
+ AsyncUtils.vprintln(s"async CPS fallback transform expands to:\n ${code.tree}")
+ code
+ }
}
def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) {
diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala
index 8bb5bcd..bda4d5c 100644
--- a/src/main/scala/scala/async/AsyncAnalysis.scala
+++ b/src/main/scala/scala/async/AsyncAnalysis.scala
@@ -7,7 +7,7 @@ package scala.async
import scala.reflect.macros.Context
import scala.collection.mutable
-private[async] final case class AsyncAnalysis[C <: Context](c: C) {
+private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: AsyncBase) {
import c.universe._
val utils = TransformUtils[c.type](c)
@@ -20,8 +20,10 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
*
* Must be called on the original tree, not on the ANF transformed tree.
*/
- def reportUnsupportedAwaits(tree: Tree) {
- new UnsupportedAwaitAnalyzer().traverse(tree)
+ def reportUnsupportedAwaits(tree: Tree): Boolean = {
+ val analyzer = new UnsupportedAwaitAnalyzer
+ analyzer.traverse(tree)
+ analyzer.hasUnsupportedAwaits
}
/**
@@ -39,6 +41,8 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
}
private class UnsupportedAwaitAnalyzer extends AsyncTraverser {
+ var hasUnsupportedAwaits = false
+
override def nestedClass(classDef: ClassDef) {
val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class"
if (!reportUnsupportedAwait(classDef, s"nested $kind")) {
@@ -89,10 +93,16 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
}
badAwaits foreach {
tree =>
- c.error(tree.pos, s"await must not be used under a $whyUnsupported.")
+ reportError(tree.pos, s"await must not be used under a $whyUnsupported.")
}
badAwaits.nonEmpty
}
+
+ private def reportError(pos: Position, msg: String) {
+ hasUnsupportedAwaits = true
+ if (!asyncBase.fallbackEnabled)
+ c.error(pos, msg)
+ }
}
private class AsyncDefinitionUseAnalyzer extends AsyncTraverser {
diff --git a/src/main/scala/scala/async/AsyncWithCPSFallback.scala b/src/main/scala/scala/async/AsyncWithCPSFallback.scala
new file mode 100644
index 0000000..39e43a3
--- /dev/null
+++ b/src/main/scala/scala/async/AsyncWithCPSFallback.scala
@@ -0,0 +1,31 @@
+package scala.async
+
+import scala.language.experimental.macros
+
+import scala.reflect.macros.Context
+import scala.util.continuations._
+
+object AsyncWithCPSFallback extends AsyncBase {
+
+ import scala.concurrent.{Future, ExecutionContext}
+ import ExecutionContext.Implicits.global
+
+ lazy val futureSystem = ScalaConcurrentFutureSystem
+ type FS = ScalaConcurrentFutureSystem.type
+
+ /* Fall-back for `await` when it is called at an unsupported position.
+ */
+ override def awaitFallback[T, U](awaitable: futureSystem.Fut[T], p: futureSystem.Prom[U]): T @cpsParam[U, Unit] =
+ shift {
+ (k: (T => U)) =>
+ awaitable onComplete {
+ case tr => p.success(k(tr.get))
+ }
+ }
+
+ override def fallbackEnabled = true
+
+ def async[T](body: T) = macro asyncImpl[T]
+
+ override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body)
+}
diff --git a/src/main/scala/scala/async/FutureSystem.scala b/src/main/scala/scala/async/FutureSystem.scala
index e9373b3..f0b4653 100644
--- a/src/main/scala/scala/async/FutureSystem.scala
+++ b/src/main/scala/scala/async/FutureSystem.scala
@@ -51,9 +51,12 @@ trait FutureSystem {
/** Complete a promise with a value */
def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit]
+
+ def spawn(tree: context.Tree): context.Tree =
+ future(context.Expr[Unit](tree))(execContext).tree
}
- def mkOps(c: Context): Ops {val context: c.type}
+ def mkOps(c: Context): Ops { val context: c.type }
}
object ScalaConcurrentFutureSystem extends FutureSystem {
diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala
index f780799..faebcc3 100644
--- a/src/main/scala/scala/async/TransformUtils.scala
+++ b/src/main/scala/scala/async/TransformUtils.scala
@@ -162,11 +162,14 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe))
val NonFatalClass = c.mirror.staticModule("scala.util.control.NonFatal")
- val Async_await = {
+ private def asyncMember(name: String) = {
val asyncMod = c.mirror.staticClass("scala.async.AsyncBase")
val tpe = asyncMod.asType.toType
- tpe.member(c.universe.newTermName("await")).ensuring(_ != NoSymbol)
+ tpe.member(newTermName(name)).ensuring(_ != NoSymbol)
}
+
+ val Async_await = asyncMember("await")
+ val Async_awaitFallback = asyncMember("awaitFallback")
}
/** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */