aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorphaller <hallerp@gmail.com>2012-10-29 14:41:42 +0100
committerphaller <hallerp@gmail.com>2012-10-29 14:41:42 +0100
commit23403a5ba6e7d045231d57572813859f6d344377 (patch)
tree18a1f90ce0a75214154c18633c67025efbdd71c9
parent4ce496401b8e2c999bb69593ad31bc39ea2bfd0c (diff)
downloadscala-async-23403a5ba6e7d045231d57572813859f6d344377.tar.gz
scala-async-23403a5ba6e7d045231d57572813859f6d344377.tar.bz2
scala-async-23403a5ba6e7d045231d57572813859f6d344377.zip
WIP: support await in if-else expressions
-rw-r--r--src/async/library/scala/async/Async.scala276
-rw-r--r--src/async/library/scala/async/AsyncUtils.scala2
-rw-r--r--src/async/test/async-spec/if-else0.scala82
-rwxr-xr-xtest-if-else.sh5
4 files changed, 326 insertions, 39 deletions
diff --git a/src/async/library/scala/async/Async.scala b/src/async/library/scala/async/Async.scala
index 92d6fcd..53ed062 100644
--- a/src/async/library/scala/async/Async.scala
+++ b/src/async/library/scala/async/Async.scala
@@ -20,6 +20,8 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
import c.universe._
import Flag._
+ private val awaitMethod = awaitSym(c)
+
/* Make a partial function literal handling case #num:
*
* {
@@ -49,6 +51,20 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
Ident(newTermName("state")),
Apply(Select(Ident(newTermName("state")), newTermName("$plus")), List(Literal(Constant(1)))))
+ def mkStateTree(nextState: Int): c.Tree =
+ Assign(
+ Ident(newTermName("state")),
+ Literal(Constant(nextState)))
+
+ def mkVarDefTree(resultType: c.universe.Type, resultName: c.universe.TermName): c.Tree = {
+ val rhs =
+ if (resultType <:< definitions.IntTpe) Literal(Constant(0))
+ else if (resultType <:< definitions.LongTpe) Literal(Constant(0L))
+ else if (resultType <:< definitions.BooleanTpe) Literal(Constant(false))
+ else Literal(Constant(null))
+ ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), rhs)
+ }
+
def mkHandlerTree(num: Int, rhs: c.Tree): c.Tree = {
val partFunIdent = Ident(c.mirror.staticClass("scala.PartialFunction"))
val intIdent = Ident(definitions.IntClass)
@@ -83,23 +99,50 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
)
}
- class AsyncState(stats: List[c.Tree]) {
+ class AsyncState(stats: List[c.Tree], protected val state: Int, protected val nextState: Int) {
val body: c.Tree =
if (stats.size == 1) stats.head
else Block(stats: _*)
- def mkHandlerTreeForState(num: Int): c.Tree =
- mkHandlerTree(num, Block((stats :+ mkIncrStateTree()): _*))
+ val varDefs: List[(c.universe.TermName, c.universe.Type)] = List()
+
+ def mkHandlerTreeForState(): c.Tree =
+ mkHandlerTree(state, Block((stats :+ mkStateTree(nextState)): _*))
+
+ def mkHandlerTreeForState(nextState: Int): c.Tree =
+ mkHandlerTree(state, Block((stats :+ mkStateTree(nextState)): _*))
def varDefForResult: Option[c.Tree] =
None
+
+ def allVarDefs: List[c.Tree] =
+ varDefForResult.toList ++ varDefs.map(p => mkVarDefTree(p._2, p._1))
+
+ override val toString: String =
+ s"AsyncState #$state, next = $nextState"
}
- abstract class AsyncStateWithAwait(stats: List[c.Tree]) extends AsyncState(stats) {
+ class AsyncStateWithIf(stats: List[c.Tree], state: Int)
+ extends AsyncState(stats, state, 0) { // nextState unused, since encoded in then and else branches
+
+ override def mkHandlerTreeForState(): c.Tree =
+ mkHandlerTree(state, Block(stats: _*))
+
+ //TODO mkHandlerTreeForState(nextState: Int)
+
+ override val toString: String =
+ s"AsyncStateWithIf #$state, next = $nextState"
+ }
+
+ abstract class AsyncStateWithAwait(stats: List[c.Tree], state: Int, nextState: Int)
+ extends AsyncState(stats, state, nextState) {
val awaitable: c.Tree
val resultName: c.universe.TermName
val resultType: c.universe.Type
+ override val toString: String =
+ s"AsyncStateWithAwait #$state, next = $nextState"
+
/* Make an `onComplete` invocation:
*
* awaitable.onComplete {
@@ -159,6 +202,36 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
)
}
+ /* Make an `onComplete` invocation which sets the state to `nextState` upon resuming:
+ *
+ * awaitable.onComplete {
+ * case tr =>
+ * resultName = tr.get
+ * state = `nextState`
+ * resume()
+ * }
+ */
+ def mkOnCompleteStateTree(nextState: Int): c.Tree = {
+ val tryGetTree =
+ Assign(
+ Ident(resultName.toString),
+ Select(Ident("tr"), c.universe.newTermName("get"))
+ )
+ val handlerTree =
+ Match(
+ EmptyTree,
+ List(
+ CaseDef(Bind(c.universe.newTermName("tr"), Ident("_")), EmptyTree,
+ Block(tryGetTree, mkStateTree(nextState), Apply(Ident("resume"), List())) // rhs of case
+ )
+ )
+ )
+ Apply(
+ Select(awaitable, c.universe.newTermName("onComplete")),
+ List(handlerTree)
+ )
+ }
+
/* Make a partial function literal handling case #num:
*
* {
@@ -189,9 +262,14 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
* }
* }
*/
- override def mkHandlerTreeForState(num: Int): c.Tree = {
+ override def mkHandlerTreeForState(): c.Tree = {
+ assert(awaitable != null)
+ mkHandlerTree(state, Block((stats :+ mkOnCompleteIncrStateTree): _*))
+ }
+
+ override def mkHandlerTreeForState(nextState: Int): c.Tree = {
assert(awaitable != null)
- builder.mkHandlerTree(num, Block((stats :+ mkOnCompleteIncrStateTree): _*))
+ mkHandlerTree(state, Block((stats :+ mkOnCompleteStateTree(nextState)): _*))
}
//TODO: complete for other primitive types, how to handle value classes?
@@ -210,7 +288,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
/*
* Builder for a single state of an async method.
*/
- class AsyncStateBuilder extends Builder[c.Tree, AsyncState] {
+ class AsyncStateBuilder(state: Int) extends Builder[c.Tree, AsyncState] {
self =>
/* Statements preceding an await call. */
@@ -225,19 +303,32 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
/* Result type of an await call. */
var resultType: c.universe.Type = null
+ var nextState: Int = state + 1
+
+ private val varDefs = ListBuffer[(c.universe.TermName, c.universe.Type)]()
+
def += (stat: c.Tree): this.type = {
stats += c.resetAllAttrs(stat.duplicate)
this
}
+ //TODO do not ignore `mods`
+ def addVarDef(mods: Any, name: c.universe.TermName, tpt: c.Tree): this.type = {
+ varDefs += (name -> tpt.tpe)
+ this
+ }
+
def result(): AsyncState =
if (awaitable == null)
- new AsyncState(stats.toList)
+ new AsyncState(stats.toList, state, nextState) {
+ override val varDefs = self.varDefs.toList
+ }
else
- new AsyncStateWithAwait(stats.toList) {
+ new AsyncStateWithAwait(stats.toList, state, nextState) {
val awaitable = self.awaitable
val resultName = self.resultName
val resultType = self.resultType
+ override val varDefs = self.varDefs.toList
}
def clear(): Unit = {
@@ -254,56 +345,164 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
* @param awaitResultName the name of the variable that the result of await is assigned to
* @param awaitResultType the type of the result of await
*/
- def complete(awaitArg: c.Tree, awaitResultName: c.universe.TermName, awaitResultType: Tree): this.type = {
+ def complete(awaitArg: c.Tree, awaitResultName: c.universe.TermName, awaitResultType: Tree, nextState: Int = state + 1): this.type = {
awaitable = c.resetAllAttrs(awaitArg.duplicate)
resultName = awaitResultName
resultType = awaitResultType.tpe
+ this.nextState = nextState
this
}
+ def complete(nextState: Int): this.type = {
+ this.nextState = nextState
+ this
+ }
+
+ def resultWithIf(condTree: c.Tree, thenState: Int, elseState: Int): AsyncState = {
+ // 1. build changed if-else tree
+ // 2. insert that tree at the end of the current state
+ val cond = c.resetAllAttrs(condTree.duplicate)
+ this += If(cond, mkStateTree(thenState), mkStateTree(elseState))
+ new AsyncStateWithIf(stats.toList, state) {
+ override val varDefs = self.varDefs.toList
+ }
+ }
+
override def toString: String = {
val statsBeforeAwait = stats.mkString("\n")
s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName"
}
}
- class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int) {
+ /* current issue:
+ def m2(y: Int): Future[Int] = async {
+ val f = m1(y)
+ if (y > 0) {
+ val x = await(f)
+ x + 2
+ } else {
+ val x = await(f)
+ x - 2
+ }
+ }
+
+ */
+ class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, budget: Int) {
val asyncStates = ListBuffer[builder.AsyncState]()
- private var stateBuilder = new builder.AsyncStateBuilder // current state builder
- private val awaitMethod = awaitSym(c)
+
+ private var stateBuilder = new builder.AsyncStateBuilder(startState) // current state builder
+ private var currState = startState
+
+ private var remainingBudget = budget
// populate asyncStates
for (stat <- stats) stat match {
// the val name = await(..) pattern
case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == awaitMethod =>
asyncStates += stateBuilder.complete(args(0), name, tpt).result // complete with await
- stateBuilder = new builder.AsyncStateBuilder
-
+ currState += 1
+ stateBuilder = new builder.AsyncStateBuilder(currState)
+
+ case ValDef(mods, name, tpt, rhs) =>
+ stateBuilder.addVarDef(mods, name, tpt)
+ stateBuilder += // instead of adding `stat` we add a simple assignment
+ Assign(Ident(name), c.resetAllAttrs(rhs.duplicate))
+
+ case If(cond, thenp, elsep) =>
+ val ifBudget: Int = remainingBudget / 2
+ remainingBudget -= ifBudget
+ println(s"ASYNC IF: ifBudget = $ifBudget")
+ // state that we continue with after if-else: currState + ifBudget
+
+ val thenBudget: Int = ifBudget / 2
+ val elseBudget = ifBudget - thenBudget
+
+ asyncStates +=
+ stateBuilder.resultWithIf(cond, currState + 1, currState + thenBudget)
+
+ val thenBuilder = thenp match {
+ case Block(thenStats, thenExpr) =>
+ new AsyncBlockBuilder(thenStats, thenExpr, currState + 1, currState + ifBudget, thenBudget)
+ case _ =>
+ new AsyncBlockBuilder(List(thenp), Literal(Constant(())), currState + 1, currState + ifBudget, thenBudget)
+ }
+ println("ASYNC IF: states of thenp:")
+ for (s <- thenBuilder.asyncStates)
+ println(s.toString)
+
+ // insert states of thenBuilder into asyncStates
+ asyncStates ++= thenBuilder.asyncStates
+
+ val elseBuilder = elsep match {
+ case Block(elseStats, elseExpr) =>
+ new AsyncBlockBuilder(elseStats, elseExpr, currState + thenBudget, currState + ifBudget, elseBudget)
+ case _ =>
+ new AsyncBlockBuilder(List(elsep), Literal(Constant(())), currState + thenBudget, currState + ifBudget, elseBudget)
+ }
+ // insert states of elseBuilder into asyncStates
+ asyncStates ++= elseBuilder.asyncStates
+
+ // create new state builder for state `currState + ifBudget`
+ currState = currState + ifBudget
+ stateBuilder = new builder.AsyncStateBuilder(currState)
+
case _ =>
stateBuilder += stat
}
// complete last state builder (representing the expressions after the last await)
- asyncStates += (stateBuilder += expr).result
+ stateBuilder += expr
+ val lastState = stateBuilder.complete(endState).result
+ asyncStates += lastState
/* Builds the handler expression for a sequence of async states.
- * Also returns the index of the last state.
*/
- def mkHandlerExpr(): (c.Expr[PartialFunction[Int, Unit]], Int) = {
- //var handlerExpr = asyncStates(0).mkHandlerForState(1) // state 0 but { case 1 => ... }
- var handlerTree = asyncStates(0).mkHandlerTreeForState(0)
- var handlerExpr = c.Expr(handlerTree).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]]
-
- var i = 1
- for (asyncState <- asyncStates.tail.init) {
- //val handlerForNextState = asyncStates(i).mkHandlerForState(i+1)
- val handlerTreeForNextState = asyncState.mkHandlerTreeForState(i)
+ def mkHandlerExpr(): c.Expr[PartialFunction[Int, Unit]] = {
+ assert(asyncStates.size > 1)
+
+ println(s"!!ASYNC mkHandlerExpr: asyncStates.size = ${asyncStates.size}")
+ println(s"!!ASYNC state 0: ${asyncStates(0)}")
+
+ var handlerTree =
+ if (asyncStates.size > 2) asyncStates(0).mkHandlerTreeForState()
+ else asyncStates(0).mkHandlerTreeForState(endState)
+ var handlerExpr =
+ c.Expr(handlerTree).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]]
+
+ if (asyncStates.size == 2)
+ handlerExpr
+ else if (asyncStates.size == 3) {
+ // asyncStates(1) must continue with endState
+ val handlerTreeForLastState = asyncStates(1).mkHandlerTreeForState(endState)
+ val currentHandlerTreeNaked = c.resetAllAttrs(handlerExpr.tree.duplicate)
+ c.Expr(
+ Apply(Select(currentHandlerTreeNaked, newTermName("orElse")),
+ List(handlerTreeForLastState))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]]
+ } else { // asyncStates.size > 3
+ var i = startState + 1
+
+ println("!!ASYNC start for loop")
+
+ // do not traverse first state: asyncStates.tail
+ // do not traverse last state: asyncStates.tail.init
+ // handle second to last state specially: asyncStates.tail.init.init
+ for (asyncState <- asyncStates.tail.init.init) {
+ println(s"!!ASYNC current asyncState: $asyncState")
+ val handlerTreeForNextState = asyncState.mkHandlerTreeForState()
+ val currentHandlerTreeNaked = c.resetAllAttrs(handlerExpr.tree.duplicate)
+ handlerExpr = c.Expr(
+ Apply(Select(currentHandlerTreeNaked, newTermName("orElse")),
+ List(handlerTreeForNextState))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]]
+ i += 1
+ }
+
+ val lastState = asyncStates.tail.init.last
+ println(s"!!ASYNC current asyncState (forced to $endState): $lastState")
+ val handlerTreeForLastState = lastState.mkHandlerTreeForState(endState)
val currentHandlerTreeNaked = c.resetAllAttrs(handlerExpr.tree.duplicate)
- handlerExpr = c.Expr(
- Apply(Select(currentHandlerTreeNaked, newTermName("orElse")), List(handlerTreeForNextState))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]]
- i += 1
+ c.Expr(
+ Apply(Select(currentHandlerTreeNaked, newTermName("orElse")),
+ List(handlerTreeForLastState))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]]
}
- // asyncStates(i) does not end with `await` (asyncStates(i).awaitable == null)
- (handlerExpr, i)
}
}
@@ -327,28 +526,29 @@ object Async extends AsyncUtils {
body.tree match {
case Block(stats, expr) =>
- val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000)
+ val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000, 1000)
- vprintln("states of current method:")
+ vprintln(s"states of current method (${ asyncBlockBuilder.asyncStates }):")
asyncBlockBuilder.asyncStates foreach vprintln
- val (handlerExpr, indexOfLastState) = asyncBlockBuilder.mkHandlerExpr()
+ val handlerExpr = asyncBlockBuilder.mkHandlerExpr()
- vprintln(s"GENERATED handler expr ($indexOfLastState):")
+ vprintln(s"GENERATED handler expr:")
vprintln(handlerExpr)
val localVarDefs = ListBuffer[c.Tree]()
for (state <- asyncBlockBuilder.asyncStates.init) // exclude last state (doesn't have await result)
- localVarDefs ++= state.varDefForResult.toList
+ localVarDefs ++= //state.varDefForResult.toList
+ state.allVarDefs
// pad up to 5 var defs
if (localVarDefs.size < 5)
for (_ <- localVarDefs.size until 5) localVarDefs += EmptyTree
val handlerForLastState: c.Expr[PartialFunction[Int, Unit]] = {
val tree = Apply(Select(Ident("result"), c.universe.newTermName("success")),
- List(asyncBlockBuilder.asyncStates(indexOfLastState).body))
+ List(asyncBlockBuilder.asyncStates.last.body))
//builder.mkHandler(indexOfLastState + 1, c.Expr[Unit](tree))
- builder.mkHandler(indexOfLastState, c.Expr[Unit](tree))
+ builder.mkHandler(1000, c.Expr[Unit](tree))
}
vprintln("GENERATED handler for last state:")
diff --git a/src/async/library/scala/async/AsyncUtils.scala b/src/async/library/scala/async/AsyncUtils.scala
index 820541b..98330a5 100644
--- a/src/async/library/scala/async/AsyncUtils.scala
+++ b/src/async/library/scala/async/AsyncUtils.scala
@@ -10,7 +10,7 @@ import scala.reflect.macros.Context
*/
trait AsyncUtils {
- val verbose = false
+ val verbose = true
protected def vprintln(s: Any): Unit = if (verbose)
println("[async] "+s)
diff --git a/src/async/test/async-spec/if-else0.scala b/src/async/test/async-spec/if-else0.scala
new file mode 100644
index 0000000..2f60901
--- /dev/null
+++ b/src/async/test/async-spec/if-else0.scala
@@ -0,0 +1,82 @@
+/**
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+package scala.async
+
+import language.{ reflectiveCalls, postfixOps }
+import scala.concurrent.{ Future, ExecutionContext, future, Await }
+import scala.concurrent.duration._
+import scala.async.Async.{ async, await }
+
+
+object Test extends App {
+
+ IfElseSpec.check()
+
+}
+
+
+class TestIfElseClass {
+ import ExecutionContext.Implicits.global
+
+ def m1(x: Int): Future[Int] = future {
+ Thread.sleep(1000)
+ x + 2
+ }
+
+ def m2(y: Int): Future[Int] = async {
+ val f = m1(y)
+ var z = 0
+ if (y > 0) {
+ val x1 = await(f)
+ z = x1 + 2
+ } else {
+ val x2 = await(f)
+ z = x2 - 2
+ }
+ z
+ }
+/*
+ def m3(y: Int): Future[Int] = async {
+ val f1 = m1(y)
+ val x1 = await(f1)
+ val f2 = m1(y + 2)
+ val x2 = await(f2)
+ x1 + x2
+ }
+*/
+ // currently fails with: error: not found: value f2
+/*
+ def m4(y: Int): Future[Int] = async {
+ val f1 = m1(y)
+ val f2 = m1(y + 2)
+ val x1 = await(f1)
+ println("between two awaits")
+ val x2 = await(f2)
+ x1 + x2
+ }
+*/
+}
+
+
+object IfElseSpec extends MinimalScalaTest {
+
+ "An async method" should {
+ "support a simple await" in {
+ val o = new Test1Class
+ val fut = o.m2(10)
+ val res = Await.result(fut, 2 seconds)
+ res mustBe(14)
+ }
+ }
+/*
+ "An async method" should {
+ "support several awaits in sequence" in {
+ val o = new Test1Class
+ val fut = o.m3(10)
+ val res = Await.result(fut, 4 seconds)
+ res mustBe(26)
+ }
+ }
+*/
+}
diff --git a/test-if-else.sh b/test-if-else.sh
new file mode 100755
index 0000000..64b7276
--- /dev/null
+++ b/test-if-else.sh
@@ -0,0 +1,5 @@
+#!/bin/bash
+mkdir -p test-classes
+scalac -cp classes -d test-classes src/async/test/async-spec/MinimalScalaTest.scala
+scalac -Xprint:typer -cp classes:test-classes -d test-classes src/async/test/async-spec/if-else0.scala
+scala -cp test-classes scala.async.Test