summaryrefslogtreecommitdiff
path: root/src/compiler/scala/tools/nsc/transform/TailCalls.scala
diff options
context:
space:
mode:
authorPaul Phillips <paulp@improving.org>2011-01-10 07:21:56 +0000
committerPaul Phillips <paulp@improving.org>2011-01-10 07:21:56 +0000
commite05dfaeabf430dac8909ff9e5a5911b0c94101ae (patch)
tree397ce22c41c4c5b95390625ef73fcb70cef856b6 /src/compiler/scala/tools/nsc/transform/TailCalls.scala
parent58b5c24df8ba06c02b92723922bc21381fedeb93 (diff)
downloadscala-e05dfaeabf430dac8909ff9e5a5911b0c94101ae.tar.gz
scala-e05dfaeabf430dac8909ff9e5a5911b0c94101ae.tar.bz2
scala-e05dfaeabf430dac8909ff9e5a5911b0c94101ae.zip
A pretty severe bug in the recognition of tail ...
A pretty severe bug in the recognition of tail call elimination. It turns out that Tailcalls will perform "partial elimination" in situations such as: @annotation.tailrec final def f(x: Int): Int = f(f(x)) The outer call to f1 becomes a jump, but the inner call remains as it was. I implemented @tailrec under the impression that if the optimization had taken place, it had gone all the way. So this is now fixed with a direct examination of the rewritten tree. While I was in there I threw in some improved error reporting: the error positioning is now on the call which is not in tail position rather than the method declaration. Closes #4135, no review.
Diffstat (limited to 'src/compiler/scala/tools/nsc/transform/TailCalls.scala')
-rw-r--r--src/compiler/scala/tools/nsc/transform/TailCalls.scala227
1 files changed, 100 insertions, 127 deletions
diff --git a/src/compiler/scala/tools/nsc/transform/TailCalls.scala b/src/compiler/scala/tools/nsc/transform/TailCalls.scala
index c7a3e6a778..214248e1f2 100644
--- a/src/compiler/scala/tools/nsc/transform/TailCalls.scala
+++ b/src/compiler/scala/tools/nsc/transform/TailCalls.scala
@@ -13,13 +13,10 @@ import scala.tools.nsc.symtab.Flags
* @author Iulian Dragos
* @version 1.0
*/
-abstract class TailCalls extends Transform
- /* with JavaLogging() */ {
- // inherits abstract value `global' and class `Phase' from Transform
-
- import global._ // the global environment
- import definitions._ // standard classes and methods
- import typer.{typed, atOwner} // methods to type trees
+abstract class TailCalls extends Transform {
+ import global._ // the global environment
+ import definitions._ // standard classes and methods
+ import typer.{ typed, typedPos } // methods to type trees
val phaseName: String = "tailcalls"
@@ -87,6 +84,7 @@ abstract class TailCalls extends Transform
* </p>
*/
class TailCallElimination(unit: CompilationUnit) extends Transformer {
+ private val defaultReason = "it contains a recursive call not in tail position"
class Context {
/** The current method */
@@ -102,7 +100,8 @@ abstract class TailCalls extends Transform
var tailPos = false
/** The reason this method could not be optimized. */
- var tailrecFailReason = "reason indeterminate"
+ var failReason = defaultReason
+ var failPos: Position = null
/** Is the label accessed? */
var accessed = false
@@ -114,8 +113,11 @@ abstract class TailCalls extends Transform
this.tparams = that.tparams
this.tailPos = that.tailPos
this.accessed = that.accessed
+ this.failPos = that.failPos
}
+ def enclosingType = currentMethod.enclClass.typeOfThis
+
/** Create a new method symbol for the current method and store it in
* the label field.
*/
@@ -130,84 +132,82 @@ abstract class TailCalls extends Transform
)
}
- private def mkContext(that: Context) = new Context(that)
- private def mkContext(that: Context, tp: Boolean): Context = {
- val t = mkContext(that)
- t.tailPos = tp
+ private var ctx: Context = new Context()
+ private def noTailContext() = {
+ val t = new Context(ctx)
+ t.tailPos = false
t
}
- private var ctx: Context = new Context()
- private def enclosingType = ctx.currentMethod.enclClass.typeOfThis
-
/** Rewrite this tree to contain no tail recursive calls */
def transform(tree: Tree, nctx: Context): Tree = {
- val oldCtx = ctx
+ val saved = ctx
ctx = nctx
- val t = transform(tree)
- this.ctx = oldCtx
- t
+ try transform(tree)
+ finally this.ctx = saved
+ }
+
+ def noTailTransform(tree: Tree): Tree = transform(tree, noTailContext())
+ def noTailTransforms(trees: List[Tree]) = {
+ val nctx = noTailContext()
+ trees map (t => transform(t, nctx))
}
override def transform(tree: Tree): Tree = {
/** A possibly polymorphic apply to be considered for tail call transformation.
*/
def rewriteApply(target: Tree, fun: Tree, targs: List[Tree], args: List[Tree]) = {
- def receiver = fun match {
- case Select(qual, _) => Some(qual)
- case _ => None
+ val receiver: Tree = fun match {
+ case Select(qual, _) => qual
+ case _ => EmptyTree
}
- def receiverIsSame = receiver exists (enclosingType.widen =:= _.tpe.widen)
- def receiverIsSuper = receiver exists (enclosingType.widen <:< _.tpe.widen)
+ def receiverIsSame = ctx.enclosingType.widen =:= receiver.tpe.widen
+ def receiverIsSuper = ctx.enclosingType.widen <:< receiver.tpe.widen
def isRecursiveCall = ctx.currentMethod eq fun.symbol
def isMandatory = ctx.currentMethod hasAnnotation TailrecClass
def isEligible = ctx.currentMethod.isEffectivelyFinal
- def transformArgs = transformTrees(args, mkContext(ctx, false))
+ def transformArgs = noTailTransforms(args)
def matchesTypeArgs = ctx.tparams sameElements (targs map (_.tpe.typeSymbol))
- def defaultTree = treeCopy.Apply(tree, target, transformArgs)
/** Records failure reason in Context for reporting.
*/
def cannotRewrite(reason: String) = {
- if (isMandatory)
- ctx.tailrecFailReason = reason
+ ctx.failReason = reason
+ ctx.failPos = fun.pos
- defaultTree
+ treeCopy.Apply(tree, target, transformArgs)
}
def notRecursiveReason() =
if (receiverIsSuper) "it contains a recursive call targetting a supertype"
else "it contains a recursive call not in tail position"
- def rewriteTailCall(receiver: Tree, otherArgs: List[Tree]): Tree = {
+ def rewriteTailCall(recv: Tree, otherArgs: List[Tree]): Tree = {
log("Rewriting tail recursive method call at: " + fun.pos)
ctx.accessed = true
- typed { atPos(fun.pos)(Apply(Ident(ctx.label), receiver :: otherArgs)) }
+ typedPos(fun.pos)(Apply(Ident(ctx.label), recv :: otherArgs))
}
if (!isRecursiveCall) cannotRewrite(notRecursiveReason())
else if (!isEligible) cannotRewrite("it is neither private nor final so can be overridden")
- else if (!ctx.tailPos) cannotRewrite("it contains a recursive call not in tail position")
+ else if (!ctx.tailPos) cannotRewrite(defaultReason)
else if (!matchesTypeArgs) cannotRewrite("it is called recursively with different type arguments")
- else receiver match {
- case Some(qual) =>
- if (forMSIL) cannotRewrite("it cannot be optimized on MSIL")
- else if (!receiverIsSame) cannotRewrite("it changes type of 'this' on a polymorphic recursive call")
- else rewriteTailCall(qual, transformArgs)
- case _ => rewriteTailCall(This(currentClass), transformArgs)
- }
+ else if (receiver == EmptyTree) rewriteTailCall(This(currentClass), transformArgs)
+ else if (forMSIL) cannotRewrite("it cannot be optimized on MSIL")
+ else if (!receiverIsSame) cannotRewrite("it changes type of 'this' on a polymorphic recursive call")
+ else rewriteTailCall(receiver, transformArgs)
}
tree match {
-
case dd @ DefDef(mods, name, tparams, vparams, tpt, rhs) =>
log("Entering DefDef: " + name)
- val newCtx = mkContext(ctx)
+ val newCtx = new Context(ctx)
+ newCtx.failPos = dd.pos
newCtx.currentMethod = tree.symbol
newCtx.makeLabel()
val currentClassParam = tree.symbol.newSyntheticValueParam(currentClass.typeOfThis)
- newCtx.label.setInfo(MethodType(currentClassParam :: tree.symbol.tpe.params, tree.symbol.tpe.finalResultType))
+ newCtx.label setInfo MethodType(currentClassParam :: tree.symbol.tpe.params, tree.symbol.tpe.finalResultType)
newCtx.tailPos = true
val isEligible = newCtx.currentMethod.isEffectivelyFinal
@@ -219,120 +219,93 @@ abstract class TailCalls extends Transform
tree.symbol.tpe match {
case PolyType(tpes, restpe) =>
newCtx.tparams = tparams map (_.symbol)
- newCtx.label.setInfo(
- newCtx.label.tpe.substSym(tpes, tparams map (_.symbol)))
+ newCtx.label setInfo newCtx.label.tpe.substSym(tpes, tparams map (_.symbol))
case _ =>
}
}
+ val newRHS = transform(rhs, newCtx)
+ def tailrecFailure(pos: Position, reason: String) {
+ unit.error(pos, "could not optimize @tailrec annotated " + newCtx.currentMethod + ": " + reason)
+ }
- val t1 = treeCopy.DefDef(tree, mods, name, tparams, vparams, tpt, {
- val transformed = transform(rhs, newCtx)
-
- transformed match {
- case newRHS if isEligible && newCtx.accessed =>
- log("Rewrote def " + newCtx.currentMethod)
- val newThis = newCtx.currentMethod
- . newValue (tree.pos, nme.THIS)
- . setInfo (currentClass.typeOfThis)
- . setFlag (Flags.SYNTHETIC)
-
- typed(atPos(tree.pos)(Block(
- List(ValDef(newThis, This(currentClass))),
- LabelDef(newCtx.label, newThis :: (vparams.flatten map (_.symbol)), newRHS)
- )))
- case rhs =>
- if (isMandatory)
- unit.error(dd.pos, "could not optimize @tailrec annotated method: " + newCtx.tailrecFailReason)
-
- rhs
+ treeCopy.DefDef(tree, mods, name, tparams, vparams, tpt, {
+ if (isEligible && newCtx.accessed) {
+ /** We have rewritten the tree, but there may be nested recursive calls remaining.
+ * If @tailrec is given we need to fail those now.
+ */
+ if (isMandatory) {
+ for (t @ Apply(fn, _) <- newRHS ; if fn.symbol == newCtx.currentMethod)
+ tailrecFailure(t.pos, defaultReason)
+ }
+
+ val newThis = newCtx.currentMethod
+ . newValue (tree.pos, nme.THIS)
+ . setInfo (currentClass.typeOfThis)
+ . setFlag (Flags.SYNTHETIC)
+
+ typedPos(tree.pos)(Block(
+ List(ValDef(newThis, This(currentClass))),
+ LabelDef(newCtx.label, newThis :: (vparams.flatten map (_.symbol)), newRHS)
+ ))
}
- })
+ else {
+ if (isMandatory)
+ tailrecFailure(newCtx.failPos, newCtx.failReason)
- log("Leaving DefDef: " + name)
- t1
-
- case EmptyTree => tree
-
- case PackageDef(_, _) =>
- super.transform(tree)
-
- case ClassDef(_, name, _, _) =>
- log("Entering class " + name)
- val res = super.transform(tree)
- log("Leaving class " + name)
- res
-
- case ValDef(mods, name, tpt, rhs) => super.transform(tree)
- case LabelDef(name, params, rhs) => super.transform(tree)
-
- case Template(parents, self, body) =>
- super.transform(tree)
+ newRHS
+ }
+ })
case Block(stats, expr) =>
treeCopy.Block(tree,
- transformTrees(stats, mkContext(ctx, false)),
- transform(expr))
+ noTailTransforms(stats),
+ transform(expr)
+ )
case CaseDef(pat, guard, body) =>
- treeCopy.CaseDef(tree, pat, guard, transform(body))
-
- case Alternative(_) | Star(_) | Bind(_, _) =>
- throw new RuntimeException("We should've never gotten inside a pattern")
-
- case Function(vparams, body) =>
- tree
- //throw new RuntimeException("Anonymous function should not exist at this point. at: " + unit.position(tree.pos));
-
- case Assign(lhs, rhs) =>
- super.transform(tree)
+ treeCopy.CaseDef(tree,
+ pat,
+ guard,
+ transform(body)
+ )
case If(cond, thenp, elsep) =>
- treeCopy.If(tree, cond, transform(thenp), transform(elsep))
-
- case Match(selector, cases) => //super.transform(tree);
- treeCopy.Match(tree, transform(selector, mkContext(ctx, false)), transformTrees(cases).asInstanceOf[List[CaseDef]])
+ treeCopy.If(tree,
+ cond,
+ transform(thenp),
+ transform(elsep)
+ )
+
+ case Match(selector, cases) =>
+ treeCopy.Match(tree,
+ noTailTransform(selector),
+ transformTrees(cases).asInstanceOf[List[CaseDef]]
+ )
- case Return(expr) => super.transform(tree)
case Try(block, catches, finalizer) =>
// no calls inside a try are in tail position, but keep recursing for nested functions
- treeCopy.Try(tree, transform(block, mkContext(ctx, false)),
- transformTrees(catches, mkContext(ctx, false)).asInstanceOf[List[CaseDef]],
- transform(finalizer, mkContext(ctx, false)))
-
- case Throw(expr) => super.transform(tree)
- case New(tpt) => super.transform(tree)
- case Typed(expr, tpt) => super.transform(tree)
+ treeCopy.Try(tree,
+ noTailTransform(block),
+ noTailTransforms(catches).asInstanceOf[List[CaseDef]],
+ noTailTransform(finalizer)
+ )
case Apply(tapply @ TypeApply(fun, targs), vargs) =>
rewriteApply(tapply, fun, targs, vargs)
- case TypeApply(fun, args) =>
- super.transform(tree)
-
case Apply(fun, args) =>
if (fun.symbol == Boolean_or || fun.symbol == Boolean_and)
treeCopy.Apply(tree, fun, transformTrees(args))
else
rewriteApply(fun, fun, Nil, args)
- case Super(qual, mix) =>
- tree
- case This(qual) =>
- tree
- case Select(qualifier, selector) =>
- tree
- case Ident(name) =>
- tree
- case Literal(value) =>
- tree
- case TypeTree() =>
+ case Alternative(_) | Star(_) | Bind(_, _) =>
+ system.error("We should've never gotten inside a pattern")
+ case EmptyTree | Super(_, _) | This(_) | Select(_, _) | Ident(_) | Literal(_) | Function(_, _) | TypeTree() =>
tree
case _ =>
- tree
+ super.transform(tree)
}
}
-
- def transformTrees(trees: List[Tree], nctx: Context): List[Tree] =
- trees map ((tree) => transform(tree, nctx))
}
}