summaryrefslogtreecommitdiff
path: root/src/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'src/compiler')
-rw-r--r--src/compiler/scala/tools/nsc/matching/ParallelMatching.scala76
-rw-r--r--src/compiler/scala/tools/nsc/matching/PatternBindings.scala12
-rw-r--r--src/compiler/scala/tools/nsc/matching/Patterns.scala6
3 files changed, 47 insertions, 47 deletions
diff --git a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala
index 5b946cca6b..54ad0354da 100644
--- a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala
+++ b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala
@@ -221,20 +221,16 @@ trait ParallelMatching extends ast.TreeDSL
/***** Rule Applications *****/
- trait RuleApplicationFormal extends RuleApplication {
- def cond: Tree
- def success: Tree
- def failure: Tree
-
- def codegen: Tree = IF (cond) THEN (success) ELSE (failure)
- }
-
sealed abstract class RuleApplication {
def pmatch: PatternMatch
def rest: Rep
+ def cond: Tree
+ def success: Tree
+ def failure: Tree
lazy val PatternMatch(scrut, patterns) = pmatch
lazy val head = pmatch.head
+ def codegen: Tree = IF (cond) THEN (success) ELSE (failure)
def mkFail(xs: List[Row]): Tree = xs match {
case Nil => failTree
@@ -250,7 +246,7 @@ trait ParallelMatching extends ast.TreeDSL
/** {case ... if guard => bx} else {guardedRest} */
/** VariableRule: The top-most rows has only variable (non-constructor) patterns. */
- case class VariableRule(subst: Bindings, guard: Guard, guardedRest: Rep, bx: Int) extends RuleApplicationFormal {
+ case class VariableRule(subst: Bindings, guard: Guard, guardedRest: Rep, bx: Int) extends RuleApplication {
def pmatch: PatternMatch = impossible
def rest: Rep = guardedRest
@@ -272,7 +268,7 @@ trait ParallelMatching extends ast.TreeDSL
val literals = pmatch.ps
val defaultPattern = pmatch.defaultPattern
- private lazy val target: Tree =
+ private lazy val casted: Tree =
if (!scrut.tpe.isInt) scrut.id DOT nme.toInt else scrut.id
// creates a row transformer for injecting the default case bindings at a given index
@@ -324,16 +320,21 @@ trait ParallelMatching extends ast.TreeDSL
lazy val defaultTree = make(rest.tvars, defaultRows).toTree
def casesWithDefault = cases ::: List(CASE(WILD(IntClass.tpe)) ==> defaultTree)
- // only one case becomes if/else, otherwise match
- def tree() = cases match {
- case List(CaseDef(lit, _, body)) => IF (scrut.id MEMBER_== lit) THEN body ELSE defaultTree
- case _ => target MATCH (casesWithDefault: _*)
+ // cond/success/failure only used if there is exactly one case.
+ lazy val (cond, success) = cases match {
+ case List(CaseDef(lit, _, body)) => (scrut.id MEMBER_== lit, body)
}
+ lazy val failure = defaultTree
+
+ // only one case becomes if/else, otherwise match
+ def tree() =
+ if (cases.size == 1) codegen
+ else casted MATCH (casesWithDefault: _*)
}
/** mixture rule for unapply pattern
*/
- class MixUnapply(val pmatch: PatternMatch, val rest: Rep, typeTest: Boolean) extends RuleApplicationFormal {
+ class MixUnapply(val pmatch: PatternMatch, val rest: Rep, typeTest: Boolean) extends RuleApplication {
val uapattern = head match { case x: UnapplyPattern => x ; case _ => abort("XXX") }
val ua @ UnApply(app, args) = head.tree
@@ -418,18 +419,11 @@ trait ParallelMatching extends ast.TreeDSL
/** handle sequence pattern and ArrayValue (but not star patterns)
*/
- sealed class MixSequence(val pmatch: PatternMatch, val rest: Rep) extends RuleApplicationFormal {
+ sealed class MixSequence(val pmatch: PatternMatch, val rest: Rep) extends RuleApplication {
// Called 'pivot' since it's the head of the column under consideration in the mixture rule.
val pivot @ SequencePattern(av @ ArrayValue(_, _)) = head
private def pivotLen = pivot.nonStarLength
- /** array elements except for star (if present) */
- protected def nonStarElems(x: ArrayValue) =
- if (isRightIgnoring(x)) x.elems.init else x.elems
-
- final def removeStar(xs: List[Pattern]): List[Pattern] =
- xs.init ::: List(Pattern(makeBind(xs.last.boundVariables, WILD(scrut.seqType))))
-
def mustCheck(first: Pattern, next: Pattern): Boolean = {
if (first.tree eq next.tree)
return false
@@ -437,21 +431,29 @@ trait ParallelMatching extends ast.TreeDSL
!(first completelyCovers next)
}
- def getSubPatterns(x: Pattern): Option[List[Pattern]] = condOpt(x.tree) {
- case av @ ArrayValue(_, xs) if nonStarElems(av).length == pivotLen =>
- val (star1, star2) = (pivot.hasStar, isRightIgnoring(av))
+ def getSubPatterns(x: Pattern): Option[List[Pattern]] = {
+ def defaults = Some(emptyPatterns(pivot.elemPatterns.length + 1))
+ val av @ ArrayValue(_, xs) = x.tree match {
+ case x: ArrayValue => x
+ case EmptyTree | WILD() => return defaults
+ case _ => return None
+ }
- (star1, star2) match {
- case (true, true) => removeStar(toPats(xs)) ::: List(NoPattern)
+ val sp = x.asInstanceOf[SequencePattern]
+ val (star1, star2) = (pivot.hasStar, sp.hasStar)
+
+ if (sp.nonStarLength == pivotLen) {
+ Some((star1, star2) match {
+ case (true, true) => (sp rebindStar scrut.seqType) ::: List(NoPattern)
case (true, false) => toPats(xs ::: List(gen.mkNil, EmptyTree))
- case (false, true) => removeStar(toPats(xs))
+ case (false, true) => (sp rebindStar scrut.seqType)
case (false, false) => toPats(xs) ::: List(NoPattern)
- }
- case av @ ArrayValue(_, xs) if pivot.hasStar && isRightIgnoring(av) && xs.length-1 < pivotLen =>
- emptyPatterns(pivotLen + 1) ::: List(x)
-
- case EmptyTree | WILD() =>
- emptyPatterns(pivot.elemPatterns.length + 1)
+ })
+ }
+ else if (pivot.hasStar && isRightIgnoring(av) && xs.length-1 < pivotLen)
+ Some(emptyPatterns(pivotLen + 1) ::: List(x))
+ else
+ defaults // XXX
}
lazy val cond = (pivot precondition pmatch).get
@@ -487,7 +489,7 @@ trait ParallelMatching extends ast.TreeDSL
}
// @todo: equals test for same constant
- class MixEquals(val pmatch: PatternMatch, val rest: Rep) extends RuleApplicationFormal {
+ class MixEquals(val pmatch: PatternMatch, val rest: Rep) extends RuleApplication {
private def mkNewRep(rows: List[Row]) =
make(scrut.sym :: rest.tvars, rows).toTree
@@ -520,7 +522,7 @@ trait ParallelMatching extends ast.TreeDSL
/** mixture rule for type tests
**/
- class MixTypes(val pmatch: PatternMatch, val rest: Rep) extends RuleApplicationFormal {
+ class MixTypes(val pmatch: PatternMatch, val rest: Rep) extends RuleApplication {
// see bug1434.scala for an illustration of why "x <:< y" is insufficient.
// this code is definitely inadequate at best. Inherited comment:
//
diff --git a/src/compiler/scala/tools/nsc/matching/PatternBindings.scala b/src/compiler/scala/tools/nsc/matching/PatternBindings.scala
index 991b94330f..de622b91a2 100644
--- a/src/compiler/scala/tools/nsc/matching/PatternBindings.scala
+++ b/src/compiler/scala/tools/nsc/matching/PatternBindings.scala
@@ -43,11 +43,6 @@ trait PatternBindings extends ast.TreeDSL
}
}
- def makeBind(vs: List[Symbol], pat: Tree): Tree = vs match {
- case Nil => pat
- case x :: xs => Bind(x, makeBind(xs, pat)) setType pat.tpe
- }
-
trait PatternBindingLogic {
self: Pattern =>
@@ -73,8 +68,12 @@ trait PatternBindings extends ast.TreeDSL
// Bind(v3, Bind(v2, Bind(v1, tree)))
// This takes the given tree and creates a new pattern
// using the same bindings.
- def rebindTo(t: Tree): Pattern =
+ def rebindTo(t: Tree): Pattern = {
+ if (boundVariables.size < definedVars.size)
+ TRACE("In %s, boundVariables = %s but definedVars = %s", this, boundVariables, definedVars)
+
Pattern(wrapBindings(boundVariables, t))
+ }
// Wrap this pattern's bindings around (_: Type)
def rebindToType(tpe: Type, annotatedType: Type = null): Pattern = {
@@ -98,7 +97,6 @@ trait PatternBindings extends ast.TreeDSL
}
/** Helpers **/
-
private def wrapBindings(vs: List[Symbol], pat: Tree): Tree = vs match {
case Nil => pat
case x :: xs => Bind(x, wrapBindings(xs, pat)) setType pat.tpe
diff --git a/src/compiler/scala/tools/nsc/matching/Patterns.scala b/src/compiler/scala/tools/nsc/matching/Patterns.scala
index 573adee60b..647d6c92d8 100644
--- a/src/compiler/scala/tools/nsc/matching/Patterns.scala
+++ b/src/compiler/scala/tools/nsc/matching/Patterns.scala
@@ -163,7 +163,7 @@ trait Patterns extends ast.TreeDSL {
val listRef = typeRef(pre, ListClass, List(tpe))
def fold(x: Tree, xs: Tree) = unbind(x) match {
- case _: Star => makeBind(Pattern(x).definedVars, WILD(x.tpe))
+ case _: Star => Pattern(x) rebindTo WILD(x.tpe) boundTree // this is using boundVariables instead of definedVars
case _ =>
val dummyMethod = new TermSymbol(NoSymbol, NoPosition, "matching$dummy")
val consType = MethodType(dummyMethod newSyntheticValueParams List(tpe, listRef), consRef)
@@ -187,6 +187,7 @@ trait Patterns extends ast.TreeDSL {
lazy val ArrayValue(elemtpt, elems) = tree
lazy val elemPatterns = toPats(elems)
lazy val nonStarPatterns = if (hasStar) elemPatterns.init else elemPatterns
+ private def starPattern = elemPatterns.last
override def subpatternsForVars: List[Pattern] = elemPatterns
@@ -196,7 +197,7 @@ trait Patterns extends ast.TreeDSL {
def rebindStar(seqType: Type): List[Pattern] = {
require(hasStar)
- nonStarPatterns ::: List(elemPatterns.last rebindToType seqType)
+ nonStarPatterns ::: List(starPattern rebindTo WILD(seqType))
}
// optimization to avoid trying to match if length makes it impossible
@@ -484,7 +485,6 @@ trait Patterns extends ast.TreeDSL {
if (boundTree eq tree) Pattern(tree)
else Pattern(tree) withBoundTree boundTree.asInstanceOf[Bind]
- // override def toString() = "Pattern(%s, %s)".format(tree, boundVariables)
override def equals(other: Any) = other match {
case x: Pattern => this.boundTree == x.boundTree
case _ => super.equals(other)