diff options
4 files changed, 123 insertions, 192 deletions
diff --git a/src/compiler/scala/tools/nsc/matching/Matrix.scala b/src/compiler/scala/tools/nsc/matching/Matrix.scala index 2532801d83..5648f97f84 100644 --- a/src/compiler/scala/tools/nsc/matching/Matrix.scala +++ b/src/compiler/scala/tools/nsc/matching/Matrix.scala @@ -88,21 +88,17 @@ trait Matrix extends MatrixAdditions { context: MatrixContext): Tree = { import context._ - // log("handlePattern: selector.tpe = " + selector.tpe) - // sets up top level match val matrixInit: MatrixInit = { val v = copyVar(selector, isChecked, selector.tpe, "temp") MatrixInit(List(v), cases, atPos(selector.pos)(MATCHERROR(v.ident))) } - - val matrix = new MatchMatrix(context) { lazy val data = matrixInit } - val rep = matrix.expansion // expands casedefs and assigns name - val mch = typer typed rep.toTree // executes algorithm, converts tree to DFA - val dfatree = typer typed Block(matrixInit.valDefs, mch) // packages into a code block + val matrix = new MatchMatrix(context) { lazy val data = matrixInit } + val mch = typer typed matrix.expansion.toTree + val dfatree = typer typed Block(matrix.data.valDefs, mch) // redundancy check - matrix.targets filter (_.isNotReached) foreach (cs => cunit.error(cs.body.pos, "unreachable code")) + matrix.targets filter (_.unreached) foreach (cs => cunit.error(cs.body.pos, "unreachable code")) // optimize performs squeezing and resets any remaining NO_EXHAUSTIVE tracing("handlePattern")(matrix optimize dfatree) } @@ -168,9 +164,9 @@ trait Matrix extends MatrixAdditions { val emptyPatternVarGroup = PatternVarGroup() class PatternVarGroup(val pvs: List[PatternVar]) { - def syms = pvs map (_.sym) + def syms = pvs map (_.sym) def valDefs = pvs map (_.valDef) - def idents = pvs map (_.ident) + def idents = pvs map (_.ident) def extractIndex(index: Int): (PatternVar, PatternVarGroup) = { val (t, ts) = self.extractIndex(pvs, index) diff --git a/src/compiler/scala/tools/nsc/matching/MatrixAdditions.scala b/src/compiler/scala/tools/nsc/matching/MatrixAdditions.scala index 3b481dd03e..e59d8c7858 100644 --- a/src/compiler/scala/tools/nsc/matching/MatrixAdditions.scala +++ b/src/compiler/scala/tools/nsc/matching/MatrixAdditions.scala @@ -12,8 +12,7 @@ import PartialFunction._ /** Traits which are mixed into MatchMatrix, but separated out as * (somewhat) independent components to keep them on the sidelines. */ -trait MatrixAdditions extends ast.TreeDSL -{ +trait MatrixAdditions extends ast.TreeDSL { self: ExplicitOuter with ParallelMatching => import global.{ typer => _, _ } @@ -30,50 +29,52 @@ trait MatrixAdditions extends ast.TreeDSL private val settings_squeeze = !settings.Ynosqueeze.value - def squeezedBlockPVs(pvs: List[PatternVar], exp: Tree): Tree = - squeezedBlock(pvs map (_.valDef), exp) + class RefTraverser(vd: ValDef) extends Traverser { + private val targetSymbol = vd.symbol + private var safeRefs = 0 + private var isSafe = true - /** Compresses multiple Blocks. */ - def mkBlock(stats: List[Tree], expr: Tree): Tree = expr match { - case Block(stats1, expr1) if stats.isEmpty => mkBlock(stats1, expr1) - case _ => Block(stats, expr) - } + def canDrop = isSafe && safeRefs == 0 + def canInline = isSafe && safeRefs == 1 - def squeezedBlock(vds: List[Tree], exp: Tree): Tree = - if (settings_squeeze) mkBlock(Nil, squeezedBlock1(vds, exp)) - else mkBlock(vds, exp) + override def traverse(tree: Tree): Unit = tree match { + case t: Ident if t.symbol eq targetSymbol => + // target symbol's owner should match currentOwner + if (targetSymbol.owner == currentOwner) safeRefs += 1 + else isSafe = false - private def squeezedBlock1(vds: List[Tree], exp: Tree): Tree = { - class RefTraverser(sym: Symbol) extends Traverser { - var nref, nsafeRef = 0 - override def traverse(tree: Tree) = tree match { - case t: Ident if t.symbol eq sym => - nref += 1 - if (sym.owner == currentOwner) // oldOwner should match currentOwner - nsafeRef += 1 - - case LabelDef(_, args, rhs) => - (args dropWhile(_.symbol ne sym)) match { - case Nil => - case _ => nref += 2 // cannot substitute this one - } - traverse(rhs) - case t if nref > 1 => // abort, no story to tell - case t => - super.traverse(t) - } - } + case LabelDef(_, params, rhs) => + if (params exists (_.symbol eq targetSymbol)) // cannot substitute this one + isSafe = false - class Subst(sym: Symbol, rhs: Tree) extends Transformer { - var stop = false - override def transform(tree: Tree) = tree match { - case t: Ident if t.symbol == sym => - stop = true - rhs - case _ => if (stop) tree else super.transform(tree) - } + traverse(rhs) + case _ if safeRefs > 1 => () + case _ => + super.traverse(tree) + } + } + class Subst(vd: ValDef) extends Transformer { + private var stop = false + override def transform(tree: Tree): Tree = tree match { + case t: Ident if t.symbol == vd.symbol => + stop = true + vd.rhs + case _ => + if (stop) tree + else super.transform(tree) } + } + + /** Compresses multiple Blocks. */ + private def combineBlocks(stats: List[Tree], expr: Tree): Tree = expr match { + case Block(stats1, expr1) if stats.isEmpty => combineBlocks(stats1, expr1) + case _ => Block(stats, expr) + } + def squeezedBlock(vds: List[Tree], exp: Tree): Tree = + if (settings_squeeze) combineBlocks(Nil, squeezedBlock1(vds, exp)) + else combineBlocks(vds, exp) + private def squeezedBlock1(vds: List[Tree], exp: Tree): Tree = { lazy val squeezedTail = squeezedBlock(vds.tail, exp) def default = squeezedTail match { case Block(vds2, exp2) => Block(vds.head :: vds2, exp2) @@ -83,17 +84,13 @@ trait MatrixAdditions extends ast.TreeDSL if (vds.isEmpty) exp else vds.head match { case vd: ValDef => - val sym = vd.symbol - val rt = new RefTraverser(sym) - rt.atOwner (owner) (rt traverse squeezedTail) - - rt.nref match { - case 0 => squeezedTail - case 1 if rt.nsafeRef == 1 => new Subst(sym, vd.rhs) transform squeezedTail - case _ => default - } - case _ => - default + val rt = new RefTraverser(vd) + rt.atOwner(owner)(rt traverse squeezedTail) + + if (rt.canDrop) squeezedTail + else if (rt.canInline) new Subst(vd) transform squeezedTail + else default + case _ => default } } } @@ -109,9 +106,7 @@ trait MatrixAdditions extends ast.TreeDSL object lxtt extends Transformer { override def transform(tree: Tree): Tree = tree match { case blck @ Block(vdefs, ld @ LabelDef(name, params, body)) => - def shouldInline(t: FinalState) = t.isReachedOnce && (t.labelSym eq ld.symbol) - - if (targets exists shouldInline) squeezedBlock(vdefs, body) + if (targets exists (_ shouldInline ld.symbol)) squeezedBlock(vdefs, body) else blck case t => @@ -166,7 +161,7 @@ trait MatrixAdditions extends ast.TreeDSL private def requiresExhaustive(sym: Symbol) = { (sym.isMutable) && // indicates that have not yet checked exhaustivity - !(sym hasFlag NO_EXHAUSTIVE) && // indicates @unchecked + !(sym hasFlag NO_EXHAUSTIVE) && // indicates @unchecked (sym.tpe.typeSymbol.isSealed) && !isValueClass(sym.tpe.typeSymbol) // make sure it's not a primitive, else (5: Byte) match { case 5 => ... } sees no Byte } diff --git a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala index 89e849d038..1b0265ce5d 100644 --- a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala +++ b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala @@ -25,7 +25,7 @@ trait ParallelMatching extends ast.TreeDSL self: ExplicitOuter => import global.{ typer => _, _ } - import definitions.{ AnyRefClass, IntClass, BooleanClass, getProductArgs, productProj } + import definitions.{ AnyRefClass, NothingClass, IntClass, BooleanClass, getProductArgs, productProj } import CODE._ import Types._ import Debug._ @@ -55,16 +55,8 @@ trait ParallelMatching extends ast.TreeDSL */ final def requestBody(bx: Int, subst: Bindings): Tree = { // shortcut - if (bx < 0) - return Apply(ID(shortCuts(-bx-1)), Nil) - - val pvgroup = PatternVarGroup.fromBindings(subst.get(), targets(bx).freeVars) - val target = targets(bx) - - // first time this bx is requested - might be bound elsewhere - if (target.isNotReached) target.createLabelBody(bx, pvgroup) - // call label "method" if possible - else target.getLabelBody(pvgroup) + if (bx < 0) Apply(ID(shortCuts(-bx-1)), Nil) + else targets(bx) labelBody subst } /** the injection here handles alternatives and unapply type tests */ @@ -107,9 +99,9 @@ trait ParallelMatching extends ast.TreeDSL def allValDefs = extraValDefs ::: accessorPatternVars.valDefs // tests - def isDefined = sym ne NoSymbol - def isSimple = tpe.isByte || tpe.isShort || tpe.isChar || tpe.isInt - def isCaseClass = tpe.typeSymbol.isCase + def isDefined = sym ne NoSymbol + def isSubrangeType = Set(ByteClass, ShortClass, CharClass, IntClass) contains tpe.typeSymbol + def isCaseClass = tpe.typeSymbol.isCase // sequences def seqType = tpe.widen baseType SeqClass @@ -144,7 +136,7 @@ trait ParallelMatching extends ast.TreeDSL // ... should compile to a switch. It doesn't because the scrut isn't Int/Char, but // that could be handle in an if/else since every pattern requires an Int. // More immediately, Byte and Short scruts should also work. - if (!scrut.isSimple) None + if (!scrut.isSubrangeType) None else { val (_lits, others) = ps span isSwitchableConst val lits = _lits collect { case x: LiteralPattern => x } @@ -162,7 +154,7 @@ trait ParallelMatching extends ast.TreeDSL override val ps: List[LiteralPattern], val defaultPattern: Option[Pattern] ) extends PatternMatch(scrut, ps) { - require(scrut.isSimple && (ps forall (_.isSwitchable))) + require(scrut.isSubrangeType && (ps forall (_.isSwitchable))) } case class PatternMatch(scrut: Scrutinee, ps: List[Pattern]) { @@ -217,17 +209,6 @@ trait ParallelMatching extends ast.TreeDSL def MixtureRule(scrut: Scrutinee, column: List[Pattern], rest: Rep): RuleApplication = PatternMatch(scrut, column) mkRule rest - /** - * Class encapsulating a guard expression in a pattern match: - * case ... if(tree) => ... - */ - case class Guard(tree: Tree) { - def isEmpty = tree.isEmpty - def duplicate = Guard(tree.duplicate) - override def toString() = if (isEmpty) "" else " // if %s" format tree - } - val NoGuard = Guard(EmptyTree) - /***** Rule Applications *****/ sealed abstract class RuleApplication { @@ -263,14 +244,13 @@ trait ParallelMatching extends ast.TreeDSL /** {case ... if guard => bx} else {guardedRest} */ /** VariableRule: The top-most rows has only variable (non-constructor) patterns. */ - case class VariableRule(subst: Bindings, guard: Guard, guardedRest: Rep, bx: Int) extends RuleApplication { + case class VariableRule(subst: Bindings, guard: Tree, guardedRest: Rep, bx: Int) extends RuleApplication { def pmatch: PatternMatch = impossible def rest: Rep = guardedRest - lazy val cond = if (guard.isEmpty) TRUE else guard.duplicate.tree + lazy val cond = if (guard.isEmpty) TRUE else guard lazy val success = requestBody(bx, subst) lazy val failure = guardedRest.toTree - lazy val pvgroup = PatternVarGroup.fromBindings(subst.get()) final def tree(): Tree = @@ -278,11 +258,7 @@ trait ParallelMatching extends ast.TreeDSL else squeezedBlock(pvgroup.valDefs, codegen) } - /** Mixture rule for all literal ints (and chars) i.e. hopefully a switch - * will be emitted on the JVM. - */ - class MixLiteralInts(val pmatch: PatternSwitch, val rest: Rep) extends RuleApplication - { + class MixLiteralInts(val pmatch: PatternSwitch, val rest: Rep) extends RuleApplication { val literals = pmatch.ps val defaultPattern = pmatch.defaultPattern @@ -312,42 +288,41 @@ trait ParallelMatching extends ast.TreeDSL case None => (Nil, Nil) case Some(Pattern(_, vs)) => (vs, List(rebindAll(rest rows literals.size, vs, scrut.sym))) } + // literalMap is a map from each literal to a list of row indices. // varMap is a list from each literal to a list of the defined vars. - lazy val (literalMap, varMap) = { - val tags = literals map (_.intValue) - val varMap = tags zip (literals map (_.deepBoundVariables)) - val litMap = - tags.zipWithIndex.reverse.foldLeft(IntMap.empty[List[Int]]) { - // we reverse before the fold so the list can be built with :: - case (map, (tag, index)) => map.updated(tag, index :: map.getOrElse(tag, Nil)) - } - - (litMap, varMap) + lazy val (litPairs, varMap) = ( + literals.zipWithIndex map { + case (lit, index) => + val tag = lit.intValue + (tag -> index, tag -> lit.deepBoundVariables) + } unzip + ) + def literalMap = litPairs groupBy (_._1) map { + case (k, vs) => (k, vs map (_._2)) } lazy val cases = for ((tag, indices) <- literalMap.toList) yield { val newRows = indices map (i => addDefaultVars(i)(rest rows i)) - val r = remake(newRows ::: defaultRows, includeScrut = false) - val r2 = make(r.tvars, r.rows map (x => x rebind bindVars(tag, x.subst))) + val r = remake(newRows ++ defaultRows, includeScrut = false) + val r2 = make(r.tvars, r.rows map (x => x rebind bindVars(tag, x.subst))) CASE(Literal(tag)) ==> r2.toTree } - lazy val defaultTree = remake(defaultRows, includeScrut = false).toTree - def casesWithDefault = cases ::: List(CASE(WILD(IntClass.tpe)) ==> defaultTree) + lazy val defaultTree = remake(defaultRows, includeScrut = false).toTree + def defaultCase = CASE(WILD(IntClass.tpe)) ==> defaultTree // cond/success/failure only used if there is exactly one case. - lazy val (cond, success) = cases match { - case List(CaseDef(lit, _, body)) => (scrut.id MEMBER_== lit, body) - } + lazy val cond = scrut.id MEMBER_== cases.head.pat + lazy val success = cases.head.body lazy val failure = defaultTree // only one case becomes if/else, otherwise match def tree() = if (cases.size == 1) codegen - else casted MATCH (casesWithDefault: _*) + else casted MATCH (cases :+ defaultCase: _*) } /** mixture rule for unapply pattern @@ -426,7 +401,7 @@ trait ParallelMatching extends ast.TreeDSL val (squeezePVs, pvs, rows) = doSuccess val srep = remake(rows, pvs).toTree - squeezedBlockPVs(squeezePVs, srep) + squeezedBlock(squeezePVs map (_.valDef), srep) } final def tree() = @@ -521,7 +496,7 @@ trait ParallelMatching extends ast.TreeDSL // wrapping in a null check on the scrutinee nullSafe(compareFn, FALSE)(scrut.id) } - lazy val success = squeezedBlockPVs(pvs, remake(successRows, pvs, hasStar).toTree) + lazy val success = squeezedBlock(pvs map (_.valDef), remake(successRows, pvs, hasStar).toTree) lazy val failure = remake(failRows).toTree final def tree(): Tree = codegen @@ -535,7 +510,7 @@ trait ParallelMatching extends ast.TreeDSL private lazy val rhs = decodedEqualsType(head.tpe) match { case SingleType(pre, sym) => REF(pre, sym) - case PseudoType(o) => o.duplicate + case PseudoType(o) => o } lazy val label = @@ -546,7 +521,7 @@ trait ParallelMatching extends ast.TreeDSL lazy val success = remake(List( rest.rows.head.insert2(List(NoPattern), head.boundVariables, scrut.sym), - Row(emptyPatterns(1 + rest.tvars.size), NoBinding, NoGuard, shortCut(label)) + Row(emptyPatterns(1 + rest.tvars.size), NoBinding, EmptyTree, shortCut(label)) )).toTree lazy val failure = LabelDef(label, Nil, labelBody) @@ -637,7 +612,7 @@ trait ParallelMatching extends ast.TreeDSL /*** States, Rows, Etc. ***/ - case class Row(pats: List[Pattern], subst: Bindings, guard: Guard, bx: Int) { + case class Row(pats: List[Pattern], subst: Bindings, guard: Tree, bx: Int) { private def nobindings = subst.get().isEmpty private def bindstr = if (nobindings) "" else pp(subst) // if (pats exists (p => !p.isDefault)) @@ -693,8 +668,8 @@ trait ParallelMatching extends ast.TreeDSL def vprint(vs: List[Any]) = if (vs.isEmpty) "" else ": %s".format(pp(vs)) def rprint(r: Row) = pp(r) def tprint(t: FinalState) = - if (t.freeVars.isEmpty) " ==> %s".format(pp(t.body)) - else " ==>\n %s".format(pp(t.freeVars -> t.body)) + if (t.params.isEmpty) " ==> %s".format(pp(t.body)) + else " ==>\n %s".format(pp(t.params -> t.body)) val xs = rows zip targets map { case (r,t) => rprint(r) + tprint(t) } val ppstr = pp(xs, newlines = true) @@ -703,70 +678,38 @@ trait ParallelMatching extends ast.TreeDSL } } - abstract class State { - def body: Tree - def freeVars: List[Symbol] - def isFinal: Boolean - } - - case class FinalState(bx: Int, body: Tree, freeVars: List[Symbol]) extends State { + case class FinalState(bx: Int, body: Tree, params: List[Symbol]) { private var referenceCount = 0 - private var _label: LabelDef = null - private var _labelSym: Symbol = null - - def labelSym = _labelSym - def label = _label - - // @bug: typer is not able to digest a body of type Nothing being assigned result type Unit - def bodyTpe = if (body.tpe.isNothing) body.tpe else matchResultType - def duplicate = body.duplicate setType bodyTpe - - def isFinal = true - def isLabellable = !cond(body) { case _: Literal => true } - def isNotReached = referenceCount == 0 - def isReachedOnce = referenceCount == 1 - def isReachedTwice = referenceCount > 1 - - // arguments to pass to this body%xx - def labelParamTypes = label.tpe.paramTypes - - def createLabelBody(index: Int, pvgroup: PatternVarGroup) = { - val args = pvgroup.syms - val vdefs = pvgroup.valDefs - - val name = "body%" + index - require(_labelSym == null) - referenceCount += 1 - - if (isLabellable) { - val mtype = MethodType(freeVars, bodyTpe) - _labelSym = owner.newLabel(body.pos, name) setInfo mtype - _label = typer typedLabelDef LabelDef(_labelSym, freeVars, body setType bodyTpe) - // TRACE("Creating index %d: mtype = %s".format(bx, mtype)) - // TRACE("[New label] def %s%s: %s = %s".format(name, pp(freeVars), bodyTpe, body)) - } - - ifLabellable(vdefs, squeezedBlock(vdefs, label)) + // typer is not able to digest a body of type Nothing being assigned result type Unit + private def caseResultType = if (body.tpe.isNothing) body.tpe else matchResultType + private lazy val label: LabelDef = body match { + case Literal(_) => null + case _ => + val symbol = owner.newLabel(body.pos, "body%" + bx) setInfo MethodType(params, caseResultType) + // typer typedLabelDef + LabelDef(symbol, params, body setType caseResultType) } - def getLabelBody(pvgroup: PatternVarGroup): Tree = { - val idents = pvgroup map (_.rhs) - val vdefs = pvgroup.valDefs + def unreached = referenceCount == 0 + def shouldInline(sym: Symbol) = referenceCount == 1 && label != null && label.symbol == sym + + def labelBody(subst: Bindings): Tree = { referenceCount += 1 + val pvgroup = PatternVarGroup.fromBindings(subst.get(), params) - ifLabellable(vdefs, ID(labelSym) APPLY (idents)) + if (referenceCount > 1 && label != null) + ID(label.symbol) APPLY (pvgroup map (_.rhs)) + else squeezedBlock(pvgroup.valDefs, + if (label != null) label + else body.duplicate setType caseResultType + ) } - - private def ifLabellable(vdefs: List[Tree], t: => Tree) = - if (isLabellable) t - else squeezedBlock(vdefs, duplicate) - - override def toString() = pp("Final%d%s".format(bx, pp(freeVars)) -> body) + override def toString() = pp("Final%d%s".format(bx, pp(params)) -> body) } case class Rep(val tvars: PatternVarGroup, val rows: List[Row]) { lazy val Row(pats, subst, guard, index) = rows.head - lazy val guardedRest = if (guard.isEmpty) NoRep else make(tvars, rows.tail) + lazy val guardedRest = if (guard.isEmpty) Rep(Nil, Nil) else make(tvars, rows.tail) lazy val (defaults, others) = pats span (_.isDefault) /** Sealed classes. */ @@ -818,12 +761,11 @@ trait ParallelMatching extends ast.TreeDSL else "Rep(%dx%d)%s%s".format(tvars.size, rows.size, ppn(tvars), ppn(rows)) } - val NoRep = Rep(Nil, Nil) /** Expands the patterns recursively. */ final def expand(roots: List[PatternVar], cases: List[CaseDef]) = tracing("Expanded")(ExpandedMatrix( for ((CaseDef(pat, guard, body), index) <- cases.zipWithIndex) yield { - def mkRow(ps: List[Tree]) = Row(toPats(ps), NoBinding, Guard(guard), index) + def mkRow(ps: List[Tree]) = Row(toPats(ps), NoBinding, guard, index) val pattern = Pattern(pat) val row = mkRow(pat match { diff --git a/src/compiler/scala/tools/nsc/matching/Patterns.scala b/src/compiler/scala/tools/nsc/matching/Patterns.scala index 40ff73e648..742ab32736 100644 --- a/src/compiler/scala/tools/nsc/matching/Patterns.scala +++ b/src/compiler/scala/tools/nsc/matching/Patterns.scala @@ -98,7 +98,15 @@ trait Patterns extends ast.TreeDSL { require (args.isEmpty) val Apply(select: Select, _) = tree - override def sufficientType = mkSingletonFromQualifier + override lazy val sufficientType = qualifier.tpe match { + case t: ThisType => singleType(t, sym) // this.X + case _ => + qualifier match { + case _: Apply => PseudoType(tree) + case _ => singleType(Pattern(qualifier).necessaryType, sym) + } + } + override def simplify(pv: PatternVar) = this.rebindToObjectCheck() override def description = backticked match { case Some(s) => "this." + s @@ -382,16 +390,6 @@ trait Patterns extends ast.TreeDSL { case Apply(f, Nil) => getPathSegments(f) case _ => Nil } - protected def mkSingletonFromQualifier = { - def pType = qualifier match { - case _: Apply => PseudoType(tree) - case _ => singleType(Pattern(qualifier).necessaryType, sym) - } - qualifier.tpe match { - case t: ThisType => singleType(t, sym) // this.X - case _ => pType - } - } } sealed trait NamePattern extends Pattern { |