diff options
Diffstat (limited to 'compiler/src/dotty/tools/dotc/transform/TailRec.scala')
-rw-r--r-- | compiler/src/dotty/tools/dotc/transform/TailRec.scala | 384 |
1 files changed, 384 insertions, 0 deletions
diff --git a/compiler/src/dotty/tools/dotc/transform/TailRec.scala b/compiler/src/dotty/tools/dotc/transform/TailRec.scala new file mode 100644 index 000000000..dc4454439 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/TailRec.scala @@ -0,0 +1,384 @@ +package dotty.tools.dotc.transform + +import dotty.tools.dotc.ast.Trees._ +import dotty.tools.dotc.ast.{TreeTypeMap, tpd} +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Decorators._ +import dotty.tools.dotc.core.DenotTransformers.DenotTransformer +import dotty.tools.dotc.core.Denotations.SingleDenotation +import dotty.tools.dotc.core.Symbols._ +import dotty.tools.dotc.core.Types._ +import dotty.tools.dotc.core._ +import dotty.tools.dotc.transform.TailRec._ +import dotty.tools.dotc.transform.TreeTransforms.{MiniPhaseTransform, TransformerInfo} + +/** + * A Tail Rec Transformer + * @author Erik Stenman, Iulian Dragos, + * ported and heavily modified for dotty by Dmitry Petrashko + * @version 1.1 + * + * What it does: + * <p> + * Finds method calls in tail-position and replaces them with jumps. + * A call is in a tail-position if it is the last instruction to be + * executed in the body of a method. This is done by recursing over + * the trees that may contain calls in tail-position (trees that can't + * contain such calls are not transformed). However, they are not that + * many. + * </p> + * <p> + * Self-recursive calls in tail-position are replaced by jumps to a + * label at the beginning of the method. As the JVM provides no way to + * jump from a method to another one, non-recursive calls in + * tail-position are not optimized. + * </p> + * <p> + * A method call is self-recursive if it calls the current method and + * the method is final (otherwise, it could + * be a call to an overridden method in a subclass). + * + * Recursive calls on a different instance + * are optimized. Since 'this' is not a local variable it s added as + * a label parameter. + * </p> + * <p> + * This phase has been moved before pattern matching to catch more + * of the common cases of tail recursive functions. This means that + * more cases should be taken into account (like nested function, and + * pattern cases). + * </p> + * <p> + * If a method contains self-recursive calls, a label is added to at + * the beginning of its body and the calls are replaced by jumps to + * that label. + * </p> + * <p> + * + * In scalac, If the method had type parameters, the call must contain same + * parameters as type arguments. This is no longer case in dotc. + * In scalac, this is named tailCall but it does only provide optimization for + * self recursive functions, that's why it's renamed to tailrec + * </p> + */ +class TailRec extends MiniPhaseTransform with DenotTransformer with FullParameterization { thisTransform => + + import dotty.tools.dotc.ast.tpd._ + + override def transform(ref: SingleDenotation)(implicit ctx: Context): SingleDenotation = ref + + override def phaseName: String = "tailrec" + override def treeTransformPhase = thisTransform // TODO Make sure tailrec runs at next phase. + + final val labelPrefix = "tailLabel" + final val labelFlags = Flags.Synthetic | Flags.Label + + /** Symbols of methods that have @tailrec annotatios inside */ + private val methodsWithInnerAnnots = new collection.mutable.HashSet[Symbol]() + + override def transformUnit(tree: Tree)(implicit ctx: Context, info: TransformerInfo): Tree = { + methodsWithInnerAnnots.clear() + tree + } + + override def transformTyped(tree: Typed)(implicit ctx: Context, info: TransformerInfo): Tree = { + if (tree.tpt.tpe.hasAnnotation(defn.TailrecAnnot)) + methodsWithInnerAnnots += ctx.owner.enclosingMethod + tree + } + + private def mkLabel(method: Symbol, abstractOverClass: Boolean)(implicit c: Context): TermSymbol = { + val name = c.freshName(labelPrefix) + + if (method.owner.isClass) + c.newSymbol(method, name.toTermName, labelFlags, fullyParameterizedType(method.info, method.enclosingClass.asClass, abstractOverClass, liftThisType = false)) + else c.newSymbol(method, name.toTermName, labelFlags, method.info) + } + + override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = { + val sym = tree.symbol + tree match { + case dd@DefDef(name, tparams, vparamss0, tpt, _) + if (sym.isEffectivelyFinal) && !((sym is Flags.Accessor) || (dd.rhs eq EmptyTree) || (sym is Flags.Label)) => + val mandatory = sym.hasAnnotation(defn.TailrecAnnot) + atGroupEnd { implicit ctx: Context => + + cpy.DefDef(dd)(rhs = { + + val defIsTopLevel = sym.owner.isClass + val origMeth = sym + val label = mkLabel(sym, abstractOverClass = defIsTopLevel) + val owner = ctx.owner.enclosingClass.asClass + val thisTpe = owner.thisType.widen + + var rewrote = false + + // Note: this can be split in two separate transforms(in different groups), + // than first one will collect info about which transformations and rewritings should be applied + // and second one will actually apply, + // now this speculatively transforms tree and throws away result in many cases + val rhsSemiTransformed = { + val transformer = new TailRecElimination(origMeth, dd.tparams, owner, thisTpe, mandatory, label, abstractOverClass = defIsTopLevel) + val rhs = atGroupEnd(transformer.transform(dd.rhs)(_)) + rewrote = transformer.rewrote + rhs + } + + if (rewrote) { + val dummyDefDef = cpy.DefDef(tree)(rhs = rhsSemiTransformed) + if (tree.symbol.owner.isClass) { + val labelDef = fullyParameterizedDef(label, dummyDefDef, abstractOverClass = defIsTopLevel) + val call = forwarder(label, dd, abstractOverClass = defIsTopLevel, liftThisType = true) + Block(List(labelDef), call) + } else { // inner method. Tail recursion does not change `this` + val labelDef = polyDefDef(label, trefs => vrefss => { + val origMeth = tree.symbol + val origTParams = tree.tparams.map(_.symbol) + val origVParams = tree.vparamss.flatten map (_.symbol) + new TreeTypeMap( + typeMap = identity(_) + .substDealias(origTParams, trefs) + .subst(origVParams, vrefss.flatten.map(_.tpe)), + oldOwners = origMeth :: Nil, + newOwners = label :: Nil + ).transform(rhsSemiTransformed) + }) + val callIntoLabel = ( + if (dd.tparams.isEmpty) ref(label) + else ref(label).appliedToTypes(dd.tparams.map(_.tpe)) + ).appliedToArgss(vparamss0.map(_.map(x=> ref(x.symbol)))) + Block(List(labelDef), callIntoLabel) + }} else { + if (mandatory) ctx.error( + "TailRec optimisation not applicable, method not tail recursive", + // FIXME: want to report this error on `dd.namePos`, but + // because of extension method getting a weird pos, it is + // better to report on symbol so there's no overlap + sym.pos + ) + dd.rhs + } + }) + } + case d: DefDef if d.symbol.hasAnnotation(defn.TailrecAnnot) || methodsWithInnerAnnots.contains(d.symbol) => + ctx.error("TailRec optimisation not applicable, method is neither private nor final so can be overridden", sym.pos) + d + case d if d.symbol.hasAnnotation(defn.TailrecAnnot) || methodsWithInnerAnnots.contains(d.symbol) => + ctx.error("TailRec optimisation not applicable, not a method", sym.pos) + d + case _ => tree + } + + } + + class TailRecElimination(method: Symbol, methTparams: List[Tree], enclosingClass: Symbol, thisType: Type, isMandatory: Boolean, label: Symbol, abstractOverClass: Boolean) extends tpd.TreeMap { + + import dotty.tools.dotc.ast.tpd._ + + var rewrote = false + + private val defaultReason = "it contains a recursive call not in tail position" + + private var ctx: TailContext = yesTailContext + + /** Rewrite this tree to contain no tail recursive calls */ + def transform(tree: Tree, nctx: TailContext)(implicit c: Context): Tree = { + if (ctx == nctx) transform(tree) + else { + val saved = ctx + ctx = nctx + try transform(tree) + finally this.ctx = saved + } + } + + def yesTailTransform(tree: Tree)(implicit c: Context): Tree = + transform(tree, yesTailContext) + + def noTailTransform(tree: Tree)(implicit c: Context): Tree = + transform(tree, noTailContext) + + def noTailTransforms[Tr <: Tree](trees: List[Tr])(implicit c: Context): List[Tr] = + trees.map(noTailTransform).asInstanceOf[List[Tr]] + + override def transform(tree: Tree)(implicit c: Context): Tree = { + /* A possibly polymorphic apply to be considered for tail call transformation. */ + def rewriteApply(tree: Tree, sym: Symbol, required: Boolean = false): Tree = { + def receiverArgumentsAndSymbol(t: Tree, accArgs: List[List[Tree]] = Nil, accT: List[Tree] = Nil): + (Tree, Tree, List[List[Tree]], List[Tree], Symbol) = t match { + case TypeApply(fun, targs) if fun.symbol eq t.symbol => receiverArgumentsAndSymbol(fun, accArgs, targs) + case Apply(fn, args) if fn.symbol == t.symbol => receiverArgumentsAndSymbol(fn, args :: accArgs, accT) + case Select(qual, _) => (qual, t, accArgs, accT, t.symbol) + case x: This => (x, x, accArgs, accT, x.symbol) + case x: Ident if x.symbol eq method => (EmptyTree, x, accArgs, accT, x.symbol) + case x => (x, x, accArgs, accT, x.symbol) + } + + val (prefix, call, arguments, typeArguments, symbol) = receiverArgumentsAndSymbol(tree) + val hasConformingTargs = (typeArguments zip methTparams).forall{x => x._1.tpe <:< x._2.tpe} + val recv = noTailTransform(prefix) + + val targs = typeArguments.map(noTailTransform) + val argumentss = arguments.map(noTailTransforms) + + val recvWiden = recv.tpe.widenDealias + + val receiverIsSame = enclosingClass.typeRef.widenDealias =:= recvWiden + val receiverIsSuper = (method.name eq sym) && enclosingClass.typeRef.widen <:< recvWiden + val receiverIsThis = recv.tpe =:= thisType || recv.tpe.widen =:= thisType + + val isRecursiveCall = (method eq sym) + + def continue = { + val method = noTailTransform(call) + val methodWithTargs = if (targs.nonEmpty) TypeApply(method, targs) else method + if (methodWithTargs.tpe.widen.isParameterless) methodWithTargs + else argumentss.foldLeft(methodWithTargs) { + // case (method, args) => Apply(method, args) // Dotty deviation no auto-detupling yet. Interesting that one can do it in Scala2! + (method, args) => Apply(method, args) + } + } + def fail(reason: String) = { + if (isMandatory || required) c.error(s"Cannot rewrite recursive call: $reason", tree.pos) + else c.debuglog("Cannot rewrite recursive call at: " + tree.pos + " because: " + reason) + continue + } + + def rewriteTailCall(recv: Tree): Tree = { + c.debuglog("Rewriting tail recursive call: " + tree.pos) + rewrote = true + val receiver = noTailTransform(recv) + + val callTargs: List[tpd.Tree] = + if (abstractOverClass) { + val classTypeArgs = recv.tpe.baseTypeWithArgs(enclosingClass).argInfos + targs ::: classTypeArgs.map(x => ref(x.typeSymbol)) + } else targs + + val method = if (callTargs.nonEmpty) TypeApply(Ident(label.termRef), callTargs) else Ident(label.termRef) + val thisPassed = + if (this.method.owner.isClass) + method.appliedTo(receiver.ensureConforms(method.tpe.widen.firstParamTypes.head)) + else method + + val res = + if (thisPassed.tpe.widen.isParameterless) thisPassed + else argumentss.foldLeft(thisPassed) { + (met, ar) => Apply(met, ar) // Dotty deviation no auto-detupling yet. + } + res + } + + if (isRecursiveCall) { + if (ctx.tailPos) { + if (!hasConformingTargs) fail("it changes type arguments on a polymorphic recursive call") + else if (recv eq EmptyTree) rewriteTailCall(This(enclosingClass.asClass)) + else if (receiverIsSame || receiverIsThis) rewriteTailCall(recv) + else fail("it changes type of 'this' on a polymorphic recursive call") + } + else fail(defaultReason) + } else { + if (receiverIsSuper) fail("it contains a recursive call targeting a supertype") + else continue + } + } + + def rewriteTry(tree: Try): Try = { + if (tree.finalizer eq EmptyTree) { + // SI-1672 Catches are in tail position when there is no finalizer + tpd.cpy.Try(tree)( + noTailTransform(tree.expr), + transformSub(tree.cases), + EmptyTree + ) + } + else { + tpd.cpy.Try(tree)( + noTailTransform(tree.expr), + noTailTransforms(tree.cases), + noTailTransform(tree.finalizer) + ) + } + } + + val res: Tree = tree match { + + case Ident(qual) => + val sym = tree.symbol + if (sym == method && ctx.tailPos) rewriteApply(tree, sym) + else tree + + case tree: Select => + val sym = tree.symbol + if (sym == method && ctx.tailPos) rewriteApply(tree, sym) + else tpd.cpy.Select(tree)(noTailTransform(tree.qualifier), tree.name) + + case Apply(fun, args) => + val meth = fun.symbol + if (meth == defn.Boolean_|| || meth == defn.Boolean_&&) + tpd.cpy.Apply(tree)(fun, transform(args)) + else + rewriteApply(tree, meth) + + case tree@Block(stats, expr) => + tpd.cpy.Block(tree)( + noTailTransforms(stats), + transform(expr) + ) + case tree @ Typed(t: Apply, tpt) if tpt.tpe.hasAnnotation(defn.TailrecAnnot) => + tpd.Typed(rewriteApply(t, t.fun.symbol, required = true), tpt) + case tree@If(cond, thenp, elsep) => + tpd.cpy.If(tree)( + noTailTransform(cond), + transform(thenp), + transform(elsep) + ) + + case tree@CaseDef(_, _, body) => + cpy.CaseDef(tree)(body = transform(body)) + + case tree@Match(selector, cases) => + tpd.cpy.Match(tree)( + noTailTransform(selector), + transformSub(cases) + ) + + case tree: Try => + rewriteTry(tree) + + case Alternative(_) | Bind(_, _) => + assert(false, "We should never have gotten inside a pattern") + tree + + case t @ DefDef(_, _, _, _, _) => + t // todo: could improve to handle DefDef's with a label flag calls to which are in tail position + + case ValDef(_, _, _) | EmptyTree | Super(_, _) | This(_) | + Literal(_) | TypeTree() | TypeDef(_, _) => + tree + + case Return(expr, from) => + tpd.cpy.Return(tree)(noTailTransform(expr), from) + + case _ => + super.transform(tree) + } + + res + } + } + + /** If references to original `target` from fully parameterized method `derived` should be + * rewired to some fully parameterized method, that method symbol, + * otherwise NoSymbol. + */ + override protected def rewiredTarget(target: Symbol, derived: Symbol)(implicit ctx: Context): Symbol = NoSymbol +} + +object TailRec { + + final class TailContext(val tailPos: Boolean) extends AnyVal + + final val noTailContext = new TailContext(false) + final val yesTailContext = new TailContext(true) +} |