aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorphaller <hallerp@gmail.com>2012-11-13 01:07:28 +0100
committerphaller <hallerp@gmail.com>2012-11-16 09:42:30 +0100
commitc4ceea0ca8538297622634121b99e2357ca72acb (patch)
tree31c2265496175dde52244f2506c94997fd9194dd
parentf451904320d02c7dbe6b298f6ff790ca5cf5f080 (diff)
downloadscala-async-c4ceea0ca8538297622634121b99e2357ca72acb.tar.gz
scala-async-c4ceea0ca8538297622634121b99e2357ca72acb.tar.bz2
scala-async-c4ceea0ca8538297622634121b99e2357ca72acb.zip
Add selective ANF transform
- Does not descend into class and module defs - Adds several tests, including tests for if-else
-rw-r--r--src/main/scala/scala/async/AnfTransform.scala108
-rw-r--r--src/main/scala/scala/async/Async.scala12
-rw-r--r--src/main/scala/scala/async/ExprBuilder.scala22
-rw-r--r--src/test/scala/scala/async/run/anf/AnfTransformSpec.scala94
4 files changed, 224 insertions, 12 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala
new file mode 100644
index 0000000..86dea75
--- /dev/null
+++ b/src/main/scala/scala/async/AnfTransform.scala
@@ -0,0 +1,108 @@
+package scala.async
+
+import scala.reflect.macros.Context
+
+class AnfTransform[C <: Context](val c: C) {
+ import c.universe._
+ import AsyncUtils._
+
+ object inline {
+ //TODO: DRY
+ private def defaultValue(tpe: Type): Literal = {
+ val defaultValue: Any =
+ if (tpe <:< definitions.BooleanTpe) false
+ else if (definitions.ScalaNumericValueClasses.exists(tpe <:< _.toType)) 0
+ else if (tpe <:< definitions.AnyValTpe) 0
+ else null
+ Literal(Constant(defaultValue))
+ }
+
+ def transformToList(tree: Tree): List[Tree] = {
+ val stats :+ expr = anf.transformToList(tree)
+ expr match {
+
+ case Apply(fun, args) if { vprintln("check fun.toString: " + fun.toString); fun.toString.startsWith("scala.async.Async.await") } =>
+ vprintln("found await!!")
+ val liftedName = c.fresh("await$")
+ stats :+ ValDef(NoMods, liftedName, TypeTree(), expr) :+ Ident(liftedName)
+
+ case If(cond, thenp, elsep) =>
+ val liftedName = c.fresh("ifres$")
+ val varDef =
+ ValDef(Modifiers(Flag.MUTABLE), liftedName, TypeTree(expr.tpe), defaultValue(expr.tpe))
+ val thenWithAssign = thenp match {
+ case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(liftedName), thenExpr))
+ case _ => Assign(Ident(liftedName), thenp)
+ }
+ val elseWithAssign = elsep match {
+ case Block(elseStats, elseExpr) => Block(elseStats, Assign(Ident(liftedName), elseExpr))
+ case _ => Assign(Ident(liftedName), elsep)
+ }
+ val ifWithAssign =
+ If(cond, thenWithAssign, elseWithAssign)
+ stats :+ varDef :+ ifWithAssign :+ Ident(liftedName)
+
+ case _ =>
+ vprintln("found something else")
+ stats :+ expr
+ }
+ }
+
+ def transformToList(trees: List[Tree]): List[Tree] = trees match {
+ case fst :: rest => transformToList(fst) ++ transformToList(rest)
+ case Nil => Nil
+ }
+ }
+
+ object anf {
+ def transformToList(tree: Tree): List[Tree] = tree match {
+ case Select(qual, sel) =>
+ val stats :+ expr = inline.transformToList(qual)
+ stats :+ Select(expr, sel)
+
+ case Apply(fun, args) =>
+ val funStats :+ simpleFun = inline.transformToList(fun)
+ val argLists = args map inline.transformToList
+ val allArgStats = argLists flatMap (_.init)
+ val simpleArgs = argLists map (_.last)
+ funStats ++ allArgStats :+ Apply(simpleFun, simpleArgs)
+
+ case Block(stats, expr) =>
+ inline.transformToList(stats) ++ inline.transformToList(expr)
+
+ case ValDef(mods, name, tpt, rhs) =>
+ val stats :+ expr = inline.transformToList(rhs)
+ stats :+ ValDef(mods, name, tpt, expr)
+
+ case Assign(name, rhs) =>
+ val stats :+ expr = inline.transformToList(rhs)
+ stats :+ Assign(name, expr)
+
+ case If(cond, thenp, elsep) =>
+ val stats :+ expr = inline.transformToList(cond)
+ val thenStats :+ thenExpr = inline.transformToList(thenp)
+ val elseStats :+ elseExpr = inline.transformToList(elsep)
+ stats :+
+ c.typeCheck(If(expr, Block(thenStats, thenExpr), Block(elseStats, elseExpr)),
+ lub(List(thenp.tpe, elsep.tpe)))
+
+ //TODO
+ case Literal(_) | Ident(_) | This(_) | Match(_, _) | New(_) | Function(_, _) => List(tree)
+
+ case TypeApply(fun, targs) =>
+ val funStats :+ simpleFun = inline.transformToList(fun)
+ funStats :+ TypeApply(simpleFun, targs)
+
+ //TODO
+ case DefDef(mods, name, tparams, vparamss, tpt, rhs) => List(tree)
+
+ case ClassDef(mods, name, tparams, impl) => List(tree)
+
+ case ModuleDef(mods, name, impl) => List(tree)
+
+ case _ =>
+ println("do not handle tree "+tree)
+ ???
+ }
+ }
+}
diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala
index b71ce74..8fc7ccf 100644
--- a/src/main/scala/scala/async/Async.scala
+++ b/src/main/scala/scala/async/Async.scala
@@ -74,7 +74,15 @@ abstract class AsyncBase {
import builder.defn._
import builder.name
import builder.futureSystemOps
- val (stats, expr) = body.tree match {
+
+ val transform = new AnfTransform[c.type](c)
+ val typedBody = c.typeCheck(body.tree)
+ val stats1 :+ expr1 = transform.anf.transformToList(typedBody)
+ val btree = c.typeCheck(Block(stats1, expr1))
+
+ AsyncUtils.vprintln(s"ANF transform expands to:\n $btree")
+
+ val (stats, expr) = btree match {
case Block(stats, expr) => (stats, expr)
case tree => (Nil, tree)
}
@@ -86,7 +94,7 @@ abstract class AsyncBase {
val handlerCases: List[CaseDef] = asyncBlockBuilder.mkCombinedHandlerCases[T]()
val initStates = asyncBlockBuilder.asyncStates.init
- val localVarTrees = initStates.flatMap(_.allVarDefs).toList
+ val localVarTrees = asyncBlockBuilder.asyncStates.flatMap(_.allVarDefs).toList
/*
lazy val onCompleteHandler = (tr: Try[Any]) => state match {
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala
index 2b35ff4..4ace31c 100644
--- a/src/main/scala/scala/async/ExprBuilder.scala
+++ b/src/main/scala/scala/async/ExprBuilder.scala
@@ -54,6 +54,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
val defaultValue: Any =
if (tpe <:< definitions.BooleanTpe) false
else if (definitions.ScalaNumericValueClasses.exists(tpe <:< _.toType)) 0
+ else if (tpe <:< definitions.AnyValTpe) 0
else null
Literal(Constant(defaultValue))
}
@@ -142,7 +143,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
/*
* Builder for a single state of an async method.
*/
- class AsyncStateBuilder(state: Int, private var nameMap: Map[c.Symbol, c.Name]) {
+ class AsyncStateBuilder(state: Int, private var nameMap: Map[String, c.Name]) {
self =>
/* Statements preceding an await call. */
@@ -163,8 +164,8 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
private val renamer = new Transformer {
override def transform(tree: Tree) = tree match {
- case Ident(_) if nameMap.keySet contains tree.symbol =>
- Ident(nameMap(tree.symbol))
+ case Ident(_) if nameMap.keySet contains tree.symbol.toString =>
+ Ident(nameMap(tree.symbol.toString))
case _ =>
super.transform(tree)
}
@@ -178,7 +179,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
//TODO do not ignore `mods`
def addVarDef(mods: Any, name: TermName, tpt: c.Tree, rhs: c.Tree, extNameMap: Map[c.Symbol, c.Name]): this.type = {
varDefs += (name -> tpt.tpe)
- nameMap ++= extNameMap // update name map
+ nameMap ++= extNameMap.map { case (k, v) => (k.toString, v) } // update name map
this += Assign(Ident(name), rhs)
this
}
@@ -205,8 +206,9 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
*/
def complete(awaitArg: c.Tree, awaitResultName: TermName, awaitResultType: Tree,
extNameMap: Map[c.Symbol, c.Name], nextState: Int = state + 1): this.type = {
- nameMap ++= extNameMap
- awaitable = resetDuplicate(renamer.transform(awaitArg))
+ nameMap ++= extNameMap.map { case (k, v) => (k.toString, v) }
+ val renamed = renamer.transform(awaitArg)
+ awaitable = resetDuplicate(renamed)
resultName = awaitResultName
resultType = awaitResultType.tpe
this.nextState = nextState
@@ -273,7 +275,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
budget: Int, private var toRename: Map[c.Symbol, c.Name]) {
val asyncStates = ListBuffer[builder.AsyncState]()
- private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename)
+ private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename.map { case (k, v) => (k.toString, v) })
// current state builder
private var currState = startState
@@ -306,7 +308,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
else
assert(false, "too many invocations of `await` in current method")
currState += 1
- stateBuilder = new builder.AsyncStateBuilder(currState, toRename)
+ stateBuilder = new builder.AsyncStateBuilder(currState, toRename.map { case (k, v) => (k.toString, v) })
case ValDef(mods, name, tpt, rhs) =>
checkForUnsupportedAwait(rhs)
@@ -339,7 +341,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
// create new state builder for state `currState + ifBudget`
currState = currState + ifBudget
- stateBuilder = new builder.AsyncStateBuilder(currState, toRename)
+ stateBuilder = new builder.AsyncStateBuilder(currState, toRename.map { case (k, v) => (k.toString, v) })
case Match(scrutinee, cases) =>
vprintln("transforming match expr: " + stat)
@@ -366,7 +368,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
// create new state builder for state `currState + matchBudget`
currState = currState + matchBudget
- stateBuilder = new builder.AsyncStateBuilder(currState, toRename)
+ stateBuilder = new builder.AsyncStateBuilder(currState, toRename.map { case (k, v) => (k.toString, v) })
case ClassDef(_, name, _, _) =>
// do not allow local class definitions, because of SI-5467 (specific to case classes, though)
diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
new file mode 100644
index 0000000..f38efa9
--- /dev/null
+++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
@@ -0,0 +1,94 @@
+/**
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async
+package run
+package anf
+
+import language.{reflectiveCalls, postfixOps}
+import scala.concurrent.{Future, ExecutionContext, future, Await}
+import scala.concurrent.duration._
+import scala.async.Async.{async, await}
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+
+
+class AnfTestClass {
+
+ import ExecutionContext.Implicits.global
+
+ def base(x: Int): Future[Int] = future {
+ x + 2
+ }
+
+ def m(y: Int): Future[Int] = async {
+ val f = base(y)
+ await(f)
+ }
+
+ def m2(y: Int): Future[Int] = async {
+ val f = base(y)
+ val f2 = base(y + 1)
+ await(f) + await(f2)
+ }
+
+ def m3(y: Int): Future[Int] = async {
+ val f = base(y)
+ var z = 0
+ if (y > 0) {
+ z = await(f) + 2
+ } else {
+ z = await(f) - 2
+ }
+ z
+ }
+
+ def m4(y: Int): Future[Int] = async {
+ val f = base(y)
+ val z = if (y > 0) {
+ await(f) + 2
+ } else {
+ await(f) - 2
+ }
+ z + 1
+ }
+}
+
+
+@RunWith(classOf[JUnit4])
+class AnfTransformSpec {
+
+ @Test
+ def `simple ANF transform`() {
+ val o = new AnfTestClass
+ val fut = o.m(10)
+ val res = Await.result(fut, 2 seconds)
+ res mustBe (12)
+ }
+
+ @Test
+ def `simple ANF transform 2`() {
+ val o = new AnfTestClass
+ val fut = o.m2(10)
+ val res = Await.result(fut, 2 seconds)
+ res mustBe (25)
+ }
+
+ @Test
+ def `simple ANF transform 3`() {
+ val o = new AnfTestClass
+ val fut = o.m3(10)
+ val res = Await.result(fut, 2 seconds)
+ res mustBe (14)
+ }
+
+ @Test
+ def `ANF transform of assigning the result of an if-else`() {
+ val o = new AnfTestClass
+ val fut = o.m4(10)
+ val res = Await.result(fut, 2 seconds)
+ res mustBe (15)
+ }
+}