From 2d8506a64392cd7192b6831c38798cc9a7c8bfed Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Sun, 7 Jul 2013 10:48:11 +1000 Subject: Move implementation details to scala.async.internal._. If we intend to keep CPS fallback around for any length of time it should probably move there too. --- src/main/scala/scala/async/AnfTransform.scala | 252 ------------- src/main/scala/scala/async/Async.scala | 94 ----- src/main/scala/scala/async/AsyncAnalysis.scala | 91 ----- src/main/scala/scala/async/AsyncBase.scala | 23 ++ src/main/scala/scala/async/AsyncMacro.scala | 29 -- src/main/scala/scala/async/AsyncTransform.scala | 176 ---------- src/main/scala/scala/async/AsyncUtils.scala | 16 - src/main/scala/scala/async/ExprBuilder.scala | 387 -------------------- src/main/scala/scala/async/FutureSystem.scala | 153 -------- src/main/scala/scala/async/Lifter.scala | 150 -------- src/main/scala/scala/async/StateAssigner.scala | 14 - src/main/scala/scala/async/StateMachine.scala | 12 + src/main/scala/scala/async/TransformUtils.scala | 251 ------------- .../continuations/AsyncBaseWithCPSFallback.scala | 3 +- .../continuations/ScalaConcurrentCPSFallback.scala | 1 + .../scala/scala/async/internal/AnfTransform.scala | 253 ++++++++++++++ .../scala/scala/async/internal/AsyncAnalysis.scala | 91 +++++ .../scala/scala/async/internal/AsyncBase.scala | 58 +++ src/main/scala/scala/async/internal/AsyncId.scala | 64 ++++ .../scala/scala/async/internal/AsyncMacro.scala | 29 ++ .../scala/async/internal/AsyncTransform.scala | 176 ++++++++++ .../scala/scala/async/internal/AsyncUtils.scala | 16 + .../scala/scala/async/internal/ExprBuilder.scala | 388 +++++++++++++++++++++ .../scala/scala/async/internal/FutureSystem.scala | 106 ++++++ src/main/scala/scala/async/internal/Lifter.scala | 150 ++++++++ .../scala/scala/async/internal/StateAssigner.scala | 14 + .../scala/async/internal/TransformUtils.scala | 251 +++++++++++++ 27 files changed, 1634 insertions(+), 1614 deletions(-) delete mode 100644 src/main/scala/scala/async/AnfTransform.scala delete mode 100644 src/main/scala/scala/async/Async.scala delete mode 100644 src/main/scala/scala/async/AsyncAnalysis.scala create mode 100644 src/main/scala/scala/async/AsyncBase.scala delete mode 100644 src/main/scala/scala/async/AsyncMacro.scala delete mode 100644 src/main/scala/scala/async/AsyncTransform.scala delete mode 100644 src/main/scala/scala/async/AsyncUtils.scala delete mode 100644 src/main/scala/scala/async/ExprBuilder.scala delete mode 100644 src/main/scala/scala/async/FutureSystem.scala delete mode 100644 src/main/scala/scala/async/Lifter.scala delete mode 100644 src/main/scala/scala/async/StateAssigner.scala create mode 100644 src/main/scala/scala/async/StateMachine.scala delete mode 100644 src/main/scala/scala/async/TransformUtils.scala create mode 100644 src/main/scala/scala/async/internal/AnfTransform.scala create mode 100644 src/main/scala/scala/async/internal/AsyncAnalysis.scala create mode 100644 src/main/scala/scala/async/internal/AsyncBase.scala create mode 100644 src/main/scala/scala/async/internal/AsyncId.scala create mode 100644 src/main/scala/scala/async/internal/AsyncMacro.scala create mode 100644 src/main/scala/scala/async/internal/AsyncTransform.scala create mode 100644 src/main/scala/scala/async/internal/AsyncUtils.scala create mode 100644 src/main/scala/scala/async/internal/ExprBuilder.scala create mode 100644 src/main/scala/scala/async/internal/FutureSystem.scala create mode 100644 src/main/scala/scala/async/internal/Lifter.scala create mode 100644 src/main/scala/scala/async/internal/StateAssigner.scala create mode 100644 src/main/scala/scala/async/internal/TransformUtils.scala (limited to 'src/main/scala') diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala deleted file mode 100644 index 14263da..0000000 --- a/src/main/scala/scala/async/AnfTransform.scala +++ /dev/null @@ -1,252 +0,0 @@ - -/* - * Copyright (C) 2012 Typesafe Inc. - */ - -package scala.async - -import scala.tools.nsc.Global - -private[async] trait AnfTransform { - self: AsyncMacro => - - import global._ - import reflect.internal.Flags._ - - def anfTransform(tree: Tree): Block = { - // Must prepend the () for issue #31. - val block = callSiteTyper.typedPos(tree.pos)(Block(List(Literal(Constant(()))), tree)).setType(tree.tpe) - - new SelectiveAnfTransform().transform(block) - } - - sealed abstract class AnfMode - - case object Anf extends AnfMode - - case object Linearizing extends AnfMode - - final class SelectiveAnfTransform extends MacroTypingTransformer { - var mode: AnfMode = Anf - - def blockToList(tree: Tree): List[Tree] = tree match { - case Block(stats, expr) => stats :+ expr - case t => t :: Nil - } - - def listToBlock(trees: List[Tree]): Block = trees match { - case trees @ (init :+ last) => - val pos = trees.map(_.pos).reduceLeft(_ union _) - Block(init, last).setType(last.tpe).setPos(pos) - } - - override def transform(tree: Tree): Block = { - def anfLinearize: Block = { - val trees: List[Tree] = mode match { - case Anf => anf._transformToList(tree) - case Linearizing => linearize._transformToList(tree) - } - listToBlock(trees) - } - tree match { - case _: ValDef | _: DefDef | _: Function | _: ClassDef | _: TypeDef => - atOwner(tree.symbol)(anfLinearize) - case _: ModuleDef => - atOwner(tree.symbol.moduleClass orElse tree.symbol)(anfLinearize) - case _ => - anfLinearize - } - } - - private object linearize { - def transformToList(tree: Tree): List[Tree] = { - mode = Linearizing; blockToList(transform(tree)) - } - - def transformToBlock(tree: Tree): Block = listToBlock(transformToList(tree)) - - def _transformToList(tree: Tree): List[Tree] = trace(tree) { - val stats :+ expr = anf.transformToList(tree) - expr match { - case Apply(fun, args) if isAwait(fun) => - val valDef = defineVal(name.await, expr, tree.pos) - stats :+ valDef :+ gen.mkAttributedStableRef(valDef.symbol) - - case If(cond, thenp, elsep) => - // if type of if-else is Unit don't introduce assignment, - // but add Unit value to bring it into form expected by async transform - if (expr.tpe =:= definitions.UnitTpe) { - stats :+ expr :+ localTyper.typedPos(expr.pos)(Literal(Constant(()))) - } else { - val varDef = defineVar(name.ifRes, expr.tpe, tree.pos) - def branchWithAssign(orig: Tree) = localTyper.typedPos(orig.pos) { - def cast(t: Tree) = mkAttributedCastPreservingAnnotations(t, varDef.symbol.tpe) - orig match { - case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr))) - case _ => Assign(Ident(varDef.symbol), cast(orig)) - } - }.setType(orig.tpe) - val ifWithAssign = treeCopy.If(tree, cond, branchWithAssign(thenp), branchWithAssign(elsep)) - stats :+ varDef :+ ifWithAssign :+ gen.mkAttributedStableRef(varDef.symbol) - } - - case Match(scrut, cases) => - // if type of match is Unit don't introduce assignment, - // but add Unit value to bring it into form expected by async transform - if (expr.tpe =:= definitions.UnitTpe) { - stats :+ expr :+ localTyper.typedPos(expr.pos)(Literal(Constant(()))) - } - else { - val varDef = defineVar(name.matchRes, expr.tpe, tree.pos) - def typedAssign(lhs: Tree) = - localTyper.typedPos(lhs.pos)(Assign(Ident(varDef.symbol), mkAttributedCastPreservingAnnotations(lhs, varDef.symbol.tpe))) - val casesWithAssign = cases map { - case cd@CaseDef(pat, guard, body) => - val newBody = body match { - case b@Block(caseStats, caseExpr) => treeCopy.Block(b, caseStats, typedAssign(caseExpr)) - case _ => typedAssign(body) - } - treeCopy.CaseDef(cd, pat, guard, newBody) - } - val matchWithAssign = treeCopy.Match(tree, scrut, casesWithAssign) - require(matchWithAssign.tpe != null, matchWithAssign) - stats :+ varDef :+ matchWithAssign :+ gen.mkAttributedStableRef(varDef.symbol) - } - case _ => - stats :+ expr - } - } - - private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = { - val sym = currOwner.newTermSymbol(name.fresh(prefix), pos, MUTABLE | SYNTHETIC).setInfo(tp) - ValDef(sym, gen.mkZero(tp)).setType(NoType).setPos(pos) - } - } - - private object trace { - private var indent = -1 - - def indentString = " " * indent - - def apply[T](args: Any)(t: => T): T = { - def prefix = mode.toString.toLowerCase - indent += 1 - def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127) - try { - AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})") - val result = t - AsyncUtils.trace(s"${indentString}= ${oneLine(result)}") - result - } finally { - indent -= 1 - } - } - } - - private def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = { - val sym = currOwner.newTermSymbol(name.fresh(prefix), pos, SYNTHETIC).setInfo(lhs.tpe) - changeOwner(lhs, currentOwner, sym) - ValDef(sym, changeOwner(lhs, currentOwner, sym)).setType(NoType).setPos(pos) - } - - private object anf { - def transformToList(tree: Tree): List[Tree] = { - mode = Anf; blockToList(transform(tree)) - } - - def _transformToList(tree: Tree): List[Tree] = trace(tree) { - val containsAwait = tree exists isAwait - if (!containsAwait) { - List(tree) - } else tree match { - case Select(qual, sel) => - val stats :+ expr = linearize.transformToList(qual) - stats :+ treeCopy.Select(tree, expr, sel) - - case treeInfo.Applied(fun, targs, argss) if argss.nonEmpty => - // we an assume that no await call appears in a by-name argument position, - // this has already been checked. - val funStats :+ simpleFun = linearize.transformToList(fun) - def isAwaitRef(name: Name) = name.toString.startsWith(AnfTransform.this.name.await + "$") - val (argStatss, argExprss): (List[List[List[Tree]]], List[List[Tree]]) = - mapArgumentss[List[Tree]](fun, argss) { - case Arg(expr, byName, _) if byName /*|| isPure(expr) TODO */ => (Nil, expr) - case Arg(expr@Ident(name), _, _) if isAwaitRef(name) => (Nil, expr) // TODO needed? // not typed, so it eludes the check in `isSafeToInline` - case Arg(expr, _, argName) => - linearize.transformToList(expr) match { - case stats :+ expr1 => - val valDef = defineVal(argName, expr1, expr1.pos) - require(valDef.tpe != null, valDef) - val stats1 = stats :+ valDef - //stats1.foreach(changeOwner(_, currentOwner, currentOwner.owner)) - (stats1, gen.stabilize(gen.mkAttributedIdent(valDef.symbol))) - } - } - val applied = treeInfo.dissectApplied(tree) - val core = if (targs.isEmpty) simpleFun else treeCopy.TypeApply(applied.callee, simpleFun, targs) - val newApply = argExprss.foldLeft(core)(Apply(_, _)).setSymbol(tree.symbol) - val typedNewApply = localTyper.typedPos(tree.pos)(newApply).setType(tree.tpe) - funStats ++ argStatss.flatten.flatten :+ typedNewApply - case Block(stats, expr) => - (stats :+ expr).flatMap(linearize.transformToList) - - case ValDef(mods, name, tpt, rhs) => - if (rhs exists isAwait) { - val stats :+ expr = atOwner(currOwner.owner)(linearize.transformToList(rhs)) - stats.foreach(changeOwner(_, currOwner, currOwner.owner)) - stats :+ treeCopy.ValDef(tree, mods, name, tpt, expr) - } else List(tree) - - case Assign(lhs, rhs) => - val stats :+ expr = linearize.transformToList(rhs) - stats :+ treeCopy.Assign(tree, lhs, expr) - - case If(cond, thenp, elsep) => - val condStats :+ condExpr = linearize.transformToList(cond) - val thenBlock = linearize.transformToBlock(thenp) - val elseBlock = linearize.transformToBlock(elsep) - // Typechecking with `condExpr` as the condition fails if the condition - // contains an await. `ifTree.setType(tree.tpe)` also fails; it seems - // we rely on this call to `typeCheck` descending into the branches. - // But, we can get away with typechecking a throwaway `If` tree with the - // original scrutinee and the new branches, and setting that type on - // the real `If` tree. - val iff = treeCopy.If(tree, condExpr, thenBlock, elseBlock) - condStats :+ iff - - case Match(scrut, cases) => - val scrutStats :+ scrutExpr = linearize.transformToList(scrut) - val caseDefs = cases map { - case CaseDef(pat, guard, body) => - // extract local variables for all names bound in `pat`, and rewrite `body` - // to refer to these. - // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`. - val block = linearize.transformToBlock(body) - val (valDefs, mappings) = (pat collect { - case b@Bind(name, _) => - val vd = defineVal(name.toTermName + AnfTransform.this.name.bindSuffix, gen.mkAttributedStableRef(b.symbol), b.pos) - (vd, (b.symbol, vd.symbol)) - }).unzip - val (from, to) = mappings.unzip - val b@Block(stats1, expr1) = block.substituteSymbols(from, to).asInstanceOf[Block] - val newBlock = treeCopy.Block(b, valDefs ++ stats1, expr1) - treeCopy.CaseDef(tree, pat, guard, newBlock) - } - // Refer to comments the translation of `If` above. - val typedMatch = treeCopy.Match(tree, scrutExpr, caseDefs) - scrutStats :+ typedMatch - - case LabelDef(name, params, rhs) => - List(LabelDef(name, params, Block(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol)) - - case TypeApply(fun, targs) => - val funStats :+ simpleFun = linearize.transformToList(fun) - funStats :+ treeCopy.TypeApply(tree, simpleFun, targs) - - case _ => - List(tree) - } - } - } - } -} diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala deleted file mode 100644 index 5f577cf..0000000 --- a/src/main/scala/scala/async/Async.scala +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Copyright (C) 2012 Typesafe Inc. - */ - -package scala.async - -import scala.language.experimental.macros -import scala.reflect.macros.Context -import scala.reflect.internal.annotations.compileTimeOnly -import scala.tools.nsc.Global -import language.reflectiveCalls -import scala.concurrent.ExecutionContext - -object Async extends AsyncBase { - - import scala.concurrent.Future - - lazy val futureSystem = ScalaConcurrentFutureSystem - type FS = ScalaConcurrentFutureSystem.type - - def async[T](body: T)(implicit execContext: ExecutionContext): Future[T] = macro asyncImpl[T] - - override def asyncImpl[T: c.WeakTypeTag](c: Context) - (body: c.Expr[T]) - (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[Future[T]] = { - super.asyncImpl[T](c)(body)(execContext) - } -} - -object AsyncId extends AsyncBase { - lazy val futureSystem = IdentityFutureSystem - type FS = IdentityFutureSystem.type - - def async[T](body: T) = macro asyncIdImpl[T] - - def asyncIdImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = asyncImpl[T](c)(body)(c.literalUnit) -} - -/** - * A base class for the `async` macro. Subclasses must provide: - * - * - Concrete types for a given future system - * - Tree manipulations to create and complete the equivalent of Future and Promise - * in that system. - * - The `async` macro declaration itself, and a forwarder for the macro implementation. - * (The latter is temporarily needed to workaround bug SI-6650 in the macro system) - * - * The default implementation, [[scala.async.Async]], binds the macro to `scala.concurrent._`. - */ -abstract class AsyncBase { - self => - - type FS <: FutureSystem - val futureSystem: FS - - /** - * A call to `await` must be nested in an enclosing `async` block. - * - * A call to `await` does not block the current thread, rather it is a delimiter - * used by the enclosing `async` macro. Code following the `await` - * call is executed asynchronously, when the argument of `await` has been completed. - * - * @param awaitable the future from which a value is awaited. - * @tparam T the type of that value. - * @return the value. - */ - @compileTimeOnly("`await` must be enclosed in an `async` block") - def await[T](awaitable: futureSystem.Fut[T]): T = ??? - - protected[async] def fallbackEnabled = false - - def asyncImpl[T: c.WeakTypeTag](c: Context) - (body: c.Expr[T]) - (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = { - import c.universe._ - - val asyncMacro = AsyncMacro(c, futureSystem) - - val code = asyncMacro.asyncTransform[T]( - body.tree.asInstanceOf[asyncMacro.global.Tree], - execContext.tree.asInstanceOf[asyncMacro.global.Tree], - fallbackEnabled)(implicitly[c.WeakTypeTag[T]].asInstanceOf[asyncMacro.global.WeakTypeTag[T]]) - - AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code}") - c.Expr[futureSystem.Fut[T]](code.asInstanceOf[Tree]) - } -} - -/** Internal class used by the `async` macro; should not be manually extended by client code */ -abstract class StateMachine[Result, EC] extends (scala.util.Try[Any] => Unit) with (() => Unit) { - def result: Result - - def execContext: EC -} diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala deleted file mode 100644 index 424318e..0000000 --- a/src/main/scala/scala/async/AsyncAnalysis.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright (C) 2012 Typesafe Inc. - */ - -package scala.async - -import scala.reflect.macros.Context -import scala.collection.mutable - -trait AsyncAnalysis { - self: AsyncMacro => - - import global._ - - /** - * Analyze the contents of an `async` block in order to: - * - Report unsupported `await` calls under nested templates, functions, by-name arguments. - * - * Must be called on the original tree, not on the ANF transformed tree. - */ - def reportUnsupportedAwaits(tree: Tree, report: Boolean): Boolean = { - val analyzer = new UnsupportedAwaitAnalyzer(report) - analyzer.traverse(tree) - analyzer.hasUnsupportedAwaits - } - - private class UnsupportedAwaitAnalyzer(report: Boolean) extends AsyncTraverser { - var hasUnsupportedAwaits = false - - override def nestedClass(classDef: ClassDef) { - val kind = if (classDef.symbol.isTrait) "trait" else "class" - reportUnsupportedAwait(classDef, s"nested ${kind}") - } - - override def nestedModule(module: ModuleDef) { - reportUnsupportedAwait(module, "nested object") - } - - override def nestedMethod(defDef: DefDef) { - reportUnsupportedAwait(defDef, "nested method") - } - - override def byNameArgument(arg: Tree) { - reportUnsupportedAwait(arg, "by-name argument") - } - - override def function(function: Function) { - reportUnsupportedAwait(function, "nested function") - } - - override def patMatFunction(tree: Match) { - reportUnsupportedAwait(tree, "nested function") - } - - override def traverse(tree: Tree) { - def containsAwait = tree exists isAwait - tree match { - case Try(_, _, _) if containsAwait => - reportUnsupportedAwait(tree, "try/catch") - super.traverse(tree) - case Return(_) => - abort(tree.pos, "return is illegal within a async block") - case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) => - // TODO lift this restriction - abort(tree.pos, "lazy vals are illegal within an async block") - case _ => - super.traverse(tree) - } - } - - /** - * @return true, if the tree contained an unsupported await. - */ - private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String): Boolean = { - val badAwaits: List[RefTree] = tree collect { - case rt: RefTree if isAwait(rt) => rt - } - badAwaits foreach { - tree => - reportError(tree.pos, s"await must not be used under a $whyUnsupported.") - } - badAwaits.nonEmpty - } - - private def reportError(pos: Position, msg: String) { - hasUnsupportedAwaits = true - if (report) - abort(pos, msg) - } - } -} diff --git a/src/main/scala/scala/async/AsyncBase.scala b/src/main/scala/scala/async/AsyncBase.scala new file mode 100644 index 0000000..ff04a57 --- /dev/null +++ b/src/main/scala/scala/async/AsyncBase.scala @@ -0,0 +1,23 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async + +import scala.language.experimental.macros +import scala.reflect.macros.Context +import scala.concurrent.{Future, ExecutionContext} +import scala.async.internal.{AsyncBase, ScalaConcurrentFutureSystem} + +object Async extends AsyncBase { + type FS = ScalaConcurrentFutureSystem.type + val futureSystem: FS = ScalaConcurrentFutureSystem + + def async[T](body: T)(implicit execContext: ExecutionContext): Future[T] = macro asyncImpl[T] + + override def asyncImpl[T: c.WeakTypeTag](c: Context) + (body: c.Expr[T]) + (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[Future[T]] = { + super.asyncImpl[T](c)(body)(execContext) + } +} diff --git a/src/main/scala/scala/async/AsyncMacro.scala b/src/main/scala/scala/async/AsyncMacro.scala deleted file mode 100644 index 8827351..0000000 --- a/src/main/scala/scala/async/AsyncMacro.scala +++ /dev/null @@ -1,29 +0,0 @@ -package scala.async - -import scala.tools.nsc.Global -import scala.tools.nsc.transform.TypingTransformers - -object AsyncMacro { - def apply(c: reflect.macros.Context, futureSystem0: FutureSystem): AsyncMacro = { - import language.reflectiveCalls - val powerContext = c.asInstanceOf[c.type {val universe: Global; val callsiteTyper: universe.analyzer.Typer}] - new AsyncMacro { - val global: powerContext.universe.type = powerContext.universe - val callSiteTyper: global.analyzer.Typer = powerContext.callsiteTyper - val futureSystem: futureSystem0.type = futureSystem0 - val futureSystemOps: futureSystem.Ops {val universe: global.type} = futureSystem0.mkOps(global) - val macroApplication: global.Tree = c.macroApplication.asInstanceOf[global.Tree] - } - } -} - -private[async] trait AsyncMacro - extends TypingTransformers - with AnfTransform with TransformUtils with Lifter - with ExprBuilder with AsyncTransform with AsyncAnalysis { - - val global: Global - val callSiteTyper: global.analyzer.Typer - val macroApplication: global.Tree - -} diff --git a/src/main/scala/scala/async/AsyncTransform.scala b/src/main/scala/scala/async/AsyncTransform.scala deleted file mode 100644 index 129f88e..0000000 --- a/src/main/scala/scala/async/AsyncTransform.scala +++ /dev/null @@ -1,176 +0,0 @@ -package scala.async - -trait AsyncTransform { - self: AsyncMacro => - - import global._ - - def asyncTransform[T](body: Tree, execContext: Tree, cpsFallbackEnabled: Boolean) - (implicit resultType: WeakTypeTag[T]): Tree = { - - reportUnsupportedAwaits(body, report = !cpsFallbackEnabled) - - // Transform to A-normal form: - // - no await calls in qualifiers or arguments, - // - if/match only used in statement position. - val anfTree: Block = anfTransform(body) - - val resumeFunTreeDummyBody = DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), Literal(Constant(()))) - - val applyDefDefDummyBody: DefDef = { - val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree))) - DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), Literal(Constant(()))) - } - - val stateMachineType = applied("scala.async.StateMachine", List(futureSystemOps.promType[T], futureSystemOps.execContextType)) - - val stateMachine: ClassDef = { - val body: List[Tree] = { - val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0))) - val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree) - val execContextValDef = ValDef(NoMods, name.execContext, TypeTree(), execContext) - - val apply0DefDef: DefDef = { - // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`. - // See SI-1247 for the the optimization that avoids creatio - DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil)) - } - List(emptyConstructor, stateVar, result, execContextValDef) ++ List(resumeFunTreeDummyBody, applyDefDefDummyBody, apply0DefDef) - } - val template = { - Template(List(stateMachineType), emptyValDef, body) - } - val t = ClassDef(NoMods, name.stateMachineT, Nil, template) - callSiteTyper.typedPos(macroApplication.pos)(Block(t :: Nil, Literal(Constant(())))) - t - } - - val asyncBlock: AsyncBlock = { - val symLookup = new SymLookup(stateMachine.symbol, applyDefDefDummyBody.vparamss.head.head.symbol) - buildAsyncBlock(anfTree, symLookup) - } - - logDiagnostics(anfTree, asyncBlock.asyncStates.map(_.toString)) - - def startStateMachine: Tree = { - val stateMachineSpliced: Tree = spliceMethodBodies( - liftables(asyncBlock.asyncStates), - stateMachine, - asyncBlock.onCompleteHandler[T], - asyncBlock.resumeFunTree[T].rhs - ) - - def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection) - - Block(List[Tree]( - stateMachineSpliced, - ValDef(NoMods, name.stateMachine, stateMachineType, Apply(Select(New(Ident(stateMachine.symbol)), nme.CONSTRUCTOR), Nil)), - futureSystemOps.spawn(Apply(selectStateMachine(name.apply), Nil), selectStateMachine(name.execContext)) - ), - futureSystemOps.promiseToFuture(Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree) - } - - val isSimple = asyncBlock.asyncStates.size == 1 - if (isSimple) - futureSystemOps.spawn(body, execContext) // generate lean code for the simple case of `async { 1 + 1 }` - else - startStateMachine - } - - def logDiagnostics(anfTree: Tree, states: Seq[String]) { - val pos = macroApplication.pos - def location = try { - pos.source.path - } catch { - case _: UnsupportedOperationException => - pos.toString - } - - AsyncUtils.vprintln(s"In file '$location':") - AsyncUtils.vprintln(s"${macroApplication}") - AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree") - states foreach (s => AsyncUtils.vprintln(s)) - } - - def spliceMethodBodies(liftables: List[Tree], tree: Tree, applyBody: Tree, - resumeBody: Tree): Tree = { - - val liftedSyms = liftables.map(_.symbol).toSet - val stateMachineClass = tree.symbol - liftedSyms.foreach { - sym => - if (sym != null) { - sym.owner = stateMachineClass - if (sym.isModule) - sym.moduleClass.owner = stateMachineClass - } - } - // Replace the ValDefs in the splicee with Assigns to the corresponding lifted - // fields. Similarly, replace references to them with references to the field. - // - // This transform will be only be run on the RHS of `def foo`. - class UseFields extends MacroTypingTransformer { - override def transform(tree: Tree): Tree = tree match { - case _ if currentOwner == stateMachineClass => - super.transform(tree) - case ValDef(_, _, _, rhs) if liftedSyms(tree.symbol) => - atOwner(currentOwner) { - val fieldSym = tree.symbol - val set = Assign(gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym), transform(rhs)) - changeOwner(set, tree.symbol, currentOwner) - localTyper.typedPos(tree.pos)(set) - } - case _: DefTree if liftedSyms(tree.symbol) => - EmptyTree - case Ident(name) if liftedSyms(tree.symbol) => - val fieldSym = tree.symbol - gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym).setType(tree.tpe) - case _ => - super.transform(tree) - } - } - - val liftablesUseFields = liftables.map { - case vd: ValDef => vd - case x => - val useField = new UseFields() - //.substituteSymbols(fromSyms, toSyms) - useField.atOwner(stateMachineClass)(useField.transform(x)) - } - - tree.children.foreach { - t => - new ChangeOwnerAndModuleClassTraverser(callSiteTyper.context.owner, tree.symbol).traverse(t) - } - val treeSubst = tree - - def fixup(dd: DefDef, body: Tree, ctx: analyzer.Context): Tree = { - val spliceeAnfFixedOwnerSyms = body - val useField = new UseFields() - val newRhs = useField.atOwner(dd.symbol)(useField.transform(spliceeAnfFixedOwnerSyms)) - val typer = global.analyzer.newTyper(ctx.make(dd, dd.symbol)) - treeCopy.DefDef(dd, dd.mods, dd.name, dd.tparams, dd.vparamss, dd.tpt, typer.typed(newRhs)) - } - - liftablesUseFields.foreach(t => if (t.symbol != null) stateMachineClass.info.decls.enter(t.symbol)) - - val result0 = transformAt(treeSubst) { - case t@Template(parents, self, stats) => - (ctx: analyzer.Context) => { - treeCopy.Template(t, parents, self, liftablesUseFields ++ stats) - } - } - val result = transformAt(result0) { - case dd@DefDef(_, name.apply, _, List(List(_)), _, _) if dd.symbol.owner == stateMachineClass => - (ctx: analyzer.Context) => - val typedTree = fixup(dd, changeOwner(applyBody, callSiteTyper.context.owner, dd.symbol), ctx) - typedTree - case dd@DefDef(_, name.resume, _, _, _, _) if dd.symbol.owner == stateMachineClass => - (ctx: analyzer.Context) => - val changed = changeOwner(resumeBody, callSiteTyper.context.owner, dd.symbol) - val res = fixup(dd, changed, ctx) - res - } - result - } -} diff --git a/src/main/scala/scala/async/AsyncUtils.scala b/src/main/scala/scala/async/AsyncUtils.scala deleted file mode 100644 index 1ade5f0..0000000 --- a/src/main/scala/scala/async/AsyncUtils.scala +++ /dev/null @@ -1,16 +0,0 @@ -/* - * Copyright (C) 2012 Typesafe Inc. - */ -package scala.async - -object AsyncUtils { - - private def enabled(level: String) = sys.props.getOrElse(s"scala.async.$level", "false").equalsIgnoreCase("true") - - private def verbose = enabled("debug") - private def trace = enabled("trace") - - private[async] def vprintln(s: => Any): Unit = if (verbose) println(s"[async] $s") - - private[async] def trace(s: => Any): Unit = if (trace) println(s"[async] $s") -} diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala deleted file mode 100644 index a3837d3..0000000 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ /dev/null @@ -1,387 +0,0 @@ -/* - * Copyright (C) 2012 Typesafe Inc. - */ -package scala.async - -import scala.reflect.macros.Context -import scala.collection.mutable.ListBuffer -import collection.mutable -import language.existentials -import scala.reflect.api.Universe -import scala.reflect.api - -trait ExprBuilder { - builder: AsyncMacro => - - import global._ - import defn._ - - val futureSystem: FutureSystem - val futureSystemOps: futureSystem.Ops { val universe: global.type } - - val stateAssigner = new StateAssigner - val labelDefStates = collection.mutable.Map[Symbol, Int]() - - trait AsyncState { - def state: Int - - def mkHandlerCaseForState: CaseDef - - def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = None - - def stats: List[Tree] - - final def allStats: List[Tree] = this match { - case a: AsyncStateWithAwait => stats :+ a.awaitable.resultValDef - case _ => stats - } - - final def body: Tree = stats match { - case stat :: Nil => stat - case init :+ last => Block(init, last) - } - } - - /** A sequence of statements the concludes with a unconditional transition to `nextState` */ - final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int, symLookup: SymLookup) - extends AsyncState { - - def mkHandlerCaseForState: CaseDef = - mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup) :+ mkResumeApply(symLookup)) - - override val toString: String = - s"AsyncState #$state, next = $nextState" - } - - /** A sequence of statements with a conditional transition to the next state, which will represent - * a branch of an `if` or a `match`. - */ - final class AsyncStateWithoutAwait(val stats: List[Tree], val state: Int) extends AsyncState { - override def mkHandlerCaseForState: CaseDef = - mkHandlerCase(state, stats) - - override val toString: String = - s"AsyncStateWithoutAwait #$state" - } - - /** A sequence of statements that concludes with an `await` call. The `onComplete` - * handler will unconditionally transition to `nestState`.`` - */ - final class AsyncStateWithAwait(val stats: List[Tree], val state: Int, nextState: Int, - val awaitable: Awaitable, symLookup: SymLookup) - extends AsyncState { - - override def mkHandlerCaseForState: CaseDef = { - val callOnComplete = futureSystemOps.onComplete(Expr(awaitable.expr), - Expr(This(tpnme.EMPTY)), Expr(Ident(name.execContext))).tree - mkHandlerCase(state, stats :+ callOnComplete) - } - - override def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = { - val tryGetTree = - Assign( - Ident(awaitable.resultName), - TypeApply(Select(Select(Ident(symLookup.applyTrParam), Try_get), newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType))) - ) - - /* if (tr.isFailure) - * result.complete(tr.asInstanceOf[Try[T]]) - * else { - * = tr.get.asInstanceOf[] - * - * - * } - */ - val ifIsFailureTree = - If(Select(Ident(symLookup.applyTrParam), Try_isFailure), - futureSystemOps.completeProm[T]( - Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), - Expr[scala.util.Try[T]]( - TypeApply(Select(Ident(symLookup.applyTrParam), newTermName("asInstanceOf")), - List(TypeTree(weakTypeOf[scala.util.Try[T]]))))).tree, - Block(List(tryGetTree, mkStateTree(nextState, symLookup)), mkResumeApply(symLookup)) - ) - - Some(mkHandlerCase(state, List(ifIsFailureTree))) - } - - override val toString: String = - s"AsyncStateWithAwait #$state, next = $nextState" - } - - /* - * Builder for a single state of an async method. - */ - final class AsyncStateBuilder(state: Int, private val symLookup: SymLookup) { - /* Statements preceding an await call. */ - private val stats = ListBuffer[Tree]() - /** The state of the target of a LabelDef application (while loop jump) */ - private var nextJumpState: Option[Int] = None - - def +=(stat: Tree): this.type = { - assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat") - def addStat() = stats += stat - stat match { - case Apply(fun, Nil) => - labelDefStates get fun.symbol match { - case Some(nextState) => nextJumpState = Some(nextState) - case None => addStat() - } - case _ => addStat() - } - this - } - - def resultWithAwait(awaitable: Awaitable, - nextState: Int): AsyncState = { - val effectiveNextState = nextJumpState.getOrElse(nextState) - new AsyncStateWithAwait(stats.toList, state, effectiveNextState, awaitable, symLookup) - } - - def resultSimple(nextState: Int): AsyncState = { - val effectiveNextState = nextJumpState.getOrElse(nextState) - new SimpleAsyncState(stats.toList, state, effectiveNextState, symLookup) - } - - def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = { - def mkBranch(state: Int) = Block(mkStateTree(state, symLookup) :: Nil, mkResumeApply(symLookup)) - this += If(condTree, mkBranch(thenState), mkBranch(elseState)) - new AsyncStateWithoutAwait(stats.toList, state) - } - - /** - * Build `AsyncState` ending with a match expression. - * - * The cases of the match simply resume at the state of their corresponding right-hand side. - * - * @param scrutTree tree of the scrutinee - * @param cases list of case definitions - * @param caseStates starting state of the right-hand side of the each case - * @return an `AsyncState` representing the match expression - */ - def resultWithMatch(scrutTree: Tree, cases: List[CaseDef], caseStates: List[Int], symLookup: SymLookup): AsyncState = { - // 1. build list of changed cases - val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match { - case CaseDef(pat, guard, rhs) => - val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal) - CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num), symLookup), mkResumeApply(symLookup))) - } - // 2. insert changed match tree at the end of the current state - this += Match(scrutTree, newCases) - new AsyncStateWithoutAwait(stats.toList, state) - } - - def resultWithLabel(startLabelState: Int, symLookup: SymLookup): AsyncState = { - this += Block(mkStateTree(startLabelState, symLookup) :: Nil, mkResumeApply(symLookup)) - new AsyncStateWithoutAwait(stats.toList, state) - } - - override def toString: String = { - val statsBeforeAwait = stats.mkString("\n") - s"ASYNC STATE:\n$statsBeforeAwait" - } - } - - /** - * An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`). - * - * @param stats a list of expressions - * @param expr the last expression of the block - * @param startState the start state - * @param endState the state to continue with - */ - final private class AsyncBlockBuilder(stats: List[Tree], expr: Tree, startState: Int, endState: Int, - private val symLookup: SymLookup) { - val asyncStates = ListBuffer[AsyncState]() - - var stateBuilder = new AsyncStateBuilder(startState, symLookup) - var currState = startState - - /* TODO Fall back to CPS plug-in if tree contains an `await` call. */ - def checkForUnsupportedAwait(tree: Tree) = if (tree exists { - case Apply(fun, _) if isAwait(fun) => true - case _ => false - }) abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException - - def nestedBlockBuilder(nestedTree: Tree, startState: Int, endState: Int) = { - val (nestedStats, nestedExpr) = statsAndExpr(nestedTree) - new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, symLookup) - } - - import stateAssigner.nextState - - // populate asyncStates - for (stat <- stats) stat match { - // the val name = await(..) pattern - case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => - val afterAwaitState = nextState() - val awaitable = Awaitable(arg, stat.symbol, tpt.tpe, vd) - asyncStates += stateBuilder.resultWithAwait(awaitable, afterAwaitState) // complete with await - currState = afterAwaitState - stateBuilder = new AsyncStateBuilder(currState, symLookup) - - case If(cond, thenp, elsep) if stat exists isAwait => - checkForUnsupportedAwait(cond) - - val thenStartState = nextState() - val elseStartState = nextState() - val afterIfState = nextState() - - asyncStates += - // the two Int arguments are the start state of the then branch and the else branch, respectively - stateBuilder.resultWithIf(cond, thenStartState, elseStartState) - - List((thenp, thenStartState), (elsep, elseStartState)) foreach { - case (branchTree, state) => - val builder = nestedBlockBuilder(branchTree, state, afterIfState) - asyncStates ++= builder.asyncStates - } - - currState = afterIfState - stateBuilder = new AsyncStateBuilder(currState, symLookup) - - case Match(scrutinee, cases) if stat exists isAwait => - checkForUnsupportedAwait(scrutinee) - - val caseStates = cases.map(_ => nextState()) - val afterMatchState = nextState() - - asyncStates += - stateBuilder.resultWithMatch(scrutinee, cases, caseStates, symLookup) - - for ((cas, num) <- cases.zipWithIndex) { - val (stats, expr) = statsAndExpr(cas.body) - val stats1 = stats.dropWhile(isSyntheticBindVal) - val builder = nestedBlockBuilder(Block(stats1, expr), caseStates(num), afterMatchState) - asyncStates ++= builder.asyncStates - } - - currState = afterMatchState - stateBuilder = new AsyncStateBuilder(currState, symLookup) - - case ld@LabelDef(name, params, rhs) if rhs exists isAwait => - val startLabelState = nextState() - val afterLabelState = nextState() - asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup) - labelDefStates(ld.symbol) = startLabelState - val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState) - asyncStates ++= builder.asyncStates - - currState = afterLabelState - stateBuilder = new AsyncStateBuilder(currState, symLookup) - case _ => - checkForUnsupportedAwait(stat) - stateBuilder += stat - } - // complete last state builder (representing the expressions after the last await) - stateBuilder += expr - val lastState = stateBuilder.resultSimple(endState) - asyncStates += lastState - } - - trait AsyncBlock { - def asyncStates: List[AsyncState] - - def onCompleteHandler[T: WeakTypeTag]: Tree - - def resumeFunTree[T]: DefDef - } - - case class SymLookup(stateMachineClass: Symbol, applyTrParam: Symbol) { - def stateMachineMember(name: TermName): Symbol = - stateMachineClass.info.member(name) - def memberRef(name: TermName) = gen.mkAttributedRef(stateMachineMember(name)) - } - - def buildAsyncBlock(block: Block, symLookup: SymLookup): AsyncBlock = { - val Block(stats, expr) = block - val startState = stateAssigner.nextState() - val endState = Int.MaxValue - - val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, symLookup) - - new AsyncBlock { - def asyncStates = blockBuilder.asyncStates.toList - - def mkCombinedHandlerCases[T]: List[CaseDef] = { - val caseForLastState: CaseDef = { - val lastState = asyncStates.last - val lastStateBody = Expr[T](lastState.body) - val rhs = futureSystemOps.completeProm( - Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), reify(scala.util.Success(lastStateBody.splice))) - mkHandlerCase(lastState.state, rhs.tree) - } - asyncStates.toList match { - case s :: Nil => - List(caseForLastState) - case _ => - val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState - initCases :+ caseForLastState - } - } - - val initStates = asyncStates.init - - /** - * def resume(): Unit = { - * try { - * state match { - * case 0 => { - * f11 = exprReturningFuture - * f11.onComplete(onCompleteHandler)(context) - * } - * ... - * } - * } catch { - * case NonFatal(t) => result.failure(t) - * } - * } - */ - def resumeFunTree[T]: DefDef = - DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), - Try( - Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T]), - List( - CaseDef( - Bind(name.t, Ident(nme.WILDCARD)), - Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), - Block(List({ - val t = Expr[Throwable](Ident(name.t)) - futureSystemOps.completeProm[T]( - Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), reify(scala.util.Failure(t.splice))).tree - }), literalUnit))), EmptyTree)) - - /** - * // assumes tr: Try[Any] is in scope. - * // - * state match { - * case 0 => { - * x11 = tr.get.asInstanceOf[Double]; - * state = 1; - * resume() - * } - */ - def onCompleteHandler[T: WeakTypeTag]: Tree = Match(symLookup.memberRef(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList) - } - } - - private def isSyntheticBindVal(tree: Tree) = tree match { - case vd@ValDef(_, lname, _, Ident(rname)) => lname.toString.contains(name.bindSuffix) - case _ => false - } - - case class Awaitable(expr: Tree, resultName: Symbol, resultType: Type, resultValDef: ValDef) - - private def mkResumeApply(symLookup: SymLookup) = Apply(symLookup.memberRef(name.resume), Nil) - - private def mkStateTree(nextState: Int, symLookup: SymLookup): Tree = - Assign(symLookup.memberRef(name.state), Literal(Constant(nextState))) - - private def mkHandlerCase(num: Int, rhs: List[Tree]): CaseDef = - mkHandlerCase(num, Block(rhs, literalUnit)) - - private def mkHandlerCase(num: Int, rhs: Tree): CaseDef = - CaseDef(Literal(Constant(num)), EmptyTree, rhs) - - private def literalUnit = Literal(Constant(())) -} diff --git a/src/main/scala/scala/async/FutureSystem.scala b/src/main/scala/scala/async/FutureSystem.scala deleted file mode 100644 index 0c04296..0000000 --- a/src/main/scala/scala/async/FutureSystem.scala +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Copyright (C) 2012 Typesafe Inc. - */ -package scala.async - -import scala.language.higherKinds - -import scala.reflect.macros.Context -import scala.reflect.internal.SymbolTable - -/** - * An abstraction over a future system. - * - * Used by the macro implementations in [[scala.async.AsyncBase]] to - * customize the code generation. - * - * The API mirrors that of `scala.concurrent.Future`, see the instance - * [[scala.async.ScalaConcurrentFutureSystem]] for an example of how - * to implement this. - */ -trait FutureSystem { - /** A container to receive the final value of the computation */ - type Prom[A] - /** A (potentially in-progress) computation */ - type Fut[A] - /** An execution context, required to create or register an on completion callback on a Future. */ - type ExecContext - - trait Ops { - val universe: reflect.internal.SymbolTable - - import universe._ - def Expr[T: WeakTypeTag](tree: Tree): Expr[T] = universe.Expr[T](rootMirror, universe.FixedMirrorTreeCreator(rootMirror, tree)) - - def promType[A: WeakTypeTag]: Type - def execContextType: Type - - /** Create an empty promise */ - def createProm[A: WeakTypeTag]: Expr[Prom[A]] - - /** Extract a future from the given promise. */ - def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]): Expr[Fut[A]] - - /** Construct a future to asynchronously compute the given expression */ - def future[A: WeakTypeTag](a: Expr[A])(execContext: Expr[ExecContext]): Expr[Fut[A]] - - /** Register an call back to run on completion of the given future */ - def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U], - execContext: Expr[ExecContext]): Expr[Unit] - - /** Complete a promise with a value */ - def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] - - def spawn(tree: Tree, execContext: Tree): Tree = - future(Expr[Unit](tree))(Expr[ExecContext](execContext)).tree - - // TODO Why is this needed? - def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] - } - - def mkOps(c: SymbolTable): Ops { val universe: c.type } -} - - -object ScalaConcurrentFutureSystem extends FutureSystem { - - import scala.concurrent._ - - type Prom[A] = Promise[A] - type Fut[A] = Future[A] - type ExecContext = ExecutionContext - - def mkOps(c: SymbolTable): Ops {val universe: c.type} = new Ops { - val universe: c.type = c - - import universe._ - - def promType[A: WeakTypeTag]: Type = weakTypeOf[Promise[A]] - def execContextType: Type = weakTypeOf[ExecutionContext] - - def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { - Promise[A]() - } - - def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]) = reify { - prom.splice.future - } - - def future[A: WeakTypeTag](a: Expr[A])(execContext: Expr[ExecContext]) = reify { - Future(a.splice)(execContext.splice) - } - - def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U], - execContext: Expr[ExecContext]): Expr[Unit] = reify { - future.splice.onComplete(fun.splice)(execContext.splice) - } - - def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify { - prom.splice.complete(value.splice) - Expr[Unit](Literal(Constant(()))).splice - } - - def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = reify { - future.splice.asInstanceOf[Fut[A]] - } - } -} - -/** - * A trivial implementation of [[scala.async.FutureSystem]] that performs computations - * on the current thread. Useful for testing. - */ -object IdentityFutureSystem extends FutureSystem { - - class Prom[A](var a: A) - - type Fut[A] = A - type ExecContext = Unit - - def mkOps(c: SymbolTable): Ops {val universe: c.type} = new Ops { - val universe: c.type = c - - import universe._ - - def execContext: Expr[ExecContext] = Expr[Unit](Literal(Constant(()))) - - def promType[A: WeakTypeTag]: Type = weakTypeOf[Prom[A]] - def execContextType: Type = weakTypeOf[Unit] - - def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { - new Prom(null.asInstanceOf[A]) - } - - def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]) = reify { - prom.splice.a - } - - def future[A: WeakTypeTag](t: Expr[A])(execContext: Expr[ExecContext]) = t - - def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U], - execContext: Expr[ExecContext]): Expr[Unit] = reify { - fun.splice.apply(util.Success(future.splice)) - Expr[Unit](Literal(Constant(()))).splice - } - - def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify { - prom.splice.a = value.splice.get - Expr[Unit](Literal(Constant(()))).splice - } - - def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = ??? - } -} diff --git a/src/main/scala/scala/async/Lifter.scala b/src/main/scala/scala/async/Lifter.scala deleted file mode 100644 index 52ce47d..0000000 --- a/src/main/scala/scala/async/Lifter.scala +++ /dev/null @@ -1,150 +0,0 @@ -package scala.async - -trait Lifter { - self: AsyncMacro => - import global._ - - /** - * Identify which DefTrees are used (including transitively) which are declared - * in some state but used (including transitively) in another state. - * - * These will need to be lifted to class members of the state machine. - */ - def liftables(asyncStates: List[AsyncState]): List[Tree] = { - object companionship { - private val companions = collection.mutable.Map[Symbol, Symbol]() - private val companionsInverse = collection.mutable.Map[Symbol, Symbol]() - private def record(sym1: Symbol, sym2: Symbol) { - companions(sym1) = sym2 - companions(sym2) = sym1 - } - - def record(defs: List[Tree]) { - // Keep note of local companions so we rename them consistently - // when lifting. - val comps = for { - cd@ClassDef(_, _, _, _) <- defs - md@ModuleDef(_, _, _) <- defs - if (cd.name.toTermName == md.name) - } record(cd.symbol, md.symbol) - } - def companionOf(sym: Symbol): Symbol = { - companions.get(sym).orElse(companionsInverse.get(sym)).getOrElse(NoSymbol) - } - } - - - val defs: Map[Tree, Int] = { - /** Collect the DefTrees directly enclosed within `t` that have the same owner */ - def collectDirectlyEnclosedDefs(t: Tree): List[DefTree] = t match { - case dt: DefTree => dt :: Nil - case _: Function => Nil - case t => - val childDefs = t.children.flatMap(collectDirectlyEnclosedDefs(_)) - companionship.record(childDefs) - childDefs - } - asyncStates.flatMap { - asyncState => - val defs = collectDirectlyEnclosedDefs(Block(asyncState.allStats: _*)) - defs.map((_, asyncState.state)) - }.toMap - } - - // In which block are these symbols defined? - val symToDefiningState: Map[Symbol, Int] = defs.map { - case (k, v) => (k.symbol, v) - } - - // The definitions trees - val symToTree: Map[Symbol, Tree] = defs.map { - case (k, v) => (k.symbol, k) - } - - // The direct references of each definition tree - val defSymToReferenced: Map[Symbol, List[Symbol]] = defs.keys.map { - case tree => (tree.symbol, tree.collect { - case rt: RefTree if symToDefiningState.contains(rt.symbol) => rt.symbol - }) - }.toMap - - // The direct references of each block, excluding references of `DefTree`-s which - // are already accounted for. - val stateIdToDirectlyReferenced: Map[Int, List[Symbol]] = { - val refs: List[(Int, Symbol)] = asyncStates.flatMap( - asyncState => asyncState.stats.filterNot(_.isDef).flatMap(_.collect { - case rt: RefTree if symToDefiningState.contains(rt.symbol) => (asyncState.state, rt.symbol) - }) - ) - toMultiMap(refs) - } - - def liftableSyms: Set[Symbol] = { - val liftableMutableSet = collection.mutable.Set[Symbol]() - def markForLift(sym: Symbol) { - if (!liftableMutableSet(sym)) { - liftableMutableSet += sym - - // Only mark transitive references of defs, modules and classes. The RHS of lifted vals/vars - // stays in its original location, so things that it refers to need not be lifted. - if (!(sym.isVal || sym.isVar)) - defSymToReferenced(sym).foreach(sym2 => markForLift(sym2)) - } - } - // Start things with DefTrees directly referenced from statements from other states... - val liftableStatementRefs: List[Symbol] = stateIdToDirectlyReferenced.toList.flatMap { - case (i, syms) => syms.filter(sym => symToDefiningState(sym) != i) - } - // .. and likewise for DefTrees directly referenced by other DefTrees from other states - val liftableRefsOfDefTrees = defSymToReferenced.toList.flatMap { - case (referee, referents) => referents.filter(sym => symToDefiningState(sym) != symToDefiningState(referee)) - } - // Mark these for lifting, which will follow transitive references. - (liftableStatementRefs ++ liftableRefsOfDefTrees).foreach(markForLift) - liftableMutableSet.toSet - } - - val lifted = liftableSyms.map(symToTree).toList.map { - case vd@ValDef(_, _, tpt, rhs) => - import reflect.internal.Flags._ - val sym = vd.symbol - sym.setFlag(MUTABLE | STABLE | PRIVATE | LOCAL) - sym.name = name.fresh(sym.name.toTermName) - sym.modifyInfo(_.deconst) - ValDef(vd.symbol, gen.mkZero(vd.symbol.info)).setPos(vd.pos) - case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) => - import reflect.internal.Flags._ - val sym = dd.symbol - sym.name = this.name.fresh(sym.name.toTermName) - sym.setFlag(PRIVATE | LOCAL) - DefDef(dd.symbol, rhs).setPos(dd.pos) - case cd@ClassDef(_, _, _, impl) => - import reflect.internal.Flags._ - val sym = cd.symbol - sym.name = newTypeName(name.fresh(sym.name.toString).toString) - companionship.companionOf(cd.symbol) match { - case NoSymbol => - case moduleSymbol => - moduleSymbol.name = sym.name.toTermName - moduleSymbol.moduleClass.name = moduleSymbol.name.toTypeName - } - ClassDef(cd.symbol, impl).setPos(cd.pos) - case md@ModuleDef(_, _, impl) => - import reflect.internal.Flags._ - val sym = md.symbol - companionship.companionOf(md.symbol) match { - case NoSymbol => - sym.name = name.fresh(sym.name.toTermName) - sym.moduleClass.name = sym.name.toTypeName - case classSymbol => // will be renamed by `case ClassDef` above. - } - ModuleDef(md.symbol, impl).setPos(md.pos) - case td@TypeDef(_, _, _, rhs) => - import reflect.internal.Flags._ - val sym = td.symbol - sym.name = newTypeName(name.fresh(sym.name.toString).toString) - TypeDef(td.symbol, rhs).setPos(td.pos) - } - lifted - } -} diff --git a/src/main/scala/scala/async/StateAssigner.scala b/src/main/scala/scala/async/StateAssigner.scala deleted file mode 100644 index bc60a6d..0000000 --- a/src/main/scala/scala/async/StateAssigner.scala +++ /dev/null @@ -1,14 +0,0 @@ -/* - * Copyright (C) 2012 Typesafe Inc. - */ - -package scala.async - -private[async] final class StateAssigner { - private var current = -1 - - def nextState(): Int = { - current += 1 - current - } -} \ No newline at end of file diff --git a/src/main/scala/scala/async/StateMachine.scala b/src/main/scala/scala/async/StateMachine.scala new file mode 100644 index 0000000..823df71 --- /dev/null +++ b/src/main/scala/scala/async/StateMachine.scala @@ -0,0 +1,12 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async + +/** Internal class used by the `async` macro; should not be manually extended by client code */ +abstract class StateMachine[Result, EC] extends (scala.util.Try[Any] => Unit) with (() => Unit) { + def result: Result + + def execContext: EC +} diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala deleted file mode 100644 index d69d03a..0000000 --- a/src/main/scala/scala/async/TransformUtils.scala +++ /dev/null @@ -1,251 +0,0 @@ -/* - * Copyright (C) 2012 Typesafe Inc. - */ -package scala.async - -import scala.reflect.macros.Context -import reflect.ClassTag -import scala.reflect.macros.runtime.AbortMacroException - -/** - * Utilities used in both `ExprBuilder` and `AnfTransform`. - */ -private[async] trait TransformUtils { - self: AsyncMacro => - - import global._ - - object name { - val resume = newTermName("resume") - val apply = newTermName("apply") - val matchRes = "matchres" - val ifRes = "ifres" - val await = "await" - val bindSuffix = "$bind" - - val state = newTermName("state") - val result = newTermName("result") - val execContext = newTermName("execContext") - val stateMachine = newTermName(fresh("stateMachine")) - val stateMachineT = stateMachine.toTypeName - val tr = newTermName("tr") - val t = newTermName("throwable") - - def fresh(name: TermName): TermName = newTermName(fresh(name.toString)) - - def fresh(name: String): String = currentUnit.freshTermName("" + name + "$").toString - } - - def isAwait(fun: Tree) = - fun.symbol == defn.Async_await - - private lazy val Boolean_ShortCircuits: Set[Symbol] = { - import definitions.BooleanClass - def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName) - val Boolean_&& = BooleanTermMember("&&") - val Boolean_|| = BooleanTermMember("||") - Set(Boolean_&&, Boolean_||) - } - - private def isByName(fun: Tree): ((Int, Int) => Boolean) = { - if (Boolean_ShortCircuits contains fun.symbol) (i, j) => true - else { - val paramss = fun.tpe.paramss - val byNamess = paramss.map(_.map(_.isByNameParam)) - (i, j) => util.Try(byNamess(i)(j)).getOrElse(false) - } - } - private def argName(fun: Tree): ((Int, Int) => String) = { - val paramss = fun.tpe.paramss - val namess = paramss.map(_.map(_.name.toString)) - (i, j) => util.Try(namess(i)(j)).getOrElse(s"arg_${i}_${j}") - } - - def Expr[A: WeakTypeTag](t: Tree) = global.Expr[A](rootMirror, new FixedMirrorTreeCreator(rootMirror, t)) - - object defn { - def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = { - Expr(Apply(Ident(definitions.List_apply), args.map(_.tree))) - } - - def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify { - self.splice.contains(elem.splice) - } - - def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify { - self.splice.apply(arg.splice) - } - - def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify { - self.splice == other.splice - } - - def mkTry_get[A](self: Expr[util.Try[A]]) = reify { - self.splice.get - } - - val TryClass = rootMirror.staticClass("scala.util.Try") - val Try_get = TryClass.typeSignature.member(newTermName("get")).ensuring(_ != NoSymbol) - val Try_isFailure = TryClass.typeSignature.member(newTermName("isFailure")).ensuring(_ != NoSymbol) - val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe)) - val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal") - val AsyncClass = rootMirror.staticClass("scala.async.AsyncBase") - val Async_await = AsyncClass.typeSignature.member(newTermName("await")).ensuring(_ != NoSymbol) - } - - def isSafeToInline(tree: Tree) = { - treeInfo.isExprSafeToInline(tree) - } - - /** Map a list of arguments to: - * - A list of argument Trees - * - A list of auxillary results. - * - * The function unwraps and rewraps the `arg :_*` construct. - * - * @param args The original argument trees - * @param f A function from argument (with '_*' unwrapped) and argument index to argument. - * @tparam A The type of the auxillary result - */ - private def mapArguments[A](args: List[Tree])(f: (Tree, Int) => (A, Tree)): (List[A], List[Tree]) = { - args match { - case args :+ Typed(tree, Ident(tpnme.WILDCARD_STAR)) => - val (a, argExprs :+ lastArgExpr) = (args :+ tree).zipWithIndex.map(f.tupled).unzip - val exprs = argExprs :+ Typed(lastArgExpr, Ident(tpnme.WILDCARD_STAR)).setPos(lastArgExpr.pos) - (a, exprs) - case args => - args.zipWithIndex.map(f.tupled).unzip - } - } - - case class Arg(expr: Tree, isByName: Boolean, argName: String) - - /** - * Transform a list of argument lists, producing the transformed lists, and lists of auxillary - * results. - * - * The function `f` need not concern itself with varargs arguments e.g (`xs : _*`). It will - * receive `xs`, and it's result will be re-wrapped as `f(xs) : _*`. - * - * @param fun The function being applied - * @param argss The argument lists - * @return (auxillary results, mapped argument trees) - */ - def mapArgumentss[A](fun: Tree, argss: List[List[Tree]])(f: Arg => (A, Tree)): (List[List[A]], List[List[Tree]]) = { - val isByNamess: (Int, Int) => Boolean = isByName(fun) - val argNamess: (Int, Int) => String = argName(fun) - argss.zipWithIndex.map { case (args, i) => - mapArguments[A](args) { - (tree, j) => f(Arg(tree, isByNamess(i, j), argNamess(i, j))) - } - }.unzip - } - - - def statsAndExpr(tree: Tree): (List[Tree], Tree) = tree match { - case Block(stats, expr) => (stats, expr) - case _ => (List(tree), Literal(Constant(()))) - } - - def emptyConstructor: DefDef = { - val emptySuperCall = Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), Nil) - DefDef(NoMods, nme.CONSTRUCTOR, List(), List(List()), TypeTree(), Block(List(emptySuperCall), Literal(Constant(())))) - } - - def applied(className: String, types: List[Type]): AppliedTypeTree = - AppliedTypeTree(Ident(rootMirror.staticClass(className)), types.map(TypeTree(_))) - - /** Descends into the regions of the tree that are subject to the - * translation to a state machine by `async`. When a nested template, - * function, or by-name argument is encountered, the descent stops, - * and `nestedClass` etc are invoked. - */ - trait AsyncTraverser extends Traverser { - def nestedClass(classDef: ClassDef) { - } - - def nestedModule(module: ModuleDef) { - } - - def nestedMethod(module: DefDef) { - } - - def byNameArgument(arg: Tree) { - } - - def function(function: Function) { - } - - def patMatFunction(tree: Match) { - } - - override def traverse(tree: Tree) { - tree match { - case cd: ClassDef => nestedClass(cd) - case md: ModuleDef => nestedModule(md) - case dd: DefDef => nestedMethod(dd) - case fun: Function => function(fun) - case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions` - case treeInfo.Applied(fun, targs, argss) if argss.nonEmpty => - val isInByName = isByName(fun) - for ((args, i) <- argss.zipWithIndex) { - for ((arg, j) <- args.zipWithIndex) { - if (!isInByName(i, j)) traverse(arg) - else byNameArgument(arg) - } - } - traverse(fun) - case _ => super.traverse(tree) - } - } - } - - def abort(pos: Position, msg: String) = throw new AbortMacroException(pos, msg) - - abstract class MacroTypingTransformer extends TypingTransformer(callSiteTyper.context.unit) { - currentOwner = callSiteTyper.context.owner - - def currOwner: Symbol = currentOwner - - localTyper = global.analyzer.newTyper(callSiteTyper.context.make(unit = callSiteTyper.context.unit)) - } - - def transformAt(tree: Tree)(f: PartialFunction[Tree, (analyzer.Context => Tree)]) = { - object trans extends MacroTypingTransformer { - override def transform(tree: Tree): Tree = { - if (f.isDefinedAt(tree)) { - f(tree)(localTyper.context) - } else super.transform(tree) - } - } - trans.transform(tree) - } - - def changeOwner(tree: Tree, oldOwner: Symbol, newOwner: Symbol): tree.type = { - new ChangeOwnerAndModuleClassTraverser(oldOwner, newOwner).traverse(tree) - tree - } - - class ChangeOwnerAndModuleClassTraverser(oldowner: Symbol, newowner: Symbol) - extends ChangeOwnerTraverser(oldowner, newowner) { - - override def traverse(tree: Tree) { - tree match { - case _: DefTree => change(tree.symbol.moduleClass) - case _ => - } - super.traverse(tree) - } - } - - def toMultiMap[A, B](as: Iterable[(A, B)]): Map[A, List[B]] = - as.toList.groupBy(_._1).mapValues(_.map(_._2).toList).toMap - - // Attributed version of `TreeGen#mkCastPreservingAnnotations` - def mkAttributedCastPreservingAnnotations(tree: Tree, tp: Type): Tree = { - atPos(tree.pos) { - val casted = gen.mkAttributedCast(tree, tp.withoutAnnotations.dealias) - Typed(casted, TypeTree(tp)).setType(tp) - } - } -} diff --git a/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala b/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala index 7abc6e8..2902558 100644 --- a/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala +++ b/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala @@ -9,8 +9,9 @@ import scala.language.experimental.macros import scala.reflect.macros.Context import scala.util.continuations._ +import scala.async.internal.{AsyncMacro, AsyncUtils} -trait AsyncBaseWithCPSFallback extends AsyncBase { +trait AsyncBaseWithCPSFallback extends internal.AsyncBase { /* Fall-back for `await` using CPS plugin. * diff --git a/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala b/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala index 018ad05..f864ad6 100644 --- a/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala +++ b/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala @@ -7,6 +7,7 @@ package continuations import scala.util.continuations._ import scala.concurrent.{Future, Promise, ExecutionContext} +import scala.async.internal.ScalaConcurrentFutureSystem trait ScalaConcurrentCPSFallback { self: AsyncBaseWithCPSFallback => diff --git a/src/main/scala/scala/async/internal/AnfTransform.scala b/src/main/scala/scala/async/internal/AnfTransform.scala new file mode 100644 index 0000000..80f8161 --- /dev/null +++ b/src/main/scala/scala/async/internal/AnfTransform.scala @@ -0,0 +1,253 @@ + +/* + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async.internal + +import scala.tools.nsc.Global +import scala.Predef._ + +private[async] trait AnfTransform { + self: AsyncMacro => + + import global._ + import reflect.internal.Flags._ + + def anfTransform(tree: Tree): Block = { + // Must prepend the () for issue #31. + val block = callSiteTyper.typedPos(tree.pos)(Block(List(Literal(Constant(()))), tree)).setType(tree.tpe) + + new SelectiveAnfTransform().transform(block) + } + + sealed abstract class AnfMode + + case object Anf extends AnfMode + + case object Linearizing extends AnfMode + + final class SelectiveAnfTransform extends MacroTypingTransformer { + var mode: AnfMode = Anf + + def blockToList(tree: Tree): List[Tree] = tree match { + case Block(stats, expr) => stats :+ expr + case t => t :: Nil + } + + def listToBlock(trees: List[Tree]): Block = trees match { + case trees @ (init :+ last) => + val pos = trees.map(_.pos).reduceLeft(_ union _) + Block(init, last).setType(last.tpe).setPos(pos) + } + + override def transform(tree: Tree): Block = { + def anfLinearize: Block = { + val trees: List[Tree] = mode match { + case Anf => anf._transformToList(tree) + case Linearizing => linearize._transformToList(tree) + } + listToBlock(trees) + } + tree match { + case _: ValDef | _: DefDef | _: Function | _: ClassDef | _: TypeDef => + atOwner(tree.symbol)(anfLinearize) + case _: ModuleDef => + atOwner(tree.symbol.moduleClass orElse tree.symbol)(anfLinearize) + case _ => + anfLinearize + } + } + + private object linearize { + def transformToList(tree: Tree): List[Tree] = { + mode = Linearizing; blockToList(transform(tree)) + } + + def transformToBlock(tree: Tree): Block = listToBlock(transformToList(tree)) + + def _transformToList(tree: Tree): List[Tree] = trace(tree) { + val stats :+ expr = anf.transformToList(tree) + expr match { + case Apply(fun, args) if isAwait(fun) => + val valDef = defineVal(name.await, expr, tree.pos) + stats :+ valDef :+ gen.mkAttributedStableRef(valDef.symbol) + + case If(cond, thenp, elsep) => + // if type of if-else is Unit don't introduce assignment, + // but add Unit value to bring it into form expected by async transform + if (expr.tpe =:= definitions.UnitTpe) { + stats :+ expr :+ localTyper.typedPos(expr.pos)(Literal(Constant(()))) + } else { + val varDef = defineVar(name.ifRes, expr.tpe, tree.pos) + def branchWithAssign(orig: Tree) = localTyper.typedPos(orig.pos) { + def cast(t: Tree) = mkAttributedCastPreservingAnnotations(t, varDef.symbol.tpe) + orig match { + case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr))) + case _ => Assign(Ident(varDef.symbol), cast(orig)) + } + }.setType(orig.tpe) + val ifWithAssign = treeCopy.If(tree, cond, branchWithAssign(thenp), branchWithAssign(elsep)) + stats :+ varDef :+ ifWithAssign :+ gen.mkAttributedStableRef(varDef.symbol) + } + + case Match(scrut, cases) => + // if type of match is Unit don't introduce assignment, + // but add Unit value to bring it into form expected by async transform + if (expr.tpe =:= definitions.UnitTpe) { + stats :+ expr :+ localTyper.typedPos(expr.pos)(Literal(Constant(()))) + } + else { + val varDef = defineVar(name.matchRes, expr.tpe, tree.pos) + def typedAssign(lhs: Tree) = + localTyper.typedPos(lhs.pos)(Assign(Ident(varDef.symbol), mkAttributedCastPreservingAnnotations(lhs, varDef.symbol.tpe))) + val casesWithAssign = cases map { + case cd@CaseDef(pat, guard, body) => + val newBody = body match { + case b@Block(caseStats, caseExpr) => treeCopy.Block(b, caseStats, typedAssign(caseExpr)) + case _ => typedAssign(body) + } + treeCopy.CaseDef(cd, pat, guard, newBody) + } + val matchWithAssign = treeCopy.Match(tree, scrut, casesWithAssign) + require(matchWithAssign.tpe != null, matchWithAssign) + stats :+ varDef :+ matchWithAssign :+ gen.mkAttributedStableRef(varDef.symbol) + } + case _ => + stats :+ expr + } + } + + private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = { + val sym = currOwner.newTermSymbol(name.fresh(prefix), pos, MUTABLE | SYNTHETIC).setInfo(tp) + ValDef(sym, gen.mkZero(tp)).setType(NoType).setPos(pos) + } + } + + private object trace { + private var indent = -1 + + def indentString = " " * indent + + def apply[T](args: Any)(t: => T): T = { + def prefix = mode.toString.toLowerCase + indent += 1 + def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127) + try { + AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})") + val result = t + AsyncUtils.trace(s"${indentString}= ${oneLine(result)}") + result + } finally { + indent -= 1 + } + } + } + + private def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = { + val sym = currOwner.newTermSymbol(name.fresh(prefix), pos, SYNTHETIC).setInfo(lhs.tpe) + changeOwner(lhs, currentOwner, sym) + ValDef(sym, changeOwner(lhs, currentOwner, sym)).setType(NoType).setPos(pos) + } + + private object anf { + def transformToList(tree: Tree): List[Tree] = { + mode = Anf; blockToList(transform(tree)) + } + + def _transformToList(tree: Tree): List[Tree] = trace(tree) { + val containsAwait = tree exists isAwait + if (!containsAwait) { + List(tree) + } else tree match { + case Select(qual, sel) => + val stats :+ expr = linearize.transformToList(qual) + stats :+ treeCopy.Select(tree, expr, sel) + + case treeInfo.Applied(fun, targs, argss) if argss.nonEmpty => + // we an assume that no await call appears in a by-name argument position, + // this has already been checked. + val funStats :+ simpleFun = linearize.transformToList(fun) + def isAwaitRef(name: Name) = name.toString.startsWith(AnfTransform.this.name.await + "$") + val (argStatss, argExprss): (List[List[List[Tree]]], List[List[Tree]]) = + mapArgumentss[List[Tree]](fun, argss) { + case Arg(expr, byName, _) if byName /*|| isPure(expr) TODO */ => (Nil, expr) + case Arg(expr@Ident(name), _, _) if isAwaitRef(name) => (Nil, expr) // TODO needed? // not typed, so it eludes the check in `isSafeToInline` + case Arg(expr, _, argName) => + linearize.transformToList(expr) match { + case stats :+ expr1 => + val valDef = defineVal(argName, expr1, expr1.pos) + require(valDef.tpe != null, valDef) + val stats1 = stats :+ valDef + //stats1.foreach(changeOwner(_, currentOwner, currentOwner.owner)) + (stats1, gen.stabilize(gen.mkAttributedIdent(valDef.symbol))) + } + } + val applied = treeInfo.dissectApplied(tree) + val core = if (targs.isEmpty) simpleFun else treeCopy.TypeApply(applied.callee, simpleFun, targs) + val newApply = argExprss.foldLeft(core)(Apply(_, _)).setSymbol(tree.symbol) + val typedNewApply = localTyper.typedPos(tree.pos)(newApply).setType(tree.tpe) + funStats ++ argStatss.flatten.flatten :+ typedNewApply + case Block(stats, expr) => + (stats :+ expr).flatMap(linearize.transformToList) + + case ValDef(mods, name, tpt, rhs) => + if (rhs exists isAwait) { + val stats :+ expr = atOwner(currOwner.owner)(linearize.transformToList(rhs)) + stats.foreach(changeOwner(_, currOwner, currOwner.owner)) + stats :+ treeCopy.ValDef(tree, mods, name, tpt, expr) + } else List(tree) + + case Assign(lhs, rhs) => + val stats :+ expr = linearize.transformToList(rhs) + stats :+ treeCopy.Assign(tree, lhs, expr) + + case If(cond, thenp, elsep) => + val condStats :+ condExpr = linearize.transformToList(cond) + val thenBlock = linearize.transformToBlock(thenp) + val elseBlock = linearize.transformToBlock(elsep) + // Typechecking with `condExpr` as the condition fails if the condition + // contains an await. `ifTree.setType(tree.tpe)` also fails; it seems + // we rely on this call to `typeCheck` descending into the branches. + // But, we can get away with typechecking a throwaway `If` tree with the + // original scrutinee and the new branches, and setting that type on + // the real `If` tree. + val iff = treeCopy.If(tree, condExpr, thenBlock, elseBlock) + condStats :+ iff + + case Match(scrut, cases) => + val scrutStats :+ scrutExpr = linearize.transformToList(scrut) + val caseDefs = cases map { + case CaseDef(pat, guard, body) => + // extract local variables for all names bound in `pat`, and rewrite `body` + // to refer to these. + // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`. + val block = linearize.transformToBlock(body) + val (valDefs, mappings) = (pat collect { + case b@Bind(name, _) => + val vd = defineVal(name.toTermName + AnfTransform.this.name.bindSuffix, gen.mkAttributedStableRef(b.symbol), b.pos) + (vd, (b.symbol, vd.symbol)) + }).unzip + val (from, to) = mappings.unzip + val b@Block(stats1, expr1) = block.substituteSymbols(from, to).asInstanceOf[Block] + val newBlock = treeCopy.Block(b, valDefs ++ stats1, expr1) + treeCopy.CaseDef(tree, pat, guard, newBlock) + } + // Refer to comments the translation of `If` above. + val typedMatch = treeCopy.Match(tree, scrutExpr, caseDefs) + scrutStats :+ typedMatch + + case LabelDef(name, params, rhs) => + List(LabelDef(name, params, Block(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol)) + + case TypeApply(fun, targs) => + val funStats :+ simpleFun = linearize.transformToList(fun) + funStats :+ treeCopy.TypeApply(tree, simpleFun, targs) + + case _ => + List(tree) + } + } + } + } +} diff --git a/src/main/scala/scala/async/internal/AsyncAnalysis.scala b/src/main/scala/scala/async/internal/AsyncAnalysis.scala new file mode 100644 index 0000000..62842c9 --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncAnalysis.scala @@ -0,0 +1,91 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async.internal + +import scala.reflect.macros.Context +import scala.collection.mutable + +trait AsyncAnalysis { + self: AsyncMacro => + + import global._ + + /** + * Analyze the contents of an `async` block in order to: + * - Report unsupported `await` calls under nested templates, functions, by-name arguments. + * + * Must be called on the original tree, not on the ANF transformed tree. + */ + def reportUnsupportedAwaits(tree: Tree, report: Boolean): Boolean = { + val analyzer = new UnsupportedAwaitAnalyzer(report) + analyzer.traverse(tree) + analyzer.hasUnsupportedAwaits + } + + private class UnsupportedAwaitAnalyzer(report: Boolean) extends AsyncTraverser { + var hasUnsupportedAwaits = false + + override def nestedClass(classDef: ClassDef) { + val kind = if (classDef.symbol.isTrait) "trait" else "class" + reportUnsupportedAwait(classDef, s"nested ${kind}") + } + + override def nestedModule(module: ModuleDef) { + reportUnsupportedAwait(module, "nested object") + } + + override def nestedMethod(defDef: DefDef) { + reportUnsupportedAwait(defDef, "nested method") + } + + override def byNameArgument(arg: Tree) { + reportUnsupportedAwait(arg, "by-name argument") + } + + override def function(function: Function) { + reportUnsupportedAwait(function, "nested function") + } + + override def patMatFunction(tree: Match) { + reportUnsupportedAwait(tree, "nested function") + } + + override def traverse(tree: Tree) { + def containsAwait = tree exists isAwait + tree match { + case Try(_, _, _) if containsAwait => + reportUnsupportedAwait(tree, "try/catch") + super.traverse(tree) + case Return(_) => + abort(tree.pos, "return is illegal within a async block") + case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) => + // TODO lift this restriction + abort(tree.pos, "lazy vals are illegal within an async block") + case _ => + super.traverse(tree) + } + } + + /** + * @return true, if the tree contained an unsupported await. + */ + private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String): Boolean = { + val badAwaits: List[RefTree] = tree collect { + case rt: RefTree if isAwait(rt) => rt + } + badAwaits foreach { + tree => + reportError(tree.pos, s"await must not be used under a $whyUnsupported.") + } + badAwaits.nonEmpty + } + + private def reportError(pos: Position, msg: String) { + hasUnsupportedAwaits = true + if (report) + abort(pos, msg) + } + } +} diff --git a/src/main/scala/scala/async/internal/AsyncBase.scala b/src/main/scala/scala/async/internal/AsyncBase.scala new file mode 100644 index 0000000..2f7e38d --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncBase.scala @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async.internal + +import scala.reflect.internal.annotations.compileTimeOnly +import scala.reflect.macros.Context + +/** + * A base class for the `async` macro. Subclasses must provide: + * + * - Concrete types for a given future system + * - Tree manipulations to create and complete the equivalent of Future and Promise + * in that system. + * - The `async` macro declaration itself, and a forwarder for the macro implementation. + * (The latter is temporarily needed to workaround bug SI-6650 in the macro system) + * + * The default implementation, [[scala.async.Async]], binds the macro to `scala.concurrent._`. + */ +abstract class AsyncBase { + self => + + type FS <: FutureSystem + val futureSystem: FS + + /** + * A call to `await` must be nested in an enclosing `async` block. + * + * A call to `await` does not block the current thread, rather it is a delimiter + * used by the enclosing `async` macro. Code following the `await` + * call is executed asynchronously, when the argument of `await` has been completed. + * + * @param awaitable the future from which a value is awaited. + * @tparam T the type of that value. + * @return the value. + */ + @compileTimeOnly("`await` must be enclosed in an `async` block") + def await[T](awaitable: futureSystem.Fut[T]): T = ??? + + protected[async] def fallbackEnabled = false + + def asyncImpl[T: c.WeakTypeTag](c: Context) + (body: c.Expr[T]) + (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = { + import c.universe._ + + val asyncMacro = AsyncMacro(c, futureSystem) + + val code = asyncMacro.asyncTransform[T]( + body.tree.asInstanceOf[asyncMacro.global.Tree], + execContext.tree.asInstanceOf[asyncMacro.global.Tree], + fallbackEnabled)(implicitly[c.WeakTypeTag[T]].asInstanceOf[asyncMacro.global.WeakTypeTag[T]]) + + AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code}") + c.Expr[futureSystem.Fut[T]](code.asInstanceOf[Tree]) + } +} diff --git a/src/main/scala/scala/async/internal/AsyncId.scala b/src/main/scala/scala/async/internal/AsyncId.scala new file mode 100644 index 0000000..394f587 --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncId.scala @@ -0,0 +1,64 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async.internal + +import language.experimental.macros +import scala.reflect.macros.Context +import scala.reflect.internal.SymbolTable + +object AsyncId extends AsyncBase { + lazy val futureSystem = IdentityFutureSystem + type FS = IdentityFutureSystem.type + + def async[T](body: T) = macro asyncIdImpl[T] + + def asyncIdImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = asyncImpl[T](c)(body)(c.literalUnit) +} + +/** + * A trivial implementation of [[FutureSystem]] that performs computations + * on the current thread. Useful for testing. + */ +object IdentityFutureSystem extends FutureSystem { + + class Prom[A](var a: A) + + type Fut[A] = A + type ExecContext = Unit + + def mkOps(c: SymbolTable): Ops {val universe: c.type} = new Ops { + val universe: c.type = c + + import universe._ + + def execContext: Expr[ExecContext] = Expr[Unit](Literal(Constant(()))) + + def promType[A: WeakTypeTag]: Type = weakTypeOf[Prom[A]] + def execContextType: Type = weakTypeOf[Unit] + + def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { + new Prom(null.asInstanceOf[A]) + } + + def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]) = reify { + prom.splice.a + } + + def future[A: WeakTypeTag](t: Expr[A])(execContext: Expr[ExecContext]) = t + + def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U], + execContext: Expr[ExecContext]): Expr[Unit] = reify { + fun.splice.apply(util.Success(future.splice)) + Expr[Unit](Literal(Constant(()))).splice + } + + def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify { + prom.splice.a = value.splice.get + Expr[Unit](Literal(Constant(()))).splice + } + + def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = ??? + } +} diff --git a/src/main/scala/scala/async/internal/AsyncMacro.scala b/src/main/scala/scala/async/internal/AsyncMacro.scala new file mode 100644 index 0000000..6b7d031 --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncMacro.scala @@ -0,0 +1,29 @@ +package scala.async.internal + +import scala.tools.nsc.Global +import scala.tools.nsc.transform.TypingTransformers + +object AsyncMacro { + def apply(c: reflect.macros.Context, futureSystem0: FutureSystem): AsyncMacro = { + import language.reflectiveCalls + val powerContext = c.asInstanceOf[c.type {val universe: Global; val callsiteTyper: universe.analyzer.Typer}] + new AsyncMacro { + val global: powerContext.universe.type = powerContext.universe + val callSiteTyper: global.analyzer.Typer = powerContext.callsiteTyper + val futureSystem: futureSystem0.type = futureSystem0 + val futureSystemOps: futureSystem.Ops {val universe: global.type} = futureSystem0.mkOps(global) + val macroApplication: global.Tree = c.macroApplication.asInstanceOf[global.Tree] + } + } +} + +private[async] trait AsyncMacro + extends TypingTransformers + with AnfTransform with TransformUtils with Lifter + with ExprBuilder with AsyncTransform with AsyncAnalysis { + + val global: Global + val callSiteTyper: global.analyzer.Typer + val macroApplication: global.Tree + +} diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala new file mode 100644 index 0000000..bdc8664 --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncTransform.scala @@ -0,0 +1,176 @@ +package scala.async.internal + +trait AsyncTransform { + self: AsyncMacro => + + import global._ + + def asyncTransform[T](body: Tree, execContext: Tree, cpsFallbackEnabled: Boolean) + (implicit resultType: WeakTypeTag[T]): Tree = { + + reportUnsupportedAwaits(body, report = !cpsFallbackEnabled) + + // Transform to A-normal form: + // - no await calls in qualifiers or arguments, + // - if/match only used in statement position. + val anfTree: Block = anfTransform(body) + + val resumeFunTreeDummyBody = DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), Literal(Constant(()))) + + val applyDefDefDummyBody: DefDef = { + val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree))) + DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), Literal(Constant(()))) + } + + val stateMachineType = applied("scala.async.StateMachine", List(futureSystemOps.promType[T], futureSystemOps.execContextType)) + + val stateMachine: ClassDef = { + val body: List[Tree] = { + val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0))) + val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree) + val execContextValDef = ValDef(NoMods, name.execContext, TypeTree(), execContext) + + val apply0DefDef: DefDef = { + // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`. + // See SI-1247 for the the optimization that avoids creatio + DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil)) + } + List(emptyConstructor, stateVar, result, execContextValDef) ++ List(resumeFunTreeDummyBody, applyDefDefDummyBody, apply0DefDef) + } + val template = { + Template(List(stateMachineType), emptyValDef, body) + } + val t = ClassDef(NoMods, name.stateMachineT, Nil, template) + callSiteTyper.typedPos(macroApplication.pos)(Block(t :: Nil, Literal(Constant(())))) + t + } + + val asyncBlock: AsyncBlock = { + val symLookup = new SymLookup(stateMachine.symbol, applyDefDefDummyBody.vparamss.head.head.symbol) + buildAsyncBlock(anfTree, symLookup) + } + + logDiagnostics(anfTree, asyncBlock.asyncStates.map(_.toString)) + + def startStateMachine: Tree = { + val stateMachineSpliced: Tree = spliceMethodBodies( + liftables(asyncBlock.asyncStates), + stateMachine, + asyncBlock.onCompleteHandler[T], + asyncBlock.resumeFunTree[T].rhs + ) + + def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection) + + Block(List[Tree]( + stateMachineSpliced, + ValDef(NoMods, name.stateMachine, stateMachineType, Apply(Select(New(Ident(stateMachine.symbol)), nme.CONSTRUCTOR), Nil)), + futureSystemOps.spawn(Apply(selectStateMachine(name.apply), Nil), selectStateMachine(name.execContext)) + ), + futureSystemOps.promiseToFuture(Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree) + } + + val isSimple = asyncBlock.asyncStates.size == 1 + if (isSimple) + futureSystemOps.spawn(body, execContext) // generate lean code for the simple case of `async { 1 + 1 }` + else + startStateMachine + } + + def logDiagnostics(anfTree: Tree, states: Seq[String]) { + val pos = macroApplication.pos + def location = try { + pos.source.path + } catch { + case _: UnsupportedOperationException => + pos.toString + } + + AsyncUtils.vprintln(s"In file '$location':") + AsyncUtils.vprintln(s"${macroApplication}") + AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree") + states foreach (s => AsyncUtils.vprintln(s)) + } + + def spliceMethodBodies(liftables: List[Tree], tree: Tree, applyBody: Tree, + resumeBody: Tree): Tree = { + + val liftedSyms = liftables.map(_.symbol).toSet + val stateMachineClass = tree.symbol + liftedSyms.foreach { + sym => + if (sym != null) { + sym.owner = stateMachineClass + if (sym.isModule) + sym.moduleClass.owner = stateMachineClass + } + } + // Replace the ValDefs in the splicee with Assigns to the corresponding lifted + // fields. Similarly, replace references to them with references to the field. + // + // This transform will be only be run on the RHS of `def foo`. + class UseFields extends MacroTypingTransformer { + override def transform(tree: Tree): Tree = tree match { + case _ if currentOwner == stateMachineClass => + super.transform(tree) + case ValDef(_, _, _, rhs) if liftedSyms(tree.symbol) => + atOwner(currentOwner) { + val fieldSym = tree.symbol + val set = Assign(gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym), transform(rhs)) + changeOwner(set, tree.symbol, currentOwner) + localTyper.typedPos(tree.pos)(set) + } + case _: DefTree if liftedSyms(tree.symbol) => + EmptyTree + case Ident(name) if liftedSyms(tree.symbol) => + val fieldSym = tree.symbol + gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym).setType(tree.tpe) + case _ => + super.transform(tree) + } + } + + val liftablesUseFields = liftables.map { + case vd: ValDef => vd + case x => + val useField = new UseFields() + //.substituteSymbols(fromSyms, toSyms) + useField.atOwner(stateMachineClass)(useField.transform(x)) + } + + tree.children.foreach { + t => + new ChangeOwnerAndModuleClassTraverser(callSiteTyper.context.owner, tree.symbol).traverse(t) + } + val treeSubst = tree + + def fixup(dd: DefDef, body: Tree, ctx: analyzer.Context): Tree = { + val spliceeAnfFixedOwnerSyms = body + val useField = new UseFields() + val newRhs = useField.atOwner(dd.symbol)(useField.transform(spliceeAnfFixedOwnerSyms)) + val typer = global.analyzer.newTyper(ctx.make(dd, dd.symbol)) + treeCopy.DefDef(dd, dd.mods, dd.name, dd.tparams, dd.vparamss, dd.tpt, typer.typed(newRhs)) + } + + liftablesUseFields.foreach(t => if (t.symbol != null) stateMachineClass.info.decls.enter(t.symbol)) + + val result0 = transformAt(treeSubst) { + case t@Template(parents, self, stats) => + (ctx: analyzer.Context) => { + treeCopy.Template(t, parents, self, liftablesUseFields ++ stats) + } + } + val result = transformAt(result0) { + case dd@DefDef(_, name.apply, _, List(List(_)), _, _) if dd.symbol.owner == stateMachineClass => + (ctx: analyzer.Context) => + val typedTree = fixup(dd, changeOwner(applyBody, callSiteTyper.context.owner, dd.symbol), ctx) + typedTree + case dd@DefDef(_, name.resume, _, _, _, _) if dd.symbol.owner == stateMachineClass => + (ctx: analyzer.Context) => + val changed = changeOwner(resumeBody, callSiteTyper.context.owner, dd.symbol) + val res = fixup(dd, changed, ctx) + res + } + result + } +} diff --git a/src/main/scala/scala/async/internal/AsyncUtils.scala b/src/main/scala/scala/async/internal/AsyncUtils.scala new file mode 100644 index 0000000..8700bd6 --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncUtils.scala @@ -0,0 +1,16 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ +package scala.async.internal + +object AsyncUtils { + + private def enabled(level: String) = sys.props.getOrElse(s"scala.async.$level", "false").equalsIgnoreCase("true") + + private def verbose = enabled("debug") + private def trace = enabled("trace") + + private[async] def vprintln(s: => Any): Unit = if (verbose) println(s"[async] $s") + + private[async] def trace(s: => Any): Unit = if (trace) println(s"[async] $s") +} diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala new file mode 100644 index 0000000..1ce30e6 --- /dev/null +++ b/src/main/scala/scala/async/internal/ExprBuilder.scala @@ -0,0 +1,388 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ +package scala.async.internal + +import scala.reflect.macros.Context +import scala.collection.mutable.ListBuffer +import collection.mutable +import language.existentials +import scala.reflect.api.Universe +import scala.reflect.api +import scala.Some + +trait ExprBuilder { + builder: AsyncMacro => + + import global._ + import defn._ + + val futureSystem: FutureSystem + val futureSystemOps: futureSystem.Ops { val universe: global.type } + + val stateAssigner = new StateAssigner + val labelDefStates = collection.mutable.Map[Symbol, Int]() + + trait AsyncState { + def state: Int + + def mkHandlerCaseForState: CaseDef + + def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = None + + def stats: List[Tree] + + final def allStats: List[Tree] = this match { + case a: AsyncStateWithAwait => stats :+ a.awaitable.resultValDef + case _ => stats + } + + final def body: Tree = stats match { + case stat :: Nil => stat + case init :+ last => Block(init, last) + } + } + + /** A sequence of statements the concludes with a unconditional transition to `nextState` */ + final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int, symLookup: SymLookup) + extends AsyncState { + + def mkHandlerCaseForState: CaseDef = + mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup) :+ mkResumeApply(symLookup)) + + override val toString: String = + s"AsyncState #$state, next = $nextState" + } + + /** A sequence of statements with a conditional transition to the next state, which will represent + * a branch of an `if` or a `match`. + */ + final class AsyncStateWithoutAwait(val stats: List[Tree], val state: Int) extends AsyncState { + override def mkHandlerCaseForState: CaseDef = + mkHandlerCase(state, stats) + + override val toString: String = + s"AsyncStateWithoutAwait #$state" + } + + /** A sequence of statements that concludes with an `await` call. The `onComplete` + * handler will unconditionally transition to `nestState`.`` + */ + final class AsyncStateWithAwait(val stats: List[Tree], val state: Int, nextState: Int, + val awaitable: Awaitable, symLookup: SymLookup) + extends AsyncState { + + override def mkHandlerCaseForState: CaseDef = { + val callOnComplete = futureSystemOps.onComplete(Expr(awaitable.expr), + Expr(This(tpnme.EMPTY)), Expr(Ident(name.execContext))).tree + mkHandlerCase(state, stats :+ callOnComplete) + } + + override def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = { + val tryGetTree = + Assign( + Ident(awaitable.resultName), + TypeApply(Select(Select(Ident(symLookup.applyTrParam), Try_get), newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType))) + ) + + /* if (tr.isFailure) + * result.complete(tr.asInstanceOf[Try[T]]) + * else { + * = tr.get.asInstanceOf[] + * + * + * } + */ + val ifIsFailureTree = + If(Select(Ident(symLookup.applyTrParam), Try_isFailure), + futureSystemOps.completeProm[T]( + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), + Expr[scala.util.Try[T]]( + TypeApply(Select(Ident(symLookup.applyTrParam), newTermName("asInstanceOf")), + List(TypeTree(weakTypeOf[scala.util.Try[T]]))))).tree, + Block(List(tryGetTree, mkStateTree(nextState, symLookup)), mkResumeApply(symLookup)) + ) + + Some(mkHandlerCase(state, List(ifIsFailureTree))) + } + + override val toString: String = + s"AsyncStateWithAwait #$state, next = $nextState" + } + + /* + * Builder for a single state of an async method. + */ + final class AsyncStateBuilder(state: Int, private val symLookup: SymLookup) { + /* Statements preceding an await call. */ + private val stats = ListBuffer[Tree]() + /** The state of the target of a LabelDef application (while loop jump) */ + private var nextJumpState: Option[Int] = None + + def +=(stat: Tree): this.type = { + assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat") + def addStat() = stats += stat + stat match { + case Apply(fun, Nil) => + labelDefStates get fun.symbol match { + case Some(nextState) => nextJumpState = Some(nextState) + case None => addStat() + } + case _ => addStat() + } + this + } + + def resultWithAwait(awaitable: Awaitable, + nextState: Int): AsyncState = { + val effectiveNextState = nextJumpState.getOrElse(nextState) + new AsyncStateWithAwait(stats.toList, state, effectiveNextState, awaitable, symLookup) + } + + def resultSimple(nextState: Int): AsyncState = { + val effectiveNextState = nextJumpState.getOrElse(nextState) + new SimpleAsyncState(stats.toList, state, effectiveNextState, symLookup) + } + + def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = { + def mkBranch(state: Int) = Block(mkStateTree(state, symLookup) :: Nil, mkResumeApply(symLookup)) + this += If(condTree, mkBranch(thenState), mkBranch(elseState)) + new AsyncStateWithoutAwait(stats.toList, state) + } + + /** + * Build `AsyncState` ending with a match expression. + * + * The cases of the match simply resume at the state of their corresponding right-hand side. + * + * @param scrutTree tree of the scrutinee + * @param cases list of case definitions + * @param caseStates starting state of the right-hand side of the each case + * @return an `AsyncState` representing the match expression + */ + def resultWithMatch(scrutTree: Tree, cases: List[CaseDef], caseStates: List[Int], symLookup: SymLookup): AsyncState = { + // 1. build list of changed cases + val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match { + case CaseDef(pat, guard, rhs) => + val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal) + CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num), symLookup), mkResumeApply(symLookup))) + } + // 2. insert changed match tree at the end of the current state + this += Match(scrutTree, newCases) + new AsyncStateWithoutAwait(stats.toList, state) + } + + def resultWithLabel(startLabelState: Int, symLookup: SymLookup): AsyncState = { + this += Block(mkStateTree(startLabelState, symLookup) :: Nil, mkResumeApply(symLookup)) + new AsyncStateWithoutAwait(stats.toList, state) + } + + override def toString: String = { + val statsBeforeAwait = stats.mkString("\n") + s"ASYNC STATE:\n$statsBeforeAwait" + } + } + + /** + * An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`). + * + * @param stats a list of expressions + * @param expr the last expression of the block + * @param startState the start state + * @param endState the state to continue with + */ + final private class AsyncBlockBuilder(stats: List[Tree], expr: Tree, startState: Int, endState: Int, + private val symLookup: SymLookup) { + val asyncStates = ListBuffer[AsyncState]() + + var stateBuilder = new AsyncStateBuilder(startState, symLookup) + var currState = startState + + /* TODO Fall back to CPS plug-in if tree contains an `await` call. */ + def checkForUnsupportedAwait(tree: Tree) = if (tree exists { + case Apply(fun, _) if isAwait(fun) => true + case _ => false + }) abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException + + def nestedBlockBuilder(nestedTree: Tree, startState: Int, endState: Int) = { + val (nestedStats, nestedExpr) = statsAndExpr(nestedTree) + new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, symLookup) + } + + import stateAssigner.nextState + + // populate asyncStates + for (stat <- stats) stat match { + // the val name = await(..) pattern + case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => + val afterAwaitState = nextState() + val awaitable = Awaitable(arg, stat.symbol, tpt.tpe, vd) + asyncStates += stateBuilder.resultWithAwait(awaitable, afterAwaitState) // complete with await + currState = afterAwaitState + stateBuilder = new AsyncStateBuilder(currState, symLookup) + + case If(cond, thenp, elsep) if stat exists isAwait => + checkForUnsupportedAwait(cond) + + val thenStartState = nextState() + val elseStartState = nextState() + val afterIfState = nextState() + + asyncStates += + // the two Int arguments are the start state of the then branch and the else branch, respectively + stateBuilder.resultWithIf(cond, thenStartState, elseStartState) + + List((thenp, thenStartState), (elsep, elseStartState)) foreach { + case (branchTree, state) => + val builder = nestedBlockBuilder(branchTree, state, afterIfState) + asyncStates ++= builder.asyncStates + } + + currState = afterIfState + stateBuilder = new AsyncStateBuilder(currState, symLookup) + + case Match(scrutinee, cases) if stat exists isAwait => + checkForUnsupportedAwait(scrutinee) + + val caseStates = cases.map(_ => nextState()) + val afterMatchState = nextState() + + asyncStates += + stateBuilder.resultWithMatch(scrutinee, cases, caseStates, symLookup) + + for ((cas, num) <- cases.zipWithIndex) { + val (stats, expr) = statsAndExpr(cas.body) + val stats1 = stats.dropWhile(isSyntheticBindVal) + val builder = nestedBlockBuilder(Block(stats1, expr), caseStates(num), afterMatchState) + asyncStates ++= builder.asyncStates + } + + currState = afterMatchState + stateBuilder = new AsyncStateBuilder(currState, symLookup) + + case ld@LabelDef(name, params, rhs) if rhs exists isAwait => + val startLabelState = nextState() + val afterLabelState = nextState() + asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup) + labelDefStates(ld.symbol) = startLabelState + val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState) + asyncStates ++= builder.asyncStates + + currState = afterLabelState + stateBuilder = new AsyncStateBuilder(currState, symLookup) + case _ => + checkForUnsupportedAwait(stat) + stateBuilder += stat + } + // complete last state builder (representing the expressions after the last await) + stateBuilder += expr + val lastState = stateBuilder.resultSimple(endState) + asyncStates += lastState + } + + trait AsyncBlock { + def asyncStates: List[AsyncState] + + def onCompleteHandler[T: WeakTypeTag]: Tree + + def resumeFunTree[T]: DefDef + } + + case class SymLookup(stateMachineClass: Symbol, applyTrParam: Symbol) { + def stateMachineMember(name: TermName): Symbol = + stateMachineClass.info.member(name) + def memberRef(name: TermName) = gen.mkAttributedRef(stateMachineMember(name)) + } + + def buildAsyncBlock(block: Block, symLookup: SymLookup): AsyncBlock = { + val Block(stats, expr) = block + val startState = stateAssigner.nextState() + val endState = Int.MaxValue + + val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, symLookup) + + new AsyncBlock { + def asyncStates = blockBuilder.asyncStates.toList + + def mkCombinedHandlerCases[T]: List[CaseDef] = { + val caseForLastState: CaseDef = { + val lastState = asyncStates.last + val lastStateBody = Expr[T](lastState.body) + val rhs = futureSystemOps.completeProm( + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), reify(scala.util.Success(lastStateBody.splice))) + mkHandlerCase(lastState.state, rhs.tree) + } + asyncStates.toList match { + case s :: Nil => + List(caseForLastState) + case _ => + val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState + initCases :+ caseForLastState + } + } + + val initStates = asyncStates.init + + /** + * def resume(): Unit = { + * try { + * state match { + * case 0 => { + * f11 = exprReturningFuture + * f11.onComplete(onCompleteHandler)(context) + * } + * ... + * } + * } catch { + * case NonFatal(t) => result.failure(t) + * } + * } + */ + def resumeFunTree[T]: DefDef = + DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), + Try( + Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T]), + List( + CaseDef( + Bind(name.t, Ident(nme.WILDCARD)), + Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), + Block(List({ + val t = Expr[Throwable](Ident(name.t)) + futureSystemOps.completeProm[T]( + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), reify(scala.util.Failure(t.splice))).tree + }), literalUnit))), EmptyTree)) + + /** + * // assumes tr: Try[Any] is in scope. + * // + * state match { + * case 0 => { + * x11 = tr.get.asInstanceOf[Double]; + * state = 1; + * resume() + * } + */ + def onCompleteHandler[T: WeakTypeTag]: Tree = Match(symLookup.memberRef(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList) + } + } + + private def isSyntheticBindVal(tree: Tree) = tree match { + case vd@ValDef(_, lname, _, Ident(rname)) => lname.toString.contains(name.bindSuffix) + case _ => false + } + + case class Awaitable(expr: Tree, resultName: Symbol, resultType: Type, resultValDef: ValDef) + + private def mkResumeApply(symLookup: SymLookup) = Apply(symLookup.memberRef(name.resume), Nil) + + private def mkStateTree(nextState: Int, symLookup: SymLookup): Tree = + Assign(symLookup.memberRef(name.state), Literal(Constant(nextState))) + + private def mkHandlerCase(num: Int, rhs: List[Tree]): CaseDef = + mkHandlerCase(num, Block(rhs, literalUnit)) + + private def mkHandlerCase(num: Int, rhs: Tree): CaseDef = + CaseDef(Literal(Constant(num)), EmptyTree, rhs) + + private def literalUnit = Literal(Constant(())) +} diff --git a/src/main/scala/scala/async/internal/FutureSystem.scala b/src/main/scala/scala/async/internal/FutureSystem.scala new file mode 100644 index 0000000..101b7bf --- /dev/null +++ b/src/main/scala/scala/async/internal/FutureSystem.scala @@ -0,0 +1,106 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ +package scala.async.internal + +import scala.language.higherKinds + +import scala.reflect.macros.Context +import scala.reflect.internal.SymbolTable + +/** + * An abstraction over a future system. + * + * Used by the macro implementations in [[scala.async.AsyncBase]] to + * customize the code generation. + * + * The API mirrors that of `scala.concurrent.Future`, see the instance + * [[ScalaConcurrentFutureSystem]] for an example of how + * to implement this. + */ +trait FutureSystem { + /** A container to receive the final value of the computation */ + type Prom[A] + /** A (potentially in-progress) computation */ + type Fut[A] + /** An execution context, required to create or register an on completion callback on a Future. */ + type ExecContext + + trait Ops { + val universe: reflect.internal.SymbolTable + + import universe._ + def Expr[T: WeakTypeTag](tree: Tree): Expr[T] = universe.Expr[T](rootMirror, universe.FixedMirrorTreeCreator(rootMirror, tree)) + + def promType[A: WeakTypeTag]: Type + def execContextType: Type + + /** Create an empty promise */ + def createProm[A: WeakTypeTag]: Expr[Prom[A]] + + /** Extract a future from the given promise. */ + def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]): Expr[Fut[A]] + + /** Construct a future to asynchronously compute the given expression */ + def future[A: WeakTypeTag](a: Expr[A])(execContext: Expr[ExecContext]): Expr[Fut[A]] + + /** Register an call back to run on completion of the given future */ + def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U], + execContext: Expr[ExecContext]): Expr[Unit] + + /** Complete a promise with a value */ + def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] + + def spawn(tree: Tree, execContext: Tree): Tree = + future(Expr[Unit](tree))(Expr[ExecContext](execContext)).tree + + // TODO Why is this needed? + def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] + } + + def mkOps(c: SymbolTable): Ops { val universe: c.type } +} + +object ScalaConcurrentFutureSystem extends FutureSystem { + + import scala.concurrent._ + + type Prom[A] = Promise[A] + type Fut[A] = Future[A] + type ExecContext = ExecutionContext + + def mkOps(c: SymbolTable): Ops {val universe: c.type} = new Ops { + val universe: c.type = c + + import universe._ + + def promType[A: WeakTypeTag]: Type = weakTypeOf[Promise[A]] + def execContextType: Type = weakTypeOf[ExecutionContext] + + def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { + Promise[A]() + } + + def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]) = reify { + prom.splice.future + } + + def future[A: WeakTypeTag](a: Expr[A])(execContext: Expr[ExecContext]) = reify { + Future(a.splice)(execContext.splice) + } + + def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U], + execContext: Expr[ExecContext]): Expr[Unit] = reify { + future.splice.onComplete(fun.splice)(execContext.splice) + } + + def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify { + prom.splice.complete(value.splice) + Expr[Unit](Literal(Constant(()))).splice + } + + def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = reify { + future.splice.asInstanceOf[Fut[A]] + } + } +} diff --git a/src/main/scala/scala/async/internal/Lifter.scala b/src/main/scala/scala/async/internal/Lifter.scala new file mode 100644 index 0000000..f49dcbb --- /dev/null +++ b/src/main/scala/scala/async/internal/Lifter.scala @@ -0,0 +1,150 @@ +package scala.async.internal + +trait Lifter { + self: AsyncMacro => + import global._ + + /** + * Identify which DefTrees are used (including transitively) which are declared + * in some state but used (including transitively) in another state. + * + * These will need to be lifted to class members of the state machine. + */ + def liftables(asyncStates: List[AsyncState]): List[Tree] = { + object companionship { + private val companions = collection.mutable.Map[Symbol, Symbol]() + private val companionsInverse = collection.mutable.Map[Symbol, Symbol]() + private def record(sym1: Symbol, sym2: Symbol) { + companions(sym1) = sym2 + companions(sym2) = sym1 + } + + def record(defs: List[Tree]) { + // Keep note of local companions so we rename them consistently + // when lifting. + val comps = for { + cd@ClassDef(_, _, _, _) <- defs + md@ModuleDef(_, _, _) <- defs + if (cd.name.toTermName == md.name) + } record(cd.symbol, md.symbol) + } + def companionOf(sym: Symbol): Symbol = { + companions.get(sym).orElse(companionsInverse.get(sym)).getOrElse(NoSymbol) + } + } + + + val defs: Map[Tree, Int] = { + /** Collect the DefTrees directly enclosed within `t` that have the same owner */ + def collectDirectlyEnclosedDefs(t: Tree): List[DefTree] = t match { + case dt: DefTree => dt :: Nil + case _: Function => Nil + case t => + val childDefs = t.children.flatMap(collectDirectlyEnclosedDefs(_)) + companionship.record(childDefs) + childDefs + } + asyncStates.flatMap { + asyncState => + val defs = collectDirectlyEnclosedDefs(Block(asyncState.allStats: _*)) + defs.map((_, asyncState.state)) + }.toMap + } + + // In which block are these symbols defined? + val symToDefiningState: Map[Symbol, Int] = defs.map { + case (k, v) => (k.symbol, v) + } + + // The definitions trees + val symToTree: Map[Symbol, Tree] = defs.map { + case (k, v) => (k.symbol, k) + } + + // The direct references of each definition tree + val defSymToReferenced: Map[Symbol, List[Symbol]] = defs.keys.map { + case tree => (tree.symbol, tree.collect { + case rt: RefTree if symToDefiningState.contains(rt.symbol) => rt.symbol + }) + }.toMap + + // The direct references of each block, excluding references of `DefTree`-s which + // are already accounted for. + val stateIdToDirectlyReferenced: Map[Int, List[Symbol]] = { + val refs: List[(Int, Symbol)] = asyncStates.flatMap( + asyncState => asyncState.stats.filterNot(_.isDef).flatMap(_.collect { + case rt: RefTree if symToDefiningState.contains(rt.symbol) => (asyncState.state, rt.symbol) + }) + ) + toMultiMap(refs) + } + + def liftableSyms: Set[Symbol] = { + val liftableMutableSet = collection.mutable.Set[Symbol]() + def markForLift(sym: Symbol) { + if (!liftableMutableSet(sym)) { + liftableMutableSet += sym + + // Only mark transitive references of defs, modules and classes. The RHS of lifted vals/vars + // stays in its original location, so things that it refers to need not be lifted. + if (!(sym.isVal || sym.isVar)) + defSymToReferenced(sym).foreach(sym2 => markForLift(sym2)) + } + } + // Start things with DefTrees directly referenced from statements from other states... + val liftableStatementRefs: List[Symbol] = stateIdToDirectlyReferenced.toList.flatMap { + case (i, syms) => syms.filter(sym => symToDefiningState(sym) != i) + } + // .. and likewise for DefTrees directly referenced by other DefTrees from other states + val liftableRefsOfDefTrees = defSymToReferenced.toList.flatMap { + case (referee, referents) => referents.filter(sym => symToDefiningState(sym) != symToDefiningState(referee)) + } + // Mark these for lifting, which will follow transitive references. + (liftableStatementRefs ++ liftableRefsOfDefTrees).foreach(markForLift) + liftableMutableSet.toSet + } + + val lifted = liftableSyms.map(symToTree).toList.map { + case vd@ValDef(_, _, tpt, rhs) => + import reflect.internal.Flags._ + val sym = vd.symbol + sym.setFlag(MUTABLE | STABLE | PRIVATE | LOCAL) + sym.name = name.fresh(sym.name.toTermName) + sym.modifyInfo(_.deconst) + ValDef(vd.symbol, gen.mkZero(vd.symbol.info)).setPos(vd.pos) + case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) => + import reflect.internal.Flags._ + val sym = dd.symbol + sym.name = this.name.fresh(sym.name.toTermName) + sym.setFlag(PRIVATE | LOCAL) + DefDef(dd.symbol, rhs).setPos(dd.pos) + case cd@ClassDef(_, _, _, impl) => + import reflect.internal.Flags._ + val sym = cd.symbol + sym.name = newTypeName(name.fresh(sym.name.toString).toString) + companionship.companionOf(cd.symbol) match { + case NoSymbol => + case moduleSymbol => + moduleSymbol.name = sym.name.toTermName + moduleSymbol.moduleClass.name = moduleSymbol.name.toTypeName + } + ClassDef(cd.symbol, impl).setPos(cd.pos) + case md@ModuleDef(_, _, impl) => + import reflect.internal.Flags._ + val sym = md.symbol + companionship.companionOf(md.symbol) match { + case NoSymbol => + sym.name = name.fresh(sym.name.toTermName) + sym.moduleClass.name = sym.name.toTypeName + case classSymbol => // will be renamed by `case ClassDef` above. + } + ModuleDef(md.symbol, impl).setPos(md.pos) + case td@TypeDef(_, _, _, rhs) => + import reflect.internal.Flags._ + val sym = td.symbol + sym.name = newTypeName(name.fresh(sym.name.toString).toString) + TypeDef(td.symbol, rhs).setPos(td.pos) + } + lifted + } +} diff --git a/src/main/scala/scala/async/internal/StateAssigner.scala b/src/main/scala/scala/async/internal/StateAssigner.scala new file mode 100644 index 0000000..cdde7a4 --- /dev/null +++ b/src/main/scala/scala/async/internal/StateAssigner.scala @@ -0,0 +1,14 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async.internal + +private[async] final class StateAssigner { + private var current = -1 + + def nextState(): Int = { + current += 1 + current + } +} diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala new file mode 100644 index 0000000..2582c91 --- /dev/null +++ b/src/main/scala/scala/async/internal/TransformUtils.scala @@ -0,0 +1,251 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ +package scala.async.internal + +import scala.reflect.macros.Context +import reflect.ClassTag +import scala.reflect.macros.runtime.AbortMacroException + +/** + * Utilities used in both `ExprBuilder` and `AnfTransform`. + */ +private[async] trait TransformUtils { + self: AsyncMacro => + + import global._ + + object name { + val resume = newTermName("resume") + val apply = newTermName("apply") + val matchRes = "matchres" + val ifRes = "ifres" + val await = "await" + val bindSuffix = "$bind" + + val state = newTermName("state") + val result = newTermName("result") + val execContext = newTermName("execContext") + val stateMachine = newTermName(fresh("stateMachine")) + val stateMachineT = stateMachine.toTypeName + val tr = newTermName("tr") + val t = newTermName("throwable") + + def fresh(name: TermName): TermName = newTermName(fresh(name.toString)) + + def fresh(name: String): String = currentUnit.freshTermName("" + name + "$").toString + } + + def isAwait(fun: Tree) = + fun.symbol == defn.Async_await + + private lazy val Boolean_ShortCircuits: Set[Symbol] = { + import definitions.BooleanClass + def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName) + val Boolean_&& = BooleanTermMember("&&") + val Boolean_|| = BooleanTermMember("||") + Set(Boolean_&&, Boolean_||) + } + + private def isByName(fun: Tree): ((Int, Int) => Boolean) = { + if (Boolean_ShortCircuits contains fun.symbol) (i, j) => true + else { + val paramss = fun.tpe.paramss + val byNamess = paramss.map(_.map(_.isByNameParam)) + (i, j) => util.Try(byNamess(i)(j)).getOrElse(false) + } + } + private def argName(fun: Tree): ((Int, Int) => String) = { + val paramss = fun.tpe.paramss + val namess = paramss.map(_.map(_.name.toString)) + (i, j) => util.Try(namess(i)(j)).getOrElse(s"arg_${i}_${j}") + } + + def Expr[A: WeakTypeTag](t: Tree) = global.Expr[A](rootMirror, new FixedMirrorTreeCreator(rootMirror, t)) + + object defn { + def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = { + Expr(Apply(Ident(definitions.List_apply), args.map(_.tree))) + } + + def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify { + self.splice.contains(elem.splice) + } + + def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify { + self.splice.apply(arg.splice) + } + + def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify { + self.splice == other.splice + } + + def mkTry_get[A](self: Expr[util.Try[A]]) = reify { + self.splice.get + } + + val TryClass = rootMirror.staticClass("scala.util.Try") + val Try_get = TryClass.typeSignature.member(newTermName("get")).ensuring(_ != NoSymbol) + val Try_isFailure = TryClass.typeSignature.member(newTermName("isFailure")).ensuring(_ != NoSymbol) + val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe)) + val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal") + val AsyncClass = rootMirror.staticClass("scala.async.internal.AsyncBase") + val Async_await = AsyncClass.typeSignature.member(newTermName("await")).ensuring(_ != NoSymbol) + } + + def isSafeToInline(tree: Tree) = { + treeInfo.isExprSafeToInline(tree) + } + + /** Map a list of arguments to: + * - A list of argument Trees + * - A list of auxillary results. + * + * The function unwraps and rewraps the `arg :_*` construct. + * + * @param args The original argument trees + * @param f A function from argument (with '_*' unwrapped) and argument index to argument. + * @tparam A The type of the auxillary result + */ + private def mapArguments[A](args: List[Tree])(f: (Tree, Int) => (A, Tree)): (List[A], List[Tree]) = { + args match { + case args :+ Typed(tree, Ident(tpnme.WILDCARD_STAR)) => + val (a, argExprs :+ lastArgExpr) = (args :+ tree).zipWithIndex.map(f.tupled).unzip + val exprs = argExprs :+ Typed(lastArgExpr, Ident(tpnme.WILDCARD_STAR)).setPos(lastArgExpr.pos) + (a, exprs) + case args => + args.zipWithIndex.map(f.tupled).unzip + } + } + + case class Arg(expr: Tree, isByName: Boolean, argName: String) + + /** + * Transform a list of argument lists, producing the transformed lists, and lists of auxillary + * results. + * + * The function `f` need not concern itself with varargs arguments e.g (`xs : _*`). It will + * receive `xs`, and it's result will be re-wrapped as `f(xs) : _*`. + * + * @param fun The function being applied + * @param argss The argument lists + * @return (auxillary results, mapped argument trees) + */ + def mapArgumentss[A](fun: Tree, argss: List[List[Tree]])(f: Arg => (A, Tree)): (List[List[A]], List[List[Tree]]) = { + val isByNamess: (Int, Int) => Boolean = isByName(fun) + val argNamess: (Int, Int) => String = argName(fun) + argss.zipWithIndex.map { case (args, i) => + mapArguments[A](args) { + (tree, j) => f(Arg(tree, isByNamess(i, j), argNamess(i, j))) + } + }.unzip + } + + + def statsAndExpr(tree: Tree): (List[Tree], Tree) = tree match { + case Block(stats, expr) => (stats, expr) + case _ => (List(tree), Literal(Constant(()))) + } + + def emptyConstructor: DefDef = { + val emptySuperCall = Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), Nil) + DefDef(NoMods, nme.CONSTRUCTOR, List(), List(List()), TypeTree(), Block(List(emptySuperCall), Literal(Constant(())))) + } + + def applied(className: String, types: List[Type]): AppliedTypeTree = + AppliedTypeTree(Ident(rootMirror.staticClass(className)), types.map(TypeTree(_))) + + /** Descends into the regions of the tree that are subject to the + * translation to a state machine by `async`. When a nested template, + * function, or by-name argument is encountered, the descent stops, + * and `nestedClass` etc are invoked. + */ + trait AsyncTraverser extends Traverser { + def nestedClass(classDef: ClassDef) { + } + + def nestedModule(module: ModuleDef) { + } + + def nestedMethod(module: DefDef) { + } + + def byNameArgument(arg: Tree) { + } + + def function(function: Function) { + } + + def patMatFunction(tree: Match) { + } + + override def traverse(tree: Tree) { + tree match { + case cd: ClassDef => nestedClass(cd) + case md: ModuleDef => nestedModule(md) + case dd: DefDef => nestedMethod(dd) + case fun: Function => function(fun) + case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions` + case treeInfo.Applied(fun, targs, argss) if argss.nonEmpty => + val isInByName = isByName(fun) + for ((args, i) <- argss.zipWithIndex) { + for ((arg, j) <- args.zipWithIndex) { + if (!isInByName(i, j)) traverse(arg) + else byNameArgument(arg) + } + } + traverse(fun) + case _ => super.traverse(tree) + } + } + } + + def abort(pos: Position, msg: String) = throw new AbortMacroException(pos, msg) + + abstract class MacroTypingTransformer extends TypingTransformer(callSiteTyper.context.unit) { + currentOwner = callSiteTyper.context.owner + + def currOwner: Symbol = currentOwner + + localTyper = global.analyzer.newTyper(callSiteTyper.context.make(unit = callSiteTyper.context.unit)) + } + + def transformAt(tree: Tree)(f: PartialFunction[Tree, (analyzer.Context => Tree)]) = { + object trans extends MacroTypingTransformer { + override def transform(tree: Tree): Tree = { + if (f.isDefinedAt(tree)) { + f(tree)(localTyper.context) + } else super.transform(tree) + } + } + trans.transform(tree) + } + + def changeOwner(tree: Tree, oldOwner: Symbol, newOwner: Symbol): tree.type = { + new ChangeOwnerAndModuleClassTraverser(oldOwner, newOwner).traverse(tree) + tree + } + + class ChangeOwnerAndModuleClassTraverser(oldowner: Symbol, newowner: Symbol) + extends ChangeOwnerTraverser(oldowner, newowner) { + + override def traverse(tree: Tree) { + tree match { + case _: DefTree => change(tree.symbol.moduleClass) + case _ => + } + super.traverse(tree) + } + } + + def toMultiMap[A, B](as: Iterable[(A, B)]): Map[A, List[B]] = + as.toList.groupBy(_._1).mapValues(_.map(_._2).toList).toMap + + // Attributed version of `TreeGen#mkCastPreservingAnnotations` + def mkAttributedCastPreservingAnnotations(tree: Tree, tp: Type): Tree = { + atPos(tree.pos) { + val casted = gen.mkAttributedCast(tree, tp.withoutAnnotations.dealias) + Typed(casted, TypeTree(tp)).setType(tp) + } + } +} -- cgit v1.2.3