aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/scala/async/internal
diff options
context:
space:
mode:
authorJason Zaugg <jzaugg@gmail.com>2013-07-07 10:48:11 +1000
committerJason Zaugg <jzaugg@gmail.com>2013-07-07 10:48:11 +1000
commit2d8506a64392cd7192b6831c38798cc9a7c8bfed (patch)
tree84eafcf1a9a179eeaa97dd1e3595c18351b2b814 /src/main/scala/scala/async/internal
parentc60c38ca6098402f7a9cc6d6746b664bb2b1306c (diff)
downloadscala-async-2d8506a64392cd7192b6831c38798cc9a7c8bfed.tar.gz
scala-async-2d8506a64392cd7192b6831c38798cc9a7c8bfed.tar.bz2
scala-async-2d8506a64392cd7192b6831c38798cc9a7c8bfed.zip
Move implementation details to scala.async.internal._.
If we intend to keep CPS fallback around for any length of time it should probably move there too.
Diffstat (limited to 'src/main/scala/scala/async/internal')
-rw-r--r--src/main/scala/scala/async/internal/AnfTransform.scala253
-rw-r--r--src/main/scala/scala/async/internal/AsyncAnalysis.scala91
-rw-r--r--src/main/scala/scala/async/internal/AsyncBase.scala58
-rw-r--r--src/main/scala/scala/async/internal/AsyncId.scala64
-rw-r--r--src/main/scala/scala/async/internal/AsyncMacro.scala29
-rw-r--r--src/main/scala/scala/async/internal/AsyncTransform.scala176
-rw-r--r--src/main/scala/scala/async/internal/AsyncUtils.scala16
-rw-r--r--src/main/scala/scala/async/internal/ExprBuilder.scala388
-rw-r--r--src/main/scala/scala/async/internal/FutureSystem.scala106
-rw-r--r--src/main/scala/scala/async/internal/Lifter.scala150
-rw-r--r--src/main/scala/scala/async/internal/StateAssigner.scala14
-rw-r--r--src/main/scala/scala/async/internal/TransformUtils.scala251
12 files changed, 1596 insertions, 0 deletions
diff --git a/src/main/scala/scala/async/internal/AnfTransform.scala b/src/main/scala/scala/async/internal/AnfTransform.scala
new file mode 100644
index 0000000..80f8161
--- /dev/null
+++ b/src/main/scala/scala/async/internal/AnfTransform.scala
@@ -0,0 +1,253 @@
+
+/*
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async.internal
+
+import scala.tools.nsc.Global
+import scala.Predef._
+
+private[async] trait AnfTransform {
+ self: AsyncMacro =>
+
+ import global._
+ import reflect.internal.Flags._
+
+ def anfTransform(tree: Tree): Block = {
+ // Must prepend the () for issue #31.
+ val block = callSiteTyper.typedPos(tree.pos)(Block(List(Literal(Constant(()))), tree)).setType(tree.tpe)
+
+ new SelectiveAnfTransform().transform(block)
+ }
+
+ sealed abstract class AnfMode
+
+ case object Anf extends AnfMode
+
+ case object Linearizing extends AnfMode
+
+ final class SelectiveAnfTransform extends MacroTypingTransformer {
+ var mode: AnfMode = Anf
+
+ def blockToList(tree: Tree): List[Tree] = tree match {
+ case Block(stats, expr) => stats :+ expr
+ case t => t :: Nil
+ }
+
+ def listToBlock(trees: List[Tree]): Block = trees match {
+ case trees @ (init :+ last) =>
+ val pos = trees.map(_.pos).reduceLeft(_ union _)
+ Block(init, last).setType(last.tpe).setPos(pos)
+ }
+
+ override def transform(tree: Tree): Block = {
+ def anfLinearize: Block = {
+ val trees: List[Tree] = mode match {
+ case Anf => anf._transformToList(tree)
+ case Linearizing => linearize._transformToList(tree)
+ }
+ listToBlock(trees)
+ }
+ tree match {
+ case _: ValDef | _: DefDef | _: Function | _: ClassDef | _: TypeDef =>
+ atOwner(tree.symbol)(anfLinearize)
+ case _: ModuleDef =>
+ atOwner(tree.symbol.moduleClass orElse tree.symbol)(anfLinearize)
+ case _ =>
+ anfLinearize
+ }
+ }
+
+ private object linearize {
+ def transformToList(tree: Tree): List[Tree] = {
+ mode = Linearizing; blockToList(transform(tree))
+ }
+
+ def transformToBlock(tree: Tree): Block = listToBlock(transformToList(tree))
+
+ def _transformToList(tree: Tree): List[Tree] = trace(tree) {
+ val stats :+ expr = anf.transformToList(tree)
+ expr match {
+ case Apply(fun, args) if isAwait(fun) =>
+ val valDef = defineVal(name.await, expr, tree.pos)
+ stats :+ valDef :+ gen.mkAttributedStableRef(valDef.symbol)
+
+ case If(cond, thenp, elsep) =>
+ // if type of if-else is Unit don't introduce assignment,
+ // but add Unit value to bring it into form expected by async transform
+ if (expr.tpe =:= definitions.UnitTpe) {
+ stats :+ expr :+ localTyper.typedPos(expr.pos)(Literal(Constant(())))
+ } else {
+ val varDef = defineVar(name.ifRes, expr.tpe, tree.pos)
+ def branchWithAssign(orig: Tree) = localTyper.typedPos(orig.pos) {
+ def cast(t: Tree) = mkAttributedCastPreservingAnnotations(t, varDef.symbol.tpe)
+ orig match {
+ case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr)))
+ case _ => Assign(Ident(varDef.symbol), cast(orig))
+ }
+ }.setType(orig.tpe)
+ val ifWithAssign = treeCopy.If(tree, cond, branchWithAssign(thenp), branchWithAssign(elsep))
+ stats :+ varDef :+ ifWithAssign :+ gen.mkAttributedStableRef(varDef.symbol)
+ }
+
+ case Match(scrut, cases) =>
+ // if type of match is Unit don't introduce assignment,
+ // but add Unit value to bring it into form expected by async transform
+ if (expr.tpe =:= definitions.UnitTpe) {
+ stats :+ expr :+ localTyper.typedPos(expr.pos)(Literal(Constant(())))
+ }
+ else {
+ val varDef = defineVar(name.matchRes, expr.tpe, tree.pos)
+ def typedAssign(lhs: Tree) =
+ localTyper.typedPos(lhs.pos)(Assign(Ident(varDef.symbol), mkAttributedCastPreservingAnnotations(lhs, varDef.symbol.tpe)))
+ val casesWithAssign = cases map {
+ case cd@CaseDef(pat, guard, body) =>
+ val newBody = body match {
+ case b@Block(caseStats, caseExpr) => treeCopy.Block(b, caseStats, typedAssign(caseExpr))
+ case _ => typedAssign(body)
+ }
+ treeCopy.CaseDef(cd, pat, guard, newBody)
+ }
+ val matchWithAssign = treeCopy.Match(tree, scrut, casesWithAssign)
+ require(matchWithAssign.tpe != null, matchWithAssign)
+ stats :+ varDef :+ matchWithAssign :+ gen.mkAttributedStableRef(varDef.symbol)
+ }
+ case _ =>
+ stats :+ expr
+ }
+ }
+
+ private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = {
+ val sym = currOwner.newTermSymbol(name.fresh(prefix), pos, MUTABLE | SYNTHETIC).setInfo(tp)
+ ValDef(sym, gen.mkZero(tp)).setType(NoType).setPos(pos)
+ }
+ }
+
+ private object trace {
+ private var indent = -1
+
+ def indentString = " " * indent
+
+ def apply[T](args: Any)(t: => T): T = {
+ def prefix = mode.toString.toLowerCase
+ indent += 1
+ def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127)
+ try {
+ AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})")
+ val result = t
+ AsyncUtils.trace(s"${indentString}= ${oneLine(result)}")
+ result
+ } finally {
+ indent -= 1
+ }
+ }
+ }
+
+ private def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = {
+ val sym = currOwner.newTermSymbol(name.fresh(prefix), pos, SYNTHETIC).setInfo(lhs.tpe)
+ changeOwner(lhs, currentOwner, sym)
+ ValDef(sym, changeOwner(lhs, currentOwner, sym)).setType(NoType).setPos(pos)
+ }
+
+ private object anf {
+ def transformToList(tree: Tree): List[Tree] = {
+ mode = Anf; blockToList(transform(tree))
+ }
+
+ def _transformToList(tree: Tree): List[Tree] = trace(tree) {
+ val containsAwait = tree exists isAwait
+ if (!containsAwait) {
+ List(tree)
+ } else tree match {
+ case Select(qual, sel) =>
+ val stats :+ expr = linearize.transformToList(qual)
+ stats :+ treeCopy.Select(tree, expr, sel)
+
+ case treeInfo.Applied(fun, targs, argss) if argss.nonEmpty =>
+ // we an assume that no await call appears in a by-name argument position,
+ // this has already been checked.
+ val funStats :+ simpleFun = linearize.transformToList(fun)
+ def isAwaitRef(name: Name) = name.toString.startsWith(AnfTransform.this.name.await + "$")
+ val (argStatss, argExprss): (List[List[List[Tree]]], List[List[Tree]]) =
+ mapArgumentss[List[Tree]](fun, argss) {
+ case Arg(expr, byName, _) if byName /*|| isPure(expr) TODO */ => (Nil, expr)
+ case Arg(expr@Ident(name), _, _) if isAwaitRef(name) => (Nil, expr) // TODO needed? // not typed, so it eludes the check in `isSafeToInline`
+ case Arg(expr, _, argName) =>
+ linearize.transformToList(expr) match {
+ case stats :+ expr1 =>
+ val valDef = defineVal(argName, expr1, expr1.pos)
+ require(valDef.tpe != null, valDef)
+ val stats1 = stats :+ valDef
+ //stats1.foreach(changeOwner(_, currentOwner, currentOwner.owner))
+ (stats1, gen.stabilize(gen.mkAttributedIdent(valDef.symbol)))
+ }
+ }
+ val applied = treeInfo.dissectApplied(tree)
+ val core = if (targs.isEmpty) simpleFun else treeCopy.TypeApply(applied.callee, simpleFun, targs)
+ val newApply = argExprss.foldLeft(core)(Apply(_, _)).setSymbol(tree.symbol)
+ val typedNewApply = localTyper.typedPos(tree.pos)(newApply).setType(tree.tpe)
+ funStats ++ argStatss.flatten.flatten :+ typedNewApply
+ case Block(stats, expr) =>
+ (stats :+ expr).flatMap(linearize.transformToList)
+
+ case ValDef(mods, name, tpt, rhs) =>
+ if (rhs exists isAwait) {
+ val stats :+ expr = atOwner(currOwner.owner)(linearize.transformToList(rhs))
+ stats.foreach(changeOwner(_, currOwner, currOwner.owner))
+ stats :+ treeCopy.ValDef(tree, mods, name, tpt, expr)
+ } else List(tree)
+
+ case Assign(lhs, rhs) =>
+ val stats :+ expr = linearize.transformToList(rhs)
+ stats :+ treeCopy.Assign(tree, lhs, expr)
+
+ case If(cond, thenp, elsep) =>
+ val condStats :+ condExpr = linearize.transformToList(cond)
+ val thenBlock = linearize.transformToBlock(thenp)
+ val elseBlock = linearize.transformToBlock(elsep)
+ // Typechecking with `condExpr` as the condition fails if the condition
+ // contains an await. `ifTree.setType(tree.tpe)` also fails; it seems
+ // we rely on this call to `typeCheck` descending into the branches.
+ // But, we can get away with typechecking a throwaway `If` tree with the
+ // original scrutinee and the new branches, and setting that type on
+ // the real `If` tree.
+ val iff = treeCopy.If(tree, condExpr, thenBlock, elseBlock)
+ condStats :+ iff
+
+ case Match(scrut, cases) =>
+ val scrutStats :+ scrutExpr = linearize.transformToList(scrut)
+ val caseDefs = cases map {
+ case CaseDef(pat, guard, body) =>
+ // extract local variables for all names bound in `pat`, and rewrite `body`
+ // to refer to these.
+ // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`.
+ val block = linearize.transformToBlock(body)
+ val (valDefs, mappings) = (pat collect {
+ case b@Bind(name, _) =>
+ val vd = defineVal(name.toTermName + AnfTransform.this.name.bindSuffix, gen.mkAttributedStableRef(b.symbol), b.pos)
+ (vd, (b.symbol, vd.symbol))
+ }).unzip
+ val (from, to) = mappings.unzip
+ val b@Block(stats1, expr1) = block.substituteSymbols(from, to).asInstanceOf[Block]
+ val newBlock = treeCopy.Block(b, valDefs ++ stats1, expr1)
+ treeCopy.CaseDef(tree, pat, guard, newBlock)
+ }
+ // Refer to comments the translation of `If` above.
+ val typedMatch = treeCopy.Match(tree, scrutExpr, caseDefs)
+ scrutStats :+ typedMatch
+
+ case LabelDef(name, params, rhs) =>
+ List(LabelDef(name, params, Block(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol))
+
+ case TypeApply(fun, targs) =>
+ val funStats :+ simpleFun = linearize.transformToList(fun)
+ funStats :+ treeCopy.TypeApply(tree, simpleFun, targs)
+
+ case _ =>
+ List(tree)
+ }
+ }
+ }
+ }
+}
diff --git a/src/main/scala/scala/async/internal/AsyncAnalysis.scala b/src/main/scala/scala/async/internal/AsyncAnalysis.scala
new file mode 100644
index 0000000..62842c9
--- /dev/null
+++ b/src/main/scala/scala/async/internal/AsyncAnalysis.scala
@@ -0,0 +1,91 @@
+/*
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async.internal
+
+import scala.reflect.macros.Context
+import scala.collection.mutable
+
+trait AsyncAnalysis {
+ self: AsyncMacro =>
+
+ import global._
+
+ /**
+ * Analyze the contents of an `async` block in order to:
+ * - Report unsupported `await` calls under nested templates, functions, by-name arguments.
+ *
+ * Must be called on the original tree, not on the ANF transformed tree.
+ */
+ def reportUnsupportedAwaits(tree: Tree, report: Boolean): Boolean = {
+ val analyzer = new UnsupportedAwaitAnalyzer(report)
+ analyzer.traverse(tree)
+ analyzer.hasUnsupportedAwaits
+ }
+
+ private class UnsupportedAwaitAnalyzer(report: Boolean) extends AsyncTraverser {
+ var hasUnsupportedAwaits = false
+
+ override def nestedClass(classDef: ClassDef) {
+ val kind = if (classDef.symbol.isTrait) "trait" else "class"
+ reportUnsupportedAwait(classDef, s"nested ${kind}")
+ }
+
+ override def nestedModule(module: ModuleDef) {
+ reportUnsupportedAwait(module, "nested object")
+ }
+
+ override def nestedMethod(defDef: DefDef) {
+ reportUnsupportedAwait(defDef, "nested method")
+ }
+
+ override def byNameArgument(arg: Tree) {
+ reportUnsupportedAwait(arg, "by-name argument")
+ }
+
+ override def function(function: Function) {
+ reportUnsupportedAwait(function, "nested function")
+ }
+
+ override def patMatFunction(tree: Match) {
+ reportUnsupportedAwait(tree, "nested function")
+ }
+
+ override def traverse(tree: Tree) {
+ def containsAwait = tree exists isAwait
+ tree match {
+ case Try(_, _, _) if containsAwait =>
+ reportUnsupportedAwait(tree, "try/catch")
+ super.traverse(tree)
+ case Return(_) =>
+ abort(tree.pos, "return is illegal within a async block")
+ case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) =>
+ // TODO lift this restriction
+ abort(tree.pos, "lazy vals are illegal within an async block")
+ case _ =>
+ super.traverse(tree)
+ }
+ }
+
+ /**
+ * @return true, if the tree contained an unsupported await.
+ */
+ private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String): Boolean = {
+ val badAwaits: List[RefTree] = tree collect {
+ case rt: RefTree if isAwait(rt) => rt
+ }
+ badAwaits foreach {
+ tree =>
+ reportError(tree.pos, s"await must not be used under a $whyUnsupported.")
+ }
+ badAwaits.nonEmpty
+ }
+
+ private def reportError(pos: Position, msg: String) {
+ hasUnsupportedAwaits = true
+ if (report)
+ abort(pos, msg)
+ }
+ }
+}
diff --git a/src/main/scala/scala/async/internal/AsyncBase.scala b/src/main/scala/scala/async/internal/AsyncBase.scala
new file mode 100644
index 0000000..2f7e38d
--- /dev/null
+++ b/src/main/scala/scala/async/internal/AsyncBase.scala
@@ -0,0 +1,58 @@
+/*
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async.internal
+
+import scala.reflect.internal.annotations.compileTimeOnly
+import scala.reflect.macros.Context
+
+/**
+ * A base class for the `async` macro. Subclasses must provide:
+ *
+ * - Concrete types for a given future system
+ * - Tree manipulations to create and complete the equivalent of Future and Promise
+ * in that system.
+ * - The `async` macro declaration itself, and a forwarder for the macro implementation.
+ * (The latter is temporarily needed to workaround bug SI-6650 in the macro system)
+ *
+ * The default implementation, [[scala.async.Async]], binds the macro to `scala.concurrent._`.
+ */
+abstract class AsyncBase {
+ self =>
+
+ type FS <: FutureSystem
+ val futureSystem: FS
+
+ /**
+ * A call to `await` must be nested in an enclosing `async` block.
+ *
+ * A call to `await` does not block the current thread, rather it is a delimiter
+ * used by the enclosing `async` macro. Code following the `await`
+ * call is executed asynchronously, when the argument of `await` has been completed.
+ *
+ * @param awaitable the future from which a value is awaited.
+ * @tparam T the type of that value.
+ * @return the value.
+ */
+ @compileTimeOnly("`await` must be enclosed in an `async` block")
+ def await[T](awaitable: futureSystem.Fut[T]): T = ???
+
+ protected[async] def fallbackEnabled = false
+
+ def asyncImpl[T: c.WeakTypeTag](c: Context)
+ (body: c.Expr[T])
+ (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = {
+ import c.universe._
+
+ val asyncMacro = AsyncMacro(c, futureSystem)
+
+ val code = asyncMacro.asyncTransform[T](
+ body.tree.asInstanceOf[asyncMacro.global.Tree],
+ execContext.tree.asInstanceOf[asyncMacro.global.Tree],
+ fallbackEnabled)(implicitly[c.WeakTypeTag[T]].asInstanceOf[asyncMacro.global.WeakTypeTag[T]])
+
+ AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code}")
+ c.Expr[futureSystem.Fut[T]](code.asInstanceOf[Tree])
+ }
+}
diff --git a/src/main/scala/scala/async/internal/AsyncId.scala b/src/main/scala/scala/async/internal/AsyncId.scala
new file mode 100644
index 0000000..394f587
--- /dev/null
+++ b/src/main/scala/scala/async/internal/AsyncId.scala
@@ -0,0 +1,64 @@
+/*
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async.internal
+
+import language.experimental.macros
+import scala.reflect.macros.Context
+import scala.reflect.internal.SymbolTable
+
+object AsyncId extends AsyncBase {
+ lazy val futureSystem = IdentityFutureSystem
+ type FS = IdentityFutureSystem.type
+
+ def async[T](body: T) = macro asyncIdImpl[T]
+
+ def asyncIdImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = asyncImpl[T](c)(body)(c.literalUnit)
+}
+
+/**
+ * A trivial implementation of [[FutureSystem]] that performs computations
+ * on the current thread. Useful for testing.
+ */
+object IdentityFutureSystem extends FutureSystem {
+
+ class Prom[A](var a: A)
+
+ type Fut[A] = A
+ type ExecContext = Unit
+
+ def mkOps(c: SymbolTable): Ops {val universe: c.type} = new Ops {
+ val universe: c.type = c
+
+ import universe._
+
+ def execContext: Expr[ExecContext] = Expr[Unit](Literal(Constant(())))
+
+ def promType[A: WeakTypeTag]: Type = weakTypeOf[Prom[A]]
+ def execContextType: Type = weakTypeOf[Unit]
+
+ def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify {
+ new Prom(null.asInstanceOf[A])
+ }
+
+ def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]) = reify {
+ prom.splice.a
+ }
+
+ def future[A: WeakTypeTag](t: Expr[A])(execContext: Expr[ExecContext]) = t
+
+ def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U],
+ execContext: Expr[ExecContext]): Expr[Unit] = reify {
+ fun.splice.apply(util.Success(future.splice))
+ Expr[Unit](Literal(Constant(()))).splice
+ }
+
+ def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify {
+ prom.splice.a = value.splice.get
+ Expr[Unit](Literal(Constant(()))).splice
+ }
+
+ def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = ???
+ }
+}
diff --git a/src/main/scala/scala/async/internal/AsyncMacro.scala b/src/main/scala/scala/async/internal/AsyncMacro.scala
new file mode 100644
index 0000000..6b7d031
--- /dev/null
+++ b/src/main/scala/scala/async/internal/AsyncMacro.scala
@@ -0,0 +1,29 @@
+package scala.async.internal
+
+import scala.tools.nsc.Global
+import scala.tools.nsc.transform.TypingTransformers
+
+object AsyncMacro {
+ def apply(c: reflect.macros.Context, futureSystem0: FutureSystem): AsyncMacro = {
+ import language.reflectiveCalls
+ val powerContext = c.asInstanceOf[c.type {val universe: Global; val callsiteTyper: universe.analyzer.Typer}]
+ new AsyncMacro {
+ val global: powerContext.universe.type = powerContext.universe
+ val callSiteTyper: global.analyzer.Typer = powerContext.callsiteTyper
+ val futureSystem: futureSystem0.type = futureSystem0
+ val futureSystemOps: futureSystem.Ops {val universe: global.type} = futureSystem0.mkOps(global)
+ val macroApplication: global.Tree = c.macroApplication.asInstanceOf[global.Tree]
+ }
+ }
+}
+
+private[async] trait AsyncMacro
+ extends TypingTransformers
+ with AnfTransform with TransformUtils with Lifter
+ with ExprBuilder with AsyncTransform with AsyncAnalysis {
+
+ val global: Global
+ val callSiteTyper: global.analyzer.Typer
+ val macroApplication: global.Tree
+
+}
diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala
new file mode 100644
index 0000000..bdc8664
--- /dev/null
+++ b/src/main/scala/scala/async/internal/AsyncTransform.scala
@@ -0,0 +1,176 @@
+package scala.async.internal
+
+trait AsyncTransform {
+ self: AsyncMacro =>
+
+ import global._
+
+ def asyncTransform[T](body: Tree, execContext: Tree, cpsFallbackEnabled: Boolean)
+ (implicit resultType: WeakTypeTag[T]): Tree = {
+
+ reportUnsupportedAwaits(body, report = !cpsFallbackEnabled)
+
+ // Transform to A-normal form:
+ // - no await calls in qualifiers or arguments,
+ // - if/match only used in statement position.
+ val anfTree: Block = anfTransform(body)
+
+ val resumeFunTreeDummyBody = DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), Literal(Constant(())))
+
+ val applyDefDefDummyBody: DefDef = {
+ val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)))
+ DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), Literal(Constant(())))
+ }
+
+ val stateMachineType = applied("scala.async.StateMachine", List(futureSystemOps.promType[T], futureSystemOps.execContextType))
+
+ val stateMachine: ClassDef = {
+ val body: List[Tree] = {
+ val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0)))
+ val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree)
+ val execContextValDef = ValDef(NoMods, name.execContext, TypeTree(), execContext)
+
+ val apply0DefDef: DefDef = {
+ // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`.
+ // See SI-1247 for the the optimization that avoids creatio
+ DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil))
+ }
+ List(emptyConstructor, stateVar, result, execContextValDef) ++ List(resumeFunTreeDummyBody, applyDefDefDummyBody, apply0DefDef)
+ }
+ val template = {
+ Template(List(stateMachineType), emptyValDef, body)
+ }
+ val t = ClassDef(NoMods, name.stateMachineT, Nil, template)
+ callSiteTyper.typedPos(macroApplication.pos)(Block(t :: Nil, Literal(Constant(()))))
+ t
+ }
+
+ val asyncBlock: AsyncBlock = {
+ val symLookup = new SymLookup(stateMachine.symbol, applyDefDefDummyBody.vparamss.head.head.symbol)
+ buildAsyncBlock(anfTree, symLookup)
+ }
+
+ logDiagnostics(anfTree, asyncBlock.asyncStates.map(_.toString))
+
+ def startStateMachine: Tree = {
+ val stateMachineSpliced: Tree = spliceMethodBodies(
+ liftables(asyncBlock.asyncStates),
+ stateMachine,
+ asyncBlock.onCompleteHandler[T],
+ asyncBlock.resumeFunTree[T].rhs
+ )
+
+ def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection)
+
+ Block(List[Tree](
+ stateMachineSpliced,
+ ValDef(NoMods, name.stateMachine, stateMachineType, Apply(Select(New(Ident(stateMachine.symbol)), nme.CONSTRUCTOR), Nil)),
+ futureSystemOps.spawn(Apply(selectStateMachine(name.apply), Nil), selectStateMachine(name.execContext))
+ ),
+ futureSystemOps.promiseToFuture(Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree)
+ }
+
+ val isSimple = asyncBlock.asyncStates.size == 1
+ if (isSimple)
+ futureSystemOps.spawn(body, execContext) // generate lean code for the simple case of `async { 1 + 1 }`
+ else
+ startStateMachine
+ }
+
+ def logDiagnostics(anfTree: Tree, states: Seq[String]) {
+ val pos = macroApplication.pos
+ def location = try {
+ pos.source.path
+ } catch {
+ case _: UnsupportedOperationException =>
+ pos.toString
+ }
+
+ AsyncUtils.vprintln(s"In file '$location':")
+ AsyncUtils.vprintln(s"${macroApplication}")
+ AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree")
+ states foreach (s => AsyncUtils.vprintln(s))
+ }
+
+ def spliceMethodBodies(liftables: List[Tree], tree: Tree, applyBody: Tree,
+ resumeBody: Tree): Tree = {
+
+ val liftedSyms = liftables.map(_.symbol).toSet
+ val stateMachineClass = tree.symbol
+ liftedSyms.foreach {
+ sym =>
+ if (sym != null) {
+ sym.owner = stateMachineClass
+ if (sym.isModule)
+ sym.moduleClass.owner = stateMachineClass
+ }
+ }
+ // Replace the ValDefs in the splicee with Assigns to the corresponding lifted
+ // fields. Similarly, replace references to them with references to the field.
+ //
+ // This transform will be only be run on the RHS of `def foo`.
+ class UseFields extends MacroTypingTransformer {
+ override def transform(tree: Tree): Tree = tree match {
+ case _ if currentOwner == stateMachineClass =>
+ super.transform(tree)
+ case ValDef(_, _, _, rhs) if liftedSyms(tree.symbol) =>
+ atOwner(currentOwner) {
+ val fieldSym = tree.symbol
+ val set = Assign(gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym), transform(rhs))
+ changeOwner(set, tree.symbol, currentOwner)
+ localTyper.typedPos(tree.pos)(set)
+ }
+ case _: DefTree if liftedSyms(tree.symbol) =>
+ EmptyTree
+ case Ident(name) if liftedSyms(tree.symbol) =>
+ val fieldSym = tree.symbol
+ gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym).setType(tree.tpe)
+ case _ =>
+ super.transform(tree)
+ }
+ }
+
+ val liftablesUseFields = liftables.map {
+ case vd: ValDef => vd
+ case x =>
+ val useField = new UseFields()
+ //.substituteSymbols(fromSyms, toSyms)
+ useField.atOwner(stateMachineClass)(useField.transform(x))
+ }
+
+ tree.children.foreach {
+ t =>
+ new ChangeOwnerAndModuleClassTraverser(callSiteTyper.context.owner, tree.symbol).traverse(t)
+ }
+ val treeSubst = tree
+
+ def fixup(dd: DefDef, body: Tree, ctx: analyzer.Context): Tree = {
+ val spliceeAnfFixedOwnerSyms = body
+ val useField = new UseFields()
+ val newRhs = useField.atOwner(dd.symbol)(useField.transform(spliceeAnfFixedOwnerSyms))
+ val typer = global.analyzer.newTyper(ctx.make(dd, dd.symbol))
+ treeCopy.DefDef(dd, dd.mods, dd.name, dd.tparams, dd.vparamss, dd.tpt, typer.typed(newRhs))
+ }
+
+ liftablesUseFields.foreach(t => if (t.symbol != null) stateMachineClass.info.decls.enter(t.symbol))
+
+ val result0 = transformAt(treeSubst) {
+ case t@Template(parents, self, stats) =>
+ (ctx: analyzer.Context) => {
+ treeCopy.Template(t, parents, self, liftablesUseFields ++ stats)
+ }
+ }
+ val result = transformAt(result0) {
+ case dd@DefDef(_, name.apply, _, List(List(_)), _, _) if dd.symbol.owner == stateMachineClass =>
+ (ctx: analyzer.Context) =>
+ val typedTree = fixup(dd, changeOwner(applyBody, callSiteTyper.context.owner, dd.symbol), ctx)
+ typedTree
+ case dd@DefDef(_, name.resume, _, _, _, _) if dd.symbol.owner == stateMachineClass =>
+ (ctx: analyzer.Context) =>
+ val changed = changeOwner(resumeBody, callSiteTyper.context.owner, dd.symbol)
+ val res = fixup(dd, changed, ctx)
+ res
+ }
+ result
+ }
+}
diff --git a/src/main/scala/scala/async/internal/AsyncUtils.scala b/src/main/scala/scala/async/internal/AsyncUtils.scala
new file mode 100644
index 0000000..8700bd6
--- /dev/null
+++ b/src/main/scala/scala/async/internal/AsyncUtils.scala
@@ -0,0 +1,16 @@
+/*
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+package scala.async.internal
+
+object AsyncUtils {
+
+ private def enabled(level: String) = sys.props.getOrElse(s"scala.async.$level", "false").equalsIgnoreCase("true")
+
+ private def verbose = enabled("debug")
+ private def trace = enabled("trace")
+
+ private[async] def vprintln(s: => Any): Unit = if (verbose) println(s"[async] $s")
+
+ private[async] def trace(s: => Any): Unit = if (trace) println(s"[async] $s")
+}
diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala
new file mode 100644
index 0000000..1ce30e6
--- /dev/null
+++ b/src/main/scala/scala/async/internal/ExprBuilder.scala
@@ -0,0 +1,388 @@
+/*
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+package scala.async.internal
+
+import scala.reflect.macros.Context
+import scala.collection.mutable.ListBuffer
+import collection.mutable
+import language.existentials
+import scala.reflect.api.Universe
+import scala.reflect.api
+import scala.Some
+
+trait ExprBuilder {
+ builder: AsyncMacro =>
+
+ import global._
+ import defn._
+
+ val futureSystem: FutureSystem
+ val futureSystemOps: futureSystem.Ops { val universe: global.type }
+
+ val stateAssigner = new StateAssigner
+ val labelDefStates = collection.mutable.Map[Symbol, Int]()
+
+ trait AsyncState {
+ def state: Int
+
+ def mkHandlerCaseForState: CaseDef
+
+ def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = None
+
+ def stats: List[Tree]
+
+ final def allStats: List[Tree] = this match {
+ case a: AsyncStateWithAwait => stats :+ a.awaitable.resultValDef
+ case _ => stats
+ }
+
+ final def body: Tree = stats match {
+ case stat :: Nil => stat
+ case init :+ last => Block(init, last)
+ }
+ }
+
+ /** A sequence of statements the concludes with a unconditional transition to `nextState` */
+ final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int, symLookup: SymLookup)
+ extends AsyncState {
+
+ def mkHandlerCaseForState: CaseDef =
+ mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup) :+ mkResumeApply(symLookup))
+
+ override val toString: String =
+ s"AsyncState #$state, next = $nextState"
+ }
+
+ /** A sequence of statements with a conditional transition to the next state, which will represent
+ * a branch of an `if` or a `match`.
+ */
+ final class AsyncStateWithoutAwait(val stats: List[Tree], val state: Int) extends AsyncState {
+ override def mkHandlerCaseForState: CaseDef =
+ mkHandlerCase(state, stats)
+
+ override val toString: String =
+ s"AsyncStateWithoutAwait #$state"
+ }
+
+ /** A sequence of statements that concludes with an `await` call. The `onComplete`
+ * handler will unconditionally transition to `nestState`.``
+ */
+ final class AsyncStateWithAwait(val stats: List[Tree], val state: Int, nextState: Int,
+ val awaitable: Awaitable, symLookup: SymLookup)
+ extends AsyncState {
+
+ override def mkHandlerCaseForState: CaseDef = {
+ val callOnComplete = futureSystemOps.onComplete(Expr(awaitable.expr),
+ Expr(This(tpnme.EMPTY)), Expr(Ident(name.execContext))).tree
+ mkHandlerCase(state, stats :+ callOnComplete)
+ }
+
+ override def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = {
+ val tryGetTree =
+ Assign(
+ Ident(awaitable.resultName),
+ TypeApply(Select(Select(Ident(symLookup.applyTrParam), Try_get), newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType)))
+ )
+
+ /* if (tr.isFailure)
+ * result.complete(tr.asInstanceOf[Try[T]])
+ * else {
+ * <resultName> = tr.get.asInstanceOf[<resultType>]
+ * <nextState>
+ * <mkResumeApply>
+ * }
+ */
+ val ifIsFailureTree =
+ If(Select(Ident(symLookup.applyTrParam), Try_isFailure),
+ futureSystemOps.completeProm[T](
+ Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)),
+ Expr[scala.util.Try[T]](
+ TypeApply(Select(Ident(symLookup.applyTrParam), newTermName("asInstanceOf")),
+ List(TypeTree(weakTypeOf[scala.util.Try[T]]))))).tree,
+ Block(List(tryGetTree, mkStateTree(nextState, symLookup)), mkResumeApply(symLookup))
+ )
+
+ Some(mkHandlerCase(state, List(ifIsFailureTree)))
+ }
+
+ override val toString: String =
+ s"AsyncStateWithAwait #$state, next = $nextState"
+ }
+
+ /*
+ * Builder for a single state of an async method.
+ */
+ final class AsyncStateBuilder(state: Int, private val symLookup: SymLookup) {
+ /* Statements preceding an await call. */
+ private val stats = ListBuffer[Tree]()
+ /** The state of the target of a LabelDef application (while loop jump) */
+ private var nextJumpState: Option[Int] = None
+
+ def +=(stat: Tree): this.type = {
+ assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat")
+ def addStat() = stats += stat
+ stat match {
+ case Apply(fun, Nil) =>
+ labelDefStates get fun.symbol match {
+ case Some(nextState) => nextJumpState = Some(nextState)
+ case None => addStat()
+ }
+ case _ => addStat()
+ }
+ this
+ }
+
+ def resultWithAwait(awaitable: Awaitable,
+ nextState: Int): AsyncState = {
+ val effectiveNextState = nextJumpState.getOrElse(nextState)
+ new AsyncStateWithAwait(stats.toList, state, effectiveNextState, awaitable, symLookup)
+ }
+
+ def resultSimple(nextState: Int): AsyncState = {
+ val effectiveNextState = nextJumpState.getOrElse(nextState)
+ new SimpleAsyncState(stats.toList, state, effectiveNextState, symLookup)
+ }
+
+ def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = {
+ def mkBranch(state: Int) = Block(mkStateTree(state, symLookup) :: Nil, mkResumeApply(symLookup))
+ this += If(condTree, mkBranch(thenState), mkBranch(elseState))
+ new AsyncStateWithoutAwait(stats.toList, state)
+ }
+
+ /**
+ * Build `AsyncState` ending with a match expression.
+ *
+ * The cases of the match simply resume at the state of their corresponding right-hand side.
+ *
+ * @param scrutTree tree of the scrutinee
+ * @param cases list of case definitions
+ * @param caseStates starting state of the right-hand side of the each case
+ * @return an `AsyncState` representing the match expression
+ */
+ def resultWithMatch(scrutTree: Tree, cases: List[CaseDef], caseStates: List[Int], symLookup: SymLookup): AsyncState = {
+ // 1. build list of changed cases
+ val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match {
+ case CaseDef(pat, guard, rhs) =>
+ val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal)
+ CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num), symLookup), mkResumeApply(symLookup)))
+ }
+ // 2. insert changed match tree at the end of the current state
+ this += Match(scrutTree, newCases)
+ new AsyncStateWithoutAwait(stats.toList, state)
+ }
+
+ def resultWithLabel(startLabelState: Int, symLookup: SymLookup): AsyncState = {
+ this += Block(mkStateTree(startLabelState, symLookup) :: Nil, mkResumeApply(symLookup))
+ new AsyncStateWithoutAwait(stats.toList, state)
+ }
+
+ override def toString: String = {
+ val statsBeforeAwait = stats.mkString("\n")
+ s"ASYNC STATE:\n$statsBeforeAwait"
+ }
+ }
+
+ /**
+ * An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`).
+ *
+ * @param stats a list of expressions
+ * @param expr the last expression of the block
+ * @param startState the start state
+ * @param endState the state to continue with
+ */
+ final private class AsyncBlockBuilder(stats: List[Tree], expr: Tree, startState: Int, endState: Int,
+ private val symLookup: SymLookup) {
+ val asyncStates = ListBuffer[AsyncState]()
+
+ var stateBuilder = new AsyncStateBuilder(startState, symLookup)
+ var currState = startState
+
+ /* TODO Fall back to CPS plug-in if tree contains an `await` call. */
+ def checkForUnsupportedAwait(tree: Tree) = if (tree exists {
+ case Apply(fun, _) if isAwait(fun) => true
+ case _ => false
+ }) abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException
+
+ def nestedBlockBuilder(nestedTree: Tree, startState: Int, endState: Int) = {
+ val (nestedStats, nestedExpr) = statsAndExpr(nestedTree)
+ new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, symLookup)
+ }
+
+ import stateAssigner.nextState
+
+ // populate asyncStates
+ for (stat <- stats) stat match {
+ // the val name = await(..) pattern
+ case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) =>
+ val afterAwaitState = nextState()
+ val awaitable = Awaitable(arg, stat.symbol, tpt.tpe, vd)
+ asyncStates += stateBuilder.resultWithAwait(awaitable, afterAwaitState) // complete with await
+ currState = afterAwaitState
+ stateBuilder = new AsyncStateBuilder(currState, symLookup)
+
+ case If(cond, thenp, elsep) if stat exists isAwait =>
+ checkForUnsupportedAwait(cond)
+
+ val thenStartState = nextState()
+ val elseStartState = nextState()
+ val afterIfState = nextState()
+
+ asyncStates +=
+ // the two Int arguments are the start state of the then branch and the else branch, respectively
+ stateBuilder.resultWithIf(cond, thenStartState, elseStartState)
+
+ List((thenp, thenStartState), (elsep, elseStartState)) foreach {
+ case (branchTree, state) =>
+ val builder = nestedBlockBuilder(branchTree, state, afterIfState)
+ asyncStates ++= builder.asyncStates
+ }
+
+ currState = afterIfState
+ stateBuilder = new AsyncStateBuilder(currState, symLookup)
+
+ case Match(scrutinee, cases) if stat exists isAwait =>
+ checkForUnsupportedAwait(scrutinee)
+
+ val caseStates = cases.map(_ => nextState())
+ val afterMatchState = nextState()
+
+ asyncStates +=
+ stateBuilder.resultWithMatch(scrutinee, cases, caseStates, symLookup)
+
+ for ((cas, num) <- cases.zipWithIndex) {
+ val (stats, expr) = statsAndExpr(cas.body)
+ val stats1 = stats.dropWhile(isSyntheticBindVal)
+ val builder = nestedBlockBuilder(Block(stats1, expr), caseStates(num), afterMatchState)
+ asyncStates ++= builder.asyncStates
+ }
+
+ currState = afterMatchState
+ stateBuilder = new AsyncStateBuilder(currState, symLookup)
+
+ case ld@LabelDef(name, params, rhs) if rhs exists isAwait =>
+ val startLabelState = nextState()
+ val afterLabelState = nextState()
+ asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup)
+ labelDefStates(ld.symbol) = startLabelState
+ val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState)
+ asyncStates ++= builder.asyncStates
+
+ currState = afterLabelState
+ stateBuilder = new AsyncStateBuilder(currState, symLookup)
+ case _ =>
+ checkForUnsupportedAwait(stat)
+ stateBuilder += stat
+ }
+ // complete last state builder (representing the expressions after the last await)
+ stateBuilder += expr
+ val lastState = stateBuilder.resultSimple(endState)
+ asyncStates += lastState
+ }
+
+ trait AsyncBlock {
+ def asyncStates: List[AsyncState]
+
+ def onCompleteHandler[T: WeakTypeTag]: Tree
+
+ def resumeFunTree[T]: DefDef
+ }
+
+ case class SymLookup(stateMachineClass: Symbol, applyTrParam: Symbol) {
+ def stateMachineMember(name: TermName): Symbol =
+ stateMachineClass.info.member(name)
+ def memberRef(name: TermName) = gen.mkAttributedRef(stateMachineMember(name))
+ }
+
+ def buildAsyncBlock(block: Block, symLookup: SymLookup): AsyncBlock = {
+ val Block(stats, expr) = block
+ val startState = stateAssigner.nextState()
+ val endState = Int.MaxValue
+
+ val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, symLookup)
+
+ new AsyncBlock {
+ def asyncStates = blockBuilder.asyncStates.toList
+
+ def mkCombinedHandlerCases[T]: List[CaseDef] = {
+ val caseForLastState: CaseDef = {
+ val lastState = asyncStates.last
+ val lastStateBody = Expr[T](lastState.body)
+ val rhs = futureSystemOps.completeProm(
+ Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), reify(scala.util.Success(lastStateBody.splice)))
+ mkHandlerCase(lastState.state, rhs.tree)
+ }
+ asyncStates.toList match {
+ case s :: Nil =>
+ List(caseForLastState)
+ case _ =>
+ val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState
+ initCases :+ caseForLastState
+ }
+ }
+
+ val initStates = asyncStates.init
+
+ /**
+ * def resume(): Unit = {
+ * try {
+ * state match {
+ * case 0 => {
+ * f11 = exprReturningFuture
+ * f11.onComplete(onCompleteHandler)(context)
+ * }
+ * ...
+ * }
+ * } catch {
+ * case NonFatal(t) => result.failure(t)
+ * }
+ * }
+ */
+ def resumeFunTree[T]: DefDef =
+ DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass),
+ Try(
+ Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T]),
+ List(
+ CaseDef(
+ Bind(name.t, Ident(nme.WILDCARD)),
+ Apply(Ident(defn.NonFatalClass), List(Ident(name.t))),
+ Block(List({
+ val t = Expr[Throwable](Ident(name.t))
+ futureSystemOps.completeProm[T](
+ Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), reify(scala.util.Failure(t.splice))).tree
+ }), literalUnit))), EmptyTree))
+
+ /**
+ * // assumes tr: Try[Any] is in scope.
+ * //
+ * state match {
+ * case 0 => {
+ * x11 = tr.get.asInstanceOf[Double];
+ * state = 1;
+ * resume()
+ * }
+ */
+ def onCompleteHandler[T: WeakTypeTag]: Tree = Match(symLookup.memberRef(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList)
+ }
+ }
+
+ private def isSyntheticBindVal(tree: Tree) = tree match {
+ case vd@ValDef(_, lname, _, Ident(rname)) => lname.toString.contains(name.bindSuffix)
+ case _ => false
+ }
+
+ case class Awaitable(expr: Tree, resultName: Symbol, resultType: Type, resultValDef: ValDef)
+
+ private def mkResumeApply(symLookup: SymLookup) = Apply(symLookup.memberRef(name.resume), Nil)
+
+ private def mkStateTree(nextState: Int, symLookup: SymLookup): Tree =
+ Assign(symLookup.memberRef(name.state), Literal(Constant(nextState)))
+
+ private def mkHandlerCase(num: Int, rhs: List[Tree]): CaseDef =
+ mkHandlerCase(num, Block(rhs, literalUnit))
+
+ private def mkHandlerCase(num: Int, rhs: Tree): CaseDef =
+ CaseDef(Literal(Constant(num)), EmptyTree, rhs)
+
+ private def literalUnit = Literal(Constant(()))
+}
diff --git a/src/main/scala/scala/async/internal/FutureSystem.scala b/src/main/scala/scala/async/internal/FutureSystem.scala
new file mode 100644
index 0000000..101b7bf
--- /dev/null
+++ b/src/main/scala/scala/async/internal/FutureSystem.scala
@@ -0,0 +1,106 @@
+/*
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+package scala.async.internal
+
+import scala.language.higherKinds
+
+import scala.reflect.macros.Context
+import scala.reflect.internal.SymbolTable
+
+/**
+ * An abstraction over a future system.
+ *
+ * Used by the macro implementations in [[scala.async.AsyncBase]] to
+ * customize the code generation.
+ *
+ * The API mirrors that of `scala.concurrent.Future`, see the instance
+ * [[ScalaConcurrentFutureSystem]] for an example of how
+ * to implement this.
+ */
+trait FutureSystem {
+ /** A container to receive the final value of the computation */
+ type Prom[A]
+ /** A (potentially in-progress) computation */
+ type Fut[A]
+ /** An execution context, required to create or register an on completion callback on a Future. */
+ type ExecContext
+
+ trait Ops {
+ val universe: reflect.internal.SymbolTable
+
+ import universe._
+ def Expr[T: WeakTypeTag](tree: Tree): Expr[T] = universe.Expr[T](rootMirror, universe.FixedMirrorTreeCreator(rootMirror, tree))
+
+ def promType[A: WeakTypeTag]: Type
+ def execContextType: Type
+
+ /** Create an empty promise */
+ def createProm[A: WeakTypeTag]: Expr[Prom[A]]
+
+ /** Extract a future from the given promise. */
+ def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]): Expr[Fut[A]]
+
+ /** Construct a future to asynchronously compute the given expression */
+ def future[A: WeakTypeTag](a: Expr[A])(execContext: Expr[ExecContext]): Expr[Fut[A]]
+
+ /** Register an call back to run on completion of the given future */
+ def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U],
+ execContext: Expr[ExecContext]): Expr[Unit]
+
+ /** Complete a promise with a value */
+ def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit]
+
+ def spawn(tree: Tree, execContext: Tree): Tree =
+ future(Expr[Unit](tree))(Expr[ExecContext](execContext)).tree
+
+ // TODO Why is this needed?
+ def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]]
+ }
+
+ def mkOps(c: SymbolTable): Ops { val universe: c.type }
+}
+
+object ScalaConcurrentFutureSystem extends FutureSystem {
+
+ import scala.concurrent._
+
+ type Prom[A] = Promise[A]
+ type Fut[A] = Future[A]
+ type ExecContext = ExecutionContext
+
+ def mkOps(c: SymbolTable): Ops {val universe: c.type} = new Ops {
+ val universe: c.type = c
+
+ import universe._
+
+ def promType[A: WeakTypeTag]: Type = weakTypeOf[Promise[A]]
+ def execContextType: Type = weakTypeOf[ExecutionContext]
+
+ def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify {
+ Promise[A]()
+ }
+
+ def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]) = reify {
+ prom.splice.future
+ }
+
+ def future[A: WeakTypeTag](a: Expr[A])(execContext: Expr[ExecContext]) = reify {
+ Future(a.splice)(execContext.splice)
+ }
+
+ def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U],
+ execContext: Expr[ExecContext]): Expr[Unit] = reify {
+ future.splice.onComplete(fun.splice)(execContext.splice)
+ }
+
+ def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify {
+ prom.splice.complete(value.splice)
+ Expr[Unit](Literal(Constant(()))).splice
+ }
+
+ def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = reify {
+ future.splice.asInstanceOf[Fut[A]]
+ }
+ }
+}
diff --git a/src/main/scala/scala/async/internal/Lifter.scala b/src/main/scala/scala/async/internal/Lifter.scala
new file mode 100644
index 0000000..f49dcbb
--- /dev/null
+++ b/src/main/scala/scala/async/internal/Lifter.scala
@@ -0,0 +1,150 @@
+package scala.async.internal
+
+trait Lifter {
+ self: AsyncMacro =>
+ import global._
+
+ /**
+ * Identify which DefTrees are used (including transitively) which are declared
+ * in some state but used (including transitively) in another state.
+ *
+ * These will need to be lifted to class members of the state machine.
+ */
+ def liftables(asyncStates: List[AsyncState]): List[Tree] = {
+ object companionship {
+ private val companions = collection.mutable.Map[Symbol, Symbol]()
+ private val companionsInverse = collection.mutable.Map[Symbol, Symbol]()
+ private def record(sym1: Symbol, sym2: Symbol) {
+ companions(sym1) = sym2
+ companions(sym2) = sym1
+ }
+
+ def record(defs: List[Tree]) {
+ // Keep note of local companions so we rename them consistently
+ // when lifting.
+ val comps = for {
+ cd@ClassDef(_, _, _, _) <- defs
+ md@ModuleDef(_, _, _) <- defs
+ if (cd.name.toTermName == md.name)
+ } record(cd.symbol, md.symbol)
+ }
+ def companionOf(sym: Symbol): Symbol = {
+ companions.get(sym).orElse(companionsInverse.get(sym)).getOrElse(NoSymbol)
+ }
+ }
+
+
+ val defs: Map[Tree, Int] = {
+ /** Collect the DefTrees directly enclosed within `t` that have the same owner */
+ def collectDirectlyEnclosedDefs(t: Tree): List[DefTree] = t match {
+ case dt: DefTree => dt :: Nil
+ case _: Function => Nil
+ case t =>
+ val childDefs = t.children.flatMap(collectDirectlyEnclosedDefs(_))
+ companionship.record(childDefs)
+ childDefs
+ }
+ asyncStates.flatMap {
+ asyncState =>
+ val defs = collectDirectlyEnclosedDefs(Block(asyncState.allStats: _*))
+ defs.map((_, asyncState.state))
+ }.toMap
+ }
+
+ // In which block are these symbols defined?
+ val symToDefiningState: Map[Symbol, Int] = defs.map {
+ case (k, v) => (k.symbol, v)
+ }
+
+ // The definitions trees
+ val symToTree: Map[Symbol, Tree] = defs.map {
+ case (k, v) => (k.symbol, k)
+ }
+
+ // The direct references of each definition tree
+ val defSymToReferenced: Map[Symbol, List[Symbol]] = defs.keys.map {
+ case tree => (tree.symbol, tree.collect {
+ case rt: RefTree if symToDefiningState.contains(rt.symbol) => rt.symbol
+ })
+ }.toMap
+
+ // The direct references of each block, excluding references of `DefTree`-s which
+ // are already accounted for.
+ val stateIdToDirectlyReferenced: Map[Int, List[Symbol]] = {
+ val refs: List[(Int, Symbol)] = asyncStates.flatMap(
+ asyncState => asyncState.stats.filterNot(_.isDef).flatMap(_.collect {
+ case rt: RefTree if symToDefiningState.contains(rt.symbol) => (asyncState.state, rt.symbol)
+ })
+ )
+ toMultiMap(refs)
+ }
+
+ def liftableSyms: Set[Symbol] = {
+ val liftableMutableSet = collection.mutable.Set[Symbol]()
+ def markForLift(sym: Symbol) {
+ if (!liftableMutableSet(sym)) {
+ liftableMutableSet += sym
+
+ // Only mark transitive references of defs, modules and classes. The RHS of lifted vals/vars
+ // stays in its original location, so things that it refers to need not be lifted.
+ if (!(sym.isVal || sym.isVar))
+ defSymToReferenced(sym).foreach(sym2 => markForLift(sym2))
+ }
+ }
+ // Start things with DefTrees directly referenced from statements from other states...
+ val liftableStatementRefs: List[Symbol] = stateIdToDirectlyReferenced.toList.flatMap {
+ case (i, syms) => syms.filter(sym => symToDefiningState(sym) != i)
+ }
+ // .. and likewise for DefTrees directly referenced by other DefTrees from other states
+ val liftableRefsOfDefTrees = defSymToReferenced.toList.flatMap {
+ case (referee, referents) => referents.filter(sym => symToDefiningState(sym) != symToDefiningState(referee))
+ }
+ // Mark these for lifting, which will follow transitive references.
+ (liftableStatementRefs ++ liftableRefsOfDefTrees).foreach(markForLift)
+ liftableMutableSet.toSet
+ }
+
+ val lifted = liftableSyms.map(symToTree).toList.map {
+ case vd@ValDef(_, _, tpt, rhs) =>
+ import reflect.internal.Flags._
+ val sym = vd.symbol
+ sym.setFlag(MUTABLE | STABLE | PRIVATE | LOCAL)
+ sym.name = name.fresh(sym.name.toTermName)
+ sym.modifyInfo(_.deconst)
+ ValDef(vd.symbol, gen.mkZero(vd.symbol.info)).setPos(vd.pos)
+ case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
+ import reflect.internal.Flags._
+ val sym = dd.symbol
+ sym.name = this.name.fresh(sym.name.toTermName)
+ sym.setFlag(PRIVATE | LOCAL)
+ DefDef(dd.symbol, rhs).setPos(dd.pos)
+ case cd@ClassDef(_, _, _, impl) =>
+ import reflect.internal.Flags._
+ val sym = cd.symbol
+ sym.name = newTypeName(name.fresh(sym.name.toString).toString)
+ companionship.companionOf(cd.symbol) match {
+ case NoSymbol =>
+ case moduleSymbol =>
+ moduleSymbol.name = sym.name.toTermName
+ moduleSymbol.moduleClass.name = moduleSymbol.name.toTypeName
+ }
+ ClassDef(cd.symbol, impl).setPos(cd.pos)
+ case md@ModuleDef(_, _, impl) =>
+ import reflect.internal.Flags._
+ val sym = md.symbol
+ companionship.companionOf(md.symbol) match {
+ case NoSymbol =>
+ sym.name = name.fresh(sym.name.toTermName)
+ sym.moduleClass.name = sym.name.toTypeName
+ case classSymbol => // will be renamed by `case ClassDef` above.
+ }
+ ModuleDef(md.symbol, impl).setPos(md.pos)
+ case td@TypeDef(_, _, _, rhs) =>
+ import reflect.internal.Flags._
+ val sym = td.symbol
+ sym.name = newTypeName(name.fresh(sym.name.toString).toString)
+ TypeDef(td.symbol, rhs).setPos(td.pos)
+ }
+ lifted
+ }
+}
diff --git a/src/main/scala/scala/async/internal/StateAssigner.scala b/src/main/scala/scala/async/internal/StateAssigner.scala
new file mode 100644
index 0000000..cdde7a4
--- /dev/null
+++ b/src/main/scala/scala/async/internal/StateAssigner.scala
@@ -0,0 +1,14 @@
+/*
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async.internal
+
+private[async] final class StateAssigner {
+ private var current = -1
+
+ def nextState(): Int = {
+ current += 1
+ current
+ }
+}
diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala
new file mode 100644
index 0000000..2582c91
--- /dev/null
+++ b/src/main/scala/scala/async/internal/TransformUtils.scala
@@ -0,0 +1,251 @@
+/*
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+package scala.async.internal
+
+import scala.reflect.macros.Context
+import reflect.ClassTag
+import scala.reflect.macros.runtime.AbortMacroException
+
+/**
+ * Utilities used in both `ExprBuilder` and `AnfTransform`.
+ */
+private[async] trait TransformUtils {
+ self: AsyncMacro =>
+
+ import global._
+
+ object name {
+ val resume = newTermName("resume")
+ val apply = newTermName("apply")
+ val matchRes = "matchres"
+ val ifRes = "ifres"
+ val await = "await"
+ val bindSuffix = "$bind"
+
+ val state = newTermName("state")
+ val result = newTermName("result")
+ val execContext = newTermName("execContext")
+ val stateMachine = newTermName(fresh("stateMachine"))
+ val stateMachineT = stateMachine.toTypeName
+ val tr = newTermName("tr")
+ val t = newTermName("throwable")
+
+ def fresh(name: TermName): TermName = newTermName(fresh(name.toString))
+
+ def fresh(name: String): String = currentUnit.freshTermName("" + name + "$").toString
+ }
+
+ def isAwait(fun: Tree) =
+ fun.symbol == defn.Async_await
+
+ private lazy val Boolean_ShortCircuits: Set[Symbol] = {
+ import definitions.BooleanClass
+ def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName)
+ val Boolean_&& = BooleanTermMember("&&")
+ val Boolean_|| = BooleanTermMember("||")
+ Set(Boolean_&&, Boolean_||)
+ }
+
+ private def isByName(fun: Tree): ((Int, Int) => Boolean) = {
+ if (Boolean_ShortCircuits contains fun.symbol) (i, j) => true
+ else {
+ val paramss = fun.tpe.paramss
+ val byNamess = paramss.map(_.map(_.isByNameParam))
+ (i, j) => util.Try(byNamess(i)(j)).getOrElse(false)
+ }
+ }
+ private def argName(fun: Tree): ((Int, Int) => String) = {
+ val paramss = fun.tpe.paramss
+ val namess = paramss.map(_.map(_.name.toString))
+ (i, j) => util.Try(namess(i)(j)).getOrElse(s"arg_${i}_${j}")
+ }
+
+ def Expr[A: WeakTypeTag](t: Tree) = global.Expr[A](rootMirror, new FixedMirrorTreeCreator(rootMirror, t))
+
+ object defn {
+ def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = {
+ Expr(Apply(Ident(definitions.List_apply), args.map(_.tree)))
+ }
+
+ def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify {
+ self.splice.contains(elem.splice)
+ }
+
+ def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify {
+ self.splice.apply(arg.splice)
+ }
+
+ def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify {
+ self.splice == other.splice
+ }
+
+ def mkTry_get[A](self: Expr[util.Try[A]]) = reify {
+ self.splice.get
+ }
+
+ val TryClass = rootMirror.staticClass("scala.util.Try")
+ val Try_get = TryClass.typeSignature.member(newTermName("get")).ensuring(_ != NoSymbol)
+ val Try_isFailure = TryClass.typeSignature.member(newTermName("isFailure")).ensuring(_ != NoSymbol)
+ val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe))
+ val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal")
+ val AsyncClass = rootMirror.staticClass("scala.async.internal.AsyncBase")
+ val Async_await = AsyncClass.typeSignature.member(newTermName("await")).ensuring(_ != NoSymbol)
+ }
+
+ def isSafeToInline(tree: Tree) = {
+ treeInfo.isExprSafeToInline(tree)
+ }
+
+ /** Map a list of arguments to:
+ * - A list of argument Trees
+ * - A list of auxillary results.
+ *
+ * The function unwraps and rewraps the `arg :_*` construct.
+ *
+ * @param args The original argument trees
+ * @param f A function from argument (with '_*' unwrapped) and argument index to argument.
+ * @tparam A The type of the auxillary result
+ */
+ private def mapArguments[A](args: List[Tree])(f: (Tree, Int) => (A, Tree)): (List[A], List[Tree]) = {
+ args match {
+ case args :+ Typed(tree, Ident(tpnme.WILDCARD_STAR)) =>
+ val (a, argExprs :+ lastArgExpr) = (args :+ tree).zipWithIndex.map(f.tupled).unzip
+ val exprs = argExprs :+ Typed(lastArgExpr, Ident(tpnme.WILDCARD_STAR)).setPos(lastArgExpr.pos)
+ (a, exprs)
+ case args =>
+ args.zipWithIndex.map(f.tupled).unzip
+ }
+ }
+
+ case class Arg(expr: Tree, isByName: Boolean, argName: String)
+
+ /**
+ * Transform a list of argument lists, producing the transformed lists, and lists of auxillary
+ * results.
+ *
+ * The function `f` need not concern itself with varargs arguments e.g (`xs : _*`). It will
+ * receive `xs`, and it's result will be re-wrapped as `f(xs) : _*`.
+ *
+ * @param fun The function being applied
+ * @param argss The argument lists
+ * @return (auxillary results, mapped argument trees)
+ */
+ def mapArgumentss[A](fun: Tree, argss: List[List[Tree]])(f: Arg => (A, Tree)): (List[List[A]], List[List[Tree]]) = {
+ val isByNamess: (Int, Int) => Boolean = isByName(fun)
+ val argNamess: (Int, Int) => String = argName(fun)
+ argss.zipWithIndex.map { case (args, i) =>
+ mapArguments[A](args) {
+ (tree, j) => f(Arg(tree, isByNamess(i, j), argNamess(i, j)))
+ }
+ }.unzip
+ }
+
+
+ def statsAndExpr(tree: Tree): (List[Tree], Tree) = tree match {
+ case Block(stats, expr) => (stats, expr)
+ case _ => (List(tree), Literal(Constant(())))
+ }
+
+ def emptyConstructor: DefDef = {
+ val emptySuperCall = Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), Nil)
+ DefDef(NoMods, nme.CONSTRUCTOR, List(), List(List()), TypeTree(), Block(List(emptySuperCall), Literal(Constant(()))))
+ }
+
+ def applied(className: String, types: List[Type]): AppliedTypeTree =
+ AppliedTypeTree(Ident(rootMirror.staticClass(className)), types.map(TypeTree(_)))
+
+ /** Descends into the regions of the tree that are subject to the
+ * translation to a state machine by `async`. When a nested template,
+ * function, or by-name argument is encountered, the descent stops,
+ * and `nestedClass` etc are invoked.
+ */
+ trait AsyncTraverser extends Traverser {
+ def nestedClass(classDef: ClassDef) {
+ }
+
+ def nestedModule(module: ModuleDef) {
+ }
+
+ def nestedMethod(module: DefDef) {
+ }
+
+ def byNameArgument(arg: Tree) {
+ }
+
+ def function(function: Function) {
+ }
+
+ def patMatFunction(tree: Match) {
+ }
+
+ override def traverse(tree: Tree) {
+ tree match {
+ case cd: ClassDef => nestedClass(cd)
+ case md: ModuleDef => nestedModule(md)
+ case dd: DefDef => nestedMethod(dd)
+ case fun: Function => function(fun)
+ case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions`
+ case treeInfo.Applied(fun, targs, argss) if argss.nonEmpty =>
+ val isInByName = isByName(fun)
+ for ((args, i) <- argss.zipWithIndex) {
+ for ((arg, j) <- args.zipWithIndex) {
+ if (!isInByName(i, j)) traverse(arg)
+ else byNameArgument(arg)
+ }
+ }
+ traverse(fun)
+ case _ => super.traverse(tree)
+ }
+ }
+ }
+
+ def abort(pos: Position, msg: String) = throw new AbortMacroException(pos, msg)
+
+ abstract class MacroTypingTransformer extends TypingTransformer(callSiteTyper.context.unit) {
+ currentOwner = callSiteTyper.context.owner
+
+ def currOwner: Symbol = currentOwner
+
+ localTyper = global.analyzer.newTyper(callSiteTyper.context.make(unit = callSiteTyper.context.unit))
+ }
+
+ def transformAt(tree: Tree)(f: PartialFunction[Tree, (analyzer.Context => Tree)]) = {
+ object trans extends MacroTypingTransformer {
+ override def transform(tree: Tree): Tree = {
+ if (f.isDefinedAt(tree)) {
+ f(tree)(localTyper.context)
+ } else super.transform(tree)
+ }
+ }
+ trans.transform(tree)
+ }
+
+ def changeOwner(tree: Tree, oldOwner: Symbol, newOwner: Symbol): tree.type = {
+ new ChangeOwnerAndModuleClassTraverser(oldOwner, newOwner).traverse(tree)
+ tree
+ }
+
+ class ChangeOwnerAndModuleClassTraverser(oldowner: Symbol, newowner: Symbol)
+ extends ChangeOwnerTraverser(oldowner, newowner) {
+
+ override def traverse(tree: Tree) {
+ tree match {
+ case _: DefTree => change(tree.symbol.moduleClass)
+ case _ =>
+ }
+ super.traverse(tree)
+ }
+ }
+
+ def toMultiMap[A, B](as: Iterable[(A, B)]): Map[A, List[B]] =
+ as.toList.groupBy(_._1).mapValues(_.map(_._2).toList).toMap
+
+ // Attributed version of `TreeGen#mkCastPreservingAnnotations`
+ def mkAttributedCastPreservingAnnotations(tree: Tree, tp: Type): Tree = {
+ atPos(tree.pos) {
+ val casted = gen.mkAttributedCast(tree, tp.withoutAnnotations.dealias)
+ Typed(casted, TypeTree(tp)).setType(tp)
+ }
+ }
+}