diff options
-rw-r--r-- | src/compiler/scala/tools/nsc/matching/ParallelMatching.scala | 160 | ||||
-rw-r--r-- | src/compiler/scala/tools/nsc/matching/TransMatcher.scala | 11 |
2 files changed, 73 insertions, 98 deletions
diff --git a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala index 83b893a252..0f72778bc3 100644 --- a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala +++ b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala @@ -152,16 +152,12 @@ trait ParallelMatching extends ast.TreeDSL { import collection.mutable.{ HashMap, ListBuffer } class MatchMatrix(context: MatchMatrixContext, data: MatchMatrixInit) { import context._ - val MatchMatrixInit(roots, cases, failTree) = data - val labels = new HashMap[Int, Symbol]() - val shortCuts = new ListBuffer[Symbol]() - lazy val reached = new BitSet(targets.size) + val MatchMatrixInit(roots, cases, failTree) = data + val ExpandedMatrix(rows, targets) = expand(roots, cases) + val expansion: Rep = make(roots, rows) - private lazy val expandResult = expand(roots, cases) - lazy val targets: List[FinalState] = expandResult._2 - lazy val vss: List[List[Symbol]] = expandResult._3 - lazy val expansion: Rep = make(roots, expandResult._1) + val shortCuts = new ListBuffer[Symbol]() final def shortCut(theLabel: Symbol): Int = { shortCuts += theLabel @@ -187,8 +183,9 @@ trait ParallelMatching extends ast.TreeDSL { object lxtt extends Transformer { override def transform(tree: Tree): Tree = tree match { case blck @ Block(vdefs, ld @ LabelDef(name, params, body)) => - val bx = labelIndex(ld.symbol) - if (bx >= 0 && !isReachedTwice(bx)) squeezedBlock(vdefs, body) + def shouldInline(t: FinalState) = t.isReachedOnce && (t.label eq ld.symbol) + + if (targets exists shouldInline) squeezedBlock(vdefs, body) else blck case t => @@ -218,78 +215,50 @@ trait ParallelMatching extends ast.TreeDSL { returning[Tree](resetTraverser traverse _)(lxtt transform tree) } - final def isReached(bx: Int) = labels contains bx - final def markReachedTwice(bx: Int) { reached += bx } - - /** @pre bx < 0 || labelIndex(bx) != -1 */ - final def isReachedTwice(bx: Int) = (bx < 0) || reached(bx) - - /* @returns bx such that labels(bx) eq label, -1 if no such bx exists */ - final def labelIndex(label: Symbol) = labels find (_._2 eq label) map (_._1) getOrElse (-1) - /** first time bx is requested, a LabelDef is returned. next time, a jump. * the function takes care of binding */ final def requestBody(bx: Int, subst: Bindings): Tree = { - if (bx < 0) { // is shortcut - val jlabel = shortCuts(-bx-1) - return Apply(ID(jlabel), Nil) - } - if (!isReached(bx)) { // first time this bx is requested + val target = targets(bx) + lazy val FinalState(bindings, body, freeVars) = target + + // shortcut + def labelJump: Tree = Apply(ID(shortCuts(-bx-1)), Nil) + + // first time this bx is requested + def firstTime: Tree = { // might be bound elsewhere val (vsyms, vdefs) : (List[Symbol], List[Tree]) = List.unzip( - for (v <- vss(bx) ; substv <- subst(v)) yield + for (v <- freeVars ; substv <- subst(v)) yield (v, typedValDef(v, substv)) ) - val body = targets(bx).body // @bug: typer is not able to digest a body of type Nothing being assigned result type Unit - val tpe = if (body.tpe.isNothing) body.tpe else resultType - val newType = MethodType(vsyms, tpe) - val label = owner.newLabel(body.pos, "body%"+bx) setInfo newType - labels(bx) = label + val tpe = if (body.tpe.isNothing) body.tpe else resultType + target.setLabel(owner.newLabel(body.pos, "body%"+bx) setInfo MethodType(vsyms, tpe)) - return logAndReturn("requestBody(%d) first time: ".format(bx), squeezedBlock(vdefs, ( - if (isLabellable(body)) LabelDef(label, vsyms, body setType tpe) + squeezedBlock(vdefs, ( + if (isLabellable(body)) LabelDef(target.label, vsyms, body setType tpe) else body.duplicate setType tpe - ))) + )) } - // if some bx is not reached twice, its LabelDef is replaced with body itself - markReachedTwice(bx) - - val args = vss(bx) map subst flatten - val label = labels(bx) - val body = targets(bx).body - val fmls = label.tpe.paramTypes + def successiveTimes: Tree = { + val args = freeVars map subst flatten + val fmls = target.label.tpe.paramTypes + def vds = for (v <- freeVars ; substv <- subst(v)) yield typedValDef(v, substv) - def debugConsistencyFailure(): String = { - val xs = - ( for ((vs, i) <- vss.zipWithIndex) yield "vss(%d) = %s\nargs = %s".format(i, vs mkString ", ", args) ) ++ - ( for ((t, i) <- targets.zipWithIndex) yield "targets(%d) = %s".format(i, t) ) ++ - ( for ((i, l) <- labels) yield "labels(%d) = %s".format(i, l) ) ++ - ( for ((s, v) <- List("bx" -> bx, "label.tpe" -> label.tpe)) yield "%s = %s".format(s, v) ) - - xs mkString "\n" - } - // sanity checks: same length lists and args are conformant with formals - def isConsistent() = (fmls.length == args.length) && List.forall2(args, fmls)(_.tpe <:< _) - - if (!isConsistent()) { - val msg = ( - """Consistency problem compiling %s! - |Trying to call %s(%s) with arguments (%s)""" . - stripMargin.format(cunit.source, label, fmls, args) - ) - println(debugConsistencyFailure()) - // TRACE(debugConsistencyFailure()) - cunit.error(body.pos, msg) + if (isLabellable(body)) ID(target.label) APPLY (args) + else squeezedBlock(vds, body.duplicate setType resultType) } - def vds = for (v <- vss(bx) ; substv <- subst(v)) yield typedValDef(v, substv) + if (bx < 0) labelJump + else { + target.incrementReached - if (isLabellable(body)) ID(label) APPLY (args) - else squeezedBlock(vds, body.duplicate setType resultType) + if (target.isReachedOnce) firstTime + else successiveTimes + } } /** the injection here handles alternatives and unapply type tests */ @@ -367,21 +336,10 @@ trait ParallelMatching extends ast.TreeDSL { else Rep(tvars, rows).checkExhaustive } - override def toString() = { - val toPrint: List[(Any, Traversable[Any])] = ( - (vss.zipWithIndex map (_.swap)) ::: - List[(Any, Traversable[Any])]( - "labels" -> labels, - "targets" -> targets, - "reached" -> reached, - "shortCuts" -> shortCuts.toList - ) filterNot (_._2.isEmpty) - ) + override def toString() = "MatchMatrix(%s)".format(targets) - val strs = toPrint map { case (k, v) => " %s = %s\n".format(k, v) } - if (toPrint.isEmpty) "MatchMatrix()" - else "MatchMatrix(\n%s)".format(strs mkString) - } + /** Intended to be the DFA created from the match matrix. */ + class MatchAutomaton(matrix: MatchMatrix) { } /** * Encapsulates a symbol being matched on. @@ -1099,7 +1057,32 @@ trait ParallelMatching extends ast.TreeDSL { } } - case class FinalState(subst: Bindings, body: Tree) + object ExpandedMatrix { + def unapply(x: ExpandedMatrix) = Some(x.rows, x.targets) + def apply(rows: List[Row], targets: List[FinalState]) = new ExpandedMatrix(rows, targets) + } + class ExpandedMatrix(val rows: List[Row], val targets: List[FinalState]) + + abstract class State { + def bindings: Bindings + def body: Tree + def freeVars: List[Symbol] + def isFinal: Boolean + } + + case class FinalState(bindings: Bindings, body: Tree, freeVars: List[Symbol]) extends State { + private var referenceCount = 0 + private var _label: Symbol = null + def incrementReached: Unit = { referenceCount += 1 } + def setLabel(s: Symbol): Unit = { _label = s } + def label = _label + + def isFinal = true + def isNotReached = referenceCount == 0 + def isReachedOnce = referenceCount == 1 + def isReachedTwice = referenceCount > 1 + } + case class Combo(index: Int, sym: Symbol) { // is this combination covered by the given pattern? @@ -1143,10 +1126,12 @@ trait ParallelMatching extends ast.TreeDSL { case (i, syms) :: cs => for (s <- syms.toList; rest <- combine(cs)) yield Combo(i, s) :: rest } - /* internal representation is (tvars:List[Symbol], rows:List[Row]) - * - * tmp1 tmp_m - */ + /** Applying the rule will result in one of: + * + * VariableRule - if all patterns are default patterns + * MixtureRule - if one or more patterns are not default patterns + * ErrorRule - if there are no rows remaining + */ final def applyRule(): RuleApplication = { def dropIndex[T](xs: List[T], n: Int) = (xs take n) ::: (xs drop (n + 1)) @@ -1236,10 +1221,9 @@ trait ParallelMatching extends ast.TreeDSL { } val NoRep = Rep(Nil, Nil) - /** Expands the patterns recursively. */ - final def expand(roots: List[Symbol], cases: List[Tree]): (List[Row], List[FinalState], List[List[Symbol]]) = { - val res = unzip3( + final def expand(roots: List[Symbol], cases: List[Tree]): ExpandedMatrix = { + val (rows, finals) = List.unzip( for ((CaseDef(pat, guard, body), index) <- cases.zipWithIndex) yield { def mkRow(ps: List[Tree]) = Row(toPats(ps), NoBinding, Guard(guard), index) @@ -1248,11 +1232,11 @@ trait ParallelMatching extends ast.TreeDSL { case Apply(fn, args) => mkRow(args) case WILD() => mkRow(getDummies(roots.length)) } - (rowForPat, FinalState(NoBinding, body), definedVars(pat)) + (rowForPat, FinalState(NoBinding, body, definedVars(pat))) } ) - res match { case (rows, finals, vars) => (rows flatMap (x => x), finals, vars) } + new ExpandedMatrix(rows flatMap (x => x), finals) } /** returns the condition in "if (cond) k1 else k2" diff --git a/src/compiler/scala/tools/nsc/matching/TransMatcher.scala b/src/compiler/scala/tools/nsc/matching/TransMatcher.scala index 0618608e5a..78896846a4 100644 --- a/src/compiler/scala/tools/nsc/matching/TransMatcher.scala +++ b/src/compiler/scala/tools/nsc/matching/TransMatcher.scala @@ -203,16 +203,7 @@ trait TransMatcher extends ast.TreeDSL with CompactTreePrinter { val dfatree = typer typed Block(vars, mch) // packages into a code block // redundancy check - for ((cs, bx) <- cases.zipWithIndex) { - // if (!matrix.isReached(bx)) { - // println("cases = %s".format(cases)) - // println("matrix = %s, rep = %s".format(matrix, rep)) - // println("dfatree = " + toCompactString(dfatree)) - // } - if (!matrix.isReached(bx)) - cunit.error(cs.body.pos, "unreachable code") - } - + matrix.targets filter (_.isNotReached) foreach (cs => cunit.error(cs.body.pos, "unreachable code")) // cleanup performs squeezing and resets any remaining TRANS_FLAGs matrix cleanup dfatree } |