From 61d34ed3fd61ab60ddea10445c737a3c9a6aa525 Mon Sep 17 00:00:00 2001 From: Paul Phillips Date: Thu, 16 Feb 2012 19:41:31 -0800 Subject: An old patch on pattern matcher exhaustivity. Simplifies the implementation a bit. --- .../scala/tools/nsc/matching/MatchSupport.scala | 4 ++ src/compiler/scala/tools/nsc/matching/Matrix.scala | 9 ++-- .../scala/tools/nsc/matching/MatrixAdditions.scala | 16 +----- .../tools/nsc/matching/ParallelMatching.scala | 2 +- .../scala/tools/nsc/matching/PatternBindings.scala | 5 +- .../scala/tools/nsc/matching/Patterns.scala | 61 +++++++++++++--------- 6 files changed, 50 insertions(+), 47 deletions(-) (limited to 'src') diff --git a/src/compiler/scala/tools/nsc/matching/MatchSupport.scala b/src/compiler/scala/tools/nsc/matching/MatchSupport.scala index 5e46960d04..371f4bc4d8 100644 --- a/src/compiler/scala/tools/nsc/matching/MatchSupport.scala +++ b/src/compiler/scala/tools/nsc/matching/MatchSupport.scala @@ -115,6 +115,10 @@ trait MatchSupport extends ast.TreeDSL { self: ParallelMatching => println(fmt.format(xs: _*) + " == " + x) x } + private[nsc] def debugging[T](fmt: String, xs: Any*)(x: T): T = { + if (settings.debug.value) printing(fmt, xs: _*)(x) + else x + } def indent(s: Any) = s.toString() split "\n" map (" " + _) mkString "\n" def indentAll(s: Seq[Any]) = s map (" " + _.toString() + "\n") mkString diff --git a/src/compiler/scala/tools/nsc/matching/Matrix.scala b/src/compiler/scala/tools/nsc/matching/Matrix.scala index d81f05cd51..e1ff88557e 100644 --- a/src/compiler/scala/tools/nsc/matching/Matrix.scala +++ b/src/compiler/scala/tools/nsc/matching/Matrix.scala @@ -198,6 +198,10 @@ trait Matrix extends MatrixAdditions { class PatternVar(val lhs: Symbol, val rhs: Tree, val checked: Boolean) { def sym = lhs def tpe = lhs.tpe + if (checked) + lhs resetFlag NO_EXHAUSTIVE + else + lhs setFlag NO_EXHAUSTIVE // See #1427 for an example of a crash which occurs unless we retype: // in that instance there is an existential in the pattern. @@ -207,11 +211,6 @@ trait Matrix extends MatrixAdditions { override def toString() = "%s: %s = %s".format(lhs, tpe, rhs) } - /** Sets the rhs to EmptyTree, which makes the valDef ignored in Scrutinee. - */ - def specialVar(lhs: Symbol, checked: Boolean) = - new PatternVar(lhs, EmptyTree, checked) - /** Given a tree, creates a new synthetic variable of the same type * and assigns the tree to it. */ diff --git a/src/compiler/scala/tools/nsc/matching/MatrixAdditions.scala b/src/compiler/scala/tools/nsc/matching/MatrixAdditions.scala index 24d3c38e74..e72a0007a0 100644 --- a/src/compiler/scala/tools/nsc/matching/MatrixAdditions.scala +++ b/src/compiler/scala/tools/nsc/matching/MatrixAdditions.scala @@ -131,23 +131,11 @@ trait MatrixAdditions extends ast.TreeDSL { import Flags.{ MUTABLE, ABSTRACT, SEALED } - private case class Combo(index: Int, sym: Symbol) { - val isBaseClass = sym.tpe.baseClasses.toSet - - // is this combination covered by the given pattern? - def isCovered(p: Pattern) = { - def coversSym = isBaseClass(decodedEqualsType(p.tpe).typeSymbol) - - cond(p.tree) { - case _: UnApply | _: ArrayValue => true - case x => p.isDefault || coversSym - } - } - } + private case class Combo(index: Int, sym: Symbol) { } /* True if the patterns in 'row' cover the given type symbol combination, and has no guard. */ private def rowCoversCombo(row: Row, combos: List[Combo]) = - row.guard.isEmpty && (combos forall (c => c isCovered row.pats(c.index))) + row.guard.isEmpty && combos.forall(c => row.pats(c.index) covers c.sym) private def requiresExhaustive(sym: Symbol) = { (sym.isMutable) && // indicates that have not yet checked exhaustivity diff --git a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala index 9d4c9b4411..1285e29d4a 100644 --- a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala +++ b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala @@ -745,7 +745,7 @@ trait ParallelMatching extends ast.TreeDSL (others.head :: _column.tail, make(_tvars, _rows)) def mix() = { - val newScrut = new Scrutinee(specialVar(_pv.sym, _pv.checked)) + val newScrut = new Scrutinee(new PatternVar(_pv.sym, EmptyTree, _pv.checked)) PatternMatch(newScrut, _ncol) mkRule _nrep } } diff --git a/src/compiler/scala/tools/nsc/matching/PatternBindings.scala b/src/compiler/scala/tools/nsc/matching/PatternBindings.scala index 5dd7d8f3ee..56297f0195 100644 --- a/src/compiler/scala/tools/nsc/matching/PatternBindings.scala +++ b/src/compiler/scala/tools/nsc/matching/PatternBindings.scala @@ -19,9 +19,10 @@ trait PatternBindings extends ast.TreeDSL import Debug._ /** EqualsPattern **/ - def isEquals(tpe: Type) = cond(tpe) { case TypeRef(_, EqualsPatternClass, _) => true } + def isEquals(tpe: Type) = tpe.typeSymbol == EqualsPatternClass def mkEqualsRef(tpe: Type) = typeRef(NoPrefix, EqualsPatternClass, List(tpe)) - def decodedEqualsType(tpe: Type) = condOpt(tpe) { case TypeRef(_, EqualsPatternClass, List(arg)) => arg } getOrElse (tpe) + def decodedEqualsType(tpe: Type) = + if (tpe.typeSymbol == EqualsPatternClass) tpe.typeArgs.head else tpe // A subtype test which creates fresh existentials for type // parameters on the right hand side. diff --git a/src/compiler/scala/tools/nsc/matching/Patterns.scala b/src/compiler/scala/tools/nsc/matching/Patterns.scala index 18409cfffe..a6d8556db3 100644 --- a/src/compiler/scala/tools/nsc/matching/Patterns.scala +++ b/src/compiler/scala/tools/nsc/matching/Patterns.scala @@ -26,19 +26,6 @@ trait Patterns extends ast.TreeDSL { type PatternMatch = MatchMatrix#PatternMatch private type PatternVar = MatrixContext#PatternVar - // private def unapplyArgs(x: Any) = x match { - // case UnApply(Apply(TypeApply(_, targs), args), _) => (targs map (_.symbol), args map (_.symbol)) - // case _ => (Nil, Nil) - // } - // - // private def unapplyCall(x: Any) = x match { - // case UnApply(t, _) => treeInfo.methPart(t).symbol - // case _ => NoSymbol - // } - - private lazy val dummyMethod = - NoSymbol.newTermSymbol(newTermName("matching$dummy")) - // Fresh patterns def emptyPatterns(i: Int): List[Pattern] = List.fill(i)(NoPattern) def emptyTrees(i: Int): List[Tree] = List.fill(i)(EmptyTree) @@ -56,13 +43,14 @@ trait Patterns extends ast.TreeDSL { case class VariablePattern(tree: Ident) extends NamePattern { lazy val Ident(name) = tree require(isVarPattern(tree) && name != nme.WILDCARD) - + override def covers(sym: Symbol) = true override def description = "%s".format(name) } // 8.1.1 (b) case class WildcardPattern() extends Pattern { - def tree = EmptyTree + val tree = EmptyTree + override def covers(sym: Symbol) = true override def isDefault = true override def description = "_" } @@ -71,6 +59,8 @@ trait Patterns extends ast.TreeDSL { case class TypedPattern(tree: Typed) extends Pattern { lazy val Typed(expr, tpt) = tree + override def covers(sym: Symbol) = newMatchesPattern(sym, tpt.tpe) + override def sufficientType = tpt.tpe override def subpatternsForVars: List[Pattern] = List(Pattern(expr)) override def simplify(pv: PatternVar) = Pattern(expr) match { case ExtractorPattern(ua) if pv.sym.tpe <:< tpt.tpe => this rebindTo expr @@ -115,6 +105,7 @@ trait Patterns extends ast.TreeDSL { } } + override def covers(sym: Symbol) = newMatchesPattern(sym, sufficientType) override def simplify(pv: PatternVar) = this.rebindToObjectCheck() override def description = backticked match { case Some(s) => "this." + s @@ -133,13 +124,15 @@ trait Patterns extends ast.TreeDSL { case class ObjectPattern(tree: Apply) extends ApplyPattern { // NamePattern? require(!fn.isType && isModule) + override def covers(sym: Symbol) = newMatchesPattern(sym, sufficientType) override def sufficientType = tpe.narrow override def simplify(pv: PatternVar) = this.rebindToObjectCheck() override def description = "Obj(%s)".format(fn) } // 8.1.4 (e) case class SimpleIdPattern(tree: Ident) extends NamePattern { - lazy val Ident(name) = tree + val Ident(name) = tree + override def covers(sym: Symbol) = newMatchesPattern(sym, tpe.narrow) override def description = "Id(%s)".format(name) } @@ -163,6 +156,11 @@ trait Patterns extends ast.TreeDSL { if (args.isEmpty) this rebindToEmpty tree.tpe else this + override def covers(sym: Symbol) = { + debugging("[constructor] Does " + this + " cover " + sym + " ? ") { + sym.tpe.typeSymbol == this.tpe.typeSymbol + } + } override def description = { if (isColonColon) "%s :: %s".format(Pattern(args(0)), Pattern(args(1))) else "%s(%s)".format(name, toPats(args).mkString(", ")) @@ -175,17 +173,12 @@ trait Patterns extends ast.TreeDSL { // 8.1.7 / 8.1.8 (unapply and unapplySeq calls) case class ExtractorPattern(tree: UnApply) extends UnapplyPattern { - override def simplify(pv: PatternVar) = { - if (pv.sym hasFlag NO_EXHAUSTIVE) () - else { - TRACE("Setting NO_EXHAUSTIVE on " + pv.sym + " due to extractor " + tree) - pv.sym setFlag NO_EXHAUSTIVE - } + private def uaTyped = Typed(tree, TypeTree(arg.tpe)) setType arg.tpe + override def simplify(pv: PatternVar) = { if (pv.tpe <:< arg.tpe) this else this rebindTo uaTyped } - override def description = "Unapply(%s => %s)".format(necessaryType, resTypesString) } @@ -208,6 +201,7 @@ trait Patterns extends ast.TreeDSL { private def listFolder(hd: Tree, tl: Tree): Tree = unbind(hd) match { case t @ Star(_) => moveBindings(hd, WILD(t.tpe)) case _ => + val dummyMethod = NoSymbol.newTermSymbol(newTermName("matching$dummy")) val consType = MethodType(dummyMethod newSyntheticValueParams List(packedType, listRef), consRef) Apply(TypeTree(consType), List(hd, tl)) setType consRef @@ -376,7 +370,7 @@ trait Patterns extends ast.TreeDSL { case _: This if isVariableName(name) => Some("`%s`".format(name)) case _ => None } - + override def covers(sym: Symbol) = newMatchesPattern(sym, tree.tpe) protected def getPathSegments(t: Tree): List[Name] = t match { case Select(q, name) => name :: getPathSegments(q) case Apply(f, Nil) => getPathSegments(f) @@ -395,7 +389,13 @@ trait Patterns extends ast.TreeDSL { lazy val UnApply(unfn, args) = tree lazy val Apply(fn, _) = unfn lazy val MethodType(List(arg, _*), _) = fn.tpe - protected def uaTyped = Typed(tree, TypeTree(arg.tpe)) setType arg.tpe + + // Covers if the symbol matches the unapply method's argument type, + // and the return type of the unapply is Some. + override def covers(sym: Symbol) = newMatchesPattern(sym, arg.tpe) + + // TODO: for alwaysCovers: + // fn.tpe.finalResultType.typeSymbol == SomeClass override def necessaryType = arg.tpe override def subpatternsForVars = args match { @@ -419,6 +419,7 @@ trait Patterns extends ast.TreeDSL { else emptyPatterns(sufficientType.typeSymbol.caseFieldAccessors.size) def isConstructorPattern = fn.isType + override def covers(sym: Symbol) = newMatchesPattern(sym, fn.tpe) } sealed abstract class Pattern extends PatternBindingLogic { @@ -443,6 +444,15 @@ trait Patterns extends ast.TreeDSL { // the subpatterns for this pattern (at the moment, that means constructor arguments) def subpatterns(pm: MatchMatrix#PatternMatch): List[Pattern] = pm.dummies + // if this pattern should be considered to cover the given symbol + def covers(sym: Symbol): Boolean = newMatchesPattern(sym, sufficientType) + def newMatchesPattern(sym: Symbol, pattp: Type) = { + debugging("[" + kindString + "] Does " + pattp + " cover " + sym + " ? ") { + (sym.isModuleClass && (sym.tpe.typeSymbol eq pattp.typeSymbol)) || + (sym.tpe.baseTypeSeq exists (_ matchesPattern pattp)) + } + } + def sym = tree.symbol def tpe = tree.tpe def isEmpty = tree.isEmpty @@ -475,6 +485,7 @@ trait Patterns extends ast.TreeDSL { final override def toString = description def toTypeString() = "%s <: x <: %s".format(necessaryType, sufficientType) + def kindString = "" } /*** Extractors ***/ -- cgit v1.2.3