aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/scala/async/internal/TransformUtils.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/scala/async/internal/TransformUtils.scala')
-rw-r--r--src/main/scala/scala/async/internal/TransformUtils.scala126
1 files changed, 124 insertions, 2 deletions
diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala
index 547f980..ed8b103 100644
--- a/src/main/scala/scala/async/internal/TransformUtils.scala
+++ b/src/main/scala/scala/async/internal/TransformUtils.scala
@@ -41,6 +41,19 @@ private[async] trait TransformUtils {
def isAwait(fun: Tree) =
fun.symbol == defn.Async_await
+ def newBlock(stats: List[Tree], expr: Tree): Block = {
+ Block(stats, expr)
+ }
+
+ def isLiteralUnit(t: Tree) = t match {
+ case Literal(Constant(())) =>
+ true
+ case _ => false
+ }
+
+ def isPastTyper =
+ c.universe.asInstanceOf[scala.reflect.internal.SymbolTable].isPastTyper
+
// Copy pasted from TreeInfo in the compiler.
// Using a quasiquote pattern like `case q"$fun[..$targs](...$args)" => is not
// sufficient since https://github.com/scala/scala/pull/3656 as it doesn't match
@@ -150,6 +163,7 @@ private[async] trait TransformUtils {
}
val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal")
+ val ThrowableClass = rootMirror.staticClass("java.lang.Throwable")
val Async_await = asyncBase.awaitMethod(c.universe)(c.macroApplication.symbol).ensuring(_ != NoSymbol)
val IllegalStateExceptionClass = rootMirror.staticClass("java.lang.IllegalStateException")
}
@@ -161,16 +175,26 @@ private[async] trait TransformUtils {
val labelDefs = t.collect {
case ld: LabelDef => ld.symbol
}.toSet
- t.exists {
+ val result = t.exists {
case rt: RefTree => rt.symbol != null && isLabel(rt.symbol) && !(labelDefs contains rt.symbol)
case _ => false
}
+ result
}
- private def isLabel(sym: Symbol): Boolean = {
+ def isLabel(sym: Symbol): Boolean = {
val LABEL = 1L << 17 // not in the public reflection API.
(internal.flags(sym).asInstanceOf[Long] & LABEL) != 0L
}
+ def symId(sym: Symbol): Int = {
+ val symtab = this.c.universe.asInstanceOf[reflect.internal.SymbolTable]
+ sym.asInstanceOf[symtab.Symbol].id
+ }
+ def substituteTrees(t: Tree, from: List[Symbol], to: List[Tree]): Tree = {
+ val symtab = this.c.universe.asInstanceOf[reflect.internal.SymbolTable]
+ val subst = new symtab.TreeSubstituter(from.asInstanceOf[List[symtab.Symbol]], to.asInstanceOf[List[symtab.Tree]])
+ subst.transform(t.asInstanceOf[symtab.Tree]).asInstanceOf[Tree]
+ }
/** Map a list of arguments to:
@@ -362,4 +386,102 @@ private[async] trait TransformUtils {
else withAnnotation(tp, Annotation(UncheckedBoundsClass.asType.toType, Nil, ListMap()))
}
// =====================================
+
+ /**
+ * Efficiently decorate each subtree within `t` with the result of `t exists isAwait`,
+ * and return a function that can be used on derived trees to efficiently test the
+ * same condition.
+ *
+ * If the derived tree contains synthetic wrapper trees, these will be recursed into
+ * in search of a sub tree that was decorated with the cached answer.
+ */
+ final def containsAwaitCached(t: Tree): Tree => Boolean = {
+ val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable]
+ def attachContainsAwait(t: Tree): Unit = {
+ val t1 = t.asInstanceOf[symtab.Tree]
+ t1.updateAttachment(ContainsAwait)
+ t1.removeAttachment[NoAwait.type]
+ }
+ def attachNoAwait(t: Tree): Unit = {
+ val t1 = t.asInstanceOf[symtab.Tree]
+ t1.updateAttachment(NoAwait)
+ }
+ object markContainsAwaitTraverser extends Traverser {
+ var stack: List[Tree] = Nil
+
+ override def traverse(tree: Tree): Unit = {
+ stack ::= tree
+ try {
+ if (isAwait(tree))
+ stack.foreach(attachContainsAwait)
+ else
+ attachNoAwait(tree)
+ super.traverse(tree)
+ } finally stack = stack.tail
+ }
+ }
+ markContainsAwaitTraverser.traverse(t)
+
+ (t: Tree) => {
+ val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable]
+ object traverser extends Traverser {
+ var containsAwait = false
+ override def traverse(tree: Tree): Unit = {
+ if (tree.asInstanceOf[symtab.Tree].hasAttachment[NoAwait.type])
+ ()
+ else if (tree.asInstanceOf[symtab.Tree].hasAttachment[ContainsAwait.type])
+ containsAwait = true
+ else super.traverse(tree)
+ }
+ }
+ traverser.traverse(t)
+ traverser.containsAwait
+ }
+ }
+
+ final def adjustTypeOfTranslatedPatternMatches(t: Tree, owner: Symbol): Tree = {
+ import definitions.UnitTpe
+ typingTransform(t, owner) {
+ (tree, api) =>
+ tree match {
+ case Block(stats, expr) =>
+ val stats1 = stats map api.recur
+ val expr1 = api.recur(expr)
+ if (expr1.tpe =:= UnitTpe)
+ internal.setType(treeCopy.Block(tree, stats1, expr1), UnitTpe)
+ else
+ treeCopy.Block(tree, stats1, expr1)
+ case If(cond, thenp, elsep) =>
+ val cond1 = api.recur(cond)
+ val thenp1 = api.recur(thenp)
+ val elsep1 = api.recur(elsep)
+ if (thenp1.tpe =:= definitions.UnitTpe && elsep.tpe =:= UnitTpe)
+ internal.setType(treeCopy.If(tree, cond1, thenp1, elsep1), UnitTpe)
+ else
+ treeCopy.If(tree, cond1, thenp1, elsep1)
+ case Apply(fun, args) if isLabel(fun.symbol) =>
+ internal.setType(treeCopy.Apply(tree, api.recur(fun), args map api.recur), UnitTpe)
+ case t => api.default(t)
+ }
+ }
+ }
+
+ final def mkMutableField(tpt: Type, name: TermName, init: Tree): List[Tree] = {
+ if (isPastTyper) {
+ // If we are running after the typer phase (ie being called from a compiler plugin)
+ // we have to create the trio of members manually.
+ val ACCESSOR = (1L << 27).asInstanceOf[FlagSet]
+ val STABLE = (1L << 22).asInstanceOf[FlagSet]
+ val field = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name + " ", TypeTree(tpt), init)
+ val getter = DefDef(Modifiers(ACCESSOR | STABLE), name, Nil, Nil, TypeTree(tpt), Select(This(tpnme.EMPTY), field.name))
+ val setter = DefDef(Modifiers(ACCESSOR), name + "_=", Nil, List(List(ValDef(NoMods, TermName("x"), TypeTree(tpt), EmptyTree))), TypeTree(definitions.UnitTpe), Assign(Select(This(tpnme.EMPTY), field.name), Ident(TermName("x"))))
+ field :: getter :: setter :: Nil
+ } else {
+ val result = ValDef(NoMods, name, TypeTree(tpt), init)
+ result :: Nil
+ }
+ }
}
+
+case object ContainsAwait
+case object NoAwait \ No newline at end of file