aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/dotty/tools/backend/jvm/LabelDefs.scala109
-rw-r--r--tests/pos/Labels.scala17
2 files changed, 84 insertions, 42 deletions
diff --git a/src/dotty/tools/backend/jvm/LabelDefs.scala b/src/dotty/tools/backend/jvm/LabelDefs.scala
index 18aec6b13..769dcdc36 100644
--- a/src/dotty/tools/backend/jvm/LabelDefs.scala
+++ b/src/dotty/tools/backend/jvm/LabelDefs.scala
@@ -1,10 +1,10 @@
package dotty.tools.backend.jvm
-import dotty.tools.dotc.ast.tpd
+import dotty.tools.dotc.ast.Trees.Thicket
+import dotty.tools.dotc.ast.{Trees, tpd}
import dotty.tools.dotc.core.Contexts.Context
import dotty.tools.dotc.core.Types
import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, TreeTransform, MiniPhase, MiniPhaseTransform}
-import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc
import dotty.tools.dotc.backend.jvm.DottyPrimitives
import dotty.tools.dotc.core.Flags.FlagSet
@@ -80,6 +80,8 @@ class LabelDefs extends MiniPhaseTransform {
def phaseName: String = "labelDef"
val queue = new ArrayBuffer[Tree]()
+ val beingAppended = new mutable.HashSet[Symbol]()
+ var labelLevel = 0
override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
if (tree.symbol is Flags.Label) tree
@@ -89,55 +91,72 @@ class LabelDefs extends MiniPhaseTransform {
val labelCalls = collectLabelDefs.labelCalls
var entryPoints = collectLabelDefs.parentLabelCalls
var labelDefs = collectLabelDefs.labelDefs
+ var callCounts = collectLabelDefs.callCounts
// make sure that for every label there's a single location it should return and single entry point
// if theres already a location that it returns to that's a failure
val disallowed = new mutable.HashMap[Symbol, Tree]()
queue.sizeHint(labelCalls.size + entryPoints.size)
- def moveLabels(entryPoint: Tree): List[Tree] = {
- if ((entryPoint.symbol is Flags.Label) && labelDefs.contains(entryPoint.symbol)) {
+
+ def putLabelDefsNearCallees = new TreeMap() {
+
+ override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = {
+ tree match {
+ case t: Apply if (entryPoints.contains(t)) =>
+ entryPoints = entryPoints - t
+ labelLevel = labelLevel + 1
+ val r = Block(moveLabels(t), t)
+ labelLevel = labelLevel - 1
+ if(labelLevel == 0) beingAppended.clear()
+ r
+ case _ => if (entryPoints.nonEmpty && labelDefs.nonEmpty) super.transform(tree) else tree
+ }
+
+ }
+ }
+
+ def moveLabels(entryPoint: Apply): List[Tree] = {
+ val entrySym = entryPoint.symbol
+ if ((entrySym is Flags.Label) && labelDefs.contains(entrySym)) {
val visitedNow = new mutable.HashMap[Symbol, Tree]()
val treesToAppend = new ArrayBuffer[Tree]() // order matters. parents should go first
+ treesToAppend += labelDefs(entrySym)
queue.clear()
var visited = 0
queue += entryPoint
while (visited < queue.size) {
val owningLabelDefSym = queue(visited).symbol
- val owningLabelDef = labelDefs(owningLabelDefSym)
- for (call <- labelCalls(owningLabelDefSym))
- if (disallowed.contains(call.symbol)) {
- val oldCall = disallowed(call.symbol)
- ctx.error(s"Multiple return locations for Label $oldCall and $call", call.symbol.pos)
- } else {
- if ((!visitedNow.contains(call.symbol)) && labelDefs.contains(call.symbol)) {
- visitedNow.put(call.symbol, labelDefs(call.symbol))
- queue += call
+ for (call <- labelCalls(owningLabelDefSym)) {
+ val callSym = call.symbol
+ if (!beingAppended.contains(callSym)) {
+ if (disallowed.contains(callSym)) {
+ val oldCall = disallowed(callSym)
+ ctx.error(s"Multiple return locations for Label $oldCall and $call", callSym.pos)
+ } else {
+ if ((!visitedNow.contains(callSym)) && labelDefs.contains(callSym)) {
+ val defTree = labelDefs(callSym)
+ visitedNow.put(callSym, defTree)
+ val callCount = callCounts(callSym)
+ if (callCount > 1) {
+ if (!treesToAppend.contains(defTree)) {
+ treesToAppend += defTree
+ queue += call
+
+ }
+ } else if (entryPoint.symbol ne callSym) entryPoints += call
+ }
}
}
- if (!treesToAppend.contains(owningLabelDef)) {
- treesToAppend += owningLabelDef
}
+
visited += 1
}
- disallowed ++= visitedNow
-
- treesToAppend.toList
+ beingAppended ++= treesToAppend.map(_.symbol)
+ treesToAppend.toList.map(putLabelDefsNearCallees.transform)
} else Nil
}
- val putLabelDefsNearCallees = new TreeMap() {
-
- override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = {
- tree match {
- case t: Apply if (entryPoints.contains(t)) =>
- entryPoints = entryPoints - t
- Block(moveLabels(t), t)
- case _ => if (entryPoints.nonEmpty && labelDefs.nonEmpty) super.transform(tree) else tree
- }
- }
- }
-
val res = cpy.DefDef(tree)(rhs = putLabelDefsNearCallees.transform(newRhs))
@@ -149,8 +168,7 @@ class LabelDefs extends MiniPhaseTransform {
// label calls from this DefDef
var parentLabelCalls: mutable.Set[Tree] = new mutable.HashSet[Tree]()
- var isInsideLabel = false
- var isInsideBlock = false
+ var callCounts: mutable.Map[Symbol, Int] = new mutable.HashMap[Symbol, Int]().withDefaultValue(0)
def shouldMoveLabel = true
@@ -158,6 +176,7 @@ class LabelDefs extends MiniPhaseTransform {
val labelDefs = new mutable.HashMap[Symbol, Tree]()
// owner -> all calls by this owner
val labelCalls = new mutable.HashMap[Symbol, mutable.Set[Tree]]()
+ var owner: Symbol = null
def clear = {
parentLabelCalls.clear()
@@ -167,28 +186,34 @@ class LabelDefs extends MiniPhaseTransform {
override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = tree match {
case t: Template => t
- case t: Block =>
- val tmp = isInsideBlock
- isInsideBlock = true
+ case t: Block =>
val r = super.transform(t)
- isInsideBlock = tmp
- r
+ r match {
+ case t: Block if t.stats.isEmpty => t.expr
+ case _ => r
+ }
case t: DefDef =>
assert(t.symbol is Flags.Label)
+
val st = parentLabelCalls
parentLabelCalls = new mutable.HashSet[Tree]()
- val tmp = isInsideLabel
- isInsideLabel = true
+ val symt = owner
+ owner = t.symbol
+
val r = super.transform(tree)
- isInsideLabel = tmp
+
+ owner = symt
labelCalls(r.symbol) = parentLabelCalls
parentLabelCalls = st
+
if(shouldMoveLabel) {
- labelDefs(r.symbol) = r
- EmptyTree
+ labelDefs(r.symbol) = r
+ EmptyTree
} else r
case t: Apply if t.symbol is Flags.Label =>
+ val sym = t.symbol
parentLabelCalls = parentLabelCalls + t
+ if(owner != sym) callCounts(sym) = callCounts(sym) + 1
super.transform(tree)
case _ =>
super.transform(tree)
diff --git a/tests/pos/Labels.scala b/tests/pos/Labels.scala
index 4a84175af..d82287313 100644
--- a/tests/pos/Labels.scala
+++ b/tests/pos/Labels.scala
@@ -1,3 +1,7 @@
+import dotty.tools.dotc.ast.Trees.Thicket
+import dotty.tools.dotc.ast.tpd._
+
+
object Labels {
def main(args: Array[String]): Unit = {
var i = 10
@@ -18,4 +22,17 @@ object Labels {
case t@2 => println("two" + t)
case _ => println("default")
}
+
+ def flatten(trees: Tree): Int = {
+ trees match {
+ case Thicket(elems) =>
+ while (trees ne trees) {
+ }
+ case tree =>
+ 33
+ }
+ 55
+ }
+
+
}