diff options
Diffstat (limited to 'src')
13 files changed, 592 insertions, 117 deletions
diff --git a/src/main/scala/scala/async/internal/AnfTransform.scala b/src/main/scala/scala/async/internal/AnfTransform.scala index 4545ca6..3d14fb7 100644 --- a/src/main/scala/scala/async/internal/AnfTransform.scala +++ b/src/main/scala/scala/async/internal/AnfTransform.scala @@ -27,6 +27,27 @@ private[async] trait AnfTransform { val tree1 = adjustTypeOfTranslatedPatternMatches(block, owner) var mode: AnfMode = Anf + + object trace { + private var indent = -1 + + private def indentString = " " * indent + + def apply[T](args: Any)(t: => T): T = { + def prefix = mode.toString.toLowerCase + indent += 1 + def oneLine(s: Any) = s.toString.replaceAll("""\n""", "\\\\n").take(127) + try { + AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})") + val result = t + AsyncUtils.trace(s"${indentString}= ${oneLine(result)}") + result + } finally { + indent -= 1 + } + } + } + typingTransform(tree1, owner)((tree, api) => { def blockToList(tree: Tree): List[Tree] = tree match { case Block(stats, expr) => stats :+ expr @@ -47,7 +68,7 @@ private[async] trait AnfTransform { def transformToBlock(tree: Tree): Block = listToBlock(transformToList(tree)) def _transformToList(tree: Tree): List[Tree] = trace(tree) { - val stats :+ expr = anf.transformToList(tree) + val stats :+ expr = _anf.transformToList(tree) def statsExprUnit = stats :+ expr :+ api.typecheck(atPos(expr.pos)(Literal(Constant(())))) def statsExprThrow = @@ -97,8 +118,11 @@ private[async] trait AnfTransform { val ifWithAssign = treeCopy.If(tree, cond, branchWithAssign(thenp), branchWithAssign(elsep)).setType(definitions.UnitTpe) stats :+ varDef :+ ifWithAssign :+ atPos(tree.pos)(gen.mkAttributedStableRef(varDef.symbol)).setType(tree.tpe) } - case LabelDef(name, params, rhs) => - statsExprUnit + case ld @ LabelDef(name, params, rhs) => + if (ld.symbol.info.resultType.typeSymbol == definitions.UnitClass) + statsExprUnit + else + stats :+ expr case Match(scrut, cases) => // if type of match is Unit don't introduce assignment, @@ -134,32 +158,12 @@ private[async] trait AnfTransform { } } - object trace { - private var indent = -1 - - private def indentString = " " * indent - - def apply[T](args: Any)(t: => T): T = { - def prefix = mode.toString.toLowerCase - indent += 1 - def oneLine(s: Any) = s.toString.replaceAll("""\n""", "\\\\n").take(127) - try { - AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})") - val result = t - AsyncUtils.trace(s"${indentString}= ${oneLine(result)}") - result - } finally { - indent -= 1 - } - } - } - def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = { val sym = api.currentOwner.newTermSymbol(name.fresh(prefix), pos, SYNTHETIC).setInfo(uncheckedBounds(lhs.tpe)) internal.valDef(sym, internal.changeOwner(lhs, api.currentOwner, sym)).setType(NoType).setPos(pos) } - object anf { + object _anf { def transformToList(tree: Tree): List[Tree] = { mode = Anf; blockToList(api.recur(tree)) } @@ -219,8 +223,29 @@ private[async] trait AnfTransform { funStats ++ argStatss.flatten.flatten :+ typedNewApply case Block(stats, expr) => - val trees = stats.flatMap(linearize.transformToList).filterNot(isLiteralUnit) ::: linearize.transformToList(expr) - eliminateMatchEndLabelParameter(trees) + val stats1 = stats.flatMap(linearize.transformToList).filterNot(isLiteralUnit) + val exprs1 = linearize.transformToList(expr) + val trees = stats1 ::: exprs1 + def isMatchEndLabel(t: Tree): Boolean = t match { + case ValDef(_, _, _, t) if isMatchEndLabel(t) => true + case ld: LabelDef if ld.name.toString.startsWith("matchEnd") => true + case _ => false + } + def groupsEndingWith[T](ts: List[T])(f: T => Boolean): List[List[T]] = if (ts.isEmpty) Nil else { + ts.indexWhere(f) match { + case -1 => List(ts) + case i => + val (ts1, ts2) = ts.splitAt(i + 1) + ts1 :: groupsEndingWith(ts2)(f) + } + } + val matchGroups = groupsEndingWith(trees)(isMatchEndLabel) + val trees1 = matchGroups.flatMap(eliminateMatchEndLabelParameter) + val result = trees1 flatMap { + case Block(stats, expr) => stats :+ expr + case t => t :: Nil + } + result case ValDef(mods, name, tpt, rhs) => if (containsAwait(rhs)) { @@ -260,7 +285,10 @@ private[async] trait AnfTransform { scrutStats :+ treeCopy.Match(tree, scrutExpr, caseDefs) case LabelDef(name, params, rhs) => - List(LabelDef(name, params, newBlock(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol)) + if (tree.symbol.info.typeSymbol == definitions.UnitClass) + List(treeCopy.LabelDef(tree, name, params, api.typecheck(newBlock(linearize.transformToList(rhs), Literal(Constant(()))))).setSymbol(tree.symbol)) + else + List(treeCopy.LabelDef(tree, name, params, api.typecheck(listToBlock(linearize.transformToList(rhs)))).setSymbol(tree.symbol)) case TypeApply(fun, targs) => val funStats :+ simpleFun = linearize.transformToList(fun) @@ -274,7 +302,7 @@ private[async] trait AnfTransform { // Replace the label parameters on `matchEnd` with use of a `matchRes` temporary variable // - // CaseDefs are translated to labels without parmeters. A terminal label, `matchEnd`, accepts + // CaseDefs are translated to labels without parameters. A terminal label, `matchEnd`, accepts // a parameter which is the result of the match (this is regular, so even Unit-typed matches have this). // // For our purposes, it is easier to: @@ -286,41 +314,78 @@ private[async] trait AnfTransform { val caseDefToMatchResult = collection.mutable.Map[Symbol, Symbol]() val matchResults = collection.mutable.Buffer[Tree]() - val statsExpr0 = statsExpr.reverseMap { - case ld @ LabelDef(_, param :: Nil, body) => + def modifyLabelDef(ld: LabelDef): (Tree, Tree) = { + val symTab = c.universe.asInstanceOf[reflect.internal.SymbolTable] + val param = ld.params.head + val ld2 = if (ld.params.head.tpe.typeSymbol == definitions.UnitClass) { + // Unit typed match: eliminate the label def parameter, but don't create a matchres temp variable to + // store the result for cleaner generated code. + caseDefToMatchResult(ld.symbol) = NoSymbol + val rhs2 = substituteTrees(ld.rhs, param.symbol :: Nil, api.typecheck(literalUnit) :: Nil) + (treeCopy.LabelDef(ld, ld.name, Nil, api.typecheck(literalUnit)), rhs2) + } else { + // Otherwise, create the matchres var. We'll callers of the label def below. + // Remember: we're iterating through the statement sequence in reverse, so we'll get + // to the LabelDef and mutate `matchResults` before we'll get to its callers. val matchResult = linearize.defineVar(name.matchRes, param.tpe, ld.pos) matchResults += matchResult caseDefToMatchResult(ld.symbol) = matchResult.symbol - val ld2 = treeCopy.LabelDef(ld, ld.name, Nil, body.substituteSymbols(param.symbol :: Nil, matchResult.symbol :: Nil)) - setInfo(ld.symbol, methodType(Nil, ld.symbol.info.resultType)) - ld2 + val rhs2 = ld.rhs.substituteSymbols(param.symbol :: Nil, matchResult.symbol :: Nil) + (treeCopy.LabelDef(ld, ld.name, Nil, api.typecheck(literalUnit)), rhs2) + } + setInfo(ld.symbol, methodType(Nil, definitions.UnitTpe)) + ld2 + } + val statsExpr0 = statsExpr.reverse.flatMap { + case ld @ LabelDef(_, param :: Nil, _) => + val (ld1, after) = modifyLabelDef(ld) + List(after, ld1) + case a @ ValDef(mods, name, tpt, ld @ LabelDef(_, param :: Nil, _)) => + val (ld1, after) = modifyLabelDef(ld) + List(treeCopy.ValDef(a, mods, name, tpt, after), ld1) case t => - if (caseDefToMatchResult.isEmpty) t - else typingTransform(t)((tree, api) => + if (caseDefToMatchResult.isEmpty) t :: Nil + else typingTransform(t)((tree, api) => { + def typedPos(pos: Position)(t: Tree): Tree = + api.typecheck(atPos(pos)(t)) tree match { case Apply(fun, arg :: Nil) if isLabel(fun.symbol) && caseDefToMatchResult.contains(fun.symbol) => - api.typecheck(atPos(tree.pos)(newBlock(Assign(Ident(caseDefToMatchResult(fun.symbol)), api.recur(arg)) :: Nil, treeCopy.Apply(tree, fun, Nil)))) - case Block(stats, expr) => + val temp = caseDefToMatchResult(fun.symbol) + if (temp == NoSymbol) + typedPos(tree.pos)(newBlock(api.recur(arg) :: Nil, treeCopy.Apply(tree, fun, Nil))) + else + // setType needed for LateExpansion.shadowingRefinedType test case. There seems to be an inconsistency + // in the trees after pattern matcher. + // TODO miminize the problem in patmat and fix in scalac. + typedPos(tree.pos)(newBlock(Assign(Ident(temp), api.recur(internal.setType(arg, fun.tpe.paramLists.head.head.info))) :: Nil, treeCopy.Apply(tree, fun, Nil))) + case Block(stats, expr: Apply) if isLabel(expr.symbol) => api.default(tree) match { - case Block(stats, Block(stats1, expr)) => - treeCopy.Block(tree, stats ::: stats1, expr) + case Block(stats0, Block(stats1, expr1)) => + // flatten the block returned by `case Apply` above into the enclosing block for + // cleaner generated code. + treeCopy.Block(tree, stats0 ::: stats1, expr1) case t => t } case _ => api.default(tree) } - ) + }) :: Nil } matchResults.toList match { - case Nil => statsExpr - case r1 :: Nil => (r1 +: statsExpr0.reverse) :+ atPos(tree.pos)(gen.mkAttributedIdent(r1.symbol)) + case _ if caseDefToMatchResult.isEmpty => + statsExpr // return the original trees if nothing changed + case Nil => + statsExpr0.reverse :+ literalUnit // must have been a unit-typed match, no matchRes variable to definne or refer to + case r1 :: Nil => + // { var matchRes = _; ....; matchRes } + (r1 +: statsExpr0.reverse) :+ atPos(tree.pos)(gen.mkAttributedIdent(r1.symbol)) case _ => c.error(macroPos, "Internal error: unexpected tree encountered during ANF transform " + statsExpr); statsExpr } } def anfLinearize(tree: Tree): Block = { val trees: List[Tree] = mode match { - case Anf => anf._transformToList(tree) + case Anf => _anf._transformToList(tree) case Linearizing => linearize._transformToList(tree) } listToBlock(trees) diff --git a/src/main/scala/scala/async/internal/AsyncAnalysis.scala b/src/main/scala/scala/async/internal/AsyncAnalysis.scala index 6540bdb..bd64f17 100644 --- a/src/main/scala/scala/async/internal/AsyncAnalysis.scala +++ b/src/main/scala/scala/async/internal/AsyncAnalysis.scala @@ -4,6 +4,7 @@ package scala.async.internal +import scala.collection.mutable.ListBuffer import scala.reflect.macros.Context import scala.collection.mutable @@ -53,14 +54,15 @@ trait AsyncAnalysis { } override def traverse(tree: Tree) { - def containsAwait = tree exists isAwait tree match { - case Try(_, _, _) if containsAwait => + case Try(_, _, _) if containsAwait(tree) => reportUnsupportedAwait(tree, "try/catch") super.traverse(tree) case Return(_) => c.abort(tree.pos, "return is illegal within a async block") - case DefDef(mods, _, _, _, _, _) if mods.hasFlag(Flag.LAZY) && containsAwait => + case DefDef(mods, _, _, _, _, _) if mods.hasFlag(Flag.LAZY) && containsAwait(tree) => + reportUnsupportedAwait(tree, "lazy val initializer") + case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) && containsAwait(tree) => reportUnsupportedAwait(tree, "lazy val initializer") case CaseDef(_, guard, _) if guard exists isAwait => // TODO lift this restriction @@ -74,9 +76,19 @@ trait AsyncAnalysis { * @return true, if the tree contained an unsupported await. */ private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String): Boolean = { - val badAwaits: List[RefTree] = tree collect { - case rt: RefTree if isAwait(rt) => rt + val badAwaits = ListBuffer[Tree]() + object traverser extends Traverser { + override def traverse(tree: Tree): Unit = { + if (!isAsync(tree)) + super.traverse(tree) + tree match { + case rt: RefTree if isAwait(rt) => + badAwaits += rt + case _ => + } + } } + traverser(tree) badAwaits foreach { tree => reportError(tree.pos, s"await must not be used under a $whyUnsupported.") diff --git a/src/main/scala/scala/async/internal/AsyncBase.scala b/src/main/scala/scala/async/internal/AsyncBase.scala index 7a1e274..ec9dc25 100644 --- a/src/main/scala/scala/async/internal/AsyncBase.scala +++ b/src/main/scala/scala/async/internal/AsyncBase.scala @@ -53,9 +53,16 @@ abstract class AsyncBase { c.Expr[futureSystem.Fut[T]](code) } + protected[async] def asyncMethod(u: Universe)(asyncMacroSymbol: u.Symbol): u.Symbol = { + import u._ + if (asyncMacroSymbol == null) NoSymbol + else asyncMacroSymbol.owner.typeSignature.member(newTermName("async")) + } + protected[async] def awaitMethod(u: Universe)(asyncMacroSymbol: u.Symbol): u.Symbol = { import u._ - asyncMacroSymbol.owner.typeSignature.member(newTermName("await")) + if (asyncMacroSymbol == null) NoSymbol + else asyncMacroSymbol.owner.typeSignature.member(newTermName("await")) } protected[async] def nullOut(u: Universe)(name: u.Expr[String], v: u.Expr[Any]): u.Expr[Unit] = diff --git a/src/main/scala/scala/async/internal/AsyncMacro.scala b/src/main/scala/scala/async/internal/AsyncMacro.scala index e22407d..57c73fc 100644 --- a/src/main/scala/scala/async/internal/AsyncMacro.scala +++ b/src/main/scala/scala/async/internal/AsyncMacro.scala @@ -11,7 +11,7 @@ object AsyncMacro { // These members are required by `ExprBuilder`: val futureSystem: FutureSystem = base.futureSystem val futureSystemOps: futureSystem.Ops {val c: self.c.type} = futureSystem.mkOps(c) - val containsAwait: c.Tree => Boolean = containsAwaitCached(body0) + var containsAwait: c.Tree => Boolean = containsAwaitCached(body0) } } } @@ -22,7 +22,7 @@ private[async] trait AsyncMacro val c: scala.reflect.macros.Context val body: c.Tree - val containsAwait: c.Tree => Boolean + var containsAwait: c.Tree => Boolean lazy val macroPos = c.macroApplication.pos.makeTransparent def atMacroPos(t: c.Tree) = c.universe.atPos(macroPos)(t) diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala index af290e4..feb1ea7 100644 --- a/src/main/scala/scala/async/internal/AsyncTransform.scala +++ b/src/main/scala/scala/async/internal/AsyncTransform.scala @@ -26,6 +26,9 @@ trait AsyncTransform { val anfTree = futureSystemOps.postAnfTransform(anfTree0) + cleanupContainsAwaitAttachments(anfTree) + containsAwait = containsAwaitCached(anfTree) + val applyDefDefDummyBody: DefDef = { val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(futureSystemOps.tryType[Any]), EmptyTree))) DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), literalUnit) @@ -46,8 +49,16 @@ trait AsyncTransform { List(emptyConstructor, stateVar) ++ resultAndAccessors ++ List(execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef) } - val tryToUnit = appliedType(definitions.FunctionClass(1), futureSystemOps.tryType[Any], typeOf[Unit]) - val template = Template(List(tryToUnit, typeOf[() => Unit]).map(TypeTree(_)), emptyValDef, body) + val customParents = futureSystemOps.stateMachineClassParents + val tycon = if (customParents.exists(!_.typeSymbol.asClass.isTrait)) { + // prefer extending a class to reduce the class file size of the state machine. + symbolOf[scala.runtime.AbstractFunction1[Any, Any]] + } else { + // ... unless a custom future system already extends some class + symbolOf[scala.Function1[Any, Any]] + } + val tryToUnit = appliedType(tycon, futureSystemOps.tryType[Any], typeOf[Unit]) + val template = Template((futureSystemOps.stateMachineClassParents ::: List(tryToUnit, typeOf[() => Unit])).map(TypeTree(_)), emptyValDef, body) val t = ClassDef(NoMods, name.stateMachineT, Nil, template) typecheckClassDef(t) @@ -148,10 +159,13 @@ trait AsyncTransform { case ValDef(_, _, _, rhs) if liftedSyms(tree.symbol) => api.atOwner(api.currentOwner) { val fieldSym = tree.symbol - val lhs = atPos(tree.pos) { - gen.mkAttributedStableRef(thisType(fieldSym.owner.asClass), fieldSym) + if (fieldSym.asTerm.isLazy) Literal(Constant(())) + else { + val lhs = atPos(tree.pos) { + gen.mkAttributedStableRef(thisType(fieldSym.owner.asClass), fieldSym) + } + treeCopy.Assign(tree, lhs, api.recur(rhs)).setType(definitions.UnitTpe).changeOwner(fieldSym, api.currentOwner) } - treeCopy.Assign(tree, lhs, api.recur(rhs)).setType(definitions.UnitTpe).changeOwner(fieldSym, api.currentOwner) } case _: DefTree if liftedSyms(tree.symbol) => EmptyTree @@ -165,7 +179,7 @@ trait AsyncTransform { } val liftablesUseFields = liftables.map { - case vd: ValDef => vd + case vd: ValDef if !vd.symbol.asTerm.isLazy => vd case x => typingTransform(x, stateMachineClass)(useFields) } diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala index 16b9207..fe62cd6 100644 --- a/src/main/scala/scala/async/internal/ExprBuilder.scala +++ b/src/main/scala/scala/async/internal/ExprBuilder.scala @@ -3,7 +3,6 @@ */ package scala.async.internal -import scala.reflect.macros.Context import scala.collection.mutable.ListBuffer import collection.mutable import language.existentials @@ -34,18 +33,17 @@ trait ExprBuilder { var stats: List[Tree] - def statsAnd(trees: List[Tree]): List[Tree] = { - val body = stats match { + def treesThenStats(trees: List[Tree]): List[Tree] = { + (stats match { case init :+ last if tpeOf(last) =:= definitions.NothingTpe => - adaptToUnit(init :+ Typed(last, TypeTree(definitions.AnyTpe))) + adaptToUnit((trees ::: init) :+ Typed(last, TypeTree(definitions.AnyTpe))) case _ => - adaptToUnit(stats) - } - Try(body, Nil, adaptToUnit(trees)) :: Nil + adaptToUnit(trees ::: stats) + }) :: Nil } final def allStats: List[Tree] = this match { - case a: AsyncStateWithAwait => statsAnd(a.awaitable.resultValDef :: Nil) + case a: AsyncStateWithAwait => treesThenStats(a.awaitable.resultValDef :: Nil) case _ => stats } @@ -63,7 +61,7 @@ trait ExprBuilder { List(nextState) def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef = { - mkHandlerCase(state, statsAnd(mkStateTree(nextState, symLookup) :: Nil)) + mkHandlerCase(state, treesThenStats(mkStateTree(nextState, symLookup) :: Nil)) } override val toString: String = @@ -95,14 +93,17 @@ trait ExprBuilder { val fun = This(tpnme.EMPTY) val callOnComplete = futureSystemOps.onComplete[Any, Unit](c.Expr[futureSystem.Fut[Any]](awaitable.expr), c.Expr[futureSystem.Tryy[Any] => Unit](fun), c.Expr[futureSystem.ExecContext](Ident(name.execContext))).tree - val tryGetOrCallOnComplete = - if (futureSystemOps.continueCompletedFutureOnSameThread) - If(futureSystemOps.isCompleted(c.Expr[futureSystem.Fut[_]](awaitable.expr)).tree, - adaptToUnit(ifIsFailureTree[T](futureSystemOps.getCompleted[Any](c.Expr[futureSystem.Fut[Any]](awaitable.expr)).tree) :: Nil), + val tryGetOrCallOnComplete: List[Tree] = + if (futureSystemOps.continueCompletedFutureOnSameThread) { + val tempName = name.fresh(name.completed) + val initTemp = ValDef(NoMods, tempName, TypeTree(futureSystemOps.tryType[Any]), futureSystemOps.getCompleted[Any](c.Expr[futureSystem.Fut[Any]](awaitable.expr)).tree) + val ifTree = If(Apply(Select(Literal(Constant(null)), TermName("ne")), Ident(tempName) :: Nil), + adaptToUnit(ifIsFailureTree[T](Ident(tempName)) :: Nil), Block(toList(callOnComplete), Return(literalUnit))) - else - Block(toList(callOnComplete), Return(literalUnit)) - mkHandlerCase(state, stats ++ List(mkStateTree(onCompleteState, symLookup), tryGetOrCallOnComplete)) + initTemp :: ifTree :: Nil + } else + toList(callOnComplete) ::: Return(literalUnit) :: Nil + mkHandlerCase(state, stats ++ List(mkStateTree(onCompleteState, symLookup)) ++ tryGetOrCallOnComplete) } private def tryGetTree(tryReference: => Tree) = @@ -237,10 +238,8 @@ trait ExprBuilder { var stateBuilder = new AsyncStateBuilder(startState, symLookup) var currState = startState - def checkForUnsupportedAwait(tree: Tree) = if (tree exists { - case Apply(fun, _) if isAwait(fun) => true - case _ => false - }) c.abort(tree.pos, "await must not be used in this position") + def checkForUnsupportedAwait(tree: Tree) = if (containsAwait(tree)) + c.abort(tree.pos, "await must not be used in this position") def nestedBlockBuilder(nestedTree: Tree, startState: Int, endState: Int) = { val (nestedStats, nestedExpr) = statsAndExpr(nestedTree) @@ -253,12 +252,17 @@ trait ExprBuilder { case LabelDef(name, _, _) => name.toString.startsWith("case") case _ => false } - val (before, _ :: after) = (stats :+ expr).span(_ ne t) - before.reverse.takeWhile(isPatternCaseLabelDef) ::: after.takeWhile(isPatternCaseLabelDef) + val span = (stats :+ expr).filterNot(isLiteralUnit).span(_ ne t) + span match { + case (before, _ :: after) => + before.reverse.takeWhile(isPatternCaseLabelDef) ::: after.takeWhile(isPatternCaseLabelDef) + case _ => + stats :+ expr + } } // populate asyncStates - for (stat <- (stats :+ expr)) stat match { + def add(stat: Tree): Unit = stat match { // the val name = await(..) pattern case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => val onCompleteState = nextState() @@ -317,10 +321,13 @@ trait ExprBuilder { asyncStates ++= builder.asyncStates currState = afterLabelState stateBuilder = new AsyncStateBuilder(currState, symLookup) + case b @ Block(stats, expr) => + (stats :+ expr) foreach (add) case _ => checkForUnsupportedAwait(stat) stateBuilder += stat } + for (stat <- (stats :+ expr)) add(stat) val lastState = stateBuilder.resultSimple(endState) asyncStates += lastState } @@ -359,8 +366,8 @@ trait ExprBuilder { val caseForLastState: CaseDef = { val lastState = asyncStates.last val lastStateBody = c.Expr[T](lastState.body) - val rhs = futureSystemOps.completeProm( - c.Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), futureSystemOps.tryySuccess[T](lastStateBody)) + val rhs = futureSystemOps.completeWithSuccess( + c.Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), lastStateBody) mkHandlerCase(lastState.state, Block(rhs.tree, Return(literalUnit))) } asyncStates.toList match { @@ -394,7 +401,10 @@ trait ExprBuilder { * } */ private def resumeFunTree[T: WeakTypeTag]: Tree = { - val body = Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T])) + val stateMemberSymbol = symLookup.stateMachineMember(name.state) + val stateMemberRef = symLookup.memberRef(name.state) + val body = Match(stateMemberRef, mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T]) ++ List(CaseDef(Ident(nme.WILDCARD), EmptyTree, Throw(Apply(Select(New(Ident(defn.IllegalStateExceptionClass)), termNames.CONSTRUCTOR), List()))))) + Try( body, List( @@ -464,13 +474,24 @@ trait ExprBuilder { private def tpeOf(t: Tree): Type = t match { case _ if t.tpe != null => t.tpe case Try(body, Nil, _) => tpeOf(body) + case Block(_, expr) => tpeOf(expr) + case Literal(Constant(value)) if value == () => definitions.UnitTpe + case Return(_) => definitions.NothingTpe case _ => NoType } private def adaptToUnit(rhs: List[Tree]): c.universe.Block = { rhs match { + case (rhs: Block) :: Nil if tpeOf(rhs) <:< definitions.UnitTpe => + rhs case init :+ last if tpeOf(last) <:< definitions.UnitTpe => Block(init, last) + case init :+ (last @ Literal(Constant(()))) => + Block(init, last) + case init :+ (last @ Block(_, Return(_) | Literal(Constant(())))) => + Block(init, last) + case init :+ (Block(stats, expr)) => + Block(init, Block(stats :+ expr, literalUnit)) case _ => Block(rhs, literalUnit) } diff --git a/src/main/scala/scala/async/internal/FutureSystem.scala b/src/main/scala/scala/async/internal/FutureSystem.scala index 6fccfdd..f330cbf 100644 --- a/src/main/scala/scala/async/internal/FutureSystem.scala +++ b/src/main/scala/scala/async/internal/FutureSystem.scala @@ -33,6 +33,7 @@ trait FutureSystem { def promType[A: WeakTypeTag]: Type def tryType[A: WeakTypeTag]: Type def execContextType: Type + def stateMachineClassParents: List[Type] = Nil /** Create an empty promise */ def createProm[A: WeakTypeTag]: Expr[Prom[A]] @@ -48,13 +49,16 @@ trait FutureSystem { execContext: Expr[ExecContext]): Expr[Unit] def continueCompletedFutureOnSameThread = false - def isCompleted(future: Expr[Fut[_]]): Expr[Boolean] = - throw new UnsupportedOperationException("isCompleted not supported by this FutureSystem") + + /** Return `null` if this future is not yet completed, or `Tryy[A]` with the completed result + * otherwise + */ def getCompleted[A: WeakTypeTag](future: Expr[Fut[A]]): Expr[Tryy[A]] = throw new UnsupportedOperationException("getCompleted not supported by this FutureSystem") /** Complete a promise with a value */ def completeProm[A](prom: Expr[Prom[A]], value: Expr[Tryy[A]]): Expr[Unit] + def completeWithSuccess[A: WeakTypeTag](prom: Expr[Prom[A]], value: Expr[A]): Expr[Unit] = completeProm(prom, tryySuccess(value)) def spawn(tree: Tree, execContext: Tree): Tree = future(c.Expr[Unit](tree))(c.Expr[ExecContext](execContext)).tree @@ -108,11 +112,8 @@ object ScalaConcurrentFutureSystem extends FutureSystem { override def continueCompletedFutureOnSameThread: Boolean = true - override def isCompleted(future: Expr[Fut[_]]): Expr[Boolean] = reify { - future.splice.isCompleted - } override def getCompleted[A: WeakTypeTag](future: Expr[Fut[A]]): Expr[Tryy[A]] = reify { - future.splice.value.get + if (future.splice.isCompleted) future.splice.value.get else null } def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify { diff --git a/src/main/scala/scala/async/internal/Lifter.scala b/src/main/scala/scala/async/internal/Lifter.scala index 2998baf..3afe6d6 100644 --- a/src/main/scala/scala/async/internal/Lifter.scala +++ b/src/main/scala/scala/async/internal/Lifter.scala @@ -76,8 +76,9 @@ trait Lifter { // are already accounted for. val stateIdToDirectlyReferenced: Map[Int, List[Symbol]] = { val refs: List[(Int, Symbol)] = asyncStates.flatMap( - asyncState => asyncState.stats.filterNot(_.isDef).flatMap(_.collect { - case rt: RefTree if symToDefiningState.contains(rt.symbol) => (asyncState.state, rt.symbol) + asyncState => asyncState.stats.filterNot(t => t.isDef && !isLabel(t.symbol)).flatMap(_.collect { + case rt: RefTree + if symToDefiningState.contains(rt.symbol) => (asyncState.state, rt.symbol) }) ) toMultiMap(refs) @@ -91,7 +92,7 @@ trait Lifter { // Only mark transitive references of defs, modules and classes. The RHS of lifted vals/vars // stays in its original location, so things that it refers to need not be lifted. - if (!(sym.isTerm && (sym.asTerm.isVal || sym.asTerm.isVar))) + if (!(sym.isTerm && !sym.asTerm.isLazy && (sym.asTerm.isVal || sym.asTerm.isVar))) defSymToReferenced(sym).foreach(sym2 => markForLift(sym2)) } } @@ -116,7 +117,8 @@ trait Lifter { sym.setFlag(MUTABLE | STABLE | PRIVATE | LOCAL) sym.setName(name.fresh(sym.name.toTermName)) sym.setInfo(deconst(sym.info)) - treeCopy.ValDef(vd, Modifiers(sym.flags), sym.name, TypeTree(tpe(sym)).setPos(t.pos), EmptyTree) + val rhs1 = if (sym.asTerm.isLazy) rhs else EmptyTree + treeCopy.ValDef(vd, Modifiers(sym.flags), sym.name, TypeTree(tpe(sym)).setPos(t.pos), rhs1) case dd@DefDef(_, _, tparams, vparamss, tpt, rhs) => sym.setName(this.name.fresh(sym.name.toTermName)) sym.setFlag(PRIVATE | LOCAL) diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala index 90419d3..c86540b 100644 --- a/src/main/scala/scala/async/internal/TransformUtils.scala +++ b/src/main/scala/scala/async/internal/TransformUtils.scala @@ -24,6 +24,7 @@ private[async] trait TransformUtils { val ifRes = "ifres" val await = "await" val bindSuffix = "$bind" + val completed = newTermName("completed") val state = newTermName("state") val result = newTermName("result") @@ -38,6 +39,9 @@ private[async] trait TransformUtils { def fresh(name: String): String = c.freshName(name) } + def isAsync(fun: Tree) = + fun.symbol == defn.Async_async + def isAwait(fun: Tree) = fun.symbol == defn.Async_await @@ -164,7 +168,8 @@ private[async] trait TransformUtils { val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal") val ThrowableClass = rootMirror.staticClass("java.lang.Throwable") - val Async_await = asyncBase.awaitMethod(c.universe)(c.macroApplication.symbol).ensuring(_ != NoSymbol) + lazy val Async_async = asyncBase.asyncMethod(c.universe)(c.macroApplication.symbol) + lazy val Async_await = asyncBase.awaitMethod(c.universe)(c.macroApplication.symbol) val IllegalStateExceptionClass = rootMirror.staticClass("java.lang.IllegalStateException") } @@ -186,6 +191,10 @@ private[async] trait TransformUtils { val LABEL = 1L << 17 // not in the public reflection API. (internal.flags(sym).asInstanceOf[Long] & LABEL) != 0L } + def isSynth(sym: Symbol): Boolean = { + val SYNTHETIC = 1 << 21 // not in the public reflection API. + (internal.flags(sym).asInstanceOf[Long] & SYNTHETIC) != 0L + } def symId(sym: Symbol): Int = { val symtab = this.c.universe.asInstanceOf[reflect.internal.SymbolTable] sym.asInstanceOf[symtab.Symbol].id @@ -281,6 +290,8 @@ private[async] trait TransformUtils { override def traverse(tree: Tree) { tree match { + case _ if isAsync(tree) => + // Under -Ymacro-expand:discard, used in the IDE, nested async blocks will be visible to the outer blocks case cd: ClassDef => nestedClass(cd) case md: ModuleDef => nestedModule(md) case dd: DefDef => nestedMethod(dd) @@ -382,7 +393,7 @@ private[async] trait TransformUtils { catch { case _: ScalaReflectionException => NoSymbol } } final def uncheckedBounds(tp: Type): Type = { - if (tp.typeArgs.isEmpty || UncheckedBoundsClass == NoSymbol) tp + if ((tp.typeArgs.isEmpty && (tp match { case _: TypeRef => true; case _ => false}))|| UncheckedBoundsClass == NoSymbol) tp else withAnnotation(tp, Annotation(UncheckedBoundsClass.asType.toType, Nil, ListMap())) } // ===================================== @@ -396,9 +407,11 @@ private[async] trait TransformUtils { * in search of a sub tree that was decorated with the cached answer. */ final def containsAwaitCached(t: Tree): Tree => Boolean = { + if (c.macroApplication.symbol == null) return (t => false) + def treeCannotContainAwait(t: Tree) = t match { case _: Ident | _: TypeTree | _: Literal => true - case _ => false + case _ => isAsync(t) } def shouldAttach(t: Tree) = !treeCannotContainAwait(t) val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] @@ -417,11 +430,15 @@ private[async] trait TransformUtils { override def traverse(tree: Tree): Unit = { stack ::= tree try { - if (isAwait(tree)) - stack.foreach(attachContainsAwait) - else - attachNoAwait(tree) - super.traverse(tree) + if (isAsync(tree)) { + ; + } else { + if (isAwait(tree)) + stack.foreach(attachContainsAwait) + else + attachNoAwait(tree) + super.traverse(tree) + } } finally stack = stack.tail } } @@ -463,6 +480,14 @@ private[async] trait TransformUtils { typingTransform(t, owner) { (tree, api) => tree match { + case LabelDef(name, params, rhs) => + val rhs1 = api.recur(rhs) + if (rhs1.tpe =:= UnitTpe) { + internal.setInfo(tree.symbol, internal.methodType(tree.symbol.info.paramLists.head, UnitTpe)) + treeCopy.LabelDef(tree, name, params, rhs1) + } else { + treeCopy.LabelDef(tree, name, params, rhs1) + } case Block(stats, expr) => val stats1 = stats map api.recur val expr1 = api.recur(expr) diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index 09fa69e..1637102 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -54,7 +54,7 @@ class TreeInterrogation { } object TreeInterrogation extends App { - def withDebug[T](t: => T) { + def withDebug[T](t: => T): T = { def set(level: String, value: Boolean) = System.setProperty(s"scala.async.$level", value.toString) val levels = Seq("trace", "debug") def setAll(value: Boolean) = levels.foreach(set(_, value)) diff --git a/src/test/scala/scala/async/run/WarningsSpec.scala b/src/test/scala/scala/async/run/WarningsSpec.scala index 00c6466..c80bf9e 100644 --- a/src/test/scala/scala/async/run/WarningsSpec.scala +++ b/src/test/scala/scala/async/run/WarningsSpec.scala @@ -74,4 +74,24 @@ class WarningsSpec { run.compileSources(sourceFile :: Nil) assert(!global.reporter.hasErrors, global.reporter.asInstanceOf[StoreReporter].infos) } + + @Test + def ignoreNestedAwaitsInIDE_t1002561() { + // https://www.assembla.com/spaces/scala-ide/tickets/1002561 + val global = mkGlobal("-cp ${toolboxClasspath} -Yrangepos -Ystop-after:typer ") + val source = """ + | class Test { + | def test = { + | import scala.async.Async._, scala.concurrent._, ExecutionContext.Implicits.global + | async { + | 1 + await({def foo = (async(await(async(2)))); foo}) + | } + | } + |} + """.stripMargin + val run = new global.Run + val sourceFile = global.newSourceFile(source) + run.compileSources(sourceFile :: Nil) + assert(!global.reporter.hasErrors, global.reporter.asInstanceOf[StoreReporter].infos) + } } diff --git a/src/test/scala/scala/async/run/late/LateExpansion.scala b/src/test/scala/scala/async/run/late/LateExpansion.scala index b866527..a40b1af 100644 --- a/src/test/scala/scala/async/run/late/LateExpansion.scala +++ b/src/test/scala/scala/async/run/late/LateExpansion.scala @@ -3,10 +3,12 @@ package scala.async.run.late import java.io.File import junit.framework.Assert.assertEquals -import org.junit.Test +import org.junit.{Assert, Test} import scala.annotation.StaticAnnotation -import scala.async.internal.{AsyncId, AsyncMacro} +import scala.annotation.meta.{field, getter} +import scala.async.TreeInterrogation +import scala.async.internal.AsyncId import scala.reflect.internal.util.ScalaClassLoader.URLClassLoader import scala.tools.nsc._ import scala.tools.nsc.plugins.{Plugin, PluginComponent} @@ -16,6 +18,7 @@ import scala.tools.nsc.transform.TypingTransformers // Tests for customized use of the async transform from a compiler plugin, which // calls it from a new phase that runs after patmat. class LateExpansion { + @Test def test0(): Unit = { val result = wrapAndRun( """ @@ -75,6 +78,263 @@ class LateExpansion { assertEquals("case 3: blerg3", result) } + @Test def polymorphicMethod(): Unit = { + val result = run( + """ + |import scala.async.run.late.{autoawait,lateasync} + |object Test { + | class C { override def toString = "C" } + | @autoawait def foo[A <: C](a: A): A = a + | @lateasync + | def test1[CC <: C](c: CC): (CC, CC) = { + | val x: (CC, CC) = 0 match { case _ if false => ???; case _ => (foo(c), foo(c)) } + | x + | } + | def test(): (C, C) = test1(new C) + |} + | """.stripMargin) + assertEquals("(C,C)", result.toString) + } + + @Test def shadowing(): Unit = { + val result = run( + """ + |import scala.async.run.late.{autoawait,lateasync} + |object Test { + | trait Foo + | trait Bar extends Foo + | @autoawait def boundary = "" + | @lateasync + | def test: Unit = { + | (new Bar {}: Any) match { + | case foo: Bar => + | boundary + | 0 match { + | case _ => foo; () + | } + | () + | } + | () + | } + |} + | """.stripMargin) + } + + @Test def shadowing0(): Unit = { + val result = run( + """ + |import scala.async.run.late.{autoawait,lateasync} + |object Test { + | trait Foo + | trait Bar + | def test: Any = test(new C) + | @autoawait def asyncBoundary: String = "" + | @lateasync + | def test(foo: Foo): Foo = foo match { + | case foo: Bar => + | val foo2: Foo with Bar = new Foo with Bar {} + | asyncBoundary + | null match { + | case _ => foo2 + | } + | case other => foo + | } + | class C extends Foo with Bar + |} + | """.stripMargin) + } + @Test def shadowing2(): Unit = { + val result = run( + """ + |import scala.async.run.late.{autoawait,lateasync} + |object Test { + | trait Base; trait Foo[T <: Base] { @autoawait def func: Option[Foo[T]] = None } + | class Sub extends Base + | trait Bar extends Foo[Sub] + | def test: Any = test(new Bar {}) + | @lateasync + | def test[T <: Base](foo: Foo[T]): Foo[T] = foo match { + | case foo: Bar => + | val res = foo.func + | res match { + | case _ => + | } + | foo + | case other => foo + | } + | test(new Bar {}) + |} + | """.stripMargin) + } + + @Test def patternAlternative(): Unit = { + val result = wrapAndRun( + """ + | @autoawait def one = 1 + | + | @lateasync def test = { + | Option(true) match { + | case null | None => false + | case Some(v) => one; v + | } + | } + | """.stripMargin) + } + + @Test def patternAlternativeBothAnnotations(): Unit = { + val result = wrapAndRun( + """ + |import scala.async.run.late.{autoawait,lateasync} + |object Test { + | @autoawait def func1() = "hello" + | @lateasync def func(a: Option[Boolean]) = a match { + | case null | None => func1 + " world" + | case _ => "okay" + | } + | def test: Any = func(None) + |} + | """.stripMargin) + } + + @Test def shadowingRefinedTypes(): Unit = { + val result = run( + s""" + |import scala.async.run.late.{autoawait,lateasync} + |trait Base + |class Sub extends Base + |trait Foo[T <: Base] { + | @autoawait def func: Option[Foo[T]] = None + |} + |trait Bar extends Foo[Sub] + |object Test { + | @lateasync def func[T <: Base](foo: Foo[T]): Foo[T] = foo match { // the whole pattern match will be wrapped with async{ } + | case foo: Bar => + | val res = foo.func // will be rewritten into: await(foo.func) + | res match { + | case Some(v) => v // this will report type mismtach + | case other => foo + | } + | case other => foo + | } + | def test: Any = { val b = new Bar{}; func(b) == b } + |}""".stripMargin) + assertEquals(true, result) + } + + @Test def testMatchEndIssue(): Unit = { + val result = run( + """ + |import scala.async.run.late.{autoawait,lateasync} + |sealed trait Subject + |final class Principal(val name: String) extends Subject + |object Principal { + | def unapply(p: Principal): Option[String] = Some(p.name) + |} + |object Test { + | @autoawait @lateasync + | def containsPrincipal(search: String, value: Subject): Boolean = value match { + | case Principal(name) if name == search => true + | case Principal(name) => containsPrincipal(search, value) + | case other => false + | } + | + | @lateasync + | def test = containsPrincipal("test", new Principal("test")) + |} + | """.stripMargin) + } + + @Test def testGenericTypeBoundaryIssue(): Unit = { + val result = run( + """ + import scala.async.run.late.{autoawait,lateasync} + trait InstrumentOfValue + trait Security[T <: InstrumentOfValue] extends InstrumentOfValue + class Bound extends Security[Bound] + class Futures extends Security[Futures] + object TestGenericTypeBoundIssue { + @autoawait @lateasync def processBound(bound: Bound): Unit = { println("process Bound") } + @autoawait @lateasync def processFutures(futures: Futures): Unit = { println("process Futures") } + @autoawait @lateasync def doStuff(sec: Security[_]): Unit = { + sec match { + case bound: Bound => processBound(bound) + case futures: Futures => processFutures(futures) + case _ => throw new Exception("Unknown Security type: " + sec) + } + } + } + """.stripMargin) + } + + @Test def testReturnTupleIssue(): Unit = { + val result = run( + """ + import scala.async.run.late.{autoawait,lateasync} + class TestReturnExprIssue(str: String) { + @autoawait @lateasync def getTestValue = Some(42) + @autoawait @lateasync def doStuff: Int = { + val opt: Option[Int] = getTestValue // here we have an async method invoke + opt match { + case Some(li) => li // use the result somehow + case None => + } + 42 // type mismatch; found : AnyVal required: Int + } + } + """.stripMargin) + } + + + @Test def testAfterRefchecksIssue(): Unit = { + val result = run( + """ + import scala.async.run.late.{autoawait,lateasync} + trait Factory[T] { def create: T } + sealed trait TimePoint + class TimeLine[TP <: TimePoint](val tpInitial: Factory[TP]) { + @autoawait @lateasync private[TimeLine] val tp: TP = tpInitial.create + @autoawait @lateasync def timePoint: TP = tp + } + object Test { + def test: Unit = () + } + """) + } + + @Test def testArrayIndexOutOfBoundIssue(): Unit = { + val result = run( + """ + import scala.async.run.late.{autoawait,lateasync} + + sealed trait Result + case object A extends Result + case object B extends Result + case object C extends Result + + object Test { + protected def doStuff(res: Result) = { + class C { + @autoawait def needCheck = false + + @lateasync def m = { + if (needCheck) "NO" + else { + res match { + case A => 1 + case _ => 2 + } + } + } + } + } + + + @lateasync + def test() = doStuff(B) + } + """) + } + def wrapAndRun(code: String): Any = { run( s""" @@ -88,10 +348,49 @@ class LateExpansion { | """.stripMargin) } + + @Test def testNegativeArraySizeException(): Unit = { + val result = run( + """ + import scala.async.run.late.{autoawait,lateasync} + + object Test { + def foo(foo: Any, bar: Any) = () + @autoawait def getValue = 4.2 + @lateasync def func(f: Any) = { + foo(f match { case _ if "".isEmpty => 2 }, getValue); + } + + @lateasync + def test() = func(4) + } + """) + } + @Test def testNegativeArraySizeExceptionFine1(): Unit = { + val result = run( + """ + import scala.async.run.late.{autoawait,lateasync} + case class FixedFoo(foo: Int) + class Foobar(val foo: Int, val bar: Double) { + @autoawait @lateasync def getValue = 4.2 + @autoawait @lateasync def func(f: Any) = { + new Foobar(foo = f match { + case FixedFoo(x) => x + case _ => 2 + }, + bar = getValue) + } + } + object Test { + @lateasync def test() = new Foobar(0, 0).func(4) + } + """) + } def run(code: String): Any = { val reporter = new StoreReporter val settings = new Settings(println(_)) - settings.outdir.value = sys.props("java.io.tmpdir") + // settings.processArgumentString("-Xprint:patmat,postpatmat,jvm -Ybackend:GenASM -nowarn") + settings.outdir.value = "/tmp" settings.embeddedDefaults(getClass.getClassLoader) val isInSBT = !settings.classpath.isSetByUser if (isInSBT) settings.usejavacp.value = true @@ -108,8 +407,10 @@ class LateExpansion { val run = new Run val source = newSourceFile(code) - run.compileSources(source :: Nil) - assert(!reporter.hasErrors, reporter.infos.mkString("\n")) +// TreeInterrogation.withDebug { + run.compileSources(source :: Nil) +// } + Assert.assertTrue(reporter.infos.mkString("\n"), !reporter.hasErrors) val loader = new URLClassLoader(Seq(new File(settings.outdir.value).toURI.toURL), global.getClass.getClassLoader) val cls = loader.loadClass("Test") cls.getMethod("test").invoke(null) @@ -133,20 +434,26 @@ abstract class LatePlugin extends Plugin { super.transform(tree) match { case ap@Apply(fun, args) if fun.symbol.hasAnnotation(autoAwaitSym) => localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, awaitSym), TypeTree(ap.tpe) :: Nil), ap :: Nil)) + case sel@Select(fun, _) if sel.symbol.hasAnnotation(autoAwaitSym) && !(tree.tpe.isInstanceOf[MethodTypeApi] || tree.tpe.isInstanceOf[PolyTypeApi] ) => + localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, awaitSym), TypeTree(sel.tpe) :: Nil), sel :: Nil)) case dd: DefDef if dd.symbol.hasAnnotation(lateAsyncSym) => atOwner(dd.symbol) { - val expandee = localTyper.context.withMacrosDisabled( - localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(dd.rhs.tpe) :: Nil), List(dd.rhs))) - ) - val c = analyzer.macroContext(localTyper, gen.mkAttributedRef(asyncIdSym), expandee) - val asyncMacro = AsyncMacro(c, AsyncId)(dd.rhs) - val code = asyncMacro.asyncTransform[Any](localTyper.typed(Literal(Constant(()))))(c.weakTypeTag[Any]) - deriveDefDef(dd)(_ => localTyper.typed(code)) + deriveDefDef(dd){ rhs: Tree => + val invoke = Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(rhs.tpe) :: Nil), List(rhs)) + localTyper.typed(atPos(dd.pos)(invoke)) + } } + case vd: ValDef if vd.symbol.hasAnnotation(lateAsyncSym) => atOwner(vd.symbol) { + deriveValDef(vd){ rhs: Tree => + val invoke = Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(rhs.tpe) :: Nil), List(rhs)) + localTyper.typed(atPos(vd.pos)(invoke)) + } + } + case vd: ValDef => + vd case x => x } } } - override def newPhase(prev: Phase): Phase = new StdPhase(prev) { override def apply(unit: CompilationUnit): Unit = { val translated = newTransformer(unit).transformUnit(unit) @@ -155,7 +462,7 @@ abstract class LatePlugin extends Plugin { } } - override val runsAfter: List[String] = "patmat" :: Nil + override val runsAfter: List[String] = "refchecks" :: Nil override val phaseName: String = "postpatmat" }) @@ -164,7 +471,9 @@ abstract class LatePlugin extends Plugin { } // Methods with this annotation are translated to having the RHS wrapped in `AsyncId.async { <original RHS> }` +@field final class lateasync extends StaticAnnotation // Calls to methods with this annotation are translated to `AsyncId.await(<call>)` +@getter final class autoawait extends StaticAnnotation diff --git a/src/test/scala/scala/async/run/live/LiveVariablesSpec.scala b/src/test/scala/scala/async/run/live/LiveVariablesSpec.scala index 01cf911..af236aa 100644 --- a/src/test/scala/scala/async/run/live/LiveVariablesSpec.scala +++ b/src/test/scala/scala/async/run/live/LiveVariablesSpec.scala @@ -266,7 +266,6 @@ class LiveVariablesSpec { // https://github.com/scala/async/issues/104 @Test def dontNullOutVarsOfTypeNothing_t104(): Unit = { - implicit val ec: scala.concurrent.ExecutionContext = null import scala.async.Async._ import scala.concurrent.duration.Duration import scala.concurrent.{Await, Future} |