diff options
Diffstat (limited to 'src/main/scala/scala')
-rw-r--r-- | src/main/scala/scala/async/AnfTransform.scala | 62 | ||||
-rw-r--r-- | src/main/scala/scala/async/Async.scala | 81 | ||||
-rw-r--r-- | src/main/scala/scala/async/AsyncAnalysis.scala | 39 | ||||
-rw-r--r-- | src/main/scala/scala/async/AsyncUtils.scala | 14 | ||||
-rw-r--r-- | src/main/scala/scala/async/ExprBuilder.scala | 426 | ||||
-rw-r--r-- | src/main/scala/scala/async/FutureSystem.scala | 2 | ||||
-rw-r--r-- | src/main/scala/scala/async/StateAssigner.scala | 4 | ||||
-rw-r--r-- | src/main/scala/scala/async/TransformUtils.scala | 62 |
8 files changed, 380 insertions, 310 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index 6bf6e9a..0836634 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -1,13 +1,25 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + package scala.async import scala.reflect.macros.Context -class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { +private[async] final case class AnfTransform[C <: Context](val c: C) { import c.universe._ + val utils = TransformUtils[c.type](c) + import utils._ + + def apply(tree: Tree): List[Tree] = { + val unique = uniqueNames(tree) + // Must prepend the () for issue #31. + anf.transformToList(Block(List(c.literalUnit.tree), unique)) + } - def uniqueNames(tree: Tree): Tree = { + private def uniqueNames(tree: Tree): Tree = { new UniqueNames(tree).transform(tree) } @@ -16,7 +28,7 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { * * This step is needed to allow us to safely merge blocks during the `inline` transform below. */ - final class UniqueNames(tree: Tree) extends Transformer { + private final class UniqueNames(tree: Tree) extends Transformer { val repeatedNames: Set[Name] = tree.collect { case dt: DefTree => dt.symbol.name }.groupBy(x => x).filter(_._2.size > 1).keySet @@ -32,7 +44,7 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { val trans = super.transform(defTree) val origName = defTree.symbol.name val sym = defTree.symbol.asInstanceOf[symtab.Symbol] - val fresh = c.fresh("" + sym.name + "$") + val fresh = name.fresh(sym.name.toString) sym.name = defTree.symbol.name match { case _: TermName => symtab.newTermName(fresh) case _: TypeName => symtab.newTypeName(fresh) @@ -65,12 +77,29 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { } } - object inline { - def transformToList(tree: Tree): List[Tree] = { + private object trace { + private var indent = -1 + def indentString = " " * indent + def apply[T](prefix: String, args: Any)(t: => T): T = { + indent += 1 + def oneLine(s: Any) = s.toString.replaceAll("""\n""", "\\\\n").take(127) + try { + AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})") + val result = t + AsyncUtils.trace(s"${indentString}= ${oneLine(result)}") + result + } finally { + indent -= 1 + } + } + } + + private object inline { + def transformToList(tree: Tree): List[Tree] = trace("inline", tree) { val stats :+ expr = anf.transformToList(tree) expr match { case Apply(fun, args) if isAwait(fun) => - val valDef = defineVal("await", expr, tree.pos) + val valDef = defineVal(name.await, expr, tree.pos) stats :+ valDef :+ Ident(valDef.name) case If(cond, thenp, elsep) => @@ -79,7 +108,7 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { if (expr.tpe =:= definitions.UnitTpe) { stats :+ expr :+ Literal(Constant(())) } else { - val varDef = defineVar("ifres", expr.tpe, tree.pos) + val varDef = defineVar(name.ifRes, expr.tpe, tree.pos) def branchWithAssign(orig: Tree) = orig match { case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.name), thenExpr)) case _ => Assign(Ident(varDef.name), orig) @@ -95,7 +124,7 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { stats :+ expr :+ Literal(Constant(())) } else { - val varDef = defineVar("matchres", expr.tpe, tree.pos) + val varDef = defineVar(name.matchRes, expr.tpe, tree.pos) val casesWithAssign = cases map { case cd@CaseDef(pat, guard, Block(caseStats, caseExpr)) => attachCopy.CaseDef(cd)(pat, guard, Block(caseStats, Assign(Ident(varDef.name), caseExpr))) @@ -119,23 +148,22 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { case stats :+ expr => Block(stats, expr) } - def liftedName(prefix: String) = c.fresh(prefix + "$") - private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = { - val vd = ValDef(Modifiers(Flag.MUTABLE), liftedName(prefix), TypeTree(tp), defaultValue(tp)) + val vd = ValDef(Modifiers(Flag.MUTABLE), name.fresh(prefix), TypeTree(tp), defaultValue(tp)) vd.setPos(pos) vd } private def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = { - val vd = ValDef(NoMods, liftedName(prefix), TypeTree(), lhs) + val vd = ValDef(NoMods, name.fresh(prefix), TypeTree(), lhs) vd.setPos(pos) vd } } - object anf { - def transformToList(tree: Tree): List[Tree] = { + private object anf { + + private[AnfTransform] def transformToList(tree: Tree): List[Tree] = trace("anf", tree) { def containsAwait = tree exists isAwait tree match { case Select(qual, sel) if containsAwait => @@ -181,6 +209,9 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { } scrutStats :+ c.typeCheck(attachCopy.Match(tree)(scrutExpr, caseDefs)) + case LabelDef(name, params, rhs) if containsAwait => + List(LabelDef(name, params, Block(inline.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol)) + case TypeApply(fun, targs) if containsAwait => val funStats :+ simpleFun = inline.transformToList(fun) funStats :+ attachCopy.TypeApply(tree)(simpleFun, targs).setSymbol(tree.symbol) @@ -190,5 +221,4 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { } } } - } diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index 0b77c78..ef506a5 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -1,6 +1,7 @@ -/** +/* * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> */ + package scala.async import scala.language.experimental.macros @@ -63,13 +64,11 @@ abstract class AsyncBase { def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { import c.universe._ - import Flag._ - val builder = new ExprBuilder[c.type, futureSystem.type](c, self.futureSystem) - val anaylzer = new AsyncAnalysis[c.type](c) - - import builder.defn._ - import builder.name + val builder = ExprBuilder[c.type, futureSystem.type](c, self.futureSystem) + val anaylzer = AsyncAnalysis[c.type](c) + val utils = TransformUtils[c.type](c) + import utils.{name, defn} import builder.futureSystemOps anaylzer.reportUnsupportedAwaits(body.tree) @@ -78,12 +77,9 @@ abstract class AsyncBase { // - no await calls in qualifiers or arguments, // - if/match only used in statement position. val anfTree: Block = { - val transform = new AnfTransform[c.type](c) - val unique = transform.uniqueNames(body.tree) - val stats1 :+ expr1 = transform.anf.transformToList(unique) - + val anf = AnfTransform[c.type](c) + val stats1 :+ expr1 = anf(body.tree) val block = Block(stats1, expr1) - c.typeCheck(block).asInstanceOf[Block] } @@ -93,64 +89,21 @@ abstract class AsyncBase { val renameMap: Map[Symbol, TermName] = { anaylzer.valDefsUsedInSubsequentStates(anfTree).map { vd => - (vd.symbol, builder.name.fresh(vd.name)) + (vd.symbol, name.fresh(vd.name)) }.toMap } - val startState = builder.stateAssigner.nextState() - val endState = Int.MaxValue - - val asyncBlockBuilder = new builder.AsyncBlockBuilder(anfTree.stats, anfTree.expr, startState, endState, renameMap) - val handlerCases: List[CaseDef] = asyncBlockBuilder.mkCombinedHandlerCases[T]() - - import asyncBlockBuilder.asyncStates + val asyncBlock: builder.AsyncBlock = builder.build(anfTree, renameMap) + import asyncBlock.asyncStates logDiagnostics(c)(anfTree, asyncStates.map(_.toString)) - val initStates = asyncStates.init - val localVarTrees = asyncStates.flatMap(_.allVarDefs).toList - - /* - lazy val onCompleteHandler = (tr: Try[Any]) => state match { - case 0 => { - x11 = tr.get.asInstanceOf[Double]; - state = 1; - resume() - } - ... - */ - val onCompleteHandler = { - val onCompleteHandlers = initStates.flatMap(_.mkOnCompleteHandler()).toList - Function( - List(ValDef(Modifiers(PARAM), name.tr, TypeTree(TryAnyType), EmptyTree)), - Match(Ident(name.state), onCompleteHandlers)) - } - /* - def resume(): Unit = { - try { - state match { - case 0 => { - f11 = exprReturningFuture - f11.onComplete(onCompleteHandler)(context) - } - ... - } - } catch { - case NonFatal(t) => result.failure(t) - } - } - */ - val resumeFunTree: c.Tree = DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), - Try( - Match(Ident(name.state), handlerCases), - List( - CaseDef( - Apply(Ident(NonFatalClass), List(Bind(name.tr, Ident(nme.WILDCARD)))), - EmptyTree, - Block(List({ - val t = c.Expr[Throwable](Ident(name.tr)) - futureSystemOps.completeProm[T](c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Failure(t.splice))).tree - }), c.literalUnit.tree))), EmptyTree)) + val localVarTrees = anfTree.collect { + case vd@ValDef(_, _, tpt, _) if renameMap contains vd.symbol => + utils.mkVarDefTree(tpt.tpe, renameMap(vd.symbol)) + } + val onCompleteHandler = asyncBlock.onCompleteHandler + val resumeFunTree = asyncBlock.resumeFunTree[T] val prom: Expr[futureSystem.Prom[T]] = reify { // Create the empty promise diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala index 180e006..4f5bf8d 100644 --- a/src/main/scala/scala/async/AsyncAnalysis.scala +++ b/src/main/scala/scala/async/AsyncAnalysis.scala @@ -1,11 +1,18 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + package scala.async import scala.reflect.macros.Context import collection.mutable -private[async] final class AsyncAnalysis[C <: Context](override val c: C) extends TransformUtils(c) { +private[async] final case class AsyncAnalysis[C <: Context](val c: C) { import c.universe._ + val utils = TransformUtils[c.type](c) + import utils._ + /** * Analyze the contents of an `async` block in order to: * - Report unsupported `await` calls under nested templates, functions, by-name arguments. @@ -29,7 +36,7 @@ private[async] final class AsyncAnalysis[C <: Context](override val c: C) extend analyzer.valDefsToLift.toList } - private class UnsupportedAwaitAnalyzer extends super.AsyncTraverser { + private class UnsupportedAwaitAnalyzer extends AsyncTraverser { override def nestedClass(classDef: ClassDef) { val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class" if (!reportUnsupportedAwait(classDef, s"nested $kind")) { @@ -45,6 +52,10 @@ private[async] final class AsyncAnalysis[C <: Context](override val c: C) extend } } + override def nestedMethod(module: DefDef) { + reportUnsupportedAwait(module, "nested method") + } + override def byNameArgument(arg: Tree) { reportUnsupportedAwait(arg, "by-name argument") } @@ -53,6 +64,22 @@ private[async] final class AsyncAnalysis[C <: Context](override val c: C) extend reportUnsupportedAwait(function, "nested function") } + override def traverse(tree: Tree) { + def containsAwait = tree exists isAwait + tree match { + case Try(_, _, _) if containsAwait => + reportUnsupportedAwait(tree, "try/catch") + super.traverse(tree) + case If(cond, _, _) if containsAwait => + reportUnsupportedAwait(cond, "condition") + super.traverse(tree) + case Return(_) => + c.abort(tree.pos, "return is illegal within a async block") + case _ => + super.traverse(tree) + } + } + /** * @return true, if the tree contained an unsupported await. */ @@ -68,7 +95,7 @@ private[async] final class AsyncAnalysis[C <: Context](override val c: C) extend } } - private class AsyncDefinitionUseAnalyzer extends super.AsyncTraverser { + private class AsyncDefinitionUseAnalyzer extends AsyncTraverser { private var chunkId = 0 private def nextChunk() = chunkId += 1 @@ -83,6 +110,8 @@ private[async] final class AsyncAnalysis[C <: Context](override val c: C) extend traverseChunks(List(cond, thenp, elsep)) case Match(selector, cases) if tree exists isAwait => traverseChunks(selector :: cases) + case LabelDef(name, params, rhs) if rhs exists isAwait => + traverseChunks(rhs :: Nil) case Apply(fun, args) if isAwait(fun) => super.traverse(tree) nextChunk() @@ -92,10 +121,10 @@ private[async] final class AsyncAnalysis[C <: Context](override val c: C) extend if (isAwait(vd.rhs)) valDefsToLift += vd case as: Assign => if (isAwait(as.rhs)) { - assert(as.symbol != null, "internal error: null symbol for Assign tree:" + as) + assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol) // TODO test the orElse case, try to remove the restriction. - val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block. " + as.symbol)) + val (vd, defBlockId) = valDefChunkId.getOrElse(as.lhs.symbol, c.abort(as.pos, s"await may only be assigned to a var/val defined in the async block. ${as.lhs} ${as.lhs.symbol}")) valDefsToLift += vd } super.traverse(tree) diff --git a/src/main/scala/scala/async/AsyncUtils.scala b/src/main/scala/scala/async/AsyncUtils.scala index 77c155f..87a63d7 100644 --- a/src/main/scala/scala/async/AsyncUtils.scala +++ b/src/main/scala/scala/async/AsyncUtils.scala @@ -1,4 +1,4 @@ -/** +/* * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> */ package scala.async @@ -8,8 +8,12 @@ package scala.async */ object AsyncUtils { - private val verbose = false - - private[async] def vprintln(s: => Any): Unit = if (verbose) - println("[async] "+s) + private def enabled(level: String) = sys.props.getOrElse(s"scala.async.$level", "false").equalsIgnoreCase("true") + + private val verbose = enabled("debug") + private val trace = enabled("trace") + + private[async] def vprintln(s: => Any): Unit = if (verbose) println(s"[async] $s") + + private[async] def trace(s: => Any): Unit = if (trace) println(s"[async] $s") } diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index b7cfc8f..5ae01f9 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -1,4 +1,4 @@ -/** +/* * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> */ package scala.async @@ -10,212 +10,127 @@ import collection.mutable /* * @author Philipp Haller */ -final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val futureSystem: FS) - extends TransformUtils(c) { +private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSystem: FS) { builder => + val utils = TransformUtils[c.type](c) + import c.universe._ + import utils._ import defn._ - private[async] object name { - def suffix(string: String) = string + "$async" - - def suffixedName(prefix: String) = newTermName(suffix(prefix)) - - val state = suffixedName("state") - val result = suffixedName("result") - val resume = suffixedName("resume") - val execContext = suffixedName("execContext") - - // TODO do we need to freshen any of these? - val x1 = newTermName("x$1") - val tr = newTermName("tr") - val onCompleteHandler = suffixedName("onCompleteHandler") - - def fresh(name: TermName) = if (name.toString.contains("$")) name else newTermName(c.fresh("" + name + "$")) - } - - private[async] lazy val futureSystemOps = futureSystem.mkOps(c) - - private def resetDuplicate(tree: Tree) = c.resetAllAttrs(tree.duplicate) - - private def mkResumeApply = Apply(Ident(name.resume), Nil) + lazy val futureSystemOps = futureSystem.mkOps(c) - private def mkStateTree(nextState: Int): c.Tree = - mkStateTree(c.literal(nextState).tree) + val stateAssigner = new StateAssigner + val labelDefStates = collection.mutable.Map[Symbol, Int]() - private def mkStateTree(nextState: Tree): c.Tree = - Assign(Ident(name.state), nextState) + trait AsyncState { + def state: Int - private def mkVarDefTree(resultType: Type, resultName: TermName): c.Tree = { - ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), defaultValue(resultType)) - } + def mkHandlerCaseForState: CaseDef - private def mkHandlerCase(num: Int, rhs: List[c.Tree]): CaseDef = - mkHandlerCase(num, Block(rhs: _*)) + def mkOnCompleteHandler: Option[CaseDef] = None - private def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef = - CaseDef(c.literal(num).tree, EmptyTree, rhs) + def stats: List[Tree] - class AsyncState(stats: List[c.Tree], val state: Int, val nextState: Int) { - val body: c.Tree = stats match { + final def body: c.Tree = stats match { case stat :: Nil => stat case _ => Block(stats: _*) } + } - val varDefs: List[(TermName, Type)] = Nil + /** A sequence of statements the concludes with a unconditional transition to `nextState` */ + final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int) + extends AsyncState { - def mkHandlerCaseForState(): CaseDef = + def mkHandlerCaseForState: CaseDef = mkHandlerCase(state, stats :+ mkStateTree(nextState) :+ mkResumeApply) - def mkOnCompleteHandler(): Option[CaseDef] = { - this match { - case aw: AsyncStateWithAwait => - val tryGetTree = - Assign( - Ident(aw.resultName), - TypeApply(Select(Select(Ident(name.tr), Try_get), newTermName("asInstanceOf")), List(TypeTree(aw.resultType))) - ) - val updateState = mkStateTree(nextState) - Some(mkHandlerCase(state, List(tryGetTree, updateState, mkResumeApply))) - case _ => - None - } - } - - def varDefForResult: Option[c.Tree] = - None - - def allVarDefs: List[c.Tree] = - varDefForResult.toList ++ varDefs.map(p => mkVarDefTree(p._2, p._1)) - override val toString: String = s"AsyncState #$state, next = $nextState" } - class AsyncStateWithoutAwait(stats: List[c.Tree], state: Int) - extends AsyncState(stats, state, 0) { - // nextState unused, since encoded in then and else branches - - override def mkHandlerCaseForState(): CaseDef = + /** A sequence of statements with a conditional transition to the next state, which will represent + * a branch of an `if` or a `match`. + */ + final class AsyncStateWithoutAwait(val stats: List[c.Tree], val state: Int) extends AsyncState { + override def mkHandlerCaseForState: CaseDef = mkHandlerCase(state, stats) override val toString: String = - s"AsyncStateWithIf #$state, next = $nextState" + s"AsyncStateWithoutAwait #$state" } - abstract class AsyncStateWithAwait(stats: List[c.Tree], state: Int, nextState: Int) - extends AsyncState(stats, state, nextState) { - val awaitable : c.Tree - val resultName: TermName - val resultType: Type - - protected def tryType = appliedType(TryClass.toType, List(resultType)) - - override val toString: String = - s"AsyncStateWithAwait #$state, next = $nextState" - - private def mkOnCompleteTree: c.Tree = { - futureSystemOps.onComplete(c.Expr(awaitable), c.Expr(Ident(name.onCompleteHandler)), c.Expr(Ident(name.execContext))).tree + /** A sequence of statements that concludes with an `await` call. The `onComplete` + * handler will unconditionally transition to `nestState`.`` + */ + final class AsyncStateWithAwait(val stats: List[c.Tree], val state: Int, nextState: Int, + awaitable: Awaitable) + extends AsyncState { + + override def mkHandlerCaseForState: CaseDef = { + val callOnComplete = futureSystemOps.onComplete(c.Expr(awaitable.expr), + c.Expr(Ident(name.onCompleteHandler)), c.Expr(Ident(name.execContext))).tree + mkHandlerCase(state, stats :+ callOnComplete) } - override def mkHandlerCaseForState(): CaseDef = { - assert(awaitable != null) - mkHandlerCase(state, stats :+ mkOnCompleteTree) + override def mkOnCompleteHandler: Option[CaseDef] = { + val tryGetTree = + Assign( + Ident(awaitable.resultName), + TypeApply(Select(Select(Ident(name.tr), Try_get), newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType))) + ) + Some(mkHandlerCase(state, List(tryGetTree, mkStateTree(nextState), mkResumeApply))) } - override def varDefForResult: Option[c.Tree] = - Some(mkVarDefTree(resultType, resultName)) + override val toString: String = + s"AsyncStateWithAwait #$state, next = $nextState" } /* * Builder for a single state of an async method. */ - class AsyncStateBuilder(state: Int, private val nameMap: Map[Symbol, c.Name]) { - self => - + final class AsyncStateBuilder(state: Int, private val nameMap: Map[Symbol, c.Name]) { /* Statements preceding an await call. */ - private val stats = ListBuffer[c.Tree]() + private val stats = ListBuffer[c.Tree]() + /** The state of the target of a LabelDef application (while loop jump) */ + private var nextJumpState: Option[Int] = None - /* Argument of an await call. */ - var awaitable: c.Tree = null - - /* Result name of an await call. */ - var resultName: TermName = null - - /* Result type of an await call. */ - var resultType: Type = null - - var nextState: Int = -1 - - private val varDefs = ListBuffer[(TermName, Type)]() - - private val renamer = new Transformer { - override def transform(tree: Tree) = tree match { - case Ident(_) if nameMap.keySet contains tree.symbol => - Ident(nameMap(tree.symbol)) - case _ => - super.transform(tree) - } - } + private def renameReset(tree: Tree) = resetDuplicate(substituteNames(tree, nameMap)) def +=(stat: c.Tree): this.type = { - stats += resetDuplicate(renamer.transform(stat)) - this - } - - //TODO do not ignore `mods` - def addVarDef(mods: Any, name: TermName, tpt: c.Tree, rhs: c.Tree): this.type = { - varDefs += (name -> tpt.tpe) - this += Assign(Ident(name), rhs) + assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat") + def addStat() = stats += renameReset(stat) + stat match { + case Apply(fun, Nil) => + labelDefStates get fun.symbol match { + case Some(nextState) => nextJumpState = Some(nextState) + case None => addStat() + } + case _ => addStat() + } this } - def result(): AsyncState = - if (awaitable == null) - new AsyncState(stats.toList, state, nextState) { - override val varDefs = self.varDefs.toList - } - else - new AsyncStateWithAwait(stats.toList, state, nextState) { - val awaitable = self.awaitable - val resultName = self.resultName - val resultType = self.resultType - override val varDefs = self.varDefs.toList - } - - /* Result needs to be created as a var at the beginning of the transformed method body, so that - * it is visible in subsequent states of the state machine. - * - * @param awaitArg the argument of await - * @param awaitResultName the name of the variable that the result of await is assigned to - * @param awaitResultType the type of the result of await - */ - def complete(awaitArg: c.Tree, awaitResultName: TermName, awaitResultType: Tree, - nextState: Int): this.type = { - val renamed = renamer.transform(awaitArg) - awaitable = resetDuplicate(renamed) - resultName = awaitResultName - resultType = awaitResultType.tpe - this.nextState = nextState - this + def resultWithAwait(awaitable: Awaitable, + nextState: Int): AsyncState = { + val sanitizedAwaitable = awaitable.copy(expr = renameReset(awaitable.expr)) + val effectiveNextState = nextJumpState.getOrElse(nextState) + new AsyncStateWithAwait(stats.toList, state, effectiveNextState, sanitizedAwaitable) } - def complete(nextState: Int): this.type = { - this.nextState = nextState - this + def resultSimple(nextState: Int): AsyncState = { + val effectiveNextState = nextJumpState.getOrElse(nextState) + new SimpleAsyncState(stats.toList, state, effectiveNextState) } def resultWithIf(condTree: c.Tree, thenState: Int, elseState: Int): AsyncState = { // 1. build changed if-else tree // 2. insert that tree at the end of the current state - val cond = resetDuplicate(condTree) - this += If(cond, - Block(mkStateTree(thenState), mkResumeApply), - Block(mkStateTree(elseState), mkResumeApply)) - new AsyncStateWithoutAwait(stats.toList, state) { - override val varDefs = self.varDefs.toList - } + val cond = renameReset(condTree) + def mkBranch(state: Int) = Block(mkStateTree(state), mkResumeApply) + this += If(cond, mkBranch(thenState), mkBranch(elseState)) + new AsyncStateWithoutAwait(stats.toList, state) } /** @@ -234,20 +149,21 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val case CaseDef(pat, guard, rhs) => CaseDef(pat, guard, Block(mkStateTree(caseStates(num)), mkResumeApply)) } // 2. insert changed match tree at the end of the current state - this += Match(resetDuplicate(scrutTree), newCases) - new AsyncStateWithoutAwait(stats.toList, state) { - override val varDefs = self.varDefs.toList - } + this += Match(renameReset(scrutTree), newCases) + new AsyncStateWithoutAwait(stats.toList, state) + } + + def resultWithLabel(startLabelState: Int): AsyncState = { + this += Block(mkStateTree(startLabelState), mkResumeApply) + new AsyncStateWithoutAwait(stats.toList, state) } override def toString: String = { val statsBeforeAwait = stats.mkString("\n") - s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName" + s"ASYNC STATE:\n$statsBeforeAwait" } } - val stateAssigner = new StateAssigner - /** * An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`). * @@ -257,107 +173,191 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val * @param endState the state to continue with * @param toRename a `Map` for renaming the given key symbols to the mangled value names */ - class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, - private val toRename: Map[Symbol, c.Name]) { - val asyncStates = ListBuffer[builder.AsyncState]() + final private class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, + private val toRename: Map[Symbol, c.Name]) { + val asyncStates = ListBuffer[AsyncState]() - private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename) - // current state builder - private var currState = startState + var stateBuilder = new AsyncStateBuilder(startState, toRename) + var currState = startState /* TODO Fall back to CPS plug-in if tree contains an `await` call. */ def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists { case Apply(fun, _) if isAwait(fun) => true case _ => false - }) c.abort(tree.pos, "await unsupported in this position") //throw new FallbackToCpsException + }) c.abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException - def builderForBranch(tree: c.Tree, state: Int, nextState: Int): AsyncBlockBuilder = { - val (branchStats, branchExpr) = tree match { - case Block(s, e) => (s, e) - case _ => (List(tree), c.literalUnit.tree) - } - new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, toRename) + def nestedBlockBuilder(nestedTree: Tree, startState: Int, endState: Int) = { + val (nestedStats, nestedExpr) = statsAndExpr(nestedTree) + new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, toRename) } + import stateAssigner.nextState + // populate asyncStates for (stat <- stats) stat match { // the val name = await(..) pattern - case ValDef(mods, name, tpt, Apply(fun, args)) if isAwait(fun) => - val afterAwaitState = stateAssigner.nextState() - asyncStates += stateBuilder.complete(args.head, toRename(stat.symbol).toTermName, tpt, afterAwaitState).result // complete with await + case ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => + val afterAwaitState = nextState() + val awaitable = Awaitable(arg, toRename(stat.symbol).toTermName, tpt.tpe) + asyncStates += stateBuilder.resultWithAwait(awaitable, afterAwaitState) // complete with await currState = afterAwaitState - stateBuilder = new builder.AsyncStateBuilder(currState, toRename) + stateBuilder = new AsyncStateBuilder(currState, toRename) case ValDef(mods, name, tpt, rhs) if toRename contains stat.symbol => checkForUnsupportedAwait(rhs) - - // when adding assignment need to take `toRename` into account - stateBuilder.addVarDef(mods, toRename(stat.symbol).toTermName, tpt, rhs) + stateBuilder += Assign(Ident(toRename(stat.symbol).toTermName), rhs) case If(cond, thenp, elsep) if stat exists isAwait => checkForUnsupportedAwait(cond) - val thenStartState = stateAssigner.nextState() - val elseStartState = stateAssigner.nextState() - val afterIfState = stateAssigner.nextState() + val thenStartState = nextState() + val elseStartState = nextState() + val afterIfState = nextState() asyncStates += // the two Int arguments are the start state of the then branch and the else branch, respectively stateBuilder.resultWithIf(cond, thenStartState, elseStartState) List((thenp, thenStartState), (elsep, elseStartState)) foreach { - case (tree, state) => - val builder = builderForBranch(tree, state, afterIfState) + case (branchTree, state) => + val builder = nestedBlockBuilder(branchTree, state, afterIfState) asyncStates ++= builder.asyncStates } currState = afterIfState - stateBuilder = new builder.AsyncStateBuilder(currState, toRename) + stateBuilder = new AsyncStateBuilder(currState, toRename) case Match(scrutinee, cases) if stat exists isAwait => checkForUnsupportedAwait(scrutinee) - val caseStates = cases.map(_ => stateAssigner.nextState()) - val afterMatchState = stateAssigner.nextState() + val caseStates = cases.map(_ => nextState()) + val afterMatchState = nextState() asyncStates += stateBuilder.resultWithMatch(scrutinee, cases, caseStates) for ((cas, num) <- cases.zipWithIndex) { - val (casStats, casExpr) = cas match { - case CaseDef(_, _, Block(s, e)) => (s, e) - case CaseDef(_, _, rhs) => (List(rhs), c.literalUnit.tree) - } - val builder = new AsyncBlockBuilder(casStats, casExpr, caseStates(num), afterMatchState, toRename) + val builder = nestedBlockBuilder(cas.body, caseStates(num), afterMatchState) asyncStates ++= builder.asyncStates } currState = afterMatchState - stateBuilder = new builder.AsyncStateBuilder(currState, toRename) - - case _ => + stateBuilder = new AsyncStateBuilder(currState, toRename) + + case ld@LabelDef(name, params, rhs) if rhs exists isAwait => + val startLabelState = nextState() + val afterLabelState = nextState() + asyncStates += stateBuilder.resultWithLabel(startLabelState) + labelDefStates(ld.symbol) = startLabelState + val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState) + asyncStates ++= builder.asyncStates + + currState = afterLabelState + stateBuilder = new AsyncStateBuilder(currState, toRename) + case _ => checkForUnsupportedAwait(stat) stateBuilder += stat } // complete last state builder (representing the expressions after the last await) stateBuilder += expr - val lastState = stateBuilder.complete(endState).result() + val lastState = stateBuilder.resultSimple(endState) asyncStates += lastState + } + + trait AsyncBlock { + def asyncStates: List[AsyncState] - def mkCombinedHandlerCases[T](): List[CaseDef] = { - val caseForLastState: CaseDef = { - val lastState = asyncStates.last - val lastStateBody = c.Expr[T](lastState.body) - val rhs = futureSystemOps.completeProm(c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Success(lastStateBody.splice))) - mkHandlerCase(lastState.state, rhs.tree) + def onCompleteHandler: Tree + + def resumeFunTree[T]: Tree + } + + def build(block: Block, toRename: Map[Symbol, c.Name]): AsyncBlock = { + val Block(stats, expr) = block + val startState = stateAssigner.nextState() + val endState = Int.MaxValue + + val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, toRename) + + new AsyncBlock { + def asyncStates = blockBuilder.asyncStates.toList + + def mkCombinedHandlerCases[T]: List[CaseDef] = { + val caseForLastState: CaseDef = { + val lastState = asyncStates.last + val lastStateBody = c.Expr[T](lastState.body) + val rhs = futureSystemOps.completeProm( + c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Success(lastStateBody.splice))) + mkHandlerCase(lastState.state, rhs.tree) + } + asyncStates.toList match { + case s :: Nil => + List(caseForLastState) + case _ => + val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState + initCases :+ caseForLastState + } } - asyncStates.toList match { - case s :: Nil => - List(caseForLastState) - case _ => - val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState() - initCases :+ caseForLastState + + val initStates = asyncStates.init + + /** + * lazy val onCompleteHandler = (tr: Try[Any]) => state match { + * case 0 => { + * x11 = tr.get.asInstanceOf[Double]; + * state = 1; + * resume() + * } + */ + val onCompleteHandler: Tree = { + val onCompleteHandlers = initStates.flatMap(_.mkOnCompleteHandler).toList + Function( + List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)), + Match(Ident(name.state), onCompleteHandlers)) } + + /** + * def resume(): Unit = { + * try { + * state match { + * case 0 => { + * f11 = exprReturningFuture + * f11.onComplete(onCompleteHandler)(context) + * } + * ... + * } + * } catch { + * case NonFatal(t) => result.failure(t) + * } + * } + */ + def resumeFunTree[T]: Tree = + DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), + Try( + Match(Ident(name.state), mkCombinedHandlerCases[T]), + List( + CaseDef( + Apply(Ident(defn.NonFatalClass), List(Bind(name.tr, Ident(nme.WILDCARD)))), + EmptyTree, + Block(List({ + val t = c.Expr[Throwable](Ident(name.tr)) + futureSystemOps.completeProm[T](c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Failure(t.splice))).tree + }), c.literalUnit.tree))), EmptyTree)) } } + + private final case class Awaitable(expr: Tree, resultName: TermName, resultType: Type) + + private def resetDuplicate(tree: Tree) = c.resetAllAttrs(tree.duplicate) + + private def mkResumeApply = Apply(Ident(name.resume), Nil) + + private def mkStateTree(nextState: Int): c.Tree = + Assign(Ident(name.state), c.literal(nextState).tree) + + private def mkHandlerCase(num: Int, rhs: List[c.Tree]): CaseDef = + mkHandlerCase(num, Block(rhs: _*)) + + private def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef = + CaseDef(c.literal(num).tree, EmptyTree, rhs) } diff --git a/src/main/scala/scala/async/FutureSystem.scala b/src/main/scala/scala/async/FutureSystem.scala index ed967cb..20bbea3 100644 --- a/src/main/scala/scala/async/FutureSystem.scala +++ b/src/main/scala/scala/async/FutureSystem.scala @@ -1,4 +1,4 @@ -/** +/* * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> */ package scala.async diff --git a/src/main/scala/scala/async/StateAssigner.scala b/src/main/scala/scala/async/StateAssigner.scala index 4f6c5a0..bc60a6d 100644 --- a/src/main/scala/scala/async/StateAssigner.scala +++ b/src/main/scala/scala/async/StateAssigner.scala @@ -1,3 +1,7 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + package scala.async private[async] final class StateAssigner { diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index b79be87..8838bb3 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -1,4 +1,4 @@ -/** +/* * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> */ package scala.async @@ -9,11 +9,34 @@ import reflect.ClassTag /** * Utilities used in both `ExprBuilder` and `AnfTransform`. */ -class TransformUtils[C <: Context](val c: C) { +private[async] final case class TransformUtils[C <: Context](val c: C) { import c.universe._ - protected def defaultValue(tpe: Type): Literal = { + object name { + def suffix(string: String) = string + "$async" + + def suffixedName(prefix: String) = newTermName(suffix(prefix)) + + val state = suffixedName("state") + val result = suffixedName("result") + val resume = suffixedName("resume") + val execContext = suffixedName("execContext") + + // TODO do we need to freshen any of these? + val tr = newTermName("tr") + val onCompleteHandler = suffixedName("onCompleteHandler") + + val matchRes = "matchres" + val ifRes = "ifres" + val await = "await" + + def fresh(name: TermName): TermName = newTermName(fresh(name.toString)) + + def fresh(name: String): String = if (name.toString.contains("$")) name else c.fresh("" + name + "$") + } + + def defaultValue(tpe: Type): Literal = { val defaultValue: Any = if (tpe <:< definitions.BooleanTpe) false else if (definitions.ScalaNumericValueClasses.exists(tpe <:< _.toType)) 0 @@ -22,9 +45,22 @@ class TransformUtils[C <: Context](val c: C) { Literal(Constant(defaultValue)) } - protected def isAwait(fun: Tree) = + def isAwait(fun: Tree) = fun.symbol == defn.Async_await + /** Replace all `Ident` nodes referring to one of the keys n `renameMap` with a node + * referring to the corresponding new name + */ + def substituteNames(tree: Tree, renameMap: Map[Symbol, Name]): Tree = { + val renamer = new Transformer { + override def transform(tree: Tree) = tree match { + case Ident(_) => (renameMap get tree.symbol).fold(tree)(Ident(_)) + case _ => super.transform(tree) + } + } + renamer.transform(tree) + } + /** Descends into the regions of the tree that are subject to the * translation to a state machine by `async`. When a nested template, * function, or by-name argument is encountered, the descend stops, @@ -37,6 +73,9 @@ class TransformUtils[C <: Context](val c: C) { def nestedModule(module: ModuleDef) { } + def nestedMethod(module: DefDef) { + } + def byNameArgument(arg: Tree) { } @@ -47,6 +86,7 @@ class TransformUtils[C <: Context](val c: C) { tree match { case cd: ClassDef => nestedClass(cd) case md: ModuleDef => nestedModule(md) + case dd: DefDef => nestedMethod(dd) case fun: Function => function(fun) case Apply(fun, args) => val isInByName = isByName(fun) @@ -68,7 +108,7 @@ class TransformUtils[C <: Context](val c: C) { Set(Boolean_&&, Boolean_||) } - protected def isByName(fun: Tree): (Int => Boolean) = { + def isByName(fun: Tree): (Int => Boolean) = { if (Boolean_ShortCircuits contains fun.symbol) i => true else fun.tpe match { case MethodType(params, _) => @@ -78,7 +118,16 @@ class TransformUtils[C <: Context](val c: C) { } } - private[async] object defn { + def statsAndExpr(tree: Tree): (List[Tree], Tree) = tree match { + case Block(stats, expr) => (stats, expr) + case _ => (List(tree), Literal(Constant(()))) + } + + def mkVarDefTree(resultType: Type, resultName: TermName): c.Tree = { + ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), defaultValue(resultType)) + } + + object defn { def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = { c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree))) } @@ -157,4 +206,5 @@ class TransformUtils[C <: Context](val c: C) { def ValDef(tree: Tree)(mods: Modifiers, name: TermName, tpt: Tree, rhs: Tree): ValDef = copyAttach(tree, c.universe.ValDef(mods, name, tpt, rhs)) } + } |