summaryrefslogtreecommitdiff
path: root/src/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'src/compiler')
-rw-r--r--src/compiler/scala/tools/nsc/matching/Matrix.scala34
-rw-r--r--src/compiler/scala/tools/nsc/matching/MatrixAdditions.scala7
-rw-r--r--src/compiler/scala/tools/nsc/matching/ParallelMatching.scala109
-rw-r--r--src/compiler/scala/tools/nsc/matching/TransMatcher.scala21
4 files changed, 88 insertions, 83 deletions
diff --git a/src/compiler/scala/tools/nsc/matching/Matrix.scala b/src/compiler/scala/tools/nsc/matching/Matrix.scala
index 6d9f6a099b..e0d15a4c49 100644
--- a/src/compiler/scala/tools/nsc/matching/Matrix.scala
+++ b/src/compiler/scala/tools/nsc/matching/Matrix.scala
@@ -77,12 +77,6 @@ trait Matrix extends MatrixAdditions {
states. Otherwise,the error state is used after its reference count has been incremented.
**/
- case class MatrixInit(
- roots: List[Symbol],
- cases: List[CaseDef],
- default: Tree
- )
-
case class MatrixContext(
handleOuter: Tree => Tree, // for outer pointer
typer: Typer, // a local typer
@@ -95,14 +89,30 @@ trait Matrix extends MatrixAdditions {
// TRANS_FLAG communicates there should be no exhaustiveness checking
private def flags(checked: Boolean) = if (checked) Nil else List(TRANS_FLAG)
- /** Every new variable allocated gets one of these. */
- class PatternVar(val lhs: Symbol, val rhs: Tree) {
+ case class MatrixInit(
+ roots: List[PatternVar],
+ cases: List[CaseDef],
+ default: Tree
+ ) {
+ def tvars = roots map (_.lhs)
+ def valDefs = roots map (_.valDef)
+ }
+
+ /** Every temporary variable allocated is put in a PatternVar.
+ */
+ class PatternVar(val lhs: Symbol, val rhs: Tree, val checked: Boolean) {
+ def sym = lhs
lazy val ident = ID(lhs)
lazy val valDef = typedValDef(lhs, rhs)
override def toString() = "%s: %s = %s".format(lhs, lhs.info, rhs)
}
+ /** Sets the rhs to EmptyTree, which makes the valDef ignored in Scrutinee.
+ */
+ def specialVar(lhs: Symbol, checked: Boolean) =
+ new PatternVar(lhs, EmptyTree, checked)
+
/** Given a tree, creates a new synthetic variable of the same type
* and assigns the tree to it.
*/
@@ -116,15 +126,17 @@ trait Matrix extends MatrixAdditions {
val name = newName(root.pos, label)
val sym = newVar(root.pos, tpe, flags(checked), name)
- tracing("copy", new PatternVar(sym, root))
+ tracing("copy", new PatternVar(sym, root, checked))
}
- /** The rhs is expressed as a function of the lhs. */
+ /** Creates a new synthetic variable of the specified type and
+ * assigns the result of f(symbol) to it.
+ */
def createVar(tpe: Type, f: Symbol => Tree, checked: Boolean) = {
val lhs = newVar(owner.pos, tpe, flags(checked))
val rhs = f(lhs)
- tracing("create", new PatternVar(lhs, rhs))
+ tracing("create", new PatternVar(lhs, rhs, checked))
}
private def newVar(
diff --git a/src/compiler/scala/tools/nsc/matching/MatrixAdditions.scala b/src/compiler/scala/tools/nsc/matching/MatrixAdditions.scala
index c1364af806..c27c713a47 100644
--- a/src/compiler/scala/tools/nsc/matching/MatrixAdditions.scala
+++ b/src/compiler/scala/tools/nsc/matching/MatrixAdditions.scala
@@ -25,6 +25,9 @@ trait MatrixAdditions extends ast.TreeDSL
private[matching] trait Squeezer {
self: MatrixContext =>
+ def squeezedBlockPVs(pvs: List[PatternVar], exp: Tree): Tree =
+ squeezedBlock(pvs map (_.valDef), exp)
+
def squeezedBlock(vds: List[Tree], exp: Tree): Tree =
if (settings_squeeze) Block(Nil, squeezedBlock1(vds, exp))
else Block(vds, exp)
@@ -204,7 +207,7 @@ trait MatrixAdditions extends ast.TreeDSL
}
private lazy val inexhaustives: List[List[Combo]] = {
val collected =
- for ((sym, i) <- tvars.zipWithIndex ; if requiresExhaustive(sym)) yield
+ for ((pv, i) <- tvars.zipWithIndex ; val sym = pv.lhs ; if requiresExhaustive(sym)) yield
i -> sealedSymsFor(sym.tpe.typeSymbol)
val folded =
@@ -230,7 +233,7 @@ trait MatrixAdditions extends ast.TreeDSL
def check = {
def errMsg = (inexhaustives map mkMissingStr).mkString
if (inexhaustives.nonEmpty)
- cunit.warning(tvars.head.pos, "match is not exhaustive!\n" + errMsg)
+ cunit.warning(tvars.head.lhs.pos, "match is not exhaustive!\n" + errMsg)
rep
}
diff --git a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala
index 801bdd3d19..8143b2e731 100644
--- a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala
+++ b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala
@@ -37,12 +37,14 @@ trait ParallelMatching extends ast.TreeDSL
def toPats(xs: List[Tree]): List[Pattern] = xs map Pattern.apply
/** The umbrella matrix class. **/
- class MatchMatrix(val context: MatrixContext, data: MatrixInit) extends MatchMatrixOptimizer with MatrixExhaustiveness {
+ abstract class MatchMatrix(val context: MatrixContext) extends MatchMatrixOptimizer with MatrixExhaustiveness {
import context._
- val MatrixInit(roots, cases, failTree) = data
- val ExpandedMatrix(rows, targets) = expand(roots, cases)
- val expansion: Rep = make(roots, rows)
+ def data: MatrixContext#MatrixInit
+
+ lazy val MatrixInit(roots, cases, failTree) = data
+ lazy val ExpandedMatrix(rows, targets) = expand(roots, cases)
+ lazy val expansion: Rep = make(roots, rows)
val shortCuts = new ListBuffer[Symbol]()
@@ -85,8 +87,8 @@ trait ParallelMatching extends ast.TreeDSL
// make(_tvars, _rows)
/** the injection here handles alternatives and unapply type tests */
- final def make(tvars: List[Symbol], row1: List[Row]): Rep = {
- def classifyPat(opat: Pattern, j: Int): Pattern = opat simplify tvars(j)
+ final def make(tvars: List[PatternVar], row1: List[Row]): Rep = {
+ def classifyPat(opat: Pattern, j: Int): Pattern = opat simplify tvars(j).sym
val rows = row1 flatMap (_ expandAlternatives classifyPat)
if (rows.length != row1.length) make(tvars, rows) // recursive call if any change
@@ -96,32 +98,30 @@ trait ParallelMatching extends ast.TreeDSL
override def toString() = "MatchMatrix(%s) { %s }".format(matchResultType, indentAll(targets))
/**
- * Encapsulates a symbol being matched on.
- *
- * sym match { ... }
- *
- * results in Scrutinee(sym).
+ * Encapsulates a symbol being matched on. It is created from a
+ * PatternVar, which encapsulates the symbol's creation and assignment.
*
- * Note that we only ever match on Symbols, not Trees: a temporary variable
- * is created for any expressions being matched on.
+ * We never match on trees directly - a temporary variable is created
+ * (in a PatternVar) for any expression being matched on.
*/
- class Scrutinee(val sym: Symbol, extraValdefs: List[ValDef] = Nil) {
+ class Scrutinee(val pv: PatternVar) {
import definitions._
// presenting a face of our symbol
- def tpe = sym.tpe
- def pos = sym.pos
- def id = ID(sym) // attributed ident
+ def sym = pv.sym
+ def tpe = sym.tpe
+ def pos = sym.pos
+ def id = ID(sym) // attributed ident
def accessors = if (isCaseClass) sym.caseFieldAccessors else Nil
def accessorTypes = accessors map (x => (tpe memberType x).resultType)
- private lazy val accessorPatternVars =
+ lazy val accessorPatternVars =
for ((accessor, tpe) <- accessors zip accessorTypes) yield
createVar(tpe, _ => fn(id, accessor))
- def accessorVars = accessorPatternVars map (_.lhs)
- def accessorValDefs = extraValdefs ::: (accessorPatternVars map (_.valDef))
+ private def extraValDefs = if (pv.rhs.isEmpty) Nil else List(pv.valDef)
+ def allValDefs = extraValDefs ::: (accessorPatternVars map (_.valDef))
// tests
def isDefined = sym ne NoSymbol
@@ -141,10 +141,7 @@ trait ParallelMatching extends ast.TreeDSL
def castedTo(headType: Type) =
if (tpe =:= headType) this
- else {
- val pv = createVar(headType, lhs => id AS_ANY lhs.tpe)
- new Scrutinee(pv.lhs, List(pv.valDef))
- }
+ else new Scrutinee(createVar(headType, lhs => id AS_ANY lhs.tpe))
override def toString() = "(%s: %s)".format(id, tpe)
}
@@ -246,7 +243,7 @@ trait ParallelMatching extends ast.TreeDSL
def mkFail(xs: List[Row]): Tree = xs match {
case Nil => failTree
- case _ => make(scrut.sym :: rest.tvars, xs).toTree
+ case _ => make(scrut.pv :: rest.tvars, xs).toTree
}
/** translate outcome of the rule application into code (possible involving recursive application of rewriting) */
@@ -391,38 +388,35 @@ trait ParallelMatching extends ast.TreeDSL
lazy val failure =
mkFail(zipped.tail filterNot (x => isSameUnapply(x._1)) map { case (pat, r) => r insert pat })
- private def doSuccess: (List[Tree], List[Symbol], List[Row]) = {
- lazy val alloc = scrut.createVar(
+ private def doSuccess: (List[PatternVar], List[PatternVar], List[Row]) = {
+ // pattern variable for the unapply result of Some(x).get
+ lazy val pv = scrut.createVar(
app.tpe typeArgs 0,
_ => fn(ID(unapplyResult.lhs), nme.get)
)
- def vdef = alloc.valDef
- def lhs = alloc.lhs
+ def tuple = pv.lhs
// at this point it's Some[T1,T2...]
- lazy val tpes = getProductArgs(lhs.tpe).get
+ lazy val tpes = getProductArgs(tuple.tpe).get
- // one allocation per tuple element
- lazy val allocations =
+ // one pattern variable per tuple element
+ lazy val tuplePVs =
for ((tpe, i) <- tpes.zipWithIndex) yield
- scrut.createVar(tpe, _ => fn(ID(lhs), productProj(lhs, i + 1)))
-
- def vdefs = allocations map (_.valDef)
- def vsyms = allocations map (_.lhs)
+ scrut.createVar(tpe, _ => fn(ID(tuple), productProj(tuple, i + 1)))
// 0 is Boolean, 1 is Option[T], 2+ is Option[(T1,T2,...)]
args.length match {
case 0 => (Nil, Nil, mkNewRows((xs) => Nil, 0))
- case 1 => (List(vdef), List(lhs), mkNewRows(xs => List(xs.head), 1))
- case _ => (vdef :: vdefs, vsyms, mkNewRows(identity, tpes.size))
+ case 1 => (List(pv), List(pv), mkNewRows(xs => List(xs.head), 1))
+ case _ => (pv :: tuplePVs, tuplePVs, mkNewRows(identity, tpes.size))
}
}
lazy val success = {
- val (vdefs, ntemps, rows) = doSuccess
- val srep = make(ntemps ::: scrut.sym :: rest.tvars, rows).toTree
+ val (squeezePVs, pvs, rows) = doSuccess
+ val srep = make(pvs ::: scrut.pv :: rest.tvars, rows).toTree
- squeezedBlock(vdefs, srep)
+ squeezedBlockPVs(squeezePVs, srep)
}
final def tree() =
@@ -475,7 +469,7 @@ trait ParallelMatching extends ast.TreeDSL
def elemAt(i: Int) = (scrut.id DOT (scrut.tpe member nme.apply))(LIT(i))
def elemCount = pivot.nonStarPatterns.size
- val vs =
+ val pvs =
// one per element .. pos = pat.pos
(pivot.nonStarPatterns.zipWithIndex map { case (pat, i) => scrut.createVar(scrut.elemType, _ => elemAt(i)) }) :::
// and one for the rest of the sequence
@@ -488,12 +482,12 @@ trait ParallelMatching extends ast.TreeDSL
}
)
- val symList = if (pivot.hasStar) List(scrut.sym) else Nil
- val succ = make(List(vs map (_.lhs), symList, rest.tvars).flatten, nrows.flatten)
+ val symList = if (pivot.hasStar) List(scrut.pv) else Nil
+ val succ = make(List(pvs, symList, rest.tvars).flatten, nrows.flatten)
(
- squeezedBlock(vs map (_.valDef), succ.toTree),
- make(scrut.sym :: rest.tvars, frows.flatten).toTree
+ squeezedBlockPVs(pvs, succ.toTree),
+ make(scrut.pv :: rest.tvars, frows.flatten).toTree
)
}
@@ -503,7 +497,7 @@ trait ParallelMatching extends ast.TreeDSL
// @todo: equals test for same constant
class MixEquals(val pmatch: PatternMatch, val rest: Rep) extends RuleApplication {
private def mkNewRep(rows: List[Row]) =
- make(scrut.sym :: rest.tvars, rows).toTree
+ make(scrut.pv :: rest.tvars, rows).toTree
private lazy val labelBody =
mkNewRep(List.map2(rest.rows.tail, pmatch.tail)(_ insert _))
@@ -634,7 +628,7 @@ trait ParallelMatching extends ast.TreeDSL
lazy val success = {
val (subtests, subtestVars) =
- if (isAnyMoreSpecific) (mkZipped, List(casted.sym))
+ if (isAnyMoreSpecific) (mkZipped, List(casted.pv))
else (subsumed, Nil)
val newRows =
@@ -642,9 +636,9 @@ trait ParallelMatching extends ast.TreeDSL
(rest rows j).insert2(ps, pmatch(j).boundVariables, casted.sym)
val srep =
- make(subtestVars ::: casted.accessorVars ::: rest.tvars, newRows)
+ make(subtestVars ::: casted.accessorPatternVars ::: rest.tvars, newRows)
- squeezedBlock(casted.accessorValDefs, srep.toTree)
+ squeezedBlock(casted.allValDefs, srep.toTree)
}
lazy val failure =
mkFail(remaining map tupled((p1, p2) => rest rows p1 insert p2))
@@ -658,10 +652,7 @@ trait ParallelMatching extends ast.TreeDSL
if (pats exists (p => !p.isDefault))
traceCategory("Row", "%s", pp(pats))
- /** Drops the 'i'th pattern */
- def drop(i: Int) = copy(pats = dropIndex(pats, i))
-
- /** Extracts the nth pattern. */
+ /** Extracts the 'i'th pattern. */
def extractColumn(i: Int) = {
val (x, xs) = extractIndex(pats, i)
(x, copy(pats = xs))
@@ -780,7 +771,7 @@ trait ParallelMatching extends ast.TreeDSL
override def toString() = pp("Final%d%s".format(bx, pp(freeVars)) -> body)
}
- case class Rep(val tvars: List[Symbol], val rows: List[Row]) {
+ case class Rep(val tvars: List[PatternVar], val rows: List[Row]) {
lazy val Row(pats, subst, guard, index) = rows.head
lazy val guardedRest = if (guard.isEmpty) NoRep else make(tvars, rows.tail)
lazy val (defaults, others) = pats span (_.isDefault)
@@ -795,14 +786,14 @@ trait ParallelMatching extends ast.TreeDSL
List.unzip(rows map (_ extractColumn index))
/** Now the 'i'th tvar is separated out and used as a new Scrutinee. */
- private val (_sym, _tvars) =
+ private val (_pv, _tvars) =
extractIndex(tvars, index)
/** The non-default pattern (others.head) replaces the column head. */
private val (_ncol, _nrep) =
(others.head :: _column.tail, make(_tvars, _rows))
- def mix = MixtureRule(new Scrutinee(_sym), _ncol, _nrep)
+ def mix = MixtureRule(new Scrutinee(specialVar(_pv.sym, _pv.checked)), _ncol, _nrep)
}
/** Converts this to a tree - recursively acquires subreps. */
@@ -811,7 +802,7 @@ trait ParallelMatching extends ast.TreeDSL
/** The VariableRule. */
private def variable() = {
val binding = (defaults map (_.boundVariables) zip tvars) .
- foldLeft(subst)((b, pair) => b.add(pair._1, pair._2))
+ foldLeft(subst)((b, pair) => b.add(pair._1, pair._2.lhs))
VariableRule(binding, guard, guardedRest, index)
}
@@ -837,7 +828,7 @@ trait ParallelMatching extends ast.TreeDSL
val NoRep = Rep(Nil, Nil)
/** Expands the patterns recursively. */
- final def expand(roots: List[Symbol], cases: List[CaseDef]): ExpandedMatrix = {
+ final def expand(roots: List[PatternVar], cases: List[CaseDef]): ExpandedMatrix = {
val (rows, finals) = List.unzip(
for ((CaseDef(pat, guard, body), index) <- cases.zipWithIndex) yield {
def mkRow(ps: List[Tree]) = Row(toPats(ps), NoBinding, Guard(guard), index)
diff --git a/src/compiler/scala/tools/nsc/matching/TransMatcher.scala b/src/compiler/scala/tools/nsc/matching/TransMatcher.scala
index 2fb162a059..88db07518d 100644
--- a/src/compiler/scala/tools/nsc/matching/TransMatcher.scala
+++ b/src/compiler/scala/tools/nsc/matching/TransMatcher.scala
@@ -56,31 +56,30 @@ trait TransMatcher extends ast.TreeDSL {
(cases forall caseIsOk)
// For x match { ... we start with a single root
- def singleMatch(): (List[Tree], MatrixInit) = {
+ def singleMatch(): MatrixInit = {
val v = copyVar(selector, isChecked)
-
- (tracing("root(s)", List(v.valDef)), MatrixInit(List(v.lhs), cases, matchError(v.ident)))
+ tracing("root(s)", context.MatrixInit(List(v), cases, matchError(v.ident)))
}
// For (x, y, z) match { ... we start with multiple roots, called tpXX.
- def tupleMatch(app: Apply): (List[Tree], MatrixInit) = {
+ def tupleMatch(app: Apply): MatrixInit = {
val Apply(fn, args) = app
val vs = args zip rootTypes map { case (arg, tpe) => copyVar(arg, isChecked, tpe, "tp") }
-
def merror = matchError(treeCopy.Apply(app, fn, vs map (_.ident)))
- (tracing("root(s)", vs map (_.valDef)), MatrixInit(vs map (_.lhs), cases, merror))
+
+ tracing("root(s)", context.MatrixInit(vs, cases, merror))
}
// sets up top level variables and algorithm input
- val (vars, matrixInit) = selector match {
+ val matrixInit = selector match {
case app @ Apply(fn, _) if isTupleType(selector.tpe) && doApply(fn) => tupleMatch(app)
case _ => singleMatch()
}
- val matrix = new MatchMatrix(context, matrixInit)
- val rep = matrix.expansion // expands casedefs and assigns name
- val mch = typer typed rep.toTree // executes algorithm, converts tree to DFA
- val dfatree = typer typed Block(vars, mch) // packages into a code block
+ val matrix = new MatchMatrix(context) { lazy val data = matrixInit }
+ val rep = matrix.expansion // expands casedefs and assigns name
+ val mch = typer typed rep.toTree // executes algorithm, converts tree to DFA
+ val dfatree = typer typed Block(matrixInit.valDefs, mch) // packages into a code block
// redundancy check
matrix.targets filter (_.isNotReached) foreach (cs => cunit.error(cs.body.pos, "unreachable code"))