summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJason Zaugg <jzaugg@gmail.com>2014-11-04 14:55:45 +1000
committerJason Zaugg <jzaugg@gmail.com>2014-11-04 14:55:45 +1000
commitb556b2fdcc7198bffe0ee90c5adc8c9eb3c29e36 (patch)
tree97aac3270399e4e12c71994099f84b9a048bdf30
parentd61d007d4852032fcfd339ab2c904a4de6836c4d (diff)
parentc6c58071a785af3a55e7e51339461e86c58ae876 (diff)
downloadscala-b556b2fdcc7198bffe0ee90c5adc8c9eb3c29e36.tar.gz
scala-b556b2fdcc7198bffe0ee90c5adc8c9eb3c29e36.tar.bz2
scala-b556b2fdcc7198bffe0ee90c5adc8c9eb3c29e36.zip
Merge pull request #4036 from retronym/topic/opt-tail-calls
SI-8893 Restore linear perf in TailCalls with nested matches
-rw-r--r--src/compiler/scala/tools/nsc/backend/icode/BasicBlocks.scala18
-rw-r--r--src/compiler/scala/tools/nsc/transform/TailCalls.scala41
-rw-r--r--src/reflect/scala/reflect/internal/Definitions.scala4
-rw-r--r--test/files/pos/t8893.scala129
-rw-r--r--test/files/run/t8893.scala40
-rw-r--r--test/files/run/t8893b.scala15
6 files changed, 227 insertions, 20 deletions
diff --git a/src/compiler/scala/tools/nsc/backend/icode/BasicBlocks.scala b/src/compiler/scala/tools/nsc/backend/icode/BasicBlocks.scala
index f9551697d2..ad1975ef23 100644
--- a/src/compiler/scala/tools/nsc/backend/icode/BasicBlocks.scala
+++ b/src/compiler/scala/tools/nsc/backend/icode/BasicBlocks.scala
@@ -300,14 +300,16 @@ trait BasicBlocks {
if (!closed)
instructionList = instructionList map (x => map.getOrElse(x, x))
else
- instrs.zipWithIndex collect {
- case (oldInstr, i) if map contains oldInstr =>
- // SI-6288 clone important here because `replaceInstruction` assigns
- // a position to `newInstr`. Without this, a single instruction can
- // be added twice, and the position last position assigned clobbers
- // all previous positions in other usages.
- val newInstr = map(oldInstr).clone()
- code.touched |= replaceInstruction(i, newInstr)
+ instrs.iterator.zipWithIndex foreach {
+ case (oldInstr, i) =>
+ if (map contains oldInstr) {
+ // SI-6288 clone important here because `replaceInstruction` assigns
+ // a position to `newInstr`. Without this, a single instruction can
+ // be added twice, and the position last position assigned clobbers
+ // all previous positions in other usages.
+ val newInstr = map(oldInstr).clone()
+ code.touched |= replaceInstruction(i, newInstr)
+ }
}
////////////////////// Emit //////////////////////
diff --git a/src/compiler/scala/tools/nsc/transform/TailCalls.scala b/src/compiler/scala/tools/nsc/transform/TailCalls.scala
index ef534f70fd..16ea3ea90f 100644
--- a/src/compiler/scala/tools/nsc/transform/TailCalls.scala
+++ b/src/compiler/scala/tools/nsc/transform/TailCalls.scala
@@ -129,6 +129,13 @@ abstract class TailCalls extends Transform {
}
override def toString = s"${method.name} tparams=$tparams tailPos=$tailPos label=$label label info=${label.info}"
+ final def noTailContext() = clonedTailContext(false)
+ final def yesTailContext() = clonedTailContext(true)
+ protected def clonedTailContext(tailPos: Boolean): TailContext = this match {
+ case _ if this.tailPos == tailPos => this
+ case clone: ClonedTailContext => clone.that.clonedTailContext(tailPos)
+ case _ => new ClonedTailContext(this, tailPos)
+ }
}
object EmptyTailContext extends TailContext {
@@ -174,7 +181,7 @@ abstract class TailCalls extends Transform {
}
def containsRecursiveCall(t: Tree) = t exists isRecursiveCall
}
- class ClonedTailContext(that: TailContext, override val tailPos: Boolean) extends TailContext {
+ class ClonedTailContext(val that: TailContext, override val tailPos: Boolean) extends TailContext {
def method = that.method
def tparams = that.tparams
def methodPos = that.methodPos
@@ -183,9 +190,6 @@ abstract class TailCalls extends Transform {
}
private var ctx: TailContext = EmptyTailContext
- private def noTailContext() = new ClonedTailContext(ctx, tailPos = false)
- private def yesTailContext() = new ClonedTailContext(ctx, tailPos = true)
-
override def transformUnit(unit: CompilationUnit): Unit = {
try {
@@ -206,16 +210,16 @@ abstract class TailCalls extends Transform {
finally this.ctx = saved
}
- def yesTailTransform(tree: Tree): Tree = transform(tree, yesTailContext())
- def noTailTransform(tree: Tree): Tree = transform(tree, noTailContext())
+ def yesTailTransform(tree: Tree): Tree = transform(tree, ctx.yesTailContext())
+ def noTailTransform(tree: Tree): Tree = transform(tree, ctx.noTailContext())
def noTailTransforms(trees: List[Tree]) = {
- val nctx = noTailContext()
- trees map (t => transform(t, nctx))
+ val nctx = ctx.noTailContext()
+ trees mapConserve (t => transform(t, nctx))
}
override def transform(tree: Tree): Tree = {
/* A possibly polymorphic apply to be considered for tail call transformation. */
- def rewriteApply(target: Tree, fun: Tree, targs: List[Tree], args: List[Tree]) = {
+ def rewriteApply(target: Tree, fun: Tree, targs: List[Tree], args: List[Tree], mustTransformArgs: Boolean = true) = {
val receiver: Tree = fun match {
case Select(qual, _) => qual
case _ => EmptyTree
@@ -223,7 +227,7 @@ abstract class TailCalls extends Transform {
def receiverIsSame = ctx.enclosingType.widen =:= receiver.tpe.widen
def receiverIsSuper = ctx.enclosingType.widen <:< receiver.tpe.widen
def isRecursiveCall = (ctx.method eq fun.symbol) && ctx.tailPos
- def transformArgs = noTailTransforms(args)
+ def transformArgs = if (mustTransformArgs) noTailTransforms(args) else args
def matchesTypeArgs = ctx.tparams sameElements (targs map (_.tpe.typeSymbol))
/* Records failure reason in Context for reporting.
@@ -265,6 +269,10 @@ abstract class TailCalls extends Transform {
!(sym.hasAccessorFlag || sym.isConstructor)
}
+ // intentionally shadowing imports from definitions for performance
+ val runDefinitions = currentRun.runDefinitions
+ import runDefinitions.{Boolean_or, Boolean_and}
+
tree match {
case ValDef(_, _, _, _) =>
if (tree.symbol.isLazy && tree.symbol.hasAnnotation(TailrecClass))
@@ -312,8 +320,13 @@ abstract class TailCalls extends Transform {
// the assumption is once we encounter a case, the remainder of the block will consist of cases
// the prologue may be empty, usually it is the valdef that stores the scrut
val (prologue, cases) = stats span (s => !s.isInstanceOf[LabelDef])
+ val transformedPrologue = noTailTransforms(prologue)
+ val transformedCases = transformTrees(cases)
+ val transformedStats =
+ if ((prologue eq transformedPrologue) && (cases eq transformedCases)) stats // allow reuse of `tree` if the subtransform was an identity
+ else transformedPrologue ++ transformedCases
treeCopy.Block(tree,
- noTailTransforms(prologue) ++ transformTrees(cases),
+ transformedStats,
transform(expr)
)
@@ -380,7 +393,7 @@ abstract class TailCalls extends Transform {
if (res ne arg)
treeCopy.Apply(tree, fun, res :: Nil)
else
- rewriteApply(fun, fun, Nil, args)
+ rewriteApply(fun, fun, Nil, args, mustTransformArgs = false)
case Apply(fun, args) =>
rewriteApply(fun, fun, Nil, args)
@@ -421,6 +434,10 @@ abstract class TailCalls extends Transform {
def traverseNoTail(tree: Tree) = traverse(tree, maybeTailNew = false)
def traverseTreesNoTail(trees: List[Tree]) = trees foreach traverseNoTail
+ // intentionally shadowing imports from definitions for performance
+ private val runDefinitions = currentRun.runDefinitions
+ import runDefinitions.{Boolean_or, Boolean_and}
+
override def traverse(tree: Tree) = tree match {
// we're looking for label(x){x} in tail position, since that means `a` is in tail position in a call `label(a)`
case LabelDef(_, List(arg), body@Ident(_)) if arg.symbol == body.symbol =>
diff --git a/src/reflect/scala/reflect/internal/Definitions.scala b/src/reflect/scala/reflect/internal/Definitions.scala
index 666a3a5e64..e2ee6a9076 100644
--- a/src/reflect/scala/reflect/internal/Definitions.scala
+++ b/src/reflect/scala/reflect/internal/Definitions.scala
@@ -1439,6 +1439,10 @@ trait Definitions extends api.StandardDefinitions {
lazy val isUnbox = unboxMethod.values.toSet[Symbol]
lazy val isBox = boxMethod.values.toSet[Symbol]
+ lazy val Boolean_and = definitions.Boolean_and
+ lazy val Boolean_or = definitions.Boolean_or
+ lazy val Boolean_not = definitions.Boolean_not
+
lazy val Option_apply = getMemberMethod(OptionModule, nme.apply)
lazy val List_apply = DefinitionsClass.this.List_apply
diff --git a/test/files/pos/t8893.scala b/test/files/pos/t8893.scala
new file mode 100644
index 0000000000..b87c8bdd3c
--- /dev/null
+++ b/test/files/pos/t8893.scala
@@ -0,0 +1,129 @@
+// Took > 10 minutes to run the tail call phase.
+object Test {
+ def a(): Option[String] = Some("a")
+
+ def main(args: Array[String]) {
+ a() match {
+ case Some(b1) =>
+ a() match {
+ case Some(b2) =>
+ a() match {
+ case Some(b3) =>
+ a() match {
+ case Some(b4) =>
+ a() match {
+ case Some(b5) =>
+ a() match {
+ case Some(b6) =>
+ a() match {
+ case Some(b7) =>
+ a() match {
+ case Some(b8) =>
+ a() match {
+ case Some(b9) =>
+ a() match {
+ case Some(b10) =>
+ a() match {
+ case Some(b11) =>
+ a() match {
+ case Some(b12) =>
+ a() match {
+ case Some(b13) =>
+ a() match {
+ case Some(b14) =>
+ a() match {
+ case Some(b15) =>
+ a() match {
+ case Some(b16) =>
+ a() match {
+ case Some(b17) =>
+ a() match {
+ case Some(b18) =>
+ a() match {
+ case Some(b19) =>
+ a() match {
+ case Some(b20) =>
+ a() match {
+ case Some(b21) =>
+ a() match {
+ case Some(b22) =>
+ a() match {
+ case Some(b23) =>
+ a() match {
+ case Some(b24) =>
+ a() match {
+ case Some(b25) =>
+ a() match {
+ case Some(b26) =>
+ a() match {
+ case Some(b27) =>
+ a() match {
+ case Some(b28) =>
+ a() match {
+ case Some(b29) =>
+ a() match {
+ case Some(b30) =>
+ println("yay")
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ }
+}
+
diff --git a/test/files/run/t8893.scala b/test/files/run/t8893.scala
new file mode 100644
index 0000000000..6fef8ae912
--- /dev/null
+++ b/test/files/run/t8893.scala
@@ -0,0 +1,40 @@
+import annotation.tailrec
+
+object Test {
+ def a(): Option[String] = Some("a")
+
+ def test1: Any = {
+ a() match {
+ case Some(b1) =>
+ a() match {
+ case Some(b2) =>
+ @tailrec
+ def tick(i: Int): Unit = if (i < 0) () else tick(i - 1)
+ tick(10000000) // testing that this doesn't SOE
+ case None => None
+ }
+ case None => None
+ }
+ }
+
+ def test2: Any = {
+ a() match {
+ case Some(b1) =>
+ a() match {
+ case Some(b2) =>
+ @tailrec
+ def tick(i: Int): Unit = if (i < 0) () else tick(i - 1)
+ tick(10000000) // testing that this doesn't SOE
+ case None => test1
+ }
+ case None =>
+ test1 // not a tail call
+ test1
+ }
+ }
+
+ def main(args: Array[String]) {
+ test1
+ test2
+ }
+}
diff --git a/test/files/run/t8893b.scala b/test/files/run/t8893b.scala
new file mode 100644
index 0000000000..19120871aa
--- /dev/null
+++ b/test/files/run/t8893b.scala
@@ -0,0 +1,15 @@
+// Testing that recursive calls in tail positions are replaced with
+// jumps, even though the method contains recursive calls outside
+// of the tail position.
+object Test {
+ def tick(i : Int): Unit =
+ if (i == 0) ()
+ else if (i == 42) {
+ tick(0) /*not in tail posiiton*/
+ tick(i - 1)
+ } else tick(i - 1)
+
+ def main(args: Array[String]): Unit = {
+ tick(1000000)
+ }
+}