diff options
Diffstat (limited to 'src/compiler')
4 files changed, 88 insertions, 83 deletions
diff --git a/src/compiler/scala/tools/nsc/matching/Matrix.scala b/src/compiler/scala/tools/nsc/matching/Matrix.scala index 6d9f6a099b..e0d15a4c49 100644 --- a/src/compiler/scala/tools/nsc/matching/Matrix.scala +++ b/src/compiler/scala/tools/nsc/matching/Matrix.scala @@ -77,12 +77,6 @@ trait Matrix extends MatrixAdditions { states. Otherwise,the error state is used after its reference count has been incremented. **/ - case class MatrixInit( - roots: List[Symbol], - cases: List[CaseDef], - default: Tree - ) - case class MatrixContext( handleOuter: Tree => Tree, // for outer pointer typer: Typer, // a local typer @@ -95,14 +89,30 @@ trait Matrix extends MatrixAdditions { // TRANS_FLAG communicates there should be no exhaustiveness checking private def flags(checked: Boolean) = if (checked) Nil else List(TRANS_FLAG) - /** Every new variable allocated gets one of these. */ - class PatternVar(val lhs: Symbol, val rhs: Tree) { + case class MatrixInit( + roots: List[PatternVar], + cases: List[CaseDef], + default: Tree + ) { + def tvars = roots map (_.lhs) + def valDefs = roots map (_.valDef) + } + + /** Every temporary variable allocated is put in a PatternVar. + */ + class PatternVar(val lhs: Symbol, val rhs: Tree, val checked: Boolean) { + def sym = lhs lazy val ident = ID(lhs) lazy val valDef = typedValDef(lhs, rhs) override def toString() = "%s: %s = %s".format(lhs, lhs.info, rhs) } + /** Sets the rhs to EmptyTree, which makes the valDef ignored in Scrutinee. + */ + def specialVar(lhs: Symbol, checked: Boolean) = + new PatternVar(lhs, EmptyTree, checked) + /** Given a tree, creates a new synthetic variable of the same type * and assigns the tree to it. */ @@ -116,15 +126,17 @@ trait Matrix extends MatrixAdditions { val name = newName(root.pos, label) val sym = newVar(root.pos, tpe, flags(checked), name) - tracing("copy", new PatternVar(sym, root)) + tracing("copy", new PatternVar(sym, root, checked)) } - /** The rhs is expressed as a function of the lhs. */ + /** Creates a new synthetic variable of the specified type and + * assigns the result of f(symbol) to it. + */ def createVar(tpe: Type, f: Symbol => Tree, checked: Boolean) = { val lhs = newVar(owner.pos, tpe, flags(checked)) val rhs = f(lhs) - tracing("create", new PatternVar(lhs, rhs)) + tracing("create", new PatternVar(lhs, rhs, checked)) } private def newVar( diff --git a/src/compiler/scala/tools/nsc/matching/MatrixAdditions.scala b/src/compiler/scala/tools/nsc/matching/MatrixAdditions.scala index c1364af806..c27c713a47 100644 --- a/src/compiler/scala/tools/nsc/matching/MatrixAdditions.scala +++ b/src/compiler/scala/tools/nsc/matching/MatrixAdditions.scala @@ -25,6 +25,9 @@ trait MatrixAdditions extends ast.TreeDSL private[matching] trait Squeezer { self: MatrixContext => + def squeezedBlockPVs(pvs: List[PatternVar], exp: Tree): Tree = + squeezedBlock(pvs map (_.valDef), exp) + def squeezedBlock(vds: List[Tree], exp: Tree): Tree = if (settings_squeeze) Block(Nil, squeezedBlock1(vds, exp)) else Block(vds, exp) @@ -204,7 +207,7 @@ trait MatrixAdditions extends ast.TreeDSL } private lazy val inexhaustives: List[List[Combo]] = { val collected = - for ((sym, i) <- tvars.zipWithIndex ; if requiresExhaustive(sym)) yield + for ((pv, i) <- tvars.zipWithIndex ; val sym = pv.lhs ; if requiresExhaustive(sym)) yield i -> sealedSymsFor(sym.tpe.typeSymbol) val folded = @@ -230,7 +233,7 @@ trait MatrixAdditions extends ast.TreeDSL def check = { def errMsg = (inexhaustives map mkMissingStr).mkString if (inexhaustives.nonEmpty) - cunit.warning(tvars.head.pos, "match is not exhaustive!\n" + errMsg) + cunit.warning(tvars.head.lhs.pos, "match is not exhaustive!\n" + errMsg) rep } diff --git a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala index 801bdd3d19..8143b2e731 100644 --- a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala +++ b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala @@ -37,12 +37,14 @@ trait ParallelMatching extends ast.TreeDSL def toPats(xs: List[Tree]): List[Pattern] = xs map Pattern.apply /** The umbrella matrix class. **/ - class MatchMatrix(val context: MatrixContext, data: MatrixInit) extends MatchMatrixOptimizer with MatrixExhaustiveness { + abstract class MatchMatrix(val context: MatrixContext) extends MatchMatrixOptimizer with MatrixExhaustiveness { import context._ - val MatrixInit(roots, cases, failTree) = data - val ExpandedMatrix(rows, targets) = expand(roots, cases) - val expansion: Rep = make(roots, rows) + def data: MatrixContext#MatrixInit + + lazy val MatrixInit(roots, cases, failTree) = data + lazy val ExpandedMatrix(rows, targets) = expand(roots, cases) + lazy val expansion: Rep = make(roots, rows) val shortCuts = new ListBuffer[Symbol]() @@ -85,8 +87,8 @@ trait ParallelMatching extends ast.TreeDSL // make(_tvars, _rows) /** the injection here handles alternatives and unapply type tests */ - final def make(tvars: List[Symbol], row1: List[Row]): Rep = { - def classifyPat(opat: Pattern, j: Int): Pattern = opat simplify tvars(j) + final def make(tvars: List[PatternVar], row1: List[Row]): Rep = { + def classifyPat(opat: Pattern, j: Int): Pattern = opat simplify tvars(j).sym val rows = row1 flatMap (_ expandAlternatives classifyPat) if (rows.length != row1.length) make(tvars, rows) // recursive call if any change @@ -96,32 +98,30 @@ trait ParallelMatching extends ast.TreeDSL override def toString() = "MatchMatrix(%s) { %s }".format(matchResultType, indentAll(targets)) /** - * Encapsulates a symbol being matched on. - * - * sym match { ... } - * - * results in Scrutinee(sym). + * Encapsulates a symbol being matched on. It is created from a + * PatternVar, which encapsulates the symbol's creation and assignment. * - * Note that we only ever match on Symbols, not Trees: a temporary variable - * is created for any expressions being matched on. + * We never match on trees directly - a temporary variable is created + * (in a PatternVar) for any expression being matched on. */ - class Scrutinee(val sym: Symbol, extraValdefs: List[ValDef] = Nil) { + class Scrutinee(val pv: PatternVar) { import definitions._ // presenting a face of our symbol - def tpe = sym.tpe - def pos = sym.pos - def id = ID(sym) // attributed ident + def sym = pv.sym + def tpe = sym.tpe + def pos = sym.pos + def id = ID(sym) // attributed ident def accessors = if (isCaseClass) sym.caseFieldAccessors else Nil def accessorTypes = accessors map (x => (tpe memberType x).resultType) - private lazy val accessorPatternVars = + lazy val accessorPatternVars = for ((accessor, tpe) <- accessors zip accessorTypes) yield createVar(tpe, _ => fn(id, accessor)) - def accessorVars = accessorPatternVars map (_.lhs) - def accessorValDefs = extraValdefs ::: (accessorPatternVars map (_.valDef)) + private def extraValDefs = if (pv.rhs.isEmpty) Nil else List(pv.valDef) + def allValDefs = extraValDefs ::: (accessorPatternVars map (_.valDef)) // tests def isDefined = sym ne NoSymbol @@ -141,10 +141,7 @@ trait ParallelMatching extends ast.TreeDSL def castedTo(headType: Type) = if (tpe =:= headType) this - else { - val pv = createVar(headType, lhs => id AS_ANY lhs.tpe) - new Scrutinee(pv.lhs, List(pv.valDef)) - } + else new Scrutinee(createVar(headType, lhs => id AS_ANY lhs.tpe)) override def toString() = "(%s: %s)".format(id, tpe) } @@ -246,7 +243,7 @@ trait ParallelMatching extends ast.TreeDSL def mkFail(xs: List[Row]): Tree = xs match { case Nil => failTree - case _ => make(scrut.sym :: rest.tvars, xs).toTree + case _ => make(scrut.pv :: rest.tvars, xs).toTree } /** translate outcome of the rule application into code (possible involving recursive application of rewriting) */ @@ -391,38 +388,35 @@ trait ParallelMatching extends ast.TreeDSL lazy val failure = mkFail(zipped.tail filterNot (x => isSameUnapply(x._1)) map { case (pat, r) => r insert pat }) - private def doSuccess: (List[Tree], List[Symbol], List[Row]) = { - lazy val alloc = scrut.createVar( + private def doSuccess: (List[PatternVar], List[PatternVar], List[Row]) = { + // pattern variable for the unapply result of Some(x).get + lazy val pv = scrut.createVar( app.tpe typeArgs 0, _ => fn(ID(unapplyResult.lhs), nme.get) ) - def vdef = alloc.valDef - def lhs = alloc.lhs + def tuple = pv.lhs // at this point it's Some[T1,T2...] - lazy val tpes = getProductArgs(lhs.tpe).get + lazy val tpes = getProductArgs(tuple.tpe).get - // one allocation per tuple element - lazy val allocations = + // one pattern variable per tuple element + lazy val tuplePVs = for ((tpe, i) <- tpes.zipWithIndex) yield - scrut.createVar(tpe, _ => fn(ID(lhs), productProj(lhs, i + 1))) - - def vdefs = allocations map (_.valDef) - def vsyms = allocations map (_.lhs) + scrut.createVar(tpe, _ => fn(ID(tuple), productProj(tuple, i + 1))) // 0 is Boolean, 1 is Option[T], 2+ is Option[(T1,T2,...)] args.length match { case 0 => (Nil, Nil, mkNewRows((xs) => Nil, 0)) - case 1 => (List(vdef), List(lhs), mkNewRows(xs => List(xs.head), 1)) - case _ => (vdef :: vdefs, vsyms, mkNewRows(identity, tpes.size)) + case 1 => (List(pv), List(pv), mkNewRows(xs => List(xs.head), 1)) + case _ => (pv :: tuplePVs, tuplePVs, mkNewRows(identity, tpes.size)) } } lazy val success = { - val (vdefs, ntemps, rows) = doSuccess - val srep = make(ntemps ::: scrut.sym :: rest.tvars, rows).toTree + val (squeezePVs, pvs, rows) = doSuccess + val srep = make(pvs ::: scrut.pv :: rest.tvars, rows).toTree - squeezedBlock(vdefs, srep) + squeezedBlockPVs(squeezePVs, srep) } final def tree() = @@ -475,7 +469,7 @@ trait ParallelMatching extends ast.TreeDSL def elemAt(i: Int) = (scrut.id DOT (scrut.tpe member nme.apply))(LIT(i)) def elemCount = pivot.nonStarPatterns.size - val vs = + val pvs = // one per element .. pos = pat.pos (pivot.nonStarPatterns.zipWithIndex map { case (pat, i) => scrut.createVar(scrut.elemType, _ => elemAt(i)) }) ::: // and one for the rest of the sequence @@ -488,12 +482,12 @@ trait ParallelMatching extends ast.TreeDSL } ) - val symList = if (pivot.hasStar) List(scrut.sym) else Nil - val succ = make(List(vs map (_.lhs), symList, rest.tvars).flatten, nrows.flatten) + val symList = if (pivot.hasStar) List(scrut.pv) else Nil + val succ = make(List(pvs, symList, rest.tvars).flatten, nrows.flatten) ( - squeezedBlock(vs map (_.valDef), succ.toTree), - make(scrut.sym :: rest.tvars, frows.flatten).toTree + squeezedBlockPVs(pvs, succ.toTree), + make(scrut.pv :: rest.tvars, frows.flatten).toTree ) } @@ -503,7 +497,7 @@ trait ParallelMatching extends ast.TreeDSL // @todo: equals test for same constant class MixEquals(val pmatch: PatternMatch, val rest: Rep) extends RuleApplication { private def mkNewRep(rows: List[Row]) = - make(scrut.sym :: rest.tvars, rows).toTree + make(scrut.pv :: rest.tvars, rows).toTree private lazy val labelBody = mkNewRep(List.map2(rest.rows.tail, pmatch.tail)(_ insert _)) @@ -634,7 +628,7 @@ trait ParallelMatching extends ast.TreeDSL lazy val success = { val (subtests, subtestVars) = - if (isAnyMoreSpecific) (mkZipped, List(casted.sym)) + if (isAnyMoreSpecific) (mkZipped, List(casted.pv)) else (subsumed, Nil) val newRows = @@ -642,9 +636,9 @@ trait ParallelMatching extends ast.TreeDSL (rest rows j).insert2(ps, pmatch(j).boundVariables, casted.sym) val srep = - make(subtestVars ::: casted.accessorVars ::: rest.tvars, newRows) + make(subtestVars ::: casted.accessorPatternVars ::: rest.tvars, newRows) - squeezedBlock(casted.accessorValDefs, srep.toTree) + squeezedBlock(casted.allValDefs, srep.toTree) } lazy val failure = mkFail(remaining map tupled((p1, p2) => rest rows p1 insert p2)) @@ -658,10 +652,7 @@ trait ParallelMatching extends ast.TreeDSL if (pats exists (p => !p.isDefault)) traceCategory("Row", "%s", pp(pats)) - /** Drops the 'i'th pattern */ - def drop(i: Int) = copy(pats = dropIndex(pats, i)) - - /** Extracts the nth pattern. */ + /** Extracts the 'i'th pattern. */ def extractColumn(i: Int) = { val (x, xs) = extractIndex(pats, i) (x, copy(pats = xs)) @@ -780,7 +771,7 @@ trait ParallelMatching extends ast.TreeDSL override def toString() = pp("Final%d%s".format(bx, pp(freeVars)) -> body) } - case class Rep(val tvars: List[Symbol], val rows: List[Row]) { + case class Rep(val tvars: List[PatternVar], val rows: List[Row]) { lazy val Row(pats, subst, guard, index) = rows.head lazy val guardedRest = if (guard.isEmpty) NoRep else make(tvars, rows.tail) lazy val (defaults, others) = pats span (_.isDefault) @@ -795,14 +786,14 @@ trait ParallelMatching extends ast.TreeDSL List.unzip(rows map (_ extractColumn index)) /** Now the 'i'th tvar is separated out and used as a new Scrutinee. */ - private val (_sym, _tvars) = + private val (_pv, _tvars) = extractIndex(tvars, index) /** The non-default pattern (others.head) replaces the column head. */ private val (_ncol, _nrep) = (others.head :: _column.tail, make(_tvars, _rows)) - def mix = MixtureRule(new Scrutinee(_sym), _ncol, _nrep) + def mix = MixtureRule(new Scrutinee(specialVar(_pv.sym, _pv.checked)), _ncol, _nrep) } /** Converts this to a tree - recursively acquires subreps. */ @@ -811,7 +802,7 @@ trait ParallelMatching extends ast.TreeDSL /** The VariableRule. */ private def variable() = { val binding = (defaults map (_.boundVariables) zip tvars) . - foldLeft(subst)((b, pair) => b.add(pair._1, pair._2)) + foldLeft(subst)((b, pair) => b.add(pair._1, pair._2.lhs)) VariableRule(binding, guard, guardedRest, index) } @@ -837,7 +828,7 @@ trait ParallelMatching extends ast.TreeDSL val NoRep = Rep(Nil, Nil) /** Expands the patterns recursively. */ - final def expand(roots: List[Symbol], cases: List[CaseDef]): ExpandedMatrix = { + final def expand(roots: List[PatternVar], cases: List[CaseDef]): ExpandedMatrix = { val (rows, finals) = List.unzip( for ((CaseDef(pat, guard, body), index) <- cases.zipWithIndex) yield { def mkRow(ps: List[Tree]) = Row(toPats(ps), NoBinding, Guard(guard), index) diff --git a/src/compiler/scala/tools/nsc/matching/TransMatcher.scala b/src/compiler/scala/tools/nsc/matching/TransMatcher.scala index 2fb162a059..88db07518d 100644 --- a/src/compiler/scala/tools/nsc/matching/TransMatcher.scala +++ b/src/compiler/scala/tools/nsc/matching/TransMatcher.scala @@ -56,31 +56,30 @@ trait TransMatcher extends ast.TreeDSL { (cases forall caseIsOk) // For x match { ... we start with a single root - def singleMatch(): (List[Tree], MatrixInit) = { + def singleMatch(): MatrixInit = { val v = copyVar(selector, isChecked) - - (tracing("root(s)", List(v.valDef)), MatrixInit(List(v.lhs), cases, matchError(v.ident))) + tracing("root(s)", context.MatrixInit(List(v), cases, matchError(v.ident))) } // For (x, y, z) match { ... we start with multiple roots, called tpXX. - def tupleMatch(app: Apply): (List[Tree], MatrixInit) = { + def tupleMatch(app: Apply): MatrixInit = { val Apply(fn, args) = app val vs = args zip rootTypes map { case (arg, tpe) => copyVar(arg, isChecked, tpe, "tp") } - def merror = matchError(treeCopy.Apply(app, fn, vs map (_.ident))) - (tracing("root(s)", vs map (_.valDef)), MatrixInit(vs map (_.lhs), cases, merror)) + + tracing("root(s)", context.MatrixInit(vs, cases, merror)) } // sets up top level variables and algorithm input - val (vars, matrixInit) = selector match { + val matrixInit = selector match { case app @ Apply(fn, _) if isTupleType(selector.tpe) && doApply(fn) => tupleMatch(app) case _ => singleMatch() } - val matrix = new MatchMatrix(context, matrixInit) - val rep = matrix.expansion // expands casedefs and assigns name - val mch = typer typed rep.toTree // executes algorithm, converts tree to DFA - val dfatree = typer typed Block(vars, mch) // packages into a code block + val matrix = new MatchMatrix(context) { lazy val data = matrixInit } + val rep = matrix.expansion // expands casedefs and assigns name + val mch = typer typed rep.toTree // executes algorithm, converts tree to DFA + val dfatree = typer typed Block(matrixInit.valDefs, mch) // packages into a code block // redundancy check matrix.targets filter (_.isNotReached) foreach (cs => cunit.error(cs.body.pos, "unreachable code")) |