From 9d9889a7d6b9625aff6ee9ef72850cd9c9e7c17c Mon Sep 17 00:00:00 2001 From: Paul Phillips Date: Wed, 14 Oct 2009 21:57:35 +0000 Subject: A hard fought distillation of sequence patterns. I can fix #1697 without making other things break (fix not included but should be forthcoming.) --- .../tools/nsc/matching/ParallelMatching.scala | 100 +++++--------- .../scala/tools/nsc/matching/Patterns.scala | 143 +++++++++++---------- 2 files changed, 107 insertions(+), 136 deletions(-) (limited to 'src') diff --git a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala index 6c8c2c5956..c5fd840973 100644 --- a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala +++ b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala @@ -189,7 +189,7 @@ trait ParallelMatching extends ast.TreeDSL // Any unapply - returns Some(true) if a type test is needed before the unapply can // be called (e.g. def unapply(x: Foo) = { ... } but our scrutinee is type Any.) object AnyUnapply { - def unapply(x: Tree): Option[Boolean] = condOpt(x) { + def unapply(x: Pattern): Option[Boolean] = condOpt(x.tree) { case UnapplyParamType(tpe) => !(scrut.tpe <:< tpe) } } @@ -201,18 +201,16 @@ trait ParallelMatching extends ast.TreeDSL } def mkRule(rest: Rep): RuleApplication = { - tracing("Rule", head.tree match { - case x if isEquals(x.tpe) => new MixEquals(this, rest) - case x: ArrayValue => new MixSequence(this, rest) - case AnyUnapply(false) => new MixUnapply(this, rest, false) - // case TypedUnapply(needsTest) => - case _ => - isPatternSwitch(scrut, ps) match { - case Some(x) => new MixLiteralInts(x, rest) - case _ => new MixTypes(this, rest) - } - } - ) + tracing("Rule", head match { + case x if isEquals(x.tree.tpe) => new MixEquals(this, rest) + case x: SequencePattern => new MixSequence(this, rest, x) + case AnyUnapply(false) => new MixUnapply(this, rest, false) + case _ => + isPatternSwitch(scrut, ps) match { + case Some(x) => new MixLiteralInts(x, rest) + case _ => new MixTypes(this, rest) + } + }) } override def toString() = "%s match {%s}".format(scrut, indentAll(ps)) } // PatternMatch @@ -441,69 +439,37 @@ trait ParallelMatching extends ast.TreeDSL squeezedBlock(List(handleOuter(unapplyResult.valDef)), codegen) } - /** handle sequence pattern and ArrayValue (but not star patterns) + /** Handle Sequence patterns (including Star patterns.) + * Note: pivot == head, just better typed. */ - sealed class MixSequence(val pmatch: PatternMatch, val rest: Rep) extends RuleApplication { - // Called 'pivot' since it's the head of the column under consideration in the mixture rule. - val pivot @ SequencePattern(av @ ArrayValue(_, _)) = head + sealed class MixSequence(val pmatch: PatternMatch, val rest: Rep, pivot: SequencePattern) extends RuleApplication { + def hasStar = pivot.hasStar private def pivotLen = pivot.nonStarLength - def mustCheck(first: Pattern, next: Pattern): Boolean = { - if (first.tree eq next.tree) - return false - - !(first completelyCovers next) - } + // one pattern var per sequence element up to elemCount, and one more for the rest of the sequence + lazy val pvs = scrut createSequenceVars pivotLen - def getSubPatterns(x: Pattern): Option[List[Pattern]] = { - // def defaults = Some(emptyPatterns(pivot.elemPatterns.length + 1)) - def defaults = Some(pivot.dummies) - val av @ ArrayValue(_, xs) = x.tree match { - case x: ArrayValue => x - case EmptyTree | WILD() => return defaults - case _ => return None - } + // divide the remaining rows into success/failure branches, expanding subsequences of patterns + private lazy val rowsplit = { + require(scrut.tpe <:< head.tpe) - val sp = x.asInstanceOf[SequencePattern] - val (star1, star2) = (pivot.hasStar, sp.hasStar) + List.unzip( + for ((c, rows) <- pmatch pzip rest.rows) yield { + def canSkip = pivot canSkipSubsequences c + def passthrough(skip: Boolean) = if (skip) None else Some(rows insert c) - if (sp.nonStarLength == pivotLen) { - Some((star1, star2) match { - case (true, true) => (sp rebindStar scrut.seqType) ::: List(NoPattern) - case (true, false) => toPats(xs ::: List(gen.mkNil, EmptyTree)) - case (false, true) => (sp rebindStar scrut.seqType) - case (false, false) => toPats(xs) ::: List(NoPattern) - }) - } - else if (pivot.hasStar && sp.hasStar && xs.length-1 < pivotLen) - Some(emptyPatterns(pivotLen + 1) ::: List(x)) - else - defaults // XXX - } - - lazy val cond = (pivot precondition pmatch).get - - lazy val (success, failure) = { - assert(scrut.tpe <:< head.tpe, "fatal: %s is not <:< %s".format(scrut, head.tpe)) - - // one pattern var per sequence element up to elemCount, and one more for the rest of the sequence - val pvs = scrut createSequenceVars pivot.nonStarPatterns.size - - val (nrows, frows): (List[Option[Row]], List[Option[Row]]) = List.unzip( - for ((c, rows) <- pmatch pzip rest.rows) yield getSubPatterns(c) match { - case Some(ps) => (Some(rows insert ps), if (mustCheck(head, c)) Some(rows insert c) else None) - case None => (None, Some(rows insert c)) + pivot.subsequences(c, scrut.seqType) match { + case Some(ps) => (Some(rows insert ps), passthrough(canSkip)) + case None => (None, passthrough(false)) + } } - ) - - val succ = remake(nrows.flatten, pvs, includeScrut = pivot.hasStar) - - ( - squeezedBlockPVs(pvs, succ.toTree), - remake(frows.flatten).toTree - ) + ) match { case (l1, l2) => (l1.flatten, l2.flatten) } } + lazy val cond = (pivot precondition pmatch).get // length check + lazy val success = squeezedBlockPVs(pvs, remake(rowsplit._1, pvs, hasStar).toTree) + lazy val failure = remake(rowsplit._2).toTree + final def tree(): Tree = codegen } diff --git a/src/compiler/scala/tools/nsc/matching/Patterns.scala b/src/compiler/scala/tools/nsc/matching/Patterns.scala index e2ef626751..ffbe8191e7 100644 --- a/src/compiler/scala/tools/nsc/matching/Patterns.scala +++ b/src/compiler/scala/tools/nsc/matching/Patterns.scala @@ -22,7 +22,7 @@ trait Patterns extends ast.TreeDSL { import definitions._ import CODE._ import Debug._ - import treeInfo.{ unbind, isVarPattern, isVariableName } + import treeInfo.{ unbind, isStar, isVarPattern, isVariableName } type PatternMatch = MatchMatrix#PatternMatch private type PatternVar = MatrixContext#PatternVar @@ -37,6 +37,9 @@ trait Patterns extends ast.TreeDSL { // The constant null pattern def NullPattern = LiteralPattern(NULL) + // The Nil pattern + def NilPattern = Pattern(gen.mkNil) + // 8.1.1 case class VariablePattern(tree: Ident) extends NamePattern { val Ident(name) = tree @@ -208,35 +211,35 @@ trait Patterns extends ast.TreeDSL { override def description = "UnSeq(%s => %s)".format(tptArg, resTypesString) } - // 8.1.8 (b) (literal ArrayValues) - case class SequencePattern(tree: ArrayValue) extends Pattern { + abstract class SequencePattern extends Pattern { + val tree: ArrayValue + def nonStarPatterns: List[Pattern] + def subsequences(other: Pattern, seqType: Type): Option[List[Pattern]] + def canSkipSubsequences(second: Pattern): Boolean + lazy val ArrayValue(elemtpt, elems) = tree lazy val elemPatterns = toPats(elems) - lazy val nonStarPatterns = if (hasStar) elemPatterns.init else elemPatterns - private def lastPattern = elemPatterns.last + override def dummies = emptyPatterns(elems.length + 1) override def subpatternsForVars: List[Pattern] = elemPatterns def nonStarLength = nonStarPatterns.length def isAllDefaults = nonStarPatterns forall (_.isDefault) - override def dummies = emptyPatterns(elemPatterns.length + 1) + def isShorter(other: SequencePattern) = nonStarLength < other.nonStarLength + def isSameLength(other: SequencePattern) = nonStarLength == other.nonStarLength - def rebindStar(seqType: Type): List[Pattern] = { - require(hasStar) - nonStarPatterns ::: List(lastPattern rebindTo WILD(seqType)) - } + protected def lengthCheckOp: (Tree, Tree) => Tree = + if (hasStar) _ ANY_>= _ + else _ MEMBER_== _ // optimization to avoid trying to match if length makes it impossible override def precondition(pm: PatternMatch) = { import pm.{ scrut, head } val len = nonStarLength - val compareOp = head.tpe member nme.lengthCompare // symbol for "lengthCompare" method - val op: (Tree, Tree) => Tree = - if (hasStar) _ ANY_>= _ - else _ MEMBER_== _ + val compareOp = head.tpe member nme.lengthCompare - def cmpFunction(t1: Tree) = op((t1 DOT compareOp)(LIT(len)), ZERO) + def cmpFunction(t1: Tree) = lengthCheckOp((t1 DOT compareOp)(LIT(len)), ZERO) Some(nullSafe(cmpFunction _, FALSE)(scrut.id)) } @@ -245,41 +248,54 @@ trait Patterns extends ast.TreeDSL { * (the conditional supplied by getPrecondition.) This is an optimization to avoid checking sequences * which cannot match due to a length incompatibility. */ - override def completelyCovers(second: Pattern): Boolean = { - if (second.isDefault) return false - - second match { - case x: SequencePattern => - val (len1, len2) = (nonStarLength, x.nonStarLength) - val (star1, star2) = (this.hasStar, x.hasStar) - - // this still needs rewriting. - val res = - ( star1 && star2 && len2 < len1 ) || // Seq(a,b,c,_*) followed by Seq(a,b,_*) because of (a,b) - ( star1 && !star2 && len2 < len1 && isAllDefaults ) || // Seq(a,b,c,_*) followed by Seq(a,b) because of (a,b) - // ( star1 && len2 < len1 ) || - (!star1 && star2 ) || - (!star1 && !star2 && len2 >= len1 ) - - !res - case _ => - // shouldn't happen... - false - } - } override def description = "Seq(%s)".format(elemPatterns) } + // 8.1.8 (b) (literal ArrayValues) + case class SequenceNoStarPattern(tree: ArrayValue) extends SequencePattern { + require(!hasStar) + lazy val nonStarPatterns = elemPatterns + + // no star + def subsequences(other: Pattern, seqType: Type): Option[List[Pattern]] = + condOpt(other) { + case next: SequenceStarPattern if isSameLength(next) => next rebindStar seqType + case next: SequenceNoStarPattern if isSameLength(next) => next.elemPatterns ::: List(NoPattern) + case WildcardPattern() | (_: SequencePattern) => dummies + } + + def canSkipSubsequences(second: Pattern): Boolean = + (tree eq second.tree) || (cond(second) { + case x: SequenceNoStarPattern => (x isShorter this) && this.isAllDefaults + }) + } + // 8.1.8 (b) - // temporarily subsumed by SequencePattern - // case class SequenceStarPattern(tree: ArrayValue) extends Pattern { } + case class SequenceStarPattern(tree: ArrayValue) extends SequencePattern { + require(hasStar) + lazy val nonStarPatterns = elemPatterns.init + + // yes star + private def nilPats = List(NilPattern, NoPattern) + def subsequences(other: Pattern, seqType: Type): Option[List[Pattern]] = + condOpt(other) { + case next: SequenceStarPattern if isSameLength(next) => (next rebindStar seqType) ::: List(NoPattern) + case next: SequenceStarPattern if (next isShorter this) => (dummies drop 1) ::: List(next) + case next: SequenceNoStarPattern if isSameLength(next) => next.elemPatterns ::: nilPats + case WildcardPattern() | (_: SequencePattern) => dummies + } - // abstract trait ArrayValuePattern extends Pattern { - // val tree: ArrayValue - // lazy val av @ ArrayValue(elemTpt, elems) = tree - // lazy val elemPatterns = toPats(elems) - // def nonStarElems = if (isRightIgnoring) elems.init else elems - // } + def rebindStar(seqType: Type): List[Pattern] = + nonStarPatterns ::: List(elemPatterns.last rebindTo WILD(seqType)) + + def canSkipSubsequences(second: Pattern): Boolean = + (tree eq second.tree) || (cond(second) { + case x: SequenceStarPattern => this isShorter x + case x: SequenceNoStarPattern => !(x isShorter this) + }) + + override def description = "Seq*(%s)".format(elemPatterns) + } // 8.1.8 (c) case class StarPattern(tree: Star) extends Pattern { @@ -313,6 +329,16 @@ trait Patterns extends ast.TreeDSL { // a small tree -> pattern cache private val cache = new collection.mutable.HashMap[Tree, Pattern] + def unadorn(x: Tree): Tree = x match { + case Typed(expr, _) => unadorn(expr) + case Bind(_, x) => unadorn(x) + case _ => x + } + + def isRightIgnoring(t: Tree) = cond(unadorn(t)) { + case ArrayValue(_, xs) if !xs.isEmpty => isStar(unadorn(xs.last)) + } + def apply(tree: Tree): Pattern = { if (cache contains tree) return cache(tree) @@ -327,8 +353,7 @@ trait Patterns extends ast.TreeDSL { case x: Literal => LiteralPattern(x) case x: UnApply => UnapplyPattern(x) case x: Ident => if (isVarPattern(x)) VariablePattern(x) else SimpleIdPattern(x) - // case x: ArrayValue => if (isRightIgnoring(x)) SequenceStarPattern(x) else SequencePattern(x) - case x: ArrayValue => SequencePattern(x) + case x: ArrayValue => if (isRightIgnoring(x)) SequenceStarPattern(x) else SequenceNoStarPattern(x) case x: Select => StableIdPattern(x) case x: Star => StarPattern(x) case x: This => ThisPattern(x) // XXX ? @@ -521,15 +546,12 @@ trait Patterns extends ast.TreeDSL { def isCaseClass = tpe.typeSymbol hasFlag Flags.CASE def isObject = isSymValid && prefix.isStable // XXX not entire logic - def unadorn(x: Tree): Tree = x match { - case Typed(expr, _) => unadorn(expr) - case Bind(_, x) => unadorn(x) - case _ => x - } + def unadorn(t: Tree): Tree = Pattern unadorn t private def isStar(x: Tree) = cond(unadorn(x)) { case Star(_) => true } private def endsStar(xs: List[Tree]) = xs.nonEmpty && isStar(xs.last) + def isStarSequence = isSequence && hasStar def isSequence = cond(unadorn(tree)) { case Sequence(xs) => true case ArrayValue(tpt, xs) => true @@ -577,21 +599,4 @@ trait Patterns extends ast.TreeDSL { } } } - - // object SeqStarSubPatterns { - // def removeStar(xs: List[Tree], seqType: Type): List[Pattern] = { - // val ps = toPats(xs) - // ps.init ::: List(ps.last rebindToType seqType) - // } - // - // def unapply(x: Pattern)(implicit min: Int, seqType: Type): Option[List[Pattern]] = x.tree match { - // case av @ ArrayValue(_, xs) => - // if (!isRightIgnoring(av) && xs.length == min) Some(toPats(xs ::: List(gen.mkNil, EmptyTree))) // Seq(p1,...,pN) - // else if ( isRightIgnoring(av) && xs.length-1 == min) Some(removeStar(xs, seqType) ::: List(NoPattern)) // Seq(p1,...,pN,_*) - // else if ( isRightIgnoring(av) && xs.length-1 < min) Some(emptyPatterns(min + 1) ::: List(x)) // Seq(p1..,pJ,_*) J < N - // else None - // case _ => - // if (x.isDefault) Some(emptyPatterns(min + 1 + 1)) else None - // } - // } } \ No newline at end of file -- cgit v1.2.3