From 403bf69a0b2134773bc751bb79c81978608238ff Mon Sep 17 00:00:00 2001 From: Paul Phillips Date: Tue, 6 Oct 2009 06:20:55 +0000 Subject: Another day of pattern matcher work. by accident as I go - I just noticed #2175 is working. That's even better than fixing them on purpose, in the same way that "money won is twice as sweet as money earned." --- .../tools/nsc/matching/ParallelMatching.scala | 179 +++++++++------------ .../scala/tools/nsc/matching/PatternBindings.scala | 69 ++++---- .../scala/tools/nsc/matching/Patterns.scala | 91 ++++++++--- 3 files changed, 175 insertions(+), 164 deletions(-) diff --git a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala index 65106abc7e..8f354c2be9 100644 --- a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala +++ b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala @@ -47,7 +47,6 @@ trait ParallelMatching extends ast.TreeDSL /** Transition **/ def isRightIgnoring(t: Tree) = cond(unbind(t)) { case ArrayValue(_, xs) if !xs.isEmpty => isStar(xs.last) } - def getDummies(i: Int): List[Tree] = List.fill(i)(EmptyTree) def toPats(xs: List[Tree]): List[Pattern] = xs map Pattern.apply /** Back to the regular schedule. **/ @@ -78,9 +77,9 @@ trait ParallelMatching extends ast.TreeDSL // shortcut if (bx < 0) Apply(ID(shortCuts(-bx-1)), Nil) // first time this bx is requested - might be bound elsewhere - else if (target.isNotReached) target.createLabelBody("body%"+bx, vsyms, vdefs) + else if (target.isNotReached) target.createLabelBody("body%"+bx, patternVars, patternValDefs) // call label "method" if possible - else target.getLabelBody(idents, vdefs) + else target.getLabelBody(idents, patternValDefs) } /** the injection here handles alternatives and unapply type tests */ @@ -110,23 +109,30 @@ trait ParallelMatching extends ast.TreeDSL // presenting a face of our symbol def tpe = sym.tpe def pos = sym.pos - def accessors = sym.caseFieldAccessors def id = ID(sym) // attributed ident + def accessors = if (isCaseClass) sym.caseFieldAccessors else Nil + def accessorVars = accessors map (a => newVarOfTpe((tpe memberType a).resultType)) + // tests - def isDefined = sym ne NoSymbol - def isSimple = tpe.isByte || tpe.isShort || tpe.isChar || tpe.isInt + def isDefined = sym ne NoSymbol + def isSimple = tpe.isByte || tpe.isShort || tpe.isChar || tpe.isInt + def isCaseClass = tpe.typeSymbol hasFlag Flags.CASE // sequences def seqType = tpe.widen baseType SeqClass def elemType = tpe typeArgs 0 // can this happen? if (seqType == NoType) error("...") + def newVarOfTpe(tpe: Type) = context.newVar(pos, tpe, flags) + def newVarOfSeqType = newVar(pos, seqType) + def newVarOfElemType = newVar(pos, elemType) + // for propagating "unchecked" to synthetic vars def flags: List[Long] = List(Flags.TRANS_FLAG) filter (sym hasFlag _) def castedTo(headType: Type) = if (tpe =:= headType) this - else new Scrutinee(newVar(pos, headType, flags = flags)) + else new Scrutinee(newVar(pos, headType, flags)) override def toString() = "Scrutinee(sym = %s, tpe = %s, id = %s)".format(sym, tpe, id) } @@ -141,10 +147,8 @@ trait ParallelMatching extends ast.TreeDSL // More immediately, Byte and Short scruts should also work. if (!scrut.isSimple) None else { - val (lits, others) = { - val (l, o) = ps span isSwitchableConst - (l filterMap { case x: LiteralPattern => x }, o) - } + val (_lits, others) = ps span isSwitchableConst + val lits = _lits filterMap { case x: LiteralPattern => x } condOpt(others) { case Nil => new PatternSwitch(scrut, lits, None) @@ -163,19 +167,17 @@ trait ParallelMatching extends ast.TreeDSL } case class PatternMatch(scrut: Scrutinee, ps: List[Pattern]) { - private lazy val trees = ps map (_.boundTree) - lazy val head = ps.head - lazy val tail = ps.tail - lazy val size = ps.length + def head = ps.head + def tail = ps.tail + def size = ps.length - lazy val headType = head.matchingType + def headType = head.matchingType def isCaseHead = head.isCaseClass - def dummies = if (isCaseHead) getDummies(headType.typeSymbol.caseFieldAccessors.length) else Nil - def dummyPatterns = dummies map (x => Pattern(x)) + private val dummyCount = if (isCaseHead) headType.typeSymbol.caseFieldAccessors.length else 0 + def dummies = emptyPatterns(dummyCount) def apply(i: Int): Pattern = ps(i) def pzip() = ps.zipWithIndex - def zip[T](others: List[T]) = trees zip others def pzip[T](others: List[T]) = ps zip others // Any unapply - returns Some(true) if a type test is needed before the unapply can @@ -185,9 +187,8 @@ trait ParallelMatching extends ast.TreeDSL } def mkRule(rest: Rep): RuleApplication = { - logAndReturn("mkRule: ", head.boundTree match { + logAndReturn("mkRule: ", head.tree match { case x if isEquals(x.tpe) => new MixEquals(this, rest) - case x: ArrayValue if isRightIgnoring(x) => new MixSequenceStar(this, rest) case x: ArrayValue => new MixSequence(this, rest) case AnyUnapply(false) => new MixUnapply(this, rest, false) case _ => @@ -261,10 +262,11 @@ trait ParallelMatching extends ast.TreeDSL final def tree(): Tree = { def body = requestBody(bx, subst) def guardTest = IF (guard.duplicate.tree) THEN body ELSE guardedRest.toTree + implicit val ctx = context typer typed( if (guard.isEmpty) body - else squeezedBlock(subst targetParams typer, guardTest) + else squeezedBlock(subst.infoForAll.patternValDefs, guardTest) ) } } @@ -358,12 +360,9 @@ trait ParallelMatching extends ast.TreeDSL } } - def newVarCapture(pos: Position, tpe: Type) = - newVar(pos, tpe, flags = scrut.flags) - /** returns (un)apply-call, success-rep, optional fail-rep */ final def getTransition(): Branch[UnapplyCall] = { - val unapplyRes = newVarCapture(ua.pos, app.tpe) + val unapplyRes = newVar(ua.pos, app.tpe, scrut.flags) val rhs = Apply(fxn, scrut.id :: trailingArgs) setType unapplyRes.tpe val zipped = pats pzip rest.rows val nrowsOther = zipped.tail flatMap { @@ -385,7 +384,7 @@ trait ParallelMatching extends ast.TreeDSL case _ => r insert (emptyPatterns(dum) ::: List(pat)) } def mkGet(s: Symbol) = typedValDef(s, fn(ID(unapplyRes), nme.get)) - def mkVar(tpe: Type) = newVarCapture(ua.pos, tpe) + def mkVar(tpe: Type) = newVar(ua.pos, tpe, scrut.flags) // 0 args => Boolean, 1 => Option[T], >1 => Option[? <: ProductN[T1,...,Tn]] args.length match { @@ -441,42 +440,46 @@ trait ParallelMatching extends ast.TreeDSL /** handle sequence pattern and ArrayValue (but not star patterns) */ sealed class MixSequence(val pats: PatternMatch, val rest: Rep) extends RuleApplication { + // Called 'pivot' since it's the head of the column under consideration in the mixture rule. + val pivot @ SequencePattern(av @ ArrayValue(_, _)) = head + private def pivotLen = pivot.nonStarLength + /** array elements except for star (if present) */ protected def nonStarElems(x: ArrayValue) = if (isRightIgnoring(x)) x.elems.init else x.elems - protected def elemLength(x: ArrayValue) = nonStarElems(x).length - protected def isAllDefaults(x: ArrayValue) = nonStarElems(x) forall (t => Pattern(t).isDefault) - final def removeStar(xs: List[Pattern]): List[Pattern] = xs.init ::: List(Pattern(makeBind(xs.last.boundVariables, WILD(scrut.seqType)))) - protected def getSubPatterns(len: Int, x: Tree): Option[List[Pattern]] = condOpt(x) { - case av @ ArrayValue(_,xs) if !isRightIgnoring(av) && xs.length == len => toPats(xs) ::: List(NoPattern) - case av @ ArrayValue(_,xs) if isRightIgnoring(av) && xs.length == len+1 => removeStar(toPats(xs)) // (*) - case EmptyTree | WILD() => emptyPatterns(len + 1) + def mustCheck(first: Pattern, next: Pattern): Boolean = { + if (first.tree eq next.tree) + return false + + !(first completelyCovers next) } - protected def makeSuccRep(vs: List[Symbol], tail: Symbol, nrows: List[Row]) = - make(vs ::: tail :: rest.tvars, nrows) + def getSubPatterns(x: Pattern): Option[List[Pattern]] = condOpt(x.tree) { + case av @ ArrayValue(_, xs) if nonStarElems(av).length == pivotLen => + val (star1, star2) = (pivot.hasStar, isRightIgnoring(av)) - /** True if 'next' must be checked even if 'first' failed to match after passing its length test - * (the conditional supplied by getPrecondition.) This is an optimization to avoid checking sequences - * which cannot match due to a length incompatibility. - */ - protected def mustCheck(first: Tree, next: Tree): Boolean = - (first ne next) && (Pattern(next).isDefault || cond((first, next)) { - case (av: ArrayValue, bv: ArrayValue) => - // number of non-star elements in each sequence - val (len1, len2) = (elemLength(av), elemLength(bv)) - val (star1, star2) = (isRightIgnoring(av), isRightIgnoring(bv)) - - // this still needs rewriting. - ( star1 && star2 && len2 < len1 ) || // Seq(a,b,c,_*) followed by Seq(a,b,_*) because of (a,b) - ( star1 && !star2 && len2 < len1 && isAllDefaults(av)) || // Seq(a,b,c,_*) followed by Seq(a,b) because of (a,b) - (!star1 && star2 ) || - (!star1 && !star2 && len2 >= len1 ) - }) + (star1, star2) match { + case (true, true) => removeStar(toPats(xs)) ::: List(NoPattern) + case (true, false) => toPats(xs ::: List(gen.mkNil, EmptyTree)) + case (false, true) => removeStar(toPats(xs)) + case (false, false) => toPats(xs) ::: List(NoPattern) + } + case av @ ArrayValue(_, xs) if pivot.hasStar && isRightIgnoring(av) && xs.length-1 < pivotLen => + emptyPatterns(pivotLen + 1) ::: List(x) + + case EmptyTree | WILD() => + emptyPatterns(pivot.elemPatterns.length + 1) + } + + def makeSuccRep(vs: List[Symbol], tail: Symbol, nrows: List[Row]) = { + val ssym = if (pivot.hasStar) List(scrut.sym) else Nil + + make(List(vs, List(tail), ssym, rest.tvars).flatten, nrows) + } case class TransitionContext(f: TreeFunction2) @@ -484,28 +487,25 @@ trait ParallelMatching extends ast.TreeDSL def getTransition(): Branch[TransitionContext] = { assert(scrut.tpe <:< head.tpe, "fatal: %s is not <:< %s".format(scrut, head.tpe)) - val av @ ArrayValue(_, elems) = head.tree - val ys = if (isRightIgnoring(av)) elems.init else elems - val vs = ys map (y => newVar(unbind(y).pos, scrut.elemType)) - def scrutCopy = scrut.id.duplicate - - lazy val tail = newVar(scrut.pos, scrut.seqType) - lazy val lastBinding = typedValDef(tail, scrutCopy DROP ys.size) - def elemAt(i: Int) = typer typed ((scrutCopy DOT (scrutCopy.tpe member nme.apply))(LIT(i))) + val vs = pivot.nonStarPatterns map (x => newVar(x.tree.pos, scrut.elemType)) + lazy val tail = scrut.newVarOfSeqType + lazy val lastBinding = typedValDef(tail, scrut.id DROP vs.size) + def elemAt(i: Int) = typer typed ((scrut.id DOT (scrut.tpe member nme.apply))(LIT(i))) val bindings = (vs.zipWithIndex map tupled((v, i) => typedValDef(v, elemAt(i)))) ::: List(lastBinding) val (nrows, frows): (List[Option[Row]], List[Option[Row]]) = List.unzip( - for ((c, rows) <- pats zip rest.rows) yield getSubPatterns(ys.size, c) match { - case Some(ps) => (Some(rows insert ps), if (mustCheck(av, c)) Some(rows insert Pattern(c)) else None) - case None => (None, Some(rows insert Pattern(c))) - }) + for ((c, rows) <- pats pzip rest.rows) yield getSubPatterns(c) match { + case Some(ps) => (Some(rows insert ps), if (mustCheck(head, c)) Some(rows insert c) else None) + case None => (None, Some(rows insert c)) + } + ) val succ = makeSuccRep(vs, tail, nrows flatMap (x => x)) val fail = mkFail(frows flatMap (x => x)) def transition = (thenp: Tree, elsep: Tree) => - IF (getPrecondition(scrutCopy, elemLength(av))) THEN squeezedBlock(bindings, thenp) ELSE elsep + IF (getPrecondition(scrut.id, pivot.nonStarLength)) THEN squeezedBlock(bindings, thenp) ELSE elsep Branch(TransitionContext(transition), succ, fail) } @@ -517,9 +517,10 @@ trait ParallelMatching extends ast.TreeDSL typer typed nullSafe(cmpFunction _, FALSE)(tree) } - // precondition for matching: sequence is exactly length of arg + // precondition for matching protected def getPrecondition(tree: Tree, lengthArg: Int) = - lengthCheck(tree, lengthArg, _ MEMBER_== _) + if (pivot.hasStar) lengthCheck(tree, lengthArg, _ ANY_>= _) // seq length >= pattern length + else lengthCheck(tree, lengthArg, _ MEMBER_== _) // seq length == pattern length final def tree() = { val Branch(TransitionContext(transition), succ, Some(fail)) = this.getTransition @@ -527,36 +528,6 @@ trait ParallelMatching extends ast.TreeDSL } } - /** handle sequence pattern and ArrayValue with star patterns - */ - final class MixSequenceStar(pats: PatternMatch, rest: Rep) extends MixSequence(pats, rest) { - // override def getSubPatterns(minlen: Int, x: Tree): Option[List[Pattern]] = { - // implicit val min = minlen - // implicit val tpe = scrut.seqType - // - // condOpt(Pattern(x)) { case SeqStarSubPatterns(xs) => xs } - // } - - // in principle, we could optimize more, but variable binding gets complicated (@todo use finite state methods instead) - override def getSubPatterns(minlen: Int, x: Tree): Option[List[Pattern]] = condOpt(x) { - case av @ ArrayValue(_,xs) if (!isRightIgnoring(av) && xs.length == minlen) => // Seq(p1,...,pN) - toPats(xs ::: List(gen.mkNil, EmptyTree)) - case av @ ArrayValue(_,xs) if ( isRightIgnoring(av) && xs.length-1 == minlen) => // Seq(p1,...,pN,_*) - removeStar(toPats(xs)) ::: List(NoPattern) - case av @ ArrayValue(_,xs) if ( isRightIgnoring(av) && xs.length-1 < minlen) => // Seq(p1..,pJ,_*) J < N - emptyPatterns(minlen + 1) ::: List(Pattern(x)) - case EmptyTree | WILD() => - emptyPatterns(minlen + 1 + 1) - } - - override protected def makeSuccRep(vs: List[Symbol], tail: Symbol, nrows: List[Row]) = - mkNewRep(vs ::: List(tail), rest.tvars, nrows) - - // precondition for matching - override protected def getPrecondition(tree: Tree, lengthArg: Int) = - lengthCheck(tree, lengthArg, _ ANY_>= _) - } - // @todo: equals test for same constant class MixEquals(val pats: PatternMatch, val rest: Rep) extends RuleApplication { /** condition (to be used in IF), success and failure Rep */ @@ -625,12 +596,11 @@ trait ParallelMatching extends ast.TreeDSL def sMatchesP = matches(s, p) def pMatchesS = matches(p, s) - def sEqualsP = p =:= s - def alts[T](yes: T, no: T): T = if (sEqualsP) yes else no + def alts[T](yes: T, no: T): T = if (p =:= s) yes else no def isObjectTest = pattern.isObject && (p =:= pattern.mkSingleton) - lazy val dummy = (j, pats.dummyPatterns) + lazy val dummy = (j, pats.dummies) lazy val pass = (j, pattern) lazy val subs = (j, pattern subpatterns pats) @@ -647,7 +617,7 @@ trait ParallelMatching extends ast.TreeDSL // (4) never =:= for (pattern match { - case Pattern(LIT(null), _) if !sEqualsP => (None, None, pass) // (1) + case Pattern(LIT(null), _) if !(p =:= s) => (None, None, pass) // (1) case x if isObjectTest => (NoPattern, dummy, None) // (2) case Pattern(Typed(pp @ Pattern(_: UnApply, _), _), _) if sMatchesP => (Pattern(pp), dummy, None) // (3) case Pattern(Typed(pp, _), _) if sMatchesP => (alts(Pattern(pp), pattern), dummy, None) // (4) @@ -676,18 +646,15 @@ trait ParallelMatching extends ast.TreeDSL /** returns casted symbol, success matrix and optionally fail matrix for type test on the top of this column */ final def getTransition(): Branch[Scrutinee] = { val casted = scrut castedTo pats.headType - // val neededCast = (scrut ne casted) val isAnyMoreSpecific = moreSpecific exists (x => !x.isEmpty) def mkZipped = moreSpecific zip subsumed map { case (mspat, (j, pats)) => (j, mspat :: pats) } - def mkAccessors = casted.accessors map (m => newVar(scrut.pos, (casted.tpe memberType m).resultType, scrut.flags)) val (subtests, subtestVars) = if (isAnyMoreSpecific) (mkZipped, List(casted.sym)) else (subsumed, Nil) - val accessorVars = if (pats.isCaseHead) mkAccessors else Nil val newRows = for ((j, ps) <- subtests) yield (rest rows j).insert2(ps, pats(j).boundVariables, casted.sym) @@ -695,7 +662,7 @@ trait ParallelMatching extends ast.TreeDSL Branch( casted, // succeeding => transition to translate(subsumed) (taking into account more specific) - make(subtestVars ::: accessorVars ::: rest.tvars, newRows), + make(subtestVars ::: casted.accessorVars ::: rest.tvars, newRows), // fails => transition to translate(remaining) mkFailOpt(remaining map tupled((p1, p2) => rest rows p1 insert p2)) ) diff --git a/src/compiler/scala/tools/nsc/matching/PatternBindings.scala b/src/compiler/scala/tools/nsc/matching/PatternBindings.scala index 575050821a..f791236a10 100644 --- a/src/compiler/scala/tools/nsc/matching/PatternBindings.scala +++ b/src/compiler/scala/tools/nsc/matching/PatternBindings.scala @@ -7,6 +7,7 @@ package scala.tools.nsc package matching import transform.ExplicitOuter +import collection.immutable.TreeMap trait PatternBindings extends ast.TreeDSL { @@ -44,55 +45,49 @@ trait PatternBindings extends ast.TreeDSL } case class Binding(pvar: Symbol, tvar: Symbol) { - override def toString() = "%s: %s @ %s: %s".format(pvar.name, pvar.tpe, tvar.name, tvar.tpe) + // see bug #1843 for the consequences of not setting info. + // there is surely a better way to do this, especially since + // this looks to be the only usage of containsTp anywhere + // in the compiler, but it suffices for now. + if (tvar.info containsTp WildcardType) + tvar setInfo pvar.info + + def toIdent = + Ident(tvar) setType pvar.tpe + + def castIfNeeded = + if (tvar.tpe <:< pvar.tpe) ID(tvar) + else ID(tvar) AS_ANY pvar.tpe } - case class BindingsInfo(xs: List[BindingInfo]) { - def idents = xs map (_.ident) - def vsyms = xs map (_.vsym) + case class BindingsInfo(xs: List[Binding]) { + def patternVars = xs map (_.pvar) + def temporaryVars = xs map (_.tvar) + def idents = xs map (_.toIdent) - def vdefs(implicit context: MatrixContext) = - xs map (x => context.typedValDef(x.vsym, x.ident)) + def patternValDefs(implicit context: MatrixContext) = + for (b @ Binding(pvar, tvar) <- xs) yield + context.typedValDef(pvar, b.toIdent) } - case class BindingInfo(vsym: Symbol, ident: Ident) - case class Bindings(bindings: Binding*) extends Function1[Symbol, Option[Ident]] { - private def castIfNeeded(pvar: Symbol, tvar: Symbol) = - if (tvar.tpe <:< pvar.tpe) ID(tvar) - else ID(tvar) AS_ANY pvar.tpe + class Bindings(private val vlist: List[Binding]) extends Function1[Symbol, Option[Ident]] { + def this() = this(Nil) + + def vmap(v: Symbol): Option[Binding] = vlist find (_.pvar eq v) // filters the given list down to those defined in these bindings - def infoFor(vs: List[Symbol]): BindingsInfo = BindingsInfo( - for (v <- vs ; substv <- apply(v)) yield - BindingInfo(v, substv) - ) + def infoFor(vs: List[Symbol]) = BindingsInfo(vs map vmap flatten) + def infoForAll = BindingsInfo(vlist) def add(vs: Iterable[Symbol], tvar: Symbol): Bindings = { - def newBinding(v: Symbol) = { - // see bug #1843 for the consequences of not setting info. - // there is surely a better way to do this, especially since - // this looks to be the only usage of containsTp anywhere - // in the compiler, but it suffices for now. - if (tvar.info containsTp WildcardType) - tvar setInfo v.info - - Binding(v, tvar) - } - val newBindings = vs.toList map newBinding - Bindings(newBindings ++ bindings: _*) + val newBindings = vs.toList map (v => Binding(v, tvar)) + new Bindings(newBindings ++ vlist) } - def apply(v: Symbol): Option[Ident] = - bindings find (_.pvar eq v) map (x => Ident(x.tvar) setType v.tpe) - - override def toString() = - if (bindings.isEmpty) "" else bindings.mkString(" Bound(", ", ", ")") + def apply(v: Symbol): Option[Ident] = vmap(v) map (_.toIdent) - /** The corresponding list of value definitions. */ - final def targetParams(implicit typer: analyzer.Typer): List[ValDef] = - for (Binding(v, t) <- bindings.toList) yield - VAL(v) === (typer typed castIfNeeded(v, t)) + override def toString() = " Bound(%s)".format(vlist) } - val NoBinding: Bindings = Bindings() + val NoBinding: Bindings = new Bindings() } \ No newline at end of file diff --git a/src/compiler/scala/tools/nsc/matching/Patterns.scala b/src/compiler/scala/tools/nsc/matching/Patterns.scala index cdaa37621f..017840d0c6 100644 --- a/src/compiler/scala/tools/nsc/matching/Patterns.scala +++ b/src/compiler/scala/tools/nsc/matching/Patterns.scala @@ -37,7 +37,8 @@ trait Patterns extends ast.TreeDSL { import treeInfo.{ unbind, isVarPattern } // Fresh patterns - final def emptyPatterns(i: Int): List[Pattern] = List.fill(i)(NoPattern) + def emptyPatterns(i: Int): List[Pattern] = List.fill(i)(NoPattern) + def emptyTrees(i: Int): List[Tree] = List.fill(i)(EmptyTree) // An empty pattern def NoPattern = WildcardPattern() @@ -82,12 +83,17 @@ trait Patterns extends ast.TreeDSL { } // 8.1.4 - case class StableIdPattern(tree: Ident) extends IdentifierPattern { } + case class StableIdPattern(tree: Ident) extends IdentifierPattern { + val Ident(name) = tree + } // 8.1.4 (b) - case class SelectPattern(tree: Select) extends IdentifierPattern { } + case class SelectPattern(tree: Select) extends IdentifierPattern { + val Select(qualifier, name) = tree + } trait IdentifierPattern extends Pattern { + val name: Name override def simplify(testVar: Symbol) = this.rebindToEqualsCheck() override def matchingType = mkSingleton } @@ -96,9 +102,9 @@ trait Patterns extends ast.TreeDSL { case class ConstructorPattern(tree: Apply) extends ApplyPattern { require(fn.isType && this.isCaseClass) - override def subpatterns(pats: MatchMatrix#PatternMatch) = - if (pats.isCaseHead) args map Pattern.apply - else super.subpatterns(pats) + override def subpatterns(pm: MatchMatrix#PatternMatch) = + if (pm.isCaseHead) args map Pattern.apply + else super.subpatterns(pm) override def simplify(testVar: Symbol) = if (args.isEmpty) this rebindToEmpty tree.tpe @@ -113,7 +119,7 @@ trait Patterns extends ast.TreeDSL { override def simplify(testVar: Symbol) = { def examinePrefix(path: Tree) = (path, path.tpe) match { - case (_, t: ThisType) => singleType(t, sym) + case (_, t: ThisType) => singleType(t, sym) // this.X case (_: Apply, _) => PseudoType(tree) case _ => singleType(Pattern(path).mkSingleton, sym) } @@ -167,10 +173,52 @@ trait Patterns extends ast.TreeDSL { } // 8.1.8 (b) (literal ArrayValues) - case class SequencePattern(tree: ArrayValue) extends Pattern { } + case class SequencePattern(tree: ArrayValue) extends Pattern { + lazy val ArrayValue(elemtpt, elems) = tree + lazy val elemPatterns = toPats(elems) + lazy val nonStarPatterns = if (hasStar) elemPatterns.init else elemPatterns + + def hasStar = isRightIgnoring(tree) + def nonStarLength = nonStarPatterns.length + def isAllDefaults = nonStarPatterns forall (_.isDefault) + + def rebindStar(seqType: Type): List[Pattern] = { + require(hasStar) + nonStarPatterns ::: List(elemPatterns.last rebindToType seqType) + } + + /** True if 'next' must be checked even if 'first' failed to match after passing its length test + * (the conditional supplied by getPrecondition.) This is an optimization to avoid checking sequences + * which cannot match due to a length incompatibility. + */ + + override def completelyCovers(second: Pattern): Boolean = { + if (second.isDefault) return false + + second match { + case x: SequencePattern => + val (len1, len2) = (nonStarLength, x.nonStarLength) + val (star1, star2) = (this.hasStar, x.hasStar) + + // this still needs rewriting. + val res = + ( star1 && star2 && len2 < len1 ) || // Seq(a,b,c,_*) followed by Seq(a,b,_*) because of (a,b) + ( star1 && !star2 && len2 < len1 && isAllDefaults ) || // Seq(a,b,c,_*) followed by Seq(a,b) because of (a,b) + // ( star1 && len2 < len1 ) || + (!star1 && star2 ) || + (!star1 && !star2 && len2 >= len1 ) + + !res + case _ => + // shouldn't happen... + false + } + } + } - // // 8.1.8 (b) - case class SequenceStarPattern(tree: ArrayValue) extends Pattern { } + // 8.1.8 (b) + // temporarily subsumed by SequencePattern + // case class SequenceStarPattern(tree: ArrayValue) extends Pattern { } // abstract trait ArrayValuePattern extends Pattern { // val tree: ArrayValue @@ -179,6 +227,11 @@ trait Patterns extends ast.TreeDSL { // def nonStarElems = if (isRightIgnoring) elems.init else elems // } + // 8.1.8 (c) + case class StarPattern(tree: Star) extends Pattern { + val Star(elem) = tree + } + // 8.1.9 // InfixPattern ... subsumed by Constructor/Extractor Patterns @@ -192,14 +245,6 @@ trait Patterns extends ast.TreeDSL { // XMLPattern ... for now, subsumed by SequencePattern, but if we want // to make it work right, it probably needs special handling. - // XXX - temporary pattern until we have integrated every tree type. - case class MiscPattern(tree: Tree) extends Pattern { - log("Resorted to MiscPattern: %s/%s".format(tree, tree.getClass)) - override def simplify(testVar: Symbol) = tree match { - case x: Ident => this.rebindToEqualsCheck() - case _ => super.simplify(testVar) - } - } object Pattern { def apply(tree: Tree): Pattern = tree match { @@ -212,9 +257,10 @@ trait Patterns extends ast.TreeDSL { case x: Literal => LiteralPattern(x) case x: UnApply => UnapplyPattern(x) case x: Ident => if (isVarPattern(x)) VariablePattern(x) else StableIdPattern(x) - case x: ArrayValue => if (isRightIgnoring(x)) SequenceStarPattern(x) else SequencePattern(x) + // case x: ArrayValue => if (isRightIgnoring(x)) SequenceStarPattern(x) else SequencePattern(x) + case x: ArrayValue => SequencePattern(x) case x: Select => SelectPattern(x) - case x: Star => MiscPattern(x) // XXX + case x: Star => StarPattern(x) case _ => abort("Unknown Tree reached pattern matcher: %s/%s".format(tree, tree.getClass)) } def unapply(other: Any): Option[(Tree, List[Symbol])] = other match { @@ -285,7 +331,7 @@ trait Patterns extends ast.TreeDSL { def simplify(testVar: Symbol): Pattern = this def simplify(): Pattern = this simplify NoSymbol - def subpatterns(pats: MatchMatrix#PatternMatch): List[Pattern] = pats.dummyPatterns + def subpatterns(pm: MatchMatrix#PatternMatch): List[Pattern] = pm.dummies // 8.1.13 // A pattern p is irrefutable for type T if any of the following applies: @@ -296,6 +342,9 @@ trait Patterns extends ast.TreeDSL { // pi is irrefutable for Ti. def irrefutableFor(tpe: Type) = false + // does this pattern completely cover that pattern (i.e. latter cannot be matched) + def completelyCovers(second: Pattern) = false + // Is this a default pattern (untyped "_" or an EmptyTree inserted by the matcher) def isDefault = false -- cgit v1.2.3