aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/scala/scala/async/internal/AnfTransform.scala81
-rw-r--r--src/main/scala/scala/async/internal/AsyncBase.scala4
-rw-r--r--src/main/scala/scala/async/internal/AsyncId.scala10
-rw-r--r--src/main/scala/scala/async/internal/AsyncMacro.scala7
-rw-r--r--src/main/scala/scala/async/internal/AsyncTransform.scala11
-rw-r--r--src/main/scala/scala/async/internal/ExprBuilder.scala80
-rw-r--r--src/main/scala/scala/async/internal/Lifter.scala1
-rw-r--r--src/main/scala/scala/async/internal/StateAssigner.scala3
-rw-r--r--src/main/scala/scala/async/internal/TransformUtils.scala145
-rw-r--r--src/test/scala/scala/async/TreeInterrogation.scala2
-rw-r--r--src/test/scala/scala/async/run/late/LateExpansion.scala170
11 files changed, 459 insertions, 55 deletions
diff --git a/src/main/scala/scala/async/internal/AnfTransform.scala b/src/main/scala/scala/async/internal/AnfTransform.scala
index f81f5af..4545ca6 100644
--- a/src/main/scala/scala/async/internal/AnfTransform.scala
+++ b/src/main/scala/scala/async/internal/AnfTransform.scala
@@ -16,16 +16,18 @@ private[async] trait AnfTransform {
import c.internal._
import decorators._
- def anfTransform(tree: Tree): Block = {
+ def anfTransform(tree: Tree, owner: Symbol): Block = {
// Must prepend the () for issue #31.
- val block = c.typecheck(atPos(tree.pos)(Block(List(Literal(Constant(()))), tree))).setType(tree.tpe)
+ val block = c.typecheck(atPos(tree.pos)(newBlock(List(Literal(Constant(()))), tree))).setType(tree.tpe)
sealed abstract class AnfMode
case object Anf extends AnfMode
case object Linearizing extends AnfMode
+ val tree1 = adjustTypeOfTranslatedPatternMatches(block, owner)
+
var mode: AnfMode = Anf
- typingTransform(block)((tree, api) => {
+ typingTransform(tree1, owner)((tree, api) => {
def blockToList(tree: Tree): List[Tree] = tree match {
case Block(stats, expr) => stats :+ expr
case t => t :: Nil
@@ -34,7 +36,7 @@ private[async] trait AnfTransform {
def listToBlock(trees: List[Tree]): Block = trees match {
case trees @ (init :+ last) =>
val pos = trees.map(_.pos).reduceLeft(_ union _)
- Block(init, last).setType(last.tpe).setPos(pos)
+ newBlock(init, last).setType(last.tpe).setPos(pos)
}
object linearize {
@@ -66,6 +68,17 @@ private[async] trait AnfTransform {
stats :+ valDef :+ atPos(tree.pos)(ref1)
case If(cond, thenp, elsep) =>
+ // If we run the ANF transform post patmat, deal with trees like `(if (cond) jump1(){String} else jump2(){String}){String}`
+ // as though it was typed with `Unit`.
+ def isPatMatGeneratedJump(t: Tree): Boolean = t match {
+ case Block(_, expr) => isPatMatGeneratedJump(expr)
+ case If(_, thenp, elsep) => isPatMatGeneratedJump(thenp) && isPatMatGeneratedJump(elsep)
+ case _: Apply if isLabel(t.symbol) => true
+ case _ => false
+ }
+ if (isPatMatGeneratedJump(expr)) {
+ internal.setType(expr, definitions.UnitTpe)
+ }
// if type of if-else is Unit don't introduce assignment,
// but add Unit value to bring it into form expected by async transform
if (expr.tpe =:= definitions.UnitTpe) {
@@ -77,7 +90,7 @@ private[async] trait AnfTransform {
def branchWithAssign(orig: Tree) = api.typecheck(atPos(orig.pos) {
def cast(t: Tree) = mkAttributedCastPreservingAnnotations(t, tpe(varDef.symbol))
orig match {
- case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr)))
+ case Block(thenStats, thenExpr) => newBlock(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr)))
case _ => Assign(Ident(varDef.symbol), cast(orig))
}
})
@@ -115,7 +128,7 @@ private[async] trait AnfTransform {
}
}
- private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = {
+ def defineVar(prefix: String, tp: Type, pos: Position): ValDef = {
val sym = api.currentOwner.newTermSymbol(name.fresh(prefix), pos, MUTABLE | SYNTHETIC).setInfo(uncheckedBounds(tp))
valDef(sym, mkZero(uncheckedBounds(tp))).setType(NoType).setPos(pos)
}
@@ -152,8 +165,7 @@ private[async] trait AnfTransform {
}
def _transformToList(tree: Tree): List[Tree] = trace(tree) {
- val containsAwait = tree exists isAwait
- if (!containsAwait) {
+ if (!containsAwait(tree)) {
tree match {
case Block(stats, expr) =>
// avoids nested block in `while(await(false)) ...`.
@@ -207,10 +219,11 @@ private[async] trait AnfTransform {
funStats ++ argStatss.flatten.flatten :+ typedNewApply
case Block(stats, expr) =>
- (stats :+ expr).flatMap(linearize.transformToList)
+ val trees = stats.flatMap(linearize.transformToList).filterNot(isLiteralUnit) ::: linearize.transformToList(expr)
+ eliminateMatchEndLabelParameter(trees)
case ValDef(mods, name, tpt, rhs) =>
- if (rhs exists isAwait) {
+ if (containsAwait(rhs)) {
val stats :+ expr = api.atOwner(api.currentOwner.owner)(linearize.transformToList(rhs))
stats.foreach(_.changeOwner(api.currentOwner, api.currentOwner.owner))
stats :+ treeCopy.ValDef(tree, mods, name, tpt, expr)
@@ -247,7 +260,7 @@ private[async] trait AnfTransform {
scrutStats :+ treeCopy.Match(tree, scrutExpr, caseDefs)
case LabelDef(name, params, rhs) =>
- List(LabelDef(name, params, Block(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol))
+ List(LabelDef(name, params, newBlock(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol))
case TypeApply(fun, targs) =>
val funStats :+ simpleFun = linearize.transformToList(fun)
@@ -259,6 +272,52 @@ private[async] trait AnfTransform {
}
}
+ // Replace the label parameters on `matchEnd` with use of a `matchRes` temporary variable
+ //
+ // CaseDefs are translated to labels without parmeters. A terminal label, `matchEnd`, accepts
+ // a parameter which is the result of the match (this is regular, so even Unit-typed matches have this).
+ //
+ // For our purposes, it is easier to:
+ // - extract a `matchRes` variable
+ // - rewrite the terminal label def to take no parameters, and instead read this temp variable
+ // - change jumps to the terminal label to an assignment and a no-arg label application
+ def eliminateMatchEndLabelParameter(statsExpr: List[Tree]): List[Tree] = {
+ import internal.{methodType, setInfo}
+ val caseDefToMatchResult = collection.mutable.Map[Symbol, Symbol]()
+
+ val matchResults = collection.mutable.Buffer[Tree]()
+ val statsExpr0 = statsExpr.reverseMap {
+ case ld @ LabelDef(_, param :: Nil, body) =>
+ val matchResult = linearize.defineVar(name.matchRes, param.tpe, ld.pos)
+ matchResults += matchResult
+ caseDefToMatchResult(ld.symbol) = matchResult.symbol
+ val ld2 = treeCopy.LabelDef(ld, ld.name, Nil, body.substituteSymbols(param.symbol :: Nil, matchResult.symbol :: Nil))
+ setInfo(ld.symbol, methodType(Nil, ld.symbol.info.resultType))
+ ld2
+ case t =>
+ if (caseDefToMatchResult.isEmpty) t
+ else typingTransform(t)((tree, api) =>
+ tree match {
+ case Apply(fun, arg :: Nil) if isLabel(fun.symbol) && caseDefToMatchResult.contains(fun.symbol) =>
+ api.typecheck(atPos(tree.pos)(newBlock(Assign(Ident(caseDefToMatchResult(fun.symbol)), api.recur(arg)) :: Nil, treeCopy.Apply(tree, fun, Nil))))
+ case Block(stats, expr) =>
+ api.default(tree) match {
+ case Block(stats, Block(stats1, expr)) =>
+ treeCopy.Block(tree, stats ::: stats1, expr)
+ case t => t
+ }
+ case _ =>
+ api.default(tree)
+ }
+ )
+ }
+ matchResults.toList match {
+ case Nil => statsExpr
+ case r1 :: Nil => (r1 +: statsExpr0.reverse) :+ atPos(tree.pos)(gen.mkAttributedIdent(r1.symbol))
+ case _ => c.error(macroPos, "Internal error: unexpected tree encountered during ANF transform " + statsExpr); statsExpr
+ }
+ }
+
def anfLinearize(tree: Tree): Block = {
val trees: List[Tree] = mode match {
case Anf => anf._transformToList(tree)
diff --git a/src/main/scala/scala/async/internal/AsyncBase.scala b/src/main/scala/scala/async/internal/AsyncBase.scala
index 7464c42..7a1e274 100644
--- a/src/main/scala/scala/async/internal/AsyncBase.scala
+++ b/src/main/scala/scala/async/internal/AsyncBase.scala
@@ -43,9 +43,9 @@ abstract class AsyncBase {
(body: c.Expr[T])
(execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = {
import c.universe._, c.internal._, decorators._
- val asyncMacro = AsyncMacro(c, self)
+ val asyncMacro = AsyncMacro(c, self)(body.tree)
- val code = asyncMacro.asyncTransform[T](body.tree, execContext.tree)(c.weakTypeTag[T])
+ val code = asyncMacro.asyncTransform[T](execContext.tree)(c.weakTypeTag[T])
AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code}")
// Mark range positions for synthetic code as transparent to allow some wiggle room for overlapping ranges
diff --git a/src/main/scala/scala/async/internal/AsyncId.scala b/src/main/scala/scala/async/internal/AsyncId.scala
index 3afa55b..8654474 100644
--- a/src/main/scala/scala/async/internal/AsyncId.scala
+++ b/src/main/scala/scala/async/internal/AsyncId.scala
@@ -41,11 +41,11 @@ object AsyncTestLV extends AsyncBase {
* A trivial implementation of [[FutureSystem]] that performs computations
* on the current thread. Useful for testing.
*/
+class Box[A] {
+ var a: A = _
+}
object IdentityFutureSystem extends FutureSystem {
-
- class Prom[A] {
- var a: A = _
- }
+ type Prom[A] = Box[A]
type Fut[A] = A
type ExecContext = Unit
@@ -57,7 +57,7 @@ object IdentityFutureSystem extends FutureSystem {
def execContext: Expr[ExecContext] = c.Expr[Unit](Literal(Constant(())))
- def promType[A: WeakTypeTag]: Type = weakTypeOf[Prom[A]]
+ def promType[A: WeakTypeTag]: Type = weakTypeOf[Box[A]]
def tryType[A: WeakTypeTag]: Type = weakTypeOf[scala.util.Try[A]]
def execContextType: Type = weakTypeOf[Unit]
diff --git a/src/main/scala/scala/async/internal/AsyncMacro.scala b/src/main/scala/scala/async/internal/AsyncMacro.scala
index e969f9b..e22407d 100644
--- a/src/main/scala/scala/async/internal/AsyncMacro.scala
+++ b/src/main/scala/scala/async/internal/AsyncMacro.scala
@@ -1,15 +1,17 @@
package scala.async.internal
object AsyncMacro {
- def apply(c0: reflect.macros.Context, base: AsyncBase): AsyncMacro { val c: c0.type } = {
+ def apply(c0: reflect.macros.Context, base: AsyncBase)(body0: c0.Tree): AsyncMacro { val c: c0.type } = {
import language.reflectiveCalls
new AsyncMacro { self =>
val c: c0.type = c0
+ val body: c.Tree = body0
// This member is required by `AsyncTransform`:
val asyncBase: AsyncBase = base
// These members are required by `ExprBuilder`:
val futureSystem: FutureSystem = base.futureSystem
val futureSystemOps: futureSystem.Ops {val c: self.c.type} = futureSystem.mkOps(c)
+ val containsAwait: c.Tree => Boolean = containsAwaitCached(body0)
}
}
}
@@ -19,7 +21,10 @@ private[async] trait AsyncMacro
with ExprBuilder with AsyncTransform with AsyncAnalysis with LiveVariables {
val c: scala.reflect.macros.Context
+ val body: c.Tree
+ val containsAwait: c.Tree => Boolean
lazy val macroPos = c.macroApplication.pos.makeTransparent
def atMacroPos(t: c.Tree) = c.universe.atPos(macroPos)(t)
+
}
diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala
index baa3fc2..af290e4 100644
--- a/src/main/scala/scala/async/internal/AsyncTransform.scala
+++ b/src/main/scala/scala/async/internal/AsyncTransform.scala
@@ -9,7 +9,7 @@ trait AsyncTransform {
val asyncBase: AsyncBase
- def asyncTransform[T](body: Tree, execContext: Tree)
+ def asyncTransform[T](execContext: Tree)
(resultType: WeakTypeTag[T]): Tree = {
// We annotate the type of the whole expression as `T @uncheckedBounds` so as not to introduce
@@ -22,7 +22,7 @@ trait AsyncTransform {
// Transform to A-normal form:
// - no await calls in qualifiers or arguments,
// - if/match only used in statement position.
- val anfTree0: Block = anfTransform(body)
+ val anfTree0: Block = anfTransform(body, c.internal.enclosingOwner)
val anfTree = futureSystemOps.postAnfTransform(anfTree0)
@@ -35,7 +35,7 @@ trait AsyncTransform {
val stateMachine: ClassDef = {
val body: List[Tree] = {
val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(StateAssigner.Initial)))
- val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T](uncheckedBoundsResultTag)), futureSystemOps.createProm[T](uncheckedBoundsResultTag).tree)
+ val resultAndAccessors = mkMutableField(futureSystemOps.promType[T](uncheckedBoundsResultTag), name.result, futureSystemOps.createProm[T](uncheckedBoundsResultTag).tree)
val execContextValDef = ValDef(NoMods, name.execContext, TypeTree(), execContext)
val apply0DefDef: DefDef = {
@@ -43,7 +43,7 @@ trait AsyncTransform {
// See SI-1247 for the the optimization that avoids creation.
DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.apply), literalNull :: Nil))
}
- List(emptyConstructor, stateVar, result, execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef)
+ List(emptyConstructor, stateVar) ++ resultAndAccessors ++ List(execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef)
}
val tryToUnit = appliedType(definitions.FunctionClass(1), futureSystemOps.tryType[Any], typeOf[Unit])
@@ -98,10 +98,11 @@ trait AsyncTransform {
}
val isSimple = asyncBlock.asyncStates.size == 1
- if (isSimple)
+ val result = if (isSimple)
futureSystemOps.spawn(body, execContext) // generate lean code for the simple case of `async { 1 + 1 }`
else
startStateMachine
+ cleanupContainsAwaitAttachments(result)
}
def logDiagnostics(anfTree: Tree, states: Seq[String]) {
diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala
index 164e85b..16b9207 100644
--- a/src/main/scala/scala/async/internal/ExprBuilder.scala
+++ b/src/main/scala/scala/async/internal/ExprBuilder.scala
@@ -146,6 +146,8 @@ trait ExprBuilder {
private val stats = ListBuffer[Tree]()
/** The state of the target of a LabelDef application (while loop jump) */
private var nextJumpState: Option[Int] = None
+ private var nextJumpSymbol: Symbol = NoSymbol
+ def effectiveNextState(nextState: Int) = nextJumpState.orElse(if (nextJumpSymbol == NoSymbol) None else Some(stateIdForLabel(nextJumpSymbol))).getOrElse(nextState)
def +=(stat: Tree): this.type = {
stat match {
@@ -155,11 +157,16 @@ trait ExprBuilder {
}
def addStat() = stats += stat
stat match {
- case Apply(fun, Nil) =>
+ case Apply(fun, args) if isLabel(fun.symbol) =>
// labelDefStates belongs to the current ExprBuilder
labelDefStates get fun.symbol match {
- case opt @ Some(nextState) => nextJumpState = opt // re-use object
- case None => addStat()
+ case opt@Some(nextState) =>
+ // A backward jump
+ nextJumpState = opt // re-use object
+ nextJumpSymbol = fun.symbol
+ case None =>
+ // We haven't the corresponding LabelDef, this is a forward jump
+ nextJumpSymbol = fun.symbol
}
case _ => addStat()
}
@@ -169,13 +176,11 @@ trait ExprBuilder {
def resultWithAwait(awaitable: Awaitable,
onCompleteState: Int,
nextState: Int): AsyncState = {
- val effectiveNextState = nextJumpState.getOrElse(nextState)
- new AsyncStateWithAwait(stats.toList, state, onCompleteState, effectiveNextState, awaitable, symLookup)
+ new AsyncStateWithAwait(stats.toList, state, onCompleteState, effectiveNextState(nextState), awaitable, symLookup)
}
def resultSimple(nextState: Int): AsyncState = {
- val effectiveNextState = nextJumpState.getOrElse(nextState)
- new SimpleAsyncState(stats.toList, state, effectiveNextState, symLookup)
+ new SimpleAsyncState(stats.toList, state, effectiveNextState(nextState), symLookup)
}
def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = {
@@ -243,9 +248,17 @@ trait ExprBuilder {
}
import stateAssigner.nextState
+ def directlyAdjacentLabelDefs(t: Tree): List[Tree] = {
+ def isPatternCaseLabelDef(t: Tree) = t match {
+ case LabelDef(name, _, _) => name.toString.startsWith("case")
+ case _ => false
+ }
+ val (before, _ :: after) = (stats :+ expr).span(_ ne t)
+ before.reverse.takeWhile(isPatternCaseLabelDef) ::: after.takeWhile(isPatternCaseLabelDef)
+ }
// populate asyncStates
- for (stat <- stats) stat match {
+ for (stat <- (stats :+ expr)) stat match {
// the val name = await(..) pattern
case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) =>
val onCompleteState = nextState()
@@ -255,7 +268,7 @@ trait ExprBuilder {
currState = afterAwaitState
stateBuilder = new AsyncStateBuilder(currState, symLookup)
- case If(cond, thenp, elsep) if (stat exists isAwait) || containsForiegnLabelJump(stat) =>
+ case If(cond, thenp, elsep) if containsAwait(stat) || containsForiegnLabelJump(stat) =>
checkForUnsupportedAwait(cond)
val thenStartState = nextState()
@@ -275,7 +288,7 @@ trait ExprBuilder {
currState = afterIfState
stateBuilder = new AsyncStateBuilder(currState, symLookup)
- case Match(scrutinee, cases) if stat exists isAwait =>
+ case Match(scrutinee, cases) if containsAwait(stat) =>
checkForUnsupportedAwait(scrutinee)
val caseStates = cases.map(_ => nextState())
@@ -293,24 +306,21 @@ trait ExprBuilder {
currState = afterMatchState
stateBuilder = new AsyncStateBuilder(currState, symLookup)
+ case ld @ LabelDef(name, params, rhs)
+ if containsAwait(rhs) || directlyAdjacentLabelDefs(ld).exists(containsAwait) =>
- case ld @ LabelDef(name, params, rhs) if rhs exists isAwait =>
- val startLabelState = nextState()
+ val startLabelState = stateIdForLabel(ld.symbol)
val afterLabelState = nextState()
asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup)
labelDefStates(ld.symbol) = startLabelState
val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState)
asyncStates ++= builder.asyncStates
-
currState = afterLabelState
stateBuilder = new AsyncStateBuilder(currState, symLookup)
-
case _ =>
checkForUnsupportedAwait(stat)
stateBuilder += stat
}
- // complete last state builder (representing the expressions after the last await)
- stateBuilder += expr
val lastState = stateBuilder.resultSimple(endState)
asyncStates += lastState
}
@@ -383,18 +393,26 @@ trait ExprBuilder {
* }
* }
*/
- private def resumeFunTree[T: WeakTypeTag]: Tree =
- Try(
- Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T]) ),
- List(
- CaseDef(
- Bind(name.t, Ident(nme.WILDCARD)),
- Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), {
- val t = c.Expr[Throwable](Ident(name.t))
- val complete = futureSystemOps.completeProm[T](
+ private def resumeFunTree[T: WeakTypeTag]: Tree = {
+ val body = Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T]))
+ Try(
+ body,
+ List(
+ CaseDef(
+ Bind(name.t, Typed(Ident(nme.WILDCARD), Ident(defn.ThrowableClass))),
+ EmptyTree, {
+ val then = {
+ val t = c.Expr[Throwable](Ident(name.t))
+ val complete = futureSystemOps.completeProm[T](
c.Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), futureSystemOps.tryyFailure[T](t)).tree
- Block(toList(complete), Return(literalUnit))
- })), EmptyTree)
+ Block(toList(complete), Return(literalUnit))
+ }
+ If(Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), then, Throw(Ident(name.t)))
+ then
+ })), EmptyTree)
+
+ //body
+ }
def forever(t: Tree): Tree = {
val labelName = name.fresh("while$")
@@ -435,6 +453,14 @@ trait ExprBuilder {
private def mkHandlerCase(num: Int, rhs: List[Tree]): CaseDef =
mkHandlerCase(num, adaptToUnit(rhs))
+ // We use the convention that the state machine's ID for a state corresponding to
+ // a labeldef will a negative number be based on the symbol ID. This allows us
+ // to translate a forward jump to the label as a state transition to a known state
+ // ID, even though the state machine transform hasn't yet processed the target label
+ // def. Negative numbers are used so as as not to clash with regular state IDs, which
+ // are allocated in ascending order from 0.
+ private def stateIdForLabel(sym: Symbol): Int = -symId(sym)
+
private def tpeOf(t: Tree): Type = t match {
case _ if t.tpe != null => t.tpe
case Try(body, Nil, _) => tpeOf(body)
diff --git a/src/main/scala/scala/async/internal/Lifter.scala b/src/main/scala/scala/async/internal/Lifter.scala
index 4242a8e..2998baf 100644
--- a/src/main/scala/scala/async/internal/Lifter.scala
+++ b/src/main/scala/scala/async/internal/Lifter.scala
@@ -40,6 +40,7 @@ trait Lifter {
val defs: Map[Tree, Int] = {
/** Collect the DefTrees directly enclosed within `t` that have the same owner */
def collectDirectlyEnclosedDefs(t: Tree): List[DefTree] = t match {
+ case ld: LabelDef => Nil
case dt: DefTree => dt :: Nil
case _: Function => Nil
case t =>
diff --git a/src/main/scala/scala/async/internal/StateAssigner.scala b/src/main/scala/scala/async/internal/StateAssigner.scala
index 55e7a51..2b74e8d 100644
--- a/src/main/scala/scala/async/internal/StateAssigner.scala
+++ b/src/main/scala/scala/async/internal/StateAssigner.scala
@@ -7,8 +7,7 @@ package scala.async.internal
private[async] final class StateAssigner {
private var current = StateAssigner.Initial
- def nextState(): Int =
- try current finally current += 1
+ def nextState(): Int = try current finally current += 1
}
object StateAssigner {
diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala
index 547f980..90419d3 100644
--- a/src/main/scala/scala/async/internal/TransformUtils.scala
+++ b/src/main/scala/scala/async/internal/TransformUtils.scala
@@ -41,6 +41,19 @@ private[async] trait TransformUtils {
def isAwait(fun: Tree) =
fun.symbol == defn.Async_await
+ def newBlock(stats: List[Tree], expr: Tree): Block = {
+ Block(stats, expr)
+ }
+
+ def isLiteralUnit(t: Tree) = t match {
+ case Literal(Constant(())) =>
+ true
+ case _ => false
+ }
+
+ def isPastTyper =
+ c.universe.asInstanceOf[scala.reflect.internal.SymbolTable].isPastTyper
+
// Copy pasted from TreeInfo in the compiler.
// Using a quasiquote pattern like `case q"$fun[..$targs](...$args)" => is not
// sufficient since https://github.com/scala/scala/pull/3656 as it doesn't match
@@ -150,6 +163,7 @@ private[async] trait TransformUtils {
}
val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal")
+ val ThrowableClass = rootMirror.staticClass("java.lang.Throwable")
val Async_await = asyncBase.awaitMethod(c.universe)(c.macroApplication.symbol).ensuring(_ != NoSymbol)
val IllegalStateExceptionClass = rootMirror.staticClass("java.lang.IllegalStateException")
}
@@ -161,16 +175,26 @@ private[async] trait TransformUtils {
val labelDefs = t.collect {
case ld: LabelDef => ld.symbol
}.toSet
- t.exists {
+ val result = t.exists {
case rt: RefTree => rt.symbol != null && isLabel(rt.symbol) && !(labelDefs contains rt.symbol)
case _ => false
}
+ result
}
- private def isLabel(sym: Symbol): Boolean = {
+ def isLabel(sym: Symbol): Boolean = {
val LABEL = 1L << 17 // not in the public reflection API.
(internal.flags(sym).asInstanceOf[Long] & LABEL) != 0L
}
+ def symId(sym: Symbol): Int = {
+ val symtab = this.c.universe.asInstanceOf[reflect.internal.SymbolTable]
+ sym.asInstanceOf[symtab.Symbol].id
+ }
+ def substituteTrees(t: Tree, from: List[Symbol], to: List[Tree]): Tree = {
+ val symtab = this.c.universe.asInstanceOf[reflect.internal.SymbolTable]
+ val subst = new symtab.TreeSubstituter(from.asInstanceOf[List[symtab.Symbol]], to.asInstanceOf[List[symtab.Tree]])
+ subst.transform(t.asInstanceOf[symtab.Tree]).asInstanceOf[Tree]
+ }
/** Map a list of arguments to:
@@ -362,4 +386,121 @@ private[async] trait TransformUtils {
else withAnnotation(tp, Annotation(UncheckedBoundsClass.asType.toType, Nil, ListMap()))
}
// =====================================
+
+ /**
+ * Efficiently decorate each subtree within `t` with the result of `t exists isAwait`,
+ * and return a function that can be used on derived trees to efficiently test the
+ * same condition.
+ *
+ * If the derived tree contains synthetic wrapper trees, these will be recursed into
+ * in search of a sub tree that was decorated with the cached answer.
+ */
+ final def containsAwaitCached(t: Tree): Tree => Boolean = {
+ def treeCannotContainAwait(t: Tree) = t match {
+ case _: Ident | _: TypeTree | _: Literal => true
+ case _ => false
+ }
+ def shouldAttach(t: Tree) = !treeCannotContainAwait(t)
+ val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable]
+ def attachContainsAwait(t: Tree): Unit = if (shouldAttach(t)) {
+ val t1 = t.asInstanceOf[symtab.Tree]
+ t1.updateAttachment(ContainsAwait)
+ t1.removeAttachment[NoAwait.type]
+ }
+ def attachNoAwait(t: Tree): Unit = if (shouldAttach(t)) {
+ val t1 = t.asInstanceOf[symtab.Tree]
+ t1.updateAttachment(NoAwait)
+ }
+ object markContainsAwaitTraverser extends Traverser {
+ var stack: List[Tree] = Nil
+
+ override def traverse(tree: Tree): Unit = {
+ stack ::= tree
+ try {
+ if (isAwait(tree))
+ stack.foreach(attachContainsAwait)
+ else
+ attachNoAwait(tree)
+ super.traverse(tree)
+ } finally stack = stack.tail
+ }
+ }
+ markContainsAwaitTraverser.traverse(t)
+
+ (t: Tree) => {
+ object traverser extends Traverser {
+ var containsAwait = false
+ override def traverse(tree: Tree): Unit = {
+ def castTree = tree.asInstanceOf[symtab.Tree]
+ if (!castTree.hasAttachment[NoAwait.type]) {
+ if (castTree.hasAttachment[ContainsAwait.type])
+ containsAwait = true
+ else if (!treeCannotContainAwait(t))
+ super.traverse(tree)
+ }
+ }
+ }
+ traverser.traverse(t)
+ traverser.containsAwait
+ }
+ }
+
+ final def cleanupContainsAwaitAttachments(t: Tree): t.type = {
+ val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable]
+ t.foreach {t =>
+ t.asInstanceOf[symtab.Tree].removeAttachment[ContainsAwait.type]
+ t.asInstanceOf[symtab.Tree].removeAttachment[NoAwait.type]
+ }
+ t
+ }
+
+ // First modification to translated patterns:
+ // - Set the type of label jumps to `Unit`
+ // - Propagate this change to trees known to directly enclose them:
+ // ``If` / `Block`) adjust types of enclosing
+ final def adjustTypeOfTranslatedPatternMatches(t: Tree, owner: Symbol): Tree = {
+ import definitions.UnitTpe
+ typingTransform(t, owner) {
+ (tree, api) =>
+ tree match {
+ case Block(stats, expr) =>
+ val stats1 = stats map api.recur
+ val expr1 = api.recur(expr)
+ if (expr1.tpe =:= UnitTpe)
+ internal.setType(treeCopy.Block(tree, stats1, expr1), UnitTpe)
+ else
+ treeCopy.Block(tree, stats1, expr1)
+ case If(cond, thenp, elsep) =>
+ val cond1 = api.recur(cond)
+ val thenp1 = api.recur(thenp)
+ val elsep1 = api.recur(elsep)
+ if (thenp1.tpe =:= definitions.UnitTpe && elsep.tpe =:= UnitTpe)
+ internal.setType(treeCopy.If(tree, cond1, thenp1, elsep1), UnitTpe)
+ else
+ treeCopy.If(tree, cond1, thenp1, elsep1)
+ case Apply(fun, args) if isLabel(fun.symbol) =>
+ internal.setType(treeCopy.Apply(tree, api.recur(fun), args map api.recur), UnitTpe)
+ case t => api.default(t)
+ }
+ }
+ }
+
+ final def mkMutableField(tpt: Type, name: TermName, init: Tree): List[Tree] = {
+ if (isPastTyper) {
+ // If we are running after the typer phase (ie being called from a compiler plugin)
+ // we have to create the trio of members manually.
+ val ACCESSOR = (1L << 27).asInstanceOf[FlagSet]
+ val STABLE = (1L << 22).asInstanceOf[FlagSet]
+ val field = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name + " ", TypeTree(tpt), init)
+ val getter = DefDef(Modifiers(ACCESSOR | STABLE), name, Nil, Nil, TypeTree(tpt), Select(This(tpnme.EMPTY), field.name))
+ val setter = DefDef(Modifiers(ACCESSOR), name + "_=", Nil, List(List(ValDef(NoMods, TermName("x"), TypeTree(tpt), EmptyTree))), TypeTree(definitions.UnitTpe), Assign(Select(This(tpnme.EMPTY), field.name), Ident(TermName("x"))))
+ field :: getter :: setter :: Nil
+ } else {
+ val result = ValDef(NoMods, name, TypeTree(tpt), init)
+ result :: Nil
+ }
+ }
}
+
+case object ContainsAwait
+case object NoAwait \ No newline at end of file
diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala
index d6c619f..09fa69e 100644
--- a/src/test/scala/scala/async/TreeInterrogation.scala
+++ b/src/test/scala/scala/async/TreeInterrogation.scala
@@ -82,6 +82,8 @@ object TreeInterrogation extends App {
println(tree)
val tree1 = tb.typeCheck(tree.duplicate)
println(cm.universe.show(tree1))
+
println(tb.eval(tree))
}
+
}
diff --git a/src/test/scala/scala/async/run/late/LateExpansion.scala b/src/test/scala/scala/async/run/late/LateExpansion.scala
new file mode 100644
index 0000000..b866527
--- /dev/null
+++ b/src/test/scala/scala/async/run/late/LateExpansion.scala
@@ -0,0 +1,170 @@
+package scala.async.run.late
+
+import java.io.File
+
+import junit.framework.Assert.assertEquals
+import org.junit.Test
+
+import scala.annotation.StaticAnnotation
+import scala.async.internal.{AsyncId, AsyncMacro}
+import scala.reflect.internal.util.ScalaClassLoader.URLClassLoader
+import scala.tools.nsc._
+import scala.tools.nsc.plugins.{Plugin, PluginComponent}
+import scala.tools.nsc.reporters.StoreReporter
+import scala.tools.nsc.transform.TypingTransformers
+
+// Tests for customized use of the async transform from a compiler plugin, which
+// calls it from a new phase that runs after patmat.
+class LateExpansion {
+ @Test def test0(): Unit = {
+ val result = wrapAndRun(
+ """
+ | @autoawait def id(a: String) = a
+ | id("foo") + id("bar")
+ | """.stripMargin)
+ assertEquals("foobar", result)
+ }
+ @Test def testGuard(): Unit = {
+ val result = wrapAndRun(
+ """
+ | @autoawait def id[A](a: A) = a
+ | "" match { case _ if id(false) => ???; case _ => "okay" }
+ | """.stripMargin)
+ assertEquals("okay", result)
+ }
+
+ @Test def testExtractor(): Unit = {
+ val result = wrapAndRun(
+ """
+ | object Extractor { @autoawait def unapply(a: String) = Some((a, a)) }
+ | "" match { case Extractor(a, b) if "".isEmpty => a == b }
+ | """.stripMargin)
+ assertEquals(true, result)
+ }
+
+ @Test def testNestedMatchExtractor(): Unit = {
+ val result = wrapAndRun(
+ """
+ | object Extractor { @autoawait def unapply(a: String) = Some((a, a)) }
+ | "" match {
+ | case _ if "".isEmpty =>
+ | "" match { case Extractor(a, b) => a == b }
+ | }
+ | """.stripMargin)
+ assertEquals(true, result)
+ }
+
+ @Test def testCombo(): Unit = {
+ val result = wrapAndRun(
+ """
+ | object Extractor1 { @autoawait def unapply(a: String) = Some((a + 1, a + 2)) }
+ | object Extractor2 { @autoawait def unapply(a: String) = Some(a + 3) }
+ | @autoawait def id(a: String) = a
+ | println("Test.test")
+ | val r1 = Predef.identity("blerg") match {
+ | case x if " ".isEmpty => "case 2: " + x
+ | case Extractor1(Extractor2(x), y: String) if x == "xxx" => "case 1: " + x + ":" + y
+ | x match {
+ | case Extractor1(Extractor2(x), y: String) =>
+ | case _ =>
+ | }
+ | case Extractor2(x) => "case 3: " + x
+ | }
+ | r1
+ | """.stripMargin)
+ assertEquals("case 3: blerg3", result)
+ }
+
+ def wrapAndRun(code: String): Any = {
+ run(
+ s"""
+ |import scala.async.run.late.{autoawait,lateasync}
+ |object Test {
+ | @lateasync
+ | def test: Any = {
+ | $code
+ | }
+ |}
+ | """.stripMargin)
+ }
+
+ def run(code: String): Any = {
+ val reporter = new StoreReporter
+ val settings = new Settings(println(_))
+ settings.outdir.value = sys.props("java.io.tmpdir")
+ settings.embeddedDefaults(getClass.getClassLoader)
+ val isInSBT = !settings.classpath.isSetByUser
+ if (isInSBT) settings.usejavacp.value = true
+ val global = new Global(settings, reporter) {
+ self =>
+
+ object late extends {
+ val global: self.type = self
+ } with LatePlugin
+
+ override protected def loadPlugins(): List[Plugin] = late :: Nil
+ }
+ import global._
+
+ val run = new Run
+ val source = newSourceFile(code)
+ run.compileSources(source :: Nil)
+ assert(!reporter.hasErrors, reporter.infos.mkString("\n"))
+ val loader = new URLClassLoader(Seq(new File(settings.outdir.value).toURI.toURL), global.getClass.getClassLoader)
+ val cls = loader.loadClass("Test")
+ cls.getMethod("test").invoke(null)
+ }
+}
+
+abstract class LatePlugin extends Plugin {
+ import global._
+
+ override val components: List[PluginComponent] = List(new PluginComponent with TypingTransformers {
+ val global: LatePlugin.this.global.type = LatePlugin.this.global
+
+ lazy val asyncIdSym = symbolOf[AsyncId.type]
+ lazy val asyncSym = asyncIdSym.info.member(TermName("async"))
+ lazy val awaitSym = asyncIdSym.info.member(TermName("await"))
+ lazy val autoAwaitSym = symbolOf[autoawait]
+ lazy val lateAsyncSym = symbolOf[lateasync]
+
+ def newTransformer(unit: CompilationUnit) = new TypingTransformer(unit) {
+ override def transform(tree: Tree): Tree = {
+ super.transform(tree) match {
+ case ap@Apply(fun, args) if fun.symbol.hasAnnotation(autoAwaitSym) =>
+ localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, awaitSym), TypeTree(ap.tpe) :: Nil), ap :: Nil))
+ case dd: DefDef if dd.symbol.hasAnnotation(lateAsyncSym) => atOwner(dd.symbol) {
+ val expandee = localTyper.context.withMacrosDisabled(
+ localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(dd.rhs.tpe) :: Nil), List(dd.rhs)))
+ )
+ val c = analyzer.macroContext(localTyper, gen.mkAttributedRef(asyncIdSym), expandee)
+ val asyncMacro = AsyncMacro(c, AsyncId)(dd.rhs)
+ val code = asyncMacro.asyncTransform[Any](localTyper.typed(Literal(Constant(()))))(c.weakTypeTag[Any])
+ deriveDefDef(dd)(_ => localTyper.typed(code))
+ }
+ case x => x
+ }
+ }
+ }
+
+ override def newPhase(prev: Phase): Phase = new StdPhase(prev) {
+ override def apply(unit: CompilationUnit): Unit = {
+ val translated = newTransformer(unit).transformUnit(unit)
+ //println(show(unit.body))
+ translated
+ }
+ }
+
+ override val runsAfter: List[String] = "patmat" :: Nil
+ override val phaseName: String = "postpatmat"
+
+ })
+ override val description: String = "postpatmat"
+ override val name: String = "postpatmat"
+}
+
+// Methods with this annotation are translated to having the RHS wrapped in `AsyncId.async { <original RHS> }`
+final class lateasync extends StaticAnnotation
+
+// Calls to methods with this annotation are translated to `AsyncId.await(<call>)`
+final class autoawait extends StaticAnnotation