aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJason Zaugg <jzaugg@gmail.com>2015-10-23 16:25:02 +1000
committerJason Zaugg <jzaugg@gmail.com>2016-01-19 14:27:31 +1000
commit549a656fa22af5f7f0c5e89dd6e0a19ed4b604f5 (patch)
tree7a7c778a24143923a674acb9db30ecdc5e3f8f5e /src
parent634c454dbd546e2f3db6321b4047b3cebd2f899a (diff)
downloadscala-async-549a656fa22af5f7f0c5e89dd6e0a19ed4b604f5.tar.gz
scala-async-549a656fa22af5f7f0c5e89dd6e0a19ed4b604f5.tar.bz2
scala-async-549a656fa22af5f7f0c5e89dd6e0a19ed4b604f5.zip
Various fixes to late expansion
- Detect cross-state symbol references where the RefTree is nested in a LabelDef. Failure to do so led to ill-scoped local variable references which sometimes manifest as VerifyErrors. - Emit a default case in the Match intended to be a tableswitch. We have to do this ourselves if we expand after pattern matcher - Cleanup generated code to avoid redundant blocks - Avoid unnecessary `matchRes` temporary variable for unit-typed pattern matches - Fix the trace level logging in the ANF transform to restore indented output. - Emit `{ state = nextState; ... }` rather than `try { ... } finally { state = nextState }` in state handlers. This simplifies generated code and has the same meaning, as the code in the state machine isn't reentrant and can't observe the "early" transition of the state.
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/scala/async/internal/AnfTransform.scala145
-rw-r--r--src/main/scala/scala/async/internal/AsyncBase.scala6
-rw-r--r--src/main/scala/scala/async/internal/ExprBuilder.scala52
-rw-r--r--src/main/scala/scala/async/internal/Lifter.scala5
-rw-r--r--src/main/scala/scala/async/internal/TransformUtils.scala12
-rw-r--r--src/test/scala/scala/async/TreeInterrogation.scala2
-rw-r--r--src/test/scala/scala/async/run/late/LateExpansion.scala337
7 files changed, 481 insertions, 78 deletions
diff --git a/src/main/scala/scala/async/internal/AnfTransform.scala b/src/main/scala/scala/async/internal/AnfTransform.scala
index 4545ca6..dc10a95 100644
--- a/src/main/scala/scala/async/internal/AnfTransform.scala
+++ b/src/main/scala/scala/async/internal/AnfTransform.scala
@@ -27,6 +27,27 @@ private[async] trait AnfTransform {
val tree1 = adjustTypeOfTranslatedPatternMatches(block, owner)
var mode: AnfMode = Anf
+
+ object trace {
+ private var indent = -1
+
+ private def indentString = " " * indent
+
+ def apply[T](args: Any)(t: => T): T = {
+ def prefix = mode.toString.toLowerCase
+ indent += 1
+ def oneLine(s: Any) = s.toString.replaceAll("""\n""", "\\\\n").take(127)
+ try {
+ AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})")
+ val result = t
+ AsyncUtils.trace(s"${indentString}= ${oneLine(result)}")
+ result
+ } finally {
+ indent -= 1
+ }
+ }
+ }
+
typingTransform(tree1, owner)((tree, api) => {
def blockToList(tree: Tree): List[Tree] = tree match {
case Block(stats, expr) => stats :+ expr
@@ -97,8 +118,11 @@ private[async] trait AnfTransform {
val ifWithAssign = treeCopy.If(tree, cond, branchWithAssign(thenp), branchWithAssign(elsep)).setType(definitions.UnitTpe)
stats :+ varDef :+ ifWithAssign :+ atPos(tree.pos)(gen.mkAttributedStableRef(varDef.symbol)).setType(tree.tpe)
}
- case LabelDef(name, params, rhs) =>
- statsExprUnit
+ case ld @ LabelDef(name, params, rhs) =>
+ if (ld.symbol.info.resultType.typeSymbol == definitions.UnitClass)
+ statsExprUnit
+ else
+ stats :+ expr
case Match(scrut, cases) =>
// if type of match is Unit don't introduce assignment,
@@ -134,26 +158,6 @@ private[async] trait AnfTransform {
}
}
- object trace {
- private var indent = -1
-
- private def indentString = " " * indent
-
- def apply[T](args: Any)(t: => T): T = {
- def prefix = mode.toString.toLowerCase
- indent += 1
- def oneLine(s: Any) = s.toString.replaceAll("""\n""", "\\\\n").take(127)
- try {
- AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})")
- val result = t
- AsyncUtils.trace(s"${indentString}= ${oneLine(result)}")
- result
- } finally {
- indent -= 1
- }
- }
- }
-
def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = {
val sym = api.currentOwner.newTermSymbol(name.fresh(prefix), pos, SYNTHETIC).setInfo(uncheckedBounds(lhs.tpe))
internal.valDef(sym, internal.changeOwner(lhs, api.currentOwner, sym)).setType(NoType).setPos(pos)
@@ -219,8 +223,29 @@ private[async] trait AnfTransform {
funStats ++ argStatss.flatten.flatten :+ typedNewApply
case Block(stats, expr) =>
- val trees = stats.flatMap(linearize.transformToList).filterNot(isLiteralUnit) ::: linearize.transformToList(expr)
- eliminateMatchEndLabelParameter(trees)
+ val stats1 = stats.flatMap(linearize.transformToList).filterNot(isLiteralUnit)
+ val exprs1 = linearize.transformToList(expr)
+ val trees = stats1 ::: exprs1
+ def isMatchEndLabel(t: Tree): Boolean = t match {
+ case ValDef(_, _, _, t) if isMatchEndLabel(t) => true
+ case ld: LabelDef if ld.name.toString.startsWith("matchEnd") => true
+ case _ => false
+ }
+ def groupsEndingWith[T](ts: List[T])(f: T => Boolean): List[List[T]] = if (ts.isEmpty) Nil else {
+ ts.indexWhere(f) match {
+ case -1 => List(ts)
+ case i =>
+ val (ts1, ts2) = ts.splitAt(i + 1)
+ ts1 :: groupsEndingWith(ts2)(f)
+ }
+ }
+ val matchGroups = groupsEndingWith(trees)(isMatchEndLabel)
+ val trees1 = matchGroups.flatMap(eliminateMatchEndLabelParameter)
+ val result = trees1 flatMap {
+ case Block(stats, expr) => stats :+ expr
+ case t => t :: Nil
+ }
+ result
case ValDef(mods, name, tpt, rhs) =>
if (containsAwait(rhs)) {
@@ -260,7 +285,10 @@ private[async] trait AnfTransform {
scrutStats :+ treeCopy.Match(tree, scrutExpr, caseDefs)
case LabelDef(name, params, rhs) =>
- List(LabelDef(name, params, newBlock(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol))
+ if (tree.symbol.info.typeSymbol == definitions.UnitClass)
+ List(treeCopy.LabelDef(tree, name, params, api.typecheck(newBlock(linearize.transformToList(rhs), Literal(Constant(()))))).setSymbol(tree.symbol))
+ else
+ List(treeCopy.LabelDef(tree, name, params, api.typecheck(listToBlock(linearize.transformToList(rhs)))).setSymbol(tree.symbol))
case TypeApply(fun, targs) =>
val funStats :+ simpleFun = linearize.transformToList(fun)
@@ -274,7 +302,7 @@ private[async] trait AnfTransform {
// Replace the label parameters on `matchEnd` with use of a `matchRes` temporary variable
//
- // CaseDefs are translated to labels without parmeters. A terminal label, `matchEnd`, accepts
+ // CaseDefs are translated to labels without parameters. A terminal label, `matchEnd`, accepts
// a parameter which is the result of the match (this is regular, so even Unit-typed matches have this).
//
// For our purposes, it is easier to:
@@ -286,34 +314,71 @@ private[async] trait AnfTransform {
val caseDefToMatchResult = collection.mutable.Map[Symbol, Symbol]()
val matchResults = collection.mutable.Buffer[Tree]()
- val statsExpr0 = statsExpr.reverseMap {
- case ld @ LabelDef(_, param :: Nil, body) =>
+ def modifyLabelDef(ld: LabelDef): (Tree, Tree) = {
+ val symTab = c.universe.asInstanceOf[reflect.internal.SymbolTable]
+ val param = ld.params.head
+ val ld2 = if (ld.params.head.tpe.typeSymbol == definitions.UnitClass) {
+ // Unit typed match: eliminate the label def parameter, but don't create a matchres temp variable to
+ // store the result for cleaner generated code.
+ caseDefToMatchResult(ld.symbol) = NoSymbol
+ val rhs2 = substituteTrees(ld.rhs, param.symbol :: Nil, api.typecheck(literalUnit) :: Nil)
+ (treeCopy.LabelDef(ld, ld.name, Nil, api.typecheck(literalUnit)), rhs2)
+ } else {
+ // Otherwise, create the matchres var. We'll callers of the label def below.
+ // Remember: we're iterating through the statement sequence in reverse, so we'll get
+ // to the LabelDef and mutate `matchResults` before we'll get to its callers.
val matchResult = linearize.defineVar(name.matchRes, param.tpe, ld.pos)
matchResults += matchResult
caseDefToMatchResult(ld.symbol) = matchResult.symbol
- val ld2 = treeCopy.LabelDef(ld, ld.name, Nil, body.substituteSymbols(param.symbol :: Nil, matchResult.symbol :: Nil))
- setInfo(ld.symbol, methodType(Nil, ld.symbol.info.resultType))
- ld2
+ val rhs2 = ld.rhs.substituteSymbols(param.symbol :: Nil, matchResult.symbol :: Nil)
+ (treeCopy.LabelDef(ld, ld.name, Nil, api.typecheck(literalUnit)), rhs2)
+ }
+ setInfo(ld.symbol, methodType(Nil, definitions.UnitTpe))
+ ld2
+ }
+ val statsExpr0 = statsExpr.reverse.flatMap {
+ case ld @ LabelDef(_, param :: Nil, _) =>
+ val (ld1, after) = modifyLabelDef(ld)
+ List(after, ld1)
+ case a @ ValDef(mods, name, tpt, ld @ LabelDef(_, param :: Nil, _)) =>
+ val (ld1, after) = modifyLabelDef(ld)
+ List(treeCopy.ValDef(a, mods, name, tpt, after), ld1)
case t =>
- if (caseDefToMatchResult.isEmpty) t
- else typingTransform(t)((tree, api) =>
+ if (caseDefToMatchResult.isEmpty) t :: Nil
+ else typingTransform(t)((tree, api) => {
+ def typedPos(pos: Position)(t: Tree): Tree =
+ api.typecheck(atPos(pos)(t))
tree match {
case Apply(fun, arg :: Nil) if isLabel(fun.symbol) && caseDefToMatchResult.contains(fun.symbol) =>
- api.typecheck(atPos(tree.pos)(newBlock(Assign(Ident(caseDefToMatchResult(fun.symbol)), api.recur(arg)) :: Nil, treeCopy.Apply(tree, fun, Nil))))
- case Block(stats, expr) =>
+ val temp = caseDefToMatchResult(fun.symbol)
+ if (temp == NoSymbol)
+ typedPos(tree.pos)(newBlock(api.recur(arg) :: Nil, treeCopy.Apply(tree, fun, Nil)))
+ else
+ // setType needed for LateExpansion.shadowingRefinedType test case. There seems to be an inconsistency
+ // in the trees after pattern matcher.
+ // TODO miminize the problem in patmat and fix in scalac.
+ typedPos(tree.pos)(newBlock(Assign(Ident(temp), api.recur(internal.setType(arg, fun.tpe.paramLists.head.head.info))) :: Nil, treeCopy.Apply(tree, fun, Nil)))
+ case Block(stats, expr: Apply) if isLabel(expr.symbol) =>
api.default(tree) match {
- case Block(stats, Block(stats1, expr)) =>
- treeCopy.Block(tree, stats ::: stats1, expr)
+ case Block(stats0, Block(stats1, expr1)) =>
+ // flatten the block returned by `case Apply` above into the enclosing block for
+ // cleaner generated code.
+ treeCopy.Block(tree, stats0 ::: stats1, expr1)
case t => t
}
case _ =>
api.default(tree)
}
- )
+ }) :: Nil
}
matchResults.toList match {
- case Nil => statsExpr
- case r1 :: Nil => (r1 +: statsExpr0.reverse) :+ atPos(tree.pos)(gen.mkAttributedIdent(r1.symbol))
+ case _ if caseDefToMatchResult.isEmpty =>
+ statsExpr // return the original trees if nothing changed
+ case Nil =>
+ statsExpr0.reverse :+ literalUnit // must have been a unit-typed match, no matchRes variable to definne or refer to
+ case r1 :: Nil =>
+ // { var matchRes = _; ....; matchRes }
+ (r1 +: statsExpr0.reverse) :+ atPos(tree.pos)(gen.mkAttributedIdent(r1.symbol))
case _ => c.error(macroPos, "Internal error: unexpected tree encountered during ANF transform " + statsExpr); statsExpr
}
}
diff --git a/src/main/scala/scala/async/internal/AsyncBase.scala b/src/main/scala/scala/async/internal/AsyncBase.scala
index 4853d2b..ec9dc25 100644
--- a/src/main/scala/scala/async/internal/AsyncBase.scala
+++ b/src/main/scala/scala/async/internal/AsyncBase.scala
@@ -55,12 +55,14 @@ abstract class AsyncBase {
protected[async] def asyncMethod(u: Universe)(asyncMacroSymbol: u.Symbol): u.Symbol = {
import u._
- asyncMacroSymbol.owner.typeSignature.member(newTermName("async"))
+ if (asyncMacroSymbol == null) NoSymbol
+ else asyncMacroSymbol.owner.typeSignature.member(newTermName("async"))
}
protected[async] def awaitMethod(u: Universe)(asyncMacroSymbol: u.Symbol): u.Symbol = {
import u._
- asyncMacroSymbol.owner.typeSignature.member(newTermName("await"))
+ if (asyncMacroSymbol == null) NoSymbol
+ else asyncMacroSymbol.owner.typeSignature.member(newTermName("await"))
}
protected[async] def nullOut(u: Universe)(name: u.Expr[String], v: u.Expr[Any]): u.Expr[Unit] =
diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala
index 32f09d8..3ef9da5 100644
--- a/src/main/scala/scala/async/internal/ExprBuilder.scala
+++ b/src/main/scala/scala/async/internal/ExprBuilder.scala
@@ -3,7 +3,6 @@
*/
package scala.async.internal
-import scala.reflect.macros.Context
import scala.collection.mutable.ListBuffer
import collection.mutable
import language.existentials
@@ -34,18 +33,17 @@ trait ExprBuilder {
var stats: List[Tree]
- def statsAnd(trees: List[Tree]): List[Tree] = {
- val body = stats match {
+ def treesThenStats(trees: List[Tree]): List[Tree] = {
+ (stats match {
case init :+ last if tpeOf(last) =:= definitions.NothingTpe =>
- adaptToUnit(init :+ Typed(last, TypeTree(definitions.AnyTpe)))
+ adaptToUnit((trees ::: init) :+ Typed(last, TypeTree(definitions.AnyTpe)))
case _ =>
- adaptToUnit(stats)
- }
- Try(body, Nil, adaptToUnit(trees)) :: Nil
+ adaptToUnit(trees ::: stats)
+ }) :: Nil
}
final def allStats: List[Tree] = this match {
- case a: AsyncStateWithAwait => statsAnd(a.awaitable.resultValDef :: Nil)
+ case a: AsyncStateWithAwait => treesThenStats(a.awaitable.resultValDef :: Nil)
case _ => stats
}
@@ -63,7 +61,7 @@ trait ExprBuilder {
List(nextState)
def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef = {
- mkHandlerCase(state, statsAnd(mkStateTree(nextState, symLookup) :: Nil))
+ mkHandlerCase(state, treesThenStats(mkStateTree(nextState, symLookup) :: Nil))
}
override val toString: String =
@@ -99,10 +97,10 @@ trait ExprBuilder {
if (futureSystemOps.continueCompletedFutureOnSameThread)
If(futureSystemOps.isCompleted(c.Expr[futureSystem.Fut[_]](awaitable.expr)).tree,
adaptToUnit(ifIsFailureTree[T](futureSystemOps.getCompleted[Any](c.Expr[futureSystem.Fut[Any]](awaitable.expr)).tree) :: Nil),
- Block(toList(callOnComplete), Return(literalUnit)))
+ Block(toList(callOnComplete), Return(literalUnit))) :: Nil
else
- Block(toList(callOnComplete), Return(literalUnit))
- mkHandlerCase(state, stats ++ List(mkStateTree(onCompleteState, symLookup), tryGetOrCallOnComplete))
+ toList(callOnComplete) ::: Return(literalUnit) :: Nil
+ mkHandlerCase(state, stats ++ List(mkStateTree(onCompleteState, symLookup)) ++ tryGetOrCallOnComplete)
}
private def tryGetTree(tryReference: => Tree) =
@@ -251,12 +249,17 @@ trait ExprBuilder {
case LabelDef(name, _, _) => name.toString.startsWith("case")
case _ => false
}
- val (before, _ :: after) = (stats :+ expr).span(_ ne t)
- before.reverse.takeWhile(isPatternCaseLabelDef) ::: after.takeWhile(isPatternCaseLabelDef)
+ val span = (stats :+ expr).filterNot(isLiteralUnit).span(_ ne t)
+ span match {
+ case (before, _ :: after) =>
+ before.reverse.takeWhile(isPatternCaseLabelDef) ::: after.takeWhile(isPatternCaseLabelDef)
+ case _ =>
+ stats :+ expr
+ }
}
// populate asyncStates
- for (stat <- (stats :+ expr)) stat match {
+ def add(stat: Tree): Unit = stat match {
// the val name = await(..) pattern
case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) =>
val onCompleteState = nextState()
@@ -315,10 +318,13 @@ trait ExprBuilder {
asyncStates ++= builder.asyncStates
currState = afterLabelState
stateBuilder = new AsyncStateBuilder(currState, symLookup)
+ case b @ Block(stats, expr) =>
+ (stats :+ expr) foreach (add)
case _ =>
checkForUnsupportedAwait(stat)
stateBuilder += stat
}
+ for (stat <- (stats :+ expr)) add(stat)
val lastState = stateBuilder.resultSimple(endState)
asyncStates += lastState
}
@@ -392,7 +398,10 @@ trait ExprBuilder {
* }
*/
private def resumeFunTree[T: WeakTypeTag]: Tree = {
- val body = Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T]))
+ val stateMemberSymbol = symLookup.stateMachineMember(name.state)
+ val stateMemberRef = symLookup.memberRef(name.state)
+ val body = Match(stateMemberRef, mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T]) ++ List(CaseDef(Ident(nme.WILDCARD), EmptyTree, Throw(Apply(Select(New(Ident(defn.IllegalStateExceptionClass)), termNames.CONSTRUCTOR), List())))))
+
Try(
body,
List(
@@ -462,13 +471,24 @@ trait ExprBuilder {
private def tpeOf(t: Tree): Type = t match {
case _ if t.tpe != null => t.tpe
case Try(body, Nil, _) => tpeOf(body)
+ case Block(_, expr) => tpeOf(expr)
+ case Literal(Constant(value)) if value == () => definitions.UnitTpe
+ case Return(_) => definitions.NothingTpe
case _ => NoType
}
private def adaptToUnit(rhs: List[Tree]): c.universe.Block = {
rhs match {
+ case (rhs: Block) :: Nil if tpeOf(rhs) <:< definitions.UnitTpe =>
+ rhs
case init :+ last if tpeOf(last) <:< definitions.UnitTpe =>
Block(init, last)
+ case init :+ (last @ Literal(Constant(()))) =>
+ Block(init, last)
+ case init :+ (last @ Block(_, Return(_) | Literal(Constant(())))) =>
+ Block(init, last)
+ case init :+ (Block(stats, expr)) =>
+ Block(init, Block(stats :+ expr, literalUnit))
case _ =>
Block(rhs, literalUnit)
}
diff --git a/src/main/scala/scala/async/internal/Lifter.scala b/src/main/scala/scala/async/internal/Lifter.scala
index 2998baf..9481f69 100644
--- a/src/main/scala/scala/async/internal/Lifter.scala
+++ b/src/main/scala/scala/async/internal/Lifter.scala
@@ -76,8 +76,9 @@ trait Lifter {
// are already accounted for.
val stateIdToDirectlyReferenced: Map[Int, List[Symbol]] = {
val refs: List[(Int, Symbol)] = asyncStates.flatMap(
- asyncState => asyncState.stats.filterNot(_.isDef).flatMap(_.collect {
- case rt: RefTree if symToDefiningState.contains(rt.symbol) => (asyncState.state, rt.symbol)
+ asyncState => asyncState.stats.filterNot(t => t.isDef && !isLabel(t.symbol)).flatMap(_.collect {
+ case rt: RefTree
+ if symToDefiningState.contains(rt.symbol) => (asyncState.state, rt.symbol)
})
)
toMultiMap(refs)
diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala
index e15ef1b..2999be2 100644
--- a/src/main/scala/scala/async/internal/TransformUtils.scala
+++ b/src/main/scala/scala/async/internal/TransformUtils.scala
@@ -167,8 +167,8 @@ private[async] trait TransformUtils {
val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal")
val ThrowableClass = rootMirror.staticClass("java.lang.Throwable")
- val Async_async = asyncBase.asyncMethod(c.universe)(c.macroApplication.symbol).ensuring(_ != NoSymbol)
- val Async_await = asyncBase.awaitMethod(c.universe)(c.macroApplication.symbol).ensuring(_ != NoSymbol)
+ lazy val Async_async = asyncBase.asyncMethod(c.universe)(c.macroApplication.symbol)
+ lazy val Async_await = asyncBase.awaitMethod(c.universe)(c.macroApplication.symbol)
val IllegalStateExceptionClass = rootMirror.staticClass("java.lang.IllegalStateException")
}
@@ -190,6 +190,10 @@ private[async] trait TransformUtils {
val LABEL = 1L << 17 // not in the public reflection API.
(internal.flags(sym).asInstanceOf[Long] & LABEL) != 0L
}
+ def isSynth(sym: Symbol): Boolean = {
+ val SYNTHETIC = 1 << 21 // not in the public reflection API.
+ (internal.flags(sym).asInstanceOf[Long] & SYNTHETIC) != 0L
+ }
def symId(sym: Symbol): Int = {
val symtab = this.c.universe.asInstanceOf[reflect.internal.SymbolTable]
sym.asInstanceOf[symtab.Symbol].id
@@ -388,7 +392,7 @@ private[async] trait TransformUtils {
catch { case _: ScalaReflectionException => NoSymbol }
}
final def uncheckedBounds(tp: Type): Type = {
- if (tp.typeArgs.isEmpty || UncheckedBoundsClass == NoSymbol) tp
+ if ((tp.typeArgs.isEmpty && (tp match { case _: TypeRef => true; case _ => false}))|| UncheckedBoundsClass == NoSymbol) tp
else withAnnotation(tp, Annotation(UncheckedBoundsClass.asType.toType, Nil, ListMap()))
}
// =====================================
@@ -402,6 +406,8 @@ private[async] trait TransformUtils {
* in search of a sub tree that was decorated with the cached answer.
*/
final def containsAwaitCached(t: Tree): Tree => Boolean = {
+ if (c.macroApplication.symbol == null) return (t => false)
+
def treeCannotContainAwait(t: Tree) = t match {
case _: Ident | _: TypeTree | _: Literal => true
case _ => isAsync(t)
diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala
index 09fa69e..1637102 100644
--- a/src/test/scala/scala/async/TreeInterrogation.scala
+++ b/src/test/scala/scala/async/TreeInterrogation.scala
@@ -54,7 +54,7 @@ class TreeInterrogation {
}
object TreeInterrogation extends App {
- def withDebug[T](t: => T) {
+ def withDebug[T](t: => T): T = {
def set(level: String, value: Boolean) = System.setProperty(s"scala.async.$level", value.toString)
val levels = Seq("trace", "debug")
def setAll(value: Boolean) = levels.foreach(set(_, value))
diff --git a/src/test/scala/scala/async/run/late/LateExpansion.scala b/src/test/scala/scala/async/run/late/LateExpansion.scala
index b866527..a40b1af 100644
--- a/src/test/scala/scala/async/run/late/LateExpansion.scala
+++ b/src/test/scala/scala/async/run/late/LateExpansion.scala
@@ -3,10 +3,12 @@ package scala.async.run.late
import java.io.File
import junit.framework.Assert.assertEquals
-import org.junit.Test
+import org.junit.{Assert, Test}
import scala.annotation.StaticAnnotation
-import scala.async.internal.{AsyncId, AsyncMacro}
+import scala.annotation.meta.{field, getter}
+import scala.async.TreeInterrogation
+import scala.async.internal.AsyncId
import scala.reflect.internal.util.ScalaClassLoader.URLClassLoader
import scala.tools.nsc._
import scala.tools.nsc.plugins.{Plugin, PluginComponent}
@@ -16,6 +18,7 @@ import scala.tools.nsc.transform.TypingTransformers
// Tests for customized use of the async transform from a compiler plugin, which
// calls it from a new phase that runs after patmat.
class LateExpansion {
+
@Test def test0(): Unit = {
val result = wrapAndRun(
"""
@@ -75,6 +78,263 @@ class LateExpansion {
assertEquals("case 3: blerg3", result)
}
+ @Test def polymorphicMethod(): Unit = {
+ val result = run(
+ """
+ |import scala.async.run.late.{autoawait,lateasync}
+ |object Test {
+ | class C { override def toString = "C" }
+ | @autoawait def foo[A <: C](a: A): A = a
+ | @lateasync
+ | def test1[CC <: C](c: CC): (CC, CC) = {
+ | val x: (CC, CC) = 0 match { case _ if false => ???; case _ => (foo(c), foo(c)) }
+ | x
+ | }
+ | def test(): (C, C) = test1(new C)
+ |}
+ | """.stripMargin)
+ assertEquals("(C,C)", result.toString)
+ }
+
+ @Test def shadowing(): Unit = {
+ val result = run(
+ """
+ |import scala.async.run.late.{autoawait,lateasync}
+ |object Test {
+ | trait Foo
+ | trait Bar extends Foo
+ | @autoawait def boundary = ""
+ | @lateasync
+ | def test: Unit = {
+ | (new Bar {}: Any) match {
+ | case foo: Bar =>
+ | boundary
+ | 0 match {
+ | case _ => foo; ()
+ | }
+ | ()
+ | }
+ | ()
+ | }
+ |}
+ | """.stripMargin)
+ }
+
+ @Test def shadowing0(): Unit = {
+ val result = run(
+ """
+ |import scala.async.run.late.{autoawait,lateasync}
+ |object Test {
+ | trait Foo
+ | trait Bar
+ | def test: Any = test(new C)
+ | @autoawait def asyncBoundary: String = ""
+ | @lateasync
+ | def test(foo: Foo): Foo = foo match {
+ | case foo: Bar =>
+ | val foo2: Foo with Bar = new Foo with Bar {}
+ | asyncBoundary
+ | null match {
+ | case _ => foo2
+ | }
+ | case other => foo
+ | }
+ | class C extends Foo with Bar
+ |}
+ | """.stripMargin)
+ }
+ @Test def shadowing2(): Unit = {
+ val result = run(
+ """
+ |import scala.async.run.late.{autoawait,lateasync}
+ |object Test {
+ | trait Base; trait Foo[T <: Base] { @autoawait def func: Option[Foo[T]] = None }
+ | class Sub extends Base
+ | trait Bar extends Foo[Sub]
+ | def test: Any = test(new Bar {})
+ | @lateasync
+ | def test[T <: Base](foo: Foo[T]): Foo[T] = foo match {
+ | case foo: Bar =>
+ | val res = foo.func
+ | res match {
+ | case _ =>
+ | }
+ | foo
+ | case other => foo
+ | }
+ | test(new Bar {})
+ |}
+ | """.stripMargin)
+ }
+
+ @Test def patternAlternative(): Unit = {
+ val result = wrapAndRun(
+ """
+ | @autoawait def one = 1
+ |
+ | @lateasync def test = {
+ | Option(true) match {
+ | case null | None => false
+ | case Some(v) => one; v
+ | }
+ | }
+ | """.stripMargin)
+ }
+
+ @Test def patternAlternativeBothAnnotations(): Unit = {
+ val result = wrapAndRun(
+ """
+ |import scala.async.run.late.{autoawait,lateasync}
+ |object Test {
+ | @autoawait def func1() = "hello"
+ | @lateasync def func(a: Option[Boolean]) = a match {
+ | case null | None => func1 + " world"
+ | case _ => "okay"
+ | }
+ | def test: Any = func(None)
+ |}
+ | """.stripMargin)
+ }
+
+ @Test def shadowingRefinedTypes(): Unit = {
+ val result = run(
+ s"""
+ |import scala.async.run.late.{autoawait,lateasync}
+ |trait Base
+ |class Sub extends Base
+ |trait Foo[T <: Base] {
+ | @autoawait def func: Option[Foo[T]] = None
+ |}
+ |trait Bar extends Foo[Sub]
+ |object Test {
+ | @lateasync def func[T <: Base](foo: Foo[T]): Foo[T] = foo match { // the whole pattern match will be wrapped with async{ }
+ | case foo: Bar =>
+ | val res = foo.func // will be rewritten into: await(foo.func)
+ | res match {
+ | case Some(v) => v // this will report type mismtach
+ | case other => foo
+ | }
+ | case other => foo
+ | }
+ | def test: Any = { val b = new Bar{}; func(b) == b }
+ |}""".stripMargin)
+ assertEquals(true, result)
+ }
+
+ @Test def testMatchEndIssue(): Unit = {
+ val result = run(
+ """
+ |import scala.async.run.late.{autoawait,lateasync}
+ |sealed trait Subject
+ |final class Principal(val name: String) extends Subject
+ |object Principal {
+ | def unapply(p: Principal): Option[String] = Some(p.name)
+ |}
+ |object Test {
+ | @autoawait @lateasync
+ | def containsPrincipal(search: String, value: Subject): Boolean = value match {
+ | case Principal(name) if name == search => true
+ | case Principal(name) => containsPrincipal(search, value)
+ | case other => false
+ | }
+ |
+ | @lateasync
+ | def test = containsPrincipal("test", new Principal("test"))
+ |}
+ | """.stripMargin)
+ }
+
+ @Test def testGenericTypeBoundaryIssue(): Unit = {
+ val result = run(
+ """
+ import scala.async.run.late.{autoawait,lateasync}
+ trait InstrumentOfValue
+ trait Security[T <: InstrumentOfValue] extends InstrumentOfValue
+ class Bound extends Security[Bound]
+ class Futures extends Security[Futures]
+ object TestGenericTypeBoundIssue {
+ @autoawait @lateasync def processBound(bound: Bound): Unit = { println("process Bound") }
+ @autoawait @lateasync def processFutures(futures: Futures): Unit = { println("process Futures") }
+ @autoawait @lateasync def doStuff(sec: Security[_]): Unit = {
+ sec match {
+ case bound: Bound => processBound(bound)
+ case futures: Futures => processFutures(futures)
+ case _ => throw new Exception("Unknown Security type: " + sec)
+ }
+ }
+ }
+ """.stripMargin)
+ }
+
+ @Test def testReturnTupleIssue(): Unit = {
+ val result = run(
+ """
+ import scala.async.run.late.{autoawait,lateasync}
+ class TestReturnExprIssue(str: String) {
+ @autoawait @lateasync def getTestValue = Some(42)
+ @autoawait @lateasync def doStuff: Int = {
+ val opt: Option[Int] = getTestValue // here we have an async method invoke
+ opt match {
+ case Some(li) => li // use the result somehow
+ case None =>
+ }
+ 42 // type mismatch; found : AnyVal required: Int
+ }
+ }
+ """.stripMargin)
+ }
+
+
+ @Test def testAfterRefchecksIssue(): Unit = {
+ val result = run(
+ """
+ import scala.async.run.late.{autoawait,lateasync}
+ trait Factory[T] { def create: T }
+ sealed trait TimePoint
+ class TimeLine[TP <: TimePoint](val tpInitial: Factory[TP]) {
+ @autoawait @lateasync private[TimeLine] val tp: TP = tpInitial.create
+ @autoawait @lateasync def timePoint: TP = tp
+ }
+ object Test {
+ def test: Unit = ()
+ }
+ """)
+ }
+
+ @Test def testArrayIndexOutOfBoundIssue(): Unit = {
+ val result = run(
+ """
+ import scala.async.run.late.{autoawait,lateasync}
+
+ sealed trait Result
+ case object A extends Result
+ case object B extends Result
+ case object C extends Result
+
+ object Test {
+ protected def doStuff(res: Result) = {
+ class C {
+ @autoawait def needCheck = false
+
+ @lateasync def m = {
+ if (needCheck) "NO"
+ else {
+ res match {
+ case A => 1
+ case _ => 2
+ }
+ }
+ }
+ }
+ }
+
+
+ @lateasync
+ def test() = doStuff(B)
+ }
+ """)
+ }
+
def wrapAndRun(code: String): Any = {
run(
s"""
@@ -88,10 +348,49 @@ class LateExpansion {
| """.stripMargin)
}
+
+ @Test def testNegativeArraySizeException(): Unit = {
+ val result = run(
+ """
+ import scala.async.run.late.{autoawait,lateasync}
+
+ object Test {
+ def foo(foo: Any, bar: Any) = ()
+ @autoawait def getValue = 4.2
+ @lateasync def func(f: Any) = {
+ foo(f match { case _ if "".isEmpty => 2 }, getValue);
+ }
+
+ @lateasync
+ def test() = func(4)
+ }
+ """)
+ }
+ @Test def testNegativeArraySizeExceptionFine1(): Unit = {
+ val result = run(
+ """
+ import scala.async.run.late.{autoawait,lateasync}
+ case class FixedFoo(foo: Int)
+ class Foobar(val foo: Int, val bar: Double) {
+ @autoawait @lateasync def getValue = 4.2
+ @autoawait @lateasync def func(f: Any) = {
+ new Foobar(foo = f match {
+ case FixedFoo(x) => x
+ case _ => 2
+ },
+ bar = getValue)
+ }
+ }
+ object Test {
+ @lateasync def test() = new Foobar(0, 0).func(4)
+ }
+ """)
+ }
def run(code: String): Any = {
val reporter = new StoreReporter
val settings = new Settings(println(_))
- settings.outdir.value = sys.props("java.io.tmpdir")
+ // settings.processArgumentString("-Xprint:patmat,postpatmat,jvm -Ybackend:GenASM -nowarn")
+ settings.outdir.value = "/tmp"
settings.embeddedDefaults(getClass.getClassLoader)
val isInSBT = !settings.classpath.isSetByUser
if (isInSBT) settings.usejavacp.value = true
@@ -108,8 +407,10 @@ class LateExpansion {
val run = new Run
val source = newSourceFile(code)
- run.compileSources(source :: Nil)
- assert(!reporter.hasErrors, reporter.infos.mkString("\n"))
+// TreeInterrogation.withDebug {
+ run.compileSources(source :: Nil)
+// }
+ Assert.assertTrue(reporter.infos.mkString("\n"), !reporter.hasErrors)
val loader = new URLClassLoader(Seq(new File(settings.outdir.value).toURI.toURL), global.getClass.getClassLoader)
val cls = loader.loadClass("Test")
cls.getMethod("test").invoke(null)
@@ -133,20 +434,26 @@ abstract class LatePlugin extends Plugin {
super.transform(tree) match {
case ap@Apply(fun, args) if fun.symbol.hasAnnotation(autoAwaitSym) =>
localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, awaitSym), TypeTree(ap.tpe) :: Nil), ap :: Nil))
+ case sel@Select(fun, _) if sel.symbol.hasAnnotation(autoAwaitSym) && !(tree.tpe.isInstanceOf[MethodTypeApi] || tree.tpe.isInstanceOf[PolyTypeApi] ) =>
+ localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, awaitSym), TypeTree(sel.tpe) :: Nil), sel :: Nil))
case dd: DefDef if dd.symbol.hasAnnotation(lateAsyncSym) => atOwner(dd.symbol) {
- val expandee = localTyper.context.withMacrosDisabled(
- localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(dd.rhs.tpe) :: Nil), List(dd.rhs)))
- )
- val c = analyzer.macroContext(localTyper, gen.mkAttributedRef(asyncIdSym), expandee)
- val asyncMacro = AsyncMacro(c, AsyncId)(dd.rhs)
- val code = asyncMacro.asyncTransform[Any](localTyper.typed(Literal(Constant(()))))(c.weakTypeTag[Any])
- deriveDefDef(dd)(_ => localTyper.typed(code))
+ deriveDefDef(dd){ rhs: Tree =>
+ val invoke = Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(rhs.tpe) :: Nil), List(rhs))
+ localTyper.typed(atPos(dd.pos)(invoke))
+ }
}
+ case vd: ValDef if vd.symbol.hasAnnotation(lateAsyncSym) => atOwner(vd.symbol) {
+ deriveValDef(vd){ rhs: Tree =>
+ val invoke = Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(rhs.tpe) :: Nil), List(rhs))
+ localTyper.typed(atPos(vd.pos)(invoke))
+ }
+ }
+ case vd: ValDef =>
+ vd
case x => x
}
}
}
-
override def newPhase(prev: Phase): Phase = new StdPhase(prev) {
override def apply(unit: CompilationUnit): Unit = {
val translated = newTransformer(unit).transformUnit(unit)
@@ -155,7 +462,7 @@ abstract class LatePlugin extends Plugin {
}
}
- override val runsAfter: List[String] = "patmat" :: Nil
+ override val runsAfter: List[String] = "refchecks" :: Nil
override val phaseName: String = "postpatmat"
})
@@ -164,7 +471,9 @@ abstract class LatePlugin extends Plugin {
}
// Methods with this annotation are translated to having the RHS wrapped in `AsyncId.async { <original RHS> }`
+@field
final class lateasync extends StaticAnnotation
// Calls to methods with this annotation are translated to `AsyncId.await(<call>)`
+@getter
final class autoawait extends StaticAnnotation