diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/dotty/tools/backend/jvm/LabelDefs.scala | 109 |
1 files changed, 67 insertions, 42 deletions
diff --git a/src/dotty/tools/backend/jvm/LabelDefs.scala b/src/dotty/tools/backend/jvm/LabelDefs.scala index 18aec6b13..769dcdc36 100644 --- a/src/dotty/tools/backend/jvm/LabelDefs.scala +++ b/src/dotty/tools/backend/jvm/LabelDefs.scala @@ -1,10 +1,10 @@ package dotty.tools.backend.jvm -import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.ast.Trees.Thicket +import dotty.tools.dotc.ast.{Trees, tpd} import dotty.tools.dotc.core.Contexts.Context import dotty.tools.dotc.core.Types import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, TreeTransform, MiniPhase, MiniPhaseTransform} -import dotty.tools.dotc.ast.tpd import dotty.tools.dotc import dotty.tools.dotc.backend.jvm.DottyPrimitives import dotty.tools.dotc.core.Flags.FlagSet @@ -80,6 +80,8 @@ class LabelDefs extends MiniPhaseTransform { def phaseName: String = "labelDef" val queue = new ArrayBuffer[Tree]() + val beingAppended = new mutable.HashSet[Symbol]() + var labelLevel = 0 override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = { if (tree.symbol is Flags.Label) tree @@ -89,55 +91,72 @@ class LabelDefs extends MiniPhaseTransform { val labelCalls = collectLabelDefs.labelCalls var entryPoints = collectLabelDefs.parentLabelCalls var labelDefs = collectLabelDefs.labelDefs + var callCounts = collectLabelDefs.callCounts // make sure that for every label there's a single location it should return and single entry point // if theres already a location that it returns to that's a failure val disallowed = new mutable.HashMap[Symbol, Tree]() queue.sizeHint(labelCalls.size + entryPoints.size) - def moveLabels(entryPoint: Tree): List[Tree] = { - if ((entryPoint.symbol is Flags.Label) && labelDefs.contains(entryPoint.symbol)) { + + def putLabelDefsNearCallees = new TreeMap() { + + override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = { + tree match { + case t: Apply if (entryPoints.contains(t)) => + entryPoints = entryPoints - t + labelLevel = labelLevel + 1 + val r = Block(moveLabels(t), t) + labelLevel = labelLevel - 1 + if(labelLevel == 0) beingAppended.clear() + r + case _ => if (entryPoints.nonEmpty && labelDefs.nonEmpty) super.transform(tree) else tree + } + + } + } + + def moveLabels(entryPoint: Apply): List[Tree] = { + val entrySym = entryPoint.symbol + if ((entrySym is Flags.Label) && labelDefs.contains(entrySym)) { val visitedNow = new mutable.HashMap[Symbol, Tree]() val treesToAppend = new ArrayBuffer[Tree]() // order matters. parents should go first + treesToAppend += labelDefs(entrySym) queue.clear() var visited = 0 queue += entryPoint while (visited < queue.size) { val owningLabelDefSym = queue(visited).symbol - val owningLabelDef = labelDefs(owningLabelDefSym) - for (call <- labelCalls(owningLabelDefSym)) - if (disallowed.contains(call.symbol)) { - val oldCall = disallowed(call.symbol) - ctx.error(s"Multiple return locations for Label $oldCall and $call", call.symbol.pos) - } else { - if ((!visitedNow.contains(call.symbol)) && labelDefs.contains(call.symbol)) { - visitedNow.put(call.symbol, labelDefs(call.symbol)) - queue += call + for (call <- labelCalls(owningLabelDefSym)) { + val callSym = call.symbol + if (!beingAppended.contains(callSym)) { + if (disallowed.contains(callSym)) { + val oldCall = disallowed(callSym) + ctx.error(s"Multiple return locations for Label $oldCall and $call", callSym.pos) + } else { + if ((!visitedNow.contains(callSym)) && labelDefs.contains(callSym)) { + val defTree = labelDefs(callSym) + visitedNow.put(callSym, defTree) + val callCount = callCounts(callSym) + if (callCount > 1) { + if (!treesToAppend.contains(defTree)) { + treesToAppend += defTree + queue += call + + } + } else if (entryPoint.symbol ne callSym) entryPoints += call + } } } - if (!treesToAppend.contains(owningLabelDef)) { - treesToAppend += owningLabelDef } + visited += 1 } - disallowed ++= visitedNow - - treesToAppend.toList + beingAppended ++= treesToAppend.map(_.symbol) + treesToAppend.toList.map(putLabelDefsNearCallees.transform) } else Nil } - val putLabelDefsNearCallees = new TreeMap() { - - override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = { - tree match { - case t: Apply if (entryPoints.contains(t)) => - entryPoints = entryPoints - t - Block(moveLabels(t), t) - case _ => if (entryPoints.nonEmpty && labelDefs.nonEmpty) super.transform(tree) else tree - } - } - } - val res = cpy.DefDef(tree)(rhs = putLabelDefsNearCallees.transform(newRhs)) @@ -149,8 +168,7 @@ class LabelDefs extends MiniPhaseTransform { // label calls from this DefDef var parentLabelCalls: mutable.Set[Tree] = new mutable.HashSet[Tree]() - var isInsideLabel = false - var isInsideBlock = false + var callCounts: mutable.Map[Symbol, Int] = new mutable.HashMap[Symbol, Int]().withDefaultValue(0) def shouldMoveLabel = true @@ -158,6 +176,7 @@ class LabelDefs extends MiniPhaseTransform { val labelDefs = new mutable.HashMap[Symbol, Tree]() // owner -> all calls by this owner val labelCalls = new mutable.HashMap[Symbol, mutable.Set[Tree]]() + var owner: Symbol = null def clear = { parentLabelCalls.clear() @@ -167,28 +186,34 @@ class LabelDefs extends MiniPhaseTransform { override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = tree match { case t: Template => t - case t: Block => - val tmp = isInsideBlock - isInsideBlock = true + case t: Block => val r = super.transform(t) - isInsideBlock = tmp - r + r match { + case t: Block if t.stats.isEmpty => t.expr + case _ => r + } case t: DefDef => assert(t.symbol is Flags.Label) + val st = parentLabelCalls parentLabelCalls = new mutable.HashSet[Tree]() - val tmp = isInsideLabel - isInsideLabel = true + val symt = owner + owner = t.symbol + val r = super.transform(tree) - isInsideLabel = tmp + + owner = symt labelCalls(r.symbol) = parentLabelCalls parentLabelCalls = st + if(shouldMoveLabel) { - labelDefs(r.symbol) = r - EmptyTree + labelDefs(r.symbol) = r + EmptyTree } else r case t: Apply if t.symbol is Flags.Label => + val sym = t.symbol parentLabelCalls = parentLabelCalls + t + if(owner != sym) callCounts(sym) = callCounts(sym) + 1 super.transform(tree) case _ => super.transform(tree) |