From c4ceea0ca8538297622634121b99e2357ca72acb Mon Sep 17 00:00:00 2001 From: phaller Date: Tue, 13 Nov 2012 01:07:28 +0100 Subject: Add selective ANF transform - Does not descend into class and module defs - Adds several tests, including tests for if-else --- src/main/scala/scala/async/AnfTransform.scala | 108 +++++++++++++++++++++ src/main/scala/scala/async/Async.scala | 12 ++- src/main/scala/scala/async/ExprBuilder.scala | 22 +++-- .../scala/async/run/anf/AnfTransformSpec.scala | 94 ++++++++++++++++++ 4 files changed, 224 insertions(+), 12 deletions(-) create mode 100644 src/main/scala/scala/async/AnfTransform.scala create mode 100644 src/test/scala/scala/async/run/anf/AnfTransformSpec.scala diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala new file mode 100644 index 0000000..86dea75 --- /dev/null +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -0,0 +1,108 @@ +package scala.async + +import scala.reflect.macros.Context + +class AnfTransform[C <: Context](val c: C) { + import c.universe._ + import AsyncUtils._ + + object inline { + //TODO: DRY + private def defaultValue(tpe: Type): Literal = { + val defaultValue: Any = + if (tpe <:< definitions.BooleanTpe) false + else if (definitions.ScalaNumericValueClasses.exists(tpe <:< _.toType)) 0 + else if (tpe <:< definitions.AnyValTpe) 0 + else null + Literal(Constant(defaultValue)) + } + + def transformToList(tree: Tree): List[Tree] = { + val stats :+ expr = anf.transformToList(tree) + expr match { + + case Apply(fun, args) if { vprintln("check fun.toString: " + fun.toString); fun.toString.startsWith("scala.async.Async.await") } => + vprintln("found await!!") + val liftedName = c.fresh("await$") + stats :+ ValDef(NoMods, liftedName, TypeTree(), expr) :+ Ident(liftedName) + + case If(cond, thenp, elsep) => + val liftedName = c.fresh("ifres$") + val varDef = + ValDef(Modifiers(Flag.MUTABLE), liftedName, TypeTree(expr.tpe), defaultValue(expr.tpe)) + val thenWithAssign = thenp match { + case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(liftedName), thenExpr)) + case _ => Assign(Ident(liftedName), thenp) + } + val elseWithAssign = elsep match { + case Block(elseStats, elseExpr) => Block(elseStats, Assign(Ident(liftedName), elseExpr)) + case _ => Assign(Ident(liftedName), elsep) + } + val ifWithAssign = + If(cond, thenWithAssign, elseWithAssign) + stats :+ varDef :+ ifWithAssign :+ Ident(liftedName) + + case _ => + vprintln("found something else") + stats :+ expr + } + } + + def transformToList(trees: List[Tree]): List[Tree] = trees match { + case fst :: rest => transformToList(fst) ++ transformToList(rest) + case Nil => Nil + } + } + + object anf { + def transformToList(tree: Tree): List[Tree] = tree match { + case Select(qual, sel) => + val stats :+ expr = inline.transformToList(qual) + stats :+ Select(expr, sel) + + case Apply(fun, args) => + val funStats :+ simpleFun = inline.transformToList(fun) + val argLists = args map inline.transformToList + val allArgStats = argLists flatMap (_.init) + val simpleArgs = argLists map (_.last) + funStats ++ allArgStats :+ Apply(simpleFun, simpleArgs) + + case Block(stats, expr) => + inline.transformToList(stats) ++ inline.transformToList(expr) + + case ValDef(mods, name, tpt, rhs) => + val stats :+ expr = inline.transformToList(rhs) + stats :+ ValDef(mods, name, tpt, expr) + + case Assign(name, rhs) => + val stats :+ expr = inline.transformToList(rhs) + stats :+ Assign(name, expr) + + case If(cond, thenp, elsep) => + val stats :+ expr = inline.transformToList(cond) + val thenStats :+ thenExpr = inline.transformToList(thenp) + val elseStats :+ elseExpr = inline.transformToList(elsep) + stats :+ + c.typeCheck(If(expr, Block(thenStats, thenExpr), Block(elseStats, elseExpr)), + lub(List(thenp.tpe, elsep.tpe))) + + //TODO + case Literal(_) | Ident(_) | This(_) | Match(_, _) | New(_) | Function(_, _) => List(tree) + + case TypeApply(fun, targs) => + val funStats :+ simpleFun = inline.transformToList(fun) + funStats :+ TypeApply(simpleFun, targs) + + //TODO + case DefDef(mods, name, tparams, vparamss, tpt, rhs) => List(tree) + + case ClassDef(mods, name, tparams, impl) => List(tree) + + case ModuleDef(mods, name, impl) => List(tree) + + case _ => + println("do not handle tree "+tree) + ??? + } + } +} diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index b71ce74..8fc7ccf 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -74,7 +74,15 @@ abstract class AsyncBase { import builder.defn._ import builder.name import builder.futureSystemOps - val (stats, expr) = body.tree match { + + val transform = new AnfTransform[c.type](c) + val typedBody = c.typeCheck(body.tree) + val stats1 :+ expr1 = transform.anf.transformToList(typedBody) + val btree = c.typeCheck(Block(stats1, expr1)) + + AsyncUtils.vprintln(s"ANF transform expands to:\n $btree") + + val (stats, expr) = btree match { case Block(stats, expr) => (stats, expr) case tree => (Nil, tree) } @@ -86,7 +94,7 @@ abstract class AsyncBase { val handlerCases: List[CaseDef] = asyncBlockBuilder.mkCombinedHandlerCases[T]() val initStates = asyncBlockBuilder.asyncStates.init - val localVarTrees = initStates.flatMap(_.allVarDefs).toList + val localVarTrees = asyncBlockBuilder.asyncStates.flatMap(_.allVarDefs).toList /* lazy val onCompleteHandler = (tr: Try[Any]) => state match { diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 2b35ff4..4ace31c 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -54,6 +54,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy val defaultValue: Any = if (tpe <:< definitions.BooleanTpe) false else if (definitions.ScalaNumericValueClasses.exists(tpe <:< _.toType)) 0 + else if (tpe <:< definitions.AnyValTpe) 0 else null Literal(Constant(defaultValue)) } @@ -142,7 +143,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy /* * Builder for a single state of an async method. */ - class AsyncStateBuilder(state: Int, private var nameMap: Map[c.Symbol, c.Name]) { + class AsyncStateBuilder(state: Int, private var nameMap: Map[String, c.Name]) { self => /* Statements preceding an await call. */ @@ -163,8 +164,8 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy private val renamer = new Transformer { override def transform(tree: Tree) = tree match { - case Ident(_) if nameMap.keySet contains tree.symbol => - Ident(nameMap(tree.symbol)) + case Ident(_) if nameMap.keySet contains tree.symbol.toString => + Ident(nameMap(tree.symbol.toString)) case _ => super.transform(tree) } @@ -178,7 +179,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy //TODO do not ignore `mods` def addVarDef(mods: Any, name: TermName, tpt: c.Tree, rhs: c.Tree, extNameMap: Map[c.Symbol, c.Name]): this.type = { varDefs += (name -> tpt.tpe) - nameMap ++= extNameMap // update name map + nameMap ++= extNameMap.map { case (k, v) => (k.toString, v) } // update name map this += Assign(Ident(name), rhs) this } @@ -205,8 +206,9 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy */ def complete(awaitArg: c.Tree, awaitResultName: TermName, awaitResultType: Tree, extNameMap: Map[c.Symbol, c.Name], nextState: Int = state + 1): this.type = { - nameMap ++= extNameMap - awaitable = resetDuplicate(renamer.transform(awaitArg)) + nameMap ++= extNameMap.map { case (k, v) => (k.toString, v) } + val renamed = renamer.transform(awaitArg) + awaitable = resetDuplicate(renamed) resultName = awaitResultName resultType = awaitResultType.tpe this.nextState = nextState @@ -273,7 +275,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy budget: Int, private var toRename: Map[c.Symbol, c.Name]) { val asyncStates = ListBuffer[builder.AsyncState]() - private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename) + private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename.map { case (k, v) => (k.toString, v) }) // current state builder private var currState = startState @@ -306,7 +308,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy else assert(false, "too many invocations of `await` in current method") currState += 1 - stateBuilder = new builder.AsyncStateBuilder(currState, toRename) + stateBuilder = new builder.AsyncStateBuilder(currState, toRename.map { case (k, v) => (k.toString, v) }) case ValDef(mods, name, tpt, rhs) => checkForUnsupportedAwait(rhs) @@ -339,7 +341,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy // create new state builder for state `currState + ifBudget` currState = currState + ifBudget - stateBuilder = new builder.AsyncStateBuilder(currState, toRename) + stateBuilder = new builder.AsyncStateBuilder(currState, toRename.map { case (k, v) => (k.toString, v) }) case Match(scrutinee, cases) => vprintln("transforming match expr: " + stat) @@ -366,7 +368,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy // create new state builder for state `currState + matchBudget` currState = currState + matchBudget - stateBuilder = new builder.AsyncStateBuilder(currState, toRename) + stateBuilder = new builder.AsyncStateBuilder(currState, toRename.map { case (k, v) => (k.toString, v) }) case ClassDef(_, name, _, _) => // do not allow local class definitions, because of SI-5467 (specific to case classes, though) diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala new file mode 100644 index 0000000..f38efa9 --- /dev/null +++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala @@ -0,0 +1,94 @@ +/** + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async +package run +package anf + +import language.{reflectiveCalls, postfixOps} +import scala.concurrent.{Future, ExecutionContext, future, Await} +import scala.concurrent.duration._ +import scala.async.Async.{async, await} +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 + + +class AnfTestClass { + + import ExecutionContext.Implicits.global + + def base(x: Int): Future[Int] = future { + x + 2 + } + + def m(y: Int): Future[Int] = async { + val f = base(y) + await(f) + } + + def m2(y: Int): Future[Int] = async { + val f = base(y) + val f2 = base(y + 1) + await(f) + await(f2) + } + + def m3(y: Int): Future[Int] = async { + val f = base(y) + var z = 0 + if (y > 0) { + z = await(f) + 2 + } else { + z = await(f) - 2 + } + z + } + + def m4(y: Int): Future[Int] = async { + val f = base(y) + val z = if (y > 0) { + await(f) + 2 + } else { + await(f) - 2 + } + z + 1 + } +} + + +@RunWith(classOf[JUnit4]) +class AnfTransformSpec { + + @Test + def `simple ANF transform`() { + val o = new AnfTestClass + val fut = o.m(10) + val res = Await.result(fut, 2 seconds) + res mustBe (12) + } + + @Test + def `simple ANF transform 2`() { + val o = new AnfTestClass + val fut = o.m2(10) + val res = Await.result(fut, 2 seconds) + res mustBe (25) + } + + @Test + def `simple ANF transform 3`() { + val o = new AnfTestClass + val fut = o.m3(10) + val res = Await.result(fut, 2 seconds) + res mustBe (14) + } + + @Test + def `ANF transform of assigning the result of an if-else`() { + val o = new AnfTestClass + val fut = o.m4(10) + val res = Await.result(fut, 2 seconds) + res mustBe (15) + } +} -- cgit v1.2.3