aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/scala/async/internal/TransformUtils.scala
diff options
context:
space:
mode:
authorJason Zaugg <jzaugg@gmail.com>2015-07-23 23:15:37 +1000
committerJason Zaugg <jzaugg@gmail.com>2015-09-22 16:53:33 +1000
commite3ff0382ae4e015fc69da8335450718951714982 (patch)
tree3f89dace31be3cd125531c0ba24270aa45100d7e /src/main/scala/scala/async/internal/TransformUtils.scala
parent93f207fee780652d08f93e1ea40e018db59fee99 (diff)
downloadscala-async-e3ff0382ae4e015fc69da8335450718951714982.tar.gz
scala-async-e3ff0382ae4e015fc69da8335450718951714982.tar.bz2
scala-async-e3ff0382ae4e015fc69da8335450718951714982.zip
Enable a compiler plugin to use the async transform after patmat
Currently, the async transformation is performed during the typer phase, like all other macros. We have to levy a few artificial restrictions on whern an async boundary may be: for instance we don't support await within a pattern guard. A more natural home for the transform would be after patterns have been translated. The test case in this commit shows how to use the async transform from a custom compiler phase after patmat. The remainder of the commit updates the implementation to handle the new tree shapes. For states that correspond to a label definition, we use `-symbol.id` as the state ID. This made it easier to emit the forward jumps to when processing the label application before we had seen the label definition. I've also made the transformation more efficient in the way it checks whether a given tree encloses an `await` call: we traverse the input tree at the start of the macro, and decorate it with tree attachments containig the answer to this question. Even after the ANF and state machine transforms introduce new layers of synthetic trees, the `containsAwait` code need only traverse shallowly through those trees to find a child that has the cached answer from the original traversal. I had to special case the ANF transform for expressions that always lead to a label jump: we avoids trying to push an assignment to a result variable into `if (cond) jump1() else jump2()`, in trees of the form: ``` % cat sandbox/jump.scala class Test { def test = { (null: Any) match { case _: String => "" case _ => "" } } } % qscalac -Xprint:patmat -Xprint-types sandbox/jump.scala def test: String = { case <synthetic> val x1: Any = (null{Null(null)}: Any){Any}; case5(){ if (x1.isInstanceOf{[T0]=> Boolean}[String]{Boolean}) matchEnd4{(x: String)String}(""{String("")}){String} else case6{()String}(){String}{String} }{String}; case6(){ matchEnd4{(x: String)String}(""{String("")}){String} }{String}; matchEnd4(x: String){ x{String} }{String} }{String} ```
Diffstat (limited to 'src/main/scala/scala/async/internal/TransformUtils.scala')
-rw-r--r--src/main/scala/scala/async/internal/TransformUtils.scala126
1 files changed, 124 insertions, 2 deletions
diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala
index 547f980..ed8b103 100644
--- a/src/main/scala/scala/async/internal/TransformUtils.scala
+++ b/src/main/scala/scala/async/internal/TransformUtils.scala
@@ -41,6 +41,19 @@ private[async] trait TransformUtils {
def isAwait(fun: Tree) =
fun.symbol == defn.Async_await
+ def newBlock(stats: List[Tree], expr: Tree): Block = {
+ Block(stats, expr)
+ }
+
+ def isLiteralUnit(t: Tree) = t match {
+ case Literal(Constant(())) =>
+ true
+ case _ => false
+ }
+
+ def isPastTyper =
+ c.universe.asInstanceOf[scala.reflect.internal.SymbolTable].isPastTyper
+
// Copy pasted from TreeInfo in the compiler.
// Using a quasiquote pattern like `case q"$fun[..$targs](...$args)" => is not
// sufficient since https://github.com/scala/scala/pull/3656 as it doesn't match
@@ -150,6 +163,7 @@ private[async] trait TransformUtils {
}
val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal")
+ val ThrowableClass = rootMirror.staticClass("java.lang.Throwable")
val Async_await = asyncBase.awaitMethod(c.universe)(c.macroApplication.symbol).ensuring(_ != NoSymbol)
val IllegalStateExceptionClass = rootMirror.staticClass("java.lang.IllegalStateException")
}
@@ -161,16 +175,26 @@ private[async] trait TransformUtils {
val labelDefs = t.collect {
case ld: LabelDef => ld.symbol
}.toSet
- t.exists {
+ val result = t.exists {
case rt: RefTree => rt.symbol != null && isLabel(rt.symbol) && !(labelDefs contains rt.symbol)
case _ => false
}
+ result
}
- private def isLabel(sym: Symbol): Boolean = {
+ def isLabel(sym: Symbol): Boolean = {
val LABEL = 1L << 17 // not in the public reflection API.
(internal.flags(sym).asInstanceOf[Long] & LABEL) != 0L
}
+ def symId(sym: Symbol): Int = {
+ val symtab = this.c.universe.asInstanceOf[reflect.internal.SymbolTable]
+ sym.asInstanceOf[symtab.Symbol].id
+ }
+ def substituteTrees(t: Tree, from: List[Symbol], to: List[Tree]): Tree = {
+ val symtab = this.c.universe.asInstanceOf[reflect.internal.SymbolTable]
+ val subst = new symtab.TreeSubstituter(from.asInstanceOf[List[symtab.Symbol]], to.asInstanceOf[List[symtab.Tree]])
+ subst.transform(t.asInstanceOf[symtab.Tree]).asInstanceOf[Tree]
+ }
/** Map a list of arguments to:
@@ -362,4 +386,102 @@ private[async] trait TransformUtils {
else withAnnotation(tp, Annotation(UncheckedBoundsClass.asType.toType, Nil, ListMap()))
}
// =====================================
+
+ /**
+ * Efficiently decorate each subtree within `t` with the result of `t exists isAwait`,
+ * and return a function that can be used on derived trees to efficiently test the
+ * same condition.
+ *
+ * If the derived tree contains synthetic wrapper trees, these will be recursed into
+ * in search of a sub tree that was decorated with the cached answer.
+ */
+ final def containsAwaitCached(t: Tree): Tree => Boolean = {
+ val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable]
+ def attachContainsAwait(t: Tree): Unit = {
+ val t1 = t.asInstanceOf[symtab.Tree]
+ t1.updateAttachment(ContainsAwait)
+ t1.removeAttachment[NoAwait.type]
+ }
+ def attachNoAwait(t: Tree): Unit = {
+ val t1 = t.asInstanceOf[symtab.Tree]
+ t1.updateAttachment(NoAwait)
+ }
+ object markContainsAwaitTraverser extends Traverser {
+ var stack: List[Tree] = Nil
+
+ override def traverse(tree: Tree): Unit = {
+ stack ::= tree
+ try {
+ if (isAwait(tree))
+ stack.foreach(attachContainsAwait)
+ else
+ attachNoAwait(tree)
+ super.traverse(tree)
+ } finally stack = stack.tail
+ }
+ }
+ markContainsAwaitTraverser.traverse(t)
+
+ (t: Tree) => {
+ val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable]
+ object traverser extends Traverser {
+ var containsAwait = false
+ override def traverse(tree: Tree): Unit = {
+ if (tree.asInstanceOf[symtab.Tree].hasAttachment[NoAwait.type])
+ ()
+ else if (tree.asInstanceOf[symtab.Tree].hasAttachment[ContainsAwait.type])
+ containsAwait = true
+ else super.traverse(tree)
+ }
+ }
+ traverser.traverse(t)
+ traverser.containsAwait
+ }
+ }
+
+ final def adjustTypeOfTranslatedPatternMatches(t: Tree, owner: Symbol): Tree = {
+ import definitions.UnitTpe
+ typingTransform(t, owner) {
+ (tree, api) =>
+ tree match {
+ case Block(stats, expr) =>
+ val stats1 = stats map api.recur
+ val expr1 = api.recur(expr)
+ if (expr1.tpe =:= UnitTpe)
+ internal.setType(treeCopy.Block(tree, stats1, expr1), UnitTpe)
+ else
+ treeCopy.Block(tree, stats1, expr1)
+ case If(cond, thenp, elsep) =>
+ val cond1 = api.recur(cond)
+ val thenp1 = api.recur(thenp)
+ val elsep1 = api.recur(elsep)
+ if (thenp1.tpe =:= definitions.UnitTpe && elsep.tpe =:= UnitTpe)
+ internal.setType(treeCopy.If(tree, cond1, thenp1, elsep1), UnitTpe)
+ else
+ treeCopy.If(tree, cond1, thenp1, elsep1)
+ case Apply(fun, args) if isLabel(fun.symbol) =>
+ internal.setType(treeCopy.Apply(tree, api.recur(fun), args map api.recur), UnitTpe)
+ case t => api.default(t)
+ }
+ }
+ }
+
+ final def mkMutableField(tpt: Type, name: TermName, init: Tree): List[Tree] = {
+ if (isPastTyper) {
+ // If we are running after the typer phase (ie being called from a compiler plugin)
+ // we have to create the trio of members manually.
+ val ACCESSOR = (1L << 27).asInstanceOf[FlagSet]
+ val STABLE = (1L << 22).asInstanceOf[FlagSet]
+ val field = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name + " ", TypeTree(tpt), init)
+ val getter = DefDef(Modifiers(ACCESSOR | STABLE), name, Nil, Nil, TypeTree(tpt), Select(This(tpnme.EMPTY), field.name))
+ val setter = DefDef(Modifiers(ACCESSOR), name + "_=", Nil, List(List(ValDef(NoMods, TermName("x"), TypeTree(tpt), EmptyTree))), TypeTree(definitions.UnitTpe), Assign(Select(This(tpnme.EMPTY), field.name), Ident(TermName("x"))))
+ field :: getter :: setter :: Nil
+ } else {
+ val result = ValDef(NoMods, name, TypeTree(tpt), init)
+ result :: Nil
+ }
+ }
}
+
+case object ContainsAwait
+case object NoAwait \ No newline at end of file