package dotty.tools.dotc.transform import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, TreeTransform, TreeTransformer} import dotty.tools.dotc.ast.{Trees, tpd} import dotty.tools.dotc.core.Contexts.Context import scala.collection.mutable.ListBuffer import dotty.tools.dotc.core._ import dotty.tools.dotc.core.Symbols.NoSymbol import scala.annotation.tailrec import Types._, Contexts._, Constants._, Names._, NameOps._, Flags._ import SymDenotations._, Symbols._, StdNames._, Annotations._, Trees._ import Decorators._ import Symbols._ import scala.Some import dotty.tools.dotc.transform.TreeTransforms.{NXTransformations, TransformerInfo, TreeTransform, TreeTransformer} import dotty.tools.dotc.core.Contexts.Context import scala.collection.mutable import dotty.tools.dotc.core.Names.Name import NameOps._ import dotty.tools.dotc.CompilationUnit import dotty.tools.dotc.util.Positions.{Position, Coord} import dotty.tools.dotc.util.Positions.NoPosition import dotty.tools.dotc.core.DenotTransformers.DenotTransformer import dotty.tools.dotc.core.Denotations.SingleDenotation import dotty.tools.dotc.transform.TailRec._ /** * A Tail Rec Transformer * * @author Erik Stenman, Iulian Dragos, * ported to dotty by Dmitry Petrashko * @version 1.1 * * What it does: *

* Finds method calls in tail position and replaces them with jumps.
* A call is in tail position if it is the last instruction to be
* executed in the body of a method. This is done by recursing over
* the trees that may contain calls in tail position (trees that can't
* contain such calls are not transformed). Fortunately, there are not
* many such trees.
*

*

* Self-recursive calls in tail position are replaced by jumps to a
* label at the beginning of the method. As the JVM provides no way to
* jump from one method to another, non-recursive calls in
* tail position are not optimized.
*

*

* A method call is self-recursive if it calls the current method and
* the method is final (otherwise, it could
* be a call to an overridden method in a subclass).
*
* Recursive calls on a different instance
* are optimized. Since 'this' is not a local variable, it is added as
* a label parameter.
*

*

* This phase has been moved before pattern matching to catch more
* of the common cases of tail-recursive functions. This means that
* more cases must be taken into account (like nested functions and
* pattern cases).
*

*

* If a method contains self-recursive calls, a label is added at
* the beginning of its body and the calls are replaced by jumps to
* that label.
*

*

*
* In scalac, if the method had type parameters, the call had to pass the same
* parameters as type arguments. This is no longer the case in dotc.
* In scalac this phase is named tailCall, but it only provides optimization for
* self-recursive functions; that's why it is renamed to tailrec here.
*

*/
class TailRec extends TreeTransform with DenotTransformer {

  import tpd._

  /** Denotations are passed through unchanged here; the only denotation update
   *  this phase performs is the `installAfter` in `TailRecElimination.transform`.
   */
  override def transform(ref: SingleDenotation)(implicit ctx: Context): SingleDenotation = ref

  override def name: String = "tailrec"

  /** Prefix for the fresh names of generated tail-call labels. */
  final val labelPrefix = "tailLabel"

  /** Creates a fresh synthetic term symbol, owned by `method` and of type `tp`,
   *  that serves as the jump target of rewritten tail calls.
   */
  private def mkLabel(method: Symbol, tp: Type)(implicit c: Context): TermSymbol = {
    val name = c.freshName(labelPrefix)
    c.newSymbol(method, name.toTermName, Flags.Synthetic, tp)
  }

