diff options
-rw-r--r-- | src/dotty/tools/backend/jvm/LabelDefs.scala | 103 |
1 files changed, 58 insertions, 45 deletions
diff --git a/src/dotty/tools/backend/jvm/LabelDefs.scala b/src/dotty/tools/backend/jvm/LabelDefs.scala index 0e50e9366..18aec6b13 100644 --- a/src/dotty/tools/backend/jvm/LabelDefs.scala +++ b/src/dotty/tools/backend/jvm/LabelDefs.scala @@ -32,6 +32,7 @@ import java.lang.AssertionError import dotty.tools.dotc.util.Positions.Position import Decorators._ import tpd._ +import Flags._ import StdNames.nme /** @@ -80,54 +81,68 @@ class LabelDefs extends MiniPhaseTransform { val queue = new ArrayBuffer[Tree]() - - - override def transformBlock(tree: tpd.Block)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = { - collectLabelDefs.clear - val newStats = collectLabelDefs.transformStats(tree.stats) - val newExpr = collectLabelDefs.transform(tree.expr) - val labelCalls = collectLabelDefs.labelCalls - val entryPoints = collectLabelDefs.parentLabelCalls - val labelDefs = collectLabelDefs.labelDefs - - // make sure that for every label there's a single location it should return and single entry point - // if theres already a location that it returns to that's a failure - val disallowed = new mutable.HashMap[Symbol, Tree]() - queue.sizeHint(labelCalls.size + entryPoints.size) - def moveLabels(entryPoint: Tree): List[Tree] = { - if((entryPoint.symbol is Flags.Label) && labelDefs.contains(entryPoint.symbol)) { - val visitedNow = new mutable.HashMap[Symbol, Tree]() - val treesToAppend = new ArrayBuffer[Tree]() // order matters. parents should go first - queue.clear() - - var visited = 0 - queue += entryPoint - while (visited < queue.size) { - val owningLabelDefSym = queue(visited).symbol - val owningLabelDef = labelDefs(owningLabelDefSym) - for (call <- labelCalls(owningLabelDefSym)) - if (disallowed.contains(call.symbol)) { - val oldCall = disallowed(call.symbol) - ctx.error(s"Multiple return locations for Label $oldCall and $call", call.symbol.pos) - } else { - if ((!visitedNow.contains(call.symbol)) && labelDefs.contains(call.symbol)) { - val df = labelDefs(call.symbol) - visitedNow.put(call.symbol, labelDefs(call.symbol)) - queue += call + override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = { + if (tree.symbol is Flags.Label) tree + else { + collectLabelDefs.clear + val newRhs = collectLabelDefs.transform(tree.rhs) + val labelCalls = collectLabelDefs.labelCalls + var entryPoints = collectLabelDefs.parentLabelCalls + var labelDefs = collectLabelDefs.labelDefs + + // make sure that for every label there's a single location it should return and single entry point + // if theres already a location that it returns to that's a failure + val disallowed = new mutable.HashMap[Symbol, Tree]() + queue.sizeHint(labelCalls.size + entryPoints.size) + def moveLabels(entryPoint: Tree): List[Tree] = { + if ((entryPoint.symbol is Flags.Label) && labelDefs.contains(entryPoint.symbol)) { + val visitedNow = new mutable.HashMap[Symbol, Tree]() + val treesToAppend = new ArrayBuffer[Tree]() // order matters. parents should go first + queue.clear() + + var visited = 0 + queue += entryPoint + while (visited < queue.size) { + val owningLabelDefSym = queue(visited).symbol + val owningLabelDef = labelDefs(owningLabelDefSym) + for (call <- labelCalls(owningLabelDefSym)) + if (disallowed.contains(call.symbol)) { + val oldCall = disallowed(call.symbol) + ctx.error(s"Multiple return locations for Label $oldCall and $call", call.symbol.pos) + } else { + if ((!visitedNow.contains(call.symbol)) && labelDefs.contains(call.symbol)) { + visitedNow.put(call.symbol, labelDefs(call.symbol)) + queue += call + } } + if (!treesToAppend.contains(owningLabelDef)) { + treesToAppend += owningLabelDef } - if(!treesToAppend.contains(owningLabelDef)) - treesToAppend += owningLabelDef - visited += 1 + visited += 1 + } + disallowed ++= visitedNow + + treesToAppend.toList + } else Nil + } + + val putLabelDefsNearCallees = new TreeMap() { + + override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = { + tree match { + case t: Apply if (entryPoints.contains(t)) => + entryPoints = entryPoints - t + Block(moveLabels(t), t) + case _ => if (entryPoints.nonEmpty && labelDefs.nonEmpty) super.transform(tree) else tree + } } - disallowed ++= visitedNow + } - treesToAppend.toList - } else Nil - } - cpy.Block(tree)(entryPoints.flatMap(moveLabels).toList ++ newStats, newExpr) + val res = cpy.DefDef(tree)(rhs = putLabelDefsNearCallees.transform(newRhs)) + res + } } val collectLabelDefs = new TreeMap() { @@ -137,13 +152,12 @@ class LabelDefs extends MiniPhaseTransform { var isInsideLabel = false var isInsideBlock = false - def shouldMoveLabel = !isInsideBlock + def shouldMoveLabel = true // labelSymbol -> Defining tree val labelDefs = new mutable.HashMap[Symbol, Tree]() // owner -> all calls by this owner val labelCalls = new mutable.HashMap[Symbol, mutable.Set[Tree]]() - val labelCallCounts = new mutable.HashMap[Symbol, Int]() def clear = { parentLabelCalls.clear() @@ -175,7 +189,6 @@ class LabelDefs extends MiniPhaseTransform { } else r case t: Apply if t.symbol is Flags.Label => parentLabelCalls = parentLabelCalls + t - labelCallCounts.get(t.symbol) super.transform(tree) case _ => super.transform(tree) |