From 2482bddf7e64c5bdbd227819ba4c019d88f74c9b Mon Sep 17 00:00:00 2001 From: Paul Phillips Date: Sun, 11 Oct 2009 15:58:37 +0000 Subject: Created PatternVarGroup to hold a sequence of p... Created PatternVarGroup to hold a sequence of patternvars. Soon all the variable binding code will be swallowed and each action related to variables will happen at one location only. --- src/compiler/scala/tools/nsc/matching/Matrix.scala | 38 ++++++- .../tools/nsc/matching/ParallelMatching.scala | 113 +++++++++------------ 2 files changed, 86 insertions(+), 65 deletions(-) (limited to 'src') diff --git a/src/compiler/scala/tools/nsc/matching/Matrix.scala b/src/compiler/scala/tools/nsc/matching/Matrix.scala index e0d15a4c49..cd66c9c5ce 100644 --- a/src/compiler/scala/tools/nsc/matching/Matrix.scala +++ b/src/compiler/scala/tools/nsc/matching/Matrix.scala @@ -98,12 +98,48 @@ trait Matrix extends MatrixAdditions { def valDefs = roots map (_.valDef) } + implicit def pvlist2pvgroup(xs: List[PatternVar]): PatternVarGroup = + PatternVarGroup(xs) + + object PatternVarGroup { + def apply(xs: PatternVar*) = new PatternVarGroup(xs.toList) + def apply(xs: List[PatternVar]) = new PatternVarGroup(xs) + } + + val emptyPatternVarGroup = PatternVarGroup() + class PatternVarGroup(val pvs: List[PatternVar]) { + def syms = pvs map (_.sym) + def valDefs = pvs map (_.valDef) + def idents = pvs map (_.ident) + + def extractIndex(index: Int): (PatternVar, PatternVarGroup) = { + val (t, ts) = self.extractIndex(pvs, index) + (t, PatternVarGroup(ts)) + } + + def size = pvs.size + def head = pvs.head + def ::(t: PatternVar) = PatternVarGroup(t :: pvs) + def :::(ts: List[PatternVar]) = PatternVarGroup(ts ::: pvs) + def ++(other: PatternVarGroup) = PatternVarGroup(pvs ::: other.pvs) + + def apply(i: Int) = pvs(i) + def zipWithIndex = pvs.zipWithIndex + def indices = pvs.indices + def map[T](f: PatternVar => T) = pvs map f + def filter(p: PatternVar => Boolean) = PatternVarGroup(pvs filter p) + } + /** Every temporary variable allocated is put in a PatternVar. */ class PatternVar(val lhs: Symbol, val rhs: Tree, val checked: Boolean) { def sym = lhs + def valsym = valDef.symbol + + def tpe = valsym.tpe // XXX how will sym.tpe differ from sym.tpe ? + lazy val ident = ID(lhs) - lazy val valDef = typedValDef(lhs, rhs) + lazy val valDef = tracing("typedVal", typer typedValDef (VAL(lhs) === rhs)) override def toString() = "%s: %s = %s".format(lhs, lhs.info, rhs) } diff --git a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala index 8143b2e731..6ba9ad3b18 100644 --- a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala +++ b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala @@ -73,21 +73,8 @@ trait ParallelMatching extends ast.TreeDSL else target.getLabelBody(idents, patternValDefs) } - // make(tvars, rows) - // make(scrut.sym :: rest.tvars, xs).toTree - // make(rest.tvars, newRows ::: defaultRows) - // make(r.tvars, r.rows map (x => x rebind bindVars(tag, x.subst))) - // make(rest.tvars, defaultRows).toTree - // make(ntemps ::: scrut.sym :: rest.tvars, rows).toTree - // make(List(vs map (_.lhs), symList, rest.tvars).flatten, nrows.flatten) - // make(scrut.sym :: rest.tvars, frows.flatten).toTree - // make(scrut.sym :: rest.tvars, rows).toTree - // make(subtestVars ::: casted.accessorVars ::: rest.tvars, newRows) - // make(tvars, rows.tail) - // make(_tvars, _rows) - /** the injection here handles alternatives and unapply type tests */ - final def make(tvars: List[PatternVar], row1: List[Row]): Rep = { + final def make(tvars: PatternVarGroup, row1: List[Row]): Rep = { def classifyPat(opat: Pattern, j: Int): Pattern = opat simplify tvars(j).sym val rows = row1 flatMap (_ expandAlternatives classifyPat) @@ -116,12 +103,13 @@ trait ParallelMatching extends ast.TreeDSL def accessors = if (isCaseClass) sym.caseFieldAccessors else Nil def accessorTypes = accessors map (x => (tpe memberType x).resultType) - lazy val accessorPatternVars = + lazy val accessorPatternVars = PatternVarGroup( for ((accessor, tpe) <- accessors zip accessorTypes) yield createVar(tpe, _ => fn(id, accessor)) + ) private def extraValDefs = if (pv.rhs.isEmpty) Nil else List(pv.valDef) - def allValDefs = extraValDefs ::: (accessorPatternVars map (_.valDef)) + def allValDefs = extraValDefs ::: accessorPatternVars.valDefs // tests def isDefined = sym ne NoSymbol @@ -129,8 +117,12 @@ trait ParallelMatching extends ast.TreeDSL def isCaseClass = tpe.typeSymbol hasFlag Flags.CASE // sequences - def seqType = tpe.widen baseType SeqClass - def elemType = tpe typeArgs 0 + def seqType = tpe.widen baseType SeqClass + def elemType = tpe typeArgs 0 + def elemAt(i: Int) = (id DOT (tpe member nme.apply))(LIT(i)) + + def createElemVar(i: Int) = createVar(elemType, _ => elemAt(i)) + def createSeqVar(drop: Int) = createVar(seqType, _ => id DROP drop) // for propagating "unchecked" to synthetic vars def isChecked = !(sym hasFlag TRANS_FLAG) @@ -241,9 +233,17 @@ trait ParallelMatching extends ast.TreeDSL lazy val head = pmatch.head def codegen: Tree = IF (cond) THEN (success) ELSE (failure) - def mkFail(xs: List[Row]): Tree = xs match { - case Nil => failTree - case _ => make(scrut.pv :: rest.tvars, xs).toTree + def mkFail(xs: List[Row]): Tree = + if (xs.isEmpty) failTree + else remake(xs).toTree + + def remake( + rows: List[Row], + pvgroup: PatternVarGroup = emptyPatternVarGroup, + includeScrut: Boolean = true): Rep = + { + val scrutpvs = if (includeScrut) List(pmatch.scrut.pv) else Nil + make(pvgroup.pvs ::: scrutpvs ::: rest.tvars, rows) } /** translate outcome of the rule application into code (possible involving recursive application of rewriting) */ @@ -320,13 +320,13 @@ trait ParallelMatching extends ast.TreeDSL lazy val cases = for ((tag, indices) <- literalMap.toList) yield { val newRows = indices map (i => addDefaultVars(i)(rest rows i)) - val r = make(rest.tvars, newRows ::: defaultRows) - val r2 = make(r.tvars, r.rows map (x => x rebind bindVars(tag, x.subst))) + val r = remake(newRows ::: defaultRows, includeScrut = false) + val r2 = make(r.tvars, r.rows map (x => x rebind bindVars(tag, x.subst))) CASE(Literal(tag)) ==> r2.toTree } - lazy val defaultTree = make(rest.tvars, defaultRows).toTree + lazy val defaultTree = remake(defaultRows, includeScrut = false).toTree def casesWithDefault = cases ::: List(CASE(WILD(IntClass.tpe)) ==> defaultTree) // cond/success/failure only used if there is exactly one case. @@ -353,7 +353,7 @@ trait ParallelMatching extends ast.TreeDSL private lazy val zipped = pmatch pzip rest.rows - lazy val unapplyResult = + lazy val unapplyResult: PatternVar = scrut.createVar(app.tpe, lhs => reapply setType lhs.tpe) // XXX in transition. @@ -379,11 +379,9 @@ trait ParallelMatching extends ast.TreeDSL case _ => r insert (emptyPatterns(dum) ::: List(pat)) } - lazy val cond: Tree = { - val s = unapplyResult.valDef.symbol - if (s.tpe.isBoolean) ID(s) - else s IS_DEFINED - } + lazy val cond: Tree = + if (unapplyResult.tpe.isBoolean) ID(unapplyResult.valsym) + else unapplyResult.valsym IS_DEFINED lazy val failure = mkFail(zipped.tail filterNot (x => isSameUnapply(x._1)) map { case (pat, r) => r insert pat }) @@ -414,7 +412,7 @@ trait ParallelMatching extends ast.TreeDSL lazy val success = { val (squeezePVs, pvs, rows) = doSuccess - val srep = make(pvs ::: scrut.pv :: rest.tvars, rows).toTree + val srep = remake(rows, pvs).toTree squeezedBlockPVs(squeezePVs, srep) } @@ -466,14 +464,10 @@ trait ParallelMatching extends ast.TreeDSL lazy val (success, failure) = { assert(scrut.tpe <:< head.tpe, "fatal: %s is not <:< %s".format(scrut, head.tpe)) - def elemAt(i: Int) = (scrut.id DOT (scrut.tpe member nme.apply))(LIT(i)) - def elemCount = pivot.nonStarPatterns.size - val pvs = - // one per element .. pos = pat.pos - (pivot.nonStarPatterns.zipWithIndex map { case (pat, i) => scrut.createVar(scrut.elemType, _ => elemAt(i)) }) ::: - // and one for the rest of the sequence - List(scrut.createVar(scrut.seqType, _ => scrut.id DROP elemCount)) + // one pattern var per sequence element up to elemCount, and one more for the rest of the sequence + val elemCount = pivot.nonStarPatterns.size + val pvs = ((0 until elemCount).toList map (scrut createElemVar _)) ::: List(scrut createSeqVar elemCount) val (nrows, frows): (List[Option[Row]], List[Option[Row]]) = List.unzip( for ((c, rows) <- pmatch pzip rest.rows) yield getSubPatterns(c) match { @@ -482,12 +476,11 @@ trait ParallelMatching extends ast.TreeDSL } ) - val symList = if (pivot.hasStar) List(scrut.pv) else Nil - val succ = make(List(pvs, symList, rest.tvars).flatten, nrows.flatten) + val succ = remake(nrows.flatten, pvs, includeScrut = pivot.hasStar) ( squeezedBlockPVs(pvs, succ.toTree), - make(scrut.pv :: rest.tvars, frows.flatten).toTree + remake(frows.flatten).toTree ) } @@ -496,11 +489,8 @@ trait ParallelMatching extends ast.TreeDSL // @todo: equals test for same constant class MixEquals(val pmatch: PatternMatch, val rest: Rep) extends RuleApplication { - private def mkNewRep(rows: List[Row]) = - make(scrut.pv :: rest.tvars, rows).toTree - private lazy val labelBody = - mkNewRep(List.map2(rest.rows.tail, pmatch.tail)(_ insert _)) + remake(List.map2(rest.rows.tail, pmatch.tail)(_ insert _)).toTree private lazy val rhs = decodedEqualsType(head.tpe) match { @@ -514,11 +504,10 @@ trait ParallelMatching extends ast.TreeDSL lazy val cond = handleOuter(scrut.id MEMBER_== rhs) - lazy val success = - mkNewRep(List( - rest.rows.head.insert2(List(NoPattern), head.boundVariables, scrut.sym), - Row(emptyPatterns(1 + rest.tvars.length), NoBinding, NoGuard, shortCut(label)) - )) + lazy val success = remake(List( + rest.rows.head.insert2(List(NoPattern), head.boundVariables, scrut.sym), + Row(emptyPatterns(1 + rest.tvars.size), NoBinding, NoGuard, shortCut(label)) + )).toTree lazy val failure = LabelDef(label, Nil, labelBody) @@ -636,7 +625,7 @@ trait ParallelMatching extends ast.TreeDSL (rest rows j).insert2(ps, pmatch(j).boundVariables, casted.sym) val srep = - make(subtestVars ::: casted.accessorPatternVars ::: rest.tvars, newRows) + remake(newRows, subtestVars ::: casted.accessorPatternVars, includeScrut = false) squeezedBlock(casted.allValDefs, srep.toTree) } @@ -687,7 +676,8 @@ trait ParallelMatching extends ast.TreeDSL object ExpandedMatrix { def unapply(x: ExpandedMatrix) = Some((x.rows, x.targets)) - def apply(rows: List[Row], targets: List[FinalState]) = new ExpandedMatrix(rows, targets) + def apply(rowz: List[(Row, FinalState)]) = + new ExpandedMatrix(rowz map (_._1), rowz map (_._2)) } class ExpandedMatrix(val rows: List[Row], val targets: List[FinalState]) { @@ -697,8 +687,6 @@ trait ParallelMatching extends ast.TreeDSL "ExpandedMatrix(%d)".format(rows.size) + pp(rows zip targets, true) } - case class Branch[T](action: T, succ: Rep, fail: Option[Rep]) - abstract class State { def body: Tree def freeVars: List[Symbol] @@ -771,7 +759,7 @@ trait ParallelMatching extends ast.TreeDSL override def toString() = pp("Final%d%s".format(bx, pp(freeVars)) -> body) } - case class Rep(val tvars: List[PatternVar], val rows: List[Row]) { + case class Rep(val tvars: PatternVarGroup, val rows: List[Row]) { lazy val Row(pats, subst, guard, index) = rows.head lazy val guardedRest = if (guard.isEmpty) NoRep else make(tvars, rows.tail) lazy val (defaults, others) = pats span (_.isDefault) @@ -787,7 +775,7 @@ trait ParallelMatching extends ast.TreeDSL /** Now the 'i'th tvar is separated out and used as a new Scrutinee. */ private val (_pv, _tvars) = - extractIndex(tvars, index) + tvars extractIndex index /** The non-default pattern (others.head) replaces the column head. */ private val (_ncol, _nrep) = @@ -801,7 +789,7 @@ trait ParallelMatching extends ast.TreeDSL /** The VariableRule. */ private def variable() = { - val binding = (defaults map (_.boundVariables) zip tvars) . + val binding = (defaults map (_.boundVariables) zip tvars.pvs) . foldLeft(subst)((b, pair) => b.add(pair._1, pair._2.lhs)) VariableRule(binding, guard, guardedRest, index) @@ -828,8 +816,8 @@ trait ParallelMatching extends ast.TreeDSL val NoRep = Rep(Nil, Nil) /** Expands the patterns recursively. */ - final def expand(roots: List[PatternVar], cases: List[CaseDef]): ExpandedMatrix = { - val (rows, finals) = List.unzip( + final def expand(roots: List[PatternVar], cases: List[CaseDef]) = + tracing("Expanded", ExpandedMatrix( for ((CaseDef(pat, guard, body), index) <- cases.zipWithIndex) yield { def mkRow(ps: List[Tree]) = Row(toPats(ps), NoBinding, Guard(guard), index) @@ -840,13 +828,10 @@ trait ParallelMatching extends ast.TreeDSL case WILD() => emptyTrees(roots.length) }) - (row, FinalState(index, body, pattern.definedVars)) - } + row -> FinalState(index, body, pattern.definedVars) + }) ) - tracing("Expanded", new ExpandedMatrix(rows, finals)) - } - /** returns the condition in "if (cond) k1 else k2" */ final def condition(tpe: Type, scrut: Scrutinee): Tree = { -- cgit v1.2.3