From a90d1f01d603d9f00445ead48a87a051cd0ede15 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Fri, 31 May 2013 23:49:52 +0200 Subject: SI-6574 Support @tailrec for extension methods. Currently, when the body of an extension method is transplanted to the companion object, recursive calls point back to the original instance method. That changes during erasure, but this is too late for tail call analysis/elimination. This commit eagerly updates the recursive calls to point to the extension method in the companion. It also removes the @tailrec annotation from the original method. --- .../tools/nsc/transform/ExtensionMethods.scala | 39 +++++++++++++++++++--- test/files/neg/t6574.check | 7 ++++ test/files/neg/t6574.scala | 10 ++++++ test/files/pos/t6574.scala | 19 +++++++++++ test/files/run/t6574b.check | 1 + test/files/run/t6574b.scala | 7 ++++ 6 files changed, 79 insertions(+), 4 deletions(-) create mode 100644 test/files/neg/t6574.check create mode 100644 test/files/neg/t6574.scala create mode 100644 test/files/pos/t6574.scala create mode 100644 test/files/run/t6574b.check create mode 100644 test/files/run/t6574b.scala diff --git a/src/compiler/scala/tools/nsc/transform/ExtensionMethods.scala b/src/compiler/scala/tools/nsc/transform/ExtensionMethods.scala index 672d9d232a..56ec49e962 100644 --- a/src/compiler/scala/tools/nsc/transform/ExtensionMethods.scala +++ b/src/compiler/scala/tools/nsc/transform/ExtensionMethods.scala @@ -208,6 +208,7 @@ abstract class ExtensionMethods extends Transform with TypingTransformers { companion.moduleClass.newMethod(extensionName, origMeth.pos, origMeth.flags & ~OVERRIDE & ~PROTECTED | FINAL) setAnnotations origMeth.annotations ) + origMeth.removeAnnotation(TailrecClass) // it's on the extension method, now. companion.info.decls.enter(extensionMeth) } @@ -221,15 +222,16 @@ abstract class ExtensionMethods extends Transform with TypingTransformers { val extensionParams = allParameters(extensionMono) val extensionThis = gen.mkAttributedStableRef(thiz setPos extensionMeth.pos) - val extensionBody = ( - rhs + val extensionBody: Tree = { + val tree = rhs .substituteSymbols(origTpeParams, extensionTpeParams) .substituteSymbols(origParams, extensionParams) .substituteThis(origThis, extensionThis) .changeOwner(origMeth -> extensionMeth) - ) + new SubstututeRecursion(origMeth, extensionMeth, unit).transform(tree) + } - // Record the extension method ( FIXME: because... ? ) + // Record the extension method. Later, in `Extender#transformStats`, these will be added to the companion object. extensionDefs(companion) += atPos(tree.pos)(DefDef(extensionMeth, extensionBody)) // These three lines are assembling Foo.bar$extension[T1, T2, ...]($this) @@ -264,4 +266,33 @@ abstract class ExtensionMethods extends Transform with TypingTransformers { stat } } + + final class SubstututeRecursion(origMeth: Symbol, extensionMeth: Symbol, + unit: CompilationUnit) extends TypingTransformer(unit) { + override def transform(tree: Tree): Tree = tree match { + // SI-6574 Rewrite recursive calls against the extension method so they can + // be tail call optimized later. The tailcalls phases comes before + // erasure, which performs this translation more generally at all call + // sites. + // + // // Source + // class C[C] { def meth[M](a: A) = { { : C[C'] }.meth[M'] } } + // + // // Translation + // class C[C] { def meth[M](a: A) = { { : C[C'] }.meth[M'](a1) } } + // object C { def meth$extension[M, C](this$: C[C], a: A) + // = { meth$extension[M', C']({ : C[C'] })(a1) } } + case treeInfo.Applied(sel @ Select(qual, _), targs, argss) if sel.symbol == origMeth => + import gen.CODE._ + localTyper.typedPos(tree.pos) { + val allArgss = List(qual) :: argss + val origThis = extensionMeth.owner.companionClass + val baseType = qual.tpe.baseType(origThis) + val allTargs = targs.map(_.tpe) ::: baseType.typeArgs + val fun = gen.mkAttributedTypeApply(THIS(extensionMeth.owner), extensionMeth, allTargs) + allArgss.foldLeft(fun)(Apply(_, _)) + } + case _ => super.transform(tree) + } + } } diff --git a/test/files/neg/t6574.check b/test/files/neg/t6574.check new file mode 100644 index 0000000000..c67b4ed804 --- /dev/null +++ b/test/files/neg/t6574.check @@ -0,0 +1,7 @@ +t6574.scala:4: error: could not optimize @tailrec annotated method notTailPos$extension: it contains a recursive call not in tail position + println("tail") + ^ +t6574.scala:8: error: could not optimize @tailrec annotated method differentTypeArgs$extension: it is called recursively with different type arguments + {(); new Bad[String, Unit](0)}.differentTypeArgs + ^ +two errors found diff --git a/test/files/neg/t6574.scala b/test/files/neg/t6574.scala new file mode 100644 index 0000000000..bba97ad62e --- /dev/null +++ b/test/files/neg/t6574.scala @@ -0,0 +1,10 @@ +class Bad[X, Y](val v: Int) extends AnyVal { + @annotation.tailrec final def notTailPos[Z](a: Int)(b: String) { + this.notTailPos[Z](a)(b) + println("tail") + } + + @annotation.tailrec final def differentTypeArgs { + {(); new Bad[String, Unit](0)}.differentTypeArgs + } +} diff --git a/test/files/pos/t6574.scala b/test/files/pos/t6574.scala new file mode 100644 index 0000000000..59c1701eb4 --- /dev/null +++ b/test/files/pos/t6574.scala @@ -0,0 +1,19 @@ +class Bad[X, Y](val v: Int) extends AnyVal { + def vv = v + @annotation.tailrec final def foo[Z](a: Int)(b: String) { + this.foo[Z](a)(b) + } + + @annotation.tailrec final def differentReceiver { + {(); new Bad[X, Y](0)}.differentReceiver + } + + @annotation.tailrec final def dependent[Z](a: Int)(b: String): b.type = { + this.dependent[Z](a)(b) + } +} + +class HK[M[_]](val v: Int) extends AnyVal { + def hk[N[_]]: Unit = if (false) hk[M] else () +} + diff --git a/test/files/run/t6574b.check b/test/files/run/t6574b.check new file mode 100644 index 0000000000..e10fa4f810 --- /dev/null +++ b/test/files/run/t6574b.check @@ -0,0 +1 @@ +List(5, 4, 3, 2, 1) diff --git a/test/files/run/t6574b.scala b/test/files/run/t6574b.scala new file mode 100644 index 0000000000..df329a31ca --- /dev/null +++ b/test/files/run/t6574b.scala @@ -0,0 +1,7 @@ +object Test extends App { + implicit class AnyOps(val i: Int) extends AnyVal { + private def parentsOf(x: Int): List[Int] = if (x == 0) Nil else x :: parentsOf(x - 1) + def parents: List[Int] = parentsOf(i) + } + println((5).parents) +} -- cgit v1.2.3