From 087d1e4e138eccf4b2d420298affb4289632bf73 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Thu, 22 Nov 2012 17:50:50 +0100 Subject: Support match as an expression. - corrects detection of await calls in the ANF transform. - Split AsyncAnalyzer into two parts. Unsupported await detection must happen prior to the async transform to prevent the ANF lifting out by-name arguments to vals and hence changing the semantics. --- src/main/scala/scala/async/AnfTransform.scala | 54 ++++-- src/main/scala/scala/async/Async.scala | 6 +- src/main/scala/scala/async/ExprBuilder.scala | 184 +++++++++------------ src/main/scala/scala/async/TransformUtils.scala | 98 +++++++++++ src/test/scala/scala/async/TreeInterrogation.scala | 4 +- src/test/scala/scala/async/neg/NakedAwait.scala | 3 +- .../scala/async/run/anf/AnfTransformSpec.scala | 74 ++++++--- src/test/scala/scala/async/run/block1/block1.scala | 1 - 8 files changed, 272 insertions(+), 152 deletions(-) diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index e1d7cd5..0921b55 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -3,6 +3,7 @@ package scala.async import scala.reflect.macros.Context class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { + import c.universe._ import AsyncUtils._ @@ -11,7 +12,7 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { val stats :+ expr = anf.transformToList(tree) expr match { - case Apply(fun, args) if fun.toString.startsWith("scala.async.Async.await") => + case Apply(fun, args) if isAwait(fun) => val liftedName = c.fresh("await$") stats :+ ValDef(NoMods, liftedName, TypeTree(), expr) :+ Ident(liftedName) @@ -27,17 +28,35 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { ValDef(Modifiers(Flag.MUTABLE), liftedName, TypeTree(expr.tpe), defaultValue(expr.tpe)) val thenWithAssign = thenp match { case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(liftedName), thenExpr)) - case _ => Assign(Ident(liftedName), thenp) + case _ => Assign(Ident(liftedName), thenp) } val elseWithAssign = elsep match { case Block(elseStats, elseExpr) => Block(elseStats, Assign(Ident(liftedName), elseExpr)) - case _ => Assign(Ident(liftedName), elsep) + case _ => Assign(Ident(liftedName), elsep) } val ifWithAssign = If(cond, thenWithAssign, elseWithAssign) stats :+ varDef :+ ifWithAssign :+ Ident(liftedName) } - case _ => + + case Match(scrut, cases) => + // if type of match is Unit don't introduce assignment, + // but add Unit value to bring it into form expected by async transform + if (expr.tpe =:= definitions.UnitTpe) { + stats :+ expr :+ Literal(Constant(())) + } + else { + val liftedName = c.fresh("matchres$") + val varDef = + ValDef(Modifiers(Flag.MUTABLE), liftedName, TypeTree(expr.tpe), defaultValue(expr.tpe)) + val casesWithAssign = cases map { + case CaseDef(pat, guard, Block(caseStats, caseExpr)) => CaseDef(pat, guard, Block(caseStats, Assign(Ident(liftedName), caseExpr))) + case CaseDef(pat, guard, tree) => CaseDef(pat, guard, Assign(Ident(liftedName), tree)) + } + val matchWithAssign = Match(scrut, casesWithAssign) + stats :+ varDef :+ matchWithAssign :+ Ident(liftedName) + } + case _ => stats :+ expr } } @@ -46,6 +65,10 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { case fst :: rest => transformToList(fst) ++ transformToList(rest) case Nil => Nil } + + def transformToBlock(tree: Tree): Block = transformToList(tree) match { + case stats :+ expr => Block(stats, expr) + } } object anf { @@ -59,7 +82,7 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { val argLists = args map inline.transformToList val allArgStats = argLists flatMap (_.init) val simpleArgs = argLists map (_.last) - funStats ++ allArgStats :+ Apply(simpleFun, simpleArgs) + funStats ++ allArgStats :+ Apply(simpleFun, simpleArgs).setSymbol(tree.symbol) case Block(stats, expr) => inline.transformToList(stats) ++ inline.transformToList(expr) @@ -74,13 +97,22 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { case If(cond, thenp, elsep) => val stats :+ expr = inline.transformToList(cond) - val thenStats :+ thenExpr = inline.transformToList(thenp) - val elseStats :+ elseExpr = inline.transformToList(elsep) + val thenBlock = inline.transformToBlock(thenp) + val elseBlock = inline.transformToBlock(elsep) stats :+ - c.typeCheck(If(expr, Block(thenStats, thenExpr), Block(elseStats, elseExpr))) + c.typeCheck(If(expr, thenBlock, elseBlock)) + + case Match(scrut, cases) => + val scrutStats :+ scrutExpr = inline.transformToList(scrut) + val caseDefs = cases map { + case CaseDef(pat, guard, body) => + val block = inline.transformToBlock(body) + CaseDef(pat, guard, block) + } + scrutStats :+ c.typeCheck(Match(scrutExpr, caseDefs)) //TODO - case Literal(_) | Ident(_) | This(_) | Match(_, _) | New(_) | Function(_, _) => List(tree) + case Literal(_) | Ident(_) | This(_) | New(_) | Function(_, _) => List(tree) case TypeApply(fun, targs) => val funStats :+ simpleFun = inline.transformToList(fun) @@ -94,8 +126,8 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { case ModuleDef(mods, name, impl) => List(tree) case _ => - c.error(tree.pos, "Internal error while compiling `async` block") - ??? + c.abort(tree.pos, s"Internal error while compiling `async` block: $tree") } } + } diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index bd766f2..645f3f7 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -71,6 +71,8 @@ abstract class AsyncBase { import builder.name import builder.futureSystemOps + builder.reportUnsupportedAwaits(body.tree) + // Transform to A-normal form: // - no await calls in qualifiers or arguments, // - if/match only used in statement position. @@ -84,9 +86,7 @@ abstract class AsyncBase { // states of our generated state machine, e.g. a value assigned before // an `await` and read afterwards. val renameMap: Map[Symbol, TermName] = { - val analyzer = new builder.AsyncAnalyzer - analyzer.traverse(anfTree) - analyzer.valDefsToLift.map { + builder.valDefsUsedInSubsequentStates(anfTree).map { vd => (vd.symbol, builder.name.fresh(vd.name)) }.toMap diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 7a9c98d..735db76 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -22,14 +22,14 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val def suffixedName(prefix: String) = newTermName(suffix(prefix)) - val state = suffixedName("state") - val result = suffixedName("result") - val resume = suffixedName("resume") + val state = suffixedName("state") + val result = suffixedName("result") + val resume = suffixedName("resume") val execContext = suffixedName("execContext") // TODO do we need to freshen any of these? - val x1 = newTermName("x$1") - val tr = newTermName("tr") + val x1 = newTermName("x$1") + val tr = newTermName("tr") val onCompleteHandler = suffixedName("onCompleteHandler") def fresh(name: TermName) = newTermName(c.fresh("" + name + "$")) @@ -60,7 +60,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val class AsyncState(stats: List[c.Tree], val state: Int, val nextState: Int) { val body: c.Tree = stats match { case stat :: Nil => stat - case _ => Block(stats: _*) + case _ => Block(stats: _*) } val varDefs: List[(TermName, Type)] = Nil @@ -78,7 +78,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val ) val updateState = mkStateTree(nextState) Some(mkHandlerCase(state, List(tryGetTree, updateState, mkResumeApply))) - case _ => + case _ => None } } @@ -106,7 +106,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val abstract class AsyncStateWithAwait(stats: List[c.Tree], state: Int, nextState: Int) extends AsyncState(stats, state, nextState) { - val awaitable: c.Tree + val awaitable : c.Tree val resultName: TermName val resultType: Type @@ -154,7 +154,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val override def transform(tree: Tree) = tree match { case Ident(_) if nameMap.keySet contains tree.symbol => Ident(nameMap(tree.symbol)) - case _ => + case _ => super.transform(tree) } } @@ -178,7 +178,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val } else new AsyncStateWithAwait(stats.toList, state, nextState) { - val awaitable = self.awaitable + val awaitable = self.awaitable val resultName = self.resultName val resultType = self.resultType override val varDefs = self.varDefs.toList @@ -263,18 +263,18 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename) // current state builder - private var currState = startState + private var currState = startState /* TODO Fall back to CPS plug-in if tree contains an `await` call. */ def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists { case Apply(fun, _) if isAwait(fun) => true - case _ => false + case _ => false }) c.abort(tree.pos, "await unsupported in this position") //throw new FallbackToCpsException def builderForBranch(tree: c.Tree, state: Int, nextState: Int): AsyncBlockBuilder = { val (branchStats, branchExpr) = tree match { case Block(s, e) => (s, e) - case _ => (List(tree), c.literalUnit.tree) + case _ => (List(tree), c.literalUnit.tree) } new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, toRename) } @@ -326,7 +326,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val for ((cas, num) <- cases.zipWithIndex) { val (casStats, casExpr) = cas match { case CaseDef(_, _, Block(s, e)) => (s, e) - case CaseDef(_, _, rhs) => (List(rhs), c.literalUnit.tree) + case CaseDef(_, _, rhs) => (List(rhs), c.literalUnit.tree) } val builder = new AsyncBlockBuilder(casStats, casExpr, caseStates(num), afterMatchState, toRename) asyncStates ++= builder.asyncStates @@ -362,147 +362,111 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val asyncStates.toList match { case s :: Nil => List(caseForLastState) - case _ => + case _ => val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState() initCases :+ caseForLastState } } } - private val Boolean_ShortCircuits: Set[Symbol] = { - import definitions.BooleanClass - def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName) - val Boolean_&& = BooleanTermMember("&&") - val Boolean_|| = BooleanTermMember("||") - Set(Boolean_&&, Boolean_||) + /** + * Analyze the contents of an `async` block in order to: + * - Report unsupported `await` calls under nested templates, functions, by-name arguments. + * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based + * on whether or not they are accessed only from a single state. + */ + def reportUnsupportedAwaits(tree: Tree) { + new UnsupportedAwaitAnalyzer().traverse(tree) } - def isByName(fun: Tree): (Int => Boolean) = { - if (Boolean_ShortCircuits contains fun.symbol) i => true - else fun.tpe match { - case MethodType(params, _) => - val isByNameParams = params.map(_.asTerm.isByNameParam) - (i: Int) => isByNameParams.applyOrElse(i, (_: Int) => false) - case _ => Map() + private class UnsupportedAwaitAnalyzer extends super.AsyncTraverser { + override def nestedClass(classDef: ClassDef) { + val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class" + reportUnsupportedAwait(classDef, s"nested $kind") } - } - private def isAwait(fun: Tree) = { - fun.symbol == defn.Async_await + override def nestedModule(module: ModuleDef) { + reportUnsupportedAwait(module, "nested object") + } + + override def byNameArgument(arg: Tree) { + reportUnsupportedAwait(arg, "by-name argument") + } + + override def function(function: Function) { + reportUnsupportedAwait(function, "nested function") + } + + private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) { + val badAwaits = tree collect { + case rt: RefTree if isAwait(rt) => rt + } + badAwaits foreach { + tree => + c.error(tree.pos, s"await must not be used under a $whyUnsupported.") + } + } } /** * Analyze the contents of an `async` block in order to: - * - Report unsupported `await` calls under nested templates, functions, by-name arguments. - * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based - * on whether or not they are accessed only from a single state. + * - Report unsupported `await` calls under nested templates, functions, by-name arguments. + * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based + * on whether or not they are accessed only from a single state. */ - private[async] class AsyncAnalyzer extends Traverser { + def valDefsUsedInSubsequentStates(tree: Tree): List[ValDef] = { + val analyzer = new AsyncDefinitionUseAnalyzer + analyzer.traverse(tree) + analyzer.valDefsToLift.toList + } + + private class AsyncDefinitionUseAnalyzer extends super.AsyncTraverser { private var chunkId = 0 + private def nextChunk() = chunkId += 1 + private var valDefChunkId = Map[Symbol, (ValDef, Int)]() val valDefsToLift: mutable.Set[ValDef] = collection.mutable.Set[ValDef]() override def traverse(tree: Tree) = { tree match { - case cd: ClassDef => - val kind = if (cd.symbol.asClass.isTrait) "trait" else "class" - reportUnsupportedAwait(tree, s"nested $kind") - case md: ModuleDef => - reportUnsupportedAwait(tree, "nested object") - case _: Function => - reportUnsupportedAwait(tree, "nested anonymous function") case If(cond, thenp, elsep) if tree exists isAwait => traverseChunks(List(cond, thenp, elsep)) case Match(selector, cases) if tree exists isAwait => traverseChunks(selector :: cases) - case Apply(fun, args) if isAwait(fun) => - traverseTrees(args) - traverse(fun) + case Apply(fun, args) if isAwait(fun) => + super.traverse(tree) nextChunk() - case Apply(fun, args) => - val isInByName = isByName(fun) - for ((arg, index) <- args.zipWithIndex) { - if (!isInByName(index)) traverse(arg) - else reportUnsupportedAwait(arg, "by-name argument") - } - traverse(fun) - case vd: ValDef => + case vd: ValDef => super.traverse(tree) valDefChunkId += (vd.symbol ->(vd, chunkId)) if (isAwait(vd.rhs)) valDefsToLift += vd - case as: Assign => + case as: Assign => if (isAwait(as.rhs)) { // TODO test the orElse case, try to remove the restriction. - val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block.")) - valDefsToLift += vd + if (as.symbol != null) { + // synthetic added by the ANF transfor + val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block. " + as.symbol)) + valDefsToLift += vd + } } super.traverse(tree) - case rt: RefTree => + case rt: RefTree => valDefChunkId.get(rt.symbol) match { case Some((vd, defChunkId)) if defChunkId != chunkId => valDefsToLift += vd - case _ => + case _ => } super.traverse(tree) - case _ => super.traverse(tree) + case _ => super.traverse(tree) } } private def traverseChunks(trees: List[Tree]) { - trees.foreach {t => traverse(t); nextChunk()} - } - - private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) { - val badAwaits = tree collect { - case rt: RefTree if isAwait(rt) => rt + trees.foreach { + t => traverse(t); nextChunk() } - badAwaits foreach { - tree => - c.error(tree.pos, s"await must not be used under a $whyUnsupported.") - } - } - } - - - /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */ - private def methodSym(apply: c.Expr[Any]): Symbol = { - val tree2: Tree = c.typeCheck(apply.tree) - tree2.collect { - case s: SymTree if s.symbol.isMethod => s.symbol - }.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}")) - } - - private[async] object defn { - def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = { - c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree))) - } - - def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify(self.splice.contains(elem.splice)) - - def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify { - self.splice.apply(arg.splice) - } - - def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify { - self.splice == other.splice - } - - def mkTry_get[A](self: Expr[util.Try[A]]) = reify { - self.splice.get - } - - val Try_get = methodSym(reify((null: scala.util.Try[Any]).get)) - - val TryClass = c.mirror.staticClass("scala.util.Try") - val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe)) - val NonFatalClass = c.mirror.staticModule("scala.util.control.NonFatal") - - val Async_await = { - val asyncMod = c.mirror.staticModule("scala.async.Async") - val tpe = asyncMod.moduleClass.asType.toType - tpe.member(c.universe.newTermName("await")) } } } diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index d36c277..b8b21a3 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -9,6 +9,7 @@ import scala.reflect.macros.Context * Utilities used in both `ExprBuilder` and `AnfTransform`. */ class TransformUtils[C <: Context](val c: C) { + import c.universe._ protected def defaultValue(tpe: Type): Literal = { @@ -19,4 +20,101 @@ class TransformUtils[C <: Context](val c: C) { else null Literal(Constant(defaultValue)) } + + protected def isAwait(fun: Tree) = + fun.symbol == defn.Async_await + + /** Descends into the regions of the tree that are subject to the + * translation to a state machine by `async`. When a nested template, + * function, or by-name argument is encountered, the descend stops, + * and `nestedClass` etc are invoked. + */ + trait AsyncTraverser extends Traverser { + def nestedClass(classDef: ClassDef) { + } + + def nestedModule(module: ModuleDef) { + } + + def byNameArgument(arg: Tree) { + } + + def function(function: Function) { + } + + override def traverse(tree: Tree) { + tree match { + case cd: ClassDef => nestedClass(cd) + case md: ModuleDef => nestedModule(md) + case fun: Function => function(fun) + case Apply(fun, args) => + val isInByName = isByName(fun) + for ((arg, index) <- args.zipWithIndex) { + if (!isInByName(index)) traverse(arg) + else byNameArgument(arg) + } + traverse(fun) + case _ => super.traverse(tree) + } + } + } + + private lazy val Boolean_ShortCircuits: Set[Symbol] = { + import definitions.BooleanClass + def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName) + val Boolean_&& = BooleanTermMember("&&") + val Boolean_|| = BooleanTermMember("||") + Set(Boolean_&&, Boolean_||) + } + + protected def isByName(fun: Tree): (Int => Boolean) = { + if (Boolean_ShortCircuits contains fun.symbol) i => true + else fun.tpe match { + case MethodType(params, _) => + val isByNameParams = params.map(_.asTerm.isByNameParam) + (i: Int) => isByNameParams.applyOrElse(i, (_: Int) => false) + case _ => Map() + } + } + + private[async] object defn { + def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = { + c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree))) + } + + def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify(self.splice.contains(elem.splice)) + + def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify { + self.splice.apply(arg.splice) + } + + def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify { + self.splice == other.splice + } + + def mkTry_get[A](self: Expr[util.Try[A]]) = reify { + self.splice.get + } + + val Try_get = methodSym(reify((null: scala.util.Try[Any]).get)) + + val TryClass = c.mirror.staticClass("scala.util.Try") + val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe)) + val NonFatalClass = c.mirror.staticModule("scala.util.control.NonFatal") + + val Async_await = { + val asyncMod = c.mirror.staticModule("scala.async.Async") + val tpe = asyncMod.moduleClass.asType.toType + tpe.member(c.universe.newTermName("await")) + } + } + + + /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */ + private def methodSym(apply: c.Expr[Any]): Symbol = { + val tree2: Tree = c.typeCheck(apply.tree) + tree2.collect { + case s: SymTree if s.symbol.isMethod => s.symbol + }.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}")) + } } diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index 1ed9be2..02f4b43 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -21,7 +21,7 @@ class TreeInterrogation { | }""".stripMargin) val tree1 = tb.typeCheck(tree) - // println(cm.universe.showRaw(tree1)) + //println(cm.universe.show(tree1)) import tb.mirror.universe._ val functions = tree1.collect { @@ -32,6 +32,6 @@ class TreeInterrogation { val varDefs = tree1.collect { case ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) => name } - varDefs.map(_.decoded).toSet mustBe(Set("state$async", "onCompleteHandler$async", "x$1", "z$1")) + varDefs.map(_.decoded).toSet mustBe(Set("state$async", "onCompleteHandler$async", "await$1$1", "await$2$1")) } } diff --git a/src/test/scala/scala/async/neg/NakedAwait.scala b/src/test/scala/scala/async/neg/NakedAwait.scala index 66bc947..8b85977 100644 --- a/src/test/scala/scala/async/neg/NakedAwait.scala +++ b/src/test/scala/scala/async/neg/NakedAwait.scala @@ -17,7 +17,6 @@ class NakedAwait { } } - @Test def `await not allowed in by-name argument`() { expectError("await must not be used under a by-name argument.") { @@ -81,7 +80,7 @@ class NakedAwait { @Test def nestedFunction() { - expectError("await must not be used under a nested anonymous function.") { + expectError("await must not be used under a nested function.") { """ | import _root_.scala.async.AsyncId._ | async { () => { await(false) } } diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala index 0abb937..1d6e09a 100644 --- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala +++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala @@ -112,35 +112,63 @@ class AnfTransformSpec { State.result mustBe (14) } - @Test - def `inlining block produces duplicate definition`() { - import scala.async.AsyncId - - AsyncId.async { - val f = 12 - val x = AsyncId.await(f) + // TODO JZ +// @Test +// def `inlining block produces duplicate definition`() { +// import scala.async.AsyncId +// +// AsyncId.async { +// val f = 12 +// val x = AsyncId.await(f) +// +// { +// val x = 42 +// println(x) +// } +// +// x +// } +// } +// @Test +// def `inlining block in tail position produces duplicate definition`() { +// import scala.async.AsyncId +// +// AsyncId.async { +// val f = 12 +// val x = AsyncId.await(f) +// +// { +// val x = 42 // TODO should we rename the symbols when we collapse them into the same scope? +// x +// } +// } mustBe (42) +// } - { - val x = 42 - println(x) + @Test + def `match as expression 1`() { + import ExecutionContext.Implicits.global + val result = AsyncId.async { + val x = "" match { + case _ => AsyncId.await(1) + 1 } - x } + result mustBe (2) } - @Test - def `inlining block in tail position produces duplicate definition`() { - import scala.async.AsyncId - AsyncId.async { - val f = 12 - val x = AsyncId.await(f) - - { - val x = 42 // TODO should we rename the symbols when we collapse them into the same scope? - x + @Test + def `match as expression 2`() { + import ExecutionContext.Implicits.global + val result = AsyncId.async { + val x = "" match { + case "" if false => AsyncId.await(1) + 1 + case _ => 2 + AsyncId.await(1) } - } mustBe (42) - + val y = x + "" match { + case _ => AsyncId.await(y) + 100 + } + } + result mustBe (103) } } diff --git a/src/test/scala/scala/async/run/block1/block1.scala b/src/test/scala/scala/async/run/block1/block1.scala index a449805..0853498 100644 --- a/src/test/scala/scala/async/run/block1/block1.scala +++ b/src/test/scala/scala/async/run/block1/block1.scala @@ -27,7 +27,6 @@ class Test1Class { val f1 = m1(y) val f2 = m1(y + 2) val x1 = await(f1) - println("between two awaits") val x2 = await(f2) x1 + x2 } -- cgit v1.2.3 From 7d3b5f80507accd43ab7f7e70ef801d66540291d Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Thu, 22 Nov 2012 17:52:24 +0100 Subject: Comment. --- src/main/scala/scala/async/AnfTransform.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index 0921b55..d06fb54 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -78,6 +78,8 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { stats :+ Select(expr, sel) case Apply(fun, args) => + // we an assume that no await call appears in a by-name argument position, + // this has already been checked. val funStats :+ simpleFun = inline.transformToList(fun) val argLists = args map inline.transformToList val allArgStats = argLists flatMap (_.init) -- cgit v1.2.3 From 93520f30d77af10c0b936da3f658ec644c7ecd4b Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Thu, 22 Nov 2012 17:54:50 +0100 Subject: Lookup await symbol in AsyncBase. --- src/main/scala/scala/async/TransformUtils.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index b8b21a3..103c8d2 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -103,9 +103,9 @@ class TransformUtils[C <: Context](val c: C) { val NonFatalClass = c.mirror.staticModule("scala.util.control.NonFatal") val Async_await = { - val asyncMod = c.mirror.staticModule("scala.async.Async") - val tpe = asyncMod.moduleClass.asType.toType - tpe.member(c.universe.newTermName("await")) + val asyncMod = c.mirror.staticClass("scala.async.AsyncBase") + val tpe = asyncMod.asType.toType + tpe.member(c.universe.newTermName("await")).ensuring(_ != NoSymbol) } } -- cgit v1.2.3 From d6ce00d65ade8c31b61091d65fe21ad480c6b20c Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Thu, 22 Nov 2012 18:01:49 +0100 Subject: Refactor the analyzers to a seprarate file. --- src/main/scala/scala/async/AnfTransform.scala | 1 + src/main/scala/scala/async/Async.scala | 5 +- src/main/scala/scala/async/AsyncAnalysis.scala | 110 +++++++++++++++++++++++++ src/main/scala/scala/async/ExprBuilder.scala | 101 ----------------------- 4 files changed, 114 insertions(+), 103 deletions(-) create mode 100644 src/main/scala/scala/async/AsyncAnalysis.scala diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index d06fb54..0146210 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -1,3 +1,4 @@ + package scala.async import scala.reflect.macros.Context diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index 645f3f7..546445a 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -66,12 +66,13 @@ abstract class AsyncBase { import Flag._ val builder = new ExprBuilder[c.type, futureSystem.type](c, self.futureSystem) + val anaylzer = new AsyncAnalysis[c.type](c) import builder.defn._ import builder.name import builder.futureSystemOps - builder.reportUnsupportedAwaits(body.tree) + anaylzer.reportUnsupportedAwaits(body.tree) // Transform to A-normal form: // - no await calls in qualifiers or arguments, @@ -86,7 +87,7 @@ abstract class AsyncBase { // states of our generated state machine, e.g. a value assigned before // an `await` and read afterwards. val renameMap: Map[Symbol, TermName] = { - builder.valDefsUsedInSubsequentStates(anfTree).map { + anaylzer.valDefsUsedInSubsequentStates(anfTree).map { vd => (vd.symbol, builder.name.fresh(vd.name)) }.toMap diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala new file mode 100644 index 0000000..1b00620 --- /dev/null +++ b/src/main/scala/scala/async/AsyncAnalysis.scala @@ -0,0 +1,110 @@ +package scala.async + +import scala.reflect.macros.Context +import collection.mutable + +private[async] final class AsyncAnalysis[C <: Context](override val c: C) extends TransformUtils(c) { + import c.universe._ + + /** + * Analyze the contents of an `async` block in order to: + * - Report unsupported `await` calls under nested templates, functions, by-name arguments. + * + * Must be called on the original tree, not on the ANF transformed tree. + */ + def reportUnsupportedAwaits(tree: Tree) { + new UnsupportedAwaitAnalyzer().traverse(tree) + } + + /** + * Analyze the contents of an `async` block in order to: + * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based + * on whether or not they are accessed only from a single state. + * + * Must be called on the ANF transformed tree. + */ + def valDefsUsedInSubsequentStates(tree: Tree): List[ValDef] = { + val analyzer = new AsyncDefinitionUseAnalyzer + analyzer.traverse(tree) + analyzer.valDefsToLift.toList + } + + private class UnsupportedAwaitAnalyzer extends super.AsyncTraverser { + override def nestedClass(classDef: ClassDef) { + val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class" + reportUnsupportedAwait(classDef, s"nested $kind") + } + + override def nestedModule(module: ModuleDef) { + reportUnsupportedAwait(module, "nested object") + } + + override def byNameArgument(arg: Tree) { + reportUnsupportedAwait(arg, "by-name argument") + } + + override def function(function: Function) { + reportUnsupportedAwait(function, "nested function") + } + + private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) { + val badAwaits = tree collect { + case rt: RefTree if isAwait(rt) => rt + } + badAwaits foreach { + tree => + c.error(tree.pos, s"await must not be used under a $whyUnsupported.") + } + } + } + + private class AsyncDefinitionUseAnalyzer extends super.AsyncTraverser { + private var chunkId = 0 + + private def nextChunk() = chunkId += 1 + + private var valDefChunkId = Map[Symbol, (ValDef, Int)]() + + val valDefsToLift: mutable.Set[ValDef] = collection.mutable.Set[ValDef]() + + override def traverse(tree: Tree) = { + tree match { + case If(cond, thenp, elsep) if tree exists isAwait => + traverseChunks(List(cond, thenp, elsep)) + case Match(selector, cases) if tree exists isAwait => + traverseChunks(selector :: cases) + case Apply(fun, args) if isAwait(fun) => + super.traverse(tree) + nextChunk() + case vd: ValDef => + super.traverse(tree) + valDefChunkId += (vd.symbol ->(vd, chunkId)) + if (isAwait(vd.rhs)) valDefsToLift += vd + case as: Assign => + if (isAwait(as.rhs)) { + // TODO test the orElse case, try to remove the restriction. + if (as.symbol != null) { + // synthetic added by the ANF transfor + val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block. " + as.symbol)) + valDefsToLift += vd + } + } + super.traverse(tree) + case rt: RefTree => + valDefChunkId.get(rt.symbol) match { + case Some((vd, defChunkId)) if defChunkId != chunkId => + valDefsToLift += vd + case _ => + } + super.traverse(tree) + case _ => super.traverse(tree) + } + } + + private def traverseChunks(trees: List[Tree]) { + trees.foreach { + t => traverse(t); nextChunk() + } + } + } +} diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 735db76..573af16 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -368,105 +368,4 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val } } } - - /** - * Analyze the contents of an `async` block in order to: - * - Report unsupported `await` calls under nested templates, functions, by-name arguments. - * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based - * on whether or not they are accessed only from a single state. - */ - def reportUnsupportedAwaits(tree: Tree) { - new UnsupportedAwaitAnalyzer().traverse(tree) - } - - private class UnsupportedAwaitAnalyzer extends super.AsyncTraverser { - override def nestedClass(classDef: ClassDef) { - val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class" - reportUnsupportedAwait(classDef, s"nested $kind") - } - - override def nestedModule(module: ModuleDef) { - reportUnsupportedAwait(module, "nested object") - } - - override def byNameArgument(arg: Tree) { - reportUnsupportedAwait(arg, "by-name argument") - } - - override def function(function: Function) { - reportUnsupportedAwait(function, "nested function") - } - - private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) { - val badAwaits = tree collect { - case rt: RefTree if isAwait(rt) => rt - } - badAwaits foreach { - tree => - c.error(tree.pos, s"await must not be used under a $whyUnsupported.") - } - } - } - - /** - * Analyze the contents of an `async` block in order to: - * - Report unsupported `await` calls under nested templates, functions, by-name arguments. - * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based - * on whether or not they are accessed only from a single state. - */ - def valDefsUsedInSubsequentStates(tree: Tree): List[ValDef] = { - val analyzer = new AsyncDefinitionUseAnalyzer - analyzer.traverse(tree) - analyzer.valDefsToLift.toList - } - - private class AsyncDefinitionUseAnalyzer extends super.AsyncTraverser { - private var chunkId = 0 - - private def nextChunk() = chunkId += 1 - - private var valDefChunkId = Map[Symbol, (ValDef, Int)]() - - val valDefsToLift: mutable.Set[ValDef] = collection.mutable.Set[ValDef]() - - override def traverse(tree: Tree) = { - tree match { - case If(cond, thenp, elsep) if tree exists isAwait => - traverseChunks(List(cond, thenp, elsep)) - case Match(selector, cases) if tree exists isAwait => - traverseChunks(selector :: cases) - case Apply(fun, args) if isAwait(fun) => - super.traverse(tree) - nextChunk() - case vd: ValDef => - super.traverse(tree) - valDefChunkId += (vd.symbol ->(vd, chunkId)) - if (isAwait(vd.rhs)) valDefsToLift += vd - case as: Assign => - if (isAwait(as.rhs)) { - // TODO test the orElse case, try to remove the restriction. - if (as.symbol != null) { - // synthetic added by the ANF transfor - val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block. " + as.symbol)) - valDefsToLift += vd - } - } - super.traverse(tree) - case rt: RefTree => - valDefChunkId.get(rt.symbol) match { - case Some((vd, defChunkId)) if defChunkId != chunkId => - valDefsToLift += vd - case _ => - } - super.traverse(tree) - case _ => super.traverse(tree) - } - } - - private def traverseChunks(trees: List[Tree]) { - trees.foreach { - t => traverse(t); nextChunk() - } - } - } } -- cgit v1.2.3 From b746b7fe9f8ccbc2d37182d4be0dc24597223331 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Thu, 22 Nov 2012 19:10:36 +0100 Subject: Make the ANF transform more selective. --- src/main/scala/scala/async/AnfTransform.scala | 109 ++++++++++++-------------- 1 file changed, 52 insertions(+), 57 deletions(-) diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index 0146210..346da57 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -73,63 +73,58 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { } object anf { - def transformToList(tree: Tree): List[Tree] = tree match { - case Select(qual, sel) => - val stats :+ expr = inline.transformToList(qual) - stats :+ Select(expr, sel) - - case Apply(fun, args) => - // we an assume that no await call appears in a by-name argument position, - // this has already been checked. - val funStats :+ simpleFun = inline.transformToList(fun) - val argLists = args map inline.transformToList - val allArgStats = argLists flatMap (_.init) - val simpleArgs = argLists map (_.last) - funStats ++ allArgStats :+ Apply(simpleFun, simpleArgs).setSymbol(tree.symbol) - - case Block(stats, expr) => - inline.transformToList(stats) ++ inline.transformToList(expr) - - case ValDef(mods, name, tpt, rhs) => - val stats :+ expr = inline.transformToList(rhs) - stats :+ ValDef(mods, name, tpt, expr).setSymbol(tree.symbol) - - case Assign(name, rhs) => - val stats :+ expr = inline.transformToList(rhs) - stats :+ Assign(name, expr) - - case If(cond, thenp, elsep) => - val stats :+ expr = inline.transformToList(cond) - val thenBlock = inline.transformToBlock(thenp) - val elseBlock = inline.transformToBlock(elsep) - stats :+ - c.typeCheck(If(expr, thenBlock, elseBlock)) - - case Match(scrut, cases) => - val scrutStats :+ scrutExpr = inline.transformToList(scrut) - val caseDefs = cases map { - case CaseDef(pat, guard, body) => - val block = inline.transformToBlock(body) - CaseDef(pat, guard, block) - } - scrutStats :+ c.typeCheck(Match(scrutExpr, caseDefs)) - - //TODO - case Literal(_) | Ident(_) | This(_) | New(_) | Function(_, _) => List(tree) - - case TypeApply(fun, targs) => - val funStats :+ simpleFun = inline.transformToList(fun) - funStats :+ TypeApply(simpleFun, targs) - - //TODO - case DefDef(mods, name, tparams, vparamss, tpt, rhs) => List(tree) - - case ClassDef(mods, name, tparams, impl) => List(tree) - - case ModuleDef(mods, name, impl) => List(tree) - - case _ => - c.abort(tree.pos, s"Internal error while compiling `async` block: $tree") + def transformToList(tree: Tree): List[Tree] = { + def containsAwait = tree exists isAwait + tree match { + case Select(qual, sel) if containsAwait => + val stats :+ expr = inline.transformToList(qual) + stats :+ Select(expr, sel) + + case Apply(fun, args) if containsAwait => + // we an assume that no await call appears in a by-name argument position, + // this has already been checked. + + val funStats :+ simpleFun = inline.transformToList(fun) + val argLists = args map inline.transformToList + val allArgStats = argLists flatMap (_.init) + val simpleArgs = argLists map (_.last) + funStats ++ allArgStats :+ Apply(simpleFun, simpleArgs).setSymbol(tree.symbol) + + case Block(stats, expr) => + inline.transformToList(stats) ++ inline.transformToList(expr) + + case ValDef(mods, name, tpt, rhs) if containsAwait => + if (rhs exists isAwait) { + val stats :+ expr = inline.transformToList(rhs) + stats :+ ValDef(mods, name, tpt, expr).setSymbol(tree.symbol) + } else List(tree) + case Assign(name, rhs) if containsAwait => + val stats :+ expr = inline.transformToList(rhs) + stats :+ Assign(name, expr) + + case If(cond, thenp, elsep) if containsAwait => + val stats :+ expr = inline.transformToList(cond) + val thenBlock = inline.transformToBlock(thenp) + val elseBlock = inline.transformToBlock(elsep) + stats :+ + c.typeCheck(If(expr, thenBlock, elseBlock)) + + case Match(scrut, cases) if containsAwait => + val scrutStats :+ scrutExpr = inline.transformToList(scrut) + val caseDefs = cases map { + case CaseDef(pat, guard, body) => + val block = inline.transformToBlock(body) + CaseDef(pat, guard, block) + } + scrutStats :+ c.typeCheck(Match(scrutExpr, caseDefs)) + + case TypeApply(fun, targs) if containsAwait => + val funStats :+ simpleFun = inline.transformToList(fun) + funStats :+ TypeApply(simpleFun, targs).setSymbol(tree.symbol) + + case _ => + List(tree) + } } } -- cgit v1.2.3 From 073db882c3c562a845b92a84f9592a6531117190 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Thu, 22 Nov 2012 19:13:14 +0100 Subject: Add TODO comment. --- src/main/scala/scala/async/AnfTransform.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index 346da57..4938cd1 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -90,7 +90,7 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { val simpleArgs = argLists map (_.last) funStats ++ allArgStats :+ Apply(simpleFun, simpleArgs).setSymbol(tree.symbol) - case Block(stats, expr) => + case Block(stats, expr) => // TODO figure out why adding a guard `if containsAwait` breaks LocalClasses0Spec. inline.transformToList(stats) ++ inline.transformToList(expr) case ValDef(mods, name, tpt, rhs) if containsAwait => -- cgit v1.2.3 From 8ff80d52047360f3236fcbc8e7849d388c4aa744 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Thu, 22 Nov 2012 21:09:18 +0100 Subject: Minor refactoring in ANF transform. --- src/main/scala/scala/async/AnfTransform.scala | 58 +++++++++++++-------------- 1 file changed, 27 insertions(+), 31 deletions(-) diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index 4938cd1..24f37e7 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -6,38 +6,28 @@ import scala.reflect.macros.Context class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { import c.universe._ - import AsyncUtils._ object inline { def transformToList(tree: Tree): List[Tree] = { val stats :+ expr = anf.transformToList(tree) expr match { - case Apply(fun, args) if isAwait(fun) => - val liftedName = c.fresh("await$") - stats :+ ValDef(NoMods, liftedName, TypeTree(), expr) :+ Ident(liftedName) + val valDef = defineVal("await", expr) + stats :+ valDef :+ Ident(valDef.name) case If(cond, thenp, elsep) => // if type of if-else is Unit don't introduce assignment, // but add Unit value to bring it into form expected by async transform if (expr.tpe =:= definitions.UnitTpe) { stats :+ expr :+ Literal(Constant(())) - } - else { - val liftedName = c.fresh("ifres$") - val varDef = - ValDef(Modifiers(Flag.MUTABLE), liftedName, TypeTree(expr.tpe), defaultValue(expr.tpe)) - val thenWithAssign = thenp match { - case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(liftedName), thenExpr)) - case _ => Assign(Ident(liftedName), thenp) - } - val elseWithAssign = elsep match { - case Block(elseStats, elseExpr) => Block(elseStats, Assign(Ident(liftedName), elseExpr)) - case _ => Assign(Ident(liftedName), elsep) + } else { + val varDef = defineVar("ifres", expr.tpe) + def branchWithAssign(orig: Tree) = orig match { + case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.name), thenExpr)) + case _ => Assign(Ident(varDef.name), orig) } - val ifWithAssign = - If(cond, thenWithAssign, elseWithAssign) - stats :+ varDef :+ ifWithAssign :+ Ident(liftedName) + val ifWithAssign = If(cond, branchWithAssign(thenp), branchWithAssign(elsep)) + stats :+ varDef :+ ifWithAssign :+ Ident(varDef.name) } case Match(scrut, cases) => @@ -47,29 +37,35 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { stats :+ expr :+ Literal(Constant(())) } else { - val liftedName = c.fresh("matchres$") - val varDef = - ValDef(Modifiers(Flag.MUTABLE), liftedName, TypeTree(expr.tpe), defaultValue(expr.tpe)) + val varDef = defineVar("matchres", expr.tpe) val casesWithAssign = cases map { - case CaseDef(pat, guard, Block(caseStats, caseExpr)) => CaseDef(pat, guard, Block(caseStats, Assign(Ident(liftedName), caseExpr))) - case CaseDef(pat, guard, tree) => CaseDef(pat, guard, Assign(Ident(liftedName), tree)) + case CaseDef(pat, guard, Block(caseStats, caseExpr)) => CaseDef(pat, guard, Block(caseStats, Assign(Ident(varDef.name), caseExpr))) + case CaseDef(pat, guard, body) => CaseDef(pat, guard, Assign(Ident(varDef.name), body)) } val matchWithAssign = Match(scrut, casesWithAssign) - stats :+ varDef :+ matchWithAssign :+ Ident(liftedName) + stats :+ varDef :+ matchWithAssign :+ Ident(varDef.name) } - case _ => + case _ => stats :+ expr } } def transformToList(trees: List[Tree]): List[Tree] = trees match { case fst :: rest => transformToList(fst) ++ transformToList(rest) - case Nil => Nil + case Nil => Nil } def transformToBlock(tree: Tree): Block = transformToList(tree) match { case stats :+ expr => Block(stats, expr) } + + def liftedName(prefix: String) = c.fresh(prefix + "$") + + private def defineVar(prefix: String, tp: Type): ValDef = + ValDef(Modifiers(Flag.MUTABLE), liftedName(prefix), TypeTree(tp), defaultValue(tp)) + + private def defineVal(prefix: String, lhs: Tree): ValDef = + ValDef(NoMods, liftedName(prefix), TypeTree(), lhs) } object anf { @@ -78,7 +74,7 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { tree match { case Select(qual, sel) if containsAwait => val stats :+ expr = inline.transformToList(qual) - stats :+ Select(expr, sel) + stats :+ Select(expr, sel).setSymbol(tree.symbol) case Apply(fun, args) if containsAwait => // we an assume that no await call appears in a by-name argument position, @@ -91,16 +87,16 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { funStats ++ allArgStats :+ Apply(simpleFun, simpleArgs).setSymbol(tree.symbol) case Block(stats, expr) => // TODO figure out why adding a guard `if containsAwait` breaks LocalClasses0Spec. - inline.transformToList(stats) ++ inline.transformToList(expr) + inline.transformToList(stats :+ expr) case ValDef(mods, name, tpt, rhs) if containsAwait => if (rhs exists isAwait) { val stats :+ expr = inline.transformToList(rhs) stats :+ ValDef(mods, name, tpt, expr).setSymbol(tree.symbol) } else List(tree) - case Assign(name, rhs) if containsAwait => + case Assign(lhs, rhs) if containsAwait => val stats :+ expr = inline.transformToList(rhs) - stats :+ Assign(name, expr) + stats :+ Assign(lhs, expr) case If(cond, thenp, elsep) if containsAwait => val stats :+ expr = inline.transformToList(cond) -- cgit v1.2.3