From 82232ec47effb4a6b67b3a0792e1c7600e2d31b7 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Tue, 2 Jul 2013 15:55:34 +0200 Subject: An overdue overhaul of macro internals. - Avoid reset + retypecheck, instead hang onto the original types/symbols - Eliminated duplication between AsyncDefinitionUseAnalyzer and ExprBuilder - Instead, decide what do lift *after* running ExprBuilder - Account for transitive references local classes/objects and lift them as needed. - Make the execution context an regular implicit parameter of the macro - Fixes interaction with existential skolems and singleton types Fixes #6, #13, #16, #17, #19, #21. --- src/main/scala/scala/async/AnfTransform.scala | 450 ++++++++++----------- src/main/scala/scala/async/Async.scala | 139 ++----- src/main/scala/scala/async/AsyncAnalysis.scala | 133 +----- src/main/scala/scala/async/AsyncMacro.scala | 29 ++ src/main/scala/scala/async/AsyncTransform.scala | 176 ++++++++ src/main/scala/scala/async/ExprBuilder.scala | 205 +++++----- src/main/scala/scala/async/FutureSystem.scala | 50 ++- src/main/scala/scala/async/Lifter.scala | 150 +++++++ src/main/scala/scala/async/TransformUtils.scala | 391 ++++++------------ .../continuations/AsyncBaseWithCPSFallback.scala | 43 +- .../async/continuations/AsyncWithCPSFallback.scala | 9 +- .../scala/async/continuations/CPSBasedAsync.scala | 11 +- .../async/continuations/CPSBasedAsyncBase.scala | 8 +- src/test/scala/scala/async/TreeInterrogation.scala | 6 +- .../scala/scala/async/neg/LocalClasses0Spec.scala | 123 +----- .../scala/async/run/anf/AnfTransformSpec.scala | 2 +- .../scala/async/run/nesteddef/NestedDef.scala | 56 +++ .../scala/async/run/toughtype/ToughType.scala | 38 ++ 18 files changed, 1021 insertions(+), 998 deletions(-) create mode 100644 src/main/scala/scala/async/AsyncMacro.scala create mode 100644 src/main/scala/scala/async/AsyncTransform.scala create mode 100644 src/main/scala/scala/async/Lifter.scala diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index 5b9901d..275bc49 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -5,270 +5,246 @@ package scala.async -import scala.reflect.macros.Context +import scala.tools.nsc.Global -private[async] final case class AnfTransform[C <: Context](c: C) { +private[async] trait AnfTransform { + self: AsyncMacro => - import c.universe._ + import global._ + import reflect.internal.Flags._ - val utils = TransformUtils[c.type](c) - - import utils._ - - def apply(tree: Tree): List[Tree] = { - val unique = uniqueNames(tree) + def anfTransform(tree: Tree): Block = { // Must prepend the () for issue #31. - anf.transformToList(Block(List(c.literalUnit.tree), unique)) - } + val block = callSiteTyper.typedPos(tree.pos)(Block(List(Literal(Constant(()))), tree)).setType(tree.tpe) - private def uniqueNames(tree: Tree): Tree = { - new UniqueNames(tree).transform(tree) + new SelectiveAnfTransform().transform(block) } - /** Assigns unique names to all definitions in a tree, and adjusts references to use the new name. - * Only modifies names that appear more than once in the tree. - * - * This step is needed to allow us to safely merge blocks during the `inline` transform below. - */ - private final class UniqueNames(tree: Tree) extends Transformer { - val repeatedNames: Set[Symbol] = { - class DuplicateNameTraverser extends AsyncTraverser { - val result = collection.mutable.Buffer[Symbol]() - - override def traverse(tree: Tree) { - tree match { - case dt: DefTree => result += dt.symbol - case _ => super.traverse(tree) - } - } - } - val dupNameTraverser = new DuplicateNameTraverser - dupNameTraverser.traverse(tree) - dupNameTraverser.result.groupBy(x => x.name).filter(_._2.size > 1).values.flatten.toSet[Symbol] - } + sealed abstract class AnfMode + + case object Anf extends AnfMode - /** Stepping outside of the public Macro API to call [[scala.reflect.internal.Symbols.Symbol.name_=]] */ - val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable] + case object Linearizing extends AnfMode - val renamed = collection.mutable.Set[Symbol]() + final class SelectiveAnfTransform extends MacroTypingTransformer { + var mode: AnfMode = Anf - override def transform(tree: Tree): Tree = { + def blockToList(tree: Tree): List[Tree] = tree match { + case Block(stats, expr) => stats :+ expr + case t => t :: Nil + } + + def listToBlock(trees: List[Tree]): Block = trees match { + case trees @ (init :+ last) => + val pos = trees.map(_.pos).reduceLeft(_ union _) + Block(init, last).setType(last.tpe).setPos(pos) + } + + override def transform(tree: Tree): Block = { + def anfLinearize: Block = { + val trees: List[Tree] = mode match { + case Anf => anf._transformToList(tree) + case Linearizing => linearize._transformToList(tree) + } + listToBlock(trees) + } tree match { - case defTree: DefTree if repeatedNames(defTree.symbol) => - val trans = super.transform(defTree) - val origName = defTree.symbol.name - val sym = defTree.symbol.asInstanceOf[symtab.Symbol] - val fresh = name.fresh(sym.name.toString) - sym.name = origName match { - case _: TermName => symtab.newTermName(fresh) - case _: TypeName => symtab.newTypeName(fresh) - } - renamed += trans.symbol - val newName = trans.symbol.name - trans match { - case ValDef(mods, name, tpt, rhs) => - treeCopy.ValDef(trans, mods, newName, tpt, rhs) - case Bind(name, body) => - treeCopy.Bind(trans, newName, body) - case DefDef(mods, name, tparams, vparamss, tpt, rhs) => - treeCopy.DefDef(trans, mods, newName, tparams, vparamss, tpt, rhs) - case TypeDef(mods, name, tparams, rhs) => - treeCopy.TypeDef(tree, mods, newName, tparams, transform(rhs)) - // If we were to allow local classes / objects, we would need to rename here. - case ClassDef(mods, name, tparams, impl) => - treeCopy.ClassDef(tree, mods, newName, tparams, transform(impl).asInstanceOf[Template]) - case ModuleDef(mods, name, impl) => - treeCopy.ModuleDef(tree, mods, newName, transform(impl).asInstanceOf[Template]) - case x => super.transform(x) - } - case Ident(name) => - if (renamed(tree.symbol)) treeCopy.Ident(tree, tree.symbol.name) - else tree - case Select(fun, name) => - if (renamed(tree.symbol)) { - treeCopy.Select(tree, transform(fun), tree.symbol.name) - } else super.transform(tree) - case tt: TypeTree => - val tt1 = tt.asInstanceOf[symtab.TypeTree] - val orig = tt1.original - if (orig != null) tt1.setOriginal(transform(orig.asInstanceOf[Tree]).asInstanceOf[symtab.Tree]) - super.transform(tt) - case _ => super.transform(tree) + case _: ValDef | _: DefDef | _: Function | _: ClassDef | _: TypeDef => + atOwner(tree.symbol)(anfLinearize) + case _: ModuleDef => + atOwner(tree.symbol.moduleClass orElse tree.symbol)(anfLinearize) + case _ => + anfLinearize } } - } - private object trace { - private var indent = -1 - - def indentString = " " * indent - - def apply[T](prefix: String, args: Any)(t: => T): T = { - indent += 1 - def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127) - try { - AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})") - val result = t - AsyncUtils.trace(s"${indentString}= ${oneLine(result)}") - result - } finally { - indent -= 1 + private object linearize { + def transformToList(tree: Tree): List[Tree] = { + mode = Linearizing; blockToList(transform(tree)) } - } - } - private object inline { - def transformToList(tree: Tree): List[Tree] = trace("inline", tree) { - val stats :+ expr = anf.transformToList(tree) - expr match { - case Apply(fun, args) if isAwait(fun) => - val valDef = defineVal(name.await, expr, tree.pos) - stats :+ valDef :+ Ident(valDef.name) - - case If(cond, thenp, elsep) => - // if type of if-else is Unit don't introduce assignment, - // but add Unit value to bring it into form expected by async transform - if (expr.tpe =:= definitions.UnitTpe) { - stats :+ expr :+ Literal(Constant(())) - } else { - val varDef = defineVar(name.ifRes, expr.tpe, tree.pos) - def branchWithAssign(orig: Tree) = orig match { - case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.name), thenExpr)) - case _ => Assign(Ident(varDef.name), orig) + def transformToBlock(tree: Tree): Block = listToBlock(transformToList(tree)) + + def _transformToList(tree: Tree): List[Tree] = trace(tree) { + val stats :+ expr = anf.transformToList(tree) + expr match { + case Apply(fun, args) if isAwait(fun) => + val valDef = defineVal(name.await, expr, tree.pos) + stats :+ valDef :+ gen.mkAttributedStableRef(valDef.symbol) + + case If(cond, thenp, elsep) => + // if type of if-else is Unit don't introduce assignment, + // but add Unit value to bring it into form expected by async transform + if (expr.tpe =:= definitions.UnitTpe) { + stats :+ expr :+ localTyper.typedPos(expr.pos)(Literal(Constant(()))) + } else { + val varDef = defineVar(name.ifRes, expr.tpe, tree.pos) + def branchWithAssign(orig: Tree) = localTyper.typedPos(orig.pos)( + orig match { + case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.symbol), thenExpr)) + case _ => Assign(Ident(varDef.symbol), orig) + } + ).setType(orig.tpe) + val ifWithAssign = treeCopy.If(tree, cond, branchWithAssign(thenp), branchWithAssign(elsep)) + stats :+ varDef :+ ifWithAssign :+ gen.mkAttributedStableRef(varDef.symbol) + } + + case Match(scrut, cases) => + // if type of match is Unit don't introduce assignment, + // but add Unit value to bring it into form expected by async transform + if (expr.tpe =:= definitions.UnitTpe) { + stats :+ expr :+ localTyper.typedPos(expr.pos)(Literal(Constant(()))) } - val ifWithAssign = If(cond, branchWithAssign(thenp), branchWithAssign(elsep)) - stats :+ varDef :+ ifWithAssign :+ Ident(varDef.name) - } - - case Match(scrut, cases) => - // if type of match is Unit don't introduce assignment, - // but add Unit value to bring it into form expected by async transform - if (expr.tpe =:= definitions.UnitTpe) { - stats :+ expr :+ Literal(Constant(())) - } - else { - val varDef = defineVar(name.matchRes, expr.tpe, tree.pos) - val casesWithAssign = cases map { - case cd@CaseDef(pat, guard, Block(caseStats, caseExpr)) => - attachCopy(cd)(CaseDef(pat, guard, Block(caseStats, Assign(Ident(varDef.name), caseExpr)))) - case cd@CaseDef(pat, guard, body) => - attachCopy(cd)(CaseDef(pat, guard, Assign(Ident(varDef.name), body))) + else { + val varDef = defineVar(name.matchRes, expr.tpe, tree.pos) + def typedAssign(lhs: Tree) = + localTyper.typedPos(lhs.pos)(Assign(Ident(varDef.symbol), lhs)) + val casesWithAssign = cases map { + case cd@CaseDef(pat, guard, body) => + val newBody = body match { + case b@Block(caseStats, caseExpr) => treeCopy.Block(b, caseStats, typedAssign(caseExpr)) + case _ => typedAssign(body) + } + treeCopy.CaseDef(cd, pat, guard, newBody) + } + val matchWithAssign = treeCopy.Match(tree, scrut, casesWithAssign) + require(matchWithAssign.tpe != null, matchWithAssign) + stats :+ varDef :+ matchWithAssign :+ gen.mkAttributedStableRef(varDef.symbol) } - val matchWithAssign = attachCopy(tree)(Match(scrut, casesWithAssign)) - stats :+ varDef :+ matchWithAssign :+ Ident(varDef.name) - } - case _ => - stats :+ expr + case _ => + stats :+ expr + } } - } - def transformToList(trees: List[Tree]): List[Tree] = trees flatMap transformToList + private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = { + val sym = currOwner.newTermSymbol(name.fresh(prefix), pos, MUTABLE | SYNTHETIC).setInfo(tp) + ValDef(sym, gen.mkZero(tp)).setType(NoType).setPos(pos) + } + } - def transformToBlock(tree: Tree): Block = transformToList(tree) match { - case stats :+ expr => Block(stats, expr) + private object trace { + private var indent = -1 + + def indentString = " " * indent + + def apply[T](args: Any)(t: => T): T = { + def prefix = mode.toString.toLowerCase + indent += 1 + def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127) + try { + AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})") + val result = t + AsyncUtils.trace(s"${indentString}= ${oneLine(result)}") + result + } finally { + indent -= 1 + } + } } - private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = { - val vd = ValDef(Modifiers(Flag.MUTABLE), name.fresh(prefix), TypeTree(tp), defaultValue(tp)) - vd.setPos(pos) - vd + private def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = { + val sym = currOwner.newTermSymbol(name.fresh(prefix), pos, SYNTHETIC).setInfo(lhs.tpe) + changeOwner(lhs, currentOwner, sym) + ValDef(sym, changeOwner(lhs, currentOwner, sym)).setType(NoType).setPos(pos) } - } - private def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = { - val vd = ValDef(NoMods, name.fresh(prefix), TypeTree(), lhs) - vd.setPos(pos) - vd - } + private object anf { + def transformToList(tree: Tree): List[Tree] = { + mode = Anf; blockToList(transform(tree)) + } - private object anf { - - private[AnfTransform] def transformToList(tree: Tree): List[Tree] = trace("anf", tree) { - val containsAwait = tree exists isAwait - if (!containsAwait) { - List(tree) - } else tree match { - case Select(qual, sel) => - val stats :+ expr = inline.transformToList(qual) - stats :+ attachCopy(tree)(Select(expr, sel).setSymbol(tree.symbol)) - - case Applied(fun, targs, argss) if argss.nonEmpty => - // we an assume that no await call appears in a by-name argument position, - // this has already been checked. - val funStats :+ simpleFun = inline.transformToList(fun) - def isAwaitRef(name: Name) = name.toString.startsWith(utils.name.await + "$") - val (argStatss, argExprss): (List[List[List[Tree]]], List[List[Tree]]) = - mapArgumentss[List[Tree]](fun, argss) { - case Arg(expr, byName, _) if byName || isSafeToInline(expr) => (Nil, expr) - case Arg(expr@Ident(name), _, _) if isAwaitRef(name) => (Nil, expr) // not typed, so it eludes the check in `isSafeToInline` - case Arg(expr, _, argName) => - inline.transformToList(expr) match { - case stats :+ expr1 => - val valDef = defineVal(argName, expr1, expr.pos) - (stats :+ valDef, Ident(valDef.name)) - } - } - val core = if (targs.isEmpty) simpleFun else TypeApply(simpleFun, targs) - val newApply = argExprss.foldLeft(core)(Apply(_, _)).setSymbol(tree.symbol) - funStats ++ argStatss.flatten.flatten :+ attachCopy(tree)(newApply) - case Block(stats, expr) => - inline.transformToList(stats :+ expr) - - case ValDef(mods, name, tpt, rhs) => - if (rhs exists isAwait) { - val stats :+ expr = inline.transformToList(rhs) - stats :+ attachCopy(tree)(ValDef(mods, name, tpt, expr).setSymbol(tree.symbol)) - } else List(tree) - - case Assign(lhs, rhs) => - val stats :+ expr = inline.transformToList(rhs) - stats :+ attachCopy(tree)(Assign(lhs, expr)) - - case If(cond, thenp, elsep) => - val condStats :+ condExpr = inline.transformToList(cond) - val thenBlock = inline.transformToBlock(thenp) - val elseBlock = inline.transformToBlock(elsep) - // Typechecking with `condExpr` as the condition fails if the condition - // contains an await. `ifTree.setType(tree.tpe)` also fails; it seems - // we rely on this call to `typeCheck` descending into the branches. - // But, we can get away with typechecking a throwaway `If` tree with the - // original scrutinee and the new branches, and setting that type on - // the real `If` tree. - val ifType = c.typeCheck(If(cond, thenBlock, elseBlock)).tpe - condStats :+ - attachCopy(tree)(If(condExpr, thenBlock, elseBlock)).setType(ifType) - - case Match(scrut, cases) => - val scrutStats :+ scrutExpr = inline.transformToList(scrut) - val caseDefs = cases map { - case CaseDef(pat, guard, body) => - // extract local variables for all names bound in `pat`, and rewrite `body` - // to refer to these. - // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`. - val block = inline.transformToBlock(body) - val (valDefs, mappings) = (pat collect { - case b@Bind(name, _) => - val newName = newTermName(utils.name.fresh(name.toTermName + utils.name.bindSuffix)) - val vd = ValDef(NoMods, newName, TypeTree(), Ident(b.symbol)) - (vd, (b.symbol, newName)) - }).unzip - val Block(stats1, expr1) = utils.substituteNames(block, mappings.toMap).asInstanceOf[Block] - attachCopy(tree)(CaseDef(pat, guard, Block(valDefs ++ stats1, expr1))) - } - // Refer to comments the translation of `If` above. - val matchType = c.typeCheck(Match(scrut, caseDefs)).tpe - val typedMatch = attachCopy(tree)(Match(scrutExpr, caseDefs)).setType(tree.tpe) - scrutStats :+ typedMatch - - case LabelDef(name, params, rhs) => - List(LabelDef(name, params, Block(inline.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol)) - - case TypeApply(fun, targs) => - val funStats :+ simpleFun = inline.transformToList(fun) - funStats :+ attachCopy(tree)(TypeApply(simpleFun, targs).setSymbol(tree.symbol)) - - case _ => + def _transformToList(tree: Tree): List[Tree] = trace(tree) { + val containsAwait = tree exists isAwait + if (!containsAwait) { List(tree) + } else tree match { + case Select(qual, sel) => + val stats :+ expr = linearize.transformToList(qual) + stats :+ treeCopy.Select(tree, expr, sel) + + case treeInfo.Applied(fun, targs, argss) if argss.nonEmpty => + // we an assume that no await call appears in a by-name argument position, + // this has already been checked. + val funStats :+ simpleFun = linearize.transformToList(fun) + def isAwaitRef(name: Name) = name.toString.startsWith(AnfTransform.this.name.await + "$") + val (argStatss, argExprss): (List[List[List[Tree]]], List[List[Tree]]) = + mapArgumentss[List[Tree]](fun, argss) { + case Arg(expr, byName, _) if byName /*|| isPure(expr) TODO */ => (Nil, expr) + case Arg(expr@Ident(name), _, _) if isAwaitRef(name) => (Nil, expr) // TODO needed? // not typed, so it eludes the check in `isSafeToInline` + case Arg(expr, _, argName) => + linearize.transformToList(expr) match { + case stats :+ expr1 => + val valDef = defineVal(argName, expr1, expr1.pos) + require(valDef.tpe != null, valDef) + val stats1 = stats :+ valDef + //stats1.foreach(changeOwner(_, currentOwner, currentOwner.owner)) + (stats1, gen.stabilize(gen.mkAttributedIdent(valDef.symbol))) + } + } + val applied = treeInfo.dissectApplied(tree) + val core = if (targs.isEmpty) simpleFun else treeCopy.TypeApply(applied.callee, simpleFun, targs) + val newApply = argExprss.foldLeft(core)(Apply(_, _)).setSymbol(tree.symbol) + val typedNewApply = localTyper.typedPos(tree.pos)(newApply).setType(tree.tpe) + funStats ++ argStatss.flatten.flatten :+ typedNewApply + case Block(stats, expr) => + (stats :+ expr).flatMap(linearize.transformToList) + + case ValDef(mods, name, tpt, rhs) => + if (rhs exists isAwait) { + val stats :+ expr = atOwner(currOwner.owner)(linearize.transformToList(rhs)) + stats.foreach(changeOwner(_, currOwner, currOwner.owner)) + stats :+ treeCopy.ValDef(tree, mods, name, tpt, expr) + } else List(tree) + + case Assign(lhs, rhs) => + val stats :+ expr = linearize.transformToList(rhs) + stats :+ treeCopy.Assign(tree, lhs, expr) + + case If(cond, thenp, elsep) => + val condStats :+ condExpr = linearize.transformToList(cond) + val thenBlock = linearize.transformToBlock(thenp) + val elseBlock = linearize.transformToBlock(elsep) + // Typechecking with `condExpr` as the condition fails if the condition + // contains an await. `ifTree.setType(tree.tpe)` also fails; it seems + // we rely on this call to `typeCheck` descending into the branches. + // But, we can get away with typechecking a throwaway `If` tree with the + // original scrutinee and the new branches, and setting that type on + // the real `If` tree. + val iff = treeCopy.If(tree, condExpr, thenBlock, elseBlock) + condStats :+ iff + + case Match(scrut, cases) => + val scrutStats :+ scrutExpr = linearize.transformToList(scrut) + val caseDefs = cases map { + case CaseDef(pat, guard, body) => + // extract local variables for all names bound in `pat`, and rewrite `body` + // to refer to these. + // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`. + val block = linearize.transformToBlock(body) + val (valDefs, mappings) = (pat collect { + case b@Bind(name, _) => + val vd = defineVal(name.toTermName + AnfTransform.this.name.bindSuffix, gen.mkAttributedStableRef(b.symbol), b.pos) + (vd, (b.symbol, vd.symbol)) + }).unzip + val (from, to) = mappings.unzip + val b@Block(stats1, expr1) = block.substituteSymbols(from, to).asInstanceOf[Block] + val newBlock = treeCopy.Block(b, valDefs ++ stats1, expr1) + treeCopy.CaseDef(tree, pat, guard, newBlock) + } + // Refer to comments the translation of `If` above. + val typedMatch = treeCopy.Match(tree, scrutExpr, caseDefs) + scrutStats :+ typedMatch + + case LabelDef(name, params, rhs) => + List(LabelDef(name, params, Block(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol)) + + case TypeApply(fun, targs) => + val funStats :+ simpleFun = linearize.transformToList(fun) + funStats :+ treeCopy.TypeApply(tree, simpleFun, targs) + + case _ => + List(tree) + } } } } diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index 35d3687..5f577cf 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -7,6 +7,9 @@ package scala.async import scala.language.experimental.macros import scala.reflect.macros.Context import scala.reflect.internal.annotations.compileTimeOnly +import scala.tools.nsc.Global +import language.reflectiveCalls +import scala.concurrent.ExecutionContext object Async extends AsyncBase { @@ -15,18 +18,22 @@ object Async extends AsyncBase { lazy val futureSystem = ScalaConcurrentFutureSystem type FS = ScalaConcurrentFutureSystem.type - def async[T](body: T) = macro asyncImpl[T] + def async[T](body: T)(implicit execContext: ExecutionContext): Future[T] = macro asyncImpl[T] - override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body) + override def asyncImpl[T: c.WeakTypeTag](c: Context) + (body: c.Expr[T]) + (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[Future[T]] = { + super.asyncImpl[T](c)(body)(execContext) + } } object AsyncId extends AsyncBase { lazy val futureSystem = IdentityFutureSystem type FS = IdentityFutureSystem.type - def async[T](body: T) = macro asyncImpl[T] + def async[T](body: T) = macro asyncIdImpl[T] - override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = super.asyncImpl[T](c)(body) + def asyncIdImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = asyncImpl[T](c)(body)(c.literalUnit) } /** @@ -62,124 +69,26 @@ abstract class AsyncBase { protected[async] def fallbackEnabled = false - def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { + def asyncImpl[T: c.WeakTypeTag](c: Context) + (body: c.Expr[T]) + (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = { import c.universe._ - val analyzer = AsyncAnalysis[c.type](c, this) - val utils = TransformUtils[c.type](c) - import utils.{name, defn} - - analyzer.reportUnsupportedAwaits(body.tree) - - // Transform to A-normal form: - // - no await calls in qualifiers or arguments, - // - if/match only used in statement position. - val anfTree: Block = { - val anf = AnfTransform[c.type](c) - val restored = utils.restorePatternMatchingFunctions(body.tree) - val stats1 :+ expr1 = anf(restored) - val block = Block(stats1, expr1) - c.typeCheck(block).asInstanceOf[Block] - } - - // Analyze the block to find locals that will be accessed from multiple - // states of our generated state machine, e.g. a value assigned before - // an `await` and read afterwards. - val renameMap: Map[Symbol, TermName] = { - analyzer.defTreesUsedInSubsequentStates(anfTree).map { - vd => - (vd.symbol, name.fresh(vd.name.toTermName)) - }.toMap - } - - val builder = ExprBuilder[c.type, futureSystem.type](c, self.futureSystem, anfTree) - import builder.futureSystemOps - val asyncBlock: builder.AsyncBlock = builder.build(anfTree, renameMap) - import asyncBlock.asyncStates - logDiagnostics(c)(anfTree, asyncStates.map(_.toString)) - - // Important to retain the original declaration order here! - val localVarTrees = anfTree.collect { - case vd@ValDef(_, _, tpt, _) if renameMap contains vd.symbol => - utils.mkVarDefTree(tpt.tpe, renameMap(vd.symbol)) - case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) if renameMap contains dd.symbol => - DefDef(mods, renameMap(dd.symbol), tparams, vparamss, tpt, c.resetAllAttrs(utils.substituteNames(rhs, renameMap))) - } - - val onCompleteHandler = { - Function( - List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)), - asyncBlock.onCompleteHandler) - } - val resumeFunTree = asyncBlock.resumeFunTree[T] - - val stateMachineType = utils.applied("scala.async.StateMachine", List(futureSystemOps.promType[T], futureSystemOps.execContextType)) - - lazy val stateMachine: ClassDef = { - val body: List[Tree] = { - val stateVar = ValDef(Modifiers(Flag.MUTABLE), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0))) - val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree) - val execContext = ValDef(NoMods, name.execContext, TypeTree(), futureSystemOps.execContext.tree) - val applyDefDef: DefDef = { - val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree))) - val applyBody = asyncBlock.onCompleteHandler - DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), applyBody) - } - val apply0DefDef: DefDef = { - // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`. - // See SI-1247 for the the optimization that avoids creatio - val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree))) - val applyBody = asyncBlock.onCompleteHandler - DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil)) - } - List(utils.emptyConstructor, stateVar, result, execContext) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef) - } - val template = { - Template(List(stateMachineType), emptyValDef, body) - } - ClassDef(NoMods, name.stateMachineT, Nil, template) - } - - def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection) - - val code: c.Expr[futureSystem.Fut[T]] = { - val isSimple = asyncStates.size == 1 - val tree = - if (isSimple) - Block(Nil, futureSystemOps.spawn(body.tree)) // generate lean code for the simple case of `async { 1 + 1 }` - else { - Block(List[Tree]( - stateMachine, - ValDef(NoMods, name.stateMachine, stateMachineType, Apply(Select(New(Ident(name.stateMachineT)), nme.CONSTRUCTOR), Nil)), - futureSystemOps.spawn(Apply(selectStateMachine(name.apply), Nil)) - ), - futureSystemOps.promiseToFuture(c.Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree) - } - c.Expr[futureSystem.Fut[T]](tree) - } - - AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code.tree}") - code - } + val asyncMacro = AsyncMacro(c, futureSystem) + + val code = asyncMacro.asyncTransform[T]( + body.tree.asInstanceOf[asyncMacro.global.Tree], + execContext.tree.asInstanceOf[asyncMacro.global.Tree], + fallbackEnabled)(implicitly[c.WeakTypeTag[T]].asInstanceOf[asyncMacro.global.WeakTypeTag[T]]) - def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) { - def location = try { - c.macroApplication.pos.source.path - } catch { - case _: UnsupportedOperationException => - c.macroApplication.pos.toString - } - - AsyncUtils.vprintln(s"In file '$location':") - AsyncUtils.vprintln(s"${c.macroApplication}") - AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree") - states foreach (s => AsyncUtils.vprintln(s)) + AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code}") + c.Expr[futureSystem.Fut[T]](code.asInstanceOf[Tree]) } } /** Internal class used by the `async` macro; should not be manually extended by client code */ abstract class StateMachine[Result, EC] extends (scala.util.Try[Any] => Unit) with (() => Unit) { - def result$async: Result + def result: Result - def execContext$async: EC + def execContext: EC } diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala index 4f55f1b..424318e 100644 --- a/src/main/scala/scala/async/AsyncAnalysis.scala +++ b/src/main/scala/scala/async/AsyncAnalysis.scala @@ -7,12 +7,10 @@ package scala.async import scala.reflect.macros.Context import scala.collection.mutable -private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: AsyncBase) { - import c.universe._ +trait AsyncAnalysis { + self: AsyncMacro => - val utils = TransformUtils[c.type](c) - - import utils._ + import global._ /** * Analyze the contents of an `async` block in order to: @@ -20,47 +18,26 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: Asy * * Must be called on the original tree, not on the ANF transformed tree. */ - def reportUnsupportedAwaits(tree: Tree): Boolean = { - val analyzer = new UnsupportedAwaitAnalyzer + def reportUnsupportedAwaits(tree: Tree, report: Boolean): Boolean = { + val analyzer = new UnsupportedAwaitAnalyzer(report) analyzer.traverse(tree) analyzer.hasUnsupportedAwaits } - /** - * Analyze the contents of an `async` block in order to: - * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based - * on whether or not they are accessed only from a single state. - * - * Must be called on the ANF transformed tree. - */ - def defTreesUsedInSubsequentStates(tree: Tree): List[DefTree] = { - val analyzer = new AsyncDefinitionUseAnalyzer - analyzer.traverse(tree) - val liftable: List[DefTree] = (analyzer.valDefsToLift ++ analyzer.nestedMethodsToLift).toList.distinct - liftable - } - - private class UnsupportedAwaitAnalyzer extends AsyncTraverser { + private class UnsupportedAwaitAnalyzer(report: Boolean) extends AsyncTraverser { var hasUnsupportedAwaits = false override def nestedClass(classDef: ClassDef) { - val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class" - if (!reportUnsupportedAwait(classDef, s"nested $kind")) { - // do not allow local class definitions, because of SI-5467 (specific to case classes, though) - if (classDef.symbol.asClass.isCaseClass) - c.error(classDef.pos, s"Local case class ${classDef.name.decoded} illegal within `async` block") - } + val kind = if (classDef.symbol.isTrait) "trait" else "class" + reportUnsupportedAwait(classDef, s"nested ${kind}") } override def nestedModule(module: ModuleDef) { - if (!reportUnsupportedAwait(module, "nested object")) { - // local object definitions lead to spurious type errors (because of resetAllAttrs?) - c.error(module.pos, s"Local object ${module.name.decoded} illegal within `async` block") - } + reportUnsupportedAwait(module, "nested object") } - override def nestedMethod(module: DefDef) { - reportUnsupportedAwait(module, "nested method") + override def nestedMethod(defDef: DefDef) { + reportUnsupportedAwait(defDef, "nested method") } override def byNameArgument(arg: Tree) { @@ -82,9 +59,10 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: Asy reportUnsupportedAwait(tree, "try/catch") super.traverse(tree) case Return(_) => - c.abort(tree.pos, "return is illegal within a async block") + abort(tree.pos, "return is illegal within a async block") case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) => - c.abort(tree.pos, "lazy vals are illegal within an async block") + // TODO lift this restriction + abort(tree.pos, "lazy vals are illegal within an async block") case _ => super.traverse(tree) } @@ -106,87 +84,8 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: Asy private def reportError(pos: Position, msg: String) { hasUnsupportedAwaits = true - if (!asyncBase.fallbackEnabled) - c.error(pos, msg) + if (report) + abort(pos, msg) } } - - private class AsyncDefinitionUseAnalyzer extends AsyncTraverser { - private var chunkId = 0 - - private def nextChunk() = chunkId += 1 - - private var valDefChunkId = Map[Symbol, (ValDef, Int)]() - - val valDefsToLift : mutable.Set[ValDef] = collection.mutable.Set() - val nestedMethodsToLift: mutable.Set[DefDef] = collection.mutable.Set() - - override def nestedMethod(defDef: DefDef) { - nestedMethodsToLift += defDef - markReferencedVals(defDef) - } - - override def function(function: Function) { - markReferencedVals(function) - } - - override def patMatFunction(tree: Match) { - markReferencedVals(tree) - } - - private def markReferencedVals(tree: Tree) { - tree foreach { - case rt: RefTree => - valDefChunkId.get(rt.symbol) match { - case Some((vd, defChunkId)) => - valDefsToLift += vd // lift all vals referred to by nested functions. - case _ => - } - case _ => - } - } - - override def traverse(tree: Tree) = { - tree match { - case If(cond, thenp, elsep) if tree exists isAwait => - traverseChunks(List(cond, thenp, elsep)) - case Match(selector, cases) if tree exists isAwait => - traverseChunks(selector :: cases) - case LabelDef(name, params, rhs) if rhs exists isAwait => - traverseChunks(rhs :: Nil) - case Apply(fun, args) if isAwait(fun) => - super.traverse(tree) - nextChunk() - case vd: ValDef => - super.traverse(tree) - valDefChunkId += (vd.symbol -> (vd -> chunkId)) - val isPatternBinder = vd.name.toString.contains(name.bindSuffix) - if (isAwait(vd.rhs) || isPatternBinder) valDefsToLift += vd - case as: Assign => - if (isAwait(as.rhs)) { - assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol) - - // TODO test the orElse case, try to remove the restriction. - val (vd, defBlockId) = valDefChunkId.getOrElse(as.lhs.symbol, c.abort(as.pos, s"await may only be assigned to a var/val defined in the async block. ${as.lhs} ${as.lhs.symbol}")) - valDefsToLift += vd - } - super.traverse(tree) - case rt: RefTree => - valDefChunkId.get(rt.symbol) match { - case Some((vd, defChunkId)) if defChunkId != chunkId => - valDefsToLift += vd - case _ => - } - super.traverse(tree) - case _ => super.traverse(tree) - } - } - - private def traverseChunks(trees: List[Tree]) { - trees.foreach { - t => traverse(t); nextChunk() - } - } - } - } diff --git a/src/main/scala/scala/async/AsyncMacro.scala b/src/main/scala/scala/async/AsyncMacro.scala new file mode 100644 index 0000000..8827351 --- /dev/null +++ b/src/main/scala/scala/async/AsyncMacro.scala @@ -0,0 +1,29 @@ +package scala.async + +import scala.tools.nsc.Global +import scala.tools.nsc.transform.TypingTransformers + +object AsyncMacro { + def apply(c: reflect.macros.Context, futureSystem0: FutureSystem): AsyncMacro = { + import language.reflectiveCalls + val powerContext = c.asInstanceOf[c.type {val universe: Global; val callsiteTyper: universe.analyzer.Typer}] + new AsyncMacro { + val global: powerContext.universe.type = powerContext.universe + val callSiteTyper: global.analyzer.Typer = powerContext.callsiteTyper + val futureSystem: futureSystem0.type = futureSystem0 + val futureSystemOps: futureSystem.Ops {val universe: global.type} = futureSystem0.mkOps(global) + val macroApplication: global.Tree = c.macroApplication.asInstanceOf[global.Tree] + } + } +} + +private[async] trait AsyncMacro + extends TypingTransformers + with AnfTransform with TransformUtils with Lifter + with ExprBuilder with AsyncTransform with AsyncAnalysis { + + val global: Global + val callSiteTyper: global.analyzer.Typer + val macroApplication: global.Tree + +} diff --git a/src/main/scala/scala/async/AsyncTransform.scala b/src/main/scala/scala/async/AsyncTransform.scala new file mode 100644 index 0000000..129f88e --- /dev/null +++ b/src/main/scala/scala/async/AsyncTransform.scala @@ -0,0 +1,176 @@ +package scala.async + +trait AsyncTransform { + self: AsyncMacro => + + import global._ + + def asyncTransform[T](body: Tree, execContext: Tree, cpsFallbackEnabled: Boolean) + (implicit resultType: WeakTypeTag[T]): Tree = { + + reportUnsupportedAwaits(body, report = !cpsFallbackEnabled) + + // Transform to A-normal form: + // - no await calls in qualifiers or arguments, + // - if/match only used in statement position. + val anfTree: Block = anfTransform(body) + + val resumeFunTreeDummyBody = DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), Literal(Constant(()))) + + val applyDefDefDummyBody: DefDef = { + val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree))) + DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), Literal(Constant(()))) + } + + val stateMachineType = applied("scala.async.StateMachine", List(futureSystemOps.promType[T], futureSystemOps.execContextType)) + + val stateMachine: ClassDef = { + val body: List[Tree] = { + val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0))) + val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree) + val execContextValDef = ValDef(NoMods, name.execContext, TypeTree(), execContext) + + val apply0DefDef: DefDef = { + // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`. + // See SI-1247 for the the optimization that avoids creatio + DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil)) + } + List(emptyConstructor, stateVar, result, execContextValDef) ++ List(resumeFunTreeDummyBody, applyDefDefDummyBody, apply0DefDef) + } + val template = { + Template(List(stateMachineType), emptyValDef, body) + } + val t = ClassDef(NoMods, name.stateMachineT, Nil, template) + callSiteTyper.typedPos(macroApplication.pos)(Block(t :: Nil, Literal(Constant(())))) + t + } + + val asyncBlock: AsyncBlock = { + val symLookup = new SymLookup(stateMachine.symbol, applyDefDefDummyBody.vparamss.head.head.symbol) + buildAsyncBlock(anfTree, symLookup) + } + + logDiagnostics(anfTree, asyncBlock.asyncStates.map(_.toString)) + + def startStateMachine: Tree = { + val stateMachineSpliced: Tree = spliceMethodBodies( + liftables(asyncBlock.asyncStates), + stateMachine, + asyncBlock.onCompleteHandler[T], + asyncBlock.resumeFunTree[T].rhs + ) + + def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection) + + Block(List[Tree]( + stateMachineSpliced, + ValDef(NoMods, name.stateMachine, stateMachineType, Apply(Select(New(Ident(stateMachine.symbol)), nme.CONSTRUCTOR), Nil)), + futureSystemOps.spawn(Apply(selectStateMachine(name.apply), Nil), selectStateMachine(name.execContext)) + ), + futureSystemOps.promiseToFuture(Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree) + } + + val isSimple = asyncBlock.asyncStates.size == 1 + if (isSimple) + futureSystemOps.spawn(body, execContext) // generate lean code for the simple case of `async { 1 + 1 }` + else + startStateMachine + } + + def logDiagnostics(anfTree: Tree, states: Seq[String]) { + val pos = macroApplication.pos + def location = try { + pos.source.path + } catch { + case _: UnsupportedOperationException => + pos.toString + } + + AsyncUtils.vprintln(s"In file '$location':") + AsyncUtils.vprintln(s"${macroApplication}") + AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree") + states foreach (s => AsyncUtils.vprintln(s)) + } + + def spliceMethodBodies(liftables: List[Tree], tree: Tree, applyBody: Tree, + resumeBody: Tree): Tree = { + + val liftedSyms = liftables.map(_.symbol).toSet + val stateMachineClass = tree.symbol + liftedSyms.foreach { + sym => + if (sym != null) { + sym.owner = stateMachineClass + if (sym.isModule) + sym.moduleClass.owner = stateMachineClass + } + } + // Replace the ValDefs in the splicee with Assigns to the corresponding lifted + // fields. Similarly, replace references to them with references to the field. + // + // This transform will be only be run on the RHS of `def foo`. + class UseFields extends MacroTypingTransformer { + override def transform(tree: Tree): Tree = tree match { + case _ if currentOwner == stateMachineClass => + super.transform(tree) + case ValDef(_, _, _, rhs) if liftedSyms(tree.symbol) => + atOwner(currentOwner) { + val fieldSym = tree.symbol + val set = Assign(gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym), transform(rhs)) + changeOwner(set, tree.symbol, currentOwner) + localTyper.typedPos(tree.pos)(set) + } + case _: DefTree if liftedSyms(tree.symbol) => + EmptyTree + case Ident(name) if liftedSyms(tree.symbol) => + val fieldSym = tree.symbol + gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym).setType(tree.tpe) + case _ => + super.transform(tree) + } + } + + val liftablesUseFields = liftables.map { + case vd: ValDef => vd + case x => + val useField = new UseFields() + //.substituteSymbols(fromSyms, toSyms) + useField.atOwner(stateMachineClass)(useField.transform(x)) + } + + tree.children.foreach { + t => + new ChangeOwnerAndModuleClassTraverser(callSiteTyper.context.owner, tree.symbol).traverse(t) + } + val treeSubst = tree + + def fixup(dd: DefDef, body: Tree, ctx: analyzer.Context): Tree = { + val spliceeAnfFixedOwnerSyms = body + val useField = new UseFields() + val newRhs = useField.atOwner(dd.symbol)(useField.transform(spliceeAnfFixedOwnerSyms)) + val typer = global.analyzer.newTyper(ctx.make(dd, dd.symbol)) + treeCopy.DefDef(dd, dd.mods, dd.name, dd.tparams, dd.vparamss, dd.tpt, typer.typed(newRhs)) + } + + liftablesUseFields.foreach(t => if (t.symbol != null) stateMachineClass.info.decls.enter(t.symbol)) + + val result0 = transformAt(treeSubst) { + case t@Template(parents, self, stats) => + (ctx: analyzer.Context) => { + treeCopy.Template(t, parents, self, liftablesUseFields ++ stats) + } + } + val result = transformAt(result0) { + case dd@DefDef(_, name.apply, _, List(List(_)), _, _) if dd.symbol.owner == stateMachineClass => + (ctx: analyzer.Context) => + val typedTree = fixup(dd, changeOwner(applyBody, callSiteTyper.context.owner, dd.symbol), ctx) + typedTree + case dd@DefDef(_, name.resume, _, _, _, _) if dd.symbol.owner == stateMachineClass => + (ctx: analyzer.Context) => + val changed = changeOwner(resumeBody, callSiteTyper.context.owner, dd.symbol) + val res = fixup(dd, changed, ctx) + res + } + result + } +} diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index ca46a83..a3837d3 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -7,17 +7,17 @@ import scala.reflect.macros.Context import scala.collection.mutable.ListBuffer import collection.mutable import language.existentials +import scala.reflect.api.Universe +import scala.reflect.api -private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: C, futureSystem: FS, origTree: C#Tree) { - builder => +trait ExprBuilder { + builder: AsyncMacro => - val utils = TransformUtils[c.type](c) - - import c.universe._ - import utils._ + import global._ import defn._ - lazy val futureSystemOps = futureSystem.mkOps(c) + val futureSystem: FutureSystem + val futureSystemOps: futureSystem.Ops { val universe: global.type } val stateAssigner = new StateAssigner val labelDefStates = collection.mutable.Map[Symbol, Int]() @@ -27,22 +27,27 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: def mkHandlerCaseForState: CaseDef - def mkOnCompleteHandler[T: c.WeakTypeTag]: Option[CaseDef] = None + def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = None def stats: List[Tree] - final def body: c.Tree = stats match { + final def allStats: List[Tree] = this match { + case a: AsyncStateWithAwait => stats :+ a.awaitable.resultValDef + case _ => stats + } + + final def body: Tree = stats match { case stat :: Nil => stat case init :+ last => Block(init, last) } } /** A sequence of statements the concludes with a unconditional transition to `nextState` */ - final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int) + final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int, symLookup: SymLookup) extends AsyncState { def mkHandlerCaseForState: CaseDef = - mkHandlerCase(state, stats :+ mkStateTree(nextState) :+ mkResumeApply) + mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup) :+ mkResumeApply(symLookup)) override val toString: String = s"AsyncState #$state, next = $nextState" @@ -51,7 +56,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: /** A sequence of statements with a conditional transition to the next state, which will represent * a branch of an `if` or a `match`. */ - final class AsyncStateWithoutAwait(val stats: List[c.Tree], val state: Int) extends AsyncState { + final class AsyncStateWithoutAwait(val stats: List[Tree], val state: Int) extends AsyncState { override def mkHandlerCaseForState: CaseDef = mkHandlerCase(state, stats) @@ -62,25 +67,25 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: /** A sequence of statements that concludes with an `await` call. The `onComplete` * handler will unconditionally transition to `nestState`.`` */ - final class AsyncStateWithAwait(val stats: List[c.Tree], val state: Int, nextState: Int, - awaitable: Awaitable) + final class AsyncStateWithAwait(val stats: List[Tree], val state: Int, nextState: Int, + val awaitable: Awaitable, symLookup: SymLookup) extends AsyncState { override def mkHandlerCaseForState: CaseDef = { - val callOnComplete = futureSystemOps.onComplete(c.Expr(awaitable.expr), - c.Expr(This(tpnme.EMPTY)), c.Expr(Ident(name.execContext))).tree + val callOnComplete = futureSystemOps.onComplete(Expr(awaitable.expr), + Expr(This(tpnme.EMPTY)), Expr(Ident(name.execContext))).tree mkHandlerCase(state, stats :+ callOnComplete) } - override def mkOnCompleteHandler[T: c.WeakTypeTag]: Option[CaseDef] = { + override def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = { val tryGetTree = Assign( Ident(awaitable.resultName), - TypeApply(Select(Select(Ident(name.tr), Try_get), newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType))) + TypeApply(Select(Select(Ident(symLookup.applyTrParam), Try_get), newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType))) ) /* if (tr.isFailure) - * result$async.complete(tr.asInstanceOf[Try[T]]) + * result.complete(tr.asInstanceOf[Try[T]]) * else { * = tr.get.asInstanceOf[] * @@ -88,13 +93,13 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: * } */ val ifIsFailureTree = - If(Select(Ident(name.tr), Try_isFailure), + If(Select(Ident(symLookup.applyTrParam), Try_isFailure), futureSystemOps.completeProm[T]( - c.Expr[futureSystem.Prom[T]](Ident(name.result)), - c.Expr[scala.util.Try[T]]( - TypeApply(Select(Ident(name.tr), newTermName("asInstanceOf")), + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), + Expr[scala.util.Try[T]]( + TypeApply(Select(Ident(symLookup.applyTrParam), newTermName("asInstanceOf")), List(TypeTree(weakTypeOf[scala.util.Try[T]]))))).tree, - Block(List(tryGetTree, mkStateTree(nextState)), mkResumeApply) + Block(List(tryGetTree, mkStateTree(nextState, symLookup)), mkResumeApply(symLookup)) ) Some(mkHandlerCase(state, List(ifIsFailureTree))) @@ -107,19 +112,16 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: /* * Builder for a single state of an async method. */ - final class AsyncStateBuilder(state: Int, private val nameMap: Map[Symbol, c.Name]) { + final class AsyncStateBuilder(state: Int, private val symLookup: SymLookup) { /* Statements preceding an await call. */ - private val stats = ListBuffer[c.Tree]() + private val stats = ListBuffer[Tree]() /** The state of the target of a LabelDef application (while loop jump) */ private var nextJumpState: Option[Int] = None - private def renameReset(tree: Tree) = resetInternalAttrs(substituteNames(tree, nameMap)) - - def +=(stat: c.Tree): this.type = { + def +=(stat: Tree): this.type = { assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat") - def addStat() = stats += renameReset(stat) + def addStat() = stats += stat stat match { - case _: DefDef => // these have been lifted. case Apply(fun, Nil) => labelDefStates get fun.symbol match { case Some(nextState) => nextJumpState = Some(nextState) @@ -132,22 +134,18 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: def resultWithAwait(awaitable: Awaitable, nextState: Int): AsyncState = { - val sanitizedAwaitable = awaitable.copy(expr = renameReset(awaitable.expr)) val effectiveNextState = nextJumpState.getOrElse(nextState) - new AsyncStateWithAwait(stats.toList, state, effectiveNextState, sanitizedAwaitable) + new AsyncStateWithAwait(stats.toList, state, effectiveNextState, awaitable, symLookup) } def resultSimple(nextState: Int): AsyncState = { val effectiveNextState = nextJumpState.getOrElse(nextState) - new SimpleAsyncState(stats.toList, state, effectiveNextState) + new SimpleAsyncState(stats.toList, state, effectiveNextState, symLookup) } - def resultWithIf(condTree: c.Tree, thenState: Int, elseState: Int): AsyncState = { - // 1. build changed if-else tree - // 2. insert that tree at the end of the current state - val cond = renameReset(condTree) - def mkBranch(state: Int) = Block(mkStateTree(state) :: Nil, mkResumeApply) - this += If(cond, mkBranch(thenState), mkBranch(elseState)) + def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = { + def mkBranch(state: Int) = Block(mkStateTree(state, symLookup) :: Nil, mkResumeApply(symLookup)) + this += If(condTree, mkBranch(thenState), mkBranch(elseState)) new AsyncStateWithoutAwait(stats.toList, state) } @@ -161,23 +159,20 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: * @param caseStates starting state of the right-hand side of the each case * @return an `AsyncState` representing the match expression */ - def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], caseStates: List[Int]): AsyncState = { + def resultWithMatch(scrutTree: Tree, cases: List[CaseDef], caseStates: List[Int], symLookup: SymLookup): AsyncState = { // 1. build list of changed cases val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match { case CaseDef(pat, guard, rhs) => - val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal).map { - case ValDef(_, name, _, rhs) => Assign(Ident(name), rhs) - case t => sys.error(s"Unexpected tree. Expected ValDef, found: $t") - } - CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num)), mkResumeApply)) + val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal) + CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num), symLookup), mkResumeApply(symLookup))) } // 2. insert changed match tree at the end of the current state - this += Match(renameReset(scrutTree), newCases) + this += Match(scrutTree, newCases) new AsyncStateWithoutAwait(stats.toList, state) } - def resultWithLabel(startLabelState: Int): AsyncState = { - this += Block(mkStateTree(startLabelState) :: Nil, mkResumeApply) + def resultWithLabel(startLabelState: Int, symLookup: SymLookup): AsyncState = { + this += Block(mkStateTree(startLabelState, symLookup) :: Nil, mkResumeApply(symLookup)) new AsyncStateWithoutAwait(stats.toList, state) } @@ -194,24 +189,23 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: * @param expr the last expression of the block * @param startState the start state * @param endState the state to continue with - * @param toRename a `Map` for renaming the given key symbols to the mangled value names */ - final private class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, - private val toRename: Map[Symbol, c.Name]) { + final private class AsyncBlockBuilder(stats: List[Tree], expr: Tree, startState: Int, endState: Int, + private val symLookup: SymLookup) { val asyncStates = ListBuffer[AsyncState]() - var stateBuilder = new AsyncStateBuilder(startState, toRename) + var stateBuilder = new AsyncStateBuilder(startState, symLookup) var currState = startState /* TODO Fall back to CPS plug-in if tree contains an `await` call. */ - def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists { + def checkForUnsupportedAwait(tree: Tree) = if (tree exists { case Apply(fun, _) if isAwait(fun) => true case _ => false - }) c.abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException + }) abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException def nestedBlockBuilder(nestedTree: Tree, startState: Int, endState: Int) = { val (nestedStats, nestedExpr) = statsAndExpr(nestedTree) - new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, toRename) + new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, symLookup) } import stateAssigner.nextState @@ -219,16 +213,12 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: // populate asyncStates for (stat <- stats) stat match { // the val name = await(..) pattern - case ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => + case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => val afterAwaitState = nextState() - val awaitable = Awaitable(arg, toRename(stat.symbol).toTermName, tpt.tpe) + val awaitable = Awaitable(arg, stat.symbol, tpt.tpe, vd) asyncStates += stateBuilder.resultWithAwait(awaitable, afterAwaitState) // complete with await currState = afterAwaitState - stateBuilder = new AsyncStateBuilder(currState, toRename) - - case ValDef(mods, name, tpt, rhs) if toRename contains stat.symbol => - checkForUnsupportedAwait(rhs) - stateBuilder += Assign(Ident(toRename(stat.symbol).toTermName), rhs) + stateBuilder = new AsyncStateBuilder(currState, symLookup) case If(cond, thenp, elsep) if stat exists isAwait => checkForUnsupportedAwait(cond) @@ -248,7 +238,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: } currState = afterIfState - stateBuilder = new AsyncStateBuilder(currState, toRename) + stateBuilder = new AsyncStateBuilder(currState, symLookup) case Match(scrutinee, cases) if stat exists isAwait => checkForUnsupportedAwait(scrutinee) @@ -257,7 +247,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: val afterMatchState = nextState() asyncStates += - stateBuilder.resultWithMatch(scrutinee, cases, caseStates) + stateBuilder.resultWithMatch(scrutinee, cases, caseStates, symLookup) for ((cas, num) <- cases.zipWithIndex) { val (stats, expr) = statsAndExpr(cas.body) @@ -267,18 +257,18 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: } currState = afterMatchState - stateBuilder = new AsyncStateBuilder(currState, toRename) + stateBuilder = new AsyncStateBuilder(currState, symLookup) case ld@LabelDef(name, params, rhs) if rhs exists isAwait => val startLabelState = nextState() val afterLabelState = nextState() - asyncStates += stateBuilder.resultWithLabel(startLabelState) + asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup) labelDefStates(ld.symbol) = startLabelState val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState) asyncStates ++= builder.asyncStates currState = afterLabelState - stateBuilder = new AsyncStateBuilder(currState, toRename) + stateBuilder = new AsyncStateBuilder(currState, symLookup) case _ => checkForUnsupportedAwait(stat) stateBuilder += stat @@ -292,17 +282,23 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: trait AsyncBlock { def asyncStates: List[AsyncState] - def onCompleteHandler[T: c.WeakTypeTag]: Tree + def onCompleteHandler[T: WeakTypeTag]: Tree + + def resumeFunTree[T]: DefDef + } - def resumeFunTree[T]: Tree + case class SymLookup(stateMachineClass: Symbol, applyTrParam: Symbol) { + def stateMachineMember(name: TermName): Symbol = + stateMachineClass.info.member(name) + def memberRef(name: TermName) = gen.mkAttributedRef(stateMachineMember(name)) } - def build(block: Block, toRename: Map[Symbol, c.Name]): AsyncBlock = { + def buildAsyncBlock(block: Block, symLookup: SymLookup): AsyncBlock = { val Block(stats, expr) = block val startState = stateAssigner.nextState() val endState = Int.MaxValue - val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, toRename) + val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, symLookup) new AsyncBlock { def asyncStates = blockBuilder.asyncStates.toList @@ -310,9 +306,9 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: def mkCombinedHandlerCases[T]: List[CaseDef] = { val caseForLastState: CaseDef = { val lastState = asyncStates.last - val lastStateBody = c.Expr[T](lastState.body) + val lastStateBody = Expr[T](lastState.body) val rhs = futureSystemOps.completeProm( - c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Success(lastStateBody.splice))) + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), reify(scala.util.Success(lastStateBody.splice))) mkHandlerCase(lastState.state, rhs.tree) } asyncStates.toList match { @@ -326,18 +322,6 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: val initStates = asyncStates.init - /** - * // assumes tr: Try[Any] is in scope. - * // - * state match { - * case 0 => { - * x11 = tr.get.asInstanceOf[Double]; - * state = 1; - * resume() - * } - */ - def onCompleteHandler[T: c.WeakTypeTag]: Tree = Match(Ident(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList) - /** * def resume(): Unit = { * try { @@ -353,18 +337,31 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: * } * } */ - def resumeFunTree[T]: Tree = + def resumeFunTree[T]: DefDef = DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), Try( - Match(Ident(name.state), mkCombinedHandlerCases[T]), + Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T]), List( CaseDef( - Apply(Ident(defn.NonFatalClass), List(Bind(name.tr, Ident(nme.WILDCARD)))), - EmptyTree, + Bind(name.t, Ident(nme.WILDCARD)), + Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), Block(List({ - val t = c.Expr[Throwable](Ident(name.tr)) - futureSystemOps.completeProm[T](c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Failure(t.splice))).tree - }), c.literalUnit.tree))), EmptyTree)) + val t = Expr[Throwable](Ident(name.t)) + futureSystemOps.completeProm[T]( + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), reify(scala.util.Failure(t.splice))).tree + }), literalUnit))), EmptyTree)) + + /** + * // assumes tr: Try[Any] is in scope. + * // + * state match { + * case 0 => { + * x11 = tr.get.asInstanceOf[Double]; + * state = 1; + * resume() + * } + */ + def onCompleteHandler[T: WeakTypeTag]: Tree = Match(symLookup.memberRef(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList) } } @@ -373,22 +370,18 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: case _ => false } - private final case class Awaitable(expr: Tree, resultName: TermName, resultType: Type) - - private val internalSyms = origTree.collect { - case dt: DefTree => dt.symbol - } + case class Awaitable(expr: Tree, resultName: Symbol, resultType: Type, resultValDef: ValDef) - private def resetInternalAttrs(tree: Tree) = utils.resetInternalAttrs(tree, internalSyms) + private def mkResumeApply(symLookup: SymLookup) = Apply(symLookup.memberRef(name.resume), Nil) - private def mkResumeApply = Apply(Ident(name.resume), Nil) + private def mkStateTree(nextState: Int, symLookup: SymLookup): Tree = + Assign(symLookup.memberRef(name.state), Literal(Constant(nextState))) - private def mkStateTree(nextState: Int): c.Tree = - Assign(Ident(name.state), c.literal(nextState).tree) + private def mkHandlerCase(num: Int, rhs: List[Tree]): CaseDef = + mkHandlerCase(num, Block(rhs, literalUnit)) - private def mkHandlerCase(num: Int, rhs: List[c.Tree]): CaseDef = - mkHandlerCase(num, Block(rhs, c.literalUnit.tree)) + private def mkHandlerCase(num: Int, rhs: Tree): CaseDef = + CaseDef(Literal(Constant(num)), EmptyTree, rhs) - private def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef = - CaseDef(c.literal(num).tree, EmptyTree, rhs) + private def literalUnit = Literal(Constant(())) } diff --git a/src/main/scala/scala/async/FutureSystem.scala b/src/main/scala/scala/async/FutureSystem.scala index a050bec..0c04296 100644 --- a/src/main/scala/scala/async/FutureSystem.scala +++ b/src/main/scala/scala/async/FutureSystem.scala @@ -6,6 +6,7 @@ package scala.async import scala.language.higherKinds import scala.reflect.macros.Context +import scala.reflect.internal.SymbolTable /** * An abstraction over a future system. @@ -26,12 +27,10 @@ trait FutureSystem { type ExecContext trait Ops { - val context: reflect.macros.Context + val universe: reflect.internal.SymbolTable - import context.universe._ - - /** Lookup the execution context, typically with an implicit search */ - def execContext: Expr[ExecContext] + import universe._ + def Expr[T: WeakTypeTag](tree: Tree): Expr[T] = universe.Expr[T](rootMirror, universe.FixedMirrorTreeCreator(rootMirror, tree)) def promType[A: WeakTypeTag]: Type def execContextType: Type @@ -52,15 +51,17 @@ trait FutureSystem { /** Complete a promise with a value */ def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] - def spawn(tree: context.Tree): context.Tree = - future(context.Expr[Unit](tree))(execContext).tree + def spawn(tree: Tree, execContext: Tree): Tree = + future(Expr[Unit](tree))(Expr[ExecContext](execContext)).tree + // TODO Why is this needed? def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] } - def mkOps(c: Context): Ops { val context: c.type } + def mkOps(c: SymbolTable): Ops { val universe: c.type } } + object ScalaConcurrentFutureSystem extends FutureSystem { import scala.concurrent._ @@ -69,18 +70,13 @@ object ScalaConcurrentFutureSystem extends FutureSystem { type Fut[A] = Future[A] type ExecContext = ExecutionContext - def mkOps(c: Context): Ops {val context: c.type} = new Ops { - val context: c.type = c - - import context.universe._ + def mkOps(c: SymbolTable): Ops {val universe: c.type} = new Ops { + val universe: c.type = c - def execContext: Expr[ExecContext] = c.Expr(c.inferImplicitValue(c.weakTypeOf[ExecutionContext]) match { - case EmptyTree => c.abort(c.macroApplication.pos, "Unable to resolve implicit ExecutionContext") - case context => context - }) + import universe._ - def promType[A: WeakTypeTag]: Type = c.weakTypeOf[Promise[A]] - def execContextType: Type = c.weakTypeOf[ExecutionContext] + def promType[A: WeakTypeTag]: Type = weakTypeOf[Promise[A]] + def execContextType: Type = weakTypeOf[ExecutionContext] def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { Promise[A]() @@ -101,7 +97,7 @@ object ScalaConcurrentFutureSystem extends FutureSystem { def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify { prom.splice.complete(value.splice) - context.literalUnit.splice + Expr[Unit](Literal(Constant(()))).splice } def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = reify { @@ -121,15 +117,15 @@ object IdentityFutureSystem extends FutureSystem { type Fut[A] = A type ExecContext = Unit - def mkOps(c: Context): Ops {val context: c.type} = new Ops { - val context: c.type = c + def mkOps(c: SymbolTable): Ops {val universe: c.type} = new Ops { + val universe: c.type = c - import context.universe._ + import universe._ - def execContext: Expr[ExecContext] = c.literalUnit + def execContext: Expr[ExecContext] = Expr[Unit](Literal(Constant(()))) - def promType[A: WeakTypeTag]: Type = c.weakTypeOf[Prom[A]] - def execContextType: Type = c.weakTypeOf[Unit] + def promType[A: WeakTypeTag]: Type = weakTypeOf[Prom[A]] + def execContextType: Type = weakTypeOf[Unit] def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { new Prom(null.asInstanceOf[A]) @@ -144,12 +140,12 @@ object IdentityFutureSystem extends FutureSystem { def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U], execContext: Expr[ExecContext]): Expr[Unit] = reify { fun.splice.apply(util.Success(future.splice)) - context.literalUnit.splice + Expr[Unit](Literal(Constant(()))).splice } def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify { prom.splice.a = value.splice.get - context.literalUnit.splice + Expr[Unit](Literal(Constant(()))).splice } def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = ??? diff --git a/src/main/scala/scala/async/Lifter.scala b/src/main/scala/scala/async/Lifter.scala new file mode 100644 index 0000000..52ce47d --- /dev/null +++ b/src/main/scala/scala/async/Lifter.scala @@ -0,0 +1,150 @@ +package scala.async + +trait Lifter { + self: AsyncMacro => + import global._ + + /** + * Identify which DefTrees are used (including transitively) which are declared + * in some state but used (including transitively) in another state. + * + * These will need to be lifted to class members of the state machine. + */ + def liftables(asyncStates: List[AsyncState]): List[Tree] = { + object companionship { + private val companions = collection.mutable.Map[Symbol, Symbol]() + private val companionsInverse = collection.mutable.Map[Symbol, Symbol]() + private def record(sym1: Symbol, sym2: Symbol) { + companions(sym1) = sym2 + companions(sym2) = sym1 + } + + def record(defs: List[Tree]) { + // Keep note of local companions so we rename them consistently + // when lifting. + val comps = for { + cd@ClassDef(_, _, _, _) <- defs + md@ModuleDef(_, _, _) <- defs + if (cd.name.toTermName == md.name) + } record(cd.symbol, md.symbol) + } + def companionOf(sym: Symbol): Symbol = { + companions.get(sym).orElse(companionsInverse.get(sym)).getOrElse(NoSymbol) + } + } + + + val defs: Map[Tree, Int] = { + /** Collect the DefTrees directly enclosed within `t` that have the same owner */ + def collectDirectlyEnclosedDefs(t: Tree): List[DefTree] = t match { + case dt: DefTree => dt :: Nil + case _: Function => Nil + case t => + val childDefs = t.children.flatMap(collectDirectlyEnclosedDefs(_)) + companionship.record(childDefs) + childDefs + } + asyncStates.flatMap { + asyncState => + val defs = collectDirectlyEnclosedDefs(Block(asyncState.allStats: _*)) + defs.map((_, asyncState.state)) + }.toMap + } + + // In which block are these symbols defined? + val symToDefiningState: Map[Symbol, Int] = defs.map { + case (k, v) => (k.symbol, v) + } + + // The definitions trees + val symToTree: Map[Symbol, Tree] = defs.map { + case (k, v) => (k.symbol, k) + } + + // The direct references of each definition tree + val defSymToReferenced: Map[Symbol, List[Symbol]] = defs.keys.map { + case tree => (tree.symbol, tree.collect { + case rt: RefTree if symToDefiningState.contains(rt.symbol) => rt.symbol + }) + }.toMap + + // The direct references of each block, excluding references of `DefTree`-s which + // are already accounted for. + val stateIdToDirectlyReferenced: Map[Int, List[Symbol]] = { + val refs: List[(Int, Symbol)] = asyncStates.flatMap( + asyncState => asyncState.stats.filterNot(_.isDef).flatMap(_.collect { + case rt: RefTree if symToDefiningState.contains(rt.symbol) => (asyncState.state, rt.symbol) + }) + ) + toMultiMap(refs) + } + + def liftableSyms: Set[Symbol] = { + val liftableMutableSet = collection.mutable.Set[Symbol]() + def markForLift(sym: Symbol) { + if (!liftableMutableSet(sym)) { + liftableMutableSet += sym + + // Only mark transitive references of defs, modules and classes. The RHS of lifted vals/vars + // stays in its original location, so things that it refers to need not be lifted. + if (!(sym.isVal || sym.isVar)) + defSymToReferenced(sym).foreach(sym2 => markForLift(sym2)) + } + } + // Start things with DefTrees directly referenced from statements from other states... + val liftableStatementRefs: List[Symbol] = stateIdToDirectlyReferenced.toList.flatMap { + case (i, syms) => syms.filter(sym => symToDefiningState(sym) != i) + } + // .. and likewise for DefTrees directly referenced by other DefTrees from other states + val liftableRefsOfDefTrees = defSymToReferenced.toList.flatMap { + case (referee, referents) => referents.filter(sym => symToDefiningState(sym) != symToDefiningState(referee)) + } + // Mark these for lifting, which will follow transitive references. + (liftableStatementRefs ++ liftableRefsOfDefTrees).foreach(markForLift) + liftableMutableSet.toSet + } + + val lifted = liftableSyms.map(symToTree).toList.map { + case vd@ValDef(_, _, tpt, rhs) => + import reflect.internal.Flags._ + val sym = vd.symbol + sym.setFlag(MUTABLE | STABLE | PRIVATE | LOCAL) + sym.name = name.fresh(sym.name.toTermName) + sym.modifyInfo(_.deconst) + ValDef(vd.symbol, gen.mkZero(vd.symbol.info)).setPos(vd.pos) + case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) => + import reflect.internal.Flags._ + val sym = dd.symbol + sym.name = this.name.fresh(sym.name.toTermName) + sym.setFlag(PRIVATE | LOCAL) + DefDef(dd.symbol, rhs).setPos(dd.pos) + case cd@ClassDef(_, _, _, impl) => + import reflect.internal.Flags._ + val sym = cd.symbol + sym.name = newTypeName(name.fresh(sym.name.toString).toString) + companionship.companionOf(cd.symbol) match { + case NoSymbol => + case moduleSymbol => + moduleSymbol.name = sym.name.toTermName + moduleSymbol.moduleClass.name = moduleSymbol.name.toTypeName + } + ClassDef(cd.symbol, impl).setPos(cd.pos) + case md@ModuleDef(_, _, impl) => + import reflect.internal.Flags._ + val sym = md.symbol + companionship.companionOf(md.symbol) match { + case NoSymbol => + sym.name = name.fresh(sym.name.toTermName) + sym.moduleClass.name = sym.name.toTypeName + case classSymbol => // will be renamed by `case ClassDef` above. + } + ModuleDef(md.symbol, impl).setPos(md.pos) + case td@TypeDef(_, _, _, rhs) => + import reflect.internal.Flags._ + val sym = td.symbol + sym.name = newTypeName(name.fresh(sym.name.toString).toString) + TypeDef(td.symbol, rhs).setPos(td.pos) + } + lifted + } +} diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index ebd546f..33dd21d 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -5,115 +5,40 @@ package scala.async import scala.reflect.macros.Context import reflect.ClassTag +import scala.reflect.macros.runtime.AbortMacroException /** * Utilities used in both `ExprBuilder` and `AnfTransform`. */ -private[async] final case class TransformUtils[C <: Context](c: C) { +private[async] trait TransformUtils { + self: AsyncMacro => - import c.universe._ + import global._ object name { - def suffix(string: String) = string + "$async" - - def suffixedName(prefix: String) = newTermName(suffix(prefix)) - - val state = suffixedName("state") - val result = suffixedName("result") - val resume = suffixedName("resume") - val execContext = suffixedName("execContext") - val stateMachine = newTermName(fresh("stateMachine")) - val stateMachineT = stateMachine.toTypeName + val resume = newTermName("resume") val apply = newTermName("apply") - val applyOrElse = newTermName("applyOrElse") - val tr = newTermName("tr") val matchRes = "matchres" val ifRes = "ifres" val await = "await" val bindSuffix = "$bind" - def fresh(name: TermName): TermName = newTermName(fresh(name.toString)) + val state = newTermName("state") + val result = newTermName("result") + val execContext = newTermName("execContext") + val stateMachine = newTermName(fresh("stateMachine")) + val stateMachineT = stateMachine.toTypeName + val tr = newTermName("tr") + val t = newTermName("throwable") - def fresh(name: String): String = if (name.toString.contains("$")) name else c.fresh("" + name + "$") - } + def fresh(name: TermName): TermName = newTermName(fresh(name.toString)) - def defaultValue(tpe: Type): Literal = { - val defaultValue: Any = - if (tpe <:< definitions.BooleanTpe) false - else if (definitions.ScalaNumericValueClasses.exists(tpe <:< _.toType)) 0 - else if (tpe <:< definitions.AnyValTpe) 0 - else null - Literal(Constant(defaultValue)) + def fresh(name: String): String = if (name.toString.contains("$")) name else currentUnit.freshTermName("" + name + "$").toString } def isAwait(fun: Tree) = fun.symbol == defn.Async_await - /** Replace all `Ident` nodes referring to one of the keys n `renameMap` with a node - * referring to the corresponding new name - */ - def substituteNames(tree: Tree, renameMap: Map[Symbol, Name]): Tree = { - val renamer = new Transformer { - override def transform(tree: Tree) = tree match { - case Ident(_) => (renameMap get tree.symbol).fold(tree)(Ident(_)) - case tt: TypeTree if tt.original != EmptyTree && tt.original != null => - // We also have to apply our renaming transform on originals of TypeTrees. - // TODO 2.10.1 Can we find a cleaner way? - val symTab = c.universe.asInstanceOf[reflect.internal.SymbolTable] - val tt1 = tt.asInstanceOf[symTab.TypeTree] - tt1.setOriginal(transform(tt.original).asInstanceOf[symTab.Tree]) - super.transform(tree) - case _ => super.transform(tree) - } - } - renamer.transform(tree) - } - - /** Descends into the regions of the tree that are subject to the - * translation to a state machine by `async`. When a nested template, - * function, or by-name argument is encountered, the descent stops, - * and `nestedClass` etc are invoked. - */ - trait AsyncTraverser extends Traverser { - def nestedClass(classDef: ClassDef) { - } - - def nestedModule(module: ModuleDef) { - } - - def nestedMethod(module: DefDef) { - } - - def byNameArgument(arg: Tree) { - } - - def function(function: Function) { - } - - def patMatFunction(tree: Match) { - } - - override def traverse(tree: Tree) { - tree match { - case cd: ClassDef => nestedClass(cd) - case md: ModuleDef => nestedModule(md) - case dd: DefDef => nestedMethod(dd) - case fun: Function => function(fun) - case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions` - case Applied(fun, targs, argss) if argss.nonEmpty => - val isInByName = isByName(fun) - for ((args, i) <- argss.zipWithIndex) { - for ((arg, j) <- args.zipWithIndex) { - if (!isInByName(i, j)) traverse(arg) - else byNameArgument(arg) - } - } - traverse(fun) - case _ => super.traverse(tree) - } - } - } - private lazy val Boolean_ShortCircuits: Set[Symbol] = { import definitions.BooleanClass def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName) @@ -122,57 +47,30 @@ private[async] final case class TransformUtils[C <: Context](c: C) { Set(Boolean_&&, Boolean_||) } - def isByName(fun: Tree): ((Int, Int) => Boolean) = { + private def isByName(fun: Tree): ((Int, Int) => Boolean) = { if (Boolean_ShortCircuits contains fun.symbol) (i, j) => true else { - val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable] - val paramss = fun.tpe.asInstanceOf[symtab.Type].paramss + val paramss = fun.tpe.paramss val byNamess = paramss.map(_.map(_.isByNameParam)) (i, j) => util.Try(byNamess(i)(j)).getOrElse(false) } } - def argName(fun: Tree): ((Int, Int) => String) = { - val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable] - val paramss = fun.tpe.asInstanceOf[symtab.Type].paramss + private def argName(fun: Tree): ((Int, Int) => String) = { + val paramss = fun.tpe.paramss val namess = paramss.map(_.map(_.name.toString)) (i, j) => util.Try(namess(i)(j)).getOrElse(s"arg_${i}_${j}") } - object Applied { - val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] - object treeInfo extends { - val global: symtab.type = symtab - } with reflect.internal.TreeInfo - - def unapply(tree: Tree): Some[(Tree, List[Tree], List[List[Tree]])] = { - val treeInfo.Applied(core, targs, argss) = tree.asInstanceOf[symtab.Tree] - Some((core.asInstanceOf[Tree], targs.asInstanceOf[List[Tree]], argss.asInstanceOf[List[List[Tree]]])) - } - } - - def statsAndExpr(tree: Tree): (List[Tree], Tree) = tree match { - case Block(stats, expr) => (stats, expr) - case _ => (List(tree), Literal(Constant(()))) - } - - def mkVarDefTree(resultType: Type, resultName: TermName): c.Tree = { - ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), defaultValue(resultType)) - } - - def emptyConstructor: DefDef = { - val emptySuperCall = Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), Nil) - DefDef(NoMods, nme.CONSTRUCTOR, List(), List(List()), TypeTree(), Block(List(emptySuperCall), c.literalUnit.tree)) - } - - def applied(className: String, types: List[Type]): AppliedTypeTree = - AppliedTypeTree(Ident(c.mirror.staticClass(className)), types.map(TypeTree(_))) + def Expr[A: WeakTypeTag](t: Tree) = global.Expr[A](rootMirror, new FixedMirrorTreeCreator(rootMirror, t)) object defn { def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = { - c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree))) + Expr(Apply(Ident(definitions.List_apply), args.map(_.tree))) } - def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify(self.splice.contains(elem.splice)) + def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify { + self.splice.contains(elem.splice) + } def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify { self.splice.apply(arg.splice) @@ -186,146 +84,17 @@ private[async] final case class TransformUtils[C <: Context](c: C) { self.splice.get } - val Try_get = methodSym(reify((null: scala.util.Try[Any]).get)) - val Try_isFailure = methodSym(reify((null: scala.util.Try[Any]).isFailure)) - - val TryClass = c.mirror.staticClass("scala.util.Try") + val TryClass = rootMirror.staticClass("scala.util.Try") + val Try_get = TryClass.typeSignature.member(newTermName("get")).ensuring(_ != NoSymbol) + val Try_isFailure = TryClass.typeSignature.member(newTermName("isFailure")).ensuring(_ != NoSymbol) val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe)) - val NonFatalClass = c.mirror.staticModule("scala.util.control.NonFatal") - - private def asyncMember(name: String) = { - val asyncMod = c.mirror.staticClass("scala.async.AsyncBase") - val tpe = asyncMod.asType.toType - tpe.member(newTermName(name)).ensuring(_ != NoSymbol) - } - - val Async_await = asyncMember("await") - } - - /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */ - private def methodSym(apply: c.Expr[Any]): Symbol = { - val tree2: Tree = c.typeCheck(apply.tree) - tree2.collect { - case s: SymTree if s.symbol.isMethod => s.symbol - }.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}")) - } - - /** - * Using [[scala.reflect.api.Trees.TreeCopier]] copies more than we would like: - * we don't want to copy types and symbols to the new trees in some cases. - * - * Instead, we just copy positions and attachments. - */ - def attachCopy[T <: Tree](orig: Tree)(tree: T): tree.type = { - tree.setPos(orig.pos) - for (att <- orig.attachments.all) - tree.updateAttachment[Any](att)(ClassTag.apply[Any](att.getClass)) - tree - } - - def resetInternalAttrs(tree: Tree, internalSyms: List[Symbol]) = - new ResetInternalAttrs(internalSyms.toSet).transform(tree) - - /** - * Adaptation of [[scala.reflect.internal.Trees.ResetAttrs]] - * - * A transformer which resets symbol and tpe fields of all nodes in a given tree, - * with special treatment of: - * `TypeTree` nodes: are replaced by their original if it exists, otherwise tpe field is reset - * to empty if it started out empty or refers to local symbols (which are erased). - * `TypeApply` nodes: are deleted if type arguments end up reverted to empty - * - * `This` and `Ident` nodes referring to an external symbol are ''not'' reset. - */ - private final class ResetInternalAttrs(internalSyms: Set[Symbol]) extends Transformer { - - import language.existentials - - override def transform(tree: Tree): Tree = super.transform { - def isExternal = tree.symbol != NoSymbol && !internalSyms(tree.symbol) - - tree match { - case tpt: TypeTree => resetTypeTree(tpt) - case TypeApply(fn, args) - if args map transform exists (_.isEmpty) => transform(fn) - case EmptyTree => tree - case (_: Ident | _: This) if isExternal => tree // #35 Don't reset the symbol of Ident/This bound outside of the async block - case _ => resetTree(tree) - } - } - - private def resetTypeTree(tpt: TypeTree): Tree = { - if (tpt.original != null) - transform(tpt.original) - else if (tpt.tpe != null && tpt.asInstanceOf[symtab.TypeTree forSome {val symtab: reflect.internal.SymbolTable}].wasEmpty) { - val dupl = tpt.duplicate - dupl.tpe = null - dupl - } - else tpt - } - - private def resetTree(tree: Tree): Tree = { - val hasSymbol: Boolean = { - val reflectInternalTree = tree.asInstanceOf[symtab.Tree forSome {val symtab: reflect.internal.SymbolTable}] - reflectInternalTree.hasSymbol - } - val dupl = tree.duplicate - if (hasSymbol) - dupl.symbol = NoSymbol - dupl.tpe = null - dupl - } - } - - /** - * Replaces expressions of the form `{ new $anon extends PartialFunction[A, B] { ... ; def applyOrElse[..](...) = ... match }` - * with `Match(EmptyTree, cases`. - * - * This reverses the transformation performed in `Typers`, and works around non-idempotency of typechecking such trees. - */ - // TODO Reference JIRA issue. - final def restorePatternMatchingFunctions(tree: Tree) = - RestorePatternMatchingFunctions transform tree - - private object RestorePatternMatchingFunctions extends Transformer { - - import language.existentials - val DefaultCaseName: TermName = "defaultCase$" - - override def transform(tree: Tree): Tree = { - val SYNTHETIC = (1 << 21).toLong.asInstanceOf[FlagSet] - def isSynthetic(cd: ClassDef) = cd.mods hasFlag SYNTHETIC - - /** Is this pattern node a synthetic catch-all case, added during PartialFuction synthesis before we know - * whether the user provided cases are exhaustive. */ - def isSyntheticDefaultCase(cdef: CaseDef) = cdef match { - case CaseDef(Bind(DefaultCaseName, _), EmptyTree, _) => true - case _ => false - } - tree match { - case Block( - (cd@ClassDef(_, _, _, Template(_, _, body))) :: Nil, - Apply(Select(New(a), nme.CONSTRUCTOR), Nil)) if isSynthetic(cd) => - val restored = (body collectFirst { - case DefDef(_, /*name.apply | */ name.applyOrElse, _, _, _, Match(_, cases)) => - val nonSyntheticCases = cases.takeWhile(cdef => !isSyntheticDefaultCase(cdef)) - val transformedCases = super.transformStats(nonSyntheticCases, currentOwner).asInstanceOf[List[CaseDef]] - Match(EmptyTree, transformedCases) - }).getOrElse(c.abort(tree.pos, s"Internal Error: Unable to find original pattern matching cases in: $body")) - restored - case t => super.transform(t) - } - } + val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal") + val AsyncClass = rootMirror.staticClass("scala.async.AsyncBase") + val Async_await = AsyncClass.typeSignature.member(newTermName("await")).ensuring(_ != NoSymbol) } def isSafeToInline(tree: Tree) = { - val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] - object treeInfo extends { - val global: symtab.type = symtab - } with reflect.internal.TreeInfo - val castTree = tree.asInstanceOf[symtab.Tree] - treeInfo.isExprSafeToInline(castTree) + treeInfo.isExprSafeToInline(tree) } /** Map a list of arguments to: @@ -371,4 +140,104 @@ private[async] final case class TransformUtils[C <: Context](c: C) { } }.unzip } + + + def statsAndExpr(tree: Tree): (List[Tree], Tree) = tree match { + case Block(stats, expr) => (stats, expr) + case _ => (List(tree), Literal(Constant(()))) + } + + def emptyConstructor: DefDef = { + val emptySuperCall = Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), Nil) + DefDef(NoMods, nme.CONSTRUCTOR, List(), List(List()), TypeTree(), Block(List(emptySuperCall), Literal(Constant(())))) + } + + def applied(className: String, types: List[Type]): AppliedTypeTree = + AppliedTypeTree(Ident(rootMirror.staticClass(className)), types.map(TypeTree(_))) + + /** Descends into the regions of the tree that are subject to the + * translation to a state machine by `async`. When a nested template, + * function, or by-name argument is encountered, the descent stops, + * and `nestedClass` etc are invoked. + */ + trait AsyncTraverser extends Traverser { + def nestedClass(classDef: ClassDef) { + } + + def nestedModule(module: ModuleDef) { + } + + def nestedMethod(module: DefDef) { + } + + def byNameArgument(arg: Tree) { + } + + def function(function: Function) { + } + + def patMatFunction(tree: Match) { + } + + override def traverse(tree: Tree) { + tree match { + case cd: ClassDef => nestedClass(cd) + case md: ModuleDef => nestedModule(md) + case dd: DefDef => nestedMethod(dd) + case fun: Function => function(fun) + case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions` + case treeInfo.Applied(fun, targs, argss) if argss.nonEmpty => + val isInByName = isByName(fun) + for ((args, i) <- argss.zipWithIndex) { + for ((arg, j) <- args.zipWithIndex) { + if (!isInByName(i, j)) traverse(arg) + else byNameArgument(arg) + } + } + traverse(fun) + case _ => super.traverse(tree) + } + } + } + + def abort(pos: Position, msg: String) = throw new AbortMacroException(pos, msg) + + abstract class MacroTypingTransformer extends TypingTransformer(callSiteTyper.context.unit) { + currentOwner = callSiteTyper.context.owner + + def currOwner: Symbol = currentOwner + + localTyper = global.analyzer.newTyper(callSiteTyper.context.make(unit = callSiteTyper.context.unit)) + } + + def transformAt(tree: Tree)(f: PartialFunction[Tree, (analyzer.Context => Tree)]) = { + object trans extends MacroTypingTransformer { + override def transform(tree: Tree): Tree = { + if (f.isDefinedAt(tree)) { + f(tree)(localTyper.context) + } else super.transform(tree) + } + } + trans.transform(tree) + } + + def changeOwner(tree: Tree, oldOwner: Symbol, newOwner: Symbol): tree.type = { + new ChangeOwnerAndModuleClassTraverser(oldOwner, newOwner).traverse(tree) + tree + } + + class ChangeOwnerAndModuleClassTraverser(oldowner: Symbol, newowner: Symbol) + extends ChangeOwnerTraverser(oldowner, newowner) { + + override def traverse(tree: Tree) { + tree match { + case _: DefTree => change(tree.symbol.moduleClass) + case _ => + } + super.traverse(tree) + } + } + + def toMultiMap[A, B](as: Iterable[(A, B)]): Map[A, List[B]] = + as.toList.groupBy(_._1).mapValues(_.map(_._2).toList).toMap } diff --git a/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala b/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala index a669cfa..7abc6e8 100644 --- a/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala +++ b/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala @@ -22,21 +22,28 @@ trait AsyncBaseWithCPSFallback extends AsyncBase { /* Implements `async { ... }` using the CPS plugin. */ - protected def cpsBasedAsyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { + protected def cpsBasedAsyncImpl[T: c.WeakTypeTag](c: Context) + (body: c.Expr[T]) + (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = { import c.universe._ - def lookupMember(name: String) = { - val asyncTrait = c.mirror.staticClass("scala.async.continuations.AsyncBaseWithCPSFallback") + def lookupClassMember(clazz: String, name: String) = { + val asyncTrait = c.mirror.staticClass(clazz) val tpe = asyncTrait.asType.toType - tpe.member(newTermName(name)).ensuring(_ != NoSymbol) + tpe.member(newTermName(name)).ensuring(_ != NoSymbol, s"$clazz.$name") + } + def lookupObjectMember(clazz: String, name: String) = { + val moduleClass = c.mirror.staticModule(clazz).moduleClass + val tpe = moduleClass.asType.toType + tpe.member(newTermName(name)).ensuring(_ != NoSymbol, s"$clazz.$name") } AsyncUtils.vprintln("AsyncBaseWithCPSFallback.cpsBasedAsyncImpl") - val utils = TransformUtils[c.type](c) - val futureSystemOps = futureSystem.mkOps(c) - val awaitSym = utils.defn.Async_await - val awaitFallbackSym = lookupMember("awaitFallback") + val symTab = c.universe.asInstanceOf[reflect.internal.SymbolTable] + val futureSystemOps = futureSystem.mkOps(symTab) + val awaitSym = lookupObjectMember("scala.async.Async", "await") + val awaitFallbackSym = lookupClassMember("scala.async.continuations.AsyncBaseWithCPSFallback", "awaitFallback") // replace `await` invocations with `awaitFallback` invocations val awaitReplacer = new Transformer { @@ -60,10 +67,12 @@ trait AsyncBaseWithCPSFallback extends AsyncBase { }.asInstanceOf[Future[T]] */ + def spawn(expr: Tree) = futureSystemOps.spawn(expr.asInstanceOf[futureSystemOps.universe.Tree], execContext.tree.asInstanceOf[futureSystemOps.universe.Tree]).asInstanceOf[Tree] + val bodyWithFuture = { val tree = bodyWithAwaitFallback match { - case Block(stmts, expr) => Block(stmts, futureSystemOps.spawn(expr)) - case expr => futureSystemOps.spawn(expr) + case Block(stmts, expr) => Block(stmts, spawn(expr)) + case expr => spawn(expr) } c.Expr[futureSystem.Fut[Any]](c.resetAllAttrs(tree.duplicate)) } @@ -71,20 +80,22 @@ trait AsyncBaseWithCPSFallback extends AsyncBase { val bodyWithReset: c.Expr[futureSystem.Fut[Any]] = reify { reset { bodyWithFuture.splice } } - val bodyWithCast = futureSystemOps.castTo[T](bodyWithReset) + val bodyWithCast = futureSystemOps.castTo[T](bodyWithReset.asInstanceOf[futureSystemOps.universe.Expr[futureSystem.Fut[Any]]]).asInstanceOf[c.Expr[futureSystem.Fut[T]]] AsyncUtils.vprintln(s"CPS-based async transform expands to:\n${bodyWithCast.tree}") bodyWithCast } - override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { + override def asyncImpl[T: c.WeakTypeTag](c: Context) + (body: c.Expr[T]) + (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = { AsyncUtils.vprintln("AsyncBaseWithCPSFallback.asyncImpl") - val analyzer = AsyncAnalysis[c.type](c, this) + val asyncMacro = AsyncMacro(c, futureSystem) - if (!analyzer.reportUnsupportedAwaits(body.tree)) - super.asyncImpl[T](c)(body) // no unsupported awaits + if (!asyncMacro.reportUnsupportedAwaits(body.tree.asInstanceOf[asyncMacro.global.Tree], report = fallbackEnabled)) + super.asyncImpl[T](c)(body)(execContext) // no unsupported awaits else - cpsBasedAsyncImpl[T](c)(body) // fallback to CPS + cpsBasedAsyncImpl[T](c)(body)(execContext) // fallback to CPS } } diff --git a/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala b/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala index fe6e1a6..e0da5aa 100644 --- a/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala +++ b/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala @@ -13,8 +13,13 @@ import scala.concurrent.Future trait AsyncWithCPSFallback extends AsyncBaseWithCPSFallback with ScalaConcurrentCPSFallback object AsyncWithCPSFallback extends AsyncWithCPSFallback { + import scala.concurrent.{ExecutionContext, Future} - def async[T](body: T) = macro asyncImpl[T] + def async[T](body: T)(implicit execContext: ExecutionContext): Future[T] = macro asyncImpl[T] - override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body) + override def asyncImpl[T: c.WeakTypeTag](c: Context) + (body: c.Expr[T]) + (execContext: c.Expr[ExecutionContext]): c.Expr[Future[T]] = { + super.asyncImpl[T](c)(body)(execContext) + } } diff --git a/src/main/scala/scala/async/continuations/CPSBasedAsync.scala b/src/main/scala/scala/async/continuations/CPSBasedAsync.scala index 922d1ac..2003082 100644 --- a/src/main/scala/scala/async/continuations/CPSBasedAsync.scala +++ b/src/main/scala/scala/async/continuations/CPSBasedAsync.scala @@ -8,14 +8,17 @@ package continuations import scala.language.experimental.macros import scala.reflect.macros.Context -import scala.concurrent.Future +import scala.concurrent.{ExecutionContext, Future} trait CPSBasedAsync extends CPSBasedAsyncBase with ScalaConcurrentCPSFallback object CPSBasedAsync extends CPSBasedAsync { - def async[T](body: T) = macro asyncImpl[T] - - override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body) + def async[T](body: T)(implicit execContext: ExecutionContext): Future[T] = macro asyncImpl[T] + override def asyncImpl[T: c.WeakTypeTag](c: Context) + (body: c.Expr[T]) + (execContext: c.Expr[ExecutionContext]): c.Expr[Future[T]] = { + super.asyncImpl[T](c)(body)(execContext) + } } diff --git a/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala b/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala index 4e8ec80..a350704 100644 --- a/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala +++ b/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala @@ -15,7 +15,9 @@ import scala.util.continuations._ */ trait CPSBasedAsyncBase extends AsyncBaseWithCPSFallback { - override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = - super.cpsBasedAsyncImpl[T](c)(body) - + override def asyncImpl[T: c.WeakTypeTag](c: Context) + (body: c.Expr[T]) + (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = { + super.cpsBasedAsyncImpl[T](c)(body)(execContext) + } } diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index 43393a7..a06437d 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -40,8 +40,8 @@ class TreeInterrogation { val varDefs = tree1.collect { case ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) => name } - varDefs.map(_.decoded.trim).toSet mustBe (Set("state$async", "await$1", "await$2")) - varDefs.map(_.decoded.trim).toSet mustBe (Set("state$async", "await$1", "await$2")) + varDefs.map(_.decoded.trim).toSet mustBe (Set("state", "await$1", "await$2")) + varDefs.map(_.decoded.trim).toSet mustBe (Set("state", "await$1", "await$2")) val defDefs = tree1.collect { case t: Template => @@ -68,7 +68,7 @@ object TreeInterrogation extends App { withDebug { val cm = reflect.runtime.currentMirror - val tb = mkToolbox("-cp ${toolboxClasspath} -Xprint:flatten") + val tb = mkToolbox("-cp ${toolboxClasspath} -Xprint:typer -uniqid") import scala.async.Async._ val tree = tb.parse( """ import _root_.scala.async.AsyncId.{async, await} diff --git a/src/test/scala/scala/async/neg/LocalClasses0Spec.scala b/src/test/scala/scala/async/neg/LocalClasses0Spec.scala index 2569303..dcd9bb8 100644 --- a/src/test/scala/scala/async/neg/LocalClasses0Spec.scala +++ b/src/test/scala/scala/async/neg/LocalClasses0Spec.scala @@ -5,121 +5,32 @@ package scala.async package neg -/** - * Copyright (C) 2012 Typesafe Inc. - */ - import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.junit.Test @RunWith(classOf[JUnit4]) class LocalClasses0Spec { - @Test - def `reject a local class`() { - expectError("Local case class Person illegal within `async` block") { - """ - | import scala.concurrent.ExecutionContext.Implicits.global - | import scala.async.Async._ - | - | async { - | case class Person(name: String) - | } - """.stripMargin - } + def localClassCrashIssue16() { + import scala.async.AsyncId.{async, await} + async { + class B { def f = 1 } + await(new B()).f + } mustBe 1 } @Test - def `reject a local class 2`() { - expectError("Local case class Person illegal within `async` block") { - """ - | import scala.concurrent.{Future, ExecutionContext} - | import ExecutionContext.Implicits.global - | import scala.async.Async._ - | - | async { - | case class Person(name: String) - | val fut = Future { 5 } - | val x = await(fut) - | x - | } - """.stripMargin - } + def nestedCaseClassAndModuleAllowed() { + import AsyncId.{await, async} + async { + trait Base { def base = 0} + await(0) + case class Person(name: String) extends Base + val fut = async { "bob" } + val x = Person(await(fut)) + x.base + x.name + } mustBe "bob" } - - @Test - def `reject a local class 3`() { - expectError("Local case class Person illegal within `async` block") { - """ - | import scala.concurrent.{Future, ExecutionContext} - | import ExecutionContext.Implicits.global - | import scala.async.Async._ - | - | async { - | val fut = Future { 5 } - | val x = await(fut) - | case class Person(name: String) - | x - | } - """.stripMargin - } - } - - @Test - def `reject a local class with symbols in its name`() { - expectError("Local case class :: illegal within `async` block") { - """ - | import scala.concurrent.{Future, ExecutionContext} - | import ExecutionContext.Implicits.global - | import scala.async.Async._ - | - | async { - | val fut = Future { 5 } - | val x = await(fut) - | case class ::(name: String) - | x - | } - """.stripMargin - } - } - - @Test - def `reject a nested local class`() { - expectError("Local case class Person illegal within `async` block") { - """ - | import scala.concurrent.{Future, ExecutionContext} - | import ExecutionContext.Implicits.global - | import scala.async.Async._ - | - | async { - | val fut = Future { 5 } - | val x = 2 + 2 - | var y = 0 - | if (x > 0) { - | case class Person(name: String) - | y = await(fut) - | } else { - | y = x - | } - | y - | } - """.stripMargin - } - } - - @Test - def `reject a local singleton object`() { - expectError("Local object Person illegal within `async` block") { - """ - | import scala.concurrent.ExecutionContext.Implicits.global - | import scala.async.Async._ - | - | async { - | object Person { val name = "Joe" } - | } - """.stripMargin - } - } - } diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala index 7be6299..abce3ce 100644 --- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala +++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala @@ -238,7 +238,7 @@ class AnfTransformSpec { val res = async { var i = 0 def get = {i += 1; i} - foo(get)(get) + foo(get)(await(get)) } res mustBe "a0 = 1, b0 = 2" } diff --git a/src/test/scala/scala/async/run/nesteddef/NestedDef.scala b/src/test/scala/scala/async/run/nesteddef/NestedDef.scala index ee0a78e..cf74602 100644 --- a/src/test/scala/scala/async/run/nesteddef/NestedDef.scala +++ b/src/test/scala/scala/async/run/nesteddef/NestedDef.scala @@ -37,4 +37,60 @@ class NestedDef { } result mustBe ((0d, 44d, 2)) } + + // We must lift `foo` and `bar` in the next two tests. + @Test + def nestedDefTransitive1() { + import AsyncId._ + val result = async { + val a = 0 + val x = await(a) - 1 + def bar = a + def foo = bar + foo + } + result mustBe 0 + } + + @Test + def nestedDefTransitive2() { + import AsyncId._ + val result = async { + val a = 0 + val x = await(a) - 1 + def bar = a + def foo = bar + 0 + } + result mustBe 0 + } + + + // checking that our use/definition analysis doesn't cycle. + @Test + def mutuallyRecursive1() { + import AsyncId._ + val result = async { + val a = 0 + val x = await(a) - 1 + def foo: Int = if (true) 0 else bar + def bar: Int = if (true) 0 else foo + bar + } + result mustBe 0 + } + + // checking that our use/definition analysis doesn't cycle. + @Test + def mutuallyRecursive2() { + import AsyncId._ + val result = async { + val a = 0 + def foo: Int = if (true) 0 else bar + def bar: Int = if (true) 0 else foo + val x = await(a) - 1 + bar + } + result mustBe 0 + } } diff --git a/src/test/scala/scala/async/run/toughtype/ToughType.scala b/src/test/scala/scala/async/run/toughtype/ToughType.scala index 83f5a2d..6fcd966 100644 --- a/src/test/scala/scala/async/run/toughtype/ToughType.scala +++ b/src/test/scala/scala/async/run/toughtype/ToughType.scala @@ -67,4 +67,42 @@ class ToughTypeSpec { await(f(2)) } mustBe 3 } + + @Test def existentialBindIssue19() { + import AsyncId.{await, async} + def m7(a: Any) = async { + a match { + case s: Seq[_] => + val x = s.size + var ss = s + ss = s + await(x) + } + } + m7(Nil) mustBe 0 + } + + @Test def existentialBind2Issue19() { + import scala.async.Async._, scala.concurrent.ExecutionContext.Implicits.global + def conjure[T]: T = null.asInstanceOf[T] + + def m3 = async { + val p: List[Option[_]] = conjure[List[Option[_]]] + await(future(1)) + } + + def m4 = async { + await(future[List[_]](Nil)) + } + } + + @Test def singletonTypeIssue17() { + import scala.async.AsyncId.{async, await} + class A { class B } + async { + val a = new A + def foo(b: a.B) = 0 + await(foo(new a.B)) + } + } } -- cgit v1.2.3