From 2672f972ebdab9e504bf88227ab0dd0046b2992d Mon Sep 17 00:00:00 2001 From: Paul Phillips Date: Mon, 5 Oct 2009 05:09:14 +0000 Subject: More run of the mill pattern matcher work. to enjoy it when I no longer have to work around bugs in the pattern matcher while implementing the pattern matcher. Metacircularity: more fun applied to features than to bugs! --- .../tools/nsc/matching/ParallelMatching.scala | 133 ++++++++-------- .../scala/tools/nsc/matching/PatternBindings.scala | 12 -- .../scala/tools/nsc/matching/PatternNodes.scala | 2 +- .../scala/tools/nsc/matching/Patterns.scala | 168 +++++++++++---------- .../scala/tools/nsc/transform/ExplicitOuter.scala | 2 +- 5 files changed, 160 insertions(+), 157 deletions(-) diff --git a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala index cbc7615e7e..65106abc7e 100644 --- a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala +++ b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala @@ -29,6 +29,7 @@ trait ParallelMatching extends ast.TreeDSL import global.{ typer => _, _ } import definitions.{ AnyRefClass, IntClass, getProductArgs, productProj } + import treeInfo.{ isStar } import Types._ import CODE._ @@ -44,11 +45,8 @@ trait ParallelMatching extends ast.TreeDSL def logAndReturn[T](s: String, x: T): T = { log(s + x.toString) ; x } def traceAndReturn[T](s: String, x: T): T = { TRACE(s + x.toString) ; x } - /** Functions in transition - doomed upon completion of patternization. **/ - def isDefaultPattern(t: Tree) = cond(unbind(t)) { case EmptyTree | WILD() => true } - def isStar(t: Tree) = cond(unbind(t)) { case Star(q) => isDefaultPattern(q) } - def isRightIgnoring(t: Tree) = cond(unbind(t)) { case ArrayValue(_, xs) if !xs.isEmpty => isStar(xs.last) } - + /** Transition **/ + def isRightIgnoring(t: Tree) = cond(unbind(t)) { case ArrayValue(_, xs) if !xs.isEmpty => isStar(xs.last) } def getDummies(i: Int): List[Tree] = List.fill(i)(EmptyTree) def toPats(xs: List[Tree]): List[Pattern] = xs map Pattern.apply @@ -133,7 +131,38 @@ trait ParallelMatching extends ast.TreeDSL override def toString() = "Scrutinee(sym = %s, tpe = %s, id = %s)".format(sym, tpe, id) } - case class Patterns(scrut: Scrutinee, ps: List[Pattern]) { + def isPatternSwitch(scrut: Scrutinee, ps: List[Pattern]): Option[PatternSwitch] = { + def isSwitchableConst(x: Pattern) = cond(x) { case x: LiteralPattern if x.isSwitchable => true } + def isSwitchableDefault(x: Pattern) = isSwitchableConst(x) || x.isDefault + + // TODO - scala> (5: Any) match { case 5 => 5 ; case 6 => 7 } + // ... should compile to a switch. It doesn't because the scrut isn't Int/Char, but + // that could be handle in an if/else since every pattern requires an Int. + // More immediately, Byte and Short scruts should also work. + if (!scrut.isSimple) None + else { + val (lits, others) = { + val (l, o) = ps span isSwitchableConst + (l filterMap { case x: LiteralPattern => x }, o) + } + + condOpt(others) { + case Nil => new PatternSwitch(scrut, lits, None) + // TODO: This needs to also allow the case that the last is a compatible type pattern. + case List(x) if isSwitchableDefault(x) => new PatternSwitch(scrut, lits, Some(x)) + } + } + } + + class PatternSwitch( + scrut: Scrutinee, + override val ps: List[LiteralPattern], + val defaultPattern: Option[Pattern] + ) extends PatternMatch(scrut, ps) { + require(scrut.isSimple && (ps forall (_.isSwitchable))) + } + + case class PatternMatch(scrut: Scrutinee, ps: List[Pattern]) { private lazy val trees = ps map (_.boundTree) lazy val head = ps.head lazy val tail = ps.tail @@ -149,54 +178,33 @@ trait ParallelMatching extends ast.TreeDSL def zip[T](others: List[T]) = trees zip others def pzip[T](others: List[T]) = ps zip others - def extractSimpleSwitch(): Option[(List[Tree], Option[Pattern])] = { - def isSwitchableTag(tag: Int) = cond(tag) { case ByteTag | ShortTag | IntTag | CharTag => true } - def isSwitchableConst(t: Tree) = cond(unbind(t)) { case Literal(x: Constant) => isSwitchableTag(x.tag) } - def isSwitchableDefault(x: Tree) = isSwitchableConst(x) || isDefaultPattern(x) - - val (lits, others) = trees span isSwitchableConst - others match { - case Nil => Some((lits, None)) - // TODO: This needs to also allow the case that the last is a compatible type pattern. - case List(x) if isSwitchableDefault(x) => Some((lits, Some(Pattern(x)))) - case _ => None - } - } - // Any unapply - returns Some(true) if a type test is needed before the unapply can // be called (e.g. def unapply(x: Foo) = { ... } but our scrutinee is type Any.) object AnyUnapply { def unapply(x: Tree): Option[Boolean] = condOpt(x) { case UnapplyParamType(tpe) => !(scrut.tpe <:< tpe) } } - object SimpleSwitch { - // TODO - scala> (5: Any) match { case 5 => 5 ; case 6 => 7 } - // ... should compile to a switch. It doesn't because the scrut isn't Int/Char, but - // that could be handle in an if/else since every pattern requires an Int. - // More immediately, Byte and Short scruts should also work. - def unapply(x: Patterns) = if (x.scrut.isSimple) x.extractSimpleSwitch else None - } - def mkRule(rest: Rep): RuleApplication = { logAndReturn("mkRule: ", head.boundTree match { case x if isEquals(x.tpe) => new MixEquals(this, rest) case x: ArrayValue if isRightIgnoring(x) => new MixSequenceStar(this, rest) case x: ArrayValue => new MixSequence(this, rest) case AnyUnapply(false) => new MixUnapply(this, rest, false) - case _ => this match { - case SimpleSwitch(lits, d) => new MixLiteralInts(this, rest, lits, d) - case _ => new MixTypes(this, rest) - } + case _ => + isPatternSwitch(scrut, ps) match { + case Some(x) => new MixLiteralInts(x, rest) + case _ => new MixTypes(this, rest) + } } ) } - } // Patterns + } // PatternMatch /** picks which rewrite rule to apply * @precondition: column does not contain alternatives */ def MixtureRule(scrut: Scrutinee, column: List[Pattern], rest: Rep): RuleApplication = - Patterns(scrut, column) mkRule rest + PatternMatch(scrut, column) mkRule rest /** * Class encapsulating a guard expression in a pattern match: @@ -212,9 +220,9 @@ trait ParallelMatching extends ast.TreeDSL /***** Rule Applications *****/ sealed abstract class RuleApplication { - def pats: Patterns + def pats: PatternMatch def rest: Rep - lazy val Patterns(scrut, patterns) = pats + lazy val PatternMatch(scrut, patterns) = pats lazy val head = pats.head private def sym = scrut.sym @@ -239,7 +247,7 @@ trait ParallelMatching extends ast.TreeDSL } case class ErrorRule() extends RuleApplication { - def pats: Patterns = impossible + def pats: PatternMatch = impossible def rest: Rep = impossible final def tree() = failTree } @@ -247,7 +255,7 @@ trait ParallelMatching extends ast.TreeDSL /** {case ... if guard => bx} else {guardedRest} */ /** VariableRule: The top-most rows has only variable (non-constructor) patterns. */ case class VariableRule(subst: Bindings, guard: Guard, guardedRest: Rep, bx: Int) extends RuleApplication { - def pats: Patterns = impossible + def pats: PatternMatch = impossible def rest: Rep = guardedRest final def tree(): Tree = { @@ -264,16 +272,11 @@ trait ParallelMatching extends ast.TreeDSL /** Mixture rule for all literal ints (and chars) i.e. hopefully a switch * will be emitted on the JVM. */ - class MixLiteralInts( - val pats: Patterns, - val rest: Rep, - literals: List[Tree], - defaultPattern: Option[Pattern]) - extends RuleApplication + class MixLiteralInts(val pats: PatternSwitch, val rest: Rep) extends RuleApplication { - private object NUM { - def unapply(x: Tree): Option[Int] = condOpt(unbind(x)) { case Literal(c) => c.intValue } - } + val literals = pats.ps + val defaultPattern = pats.defaultPattern + // bound vars and rows for default pattern (only one row, but a list is easier to use later) val (defaultVars, defaultRows) = defaultPattern match { case None => (Nil, Nil) @@ -282,8 +285,8 @@ trait ParallelMatching extends ast.TreeDSL // literalMap is a map from each literal to a list of row indices. // varMap is a list from each literal to a list of the defined vars. val (literalMap, varMap) = { - val tags = literals map { case NUM(tag) => tag } - val varMap = tags zip (literals map definedVars) + val tags = literals map (_.intValue) + val varMap = tags zip (literals map (_.definedVars)) val litMap = tags.zipWithIndex.reverse.foldLeft(IntMap.empty[List[Int]]) { // we reverse before the fold so the list can be built with :: @@ -342,7 +345,7 @@ trait ParallelMatching extends ast.TreeDSL /** mixture rule for unapply pattern */ - class MixUnapply(val pats: Patterns, val rest: Rep, typeTest: Boolean) extends RuleApplication { + class MixUnapply(val pats: PatternMatch, val rest: Rep, typeTest: Boolean) extends RuleApplication { // Note: trailingArgs is not necessarily Nil, because unapply can take implicit parameters. lazy val ua @ UnApply(app, args) = head.tree lazy val Apply(fxn, _ :: trailingArgs) = app @@ -437,13 +440,13 @@ trait ParallelMatching extends ast.TreeDSL /** handle sequence pattern and ArrayValue (but not star patterns) */ - sealed class MixSequence(val pats: Patterns, val rest: Rep) extends RuleApplication { + sealed class MixSequence(val pats: PatternMatch, val rest: Rep) extends RuleApplication { /** array elements except for star (if present) */ protected def nonStarElems(x: ArrayValue) = if (isRightIgnoring(x)) x.elems.init else x.elems protected def elemLength(x: ArrayValue) = nonStarElems(x).length - protected def isAllDefaults(x: ArrayValue) = nonStarElems(x) forall isDefaultPattern + protected def isAllDefaults(x: ArrayValue) = nonStarElems(x) forall (t => Pattern(t).isDefault) final def removeStar(xs: List[Pattern]): List[Pattern] = xs.init ::: List(Pattern(makeBind(xs.last.boundVariables, WILD(scrut.seqType)))) @@ -462,7 +465,7 @@ trait ParallelMatching extends ast.TreeDSL * which cannot match due to a length incompatibility. */ protected def mustCheck(first: Tree, next: Tree): Boolean = - (first ne next) && (isDefaultPattern(next) || cond((first, next)) { + (first ne next) && (Pattern(next).isDefault || cond((first, next)) { case (av: ArrayValue, bv: ArrayValue) => // number of non-star elements in each sequence val (len1, len2) = (elemLength(av), elemLength(bv)) @@ -526,16 +529,15 @@ trait ParallelMatching extends ast.TreeDSL /** handle sequence pattern and ArrayValue with star patterns */ - final class MixSequenceStar(pats: Patterns, rest: Rep) extends MixSequence(pats, rest) { - // in principle, we could optimize more, but variable binding gets complicated (@todo use finite state methods instead) + final class MixSequenceStar(pats: PatternMatch, rest: Rep) extends MixSequence(pats, rest) { // override def getSubPatterns(minlen: Int, x: Tree): Option[List[Pattern]] = { // implicit val min = minlen // implicit val tpe = scrut.seqType - // Pattern(x) match { - // case SeqStarSubPatterns(xs) => Some(xs) - // case _ => None - // } + // + // condOpt(Pattern(x)) { case SeqStarSubPatterns(xs) => xs } // } + + // in principle, we could optimize more, but variable binding gets complicated (@todo use finite state methods instead) override def getSubPatterns(minlen: Int, x: Tree): Option[List[Pattern]] = condOpt(x) { case av @ ArrayValue(_,xs) if (!isRightIgnoring(av) && xs.length == minlen) => // Seq(p1,...,pN) toPats(xs ::: List(gen.mkNil, EmptyTree)) @@ -556,7 +558,7 @@ trait ParallelMatching extends ast.TreeDSL } // @todo: equals test for same constant - class MixEquals(val pats: Patterns, val rest: Rep) extends RuleApplication { + class MixEquals(val pats: PatternMatch, val rest: Rep) extends RuleApplication { /** condition (to be used in IF), success and failure Rep */ final def getTransition(): (Branch[Tree], Symbol) = { val value = { @@ -594,7 +596,7 @@ trait ParallelMatching extends ast.TreeDSL /** mixture rule for type tests **/ - class MixTypes(val pats: Patterns, val rest: Rep) extends RuleApplication { + class MixTypes(val pats: PatternMatch, val rest: Rep) extends RuleApplication { // see bug1434.scala for an illustration of why "x <:< y" is insufficient. // this code is definitely inadequate at best. Inherited comment: // @@ -626,7 +628,7 @@ trait ParallelMatching extends ast.TreeDSL def sEqualsP = p =:= s def alts[T](yes: T, no: T): T = if (sEqualsP) yes else no - def isObjectTest = pattern.isObject && (pats.headType =:= pattern.mkSingleton) + def isObjectTest = pattern.isObject && (p =:= pattern.mkSingleton) lazy val dummy = (j, pats.dummyPatterns) lazy val pass = (j, pattern) @@ -843,7 +845,7 @@ trait ParallelMatching extends ast.TreeDSL cond(p.tree) { case _: UnApply | _: ArrayValue => true - case x => isDefaultPattern(x) || coversSym + case x => p.isDefault || coversSym } } } @@ -982,13 +984,14 @@ trait ParallelMatching extends ast.TreeDSL val (rows, finals) = List.unzip( for ((CaseDef(pat, guard, body), index) <- cases.zipWithIndex) yield { def mkRow(ps: List[Pattern]) = Row(ps, NoBinding, Guard(guard), index) + def pattern = Pattern(pat) def rowForPat = pat match { - case _ if roots.length <= 1 => mkRow(List(Pattern(pat))) + case _ if roots.length <= 1 => mkRow(List(pattern)) case Apply(fn, args) => mkRow(toPats(args)) case WILD() => mkRow(emptyPatterns(roots.length)) } - (rowForPat, FinalState(NoBinding, body, definedVars(pat))) + (rowForPat, FinalState(NoBinding, body, pattern.definedVars)) } ) diff --git a/src/compiler/scala/tools/nsc/matching/PatternBindings.scala b/src/compiler/scala/tools/nsc/matching/PatternBindings.scala index 0fd3043281..575050821a 100644 --- a/src/compiler/scala/tools/nsc/matching/PatternBindings.scala +++ b/src/compiler/scala/tools/nsc/matching/PatternBindings.scala @@ -27,18 +27,6 @@ trait PatternBindings extends ast.TreeDSL override def safeToString: String = "PseudoType("+o+")" } - final def definedVars(x: Tree): List[Symbol] = { - def vars(x: Tree): List[Symbol] = x match { - case Apply(_, args) => args flatMap vars - case b @ Bind(_, p) => b.symbol :: vars(p) - case Typed(p, _) => vars(p) // otherwise x @ (_:T) - case UnApply(_, args) => args flatMap vars - case ArrayValue(_, xs) => xs flatMap vars - case x => Nil - } - vars(x) reverse - } - // If the given pattern contains alternatives, return it as a list of patterns. // Makes typed copies of any bindings found so all alternatives point to final state. def extractBindings(p: Tree, prevBindings: Tree => Tree = identity[Tree] _): List[Tree] = { diff --git a/src/compiler/scala/tools/nsc/matching/PatternNodes.scala b/src/compiler/scala/tools/nsc/matching/PatternNodes.scala index 6a0e48f9c0..c28a0790b2 100644 --- a/src/compiler/scala/tools/nsc/matching/PatternNodes.scala +++ b/src/compiler/scala/tools/nsc/matching/PatternNodes.scala @@ -47,7 +47,7 @@ trait PatternNodes extends ast.TreeDSL val listRef = typeRef(pre, ListClass, List(tpe)) def fold(x: Tree, xs: Tree) = unbind(x) match { - case _: Star => makeBind(definedVars(x), WILD(x.tpe)) + case _: Star => makeBind(Pattern(x).definedVars, WILD(x.tpe)) case _ => val dummyMethod = new TermSymbol(NoSymbol, NoPosition, "matching$dummy") val consType = MethodType(dummyMethod newSyntheticValueParams List(tpe, listRef), consRef) diff --git a/src/compiler/scala/tools/nsc/matching/Patterns.scala b/src/compiler/scala/tools/nsc/matching/Patterns.scala index 393a9df585..cdaa37621f 100644 --- a/src/compiler/scala/tools/nsc/matching/Patterns.scala +++ b/src/compiler/scala/tools/nsc/matching/Patterns.scala @@ -43,12 +43,15 @@ trait Patterns extends ast.TreeDSL { def NoPattern = WildcardPattern() // The constant null pattern - def NullPattern = Pattern(NULL) + def NullPattern = LiteralPattern(NULL) // 8.1.1 case class VariablePattern(tree: Ident) extends Pattern { + require(tree != Ident(nme.WILDCARD)) + override def irrefutableFor(tpe: Type) = true override def simplify(testVar: Symbol) = this.rebindToEqualsCheck() + // override def matchingType = mkSingleton ??? XXX } @@ -56,6 +59,7 @@ trait Patterns extends ast.TreeDSL { case class WildcardPattern() extends Pattern { val tree = EmptyTree override def irrefutableFor(tpe: Type) = true + override def isDefault = true } // 8.1.2 @@ -70,16 +74,20 @@ trait Patterns extends ast.TreeDSL { } // 8.1.3 - case class LiteralPattern(tree: Literal) extends Pattern { } + case class LiteralPattern(tree: Literal) extends Pattern { + val Literal(const @ Constant(value)) = tree - // 8.1.4 - case class StableIdPattern(tree: Ident) extends Pattern { - override def simplify(testVar: Symbol) = this.rebindToEqualsCheck() - override def matchingType = mkSingleton + def isSwitchable = cond(const.tag) { case ByteTag | ShortTag | IntTag | CharTag => true } + def intValue = const.intValue } + // 8.1.4 + case class StableIdPattern(tree: Ident) extends IdentifierPattern { } + // 8.1.4 (b) - case class SelectPattern(tree: Select) extends Pattern { + case class SelectPattern(tree: Select) extends IdentifierPattern { } + + trait IdentifierPattern extends Pattern { override def simplify(testVar: Symbol) = this.rebindToEqualsCheck() override def matchingType = mkSingleton } @@ -88,7 +96,7 @@ trait Patterns extends ast.TreeDSL { case class ConstructorPattern(tree: Apply) extends ApplyPattern { require(fn.isType && this.isCaseClass) - override def subpatterns(pats: MatchMatrix#Patterns) = + override def subpatterns(pats: MatchMatrix#PatternMatch) = if (pats.isCaseHead) args map Pattern.apply else super.subpatterns(pats) @@ -164,13 +172,20 @@ trait Patterns extends ast.TreeDSL { // // 8.1.8 (b) case class SequenceStarPattern(tree: ArrayValue) extends Pattern { } + // abstract trait ArrayValuePattern extends Pattern { + // val tree: ArrayValue + // lazy val av @ ArrayValue(elemTpt, elems) = tree + // lazy val elemPatterns = toPats(elems) + // def nonStarElems = if (isRightIgnoring) elems.init else elems + // } + // 8.1.9 // InfixPattern ... subsumed by Constructor/Extractor Patterns // 8.1.10 case class AlternativePattern(tree: Alternative) extends Pattern { private val Alternative(subtrees) = tree - // override def subpatterns(pats: MatchMatrix#Patterns) = subtrees map Pattern.apply + // override def subpatterns(pats: PatternMatch) = subtrees map Pattern.apply } // 8.1.11 @@ -187,13 +202,10 @@ trait Patterns extends ast.TreeDSL { } object Pattern { - def isDefaultPattern(t: Tree) = cond(unbind(t)) { case EmptyTree | WILD() => true } - def isStar(t: Tree) = cond(unbind(t)) { case Star(q) => isDefaultPattern(q) } - def isRightIgnoring(t: Tree) = cond(unbind(t)) { case ArrayValue(_, xs) if !xs.isEmpty => isStar(xs.last) } - def apply(tree: Tree): Pattern = tree match { case x: Bind => apply(unbind(tree)) withBoundTree x - case EmptyTree | WILD() => WildcardPattern() + case EmptyTree => WildcardPattern() + case Ident(nme.WILDCARD) => WildcardPattern() case x @ Alternative(ps) => AlternativePattern(x) case x: Apply => ApplyPattern(x) case x: Typed => TypedPattern(x) @@ -266,14 +278,14 @@ trait Patterns extends ast.TreeDSL { // trait SimplePattern extends Pattern { // def simplify(testVar: Symbol): Pattern = this // } - sealed abstract class Pattern { + sealed abstract class Pattern extends PatternBindingLogic { val tree: Tree // The logic formerly in classifyPat, returns either a simplification // of this pattern or identity. def simplify(testVar: Symbol): Pattern = this def simplify(): Pattern = this simplify NoSymbol - def subpatterns(pats: MatchMatrix#Patterns): List[Pattern] = pats.dummyPatterns + def subpatterns(pats: MatchMatrix#PatternMatch): List[Pattern] = pats.dummyPatterns // 8.1.13 // A pattern p is irrefutable for type T if any of the following applies: @@ -284,38 +296,8 @@ trait Patterns extends ast.TreeDSL { // pi is irrefutable for Ti. def irrefutableFor(tpe: Type) = false - // XXX only a var for short-term experimentation. - private var _boundTree: Bind = null - def boundTree = if (_boundTree == null) tree else _boundTree - def withBoundTree(x: Bind): this.type = { - _boundTree = x - this - } - lazy val boundVariables = strip(boundTree) - - private def wrapBindings(vs: List[Symbol], pat: Tree): Tree = vs match { - case Nil => pat - case x :: xs => Bind(x, wrapBindings(xs, pat)) setType pat.tpe - } - - // If a tree has bindings, boundTree looks something like - // Bind(v3, Bind(v2, Bind(v1, tree))) - // This takes the given tree and creates a new pattern - // using the same bindings. - def rebindTo(t: Tree): Pattern = - Pattern(wrapBindings(boundVariables, t)) - - // Wrap this pattern's bindings around (_: Type) - def rebindToType(tpe: Type): Pattern = - rebindTo(Typed(WILD(tpe), TypeTree(tpe)) setType tpe) - - // Wrap them around _ - def rebindToEmpty(tpe: Type): Pattern = - rebindTo(Typed(EmptyTree, TypeTree(tpe)) setType tpe) - - // Wrap them around a singleton type for an EqualsPattern check. - def rebindToEqualsCheck(): Pattern = - rebindToType(equalsCheck) + // Is this a default pattern (untyped "_" or an EmptyTree inserted by the matcher) + def isDefault = false def sym = tree.symbol def tpe = tree.tpe @@ -328,7 +310,7 @@ trait Patterns extends ast.TreeDSL { def isSymValid = (sym != null) && (sym != NoSymbol) def isModule = sym.isModule || tpe.termSymbol.isModule def isCaseClass = tpe.typeSymbol hasFlag Flags.CASE - def isObject = isSymValid && prefix.isStable + def isObject = isSymValid && prefix.isStable // XXX not entire logic def setType(tpe: Type): this.type = { tree setType tpe @@ -344,21 +326,12 @@ trait Patterns extends ast.TreeDSL { case _ => singleType(prefix, sym) } - final def isDefault = cond(tree) { case EmptyTree | WILD() => true } - final def isStar = cond(tree) { case Star(q) => Pattern(q).isDefault } final def isAlternative = cond(tree) { case Alternative(_) => true } - final def isRightIgnoring = cond(tree) { case ArrayValue(_, xs) if !xs.isEmpty => Pattern(xs.last).isStar } - - /** Helpers **/ - private def strip(t: Tree): List[Symbol] = t match { - case b @ Bind(_, pat) => b.symbol :: strip(pat) - case _ => Nil - } /** Standard methods **/ def copy(tree: Tree = this.tree): Pattern = - if (_boundTree == null) Pattern(tree) - else Pattern(tree) withBoundTree _boundTree + if (boundTree eq tree) Pattern(tree) + else Pattern(tree) withBoundTree boundTree.asInstanceOf[Bind] // override def toString() = "Pattern(%s, %s)".format(tree, boundVariables) override def equals(other: Any) = other match { @@ -368,6 +341,62 @@ trait Patterns extends ast.TreeDSL { override def hashCode() = boundTree.hashCode() } + trait PatternBindingLogic { + self: Pattern => + + // XXX only a var for short-term experimentation. + private var _boundTree: Bind = null + def boundTree = if (_boundTree == null) tree else _boundTree + def withBoundTree(x: Bind): this.type = { + _boundTree = x + this + } + lazy val boundVariables = strip(boundTree) + + def definedVars = definedVarsInternal(boundTree) + private def definedVarsInternal(x: Tree): List[Symbol] = { + def vars(x: Tree): List[Symbol] = x match { + case Apply(_, args) => args flatMap vars + case b @ Bind(_, p) => b.symbol :: vars(p) + case Typed(p, _) => vars(p) // otherwise x @ (_:T) + case UnApply(_, args) => args flatMap vars + case ArrayValue(_, xs) => xs flatMap vars + case x => Nil + } + vars(x) reverse + } + + private def wrapBindings(vs: List[Symbol], pat: Tree): Tree = vs match { + case Nil => pat + case x :: xs => Bind(x, wrapBindings(xs, pat)) setType pat.tpe + } + + // If a tree has bindings, boundTree looks something like + // Bind(v3, Bind(v2, Bind(v1, tree))) + // This takes the given tree and creates a new pattern + // using the same bindings. + def rebindTo(t: Tree): Pattern = + Pattern(wrapBindings(boundVariables, t)) + + // Wrap this pattern's bindings around (_: Type) + def rebindToType(tpe: Type): Pattern = + rebindTo(Typed(WILD(tpe), TypeTree(tpe)) setType tpe) + + // Wrap them around _ + def rebindToEmpty(tpe: Type): Pattern = + rebindTo(Typed(EmptyTree, TypeTree(tpe)) setType tpe) + + // Wrap them around a singleton type for an EqualsPattern check. + def rebindToEqualsCheck(): Pattern = + rebindToType(equalsCheck) + + /** Helpers **/ + private def strip(t: Tree): List[Symbol] = t match { + case b @ Bind(_, pat) => b.symbol :: strip(pat) + case _ => Nil + } + } + /*** Extractors ***/ object UnapplyParamType { @@ -377,12 +406,6 @@ trait Patterns extends ast.TreeDSL { } } } - // - // protected def getSubPatterns(len: Int, x: Tree): Option[List[Pattern]] = condOpt(x) { - // case av @ ArrayValue(_,xs) if !isRightIgnoring(av) && xs.length == len => toPats(xs) ::: List(NoPattern) - // case av @ ArrayValue(_,xs) if isRightIgnoring(av) && xs.length == len+1 => removeStar(toPats(xs)) // (*) - // case EmptyTree | WILD() => emptyPatterns(len + 1) - // } object SeqStarSubPatterns { def removeStar(xs: List[Tree], seqType: Type): List[Pattern] = { @@ -390,22 +413,11 @@ trait Patterns extends ast.TreeDSL { ps.init ::: List(ps.last rebindToType seqType) } - // override def getSubPatterns(minlen: Int, x: Tree): Option[List[Pattern]] = condOpt(x) { - // case av @ ArrayValue(_,xs) if (!isRightIgnoring(av) && xs.length == minlen) => // Seq(p1,...,pN) - // toPats(xs ::: List(gen.mkNil, EmptyTree)) - // case av @ ArrayValue(_,xs) if ( isRightIgnoring(av) && xs.length-1 == minlen) => // Seq(p1,...,pN,_*) - // removeStar(toPats(xs)) ::: List(NoPattern) - // case av @ ArrayValue(_,xs) if ( isRightIgnoring(av) && xs.length-1 < minlen) => // Seq(p1..,pJ,_*) J < N - // emptyPatterns(minlen + 1) ::: List(Pattern(x)) - // case EmptyTree | WILD() => - // emptyPatterns(minlen + 1 + 1) - // } - def unapply(x: Pattern)(implicit min: Int, seqType: Type): Option[List[Pattern]] = x.tree match { case av @ ArrayValue(_, xs) => if (!isRightIgnoring(av) && xs.length == min) Some(toPats(xs ::: List(gen.mkNil, EmptyTree))) // Seq(p1,...,pN) else if ( isRightIgnoring(av) && xs.length-1 == min) Some(removeStar(xs, seqType) ::: List(NoPattern)) // Seq(p1,...,pN,_*) - else if ( isRightIgnoring(av) && xs.length-1 == min) Some(emptyPatterns(min + 1) ::: List(x)) // Seq(p1..,pJ,_*) J < N + else if ( isRightIgnoring(av) && xs.length-1 < min) Some(emptyPatterns(min + 1) ::: List(x)) // Seq(p1..,pJ,_*) J < N else None case _ => if (x.isDefault) Some(emptyPatterns(min + 1 + 1)) else None diff --git a/src/compiler/scala/tools/nsc/transform/ExplicitOuter.scala b/src/compiler/scala/tools/nsc/transform/ExplicitOuter.scala index 9419bba23e..c5d8e481fe 100644 --- a/src/compiler/scala/tools/nsc/transform/ExplicitOuter.scala +++ b/src/compiler/scala/tools/nsc/transform/ExplicitOuter.scala @@ -362,7 +362,7 @@ abstract class ExplicitOuter extends InfoTransform val gdcall = if (guard == EmptyTree) EmptyTree else { - val vs = definedVars(p) + val vs = Pattern(p).definedVars val guardDef = makeGuardDef(vs, guard) nguard += transform(guardDef) // building up list of guards -- cgit v1.2.3