diff options
author | Jason Zaugg <jzaugg@gmail.com> | 2013-07-07 10:48:11 +1000 |
---|---|---|
committer | Jason Zaugg <jzaugg@gmail.com> | 2013-07-07 10:48:11 +1000 |
commit | 2d8506a64392cd7192b6831c38798cc9a7c8bfed (patch) | |
tree | 84eafcf1a9a179eeaa97dd1e3595c18351b2b814 /src/main/scala/scala/async/internal | |
parent | c60c38ca6098402f7a9cc6d6746b664bb2b1306c (diff) | |
download | scala-async-2d8506a64392cd7192b6831c38798cc9a7c8bfed.tar.gz scala-async-2d8506a64392cd7192b6831c38798cc9a7c8bfed.tar.bz2 scala-async-2d8506a64392cd7192b6831c38798cc9a7c8bfed.zip |
Move implementation details to scala.async.internal._.
If we intend to keep CPS fallback around for any length of time
it should probably move there too.
Diffstat (limited to 'src/main/scala/scala/async/internal')
-rw-r--r-- | src/main/scala/scala/async/internal/AnfTransform.scala | 253 | ||||
-rw-r--r-- | src/main/scala/scala/async/internal/AsyncAnalysis.scala | 91 | ||||
-rw-r--r-- | src/main/scala/scala/async/internal/AsyncBase.scala | 58 | ||||
-rw-r--r-- | src/main/scala/scala/async/internal/AsyncId.scala | 64 | ||||
-rw-r--r-- | src/main/scala/scala/async/internal/AsyncMacro.scala | 29 | ||||
-rw-r--r-- | src/main/scala/scala/async/internal/AsyncTransform.scala | 176 | ||||
-rw-r--r-- | src/main/scala/scala/async/internal/AsyncUtils.scala | 16 | ||||
-rw-r--r-- | src/main/scala/scala/async/internal/ExprBuilder.scala | 388 | ||||
-rw-r--r-- | src/main/scala/scala/async/internal/FutureSystem.scala | 106 | ||||
-rw-r--r-- | src/main/scala/scala/async/internal/Lifter.scala | 150 | ||||
-rw-r--r-- | src/main/scala/scala/async/internal/StateAssigner.scala | 14 | ||||
-rw-r--r-- | src/main/scala/scala/async/internal/TransformUtils.scala | 251 |
12 files changed, 1596 insertions, 0 deletions
diff --git a/src/main/scala/scala/async/internal/AnfTransform.scala b/src/main/scala/scala/async/internal/AnfTransform.scala new file mode 100644 index 0000000..80f8161 --- /dev/null +++ b/src/main/scala/scala/async/internal/AnfTransform.scala @@ -0,0 +1,253 @@ + +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + +package scala.async.internal + +import scala.tools.nsc.Global +import scala.Predef._ + +private[async] trait AnfTransform { + self: AsyncMacro => + + import global._ + import reflect.internal.Flags._ + + def anfTransform(tree: Tree): Block = { + // Must prepend the () for issue #31. + val block = callSiteTyper.typedPos(tree.pos)(Block(List(Literal(Constant(()))), tree)).setType(tree.tpe) + + new SelectiveAnfTransform().transform(block) + } + + sealed abstract class AnfMode + + case object Anf extends AnfMode + + case object Linearizing extends AnfMode + + final class SelectiveAnfTransform extends MacroTypingTransformer { + var mode: AnfMode = Anf + + def blockToList(tree: Tree): List[Tree] = tree match { + case Block(stats, expr) => stats :+ expr + case t => t :: Nil + } + + def listToBlock(trees: List[Tree]): Block = trees match { + case trees @ (init :+ last) => + val pos = trees.map(_.pos).reduceLeft(_ union _) + Block(init, last).setType(last.tpe).setPos(pos) + } + + override def transform(tree: Tree): Block = { + def anfLinearize: Block = { + val trees: List[Tree] = mode match { + case Anf => anf._transformToList(tree) + case Linearizing => linearize._transformToList(tree) + } + listToBlock(trees) + } + tree match { + case _: ValDef | _: DefDef | _: Function | _: ClassDef | _: TypeDef => + atOwner(tree.symbol)(anfLinearize) + case _: ModuleDef => + atOwner(tree.symbol.moduleClass orElse tree.symbol)(anfLinearize) + case _ => + anfLinearize + } + } + + private object linearize { + def transformToList(tree: Tree): List[Tree] = { + mode = Linearizing; blockToList(transform(tree)) + } + + def transformToBlock(tree: Tree): Block = listToBlock(transformToList(tree)) + + def _transformToList(tree: Tree): List[Tree] = trace(tree) { + val stats :+ expr = anf.transformToList(tree) + expr match { + case Apply(fun, args) if isAwait(fun) => + val valDef = defineVal(name.await, expr, tree.pos) + stats :+ valDef :+ gen.mkAttributedStableRef(valDef.symbol) + + case If(cond, thenp, elsep) => + // if type of if-else is Unit don't introduce assignment, + // but add Unit value to bring it into form expected by async transform + if (expr.tpe =:= definitions.UnitTpe) { + stats :+ expr :+ localTyper.typedPos(expr.pos)(Literal(Constant(()))) + } else { + val varDef = defineVar(name.ifRes, expr.tpe, tree.pos) + def branchWithAssign(orig: Tree) = localTyper.typedPos(orig.pos) { + def cast(t: Tree) = mkAttributedCastPreservingAnnotations(t, varDef.symbol.tpe) + orig match { + case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr))) + case _ => Assign(Ident(varDef.symbol), cast(orig)) + } + }.setType(orig.tpe) + val ifWithAssign = treeCopy.If(tree, cond, branchWithAssign(thenp), branchWithAssign(elsep)) + stats :+ varDef :+ ifWithAssign :+ gen.mkAttributedStableRef(varDef.symbol) + } + + case Match(scrut, cases) => + // if type of match is Unit don't introduce assignment, + // but add Unit value to bring it into form expected by async transform + if (expr.tpe =:= definitions.UnitTpe) { + stats :+ expr :+ localTyper.typedPos(expr.pos)(Literal(Constant(()))) + } + else { + val varDef = defineVar(name.matchRes, expr.tpe, tree.pos) + def typedAssign(lhs: Tree) = + localTyper.typedPos(lhs.pos)(Assign(Ident(varDef.symbol), mkAttributedCastPreservingAnnotations(lhs, varDef.symbol.tpe))) + val casesWithAssign = cases map { + case cd@CaseDef(pat, guard, body) => + val newBody = body match { + case b@Block(caseStats, caseExpr) => treeCopy.Block(b, caseStats, typedAssign(caseExpr)) + case _ => typedAssign(body) + } + treeCopy.CaseDef(cd, pat, guard, newBody) + } + val matchWithAssign = treeCopy.Match(tree, scrut, casesWithAssign) + require(matchWithAssign.tpe != null, matchWithAssign) + stats :+ varDef :+ matchWithAssign :+ gen.mkAttributedStableRef(varDef.symbol) + } + case _ => + stats :+ expr + } + } + + private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = { + val sym = currOwner.newTermSymbol(name.fresh(prefix), pos, MUTABLE | SYNTHETIC).setInfo(tp) + ValDef(sym, gen.mkZero(tp)).setType(NoType).setPos(pos) + } + } + + private object trace { + private var indent = -1 + + def indentString = " " * indent + + def apply[T](args: Any)(t: => T): T = { + def prefix = mode.toString.toLowerCase + indent += 1 + def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127) + try { + AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})") + val result = t + AsyncUtils.trace(s"${indentString}= ${oneLine(result)}") + result + } finally { + indent -= 1 + } + } + } + + private def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = { + val sym = currOwner.newTermSymbol(name.fresh(prefix), pos, SYNTHETIC).setInfo(lhs.tpe) + changeOwner(lhs, currentOwner, sym) + ValDef(sym, changeOwner(lhs, currentOwner, sym)).setType(NoType).setPos(pos) + } + + private object anf { + def transformToList(tree: Tree): List[Tree] = { + mode = Anf; blockToList(transform(tree)) + } + + def _transformToList(tree: Tree): List[Tree] = trace(tree) { + val containsAwait = tree exists isAwait + if (!containsAwait) { + List(tree) + } else tree match { + case Select(qual, sel) => + val stats :+ expr = linearize.transformToList(qual) + stats :+ treeCopy.Select(tree, expr, sel) + + case treeInfo.Applied(fun, targs, argss) if argss.nonEmpty => + // we an assume that no await call appears in a by-name argument position, + // this has already been checked. + val funStats :+ simpleFun = linearize.transformToList(fun) + def isAwaitRef(name: Name) = name.toString.startsWith(AnfTransform.this.name.await + "$") + val (argStatss, argExprss): (List[List[List[Tree]]], List[List[Tree]]) = + mapArgumentss[List[Tree]](fun, argss) { + case Arg(expr, byName, _) if byName /*|| isPure(expr) TODO */ => (Nil, expr) + case Arg(expr@Ident(name), _, _) if isAwaitRef(name) => (Nil, expr) // TODO needed? // not typed, so it eludes the check in `isSafeToInline` + case Arg(expr, _, argName) => + linearize.transformToList(expr) match { + case stats :+ expr1 => + val valDef = defineVal(argName, expr1, expr1.pos) + require(valDef.tpe != null, valDef) + val stats1 = stats :+ valDef + //stats1.foreach(changeOwner(_, currentOwner, currentOwner.owner)) + (stats1, gen.stabilize(gen.mkAttributedIdent(valDef.symbol))) + } + } + val applied = treeInfo.dissectApplied(tree) + val core = if (targs.isEmpty) simpleFun else treeCopy.TypeApply(applied.callee, simpleFun, targs) + val newApply = argExprss.foldLeft(core)(Apply(_, _)).setSymbol(tree.symbol) + val typedNewApply = localTyper.typedPos(tree.pos)(newApply).setType(tree.tpe) + funStats ++ argStatss.flatten.flatten :+ typedNewApply + case Block(stats, expr) => + (stats :+ expr).flatMap(linearize.transformToList) + + case ValDef(mods, name, tpt, rhs) => + if (rhs exists isAwait) { + val stats :+ expr = atOwner(currOwner.owner)(linearize.transformToList(rhs)) + stats.foreach(changeOwner(_, currOwner, currOwner.owner)) + stats :+ treeCopy.ValDef(tree, mods, name, tpt, expr) + } else List(tree) + + case Assign(lhs, rhs) => + val stats :+ expr = linearize.transformToList(rhs) + stats :+ treeCopy.Assign(tree, lhs, expr) + + case If(cond, thenp, elsep) => + val condStats :+ condExpr = linearize.transformToList(cond) + val thenBlock = linearize.transformToBlock(thenp) + val elseBlock = linearize.transformToBlock(elsep) + // Typechecking with `condExpr` as the condition fails if the condition + // contains an await. `ifTree.setType(tree.tpe)` also fails; it seems + // we rely on this call to `typeCheck` descending into the branches. + // But, we can get away with typechecking a throwaway `If` tree with the + // original scrutinee and the new branches, and setting that type on + // the real `If` tree. + val iff = treeCopy.If(tree, condExpr, thenBlock, elseBlock) + condStats :+ iff + + case Match(scrut, cases) => + val scrutStats :+ scrutExpr = linearize.transformToList(scrut) + val caseDefs = cases map { + case CaseDef(pat, guard, body) => + // extract local variables for all names bound in `pat`, and rewrite `body` + // to refer to these. + // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`. + val block = linearize.transformToBlock(body) + val (valDefs, mappings) = (pat collect { + case b@Bind(name, _) => + val vd = defineVal(name.toTermName + AnfTransform.this.name.bindSuffix, gen.mkAttributedStableRef(b.symbol), b.pos) + (vd, (b.symbol, vd.symbol)) + }).unzip + val (from, to) = mappings.unzip + val b@Block(stats1, expr1) = block.substituteSymbols(from, to).asInstanceOf[Block] + val newBlock = treeCopy.Block(b, valDefs ++ stats1, expr1) + treeCopy.CaseDef(tree, pat, guard, newBlock) + } + // Refer to comments the translation of `If` above. + val typedMatch = treeCopy.Match(tree, scrutExpr, caseDefs) + scrutStats :+ typedMatch + + case LabelDef(name, params, rhs) => + List(LabelDef(name, params, Block(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol)) + + case TypeApply(fun, targs) => + val funStats :+ simpleFun = linearize.transformToList(fun) + funStats :+ treeCopy.TypeApply(tree, simpleFun, targs) + + case _ => + List(tree) + } + } + } + } +} diff --git a/src/main/scala/scala/async/internal/AsyncAnalysis.scala b/src/main/scala/scala/async/internal/AsyncAnalysis.scala new file mode 100644 index 0000000..62842c9 --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncAnalysis.scala @@ -0,0 +1,91 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + +package scala.async.internal + +import scala.reflect.macros.Context +import scala.collection.mutable + +trait AsyncAnalysis { + self: AsyncMacro => + + import global._ + + /** + * Analyze the contents of an `async` block in order to: + * - Report unsupported `await` calls under nested templates, functions, by-name arguments. + * + * Must be called on the original tree, not on the ANF transformed tree. + */ + def reportUnsupportedAwaits(tree: Tree, report: Boolean): Boolean = { + val analyzer = new UnsupportedAwaitAnalyzer(report) + analyzer.traverse(tree) + analyzer.hasUnsupportedAwaits + } + + private class UnsupportedAwaitAnalyzer(report: Boolean) extends AsyncTraverser { + var hasUnsupportedAwaits = false + + override def nestedClass(classDef: ClassDef) { + val kind = if (classDef.symbol.isTrait) "trait" else "class" + reportUnsupportedAwait(classDef, s"nested ${kind}") + } + + override def nestedModule(module: ModuleDef) { + reportUnsupportedAwait(module, "nested object") + } + + override def nestedMethod(defDef: DefDef) { + reportUnsupportedAwait(defDef, "nested method") + } + + override def byNameArgument(arg: Tree) { + reportUnsupportedAwait(arg, "by-name argument") + } + + override def function(function: Function) { + reportUnsupportedAwait(function, "nested function") + } + + override def patMatFunction(tree: Match) { + reportUnsupportedAwait(tree, "nested function") + } + + override def traverse(tree: Tree) { + def containsAwait = tree exists isAwait + tree match { + case Try(_, _, _) if containsAwait => + reportUnsupportedAwait(tree, "try/catch") + super.traverse(tree) + case Return(_) => + abort(tree.pos, "return is illegal within a async block") + case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) => + // TODO lift this restriction + abort(tree.pos, "lazy vals are illegal within an async block") + case _ => + super.traverse(tree) + } + } + + /** + * @return true, if the tree contained an unsupported await. + */ + private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String): Boolean = { + val badAwaits: List[RefTree] = tree collect { + case rt: RefTree if isAwait(rt) => rt + } + badAwaits foreach { + tree => + reportError(tree.pos, s"await must not be used under a $whyUnsupported.") + } + badAwaits.nonEmpty + } + + private def reportError(pos: Position, msg: String) { + hasUnsupportedAwaits = true + if (report) + abort(pos, msg) + } + } +} diff --git a/src/main/scala/scala/async/internal/AsyncBase.scala b/src/main/scala/scala/async/internal/AsyncBase.scala new file mode 100644 index 0000000..2f7e38d --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncBase.scala @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + +package scala.async.internal + +import scala.reflect.internal.annotations.compileTimeOnly +import scala.reflect.macros.Context + +/** + * A base class for the `async` macro. Subclasses must provide: + * + * - Concrete types for a given future system + * - Tree manipulations to create and complete the equivalent of Future and Promise + * in that system. + * - The `async` macro declaration itself, and a forwarder for the macro implementation. + * (The latter is temporarily needed to workaround bug SI-6650 in the macro system) + * + * The default implementation, [[scala.async.Async]], binds the macro to `scala.concurrent._`. + */ +abstract class AsyncBase { + self => + + type FS <: FutureSystem + val futureSystem: FS + + /** + * A call to `await` must be nested in an enclosing `async` block. + * + * A call to `await` does not block the current thread, rather it is a delimiter + * used by the enclosing `async` macro. Code following the `await` + * call is executed asynchronously, when the argument of `await` has been completed. + * + * @param awaitable the future from which a value is awaited. + * @tparam T the type of that value. + * @return the value. + */ + @compileTimeOnly("`await` must be enclosed in an `async` block") + def await[T](awaitable: futureSystem.Fut[T]): T = ??? + + protected[async] def fallbackEnabled = false + + def asyncImpl[T: c.WeakTypeTag](c: Context) + (body: c.Expr[T]) + (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = { + import c.universe._ + + val asyncMacro = AsyncMacro(c, futureSystem) + + val code = asyncMacro.asyncTransform[T]( + body.tree.asInstanceOf[asyncMacro.global.Tree], + execContext.tree.asInstanceOf[asyncMacro.global.Tree], + fallbackEnabled)(implicitly[c.WeakTypeTag[T]].asInstanceOf[asyncMacro.global.WeakTypeTag[T]]) + + AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code}") + c.Expr[futureSystem.Fut[T]](code.asInstanceOf[Tree]) + } +} diff --git a/src/main/scala/scala/async/internal/AsyncId.scala b/src/main/scala/scala/async/internal/AsyncId.scala new file mode 100644 index 0000000..394f587 --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncId.scala @@ -0,0 +1,64 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + +package scala.async.internal + +import language.experimental.macros +import scala.reflect.macros.Context +import scala.reflect.internal.SymbolTable + +object AsyncId extends AsyncBase { + lazy val futureSystem = IdentityFutureSystem + type FS = IdentityFutureSystem.type + + def async[T](body: T) = macro asyncIdImpl[T] + + def asyncIdImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = asyncImpl[T](c)(body)(c.literalUnit) +} + +/** + * A trivial implementation of [[FutureSystem]] that performs computations + * on the current thread. Useful for testing. + */ +object IdentityFutureSystem extends FutureSystem { + + class Prom[A](var a: A) + + type Fut[A] = A + type ExecContext = Unit + + def mkOps(c: SymbolTable): Ops {val universe: c.type} = new Ops { + val universe: c.type = c + + import universe._ + + def execContext: Expr[ExecContext] = Expr[Unit](Literal(Constant(()))) + + def promType[A: WeakTypeTag]: Type = weakTypeOf[Prom[A]] + def execContextType: Type = weakTypeOf[Unit] + + def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { + new Prom(null.asInstanceOf[A]) + } + + def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]) = reify { + prom.splice.a + } + + def future[A: WeakTypeTag](t: Expr[A])(execContext: Expr[ExecContext]) = t + + def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U], + execContext: Expr[ExecContext]): Expr[Unit] = reify { + fun.splice.apply(util.Success(future.splice)) + Expr[Unit](Literal(Constant(()))).splice + } + + def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify { + prom.splice.a = value.splice.get + Expr[Unit](Literal(Constant(()))).splice + } + + def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = ??? + } +} diff --git a/src/main/scala/scala/async/internal/AsyncMacro.scala b/src/main/scala/scala/async/internal/AsyncMacro.scala new file mode 100644 index 0000000..6b7d031 --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncMacro.scala @@ -0,0 +1,29 @@ +package scala.async.internal + +import scala.tools.nsc.Global +import scala.tools.nsc.transform.TypingTransformers + +object AsyncMacro { + def apply(c: reflect.macros.Context, futureSystem0: FutureSystem): AsyncMacro = { + import language.reflectiveCalls + val powerContext = c.asInstanceOf[c.type {val universe: Global; val callsiteTyper: universe.analyzer.Typer}] + new AsyncMacro { + val global: powerContext.universe.type = powerContext.universe + val callSiteTyper: global.analyzer.Typer = powerContext.callsiteTyper + val futureSystem: futureSystem0.type = futureSystem0 + val futureSystemOps: futureSystem.Ops {val universe: global.type} = futureSystem0.mkOps(global) + val macroApplication: global.Tree = c.macroApplication.asInstanceOf[global.Tree] + } + } +} + +private[async] trait AsyncMacro + extends TypingTransformers + with AnfTransform with TransformUtils with Lifter + with ExprBuilder with AsyncTransform with AsyncAnalysis { + + val global: Global + val callSiteTyper: global.analyzer.Typer + val macroApplication: global.Tree + +} diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala new file mode 100644 index 0000000..bdc8664 --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncTransform.scala @@ -0,0 +1,176 @@ +package scala.async.internal + +trait AsyncTransform { + self: AsyncMacro => + + import global._ + + def asyncTransform[T](body: Tree, execContext: Tree, cpsFallbackEnabled: Boolean) + (implicit resultType: WeakTypeTag[T]): Tree = { + + reportUnsupportedAwaits(body, report = !cpsFallbackEnabled) + + // Transform to A-normal form: + // - no await calls in qualifiers or arguments, + // - if/match only used in statement position. + val anfTree: Block = anfTransform(body) + + val resumeFunTreeDummyBody = DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), Literal(Constant(()))) + + val applyDefDefDummyBody: DefDef = { + val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree))) + DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), Literal(Constant(()))) + } + + val stateMachineType = applied("scala.async.StateMachine", List(futureSystemOps.promType[T], futureSystemOps.execContextType)) + + val stateMachine: ClassDef = { + val body: List[Tree] = { + val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0))) + val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree) + val execContextValDef = ValDef(NoMods, name.execContext, TypeTree(), execContext) + + val apply0DefDef: DefDef = { + // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`. + // See SI-1247 for the the optimization that avoids creatio + DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil)) + } + List(emptyConstructor, stateVar, result, execContextValDef) ++ List(resumeFunTreeDummyBody, applyDefDefDummyBody, apply0DefDef) + } + val template = { + Template(List(stateMachineType), emptyValDef, body) + } + val t = ClassDef(NoMods, name.stateMachineT, Nil, template) + callSiteTyper.typedPos(macroApplication.pos)(Block(t :: Nil, Literal(Constant(())))) + t + } + + val asyncBlock: AsyncBlock = { + val symLookup = new SymLookup(stateMachine.symbol, applyDefDefDummyBody.vparamss.head.head.symbol) + buildAsyncBlock(anfTree, symLookup) + } + + logDiagnostics(anfTree, asyncBlock.asyncStates.map(_.toString)) + + def startStateMachine: Tree = { + val stateMachineSpliced: Tree = spliceMethodBodies( + liftables(asyncBlock.asyncStates), + stateMachine, + asyncBlock.onCompleteHandler[T], + asyncBlock.resumeFunTree[T].rhs + ) + + def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection) + + Block(List[Tree]( + stateMachineSpliced, + ValDef(NoMods, name.stateMachine, stateMachineType, Apply(Select(New(Ident(stateMachine.symbol)), nme.CONSTRUCTOR), Nil)), + futureSystemOps.spawn(Apply(selectStateMachine(name.apply), Nil), selectStateMachine(name.execContext)) + ), + futureSystemOps.promiseToFuture(Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree) + } + + val isSimple = asyncBlock.asyncStates.size == 1 + if (isSimple) + futureSystemOps.spawn(body, execContext) // generate lean code for the simple case of `async { 1 + 1 }` + else + startStateMachine + } + + def logDiagnostics(anfTree: Tree, states: Seq[String]) { + val pos = macroApplication.pos + def location = try { + pos.source.path + } catch { + case _: UnsupportedOperationException => + pos.toString + } + + AsyncUtils.vprintln(s"In file '$location':") + AsyncUtils.vprintln(s"${macroApplication}") + AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree") + states foreach (s => AsyncUtils.vprintln(s)) + } + + def spliceMethodBodies(liftables: List[Tree], tree: Tree, applyBody: Tree, + resumeBody: Tree): Tree = { + + val liftedSyms = liftables.map(_.symbol).toSet + val stateMachineClass = tree.symbol + liftedSyms.foreach { + sym => + if (sym != null) { + sym.owner = stateMachineClass + if (sym.isModule) + sym.moduleClass.owner = stateMachineClass + } + } + // Replace the ValDefs in the splicee with Assigns to the corresponding lifted + // fields. Similarly, replace references to them with references to the field. + // + // This transform will be only be run on the RHS of `def foo`. + class UseFields extends MacroTypingTransformer { + override def transform(tree: Tree): Tree = tree match { + case _ if currentOwner == stateMachineClass => + super.transform(tree) + case ValDef(_, _, _, rhs) if liftedSyms(tree.symbol) => + atOwner(currentOwner) { + val fieldSym = tree.symbol + val set = Assign(gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym), transform(rhs)) + changeOwner(set, tree.symbol, currentOwner) + localTyper.typedPos(tree.pos)(set) + } + case _: DefTree if liftedSyms(tree.symbol) => + EmptyTree + case Ident(name) if liftedSyms(tree.symbol) => + val fieldSym = tree.symbol + gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym).setType(tree.tpe) + case _ => + super.transform(tree) + } + } + + val liftablesUseFields = liftables.map { + case vd: ValDef => vd + case x => + val useField = new UseFields() + //.substituteSymbols(fromSyms, toSyms) + useField.atOwner(stateMachineClass)(useField.transform(x)) + } + + tree.children.foreach { + t => + new ChangeOwnerAndModuleClassTraverser(callSiteTyper.context.owner, tree.symbol).traverse(t) + } + val treeSubst = tree + + def fixup(dd: DefDef, body: Tree, ctx: analyzer.Context): Tree = { + val spliceeAnfFixedOwnerSyms = body + val useField = new UseFields() + val newRhs = useField.atOwner(dd.symbol)(useField.transform(spliceeAnfFixedOwnerSyms)) + val typer = global.analyzer.newTyper(ctx.make(dd, dd.symbol)) + treeCopy.DefDef(dd, dd.mods, dd.name, dd.tparams, dd.vparamss, dd.tpt, typer.typed(newRhs)) + } + + liftablesUseFields.foreach(t => if (t.symbol != null) stateMachineClass.info.decls.enter(t.symbol)) + + val result0 = transformAt(treeSubst) { + case t@Template(parents, self, stats) => + (ctx: analyzer.Context) => { + treeCopy.Template(t, parents, self, liftablesUseFields ++ stats) + } + } + val result = transformAt(result0) { + case dd@DefDef(_, name.apply, _, List(List(_)), _, _) if dd.symbol.owner == stateMachineClass => + (ctx: analyzer.Context) => + val typedTree = fixup(dd, changeOwner(applyBody, callSiteTyper.context.owner, dd.symbol), ctx) + typedTree + case dd@DefDef(_, name.resume, _, _, _, _) if dd.symbol.owner == stateMachineClass => + (ctx: analyzer.Context) => + val changed = changeOwner(resumeBody, callSiteTyper.context.owner, dd.symbol) + val res = fixup(dd, changed, ctx) + res + } + result + } +} diff --git a/src/main/scala/scala/async/internal/AsyncUtils.scala b/src/main/scala/scala/async/internal/AsyncUtils.scala new file mode 100644 index 0000000..8700bd6 --- /dev/null +++ b/src/main/scala/scala/async/internal/AsyncUtils.scala @@ -0,0 +1,16 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ +package scala.async.internal + +object AsyncUtils { + + private def enabled(level: String) = sys.props.getOrElse(s"scala.async.$level", "false").equalsIgnoreCase("true") + + private def verbose = enabled("debug") + private def trace = enabled("trace") + + private[async] def vprintln(s: => Any): Unit = if (verbose) println(s"[async] $s") + + private[async] def trace(s: => Any): Unit = if (trace) println(s"[async] $s") +} diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala new file mode 100644 index 0000000..1ce30e6 --- /dev/null +++ b/src/main/scala/scala/async/internal/ExprBuilder.scala @@ -0,0 +1,388 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ +package scala.async.internal + +import scala.reflect.macros.Context +import scala.collection.mutable.ListBuffer +import collection.mutable +import language.existentials +import scala.reflect.api.Universe +import scala.reflect.api +import scala.Some + +trait ExprBuilder { + builder: AsyncMacro => + + import global._ + import defn._ + + val futureSystem: FutureSystem + val futureSystemOps: futureSystem.Ops { val universe: global.type } + + val stateAssigner = new StateAssigner + val labelDefStates = collection.mutable.Map[Symbol, Int]() + + trait AsyncState { + def state: Int + + def mkHandlerCaseForState: CaseDef + + def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = None + + def stats: List[Tree] + + final def allStats: List[Tree] = this match { + case a: AsyncStateWithAwait => stats :+ a.awaitable.resultValDef + case _ => stats + } + + final def body: Tree = stats match { + case stat :: Nil => stat + case init :+ last => Block(init, last) + } + } + + /** A sequence of statements the concludes with a unconditional transition to `nextState` */ + final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int, symLookup: SymLookup) + extends AsyncState { + + def mkHandlerCaseForState: CaseDef = + mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup) :+ mkResumeApply(symLookup)) + + override val toString: String = + s"AsyncState #$state, next = $nextState" + } + + /** A sequence of statements with a conditional transition to the next state, which will represent + * a branch of an `if` or a `match`. + */ + final class AsyncStateWithoutAwait(val stats: List[Tree], val state: Int) extends AsyncState { + override def mkHandlerCaseForState: CaseDef = + mkHandlerCase(state, stats) + + override val toString: String = + s"AsyncStateWithoutAwait #$state" + } + + /** A sequence of statements that concludes with an `await` call. The `onComplete` + * handler will unconditionally transition to `nestState`.`` + */ + final class AsyncStateWithAwait(val stats: List[Tree], val state: Int, nextState: Int, + val awaitable: Awaitable, symLookup: SymLookup) + extends AsyncState { + + override def mkHandlerCaseForState: CaseDef = { + val callOnComplete = futureSystemOps.onComplete(Expr(awaitable.expr), + Expr(This(tpnme.EMPTY)), Expr(Ident(name.execContext))).tree + mkHandlerCase(state, stats :+ callOnComplete) + } + + override def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = { + val tryGetTree = + Assign( + Ident(awaitable.resultName), + TypeApply(Select(Select(Ident(symLookup.applyTrParam), Try_get), newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType))) + ) + + /* if (tr.isFailure) + * result.complete(tr.asInstanceOf[Try[T]]) + * else { + * <resultName> = tr.get.asInstanceOf[<resultType>] + * <nextState> + * <mkResumeApply> + * } + */ + val ifIsFailureTree = + If(Select(Ident(symLookup.applyTrParam), Try_isFailure), + futureSystemOps.completeProm[T]( + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), + Expr[scala.util.Try[T]]( + TypeApply(Select(Ident(symLookup.applyTrParam), newTermName("asInstanceOf")), + List(TypeTree(weakTypeOf[scala.util.Try[T]]))))).tree, + Block(List(tryGetTree, mkStateTree(nextState, symLookup)), mkResumeApply(symLookup)) + ) + + Some(mkHandlerCase(state, List(ifIsFailureTree))) + } + + override val toString: String = + s"AsyncStateWithAwait #$state, next = $nextState" + } + + /* + * Builder for a single state of an async method. + */ + final class AsyncStateBuilder(state: Int, private val symLookup: SymLookup) { + /* Statements preceding an await call. */ + private val stats = ListBuffer[Tree]() + /** The state of the target of a LabelDef application (while loop jump) */ + private var nextJumpState: Option[Int] = None + + def +=(stat: Tree): this.type = { + assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat") + def addStat() = stats += stat + stat match { + case Apply(fun, Nil) => + labelDefStates get fun.symbol match { + case Some(nextState) => nextJumpState = Some(nextState) + case None => addStat() + } + case _ => addStat() + } + this + } + + def resultWithAwait(awaitable: Awaitable, + nextState: Int): AsyncState = { + val effectiveNextState = nextJumpState.getOrElse(nextState) + new AsyncStateWithAwait(stats.toList, state, effectiveNextState, awaitable, symLookup) + } + + def resultSimple(nextState: Int): AsyncState = { + val effectiveNextState = nextJumpState.getOrElse(nextState) + new SimpleAsyncState(stats.toList, state, effectiveNextState, symLookup) + } + + def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = { + def mkBranch(state: Int) = Block(mkStateTree(state, symLookup) :: Nil, mkResumeApply(symLookup)) + this += If(condTree, mkBranch(thenState), mkBranch(elseState)) + new AsyncStateWithoutAwait(stats.toList, state) + } + + /** + * Build `AsyncState` ending with a match expression. + * + * The cases of the match simply resume at the state of their corresponding right-hand side. + * + * @param scrutTree tree of the scrutinee + * @param cases list of case definitions + * @param caseStates starting state of the right-hand side of the each case + * @return an `AsyncState` representing the match expression + */ + def resultWithMatch(scrutTree: Tree, cases: List[CaseDef], caseStates: List[Int], symLookup: SymLookup): AsyncState = { + // 1. build list of changed cases + val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match { + case CaseDef(pat, guard, rhs) => + val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal) + CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num), symLookup), mkResumeApply(symLookup))) + } + // 2. insert changed match tree at the end of the current state + this += Match(scrutTree, newCases) + new AsyncStateWithoutAwait(stats.toList, state) + } + + def resultWithLabel(startLabelState: Int, symLookup: SymLookup): AsyncState = { + this += Block(mkStateTree(startLabelState, symLookup) :: Nil, mkResumeApply(symLookup)) + new AsyncStateWithoutAwait(stats.toList, state) + } + + override def toString: String = { + val statsBeforeAwait = stats.mkString("\n") + s"ASYNC STATE:\n$statsBeforeAwait" + } + } + + /** + * An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`). + * + * @param stats a list of expressions + * @param expr the last expression of the block + * @param startState the start state + * @param endState the state to continue with + */ + final private class AsyncBlockBuilder(stats: List[Tree], expr: Tree, startState: Int, endState: Int, + private val symLookup: SymLookup) { + val asyncStates = ListBuffer[AsyncState]() + + var stateBuilder = new AsyncStateBuilder(startState, symLookup) + var currState = startState + + /* TODO Fall back to CPS plug-in if tree contains an `await` call. */ + def checkForUnsupportedAwait(tree: Tree) = if (tree exists { + case Apply(fun, _) if isAwait(fun) => true + case _ => false + }) abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException + + def nestedBlockBuilder(nestedTree: Tree, startState: Int, endState: Int) = { + val (nestedStats, nestedExpr) = statsAndExpr(nestedTree) + new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, symLookup) + } + + import stateAssigner.nextState + + // populate asyncStates + for (stat <- stats) stat match { + // the val name = await(..) pattern + case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => + val afterAwaitState = nextState() + val awaitable = Awaitable(arg, stat.symbol, tpt.tpe, vd) + asyncStates += stateBuilder.resultWithAwait(awaitable, afterAwaitState) // complete with await + currState = afterAwaitState + stateBuilder = new AsyncStateBuilder(currState, symLookup) + + case If(cond, thenp, elsep) if stat exists isAwait => + checkForUnsupportedAwait(cond) + + val thenStartState = nextState() + val elseStartState = nextState() + val afterIfState = nextState() + + asyncStates += + // the two Int arguments are the start state of the then branch and the else branch, respectively + stateBuilder.resultWithIf(cond, thenStartState, elseStartState) + + List((thenp, thenStartState), (elsep, elseStartState)) foreach { + case (branchTree, state) => + val builder = nestedBlockBuilder(branchTree, state, afterIfState) + asyncStates ++= builder.asyncStates + } + + currState = afterIfState + stateBuilder = new AsyncStateBuilder(currState, symLookup) + + case Match(scrutinee, cases) if stat exists isAwait => + checkForUnsupportedAwait(scrutinee) + + val caseStates = cases.map(_ => nextState()) + val afterMatchState = nextState() + + asyncStates += + stateBuilder.resultWithMatch(scrutinee, cases, caseStates, symLookup) + + for ((cas, num) <- cases.zipWithIndex) { + val (stats, expr) = statsAndExpr(cas.body) + val stats1 = stats.dropWhile(isSyntheticBindVal) + val builder = nestedBlockBuilder(Block(stats1, expr), caseStates(num), afterMatchState) + asyncStates ++= builder.asyncStates + } + + currState = afterMatchState + stateBuilder = new AsyncStateBuilder(currState, symLookup) + + case ld@LabelDef(name, params, rhs) if rhs exists isAwait => + val startLabelState = nextState() + val afterLabelState = nextState() + asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup) + labelDefStates(ld.symbol) = startLabelState + val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState) + asyncStates ++= builder.asyncStates + + currState = afterLabelState + stateBuilder = new AsyncStateBuilder(currState, symLookup) + case _ => + checkForUnsupportedAwait(stat) + stateBuilder += stat + } + // complete last state builder (representing the expressions after the last await) + stateBuilder += expr + val lastState = stateBuilder.resultSimple(endState) + asyncStates += lastState + } + + trait AsyncBlock { + def asyncStates: List[AsyncState] + + def onCompleteHandler[T: WeakTypeTag]: Tree + + def resumeFunTree[T]: DefDef + } + + case class SymLookup(stateMachineClass: Symbol, applyTrParam: Symbol) { + def stateMachineMember(name: TermName): Symbol = + stateMachineClass.info.member(name) + def memberRef(name: TermName) = gen.mkAttributedRef(stateMachineMember(name)) + } + + def buildAsyncBlock(block: Block, symLookup: SymLookup): AsyncBlock = { + val Block(stats, expr) = block + val startState = stateAssigner.nextState() + val endState = Int.MaxValue + + val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, symLookup) + + new AsyncBlock { + def asyncStates = blockBuilder.asyncStates.toList + + def mkCombinedHandlerCases[T]: List[CaseDef] = { + val caseForLastState: CaseDef = { + val lastState = asyncStates.last + val lastStateBody = Expr[T](lastState.body) + val rhs = futureSystemOps.completeProm( + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), reify(scala.util.Success(lastStateBody.splice))) + mkHandlerCase(lastState.state, rhs.tree) + } + asyncStates.toList match { + case s :: Nil => + List(caseForLastState) + case _ => + val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState + initCases :+ caseForLastState + } + } + + val initStates = asyncStates.init + + /** + * def resume(): Unit = { + * try { + * state match { + * case 0 => { + * f11 = exprReturningFuture + * f11.onComplete(onCompleteHandler)(context) + * } + * ... + * } + * } catch { + * case NonFatal(t) => result.failure(t) + * } + * } + */ + def resumeFunTree[T]: DefDef = + DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), + Try( + Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T]), + List( + CaseDef( + Bind(name.t, Ident(nme.WILDCARD)), + Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), + Block(List({ + val t = Expr[Throwable](Ident(name.t)) + futureSystemOps.completeProm[T]( + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), reify(scala.util.Failure(t.splice))).tree + }), literalUnit))), EmptyTree)) + + /** + * // assumes tr: Try[Any] is in scope. + * // + * state match { + * case 0 => { + * x11 = tr.get.asInstanceOf[Double]; + * state = 1; + * resume() + * } + */ + def onCompleteHandler[T: WeakTypeTag]: Tree = Match(symLookup.memberRef(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList) + } + } + + private def isSyntheticBindVal(tree: Tree) = tree match { + case vd@ValDef(_, lname, _, Ident(rname)) => lname.toString.contains(name.bindSuffix) + case _ => false + } + + case class Awaitable(expr: Tree, resultName: Symbol, resultType: Type, resultValDef: ValDef) + + private def mkResumeApply(symLookup: SymLookup) = Apply(symLookup.memberRef(name.resume), Nil) + + private def mkStateTree(nextState: Int, symLookup: SymLookup): Tree = + Assign(symLookup.memberRef(name.state), Literal(Constant(nextState))) + + private def mkHandlerCase(num: Int, rhs: List[Tree]): CaseDef = + mkHandlerCase(num, Block(rhs, literalUnit)) + + private def mkHandlerCase(num: Int, rhs: Tree): CaseDef = + CaseDef(Literal(Constant(num)), EmptyTree, rhs) + + private def literalUnit = Literal(Constant(())) +} diff --git a/src/main/scala/scala/async/internal/FutureSystem.scala b/src/main/scala/scala/async/internal/FutureSystem.scala new file mode 100644 index 0000000..101b7bf --- /dev/null +++ b/src/main/scala/scala/async/internal/FutureSystem.scala @@ -0,0 +1,106 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ +package scala.async.internal + +import scala.language.higherKinds + +import scala.reflect.macros.Context +import scala.reflect.internal.SymbolTable + +/** + * An abstraction over a future system. + * + * Used by the macro implementations in [[scala.async.AsyncBase]] to + * customize the code generation. + * + * The API mirrors that of `scala.concurrent.Future`, see the instance + * [[ScalaConcurrentFutureSystem]] for an example of how + * to implement this. + */ +trait FutureSystem { + /** A container to receive the final value of the computation */ + type Prom[A] + /** A (potentially in-progress) computation */ + type Fut[A] + /** An execution context, required to create or register an on completion callback on a Future. */ + type ExecContext + + trait Ops { + val universe: reflect.internal.SymbolTable + + import universe._ + def Expr[T: WeakTypeTag](tree: Tree): Expr[T] = universe.Expr[T](rootMirror, universe.FixedMirrorTreeCreator(rootMirror, tree)) + + def promType[A: WeakTypeTag]: Type + def execContextType: Type + + /** Create an empty promise */ + def createProm[A: WeakTypeTag]: Expr[Prom[A]] + + /** Extract a future from the given promise. */ + def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]): Expr[Fut[A]] + + /** Construct a future to asynchronously compute the given expression */ + def future[A: WeakTypeTag](a: Expr[A])(execContext: Expr[ExecContext]): Expr[Fut[A]] + + /** Register an call back to run on completion of the given future */ + def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U], + execContext: Expr[ExecContext]): Expr[Unit] + + /** Complete a promise with a value */ + def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] + + def spawn(tree: Tree, execContext: Tree): Tree = + future(Expr[Unit](tree))(Expr[ExecContext](execContext)).tree + + // TODO Why is this needed? + def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] + } + + def mkOps(c: SymbolTable): Ops { val universe: c.type } +} + +object ScalaConcurrentFutureSystem extends FutureSystem { + + import scala.concurrent._ + + type Prom[A] = Promise[A] + type Fut[A] = Future[A] + type ExecContext = ExecutionContext + + def mkOps(c: SymbolTable): Ops {val universe: c.type} = new Ops { + val universe: c.type = c + + import universe._ + + def promType[A: WeakTypeTag]: Type = weakTypeOf[Promise[A]] + def execContextType: Type = weakTypeOf[ExecutionContext] + + def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { + Promise[A]() + } + + def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]) = reify { + prom.splice.future + } + + def future[A: WeakTypeTag](a: Expr[A])(execContext: Expr[ExecContext]) = reify { + Future(a.splice)(execContext.splice) + } + + def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U], + execContext: Expr[ExecContext]): Expr[Unit] = reify { + future.splice.onComplete(fun.splice)(execContext.splice) + } + + def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify { + prom.splice.complete(value.splice) + Expr[Unit](Literal(Constant(()))).splice + } + + def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = reify { + future.splice.asInstanceOf[Fut[A]] + } + } +} diff --git a/src/main/scala/scala/async/internal/Lifter.scala b/src/main/scala/scala/async/internal/Lifter.scala new file mode 100644 index 0000000..f49dcbb --- /dev/null +++ b/src/main/scala/scala/async/internal/Lifter.scala @@ -0,0 +1,150 @@ +package scala.async.internal + +trait Lifter { + self: AsyncMacro => + import global._ + + /** + * Identify which DefTrees are used (including transitively) which are declared + * in some state but used (including transitively) in another state. + * + * These will need to be lifted to class members of the state machine. + */ + def liftables(asyncStates: List[AsyncState]): List[Tree] = { + object companionship { + private val companions = collection.mutable.Map[Symbol, Symbol]() + private val companionsInverse = collection.mutable.Map[Symbol, Symbol]() + private def record(sym1: Symbol, sym2: Symbol) { + companions(sym1) = sym2 + companions(sym2) = sym1 + } + + def record(defs: List[Tree]) { + // Keep note of local companions so we rename them consistently + // when lifting. + val comps = for { + cd@ClassDef(_, _, _, _) <- defs + md@ModuleDef(_, _, _) <- defs + if (cd.name.toTermName == md.name) + } record(cd.symbol, md.symbol) + } + def companionOf(sym: Symbol): Symbol = { + companions.get(sym).orElse(companionsInverse.get(sym)).getOrElse(NoSymbol) + } + } + + + val defs: Map[Tree, Int] = { + /** Collect the DefTrees directly enclosed within `t` that have the same owner */ + def collectDirectlyEnclosedDefs(t: Tree): List[DefTree] = t match { + case dt: DefTree => dt :: Nil + case _: Function => Nil + case t => + val childDefs = t.children.flatMap(collectDirectlyEnclosedDefs(_)) + companionship.record(childDefs) + childDefs + } + asyncStates.flatMap { + asyncState => + val defs = collectDirectlyEnclosedDefs(Block(asyncState.allStats: _*)) + defs.map((_, asyncState.state)) + }.toMap + } + + // In which block are these symbols defined? + val symToDefiningState: Map[Symbol, Int] = defs.map { + case (k, v) => (k.symbol, v) + } + + // The definitions trees + val symToTree: Map[Symbol, Tree] = defs.map { + case (k, v) => (k.symbol, k) + } + + // The direct references of each definition tree + val defSymToReferenced: Map[Symbol, List[Symbol]] = defs.keys.map { + case tree => (tree.symbol, tree.collect { + case rt: RefTree if symToDefiningState.contains(rt.symbol) => rt.symbol + }) + }.toMap + + // The direct references of each block, excluding references of `DefTree`-s which + // are already accounted for. + val stateIdToDirectlyReferenced: Map[Int, List[Symbol]] = { + val refs: List[(Int, Symbol)] = asyncStates.flatMap( + asyncState => asyncState.stats.filterNot(_.isDef).flatMap(_.collect { + case rt: RefTree if symToDefiningState.contains(rt.symbol) => (asyncState.state, rt.symbol) + }) + ) + toMultiMap(refs) + } + + def liftableSyms: Set[Symbol] = { + val liftableMutableSet = collection.mutable.Set[Symbol]() + def markForLift(sym: Symbol) { + if (!liftableMutableSet(sym)) { + liftableMutableSet += sym + + // Only mark transitive references of defs, modules and classes. The RHS of lifted vals/vars + // stays in its original location, so things that it refers to need not be lifted. + if (!(sym.isVal || sym.isVar)) + defSymToReferenced(sym).foreach(sym2 => markForLift(sym2)) + } + } + // Start things with DefTrees directly referenced from statements from other states... + val liftableStatementRefs: List[Symbol] = stateIdToDirectlyReferenced.toList.flatMap { + case (i, syms) => syms.filter(sym => symToDefiningState(sym) != i) + } + // .. and likewise for DefTrees directly referenced by other DefTrees from other states + val liftableRefsOfDefTrees = defSymToReferenced.toList.flatMap { + case (referee, referents) => referents.filter(sym => symToDefiningState(sym) != symToDefiningState(referee)) + } + // Mark these for lifting, which will follow transitive references. + (liftableStatementRefs ++ liftableRefsOfDefTrees).foreach(markForLift) + liftableMutableSet.toSet + } + + val lifted = liftableSyms.map(symToTree).toList.map { + case vd@ValDef(_, _, tpt, rhs) => + import reflect.internal.Flags._ + val sym = vd.symbol + sym.setFlag(MUTABLE | STABLE | PRIVATE | LOCAL) + sym.name = name.fresh(sym.name.toTermName) + sym.modifyInfo(_.deconst) + ValDef(vd.symbol, gen.mkZero(vd.symbol.info)).setPos(vd.pos) + case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) => + import reflect.internal.Flags._ + val sym = dd.symbol + sym.name = this.name.fresh(sym.name.toTermName) + sym.setFlag(PRIVATE | LOCAL) + DefDef(dd.symbol, rhs).setPos(dd.pos) + case cd@ClassDef(_, _, _, impl) => + import reflect.internal.Flags._ + val sym = cd.symbol + sym.name = newTypeName(name.fresh(sym.name.toString).toString) + companionship.companionOf(cd.symbol) match { + case NoSymbol => + case moduleSymbol => + moduleSymbol.name = sym.name.toTermName + moduleSymbol.moduleClass.name = moduleSymbol.name.toTypeName + } + ClassDef(cd.symbol, impl).setPos(cd.pos) + case md@ModuleDef(_, _, impl) => + import reflect.internal.Flags._ + val sym = md.symbol + companionship.companionOf(md.symbol) match { + case NoSymbol => + sym.name = name.fresh(sym.name.toTermName) + sym.moduleClass.name = sym.name.toTypeName + case classSymbol => // will be renamed by `case ClassDef` above. + } + ModuleDef(md.symbol, impl).setPos(md.pos) + case td@TypeDef(_, _, _, rhs) => + import reflect.internal.Flags._ + val sym = td.symbol + sym.name = newTypeName(name.fresh(sym.name.toString).toString) + TypeDef(td.symbol, rhs).setPos(td.pos) + } + lifted + } +} diff --git a/src/main/scala/scala/async/internal/StateAssigner.scala b/src/main/scala/scala/async/internal/StateAssigner.scala new file mode 100644 index 0000000..cdde7a4 --- /dev/null +++ b/src/main/scala/scala/async/internal/StateAssigner.scala @@ -0,0 +1,14 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + +package scala.async.internal + +private[async] final class StateAssigner { + private var current = -1 + + def nextState(): Int = { + current += 1 + current + } +} diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala new file mode 100644 index 0000000..2582c91 --- /dev/null +++ b/src/main/scala/scala/async/internal/TransformUtils.scala @@ -0,0 +1,251 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ +package scala.async.internal + +import scala.reflect.macros.Context +import reflect.ClassTag +import scala.reflect.macros.runtime.AbortMacroException + +/** + * Utilities used in both `ExprBuilder` and `AnfTransform`. + */ +private[async] trait TransformUtils { + self: AsyncMacro => + + import global._ + + object name { + val resume = newTermName("resume") + val apply = newTermName("apply") + val matchRes = "matchres" + val ifRes = "ifres" + val await = "await" + val bindSuffix = "$bind" + + val state = newTermName("state") + val result = newTermName("result") + val execContext = newTermName("execContext") + val stateMachine = newTermName(fresh("stateMachine")) + val stateMachineT = stateMachine.toTypeName + val tr = newTermName("tr") + val t = newTermName("throwable") + + def fresh(name: TermName): TermName = newTermName(fresh(name.toString)) + + def fresh(name: String): String = currentUnit.freshTermName("" + name + "$").toString + } + + def isAwait(fun: Tree) = + fun.symbol == defn.Async_await + + private lazy val Boolean_ShortCircuits: Set[Symbol] = { + import definitions.BooleanClass + def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName) + val Boolean_&& = BooleanTermMember("&&") + val Boolean_|| = BooleanTermMember("||") + Set(Boolean_&&, Boolean_||) + } + + private def isByName(fun: Tree): ((Int, Int) => Boolean) = { + if (Boolean_ShortCircuits contains fun.symbol) (i, j) => true + else { + val paramss = fun.tpe.paramss + val byNamess = paramss.map(_.map(_.isByNameParam)) + (i, j) => util.Try(byNamess(i)(j)).getOrElse(false) + } + } + private def argName(fun: Tree): ((Int, Int) => String) = { + val paramss = fun.tpe.paramss + val namess = paramss.map(_.map(_.name.toString)) + (i, j) => util.Try(namess(i)(j)).getOrElse(s"arg_${i}_${j}") + } + + def Expr[A: WeakTypeTag](t: Tree) = global.Expr[A](rootMirror, new FixedMirrorTreeCreator(rootMirror, t)) + + object defn { + def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = { + Expr(Apply(Ident(definitions.List_apply), args.map(_.tree))) + } + + def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify { + self.splice.contains(elem.splice) + } + + def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify { + self.splice.apply(arg.splice) + } + + def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify { + self.splice == other.splice + } + + def mkTry_get[A](self: Expr[util.Try[A]]) = reify { + self.splice.get + } + + val TryClass = rootMirror.staticClass("scala.util.Try") + val Try_get = TryClass.typeSignature.member(newTermName("get")).ensuring(_ != NoSymbol) + val Try_isFailure = TryClass.typeSignature.member(newTermName("isFailure")).ensuring(_ != NoSymbol) + val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe)) + val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal") + val AsyncClass = rootMirror.staticClass("scala.async.internal.AsyncBase") + val Async_await = AsyncClass.typeSignature.member(newTermName("await")).ensuring(_ != NoSymbol) + } + + def isSafeToInline(tree: Tree) = { + treeInfo.isExprSafeToInline(tree) + } + + /** Map a list of arguments to: + * - A list of argument Trees + * - A list of auxillary results. + * + * The function unwraps and rewraps the `arg :_*` construct. + * + * @param args The original argument trees + * @param f A function from argument (with '_*' unwrapped) and argument index to argument. + * @tparam A The type of the auxillary result + */ + private def mapArguments[A](args: List[Tree])(f: (Tree, Int) => (A, Tree)): (List[A], List[Tree]) = { + args match { + case args :+ Typed(tree, Ident(tpnme.WILDCARD_STAR)) => + val (a, argExprs :+ lastArgExpr) = (args :+ tree).zipWithIndex.map(f.tupled).unzip + val exprs = argExprs :+ Typed(lastArgExpr, Ident(tpnme.WILDCARD_STAR)).setPos(lastArgExpr.pos) + (a, exprs) + case args => + args.zipWithIndex.map(f.tupled).unzip + } + } + + case class Arg(expr: Tree, isByName: Boolean, argName: String) + + /** + * Transform a list of argument lists, producing the transformed lists, and lists of auxillary + * results. + * + * The function `f` need not concern itself with varargs arguments e.g (`xs : _*`). It will + * receive `xs`, and it's result will be re-wrapped as `f(xs) : _*`. + * + * @param fun The function being applied + * @param argss The argument lists + * @return (auxillary results, mapped argument trees) + */ + def mapArgumentss[A](fun: Tree, argss: List[List[Tree]])(f: Arg => (A, Tree)): (List[List[A]], List[List[Tree]]) = { + val isByNamess: (Int, Int) => Boolean = isByName(fun) + val argNamess: (Int, Int) => String = argName(fun) + argss.zipWithIndex.map { case (args, i) => + mapArguments[A](args) { + (tree, j) => f(Arg(tree, isByNamess(i, j), argNamess(i, j))) + } + }.unzip + } + + + def statsAndExpr(tree: Tree): (List[Tree], Tree) = tree match { + case Block(stats, expr) => (stats, expr) + case _ => (List(tree), Literal(Constant(()))) + } + + def emptyConstructor: DefDef = { + val emptySuperCall = Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), Nil) + DefDef(NoMods, nme.CONSTRUCTOR, List(), List(List()), TypeTree(), Block(List(emptySuperCall), Literal(Constant(())))) + } + + def applied(className: String, types: List[Type]): AppliedTypeTree = + AppliedTypeTree(Ident(rootMirror.staticClass(className)), types.map(TypeTree(_))) + + /** Descends into the regions of the tree that are subject to the + * translation to a state machine by `async`. When a nested template, + * function, or by-name argument is encountered, the descent stops, + * and `nestedClass` etc are invoked. + */ + trait AsyncTraverser extends Traverser { + def nestedClass(classDef: ClassDef) { + } + + def nestedModule(module: ModuleDef) { + } + + def nestedMethod(module: DefDef) { + } + + def byNameArgument(arg: Tree) { + } + + def function(function: Function) { + } + + def patMatFunction(tree: Match) { + } + + override def traverse(tree: Tree) { + tree match { + case cd: ClassDef => nestedClass(cd) + case md: ModuleDef => nestedModule(md) + case dd: DefDef => nestedMethod(dd) + case fun: Function => function(fun) + case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions` + case treeInfo.Applied(fun, targs, argss) if argss.nonEmpty => + val isInByName = isByName(fun) + for ((args, i) <- argss.zipWithIndex) { + for ((arg, j) <- args.zipWithIndex) { + if (!isInByName(i, j)) traverse(arg) + else byNameArgument(arg) + } + } + traverse(fun) + case _ => super.traverse(tree) + } + } + } + + def abort(pos: Position, msg: String) = throw new AbortMacroException(pos, msg) + + abstract class MacroTypingTransformer extends TypingTransformer(callSiteTyper.context.unit) { + currentOwner = callSiteTyper.context.owner + + def currOwner: Symbol = currentOwner + + localTyper = global.analyzer.newTyper(callSiteTyper.context.make(unit = callSiteTyper.context.unit)) + } + + def transformAt(tree: Tree)(f: PartialFunction[Tree, (analyzer.Context => Tree)]) = { + object trans extends MacroTypingTransformer { + override def transform(tree: Tree): Tree = { + if (f.isDefinedAt(tree)) { + f(tree)(localTyper.context) + } else super.transform(tree) + } + } + trans.transform(tree) + } + + def changeOwner(tree: Tree, oldOwner: Symbol, newOwner: Symbol): tree.type = { + new ChangeOwnerAndModuleClassTraverser(oldOwner, newOwner).traverse(tree) + tree + } + + class ChangeOwnerAndModuleClassTraverser(oldowner: Symbol, newowner: Symbol) + extends ChangeOwnerTraverser(oldowner, newowner) { + + override def traverse(tree: Tree) { + tree match { + case _: DefTree => change(tree.symbol.moduleClass) + case _ => + } + super.traverse(tree) + } + } + + def toMultiMap[A, B](as: Iterable[(A, B)]): Map[A, List[B]] = + as.toList.groupBy(_._1).mapValues(_.map(_._2).toList).toMap + + // Attributed version of `TreeGen#mkCastPreservingAnnotations` + def mkAttributedCastPreservingAnnotations(tree: Tree, tp: Type): Tree = { + atPos(tree.pos) { + val casted = gen.mkAttributedCast(tree, tp.withoutAnnotations.dealias) + Typed(casted, TypeTree(tp)).setType(tp) + } + } +} |