diff options
Diffstat (limited to 'src/main/scala/scala/async/internal/TransformUtils.scala')
-rw-r--r-- | src/main/scala/scala/async/internal/TransformUtils.scala | 126 |
1 files changed, 124 insertions, 2 deletions
diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala index 547f980..ed8b103 100644 --- a/src/main/scala/scala/async/internal/TransformUtils.scala +++ b/src/main/scala/scala/async/internal/TransformUtils.scala @@ -41,6 +41,19 @@ private[async] trait TransformUtils { def isAwait(fun: Tree) = fun.symbol == defn.Async_await + def newBlock(stats: List[Tree], expr: Tree): Block = { + Block(stats, expr) + } + + def isLiteralUnit(t: Tree) = t match { + case Literal(Constant(())) => + true + case _ => false + } + + def isPastTyper = + c.universe.asInstanceOf[scala.reflect.internal.SymbolTable].isPastTyper + // Copy pasted from TreeInfo in the compiler. // Using a quasiquote pattern like `case q"$fun[..$targs](...$args)" => is not // sufficient since https://github.com/scala/scala/pull/3656 as it doesn't match @@ -150,6 +163,7 @@ private[async] trait TransformUtils { } val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal") + val ThrowableClass = rootMirror.staticClass("java.lang.Throwable") val Async_await = asyncBase.awaitMethod(c.universe)(c.macroApplication.symbol).ensuring(_ != NoSymbol) val IllegalStateExceptionClass = rootMirror.staticClass("java.lang.IllegalStateException") } @@ -161,16 +175,26 @@ private[async] trait TransformUtils { val labelDefs = t.collect { case ld: LabelDef => ld.symbol }.toSet - t.exists { + val result = t.exists { case rt: RefTree => rt.symbol != null && isLabel(rt.symbol) && !(labelDefs contains rt.symbol) case _ => false } + result } - private def isLabel(sym: Symbol): Boolean = { + def isLabel(sym: Symbol): Boolean = { val LABEL = 1L << 17 // not in the public reflection API. (internal.flags(sym).asInstanceOf[Long] & LABEL) != 0L } + def symId(sym: Symbol): Int = { + val symtab = this.c.universe.asInstanceOf[reflect.internal.SymbolTable] + sym.asInstanceOf[symtab.Symbol].id + } + def substituteTrees(t: Tree, from: List[Symbol], to: List[Tree]): Tree = { + val symtab = this.c.universe.asInstanceOf[reflect.internal.SymbolTable] + val subst = new symtab.TreeSubstituter(from.asInstanceOf[List[symtab.Symbol]], to.asInstanceOf[List[symtab.Tree]]) + subst.transform(t.asInstanceOf[symtab.Tree]).asInstanceOf[Tree] + } /** Map a list of arguments to: @@ -362,4 +386,102 @@ private[async] trait TransformUtils { else withAnnotation(tp, Annotation(UncheckedBoundsClass.asType.toType, Nil, ListMap())) } // ===================================== + + /** + * Efficiently decorate each subtree within `t` with the result of `t exists isAwait`, + * and return a function that can be used on derived trees to efficiently test the + * same condition. + * + * If the derived tree contains synthetic wrapper trees, these will be recursed into + * in search of a sub tree that was decorated with the cached answer. + */ + final def containsAwaitCached(t: Tree): Tree => Boolean = { + val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] + def attachContainsAwait(t: Tree): Unit = { + val t1 = t.asInstanceOf[symtab.Tree] + t1.updateAttachment(ContainsAwait) + t1.removeAttachment[NoAwait.type] + } + def attachNoAwait(t: Tree): Unit = { + val t1 = t.asInstanceOf[symtab.Tree] + t1.updateAttachment(NoAwait) + } + object markContainsAwaitTraverser extends Traverser { + var stack: List[Tree] = Nil + + override def traverse(tree: Tree): Unit = { + stack ::= tree + try { + if (isAwait(tree)) + stack.foreach(attachContainsAwait) + else + attachNoAwait(tree) + super.traverse(tree) + } finally stack = stack.tail + } + } + markContainsAwaitTraverser.traverse(t) + + (t: Tree) => { + val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] + object traverser extends Traverser { + var containsAwait = false + override def traverse(tree: Tree): Unit = { + if (tree.asInstanceOf[symtab.Tree].hasAttachment[NoAwait.type]) + () + else if (tree.asInstanceOf[symtab.Tree].hasAttachment[ContainsAwait.type]) + containsAwait = true + else super.traverse(tree) + } + } + traverser.traverse(t) + traverser.containsAwait + } + } + + final def adjustTypeOfTranslatedPatternMatches(t: Tree, owner: Symbol): Tree = { + import definitions.UnitTpe + typingTransform(t, owner) { + (tree, api) => + tree match { + case Block(stats, expr) => + val stats1 = stats map api.recur + val expr1 = api.recur(expr) + if (expr1.tpe =:= UnitTpe) + internal.setType(treeCopy.Block(tree, stats1, expr1), UnitTpe) + else + treeCopy.Block(tree, stats1, expr1) + case If(cond, thenp, elsep) => + val cond1 = api.recur(cond) + val thenp1 = api.recur(thenp) + val elsep1 = api.recur(elsep) + if (thenp1.tpe =:= definitions.UnitTpe && elsep.tpe =:= UnitTpe) + internal.setType(treeCopy.If(tree, cond1, thenp1, elsep1), UnitTpe) + else + treeCopy.If(tree, cond1, thenp1, elsep1) + case Apply(fun, args) if isLabel(fun.symbol) => + internal.setType(treeCopy.Apply(tree, api.recur(fun), args map api.recur), UnitTpe) + case t => api.default(t) + } + } + } + + final def mkMutableField(tpt: Type, name: TermName, init: Tree): List[Tree] = { + if (isPastTyper) { + // If we are running after the typer phase (ie being called from a compiler plugin) + // we have to create the trio of members manually. + val ACCESSOR = (1L << 27).asInstanceOf[FlagSet] + val STABLE = (1L << 22).asInstanceOf[FlagSet] + val field = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name + " ", TypeTree(tpt), init) + val getter = DefDef(Modifiers(ACCESSOR | STABLE), name, Nil, Nil, TypeTree(tpt), Select(This(tpnme.EMPTY), field.name)) + val setter = DefDef(Modifiers(ACCESSOR), name + "_=", Nil, List(List(ValDef(NoMods, TermName("x"), TypeTree(tpt), EmptyTree))), TypeTree(definitions.UnitTpe), Assign(Select(This(tpnme.EMPTY), field.name), Ident(TermName("x")))) + field :: getter :: setter :: Nil + } else { + val result = ValDef(NoMods, name, TypeTree(tpt), init) + result :: Nil + } + } } + +case object ContainsAwait +case object NoAwait
\ No newline at end of file |