summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/compiler/scala/tools/nsc/matching/Matrix.scala18
-rw-r--r--src/compiler/scala/tools/nsc/matching/ParallelMatching.scala147
-rw-r--r--src/compiler/scala/tools/nsc/matching/TransMatcher.scala4
3 files changed, 79 insertions, 90 deletions
diff --git a/src/compiler/scala/tools/nsc/matching/Matrix.scala b/src/compiler/scala/tools/nsc/matching/Matrix.scala
index c5373f0924..631f517ab7 100644
--- a/src/compiler/scala/tools/nsc/matching/Matrix.scala
+++ b/src/compiler/scala/tools/nsc/matching/Matrix.scala
@@ -95,13 +95,19 @@ trait Matrix extends MatrixAdditions {
// TRANS_FLAG communicates there should be no exhaustiveness checking
private def flags(checked: Boolean) = if (checked) Nil else List(TRANS_FLAG)
+ /** Every new variable allocated gets one of these. */
+ class PatternVar(val lhs: Symbol, val rhs: Tree) {
+ lazy val ident = ID(lhs)
+ lazy val valDef = typedValDef(lhs, rhs)
+ }
+
/** Given a tree, creates a new synthetic variable of the same type
* and assigns the tree to it.
*/
def copyVar(
root: Tree,
+ checked: Boolean,
_tpe: Type = null,
- checked: Boolean = false,
label: String = "temp"): PatternVar =
{
val tpe = ifNull(_tpe, root.tpe)
@@ -112,20 +118,14 @@ trait Matrix extends MatrixAdditions {
}
/** The rhs is expressed as a function of the lhs. */
- def createVar(tpe: Type, f: Symbol => Tree, checked: Boolean = false) = {
+ def createVar(tpe: Type, f: Symbol => Tree, checked: Boolean) = {
val lhs = newVar(owner.pos, tpe, flags(checked))
val rhs = f(lhs)
new PatternVar(lhs, rhs)
}
- class PatternVar(val lhs: Symbol, val rhs: Tree) {
- lazy val ident = ID(lhs)
- lazy val valDef = typedValDef(lhs, rhs)
- // lazy val valDef = typedValDef(lhs, rhs setType lhs.tpe)
- }
-
- def newVar(
+ private def newVar(
pos: Position,
tpe: Type,
flags: List[Long] = Nil,
diff --git a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala
index 8c15f7ae08..5b946cca6b 100644
--- a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala
+++ b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala
@@ -93,7 +93,7 @@ trait ParallelMatching extends ast.TreeDSL
* Note that we only ever match on Symbols, not Trees: a temporary variable
* is created for any expressions being matched on.
*/
- class Scrutinee(val sym: Symbol) {
+ class Scrutinee(val sym: Symbol, extraValdefs: List[ValDef] = Nil) {
import definitions._
// presenting a face of our symbol
@@ -102,7 +102,14 @@ trait ParallelMatching extends ast.TreeDSL
def id = ID(sym) // attributed ident
def accessors = if (isCaseClass) sym.caseFieldAccessors else Nil
- def accessorVars = accessors map (a => newVarOfTpe((tpe memberType a).resultType))
+ def accessorTypes = accessors map (x => (tpe memberType x).resultType)
+
+ private lazy val accessorPatternVars =
+ for ((accessor, tpe) <- accessors zip accessorTypes) yield
+ createVar(tpe, _ => fn(id, accessor))
+
+ def accessorVars = accessorPatternVars map (_.lhs)
+ def accessorValDefs = extraValdefs ::: (accessorPatternVars map (_.valDef))
// tests
def isDefined = sym ne NoSymbol
@@ -113,10 +120,6 @@ trait ParallelMatching extends ast.TreeDSL
def seqType = tpe.widen baseType SeqClass
def elemType = tpe typeArgs 0
- def newVarOfTpe(tpe: Type) = newVar(pos, tpe, flags)
- def newVarOfSeqType = newVar(pos, seqType)
- def newVarOfElemType = newVar(pos, elemType)
-
// for propagating "unchecked" to synthetic vars
def isChecked = !(sym hasFlag TRANS_FLAG)
def flags: List[Long] = List(TRANS_FLAG) filter (sym hasFlag _)
@@ -126,7 +129,10 @@ trait ParallelMatching extends ast.TreeDSL
def castedTo(headType: Type) =
if (tpe =:= headType) this
- else new Scrutinee(newVar(pos, headType, flags))
+ else {
+ val pv = createVar(headType, lhs => id AS_ANY lhs.tpe)
+ new Scrutinee(pv.lhs, List(pv.valDef))
+ }
override def toString() = "(%s: %s)".format(id, tpe)
}
@@ -242,12 +248,6 @@ trait ParallelMatching extends ast.TreeDSL
"Rule/%s (%s =^= %s)".format(getClass.getSimpleName, scrut, head)
}
- case class ErrorRule() extends RuleApplication {
- def pmatch: PatternMatch = impossible
- def rest: Rep = impossible
- final def tree() = failTree
- }
-
/** {case ... if guard => bx} else {guardedRest} */
/** VariableRule: The top-most rows has only variable (non-constructor) patterns. */
case class VariableRule(subst: Bindings, guard: Guard, guardedRest: Rep, bx: Int) extends RuleApplicationFormal {
@@ -344,10 +344,7 @@ trait ParallelMatching extends ast.TreeDSL
private lazy val zipped = pmatch pzip rest.rows
lazy val unapplyResult =
- createVar(app.tpe,
- lhs => reapply setType lhs.tpe,
- scrut.isChecked
- )
+ scrut.createVar(app.tpe, lhs => reapply setType lhs.tpe)
// XXX in transition.
object sameUnapplyCall {
@@ -382,10 +379,12 @@ trait ParallelMatching extends ast.TreeDSL
mkFail(zipped.tail filterNot (x => isSameUnapply(x._1)) map { case (pat, r) => r insert pat })
private def doSuccess: (List[Tree], List[Symbol], List[Row]) = {
- def mkVar(tpe: Type) = newVar(ua.pos, tpe, scrut.flags)
-
- lazy val lhs = mkVar(app.tpe typeArgs 0)
- lazy val vdef = typedValDef(lhs, fn(ID(unapplyResult.lhs), nme.get))
+ lazy val alloc = scrut.createVar(
+ app.tpe typeArgs 0,
+ _ => fn(ID(unapplyResult.lhs), nme.get)
+ )
+ def vdef = alloc.valDef
+ def lhs = alloc.lhs
// at this point it's Some[T1,T2...]
lazy val tpes = getProductArgs(lhs.tpe).get
@@ -419,7 +418,7 @@ trait ParallelMatching extends ast.TreeDSL
/** handle sequence pattern and ArrayValue (but not star patterns)
*/
- sealed class MixSequence(val pmatch: PatternMatch, val rest: Rep) extends RuleApplication {
+ sealed class MixSequence(val pmatch: PatternMatch, val rest: Rep) extends RuleApplicationFormal {
// Called 'pivot' since it's the head of the column under consideration in the mixture rule.
val pivot @ SequencePattern(av @ ArrayValue(_, _)) = head
private def pivotLen = pivot.nonStarLength
@@ -455,16 +454,18 @@ trait ParallelMatching extends ast.TreeDSL
emptyPatterns(pivot.elemPatterns.length + 1)
}
- final def tree(): Tree = {
- assert(scrut.tpe <:< head.tpe, "fatal: %s is not <:< %s".format(scrut, head.tpe))
+ lazy val cond = (pivot precondition pmatch).get
- val vs = pivot.nonStarPatterns map (x => newVar(x.tree.pos, scrut.elemType))
- lazy val tail = scrut.newVarOfSeqType
- lazy val lastBinding = typedValDef(tail, scrut.id DROP vs.size)
- def elemAt(i: Int) = (scrut.id DOT (scrut.tpe member nme.apply))(LIT(i))
+ lazy val (success, failure) = {
+ assert(scrut.tpe <:< head.tpe, "fatal: %s is not <:< %s".format(scrut, head.tpe))
+ def elemAt(i: Int) = (scrut.id DOT (scrut.tpe member nme.apply))(LIT(i))
+ def elemCount = pivot.nonStarPatterns.size
- val bindings =
- (vs.zipWithIndex map tupled((v, i) => typedValDef(v, elemAt(i)))) ::: List(lastBinding)
+ val vs =
+ // one per element .. pos = pat.pos
+ (pivot.nonStarPatterns.zipWithIndex map { case (pat, i) => scrut.createVar(scrut.elemType, _ => elemAt(i)) }) :::
+ // and one for the rest of the sequence
+ List(scrut.createVar(scrut.seqType, _ => scrut.id DROP elemCount))
val (nrows, frows): (List[Option[Row]], List[Option[Row]]) = List.unzip(
for ((c, rows) <- pmatch pzip rest.rows) yield getSubPatterns(c) match {
@@ -474,12 +475,15 @@ trait ParallelMatching extends ast.TreeDSL
)
val symList = if (pivot.hasStar) List(scrut.sym) else Nil
- val cond = (pivot precondition pmatch).get
- val succ = make(List(vs, List(tail), symList, rest.tvars).flatten, nrows.flatten).toTree
- val fail = make(scrut.sym :: rest.tvars, frows.flatten).toTree
+ val succ = make(List(vs map (_.lhs), symList, rest.tvars).flatten, nrows.flatten)
- IF (cond) THEN squeezedBlock(bindings, succ) ELSE fail
+ (
+ squeezedBlock(vs map (_.valDef), succ.toTree),
+ make(scrut.sym :: rest.tvars, frows.flatten).toTree
+ )
}
+
+ final def tree(): Tree = codegen
}
// @todo: equals test for same constant
@@ -516,7 +520,7 @@ trait ParallelMatching extends ast.TreeDSL
/** mixture rule for type tests
**/
- class MixTypes(val pmatch: PatternMatch, val rest: Rep) extends RuleApplication {
+ class MixTypes(val pmatch: PatternMatch, val rest: Rep) extends RuleApplicationFormal {
// see bug1434.scala for an illustration of why "x <:< y" is insufficient.
// this code is definitely inadequate at best. Inherited comment:
//
@@ -586,6 +590,16 @@ trait ParallelMatching extends ast.TreeDSL
}
) match { case (x,y,z) => (join(x), join(y), join(z)) }
+ // temporary checks so we're less crashy while we determine what to implement.
+ def checkErroneous(scrut: Scrutinee): Type = {
+ scrut.tpe match {
+ case tpe @ ThisType(_) if tpe.termSymbol == NoSymbol =>
+ cunit.error(scrut.pos, "self type test in anonymous class forbidden by implementation.")
+ ErrorType
+ case x => x
+ }
+ }
+
override def toString = {
val msgs = List(
"moreSpecific: " + pp(moreSpecific),
@@ -596,14 +610,15 @@ trait ParallelMatching extends ast.TreeDSL
super.toString() + "\n" + indentAll(msgs)
}
- /** returns casted symbol, success matrix and optionally fail matrix for type test on the top of this column */
- final def getTransition() = {
- val casted = scrut castedTo pmatch.headType
-
- val isAnyMoreSpecific = moreSpecific exists (x => !x.isEmpty)
+ private def isAnyMoreSpecific = moreSpecific exists (x => !x.isEmpty)
+ private def mkZipped = moreSpecific zip subsumed map {
+ case (mspat, (j, pmatch)) => (j, mspat :: pmatch)
+ }
- def mkZipped = moreSpecific zip subsumed map { case (mspat, (j, pmatch)) => (j, mspat :: pmatch) }
+ lazy val casted = scrut castedTo pmatch.headType
+ lazy val cond = condition(checkErroneous(casted), scrut)
+ lazy val success = {
val (subtests, subtestVars) =
if (isAnyMoreSpecific) (mkZipped, List(casted.sym))
else (subsumed, Nil)
@@ -612,41 +627,15 @@ trait ParallelMatching extends ast.TreeDSL
for ((j, ps) <- subtests) yield
(rest rows j).insert2(ps, pmatch(j).boundVariables, casted.sym)
- val success = make(subtestVars ::: casted.accessorVars ::: rest.tvars, newRows)
- val failure = mkFail(remaining map tupled((p1, p2) => rest rows p1 insert p2))
+ val srep =
+ make(subtestVars ::: casted.accessorVars ::: rest.tvars, newRows)
- (casted, success, failure)
+ squeezedBlock(casted.accessorValDefs, srep.toTree)
}
+ lazy val failure =
+ mkFail(remaining map tupled((p1, p2) => rest rows p1 insert p2))
- // temporary checks so we're less crashy while we determine what to implement.
- def checkErroneous(scrut: Scrutinee): Type = {
- scrut.tpe match {
- case tpe @ ThisType(_) if tpe.termSymbol == NoSymbol =>
- cunit.error(scrut.pos, "self type test in anonymous class forbidden by implementation.")
- ErrorType
- case x => x
- }
- }
-
- final def tree(): Tree = {
- val (casted, srep, fail) = this.getTransition
- val castedTpe = checkErroneous(casted)
- val cond = condition(castedTpe, scrut)
- val succ = srep.toTree
-
- // dig out case field accessors that were buried in (***)
- val cfa = if (pmatch.isCaseHead) casted.accessors else Nil
- val caseTemps = srep.tvars match { case x :: xs if x == casted.sym => xs ; case x => x }
- def castedScrut = typedValDef(casted.sym, scrut.id AS_ANY castedTpe)
- def needCast = if (casted.sym ne scrut.sym) List(castedScrut) else Nil
-
- val vdefs = needCast ::: (
- for ((tmp, accessor) <- caseTemps zip cfa) yield
- typedValDef(tmp, fn(casted.id, accessor))
- )
-
- IF (cond) THEN squeezedBlock(vdefs, succ) ELSE fail
- }
+ final def tree(): Tree = codegen
}
/*** States, Rows, Etc. ***/
@@ -795,7 +784,7 @@ trait ParallelMatching extends ast.TreeDSL
}
/** Converts this to a tree - recursively acquires subreps. */
- final def toTree(): Tree = typer typed this.applyRule.tree()
+ final def toTree(): Tree = typer typed this.applyRule()
/** The VariableRule. */
private def variable() = {
@@ -812,12 +801,12 @@ trait ParallelMatching extends ast.TreeDSL
*
* VariableRule - if all patterns are default patterns
* MixtureRule - if one or more patterns are not default patterns
- * ErrorRule - if there are no rows remaining
+ * Error - no rows remaining
*/
- final def applyRule(): RuleApplication =
- if (rows.isEmpty) ErrorRule()
- else if (others.isEmpty) variable()
- else mixture()
+ final def applyRule(): Tree =
+ if (rows.isEmpty) failTree
+ else if (others.isEmpty) variable.tree()
+ else mixture.tree()
override def toString() =
if (tvars.size == 0) "Rep(%d) = %s".format(rows.size, pp(rows))
diff --git a/src/compiler/scala/tools/nsc/matching/TransMatcher.scala b/src/compiler/scala/tools/nsc/matching/TransMatcher.scala
index 4d8315530f..b73acb8004 100644
--- a/src/compiler/scala/tools/nsc/matching/TransMatcher.scala
+++ b/src/compiler/scala/tools/nsc/matching/TransMatcher.scala
@@ -56,7 +56,7 @@ trait TransMatcher extends ast.TreeDSL with CompactTreePrinter {
// For x match { ... we start with a single root
def singleMatch(): (List[Tree], MatrixInit) = {
- val v = copyVar(selector, checked = isChecked)
+ val v = copyVar(selector, isChecked)
(List(v.valDef), MatrixInit(List(v.lhs), cases, matchError(v.ident)))
}
@@ -64,7 +64,7 @@ trait TransMatcher extends ast.TreeDSL with CompactTreePrinter {
// For (x, y, z) match { ... we start with multiple roots, called tpXX.
def tupleMatch(app: Apply): (List[Tree], MatrixInit) = {
val Apply(fn, args) = app
- val vs = args zip rootTypes map { case (arg, tpe) => copyVar(arg, tpe, isChecked, "tp") }
+ val vs = args zip rootTypes map { case (arg, tpe) => copyVar(arg, isChecked, tpe, "tp") }
def merror = matchError(treeCopy.Apply(app, fn, vs map (_.ident)))
(vs map (_.valDef), MatrixInit(vs map (_.lhs), cases, merror))