summaryrefslogtreecommitdiff
path: root/src/compiler
diff options
context:
space:
mode:
authorPaul Phillips <paulp@improving.org>2009-10-14 21:57:35 +0000
committerPaul Phillips <paulp@improving.org>2009-10-14 21:57:35 +0000
commit9d9889a7d6b9625aff6ee9ef72850cd9c9e7c17c (patch)
tree8b8364894c6eea0db2ce89b2b69c02872005ac88 /src/compiler
parent1747692434cece862d63a0f67decd810707b1508 (diff)
downloadscala-9d9889a7d6b9625aff6ee9ef72850cd9c9e7c17c.tar.gz
scala-9d9889a7d6b9625aff6ee9ef72850cd9c9e7c17c.tar.bz2
scala-9d9889a7d6b9625aff6ee9ef72850cd9c9e7c17c.zip
A hard fought distillation of sequence patterns.
I can fix #1697 without making other things break (fix not included but should be forthcoming.)
Diffstat (limited to 'src/compiler')
-rw-r--r--src/compiler/scala/tools/nsc/matching/ParallelMatching.scala100
-rw-r--r--src/compiler/scala/tools/nsc/matching/Patterns.scala143
2 files changed, 107 insertions, 136 deletions
diff --git a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala
index 6c8c2c5956..c5fd840973 100644
--- a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala
+++ b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala
@@ -189,7 +189,7 @@ trait ParallelMatching extends ast.TreeDSL
// Any unapply - returns Some(true) if a type test is needed before the unapply can
// be called (e.g. def unapply(x: Foo) = { ... } but our scrutinee is type Any.)
object AnyUnapply {
- def unapply(x: Tree): Option[Boolean] = condOpt(x) {
+ def unapply(x: Pattern): Option[Boolean] = condOpt(x.tree) {
case UnapplyParamType(tpe) => !(scrut.tpe <:< tpe)
}
}
@@ -201,18 +201,16 @@ trait ParallelMatching extends ast.TreeDSL
}
def mkRule(rest: Rep): RuleApplication = {
- tracing("Rule", head.tree match {
- case x if isEquals(x.tpe) => new MixEquals(this, rest)
- case x: ArrayValue => new MixSequence(this, rest)
- case AnyUnapply(false) => new MixUnapply(this, rest, false)
- // case TypedUnapply(needsTest) =>
- case _ =>
- isPatternSwitch(scrut, ps) match {
- case Some(x) => new MixLiteralInts(x, rest)
- case _ => new MixTypes(this, rest)
- }
- }
- )
+ tracing("Rule", head match {
+ case x if isEquals(x.tree.tpe) => new MixEquals(this, rest)
+ case x: SequencePattern => new MixSequence(this, rest, x)
+ case AnyUnapply(false) => new MixUnapply(this, rest, false)
+ case _ =>
+ isPatternSwitch(scrut, ps) match {
+ case Some(x) => new MixLiteralInts(x, rest)
+ case _ => new MixTypes(this, rest)
+ }
+ })
}
override def toString() = "%s match {%s}".format(scrut, indentAll(ps))
} // PatternMatch
@@ -441,69 +439,37 @@ trait ParallelMatching extends ast.TreeDSL
squeezedBlock(List(handleOuter(unapplyResult.valDef)), codegen)
}
- /** handle sequence pattern and ArrayValue (but not star patterns)
+ /** Handle Sequence patterns (including Star patterns.)
+ * Note: pivot == head, just better typed.
*/
- sealed class MixSequence(val pmatch: PatternMatch, val rest: Rep) extends RuleApplication {
- // Called 'pivot' since it's the head of the column under consideration in the mixture rule.
- val pivot @ SequencePattern(av @ ArrayValue(_, _)) = head
+ sealed class MixSequence(val pmatch: PatternMatch, val rest: Rep, pivot: SequencePattern) extends RuleApplication {
+ def hasStar = pivot.hasStar
private def pivotLen = pivot.nonStarLength
- def mustCheck(first: Pattern, next: Pattern): Boolean = {
- if (first.tree eq next.tree)
- return false
-
- !(first completelyCovers next)
- }
+ // one pattern var per sequence element up to elemCount, and one more for the rest of the sequence
+ lazy val pvs = scrut createSequenceVars pivotLen
- def getSubPatterns(x: Pattern): Option[List[Pattern]] = {
- // def defaults = Some(emptyPatterns(pivot.elemPatterns.length + 1))
- def defaults = Some(pivot.dummies)
- val av @ ArrayValue(_, xs) = x.tree match {
- case x: ArrayValue => x
- case EmptyTree | WILD() => return defaults
- case _ => return None
- }
+ // divide the remaining rows into success/failure branches, expanding subsequences of patterns
+ private lazy val rowsplit = {
+ require(scrut.tpe <:< head.tpe)
- val sp = x.asInstanceOf[SequencePattern]
- val (star1, star2) = (pivot.hasStar, sp.hasStar)
+ List.unzip(
+ for ((c, rows) <- pmatch pzip rest.rows) yield {
+ def canSkip = pivot canSkipSubsequences c
+ def passthrough(skip: Boolean) = if (skip) None else Some(rows insert c)
- if (sp.nonStarLength == pivotLen) {
- Some((star1, star2) match {
- case (true, true) => (sp rebindStar scrut.seqType) ::: List(NoPattern)
- case (true, false) => toPats(xs ::: List(gen.mkNil, EmptyTree))
- case (false, true) => (sp rebindStar scrut.seqType)
- case (false, false) => toPats(xs) ::: List(NoPattern)
- })
- }
- else if (pivot.hasStar && sp.hasStar && xs.length-1 < pivotLen)
- Some(emptyPatterns(pivotLen + 1) ::: List(x))
- else
- defaults // XXX
- }
-
- lazy val cond = (pivot precondition pmatch).get
-
- lazy val (success, failure) = {
- assert(scrut.tpe <:< head.tpe, "fatal: %s is not <:< %s".format(scrut, head.tpe))
-
- // one pattern var per sequence element up to elemCount, and one more for the rest of the sequence
- val pvs = scrut createSequenceVars pivot.nonStarPatterns.size
-
- val (nrows, frows): (List[Option[Row]], List[Option[Row]]) = List.unzip(
- for ((c, rows) <- pmatch pzip rest.rows) yield getSubPatterns(c) match {
- case Some(ps) => (Some(rows insert ps), if (mustCheck(head, c)) Some(rows insert c) else None)
- case None => (None, Some(rows insert c))
+ pivot.subsequences(c, scrut.seqType) match {
+ case Some(ps) => (Some(rows insert ps), passthrough(canSkip))
+ case None => (None, passthrough(false))
+ }
}
- )
-
- val succ = remake(nrows.flatten, pvs, includeScrut = pivot.hasStar)
-
- (
- squeezedBlockPVs(pvs, succ.toTree),
- remake(frows.flatten).toTree
- )
+ ) match { case (l1, l2) => (l1.flatten, l2.flatten) }
}
+ lazy val cond = (pivot precondition pmatch).get // length check
+ lazy val success = squeezedBlockPVs(pvs, remake(rowsplit._1, pvs, hasStar).toTree)
+ lazy val failure = remake(rowsplit._2).toTree
+
final def tree(): Tree = codegen
}
diff --git a/src/compiler/scala/tools/nsc/matching/Patterns.scala b/src/compiler/scala/tools/nsc/matching/Patterns.scala
index e2ef626751..ffbe8191e7 100644
--- a/src/compiler/scala/tools/nsc/matching/Patterns.scala
+++ b/src/compiler/scala/tools/nsc/matching/Patterns.scala
@@ -22,7 +22,7 @@ trait Patterns extends ast.TreeDSL {
import definitions._
import CODE._
import Debug._
- import treeInfo.{ unbind, isVarPattern, isVariableName }
+ import treeInfo.{ unbind, isStar, isVarPattern, isVariableName }
type PatternMatch = MatchMatrix#PatternMatch
private type PatternVar = MatrixContext#PatternVar
@@ -37,6 +37,9 @@ trait Patterns extends ast.TreeDSL {
// The constant null pattern
def NullPattern = LiteralPattern(NULL)
+ // The Nil pattern
+ def NilPattern = Pattern(gen.mkNil)
+
// 8.1.1
case class VariablePattern(tree: Ident) extends NamePattern {
val Ident(name) = tree
@@ -208,35 +211,35 @@ trait Patterns extends ast.TreeDSL {
override def description = "UnSeq(%s => %s)".format(tptArg, resTypesString)
}
- // 8.1.8 (b) (literal ArrayValues)
- case class SequencePattern(tree: ArrayValue) extends Pattern {
+ abstract class SequencePattern extends Pattern {
+ val tree: ArrayValue
+ def nonStarPatterns: List[Pattern]
+ def subsequences(other: Pattern, seqType: Type): Option[List[Pattern]]
+ def canSkipSubsequences(second: Pattern): Boolean
+
lazy val ArrayValue(elemtpt, elems) = tree
lazy val elemPatterns = toPats(elems)
- lazy val nonStarPatterns = if (hasStar) elemPatterns.init else elemPatterns
- private def lastPattern = elemPatterns.last
+ override def dummies = emptyPatterns(elems.length + 1)
override def subpatternsForVars: List[Pattern] = elemPatterns
def nonStarLength = nonStarPatterns.length
def isAllDefaults = nonStarPatterns forall (_.isDefault)
- override def dummies = emptyPatterns(elemPatterns.length + 1)
+ def isShorter(other: SequencePattern) = nonStarLength < other.nonStarLength
+ def isSameLength(other: SequencePattern) = nonStarLength == other.nonStarLength
- def rebindStar(seqType: Type): List[Pattern] = {
- require(hasStar)
- nonStarPatterns ::: List(lastPattern rebindTo WILD(seqType))
- }
+ protected def lengthCheckOp: (Tree, Tree) => Tree =
+ if (hasStar) _ ANY_>= _
+ else _ MEMBER_== _
// optimization to avoid trying to match if length makes it impossible
override def precondition(pm: PatternMatch) = {
import pm.{ scrut, head }
val len = nonStarLength
- val compareOp = head.tpe member nme.lengthCompare // symbol for "lengthCompare" method
- val op: (Tree, Tree) => Tree =
- if (hasStar) _ ANY_>= _
- else _ MEMBER_== _
+ val compareOp = head.tpe member nme.lengthCompare
- def cmpFunction(t1: Tree) = op((t1 DOT compareOp)(LIT(len)), ZERO)
+ def cmpFunction(t1: Tree) = lengthCheckOp((t1 DOT compareOp)(LIT(len)), ZERO)
Some(nullSafe(cmpFunction _, FALSE)(scrut.id))
}
@@ -245,41 +248,54 @@ trait Patterns extends ast.TreeDSL {
* (the conditional supplied by getPrecondition.) This is an optimization to avoid checking sequences
* which cannot match due to a length incompatibility.
*/
- override def completelyCovers(second: Pattern): Boolean = {
- if (second.isDefault) return false
-
- second match {
- case x: SequencePattern =>
- val (len1, len2) = (nonStarLength, x.nonStarLength)
- val (star1, star2) = (this.hasStar, x.hasStar)
-
- // this still needs rewriting.
- val res =
- ( star1 && star2 && len2 < len1 ) || // Seq(a,b,c,_*) followed by Seq(a,b,_*) because of (a,b)
- ( star1 && !star2 && len2 < len1 && isAllDefaults ) || // Seq(a,b,c,_*) followed by Seq(a,b) because of (a,b)
- // ( star1 && len2 < len1 ) ||
- (!star1 && star2 ) ||
- (!star1 && !star2 && len2 >= len1 )
-
- !res
- case _ =>
- // shouldn't happen...
- false
- }
- }
override def description = "Seq(%s)".format(elemPatterns)
}
+ // 8.1.8 (b) (literal ArrayValues)
+ case class SequenceNoStarPattern(tree: ArrayValue) extends SequencePattern {
+ require(!hasStar)
+ lazy val nonStarPatterns = elemPatterns
+
+ // no star
+ def subsequences(other: Pattern, seqType: Type): Option[List[Pattern]] =
+ condOpt(other) {
+ case next: SequenceStarPattern if isSameLength(next) => next rebindStar seqType
+ case next: SequenceNoStarPattern if isSameLength(next) => next.elemPatterns ::: List(NoPattern)
+ case WildcardPattern() | (_: SequencePattern) => dummies
+ }
+
+ def canSkipSubsequences(second: Pattern): Boolean =
+ (tree eq second.tree) || (cond(second) {
+ case x: SequenceNoStarPattern => (x isShorter this) && this.isAllDefaults
+ })
+ }
+
// 8.1.8 (b)
- // temporarily subsumed by SequencePattern
- // case class SequenceStarPattern(tree: ArrayValue) extends Pattern { }
+ case class SequenceStarPattern(tree: ArrayValue) extends SequencePattern {
+ require(hasStar)
+ lazy val nonStarPatterns = elemPatterns.init
+
+ // yes star
+ private def nilPats = List(NilPattern, NoPattern)
+ def subsequences(other: Pattern, seqType: Type): Option[List[Pattern]] =
+ condOpt(other) {
+ case next: SequenceStarPattern if isSameLength(next) => (next rebindStar seqType) ::: List(NoPattern)
+ case next: SequenceStarPattern if (next isShorter this) => (dummies drop 1) ::: List(next)
+ case next: SequenceNoStarPattern if isSameLength(next) => next.elemPatterns ::: nilPats
+ case WildcardPattern() | (_: SequencePattern) => dummies
+ }
- // abstract trait ArrayValuePattern extends Pattern {
- // val tree: ArrayValue
- // lazy val av @ ArrayValue(elemTpt, elems) = tree
- // lazy val elemPatterns = toPats(elems)
- // def nonStarElems = if (isRightIgnoring) elems.init else elems
- // }
+ def rebindStar(seqType: Type): List[Pattern] =
+ nonStarPatterns ::: List(elemPatterns.last rebindTo WILD(seqType))
+
+ def canSkipSubsequences(second: Pattern): Boolean =
+ (tree eq second.tree) || (cond(second) {
+ case x: SequenceStarPattern => this isShorter x
+ case x: SequenceNoStarPattern => !(x isShorter this)
+ })
+
+ override def description = "Seq*(%s)".format(elemPatterns)
+ }
// 8.1.8 (c)
case class StarPattern(tree: Star) extends Pattern {
@@ -313,6 +329,16 @@ trait Patterns extends ast.TreeDSL {
// a small tree -> pattern cache
private val cache = new collection.mutable.HashMap[Tree, Pattern]
+ def unadorn(x: Tree): Tree = x match {
+ case Typed(expr, _) => unadorn(expr)
+ case Bind(_, x) => unadorn(x)
+ case _ => x
+ }
+
+ def isRightIgnoring(t: Tree) = cond(unadorn(t)) {
+ case ArrayValue(_, xs) if !xs.isEmpty => isStar(unadorn(xs.last))
+ }
+
def apply(tree: Tree): Pattern = {
if (cache contains tree)
return cache(tree)
@@ -327,8 +353,7 @@ trait Patterns extends ast.TreeDSL {
case x: Literal => LiteralPattern(x)
case x: UnApply => UnapplyPattern(x)
case x: Ident => if (isVarPattern(x)) VariablePattern(x) else SimpleIdPattern(x)
- // case x: ArrayValue => if (isRightIgnoring(x)) SequenceStarPattern(x) else SequencePattern(x)
- case x: ArrayValue => SequencePattern(x)
+ case x: ArrayValue => if (isRightIgnoring(x)) SequenceStarPattern(x) else SequenceNoStarPattern(x)
case x: Select => StableIdPattern(x)
case x: Star => StarPattern(x)
case x: This => ThisPattern(x) // XXX ?
@@ -521,15 +546,12 @@ trait Patterns extends ast.TreeDSL {
def isCaseClass = tpe.typeSymbol hasFlag Flags.CASE
def isObject = isSymValid && prefix.isStable // XXX not entire logic
- def unadorn(x: Tree): Tree = x match {
- case Typed(expr, _) => unadorn(expr)
- case Bind(_, x) => unadorn(x)
- case _ => x
- }
+ def unadorn(t: Tree): Tree = Pattern unadorn t
private def isStar(x: Tree) = cond(unadorn(x)) { case Star(_) => true }
private def endsStar(xs: List[Tree]) = xs.nonEmpty && isStar(xs.last)
+ def isStarSequence = isSequence && hasStar
def isSequence = cond(unadorn(tree)) {
case Sequence(xs) => true
case ArrayValue(tpt, xs) => true
@@ -577,21 +599,4 @@ trait Patterns extends ast.TreeDSL {
}
}
}
-
- // object SeqStarSubPatterns {
- // def removeStar(xs: List[Tree], seqType: Type): List[Pattern] = {
- // val ps = toPats(xs)
- // ps.init ::: List(ps.last rebindToType seqType)
- // }
- //
- // def unapply(x: Pattern)(implicit min: Int, seqType: Type): Option[List[Pattern]] = x.tree match {
- // case av @ ArrayValue(_, xs) =>
- // if (!isRightIgnoring(av) && xs.length == min) Some(toPats(xs ::: List(gen.mkNil, EmptyTree))) // Seq(p1,...,pN)
- // else if ( isRightIgnoring(av) && xs.length-1 == min) Some(removeStar(xs, seqType) ::: List(NoPattern)) // Seq(p1,...,pN,_*)
- // else if ( isRightIgnoring(av) && xs.length-1 < min) Some(emptyPatterns(min + 1) ::: List(x)) // Seq(p1..,pJ,_*) J < N
- // else None
- // case _ =>
- // if (x.isDefault) Some(emptyPatterns(min + 1 + 1)) else None
- // }
- // }
} \ No newline at end of file