From caa2d287d640a1917467c310f9240cd133a41f4a Mon Sep 17 00:00:00 2001 From: Paul Phillips Date: Fri, 2 Oct 2009 04:44:46 +0000 Subject: Most of this big patch is organizational, but t... Most of this big patch is organizational, but there's also a healthy dose of new code. If my last few changelog comments about patternization were at all unclear, you can now see the skeleton of what I'm after in matching/Patterns.scala. --- .../scala/tools/nsc/matching/MatchUtil.scala | 2 +- src/compiler/scala/tools/nsc/matching/Matrix.scala | 56 ++++++ .../tools/nsc/matching/ParallelMatching.scala | 150 ++++----------- .../scala/tools/nsc/matching/PatternBindings.scala | 94 +++++++++ .../scala/tools/nsc/matching/PatternNodes.scala | 71 ------- .../tools/nsc/matching/PatternOptimizer.scala | 137 ++++++++++++++ .../scala/tools/nsc/matching/Patterns.scala | 210 ++++++++++++++++----- .../scala/tools/nsc/matching/TransMatcher.scala | 137 +------------- .../scala/tools/nsc/transform/ExplicitOuter.scala | 2 +- 9 files changed, 503 insertions(+), 356 deletions(-) create mode 100644 src/compiler/scala/tools/nsc/matching/Matrix.scala create mode 100644 src/compiler/scala/tools/nsc/matching/PatternBindings.scala create mode 100644 src/compiler/scala/tools/nsc/matching/PatternOptimizer.scala diff --git a/src/compiler/scala/tools/nsc/matching/MatchUtil.scala b/src/compiler/scala/tools/nsc/matching/MatchUtil.scala index 2c7668564e..fde881beaf 100644 --- a/src/compiler/scala/tools/nsc/matching/MatchUtil.scala +++ b/src/compiler/scala/tools/nsc/matching/MatchUtil.scala @@ -12,7 +12,7 @@ object MatchUtil import collection.mutable.ListBuffer def impossible: Nothing = abort("this never happens") - def abort(msg: String): Nothing = throw new RuntimeException(msg) + def abort(msg: String): Nothing = Predef.error(msg) /** Transforms a list of triples into a triple of lists. * diff --git a/src/compiler/scala/tools/nsc/matching/Matrix.scala b/src/compiler/scala/tools/nsc/matching/Matrix.scala new file mode 100644 index 0000000000..d02576208e --- /dev/null +++ b/src/compiler/scala/tools/nsc/matching/Matrix.scala @@ -0,0 +1,56 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2009 LAMP/EPFL + * Author: Paul Phillips + */ + +package scala.tools.nsc +package matching + +import transform.ExplicitOuter +import util.Position + +trait Matrix extends PatternOptimizer { + self: ExplicitOuter with ParallelMatching => + + import global.{ typer => _, _ } + import analyzer.Typer + import CODE._ + + case class MatrixInit( + roots: List[Symbol], + cases: List[CaseDef], + default: Tree + ) + + case class MatrixContext( + handleOuter: TreeFunction1, // Tree => Tree function + typer: Typer, // a local typer + owner: Symbol, // the current owner + matchResultType: Type) // the expected result type of the whole match + extends Squeezer + { + def newVar( + pos: Position, + tpe: Type, + flags: List[Long] = Nil, + name: Name = null): Symbol = + { + val n: Name = if (name == null) newName(pos, "temp") else name + // careful: pos has special meaning + owner.newVariable(pos, n) setInfo tpe setFlag (0L /: flags)(_|_) + } + + def typedValDef(x: Symbol, rhs: Tree) = { + val finalRhs = x.tpe match { + case WildcardType => + rhs setType null + x setInfo (typer typed rhs).tpe + rhs + case _ => + typer.typed(rhs, x.tpe) + } + typer typed (VAL(x) === finalRhs) + } + } + +} \ No newline at end of file diff --git a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala index 4d1b52da93..d06bfcf575 100644 --- a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala +++ b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala @@ -9,11 +9,14 @@ package scala.tools.nsc package matching import util.Position +import transform.ExplicitOuter +import symtab.Flags import collection._ -import mutable.BitSet +import mutable.{ BitSet, HashMap, ListBuffer } import immutable.IntMap import MatchUtil._ import annotation.elidable +import Function.tupled /** Translation of match expressions. * @@ -36,20 +39,21 @@ import annotation.elidable * * @author Burak Emir */ -trait ParallelMatching extends ast.TreeDSL { - self: transform.ExplicitOuter with PatternNodes => +trait ParallelMatching extends ast.TreeDSL + with Matrix + with Patterns + with PatternBindings + with PatternOptimizer + with PatternNodes +{ + self: ExplicitOuter => import global.{ typer => _, _ } - import definitions.{ AnyRefClass, EqualsPatternClass, IntClass, getProductArgs, productProj } - import symtab.Flags + import definitions.{ AnyRefClass, IntClass, getProductArgs, productProj } import Types._ import CODE._ - import scala.Function.tupled - // XXX temp - def toPats(xs: List[Tree]): List[Pattern] = xs map (x => Pattern(x)) - - // debugging val, set to true with -Ypmat-debug + /** Debugging support: enable with -Ypmat-debug **/ private final def trace = settings.Ypmatdebug.value def ifDebug(body: => Unit): Unit = { if (settings.debug.value) body } @@ -61,27 +65,20 @@ trait ParallelMatching extends ast.TreeDSL { def logAndReturn[T](s: String, x: T): T = { log(s + x.toString) ; x } def traceAndReturn[T](s: String, x: T): T = { TRACE(s + x.toString) ; x } - // Tests on misc + /** Functions in transition - doomed upon completion of patternization. **/ def isDefaultPattern(t: Tree) = cond(unbind(t)) { case EmptyTree | WILD() => true } def isStar(t: Tree) = cond(unbind(t)) { case Star(q) => isDefaultPattern(q) } def isRightIgnoring(t: Tree) = cond(unbind(t)) { case ArrayValue(_, xs) if !xs.isEmpty => isStar(xs.last) } - // If the given pattern contains alternatives, return it as a list of patterns. - // Makes typed copies of any bindings found so all alternatives point to final state. - def extractBindings(p: Tree, prevBindings: Tree => Tree = identity[Tree] _): List[Tree] = { - def newPrev(b: Bind) = (x: Tree) => treeCopy.Bind(b, b.name, x) setType x.tpe + def getDummies(i: Int): List[Tree] = List.fill(i)(EmptyTree) + def toPats(xs: List[Tree]): List[Pattern] = xs map Pattern.apply - p match { - case b @ Bind(_, body) => extractBindings(body, newPrev(b)) - case Alternative(ps) => ps map prevBindings - } - } + /** Back to the regular schedule. **/ - import collection.mutable.{ HashMap, ListBuffer } - class MatchMatrix(context: MatchMatrixContext, data: MatchMatrixInit) { + class MatchMatrix(val context: MatrixContext, data: MatrixInit) extends MatchMatrixOptimizer { import context._ - val MatchMatrixInit(roots, cases, failTree) = data + val MatrixInit(roots, cases, failTree) = data val ExpandedMatrix(rows, targets) = expand(roots, cases) val expansion: Rep = make(roots, rows) @@ -92,57 +89,6 @@ trait ParallelMatching extends ast.TreeDSL { -shortCuts.length } - final def cleanup(tree: Tree): Tree = { - // Extractors which can spot pure true/false expressions - // even through the haze of braces - abstract class SeeThroughBlocks[T] { - protected def unapplyImpl(x: Tree): T - def unapply(x: Tree): T = x match { - case Block(Nil, expr) => unapply(expr) - case _ => unapplyImpl(x) - } - } - object IsTrue extends SeeThroughBlocks[Boolean] { - protected def unapplyImpl(x: Tree): Boolean = x equalsStructure TRUE - } - object IsFalse extends SeeThroughBlocks[Boolean] { - protected def unapplyImpl(x: Tree): Boolean = x equalsStructure FALSE - } - object lxtt extends Transformer { - override def transform(tree: Tree): Tree = tree match { - case blck @ Block(vdefs, ld @ LabelDef(name, params, body)) => - def shouldInline(t: FinalState) = t.isReachedOnce && (t.label eq ld.symbol) - - if (targets exists shouldInline) squeezedBlock(vdefs, body) - else blck - - case t => - super.transform(t match { - // note - it is too early for any other true/false related optimizations - case If(cond, IsTrue(), IsFalse()) => cond - - case If(cond1, If(cond2, thenp, elsep1), elsep2) if (elsep1 equalsStructure elsep2) => - IF (cond1 AND cond2) THEN thenp ELSE elsep1 - case If(cond1, If(cond2, thenp, Apply(jmp, Nil)), ld: LabelDef) if jmp.symbol eq ld.symbol => - IF (cond1 AND cond2) THEN thenp ELSE ld - case t => t - }) - } - } - object resetTraverser extends Traverser { - import Flags._ - def reset(vd: ValDef) = - if (vd.symbol hasFlag SYNTHETIC) vd.symbol resetFlag (TRANS_FLAG|MUTABLE) - - override def traverse(x: Tree): Unit = x match { - case vd: ValDef => reset(vd) - case _ => super.traverse(x) - } - } - - returning[Tree](resetTraverser traverse _)(lxtt transform tree) - } - /** first time bx is requested, a LabelDef is returned. next time, a jump. * the function takes care of binding */ @@ -168,9 +114,9 @@ trait ParallelMatching extends ast.TreeDSL { def classifyPat(opat: Pattern, j: Int): Pattern = { def vars = opat.boundVariables - def rebind(t: Pattern) = Pattern(makeBind(vars, t.tree)) + def rebind(t: Pattern) = Pattern(makeBind(vars, t.boundTree)) def rebindEmpty(tpe: Type) = Pattern(mkEmptyTreeBind(vars, tpe)) - def rebindTyped() = Pattern(mkTypedBind(vars, equalsCheck(opat.stripped))) + def rebindTyped() = Pattern(mkTypedBind(vars, equalsCheck(opat.tree))) // @pre for doUnapplySeq: is not right-ignoring (no star pattern) ; no exhaustivity check def doUnapplySeq(tptArg: Tree, xs: List[Tree]) = { @@ -226,7 +172,7 @@ trait ParallelMatching extends ast.TreeDSL { { case x => abort("Unexpected pattern: " + x.getClass + " => " + x) } ) reduceLeft (_ orElse _) - f(opat.stripped) + f(opat.tree) } val rows = row1 flatMap (_ expandAlternatives classifyPat) @@ -236,9 +182,6 @@ trait ParallelMatching extends ast.TreeDSL { override def toString() = "MatchMatrix(%s)".format(targets) - /** Intended to be the DFA created from the match matrix. */ - class MatchAutomaton(matrix: MatchMatrix) { } - /** * Encapsulates a symbol being matched on. * @@ -277,13 +220,13 @@ trait ParallelMatching extends ast.TreeDSL { } case class Patterns(scrut: Scrutinee, ps: List[Pattern]) { - private lazy val trees = ps map (_.tree) + private lazy val trees = ps map (_.boundTree) lazy val head = ps.head lazy val tail = ps.tail lazy val size = ps.length - lazy val headType = head.stripped match { - case p @ (_:Ident | _:Select) => head.unbound.mkSingleton // should be singleton object + lazy val headType = head.tree match { + case p @ (_:Ident | _:Select) => head.mkSingleton // should be singleton object case __UnApply(_,argtpe,_) => argtpe // ?? why argtpe? case _ => head.tpe } @@ -330,7 +273,7 @@ trait ParallelMatching extends ast.TreeDSL { } def mkRule(rest: Rep): RuleApplication = - logAndReturn("mkRule: ", head.tree match { + logAndReturn("mkRule: ", head.boundTree match { case x if isEquals(x.tpe) => new MixEquals(this, rest) case x: ArrayValue if isRightIgnoring(x) => new MixSequenceStar(this, rest) case x: ArrayValue => new MixSequence(this, rest) @@ -529,7 +472,7 @@ trait ParallelMatching extends ast.TreeDSL { */ class MixUnapply(val pats: Patterns, val rest: Rep, typeTest: Boolean) extends RuleApplication { // Note: trailingArgs is not necessarily Nil, because unapply can take implicit parameters. - lazy val ua @ UnApply(app, args) = head.stripped + lazy val ua @ UnApply(app, args) = head.tree lazy val Apply(fxn, _ :: trailingArgs) = app object sameUnapplyCall { @@ -561,7 +504,7 @@ trait ParallelMatching extends ast.TreeDSL { def mkNewRows(sameFilter: (List[Tree]) => List[Tree], dum: Int) = for ((pat @ Strip(vs, p), r) <- zipped) yield p match { case sameUnapplyCall(args) => r.insert2(toPats(sameFilter(args)) ::: List(NoPattern), vs, scrut.sym) - case _ => r insert (getDummyPatterns(dum) ::: List(Pattern(pat))) + case _ => r insert (emptyPatterns(dum) ::: List(Pattern(pat))) } def mkGet(s: Symbol) = typedValDef(s, fn(ID(unapplyRes), nme.get)) def mkVar(tpe: Type) = newVarCapture(ua.pos, tpe) @@ -633,7 +576,7 @@ trait ParallelMatching extends ast.TreeDSL { protected def getSubPatterns(len: Int, x: Tree): Option[List[Pattern]] = condOpt(x) { case av @ ArrayValue(_,xs) if !isRightIgnoring(av) && xs.length == len => toPats(xs) ::: List(NoPattern) case av @ ArrayValue(_,xs) if isRightIgnoring(av) && xs.length == len+1 => removeStar(toPats(xs)) // (*) - case EmptyTree | WILD() => getDummyPatterns(len + 1) + case EmptyTree | WILD() => emptyPatterns(len + 1) } protected def makeSuccRep(vs: List[Symbol], tail: Symbol, nrows: List[Row]) = @@ -663,7 +606,7 @@ trait ParallelMatching extends ast.TreeDSL { def getTransition(): Branch[TransitionContext] = { assert(scrut.tpe <:< head.tpe, "fatal: %s is not <:< %s".format(scrut, head.tpe)) - val av @ ArrayValue(_, elems) = head.tree + val av @ ArrayValue(_, elems) = head.boundTree val ys = if (isRightIgnoring(av)) elems.init else elems val vs = ys map (y => newVar(unbind(y).pos, scrut.elemType)) def scrutCopy = scrut.id.duplicate @@ -716,9 +659,9 @@ trait ParallelMatching extends ast.TreeDSL { case av @ ArrayValue(_,xs) if ( isRightIgnoring(av) && xs.length-1 == minlen) => // Seq(p1,...,pN,_*) removeStar(toPats(xs)) ::: List(NoPattern) case av @ ArrayValue(_,xs) if ( isRightIgnoring(av) && xs.length-1 < minlen) => // Seq(p1..,pJ,_*) J < N - getDummyPatterns(minlen + 1) ::: List(Pattern(x)) + emptyPatterns(minlen + 1) ::: List(Pattern(x)) case EmptyTree | WILD() => - getDummyPatterns(minlen + 1 + 1) + emptyPatterns(minlen + 1 + 1) } override protected def makeSuccRep(vs: List[Symbol], tail: Symbol, nrows: List[Row]) = @@ -745,7 +688,7 @@ trait ParallelMatching extends ast.TreeDSL { val label = owner.newLabel(scrut.pos, newName(scrut.pos, "failCont%")) // warning, untyped val succ = List( rest.rows.head.insert2(List(NoPattern), head.boundVariables, scrut.sym), - Row(getDummyPatterns(1 + rest.tvars.length), NoBinding, NoGuard, shortCut(label)) + Row(emptyPatterns(1 + rest.tvars.length), NoBinding, NoGuard, shortCut(label)) ) // todo: optimize if no guard, and no further tests @@ -766,13 +709,11 @@ trait ParallelMatching extends ast.TreeDSL { } } - case class PatPair(moreSpecific: Tree, moreGeneral: Tree, index: Int) - /** mixture rule for type tests **/ class MixTypes(val pats: Patterns, val rest: Rep) extends RuleApplication { private def subpatterns(p: Pattern): List[Pattern] = - p.stripped match { + p.tree match { case app @ Apply(fn, ps) if isCaseClass(app.tpe) && fn.isType => if (pats.isCaseHead) toPats(ps) else pats.dummyPatterns case Apply(fn, xs) if !xs.isEmpty || fn.isType => abort("strange Apply") case _ => pats.dummyPatterns @@ -790,7 +731,7 @@ trait ParallelMatching extends ast.TreeDSL { lazy val isDef = isDefaultPattern(pat) lazy val dummy = (j, pats.dummies) lazy val pass = (j, pat) - lazy val subs = (j, subpatterns(Pattern(pat)) map (_.tree)) + lazy val subs = (j, subpatterns(Pattern(pat)) map (_.boundTree)) lazy val cmpOld: TypeComp = spat.tpe cmp pats.headType // contains type info about pattern's type vs. head pattern import cmpOld.{ erased } @@ -803,19 +744,6 @@ trait ParallelMatching extends ast.TreeDSL { def xIsaY = s <:< p def yIsaX = p <:< s - // XXX exploring what breaks things and what doesn't - // def dummyIsOk = { - // val old = erased.yIsaX || yIsaX || isDef - // println("Old logic: %s || %s || %s == %s".format(erased.yIsaX, yIsaX, isDef, erased.yIsaX || yIsaX || isDef)) - // println("isCaseClass(spat.tpe) = %s, isCaseClass(pats.headType) = %s".format( - // isCaseClass(spat.tpe), isCaseClass(pats.headType))) - // println("spat.tpe = %s, pats.head = %s, pats.headType = %s".format( - // spat.tpe, pats.head, pats.headType)) - // - // (erased.yIsaX || yIsaX || isDef) - // // (!isCaseClass(spat.tpe) || !isCaseClass(pats.headType)) - // } - // each pattern will yield a triple of options corresponding to the three lists, // which will be flattened down to the values implicit def mkOpt[T](x: T): Option[T] = Some(x) // limits noise from Some(value) @@ -933,7 +861,7 @@ trait ParallelMatching extends ast.TreeDSL { case index => val (prefix, alts :: suffix) = newPats splitAt index // make a new row for each alternative, with it spliced into the original position - extractBindings(alts.tree) map (x => replace(prefix ::: Pattern(x) :: suffix)) + extractBindings(alts.boundTree) map (x => replace(prefix ::: Pattern(x) :: suffix)) } } override def toString() = { @@ -998,7 +926,7 @@ trait ParallelMatching extends ast.TreeDSL { case class Combo(index: Int, sym: Symbol) { // is this combination covered by the given pattern? - def isCovered(p: Pattern) = cond(p.stripped) { + def isCovered(p: Pattern) = cond(p.tree) { case _: UnApply | _: ArrayValue => true case x => isDefaultPattern(x) || (p.tpe coversSym sym) } @@ -1011,7 +939,7 @@ trait ParallelMatching extends ast.TreeDSL { import Flags._ /** Converts this to a tree - recursively acquires subreps. */ - final def toTree(): Tree = this.applyRule.tree + final def toTree(): Tree = this.applyRule.tree() private def toUse(s: Symbol) = (s hasFlag MUTABLE) && // indicates that have not yet checked exhaustivity diff --git a/src/compiler/scala/tools/nsc/matching/PatternBindings.scala b/src/compiler/scala/tools/nsc/matching/PatternBindings.scala new file mode 100644 index 0000000000..2f55ab3c2e --- /dev/null +++ b/src/compiler/scala/tools/nsc/matching/PatternBindings.scala @@ -0,0 +1,94 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2009 LAMP/EPFL + * Author: Paul Phillips + */ + +package scala.tools.nsc +package matching + +import transform.ExplicitOuter + +trait PatternBindings extends ast.TreeDSL +{ + self: ExplicitOuter with ParallelMatching => + + import global.{ typer => _, _ } + import definitions.{ EqualsPatternClass } + import CODE._ + + // If the given pattern contains alternatives, return it as a list of patterns. + // Makes typed copies of any bindings found so all alternatives point to final state. + def extractBindings(p: Tree, prevBindings: Tree => Tree = identity[Tree] _): List[Tree] = { + def newPrev(b: Bind) = (x: Tree) => treeCopy.Bind(b, b.name, x) setType x.tpe + + p match { + case b @ Bind(_, body) => extractBindings(body, newPrev(b)) + case Alternative(ps) => ps map prevBindings + } + } + + def makeBind(vs: List[Symbol], pat: Tree): Tree = vs match { + case Nil => pat + case x :: xs => Bind(x, makeBind(xs, pat)) setType pat.tpe + } + + private def mkBind(vs: List[Symbol], tpe: Type, arg: Tree) = + makeBind(vs, Typed(arg, TypeTree(tpe)) setType tpe) + + def mkTypedBind(vs: List[Symbol], tpe: Type) = mkBind(vs, tpe, WILD(tpe)) + def mkEmptyTreeBind(vs: List[Symbol], tpe: Type) = mkBind(vs, tpe, EmptyTree) + def mkEqualsRef(xs: List[Type]) = typeRef(NoPrefix, EqualsPatternClass, xs) + + case class Binding(pvar: Symbol, tvar: Symbol) { + override def toString() = "%s: %s @ %s: %s".format(pvar.name, pvar.tpe, tvar.name, tvar.tpe) + } + + case class BindingsInfo(xs: List[BindingInfo]) { + def idents = xs map (_.ident) + def vsyms = xs map (_.vsym) + + def vdefs(implicit context: MatrixContext) = + xs map (x => context.typedValDef(x.vsym, x.ident)) + } + case class BindingInfo(vsym: Symbol, ident: Ident) + + case class Bindings(bindings: Binding*) extends Function1[Symbol, Option[Ident]] { + private def castIfNeeded(pvar: Symbol, tvar: Symbol) = + if (tvar.tpe <:< pvar.tpe) ID(tvar) + else ID(tvar) AS_ANY pvar.tpe + + // filters the given list down to those defined in these bindings + def infoFor(vs: List[Symbol]): BindingsInfo = BindingsInfo( + for (v <- vs ; substv <- apply(v)) yield + BindingInfo(v, substv) + ) + + def add(vs: Iterable[Symbol], tvar: Symbol): Bindings = { + def newBinding(v: Symbol) = { + // see bug #1843 for the consequences of not setting info. + // there is surely a better way to do this, especially since + // this looks to be the only usage of containsTp anywhere + // in the compiler, but it suffices for now. + if (tvar.info containsTp WildcardType) + tvar setInfo v.info + + Binding(v, tvar) + } + val newBindings = vs.toList map newBinding + Bindings(newBindings ++ bindings: _*) + } + + def apply(v: Symbol): Option[Ident] = + bindings find (_.pvar eq v) map (x => Ident(x.tvar) setType v.tpe) + + override def toString() = + if (bindings.isEmpty) "" else bindings.mkString(" Bound(", ", ", ")") + + /** The corresponding list of value definitions. */ + final def targetParams(implicit typer: analyzer.Typer): List[ValDef] = + for (Binding(v, t) <- bindings.toList) yield + VAL(v) === (typer typed castIfNeeded(v, t)) + } + + val NoBinding: Bindings = Bindings() +} \ No newline at end of file diff --git a/src/compiler/scala/tools/nsc/matching/PatternNodes.scala b/src/compiler/scala/tools/nsc/matching/PatternNodes.scala index 50bf8db756..7b866125a1 100644 --- a/src/compiler/scala/tools/nsc/matching/PatternNodes.scala +++ b/src/compiler/scala/tools/nsc/matching/PatternNodes.scala @@ -124,21 +124,6 @@ trait PatternNodes extends ast.TreeDSL } } - final def getDummies(i: Int): List[Tree] = List.fill(i)(EmptyTree) - final def getDummyPatterns(i: Int): List[Pattern] = List.fill(i)(NoPattern) - - def makeBind(vs: List[Symbol], pat: Tree): Tree = vs match { - case Nil => pat - case x :: xs => Bind(x, makeBind(xs, pat)) setType pat.tpe - } - - private def mkBind(vs: List[Symbol], tpe: Type, arg: Tree) = - makeBind(vs, Typed(arg, TypeTree(tpe)) setType tpe) - - def mkTypedBind(vs: List[Symbol], tpe: Type) = mkBind(vs, tpe, WILD(tpe)) - def mkEmptyTreeBind(vs: List[Symbol], tpe: Type) = mkBind(vs, tpe, EmptyTree) - def mkEqualsRef(xs: List[Type]) = typeRef(NoPrefix, EqualsPatternClass, xs) - /** For folding a list into a well-typed x :: y :: etc :: tree. */ private def listFolder(tpe: Type) = { val MethodType(_, TypeRef(pre, sym, _)) = ConsClass.primaryConstructor.tpe @@ -239,60 +224,4 @@ trait PatternNodes extends ast.TreeDSL } vars(x) reverse } - - /** pvar: the symbol of the pattern variable - * tvar: the temporary variable that holds the actual value - */ - case class Binding(pvar: Symbol, tvar: Symbol) { - override def toString() = "%s: %s @ %s: %s".format(pvar.name, pvar.tpe, tvar.name, tvar.tpe) - } - - case class BindingsInfo(xs: List[BindingInfo]) { - def idents = xs map (_.ident) - def vsyms = xs map (_.vsym) - - def vdefs(implicit context: MatchMatrixContext) = - xs map (x => context.typedValDef(x.vsym, x.ident)) - } - case class BindingInfo(vsym: Symbol, ident: Ident) - - case class Bindings(bindings: Binding*) extends Function1[Symbol, Option[Ident]] { - private def castIfNeeded(pvar: Symbol, tvar: Symbol) = - if (tvar.tpe <:< pvar.tpe) ID(tvar) - else ID(tvar) AS_ANY pvar.tpe - - // filters the given list down to those defined in these bindings - def infoFor(vs: List[Symbol]): BindingsInfo = BindingsInfo( - for (v <- vs ; substv <- apply(v)) yield - BindingInfo(v, substv) - ) - - def add(vs: Iterable[Symbol], tvar: Symbol): Bindings = { - def newBinding(v: Symbol) = { - // see bug #1843 for the consequences of not setting info. - // there is surely a better way to do this, especially since - // this looks to be the only usage of containsTp anywhere - // in the compiler, but it suffices for now. - if (tvar.info containsTp WildcardType) - tvar setInfo v.info - - Binding(v, tvar) - } - val newBindings = vs.toList map newBinding - Bindings(newBindings ++ bindings: _*) - } - - def apply(v: Symbol): Option[Ident] = - bindings find (_.pvar eq v) map (x => Ident(x.tvar) setType v.tpe) - - override def toString() = - if (bindings.isEmpty) "" else bindings.mkString(" Bound(", ", ", ")") - - /** The corresponding list of value definitions. */ - final def targetParams(implicit typer: Typer): List[ValDef] = - for (Binding(v, t) <- bindings.toList) yield - VAL(v) === (typer typed castIfNeeded(v, t)) - } - - val NoBinding: Bindings = Bindings() } diff --git a/src/compiler/scala/tools/nsc/matching/PatternOptimizer.scala b/src/compiler/scala/tools/nsc/matching/PatternOptimizer.scala new file mode 100644 index 0000000000..aa51ea4bd3 --- /dev/null +++ b/src/compiler/scala/tools/nsc/matching/PatternOptimizer.scala @@ -0,0 +1,137 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2009 LAMP/EPFL + * Author: Paul Phillips + */ + +package scala.tools.nsc +package matching + +import transform.ExplicitOuter + +trait PatternOptimizer extends ast.TreeDSL +{ + self: ExplicitOuter with ParallelMatching => + + import global.{ typer => _, _ } + import symtab.Flags + import CODE._ + + private[matching] trait Squeezer { + self: MatrixContext => + + def squeezedBlock(vds: List[Tree], exp: Tree): Tree = + if (settings_squeeze) Block(Nil, squeezedBlock1(vds, exp)) + else Block(vds, exp) + + private def squeezedBlock1(vds: List[Tree], exp: Tree): Tree = { + class RefTraverser(sym: Symbol) extends Traverser { + var nref, nsafeRef = 0 + override def traverse(tree: Tree) = tree match { + case t: Ident if t.symbol eq sym => + nref += 1 + if (sym.owner == currentOwner) // oldOwner should match currentOwner + nsafeRef += 1 + + case LabelDef(_, args, rhs) => + (args dropWhile(_.symbol ne sym)) match { + case Nil => + case _ => nref += 2 // cannot substitute this one + } + traverse(rhs) + case t if nref > 1 => // abort, no story to tell + case t => + super.traverse(t) + } + } + + class Subst(sym: Symbol, rhs: Tree) extends Transformer { + var stop = false + override def transform(tree: Tree) = tree match { + case t: Ident if t.symbol == sym => + stop = true + rhs + case _ => if (stop) tree else super.transform(tree) + } + } + + lazy val squeezedTail = squeezedBlock(vds.tail, exp) + def default = squeezedTail match { + case Block(vds2, exp2) => Block(vds.head :: vds2, exp2) + case exp2 => Block(vds.head :: Nil, exp2) + } + + if (vds.isEmpty) exp + else vds.head match { + case vd: ValDef => + val sym = vd.symbol + val rt = new RefTraverser(sym) + rt.atOwner (owner) (rt traverse squeezedTail) + + rt.nref match { + case 0 => squeezedTail + case 1 if rt.nsafeRef == 1 => new Subst(sym, vd.rhs) transform squeezedTail + case _ => default + } + case _ => + default + } + } + } + + private[matching] trait MatchMatrixOptimizer { + self: MatchMatrix => + + import self.context._ + + final def optimize(tree: Tree): Tree = { + // Extractors which can spot pure true/false expressions + // even through the haze of braces + abstract class SeeThroughBlocks[T] { + protected def unapplyImpl(x: Tree): T + def unapply(x: Tree): T = x match { + case Block(Nil, expr) => unapply(expr) + case _ => unapplyImpl(x) + } + } + object IsTrue extends SeeThroughBlocks[Boolean] { + protected def unapplyImpl(x: Tree): Boolean = x equalsStructure TRUE + } + object IsFalse extends SeeThroughBlocks[Boolean] { + protected def unapplyImpl(x: Tree): Boolean = x equalsStructure FALSE + } + object lxtt extends Transformer { + override def transform(tree: Tree): Tree = tree match { + case blck @ Block(vdefs, ld @ LabelDef(name, params, body)) => + def shouldInline(t: FinalState) = t.isReachedOnce && (t.label eq ld.symbol) + + if (targets exists shouldInline) squeezedBlock(vdefs, body) + else blck + + case t => + super.transform(t match { + // note - it is too early for any other true/false related optimizations + case If(cond, IsTrue(), IsFalse()) => cond + + case If(cond1, If(cond2, thenp, elsep1), elsep2) if (elsep1 equalsStructure elsep2) => + IF (cond1 AND cond2) THEN thenp ELSE elsep1 + case If(cond1, If(cond2, thenp, Apply(jmp, Nil)), ld: LabelDef) if jmp.symbol eq ld.symbol => + IF (cond1 AND cond2) THEN thenp ELSE ld + case t => t + }) + } + } + object resetTraverser extends Traverser { + import Flags._ + def reset(vd: ValDef) = + if (vd.symbol hasFlag SYNTHETIC) vd.symbol resetFlag (TRANS_FLAG|MUTABLE) + + override def traverse(x: Tree): Unit = x match { + case vd: ValDef => reset(vd) + case _ => super.traverse(x) + } + } + + returning[Tree](resetTraverser traverse _)(lxtt transform tree) + } + } +} \ No newline at end of file diff --git a/src/compiler/scala/tools/nsc/matching/Patterns.scala b/src/compiler/scala/tools/nsc/matching/Patterns.scala index 49419b9e52..f6f84bd380 100644 --- a/src/compiler/scala/tools/nsc/matching/Patterns.scala +++ b/src/compiler/scala/tools/nsc/matching/Patterns.scala @@ -6,12 +6,25 @@ package scala.tools.nsc package matching -import util.Position -import collection._ -import mutable.BitSet -import immutable.IntMap -import MatchUtil._ -import annotation.elidable +/** + * Simple pattern types: + * + * 1 Variable x + * 3 Literal 56 + * + * Types which must be decomposed into conditionals and simple types: + * + * 2 Typed x: Int + * 4 Stable Identifier Bob or `x` + * 5 Constructor Symbol("abc") + * 6 Tuple (5, 5) + * 7 Extractor List(1, 2) + * 8 Sequence List(1, 2, _*) + * 9 Infix 5 :: xs + * 10 Alternative "foo" | "bar" + * 11 XML -- + * 12 Regular Expression -- + */ trait Patterns extends ast.TreeDSL { self: transform.ExplicitOuter => @@ -19,41 +32,154 @@ trait Patterns extends ast.TreeDSL { import global.{ typer => _, _ } import definitions._ import CODE._ + import treeInfo.{ unbind, isVarPattern } + + // Fresh patterns + final def emptyPatterns(i: Int): List[Pattern] = List.fill(i)(NoPattern) + + // A fresh, empty pattern + def NoPattern = WildcardPattern() + + // 8.1.1 + case class VariablePattern(tree: Ident) extends Pattern { + override def irrefutableFor(tpe: Type) = true + } + + // 8.1.1 (b) + case class WildcardPattern() extends Pattern { + val tree = EmptyTree + override def irrefutableFor(tpe: Type) = true + } - val NoPattern = Pattern(EmptyTree) + // 8.1.2 + case class TypedPattern(tree: Typed) extends Pattern { + private val Typed(expr, tpt) = tree - class ConstructorPattern(override val tree: Apply) extends Pattern(tree) { - private val Apply(fn, args) = tree + override def irrefutableFor(tpe: Type) = tpe <:< tree.tpe + } + + // 8.1.3 + case class LiteralPattern(tree: Literal) extends Pattern { } - def isCaseClass = fn.isType - def isCaseObject = args == Nil && !fn.isType - def isFunction = !isCaseObject && !isCaseObject + // 8.1.4 + case class StableIdPattern(tree: Ident) extends Pattern { } + + // 8.1.5 + case class ConstructorPattern(tree: Apply) extends ApplyPattern { + // XXX todo + // override def irrefutableFor(tpe: Type) = false } - class LiteralPattern(override val tree: Literal) extends Pattern(tree) - class IdentPattern(override val tree: Ident) extends Pattern(tree) - class ObjectPattern(override val tree: Apply) extends Pattern(tree) + // 8.1.6 + case class TuplePattern(tree: Apply) extends ApplyPattern { + // XXX todo + // override def irrefutableFor(tpe: Type) = false + } + + // 8.1.7 + case class ExtractorPattern(tree: UnApply) extends Pattern { } + + // 8.1.8 + case class SequencePattern(tree: ArrayValue) extends Pattern { } + + // 8.1.8 (b) + case class SequenceStarPattern(tree: ArrayValue) extends Pattern { } + + // 8.1.9 + // InfixPattern ... subsumed by Constructor/Extractor Patterns + + // 8.1.10 + case class AlternativePattern(tree: Alternative, subpatterns: Seq[Pattern]) extends Pattern { } + + // 8.1.11 + // XMLPattern ... for now, subsumed by SequencePattern, but if we want + // to make it work right, it probably needs special handling. + + + // XXX - temporary pattern until we have integrated every tree type. + case class MiscPattern(tree: Tree) extends Pattern { + // println("Resorted to MiscPattern: %s/%s".format(tree, tree.getClass)) + } - class TypedPattern(override val tree: Typed) extends Pattern(tree) - class UnapplyPattern(override val tree: UnApply) extends Pattern(tree) - class SeqPattern(override val tree: UnApply) extends Pattern(tree) object Pattern { + def isDefaultPattern(t: Tree) = cond(unbind(t)) { case EmptyTree | WILD() => true } + def isStar(t: Tree) = cond(unbind(t)) { case Star(q) => isDefaultPattern(q) } + def isRightIgnoring(t: Tree) = cond(unbind(t)) { case ArrayValue(_, xs) if !xs.isEmpty => isStar(xs.last) } + def apply(tree: Tree): Pattern = tree match { - case x: Apply => new ConstructorPattern(x) - case _ => new MiscPattern(tree) + case x: Bind => apply(unbind(tree)) withBoundTree x + case EmptyTree | WILD() => WildcardPattern() + case x @ Alternative(ps) => AlternativePattern(x, ps map apply) + case x: Apply => ApplyPattern(x) + case x: Typed => TypedPattern(x) + case x: Literal => LiteralPattern(x) + case x: UnApply => ExtractorPattern(x) + case x: Ident => if (isVarPattern(x)) VariablePattern(x) else StableIdPattern(x) + case x: ArrayValue => if (isRightIgnoring(x)) SequenceStarPattern(x) else SequencePattern(x) + case x: Select => MiscPattern(x) // XXX + case x: Star => MiscPattern(x) // XXX + case _ => abort("Unknown Tree reached pattern matcher: %s/%s".format(tree, tree.getClass)) } - // def apply(x: Tree, preGuard: Tree): Pattern = new Pattern(x, preGuard) def unapply(other: Pattern): Option[Tree] = Some(other.tree) } - class MiscPattern(tree: Tree) extends Pattern(tree) { } + // right now a tree like x @ Apply(fn, Nil) where !fn.isType + // is handled by creating a singleton type: + // + // val stype = Types.singleType(x.tpe.prefix, x.symbol) + // + // and then passing that as a type argument to EqualsPatternClass: + // + // val tpe = typeRef(NoPrefix, EqualsPatternClass, List(stype)) + // + // then creating a Typed pattern and rebinding. + // + // val newpat = Typed(EmptyTree, TypeTree(tpe)) setType tpe) + // + object ApplyPattern { + def apply(x: Apply): Pattern = { + val Apply(fn, args) = x + + if (fn.isType) { + if (isTupleType(fn.tpe)) TuplePattern(x) + else ConstructorPattern(x) + } + else if (args.isEmpty) fn match { + case _ => ConstructorPattern(x) // XXX + // case x: Ident => StableIdPattern(x) + // case x => MiscPattern(x) + } + else abort("Strange apply: %s/%s".format(x)) + } + } + + sealed abstract class ApplyPattern extends Pattern { + protected lazy val Apply(fn, args) = tree - sealed abstract class Pattern(val tree: Tree, val preGuard: Tree) { - // type T <: Tree - // val tree: T + def isConstructorPattern = fn.isType + } - def this(tree: Tree) = this(tree, null) + sealed abstract class Pattern { + val tree: Tree + + // 8.1.13 + // A pattern p is irrefutable for type T if any of the following applies: + // 1) p is a variable pattern + // 2) p is a typed pattern x: T', and T <: T' + // 3) p is a constructor pattern C(p1,...,pn), the type T is an instance of class C, + // the primary constructor of type T has argument types T1,...,Tn and and each + // pi is irrefutable for Ti. + def irrefutableFor(tpe: Type) = false + + // XXX only a var for short-term experimentation. + private var _boundTree: Bind = null + def boundTree = if (_boundTree == null) tree else _boundTree + def withBoundTree(x: Bind): this.type = { + _boundTree = x + this + } + lazy val boundVariables = strip(boundTree) def sym = tree.symbol def tpe = tree.tpe @@ -66,43 +192,37 @@ trait Patterns extends ast.TreeDSL { tree setType tpe this } - lazy val stripped = strip(tree)._1 - lazy val boundVariables = strip(tree)._2 - lazy val unbound: Pattern = copy(stripped) def mkSingleton = tpe match { case st: SingleType => st case _ => singleType(prefix, sym) } - final def isBind = cond(tree) { case x: Bind => true } - final def isDefault = cond(stripped) { case EmptyTree | WILD() => true } - final def isStar = cond(stripped) { case Star(q) => Pattern(q).isDefault } - final def isAlternative = cond(stripped) { case Alternative(_) => true } - final def isRightIgnoring = cond(stripped) { case ArrayValue(_, xs) if !xs.isEmpty => Pattern(xs.last).isStar } + final def isDefault = cond(tree) { case EmptyTree | WILD() => true } + final def isStar = cond(tree) { case Star(q) => Pattern(q).isDefault } + final def isAlternative = cond(tree) { case Alternative(_) => true } + final def isRightIgnoring = cond(tree) { case ArrayValue(_, xs) if !xs.isEmpty => Pattern(xs.last).isStar } /** returns true if pattern tests an object */ final def isObjectTest(head: Type) = isSymValid && prefix.isStable && (head =:= mkSingleton) /** Helpers **/ - private def strip(t: Tree, syms: List[Symbol] = Nil): (Tree, List[Symbol]) = t match { - case b @ Bind(_, pat) => strip(pat, b.symbol :: syms) - case _ => (t, syms) + private def strip(t: Tree): List[Symbol] = t match { + case b @ Bind(_, pat) => b.symbol :: strip(pat) + case _ => Nil } /** Standard methods **/ - def copy( - tree: Tree = this.tree, - preGuard: Tree = this.preGuard - ): Pattern = Pattern(tree) // XXX - // Pattern(tree, preGuard) + def copy(tree: Tree = this.tree): Pattern = + if (_boundTree == null) Pattern(tree) + else Pattern(tree) withBoundTree _boundTree - override def toString() = "Pattern(%s)".format(tree) + override def toString() = "Pattern(%s, %s)".format(tree, boundVariables) override def equals(other: Any) = other match { - case Pattern(t) => this.tree == t + case x: Pattern => this.boundTree == x.boundTree case _ => super.equals(other) } - override def hashCode() = tree.hashCode() + override def hashCode() = boundTree.hashCode() } } \ No newline at end of file diff --git a/src/compiler/scala/tools/nsc/matching/TransMatcher.scala b/src/compiler/scala/tools/nsc/matching/TransMatcher.scala index 02124c2041..30e0a62efd 100644 --- a/src/compiler/scala/tools/nsc/matching/TransMatcher.scala +++ b/src/compiler/scala/tools/nsc/matching/TransMatcher.scala @@ -3,27 +3,6 @@ * Copyright 2007 Google Inc. All Rights Reserved. * Author: bqe@google.com (Burak Emir) */ -// $Id$ - -/** - * Simple pattern types: - * - * 1 Variable x - * 3 Literal 56 - * - * Types which must be decomposed into conditionals and simple types: - * - * 2 Typed x: Int - * 4 Stable Identifier Bob or `x` - * 5 Constructor Symbol("abc") - * 6 Tuple (5, 5) - * 7 Extractor List(1, 2) - * 8 Sequence List(1, 2, _*) - * 9 Infix 5 :: xs - * 10 Alternative "foo" | "bar" - * 11 XML -- - * 12 Regular Expression -- - */ package scala.tools.nsc package matching @@ -31,6 +10,7 @@ package matching import util.Position import ast.{ TreePrinters, Trees } import symtab.SymbolTable +import transform.ExplicitOuter import java.io.{ StringWriter, PrintWriter } import scala.util.NameTransformer.decode @@ -39,10 +19,10 @@ import scala.util.NameTransformer.decode * @author Burak Emir */ trait TransMatcher extends ast.TreeDSL with CompactTreePrinter { - self: transform.ExplicitOuter with PatternNodes with ParallelMatching => + self: ExplicitOuter with ParallelMatching with PatternOptimizer => import global.{ typer => _, _ } - import analyzer.Typer; + import analyzer.Typer import definitions._ import symtab.Flags import CODE._ @@ -53,110 +33,13 @@ trait TransMatcher extends ast.TreeDSL with CompactTreePrinter { final val settings_squeeze = settings.Xsqueeze.value == "on" - /** Contains data structures which the match algorithm implementation - * requires but which aren't essential to the algorithm itself. - */ - case class MatchMatrixContext( - handleOuter: TreeFunction1, // Tree => Tree function - typer: Typer, // a local typer - owner: Symbol, // the current owner - matchResultType: Type) // the expected result type of the whole match - { - def newVar( - pos: Position, - tpe: Type, - flags: List[Long] = Nil, - name: Name = null): Symbol = - { - val n: Name = if (name == null) newName(pos, "temp") else name - // careful: pos has special meaning - owner.newVariable(pos, n) setInfo tpe setFlag (0L /: flags)(_|_) - } - - def typedValDef(x: Symbol, rhs: Tree) = { - val finalRhs = x.tpe match { - case WildcardType => - rhs setType null - x setInfo (typer typed rhs).tpe - rhs - case _ => - typer.typed(rhs, x.tpe) - } - typer typed (VAL(x) === finalRhs) - } - - def squeezedBlock(vds: List[Tree], exp: Tree): Tree = - if (settings_squeeze) Block(Nil, squeezedBlock1(vds, exp)) - else Block(vds, exp) - - private def squeezedBlock1(vds: List[Tree], exp: Tree): Tree = { - class RefTraverser(sym: Symbol) extends Traverser { - var nref, nsafeRef = 0 - override def traverse(tree: Tree) = tree match { - case t: Ident if t.symbol eq sym => - nref += 1 - if (sym.owner == currentOwner) // oldOwner should match currentOwner - nsafeRef += 1 - - case LabelDef(_, args, rhs) => - (args dropWhile(_.symbol ne sym)) match { - case Nil => - case _ => nref += 2 // cannot substitute this one - } - traverse(rhs) - case t if nref > 1 => // abort, no story to tell - case t => - super.traverse(t) - } - } - - class Subst(sym: Symbol, rhs: Tree) extends Transformer { - var stop = false - override def transform(tree: Tree) = tree match { - case t: Ident if t.symbol == sym => - stop = true - rhs - case _ => if (stop) tree else super.transform(tree) - } - } - - lazy val squeezedTail = squeezedBlock(vds.tail, exp) - def default = squeezedTail match { - case Block(vds2, exp2) => Block(vds.head :: vds2, exp2) - case exp2 => Block(vds.head :: Nil, exp2) - } - - if (vds.isEmpty) exp - else vds.head match { - case vd: ValDef => - val sym = vd.symbol - val rt = new RefTraverser(sym) - rt.atOwner (owner) (rt traverse squeezedTail) - - rt.nref match { - case 0 => squeezedTail - case 1 if rt.nsafeRef == 1 => new Subst(sym, vd.rhs) transform squeezedTail - case _ => default - } - case _ => - default - } - } - } - - case class MatchMatrixInit( - roots: List[Symbol], - cases: List[CaseDef], - default: Tree - ) - /** Handles all translation of pattern matching. */ def handlePattern( selector: Tree, // tree being matched upon (called scrutinee after this) cases: List[CaseDef], // list of cases in the match isChecked: Boolean, // whether exhaustiveness checking is enabled (disabled with @unchecked) - context: MatchMatrixContext): Tree = + context: MatrixContext): Tree = { import context._ @@ -172,15 +55,15 @@ trait TransMatcher extends ast.TreeDSL with CompactTreePrinter { (cases forall caseIsOk) // For x match { ... we start with a single root - def singleMatch(): (List[Tree], MatchMatrixInit) = { + def singleMatch(): (List[Tree], MatrixInit) = { val root: Symbol = newVar(selector.pos, selector.tpe, flags) val varDef: Tree = typedValDef(root, selector) - (List(varDef), MatchMatrixInit(List(root), cases, matchError(ID(root)))) + (List(varDef), MatrixInit(List(root), cases, matchError(ID(root)))) } // For (x, y, z) match { ... we start with multiple roots, called tpXX. - def tupleMatch(app: Apply): (List[Tree], MatchMatrixInit) = { + def tupleMatch(app: Apply): (List[Tree], MatrixInit) = { val Apply(fn, args) = app val (roots, vars) = List.unzip( for ((arg, typeArg) <- args zip selector.tpe.typeArgs) yield { @@ -188,7 +71,7 @@ trait TransMatcher extends ast.TreeDSL with CompactTreePrinter { (v, typedValDef(v, arg)) } ) - (vars, MatchMatrixInit(roots, cases, matchError(treeCopy.Apply(app, fn, roots map ID)))) + (vars, MatrixInit(roots, cases, matchError(treeCopy.Apply(app, fn, roots map ID)))) } // sets up top level variables and algorithm input @@ -204,8 +87,8 @@ trait TransMatcher extends ast.TreeDSL with CompactTreePrinter { // redundancy check matrix.targets filter (_.isNotReached) foreach (cs => cunit.error(cs.body.pos, "unreachable code")) - // cleanup performs squeezing and resets any remaining TRANS_FLAGs - matrix cleanup dfatree + // optimize performs squeezing and resets any remaining TRANS_FLAGs + matrix optimize dfatree } private def toCompactString(t: Tree): String = { diff --git a/src/compiler/scala/tools/nsc/transform/ExplicitOuter.scala b/src/compiler/scala/tools/nsc/transform/ExplicitOuter.scala index e7ecee9ac1..905d41e9f9 100644 --- a/src/compiler/scala/tools/nsc/transform/ExplicitOuter.scala +++ b/src/compiler/scala/tools/nsc/transform/ExplicitOuter.scala @@ -387,7 +387,7 @@ abstract class ExplicitOuter extends InfoTransform } val t = atPos(tree.pos) { - val context = MatchMatrixContext(transform, localTyper, currentOwner, tree.tpe) + val context = MatrixContext(transform, localTyper, currentOwner, tree.tpe) val t_untyped = handlePattern(nselector, ncases, checkExhaustive, context) /* if @switch annotation is present, verify the resulting tree is a Match */ -- cgit v1.2.3