diff options
author | Adriaan Moors <adriaan.moors@epfl.ch> | 2012-03-15 19:07:15 +0100 |
---|---|---|
committer | Adriaan Moors <adriaan.moors@epfl.ch> | 2012-03-20 19:53:50 +0100 |
commit | cd3d342032613e52e5917f3900f2461536a54e26 (patch) | |
tree | 9fa6a47822788ae5286798572aa605a784223c89 | |
parent | 3e0f24d2e78aa3d8dace0b2b253dfe7870b330fe (diff) | |
download | scala-cd3d342032613e52e5917f3900f2461536a54e26.tar.gz scala-cd3d342032613e52e5917f3900f2461536a54e26.tar.bz2 scala-cd3d342032613e52e5917f3900f2461536a54e26.zip |
[vpm] tailcalls support for jumpy vpm
-rw-r--r-- | src/compiler/scala/tools/nsc/transform/TailCalls.scala | 126 | ||||
-rw-r--r-- | src/compiler/scala/tools/nsc/typechecker/PatMatVirtualiser.scala | 5 |
2 files changed, 119 insertions, 12 deletions
diff --git a/src/compiler/scala/tools/nsc/transform/TailCalls.scala b/src/compiler/scala/tools/nsc/transform/TailCalls.scala index fdb5c7e52e..6ebecb02c6 100644 --- a/src/compiler/scala/tools/nsc/transform/TailCalls.scala +++ b/src/compiler/scala/tools/nsc/transform/TailCalls.scala @@ -36,6 +36,8 @@ abstract class TailCalls extends Transform { } } + private def hasSynthCaseSymbol(t: Tree) = (t.symbol ne null) && (t.symbol hasFlag (Flags.CASE | Flags.SYNTHETIC)) + /** * A Tail Call Transformer * @@ -87,10 +89,22 @@ abstract class TailCalls extends Transform { class TailCallElimination(unit: CompilationUnit) extends Transformer { private val defaultReason = "it contains a recursive call not in tail position" + /** Has the label been accessed? Then its symbol is in this set. */ + private val accessed = new collection.mutable.HashSet[Symbol]() + // `accessed` was stored as boolean in the current context -- this is no longer tenable + // with jumps to labels in tailpositions now considered in tailposition, + // a downstream context may access the label, and the upstream one will be none the wiser + // this is necessary because tail-calls may occur in places where syntactically they seem impossible + // (since we now consider jumps to labels that are in tailposition, such as matchEnd(x) {x}) + + class Context() { /** The current method */ var method: Symbol = NoSymbol + // symbols of label defs in this method that are in tail position + var tailLabels: Set[Symbol] = Set() + /** The current tail-call label */ var label: Symbol = NoSymbol @@ -104,24 +118,20 @@ abstract class TailCalls extends Transform { var failReason = defaultReason var failPos = method.pos - /** Has the label been accessed? */ - var accessed = false - def this(that: Context) = { this() this.method = that.method this.tparams = that.tparams this.tailPos = that.tailPos - this.accessed = that.accessed this.failPos = that.failPos this.label = that.label + this.tailLabels = that.tailLabels } def this(dd: DefDef) { this() this.method = dd.symbol this.tparams = dd.tparams map (_.symbol) this.tailPos = true - this.accessed = false this.failPos = dd.pos /** Create a new method symbol for the current method and store it in @@ -141,14 +151,14 @@ abstract class TailCalls extends Transform { def isEligible = method.isEffectivelyFinal // @tailrec annotation indicates mandatory transformation def isMandatory = method.hasAnnotation(TailrecClass) && !forMSIL - def isTransformed = isEligible && accessed + def isTransformed = isEligible && accessed(label) def tailrecFailure() = unit.error(failPos, "could not optimize @tailrec annotated " + method + ": " + failReason) def newThis(pos: Position) = method.newValue(nme.THIS, pos, SYNTHETIC) setInfo currentClass.typeOfThis override def toString(): String = ( "" + method.name + " tparams: " + tparams + " tailPos: " + tailPos + - " accessed: " + accessed + "\nLabel: " + label + "\nLabel type: " + label.info + " Label: " + label + " Label type: " + label.info ) } @@ -206,7 +216,7 @@ abstract class TailCalls extends Transform { def rewriteTailCall(recv: Tree): Tree = { debuglog("Rewriting tail recursive call: " + fun.pos.lineContent.trim) - ctx.accessed = true + accessed += ctx.label typedPos(fun.pos)(Apply(Ident(ctx.label), recv :: transformArgs)) } @@ -242,10 +252,16 @@ abstract class TailCalls extends Transform { unit.error(tree.pos, "@tailrec annotated method contains no recursive calls") } } - debuglog("Considering " + dd.name + " for tailcalls") + + // labels are local to a method, so only traverse the rhs of a defdef + val collectTailPosLabels = new TailPosLabelsTraverser + collectTailPosLabels traverse rhs0 + newCtx.tailLabels = collectTailPosLabels.tailLabels.toSet + + debuglog("Considering " + dd.name + " for tailcalls, with labels in tailpos: "+ newCtx.tailLabels) val newRHS = transform(rhs0, newCtx) - deriveDefDef(tree)(rhs => + deriveDefDef(tree){rhs => if (newCtx.isTransformed) { /** We have rewritten the tree, but there may be nested recursive calls remaining. * If @tailrec is given we need to fail those now. @@ -270,8 +286,22 @@ abstract class TailCalls extends Transform { newRHS } + } + + // a translated match + case Block(stats, expr) if stats forall hasSynthCaseSymbol => + // the assumption is once we encounter a case, the remainder of the block will consist of cases + // the prologue may be empty, usually it is the valdef that stores the scrut + val (prologue, cases) = stats span (s => !s.isInstanceOf[LabelDef]) + treeCopy.Block(tree, + noTailTransforms(prologue) ++ transformTrees(cases), + transform(expr) ) + // a translated casedef + case LabelDef(_, _, body) if hasSynthCaseSymbol(tree) => + deriveLabelDef(tree)(transform) + case Block(stats, expr) => treeCopy.Block(tree, noTailTransforms(stats), @@ -308,8 +338,18 @@ abstract class TailCalls extends Transform { case Apply(fun, args) => if (fun.symbol == Boolean_or || fun.symbol == Boolean_and) treeCopy.Apply(tree, fun, transformTrees(args)) - else - rewriteApply(fun, fun, Nil, args) + else if (fun.symbol.isLabel && args.nonEmpty && args.tail.isEmpty && ctx.tailLabels(fun.symbol)) { + // this is to detect tailcalls in translated matches + // it's a one-argument call to a label that is in a tailposition and that looks like label(x) {x} + // thus, the argument to the call is in tailposition and we don't need to jump to the label, tail jump instead + val saved = ctx.tailPos + ctx.tailPos = true + debuglog("in tailpos label: "+ args.head) + val res = transform(args.head) + ctx.tailPos = saved + if (res ne args.head) res // we tail-called -- TODO: shield from false-positives where we rewrite but don't tail-call + else rewriteApply(fun, fun, Nil, args) + } else rewriteApply(fun, fun, Nil, args) case Alternative(_) | Star(_) | Bind(_, _) => sys.error("We should've never gotten inside a pattern") @@ -320,4 +360,66 @@ abstract class TailCalls extends Transform { } } } + + // collect the LabelDefs (generated by the pattern matcher) in a DefDef that are in tail position + // the labels all look like: matchEnd(x) {x} + // then, in a forward jump `matchEnd(expr)`, `expr` is considered in tail position (and the matchEnd jump is replaced by the jump generated by expr) + class TailPosLabelsTraverser extends Traverser { + val tailLabels = new collection.mutable.ListBuffer[Symbol]() + + private var maybeTail: Boolean = true // since we start in the rhs of a DefDef + + def traverse(tree: Tree, maybeTailNew: Boolean): Unit = { + val saved = maybeTail + maybeTail = maybeTailNew + try traverse(tree) + finally maybeTail = saved + } + + def traverseNoTail(tree: Tree) = traverse(tree, false) + def traverseTreesNoTail(trees: List[Tree]) = trees foreach traverseNoTail + + override def traverse(tree: Tree) = tree match { + case LabelDef(_, List(arg), body@Ident(_)) if arg.symbol == body.symbol => // we're looking for label(x){x} in tail position, since that means `a` is in tail position in a call `label(a)` + if (maybeTail) tailLabels += tree.symbol + + // a translated casedef + case LabelDef(_, _, body) if hasSynthCaseSymbol(tree) => + traverse(body) + + // a translated match + case Block(stats, expr) if stats forall hasSynthCaseSymbol => + // the assumption is once we encounter a case, the remainder of the block will consist of cases + // the prologue may be empty, usually it is the valdef that stores the scrut + val (prologue, cases) = stats span (s => !s.isInstanceOf[LabelDef]) + traverseTreesNoTail(prologue) // selector (may be absent) + traverseTrees(cases) + traverse(expr) + + case CaseDef(pat, guard, body) => + traverse(body) + + case Match(selector, cases) => + traverseNoTail(selector) + traverseTrees(cases) + + case dd @ DefDef(_, _, _, _, _, _) => // we are run per-method + + case Block(stats, expr) => + traverseTreesNoTail(stats) + traverse(expr) + + case If(cond, thenp, elsep) => + traverse(thenp) + traverse(elsep) + + case Try(block, catches, finalizer) => + traverseNoTail(block) + traverseTreesNoTail(catches) + traverseNoTail(finalizer) + + case EmptyTree | Super(_, _) | This(_) | Select(_, _) | Ident(_) | Literal(_) | Function(_, _) | TypeTree() => + case _ => super.traverse(tree) + } + } } diff --git a/src/compiler/scala/tools/nsc/typechecker/PatMatVirtualiser.scala b/src/compiler/scala/tools/nsc/typechecker/PatMatVirtualiser.scala index 552c236bde..0422da54e0 100644 --- a/src/compiler/scala/tools/nsc/typechecker/PatMatVirtualiser.scala +++ b/src/compiler/scala/tools/nsc/typechecker/PatMatVirtualiser.scala @@ -1651,6 +1651,11 @@ class Foo(x: Other) { x._1 } // no error in this order LabelDef(nextCase, Nil, catchAllGen(casegen)(scrutRef)) } toList + // the generated block is taken apart in TailCalls under the following assumptions + // the assumption is once we encounter a case, the remainder of the block will consist of cases + // the prologue may be empty, usually it is the valdef that stores the scrut + // val (prologue, cases) = stats span (s => !s.isInstanceOf[LabelDef]) + val prologue = if(scrutSym ne NoSymbol) List(VAL(scrutSym) === scrut) else Nil Block( prologue ++ (cases map caseDef) ++ catchAll, |