diff options
-rw-r--r-- | src/compiler/scala/tools/nsc/ast/TreeDSL.scala | 1 | ||||
-rw-r--r-- | src/compiler/scala/tools/nsc/matching/ParallelMatching.scala | 91 | ||||
-rw-r--r-- | src/compiler/scala/tools/nsc/matching/Patterns.scala | 84 | ||||
-rw-r--r-- | test/files/pos/bug2187-2.scala | 7 | ||||
-rw-r--r-- | test/files/pos/bug2945.scala | 12 | ||||
-rw-r--r-- | test/files/run/bug2958.scala | 16 | ||||
-rw-r--r-- | test/files/run/bug3150.scala | 10 | ||||
-rw-r--r-- | test/files/run/bug3395.check | 2 | ||||
-rw-r--r-- | test/files/run/bug3395.scala | 13 | ||||
-rw-r--r-- | test/files/run/patmat-seqs.check | 13 | ||||
-rw-r--r-- | test/files/run/patmat-seqs.scala | 42 |
11 files changed, 204 insertions, 87 deletions
diff --git a/src/compiler/scala/tools/nsc/ast/TreeDSL.scala b/src/compiler/scala/tools/nsc/ast/TreeDSL.scala index fd13958053..9aa36de703 100644 --- a/src/compiler/scala/tools/nsc/ast/TreeDSL.scala +++ b/src/compiler/scala/tools/nsc/ast/TreeDSL.scala @@ -93,6 +93,7 @@ trait TreeDSL { def INT_| (other: Tree) = fn(target, getMember(IntClass, nme.OR), other) def INT_& (other: Tree) = fn(target, getMember(IntClass, nme.AND), other) + def INT_>= (other: Tree) = fn(target, getMember(IntClass, nme.GE), other) def INT_== (other: Tree) = fn(target, getMember(IntClass, nme.EQ), other) def INT_!= (other: Tree) = fn(target, getMember(IntClass, nme.NE), other) diff --git a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala index 1806dec2d2..77997c4565 100644 --- a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala +++ b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala @@ -175,7 +175,6 @@ trait ParallelMatching extends ast.TreeDSL def isCaseHead = head.isCaseClass private val dummyCount = if (isCaseHead) headType.typeSymbol.caseFieldAccessors.length else 0 def dummies = emptyPatterns(dummyCount) - // def dummies = head.dummies def apply(i: Int): Pattern = ps(i) def pzip() = ps.zipWithIndex @@ -443,32 +442,92 @@ trait ParallelMatching extends ast.TreeDSL * Note: pivot == head, just better typed. */ sealed class MixSequence(val pmatch: PatternMatch, val rest: Rep, pivot: SequencePattern) extends RuleApplication { + require(scrut.tpe <:< head.tpe) + def hasStar = pivot.hasStar - private def pivotLen = pivot.nonStarLength + private def pivotLen = pivot.nonStarLength + private def seqDummies = emptyPatterns(pivot.elems.length + 1) // one pattern var per sequence element up to elemCount, and one more for the rest of the sequence lazy val pvs = scrut createSequenceVars pivotLen - // divide the remaining rows into success/failure branches, expanding subsequences of patterns - private lazy val rowsplit = { - require(scrut.tpe <:< head.tpe) + // Should the given pattern join the expanded pivot in the success matrix? If so, + // this partial function will be defined for the pattern, and the result of the apply + // is the expanded sequence of new patterns. + lazy val successMatrixFn = new PartialFunction[Pattern, List[Pattern]] { + private def seqIsDefinedAt(x: SequenceLikePattern) = (hasStar, x.hasStar) match { + case (true, true) => true + case (true, false) => pivotLen <= x.nonStarLength + case (false, true) => pivotLen >= x.nonStarLength + case (false, false) => pivotLen == x.nonStarLength + } - val res = for ((c, rows) <- pmatch pzip rest.rows) yield { - def canSkip = pivot canSkipSubsequences c - def passthrough(skip: Boolean) = if (skip) None else Some(rows insert c) + def isDefinedAt(pat: Pattern) = pat match { + case x: SequenceLikePattern => seqIsDefinedAt(x) + case WildcardPattern() => true + case _ => false + } - pivot.subsequences(c, scrut.seqType) match { - case Some(ps) => (Some(rows insert ps), passthrough(canSkip)) - case None => (None, passthrough(false)) - } + def apply(pat: Pattern): List[Pattern] = pat match { + case x: SequenceLikePattern => + def isSameLength = pivotLen == x.nonStarLength + def rebound = x.nonStarPatterns :+ (x.elemPatterns.last rebindTo WILD(scrut.seqType)) + + (pivot.hasStar, x.hasStar, isSameLength) match { + case (true, true, true) => rebound :+ NoPattern + case (true, true, false) => (seqDummies drop 1) :+ x + case (true, false, true) => x.elemPatterns ++ List(NilPattern, NoPattern) + case (false, true, true) => rebound + case (false, false, true) => x.elemPatterns :+ NoPattern + case _ => seqDummies + } + + case _ => seqDummies } + } - res.unzip match { case (l1, l2) => (l1.flatten, l2.flatten) } + // Should the given pattern be in the fail matrix? This is true of any sequences + // as long as the result of the length test on the pivot doesn't make it impossible: + // for instance if neither sequence is right ignoring and they are of different + // lengths, the later one cannot match since its length must be wrong. + def failureMatrixFn(c: Pattern) = (pivot ne c) && (c match { + case x: SequenceLikePattern => + (hasStar, x.hasStar) match { + case (_, true) => true + case (true, false) => pivotLen > x.nonStarLength + case (false, false) => pivotLen != x.nonStarLength + } + case WildcardPattern() => true + case _ => false + }) + + // divide the remaining rows into success/failure branches, expanding subsequences of patterns + val successRows = pmatch pzip rest.rows collect { + case (c, row) if successMatrixFn isDefinedAt c => row insert successMatrixFn(c) + } + val failRows = pmatch pzip rest.rows collect { + case (c, row) if failureMatrixFn(c) => row insert c } - lazy val cond = (pivot precondition pmatch).get // length check - lazy val success = squeezedBlockPVs(pvs, remake(rowsplit._1, pvs, hasStar).toTree) - lazy val failure = remake(rowsplit._2).toTree + // the discrimination test for sequences is a call to lengthCompare. Note that + // this logic must be fully consistent wiith successMatrixFn and failureMatrixFn above: + // any inconsistency will (and frequently has) manifested as pattern matcher crashes. + lazy val cond = { + // the method call symbol + val methodOp: Symbol = head.tpe member nme.lengthCompare + + // the comparison to perform. If the pivot is right ignoring, then a scrutinee sequence + // of >= pivot length could match it; otherwise it must be exactly equal. + val compareOp: (Tree, Tree) => Tree = if (hasStar) _ INT_>= _ else _ INT_== _ + + // scrutinee.lengthCompare(pivotLength) [== | >=] 0 + val compareFn: Tree => Tree = (t: Tree) => compareOp((t DOT methodOp)(LIT(pivotLen)), ZERO) + + // wrapping in a null check on the scrutinee + nullSafe(compareFn, FALSE)(scrut.id) + } + lazy val success = squeezedBlockPVs(pvs, remake(successRows, pvs, hasStar).toTree) + lazy val failure = remake(failRows).toTree final def tree(): Tree = codegen } diff --git a/src/compiler/scala/tools/nsc/matching/Patterns.scala b/src/compiler/scala/tools/nsc/matching/Patterns.scala index a21a9c7d9f..d35049c1e5 100644 --- a/src/compiler/scala/tools/nsc/matching/Patterns.scala +++ b/src/compiler/scala/tools/nsc/matching/Patterns.scala @@ -177,9 +177,9 @@ trait Patterns extends ast.TreeDSL { } // 8.1.8 (unapplySeq calls) - case class SequenceExtractorPattern(tree: UnApply) extends UnapplyPattern { + case class SequenceExtractorPattern(tree: UnApply) extends UnapplyPattern with SequenceLikePattern { - private val UnApply( + lazy val UnApply( Apply(TypeApply(Select(_, nme.unapplySeq), List(tptArg)), _), List(ArrayValue(_, elems)) ) = tree @@ -211,88 +211,34 @@ trait Patterns extends ast.TreeDSL { override def description = "UnSeq(%s => %s)".format(tptArg, resTypesString) } - abstract class SequencePattern extends Pattern { - val tree: ArrayValue - def nonStarPatterns: List[Pattern] - def subsequences(other: Pattern, seqType: Type): Option[List[Pattern]] - def canSkipSubsequences(second: Pattern): Boolean - - lazy val ArrayValue(elemtpt, elems) = tree - lazy val elemPatterns = toPats(elems) - - override def dummies = emptyPatterns(elems.length + 1) - override def subpatternsForVars: List[Pattern] = elemPatterns + trait SequenceLikePattern extends Pattern { + def elems: List[Tree] + def elemPatterns = toPats(elems) + def nonStarPatterns: List[Pattern] = if (hasStar) elemPatterns.init else elemPatterns def nonStarLength = nonStarPatterns.length def isAllDefaults = nonStarPatterns forall (_.isDefault) - def isShorter(other: SequencePattern) = nonStarLength < other.nonStarLength - def isSameLength(other: SequencePattern) = nonStarLength == other.nonStarLength - - protected def lengthCheckOp: (Tree, Tree) => Tree = - if (hasStar) _ ANY_>= _ - else _ MEMBER_== _ - - // optimization to avoid trying to match if length makes it impossible - override def precondition(pm: PatternMatch) = { - import pm.{ scrut, head } - val len = nonStarLength - val compareOp = head.tpe member nme.lengthCompare - - def cmpFunction(t1: Tree) = lengthCheckOp((t1 DOT compareOp)(LIT(len)), ZERO) + def isShorter(other: SequenceLikePattern) = nonStarLength < other.nonStarLength + def isSameLength(other: SequenceLikePattern) = nonStarLength == other.nonStarLength + } - Some(nullSafe(cmpFunction _, FALSE)(scrut.id)) - } + abstract class SequencePattern extends Pattern with SequenceLikePattern { + val tree: ArrayValue + lazy val ArrayValue(elemtpt, elems) = tree - /** True if 'next' must be checked even if 'first' failed to match after passing its length test - * (the conditional supplied by getPrecondition.) This is an optimization to avoid checking sequences - * which cannot match due to a length incompatibility. - */ + override def subpatternsForVars: List[Pattern] = elemPatterns override def description = "Seq(%s)".format(elemPatterns) } // 8.1.8 (b) (literal ArrayValues) case class SequenceNoStarPattern(tree: ArrayValue) extends SequencePattern { require(!hasStar) - lazy val nonStarPatterns = elemPatterns - - // no star - def subsequences(other: Pattern, seqType: Type): Option[List[Pattern]] = - condOpt(other) { - case next: SequenceStarPattern if isSameLength(next) => next rebindStar seqType - case next: SequenceNoStarPattern if isSameLength(next) => next.elemPatterns ::: List(NoPattern) - case WildcardPattern() | (_: SequencePattern) => dummies - } - - def canSkipSubsequences(second: Pattern): Boolean = - (tree eq second.tree) || (cond(second) { - case x: SequenceNoStarPattern => (x isShorter this) && this.isAllDefaults - }) } // 8.1.8 (b) case class SequenceStarPattern(tree: ArrayValue) extends SequencePattern { require(hasStar) - lazy val nonStarPatterns = elemPatterns.init - - // yes star - private def nilPats = List(NilPattern, NoPattern) - def subsequences(other: Pattern, seqType: Type): Option[List[Pattern]] = - condOpt(other) { - case next: SequenceStarPattern if isSameLength(next) => (next rebindStar seqType) ::: List(NoPattern) - case next: SequenceStarPattern if (next isShorter this) => (dummies drop 1) ::: List(next) - case next: SequenceNoStarPattern if isSameLength(next) => next.elemPatterns ::: nilPats - case WildcardPattern() | (_: SequencePattern) => dummies - } - - def rebindStar(seqType: Type): List[Pattern] = - nonStarPatterns ::: List(elemPatterns.last rebindTo WILD(seqType)) - - def canSkipSubsequences(second: Pattern): Boolean = - (tree eq second.tree) || (cond(second) { - case x: SequenceStarPattern => this isShorter x - case x: SequenceNoStarPattern => !(x isShorter this) - }) override def description = "Seq*(%s)".format(elemPatterns) } @@ -504,10 +450,6 @@ trait Patterns extends ast.TreeDSL { // the right number of dummies for this pattern def dummies: List[Pattern] = Nil - // given this scrutinee, what if any condition must be satisfied before - // we even try to match? - def precondition(scrut: PatternMatch): Option[Tree] = None - // 8.1.13 // A pattern p is irrefutable for type T if any of the following applies: // 1) p is a variable pattern diff --git a/test/files/pos/bug2187-2.scala b/test/files/pos/bug2187-2.scala new file mode 100644 index 0000000000..3f2742dd89 --- /dev/null +++ b/test/files/pos/bug2187-2.scala @@ -0,0 +1,7 @@ +class Test { + def test[A](list: List[A]) = list match { + case Seq(x, y) => "xy" + case Seq(x) => "x" + case _ => "something else" + } +}
\ No newline at end of file diff --git a/test/files/pos/bug2945.scala b/test/files/pos/bug2945.scala new file mode 100644 index 0000000000..762bdb61e1 --- /dev/null +++ b/test/files/pos/bug2945.scala @@ -0,0 +1,12 @@ +object Foo { + def test(s: String) = { + (s: Seq[Char]) match { + case Seq('f', 'o', 'o', ' ', rest1 @ _*) => + rest1 + case Seq('b', 'a', 'r', ' ', ' ', rest2 @ _*) => + rest2 + case _ => + s + } + } +}
\ No newline at end of file diff --git a/test/files/run/bug2958.scala b/test/files/run/bug2958.scala new file mode 100644 index 0000000000..dcd24ecc36 --- /dev/null +++ b/test/files/run/bug2958.scala @@ -0,0 +1,16 @@ +object Test { + def f(args: Array[String]) = args match { + case Array("-p", prefix, from, to) => + prefix + from + to + + case Array(from, to) => + from + to + + case _ => + "default" + } + + def main(args: Array[String]) { + assert(f(Array("1", "2")) == "12") + } +}
\ No newline at end of file diff --git a/test/files/run/bug3150.scala b/test/files/run/bug3150.scala new file mode 100644 index 0000000000..034703b5f7 --- /dev/null +++ b/test/files/run/bug3150.scala @@ -0,0 +1,10 @@ +object Test { + case object Bob { override def equals(other: Any) = true } + def f(x: Any) = x match { case Bob => Bob } + + def main(args: Array[String]): Unit = { + assert(f(Bob) eq Bob) + assert(f(0) eq Bob) + assert(f(Nil) eq Bob) + } +} diff --git a/test/files/run/bug3395.check b/test/files/run/bug3395.check new file mode 100644 index 0000000000..5f5521fae2 --- /dev/null +++ b/test/files/run/bug3395.check @@ -0,0 +1,2 @@ +abc +def diff --git a/test/files/run/bug3395.scala b/test/files/run/bug3395.scala new file mode 100644 index 0000000000..b4990a1716 --- /dev/null +++ b/test/files/run/bug3395.scala @@ -0,0 +1,13 @@ +object Test { + def main(args: Array[String]): Unit = { + Seq("") match { + case Seq("") => println("abc") + case Seq(_, _, x) => println(x) + } + + Seq(1, 2, "def") match { + case Seq("") => println("abc") + case Seq(_, _, x) => println(x) + } + } +}
\ No newline at end of file diff --git a/test/files/run/patmat-seqs.check b/test/files/run/patmat-seqs.check new file mode 100644 index 0000000000..bb2a5ee44a --- /dev/null +++ b/test/files/run/patmat-seqs.check @@ -0,0 +1,13 @@ +s3 +s2 +s1 +s0 +ss6 +d +s3 +s3 +d +s1 +s3 +d +d diff --git a/test/files/run/patmat-seqs.scala b/test/files/run/patmat-seqs.scala new file mode 100644 index 0000000000..b5c47b4b4b --- /dev/null +++ b/test/files/run/patmat-seqs.scala @@ -0,0 +1,42 @@ +object Test { + def f1(x: Any) = x match { + case Seq(1, 2, 3) => "s3" + case Seq(4, 5) => "s2" + case Seq(7) => "s1" + case Nil => "s0" + case Seq(_, _, _, _, _, x: String) => "ss6" + case _ => "d" + } + + def f2(x: Any) = x match { + case Seq("a", "b", _*) => "s2" + case Seq(1, _*) => "s1" + case Seq(5, 6, 7, _*) => "s3" + case _ => "d" + } + + def main(args: Array[String]): Unit = { + val xs1 = List( + List(1,2,3), + List(4,5), + Vector(7), + Seq(), + Seq(1, 2, 3, 4, 5, "abcd"), + "abc" + ) map f1 + + xs1 foreach println + + val xs2 = List( + Seq(5, 6, 7), + Seq(5, 6, 7, 8, 9), + Seq("a"), + Seq(1, 6, 7), + List(5, 6, 7), + Nil, + 5 + ) map f2 + + xs2 foreach println + } +} |