diff options
author | Philipp Haller <hallerp@gmail.com> | 2012-11-22 06:10:00 -0800 |
---|---|---|
committer | Philipp Haller <hallerp@gmail.com> | 2012-11-22 06:10:00 -0800 |
commit | 1c91fec998d09e31c2c52760452af1771a092182 (patch) | |
tree | 8733f9b854baa83194b1688fa30ed5fc90fd249c | |
parent | f451904320d02c7dbe6b298f6ff790ca5cf5f080 (diff) | |
parent | 8e4a8ecdff955c4faa1dec344a2b93543ffe7d45 (diff) | |
download | scala-async-1c91fec998d09e31c2c52760452af1771a092182.tar.gz scala-async-1c91fec998d09e31c2c52760452af1771a092182.tar.bz2 scala-async-1c91fec998d09e31c2c52760452af1771a092182.zip |
Merge pull request #25 from phaller/topic/minimal-var-lifting-2
Topic/minimal var lifting 2
-rw-r--r-- | build.sbt | 2 | ||||
-rw-r--r-- | src/main/scala/scala/async/AnfTransform.scala | 101 | ||||
-rw-r--r-- | src/main/scala/scala/async/Async.scala | 65 | ||||
-rw-r--r-- | src/main/scala/scala/async/ExprBuilder.scala | 219 | ||||
-rw-r--r-- | src/main/scala/scala/async/StateAssigner.scala | 10 | ||||
-rw-r--r-- | src/main/scala/scala/async/TransformUtils.scala | 22 | ||||
-rw-r--r-- | src/test/scala/scala/async/TestUtils.scala | 4 | ||||
-rw-r--r-- | src/test/scala/scala/async/TreeInterrogation.scala | 3 | ||||
-rw-r--r-- | src/test/scala/scala/async/neg/NakedAwait.scala | 72 | ||||
-rw-r--r-- | src/test/scala/scala/async/run/anf/AnfTransformSpec.scala | 146 |
10 files changed, 542 insertions, 102 deletions
@@ -19,6 +19,8 @@ libraryDependencies += "com.novocode" % "junit-interface" % "0.10-M2" % "test" testOptions += Tests.Argument(TestFrameworks.JUnit, "-q", "-v", "-s") +parallelExecution in Global := false + autoCompilerPlugins := true libraryDependencies <<= (scalaVersion, libraryDependencies) { diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala new file mode 100644 index 0000000..e1d7cd5 --- /dev/null +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -0,0 +1,101 @@ +package scala.async + +import scala.reflect.macros.Context + +class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { + import c.universe._ + import AsyncUtils._ + + object inline { + def transformToList(tree: Tree): List[Tree] = { + val stats :+ expr = anf.transformToList(tree) + expr match { + + case Apply(fun, args) if fun.toString.startsWith("scala.async.Async.await") => + val liftedName = c.fresh("await$") + stats :+ ValDef(NoMods, liftedName, TypeTree(), expr) :+ Ident(liftedName) + + case If(cond, thenp, elsep) => + // if type of if-else is Unit don't introduce assignment, + // but add Unit value to bring it into form expected by async transform + if (expr.tpe =:= definitions.UnitTpe) { + stats :+ expr :+ Literal(Constant(())) + } + else { + val liftedName = c.fresh("ifres$") + val varDef = + ValDef(Modifiers(Flag.MUTABLE), liftedName, TypeTree(expr.tpe), defaultValue(expr.tpe)) + val thenWithAssign = thenp match { + case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(liftedName), thenExpr)) + case _ => Assign(Ident(liftedName), thenp) + } + val elseWithAssign = elsep match { + case Block(elseStats, elseExpr) => Block(elseStats, Assign(Ident(liftedName), elseExpr)) + case _ => Assign(Ident(liftedName), elsep) + } + val ifWithAssign = + If(cond, thenWithAssign, elseWithAssign) + stats :+ varDef :+ ifWithAssign :+ Ident(liftedName) + } + case _ => + stats :+ expr + } + } + + def transformToList(trees: List[Tree]): List[Tree] = trees match { + case fst :: rest => transformToList(fst) ++ transformToList(rest) + case Nil => Nil + } + } + + object anf { + def transformToList(tree: Tree): List[Tree] = tree match { + case Select(qual, sel) => + val stats :+ expr = inline.transformToList(qual) + stats :+ Select(expr, sel) + + case Apply(fun, args) => + val funStats :+ simpleFun = inline.transformToList(fun) + val argLists = args map inline.transformToList + val allArgStats = argLists flatMap (_.init) + val simpleArgs = argLists map (_.last) + funStats ++ allArgStats :+ Apply(simpleFun, simpleArgs) + + case Block(stats, expr) => + inline.transformToList(stats) ++ inline.transformToList(expr) + + case ValDef(mods, name, tpt, rhs) => + val stats :+ expr = inline.transformToList(rhs) + stats :+ ValDef(mods, name, tpt, expr).setSymbol(tree.symbol) + + case Assign(name, rhs) => + val stats :+ expr = inline.transformToList(rhs) + stats :+ Assign(name, expr) + + case If(cond, thenp, elsep) => + val stats :+ expr = inline.transformToList(cond) + val thenStats :+ thenExpr = inline.transformToList(thenp) + val elseStats :+ elseExpr = inline.transformToList(elsep) + stats :+ + c.typeCheck(If(expr, Block(thenStats, thenExpr), Block(elseStats, elseExpr))) + + //TODO + case Literal(_) | Ident(_) | This(_) | Match(_, _) | New(_) | Function(_, _) => List(tree) + + case TypeApply(fun, targs) => + val funStats :+ simpleFun = inline.transformToList(fun) + funStats :+ TypeApply(simpleFun, targs) + + //TODO + case DefDef(mods, name, tparams, vparamss, tpt, rhs) => List(tree) + + case ClassDef(mods, name, tparams, impl) => List(tree) + + case ModuleDef(mods, name, impl) => List(tree) + + case _ => + c.error(tree.pos, "Internal error while compiling `async` block") + ??? + } + } +} diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index b71ce74..bd766f2 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -4,18 +4,14 @@ package scala.async import scala.language.experimental.macros - import scala.reflect.macros.Context -import scala.collection.mutable.ListBuffer -import scala.concurrent.{Future, Promise, ExecutionContext, future} -import ExecutionContext.Implicits.global -import scala.util.control.NonFatal - /* * @author Philipp Haller */ object Async extends AsyncBase { + import scala.concurrent.Future + lazy val futureSystem = ScalaConcurrentFutureSystem type FS = ScalaConcurrentFutureSystem.type @@ -52,11 +48,11 @@ abstract class AsyncBase { /** * A call to `await` must be nested in an enclosing `async` block. - * + * * A call to `await` does not block the current thread, rather it is a delimiter * used by the enclosing `async` macro. Code following the `await` * call is executed asynchronously, when the argument of `await` has been completed. - * + * * @param awaitable the future from which a value is awaited. * @tparam T the type of that value. * @return the value. @@ -74,19 +70,38 @@ abstract class AsyncBase { import builder.defn._ import builder.name import builder.futureSystemOps - val (stats, expr) = body.tree match { - case Block(stats, expr) => (stats, expr) - case tree => (Nil, tree) + + // Transform to A-normal form: + // - no await calls in qualifiers or arguments, + // - if/match only used in statement position. + val anfTree: Block = { + val transform = new AnfTransform[c.type](c) + val stats1 :+ expr1 = transform.anf.transformToList(body.tree) + c.typeCheck(Block(stats1, expr1)).asInstanceOf[Block] } - val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000, 1000, Map()) + // Analyze the block to find locals that will be accessed from multiple + // states of our generated state machine, e.g. a value assigned before + // an `await` and read afterwards. + val renameMap: Map[Symbol, TermName] = { + val analyzer = new builder.AsyncAnalyzer + analyzer.traverse(anfTree) + analyzer.valDefsToLift.map { + vd => + (vd.symbol, builder.name.fresh(vd.name)) + }.toMap + } - asyncBlockBuilder.asyncStates foreach (s => AsyncUtils.vprintln(s)) + val startState = builder.stateAssigner.nextState() + val endState = Int.MaxValue + val asyncBlockBuilder = new builder.AsyncBlockBuilder(anfTree.stats, anfTree.expr, startState, endState, renameMap) val handlerCases: List[CaseDef] = asyncBlockBuilder.mkCombinedHandlerCases[T]() - val initStates = asyncBlockBuilder.asyncStates.init - val localVarTrees = initStates.flatMap(_.allVarDefs).toList + import asyncBlockBuilder.asyncStates + logDiagnostics(c)(anfTree, asyncStates.map(_.toString)) + val initStates = asyncStates.init + val localVarTrees = asyncStates.flatMap(_.allVarDefs).toList /* lazy val onCompleteHandler = (tr: Try[Any]) => state match { @@ -98,7 +113,7 @@ abstract class AsyncBase { ... */ val onCompleteHandler = { - val onCompleteHandlers = initStates.flatMap(_.mkOnCompleteHandler).toList + val onCompleteHandlers = initStates.flatMap(_.mkOnCompleteHandler()).toList Function( List(ValDef(Modifiers(PARAM), name.tr, TypeTree(TryAnyType), EmptyTree)), Match(Ident(name.state), onCompleteHandlers)) @@ -144,7 +159,7 @@ abstract class AsyncBase { // Spawn a future to: futureSystemOps.future[Unit] { c.Expr[Unit](Block( - // define vars for all intermediate results + // define vars for all intermediate results that are accessed from multiple states localVarTrees :+ // define the resume() method resumeFunTree :+ @@ -159,8 +174,22 @@ abstract class AsyncBase { // ... and return its Future from the macro. val result = futureSystemOps.promiseToFuture(prom) - AsyncUtils.vprintln(s"${c.macroApplication} \nexpands to:\n ${result.tree}") + AsyncUtils.vprintln(s"async state machine transform expands to:\n ${result.tree}") result } + + def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) { + def location = try { + c.macroApplication.pos.source.path + } catch { + case _: UnsupportedOperationException => + c.macroApplication.pos.toString + } + + AsyncUtils.vprintln(s"In file '$location':") + AsyncUtils.vprintln(s"${c.macroApplication}") + AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree") + states foreach (s => AsyncUtils.vprintln(s)) + } } diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 2b35ff4..7a9c98d 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -5,17 +5,16 @@ package scala.async import scala.reflect.macros.Context import scala.collection.mutable.ListBuffer -import scala.concurrent.Future -import AsyncUtils.vprintln +import collection.mutable /* * @author Philipp Haller */ -final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSystem: FS) { +final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val futureSystem: FS) + extends TransformUtils(c) { builder => import c.universe._ - import Flag._ import defn._ private[async] object name { @@ -38,8 +37,6 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy private[async] lazy val futureSystemOps = futureSystem.mkOps(c) - private val execContext = futureSystemOps.execContext - private def resetDuplicate(tree: Tree) = c.resetAllAttrs(tree.duplicate) private def mkResumeApply = Apply(Ident(name.resume), Nil) @@ -50,14 +47,6 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy private def mkStateTree(nextState: Tree): c.Tree = Assign(Ident(name.state), nextState) - private def defaultValue(tpe: Type): Literal = { - val defaultValue: Any = - if (tpe <:< definitions.BooleanTpe) false - else if (definitions.ScalaNumericValueClasses.exists(tpe <:< _.toType)) 0 - else null - Literal(Constant(defaultValue)) - } - private def mkVarDefTree(resultType: Type, resultName: TermName): c.Tree = { ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), defaultValue(resultType)) } @@ -87,7 +76,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy Ident(aw.resultName), TypeApply(Select(Select(Ident(name.tr), Try_get), newTermName("asInstanceOf")), List(TypeTree(aw.resultType))) ) - val updateState = mkStateTree(nextState) // or increment? + val updateState = mkStateTree(nextState) Some(mkHandlerCase(state, List(tryGetTree, updateState, mkResumeApply))) case _ => None @@ -142,7 +131,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy /* * Builder for a single state of an async method. */ - class AsyncStateBuilder(state: Int, private var nameMap: Map[c.Symbol, c.Name]) { + class AsyncStateBuilder(state: Int, private val nameMap: Map[Symbol, c.Name]) { self => /* Statements preceding an await call. */ @@ -157,7 +146,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy /* Result type of an await call. */ var resultType: Type = null - var nextState: Int = state + 1 + var nextState: Int = -1 private val varDefs = ListBuffer[(TermName, Type)]() @@ -176,9 +165,8 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy } //TODO do not ignore `mods` - def addVarDef(mods: Any, name: TermName, tpt: c.Tree, rhs: c.Tree, extNameMap: Map[c.Symbol, c.Name]): this.type = { + def addVarDef(mods: Any, name: TermName, tpt: c.Tree, rhs: c.Tree): this.type = { varDefs += (name -> tpt.tpe) - nameMap ++= extNameMap // update name map this += Assign(Ident(name), rhs) this } @@ -204,9 +192,9 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy * @param awaitResultType the type of the result of await */ def complete(awaitArg: c.Tree, awaitResultName: TermName, awaitResultType: Tree, - extNameMap: Map[c.Symbol, c.Name], nextState: Int = state + 1): this.type = { - nameMap ++= extNameMap - awaitable = resetDuplicate(renamer.transform(awaitArg)) + nextState: Int): this.type = { + val renamed = renamer.transform(awaitArg) + awaitable = resetDuplicate(renamed) resultName = awaitResultName resultType = awaitResultType.tpe this.nextState = nextState @@ -237,14 +225,13 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy * * @param scrutTree tree of the scrutinee * @param cases list of case definitions - * @param stateFirstCase state of the right-hand side of the first case - * @param perCaseBudget maximum number of states per case + * @param caseStates starting state of the right-hand side of the each case * @return an `AsyncState` representing the match expression */ - def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], stateFirstCase: Int, perCasebudget: Int): AsyncState = { + def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], caseStates: List[Int]): AsyncState = { // 1. build list of changed cases val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match { - case CaseDef(pat, guard, rhs) => CaseDef(pat, guard, Block(mkStateTree(num * perCasebudget + stateFirstCase), mkResumeApply)) + case CaseDef(pat, guard, rhs) => CaseDef(pat, guard, Block(mkStateTree(caseStates(num)), mkResumeApply)) } // 2. insert changed match tree at the end of the current state this += Match(resetDuplicate(scrutTree), newCases) @@ -259,6 +246,8 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy } } + val stateAssigner = new StateAssigner + /** * An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`). * @@ -266,106 +255,84 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy * @param expr the last expression of the block * @param startState the start state * @param endState the state to continue with - * @param budget the maximum number of states in this block * @param toRename a `Map` for renaming the given key symbols to the mangled value names */ class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, - budget: Int, private var toRename: Map[c.Symbol, c.Name]) { + private val toRename: Map[Symbol, c.Name]) { val asyncStates = ListBuffer[builder.AsyncState]() private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename) // current state builder private var currState = startState - private var remainingBudget = budget - /* TODO Fall back to CPS plug-in if tree contains an `await` call. */ def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists { - case Apply(fun, _) if fun.symbol == Async_await => true + case Apply(fun, _) if isAwait(fun) => true case _ => false }) c.abort(tree.pos, "await unsupported in this position") //throw new FallbackToCpsException - def builderForBranch(tree: c.Tree, state: Int, nextState: Int, budget: Int, nameMap: Map[c.Symbol, c.Name]): AsyncBlockBuilder = { + def builderForBranch(tree: c.Tree, state: Int, nextState: Int): AsyncBlockBuilder = { val (branchStats, branchExpr) = tree match { case Block(s, e) => (s, e) case _ => (List(tree), c.literalUnit.tree) } - new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, budget, nameMap) + new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, toRename) } // populate asyncStates for (stat <- stats) stat match { // the val name = await(..) pattern - case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == Async_await => - val newName = builder.name.fresh(name) - toRename += (stat.symbol -> newName) - - asyncStates += stateBuilder.complete(args.head, newName, tpt, toRename).result // complete with await - if (remainingBudget > 0) - remainingBudget -= 1 - else - assert(false, "too many invocations of `await` in current method") - currState += 1 + case ValDef(mods, name, tpt, Apply(fun, args)) if isAwait(fun) => + val afterAwaitState = stateAssigner.nextState() + asyncStates += stateBuilder.complete(args.head, toRename(stat.symbol).toTermName, tpt, afterAwaitState).result // complete with await + currState = afterAwaitState stateBuilder = new builder.AsyncStateBuilder(currState, toRename) - case ValDef(mods, name, tpt, rhs) => + case ValDef(mods, name, tpt, rhs) if toRename contains stat.symbol => checkForUnsupportedAwait(rhs) - val newName = builder.name.fresh(name) - toRename += (stat.symbol -> newName) // when adding assignment need to take `toRename` into account - stateBuilder.addVarDef(mods, newName, tpt, rhs, toRename) + stateBuilder.addVarDef(mods, toRename(stat.symbol).toTermName, tpt, rhs) - case If(cond, thenp, elsep) => + case If(cond, thenp, elsep) if stat exists isAwait => checkForUnsupportedAwait(cond) - val ifBudget: Int = remainingBudget / 2 - remainingBudget -= ifBudget //TODO test if budget > 0 - // state that we continue with after if-else: currState + ifBudget - - val thenBudget: Int = ifBudget / 2 - val elseBudget = ifBudget - thenBudget + val thenStartState = stateAssigner.nextState() + val elseStartState = stateAssigner.nextState() + val afterIfState = stateAssigner.nextState() asyncStates += // the two Int arguments are the start state of the then branch and the else branch, respectively - stateBuilder.resultWithIf(cond, currState + 1, currState + thenBudget) + stateBuilder.resultWithIf(cond, thenStartState, elseStartState) - List((thenp, currState + 1, thenBudget), (elsep, currState + thenBudget, elseBudget)) foreach { - case (tree, state, branchBudget) => - val builder = builderForBranch(tree, state, currState + ifBudget, branchBudget, toRename) + List((thenp, thenStartState), (elsep, elseStartState)) foreach { + case (tree, state) => + val builder = builderForBranch(tree, state, afterIfState) asyncStates ++= builder.asyncStates - toRename ++= builder.toRename } - // create new state builder for state `currState + ifBudget` - currState = currState + ifBudget + currState = afterIfState stateBuilder = new builder.AsyncStateBuilder(currState, toRename) - case Match(scrutinee, cases) => - vprintln("transforming match expr: " + stat) + case Match(scrutinee, cases) if stat exists isAwait => checkForUnsupportedAwait(scrutinee) - val matchBudget: Int = remainingBudget / 2 - remainingBudget -= matchBudget //TODO test if budget > 0 - // state that we continue with after match: currState + matchBudget + val caseStates = cases.map(_ => stateAssigner.nextState()) + val afterMatchState = stateAssigner.nextState() - val perCaseBudget: Int = matchBudget / cases.size asyncStates += - // the two Int arguments are the start state of the first case and the per-case state budget, respectively - stateBuilder.resultWithMatch(scrutinee, cases, currState + 1, perCaseBudget) + stateBuilder.resultWithMatch(scrutinee, cases, caseStates) for ((cas, num) <- cases.zipWithIndex) { val (casStats, casExpr) = cas match { case CaseDef(_, _, Block(s, e)) => (s, e) case CaseDef(_, _, rhs) => (List(rhs), c.literalUnit.tree) } - val builder = new AsyncBlockBuilder(casStats, casExpr, currState + (num * perCaseBudget) + 1, currState + matchBudget, perCaseBudget, toRename) + val builder = new AsyncBlockBuilder(casStats, casExpr, caseStates(num), afterMatchState, toRename) asyncStates ++= builder.asyncStates - toRename ++= builder.toRename } - // create new state builder for state `currState + matchBudget` - currState = currState + matchBudget + currState = afterMatchState stateBuilder = new builder.AsyncStateBuilder(currState, toRename) case ClassDef(_, name, _, _) => @@ -382,7 +349,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy } // complete last state builder (representing the expressions after the last await) stateBuilder += expr - val lastState = stateBuilder.complete(endState).result + val lastState = stateBuilder.complete(endState).result() asyncStates += lastState def mkCombinedHandlerCases[T](): List[CaseDef] = { @@ -402,9 +369,106 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy } } + private val Boolean_ShortCircuits: Set[Symbol] = { + import definitions.BooleanClass + def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName) + val Boolean_&& = BooleanTermMember("&&") + val Boolean_|| = BooleanTermMember("||") + Set(Boolean_&&, Boolean_||) + } + + def isByName(fun: Tree): (Int => Boolean) = { + if (Boolean_ShortCircuits contains fun.symbol) i => true + else fun.tpe match { + case MethodType(params, _) => + val isByNameParams = params.map(_.asTerm.isByNameParam) + (i: Int) => isByNameParams.applyOrElse(i, (_: Int) => false) + case _ => Map() + } + } + + private def isAwait(fun: Tree) = { + fun.symbol == defn.Async_await + } + + /** + * Analyze the contents of an `async` block in order to: + * - Report unsupported `await` calls under nested templates, functions, by-name arguments. + * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based + * on whether or not they are accessed only from a single state. + */ + private[async] class AsyncAnalyzer extends Traverser { + private var chunkId = 0 + private def nextChunk() = chunkId += 1 + private var valDefChunkId = Map[Symbol, (ValDef, Int)]() + + val valDefsToLift: mutable.Set[ValDef] = collection.mutable.Set[ValDef]() + + override def traverse(tree: Tree) = { + tree match { + case cd: ClassDef => + val kind = if (cd.symbol.asClass.isTrait) "trait" else "class" + reportUnsupportedAwait(tree, s"nested $kind") + case md: ModuleDef => + reportUnsupportedAwait(tree, "nested object") + case _: Function => + reportUnsupportedAwait(tree, "nested anonymous function") + case If(cond, thenp, elsep) if tree exists isAwait => + traverseChunks(List(cond, thenp, elsep)) + case Match(selector, cases) if tree exists isAwait => + traverseChunks(selector :: cases) + case Apply(fun, args) if isAwait(fun) => + traverseTrees(args) + traverse(fun) + nextChunk() + case Apply(fun, args) => + val isInByName = isByName(fun) + for ((arg, index) <- args.zipWithIndex) { + if (!isInByName(index)) traverse(arg) + else reportUnsupportedAwait(arg, "by-name argument") + } + traverse(fun) + case vd: ValDef => + super.traverse(tree) + valDefChunkId += (vd.symbol ->(vd, chunkId)) + if (isAwait(vd.rhs)) valDefsToLift += vd + case as: Assign => + if (isAwait(as.rhs)) { + // TODO test the orElse case, try to remove the restriction. + val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block.")) + valDefsToLift += vd + } + super.traverse(tree) + case rt: RefTree => + valDefChunkId.get(rt.symbol) match { + case Some((vd, defChunkId)) if defChunkId != chunkId => + valDefsToLift += vd + case _ => + } + super.traverse(tree) + case _ => super.traverse(tree) + } + } + + private def traverseChunks(trees: List[Tree]) { + trees.foreach {t => traverse(t); nextChunk()} + } + + private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) { + val badAwaits = tree collect { + case rt: RefTree if isAwait(rt) => rt + } + badAwaits foreach { + tree => + c.error(tree.pos, s"await must not be used under a $whyUnsupported.") + } + } + } + + /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */ private def methodSym(apply: c.Expr[Any]): Symbol = { - val tree2: Tree = c.typeCheck(apply.tree) // TODO why is this needed? + val tree2: Tree = c.typeCheck(apply.tree) tree2.collect { case s: SymTree if s.symbol.isMethod => s.symbol }.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}")) @@ -421,10 +485,6 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy self.splice.apply(arg.splice) } - def mkInt_+(self: Expr[Int])(other: Expr[Int]) = reify { - self.splice + other.splice - } - def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify { self.splice == other.splice } @@ -445,5 +505,4 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy tpe.member(c.universe.newTermName("await")) } } - } diff --git a/src/main/scala/scala/async/StateAssigner.scala b/src/main/scala/scala/async/StateAssigner.scala new file mode 100644 index 0000000..4f6c5a0 --- /dev/null +++ b/src/main/scala/scala/async/StateAssigner.scala @@ -0,0 +1,10 @@ +package scala.async + +private[async] final class StateAssigner { + private var current = -1 + + def nextState(): Int = { + current += 1 + current + } +}
\ No newline at end of file diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala new file mode 100644 index 0000000..d36c277 --- /dev/null +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -0,0 +1,22 @@ +/** + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ +package scala.async + +import scala.reflect.macros.Context + +/** + * Utilities used in both `ExprBuilder` and `AnfTransform`. + */ +class TransformUtils[C <: Context](val c: C) { + import c.universe._ + + protected def defaultValue(tpe: Type): Literal = { + val defaultValue: Any = + if (tpe <:< definitions.BooleanTpe) false + else if (definitions.ScalaNumericValueClasses.exists(tpe <:< _.toType)) 0 + else if (tpe <:< definitions.AnyValTpe) 0 + else null + Literal(Constant(defaultValue)) + } +} diff --git a/src/test/scala/scala/async/TestUtils.scala b/src/test/scala/scala/async/TestUtils.scala index bac22a3..0ae78b8 100644 --- a/src/test/scala/scala/async/TestUtils.scala +++ b/src/test/scala/scala/async/TestUtils.scala @@ -50,9 +50,9 @@ trait TestUtils { m.mkToolBox(options = compileOptions) } - def expectError(errorSnippet: String, compileOptions: String = "")(code: String) { + def expectError(errorSnippet: String, compileOptions: String = "", baseCompileOptions: String = "-cp target/scala-2.10/classes")(code: String) { intercept[ToolBoxError] { - eval(code, compileOptions) + eval(code, compileOptions + " " + baseCompileOptions) }.getMessage mustContain errorSnippet } } diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index 1293bdf..1ed9be2 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -32,7 +32,6 @@ class TreeInterrogation { val varDefs = tree1.collect { case ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) => name } - // TODO no need to lift `y` as it is only accessed from a single state. - varDefs.map(_.decoded).toSet mustBe(Set("state$async", "onCompleteHandler$async", "x$1", "z$1", "y$1")) + varDefs.map(_.decoded).toSet mustBe(Set("state$async", "onCompleteHandler$async", "x$1", "z$1")) } } diff --git a/src/test/scala/scala/async/neg/NakedAwait.scala b/src/test/scala/scala/async/neg/NakedAwait.scala index db67f18..66bc947 100644 --- a/src/test/scala/scala/async/neg/NakedAwait.scala +++ b/src/test/scala/scala/async/neg/NakedAwait.scala @@ -16,4 +16,76 @@ class NakedAwait { """.stripMargin } } + + + @Test + def `await not allowed in by-name argument`() { + expectError("await must not be used under a by-name argument.") { + """ + | import _root_.scala.async.AsyncId._ + | def foo(a: Int)(b: => Int) = 0 + | async { foo(0)(await(0)) } + """.stripMargin + } + } + + @Test + def `await not allowed in boolean short circuit argument 1`() { + expectError("await must not be used under a by-name argument.") { + """ + | import _root_.scala.async.AsyncId._ + | async { true && await(false) } + """.stripMargin + } + } + + @Test + def `await not allowed in boolean short circuit argument 2`() { + expectError("await must not be used under a by-name argument.") { + """ + | import _root_.scala.async.AsyncId._ + | async { true || await(false) } + """.stripMargin + } + } + + @Test + def nestedObject() { + expectError("await must not be used under a nested object.") { + """ + | import _root_.scala.async.AsyncId._ + | async { object Nested { await(false) } } + """.stripMargin + } + } + + @Test + def nestedTrait() { + expectError("await must not be used under a nested trait.") { + """ + | import _root_.scala.async.AsyncId._ + | async { trait Nested { await(false) } } + """.stripMargin + } + } + + @Test + def nestedClass() { + expectError("await must not be used under a nested class.") { + """ + | import _root_.scala.async.AsyncId._ + | async { class Nested { await(false) } } + """.stripMargin + } + } + + @Test + def nestedFunction() { + expectError("await must not be used under a nested anonymous function.") { + """ + | import _root_.scala.async.AsyncId._ + | async { () => { await(false) } } + """.stripMargin + } + } } diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala new file mode 100644 index 0000000..0abb937 --- /dev/null +++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala @@ -0,0 +1,146 @@ +/** + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + +package scala.async +package run +package anf + +import language.{reflectiveCalls, postfixOps} +import scala.concurrent.{Future, ExecutionContext, future, Await} +import scala.concurrent.duration._ +import scala.async.Async.{async, await} +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 + + +class AnfTestClass { + + import ExecutionContext.Implicits.global + + def base(x: Int): Future[Int] = future { + x + 2 + } + + def m(y: Int): Future[Int] = async { + val f = base(y) + await(f) + } + + def m2(y: Int): Future[Int] = async { + val f = base(y) + val f2 = base(y + 1) + await(f) + await(f2) + } + + def m3(y: Int): Future[Int] = async { + val f = base(y) + var z = 0 + if (y > 0) { + z = await(f) + 2 + } else { + z = await(f) - 2 + } + z + } + + def m4(y: Int): Future[Int] = async { + val f = base(y) + val z = if (y > 0) { + await(f) + 2 + } else { + await(f) - 2 + } + z + 1 + } + + def futureUnitIfElse(y: Int): Future[Unit] = async { + val f = base(y) + if (y > 0) { + State.result = await(f) + 2 + } else { + State.result = await(f) - 2 + } + } +} + +object State { + @volatile var result: Int = 0 +} + +@RunWith(classOf[JUnit4]) +class AnfTransformSpec { + + @Test + def `simple ANF transform`() { + val o = new AnfTestClass + val fut = o.m(10) + val res = Await.result(fut, 2 seconds) + res mustBe (12) + } + + @Test + def `simple ANF transform 2`() { + val o = new AnfTestClass + val fut = o.m2(10) + val res = Await.result(fut, 2 seconds) + res mustBe (25) + } + + @Test + def `simple ANF transform 3`() { + val o = new AnfTestClass + val fut = o.m3(10) + val res = Await.result(fut, 2 seconds) + res mustBe (14) + } + + @Test + def `ANF transform of assigning the result of an if-else`() { + val o = new AnfTestClass + val fut = o.m4(10) + val res = Await.result(fut, 2 seconds) + res mustBe (15) + } + + @Test + def `Unit-typed if-else in tail position`() { + val o = new AnfTestClass + val fut = o.futureUnitIfElse(10) + Await.result(fut, 2 seconds) + State.result mustBe (14) + } + + @Test + def `inlining block produces duplicate definition`() { + import scala.async.AsyncId + + AsyncId.async { + val f = 12 + val x = AsyncId.await(f) + + { + val x = 42 + println(x) + } + + x + } + } + @Test + def `inlining block in tail position produces duplicate definition`() { + import scala.async.AsyncId + + AsyncId.async { + val f = 12 + val x = AsyncId.await(f) + + { + val x = 42 // TODO should we rename the symbols when we collapse them into the same scope? + x + } + } mustBe (42) + + } +} |