aboutsummaryrefslogtreecommitdiff
path: root/compiler/src/dotty/tools/dotc/transform/TailRec.scala
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/src/dotty/tools/dotc/transform/TailRec.scala')
-rw-r--r--compiler/src/dotty/tools/dotc/transform/TailRec.scala384
1 files changed, 384 insertions, 0 deletions
diff --git a/compiler/src/dotty/tools/dotc/transform/TailRec.scala b/compiler/src/dotty/tools/dotc/transform/TailRec.scala
new file mode 100644
index 000000000..dc4454439
--- /dev/null
+++ b/compiler/src/dotty/tools/dotc/transform/TailRec.scala
@@ -0,0 +1,384 @@
+package dotty.tools.dotc.transform
+
+import dotty.tools.dotc.ast.Trees._
+import dotty.tools.dotc.ast.{TreeTypeMap, tpd}
+import dotty.tools.dotc.core.Contexts.Context
+import dotty.tools.dotc.core.Decorators._
+import dotty.tools.dotc.core.DenotTransformers.DenotTransformer
+import dotty.tools.dotc.core.Denotations.SingleDenotation
+import dotty.tools.dotc.core.Symbols._
+import dotty.tools.dotc.core.Types._
+import dotty.tools.dotc.core._
+import dotty.tools.dotc.transform.TailRec._
+import dotty.tools.dotc.transform.TreeTransforms.{MiniPhaseTransform, TransformerInfo}
+
+/**
+ * A Tail Rec Transformer
+ * @author Erik Stenman, Iulian Dragos,
+ * ported and heavily modified for dotty by Dmitry Petrashko
+ * @version 1.1
+ *
+ * What it does:
+ * <p>
+ * Finds method calls in tail-position and replaces them with jumps.
+ * A call is in a tail-position if it is the last instruction to be
+ * executed in the body of a method. This is done by recursing over
+ * the trees that may contain calls in tail-position (trees that can't
+ * contain such calls are not transformed). However, they are not that
+ * many.
+ * </p>
+ * <p>
+ * Self-recursive calls in tail-position are replaced by jumps to a
+ * label at the beginning of the method. As the JVM provides no way to
+ * jump from a method to another one, non-recursive calls in
+ * tail-position are not optimized.
+ * </p>
+ * <p>
+ * A method call is self-recursive if it calls the current method and
+ * the method is final (otherwise, it could
+ * be a call to an overridden method in a subclass).
+ *
+ * Recursive calls on a different instance
+ * are optimized. Since 'this' is not a local variable it s added as
+ * a label parameter.
+ * </p>
+ * <p>
+ * This phase has been moved before pattern matching to catch more
+ * of the common cases of tail recursive functions. This means that
+ * more cases should be taken into account (like nested function, and
+ * pattern cases).
+ * </p>
+ * <p>
+ * If a method contains self-recursive calls, a label is added to at
+ * the beginning of its body and the calls are replaced by jumps to
+ * that label.
+ * </p>
+ * <p>
+ *
+ * In scalac, If the method had type parameters, the call must contain same
+ * parameters as type arguments. This is no longer case in dotc.
+ * In scalac, this is named tailCall but it does only provide optimization for
+ * self recursive functions, that's why it's renamed to tailrec
+ * </p>
+ */
+class TailRec extends MiniPhaseTransform with DenotTransformer with FullParameterization { thisTransform =>
+
+ import dotty.tools.dotc.ast.tpd._
+
+ override def transform(ref: SingleDenotation)(implicit ctx: Context): SingleDenotation = ref
+
+ override def phaseName: String = "tailrec"
+ override def treeTransformPhase = thisTransform // TODO Make sure tailrec runs at next phase.
+
+ final val labelPrefix = "tailLabel"
+ final val labelFlags = Flags.Synthetic | Flags.Label
+
+ /** Symbols of methods that have @tailrec annotatios inside */
+ private val methodsWithInnerAnnots = new collection.mutable.HashSet[Symbol]()
+
+ override def transformUnit(tree: Tree)(implicit ctx: Context, info: TransformerInfo): Tree = {
+ methodsWithInnerAnnots.clear()
+ tree
+ }
+
+ override def transformTyped(tree: Typed)(implicit ctx: Context, info: TransformerInfo): Tree = {
+ if (tree.tpt.tpe.hasAnnotation(defn.TailrecAnnot))
+ methodsWithInnerAnnots += ctx.owner.enclosingMethod
+ tree
+ }
+
+ private def mkLabel(method: Symbol, abstractOverClass: Boolean)(implicit c: Context): TermSymbol = {
+ val name = c.freshName(labelPrefix)
+
+ if (method.owner.isClass)
+ c.newSymbol(method, name.toTermName, labelFlags, fullyParameterizedType(method.info, method.enclosingClass.asClass, abstractOverClass, liftThisType = false))
+ else c.newSymbol(method, name.toTermName, labelFlags, method.info)
+ }
+
+ override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
+ val sym = tree.symbol
+ tree match {
+ case dd@DefDef(name, tparams, vparamss0, tpt, _)
+ if (sym.isEffectivelyFinal) && !((sym is Flags.Accessor) || (dd.rhs eq EmptyTree) || (sym is Flags.Label)) =>
+ val mandatory = sym.hasAnnotation(defn.TailrecAnnot)
+ atGroupEnd { implicit ctx: Context =>
+
+ cpy.DefDef(dd)(rhs = {
+
+ val defIsTopLevel = sym.owner.isClass
+ val origMeth = sym
+ val label = mkLabel(sym, abstractOverClass = defIsTopLevel)
+ val owner = ctx.owner.enclosingClass.asClass
+ val thisTpe = owner.thisType.widen
+
+ var rewrote = false
+
+ // Note: this can be split in two separate transforms(in different groups),
+ // than first one will collect info about which transformations and rewritings should be applied
+ // and second one will actually apply,
+ // now this speculatively transforms tree and throws away result in many cases
+ val rhsSemiTransformed = {
+ val transformer = new TailRecElimination(origMeth, dd.tparams, owner, thisTpe, mandatory, label, abstractOverClass = defIsTopLevel)
+ val rhs = atGroupEnd(transformer.transform(dd.rhs)(_))
+ rewrote = transformer.rewrote
+ rhs
+ }
+
+ if (rewrote) {
+ val dummyDefDef = cpy.DefDef(tree)(rhs = rhsSemiTransformed)
+ if (tree.symbol.owner.isClass) {
+ val labelDef = fullyParameterizedDef(label, dummyDefDef, abstractOverClass = defIsTopLevel)
+ val call = forwarder(label, dd, abstractOverClass = defIsTopLevel, liftThisType = true)
+ Block(List(labelDef), call)
+ } else { // inner method. Tail recursion does not change `this`
+ val labelDef = polyDefDef(label, trefs => vrefss => {
+ val origMeth = tree.symbol
+ val origTParams = tree.tparams.map(_.symbol)
+ val origVParams = tree.vparamss.flatten map (_.symbol)
+ new TreeTypeMap(
+ typeMap = identity(_)
+ .substDealias(origTParams, trefs)
+ .subst(origVParams, vrefss.flatten.map(_.tpe)),
+ oldOwners = origMeth :: Nil,
+ newOwners = label :: Nil
+ ).transform(rhsSemiTransformed)
+ })
+ val callIntoLabel = (
+ if (dd.tparams.isEmpty) ref(label)
+ else ref(label).appliedToTypes(dd.tparams.map(_.tpe))
+ ).appliedToArgss(vparamss0.map(_.map(x=> ref(x.symbol))))
+ Block(List(labelDef), callIntoLabel)
+ }} else {
+ if (mandatory) ctx.error(
+ "TailRec optimisation not applicable, method not tail recursive",
+ // FIXME: want to report this error on `dd.namePos`, but
+ // because of extension method getting a weird pos, it is
+ // better to report on symbol so there's no overlap
+ sym.pos
+ )
+ dd.rhs
+ }
+ })
+ }
+ case d: DefDef if d.symbol.hasAnnotation(defn.TailrecAnnot) || methodsWithInnerAnnots.contains(d.symbol) =>
+ ctx.error("TailRec optimisation not applicable, method is neither private nor final so can be overridden", sym.pos)
+ d
+ case d if d.symbol.hasAnnotation(defn.TailrecAnnot) || methodsWithInnerAnnots.contains(d.symbol) =>
+ ctx.error("TailRec optimisation not applicable, not a method", sym.pos)
+ d
+ case _ => tree
+ }
+
+ }
+
+ class TailRecElimination(method: Symbol, methTparams: List[Tree], enclosingClass: Symbol, thisType: Type, isMandatory: Boolean, label: Symbol, abstractOverClass: Boolean) extends tpd.TreeMap {
+
+ import dotty.tools.dotc.ast.tpd._
+
+ var rewrote = false
+
+ private val defaultReason = "it contains a recursive call not in tail position"
+
+ private var ctx: TailContext = yesTailContext
+
+ /** Rewrite this tree to contain no tail recursive calls */
+ def transform(tree: Tree, nctx: TailContext)(implicit c: Context): Tree = {
+ if (ctx == nctx) transform(tree)
+ else {
+ val saved = ctx
+ ctx = nctx
+ try transform(tree)
+ finally this.ctx = saved
+ }
+ }
+
+ def yesTailTransform(tree: Tree)(implicit c: Context): Tree =
+ transform(tree, yesTailContext)
+
+ def noTailTransform(tree: Tree)(implicit c: Context): Tree =
+ transform(tree, noTailContext)
+
+ def noTailTransforms[Tr <: Tree](trees: List[Tr])(implicit c: Context): List[Tr] =
+ trees.map(noTailTransform).asInstanceOf[List[Tr]]
+
+ override def transform(tree: Tree)(implicit c: Context): Tree = {
+ /* A possibly polymorphic apply to be considered for tail call transformation. */
+ def rewriteApply(tree: Tree, sym: Symbol, required: Boolean = false): Tree = {
+ def receiverArgumentsAndSymbol(t: Tree, accArgs: List[List[Tree]] = Nil, accT: List[Tree] = Nil):
+ (Tree, Tree, List[List[Tree]], List[Tree], Symbol) = t match {
+ case TypeApply(fun, targs) if fun.symbol eq t.symbol => receiverArgumentsAndSymbol(fun, accArgs, targs)
+ case Apply(fn, args) if fn.symbol == t.symbol => receiverArgumentsAndSymbol(fn, args :: accArgs, accT)
+ case Select(qual, _) => (qual, t, accArgs, accT, t.symbol)
+ case x: This => (x, x, accArgs, accT, x.symbol)
+ case x: Ident if x.symbol eq method => (EmptyTree, x, accArgs, accT, x.symbol)
+ case x => (x, x, accArgs, accT, x.symbol)
+ }
+
+ val (prefix, call, arguments, typeArguments, symbol) = receiverArgumentsAndSymbol(tree)
+ val hasConformingTargs = (typeArguments zip methTparams).forall{x => x._1.tpe <:< x._2.tpe}
+ val recv = noTailTransform(prefix)
+
+ val targs = typeArguments.map(noTailTransform)
+ val argumentss = arguments.map(noTailTransforms)
+
+ val recvWiden = recv.tpe.widenDealias
+
+ val receiverIsSame = enclosingClass.typeRef.widenDealias =:= recvWiden
+ val receiverIsSuper = (method.name eq sym) && enclosingClass.typeRef.widen <:< recvWiden
+ val receiverIsThis = recv.tpe =:= thisType || recv.tpe.widen =:= thisType
+
+ val isRecursiveCall = (method eq sym)
+
+ def continue = {
+ val method = noTailTransform(call)
+ val methodWithTargs = if (targs.nonEmpty) TypeApply(method, targs) else method
+ if (methodWithTargs.tpe.widen.isParameterless) methodWithTargs
+ else argumentss.foldLeft(methodWithTargs) {
+ // case (method, args) => Apply(method, args) // Dotty deviation no auto-detupling yet. Interesting that one can do it in Scala2!
+ (method, args) => Apply(method, args)
+ }
+ }
+ def fail(reason: String) = {
+ if (isMandatory || required) c.error(s"Cannot rewrite recursive call: $reason", tree.pos)
+ else c.debuglog("Cannot rewrite recursive call at: " + tree.pos + " because: " + reason)
+ continue
+ }
+
+ def rewriteTailCall(recv: Tree): Tree = {
+ c.debuglog("Rewriting tail recursive call: " + tree.pos)
+ rewrote = true
+ val receiver = noTailTransform(recv)
+
+ val callTargs: List[tpd.Tree] =
+ if (abstractOverClass) {
+ val classTypeArgs = recv.tpe.baseTypeWithArgs(enclosingClass).argInfos
+ targs ::: classTypeArgs.map(x => ref(x.typeSymbol))
+ } else targs
+
+ val method = if (callTargs.nonEmpty) TypeApply(Ident(label.termRef), callTargs) else Ident(label.termRef)
+ val thisPassed =
+ if (this.method.owner.isClass)
+ method.appliedTo(receiver.ensureConforms(method.tpe.widen.firstParamTypes.head))
+ else method
+
+ val res =
+ if (thisPassed.tpe.widen.isParameterless) thisPassed
+ else argumentss.foldLeft(thisPassed) {
+ (met, ar) => Apply(met, ar) // Dotty deviation no auto-detupling yet.
+ }
+ res
+ }
+
+ if (isRecursiveCall) {
+ if (ctx.tailPos) {
+ if (!hasConformingTargs) fail("it changes type arguments on a polymorphic recursive call")
+ else if (recv eq EmptyTree) rewriteTailCall(This(enclosingClass.asClass))
+ else if (receiverIsSame || receiverIsThis) rewriteTailCall(recv)
+ else fail("it changes type of 'this' on a polymorphic recursive call")
+ }
+ else fail(defaultReason)
+ } else {
+ if (receiverIsSuper) fail("it contains a recursive call targeting a supertype")
+ else continue
+ }
+ }
+
+ def rewriteTry(tree: Try): Try = {
+ if (tree.finalizer eq EmptyTree) {
+ // SI-1672 Catches are in tail position when there is no finalizer
+ tpd.cpy.Try(tree)(
+ noTailTransform(tree.expr),
+ transformSub(tree.cases),
+ EmptyTree
+ )
+ }
+ else {
+ tpd.cpy.Try(tree)(
+ noTailTransform(tree.expr),
+ noTailTransforms(tree.cases),
+ noTailTransform(tree.finalizer)
+ )
+ }
+ }
+
+ val res: Tree = tree match {
+
+ case Ident(qual) =>
+ val sym = tree.symbol
+ if (sym == method && ctx.tailPos) rewriteApply(tree, sym)
+ else tree
+
+ case tree: Select =>
+ val sym = tree.symbol
+ if (sym == method && ctx.tailPos) rewriteApply(tree, sym)
+ else tpd.cpy.Select(tree)(noTailTransform(tree.qualifier), tree.name)
+
+ case Apply(fun, args) =>
+ val meth = fun.symbol
+ if (meth == defn.Boolean_|| || meth == defn.Boolean_&&)
+ tpd.cpy.Apply(tree)(fun, transform(args))
+ else
+ rewriteApply(tree, meth)
+
+ case tree@Block(stats, expr) =>
+ tpd.cpy.Block(tree)(
+ noTailTransforms(stats),
+ transform(expr)
+ )
+ case tree @ Typed(t: Apply, tpt) if tpt.tpe.hasAnnotation(defn.TailrecAnnot) =>
+ tpd.Typed(rewriteApply(t, t.fun.symbol, required = true), tpt)
+ case tree@If(cond, thenp, elsep) =>
+ tpd.cpy.If(tree)(
+ noTailTransform(cond),
+ transform(thenp),
+ transform(elsep)
+ )
+
+ case tree@CaseDef(_, _, body) =>
+ cpy.CaseDef(tree)(body = transform(body))
+
+ case tree@Match(selector, cases) =>
+ tpd.cpy.Match(tree)(
+ noTailTransform(selector),
+ transformSub(cases)
+ )
+
+ case tree: Try =>
+ rewriteTry(tree)
+
+ case Alternative(_) | Bind(_, _) =>
+ assert(false, "We should never have gotten inside a pattern")
+ tree
+
+ case t @ DefDef(_, _, _, _, _) =>
+ t // todo: could improve to handle DefDef's with a label flag calls to which are in tail position
+
+ case ValDef(_, _, _) | EmptyTree | Super(_, _) | This(_) |
+ Literal(_) | TypeTree() | TypeDef(_, _) =>
+ tree
+
+ case Return(expr, from) =>
+ tpd.cpy.Return(tree)(noTailTransform(expr), from)
+
+ case _ =>
+ super.transform(tree)
+ }
+
+ res
+ }
+ }
+
+ /** If references to original `target` from fully parameterized method `derived` should be
+ * rewired to some fully parameterized method, that method symbol,
+ * otherwise NoSymbol.
+ */
+ override protected def rewiredTarget(target: Symbol, derived: Symbol)(implicit ctx: Context): Symbol = NoSymbol
+}
+
+object TailRec {
+
+ final class TailContext(val tailPos: Boolean) extends AnyVal
+
+ final val noTailContext = new TailContext(false)
+ final val yesTailContext = new TailContext(true)
+}