aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJason Zaugg <jzaugg@gmail.com>2013-07-02 15:55:34 +0200
committerJason Zaugg <jzaugg@gmail.com>2013-07-03 10:04:55 +0200
commit82232ec47effb4a6b67b3a0792e1c7600e2d31b7 (patch)
treeed9925418aa0a631d1d25fd1be30f5d508e81b24
parentd63b63f536aafa494c70835526174be1987050de (diff)
downloadscala-async-82232ec47effb4a6b67b3a0792e1c7600e2d31b7.tar.gz
scala-async-82232ec47effb4a6b67b3a0792e1c7600e2d31b7.tar.bz2
scala-async-82232ec47effb4a6b67b3a0792e1c7600e2d31b7.zip
An overdue overhaul of macro internals.
- Avoid reset + retypecheck, instead hang onto the original types/symbols - Eliminated duplication between AsyncDefinitionUseAnalyzer and ExprBuilder - Instead, decide what do lift *after* running ExprBuilder - Account for transitive references local classes/objects and lift them as needed. - Make the execution context an regular implicit parameter of the macro - Fixes interaction with existential skolems and singleton types Fixes #6, #13, #16, #17, #19, #21.
-rw-r--r--src/main/scala/scala/async/AnfTransform.scala450
-rw-r--r--src/main/scala/scala/async/Async.scala139
-rw-r--r--src/main/scala/scala/async/AsyncAnalysis.scala133
-rw-r--r--src/main/scala/scala/async/AsyncMacro.scala29
-rw-r--r--src/main/scala/scala/async/AsyncTransform.scala176
-rw-r--r--src/main/scala/scala/async/ExprBuilder.scala205
-rw-r--r--src/main/scala/scala/async/FutureSystem.scala50
-rw-r--r--src/main/scala/scala/async/Lifter.scala150
-rw-r--r--src/main/scala/scala/async/TransformUtils.scala391
-rw-r--r--src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala43
-rw-r--r--src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala9
-rw-r--r--src/main/scala/scala/async/continuations/CPSBasedAsync.scala11
-rw-r--r--src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala8
-rw-r--r--src/test/scala/scala/async/TreeInterrogation.scala6
-rw-r--r--src/test/scala/scala/async/neg/LocalClasses0Spec.scala123
-rw-r--r--src/test/scala/scala/async/run/anf/AnfTransformSpec.scala2
-rw-r--r--src/test/scala/scala/async/run/nesteddef/NestedDef.scala56
-rw-r--r--src/test/scala/scala/async/run/toughtype/ToughType.scala38
18 files changed, 1021 insertions, 998 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala
index 5b9901d..275bc49 100644
--- a/src/main/scala/scala/async/AnfTransform.scala
+++ b/src/main/scala/scala/async/AnfTransform.scala
@@ -5,270 +5,246 @@
package scala.async
-import scala.reflect.macros.Context
+import scala.tools.nsc.Global
-private[async] final case class AnfTransform[C <: Context](c: C) {
+private[async] trait AnfTransform {
+ self: AsyncMacro =>
- import c.universe._
+ import global._
+ import reflect.internal.Flags._
- val utils = TransformUtils[c.type](c)
-
- import utils._
-
- def apply(tree: Tree): List[Tree] = {
- val unique = uniqueNames(tree)
+ def anfTransform(tree: Tree): Block = {
// Must prepend the () for issue #31.
- anf.transformToList(Block(List(c.literalUnit.tree), unique))
- }
+ val block = callSiteTyper.typedPos(tree.pos)(Block(List(Literal(Constant(()))), tree)).setType(tree.tpe)
- private def uniqueNames(tree: Tree): Tree = {
- new UniqueNames(tree).transform(tree)
+ new SelectiveAnfTransform().transform(block)
}
- /** Assigns unique names to all definitions in a tree, and adjusts references to use the new name.
- * Only modifies names that appear more than once in the tree.
- *
- * This step is needed to allow us to safely merge blocks during the `inline` transform below.
- */
- private final class UniqueNames(tree: Tree) extends Transformer {
- val repeatedNames: Set[Symbol] = {
- class DuplicateNameTraverser extends AsyncTraverser {
- val result = collection.mutable.Buffer[Symbol]()
-
- override def traverse(tree: Tree) {
- tree match {
- case dt: DefTree => result += dt.symbol
- case _ => super.traverse(tree)
- }
- }
- }
- val dupNameTraverser = new DuplicateNameTraverser
- dupNameTraverser.traverse(tree)
- dupNameTraverser.result.groupBy(x => x.name).filter(_._2.size > 1).values.flatten.toSet[Symbol]
- }
+ sealed abstract class AnfMode
+
+ case object Anf extends AnfMode
- /** Stepping outside of the public Macro API to call [[scala.reflect.internal.Symbols.Symbol.name_=]] */
- val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable]
+ case object Linearizing extends AnfMode
- val renamed = collection.mutable.Set[Symbol]()
+ final class SelectiveAnfTransform extends MacroTypingTransformer {
+ var mode: AnfMode = Anf
- override def transform(tree: Tree): Tree = {
+ def blockToList(tree: Tree): List[Tree] = tree match {
+ case Block(stats, expr) => stats :+ expr
+ case t => t :: Nil
+ }
+
+ def listToBlock(trees: List[Tree]): Block = trees match {
+ case trees @ (init :+ last) =>
+ val pos = trees.map(_.pos).reduceLeft(_ union _)
+ Block(init, last).setType(last.tpe).setPos(pos)
+ }
+
+ override def transform(tree: Tree): Block = {
+ def anfLinearize: Block = {
+ val trees: List[Tree] = mode match {
+ case Anf => anf._transformToList(tree)
+ case Linearizing => linearize._transformToList(tree)
+ }
+ listToBlock(trees)
+ }
tree match {
- case defTree: DefTree if repeatedNames(defTree.symbol) =>
- val trans = super.transform(defTree)
- val origName = defTree.symbol.name
- val sym = defTree.symbol.asInstanceOf[symtab.Symbol]
- val fresh = name.fresh(sym.name.toString)
- sym.name = origName match {
- case _: TermName => symtab.newTermName(fresh)
- case _: TypeName => symtab.newTypeName(fresh)
- }
- renamed += trans.symbol
- val newName = trans.symbol.name
- trans match {
- case ValDef(mods, name, tpt, rhs) =>
- treeCopy.ValDef(trans, mods, newName, tpt, rhs)
- case Bind(name, body) =>
- treeCopy.Bind(trans, newName, body)
- case DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
- treeCopy.DefDef(trans, mods, newName, tparams, vparamss, tpt, rhs)
- case TypeDef(mods, name, tparams, rhs) =>
- treeCopy.TypeDef(tree, mods, newName, tparams, transform(rhs))
- // If we were to allow local classes / objects, we would need to rename here.
- case ClassDef(mods, name, tparams, impl) =>
- treeCopy.ClassDef(tree, mods, newName, tparams, transform(impl).asInstanceOf[Template])
- case ModuleDef(mods, name, impl) =>
- treeCopy.ModuleDef(tree, mods, newName, transform(impl).asInstanceOf[Template])
- case x => super.transform(x)
- }
- case Ident(name) =>
- if (renamed(tree.symbol)) treeCopy.Ident(tree, tree.symbol.name)
- else tree
- case Select(fun, name) =>
- if (renamed(tree.symbol)) {
- treeCopy.Select(tree, transform(fun), tree.symbol.name)
- } else super.transform(tree)
- case tt: TypeTree =>
- val tt1 = tt.asInstanceOf[symtab.TypeTree]
- val orig = tt1.original
- if (orig != null) tt1.setOriginal(transform(orig.asInstanceOf[Tree]).asInstanceOf[symtab.Tree])
- super.transform(tt)
- case _ => super.transform(tree)
+ case _: ValDef | _: DefDef | _: Function | _: ClassDef | _: TypeDef =>
+ atOwner(tree.symbol)(anfLinearize)
+ case _: ModuleDef =>
+ atOwner(tree.symbol.moduleClass orElse tree.symbol)(anfLinearize)
+ case _ =>
+ anfLinearize
}
}
- }
- private object trace {
- private var indent = -1
-
- def indentString = " " * indent
-
- def apply[T](prefix: String, args: Any)(t: => T): T = {
- indent += 1
- def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127)
- try {
- AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})")
- val result = t
- AsyncUtils.trace(s"${indentString}= ${oneLine(result)}")
- result
- } finally {
- indent -= 1
+ private object linearize {
+ def transformToList(tree: Tree): List[Tree] = {
+ mode = Linearizing; blockToList(transform(tree))
}
- }
- }
- private object inline {
- def transformToList(tree: Tree): List[Tree] = trace("inline", tree) {
- val stats :+ expr = anf.transformToList(tree)
- expr match {
- case Apply(fun, args) if isAwait(fun) =>
- val valDef = defineVal(name.await, expr, tree.pos)
- stats :+ valDef :+ Ident(valDef.name)
-
- case If(cond, thenp, elsep) =>
- // if type of if-else is Unit don't introduce assignment,
- // but add Unit value to bring it into form expected by async transform
- if (expr.tpe =:= definitions.UnitTpe) {
- stats :+ expr :+ Literal(Constant(()))
- } else {
- val varDef = defineVar(name.ifRes, expr.tpe, tree.pos)
- def branchWithAssign(orig: Tree) = orig match {
- case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.name), thenExpr))
- case _ => Assign(Ident(varDef.name), orig)
+ def transformToBlock(tree: Tree): Block = listToBlock(transformToList(tree))
+
+ def _transformToList(tree: Tree): List[Tree] = trace(tree) {
+ val stats :+ expr = anf.transformToList(tree)
+ expr match {
+ case Apply(fun, args) if isAwait(fun) =>
+ val valDef = defineVal(name.await, expr, tree.pos)
+ stats :+ valDef :+ gen.mkAttributedStableRef(valDef.symbol)
+
+ case If(cond, thenp, elsep) =>
+ // if type of if-else is Unit don't introduce assignment,
+ // but add Unit value to bring it into form expected by async transform
+ if (expr.tpe =:= definitions.UnitTpe) {
+ stats :+ expr :+ localTyper.typedPos(expr.pos)(Literal(Constant(())))
+ } else {
+ val varDef = defineVar(name.ifRes, expr.tpe, tree.pos)
+ def branchWithAssign(orig: Tree) = localTyper.typedPos(orig.pos)(
+ orig match {
+ case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.symbol), thenExpr))
+ case _ => Assign(Ident(varDef.symbol), orig)
+ }
+ ).setType(orig.tpe)
+ val ifWithAssign = treeCopy.If(tree, cond, branchWithAssign(thenp), branchWithAssign(elsep))
+ stats :+ varDef :+ ifWithAssign :+ gen.mkAttributedStableRef(varDef.symbol)
+ }
+
+ case Match(scrut, cases) =>
+ // if type of match is Unit don't introduce assignment,
+ // but add Unit value to bring it into form expected by async transform
+ if (expr.tpe =:= definitions.UnitTpe) {
+ stats :+ expr :+ localTyper.typedPos(expr.pos)(Literal(Constant(())))
}
- val ifWithAssign = If(cond, branchWithAssign(thenp), branchWithAssign(elsep))
- stats :+ varDef :+ ifWithAssign :+ Ident(varDef.name)
- }
-
- case Match(scrut, cases) =>
- // if type of match is Unit don't introduce assignment,
- // but add Unit value to bring it into form expected by async transform
- if (expr.tpe =:= definitions.UnitTpe) {
- stats :+ expr :+ Literal(Constant(()))
- }
- else {
- val varDef = defineVar(name.matchRes, expr.tpe, tree.pos)
- val casesWithAssign = cases map {
- case cd@CaseDef(pat, guard, Block(caseStats, caseExpr)) =>
- attachCopy(cd)(CaseDef(pat, guard, Block(caseStats, Assign(Ident(varDef.name), caseExpr))))
- case cd@CaseDef(pat, guard, body) =>
- attachCopy(cd)(CaseDef(pat, guard, Assign(Ident(varDef.name), body)))
+ else {
+ val varDef = defineVar(name.matchRes, expr.tpe, tree.pos)
+ def typedAssign(lhs: Tree) =
+ localTyper.typedPos(lhs.pos)(Assign(Ident(varDef.symbol), lhs))
+ val casesWithAssign = cases map {
+ case cd@CaseDef(pat, guard, body) =>
+ val newBody = body match {
+ case b@Block(caseStats, caseExpr) => treeCopy.Block(b, caseStats, typedAssign(caseExpr))
+ case _ => typedAssign(body)
+ }
+ treeCopy.CaseDef(cd, pat, guard, newBody)
+ }
+ val matchWithAssign = treeCopy.Match(tree, scrut, casesWithAssign)
+ require(matchWithAssign.tpe != null, matchWithAssign)
+ stats :+ varDef :+ matchWithAssign :+ gen.mkAttributedStableRef(varDef.symbol)
}
- val matchWithAssign = attachCopy(tree)(Match(scrut, casesWithAssign))
- stats :+ varDef :+ matchWithAssign :+ Ident(varDef.name)
- }
- case _ =>
- stats :+ expr
+ case _ =>
+ stats :+ expr
+ }
}
- }
- def transformToList(trees: List[Tree]): List[Tree] = trees flatMap transformToList
+ private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = {
+ val sym = currOwner.newTermSymbol(name.fresh(prefix), pos, MUTABLE | SYNTHETIC).setInfo(tp)
+ ValDef(sym, gen.mkZero(tp)).setType(NoType).setPos(pos)
+ }
+ }
- def transformToBlock(tree: Tree): Block = transformToList(tree) match {
- case stats :+ expr => Block(stats, expr)
+ private object trace {
+ private var indent = -1
+
+ def indentString = " " * indent
+
+ def apply[T](args: Any)(t: => T): T = {
+ def prefix = mode.toString.toLowerCase
+ indent += 1
+ def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127)
+ try {
+ AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})")
+ val result = t
+ AsyncUtils.trace(s"${indentString}= ${oneLine(result)}")
+ result
+ } finally {
+ indent -= 1
+ }
+ }
}
- private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = {
- val vd = ValDef(Modifiers(Flag.MUTABLE), name.fresh(prefix), TypeTree(tp), defaultValue(tp))
- vd.setPos(pos)
- vd
+ private def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = {
+ val sym = currOwner.newTermSymbol(name.fresh(prefix), pos, SYNTHETIC).setInfo(lhs.tpe)
+ changeOwner(lhs, currentOwner, sym)
+ ValDef(sym, changeOwner(lhs, currentOwner, sym)).setType(NoType).setPos(pos)
}
- }
- private def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = {
- val vd = ValDef(NoMods, name.fresh(prefix), TypeTree(), lhs)
- vd.setPos(pos)
- vd
- }
+ private object anf {
+ def transformToList(tree: Tree): List[Tree] = {
+ mode = Anf; blockToList(transform(tree))
+ }
- private object anf {
-
- private[AnfTransform] def transformToList(tree: Tree): List[Tree] = trace("anf", tree) {
- val containsAwait = tree exists isAwait
- if (!containsAwait) {
- List(tree)
- } else tree match {
- case Select(qual, sel) =>
- val stats :+ expr = inline.transformToList(qual)
- stats :+ attachCopy(tree)(Select(expr, sel).setSymbol(tree.symbol))
-
- case Applied(fun, targs, argss) if argss.nonEmpty =>
- // we an assume that no await call appears in a by-name argument position,
- // this has already been checked.
- val funStats :+ simpleFun = inline.transformToList(fun)
- def isAwaitRef(name: Name) = name.toString.startsWith(utils.name.await + "$")
- val (argStatss, argExprss): (List[List[List[Tree]]], List[List[Tree]]) =
- mapArgumentss[List[Tree]](fun, argss) {
- case Arg(expr, byName, _) if byName || isSafeToInline(expr) => (Nil, expr)
- case Arg(expr@Ident(name), _, _) if isAwaitRef(name) => (Nil, expr) // not typed, so it eludes the check in `isSafeToInline`
- case Arg(expr, _, argName) =>
- inline.transformToList(expr) match {
- case stats :+ expr1 =>
- val valDef = defineVal(argName, expr1, expr.pos)
- (stats :+ valDef, Ident(valDef.name))
- }
- }
- val core = if (targs.isEmpty) simpleFun else TypeApply(simpleFun, targs)
- val newApply = argExprss.foldLeft(core)(Apply(_, _)).setSymbol(tree.symbol)
- funStats ++ argStatss.flatten.flatten :+ attachCopy(tree)(newApply)
- case Block(stats, expr) =>
- inline.transformToList(stats :+ expr)
-
- case ValDef(mods, name, tpt, rhs) =>
- if (rhs exists isAwait) {
- val stats :+ expr = inline.transformToList(rhs)
- stats :+ attachCopy(tree)(ValDef(mods, name, tpt, expr).setSymbol(tree.symbol))
- } else List(tree)
-
- case Assign(lhs, rhs) =>
- val stats :+ expr = inline.transformToList(rhs)
- stats :+ attachCopy(tree)(Assign(lhs, expr))
-
- case If(cond, thenp, elsep) =>
- val condStats :+ condExpr = inline.transformToList(cond)
- val thenBlock = inline.transformToBlock(thenp)
- val elseBlock = inline.transformToBlock(elsep)
- // Typechecking with `condExpr` as the condition fails if the condition
- // contains an await. `ifTree.setType(tree.tpe)` also fails; it seems
- // we rely on this call to `typeCheck` descending into the branches.
- // But, we can get away with typechecking a throwaway `If` tree with the
- // original scrutinee and the new branches, and setting that type on
- // the real `If` tree.
- val ifType = c.typeCheck(If(cond, thenBlock, elseBlock)).tpe
- condStats :+
- attachCopy(tree)(If(condExpr, thenBlock, elseBlock)).setType(ifType)
-
- case Match(scrut, cases) =>
- val scrutStats :+ scrutExpr = inline.transformToList(scrut)
- val caseDefs = cases map {
- case CaseDef(pat, guard, body) =>
- // extract local variables for all names bound in `pat`, and rewrite `body`
- // to refer to these.
- // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`.
- val block = inline.transformToBlock(body)
- val (valDefs, mappings) = (pat collect {
- case b@Bind(name, _) =>
- val newName = newTermName(utils.name.fresh(name.toTermName + utils.name.bindSuffix))
- val vd = ValDef(NoMods, newName, TypeTree(), Ident(b.symbol))
- (vd, (b.symbol, newName))
- }).unzip
- val Block(stats1, expr1) = utils.substituteNames(block, mappings.toMap).asInstanceOf[Block]
- attachCopy(tree)(CaseDef(pat, guard, Block(valDefs ++ stats1, expr1)))
- }
- // Refer to comments the translation of `If` above.
- val matchType = c.typeCheck(Match(scrut, caseDefs)).tpe
- val typedMatch = attachCopy(tree)(Match(scrutExpr, caseDefs)).setType(tree.tpe)
- scrutStats :+ typedMatch
-
- case LabelDef(name, params, rhs) =>
- List(LabelDef(name, params, Block(inline.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol))
-
- case TypeApply(fun, targs) =>
- val funStats :+ simpleFun = inline.transformToList(fun)
- funStats :+ attachCopy(tree)(TypeApply(simpleFun, targs).setSymbol(tree.symbol))
-
- case _ =>
+ def _transformToList(tree: Tree): List[Tree] = trace(tree) {
+ val containsAwait = tree exists isAwait
+ if (!containsAwait) {
List(tree)
+ } else tree match {
+ case Select(qual, sel) =>
+ val stats :+ expr = linearize.transformToList(qual)
+ stats :+ treeCopy.Select(tree, expr, sel)
+
+ case treeInfo.Applied(fun, targs, argss) if argss.nonEmpty =>
+ // we an assume that no await call appears in a by-name argument position,
+ // this has already been checked.
+ val funStats :+ simpleFun = linearize.transformToList(fun)
+ def isAwaitRef(name: Name) = name.toString.startsWith(AnfTransform.this.name.await + "$")
+ val (argStatss, argExprss): (List[List[List[Tree]]], List[List[Tree]]) =
+ mapArgumentss[List[Tree]](fun, argss) {
+ case Arg(expr, byName, _) if byName /*|| isPure(expr) TODO */ => (Nil, expr)
+ case Arg(expr@Ident(name), _, _) if isAwaitRef(name) => (Nil, expr) // TODO needed? // not typed, so it eludes the check in `isSafeToInline`
+ case Arg(expr, _, argName) =>
+ linearize.transformToList(expr) match {
+ case stats :+ expr1 =>
+ val valDef = defineVal(argName, expr1, expr1.pos)
+ require(valDef.tpe != null, valDef)
+ val stats1 = stats :+ valDef
+ //stats1.foreach(changeOwner(_, currentOwner, currentOwner.owner))
+ (stats1, gen.stabilize(gen.mkAttributedIdent(valDef.symbol)))
+ }
+ }
+ val applied = treeInfo.dissectApplied(tree)
+ val core = if (targs.isEmpty) simpleFun else treeCopy.TypeApply(applied.callee, simpleFun, targs)
+ val newApply = argExprss.foldLeft(core)(Apply(_, _)).setSymbol(tree.symbol)
+ val typedNewApply = localTyper.typedPos(tree.pos)(newApply).setType(tree.tpe)
+ funStats ++ argStatss.flatten.flatten :+ typedNewApply
+ case Block(stats, expr) =>
+ (stats :+ expr).flatMap(linearize.transformToList)
+
+ case ValDef(mods, name, tpt, rhs) =>
+ if (rhs exists isAwait) {
+ val stats :+ expr = atOwner(currOwner.owner)(linearize.transformToList(rhs))
+ stats.foreach(changeOwner(_, currOwner, currOwner.owner))
+ stats :+ treeCopy.ValDef(tree, mods, name, tpt, expr)
+ } else List(tree)
+
+ case Assign(lhs, rhs) =>
+ val stats :+ expr = linearize.transformToList(rhs)
+ stats :+ treeCopy.Assign(tree, lhs, expr)
+
+ case If(cond, thenp, elsep) =>
+ val condStats :+ condExpr = linearize.transformToList(cond)
+ val thenBlock = linearize.transformToBlock(thenp)
+ val elseBlock = linearize.transformToBlock(elsep)
+ // Typechecking with `condExpr` as the condition fails if the condition
+ // contains an await. `ifTree.setType(tree.tpe)` also fails; it seems
+ // we rely on this call to `typeCheck` descending into the branches.
+ // But, we can get away with typechecking a throwaway `If` tree with the
+ // original scrutinee and the new branches, and setting that type on
+ // the real `If` tree.
+ val iff = treeCopy.If(tree, condExpr, thenBlock, elseBlock)
+ condStats :+ iff
+
+ case Match(scrut, cases) =>
+ val scrutStats :+ scrutExpr = linearize.transformToList(scrut)
+ val caseDefs = cases map {
+ case CaseDef(pat, guard, body) =>
+ // extract local variables for all names bound in `pat`, and rewrite `body`
+ // to refer to these.
+ // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`.
+ val block = linearize.transformToBlock(body)
+ val (valDefs, mappings) = (pat collect {
+ case b@Bind(name, _) =>
+ val vd = defineVal(name.toTermName + AnfTransform.this.name.bindSuffix, gen.mkAttributedStableRef(b.symbol), b.pos)
+ (vd, (b.symbol, vd.symbol))
+ }).unzip
+ val (from, to) = mappings.unzip
+ val b@Block(stats1, expr1) = block.substituteSymbols(from, to).asInstanceOf[Block]
+ val newBlock = treeCopy.Block(b, valDefs ++ stats1, expr1)
+ treeCopy.CaseDef(tree, pat, guard, newBlock)
+ }
+ // Refer to comments the translation of `If` above.
+ val typedMatch = treeCopy.Match(tree, scrutExpr, caseDefs)
+ scrutStats :+ typedMatch
+
+ case LabelDef(name, params, rhs) =>
+ List(LabelDef(name, params, Block(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol))
+
+ case TypeApply(fun, targs) =>
+ val funStats :+ simpleFun = linearize.transformToList(fun)
+ funStats :+ treeCopy.TypeApply(tree, simpleFun, targs)
+
+ case _ =>
+ List(tree)
+ }
}
}
}
diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala
index 35d3687..5f577cf 100644
--- a/src/main/scala/scala/async/Async.scala
+++ b/src/main/scala/scala/async/Async.scala
@@ -7,6 +7,9 @@ package scala.async
import scala.language.experimental.macros
import scala.reflect.macros.Context
import scala.reflect.internal.annotations.compileTimeOnly
+import scala.tools.nsc.Global
+import language.reflectiveCalls
+import scala.concurrent.ExecutionContext
object Async extends AsyncBase {
@@ -15,18 +18,22 @@ object Async extends AsyncBase {
lazy val futureSystem = ScalaConcurrentFutureSystem
type FS = ScalaConcurrentFutureSystem.type
- def async[T](body: T) = macro asyncImpl[T]
+ def async[T](body: T)(implicit execContext: ExecutionContext): Future[T] = macro asyncImpl[T]
- override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body)
+ override def asyncImpl[T: c.WeakTypeTag](c: Context)
+ (body: c.Expr[T])
+ (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[Future[T]] = {
+ super.asyncImpl[T](c)(body)(execContext)
+ }
}
object AsyncId extends AsyncBase {
lazy val futureSystem = IdentityFutureSystem
type FS = IdentityFutureSystem.type
- def async[T](body: T) = macro asyncImpl[T]
+ def async[T](body: T) = macro asyncIdImpl[T]
- override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = super.asyncImpl[T](c)(body)
+ def asyncIdImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = asyncImpl[T](c)(body)(c.literalUnit)
}
/**
@@ -62,124 +69,26 @@ abstract class AsyncBase {
protected[async] def fallbackEnabled = false
- def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = {
+ def asyncImpl[T: c.WeakTypeTag](c: Context)
+ (body: c.Expr[T])
+ (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = {
import c.universe._
- val analyzer = AsyncAnalysis[c.type](c, this)
- val utils = TransformUtils[c.type](c)
- import utils.{name, defn}
-
- analyzer.reportUnsupportedAwaits(body.tree)
-
- // Transform to A-normal form:
- // - no await calls in qualifiers or arguments,
- // - if/match only used in statement position.
- val anfTree: Block = {
- val anf = AnfTransform[c.type](c)
- val restored = utils.restorePatternMatchingFunctions(body.tree)
- val stats1 :+ expr1 = anf(restored)
- val block = Block(stats1, expr1)
- c.typeCheck(block).asInstanceOf[Block]
- }
-
- // Analyze the block to find locals that will be accessed from multiple
- // states of our generated state machine, e.g. a value assigned before
- // an `await` and read afterwards.
- val renameMap: Map[Symbol, TermName] = {
- analyzer.defTreesUsedInSubsequentStates(anfTree).map {
- vd =>
- (vd.symbol, name.fresh(vd.name.toTermName))
- }.toMap
- }
-
- val builder = ExprBuilder[c.type, futureSystem.type](c, self.futureSystem, anfTree)
- import builder.futureSystemOps
- val asyncBlock: builder.AsyncBlock = builder.build(anfTree, renameMap)
- import asyncBlock.asyncStates
- logDiagnostics(c)(anfTree, asyncStates.map(_.toString))
-
- // Important to retain the original declaration order here!
- val localVarTrees = anfTree.collect {
- case vd@ValDef(_, _, tpt, _) if renameMap contains vd.symbol =>
- utils.mkVarDefTree(tpt.tpe, renameMap(vd.symbol))
- case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) if renameMap contains dd.symbol =>
- DefDef(mods, renameMap(dd.symbol), tparams, vparamss, tpt, c.resetAllAttrs(utils.substituteNames(rhs, renameMap)))
- }
-
- val onCompleteHandler = {
- Function(
- List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)),
- asyncBlock.onCompleteHandler)
- }
- val resumeFunTree = asyncBlock.resumeFunTree[T]
-
- val stateMachineType = utils.applied("scala.async.StateMachine", List(futureSystemOps.promType[T], futureSystemOps.execContextType))
-
- lazy val stateMachine: ClassDef = {
- val body: List[Tree] = {
- val stateVar = ValDef(Modifiers(Flag.MUTABLE), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0)))
- val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree)
- val execContext = ValDef(NoMods, name.execContext, TypeTree(), futureSystemOps.execContext.tree)
- val applyDefDef: DefDef = {
- val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)))
- val applyBody = asyncBlock.onCompleteHandler
- DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), applyBody)
- }
- val apply0DefDef: DefDef = {
- // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`.
- // See SI-1247 for the the optimization that avoids creatio
- val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)))
- val applyBody = asyncBlock.onCompleteHandler
- DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil))
- }
- List(utils.emptyConstructor, stateVar, result, execContext) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef)
- }
- val template = {
- Template(List(stateMachineType), emptyValDef, body)
- }
- ClassDef(NoMods, name.stateMachineT, Nil, template)
- }
-
- def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection)
-
- val code: c.Expr[futureSystem.Fut[T]] = {
- val isSimple = asyncStates.size == 1
- val tree =
- if (isSimple)
- Block(Nil, futureSystemOps.spawn(body.tree)) // generate lean code for the simple case of `async { 1 + 1 }`
- else {
- Block(List[Tree](
- stateMachine,
- ValDef(NoMods, name.stateMachine, stateMachineType, Apply(Select(New(Ident(name.stateMachineT)), nme.CONSTRUCTOR), Nil)),
- futureSystemOps.spawn(Apply(selectStateMachine(name.apply), Nil))
- ),
- futureSystemOps.promiseToFuture(c.Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree)
- }
- c.Expr[futureSystem.Fut[T]](tree)
- }
-
- AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code.tree}")
- code
- }
+ val asyncMacro = AsyncMacro(c, futureSystem)
+
+ val code = asyncMacro.asyncTransform[T](
+ body.tree.asInstanceOf[asyncMacro.global.Tree],
+ execContext.tree.asInstanceOf[asyncMacro.global.Tree],
+ fallbackEnabled)(implicitly[c.WeakTypeTag[T]].asInstanceOf[asyncMacro.global.WeakTypeTag[T]])
- def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) {
- def location = try {
- c.macroApplication.pos.source.path
- } catch {
- case _: UnsupportedOperationException =>
- c.macroApplication.pos.toString
- }
-
- AsyncUtils.vprintln(s"In file '$location':")
- AsyncUtils.vprintln(s"${c.macroApplication}")
- AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree")
- states foreach (s => AsyncUtils.vprintln(s))
+ AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code}")
+ c.Expr[futureSystem.Fut[T]](code.asInstanceOf[Tree])
}
}
/** Internal class used by the `async` macro; should not be manually extended by client code */
abstract class StateMachine[Result, EC] extends (scala.util.Try[Any] => Unit) with (() => Unit) {
- def result$async: Result
+ def result: Result
- def execContext$async: EC
+ def execContext: EC
}
diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala
index 4f55f1b..424318e 100644
--- a/src/main/scala/scala/async/AsyncAnalysis.scala
+++ b/src/main/scala/scala/async/AsyncAnalysis.scala
@@ -7,12 +7,10 @@ package scala.async
import scala.reflect.macros.Context
import scala.collection.mutable
-private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: AsyncBase) {
- import c.universe._
+trait AsyncAnalysis {
+ self: AsyncMacro =>
- val utils = TransformUtils[c.type](c)
-
- import utils._
+ import global._
/**
* Analyze the contents of an `async` block in order to:
@@ -20,47 +18,26 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: Asy
*
* Must be called on the original tree, not on the ANF transformed tree.
*/
- def reportUnsupportedAwaits(tree: Tree): Boolean = {
- val analyzer = new UnsupportedAwaitAnalyzer
+ def reportUnsupportedAwaits(tree: Tree, report: Boolean): Boolean = {
+ val analyzer = new UnsupportedAwaitAnalyzer(report)
analyzer.traverse(tree)
analyzer.hasUnsupportedAwaits
}
- /**
- * Analyze the contents of an `async` block in order to:
- * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based
- * on whether or not they are accessed only from a single state.
- *
- * Must be called on the ANF transformed tree.
- */
- def defTreesUsedInSubsequentStates(tree: Tree): List[DefTree] = {
- val analyzer = new AsyncDefinitionUseAnalyzer
- analyzer.traverse(tree)
- val liftable: List[DefTree] = (analyzer.valDefsToLift ++ analyzer.nestedMethodsToLift).toList.distinct
- liftable
- }
-
- private class UnsupportedAwaitAnalyzer extends AsyncTraverser {
+ private class UnsupportedAwaitAnalyzer(report: Boolean) extends AsyncTraverser {
var hasUnsupportedAwaits = false
override def nestedClass(classDef: ClassDef) {
- val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class"
- if (!reportUnsupportedAwait(classDef, s"nested $kind")) {
- // do not allow local class definitions, because of SI-5467 (specific to case classes, though)
- if (classDef.symbol.asClass.isCaseClass)
- c.error(classDef.pos, s"Local case class ${classDef.name.decoded} illegal within `async` block")
- }
+ val kind = if (classDef.symbol.isTrait) "trait" else "class"
+ reportUnsupportedAwait(classDef, s"nested ${kind}")
}
override def nestedModule(module: ModuleDef) {
- if (!reportUnsupportedAwait(module, "nested object")) {
- // local object definitions lead to spurious type errors (because of resetAllAttrs?)
- c.error(module.pos, s"Local object ${module.name.decoded} illegal within `async` block")
- }
+ reportUnsupportedAwait(module, "nested object")
}
- override def nestedMethod(module: DefDef) {
- reportUnsupportedAwait(module, "nested method")
+ override def nestedMethod(defDef: DefDef) {
+ reportUnsupportedAwait(defDef, "nested method")
}
override def byNameArgument(arg: Tree) {
@@ -82,9 +59,10 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: Asy
reportUnsupportedAwait(tree, "try/catch")
super.traverse(tree)
case Return(_) =>
- c.abort(tree.pos, "return is illegal within a async block")
+ abort(tree.pos, "return is illegal within a async block")
case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) =>
- c.abort(tree.pos, "lazy vals are illegal within an async block")
+ // TODO lift this restriction
+ abort(tree.pos, "lazy vals are illegal within an async block")
case _ =>
super.traverse(tree)
}
@@ -106,87 +84,8 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: Asy
private def reportError(pos: Position, msg: String) {
hasUnsupportedAwaits = true
- if (!asyncBase.fallbackEnabled)
- c.error(pos, msg)
+ if (report)
+ abort(pos, msg)
}
}
-
- private class AsyncDefinitionUseAnalyzer extends AsyncTraverser {
- private var chunkId = 0
-
- private def nextChunk() = chunkId += 1
-
- private var valDefChunkId = Map[Symbol, (ValDef, Int)]()
-
- val valDefsToLift : mutable.Set[ValDef] = collection.mutable.Set()
- val nestedMethodsToLift: mutable.Set[DefDef] = collection.mutable.Set()
-
- override def nestedMethod(defDef: DefDef) {
- nestedMethodsToLift += defDef
- markReferencedVals(defDef)
- }
-
- override def function(function: Function) {
- markReferencedVals(function)
- }
-
- override def patMatFunction(tree: Match) {
- markReferencedVals(tree)
- }
-
- private def markReferencedVals(tree: Tree) {
- tree foreach {
- case rt: RefTree =>
- valDefChunkId.get(rt.symbol) match {
- case Some((vd, defChunkId)) =>
- valDefsToLift += vd // lift all vals referred to by nested functions.
- case _ =>
- }
- case _ =>
- }
- }
-
- override def traverse(tree: Tree) = {
- tree match {
- case If(cond, thenp, elsep) if tree exists isAwait =>
- traverseChunks(List(cond, thenp, elsep))
- case Match(selector, cases) if tree exists isAwait =>
- traverseChunks(selector :: cases)
- case LabelDef(name, params, rhs) if rhs exists isAwait =>
- traverseChunks(rhs :: Nil)
- case Apply(fun, args) if isAwait(fun) =>
- super.traverse(tree)
- nextChunk()
- case vd: ValDef =>
- super.traverse(tree)
- valDefChunkId += (vd.symbol -> (vd -> chunkId))
- val isPatternBinder = vd.name.toString.contains(name.bindSuffix)
- if (isAwait(vd.rhs) || isPatternBinder) valDefsToLift += vd
- case as: Assign =>
- if (isAwait(as.rhs)) {
- assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol)
-
- // TODO test the orElse case, try to remove the restriction.
- val (vd, defBlockId) = valDefChunkId.getOrElse(as.lhs.symbol, c.abort(as.pos, s"await may only be assigned to a var/val defined in the async block. ${as.lhs} ${as.lhs.symbol}"))
- valDefsToLift += vd
- }
- super.traverse(tree)
- case rt: RefTree =>
- valDefChunkId.get(rt.symbol) match {
- case Some((vd, defChunkId)) if defChunkId != chunkId =>
- valDefsToLift += vd
- case _ =>
- }
- super.traverse(tree)
- case _ => super.traverse(tree)
- }
- }
-
- private def traverseChunks(trees: List[Tree]) {
- trees.foreach {
- t => traverse(t); nextChunk()
- }
- }
- }
-
}
diff --git a/src/main/scala/scala/async/AsyncMacro.scala b/src/main/scala/scala/async/AsyncMacro.scala
new file mode 100644
index 0000000..8827351
--- /dev/null
+++ b/src/main/scala/scala/async/AsyncMacro.scala
@@ -0,0 +1,29 @@
+package scala.async
+
+import scala.tools.nsc.Global
+import scala.tools.nsc.transform.TypingTransformers
+
+object AsyncMacro {
+ def apply(c: reflect.macros.Context, futureSystem0: FutureSystem): AsyncMacro = {
+ import language.reflectiveCalls
+ val powerContext = c.asInstanceOf[c.type {val universe: Global; val callsiteTyper: universe.analyzer.Typer}]
+ new AsyncMacro {
+ val global: powerContext.universe.type = powerContext.universe
+ val callSiteTyper: global.analyzer.Typer = powerContext.callsiteTyper
+ val futureSystem: futureSystem0.type = futureSystem0
+ val futureSystemOps: futureSystem.Ops {val universe: global.type} = futureSystem0.mkOps(global)
+ val macroApplication: global.Tree = c.macroApplication.asInstanceOf[global.Tree]
+ }
+ }
+}
+
+private[async] trait AsyncMacro
+ extends TypingTransformers
+ with AnfTransform with TransformUtils with Lifter
+ with ExprBuilder with AsyncTransform with AsyncAnalysis {
+
+ val global: Global
+ val callSiteTyper: global.analyzer.Typer
+ val macroApplication: global.Tree
+
+}
diff --git a/src/main/scala/scala/async/AsyncTransform.scala b/src/main/scala/scala/async/AsyncTransform.scala
new file mode 100644
index 0000000..129f88e
--- /dev/null
+++ b/src/main/scala/scala/async/AsyncTransform.scala
@@ -0,0 +1,176 @@
+package scala.async
+
+trait AsyncTransform {
+ self: AsyncMacro =>
+
+ import global._
+
+ def asyncTransform[T](body: Tree, execContext: Tree, cpsFallbackEnabled: Boolean)
+ (implicit resultType: WeakTypeTag[T]): Tree = {
+
+ reportUnsupportedAwaits(body, report = !cpsFallbackEnabled)
+
+ // Transform to A-normal form:
+ // - no await calls in qualifiers or arguments,
+ // - if/match only used in statement position.
+ val anfTree: Block = anfTransform(body)
+
+ val resumeFunTreeDummyBody = DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), Literal(Constant(())))
+
+ val applyDefDefDummyBody: DefDef = {
+ val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)))
+ DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), Literal(Constant(())))
+ }
+
+ val stateMachineType = applied("scala.async.StateMachine", List(futureSystemOps.promType[T], futureSystemOps.execContextType))
+
+ val stateMachine: ClassDef = {
+ val body: List[Tree] = {
+ val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0)))
+ val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree)
+ val execContextValDef = ValDef(NoMods, name.execContext, TypeTree(), execContext)
+
+ val apply0DefDef: DefDef = {
+ // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`.
+ // See SI-1247 for the the optimization that avoids creatio
+ DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil))
+ }
+ List(emptyConstructor, stateVar, result, execContextValDef) ++ List(resumeFunTreeDummyBody, applyDefDefDummyBody, apply0DefDef)
+ }
+ val template = {
+ Template(List(stateMachineType), emptyValDef, body)
+ }
+ val t = ClassDef(NoMods, name.stateMachineT, Nil, template)
+ callSiteTyper.typedPos(macroApplication.pos)(Block(t :: Nil, Literal(Constant(()))))
+ t
+ }
+
+ val asyncBlock: AsyncBlock = {
+ val symLookup = new SymLookup(stateMachine.symbol, applyDefDefDummyBody.vparamss.head.head.symbol)
+ buildAsyncBlock(anfTree, symLookup)
+ }
+
+ logDiagnostics(anfTree, asyncBlock.asyncStates.map(_.toString))
+
+ def startStateMachine: Tree = {
+ val stateMachineSpliced: Tree = spliceMethodBodies(
+ liftables(asyncBlock.asyncStates),
+ stateMachine,
+ asyncBlock.onCompleteHandler[T],
+ asyncBlock.resumeFunTree[T].rhs
+ )
+
+ def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection)
+
+ Block(List[Tree](
+ stateMachineSpliced,
+ ValDef(NoMods, name.stateMachine, stateMachineType, Apply(Select(New(Ident(stateMachine.symbol)), nme.CONSTRUCTOR), Nil)),
+ futureSystemOps.spawn(Apply(selectStateMachine(name.apply), Nil), selectStateMachine(name.execContext))
+ ),
+ futureSystemOps.promiseToFuture(Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree)
+ }
+
+ val isSimple = asyncBlock.asyncStates.size == 1
+ if (isSimple)
+ futureSystemOps.spawn(body, execContext) // generate lean code for the simple case of `async { 1 + 1 }`
+ else
+ startStateMachine
+ }
+
+ def logDiagnostics(anfTree: Tree, states: Seq[String]) {
+ val pos = macroApplication.pos
+ def location = try {
+ pos.source.path
+ } catch {
+ case _: UnsupportedOperationException =>
+ pos.toString
+ }
+
+ AsyncUtils.vprintln(s"In file '$location':")
+ AsyncUtils.vprintln(s"${macroApplication}")
+ AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree")
+ states foreach (s => AsyncUtils.vprintln(s))
+ }
+
+ def spliceMethodBodies(liftables: List[Tree], tree: Tree, applyBody: Tree,
+ resumeBody: Tree): Tree = {
+
+ val liftedSyms = liftables.map(_.symbol).toSet
+ val stateMachineClass = tree.symbol
+ liftedSyms.foreach {
+ sym =>
+ if (sym != null) {
+ sym.owner = stateMachineClass
+ if (sym.isModule)
+ sym.moduleClass.owner = stateMachineClass
+ }
+ }
+ // Replace the ValDefs in the splicee with Assigns to the corresponding lifted
+ // fields. Similarly, replace references to them with references to the field.
+ //
+ // This transform will be only be run on the RHS of `def foo`.
+ class UseFields extends MacroTypingTransformer {
+ override def transform(tree: Tree): Tree = tree match {
+ case _ if currentOwner == stateMachineClass =>
+ super.transform(tree)
+ case ValDef(_, _, _, rhs) if liftedSyms(tree.symbol) =>
+ atOwner(currentOwner) {
+ val fieldSym = tree.symbol
+ val set = Assign(gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym), transform(rhs))
+ changeOwner(set, tree.symbol, currentOwner)
+ localTyper.typedPos(tree.pos)(set)
+ }
+ case _: DefTree if liftedSyms(tree.symbol) =>
+ EmptyTree
+ case Ident(name) if liftedSyms(tree.symbol) =>
+ val fieldSym = tree.symbol
+ gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym).setType(tree.tpe)
+ case _ =>
+ super.transform(tree)
+ }
+ }
+
+ val liftablesUseFields = liftables.map {
+ case vd: ValDef => vd
+ case x =>
+ val useField = new UseFields()
+ //.substituteSymbols(fromSyms, toSyms)
+ useField.atOwner(stateMachineClass)(useField.transform(x))
+ }
+
+ tree.children.foreach {
+ t =>
+ new ChangeOwnerAndModuleClassTraverser(callSiteTyper.context.owner, tree.symbol).traverse(t)
+ }
+ val treeSubst = tree
+
+ def fixup(dd: DefDef, body: Tree, ctx: analyzer.Context): Tree = {
+ val spliceeAnfFixedOwnerSyms = body
+ val useField = new UseFields()
+ val newRhs = useField.atOwner(dd.symbol)(useField.transform(spliceeAnfFixedOwnerSyms))
+ val typer = global.analyzer.newTyper(ctx.make(dd, dd.symbol))
+ treeCopy.DefDef(dd, dd.mods, dd.name, dd.tparams, dd.vparamss, dd.tpt, typer.typed(newRhs))
+ }
+
+ liftablesUseFields.foreach(t => if (t.symbol != null) stateMachineClass.info.decls.enter(t.symbol))
+
+ val result0 = transformAt(treeSubst) {
+ case t@Template(parents, self, stats) =>
+ (ctx: analyzer.Context) => {
+ treeCopy.Template(t, parents, self, liftablesUseFields ++ stats)
+ }
+ }
+ val result = transformAt(result0) {
+ case dd@DefDef(_, name.apply, _, List(List(_)), _, _) if dd.symbol.owner == stateMachineClass =>
+ (ctx: analyzer.Context) =>
+ val typedTree = fixup(dd, changeOwner(applyBody, callSiteTyper.context.owner, dd.symbol), ctx)
+ typedTree
+ case dd@DefDef(_, name.resume, _, _, _, _) if dd.symbol.owner == stateMachineClass =>
+ (ctx: analyzer.Context) =>
+ val changed = changeOwner(resumeBody, callSiteTyper.context.owner, dd.symbol)
+ val res = fixup(dd, changed, ctx)
+ res
+ }
+ result
+ }
+}
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala
index ca46a83..a3837d3 100644
--- a/src/main/scala/scala/async/ExprBuilder.scala
+++ b/src/main/scala/scala/async/ExprBuilder.scala
@@ -7,17 +7,17 @@ import scala.reflect.macros.Context
import scala.collection.mutable.ListBuffer
import collection.mutable
import language.existentials
+import scala.reflect.api.Universe
+import scala.reflect.api
-private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: C, futureSystem: FS, origTree: C#Tree) {
- builder =>
+trait ExprBuilder {
+ builder: AsyncMacro =>
- val utils = TransformUtils[c.type](c)
-
- import c.universe._
- import utils._
+ import global._
import defn._
- lazy val futureSystemOps = futureSystem.mkOps(c)
+ val futureSystem: FutureSystem
+ val futureSystemOps: futureSystem.Ops { val universe: global.type }
val stateAssigner = new StateAssigner
val labelDefStates = collection.mutable.Map[Symbol, Int]()
@@ -27,22 +27,27 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
def mkHandlerCaseForState: CaseDef
- def mkOnCompleteHandler[T: c.WeakTypeTag]: Option[CaseDef] = None
+ def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = None
def stats: List[Tree]
- final def body: c.Tree = stats match {
+ final def allStats: List[Tree] = this match {
+ case a: AsyncStateWithAwait => stats :+ a.awaitable.resultValDef
+ case _ => stats
+ }
+
+ final def body: Tree = stats match {
case stat :: Nil => stat
case init :+ last => Block(init, last)
}
}
/** A sequence of statements the concludes with a unconditional transition to `nextState` */
- final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int)
+ final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int, symLookup: SymLookup)
extends AsyncState {
def mkHandlerCaseForState: CaseDef =
- mkHandlerCase(state, stats :+ mkStateTree(nextState) :+ mkResumeApply)
+ mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup) :+ mkResumeApply(symLookup))
override val toString: String =
s"AsyncState #$state, next = $nextState"
@@ -51,7 +56,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
/** A sequence of statements with a conditional transition to the next state, which will represent
* a branch of an `if` or a `match`.
*/
- final class AsyncStateWithoutAwait(val stats: List[c.Tree], val state: Int) extends AsyncState {
+ final class AsyncStateWithoutAwait(val stats: List[Tree], val state: Int) extends AsyncState {
override def mkHandlerCaseForState: CaseDef =
mkHandlerCase(state, stats)
@@ -62,25 +67,25 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
/** A sequence of statements that concludes with an `await` call. The `onComplete`
* handler will unconditionally transition to `nestState`.``
*/
- final class AsyncStateWithAwait(val stats: List[c.Tree], val state: Int, nextState: Int,
- awaitable: Awaitable)
+ final class AsyncStateWithAwait(val stats: List[Tree], val state: Int, nextState: Int,
+ val awaitable: Awaitable, symLookup: SymLookup)
extends AsyncState {
override def mkHandlerCaseForState: CaseDef = {
- val callOnComplete = futureSystemOps.onComplete(c.Expr(awaitable.expr),
- c.Expr(This(tpnme.EMPTY)), c.Expr(Ident(name.execContext))).tree
+ val callOnComplete = futureSystemOps.onComplete(Expr(awaitable.expr),
+ Expr(This(tpnme.EMPTY)), Expr(Ident(name.execContext))).tree
mkHandlerCase(state, stats :+ callOnComplete)
}
- override def mkOnCompleteHandler[T: c.WeakTypeTag]: Option[CaseDef] = {
+ override def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = {
val tryGetTree =
Assign(
Ident(awaitable.resultName),
- TypeApply(Select(Select(Ident(name.tr), Try_get), newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType)))
+ TypeApply(Select(Select(Ident(symLookup.applyTrParam), Try_get), newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType)))
)
/* if (tr.isFailure)
- * result$async.complete(tr.asInstanceOf[Try[T]])
+ * result.complete(tr.asInstanceOf[Try[T]])
* else {
* <resultName> = tr.get.asInstanceOf[<resultType>]
* <nextState>
@@ -88,13 +93,13 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
* }
*/
val ifIsFailureTree =
- If(Select(Ident(name.tr), Try_isFailure),
+ If(Select(Ident(symLookup.applyTrParam), Try_isFailure),
futureSystemOps.completeProm[T](
- c.Expr[futureSystem.Prom[T]](Ident(name.result)),
- c.Expr[scala.util.Try[T]](
- TypeApply(Select(Ident(name.tr), newTermName("asInstanceOf")),
+ Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)),
+ Expr[scala.util.Try[T]](
+ TypeApply(Select(Ident(symLookup.applyTrParam), newTermName("asInstanceOf")),
List(TypeTree(weakTypeOf[scala.util.Try[T]]))))).tree,
- Block(List(tryGetTree, mkStateTree(nextState)), mkResumeApply)
+ Block(List(tryGetTree, mkStateTree(nextState, symLookup)), mkResumeApply(symLookup))
)
Some(mkHandlerCase(state, List(ifIsFailureTree)))
@@ -107,19 +112,16 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
/*
* Builder for a single state of an async method.
*/
- final class AsyncStateBuilder(state: Int, private val nameMap: Map[Symbol, c.Name]) {
+ final class AsyncStateBuilder(state: Int, private val symLookup: SymLookup) {
/* Statements preceding an await call. */
- private val stats = ListBuffer[c.Tree]()
+ private val stats = ListBuffer[Tree]()
/** The state of the target of a LabelDef application (while loop jump) */
private var nextJumpState: Option[Int] = None
- private def renameReset(tree: Tree) = resetInternalAttrs(substituteNames(tree, nameMap))
-
- def +=(stat: c.Tree): this.type = {
+ def +=(stat: Tree): this.type = {
assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat")
- def addStat() = stats += renameReset(stat)
+ def addStat() = stats += stat
stat match {
- case _: DefDef => // these have been lifted.
case Apply(fun, Nil) =>
labelDefStates get fun.symbol match {
case Some(nextState) => nextJumpState = Some(nextState)
@@ -132,22 +134,18 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
def resultWithAwait(awaitable: Awaitable,
nextState: Int): AsyncState = {
- val sanitizedAwaitable = awaitable.copy(expr = renameReset(awaitable.expr))
val effectiveNextState = nextJumpState.getOrElse(nextState)
- new AsyncStateWithAwait(stats.toList, state, effectiveNextState, sanitizedAwaitable)
+ new AsyncStateWithAwait(stats.toList, state, effectiveNextState, awaitable, symLookup)
}
def resultSimple(nextState: Int): AsyncState = {
val effectiveNextState = nextJumpState.getOrElse(nextState)
- new SimpleAsyncState(stats.toList, state, effectiveNextState)
+ new SimpleAsyncState(stats.toList, state, effectiveNextState, symLookup)
}
- def resultWithIf(condTree: c.Tree, thenState: Int, elseState: Int): AsyncState = {
- // 1. build changed if-else tree
- // 2. insert that tree at the end of the current state
- val cond = renameReset(condTree)
- def mkBranch(state: Int) = Block(mkStateTree(state) :: Nil, mkResumeApply)
- this += If(cond, mkBranch(thenState), mkBranch(elseState))
+ def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = {
+ def mkBranch(state: Int) = Block(mkStateTree(state, symLookup) :: Nil, mkResumeApply(symLookup))
+ this += If(condTree, mkBranch(thenState), mkBranch(elseState))
new AsyncStateWithoutAwait(stats.toList, state)
}
@@ -161,23 +159,20 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
* @param caseStates starting state of the right-hand side of the each case
* @return an `AsyncState` representing the match expression
*/
- def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], caseStates: List[Int]): AsyncState = {
+ def resultWithMatch(scrutTree: Tree, cases: List[CaseDef], caseStates: List[Int], symLookup: SymLookup): AsyncState = {
// 1. build list of changed cases
val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match {
case CaseDef(pat, guard, rhs) =>
- val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal).map {
- case ValDef(_, name, _, rhs) => Assign(Ident(name), rhs)
- case t => sys.error(s"Unexpected tree. Expected ValDef, found: $t")
- }
- CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num)), mkResumeApply))
+ val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal)
+ CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num), symLookup), mkResumeApply(symLookup)))
}
// 2. insert changed match tree at the end of the current state
- this += Match(renameReset(scrutTree), newCases)
+ this += Match(scrutTree, newCases)
new AsyncStateWithoutAwait(stats.toList, state)
}
- def resultWithLabel(startLabelState: Int): AsyncState = {
- this += Block(mkStateTree(startLabelState) :: Nil, mkResumeApply)
+ def resultWithLabel(startLabelState: Int, symLookup: SymLookup): AsyncState = {
+ this += Block(mkStateTree(startLabelState, symLookup) :: Nil, mkResumeApply(symLookup))
new AsyncStateWithoutAwait(stats.toList, state)
}
@@ -194,24 +189,23 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
* @param expr the last expression of the block
* @param startState the start state
* @param endState the state to continue with
- * @param toRename a `Map` for renaming the given key symbols to the mangled value names
*/
- final private class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int,
- private val toRename: Map[Symbol, c.Name]) {
+ final private class AsyncBlockBuilder(stats: List[Tree], expr: Tree, startState: Int, endState: Int,
+ private val symLookup: SymLookup) {
val asyncStates = ListBuffer[AsyncState]()
- var stateBuilder = new AsyncStateBuilder(startState, toRename)
+ var stateBuilder = new AsyncStateBuilder(startState, symLookup)
var currState = startState
/* TODO Fall back to CPS plug-in if tree contains an `await` call. */
- def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists {
+ def checkForUnsupportedAwait(tree: Tree) = if (tree exists {
case Apply(fun, _) if isAwait(fun) => true
case _ => false
- }) c.abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException
+ }) abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException
def nestedBlockBuilder(nestedTree: Tree, startState: Int, endState: Int) = {
val (nestedStats, nestedExpr) = statsAndExpr(nestedTree)
- new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, toRename)
+ new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, symLookup)
}
import stateAssigner.nextState
@@ -219,16 +213,12 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
// populate asyncStates
for (stat <- stats) stat match {
// the val name = await(..) pattern
- case ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) =>
+ case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) =>
val afterAwaitState = nextState()
- val awaitable = Awaitable(arg, toRename(stat.symbol).toTermName, tpt.tpe)
+ val awaitable = Awaitable(arg, stat.symbol, tpt.tpe, vd)
asyncStates += stateBuilder.resultWithAwait(awaitable, afterAwaitState) // complete with await
currState = afterAwaitState
- stateBuilder = new AsyncStateBuilder(currState, toRename)
-
- case ValDef(mods, name, tpt, rhs) if toRename contains stat.symbol =>
- checkForUnsupportedAwait(rhs)
- stateBuilder += Assign(Ident(toRename(stat.symbol).toTermName), rhs)
+ stateBuilder = new AsyncStateBuilder(currState, symLookup)
case If(cond, thenp, elsep) if stat exists isAwait =>
checkForUnsupportedAwait(cond)
@@ -248,7 +238,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
}
currState = afterIfState
- stateBuilder = new AsyncStateBuilder(currState, toRename)
+ stateBuilder = new AsyncStateBuilder(currState, symLookup)
case Match(scrutinee, cases) if stat exists isAwait =>
checkForUnsupportedAwait(scrutinee)
@@ -257,7 +247,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
val afterMatchState = nextState()
asyncStates +=
- stateBuilder.resultWithMatch(scrutinee, cases, caseStates)
+ stateBuilder.resultWithMatch(scrutinee, cases, caseStates, symLookup)
for ((cas, num) <- cases.zipWithIndex) {
val (stats, expr) = statsAndExpr(cas.body)
@@ -267,18 +257,18 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
}
currState = afterMatchState
- stateBuilder = new AsyncStateBuilder(currState, toRename)
+ stateBuilder = new AsyncStateBuilder(currState, symLookup)
case ld@LabelDef(name, params, rhs) if rhs exists isAwait =>
val startLabelState = nextState()
val afterLabelState = nextState()
- asyncStates += stateBuilder.resultWithLabel(startLabelState)
+ asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup)
labelDefStates(ld.symbol) = startLabelState
val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState)
asyncStates ++= builder.asyncStates
currState = afterLabelState
- stateBuilder = new AsyncStateBuilder(currState, toRename)
+ stateBuilder = new AsyncStateBuilder(currState, symLookup)
case _ =>
checkForUnsupportedAwait(stat)
stateBuilder += stat
@@ -292,17 +282,23 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
trait AsyncBlock {
def asyncStates: List[AsyncState]
- def onCompleteHandler[T: c.WeakTypeTag]: Tree
+ def onCompleteHandler[T: WeakTypeTag]: Tree
+
+ def resumeFunTree[T]: DefDef
+ }
- def resumeFunTree[T]: Tree
+ case class SymLookup(stateMachineClass: Symbol, applyTrParam: Symbol) {
+ def stateMachineMember(name: TermName): Symbol =
+ stateMachineClass.info.member(name)
+ def memberRef(name: TermName) = gen.mkAttributedRef(stateMachineMember(name))
}
- def build(block: Block, toRename: Map[Symbol, c.Name]): AsyncBlock = {
+ def buildAsyncBlock(block: Block, symLookup: SymLookup): AsyncBlock = {
val Block(stats, expr) = block
val startState = stateAssigner.nextState()
val endState = Int.MaxValue
- val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, toRename)
+ val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, symLookup)
new AsyncBlock {
def asyncStates = blockBuilder.asyncStates.toList
@@ -310,9 +306,9 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
def mkCombinedHandlerCases[T]: List[CaseDef] = {
val caseForLastState: CaseDef = {
val lastState = asyncStates.last
- val lastStateBody = c.Expr[T](lastState.body)
+ val lastStateBody = Expr[T](lastState.body)
val rhs = futureSystemOps.completeProm(
- c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Success(lastStateBody.splice)))
+ Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), reify(scala.util.Success(lastStateBody.splice)))
mkHandlerCase(lastState.state, rhs.tree)
}
asyncStates.toList match {
@@ -327,18 +323,6 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
val initStates = asyncStates.init
/**
- * // assumes tr: Try[Any] is in scope.
- * //
- * state match {
- * case 0 => {
- * x11 = tr.get.asInstanceOf[Double];
- * state = 1;
- * resume()
- * }
- */
- def onCompleteHandler[T: c.WeakTypeTag]: Tree = Match(Ident(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList)
-
- /**
* def resume(): Unit = {
* try {
* state match {
@@ -353,18 +337,31 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
* }
* }
*/
- def resumeFunTree[T]: Tree =
+ def resumeFunTree[T]: DefDef =
DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass),
Try(
- Match(Ident(name.state), mkCombinedHandlerCases[T]),
+ Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T]),
List(
CaseDef(
- Apply(Ident(defn.NonFatalClass), List(Bind(name.tr, Ident(nme.WILDCARD)))),
- EmptyTree,
+ Bind(name.t, Ident(nme.WILDCARD)),
+ Apply(Ident(defn.NonFatalClass), List(Ident(name.t))),
Block(List({
- val t = c.Expr[Throwable](Ident(name.tr))
- futureSystemOps.completeProm[T](c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Failure(t.splice))).tree
- }), c.literalUnit.tree))), EmptyTree))
+ val t = Expr[Throwable](Ident(name.t))
+ futureSystemOps.completeProm[T](
+ Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), reify(scala.util.Failure(t.splice))).tree
+ }), literalUnit))), EmptyTree))
+
+ /**
+ * // assumes tr: Try[Any] is in scope.
+ * //
+ * state match {
+ * case 0 => {
+ * x11 = tr.get.asInstanceOf[Double];
+ * state = 1;
+ * resume()
+ * }
+ */
+ def onCompleteHandler[T: WeakTypeTag]: Tree = Match(symLookup.memberRef(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList)
}
}
@@ -373,22 +370,18 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
case _ => false
}
- private final case class Awaitable(expr: Tree, resultName: TermName, resultType: Type)
-
- private val internalSyms = origTree.collect {
- case dt: DefTree => dt.symbol
- }
+ case class Awaitable(expr: Tree, resultName: Symbol, resultType: Type, resultValDef: ValDef)
- private def resetInternalAttrs(tree: Tree) = utils.resetInternalAttrs(tree, internalSyms)
+ private def mkResumeApply(symLookup: SymLookup) = Apply(symLookup.memberRef(name.resume), Nil)
- private def mkResumeApply = Apply(Ident(name.resume), Nil)
+ private def mkStateTree(nextState: Int, symLookup: SymLookup): Tree =
+ Assign(symLookup.memberRef(name.state), Literal(Constant(nextState)))
- private def mkStateTree(nextState: Int): c.Tree =
- Assign(Ident(name.state), c.literal(nextState).tree)
+ private def mkHandlerCase(num: Int, rhs: List[Tree]): CaseDef =
+ mkHandlerCase(num, Block(rhs, literalUnit))
- private def mkHandlerCase(num: Int, rhs: List[c.Tree]): CaseDef =
- mkHandlerCase(num, Block(rhs, c.literalUnit.tree))
+ private def mkHandlerCase(num: Int, rhs: Tree): CaseDef =
+ CaseDef(Literal(Constant(num)), EmptyTree, rhs)
- private def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef =
- CaseDef(c.literal(num).tree, EmptyTree, rhs)
+ private def literalUnit = Literal(Constant(()))
}
diff --git a/src/main/scala/scala/async/FutureSystem.scala b/src/main/scala/scala/async/FutureSystem.scala
index a050bec..0c04296 100644
--- a/src/main/scala/scala/async/FutureSystem.scala
+++ b/src/main/scala/scala/async/FutureSystem.scala
@@ -6,6 +6,7 @@ package scala.async
import scala.language.higherKinds
import scala.reflect.macros.Context
+import scala.reflect.internal.SymbolTable
/**
* An abstraction over a future system.
@@ -26,12 +27,10 @@ trait FutureSystem {
type ExecContext
trait Ops {
- val context: reflect.macros.Context
+ val universe: reflect.internal.SymbolTable
- import context.universe._
-
- /** Lookup the execution context, typically with an implicit search */
- def execContext: Expr[ExecContext]
+ import universe._
+ def Expr[T: WeakTypeTag](tree: Tree): Expr[T] = universe.Expr[T](rootMirror, universe.FixedMirrorTreeCreator(rootMirror, tree))
def promType[A: WeakTypeTag]: Type
def execContextType: Type
@@ -52,15 +51,17 @@ trait FutureSystem {
/** Complete a promise with a value */
def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit]
- def spawn(tree: context.Tree): context.Tree =
- future(context.Expr[Unit](tree))(execContext).tree
+ def spawn(tree: Tree, execContext: Tree): Tree =
+ future(Expr[Unit](tree))(Expr[ExecContext](execContext)).tree
+ // TODO Why is this needed?
def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]]
}
- def mkOps(c: Context): Ops { val context: c.type }
+ def mkOps(c: SymbolTable): Ops { val universe: c.type }
}
+
object ScalaConcurrentFutureSystem extends FutureSystem {
import scala.concurrent._
@@ -69,18 +70,13 @@ object ScalaConcurrentFutureSystem extends FutureSystem {
type Fut[A] = Future[A]
type ExecContext = ExecutionContext
- def mkOps(c: Context): Ops {val context: c.type} = new Ops {
- val context: c.type = c
-
- import context.universe._
+ def mkOps(c: SymbolTable): Ops {val universe: c.type} = new Ops {
+ val universe: c.type = c
- def execContext: Expr[ExecContext] = c.Expr(c.inferImplicitValue(c.weakTypeOf[ExecutionContext]) match {
- case EmptyTree => c.abort(c.macroApplication.pos, "Unable to resolve implicit ExecutionContext")
- case context => context
- })
+ import universe._
- def promType[A: WeakTypeTag]: Type = c.weakTypeOf[Promise[A]]
- def execContextType: Type = c.weakTypeOf[ExecutionContext]
+ def promType[A: WeakTypeTag]: Type = weakTypeOf[Promise[A]]
+ def execContextType: Type = weakTypeOf[ExecutionContext]
def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify {
Promise[A]()
@@ -101,7 +97,7 @@ object ScalaConcurrentFutureSystem extends FutureSystem {
def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify {
prom.splice.complete(value.splice)
- context.literalUnit.splice
+ Expr[Unit](Literal(Constant(()))).splice
}
def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = reify {
@@ -121,15 +117,15 @@ object IdentityFutureSystem extends FutureSystem {
type Fut[A] = A
type ExecContext = Unit
- def mkOps(c: Context): Ops {val context: c.type} = new Ops {
- val context: c.type = c
+ def mkOps(c: SymbolTable): Ops {val universe: c.type} = new Ops {
+ val universe: c.type = c
- import context.universe._
+ import universe._
- def execContext: Expr[ExecContext] = c.literalUnit
+ def execContext: Expr[ExecContext] = Expr[Unit](Literal(Constant(())))
- def promType[A: WeakTypeTag]: Type = c.weakTypeOf[Prom[A]]
- def execContextType: Type = c.weakTypeOf[Unit]
+ def promType[A: WeakTypeTag]: Type = weakTypeOf[Prom[A]]
+ def execContextType: Type = weakTypeOf[Unit]
def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify {
new Prom(null.asInstanceOf[A])
@@ -144,12 +140,12 @@ object IdentityFutureSystem extends FutureSystem {
def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U],
execContext: Expr[ExecContext]): Expr[Unit] = reify {
fun.splice.apply(util.Success(future.splice))
- context.literalUnit.splice
+ Expr[Unit](Literal(Constant(()))).splice
}
def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify {
prom.splice.a = value.splice.get
- context.literalUnit.splice
+ Expr[Unit](Literal(Constant(()))).splice
}
def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = ???
diff --git a/src/main/scala/scala/async/Lifter.scala b/src/main/scala/scala/async/Lifter.scala
new file mode 100644
index 0000000..52ce47d
--- /dev/null
+++ b/src/main/scala/scala/async/Lifter.scala
@@ -0,0 +1,150 @@
+package scala.async
+
+trait Lifter {
+ self: AsyncMacro =>
+ import global._
+
+ /**
+ * Identify which DefTrees are used (including transitively) which are declared
+ * in some state but used (including transitively) in another state.
+ *
+ * These will need to be lifted to class members of the state machine.
+ */
+ def liftables(asyncStates: List[AsyncState]): List[Tree] = {
+ object companionship {
+ private val companions = collection.mutable.Map[Symbol, Symbol]()
+ private val companionsInverse = collection.mutable.Map[Symbol, Symbol]()
+ private def record(sym1: Symbol, sym2: Symbol) {
+ companions(sym1) = sym2
+ companions(sym2) = sym1
+ }
+
+ def record(defs: List[Tree]) {
+ // Keep note of local companions so we rename them consistently
+ // when lifting.
+ val comps = for {
+ cd@ClassDef(_, _, _, _) <- defs
+ md@ModuleDef(_, _, _) <- defs
+ if (cd.name.toTermName == md.name)
+ } record(cd.symbol, md.symbol)
+ }
+ def companionOf(sym: Symbol): Symbol = {
+ companions.get(sym).orElse(companionsInverse.get(sym)).getOrElse(NoSymbol)
+ }
+ }
+
+
+ val defs: Map[Tree, Int] = {
+ /** Collect the DefTrees directly enclosed within `t` that have the same owner */
+ def collectDirectlyEnclosedDefs(t: Tree): List[DefTree] = t match {
+ case dt: DefTree => dt :: Nil
+ case _: Function => Nil
+ case t =>
+ val childDefs = t.children.flatMap(collectDirectlyEnclosedDefs(_))
+ companionship.record(childDefs)
+ childDefs
+ }
+ asyncStates.flatMap {
+ asyncState =>
+ val defs = collectDirectlyEnclosedDefs(Block(asyncState.allStats: _*))
+ defs.map((_, asyncState.state))
+ }.toMap
+ }
+
+ // In which block are these symbols defined?
+ val symToDefiningState: Map[Symbol, Int] = defs.map {
+ case (k, v) => (k.symbol, v)
+ }
+
+ // The definitions trees
+ val symToTree: Map[Symbol, Tree] = defs.map {
+ case (k, v) => (k.symbol, k)
+ }
+
+ // The direct references of each definition tree
+ val defSymToReferenced: Map[Symbol, List[Symbol]] = defs.keys.map {
+ case tree => (tree.symbol, tree.collect {
+ case rt: RefTree if symToDefiningState.contains(rt.symbol) => rt.symbol
+ })
+ }.toMap
+
+ // The direct references of each block, excluding references of `DefTree`-s which
+ // are already accounted for.
+ val stateIdToDirectlyReferenced: Map[Int, List[Symbol]] = {
+ val refs: List[(Int, Symbol)] = asyncStates.flatMap(
+ asyncState => asyncState.stats.filterNot(_.isDef).flatMap(_.collect {
+ case rt: RefTree if symToDefiningState.contains(rt.symbol) => (asyncState.state, rt.symbol)
+ })
+ )
+ toMultiMap(refs)
+ }
+
+ def liftableSyms: Set[Symbol] = {
+ val liftableMutableSet = collection.mutable.Set[Symbol]()
+ def markForLift(sym: Symbol) {
+ if (!liftableMutableSet(sym)) {
+ liftableMutableSet += sym
+
+ // Only mark transitive references of defs, modules and classes. The RHS of lifted vals/vars
+ // stays in its original location, so things that it refers to need not be lifted.
+ if (!(sym.isVal || sym.isVar))
+ defSymToReferenced(sym).foreach(sym2 => markForLift(sym2))
+ }
+ }
+ // Start things with DefTrees directly referenced from statements from other states...
+ val liftableStatementRefs: List[Symbol] = stateIdToDirectlyReferenced.toList.flatMap {
+ case (i, syms) => syms.filter(sym => symToDefiningState(sym) != i)
+ }
+ // .. and likewise for DefTrees directly referenced by other DefTrees from other states
+ val liftableRefsOfDefTrees = defSymToReferenced.toList.flatMap {
+ case (referee, referents) => referents.filter(sym => symToDefiningState(sym) != symToDefiningState(referee))
+ }
+ // Mark these for lifting, which will follow transitive references.
+ (liftableStatementRefs ++ liftableRefsOfDefTrees).foreach(markForLift)
+ liftableMutableSet.toSet
+ }
+
+ val lifted = liftableSyms.map(symToTree).toList.map {
+ case vd@ValDef(_, _, tpt, rhs) =>
+ import reflect.internal.Flags._
+ val sym = vd.symbol
+ sym.setFlag(MUTABLE | STABLE | PRIVATE | LOCAL)
+ sym.name = name.fresh(sym.name.toTermName)
+ sym.modifyInfo(_.deconst)
+ ValDef(vd.symbol, gen.mkZero(vd.symbol.info)).setPos(vd.pos)
+ case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
+ import reflect.internal.Flags._
+ val sym = dd.symbol
+ sym.name = this.name.fresh(sym.name.toTermName)
+ sym.setFlag(PRIVATE | LOCAL)
+ DefDef(dd.symbol, rhs).setPos(dd.pos)
+ case cd@ClassDef(_, _, _, impl) =>
+ import reflect.internal.Flags._
+ val sym = cd.symbol
+ sym.name = newTypeName(name.fresh(sym.name.toString).toString)
+ companionship.companionOf(cd.symbol) match {
+ case NoSymbol =>
+ case moduleSymbol =>
+ moduleSymbol.name = sym.name.toTermName
+ moduleSymbol.moduleClass.name = moduleSymbol.name.toTypeName
+ }
+ ClassDef(cd.symbol, impl).setPos(cd.pos)
+ case md@ModuleDef(_, _, impl) =>
+ import reflect.internal.Flags._
+ val sym = md.symbol
+ companionship.companionOf(md.symbol) match {
+ case NoSymbol =>
+ sym.name = name.fresh(sym.name.toTermName)
+ sym.moduleClass.name = sym.name.toTypeName
+ case classSymbol => // will be renamed by `case ClassDef` above.
+ }
+ ModuleDef(md.symbol, impl).setPos(md.pos)
+ case td@TypeDef(_, _, _, rhs) =>
+ import reflect.internal.Flags._
+ val sym = td.symbol
+ sym.name = newTypeName(name.fresh(sym.name.toString).toString)
+ TypeDef(td.symbol, rhs).setPos(td.pos)
+ }
+ lifted
+ }
+}
diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala
index ebd546f..33dd21d 100644
--- a/src/main/scala/scala/async/TransformUtils.scala
+++ b/src/main/scala/scala/async/TransformUtils.scala
@@ -5,115 +5,40 @@ package scala.async
import scala.reflect.macros.Context
import reflect.ClassTag
+import scala.reflect.macros.runtime.AbortMacroException
/**
* Utilities used in both `ExprBuilder` and `AnfTransform`.
*/
-private[async] final case class TransformUtils[C <: Context](c: C) {
+private[async] trait TransformUtils {
+ self: AsyncMacro =>
- import c.universe._
+ import global._
object name {
- def suffix(string: String) = string + "$async"
-
- def suffixedName(prefix: String) = newTermName(suffix(prefix))
-
- val state = suffixedName("state")
- val result = suffixedName("result")
- val resume = suffixedName("resume")
- val execContext = suffixedName("execContext")
- val stateMachine = newTermName(fresh("stateMachine"))
- val stateMachineT = stateMachine.toTypeName
+ val resume = newTermName("resume")
val apply = newTermName("apply")
- val applyOrElse = newTermName("applyOrElse")
- val tr = newTermName("tr")
val matchRes = "matchres"
val ifRes = "ifres"
val await = "await"
val bindSuffix = "$bind"
- def fresh(name: TermName): TermName = newTermName(fresh(name.toString))
+ val state = newTermName("state")
+ val result = newTermName("result")
+ val execContext = newTermName("execContext")
+ val stateMachine = newTermName(fresh("stateMachine"))
+ val stateMachineT = stateMachine.toTypeName
+ val tr = newTermName("tr")
+ val t = newTermName("throwable")
- def fresh(name: String): String = if (name.toString.contains("$")) name else c.fresh("" + name + "$")
- }
+ def fresh(name: TermName): TermName = newTermName(fresh(name.toString))
- def defaultValue(tpe: Type): Literal = {
- val defaultValue: Any =
- if (tpe <:< definitions.BooleanTpe) false
- else if (definitions.ScalaNumericValueClasses.exists(tpe <:< _.toType)) 0
- else if (tpe <:< definitions.AnyValTpe) 0
- else null
- Literal(Constant(defaultValue))
+ def fresh(name: String): String = if (name.toString.contains("$")) name else currentUnit.freshTermName("" + name + "$").toString
}
def isAwait(fun: Tree) =
fun.symbol == defn.Async_await
- /** Replace all `Ident` nodes referring to one of the keys n `renameMap` with a node
- * referring to the corresponding new name
- */
- def substituteNames(tree: Tree, renameMap: Map[Symbol, Name]): Tree = {
- val renamer = new Transformer {
- override def transform(tree: Tree) = tree match {
- case Ident(_) => (renameMap get tree.symbol).fold(tree)(Ident(_))
- case tt: TypeTree if tt.original != EmptyTree && tt.original != null =>
- // We also have to apply our renaming transform on originals of TypeTrees.
- // TODO 2.10.1 Can we find a cleaner way?
- val symTab = c.universe.asInstanceOf[reflect.internal.SymbolTable]
- val tt1 = tt.asInstanceOf[symTab.TypeTree]
- tt1.setOriginal(transform(tt.original).asInstanceOf[symTab.Tree])
- super.transform(tree)
- case _ => super.transform(tree)
- }
- }
- renamer.transform(tree)
- }
-
- /** Descends into the regions of the tree that are subject to the
- * translation to a state machine by `async`. When a nested template,
- * function, or by-name argument is encountered, the descent stops,
- * and `nestedClass` etc are invoked.
- */
- trait AsyncTraverser extends Traverser {
- def nestedClass(classDef: ClassDef) {
- }
-
- def nestedModule(module: ModuleDef) {
- }
-
- def nestedMethod(module: DefDef) {
- }
-
- def byNameArgument(arg: Tree) {
- }
-
- def function(function: Function) {
- }
-
- def patMatFunction(tree: Match) {
- }
-
- override def traverse(tree: Tree) {
- tree match {
- case cd: ClassDef => nestedClass(cd)
- case md: ModuleDef => nestedModule(md)
- case dd: DefDef => nestedMethod(dd)
- case fun: Function => function(fun)
- case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions`
- case Applied(fun, targs, argss) if argss.nonEmpty =>
- val isInByName = isByName(fun)
- for ((args, i) <- argss.zipWithIndex) {
- for ((arg, j) <- args.zipWithIndex) {
- if (!isInByName(i, j)) traverse(arg)
- else byNameArgument(arg)
- }
- }
- traverse(fun)
- case _ => super.traverse(tree)
- }
- }
- }
-
private lazy val Boolean_ShortCircuits: Set[Symbol] = {
import definitions.BooleanClass
def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName)
@@ -122,57 +47,30 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
Set(Boolean_&&, Boolean_||)
}
- def isByName(fun: Tree): ((Int, Int) => Boolean) = {
+ private def isByName(fun: Tree): ((Int, Int) => Boolean) = {
if (Boolean_ShortCircuits contains fun.symbol) (i, j) => true
else {
- val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable]
- val paramss = fun.tpe.asInstanceOf[symtab.Type].paramss
+ val paramss = fun.tpe.paramss
val byNamess = paramss.map(_.map(_.isByNameParam))
(i, j) => util.Try(byNamess(i)(j)).getOrElse(false)
}
}
- def argName(fun: Tree): ((Int, Int) => String) = {
- val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable]
- val paramss = fun.tpe.asInstanceOf[symtab.Type].paramss
+ private def argName(fun: Tree): ((Int, Int) => String) = {
+ val paramss = fun.tpe.paramss
val namess = paramss.map(_.map(_.name.toString))
(i, j) => util.Try(namess(i)(j)).getOrElse(s"arg_${i}_${j}")
}
- object Applied {
- val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable]
- object treeInfo extends {
- val global: symtab.type = symtab
- } with reflect.internal.TreeInfo
-
- def unapply(tree: Tree): Some[(Tree, List[Tree], List[List[Tree]])] = {
- val treeInfo.Applied(core, targs, argss) = tree.asInstanceOf[symtab.Tree]
- Some((core.asInstanceOf[Tree], targs.asInstanceOf[List[Tree]], argss.asInstanceOf[List[List[Tree]]]))
- }
- }
-
- def statsAndExpr(tree: Tree): (List[Tree], Tree) = tree match {
- case Block(stats, expr) => (stats, expr)
- case _ => (List(tree), Literal(Constant(())))
- }
-
- def mkVarDefTree(resultType: Type, resultName: TermName): c.Tree = {
- ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), defaultValue(resultType))
- }
-
- def emptyConstructor: DefDef = {
- val emptySuperCall = Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), Nil)
- DefDef(NoMods, nme.CONSTRUCTOR, List(), List(List()), TypeTree(), Block(List(emptySuperCall), c.literalUnit.tree))
- }
-
- def applied(className: String, types: List[Type]): AppliedTypeTree =
- AppliedTypeTree(Ident(c.mirror.staticClass(className)), types.map(TypeTree(_)))
+ def Expr[A: WeakTypeTag](t: Tree) = global.Expr[A](rootMirror, new FixedMirrorTreeCreator(rootMirror, t))
object defn {
def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = {
- c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree)))
+ Expr(Apply(Ident(definitions.List_apply), args.map(_.tree)))
}
- def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify(self.splice.contains(elem.splice))
+ def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify {
+ self.splice.contains(elem.splice)
+ }
def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify {
self.splice.apply(arg.splice)
@@ -186,146 +84,17 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
self.splice.get
}
- val Try_get = methodSym(reify((null: scala.util.Try[Any]).get))
- val Try_isFailure = methodSym(reify((null: scala.util.Try[Any]).isFailure))
-
- val TryClass = c.mirror.staticClass("scala.util.Try")
+ val TryClass = rootMirror.staticClass("scala.util.Try")
+ val Try_get = TryClass.typeSignature.member(newTermName("get")).ensuring(_ != NoSymbol)
+ val Try_isFailure = TryClass.typeSignature.member(newTermName("isFailure")).ensuring(_ != NoSymbol)
val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe))
- val NonFatalClass = c.mirror.staticModule("scala.util.control.NonFatal")
-
- private def asyncMember(name: String) = {
- val asyncMod = c.mirror.staticClass("scala.async.AsyncBase")
- val tpe = asyncMod.asType.toType
- tpe.member(newTermName(name)).ensuring(_ != NoSymbol)
- }
-
- val Async_await = asyncMember("await")
- }
-
- /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */
- private def methodSym(apply: c.Expr[Any]): Symbol = {
- val tree2: Tree = c.typeCheck(apply.tree)
- tree2.collect {
- case s: SymTree if s.symbol.isMethod => s.symbol
- }.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}"))
- }
-
- /**
- * Using [[scala.reflect.api.Trees.TreeCopier]] copies more than we would like:
- * we don't want to copy types and symbols to the new trees in some cases.
- *
- * Instead, we just copy positions and attachments.
- */
- def attachCopy[T <: Tree](orig: Tree)(tree: T): tree.type = {
- tree.setPos(orig.pos)
- for (att <- orig.attachments.all)
- tree.updateAttachment[Any](att)(ClassTag.apply[Any](att.getClass))
- tree
- }
-
- def resetInternalAttrs(tree: Tree, internalSyms: List[Symbol]) =
- new ResetInternalAttrs(internalSyms.toSet).transform(tree)
-
- /**
- * Adaptation of [[scala.reflect.internal.Trees.ResetAttrs]]
- *
- * A transformer which resets symbol and tpe fields of all nodes in a given tree,
- * with special treatment of:
- * `TypeTree` nodes: are replaced by their original if it exists, otherwise tpe field is reset
- * to empty if it started out empty or refers to local symbols (which are erased).
- * `TypeApply` nodes: are deleted if type arguments end up reverted to empty
- *
- * `This` and `Ident` nodes referring to an external symbol are ''not'' reset.
- */
- private final class ResetInternalAttrs(internalSyms: Set[Symbol]) extends Transformer {
-
- import language.existentials
-
- override def transform(tree: Tree): Tree = super.transform {
- def isExternal = tree.symbol != NoSymbol && !internalSyms(tree.symbol)
-
- tree match {
- case tpt: TypeTree => resetTypeTree(tpt)
- case TypeApply(fn, args)
- if args map transform exists (_.isEmpty) => transform(fn)
- case EmptyTree => tree
- case (_: Ident | _: This) if isExternal => tree // #35 Don't reset the symbol of Ident/This bound outside of the async block
- case _ => resetTree(tree)
- }
- }
-
- private def resetTypeTree(tpt: TypeTree): Tree = {
- if (tpt.original != null)
- transform(tpt.original)
- else if (tpt.tpe != null && tpt.asInstanceOf[symtab.TypeTree forSome {val symtab: reflect.internal.SymbolTable}].wasEmpty) {
- val dupl = tpt.duplicate
- dupl.tpe = null
- dupl
- }
- else tpt
- }
-
- private def resetTree(tree: Tree): Tree = {
- val hasSymbol: Boolean = {
- val reflectInternalTree = tree.asInstanceOf[symtab.Tree forSome {val symtab: reflect.internal.SymbolTable}]
- reflectInternalTree.hasSymbol
- }
- val dupl = tree.duplicate
- if (hasSymbol)
- dupl.symbol = NoSymbol
- dupl.tpe = null
- dupl
- }
- }
-
- /**
- * Replaces expressions of the form `{ new $anon extends PartialFunction[A, B] { ... ; def applyOrElse[..](...) = ... match <cases> }`
- * with `Match(EmptyTree, cases`.
- *
- * This reverses the transformation performed in `Typers`, and works around non-idempotency of typechecking such trees.
- */
- // TODO Reference JIRA issue.
- final def restorePatternMatchingFunctions(tree: Tree) =
- RestorePatternMatchingFunctions transform tree
-
- private object RestorePatternMatchingFunctions extends Transformer {
-
- import language.existentials
- val DefaultCaseName: TermName = "defaultCase$"
-
- override def transform(tree: Tree): Tree = {
- val SYNTHETIC = (1 << 21).toLong.asInstanceOf[FlagSet]
- def isSynthetic(cd: ClassDef) = cd.mods hasFlag SYNTHETIC
-
- /** Is this pattern node a synthetic catch-all case, added during PartialFuction synthesis before we know
- * whether the user provided cases are exhaustive. */
- def isSyntheticDefaultCase(cdef: CaseDef) = cdef match {
- case CaseDef(Bind(DefaultCaseName, _), EmptyTree, _) => true
- case _ => false
- }
- tree match {
- case Block(
- (cd@ClassDef(_, _, _, Template(_, _, body))) :: Nil,
- Apply(Select(New(a), nme.CONSTRUCTOR), Nil)) if isSynthetic(cd) =>
- val restored = (body collectFirst {
- case DefDef(_, /*name.apply | */ name.applyOrElse, _, _, _, Match(_, cases)) =>
- val nonSyntheticCases = cases.takeWhile(cdef => !isSyntheticDefaultCase(cdef))
- val transformedCases = super.transformStats(nonSyntheticCases, currentOwner).asInstanceOf[List[CaseDef]]
- Match(EmptyTree, transformedCases)
- }).getOrElse(c.abort(tree.pos, s"Internal Error: Unable to find original pattern matching cases in: $body"))
- restored
- case t => super.transform(t)
- }
- }
+ val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal")
+ val AsyncClass = rootMirror.staticClass("scala.async.AsyncBase")
+ val Async_await = AsyncClass.typeSignature.member(newTermName("await")).ensuring(_ != NoSymbol)
}
def isSafeToInline(tree: Tree) = {
- val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable]
- object treeInfo extends {
- val global: symtab.type = symtab
- } with reflect.internal.TreeInfo
- val castTree = tree.asInstanceOf[symtab.Tree]
- treeInfo.isExprSafeToInline(castTree)
+ treeInfo.isExprSafeToInline(tree)
}
/** Map a list of arguments to:
@@ -371,4 +140,104 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
}
}.unzip
}
+
+
+ def statsAndExpr(tree: Tree): (List[Tree], Tree) = tree match {
+ case Block(stats, expr) => (stats, expr)
+ case _ => (List(tree), Literal(Constant(())))
+ }
+
+ def emptyConstructor: DefDef = {
+ val emptySuperCall = Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), Nil)
+ DefDef(NoMods, nme.CONSTRUCTOR, List(), List(List()), TypeTree(), Block(List(emptySuperCall), Literal(Constant(()))))
+ }
+
+ def applied(className: String, types: List[Type]): AppliedTypeTree =
+ AppliedTypeTree(Ident(rootMirror.staticClass(className)), types.map(TypeTree(_)))
+
+ /** Descends into the regions of the tree that are subject to the
+ * translation to a state machine by `async`. When a nested template,
+ * function, or by-name argument is encountered, the descent stops,
+ * and `nestedClass` etc are invoked.
+ */
+ trait AsyncTraverser extends Traverser {
+ def nestedClass(classDef: ClassDef) {
+ }
+
+ def nestedModule(module: ModuleDef) {
+ }
+
+ def nestedMethod(module: DefDef) {
+ }
+
+ def byNameArgument(arg: Tree) {
+ }
+
+ def function(function: Function) {
+ }
+
+ def patMatFunction(tree: Match) {
+ }
+
+ override def traverse(tree: Tree) {
+ tree match {
+ case cd: ClassDef => nestedClass(cd)
+ case md: ModuleDef => nestedModule(md)
+ case dd: DefDef => nestedMethod(dd)
+ case fun: Function => function(fun)
+ case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions`
+ case treeInfo.Applied(fun, targs, argss) if argss.nonEmpty =>
+ val isInByName = isByName(fun)
+ for ((args, i) <- argss.zipWithIndex) {
+ for ((arg, j) <- args.zipWithIndex) {
+ if (!isInByName(i, j)) traverse(arg)
+ else byNameArgument(arg)
+ }
+ }
+ traverse(fun)
+ case _ => super.traverse(tree)
+ }
+ }
+ }
+
+ def abort(pos: Position, msg: String) = throw new AbortMacroException(pos, msg)
+
+ abstract class MacroTypingTransformer extends TypingTransformer(callSiteTyper.context.unit) {
+ currentOwner = callSiteTyper.context.owner
+
+ def currOwner: Symbol = currentOwner
+
+ localTyper = global.analyzer.newTyper(callSiteTyper.context.make(unit = callSiteTyper.context.unit))
+ }
+
+ def transformAt(tree: Tree)(f: PartialFunction[Tree, (analyzer.Context => Tree)]) = {
+ object trans extends MacroTypingTransformer {
+ override def transform(tree: Tree): Tree = {
+ if (f.isDefinedAt(tree)) {
+ f(tree)(localTyper.context)
+ } else super.transform(tree)
+ }
+ }
+ trans.transform(tree)
+ }
+
+ def changeOwner(tree: Tree, oldOwner: Symbol, newOwner: Symbol): tree.type = {
+ new ChangeOwnerAndModuleClassTraverser(oldOwner, newOwner).traverse(tree)
+ tree
+ }
+
+ class ChangeOwnerAndModuleClassTraverser(oldowner: Symbol, newowner: Symbol)
+ extends ChangeOwnerTraverser(oldowner, newowner) {
+
+ override def traverse(tree: Tree) {
+ tree match {
+ case _: DefTree => change(tree.symbol.moduleClass)
+ case _ =>
+ }
+ super.traverse(tree)
+ }
+ }
+
+ def toMultiMap[A, B](as: Iterable[(A, B)]): Map[A, List[B]] =
+ as.toList.groupBy(_._1).mapValues(_.map(_._2).toList).toMap
}
diff --git a/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala b/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala
index a669cfa..7abc6e8 100644
--- a/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala
+++ b/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala
@@ -22,21 +22,28 @@ trait AsyncBaseWithCPSFallback extends AsyncBase {
/* Implements `async { ... }` using the CPS plugin.
*/
- protected def cpsBasedAsyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = {
+ protected def cpsBasedAsyncImpl[T: c.WeakTypeTag](c: Context)
+ (body: c.Expr[T])
+ (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = {
import c.universe._
- def lookupMember(name: String) = {
- val asyncTrait = c.mirror.staticClass("scala.async.continuations.AsyncBaseWithCPSFallback")
+ def lookupClassMember(clazz: String, name: String) = {
+ val asyncTrait = c.mirror.staticClass(clazz)
val tpe = asyncTrait.asType.toType
- tpe.member(newTermName(name)).ensuring(_ != NoSymbol)
+ tpe.member(newTermName(name)).ensuring(_ != NoSymbol, s"$clazz.$name")
+ }
+ def lookupObjectMember(clazz: String, name: String) = {
+ val moduleClass = c.mirror.staticModule(clazz).moduleClass
+ val tpe = moduleClass.asType.toType
+ tpe.member(newTermName(name)).ensuring(_ != NoSymbol, s"$clazz.$name")
}
AsyncUtils.vprintln("AsyncBaseWithCPSFallback.cpsBasedAsyncImpl")
- val utils = TransformUtils[c.type](c)
- val futureSystemOps = futureSystem.mkOps(c)
- val awaitSym = utils.defn.Async_await
- val awaitFallbackSym = lookupMember("awaitFallback")
+ val symTab = c.universe.asInstanceOf[reflect.internal.SymbolTable]
+ val futureSystemOps = futureSystem.mkOps(symTab)
+ val awaitSym = lookupObjectMember("scala.async.Async", "await")
+ val awaitFallbackSym = lookupClassMember("scala.async.continuations.AsyncBaseWithCPSFallback", "awaitFallback")
// replace `await` invocations with `awaitFallback` invocations
val awaitReplacer = new Transformer {
@@ -60,10 +67,12 @@ trait AsyncBaseWithCPSFallback extends AsyncBase {
}.asInstanceOf[Future[T]]
*/
+ def spawn(expr: Tree) = futureSystemOps.spawn(expr.asInstanceOf[futureSystemOps.universe.Tree], execContext.tree.asInstanceOf[futureSystemOps.universe.Tree]).asInstanceOf[Tree]
+
val bodyWithFuture = {
val tree = bodyWithAwaitFallback match {
- case Block(stmts, expr) => Block(stmts, futureSystemOps.spawn(expr))
- case expr => futureSystemOps.spawn(expr)
+ case Block(stmts, expr) => Block(stmts, spawn(expr))
+ case expr => spawn(expr)
}
c.Expr[futureSystem.Fut[Any]](c.resetAllAttrs(tree.duplicate))
}
@@ -71,20 +80,22 @@ trait AsyncBaseWithCPSFallback extends AsyncBase {
val bodyWithReset: c.Expr[futureSystem.Fut[Any]] = reify {
reset { bodyWithFuture.splice }
}
- val bodyWithCast = futureSystemOps.castTo[T](bodyWithReset)
+ val bodyWithCast = futureSystemOps.castTo[T](bodyWithReset.asInstanceOf[futureSystemOps.universe.Expr[futureSystem.Fut[Any]]]).asInstanceOf[c.Expr[futureSystem.Fut[T]]]
AsyncUtils.vprintln(s"CPS-based async transform expands to:\n${bodyWithCast.tree}")
bodyWithCast
}
- override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = {
+ override def asyncImpl[T: c.WeakTypeTag](c: Context)
+ (body: c.Expr[T])
+ (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = {
AsyncUtils.vprintln("AsyncBaseWithCPSFallback.asyncImpl")
- val analyzer = AsyncAnalysis[c.type](c, this)
+ val asyncMacro = AsyncMacro(c, futureSystem)
- if (!analyzer.reportUnsupportedAwaits(body.tree))
- super.asyncImpl[T](c)(body) // no unsupported awaits
+ if (!asyncMacro.reportUnsupportedAwaits(body.tree.asInstanceOf[asyncMacro.global.Tree], report = fallbackEnabled))
+ super.asyncImpl[T](c)(body)(execContext) // no unsupported awaits
else
- cpsBasedAsyncImpl[T](c)(body) // fallback to CPS
+ cpsBasedAsyncImpl[T](c)(body)(execContext) // fallback to CPS
}
}
diff --git a/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala b/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala
index fe6e1a6..e0da5aa 100644
--- a/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala
+++ b/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala
@@ -13,8 +13,13 @@ import scala.concurrent.Future
trait AsyncWithCPSFallback extends AsyncBaseWithCPSFallback with ScalaConcurrentCPSFallback
object AsyncWithCPSFallback extends AsyncWithCPSFallback {
+ import scala.concurrent.{ExecutionContext, Future}
- def async[T](body: T) = macro asyncImpl[T]
+ def async[T](body: T)(implicit execContext: ExecutionContext): Future[T] = macro asyncImpl[T]
- override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body)
+ override def asyncImpl[T: c.WeakTypeTag](c: Context)
+ (body: c.Expr[T])
+ (execContext: c.Expr[ExecutionContext]): c.Expr[Future[T]] = {
+ super.asyncImpl[T](c)(body)(execContext)
+ }
}
diff --git a/src/main/scala/scala/async/continuations/CPSBasedAsync.scala b/src/main/scala/scala/async/continuations/CPSBasedAsync.scala
index 922d1ac..2003082 100644
--- a/src/main/scala/scala/async/continuations/CPSBasedAsync.scala
+++ b/src/main/scala/scala/async/continuations/CPSBasedAsync.scala
@@ -8,14 +8,17 @@ package continuations
import scala.language.experimental.macros
import scala.reflect.macros.Context
-import scala.concurrent.Future
+import scala.concurrent.{ExecutionContext, Future}
trait CPSBasedAsync extends CPSBasedAsyncBase with ScalaConcurrentCPSFallback
object CPSBasedAsync extends CPSBasedAsync {
- def async[T](body: T) = macro asyncImpl[T]
-
- override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body)
+ def async[T](body: T)(implicit execContext: ExecutionContext): Future[T] = macro asyncImpl[T]
+ override def asyncImpl[T: c.WeakTypeTag](c: Context)
+ (body: c.Expr[T])
+ (execContext: c.Expr[ExecutionContext]): c.Expr[Future[T]] = {
+ super.asyncImpl[T](c)(body)(execContext)
+ }
}
diff --git a/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala b/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala
index 4e8ec80..a350704 100644
--- a/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala
+++ b/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala
@@ -15,7 +15,9 @@ import scala.util.continuations._
*/
trait CPSBasedAsyncBase extends AsyncBaseWithCPSFallback {
- override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] =
- super.cpsBasedAsyncImpl[T](c)(body)
-
+ override def asyncImpl[T: c.WeakTypeTag](c: Context)
+ (body: c.Expr[T])
+ (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = {
+ super.cpsBasedAsyncImpl[T](c)(body)(execContext)
+ }
}
diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala
index 43393a7..a06437d 100644
--- a/src/test/scala/scala/async/TreeInterrogation.scala
+++ b/src/test/scala/scala/async/TreeInterrogation.scala
@@ -40,8 +40,8 @@ class TreeInterrogation {
val varDefs = tree1.collect {
case ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) => name
}
- varDefs.map(_.decoded.trim).toSet mustBe (Set("state$async", "await$1", "await$2"))
- varDefs.map(_.decoded.trim).toSet mustBe (Set("state$async", "await$1", "await$2"))
+ varDefs.map(_.decoded.trim).toSet mustBe (Set("state", "await$1", "await$2"))
+ varDefs.map(_.decoded.trim).toSet mustBe (Set("state", "await$1", "await$2"))
val defDefs = tree1.collect {
case t: Template =>
@@ -68,7 +68,7 @@ object TreeInterrogation extends App {
withDebug {
val cm = reflect.runtime.currentMirror
- val tb = mkToolbox("-cp ${toolboxClasspath} -Xprint:flatten")
+ val tb = mkToolbox("-cp ${toolboxClasspath} -Xprint:typer -uniqid")
import scala.async.Async._
val tree = tb.parse(
""" import _root_.scala.async.AsyncId.{async, await}
diff --git a/src/test/scala/scala/async/neg/LocalClasses0Spec.scala b/src/test/scala/scala/async/neg/LocalClasses0Spec.scala
index 2569303..dcd9bb8 100644
--- a/src/test/scala/scala/async/neg/LocalClasses0Spec.scala
+++ b/src/test/scala/scala/async/neg/LocalClasses0Spec.scala
@@ -5,121 +5,32 @@
package scala.async
package neg
-/**
- * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
- */
-
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.Test
@RunWith(classOf[JUnit4])
class LocalClasses0Spec {
-
@Test
- def `reject a local class`() {
- expectError("Local case class Person illegal within `async` block") {
- """
- | import scala.concurrent.ExecutionContext.Implicits.global
- | import scala.async.Async._
- |
- | async {
- | case class Person(name: String)
- | }
- """.stripMargin
- }
+ def localClassCrashIssue16() {
+ import scala.async.AsyncId.{async, await}
+ async {
+ class B { def f = 1 }
+ await(new B()).f
+ } mustBe 1
}
@Test
- def `reject a local class 2`() {
- expectError("Local case class Person illegal within `async` block") {
- """
- | import scala.concurrent.{Future, ExecutionContext}
- | import ExecutionContext.Implicits.global
- | import scala.async.Async._
- |
- | async {
- | case class Person(name: String)
- | val fut = Future { 5 }
- | val x = await(fut)
- | x
- | }
- """.stripMargin
- }
+ def nestedCaseClassAndModuleAllowed() {
+ import AsyncId.{await, async}
+ async {
+ trait Base { def base = 0}
+ await(0)
+ case class Person(name: String) extends Base
+ val fut = async { "bob" }
+ val x = Person(await(fut))
+ x.base
+ x.name
+ } mustBe "bob"
}
-
- @Test
- def `reject a local class 3`() {
- expectError("Local case class Person illegal within `async` block") {
- """
- | import scala.concurrent.{Future, ExecutionContext}
- | import ExecutionContext.Implicits.global
- | import scala.async.Async._
- |
- | async {
- | val fut = Future { 5 }
- | val x = await(fut)
- | case class Person(name: String)
- | x
- | }
- """.stripMargin
- }
- }
-
- @Test
- def `reject a local class with symbols in its name`() {
- expectError("Local case class :: illegal within `async` block") {
- """
- | import scala.concurrent.{Future, ExecutionContext}
- | import ExecutionContext.Implicits.global
- | import scala.async.Async._
- |
- | async {
- | val fut = Future { 5 }
- | val x = await(fut)
- | case class ::(name: String)
- | x
- | }
- """.stripMargin
- }
- }
-
- @Test
- def `reject a nested local class`() {
- expectError("Local case class Person illegal within `async` block") {
- """
- | import scala.concurrent.{Future, ExecutionContext}
- | import ExecutionContext.Implicits.global
- | import scala.async.Async._
- |
- | async {
- | val fut = Future { 5 }
- | val x = 2 + 2
- | var y = 0
- | if (x > 0) {
- | case class Person(name: String)
- | y = await(fut)
- | } else {
- | y = x
- | }
- | y
- | }
- """.stripMargin
- }
- }
-
- @Test
- def `reject a local singleton object`() {
- expectError("Local object Person illegal within `async` block") {
- """
- | import scala.concurrent.ExecutionContext.Implicits.global
- | import scala.async.Async._
- |
- | async {
- | object Person { val name = "Joe" }
- | }
- """.stripMargin
- }
- }
-
}
diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
index 7be6299..abce3ce 100644
--- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
+++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
@@ -238,7 +238,7 @@ class AnfTransformSpec {
val res = async {
var i = 0
def get = {i += 1; i}
- foo(get)(get)
+ foo(get)(await(get))
}
res mustBe "a0 = 1, b0 = 2"
}
diff --git a/src/test/scala/scala/async/run/nesteddef/NestedDef.scala b/src/test/scala/scala/async/run/nesteddef/NestedDef.scala
index ee0a78e..cf74602 100644
--- a/src/test/scala/scala/async/run/nesteddef/NestedDef.scala
+++ b/src/test/scala/scala/async/run/nesteddef/NestedDef.scala
@@ -37,4 +37,60 @@ class NestedDef {
}
result mustBe ((0d, 44d, 2))
}
+
+ // We must lift `foo` and `bar` in the next two tests.
+ @Test
+ def nestedDefTransitive1() {
+ import AsyncId._
+ val result = async {
+ val a = 0
+ val x = await(a) - 1
+ def bar = a
+ def foo = bar
+ foo
+ }
+ result mustBe 0
+ }
+
+ @Test
+ def nestedDefTransitive2() {
+ import AsyncId._
+ val result = async {
+ val a = 0
+ val x = await(a) - 1
+ def bar = a
+ def foo = bar
+ 0
+ }
+ result mustBe 0
+ }
+
+
+ // checking that our use/definition analysis doesn't cycle.
+ @Test
+ def mutuallyRecursive1() {
+ import AsyncId._
+ val result = async {
+ val a = 0
+ val x = await(a) - 1
+ def foo: Int = if (true) 0 else bar
+ def bar: Int = if (true) 0 else foo
+ bar
+ }
+ result mustBe 0
+ }
+
+ // checking that our use/definition analysis doesn't cycle.
+ @Test
+ def mutuallyRecursive2() {
+ import AsyncId._
+ val result = async {
+ val a = 0
+ def foo: Int = if (true) 0 else bar
+ def bar: Int = if (true) 0 else foo
+ val x = await(a) - 1
+ bar
+ }
+ result mustBe 0
+ }
}
diff --git a/src/test/scala/scala/async/run/toughtype/ToughType.scala b/src/test/scala/scala/async/run/toughtype/ToughType.scala
index 83f5a2d..6fcd966 100644
--- a/src/test/scala/scala/async/run/toughtype/ToughType.scala
+++ b/src/test/scala/scala/async/run/toughtype/ToughType.scala
@@ -67,4 +67,42 @@ class ToughTypeSpec {
await(f(2))
} mustBe 3
}
+
+ @Test def existentialBindIssue19() {
+ import AsyncId.{await, async}
+ def m7(a: Any) = async {
+ a match {
+ case s: Seq[_] =>
+ val x = s.size
+ var ss = s
+ ss = s
+ await(x)
+ }
+ }
+ m7(Nil) mustBe 0
+ }
+
+ @Test def existentialBind2Issue19() {
+ import scala.async.Async._, scala.concurrent.ExecutionContext.Implicits.global
+ def conjure[T]: T = null.asInstanceOf[T]
+
+ def m3 = async {
+ val p: List[Option[_]] = conjure[List[Option[_]]]
+ await(future(1))
+ }
+
+ def m4 = async {
+ await(future[List[_]](Nil))
+ }
+ }
+
+ @Test def singletonTypeIssue17() {
+ import scala.async.AsyncId.{async, await}
+ class A { class B }
+ async {
+ val a = new A
+ def foo(b: a.B) = 0
+ await(foo(new a.B))
+ }
+ }
}