aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/scala/async/Async.scala (renamed from src/async/library/scala/async/Async.scala)3
-rw-r--r--src/main/scala/scala/async/AsyncUtils.scala (renamed from src/async/library/scala/async/AsyncUtils.scala)0
-rw-r--r--src/main/scala/scala/async/ExprBuilder.scala (renamed from src/async/library/scala/async/ExprBuilder.scala)268
-rw-r--r--src/test/scala/scala/async/TestLatch.scala36
-rw-r--r--src/test/scala/scala/async/TestUtils.scala38
-rw-r--r--src/test/scala/scala/async/neg/SampleNegSpec.scala25
-rw-r--r--src/test/scala/scala/async/neg/package.scala11
-rw-r--r--src/test/scala/scala/async/package.scala5
-rw-r--r--src/test/scala/scala/async/run/await0/Await0Spec.scala78
-rw-r--r--src/test/scala/scala/async/run/block0/AsyncSpec.scala61
-rw-r--r--src/test/scala/scala/async/run/block1/block1.scala46
-rw-r--r--src/test/scala/scala/async/run/ifelse0/IfElse0.scala51
-rw-r--r--src/test/scala/scala/async/run/ifelse1/IfElse1.scala131
-rw-r--r--src/test/scala/scala/async/run/ifelse2/ifelse2.scala51
-rw-r--r--src/test/scala/scala/async/run/ifelse3/IfElse3.scala54
15 files changed, 722 insertions, 136 deletions
diff --git a/src/async/library/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala
index 2c81bc3..0bf4362 100644
--- a/src/async/library/scala/async/Async.scala
+++ b/src/main/scala/scala/async/Async.scala
@@ -74,13 +74,14 @@ object Async extends AsyncUtils {
}
}
*/
+ val nonFatalModule = c.mirror.staticModule("scala.util.control.NonFatal")
val resumeFunTree: c.Tree = DefDef(Modifiers(), newTermName("resume"), List(), List(List()), Ident(definitions.UnitClass),
Try(Apply(Select(
Apply(Select(handlerExpr.tree, newTermName("orElse")), List(handlerForLastState.tree)),
newTermName("apply")), List(Ident(newTermName("state")))),
List(
CaseDef(
- Apply(Select(Select(Select(Ident(newTermName("scala")), newTermName("util")), newTermName("control")), newTermName("NonFatal")), List(Bind(newTermName("t"), Ident(nme.WILDCARD)))),
+ Apply(Ident(nonFatalModule), List(Bind(newTermName("t"), Ident(nme.WILDCARD)))),
EmptyTree,
Block(List(
Apply(Select(Ident(newTermName("result")), newTermName("failure")), List(Ident(newTermName("t"))))),
diff --git a/src/async/library/scala/async/AsyncUtils.scala b/src/main/scala/scala/async/AsyncUtils.scala
index 19e9d92..19e9d92 100644
--- a/src/async/library/scala/async/AsyncUtils.scala
+++ b/src/main/scala/scala/async/AsyncUtils.scala
diff --git a/src/async/library/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala
index b7d6446..655c26f 100644
--- a/src/async/library/scala/async/ExprBuilder.scala
+++ b/src/main/scala/scala/async/ExprBuilder.scala
@@ -4,19 +4,19 @@
package scala.async
import scala.reflect.macros.Context
-import scala.collection.mutable.{ ListBuffer, Builder }
+import scala.collection.mutable.{ListBuffer, Builder}
/*
* @author Philipp Haller
*/
class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
builder =>
-
+
import c.universe._
import Flag._
-
+
private val awaitMethod = awaitSym(c)
-
+
/* Make a partial function literal handling case #num:
*
* {
@@ -24,42 +24,46 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
* }
*/
def mkHandler(num: Int, rhs: c.Expr[Unit]): c.Expr[PartialFunction[Int, Unit]] = {
-/*
- val numLiteral = c.Expr[Int](Literal(Constant(num)))
-
- reify(new PartialFunction[Int, Unit] {
- def isDefinedAt(`x$1`: Int) =
- `x$1` == numLiteral.splice
- def apply(`x$1`: Int) = `x$1` match {
- case any: Int if any == numLiteral.splice =>
- rhs.splice
- }
- })
-*/
+ /*
+ val numLiteral = c.Expr[Int](Literal(Constant(num)))
+
+ reify(new PartialFunction[Int, Unit] {
+ def isDefinedAt(`x$1`: Int) =
+ `x$1` == numLiteral.splice
+ def apply(`x$1`: Int) = `x$1` match {
+ case any: Int if any == numLiteral.splice =>
+ rhs.splice
+ }
+ })
+ */
val rhsTree = c.resetAllAttrs(rhs.tree.duplicate)
val handlerTree = mkHandlerTree(num, rhsTree)
c.Expr(handlerTree).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]]
}
-
- def mkIncrStateTree(): c.Tree =
+
+ def mkIncrStateTree(): c.Tree = {
Assign(
Ident(newTermName("state")),
Apply(Select(Ident(newTermName("state")), newTermName("$plus")), List(Literal(Constant(1)))))
-
+ }
+
def mkStateTree(nextState: Int): c.Tree =
Assign(
Ident(newTermName("state")),
Literal(Constant(nextState)))
-
- def mkVarDefTree(resultType: c.universe.Type, resultName: c.universe.TermName): c.Tree = {
- val rhs =
- if (resultType <:< definitions.IntTpe) Literal(Constant(0))
- else if (resultType <:< definitions.LongTpe) Literal(Constant(0L))
- else if (resultType <:< definitions.BooleanTpe) Literal(Constant(false))
- else Literal(Constant(null))
- ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), rhs)
+
+ def defaultValue(tpe: Type): Literal = {
+ val defaultValue: Any =
+ if (tpe <:< definitions.BooleanTpe) false
+ else if (definitions.ScalaNumericValueClasses.exists(tpe <:< _.toType)) 0
+ else null
+ Literal(Constant(defaultValue))
}
-
+
+ def mkVarDefTree(resultType: Type, resultName: TermName): c.Tree = {
+ ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), defaultValue(resultType))
+ }
+
def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef =
CaseDef(
// pattern
@@ -68,26 +72,27 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
Apply(Select(Ident(newTermName("any")), newTermName("$eq$eq")), List(Literal(Constant(num)))),
rhs
)
-
+
def mkHandlerTreeFor(cases: List[(CaseDef, Int)]): c.Tree = {
val partFunIdent = Ident(c.mirror.staticClass("scala.PartialFunction"))
val intIdent = Ident(definitions.IntClass)
val unitIdent = Ident(definitions.UnitClass)
-
+
val caseCheck =
- Apply(Select(Apply(Select(Ident(newTermName("List")), newTermName("apply")),
- cases.map(p => Literal(Constant(p._2)))), newTermName("contains")), List(Ident(newTermName("x$1"))))
-
+ Apply(Select(Apply(Ident(definitions.List_apply),
+ cases.map(p => Literal(Constant(p._2)))), newTermName("contains")), List(Ident(newTermName("x$1"))))
+
Block(List(
// anonymous subclass of PartialFunction[Int, Unit]
+ // TODO subclass AbstractPartialFunction
ClassDef(Modifiers(FINAL), newTypeName("$anon"), List(), Template(List(AppliedTypeTree(partFunIdent, List(intIdent, unitIdent))),
emptyValDef, List(
DefDef(Modifiers(), nme.CONSTRUCTOR, List(), List(List()), TypeTree(),
Block(List(Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), List())), Literal(Constant(())))),
-
+
DefDef(Modifiers(), newTermName("isDefinedAt"), List(), List(List(ValDef(Modifiers(PARAM), newTermName("x$1"), intIdent, EmptyTree))), TypeTree(),
- caseCheck),
-
+ caseCheck),
+
DefDef(Modifiers(), newTermName("apply"), List(), List(List(ValDef(Modifiers(PARAM), newTermName("x$1"), intIdent, EmptyTree))), TypeTree(),
Match(Ident(newTermName("x$1")), cases.map(_._1)) // combine all cases into a single match
)
@@ -96,60 +101,61 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
Apply(Select(New(Ident(newTypeName("$anon"))), nme.CONSTRUCTOR), List())
)
}
-
+
def mkHandlerTree(num: Int, rhs: c.Tree): c.Tree =
mkHandlerTreeFor(List(mkHandlerCase(num, rhs) -> num))
-
+
class AsyncState(stats: List[c.Tree], val state: Int, val nextState: Int) {
val body: c.Tree =
if (stats.size == 1) stats.head
else Block(stats: _*)
-
- val varDefs: List[(c.universe.TermName, c.universe.Type)] = List()
-
+
+ val varDefs: List[(TermName, Type)] = List()
+
def mkHandlerCaseForState(): CaseDef =
mkHandlerCase(state, Block((stats :+ mkStateTree(nextState) :+ Apply(Ident("resume"), List())): _*))
-
+
def mkHandlerTreeForState(): c.Tree =
mkHandlerTree(state, Block((stats :+ mkStateTree(nextState) :+ Apply(Ident("resume"), List())): _*))
-
+
def mkHandlerTreeForState(nextState: Int): c.Tree =
mkHandlerTree(state, Block((stats :+ mkStateTree(nextState) :+ Apply(Ident("resume"), List())): _*))
-
+
def varDefForResult: Option[c.Tree] =
None
-
+
def allVarDefs: List[c.Tree] =
varDefForResult.toList ++ varDefs.map(p => mkVarDefTree(p._2, p._1))
-
+
override val toString: String =
s"AsyncState #$state, next = $nextState"
}
-
+
class AsyncStateWithIf(stats: List[c.Tree], state: Int)
- extends AsyncState(stats, state, 0) { // nextState unused, since encoded in then and else branches
-
+ extends AsyncState(stats, state, 0) {
+ // nextState unused, since encoded in then and else branches
+
override def mkHandlerTreeForState(): c.Tree =
mkHandlerTree(state, Block(stats: _*))
-
+
//TODO mkHandlerTreeForState(nextState: Int)
-
+
override def mkHandlerCaseForState(): CaseDef =
mkHandlerCase(state, Block(stats: _*))
-
+
override val toString: String =
s"AsyncStateWithIf #$state, next = $nextState"
}
-
+
abstract class AsyncStateWithAwait(stats: List[c.Tree], state: Int, nextState: Int)
- extends AsyncState(stats, state, nextState) {
+ extends AsyncState(stats, state, nextState) {
val awaitable: c.Tree
- val resultName: c.universe.TermName
- val resultType: c.universe.Type
-
+ val resultName: TermName
+ val resultType: Type
+
override val toString: String =
s"AsyncStateWithAwait #$state, next = $nextState"
-
+
/* Make an `onComplete` invocation:
*
* awaitable.onComplete {
@@ -162,23 +168,23 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
val assignTree =
Assign(
Ident(resultName.toString),
- Select(Ident("tr"), c.universe.newTermName("get"))
+ Select(Ident("tr"), newTermName("get"))
)
val handlerTree =
Match(
EmptyTree,
List(
- CaseDef(Bind(c.universe.newTermName("tr"), Ident("_")), EmptyTree,
+ CaseDef(Bind(newTermName("tr"), Ident("_")), EmptyTree,
Block(assignTree, Apply(Ident("resume"), List())) // rhs of case
)
)
)
Apply(
- Select(awaitable, c.universe.newTermName("onComplete")),
+ Select(awaitable, newTermName("onComplete")),
List(handlerTree)
)
}
-
+
/* Make an `onComplete` invocation which increments the state upon resuming:
*
* awaitable.onComplete {
@@ -192,19 +198,19 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
val tryGetTree =
Assign(
Ident(resultName.toString),
- Select(Ident("tr"), c.universe.newTermName("get"))
+ Select(Ident("tr"), newTermName("get"))
)
val handlerTree =
Match(
EmptyTree,
List(
- CaseDef(Bind(c.universe.newTermName("tr"), Ident("_")), EmptyTree,
+ CaseDef(Bind(newTermName("tr"), Ident("_")), EmptyTree,
Block(tryGetTree, mkIncrStateTree(), Apply(Ident("resume"), List())) // rhs of case
)
)
)
Apply(
- Select(awaitable, c.universe.newTermName("onComplete")),
+ Select(awaitable, newTermName("onComplete")),
List(handlerTree)
)
}
@@ -222,19 +228,19 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
val tryGetTree =
Assign(
Ident(resultName.toString),
- Select(Ident("tr"), c.universe.newTermName("get"))
+ Select(Ident("tr"), newTermName("get"))
)
val handlerTree =
Match(
EmptyTree,
List(
- CaseDef(Bind(c.universe.newTermName("tr"), Ident("_")), EmptyTree,
+ CaseDef(Bind(newTermName("tr"), Ident("_")), EmptyTree,
Block(tryGetTree, mkStateTree(nextState), Apply(Ident("resume"), List())) // rhs of case
)
)
)
Apply(
- Select(awaitable, c.universe.newTermName("onComplete")),
+ Select(awaitable, newTermName("onComplete")),
List(handlerTree)
)
}
@@ -255,7 +261,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
assert(awaitable != null)
builder.mkHandler(num, c.Expr[Unit](Block((stats :+ mkOnCompleteTree): _*)))
}
-
+
/* Make a partial function literal handling case #num:
*
* {
@@ -273,56 +279,43 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
assert(awaitable != null)
mkHandlerTree(state, Block((stats :+ mkOnCompleteIncrStateTree): _*))
}
-
+
override def mkHandlerTreeForState(nextState: Int): c.Tree = {
assert(awaitable != null)
mkHandlerTree(state, Block((stats :+ mkOnCompleteStateTree(nextState)): _*))
}
-
+
override def mkHandlerCaseForState(): CaseDef = {
assert(awaitable != null)
mkHandlerCase(state, Block((stats :+ mkOnCompleteIncrStateTree): _*))
}
-
- override def varDefForResult: Option[c.Tree] = {
- val rhs =
- if (resultType <:< definitions.IntTpe) Literal(Constant(0))
- else if (resultType <:< definitions.LongTpe) Literal(Constant(0L))
- else if (resultType <:< definitions.BooleanTpe) Literal(Constant(false))
- else if (resultType <:< definitions.FloatTpe) Literal(Constant(0.0f))
- else if (resultType <:< definitions.DoubleTpe) Literal(Constant(0.0d))
- else if (resultType <:< definitions.CharTpe) Literal(Constant(0.toChar))
- else if (resultType <:< definitions.ShortTpe) Literal(Constant(0.toShort))
- else if (resultType <:< definitions.ByteTpe) Literal(Constant(0.toByte))
- else Literal(Constant(null))
- Some(
- ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), rhs)
- )
- }
+
+ override def varDefForResult: Option[c.Tree] =
+ Some(mkVarDefTree(resultType, resultName))
}
-
+
/*
* Builder for a single state of an async method.
*/
class AsyncStateBuilder(state: Int, private var nameMap: Map[c.Symbol, c.Name]) extends Builder[c.Tree, AsyncState] {
self =>
-
+
/* Statements preceding an await call. */
private val stats = ListBuffer[c.Tree]()
-
+
/* Argument of an await call. */
var awaitable: c.Tree = null
-
+
/* Result name of an await call. */
- var resultName: c.universe.TermName = null
-
+ var resultName: TermName = null
+
/* Result type of an await call. */
- var resultType: c.universe.Type = null
-
+ var resultType: Type = null
+
var nextState: Int = state + 1
-
- private val varDefs = ListBuffer[(c.universe.TermName, c.universe.Type)]()
-
+
+ private val varDefs = ListBuffer[(TermName, Type)]()
+
private val renamer = new Transformer {
override def transform(tree: Tree) = tree match {
case Ident(_) if nameMap.keySet contains tree.symbol =>
@@ -331,20 +324,20 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
super.transform(tree)
}
}
-
- def += (stat: c.Tree): this.type = {
+
+ def +=(stat: c.Tree): this.type = {
stats += c.resetAllAttrs(renamer.transform(stat).duplicate)
this
}
-
+
//TODO do not ignore `mods`
- def addVarDef(mods: Any, name: c.universe.TermName, tpt: c.Tree, rhs: c.Tree, extNameMap: Map[c.Symbol, c.Name]): this.type = {
+ def addVarDef(mods: Any, name: TermName, tpt: c.Tree, rhs: c.Tree, extNameMap: Map[c.Symbol, c.Name]): this.type = {
varDefs += (name -> tpt.tpe)
nameMap ++= extNameMap // update name map
this += Assign(Ident(name), c.resetAllAttrs(renamer.transform(rhs).duplicate))
this
}
-
+
def result(): AsyncState =
if (awaitable == null)
new AsyncState(stats.toList, state, nextState) {
@@ -357,14 +350,14 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
val resultType = self.resultType
override val varDefs = self.varDefs.toList
}
-
+
def clear(): Unit = {
stats.clear()
awaitable = null
resultName = null
resultType = null
}
-
+
/* Result needs to be created as a var at the beginning of the transformed method body, so that
* it is visible in subsequent states of the state machine.
*
@@ -372,7 +365,8 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
* @param awaitResultName the name of the variable that the result of await is assigned to
* @param awaitResultType the type of the result of await
*/
- def complete(awaitArg: c.Tree, awaitResultName: c.universe.TermName, awaitResultType: Tree, extNameMap: Map[c.Symbol, c.Name], nextState: Int = state + 1): this.type = {
+ def complete(awaitArg: c.Tree, awaitResultName: TermName, awaitResultType: Tree,
+ extNameMap: Map[c.Symbol, c.Name], nextState: Int = state + 1): this.type = {
nameMap ++= extNameMap
awaitable = c.resetAllAttrs(renamer.transform(awaitArg).duplicate)
resultName = awaitResultName
@@ -380,51 +374,53 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
this.nextState = nextState
this
}
-
+
def complete(nextState: Int): this.type = {
this.nextState = nextState
this
}
-
+
def resultWithIf(condTree: c.Tree, thenState: Int, elseState: Int): AsyncState = {
// 1. build changed if-else tree
// 2. insert that tree at the end of the current state
val cond = c.resetAllAttrs(condTree.duplicate)
this += If(cond,
- Block(mkStateTree(thenState), Apply(Ident("resume"), List())),
- Block(mkStateTree(elseState), Apply(Ident("resume"), List())))
+ Block(mkStateTree(thenState), Apply(Ident("resume"), List())),
+ Block(mkStateTree(elseState), Apply(Ident("resume"), List())))
new AsyncStateWithIf(stats.toList, state) {
override val varDefs = self.varDefs.toList
}
}
-
+
override def toString: String = {
val statsBeforeAwait = stats.mkString("\n")
s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName"
}
}
- class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, budget: Int, private var toRename: Map[c.Symbol, c.Name]) {
+ class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int,
+ budget: Int, private var toRename: Map[c.Symbol, c.Name]) {
val asyncStates = ListBuffer[builder.AsyncState]()
-
- private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename) // current state builder
+
+ private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename)
+ // current state builder
private var currState = startState
-
+
private var remainingBudget = budget
-
+
/* Fall back to CPS plug-in if tree contains an `await` call. */
def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists {
case Apply(fun, _) if fun.symbol == awaitMethod => true
case _ => false
}) throw new FallbackToCpsException
-
+
// populate asyncStates
for (stat <- stats) stat match {
// the val name = await(..) pattern
case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == awaitMethod =>
val newName = newTermName(Async.freshString(name.toString()))
toRename += (stat.symbol -> newName)
-
+
asyncStates += stateBuilder.complete(args(0), newName, tpt, toRename).result // complete with await
if (remainingBudget > 0)
remainingBudget -= 1
@@ -432,29 +428,29 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
assert(false, "too many invocations of `await` in current method")
currState += 1
stateBuilder = new builder.AsyncStateBuilder(currState, toRename)
-
+
case ValDef(mods, name, tpt, rhs) =>
checkForUnsupportedAwait(rhs)
-
+
val newName = newTermName(Async.freshString(name.toString()))
toRename += (stat.symbol -> newName)
// when adding assignment need to take `toRename` into account
stateBuilder.addVarDef(mods, newName, tpt, rhs, toRename)
-
+
case If(cond, thenp, elsep) =>
checkForUnsupportedAwait(cond)
-
+
val ifBudget: Int = remainingBudget / 2
remainingBudget -= ifBudget //TODO test if budget > 0
// state that we continue with after if-else: currState + ifBudget
-
+
val thenBudget: Int = ifBudget / 2
val elseBudget = ifBudget - thenBudget
-
+
asyncStates +=
// the two Int arguments are the start state of the then branch and the else branch, respectively
stateBuilder.resultWithIf(cond, currState + 1, currState + thenBudget)
-
+
val thenBuilder = thenp match {
case Block(thenStats, thenExpr) =>
new AsyncBlockBuilder(thenStats, thenExpr, currState + 1, currState + ifBudget, thenBudget, toRename)
@@ -463,7 +459,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
}
asyncStates ++= thenBuilder.asyncStates
toRename ++= thenBuilder.toRename
-
+
val elseBuilder = elsep match {
case Block(elseStats, elseExpr) =>
new AsyncBlockBuilder(elseStats, elseExpr, currState + thenBudget, currState + ifBudget, elseBudget, toRename)
@@ -472,11 +468,11 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
}
asyncStates ++= elseBuilder.asyncStates
toRename ++= elseBuilder.toRename
-
+
// create new state builder for state `currState + ifBudget`
currState = currState + ifBudget
stateBuilder = new builder.AsyncStateBuilder(currState, toRename)
-
+
case _ =>
checkForUnsupportedAwait(stat)
stateBuilder += stat
@@ -485,35 +481,37 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
stateBuilder += expr
val lastState = stateBuilder.complete(endState).result
asyncStates += lastState
-
+
def mkCombinedHandlerExpr(): c.Expr[PartialFunction[Int, Unit]] = {
assert(asyncStates.size > 1)
-
+
val cases = for (state <- asyncStates.toList) yield state.mkHandlerCaseForState()
c.Expr(mkHandlerTreeFor(cases zip asyncStates.init.map(_.state))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]]
}
-
+
/* Builds the handler expression for a sequence of async states.
*/
def mkHandlerExpr(): c.Expr[PartialFunction[Int, Unit]] = {
assert(asyncStates.size > 1)
-
+
var handlerExpr =
c.Expr(asyncStates(0).mkHandlerTreeForState()).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]]
-
+
if (asyncStates.size == 2)
handlerExpr
else {
- for (asyncState <- asyncStates.tail.init) { // do not traverse first or last state
+ for (asyncState <- asyncStates.tail.init) {
+ // do not traverse first or last state
val handlerTreeForNextState = asyncState.mkHandlerTreeForState()
val currentHandlerTreeNaked = c.resetAllAttrs(handlerExpr.tree.duplicate)
handlerExpr = c.Expr(
Apply(Select(currentHandlerTreeNaked, newTermName("orElse")),
- List(handlerTreeForNextState))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]]
+ List(handlerTreeForNextState))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]]
}
handlerExpr
}
}
}
+
}
diff --git a/src/test/scala/scala/async/TestLatch.scala b/src/test/scala/scala/async/TestLatch.scala
new file mode 100644
index 0000000..676ea63
--- /dev/null
+++ b/src/test/scala/scala/async/TestLatch.scala
@@ -0,0 +1,36 @@
+package scala.async
+
+import concurrent.{CanAwait, Awaitable}
+import concurrent.duration.Duration
+import java.util.concurrent.{TimeoutException, CountDownLatch, TimeUnit}
+
+object TestLatch {
+ val DefaultTimeout = Duration(5, TimeUnit.SECONDS)
+
+ def apply(count: Int = 1) = new TestLatch(count)
+}
+
+
+class TestLatch(count: Int = 1) extends Awaitable[Unit] {
+ private var latch = new CountDownLatch(count)
+
+ def countDown() = latch.countDown()
+
+ def isOpen: Boolean = latch.getCount == 0
+
+ def open() = while (!isOpen) countDown()
+
+ def reset() = latch = new CountDownLatch(count)
+
+ @throws(classOf[TimeoutException])
+ def ready(atMost: Duration)(implicit permit: CanAwait) = {
+ val opened = latch.await(atMost.toNanos, TimeUnit.NANOSECONDS)
+ if (!opened) throw new TimeoutException(s"Timeout of ${(atMost.toString)}.")
+ this
+ }
+
+ @throws(classOf[Exception])
+ def result(atMost: Duration)(implicit permit: CanAwait): Unit = {
+ ready(atMost)
+ }
+}
diff --git a/src/test/scala/scala/async/TestUtils.scala b/src/test/scala/scala/async/TestUtils.scala
new file mode 100644
index 0000000..f4def22
--- /dev/null
+++ b/src/test/scala/scala/async/TestUtils.scala
@@ -0,0 +1,38 @@
+package scala.async
+
+import language.reflectiveCalls
+import language.postfixOps
+import language.implicitConversions
+
+import scala.reflect.{ClassTag, classTag}
+
+import scala.collection.mutable
+import scala.concurrent.{Future, Awaitable, CanAwait}
+import java.util.concurrent.{TimeoutException, CountDownLatch, TimeUnit}
+import scala.concurrent.duration.Duration
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+
+
+trait TestUtils {
+ implicit class objectops(obj: Any) {
+ def mustBe(other: Any) = assert(obj == other, obj + " is not " + other)
+
+ def mustEqual(other: Any) = mustBe(other)
+ }
+
+ implicit class stringops(text: String) {
+ def mustContain(substring: String) = assert(text contains substring, text)
+ }
+
+ def intercept[T <: Throwable : ClassTag](body: => Any): T = {
+ try {
+ body
+ throw new Exception(s"Exception of type ${classTag[T]} was not thrown")
+ } catch {
+ case t: Throwable =>
+ if (classTag[T].runtimeClass != t.getClass) throw t
+ else t.asInstanceOf[T]
+ }
+ }
+}
diff --git a/src/test/scala/scala/async/neg/SampleNegSpec.scala b/src/test/scala/scala/async/neg/SampleNegSpec.scala
new file mode 100644
index 0000000..00daf44
--- /dev/null
+++ b/src/test/scala/scala/async/neg/SampleNegSpec.scala
@@ -0,0 +1,25 @@
+package scala.async
+package neg
+
+import java.io.File
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+import org.junit.Test
+import tools.reflect.ToolBoxError
+
+@RunWith(classOf[JUnit4])
+class SampleNegSpec {
+ val f = new File("/Users/jason/code/scala-async/test/files/run/await0")
+
+ @Test
+ def `missing symbol`() {
+ intercept[ToolBoxError] {
+ eval {
+ """
+ | kaboom
+ """.stripMargin
+ }
+ }.getMessage mustContain "not found: value kaboom"
+
+ }
+}
diff --git a/src/test/scala/scala/async/neg/package.scala b/src/test/scala/scala/async/neg/package.scala
new file mode 100644
index 0000000..1326394
--- /dev/null
+++ b/src/test/scala/scala/async/neg/package.scala
@@ -0,0 +1,11 @@
+package scala.async
+
+package object neg {
+ def eval(code: String): Any = {
+ val m = scala.reflect.runtime.currentMirror
+ import scala.tools.reflect.ToolBox
+ val tb = m.mkToolBox()
+ val result = tb.eval(tb.parse(code))
+ result
+ }
+}
diff --git a/src/test/scala/scala/async/package.scala b/src/test/scala/scala/async/package.scala
new file mode 100644
index 0000000..32e8be4
--- /dev/null
+++ b/src/test/scala/scala/async/package.scala
@@ -0,0 +1,5 @@
+package scala
+
+package object async extends TestUtils {
+
+}
diff --git a/src/test/scala/scala/async/run/await0/Await0Spec.scala b/src/test/scala/scala/async/run/await0/Await0Spec.scala
new file mode 100644
index 0000000..e7740e0
--- /dev/null
+++ b/src/test/scala/scala/async/run/await0/Await0Spec.scala
@@ -0,0 +1,78 @@
+package scala.async
+package run
+package await0
+
+/**
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+import language.{reflectiveCalls, postfixOps}
+
+import scala.concurrent.{Future, ExecutionContext, future, Await}
+import scala.concurrent.duration._
+import scala.async.Async.{async, await}
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+import org.junit.Test
+
+class Await0Class {
+
+ import ExecutionContext.Implicits.global
+
+ def m1(x: Double): Future[Double] = future {
+ Thread.sleep(200)
+ x + 2.0
+ }
+
+ def m2(x: Float): Future[Float] = future {
+ Thread.sleep(200)
+ x + 2.0f
+ }
+
+ def m3(x: Char): Future[Char] = future {
+ Thread.sleep(200)
+ (x.toInt + 2).toChar
+ }
+
+ def m4(x: Short): Future[Short] = future {
+ Thread.sleep(200)
+ (x + 2).toShort
+ }
+
+ def m5(x: Byte): Future[Byte] = future {
+ Thread.sleep(200)
+ (x + 2).toByte
+ }
+
+ def m0(y: Int): Future[Double] = async {
+ val f1 = m1(y.toDouble)
+ val x1: Double = await(f1)
+
+ val f2 = m2(y.toFloat)
+ val x2: Float = await(f2)
+
+ val f3 = m3(y.toChar)
+ val x3: Char = await(f3)
+
+ val f4 = m4(y.toShort)
+ val x4: Short = await(f4)
+
+ val f5 = m5(y.toByte)
+ val x5: Byte = await(f5)
+
+ x1 + x2 + 2.0
+ }
+}
+
+@RunWith(classOf[JUnit4])
+class Await0Spec {
+
+ @Test
+ def `An async method support a simple await`() {
+ val o = new Await0Class
+ val fut = o.m0(10)
+ val res = Await.result(fut, 10 seconds)
+ res mustBe (26.0)
+ }
+}
+
diff --git a/src/test/scala/scala/async/run/block0/AsyncSpec.scala b/src/test/scala/scala/async/run/block0/AsyncSpec.scala
new file mode 100644
index 0000000..f56e394
--- /dev/null
+++ b/src/test/scala/scala/async/run/block0/AsyncSpec.scala
@@ -0,0 +1,61 @@
+/**
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async
+package run
+package block0
+
+import language.{reflectiveCalls, postfixOps}
+import scala.concurrent.{Future, ExecutionContext, future, Await}
+import scala.concurrent.duration._
+import scala.async.Async.{async, await}
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+
+
+class Test1Class {
+
+ import ExecutionContext.Implicits.global
+
+ def m1(x: Int): Future[Int] = future {
+ Thread.sleep(1000)
+ x + 2
+ }
+
+ def m2(y: Int): Future[Int] = async {
+ val f = m1(y)
+ val x = await(f)
+ x + 2
+ }
+
+ def m3(y: Int): Future[Int] = async {
+ val f1 = m1(y)
+ val x1 = await(f1)
+ val f2 = m1(y + 2)
+ val x2 = await(f2)
+ x1 + x2
+ }
+}
+
+
+@RunWith(classOf[JUnit4])
+class AsyncSpec {
+
+ @Test
+ def `simple await`() {
+ val o = new Test1Class
+ val fut = o.m2(10)
+ val res = Await.result(fut, 2 seconds)
+ res mustBe (14)
+ }
+
+ @Test
+ def `several awaits in sequence`() {
+ val o = new Test1Class
+ val fut = o.m3(10)
+ val res = Await.result(fut, 4 seconds)
+ res mustBe (26)
+ }
+}
diff --git a/src/test/scala/scala/async/run/block1/block1.scala b/src/test/scala/scala/async/run/block1/block1.scala
new file mode 100644
index 0000000..8f21688
--- /dev/null
+++ b/src/test/scala/scala/async/run/block1/block1.scala
@@ -0,0 +1,46 @@
+/**
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async
+package run
+package block1
+
+import language.{reflectiveCalls, postfixOps}
+import scala.concurrent.{Future, ExecutionContext, future, Await}
+import scala.concurrent.duration._
+import scala.async.Async.{async, await}
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+
+
+class Test1Class {
+
+ import ExecutionContext.Implicits.global
+
+ def m1(x: Int): Future[Int] = future {
+ Thread.sleep(1000)
+ x + 2
+ }
+
+ def m4(y: Int): Future[Int] = async {
+ val f1 = m1(y)
+ val f2 = m1(y + 2)
+ val x1 = await(f1)
+ println("between two awaits")
+ val x2 = await(f2)
+ x1 + x2
+ }
+}
+
+@RunWith(classOf[JUnit4])
+class Block1Spec {
+
+ @Test def `support a simple await`() {
+ val o = new Test1Class
+ val fut = o.m4(10)
+ val res = Await.result(fut, 2 seconds)
+ res mustBe (26)
+ }
+}
diff --git a/src/test/scala/scala/async/run/ifelse0/IfElse0.scala b/src/test/scala/scala/async/run/ifelse0/IfElse0.scala
new file mode 100644
index 0000000..eca3acd
--- /dev/null
+++ b/src/test/scala/scala/async/run/ifelse0/IfElse0.scala
@@ -0,0 +1,51 @@
+/**
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async
+package run
+package ifelse0
+
+import language.{reflectiveCalls, postfixOps}
+import scala.concurrent.{Future, ExecutionContext, future, Await}
+import scala.concurrent.duration._
+import scala.async.Async.{async, await}
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+import org.junit.Test
+
+
+class TestIfElseClass {
+
+ import ExecutionContext.Implicits.global
+
+ def m1(x: Int): Future[Int] = future {
+ Thread.sleep(1000)
+ x + 2
+ }
+
+ def m2(y: Int): Future[Int] = async {
+ val f = m1(y)
+ var z = 0
+ if (y > 0) {
+ val x1 = await(f)
+ z = x1 + 2
+ } else {
+ val x2 = await(f)
+ z = x2 - 2
+ }
+ z
+ }
+}
+
+
+@RunWith(classOf[JUnit4])
+class IfElseSpec {
+
+ @Test def `support await in a simple if-else expression`() {
+ val o = new TestIfElseClass
+ val fut = o.m2(10)
+ val res = Await.result(fut, 2 seconds)
+ res mustBe (14)
+ }
+}
diff --git a/src/test/scala/scala/async/run/ifelse1/IfElse1.scala b/src/test/scala/scala/async/run/ifelse1/IfElse1.scala
new file mode 100644
index 0000000..128f02a
--- /dev/null
+++ b/src/test/scala/scala/async/run/ifelse1/IfElse1.scala
@@ -0,0 +1,131 @@
+/**
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async
+package run
+package ifelse1
+
+import language.{reflectiveCalls, postfixOps}
+import scala.concurrent.{Future, ExecutionContext, future, Await}
+import scala.concurrent.duration._
+import scala.async.Async.{async, await}
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+import org.junit.Test
+
+
+class TestIfElse1Class {
+
+ import ExecutionContext.Implicits.global
+
+ def base(x: Int): Future[Int] = future {
+ Thread.sleep(1000)
+ x + 2
+ }
+
+ def m1(y: Int): Future[Int] = async {
+ val f = base(y)
+ var z = 0
+ if (y > 0) {
+ if (y > 100)
+ 5
+ else {
+ val x1 = await(f)
+ z = x1 + 2
+ }
+ } else {
+ val x2 = await(f)
+ z = x2 - 2
+ }
+ z
+ }
+
+ def m2(y: Int): Future[Int] = async {
+ val f = base(y)
+ var z = 0
+ if (y > 0) {
+ if (y < 100) {
+ val x1 = await(f)
+ z = x1 + 2
+ }
+ else
+ 5
+ } else {
+ val x2 = await(f)
+ z = x2 - 2
+ }
+ z
+ }
+
+ def m3(y: Int): Future[Int] = async {
+ val f = base(y)
+ var z = 0
+ if (y < 0) {
+ val x2 = await(f)
+ z = x2 - 2
+ } else {
+ if (y > 100)
+ 5
+ else {
+ val x1 = await(f)
+ z = x1 + 2
+ }
+ }
+ z
+ }
+
+ def m4(y: Int): Future[Int] = async {
+ val f = base(y)
+ var z = 0
+ if (y < 0) {
+ val x2 = await(f)
+ z = x2 - 2
+ } else {
+ if (y < 100) {
+ val x1 = await(f)
+ z = x1 + 2
+ } else
+ 5
+ }
+ z
+ }
+}
+
+@RunWith(classOf[JUnit4])
+class IfElse1Spec {
+
+ @Test
+ def `await in a nested if-else expression`() {
+ val o = new TestIfElse1Class
+ val fut = o.m1(10)
+ val res = Await.result(fut, 2 seconds)
+ res mustBe (14)
+ }
+
+ @Test
+ def `await in a nested if-else expression 2`() {
+ val o = new TestIfElse1Class
+ val fut = o.m2(10)
+ val res = Await.result(fut, 2 seconds)
+ res mustBe (14)
+ }
+
+
+ @Test
+ def `await in a nested if-else expression 3`() {
+ val o = new TestIfElse1Class
+ val fut = o.m3(10)
+ val res = Await.result(fut, 2 seconds)
+ res mustBe (14)
+ }
+
+
+ @Test
+ def `await in a nested if-else expression 4`() {
+ val o = new TestIfElse1Class
+ val fut = o.m4(10)
+ val res = Await.result(fut, 2 seconds)
+ res mustBe (14)
+ }
+}
diff --git a/src/test/scala/scala/async/run/ifelse2/ifelse2.scala b/src/test/scala/scala/async/run/ifelse2/ifelse2.scala
new file mode 100644
index 0000000..f894923
--- /dev/null
+++ b/src/test/scala/scala/async/run/ifelse2/ifelse2.scala
@@ -0,0 +1,51 @@
+/**
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async
+package run
+package ifelse2
+
+import language.{reflectiveCalls, postfixOps}
+import scala.concurrent.{Future, ExecutionContext, future, Await}
+import scala.concurrent.duration._
+import scala.async.Async.{async, await}
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+import org.junit.Test
+
+
+class TestIfElse2Class {
+
+ import ExecutionContext.Implicits.global
+
+ def base(x: Int): Future[Int] = future {
+ Thread.sleep(1000)
+ x + 2
+ }
+
+ def m(y: Int): Future[Int] = async {
+ val f = base(y)
+ var z = 0
+ if (y > 0) {
+ val x = await(f)
+ z = x + 2
+ } else {
+ val x = await(f)
+ z = x - 2
+ }
+ z
+ }
+}
+
+@RunWith(classOf[JUnit4])
+class IfElse2Spec {
+
+ @Test
+ def `variables of the same name in different blocks`() {
+ val o = new TestIfElse2Class
+ val fut = o.m(10)
+ val res = Await.result(fut, 2 seconds)
+ res mustBe (14)
+ }
+}
diff --git a/src/test/scala/scala/async/run/ifelse3/IfElse3.scala b/src/test/scala/scala/async/run/ifelse3/IfElse3.scala
new file mode 100644
index 0000000..0c0dbfe
--- /dev/null
+++ b/src/test/scala/scala/async/run/ifelse3/IfElse3.scala
@@ -0,0 +1,54 @@
+/**
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async
+package run
+package ifelse3
+
+import language.{reflectiveCalls, postfixOps}
+import scala.concurrent.{Future, ExecutionContext, future, Await}
+import scala.concurrent.duration._
+import scala.async.Async.{async, await}
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+import org.junit.Test
+
+
+class TestIfElse3Class {
+
+ import ExecutionContext.Implicits.global
+
+ def base(x: Int): Future[Int] = future {
+ Thread.sleep(1000)
+ x + 2
+ }
+
+ def m(y: Int): Future[Int] = async {
+ val f = base(y)
+ var z = 0
+ if (y > 0) {
+ val x1 = await(f)
+ var w = x1 + 2
+ z = w + 2
+ } else {
+ val x2 = await(f)
+ var w = x2 + 2
+ z = w - 2
+ }
+ z
+ }
+}
+
+
+@RunWith(classOf[JUnit4])
+class IfElse3Spec {
+
+ @Test
+ def `variables of the same name in different blocks`() {
+ val o = new TestIfElse3Class
+ val fut = o.m(10)
+ val res = Await.result(fut, 2 seconds)
+ res mustBe (16)
+ }
+}