From c79f8876aa04370fc99692f73825392ea48d02e2 Mon Sep 17 00:00:00 2001 From: Paul Phillips Date: Tue, 26 Apr 2011 19:00:24 +0000 Subject: Some solid progress on the pattern matcher, no ... Some solid progress on the pattern matcher, no review. --- .../scala/tools/nsc/backend/icode/GenICode.scala | 12 +- .../scala/tools/nsc/matching/MatchSupport.scala | 5 +- src/compiler/scala/tools/nsc/matching/Matrix.scala | 6 + .../tools/nsc/matching/ParallelMatching.scala | 178 +++++++++++++-------- .../scala/tools/nsc/matching/PatternBindings.scala | 17 +- 5 files changed, 136 insertions(+), 82 deletions(-) diff --git a/src/compiler/scala/tools/nsc/backend/icode/GenICode.scala b/src/compiler/scala/tools/nsc/backend/icode/GenICode.scala index 540c5f50dd..13340f7f08 100644 --- a/src/compiler/scala/tools/nsc/backend/icode/GenICode.scala +++ b/src/compiler/scala/tools/nsc/backend/icode/GenICode.scala @@ -1142,10 +1142,14 @@ abstract class GenICode extends SubComponent { log("Dropped an " + from); case _ => - if (settings.debug.value) - assert(from != UNIT, "Can't convert from UNIT to " + to + " at: " + pos) - assert(!from.isReferenceType && !to.isReferenceType, "type error: can't convert from " + from + " to " + to +" in unit "+this.unit) - ctx.bb.emit(CALL_PRIMITIVE(Conversion(from, to)), pos); + if (settings.debug.value) { + assert(from != UNIT, + "Can't convert from UNIT to " + to + " at: " + pos) + } + assert(!from.isReferenceType && !to.isReferenceType, + "type error: can't convert from " + from + " to " + to +" in unit " + unit.source) + + ctx.bb.emit(CALL_PRIMITIVE(Conversion(from, to)), pos) } } else if (from == NothingReference) { ctx.bb.emit(THROW(ThrowableClass)) diff --git a/src/compiler/scala/tools/nsc/matching/MatchSupport.scala b/src/compiler/scala/tools/nsc/matching/MatchSupport.scala index ca2e252d35..c5673fced7 100644 --- a/src/compiler/scala/tools/nsc/matching/MatchSupport.scala +++ b/src/compiler/scala/tools/nsc/matching/MatchSupport.scala @@ -31,10 +31,7 @@ trait MatchSupport extends ast.TreeDSL { self: ParallelMatching => import definitions._ implicit def enrichType(x: Type): RichType = new RichType(x) - // A subtype test which creates fresh existentials for type - // parameters on the right hand side. - private[matching] def matches(arg1: Type, arg2: Type) = - decodedEqualsType(arg1) matchesPattern decodedEqualsType(arg2) + val subrangeTypes = Set(ByteClass, ShortClass, CharClass, IntClass) class RichType(undecodedTpe: Type) { def tpe = decodedEqualsType(undecodedTpe) diff --git a/src/compiler/scala/tools/nsc/matching/Matrix.scala b/src/compiler/scala/tools/nsc/matching/Matrix.scala index 8dc960894c..b9b9b51384 100644 --- a/src/compiler/scala/tools/nsc/matching/Matrix.scala +++ b/src/compiler/scala/tools/nsc/matching/Matrix.scala @@ -236,6 +236,12 @@ trait Matrix extends MatrixAdditions { tracing("create")(new PatternVar(lhs, rhs, checked)) } + def createLazy(tpe: Type, f: Symbol => Tree, checked: Boolean) = { + val lhs = newVar(owner.pos, tpe, Flags.LAZY :: flags(checked)) + val rhs = f(lhs) + + tracing("createLazy")(new PatternVar(lhs, rhs, checked)) + } private def newVar( pos: Position, diff --git a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala index f41c37080a..a9c9d959a5 100644 --- a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala +++ b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala @@ -50,14 +50,10 @@ trait ParallelMatching extends ast.TreeDSL shortCuts(key) = theLabel -key } + def createLabelDef(prefix: String, params: List[Symbol] = Nil, tpe: Type = matchResultType) = { + val labelSym = owner.newLabel(owner.pos, cunit.freshTermName(prefix)) setInfo MethodType(params, tpe) - /** first time bx is requested, a LabelDef is returned. next time, a jump. - * the function takes care of binding - */ - final def requestBody(bx: Int, subst: Bindings): Tree = { - // shortcut - if (bx < 0) Apply(ID(shortCuts(-bx)), Nil) - else targets(bx) labelBody subst + (body: Tree) => LabelDef(labelSym, params, body setType tpe) } /** This is the recursively focal point for translating the current @@ -108,7 +104,7 @@ trait ParallelMatching extends ast.TreeDSL // tests def isDefined = sym ne NoSymbol - def isSubrangeType = Set(ByteClass, ShortClass, CharClass, IntClass) contains tpe.typeSymbol + def isSubrangeType = subrangeTypes(tpe.typeSymbol) def isCaseClass = tpe.typeSymbol.isCase // sequences @@ -205,12 +201,6 @@ trait ParallelMatching extends ast.TreeDSL override def toString() = "%s match {%s}".format(scrut, indentAll(ps)) } // PatternMatch - /** picks which rewrite rule to apply - * @precondition: column does not contain alternatives - */ - def MixtureRule(scrut: Scrutinee, column: List[Pattern], rest: Rep): RuleApplication = - PatternMatch(scrut, column) mkRule rest - /***** Rule Applications *****/ sealed abstract class RuleApplication { @@ -250,14 +240,21 @@ trait ParallelMatching extends ast.TreeDSL def pmatch: PatternMatch = impossible def rest: Rep = guardedRest - lazy val cond = if (guard.isEmpty) TRUE else guard - lazy val success = requestBody(bx, subst) - lazy val failure = guardedRest.toTree - lazy val pvgroup = PatternVarGroup.fromBindings(subst.get()) + private lazy val (valDefs, successTree) = targets(bx) applyBindings subst.toMap + lazy val cond = guard + lazy val success = successTree + lazy val failure = guardedRest.toTree final def tree(): Tree = - if (guard.isEmpty) success - else squeezedBlock(pvgroup.valDefs, codegen) + if (bx < 0) REF(shortCuts(-bx)) + else squeezedBlock( + valDefs, + if (cond.isEmpty) success else codegen + ) + + override def toString = "(case %d) {\n Bindings: %s\n\n if (%s) { %s }\n else { %s }\n}".format( + bx, subst, guard, success, guardedRest + ) } class MixLiteralInts(val pmatch: PatternSwitch, val rest: Rep) extends RuleApplication { @@ -506,29 +503,20 @@ trait ParallelMatching extends ast.TreeDSL final def tree(): Tree = codegen } - // @todo: equals test for same constant class MixEquals(val pmatch: PatternMatch, val rest: Rep) extends RuleApplication { - private lazy val labelBody = - remake((rest.rows.tail, pmatch.tail).zipped map (_ insert _)).toTree - private lazy val rhs = decodedEqualsType(head.tpe) match { case SingleType(pre, sym) => REF(pre, sym) case PseudoType(o) => o } + private lazy val labelDef = + createLabelDef("fail%")(remake((rest.rows.tail, pmatch.tail).zipped map (_ insert _)).toTree) - lazy val label = - owner.newLabel(scrut.pos, cunit.freshTermName("failCont%")) setInfo MethodType(Nil, labelBody.tpe) - - lazy val cond = - handleOuter(rhs MEMBER_== scrut.id ) - - lazy val success = remake(List( - rest.rows.head.insert2(List(NoPattern), head.boundVariables, scrut.sym), - Row(emptyPatterns(1 + rest.tvars.size), NoBinding, EmptyTree, createShortCut(label)) - )).toTree - - lazy val failure = LabelDef(label, Nil, labelBody) + lazy val cond = handleOuter(rhs MEMBER_== scrut.id) + lazy val successOne = rest.rows.head.insert2(List(NoPattern), head.boundVariables, scrut.sym) + lazy val successTwo = Row(emptyPatterns(1 + rest.tvars.size), NoBinding, EmptyTree, createShortCut(labelDef.symbol)) + lazy val success = remake(List(successOne, successTwo)).toTree + lazy val failure = labelDef final def tree() = codegen override def toString() = "MixEquals(%s == %s)".format(scrut, head) @@ -656,34 +644,82 @@ trait ParallelMatching extends ast.TreeDSL "Row(%d)(%s%s)".format(bx, pp(pats), bs) } } + abstract class State { + def bx: Int // index into the list of rows + def params: List[Symbol] // bound names to be supplied as arguments to labeldef + def body: Tree // body to execute upon match + def label: Option[LabelDef] // label definition for this state + + // Called with a bindings map when a match is achieved. + // Returns a list of variable declarations based on the labeldef parameters + // and the given substitution, and the body to execute. + protected def applyBindingsImpl(subst: Map[Symbol, Symbol]): (List[ValDef], Tree) + + final def applyBindings(subst: Map[Symbol, Symbol]): (List[ValDef], Tree) = { + _referenceCount += 1 + applyBindingsImpl(subst) + } + + private var _referenceCount = 0 + def referenceCount = _referenceCount + def unreached = referenceCount == 0 + def shouldInline(sym: Symbol) = referenceCount == 1 && label.exists(_.symbol == sym) + + protected def maybeCast(lhs: Symbol, rhs: Symbol)(tree: Tree) = { + if (rhs.tpe <:< lhs.tpe) tree + else tree AS lhs.tpe + } + + protected def newValDefinition(lhs: Symbol, rhs: Symbol) = + VAL(lhs) === maybeCast(lhs, rhs)(Ident(rhs)) + + protected def newValReference(lhs: Symbol, rhs: Symbol) = + maybeCast(lhs, rhs)(Ident(rhs)) + + protected def mapSubst[T](subst: Map[Symbol, Symbol])(f: (Symbol, Symbol) => T): List[T] = + params flatMap { lhs => + subst get lhs map (rhs => f(lhs, rhs)) orElse { + // This should not happen; the code should be structured so it is + // impossible, but that still lies ahead. + cunit.warning(lhs.pos, "No binding") + None + } + } + + protected def valDefsFor(subst: Map[Symbol, Symbol]) = + mapSubst(subst)(typer typedValDef newValDefinition(_, _)) + + protected def identsFor(subst: Map[Symbol, Symbol]) = + mapSubst(subst)(typer typed newValReference(_, _)) - case class FinalState(bx: Int, body: Tree, params: List[Symbol]) { - private var referenceCount = 0 // typer is not able to digest a body of type Nothing being assigned result type Unit - private def caseResultType = if (body.tpe.isNothing) body.tpe else matchResultType - private lazy val label: LabelDef = body match { - case Literal(_) => null - case _ => - val symbol = owner.newLabel(body.pos, "body%" + bx) setInfo MethodType(params, caseResultType) - // typer typedLabelDef - LabelDef(symbol, params, body setType caseResultType) - } - - def unreached = referenceCount == 0 - def shouldInline(sym: Symbol) = referenceCount == 1 && label != null && label.symbol == sym - - def labelBody(subst: Bindings): Tree = { - referenceCount += 1 - val pvgroup = PatternVarGroup.fromBindings(subst.get(), params) - - if (referenceCount > 1 && label != null) - ID(label.symbol) APPLY (pvgroup map (_.rhs)) - else squeezedBlock(pvgroup.valDefs, - if (label != null) label - else body.duplicate setType caseResultType - ) + protected def caseResultType = + if (body.tpe.isNothing) body.tpe else matchResultType + } + + case class LiteralState(bx: Int, params: List[Symbol], body: Tree) extends State { + def label = None + + protected def applyBindingsImpl(subst: Map[Symbol, Symbol]) = + (valDefsFor(subst), body.duplicate setType caseResultType) + } + + case class FinalState(bx: Int, params: List[Symbol], body: Tree) extends State { + traceCategory("Final State", "(%s) => %s", paramsString, body) + def label = Some(labelDef) + + private lazy val labelDef = createLabelDef("body%" + bx, params, caseResultType)(body) + + protected def applyBindingsImpl(subst: Map[Symbol, Symbol]) = { + val tree = + if (referenceCount > 1) ID(labelDef.symbol) APPLY identsFor(subst) + else labelDef + + (valDefsFor(subst), tree) } - override def toString() = pp("Final%d%s".format(bx, pp(params)) -> body) + + private def paramsString = params map (s => s.name + ": " + s.tpe) mkString ", " + override def toString() = pp("(%s) => %s".format(pp(params), body)) } case class Rep(val tvars: PatternVarGroup, val rows: List[Row]) { @@ -706,7 +742,10 @@ trait ParallelMatching extends ast.TreeDSL private val (_ncol, _nrep) = (others.head :: _column.tail, make(_tvars, _rows)) - def mix = MixtureRule(new Scrutinee(specialVar(_pv.sym, _pv.checked)), _ncol, _nrep) + def mix() = { + val newScrut = new Scrutinee(specialVar(_pv.sym, _pv.checked)) + PatternMatch(newScrut, _ncol) mkRule _nrep + } } /** Converts this to a tree - recursively acquires subreps. */ @@ -719,9 +758,8 @@ trait ParallelMatching extends ast.TreeDSL VariableRule(binding, guard, guardedRest, index) } - - /** The MixtureRule. */ - def mixture() = new Cut(defaults.size) mix + /** The MixtureRule: picks a rewrite rule to apply. */ + private def mixture() = new Cut(defaults.size) mix() /** Applying the rule will result in one of: * @@ -742,14 +780,18 @@ trait ParallelMatching extends ast.TreeDSL /** Expands the patterns recursively. */ final def expand(roots: List[PatternVar], cases: List[CaseDef]) = tracing("expand") { - for ((CaseDef(pat, guard, body), index) <- cases.zipWithIndex) yield { + for ((CaseDef(pat, guard, body), bx) <- cases.zipWithIndex) yield { val subtrees = pat match { case x if roots.length <= 1 => List(x) case Apply(_, args) => args case WILD() => emptyTrees(roots.length) } - val row = Row(toPats(subtrees), NoBinding, guard, index) - val state = FinalState(index, body, Pattern(pat).deepBoundVariables) + val params = pat filter (_.isInstanceOf[Bind]) map (_.symbol) distinct + val row = Row(toPats(subtrees), NoBinding, guard, bx) + val state = body match { + case x: Literal => LiteralState(bx, params, body) + case _ => FinalState(bx, params, body) + } row -> state } diff --git a/src/compiler/scala/tools/nsc/matching/PatternBindings.scala b/src/compiler/scala/tools/nsc/matching/PatternBindings.scala index 88983a792f..f26eec6339 100644 --- a/src/compiler/scala/tools/nsc/matching/PatternBindings.scala +++ b/src/compiler/scala/tools/nsc/matching/PatternBindings.scala @@ -20,9 +20,13 @@ trait PatternBindings extends ast.TreeDSL import Debug._ /** EqualsPattern **/ - def isEquals(tpe: Type) = cond(tpe) { case TypeRef(_, EqualsPatternClass, _) => true } - def mkEqualsRef(tpe: Type) = typeRef(NoPrefix, EqualsPatternClass, List(tpe)) - def decodedEqualsType(tpe: Type) = condOpt(tpe) { case TypeRef(_, EqualsPatternClass, List(arg)) => arg } getOrElse (tpe) + def isEquals(tpe: Type) = cond(tpe) { case TypeRef(_, EqualsPatternClass, _) => true } + def mkEqualsRef(tpe: Type) = typeRef(NoPrefix, EqualsPatternClass, List(tpe)) + def decodedEqualsType(tpe: Type) = condOpt(tpe) { case TypeRef(_, EqualsPatternClass, List(arg)) => arg } getOrElse (tpe) + + // A subtype test which creates fresh existentials for type + // parameters on the right hand side. + def matches(arg1: Type, arg2: Type) = decodedEqualsType(arg1) matchesPattern decodedEqualsType(arg2) // used as argument to `EqualsPatternClass' case class PseudoType(o: Tree) extends SimpleTypeProxy { @@ -121,7 +125,7 @@ trait PatternBindings extends ast.TreeDSL } case class Binding(pvar: Symbol, tvar: Symbol) { - override def toString() = pp(pvar -> tvar) + override def toString() = pvar.name + " -> " + tvar.name } class Bindings(private val vlist: List[Binding]) { @@ -129,6 +133,7 @@ trait PatternBindings extends ast.TreeDSL // traceCategory("Bindings", this.toString) def get() = vlist + def toMap = vlist map (x => (x.pvar, x.tvar)) toMap def add(vs: Iterable[Symbol], tvar: Symbol): Bindings = { val newBindings = vs.toList map (v => Binding(v, tvar)) @@ -136,8 +141,8 @@ trait PatternBindings extends ast.TreeDSL } override def toString() = - if (vlist.isEmpty) "No Bindings" - else "%d Bindings(%s)".format(vlist.size, pp(vlist)) + if (vlist.isEmpty) "" + else vlist.mkString(", ") } val NoBinding: Bindings = new Bindings(Nil) -- cgit v1.2.3