  /** Tries to rewrite the self-recursive tail calls in the body of `tree`.
   *
   *  Only effectively final, non-accessor methods that have a body are
   *  candidates. The body is speculatively transformed into a label whose first
   *  value parameter is `this`, followed by the method's own parameter lists.
   *  If no call was rewritten, the original body is kept. For methods annotated
   *  `@tailrec`, an error is reported when no rewrite happened or when the
   *  method is not a candidate at all.
   */
  override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
    tree match {
      case dd@DefDef(mods, name, tparams, vparamss0, tpt, rhs0)
        if (dd.symbol.isEffectivelyFinal) && !((dd.symbol is Flags.Accessor) || (rhs0 eq EmptyTree)) =>
        val mandatory = dd.symbol.hasAnnotation(defn.TailrecAnnotationClass)
        cpy.DefDef(tree, mods, name, tparams, vparamss0, tpt, rhs = {
          val owner = ctx.owner.enclosingClass
          val thisTpe = owner.thisType

          // The label's type is the method's type with an extra leading
          // parameter list `(THIS: thisTpe)` placed after any type parameters.
          val newType: Type = dd.tpe.widen match {
            case t: PolyType => PolyType(t.paramNames)(x => t.paramBounds,
              x => MethodType(List(nme.THIS), List(thisTpe), t.resultType))
            case t => MethodType(List(nme.THIS), List(thisTpe), t)
          }

          val label = mkLabel(dd.symbol, newType)

          var rewrote = false

          // Note: this could be split into two separate transforms (in different groups):
          // the first would collect which transformations and rewritings should be applied,
          // and the second would actually apply them.
          // For now this speculatively transforms the tree and throws away the result in many cases.
          val res = tpd.Closure(label, args => {
            val thiz = args.head.head
            val argMapping: Map[Symbol, Tree] = (vparamss0.flatten.map(_.symbol) zip args.tail.flatten).toMap
            val transformer = new TailRecElimination(dd.symbol, thiz, argMapping, owner, mandatory, label)
            val rhs = transformer.transform(rhs0)(ctx.withPhase(ctx.phase.next))
            rewrote = transformer.rewrote
            rhs
          }, tparams)

          if (rewrote) res
          else {
            if (mandatory)
              ctx.error("TailRec optimisation not applicable, method not tail recursive", dd.pos)
            rhs0
          }
        })

      // `tree` is statically a DefDef, so this case catches every remaining
      // `@tailrec` method, including accessors and methods without a body.
      case d: DefDef if d.symbol.hasAnnotation(defn.TailrecAnnotationClass) =>
        ctx.error("TailRec optimisation not applicable, method is neither private nor final so can be overridden", d.pos)
        d

      case _ => tree
    }
  }

  /** Tree map that rewrites self-recursive calls in tail position into
   *  applications of `label`, and remaps references to the method's
   *  parameters and to `this` onto the label's parameters.
   *
   *  @param method         the method whose body is transformed
   *  @param thiz           reference to the label's `this` parameter
   *  @param argMapping     maps each value parameter of `method` to the
   *                        corresponding label parameter reference
   *  @param enclosingClass the class enclosing `method`
   *  @param isMandatory    whether `method` is annotated `@tailrec`; if so,
   *                        failures are reported as errors instead of debug logs
   *  @param label          the label symbol tail calls are rewritten to
   */
  class TailRecElimination(method: Symbol, thiz: Tree, argMapping: Map[Symbol, Tree],
                           enclosingClass: Symbol, isMandatory: Boolean, label: Symbol) extends tpd.TreeMap {

    import tpd._

    /** Set once at least one call has been rewritten into a label application. */
    var rewrote = false

    private val defaultReason = "it contains a recursive call not in tail position"

    /** Whether the tree currently being transformed is in tail position. */
    private var ctx: TailContext = yesTailContext

    /** Transforms `tree` with the tail context temporarily set to `nctx`. */
    def transform(tree: Tree, nctx: TailContext)(implicit c: Context): Tree = {
      if (ctx == nctx) transform(tree)
      else {
        val saved = ctx
        ctx = nctx
        try transform(tree)
        finally this.ctx = saved
      }
    }

    def yesTailTransform(tree: Tree)(implicit c: Context): Tree = transform(tree, yesTailContext)

    def noTailTransform(tree: Tree)(implicit c: Context): Tree = transform(tree, noTailContext)

    def noTailTransforms(trees: List[Tree])(implicit c: Context): List[Tree] = trees map (noTailTransform)

    /** Transforms `tree` in the current tail context. Subtrees that can never be
     *  in tail position (conditions, guards, selectors, left operands of `&&`/`||`,
     *  local val right-hand sides, ...) are still visited in a non-tail context,
     *  so that references to method parameters are remapped everywhere.
     */
    override def transform(tree: Tree)(implicit c: Context): Tree = {

      /* A possibly polymorphic apply to be considered for tail call transformation. */
      def rewriteApply(tree: Tree, sym: Symbol): Tree = {

        /* Peels type and value applications off `t`; returns
         * (receiver, callee tree, value argument lists, type arguments, callee symbol).
         * A bare `Ident` of `method` has no explicit receiver, signalled by `EmptyTree`.
         */
        def receiverArgumentsAndSymbol(t: Tree, accArgs: List[List[Tree]] = Nil, accT: List[Tree] = Nil):
        (Tree, Tree, List[List[Tree]], List[Tree], Symbol) = t match {
          case TypeApply(fun, targs) if fun.symbol eq t.symbol => receiverArgumentsAndSymbol(fun, accArgs, targs)
          case Apply(fn, args) if fn.symbol == t.symbol => receiverArgumentsAndSymbol(fn, args :: accArgs, accT)
          case Select(qual, _) => (qual, t, accArgs, accT, t.symbol)
          case x: This => (x, x, accArgs, accT, x.symbol)
          case x: Ident if x.symbol eq method => (EmptyTree, x, accArgs, accT, x.symbol)
          case x => (x, x, accArgs, accT, x.symbol)
        }

        val (receiver, call, arguments, typeArguments, _) = receiverArgumentsAndSymbol(tree)
        val recv = noTailTransform(receiver)

        val targs = typeArguments.map(noTailTransform)
        val argumentss = arguments.map(noTailTransforms)

        val receiverIsSame = enclosingClass.typeRef.widen =:= recv.tpe.widen
        // NOTE(review): this appears to compare a Name with a Symbol via `eq`, which
        // never holds, so the "targeting a supertype" failure below never fires.
        // Since `fail` reports errors immediately, switching to `sym.name` could
        // reject valid @tailrec methods that call a same-named supertype method;
        // confirm the intended semantics before changing it.
        val receiverIsSuper = (method.name eq sym) && enclosingClass.typeRef.widen <:< recv.tpe.widen
        val receiverIsThis = recv.tpe.widen =:= thiz.tpe.widen

        val isRecursiveCall = (method eq sym)

        /* Rebuilds the call from its transformed parts without rewriting it. */
        def continue = {
          val fn = noTailTransform(call)
          val fnWithTargs = if (targs.nonEmpty) TypeApply(fn, targs) else fn
          if (fnWithTargs.tpe.widen.isParameterless) fnWithTargs
          else argumentss.foldLeft(fnWithTargs) {
            case (f, args) => Apply(f, args)
          }
        }

        /* Reports (or debug-logs) why the call is not rewritten, then keeps it. */
        def fail(reason: String) = {
          if (isMandatory) c.error(s"Cannot rewrite recursive call: $reason", tree.pos)
          else c.debuglog("Cannot rewrite recursive call at: " + tree.pos + " because: " + reason)
          continue
        }

        /* Replaces the call by an application of the label to `thisArg` followed by
         * the transformed argument lists. The label's type always begins with the
         * `(THIS)` parameter list, so `thisArg` is passed even for parameterless methods.
         */
        def rewriteTailCall(thisArg: Tree): Tree = {
          c.debuglog("Rewriting tail recursive call: " + tree.pos)
          rewrote = true
          val labelRef = if (targs.nonEmpty) TypeApply(Ident(label.termRef), targs) else Ident(label.termRef)
          argumentss.foldLeft(Apply(labelRef, List(thisArg))) {
            case (fn, args) => Apply(fn, args)
          }
        }

        if (isRecursiveCall) {
          if (ctx.tailPos) {
            if (recv eq EmptyTree) rewriteTailCall(thiz)
            else if (receiverIsSame || receiverIsThis) rewriteTailCall(recv)
            else fail("it changes type of 'this' on a polymorphic recursive call")
          }
          else fail(defaultReason)
        } else {
          if (receiverIsSuper) fail("it contains a recursive call targeting a supertype")
          else continue
        }
      }

      def rewriteTry(tree: Try): Tree = {

        /* Transforms the body of the handler closure in tail context `tctx`. */
        def transformHandlers(t: Tree, tctx: TailContext): Tree = t match {
          case EmptyTree => t
          case Block(List((d: DefDef)), cl@Closure(Nil, _, EmptyTree)) =>
            val newDef = cpy.DefDef(d, d.mods, d.name, d.tparams, d.vparamss, d.tpt, transform(d.rhs, tctx))
            Block(List(newDef), cl)
          case _ => assert(false, s"failed to deconstruct try handler ${t.show}"); ???
        }

        if (tree.finalizer eq EmptyTree) {
          // SI-1672 Catches are in tail position when there is no finalizer
          tpd.cpy.Try(tree,
            noTailTransform(tree.expr),
            transformHandlers(tree.handler, ctx),
            EmptyTree
          )
        }
        else {
          // With a finalizer nothing is in tail position, but the handler body
          // must still be visited so that parameter references get remapped.
          tpd.cpy.Try(tree,
            noTailTransform(tree.expr),
            transformHandlers(tree.handler, noTailContext),
            noTailTransform(tree.finalizer)
          )
        }
      }

      val res: Tree = tree match {
        case Block(stats, expr) =>
          tpd.cpy.Block(tree, noTailTransforms(stats), transform(expr))

        case t@CaseDef(pat, guard, body) =>
          // The guard is never in tail position, but may refer to method parameters.
          cpy.CaseDef(t, pat, noTailTransform(guard), transform(body))

        case If(cond, thenp, elsep) =>
          // Only the branches inherit the tail position; a recursive call in the
          // condition must not be turned into a jump.
          tpd.cpy.If(tree, noTailTransform(cond), transform(thenp), transform(elsep))

        case Match(selector, cases) =>
          tpd.cpy.Match(tree, noTailTransform(selector), transformSub(cases))

        case t: Try =>
          rewriteTry(t)

        case Apply(fun, args) if fun.symbol == defn.Boolean_or || fun.symbol == defn.Boolean_and =>
          // The right operand of `&&` / `||` is in tail position; the left operand,
          // which lives inside `fun`, is not, but still needs its references remapped.
          tpd.cpy.Apply(tree, noTailTransform(fun), transform(args))

        case Apply(fun, args) =>
          rewriteApply(tree, fun.symbol)

        case Alternative(_) | Bind(_, _) =>
          assert(false, "We should've never gotten inside a pattern")
          tree

        // NOTE(review): if `This` carries the qualifier's name rather than a symbol,
        // this `eq` never holds and explicit `this` is not replaced by `thiz` — confirm.
        case This(cls) if cls eq enclosingClass =>
          thiz

        case Select(qual, name) =>
          val sym = tree.symbol
          if (sym == method && ctx.tailPos) rewriteApply(tree, sym)
          else tpd.cpy.Select(tree, noTailTransform(qual), name)

        case vd@ValDef(vmods, vname, vtpt, vrhs) =>
          // A local val is never in tail position, but its right-hand side may
          // read method parameters, which must refer to the label's parameters.
          cpy.ValDef(vd, vmods, vname, vtpt, noTailTransform(vrhs))

        // NOTE(review): nested DefDefs (local methods, lambda bodies) are left
        // untouched, so parameter references inside them are not remapped.
        case EmptyTree | Super(_, _) | This(_) | Literal(_) | TypeTree(_) | DefDef(_, _, _, _, _, _) | TypeDef(_, _, _) =>
          tree

        case Ident(_) =>
          val sym = tree.symbol
          if (sym == method && ctx.tailPos) rewriteApply(tree, sym)
          else argMapping.get(sym) match {
            case Some(rewrite) => rewrite
            case None =>
              tree.tpe match {
                case TermRef(ThisType(`enclosingClass`), _) =>
                  if (sym.flags is Flags.Local) {
                    // Accessing a private[this] member through `thiz`:
                    // drop the Local flag so the selection is allowed.
                    val d = sym.denot
                    val newDenot = d.copySymDenotation(initFlags = sym.flags &~ Flags.Local)
                    newDenot.installAfter(TailRec.this)
                  }
                  Select(thiz, sym)
                case _ => tree
              }
          }

        case _ =>
          super.transform(tree)
      }

      res
    }
  }
}

object TailRec {

  /** Whether the tree being transformed is in tail position. */
  final class TailContext(val tailPos: Boolean) extends AnyVal

  final val noTailContext = new TailContext(false)
  final val yesTailContext = new TailContext(true)
}