aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--build.sbt2
-rw-r--r--src/main/scala/scala/async/AnfTransform.scala86
-rw-r--r--src/main/scala/scala/async/Async.scala99
-rw-r--r--src/main/scala/scala/async/AsyncAnalysis.scala70
-rw-r--r--src/main/scala/scala/async/AsyncUtils.scala4
-rw-r--r--src/main/scala/scala/async/ExprBuilder.scala41
-rw-r--r--src/main/scala/scala/async/FutureSystem.scala9
-rw-r--r--src/main/scala/scala/async/TransformUtils.scala125
-rw-r--r--src/test/scala/scala/async/TreeInterrogation.scala60
-rw-r--r--src/test/scala/scala/async/neg/LocalClasses0Spec.scala12
-rw-r--r--src/test/scala/scala/async/neg/NakedAwait.scala11
-rw-r--r--src/test/scala/scala/async/package.scala4
-rw-r--r--src/test/scala/scala/async/run/anf/AnfTransformSpec.scala4
-rw-r--r--src/test/scala/scala/async/run/hygiene/Hygiene.scala84
-rw-r--r--src/test/scala/scala/async/run/ifelse0/IfElse0.scala8
-rw-r--r--src/test/scala/scala/async/run/match0/Match0.scala43
-rw-r--r--src/test/scala/scala/async/run/nesteddef/NestedDef.scala40
18 files changed, 500 insertions, 203 deletions
diff --git a/.gitignore b/.gitignore
index ad6e494..0c4d130 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,4 @@ classes
target
.idea
.idea_modules
+*.icode
diff --git a/build.sbt b/build.sbt
index 9b0a6bd..4a3f200 100644
--- a/build.sbt
+++ b/build.sbt
@@ -1,4 +1,4 @@
-scalaVersion := "2.10.0-RC1"
+scalaVersion := "2.10.0-RC3"
organization := "org.typesafe.async"
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala
index 0836634..a2d21f6 100644
--- a/src/main/scala/scala/async/AnfTransform.scala
+++ b/src/main/scala/scala/async/AnfTransform.scala
@@ -7,10 +7,12 @@ package scala.async
import scala.reflect.macros.Context
-private[async] final case class AnfTransform[C <: Context](val c: C) {
+private[async] final case class AnfTransform[C <: Context](c: C) {
import c.universe._
+
val utils = TransformUtils[c.type](c)
+
import utils._
def apply(tree: Tree): List[Tree] = {
@@ -29,9 +31,21 @@ private[async] final case class AnfTransform[C <: Context](val c: C) {
* This step is needed to allow us to safely merge blocks during the `inline` transform below.
*/
private final class UniqueNames(tree: Tree) extends Transformer {
- val repeatedNames: Set[Name] = tree.collect {
- case dt: DefTree => dt.symbol.name
- }.groupBy(x => x).filter(_._2.size > 1).keySet
+ val repeatedNames: Set[Symbol] = {
+ class DuplicateNameTraverser extends AsyncTraverser {
+ val result = collection.mutable.Buffer[Symbol]()
+
+ override def traverse(tree: Tree) {
+ tree match {
+ case dt: DefTree => result += dt.symbol
+ case _ => super.traverse(tree)
+ }
+ }
+ }
+ val dupNameTraverser = new DuplicateNameTraverser
+ dupNameTraverser.traverse(tree)
+ dupNameTraverser.result.groupBy(x => x.name).filter(_._2.size > 1).values.flatten.toSet[Symbol]
+ }
/** Stepping outside of the public Macro API to call [[scala.reflect.internal.Symbols.Symbol.name_=]] */
val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable]
@@ -40,7 +54,7 @@ private[async] final case class AnfTransform[C <: Context](val c: C) {
override def transform(tree: Tree): Tree = {
tree match {
- case defTree: DefTree if repeatedNames(defTree.symbol.name) =>
+ case defTree: DefTree if repeatedNames(defTree.symbol) =>
val trans = super.transform(defTree)
val origName = defTree.symbol.name
val sym = defTree.symbol.asInstanceOf[symtab.Symbol]
@@ -54,6 +68,8 @@ private[async] final case class AnfTransform[C <: Context](val c: C) {
trans match {
case ValDef(mods, name, tpt, rhs) =>
treeCopy.ValDef(trans, mods, newName, tpt, rhs)
+ case Bind(name, body) =>
+ treeCopy.Bind(trans, newName, body)
case DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
treeCopy.DefDef(trans, mods, newName, tparams, vparamss, tpt, rhs)
case TypeDef(mods, name, tparams, rhs) =>
@@ -79,12 +95,16 @@ private[async] final case class AnfTransform[C <: Context](val c: C) {
private object trace {
private var indent = -1
+
def indentString = " " * indent
+
def apply[T](prefix: String, args: Any)(t: => T): T = {
indent += 1
- def oneLine(s: Any) = s.toString.replaceAll("""\n""", "\\\\n").take(127)
+ def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127)
try {
- AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})")
+ AsyncUtils.trace(s"${
+ indentString
+ }$prefix(${oneLine(args)})")
val result = t
AsyncUtils.trace(s"${indentString}= ${oneLine(result)}")
result
@@ -127,11 +147,11 @@ private[async] final case class AnfTransform[C <: Context](val c: C) {
val varDef = defineVar(name.matchRes, expr.tpe, tree.pos)
val casesWithAssign = cases map {
case cd@CaseDef(pat, guard, Block(caseStats, caseExpr)) =>
- attachCopy.CaseDef(cd)(pat, guard, Block(caseStats, Assign(Ident(varDef.name), caseExpr)))
+ attachCopy(cd)(CaseDef(pat, guard, Block(caseStats, Assign(Ident(varDef.name), caseExpr))))
case cd@CaseDef(pat, guard, body) =>
- attachCopy.CaseDef(cd)(pat, guard, Assign(Ident(varDef.name), body))
+ attachCopy(cd)(CaseDef(pat, guard, Assign(Ident(varDef.name), body)))
}
- val matchWithAssign = attachCopy.Match(tree)(scrut, casesWithAssign)
+ val matchWithAssign = attachCopy(tree)(Match(scrut, casesWithAssign))
stats :+ varDef :+ matchWithAssign :+ Ident(varDef.name)
}
case _ =>
@@ -139,10 +159,7 @@ private[async] final case class AnfTransform[C <: Context](val c: C) {
}
}
- def transformToList(trees: List[Tree]): List[Tree] = trees match {
- case fst :: rest => transformToList(fst) ++ transformToList(rest)
- case Nil => Nil
- }
+ def transformToList(trees: List[Tree]): List[Tree] = trees flatMap transformToList
def transformToBlock(tree: Tree): Block = transformToList(tree) match {
case stats :+ expr => Block(stats, expr)
@@ -168,7 +185,7 @@ private[async] final case class AnfTransform[C <: Context](val c: C) {
tree match {
case Select(qual, sel) if containsAwait =>
val stats :+ expr = inline.transformToList(qual)
- stats :+ attachCopy.Select(tree)(expr, sel).setSymbol(tree.symbol)
+ stats :+ attachCopy(tree)(Select(expr, sel).setSymbol(tree.symbol))
case Apply(fun, args) if containsAwait =>
// we an assume that no await call appears in a by-name argument position,
@@ -178,7 +195,7 @@ private[async] final case class AnfTransform[C <: Context](val c: C) {
val argLists = args map inline.transformToList
val allArgStats = argLists flatMap (_.init)
val simpleArgs = argLists map (_.last)
- funStats ++ allArgStats :+ attachCopy.Apply(tree)(simpleFun, simpleArgs).setSymbol(tree.symbol)
+ funStats ++ allArgStats :+ attachCopy(tree)(Apply(simpleFun, simpleArgs).setSymbol(tree.symbol))
case Block(stats, expr) if containsAwait =>
inline.transformToList(stats :+ expr)
@@ -186,39 +203,60 @@ private[async] final case class AnfTransform[C <: Context](val c: C) {
case ValDef(mods, name, tpt, rhs) if containsAwait =>
if (rhs exists isAwait) {
val stats :+ expr = inline.transformToList(rhs)
- stats :+ attachCopy.ValDef(tree)(mods, name, tpt, expr).setSymbol(tree.symbol)
+ stats :+ attachCopy(tree)(ValDef(mods, name, tpt, expr).setSymbol(tree.symbol))
} else List(tree)
case Assign(lhs, rhs) if containsAwait =>
val stats :+ expr = inline.transformToList(rhs)
- stats :+ attachCopy.Assign(tree)(lhs, expr)
+ stats :+ attachCopy(tree)(Assign(lhs, expr))
case If(cond, thenp, elsep) if containsAwait =>
- val stats :+ expr = inline.transformToList(cond)
+ val condStats :+ condExpr = inline.transformToList(cond)
val thenBlock = inline.transformToBlock(thenp)
val elseBlock = inline.transformToBlock(elsep)
- stats :+
- c.typeCheck(attachCopy.If(tree)(expr, thenBlock, elseBlock))
+ // Typechecking with `condExpr` as the condition fails if the condition
+ // contains an await. `ifTree.setType(tree.tpe)` also fails; it seems
+ // we rely on this call to `typeCheck` descending into the branches.
+ // But, we can get away with typechecking a throwaway `If` tree with the
+ // original scrutinee and the new branches, and setting that type on
+ // the real `If` tree.
+ val ifType = c.typeCheck(If(cond, thenBlock, elseBlock)).tpe
+ condStats :+
+ attachCopy(tree)(If(condExpr, thenBlock, elseBlock)).setType(ifType)
case Match(scrut, cases) if containsAwait =>
val scrutStats :+ scrutExpr = inline.transformToList(scrut)
val caseDefs = cases map {
case CaseDef(pat, guard, body) =>
+ // extract local variables for all names bound in `pat`, and rewrite `body`
+ // to refer to these.
+ // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`.
val block = inline.transformToBlock(body)
- attachCopy.CaseDef(tree)(pat, guard, block)
+ val (valDefs, mappings) = (pat collect {
+ case b@Bind(name, _) =>
+ val newName = newTermName(utils.name.fresh(name.toTermName + utils.name.bindSuffix))
+ val vd = ValDef(NoMods, newName, TypeTree(), Ident(b.symbol))
+ (vd, (b.symbol, newName))
+ }).unzip
+ val Block(stats1, expr1) = utils.substituteNames(block, mappings.toMap).asInstanceOf[Block]
+ attachCopy(tree)(CaseDef(pat, guard, Block(valDefs ++ stats1, expr1)))
}
- scrutStats :+ c.typeCheck(attachCopy.Match(tree)(scrutExpr, caseDefs))
+ // Refer to comments the translation of `If` above.
+ val matchType = c.typeCheck(Match(scrut, caseDefs)).tpe
+ val typedMatch = attachCopy(tree)(Match(scrutExpr, caseDefs)).setType(tree.tpe)
+ scrutStats :+ typedMatch
case LabelDef(name, params, rhs) if containsAwait =>
List(LabelDef(name, params, Block(inline.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol))
case TypeApply(fun, targs) if containsAwait =>
val funStats :+ simpleFun = inline.transformToList(fun)
- funStats :+ attachCopy.TypeApply(tree)(simpleFun, targs).setSymbol(tree.symbol)
+ funStats :+ attachCopy(tree)(TypeApply(simpleFun, targs).setSymbol(tree.symbol))
case _ =>
List(tree)
}
}
}
+
}
diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala
index ef506a5..4a770ed 100644
--- a/src/main/scala/scala/async/Async.scala
+++ b/src/main/scala/scala/async/Async.scala
@@ -11,6 +11,7 @@ import scala.reflect.macros.Context
* @author Philipp Haller
*/
object Async extends AsyncBase {
+
import scala.concurrent.Future
lazy val futureSystem = ScalaConcurrentFutureSystem
@@ -65,7 +66,6 @@ abstract class AsyncBase {
def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = {
import c.universe._
- val builder = ExprBuilder[c.type, futureSystem.type](c, self.futureSystem)
val anaylzer = AsyncAnalysis[c.type](c)
val utils = TransformUtils[c.type](c)
import utils.{name, defn}
@@ -87,54 +87,82 @@ abstract class AsyncBase {
// states of our generated state machine, e.g. a value assigned before
// an `await` and read afterwards.
val renameMap: Map[Symbol, TermName] = {
- anaylzer.valDefsUsedInSubsequentStates(anfTree).map {
+ anaylzer.defTreesUsedInSubsequentStates(anfTree).map {
vd =>
- (vd.symbol, name.fresh(vd.name))
+ (vd.symbol, name.fresh(vd.name.toTermName))
}.toMap
}
+ val builder = ExprBuilder[c.type, futureSystem.type](c, self.futureSystem, anfTree)
val asyncBlock: builder.AsyncBlock = builder.build(anfTree, renameMap)
import asyncBlock.asyncStates
logDiagnostics(c)(anfTree, asyncStates.map(_.toString))
+ // Important to retain the original declaration order here!
val localVarTrees = anfTree.collect {
- case vd@ValDef(_, _, tpt, _) if renameMap contains vd.symbol =>
+ case vd@ValDef(_, _, tpt, _) if renameMap contains vd.symbol =>
utils.mkVarDefTree(tpt.tpe, renameMap(vd.symbol))
+ case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) if renameMap contains dd.symbol =>
+ DefDef(mods, renameMap(dd.symbol), tparams, vparamss, tpt, c.resetAllAttrs(utils.substituteNames(rhs, renameMap)))
}
- val onCompleteHandler = asyncBlock.onCompleteHandler
+ val onCompleteHandler = {
+ Function(
+ List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)),
+ asyncBlock.onCompleteHandler)
+ }
val resumeFunTree = asyncBlock.resumeFunTree[T]
- val prom: Expr[futureSystem.Prom[T]] = reify {
- // Create the empty promise
- val result$async = futureSystemOps.createProm[T].splice
- // Initialize the state
- var state$async = 0
- // Resolve the execution context
- val execContext$async = futureSystemOps.execContext.splice
- var onCompleteHandler$async: util.Try[Any] => Unit = null
-
- // Spawn a future to:
- futureSystemOps.future[Unit] {
- c.Expr[Unit](Block(
- // define vars for all intermediate results that are accessed from multiple states
- localVarTrees :+
- // define the resume() method
- resumeFunTree :+
- // assign onComplete function. (The var breaks the circular dependency with resume)`
- Assign(Ident(name.onCompleteHandler), onCompleteHandler),
- // and get things started by calling resume()
- Apply(Ident(name.resume), Nil)))
- }(c.Expr[futureSystem.ExecContext](Ident(name.execContext))).splice
- // Return the promise from this reify block...
- result$async
+ val stateMachineType = utils.applied("scala.async.StateMachine", List(futureSystemOps.promType[T], futureSystemOps.execContextType))
+
+ lazy val stateMachine: ClassDef = {
+ val body: List[Tree] = {
+ val stateVar = ValDef(Modifiers(Flag.MUTABLE), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0)))
+ val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree)
+ val execContext = ValDef(NoMods, name.execContext, TypeTree(), futureSystemOps.execContext.tree)
+ val applyDefDef: DefDef = {
+ val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)))
+ val applyBody = asyncBlock.onCompleteHandler
+ DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), applyBody)
+ }
+ val apply0DefDef: DefDef = {
+ // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`.
+ // See SI-1247 for the the optimization that avoids creatio
+ val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)))
+ val applyBody = asyncBlock.onCompleteHandler
+ DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil))
+ }
+ List(utils.emptyConstructor, stateVar, result, execContext) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef)
+ }
+ val template = {
+ Template(List(stateMachineType), emptyValDef, body)
+ }
+ ClassDef(NoMods, name.stateMachineT, Nil, template)
}
- // ... and return its Future from the macro.
- val result = futureSystemOps.promiseToFuture(prom)
- AsyncUtils.vprintln(s"async state machine transform expands to:\n ${result.tree}")
+ def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection)
+
+ def spawn(tree: Tree): Tree =
+ futureSystemOps.future(c.Expr[Unit](tree))(c.Expr[futureSystem.ExecContext](selectStateMachine(name.execContext))).tree
+
+ val code: c.Expr[futureSystem.Fut[T]] = {
+ val isSimple = asyncStates.size == 1
+ val tree =
+ if (isSimple)
+ Block(Nil, spawn(body.tree)) // generate lean code for the simple case of `async { 1 + 1 }`
+ else {
+ Block(List[Tree](
+ stateMachine,
+ ValDef(NoMods, name.stateMachine, stateMachineType, New(Ident(name.stateMachineT), Nil)),
+ spawn(Apply(selectStateMachine(name.apply), Nil))
+ ),
+ futureSystemOps.promiseToFuture(c.Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree)
+ }
+ c.Expr[futureSystem.Fut[T]](tree)
+ }
- result
+ AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code.tree}")
+ code
}
def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) {
@@ -151,3 +179,10 @@ abstract class AsyncBase {
states foreach (s => AsyncUtils.vprintln(s))
}
}
+
+/** Internal class used by the `async` macro; should not be manually extended by client code */
+abstract class StateMachine[Result, EC] extends (scala.util.Try[Any] => Unit) with (() => Unit) {
+ def result$async: Result
+
+ def execContext$async: EC
+}
diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala
index 4f5bf8d..8bb5bcd 100644
--- a/src/main/scala/scala/async/AsyncAnalysis.scala
+++ b/src/main/scala/scala/async/AsyncAnalysis.scala
@@ -5,12 +5,13 @@
package scala.async
import scala.reflect.macros.Context
-import collection.mutable
+import scala.collection.mutable
-private[async] final case class AsyncAnalysis[C <: Context](val c: C) {
+private[async] final case class AsyncAnalysis[C <: Context](c: C) {
import c.universe._
val utils = TransformUtils[c.type](c)
+
import utils._
/**
@@ -30,10 +31,11 @@ private[async] final case class AsyncAnalysis[C <: Context](val c: C) {
*
* Must be called on the ANF transformed tree.
*/
- def valDefsUsedInSubsequentStates(tree: Tree): List[ValDef] = {
+ def defTreesUsedInSubsequentStates(tree: Tree): List[DefTree] = {
val analyzer = new AsyncDefinitionUseAnalyzer
analyzer.traverse(tree)
- analyzer.valDefsToLift.toList
+ val liftable: List[DefTree] = (analyzer.valDefsToLift ++ analyzer.nestedMethodsToLift).toList.distinct
+ liftable
}
private class UnsupportedAwaitAnalyzer extends AsyncTraverser {
@@ -41,7 +43,8 @@ private[async] final case class AsyncAnalysis[C <: Context](val c: C) {
val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class"
if (!reportUnsupportedAwait(classDef, s"nested $kind")) {
// do not allow local class definitions, because of SI-5467 (specific to case classes, though)
- c.error(classDef.pos, s"Local class ${classDef.name.decoded} illegal within `async` block")
+ if (classDef.symbol.asClass.isCaseClass)
+ c.error(classDef.pos, s"Local case class ${classDef.name.decoded} illegal within `async` block")
}
}
@@ -70,12 +73,9 @@ private[async] final case class AsyncAnalysis[C <: Context](val c: C) {
case Try(_, _, _) if containsAwait =>
reportUnsupportedAwait(tree, "try/catch")
super.traverse(tree)
- case If(cond, _, _) if containsAwait =>
- reportUnsupportedAwait(cond, "condition")
- super.traverse(tree)
- case Return(_) =>
+ case Return(_) =>
c.abort(tree.pos, "return is illegal within a async block")
- case _ =>
+ case _ =>
super.traverse(tree)
}
}
@@ -92,7 +92,7 @@ private[async] final case class AsyncAnalysis[C <: Context](val c: C) {
c.error(tree.pos, s"await must not be used under a $whyUnsupported.")
}
badAwaits.nonEmpty
- }
+ }
}
private class AsyncDefinitionUseAnalyzer extends AsyncTraverser {
@@ -102,40 +102,67 @@ private[async] final case class AsyncAnalysis[C <: Context](val c: C) {
private var valDefChunkId = Map[Symbol, (ValDef, Int)]()
- val valDefsToLift: mutable.Set[ValDef] = collection.mutable.Set[ValDef]()
+ val valDefsToLift : mutable.Set[ValDef] = collection.mutable.Set()
+ val nestedMethodsToLift: mutable.Set[DefDef] = collection.mutable.Set()
+
+ override def nestedMethod(defDef: DefDef) {
+ nestedMethodsToLift += defDef
+ defDef.rhs foreach {
+ case rt: RefTree =>
+ valDefChunkId.get(rt.symbol) match {
+ case Some((vd, defChunkId)) =>
+ valDefsToLift += vd // lift all vals referred to by nested methods.
+ case _ =>
+ }
+ case _ =>
+ }
+ }
+
+ override def function(function: Function) {
+ function foreach {
+ case rt: RefTree =>
+ valDefChunkId.get(rt.symbol) match {
+ case Some((vd, defChunkId)) =>
+ valDefsToLift += vd // lift all vals referred to by nested functions.
+ case _ =>
+ }
+ case _ =>
+ }
+ }
override def traverse(tree: Tree) = {
tree match {
- case If(cond, thenp, elsep) if tree exists isAwait =>
+ case If(cond, thenp, elsep) if tree exists isAwait =>
traverseChunks(List(cond, thenp, elsep))
- case Match(selector, cases) if tree exists isAwait =>
+ case Match(selector, cases) if tree exists isAwait =>
traverseChunks(selector :: cases)
case LabelDef(name, params, rhs) if rhs exists isAwait =>
traverseChunks(rhs :: Nil)
- case Apply(fun, args) if isAwait(fun) =>
+ case Apply(fun, args) if isAwait(fun) =>
super.traverse(tree)
nextChunk()
- case vd: ValDef =>
+ case vd: ValDef =>
super.traverse(tree)
valDefChunkId += (vd.symbol ->(vd, chunkId))
- if (isAwait(vd.rhs)) valDefsToLift += vd
- case as: Assign =>
+ val isPatternBinder = vd.name.toString.contains(name.bindSuffix)
+ if (isAwait(vd.rhs) || isPatternBinder) valDefsToLift += vd
+ case as: Assign =>
if (isAwait(as.rhs)) {
- assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol)
+ assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol)
// TODO test the orElse case, try to remove the restriction.
val (vd, defBlockId) = valDefChunkId.getOrElse(as.lhs.symbol, c.abort(as.pos, s"await may only be assigned to a var/val defined in the async block. ${as.lhs} ${as.lhs.symbol}"))
valDefsToLift += vd
}
super.traverse(tree)
- case rt: RefTree =>
+ case rt: RefTree =>
valDefChunkId.get(rt.symbol) match {
case Some((vd, defChunkId)) if defChunkId != chunkId =>
valDefsToLift += vd
case _ =>
}
super.traverse(tree)
- case _ => super.traverse(tree)
+ case _ => super.traverse(tree)
}
}
@@ -145,4 +172,5 @@ private[async] final case class AsyncAnalysis[C <: Context](val c: C) {
}
}
}
+
}
diff --git a/src/main/scala/scala/async/AsyncUtils.scala b/src/main/scala/scala/async/AsyncUtils.scala
index 87a63d7..999cb95 100644
--- a/src/main/scala/scala/async/AsyncUtils.scala
+++ b/src/main/scala/scala/async/AsyncUtils.scala
@@ -10,8 +10,8 @@ object AsyncUtils {
private def enabled(level: String) = sys.props.getOrElse(s"scala.async.$level", "false").equalsIgnoreCase("true")
- private val verbose = enabled("debug")
- private val trace = enabled("trace")
+ var verbose = enabled("debug")
+ var trace = enabled("trace")
private[async] def vprintln(s: => Any): Unit = if (verbose) println(s"[async] $s")
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala
index 5ae01f9..7b4ccb8 100644
--- a/src/main/scala/scala/async/ExprBuilder.scala
+++ b/src/main/scala/scala/async/ExprBuilder.scala
@@ -6,11 +6,12 @@ package scala.async
import scala.reflect.macros.Context
import scala.collection.mutable.ListBuffer
import collection.mutable
+import language.existentials
/*
* @author Philipp Haller
*/
-private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSystem: FS) {
+private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: C, futureSystem: FS, origTree: C#Tree) {
builder =>
val utils = TransformUtils[c.type](c)
@@ -70,7 +71,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](va
override def mkHandlerCaseForState: CaseDef = {
val callOnComplete = futureSystemOps.onComplete(c.Expr(awaitable.expr),
- c.Expr(Ident(name.onCompleteHandler)), c.Expr(Ident(name.execContext))).tree
+ c.Expr(This(tpnme.EMPTY)), c.Expr(Ident(name.execContext))).tree
mkHandlerCase(state, stats :+ callOnComplete)
}
@@ -96,12 +97,13 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](va
/** The state of the target of a LabelDef application (while loop jump) */
private var nextJumpState: Option[Int] = None
- private def renameReset(tree: Tree) = resetDuplicate(substituteNames(tree, nameMap))
+ private def renameReset(tree: Tree) = resetInternalAttrs(substituteNames(tree, nameMap))
def +=(stat: c.Tree): this.type = {
assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat")
def addStat() = stats += renameReset(stat)
stat match {
+ case _: DefDef => // these have been lifted.
case Apply(fun, Nil) =>
labelDefStates get fun.symbol match {
case Some(nextState) => nextJumpState = Some(nextState)
@@ -146,7 +148,12 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](va
def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], caseStates: List[Int]): AsyncState = {
// 1. build list of changed cases
val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match {
- case CaseDef(pat, guard, rhs) => CaseDef(pat, guard, Block(mkStateTree(caseStates(num)), mkResumeApply))
+ case CaseDef(pat, guard, rhs) =>
+ val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal).map {
+ case ValDef(_, name, _, rhs) => Assign(Ident(name), rhs)
+ case t => sys.error(s"Unexpected tree. Expected ValDef, found: $t")
+ }
+ CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num)), mkResumeApply))
}
// 2. insert changed match tree at the end of the current state
this += Match(renameReset(scrutTree), newCases)
@@ -237,7 +244,9 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](va
stateBuilder.resultWithMatch(scrutinee, cases, caseStates)
for ((cas, num) <- cases.zipWithIndex) {
- val builder = nestedBlockBuilder(cas.body, caseStates(num), afterMatchState)
+ val (stats, expr) = statsAndExpr(cas.body)
+ val stats1 = stats.dropWhile(isSyntheticBindVal)
+ val builder = nestedBlockBuilder(Block(stats1, expr), caseStates(num), afterMatchState)
asyncStates ++= builder.asyncStates
}
@@ -302,19 +311,16 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](va
val initStates = asyncStates.init
/**
- * lazy val onCompleteHandler = (tr: Try[Any]) => state match {
+ * // assumes tr: Try[Any] is in scope.
+ * //
+ * state match {
* case 0 => {
* x11 = tr.get.asInstanceOf[Double];
* state = 1;
* resume()
* }
*/
- val onCompleteHandler: Tree = {
- val onCompleteHandlers = initStates.flatMap(_.mkOnCompleteHandler).toList
- Function(
- List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)),
- Match(Ident(name.state), onCompleteHandlers))
- }
+ val onCompleteHandler: Tree = Match(Ident(name.state), initStates.flatMap(_.mkOnCompleteHandler).toList)
/**
* def resume(): Unit = {
@@ -346,9 +352,18 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](va
}
}
+ private def isSyntheticBindVal(tree: Tree) = tree match {
+ case vd@ValDef(_, lname, _, Ident(rname)) => lname.toString.contains(name.bindSuffix)
+ case _ => false
+ }
+
private final case class Awaitable(expr: Tree, resultName: TermName, resultType: Type)
- private def resetDuplicate(tree: Tree) = c.resetAllAttrs(tree.duplicate)
+ private val internalSyms = origTree.collect {
+ case dt: DefTree => dt.symbol
+ }
+
+ private def resetInternalAttrs(tree: Tree) = utils.resetInternalAttrs(tree, internalSyms)
private def mkResumeApply = Apply(Ident(name.resume), Nil)
diff --git a/src/main/scala/scala/async/FutureSystem.scala b/src/main/scala/scala/async/FutureSystem.scala
index 20bbea3..e9373b3 100644
--- a/src/main/scala/scala/async/FutureSystem.scala
+++ b/src/main/scala/scala/async/FutureSystem.scala
@@ -33,6 +33,9 @@ trait FutureSystem {
/** Lookup the execution context, typically with an implicit search */
def execContext: Expr[ExecContext]
+ def promType[A: WeakTypeTag]: Type
+ def execContextType: Type
+
/** Create an empty promise */
def createProm[A: WeakTypeTag]: Expr[Prom[A]]
@@ -71,6 +74,9 @@ object ScalaConcurrentFutureSystem extends FutureSystem {
case context => context
})
+ def promType[A: WeakTypeTag]: Type = c.weakTypeOf[Promise[A]]
+ def execContextType: Type = c.weakTypeOf[ExecutionContext]
+
def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify {
Promise[A]()
}
@@ -113,6 +119,9 @@ object IdentityFutureSystem extends FutureSystem {
def execContext: Expr[ExecContext] = c.literalUnit
+ def promType[A: WeakTypeTag]: Type = c.weakTypeOf[Prom[A]]
+ def execContextType: Type = c.weakTypeOf[Unit]
+
def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify {
new Prom(null.asInstanceOf[A])
}
diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala
index 8838bb3..5b1fcbe 100644
--- a/src/main/scala/scala/async/TransformUtils.scala
+++ b/src/main/scala/scala/async/TransformUtils.scala
@@ -9,7 +9,7 @@ import reflect.ClassTag
/**
* Utilities used in both `ExprBuilder` and `AnfTransform`.
*/
-private[async] final case class TransformUtils[C <: Context](val c: C) {
+private[async] final case class TransformUtils[C <: Context](c: C) {
import c.universe._
@@ -18,18 +18,18 @@ private[async] final case class TransformUtils[C <: Context](val c: C) {
def suffixedName(prefix: String) = newTermName(suffix(prefix))
- val state = suffixedName("state")
- val result = suffixedName("result")
- val resume = suffixedName("resume")
- val execContext = suffixedName("execContext")
-
- // TODO do we need to freshen any of these?
- val tr = newTermName("tr")
- val onCompleteHandler = suffixedName("onCompleteHandler")
-
- val matchRes = "matchres"
- val ifRes = "ifres"
- val await = "await"
+ val state = suffixedName("state")
+ val result = suffixedName("result")
+ val resume = suffixedName("resume")
+ val execContext = suffixedName("execContext")
+ val stateMachine = newTermName(fresh("stateMachine"))
+ val stateMachineT = stateMachine.toTypeName
+ val apply = newTermName("apply")
+ val tr = newTermName("tr")
+ val matchRes = "matchres"
+ val ifRes = "ifres"
+ val await = "await"
+ val bindSuffix = "$bind"
def fresh(name: TermName): TermName = newTermName(fresh(name.toString))
@@ -127,6 +127,14 @@ private[async] final case class TransformUtils[C <: Context](val c: C) {
ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), defaultValue(resultType))
}
+ def emptyConstructor: DefDef = {
+ val emptySuperCall = Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), Nil)
+ DefDef(NoMods, nme.CONSTRUCTOR, List(), List(List()), TypeTree(), Block(List(emptySuperCall), c.literalUnit.tree))
+ }
+
+ def applied(className: String, types: List[Type]): AppliedTypeTree =
+ AppliedTypeTree(Ident(c.mirror.staticClass(className)), types.map(TypeTree(_)))
+
object defn {
def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = {
c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree)))
@@ -146,8 +154,7 @@ private[async] final case class TransformUtils[C <: Context](val c: C) {
self.splice.get
}
- val Try_get = methodSym(reify((null: scala.util.Try[Any]).get))
-
+ val Try_get = methodSym(reify((null: scala.util.Try[Any]).get))
val TryClass = c.mirror.staticClass("scala.util.Try")
val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe))
val NonFatalClass = c.mirror.staticModule("scala.util.control.NonFatal")
@@ -159,7 +166,6 @@ private[async] final case class TransformUtils[C <: Context](val c: C) {
}
}
-
/** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */
private def methodSym(apply: c.Expr[Any]): Symbol = {
val tree2: Tree = c.typeCheck(apply.tree)
@@ -168,43 +174,72 @@ private[async] final case class TransformUtils[C <: Context](val c: C) {
}.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}"))
}
- /** Using [[scala.reflect.api.Trees.TreeCopier]] copies more than we would like:
- * we don't want to copy types and symbols to the new trees in some cases.
- *
- * Instead, we just copy positions and attachments.
- */
- object attachCopy {
- def copyAttach[T <: Tree](orig: Tree, tree: T): tree.type = {
- tree.setPos(orig.pos)
- for (att <- orig.attachments.all)
- tree.updateAttachment[Any](att)(ClassTag.apply[Any](att.getClass))
- tree
- }
-
- def Apply(tree: Tree)(fun: Tree, args: List[Tree]): Apply =
- copyAttach(tree, c.universe.Apply(fun, args))
+ /**
+ * Using [[scala.reflect.api.Trees.TreeCopier]] copies more than we would like:
+ * we don't want to copy types and symbols to the new trees in some cases.
+ *
+ * Instead, we just copy positions and attachments.
+ */
+ def attachCopy[T <: Tree](orig: Tree)(tree: T): tree.type = {
+ tree.setPos(orig.pos)
+ for (att <- orig.attachments.all)
+ tree.updateAttachment[Any](att)(ClassTag.apply[Any](att.getClass))
+ tree
+ }
- def Assign(tree: Tree)(lhs: Tree, rhs: Tree): Assign =
- copyAttach(tree, c.universe.Assign(lhs, rhs))
+ def resetInternalAttrs(tree: Tree, internalSyms: List[Symbol]) =
+ new ResetInternalAttrs(internalSyms.toSet).transform(tree)
- def CaseDef(tree: Tree)(pat: Tree, guard: Tree, block: Tree): CaseDef =
- copyAttach(tree, c.universe.CaseDef(pat, guard, block))
+ /**
+ * Adaptation of [[scala.reflect.internal.Trees.ResetAttrs]]
+ *
+ * A transformer which resets symbol and tpe fields of all nodes in a given tree,
+ * with special treatment of:
+ * `TypeTree` nodes: are replaced by their original if it exists, otherwise tpe field is reset
+ * to empty if it started out empty or refers to local symbols (which are erased).
+ * `TypeApply` nodes: are deleted if type arguments end up reverted to empty
+ *
+ * `This` and `Ident` nodes referring to an external symbol are ''not'' reset.
+ */
+ private final class ResetInternalAttrs(internalSyms: Set[Symbol]) extends Transformer {
- def If(tree: Tree)(cond: Tree, thenp: Tree, elsep: Tree): If =
- copyAttach(tree, c.universe.If(cond, thenp, elsep))
+ import language.existentials
- def Match(tree: Tree)(selector: Tree, cases: List[CaseDef]): Match =
- copyAttach(tree, c.universe.Match(selector, cases))
+ override def transform(tree: Tree): Tree = super.transform {
+ def isExternal = tree.symbol != NoSymbol && !internalSyms(tree.symbol)
- def Select(tree: Tree)(qual: Tree, name: Name): Select =
- copyAttach(tree, c.universe.Select(qual, name))
+ tree match {
+ case tpt: TypeTree => resetTypeTree(tpt)
+ case TypeApply(fn, args)
+ if args map transform exists (_.isEmpty) => transform(fn)
+ case EmptyTree => tree
+ case (_: Ident | _: This) if isExternal => tree // #35 Don't reset the symbol of Ident/This bound outside of the async block
+ case _ => resetTree(tree)
+ }
+ }
- def TypeApply(tree: Tree)(fun: Tree, args: List[Tree]): TypeApply = {
- copyAttach(tree, c.universe.TypeApply(fun, args))
+ private def resetTypeTree(tpt: TypeTree): Tree = {
+ if (tpt.original != null)
+ transform(tpt.original)
+ else if (tpt.tpe != null && tpt.asInstanceOf[symtab.TypeTree forSome {val symtab: reflect.internal.SymbolTable}].wasEmpty) {
+ val dupl = tpt.duplicate
+ dupl.tpe = null
+ dupl
+ }
+ else tpt
}
- def ValDef(tree: Tree)(mods: Modifiers, name: TermName, tpt: Tree, rhs: Tree): ValDef =
- copyAttach(tree, c.universe.ValDef(mods, name, tpt, rhs))
+ private def resetTree(tree: Tree): Tree = {
+ val hasSymbol: Boolean = {
+ val reflectInternalTree = tree.asInstanceOf[symtab.Tree forSome {val symtab: reflect.internal.SymbolTable}]
+ reflectInternalTree.hasSymbol
+ }
+ val dupl = tree.duplicate
+ if (hasSymbol)
+ dupl.symbol = NoSymbol
+ dupl.tpe = null
+ dupl
+ }
}
}
diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala
index dd239a3..a46eaf2 100644
--- a/src/test/scala/scala/async/TreeInterrogation.scala
+++ b/src/test/scala/scala/async/TreeInterrogation.scala
@@ -8,6 +8,7 @@ import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.Test
import AsyncId._
+import tools.reflect.ToolBox
@RunWith(classOf[JUnit4])
class TreeInterrogation {
@@ -20,47 +21,68 @@ class TreeInterrogation {
| async {
| val x = await(1)
| val y = x * 2
+ | def foo(a: Int) = { def nested = 0; a } // don't lift `nested`.
| val z = await(x * 3)
+ | foo(z)
| z
| }""".stripMargin)
val tree1 = tb.typeCheck(tree)
//println(cm.universe.show(tree1))
- import tb.mirror.universe._
+ import tb.u._
val functions = tree1.collect {
case f: Function => f
+ case t: Template => t
}
functions.size mustBe 1
val varDefs = tree1.collect {
case ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) => name
}
- varDefs.map(_.decoded).toSet mustBe (Set("state$async", "onCompleteHandler$async", "await$1", "await$2"))
+ varDefs.map(_.decoded.trim).toSet mustBe (Set("state$async", "await$1", "await$2"))
+ varDefs.map(_.decoded.trim).toSet mustBe (Set("state$async", "await$1", "await$2"))
+
+ val defDefs = tree1.collect {
+ case t: Template =>
+ val stats: List[Tree] = t.body
+ stats.collect {
+ case dd : DefDef
+ if !dd.symbol.isImplementationArtifact
+ && !dd.symbol.asTerm.isAccessor && !dd.symbol.asTerm.isSetter => dd.name
+ }
+ }.flatten
+ defDefs.map(_.decoded.trim).toSet mustBe (Set("foo$1", "apply", "resume$async", "<init>"))
}
+}
- //@Test
- def sandbox() {
- sys.props("scala.async.debug") = true.toString
- sys.props("scala.async.trace") = false.toString
+object TreeInterrogation extends App {
+ def withDebug[T](t: => T) {
+ AsyncUtils.trace = true
+ AsyncUtils.verbose = true
+ try t
+ finally {
+ AsyncUtils.trace = false
+ AsyncUtils.verbose = false
+ }
+ }
+ withDebug {
val cm = reflect.runtime.currentMirror
- val tb = mkToolbox("-cp target/scala-2.10/classes")
+ val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:all")
val tree = tb.parse(
""" import _root_.scala.async.AsyncId._
- | async {
- | var sum = 0
- | var i = 0
- | while (i < 5) {
- | var j = 0
- | while (j < 5) {
- | sum += await(i) * await(j)
- | j += 1
- | }
- | i += 1
- | }
- | sum
+ | val state = 23
+ | val result: Any = "result"
+ | def resume(): Any = "resume"
+ | val res = async {
+ | val f1 = async { state + 2 }
+ | val x = await(f1)
+ | val y = await(async { result })
+ | val z = await(async { resume() })
+ | (x, y, z)
| }
+ | ()
| """.stripMargin)
println(tree)
val tree1 = tb.typeCheck(tree.duplicate)
diff --git a/src/test/scala/scala/async/neg/LocalClasses0Spec.scala b/src/test/scala/scala/async/neg/LocalClasses0Spec.scala
index 06a0e71..2569303 100644
--- a/src/test/scala/scala/async/neg/LocalClasses0Spec.scala
+++ b/src/test/scala/scala/async/neg/LocalClasses0Spec.scala
@@ -18,7 +18,7 @@ class LocalClasses0Spec {
@Test
def `reject a local class`() {
- expectError("Local class Person illegal within `async` block", "-cp target/scala-2.10/classes -deprecation -Xfatal-warnings") {
+ expectError("Local case class Person illegal within `async` block") {
"""
| import scala.concurrent.ExecutionContext.Implicits.global
| import scala.async.Async._
@@ -32,7 +32,7 @@ class LocalClasses0Spec {
@Test
def `reject a local class 2`() {
- expectError("Local class Person illegal within `async` block", "-cp target/scala-2.10/classes -deprecation -Xfatal-warnings") {
+ expectError("Local case class Person illegal within `async` block") {
"""
| import scala.concurrent.{Future, ExecutionContext}
| import ExecutionContext.Implicits.global
@@ -50,7 +50,7 @@ class LocalClasses0Spec {
@Test
def `reject a local class 3`() {
- expectError("Local class Person illegal within `async` block", "-cp target/scala-2.10/classes -deprecation -Xfatal-warnings") {
+ expectError("Local case class Person illegal within `async` block") {
"""
| import scala.concurrent.{Future, ExecutionContext}
| import ExecutionContext.Implicits.global
@@ -68,7 +68,7 @@ class LocalClasses0Spec {
@Test
def `reject a local class with symbols in its name`() {
- expectError("Local class :: illegal within `async` block", "-cp target/scala-2.10/classes -deprecation -Xfatal-warnings") {
+ expectError("Local case class :: illegal within `async` block") {
"""
| import scala.concurrent.{Future, ExecutionContext}
| import ExecutionContext.Implicits.global
@@ -86,7 +86,7 @@ class LocalClasses0Spec {
@Test
def `reject a nested local class`() {
- expectError("Local class Person illegal within `async` block", "-cp target/scala-2.10/classes -deprecation -Xfatal-warnings") {
+ expectError("Local case class Person illegal within `async` block") {
"""
| import scala.concurrent.{Future, ExecutionContext}
| import ExecutionContext.Implicits.global
@@ -110,7 +110,7 @@ class LocalClasses0Spec {
@Test
def `reject a local singleton object`() {
- expectError("Local object Person illegal within `async` block", "-cp target/scala-2.10/classes -deprecation -Xfatal-warnings") {
+ expectError("Local object Person illegal within `async` block") {
"""
| import scala.concurrent.ExecutionContext.Implicits.global
| import scala.async.Async._
diff --git a/src/test/scala/scala/async/neg/NakedAwait.scala b/src/test/scala/scala/async/neg/NakedAwait.scala
index f4cfca2..ecc84f9 100644
--- a/src/test/scala/scala/async/neg/NakedAwait.scala
+++ b/src/test/scala/scala/async/neg/NakedAwait.scala
@@ -143,15 +143,4 @@ class NakedAwait {
|""".stripMargin
}
}
-
- // TODO Anf transform if to have a simple condition.
- @Test
- def ifCondition() {
- expectError("await must not be used under a condition.") {
- """
- | import _root_.scala.async.AsyncId._
- | async { if (await(true)) () }
- |""".stripMargin
- }
- }
}
diff --git a/src/test/scala/scala/async/package.scala b/src/test/scala/scala/async/package.scala
index bc4ebac..4a7a958 100644
--- a/src/test/scala/scala/async/package.scala
+++ b/src/test/scala/scala/async/package.scala
@@ -5,7 +5,7 @@
package scala
import reflect._
-import tools.reflect.ToolBoxError
+import tools.reflect.{ToolBox, ToolBoxError}
package object async {
@@ -36,7 +36,7 @@ package object async {
tb.eval(tb.parse(code))
}
- def mkToolbox(compileOptions: String = "") = {
+ def mkToolbox(compileOptions: String = ""): ToolBox[_ <: scala.reflect.api.Universe] = {
val m = scala.reflect.runtime.currentMirror
import scala.tools.reflect.ToolBox
m.mkToolBox(options = compileOptions)
diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
index 41eeaa5..6dd4db7 100644
--- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
+++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
@@ -24,8 +24,8 @@ class AnfTestClass {
}
def m(y: Int): Future[Int] = async {
- val f = base(y)
- await(f)
+ val blerg = base(y)
+ await(blerg)
}
def m2(y: Int): Future[Int] = async {
diff --git a/src/test/scala/scala/async/run/hygiene/Hygiene.scala b/src/test/scala/scala/async/run/hygiene/Hygiene.scala
index d0be2e0..9d1df21 100644
--- a/src/test/scala/scala/async/run/hygiene/Hygiene.scala
+++ b/src/test/scala/scala/async/run/hygiene/Hygiene.scala
@@ -6,45 +6,79 @@ package scala.async
package run
package hygiene
-import language.{reflectiveCalls, postfixOps}
-import concurrent._
-import scala.concurrent.{Future, ExecutionContext, future, Await}
-import scala.concurrent.duration._
-import scala.async.Async.{async, await}
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
+@RunWith(classOf[JUnit4])
+class HygieneSpec {
-class HygieneClass {
-
- import ExecutionContext.Implicits.global
+ import scala.async.AsyncId.{async, await}
- def m1(x: Int): Future[Int] = future {
- x + 2
- }
-
- def m2(y: Int) = {
+ @Test
+ def `is hygenic`() {
val state = 23
val result: Any = "result"
def resume(): Any = "resume"
- async {
- val f1 = m1(state)
- val x = await(f1)
- val y = await(future(result))
- val z = await(future(resume()))
+ val res = async {
+ val f1 = state + 2
+ val x = await(f1)
+ val y = await(result)
+ val z = await(resume())
(x, y, z)
}
+ res mustBe ((25, "result", "resume"))
}
-}
-@RunWith(classOf[JUnit4])
-class HygieneSpec {
+ @Test
+ def `external var as result of await`() {
+ var ext = 0
+ async {
+ ext = await(12)
+ }
+ ext mustBe (12)
+ }
- @Test def `is hygenic`() {
- val o = new HygieneClass
- val fut = o.m2(10)
- val res = Await.result(fut, 2 seconds)
+ @Test
+ def `external var as result of await 2`() {
+ var ext = 0
+ val inp = 10
+ async {
+ if (inp > 0)
+ ext = await(12)
+ else
+ ext = await(10)
+ }
+ ext mustBe (12)
+ }
+
+ @Test
+ def `external var as result of await 3`() {
+ var ext = 0
+ val inp = 10
+ async {
+ val x = if (inp > 0)
+ await(12)
+ else
+ await(10)
+ ext = x + await(2)
+ }
+ ext mustBe (14)
+ }
+
+ @Test
+ def `is hygenic nested`() {
+ val state = 23
+ val result: Any = "result"
+ def resume(): Any = "resume"
+ import AsyncId.{await, async}
+ val res = async {
+ val f1 = async { state + 2 }
+ val x = await(f1)
+ val y = await(async { result })
+ val z = await(async(await(async { resume() })))
+ (x, y, z)
+ }
res._1 mustBe (25)
res._2 mustBe ("result")
res._3 mustBe ("resume")
diff --git a/src/test/scala/scala/async/run/ifelse0/IfElse0.scala b/src/test/scala/scala/async/run/ifelse0/IfElse0.scala
index 0a72f1e..e2b1ca6 100644
--- a/src/test/scala/scala/async/run/ifelse0/IfElse0.scala
+++ b/src/test/scala/scala/async/run/ifelse0/IfElse0.scala
@@ -47,4 +47,12 @@ class IfElseSpec {
val res = Await.result(fut, 2 seconds)
res mustBe (14)
}
+
+ @Test def `await in condition`() {
+ import AsyncId.{async, await}
+ val result = async {
+ if ({await(true); await(true)}) await(1) else ???
+ }
+ result mustBe (1)
+ }
}
diff --git a/src/test/scala/scala/async/run/match0/Match0.scala b/src/test/scala/scala/async/run/match0/Match0.scala
index f550a69..8263e72 100644
--- a/src/test/scala/scala/async/run/match0/Match0.scala
+++ b/src/test/scala/scala/async/run/match0/Match0.scala
@@ -69,4 +69,47 @@ class MatchSpec {
val res = Await.result(fut, 2 seconds)
res mustBe (5)
}
+
+ @Test def `support await in a match expression with binds`() {
+ val result = AsyncId.async {
+ val x = 1
+ Option(x) match {
+ case op @ Some(x) =>
+ assert(op == Some(1))
+ x + AsyncId.await(x)
+ case None => AsyncId.await(0)
+ }
+ }
+ result mustBe (2)
+ }
+
+ @Test def `support await referring to pattern matching vals`() {
+ import AsyncId.{async, await}
+ val result = async {
+ val x = 1
+ val opt = Some("")
+ await(0)
+ val o @ Some(y) = opt
+
+ {
+ val o @ Some(y) = Some(".")
+ }
+
+ await(0)
+ await((o, y.isEmpty))
+ }
+ result mustBe ((Some(""), true))
+ }
+
+ @Test def `await in scrutinee`() {
+ import AsyncId.{async, await}
+ val result = async {
+ await(if ("".isEmpty) await(1) else ???) match {
+ case x if x < 0 => ???
+ case y: Int => y * await(3)
+ case _ => ???
+ }
+ }
+ result mustBe (3)
+ }
}
diff --git a/src/test/scala/scala/async/run/nesteddef/NestedDef.scala b/src/test/scala/scala/async/run/nesteddef/NestedDef.scala
new file mode 100644
index 0000000..2baef0d
--- /dev/null
+++ b/src/test/scala/scala/async/run/nesteddef/NestedDef.scala
@@ -0,0 +1,40 @@
+package scala.async
+package run
+package nesteddef
+
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+import org.junit.Test
+
+@RunWith(classOf[JUnit4])
+class NestedDef {
+
+ @Test
+ def nestedDef() {
+ import AsyncId._
+ val result = async {
+ val a = 0
+ val x = await(a) - 1
+ val local = 43
+ def bar(d: Double) = -d + a + local
+ def foo(z: Any) = (a.toDouble, bar(x).toDouble, z)
+ foo(await(2))
+ }
+ result mustBe (0d, 44d, 2)
+ }
+
+
+ @Test
+ def nestedFunction() {
+ import AsyncId._
+ val result = async {
+ val a = 0
+ val x = await(a) - 1
+ val local = 43
+ val bar = (d: Double) => -d + a + local
+ val foo = (z: Any) => (a.toDouble, bar(x).toDouble, z)
+ foo(await(2))
+ }
+ result mustBe (0d, 44d, 2)
+ }
+}