aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPhilipp Haller <hallerp@gmail.com>2012-11-22 06:10:00 -0800
committerPhilipp Haller <hallerp@gmail.com>2012-11-22 06:10:00 -0800
commit1c91fec998d09e31c2c52760452af1771a092182 (patch)
tree8733f9b854baa83194b1688fa30ed5fc90fd249c
parentf451904320d02c7dbe6b298f6ff790ca5cf5f080 (diff)
parent8e4a8ecdff955c4faa1dec344a2b93543ffe7d45 (diff)
downloadscala-async-1c91fec998d09e31c2c52760452af1771a092182.tar.gz
scala-async-1c91fec998d09e31c2c52760452af1771a092182.tar.bz2
scala-async-1c91fec998d09e31c2c52760452af1771a092182.zip
Merge pull request #25 from phaller/topic/minimal-var-lifting-2
Topic/minimal var lifting 2
-rw-r--r--build.sbt2
-rw-r--r--src/main/scala/scala/async/AnfTransform.scala101
-rw-r--r--src/main/scala/scala/async/Async.scala65
-rw-r--r--src/main/scala/scala/async/ExprBuilder.scala219
-rw-r--r--src/main/scala/scala/async/StateAssigner.scala10
-rw-r--r--src/main/scala/scala/async/TransformUtils.scala22
-rw-r--r--src/test/scala/scala/async/TestUtils.scala4
-rw-r--r--src/test/scala/scala/async/TreeInterrogation.scala3
-rw-r--r--src/test/scala/scala/async/neg/NakedAwait.scala72
-rw-r--r--src/test/scala/scala/async/run/anf/AnfTransformSpec.scala146
10 files changed, 542 insertions, 102 deletions
diff --git a/build.sbt b/build.sbt
index 7c48e67..ba0544b 100644
--- a/build.sbt
+++ b/build.sbt
@@ -19,6 +19,8 @@ libraryDependencies += "com.novocode" % "junit-interface" % "0.10-M2" % "test"
testOptions += Tests.Argument(TestFrameworks.JUnit, "-q", "-v", "-s")
+parallelExecution in Global := false
+
autoCompilerPlugins := true
libraryDependencies <<= (scalaVersion, libraryDependencies) {
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala
new file mode 100644
index 0000000..e1d7cd5
--- /dev/null
+++ b/src/main/scala/scala/async/AnfTransform.scala
@@ -0,0 +1,101 @@
+package scala.async
+
+import scala.reflect.macros.Context
+
+class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) {
+ import c.universe._
+ import AsyncUtils._
+
+ object inline {
+ def transformToList(tree: Tree): List[Tree] = {
+ val stats :+ expr = anf.transformToList(tree)
+ expr match {
+
+ case Apply(fun, args) if fun.toString.startsWith("scala.async.Async.await") =>
+ val liftedName = c.fresh("await$")
+ stats :+ ValDef(NoMods, liftedName, TypeTree(), expr) :+ Ident(liftedName)
+
+ case If(cond, thenp, elsep) =>
+ // if type of if-else is Unit don't introduce assignment,
+ // but add Unit value to bring it into form expected by async transform
+ if (expr.tpe =:= definitions.UnitTpe) {
+ stats :+ expr :+ Literal(Constant(()))
+ }
+ else {
+ val liftedName = c.fresh("ifres$")
+ val varDef =
+ ValDef(Modifiers(Flag.MUTABLE), liftedName, TypeTree(expr.tpe), defaultValue(expr.tpe))
+ val thenWithAssign = thenp match {
+ case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(liftedName), thenExpr))
+ case _ => Assign(Ident(liftedName), thenp)
+ }
+ val elseWithAssign = elsep match {
+ case Block(elseStats, elseExpr) => Block(elseStats, Assign(Ident(liftedName), elseExpr))
+ case _ => Assign(Ident(liftedName), elsep)
+ }
+ val ifWithAssign =
+ If(cond, thenWithAssign, elseWithAssign)
+ stats :+ varDef :+ ifWithAssign :+ Ident(liftedName)
+ }
+ case _ =>
+ stats :+ expr
+ }
+ }
+
+ def transformToList(trees: List[Tree]): List[Tree] = trees match {
+ case fst :: rest => transformToList(fst) ++ transformToList(rest)
+ case Nil => Nil
+ }
+ }
+
+ object anf {
+ def transformToList(tree: Tree): List[Tree] = tree match {
+ case Select(qual, sel) =>
+ val stats :+ expr = inline.transformToList(qual)
+ stats :+ Select(expr, sel)
+
+ case Apply(fun, args) =>
+ val funStats :+ simpleFun = inline.transformToList(fun)
+ val argLists = args map inline.transformToList
+ val allArgStats = argLists flatMap (_.init)
+ val simpleArgs = argLists map (_.last)
+ funStats ++ allArgStats :+ Apply(simpleFun, simpleArgs)
+
+ case Block(stats, expr) =>
+ inline.transformToList(stats) ++ inline.transformToList(expr)
+
+ case ValDef(mods, name, tpt, rhs) =>
+ val stats :+ expr = inline.transformToList(rhs)
+ stats :+ ValDef(mods, name, tpt, expr).setSymbol(tree.symbol)
+
+ case Assign(name, rhs) =>
+ val stats :+ expr = inline.transformToList(rhs)
+ stats :+ Assign(name, expr)
+
+ case If(cond, thenp, elsep) =>
+ val stats :+ expr = inline.transformToList(cond)
+ val thenStats :+ thenExpr = inline.transformToList(thenp)
+ val elseStats :+ elseExpr = inline.transformToList(elsep)
+ stats :+
+ c.typeCheck(If(expr, Block(thenStats, thenExpr), Block(elseStats, elseExpr)))
+
+ //TODO
+ case Literal(_) | Ident(_) | This(_) | Match(_, _) | New(_) | Function(_, _) => List(tree)
+
+ case TypeApply(fun, targs) =>
+ val funStats :+ simpleFun = inline.transformToList(fun)
+ funStats :+ TypeApply(simpleFun, targs)
+
+ //TODO
+ case DefDef(mods, name, tparams, vparamss, tpt, rhs) => List(tree)
+
+ case ClassDef(mods, name, tparams, impl) => List(tree)
+
+ case ModuleDef(mods, name, impl) => List(tree)
+
+ case _ =>
+ c.error(tree.pos, "Internal error while compiling `async` block")
+ ???
+ }
+ }
+}
diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala
index b71ce74..bd766f2 100644
--- a/src/main/scala/scala/async/Async.scala
+++ b/src/main/scala/scala/async/Async.scala
@@ -4,18 +4,14 @@
package scala.async
import scala.language.experimental.macros
-
import scala.reflect.macros.Context
-import scala.collection.mutable.ListBuffer
-import scala.concurrent.{Future, Promise, ExecutionContext, future}
-import ExecutionContext.Implicits.global
-import scala.util.control.NonFatal
-
/*
* @author Philipp Haller
*/
object Async extends AsyncBase {
+ import scala.concurrent.Future
+
lazy val futureSystem = ScalaConcurrentFutureSystem
type FS = ScalaConcurrentFutureSystem.type
@@ -52,11 +48,11 @@ abstract class AsyncBase {
/**
* A call to `await` must be nested in an enclosing `async` block.
- *
+ *
* A call to `await` does not block the current thread, rather it is a delimiter
* used by the enclosing `async` macro. Code following the `await`
* call is executed asynchronously, when the argument of `await` has been completed.
- *
+ *
* @param awaitable the future from which a value is awaited.
* @tparam T the type of that value.
* @return the value.
@@ -74,19 +70,38 @@ abstract class AsyncBase {
import builder.defn._
import builder.name
import builder.futureSystemOps
- val (stats, expr) = body.tree match {
- case Block(stats, expr) => (stats, expr)
- case tree => (Nil, tree)
+
+ // Transform to A-normal form:
+ // - no await calls in qualifiers or arguments,
+ // - if/match only used in statement position.
+ val anfTree: Block = {
+ val transform = new AnfTransform[c.type](c)
+ val stats1 :+ expr1 = transform.anf.transformToList(body.tree)
+ c.typeCheck(Block(stats1, expr1)).asInstanceOf[Block]
}
- val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000, 1000, Map())
+ // Analyze the block to find locals that will be accessed from multiple
+ // states of our generated state machine, e.g. a value assigned before
+ // an `await` and read afterwards.
+ val renameMap: Map[Symbol, TermName] = {
+ val analyzer = new builder.AsyncAnalyzer
+ analyzer.traverse(anfTree)
+ analyzer.valDefsToLift.map {
+ vd =>
+ (vd.symbol, builder.name.fresh(vd.name))
+ }.toMap
+ }
- asyncBlockBuilder.asyncStates foreach (s => AsyncUtils.vprintln(s))
+ val startState = builder.stateAssigner.nextState()
+ val endState = Int.MaxValue
+ val asyncBlockBuilder = new builder.AsyncBlockBuilder(anfTree.stats, anfTree.expr, startState, endState, renameMap)
val handlerCases: List[CaseDef] = asyncBlockBuilder.mkCombinedHandlerCases[T]()
- val initStates = asyncBlockBuilder.asyncStates.init
- val localVarTrees = initStates.flatMap(_.allVarDefs).toList
+ import asyncBlockBuilder.asyncStates
+ logDiagnostics(c)(anfTree, asyncStates.map(_.toString))
+ val initStates = asyncStates.init
+ val localVarTrees = asyncStates.flatMap(_.allVarDefs).toList
/*
lazy val onCompleteHandler = (tr: Try[Any]) => state match {
@@ -98,7 +113,7 @@ abstract class AsyncBase {
...
*/
val onCompleteHandler = {
- val onCompleteHandlers = initStates.flatMap(_.mkOnCompleteHandler).toList
+ val onCompleteHandlers = initStates.flatMap(_.mkOnCompleteHandler()).toList
Function(
List(ValDef(Modifiers(PARAM), name.tr, TypeTree(TryAnyType), EmptyTree)),
Match(Ident(name.state), onCompleteHandlers))
@@ -144,7 +159,7 @@ abstract class AsyncBase {
// Spawn a future to:
futureSystemOps.future[Unit] {
c.Expr[Unit](Block(
- // define vars for all intermediate results
+ // define vars for all intermediate results that are accessed from multiple states
localVarTrees :+
// define the resume() method
resumeFunTree :+
@@ -159,8 +174,22 @@ abstract class AsyncBase {
// ... and return its Future from the macro.
val result = futureSystemOps.promiseToFuture(prom)
- AsyncUtils.vprintln(s"${c.macroApplication} \nexpands to:\n ${result.tree}")
+ AsyncUtils.vprintln(s"async state machine transform expands to:\n ${result.tree}")
result
}
+
+ def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) {
+ def location = try {
+ c.macroApplication.pos.source.path
+ } catch {
+ case _: UnsupportedOperationException =>
+ c.macroApplication.pos.toString
+ }
+
+ AsyncUtils.vprintln(s"In file '$location':")
+ AsyncUtils.vprintln(s"${c.macroApplication}")
+ AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree")
+ states foreach (s => AsyncUtils.vprintln(s))
+ }
}
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala
index 2b35ff4..7a9c98d 100644
--- a/src/main/scala/scala/async/ExprBuilder.scala
+++ b/src/main/scala/scala/async/ExprBuilder.scala
@@ -5,17 +5,16 @@ package scala.async
import scala.reflect.macros.Context
import scala.collection.mutable.ListBuffer
-import scala.concurrent.Future
-import AsyncUtils.vprintln
+import collection.mutable
/*
* @author Philipp Haller
*/
-final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSystem: FS) {
+final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val futureSystem: FS)
+ extends TransformUtils(c) {
builder =>
import c.universe._
- import Flag._
import defn._
private[async] object name {
@@ -38,8 +37,6 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
private[async] lazy val futureSystemOps = futureSystem.mkOps(c)
- private val execContext = futureSystemOps.execContext
-
private def resetDuplicate(tree: Tree) = c.resetAllAttrs(tree.duplicate)
private def mkResumeApply = Apply(Ident(name.resume), Nil)
@@ -50,14 +47,6 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
private def mkStateTree(nextState: Tree): c.Tree =
Assign(Ident(name.state), nextState)
- private def defaultValue(tpe: Type): Literal = {
- val defaultValue: Any =
- if (tpe <:< definitions.BooleanTpe) false
- else if (definitions.ScalaNumericValueClasses.exists(tpe <:< _.toType)) 0
- else null
- Literal(Constant(defaultValue))
- }
-
private def mkVarDefTree(resultType: Type, resultName: TermName): c.Tree = {
ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), defaultValue(resultType))
}
@@ -87,7 +76,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
Ident(aw.resultName),
TypeApply(Select(Select(Ident(name.tr), Try_get), newTermName("asInstanceOf")), List(TypeTree(aw.resultType)))
)
- val updateState = mkStateTree(nextState) // or increment?
+ val updateState = mkStateTree(nextState)
Some(mkHandlerCase(state, List(tryGetTree, updateState, mkResumeApply)))
case _ =>
None
@@ -142,7 +131,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
/*
* Builder for a single state of an async method.
*/
- class AsyncStateBuilder(state: Int, private var nameMap: Map[c.Symbol, c.Name]) {
+ class AsyncStateBuilder(state: Int, private val nameMap: Map[Symbol, c.Name]) {
self =>
/* Statements preceding an await call. */
@@ -157,7 +146,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
/* Result type of an await call. */
var resultType: Type = null
- var nextState: Int = state + 1
+ var nextState: Int = -1
private val varDefs = ListBuffer[(TermName, Type)]()
@@ -176,9 +165,8 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
}
//TODO do not ignore `mods`
- def addVarDef(mods: Any, name: TermName, tpt: c.Tree, rhs: c.Tree, extNameMap: Map[c.Symbol, c.Name]): this.type = {
+ def addVarDef(mods: Any, name: TermName, tpt: c.Tree, rhs: c.Tree): this.type = {
varDefs += (name -> tpt.tpe)
- nameMap ++= extNameMap // update name map
this += Assign(Ident(name), rhs)
this
}
@@ -204,9 +192,9 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
* @param awaitResultType the type of the result of await
*/
def complete(awaitArg: c.Tree, awaitResultName: TermName, awaitResultType: Tree,
- extNameMap: Map[c.Symbol, c.Name], nextState: Int = state + 1): this.type = {
- nameMap ++= extNameMap
- awaitable = resetDuplicate(renamer.transform(awaitArg))
+ nextState: Int): this.type = {
+ val renamed = renamer.transform(awaitArg)
+ awaitable = resetDuplicate(renamed)
resultName = awaitResultName
resultType = awaitResultType.tpe
this.nextState = nextState
@@ -237,14 +225,13 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
*
* @param scrutTree tree of the scrutinee
* @param cases list of case definitions
- * @param stateFirstCase state of the right-hand side of the first case
- * @param perCaseBudget maximum number of states per case
+ * @param caseStates starting state of the right-hand side of the each case
* @return an `AsyncState` representing the match expression
*/
- def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], stateFirstCase: Int, perCasebudget: Int): AsyncState = {
+ def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], caseStates: List[Int]): AsyncState = {
// 1. build list of changed cases
val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match {
- case CaseDef(pat, guard, rhs) => CaseDef(pat, guard, Block(mkStateTree(num * perCasebudget + stateFirstCase), mkResumeApply))
+ case CaseDef(pat, guard, rhs) => CaseDef(pat, guard, Block(mkStateTree(caseStates(num)), mkResumeApply))
}
// 2. insert changed match tree at the end of the current state
this += Match(resetDuplicate(scrutTree), newCases)
@@ -259,6 +246,8 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
}
}
+ val stateAssigner = new StateAssigner
+
/**
* An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`).
*
@@ -266,106 +255,84 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
* @param expr the last expression of the block
* @param startState the start state
* @param endState the state to continue with
- * @param budget the maximum number of states in this block
* @param toRename a `Map` for renaming the given key symbols to the mangled value names
*/
class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int,
- budget: Int, private var toRename: Map[c.Symbol, c.Name]) {
+ private val toRename: Map[Symbol, c.Name]) {
val asyncStates = ListBuffer[builder.AsyncState]()
private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename)
// current state builder
private var currState = startState
- private var remainingBudget = budget
-
/* TODO Fall back to CPS plug-in if tree contains an `await` call. */
def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists {
- case Apply(fun, _) if fun.symbol == Async_await => true
+ case Apply(fun, _) if isAwait(fun) => true
case _ => false
}) c.abort(tree.pos, "await unsupported in this position") //throw new FallbackToCpsException
- def builderForBranch(tree: c.Tree, state: Int, nextState: Int, budget: Int, nameMap: Map[c.Symbol, c.Name]): AsyncBlockBuilder = {
+ def builderForBranch(tree: c.Tree, state: Int, nextState: Int): AsyncBlockBuilder = {
val (branchStats, branchExpr) = tree match {
case Block(s, e) => (s, e)
case _ => (List(tree), c.literalUnit.tree)
}
- new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, budget, nameMap)
+ new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, toRename)
}
// populate asyncStates
for (stat <- stats) stat match {
// the val name = await(..) pattern
- case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == Async_await =>
- val newName = builder.name.fresh(name)
- toRename += (stat.symbol -> newName)
-
- asyncStates += stateBuilder.complete(args.head, newName, tpt, toRename).result // complete with await
- if (remainingBudget > 0)
- remainingBudget -= 1
- else
- assert(false, "too many invocations of `await` in current method")
- currState += 1
+ case ValDef(mods, name, tpt, Apply(fun, args)) if isAwait(fun) =>
+ val afterAwaitState = stateAssigner.nextState()
+ asyncStates += stateBuilder.complete(args.head, toRename(stat.symbol).toTermName, tpt, afterAwaitState).result // complete with await
+ currState = afterAwaitState
stateBuilder = new builder.AsyncStateBuilder(currState, toRename)
- case ValDef(mods, name, tpt, rhs) =>
+ case ValDef(mods, name, tpt, rhs) if toRename contains stat.symbol =>
checkForUnsupportedAwait(rhs)
- val newName = builder.name.fresh(name)
- toRename += (stat.symbol -> newName)
// when adding assignment need to take `toRename` into account
- stateBuilder.addVarDef(mods, newName, tpt, rhs, toRename)
+ stateBuilder.addVarDef(mods, toRename(stat.symbol).toTermName, tpt, rhs)
- case If(cond, thenp, elsep) =>
+ case If(cond, thenp, elsep) if stat exists isAwait =>
checkForUnsupportedAwait(cond)
- val ifBudget: Int = remainingBudget / 2
- remainingBudget -= ifBudget //TODO test if budget > 0
- // state that we continue with after if-else: currState + ifBudget
-
- val thenBudget: Int = ifBudget / 2
- val elseBudget = ifBudget - thenBudget
+ val thenStartState = stateAssigner.nextState()
+ val elseStartState = stateAssigner.nextState()
+ val afterIfState = stateAssigner.nextState()
asyncStates +=
// the two Int arguments are the start state of the then branch and the else branch, respectively
- stateBuilder.resultWithIf(cond, currState + 1, currState + thenBudget)
+ stateBuilder.resultWithIf(cond, thenStartState, elseStartState)
- List((thenp, currState + 1, thenBudget), (elsep, currState + thenBudget, elseBudget)) foreach {
- case (tree, state, branchBudget) =>
- val builder = builderForBranch(tree, state, currState + ifBudget, branchBudget, toRename)
+ List((thenp, thenStartState), (elsep, elseStartState)) foreach {
+ case (tree, state) =>
+ val builder = builderForBranch(tree, state, afterIfState)
asyncStates ++= builder.asyncStates
- toRename ++= builder.toRename
}
- // create new state builder for state `currState + ifBudget`
- currState = currState + ifBudget
+ currState = afterIfState
stateBuilder = new builder.AsyncStateBuilder(currState, toRename)
- case Match(scrutinee, cases) =>
- vprintln("transforming match expr: " + stat)
+ case Match(scrutinee, cases) if stat exists isAwait =>
checkForUnsupportedAwait(scrutinee)
- val matchBudget: Int = remainingBudget / 2
- remainingBudget -= matchBudget //TODO test if budget > 0
- // state that we continue with after match: currState + matchBudget
+ val caseStates = cases.map(_ => stateAssigner.nextState())
+ val afterMatchState = stateAssigner.nextState()
- val perCaseBudget: Int = matchBudget / cases.size
asyncStates +=
- // the two Int arguments are the start state of the first case and the per-case state budget, respectively
- stateBuilder.resultWithMatch(scrutinee, cases, currState + 1, perCaseBudget)
+ stateBuilder.resultWithMatch(scrutinee, cases, caseStates)
for ((cas, num) <- cases.zipWithIndex) {
val (casStats, casExpr) = cas match {
case CaseDef(_, _, Block(s, e)) => (s, e)
case CaseDef(_, _, rhs) => (List(rhs), c.literalUnit.tree)
}
- val builder = new AsyncBlockBuilder(casStats, casExpr, currState + (num * perCaseBudget) + 1, currState + matchBudget, perCaseBudget, toRename)
+ val builder = new AsyncBlockBuilder(casStats, casExpr, caseStates(num), afterMatchState, toRename)
asyncStates ++= builder.asyncStates
- toRename ++= builder.toRename
}
- // create new state builder for state `currState + matchBudget`
- currState = currState + matchBudget
+ currState = afterMatchState
stateBuilder = new builder.AsyncStateBuilder(currState, toRename)
case ClassDef(_, name, _, _) =>
@@ -382,7 +349,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
}
// complete last state builder (representing the expressions after the last await)
stateBuilder += expr
- val lastState = stateBuilder.complete(endState).result
+ val lastState = stateBuilder.complete(endState).result()
asyncStates += lastState
def mkCombinedHandlerCases[T](): List[CaseDef] = {
@@ -402,9 +369,106 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
}
}
+ private val Boolean_ShortCircuits: Set[Symbol] = {
+ import definitions.BooleanClass
+ def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName)
+ val Boolean_&& = BooleanTermMember("&&")
+ val Boolean_|| = BooleanTermMember("||")
+ Set(Boolean_&&, Boolean_||)
+ }
+
+ def isByName(fun: Tree): (Int => Boolean) = {
+ if (Boolean_ShortCircuits contains fun.symbol) i => true
+ else fun.tpe match {
+ case MethodType(params, _) =>
+ val isByNameParams = params.map(_.asTerm.isByNameParam)
+ (i: Int) => isByNameParams.applyOrElse(i, (_: Int) => false)
+ case _ => Map()
+ }
+ }
+
+ private def isAwait(fun: Tree) = {
+ fun.symbol == defn.Async_await
+ }
+
+ /**
+ * Analyze the contents of an `async` block in order to:
+ * - Report unsupported `await` calls under nested templates, functions, by-name arguments.
+ * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based
+ * on whether or not they are accessed only from a single state.
+ */
+ private[async] class AsyncAnalyzer extends Traverser {
+ private var chunkId = 0
+ private def nextChunk() = chunkId += 1
+ private var valDefChunkId = Map[Symbol, (ValDef, Int)]()
+
+ val valDefsToLift: mutable.Set[ValDef] = collection.mutable.Set[ValDef]()
+
+ override def traverse(tree: Tree) = {
+ tree match {
+ case cd: ClassDef =>
+ val kind = if (cd.symbol.asClass.isTrait) "trait" else "class"
+ reportUnsupportedAwait(tree, s"nested $kind")
+ case md: ModuleDef =>
+ reportUnsupportedAwait(tree, "nested object")
+ case _: Function =>
+ reportUnsupportedAwait(tree, "nested anonymous function")
+ case If(cond, thenp, elsep) if tree exists isAwait =>
+ traverseChunks(List(cond, thenp, elsep))
+ case Match(selector, cases) if tree exists isAwait =>
+ traverseChunks(selector :: cases)
+ case Apply(fun, args) if isAwait(fun) =>
+ traverseTrees(args)
+ traverse(fun)
+ nextChunk()
+ case Apply(fun, args) =>
+ val isInByName = isByName(fun)
+ for ((arg, index) <- args.zipWithIndex) {
+ if (!isInByName(index)) traverse(arg)
+ else reportUnsupportedAwait(arg, "by-name argument")
+ }
+ traverse(fun)
+ case vd: ValDef =>
+ super.traverse(tree)
+ valDefChunkId += (vd.symbol ->(vd, chunkId))
+ if (isAwait(vd.rhs)) valDefsToLift += vd
+ case as: Assign =>
+ if (isAwait(as.rhs)) {
+ // TODO test the orElse case, try to remove the restriction.
+ val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block."))
+ valDefsToLift += vd
+ }
+ super.traverse(tree)
+ case rt: RefTree =>
+ valDefChunkId.get(rt.symbol) match {
+ case Some((vd, defChunkId)) if defChunkId != chunkId =>
+ valDefsToLift += vd
+ case _ =>
+ }
+ super.traverse(tree)
+ case _ => super.traverse(tree)
+ }
+ }
+
+ private def traverseChunks(trees: List[Tree]) {
+ trees.foreach {t => traverse(t); nextChunk()}
+ }
+
+ private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) {
+ val badAwaits = tree collect {
+ case rt: RefTree if isAwait(rt) => rt
+ }
+ badAwaits foreach {
+ tree =>
+ c.error(tree.pos, s"await must not be used under a $whyUnsupported.")
+ }
+ }
+ }
+
+
/** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */
private def methodSym(apply: c.Expr[Any]): Symbol = {
- val tree2: Tree = c.typeCheck(apply.tree) // TODO why is this needed?
+ val tree2: Tree = c.typeCheck(apply.tree)
tree2.collect {
case s: SymTree if s.symbol.isMethod => s.symbol
}.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}"))
@@ -421,10 +485,6 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
self.splice.apply(arg.splice)
}
- def mkInt_+(self: Expr[Int])(other: Expr[Int]) = reify {
- self.splice + other.splice
- }
-
def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify {
self.splice == other.splice
}
@@ -445,5 +505,4 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy
tpe.member(c.universe.newTermName("await"))
}
}
-
}
diff --git a/src/main/scala/scala/async/StateAssigner.scala b/src/main/scala/scala/async/StateAssigner.scala
new file mode 100644
index 0000000..4f6c5a0
--- /dev/null
+++ b/src/main/scala/scala/async/StateAssigner.scala
@@ -0,0 +1,10 @@
+package scala.async
+
+private[async] final class StateAssigner {
+ private var current = -1
+
+ def nextState(): Int = {
+ current += 1
+ current
+ }
+} \ No newline at end of file
diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala
new file mode 100644
index 0000000..d36c277
--- /dev/null
+++ b/src/main/scala/scala/async/TransformUtils.scala
@@ -0,0 +1,22 @@
+/**
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+package scala.async
+
+import scala.reflect.macros.Context
+
+/**
+ * Utilities used in both `ExprBuilder` and `AnfTransform`.
+ */
+class TransformUtils[C <: Context](val c: C) {
+ import c.universe._
+
+ protected def defaultValue(tpe: Type): Literal = {
+ val defaultValue: Any =
+ if (tpe <:< definitions.BooleanTpe) false
+ else if (definitions.ScalaNumericValueClasses.exists(tpe <:< _.toType)) 0
+ else if (tpe <:< definitions.AnyValTpe) 0
+ else null
+ Literal(Constant(defaultValue))
+ }
+}
diff --git a/src/test/scala/scala/async/TestUtils.scala b/src/test/scala/scala/async/TestUtils.scala
index bac22a3..0ae78b8 100644
--- a/src/test/scala/scala/async/TestUtils.scala
+++ b/src/test/scala/scala/async/TestUtils.scala
@@ -50,9 +50,9 @@ trait TestUtils {
m.mkToolBox(options = compileOptions)
}
- def expectError(errorSnippet: String, compileOptions: String = "")(code: String) {
+ def expectError(errorSnippet: String, compileOptions: String = "", baseCompileOptions: String = "-cp target/scala-2.10/classes")(code: String) {
intercept[ToolBoxError] {
- eval(code, compileOptions)
+ eval(code, compileOptions + " " + baseCompileOptions)
}.getMessage mustContain errorSnippet
}
}
diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala
index 1293bdf..1ed9be2 100644
--- a/src/test/scala/scala/async/TreeInterrogation.scala
+++ b/src/test/scala/scala/async/TreeInterrogation.scala
@@ -32,7 +32,6 @@ class TreeInterrogation {
val varDefs = tree1.collect {
case ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) => name
}
- // TODO no need to lift `y` as it is only accessed from a single state.
- varDefs.map(_.decoded).toSet mustBe(Set("state$async", "onCompleteHandler$async", "x$1", "z$1", "y$1"))
+ varDefs.map(_.decoded).toSet mustBe(Set("state$async", "onCompleteHandler$async", "x$1", "z$1"))
}
}
diff --git a/src/test/scala/scala/async/neg/NakedAwait.scala b/src/test/scala/scala/async/neg/NakedAwait.scala
index db67f18..66bc947 100644
--- a/src/test/scala/scala/async/neg/NakedAwait.scala
+++ b/src/test/scala/scala/async/neg/NakedAwait.scala
@@ -16,4 +16,76 @@ class NakedAwait {
""".stripMargin
}
}
+
+
+ @Test
+ def `await not allowed in by-name argument`() {
+ expectError("await must not be used under a by-name argument.") {
+ """
+ | import _root_.scala.async.AsyncId._
+ | def foo(a: Int)(b: => Int) = 0
+ | async { foo(0)(await(0)) }
+ """.stripMargin
+ }
+ }
+
+ @Test
+ def `await not allowed in boolean short circuit argument 1`() {
+ expectError("await must not be used under a by-name argument.") {
+ """
+ | import _root_.scala.async.AsyncId._
+ | async { true && await(false) }
+ """.stripMargin
+ }
+ }
+
+ @Test
+ def `await not allowed in boolean short circuit argument 2`() {
+ expectError("await must not be used under a by-name argument.") {
+ """
+ | import _root_.scala.async.AsyncId._
+ | async { true || await(false) }
+ """.stripMargin
+ }
+ }
+
+ @Test
+ def nestedObject() {
+ expectError("await must not be used under a nested object.") {
+ """
+ | import _root_.scala.async.AsyncId._
+ | async { object Nested { await(false) } }
+ """.stripMargin
+ }
+ }
+
+ @Test
+ def nestedTrait() {
+ expectError("await must not be used under a nested trait.") {
+ """
+ | import _root_.scala.async.AsyncId._
+ | async { trait Nested { await(false) } }
+ """.stripMargin
+ }
+ }
+
+ @Test
+ def nestedClass() {
+ expectError("await must not be used under a nested class.") {
+ """
+ | import _root_.scala.async.AsyncId._
+ | async { class Nested { await(false) } }
+ """.stripMargin
+ }
+ }
+
+ @Test
+ def nestedFunction() {
+ expectError("await must not be used under a nested anonymous function.") {
+ """
+ | import _root_.scala.async.AsyncId._
+ | async { () => { await(false) } }
+ """.stripMargin
+ }
+ }
}
diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
new file mode 100644
index 0000000..0abb937
--- /dev/null
+++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
@@ -0,0 +1,146 @@
+/**
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async
+package run
+package anf
+
+import language.{reflectiveCalls, postfixOps}
+import scala.concurrent.{Future, ExecutionContext, future, Await}
+import scala.concurrent.duration._
+import scala.async.Async.{async, await}
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+
+
+class AnfTestClass {
+
+ import ExecutionContext.Implicits.global
+
+ def base(x: Int): Future[Int] = future {
+ x + 2
+ }
+
+ def m(y: Int): Future[Int] = async {
+ val f = base(y)
+ await(f)
+ }
+
+ def m2(y: Int): Future[Int] = async {
+ val f = base(y)
+ val f2 = base(y + 1)
+ await(f) + await(f2)
+ }
+
+ def m3(y: Int): Future[Int] = async {
+ val f = base(y)
+ var z = 0
+ if (y > 0) {
+ z = await(f) + 2
+ } else {
+ z = await(f) - 2
+ }
+ z
+ }
+
+ def m4(y: Int): Future[Int] = async {
+ val f = base(y)
+ val z = if (y > 0) {
+ await(f) + 2
+ } else {
+ await(f) - 2
+ }
+ z + 1
+ }
+
+ def futureUnitIfElse(y: Int): Future[Unit] = async {
+ val f = base(y)
+ if (y > 0) {
+ State.result = await(f) + 2
+ } else {
+ State.result = await(f) - 2
+ }
+ }
+}
+
+object State {
+ @volatile var result: Int = 0
+}
+
+@RunWith(classOf[JUnit4])
+class AnfTransformSpec {
+
+ @Test
+ def `simple ANF transform`() {
+ val o = new AnfTestClass
+ val fut = o.m(10)
+ val res = Await.result(fut, 2 seconds)
+ res mustBe (12)
+ }
+
+ @Test
+ def `simple ANF transform 2`() {
+ val o = new AnfTestClass
+ val fut = o.m2(10)
+ val res = Await.result(fut, 2 seconds)
+ res mustBe (25)
+ }
+
+ @Test
+ def `simple ANF transform 3`() {
+ val o = new AnfTestClass
+ val fut = o.m3(10)
+ val res = Await.result(fut, 2 seconds)
+ res mustBe (14)
+ }
+
+ @Test
+ def `ANF transform of assigning the result of an if-else`() {
+ val o = new AnfTestClass
+ val fut = o.m4(10)
+ val res = Await.result(fut, 2 seconds)
+ res mustBe (15)
+ }
+
+ @Test
+ def `Unit-typed if-else in tail position`() {
+ val o = new AnfTestClass
+ val fut = o.futureUnitIfElse(10)
+ Await.result(fut, 2 seconds)
+ State.result mustBe (14)
+ }
+
+ @Test
+ def `inlining block produces duplicate definition`() {
+ import scala.async.AsyncId
+
+ AsyncId.async {
+ val f = 12
+ val x = AsyncId.await(f)
+
+ {
+ val x = 42
+ println(x)
+ }
+
+ x
+ }
+ }
+ @Test
+ def `inlining block in tail position produces duplicate definition`() {
+ import scala.async.AsyncId
+
+ AsyncId.async {
+ val f = 12
+ val x = AsyncId.await(f)
+
+ {
+ val x = 42 // TODO should we rename the symbols when we collapse them into the same scope?
+ x
+ }
+ } mustBe (42)
+
+ }
+}