From 90394b899f49cb98844197558b29582a483cc2f3 Mon Sep 17 00:00:00 2001 From: David MacIver Date: Mon, 3 Nov 2008 22:21:26 +0000 Subject: Large iamount of tidying up, mostly from paul. --- .../scala/tools/nsc/matching/MatchUtil.scala | 80 ++ .../tools/nsc/matching/ParallelMatching.scala | 1077 +++++++++----------- .../scala/tools/nsc/matching/PatternNodes.scala | 86 +- .../scala/tools/nsc/matching/TransMatcher.scala | 86 +- 4 files changed, 625 insertions(+), 704 deletions(-) create mode 100644 src/compiler/scala/tools/nsc/matching/MatchUtil.scala diff --git a/src/compiler/scala/tools/nsc/matching/MatchUtil.scala b/src/compiler/scala/tools/nsc/matching/MatchUtil.scala new file mode 100644 index 0000000000..f2d209233a --- /dev/null +++ b/src/compiler/scala/tools/nsc/matching/MatchUtil.scala @@ -0,0 +1,80 @@ +/* NSC -- new Scala compiler + */ + +package scala.tools.nsc.matching + +/** + * Utility classes, most of which probably belong somewhere else. + */ +object MatchUtil +{ + import collection.mutable.ListBuffer + + def impossible: Nothing = throw new RuntimeException("this never happens") + + object Implicits { + implicit def listPlusOps[T](xs: List[T]) = new ListPlus(xs) + } + + object Flags { + import symtab.Flags + import symtab.Symbols + + def propagateFlag(from: Symbols#Symbol, to: Symbols#Symbol, flag: Long) = if (from hasFlag flag) to setFlag flag + } + + class ListPlus[A](list: List[A]) { + /** Returns the list without the element at index n. + * If this list has fewer than n elements, the same list is returned. + * + * @param n the index of the element to drop. + * @return the list without the nth element. + */ + def dropIndex(n: Int) = list.take(n) ::: list.drop(n + 1) + + /** Returns a list formed from this list and the specified lists list2 + * and list3 by associating each element of the first list with + * the elements at the same positions in the other two. + * If any of the lists is shorter than the others, later elements in the other two are ignored. + * + * @return List((a0,b0), ..., + * (amin(m,n),bmin(m,n))) when + * List(a0, ..., am) + * zip List(b0, ..., bn) is invoked. + */ + def zip3[B, C](list2: List[B], list3: List[C]): List[(A, B, C)] = { + val b = new ListBuffer[(A, B, C)] + var xs1 = list + var xs2 = list2 + var xs3 = list3 + while (!xs1.isEmpty && !xs2.isEmpty && !xs3.isEmpty) { + b += ((xs1.head, xs2.head, xs3.head)) + xs1 = xs1.tail + xs2 = xs2.tail + xs3 = xs3.tail + } + b.toList + } + } + + object ListPlus { + /** Transforms a list of triples into a triple of lists. + * + * @param xs the list of triples to unzip + * @return a triple of lists. + */ + def unzip3[A,B,C](xs: List[(A,B,C)]): (List[A], List[B], List[C]) = { + val b1 = new ListBuffer[A] + val b2 = new ListBuffer[B] + val b3 = new ListBuffer[C] + var xc = xs + while (!xc.isEmpty) { + b1 += xc.head._1 + b2 += xc.head._2 + b3 += xc.head._3 + xc = xc.tail + } + (b1.toList, b2.toList, b3.toList) + } + } +} diff --git a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala index a01e31247d..156f1c134d 100644 --- a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala +++ b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala @@ -8,8 +8,12 @@ package scala.tools.nsc.matching import util.Position -import collection.mutable.{ListBuffer, BitSet} +import collection.mutable.BitSet import collection.immutable.IntMap +import MatchUtil.ListPlus._ +import MatchUtil.Implicits._ +import MatchUtil.Flags._ +import MatchUtil._ /** Translation of match expressions. * @@ -59,18 +63,19 @@ trait ParallelMatching { val last = column.last; // TODO: This needs to also allow the case that the last is a compatible // type pattern. - simpleSwitchCandidate(last) || isDefaultPattern(last); - }} + simpleSwitchCandidate(last) || isDefaultPattern(last) + } + } // an unapply for which we don't need a type test def isUnapplyHead(): Boolean = column.head match { - case __UnApply(_,argtpe,_) => scrutinee.tpe <:< argtpe - case _ => false + case __UnApply(_,argtpe,_) => scrutinee.tpe <:< argtpe + case _ => false } // true if pattern type is direct subtype of scrutinee (can't use just <:< cause have to take variance into account) def directSubtype(ptpe: Type) = - (ptpe.parents.exists { x => ((x.typeSymbol eq scrutinee.tpe.typeSymbol) && (x <:< scrutinee.tpe))}); + ptpe.parents.exists(x => (x.typeSymbol eq scrutinee.tpe.typeSymbol) && (x <:< scrutinee.tpe)) // true if each pattern type is case and direct subtype of scrutinee def isFlatCases(col:List[Tree]): Boolean = (col eq Nil) || { @@ -91,7 +96,7 @@ trait ParallelMatching { case (x : ArrayValue) => if (isRightIgnoring(x)) new MixSequenceStar(scrutinee, column, rest) else new MixSequence(scrutinee, column, rest); case _ if isSimpleSwitch => new MixLiterals(scrutinee, column, rest) - case _ if (settings_casetags && (column.length > 1) && isFlatCases(column)) => new MixCases(scrutinee, column, rest) + case _ if settings_casetags && (column.length > 1) && isFlatCases(column) => new MixCases(scrutinee, column, rest) case _ if isUnapplyHead() => new MixUnapply(scrutinee, column, rest) case _ => new MixTypes(scrutinee, column, rest) } @@ -112,13 +117,13 @@ trait ParallelMatching { } case class ErrorRule(implicit rep:RepFactory) extends RuleApplication(rep) { - def scrutinee:Symbol = throw new RuntimeException("this never happens") + def scrutinee: Symbol = impossible final def tree(implicit theOwner: Symbol, failTree: Tree) = failTree } /** {case ... if guard => bx} else {guardedRest} */ case class VariableRule(subst:Binding, guard: Tree, guardedRest:Rep, bx: Int)(implicit rep:RepFactory) extends RuleApplication(rep) { - def scrutinee:Symbol = throw new RuntimeException("this never happens") + def scrutinee: Symbol = impossible final def tree(implicit theOwner: Symbol, failTree: Tree): Tree = { val body = typer.typed { rep.requestBody(bx, subst) } if (guard eq EmptyTree) @@ -133,9 +138,9 @@ trait ParallelMatching { /** superclass of mixture rules for case classes and literals (both translated to switch on an integer) */ - abstract class CaseRuleApplication(rep:RepFactory) extends RuleApplication(rep) { + abstract class CaseRuleApplication(rep: RepFactory) extends RuleApplication(rep) { def column: List[Tree] - def rest:Rep + def rest: Rep // e.g. (1,1) (1,3) (42,2) for column {case ..1.. => ;; case ..42..=> ;; case ..1.. => } var defaultV: collection.immutable.Set[Symbol] = emptySymbolSet @@ -147,12 +152,9 @@ trait ParallelMatching { } def haveDefault: Boolean = !defaultIndexSet.isEmpty - - lazy val defaultRows: List[Row] = - defaultIndexSet.toList.reverseMap(grabRow); + lazy val defaultRows: List[Row] = defaultIndexSet.toList.reverseMap(grabRow); protected var tagIndices = IntMap.empty[List[Int]] - protected def grabTemps: List[Symbol] = rest.temp protected def grabRow(index: Int): Row = { val r @ Row(_,s,_,_) = rest.row(index) @@ -174,10 +176,10 @@ trait ParallelMatching { /** returns * @return list of continuations, * @return variables bound to default continuation, - * @return optionally, a default continuation, + * @return optionally, a default continuation **/ - def getTransition(implicit theOwner: Symbol): (List[(Int,Rep)],Set[Symbol],Option[Rep]) = - (tagIndicesToReps, defaultV, {if (haveDefault) Some(defaultsToRep) else None}) + def getTransition(implicit theOwner: Symbol): (List[(Int,Rep)], Set[Symbol], Option[Rep]) = + (tagIndicesToReps, defaultV, if (haveDefault) Some(defaultsToRep) else None) } /** mixture rule for flat case class (using tags) @@ -203,39 +205,34 @@ trait ParallelMatching { } final def tree(implicit theOwner: Symbol, failTree: Tree): Tree = { - val (branches, defaultV, default) = getTransition // tag body pairs - - var ndefault = if (default.isEmpty) failTree else repToTree(default.get) - var cases = branches map { - case (tag, r) => - CaseDef(Literal(tag), - EmptyTree, - { - val pat = this.column(tagIndices(tag).head); - val ptpe = pat.tpe - if (this.scrutinee.tpe.typeSymbol.hasFlag(Flags.SEALED) && strip2(pat).isInstanceOf[Apply]) { - //cast - val vtmp = newVar(pat.pos, ptpe) - squeezedBlock( - List(typedValDef(vtmp, gen.mkAsInstanceOf(mkIdent(this.scrutinee), ptpe))), - repToTree(rep.make(vtmp :: r.temp.tail, r.row)) - ) - } else repToTree(r) - } - )} + val (branches, defaultV, defaultRep) = getTransition // tag body pairs + val isSealed = scrutinee.tpe.typeSymbol hasFlag Flags.SEALED + val cases: List[CaseDef] = for ((tag, r) <- branches) yield { + val pat = column(tagIndices(tag).head) + val t2 = strip2(pat) match { + case _: Apply if isSealed => + val vtmp = newVar(pat.pos, pat.tpe) + squeezedBlock( + List(typedValDef(vtmp, gen.mkAsInstanceOf(mkIdent(this.scrutinee), pat.tpe))), + repToTree(rep.make(vtmp :: r.temp.tail, r.row)) + ) + case _ => repToTree(r) + } - // make first case a default case. - if (this.scrutinee.tpe.typeSymbol.hasFlag(Flags.SEALED) && defaultV.isEmpty) { - ndefault = cases.head.body - cases = cases.tail + CaseDef(Literal(tag), EmptyTree, t2) } - cases.length match { - case 0 => ndefault - case 1 => val CaseDef(lit,_,body) = cases.head - If(Equals(Select(mkIdent(this.scrutinee), nme.tag), lit), body, ndefault) - case _ => val defCase = CaseDef(mk_(definitions.IntClass.tpe), EmptyTree, ndefault) - Match(Select(mkIdent(this.scrutinee),nme.tag), cases ::: defCase :: Nil) + // make first case a default case. + lazy val ndefault: Tree = defaultRep.map(repToTree) getOrElse failTree + lazy val defCase: CaseDef = CaseDef(mk_(definitions.IntClass.tpe), EmptyTree, ndefault) + val (first, rest) = + if (isSealed && defaultV.isEmpty) (cases.head.body, cases.tail) + else (ndefault, cases) + + rest match { + case Nil => ndefault + case CaseDef(lit,_,body) :: Nil => If(Equals(Select(mkIdent(this.scrutinee), nme.tag), lit), body, ndefault) + case _ => Match(Select(mkIdent(this.scrutinee), nme.tag), cases ::: List(defCase)) } } } @@ -266,15 +263,14 @@ trait ParallelMatching { } final def tree(implicit theOwner: Symbol, failTree: Tree): Tree = { - val (branches, defaultV, defaultRepOpt) = this.getTransition // tag body pairs - val cases = branches map { - case (tag, r) => - val r2 = rep.make(r.temp, r.row.map(x => x.insert2(Nil, bindVars(tag, x.subst)))) - val t2 = repToTree(r2) - CaseDef(Literal(tag), EmptyTree, t2) + val (branches, defaultV, defaultRep) = this.getTransition // tag body pairs + val cases = for ((tag, r) <- branches) yield { + val r2 = rep.make(r.temp, r.row.map(x => x.insert2(Nil, bindVars(tag, x.subst)))) + val t2 = repToTree(r2) + CaseDef(Literal(tag), EmptyTree, t2) } - lazy val ndefault = defaultRepOpt.map(repToTree) getOrElse failTree + lazy val ndefault = defaultRep.map(repToTree) getOrElse failTree lazy val defCase = CaseDef(mk_(definitions.IntClass.tpe), EmptyTree, ndefault) cases match { @@ -290,106 +286,81 @@ trait ParallelMatching { /** mixture rule for unapply pattern */ - class MixUnapply(val scrutinee: Symbol, val column: List[Tree], val rest: Rep)(implicit rep: RepFactory) extends RuleApplication(rep) { + class MixUnapply(val scrutinee: Symbol, val column: List[Tree], val rest: Rep)(implicit rep: RepFactory) + extends RuleApplication(rep) { + val (vs, unapp) = strip(column.head) + lazy val ua @ UnApply(app @ Apply(fn, appargs), args) = unapp def newVarCapture(pos:Position,tpe:Type)(implicit theOwner:Symbol) = { val v = newVar(pos,tpe) - if (scrutinee.hasFlag(Flags.TRANS_FLAG)) - v.setFlag(Flags.TRANS_FLAG) // propagate "unchecked" + propagateFlag(scrutinee, v, Flags.TRANS_FLAG) // propagate "unchecked" v } - private def bindToScrutinee(x:Symbol) = typedValDef(x,mkIdent(scrutinee)) - - val (vs,unapp) = strip(column.head) + private def bindToScrutinee(x: Symbol) = typedValDef(x, mkIdent(scrutinee)) // XXX this is dead code /** returns (unapply-call, success-rep, optional fail-rep*/ final def getTransition(implicit theOwner: Symbol): (Tree, List[Tree], Rep, Option[Rep]) = { - unapp match { - case ua @ UnApply(app @ Apply(fn, appargs), args) => - object sameUnapplyCall { - def unapply(t:Tree) = t match { - case UnApply(Apply(fn1,_), differentArgs) if (fn.symbol == fn1.symbol) && fn.equalsStructure(fn1) => - Some(differentArgs) - case _ => - None - } - } - val ures = newVarCapture(ua.pos, app.tpe) - val arg0 = mkIdent(scrutinee) - val rhs = Apply(fn, arg0 :: appargs.tail) setType ures.tpe - val uacall = typedValDef(ures, rhs) - - val nrowsOther = column.tail.zip(rest.row.tail) flatMap { - case (pat, r) => strip2(pat) match { - case sameUnapplyCall(_) => Nil - case _ => List(r.insert(pat)) - }} - val nrepFail = if (nrowsOther.isEmpty) - None - else - Some(rep.make(scrutinee::rest.temp, nrowsOther)) - args.length match { - case 0 => // special case for unapply(), app.tpe is boolean - val ntemps = scrutinee :: rest.temp - val nrows = column.zip(rest.row) map { - case (pat, r) => strip2(pat) match { - case sameUnapplyCall(args) => - r.insert2(List(EmptyTree), r.subst.add(strip1(pat), scrutinee)) - case _ => - r.insert(pat) - }} - (uacall, Nil, rep.make(ntemps, nrows), nrepFail) - - case 1 => // special case for unapply(p), app.tpe is Option[T] - val vtpe = app.tpe.typeArgs(0) - val vsym = newVarCapture(ua.pos, vtpe) - val ntemps = vsym :: scrutinee :: rest.temp - val nrows = column.zip(rest.row) map { - case (pat, r: Row) => strip2(pat) match { - case sameUnapplyCall(args) => - val nsubst = r.subst.add(strip1(pat), scrutinee) - r.insert2(List(args(0), EmptyTree), nsubst) - case _ => - r.insert(List(EmptyTree, pat)) - }} - - val vdef = typedValDef(vsym, Get(mkIdent(ures))) - (uacall, List(vdef), rep.make(ntemps, nrows), nrepFail) - - case _ => // app.tpe is Option[? <: ProductN[T1,...,Tn]] - val uresGet = newVarCapture(ua.pos, app.tpe.typeArgs(0)) - val vdefHead = typedValDef(uresGet, Get(mkIdent(ures))) - val ts = definitions.getProductArgs(uresGet.tpe).get - - val (vdefs: List[Tree], vsyms: List[Symbol]) = List.unzip( - for ((vtpe, i) <- ts.zip((1 to ts.size).toList)) yield { - val vchild = newVarCapture(ua.pos, vtpe) - val accSym = definitions.productProj(uresGet, i) - val rhs = typer.typed(Apply(Select(mkIdent(uresGet), accSym), Nil)) - - (typedValDef(vchild, rhs), vchild) - }) - - val ntemps = vsyms ::: scrutinee :: rest.temp - val dummies = getDummies(ts.size) - val nrows = column.zip(rest.row) map { - case (pat, r: Row) => strip2(pat) match { - case sameUnapplyCall(args) => - val nsubst = r.subst.add(strip1(pat), scrutinee) - r.insert2(args ::: List(EmptyTree), nsubst) - case _ => - r.insert(dummies ::: List(pat)) - }} - - (uacall, vdefHead :: vdefs, rep.make(ntemps, nrows), nrepFail) - }} + object sameUnapplyCall { + def unapply(t: Tree) = t match { + case UnApply(Apply(fn1,_), differentArgs) if (fn.symbol == fn1.symbol) && fn.equalsStructure(fn1) => + Some(differentArgs) + case _ => + None + } + } + val ures = newVarCapture(ua.pos, app.tpe) + val rhs = Apply(fn, mkIdent(scrutinee) :: appargs.tail) setType ures.tpe + val uacall = typedValDef(ures, rhs) + val zipped = column.zip(rest.row) + val nrowsOther = zipped.tail.flatMap { case (pat, r) => + strip2(pat) match { case sameUnapplyCall(_) => Nil ; case _ => List(r.insert(pat)) } + } + val nrepFail = + if (nrowsOther.isEmpty) None + else Some(rep.make(scrutinee::rest.temp, nrowsOther)) + + def mkTransition(vdefs: List[Tree], ntemps: List[Symbol], nrows: List[Row]) = + (uacall, vdefs, rep.make(ntemps ::: scrutinee :: rest.temp, nrows), nrepFail) + + def mkNewRows(sameFilter: (List[Tree]) => List[Tree], defaultTrees: List[Tree]) = + for ((pat, r) <- zipped) yield strip2(pat) match { + case sameUnapplyCall(args) => r.insert2(sameFilter(args) ::: List(EmptyTree), r.subst.add(strip1(pat), scrutinee)) + case _ => r.insert(defaultTrees ::: List(pat)) + } + + args.length match { + case 0 => // special case for unapply(), app.tpe is boolean + mkTransition(Nil, Nil, mkNewRows((xs) => Nil, Nil)) + + case 1 => // special case for unapply(p), app.tpe is Option[T] + val vtpe = app.tpe.typeArgs(0) + val vsym = newVarCapture(ua.pos, vtpe) + val nrows = mkNewRows((xs) => List(xs.head), List(EmptyTree)) + val vdef = typedValDef(vsym, Get(mkIdent(ures))) + mkTransition(List(vdef), List(vsym), nrows) + + case _ => // app.tpe is Option[? <: ProductN[T1,...,Tn]] + val uresGet = newVarCapture(ua.pos, app.tpe.typeArgs(0)) + val vdefHead = typedValDef(uresGet, Get(mkIdent(ures))) + val ts = definitions.getProductArgs(uresGet.tpe).get + val nrows = mkNewRows(identity, getDummies(ts.size)) + val (vdefs: List[Tree], vsyms: List[Symbol]) = List.unzip( + for ((vtpe, i) <- ts.zip((1 to ts.size).toList)) yield { + val vchild = newVarCapture(ua.pos, vtpe) + val accSym = definitions.productProj(uresGet, i) + val rhs = typer.typed(Apply(Select(mkIdent(uresGet), accSym), Nil)) + + (typedValDef(vchild, rhs), vchild) + }) + mkTransition(vdefHead :: vdefs, vsyms, nrows) + } } /* def getTransition(...) */ final def tree(implicit theOwner: Symbol, failTree: Tree) = { - val (uacall , vdefs,srep,frep) = this.getTransition + val (uacall, vdefs, srep, frep) = this.getTransition val succ = repToTree(srep) - val fail = if (frep.isEmpty) failTree else repToTree(frep.get) + val fail = frep.map(repToTree) getOrElse failTree val cond = if (uacall.symbol.tpe.typeSymbol eq definitions.BooleanClass) typer.typed{ mkIdent(uacall.symbol) } @@ -406,8 +377,8 @@ trait ParallelMatching { private val sequenceType = scrutinee.tpe.widen.baseType(definitions.SeqClass) private val elementType = getElemType_Sequence(scrutinee.tpe) - final def removeStar(xs:List[Tree]):List[Tree] = - xs.take(xs.length-1) ::: makeBind(strip1(xs.last).toList, mk_(sequenceType)) :: Nil + final def removeStar(xs: List[Tree]): List[Tree] = + xs.init ::: makeBind(strip1(xs.last).toList, mk_(sequenceType)) :: Nil protected def getSubPatterns(len:Int, x:Tree):Option[List[Tree]] = x match { case av @ ArrayValue(_,xs) if (!isRightIgnoring(av) && xs.length == len) => Some(xs ::: List(EmptyTree)) @@ -417,7 +388,7 @@ trait ParallelMatching { } protected def makeSuccRep(vs:List[Symbol], tail:Symbol, nrows:List[Row])(implicit theOwner: Symbol) = - rep.make( vs ::: tail :: rest.temp, nrows.toList) + rep.make(vs ::: tail :: rest.temp, nrows.toList) /** returns true if x is more general than y */ protected def subsumes(x:Tree, y:Tree): Boolean = (x,y) match { @@ -443,20 +414,17 @@ trait ParallelMatching { lazy val tail = newVar(scrutinee.pos, sequenceType) lazy val lastBinding = if (ys.size > 0) seqDrop(treeAsSeq.duplicate, ys.size) else mkIdent(scrutinee) val bindings = - (vs.zipWithIndex map { case (v, i) => typedValDef(v, seqElement(treeAsSeq.duplicate, i)) }) ::: + (for ((v, i) <- vs.zipWithIndex) yield typedValDef(v, seqElement(treeAsSeq.duplicate, i))) ::: List(typedValDef(tail, lastBinding)) - val nrows = new ListBuffer[Row] - val frows = new ListBuffer[Row] - - for ((c, row) <- column.zip(rest.row)) - getSubPatterns(ys.size, c) match { - case Some(ps) => nrows += row.insert(ps) ; if (isDefaultPattern(c) || subsumes(c, av)) frows += row.insert(c) - case None => frows += row.insert(c) - } + val (nrows, frows)/* : (List[Option[Row]], List[Option[Row]]) */ = List.unzip( + for ((c, row) <- column.zip(rest.row)) yield getSubPatterns(ys.size, c) match { + case Some(ps) => (Some(row.insert(ps)), if (isDefaultPattern(c) || subsumes(c, av)) Some(row.insert(c)) else None) + case None => (None, Some(row.insert(c))) + }) - val succRep = makeSuccRep(vs, tail, nrows.toList) - val failRep = rep.make(scrutinee :: rest.temp, frows.toList) + val succRep = makeSuccRep(vs, tail, nrows.flatMap(x => x)) + val failRep = rep.make(scrutinee :: rest.temp, frows.flatMap(x => x)) // fixed length val cond = getCond(treeAsSeq, xs.length) @@ -494,7 +462,7 @@ trait ParallelMatching { } override protected def makeSuccRep(vs:List[Symbol], tail:Symbol, nrows:List[Row])(implicit theOwner: Symbol) = - rep.make( vs ::: tail :: scrutinee :: rest.temp, nrows) + rep.make(vs ::: tail :: scrutinee :: rest.temp, nrows) // lengthArg is minimal length override protected def getCond(tree:Tree, lengthArg:Int) = seqLongerThan(tree.duplicate, column.head.tpe, lengthArg - 1) @@ -507,46 +475,40 @@ trait ParallelMatching { final def getTransition(implicit theOwner: Symbol): (Tree, Rep, Symbol, Rep) = { val nmatrix = rest val vlue = (column.head.tpe: @unchecked) match { - case TypeRef(_,_,List(SingleType(pre,sym))) => - gen.mkAttributedRef(pre,sym) - case TypeRef(_,_,List(PseudoType(o))) => - o.duplicate + case TypeRef(_,_,List(SingleType(pre,sym))) => gen.mkAttributedRef(pre,sym) + case TypeRef(_,_,List(PseudoType(o))) => o.duplicate } assert(vlue.tpe ne null, "value tpe is null") - val vs = strip1(column.head) - val nsuccFst = rest.row.head match { case r: Row => r.insert2(List(EmptyTree), r.subst.add(vs, scrutinee)) } - val fLabel = theOwner.newLabel(scrutinee.pos, cunit.fresh.newName(scrutinee.pos, "failCont%")) // warning, untyped - val sx = rep.shortCut(fLabel) // register shortcut - val nsuccRow = nsuccFst :: Row(getDummies( 1 /*scrutinee*/ + rest.temp.length), NoBinding, EmptyTree, sx) :: Nil + val vs = strip1(column.head) + val nsuccFst = rest.row.head match { case r => r.insert2(List(EmptyTree), r.subst.add(vs, scrutinee)) } + val fLabel = theOwner.newLabel(scrutinee.pos, cunit.fresh.newName(scrutinee.pos, "failCont%")) // warning, untyped + val sx = rep.shortCut(fLabel) // register shortcut + val nsuccRow = nsuccFst :: Row(getDummies( 1 /*scrutinee*/ + rest.temp.length), NoBinding, EmptyTree, sx) :: Nil // todo: optimize if no guard, and no further tests val nsucc = rep.make(scrutinee :: rest.temp, nsuccRow) val nfail = repWithoutHead(column, rest) - return (typer.typed{ Equals(mkIdent(scrutinee) setType scrutinee.tpe, vlue) }, nsucc, fLabel, nfail) + + (typer.typed(Equals(mkIdent(scrutinee) setType scrutinee.tpe, vlue)), nsucc, fLabel, nfail) } final def tree(implicit theOwner: Symbol, failTree: Tree) = { val (cond, srep, fLabel, frep) = this.getTransition - val cond2 = typer.typed { rep.handleOuter(cond) } - val fail = typer.typed { repToTree(frep) } - fLabel setInfo (new MethodType(Nil, fail.tpe)) + val cond2 = typer.typed( rep.handleOuter(cond) ) + val fail = typer.typed( repToTree(frep) ) + fLabel setInfo MethodType(Nil, fail.tpe) val succ = repToTree(srep) - typer.typed{ If(cond2, succ, LabelDef(fLabel, Nil, fail)) } + typer.typed( If(cond2, succ, LabelDef(fLabel, Nil, fail)) ) } } /** mixture rule for type tests **/ class MixTypes(val scrutinee: Symbol, val column: List[Tree], val rest: Rep)(implicit rep: RepFactory) extends RuleApplication(rep) { - - var casted: Symbol = null - var moreSpecific: List[Tree] = Nil - var subsumed: List[(Int,List[Tree])] = Nil // row index and subpatterns - var remaining: List[(Int,Tree)] = Nil // row index and pattern - + // TODO: this flag is never examined val isExhaustive = !scrutinee.tpe.typeSymbol.hasFlag(Flags.SEALED) || { - val tpes = column.map {x => x.tpe.typeSymbol} - scrutinee.tpe.typeSymbol.children.forall { sym => tpes.contains(sym) } + val tpes = column.map(x => x.tpe.typeSymbol) + scrutinee.tpe.typeSymbol.children.forall(sym => tpes.contains(sym)) } private val headPatternType = strip2(column.head) match { @@ -558,18 +520,13 @@ trait ParallelMatching { private val isCaseHead = isCaseClass(headPatternType) private val dummies = if (!isCaseHead) Nil else getDummies(headPatternType.typeSymbol.caseFieldAccessors.length) - private def subpatterns(pat:Tree): List[Tree] = { - pat match { - case Bind(_,p) => - subpatterns(p) - case app @ Apply(fn, pats) if isCaseClass(app.tpe) && fn.isType => - if (isCaseHead) pats else dummies - case Apply(fn,xs) => assert((xs.isEmpty) && (!fn.isType), "strange Apply"); dummies // named constant - case _: UnApply => - dummies - case pat => - dummies - } + private def subpatterns(pat: Tree): List[Tree] = pat match { + case Bind(_,p) => subpatterns(p) + case app @ Apply(fn, pats) if isCaseClass(app.tpe) && fn.isType => if (isCaseHead) pats else dummies + case Apply(fn, xs) => // named constant + assert(xs.isEmpty && !fn.isType, "strange Apply"); dummies + // case _: UnApply => dummies + case _ => dummies } /** an approximation of _tp1 <:< tp2 that ignores _ types. this code is wrong, @@ -577,11 +534,10 @@ trait ParallelMatching { */ def subsumes_erased(_tp1:Type, tp2:Type) = { val tp1 = patternType_wrtEquals(_tp1) - tp1.isInstanceOf[TypeRef] && tp2.isInstanceOf[TypeRef] && - ((tp1.prefix =:= tp2.prefix) && - ((tp1.typeSymbol eq tp2.typeSymbol) && - (tp1.typeSymbol ne definitions.ArrayClass)) || - tp1.parents.exists(_.typeSymbol eq tp2.typeSymbol)) + lazy val eqSymbolsNotArray = (tp1.typeSymbol eq tp2.typeSymbol) && (tp1.typeSymbol ne definitions.ArrayClass) + tp1.isInstanceOf[TypeRef] && + tp2.isInstanceOf[TypeRef] && + ((tp1.prefix =:= tp2.prefix) && eqSymbolsNotArray || tp1.parents.exists(_.typeSymbol eq tp2.typeSymbol)) // rather: tp1.baseTypes.exists...? } @@ -593,42 +549,36 @@ trait ParallelMatching { headPatternType =:= singleType(pat.symbol.tpe.prefix, pat.symbol) } - /*init block*/ { - var sr = (moreSpecific,subsumed,remaining) - for ((pat, j) <- column.zipWithIndex){ - val (ms,ss,rs) = sr // more specific, more general(subsuming current), remaining patterns - val strippedPattern = strip2(pat) - val patternType = strippedPattern.tpe - sr = strippedPattern match { - case Literal(Constant(null)) if !(headPatternType =:= patternType) => // special case for constant null pattern - (ms,ss,(j,pat)::rs); - case _ if objectPattern(pat) => - (EmptyTree::ms, (j,dummies)::ss, rs); // matching an object - - case Typed(p, _) if (strip2(p).isInstanceOf[UnApply] && (patternType /*is never */ <:< headPatternType)) => - (p::ms, (j, dummies)::ss, rs); - - case q @ Typed(pp,_) if (patternType_wrtEquals(patternType) <:< headPatternType) => - ({if (pat.tpe =:= headPatternType /*never true for */) pp else q}::ms, (j, dummies)::ss, rs); - - case z:UnApply => - (ms,ss,(j,pat)::rs) - - case qq if subsumes_erased(patternType, headPatternType) || (patternType_wrtEquals(patternType) <:< headPatternType) && !isDefaultPattern(pat) => - ({if (pat.tpe =:= headPatternType /*never true for */) EmptyTree else pat}::ms, (j,subpatterns(pat))::ss, rs); - - case _ if subsumes_erased(headPatternType, patternType) || (headPatternType <:< patternType /*never true for */) || isDefaultPattern(pat) => - (EmptyTree::ms, (j, dummies)::ss, (j,pat)::rs) // subsuming (matched *and* remaining pattern) - + // moreSpecific: more specific patterns + // subsumed: more general patterns (subsuming current), row index and subpatterns + // remaining: remaining, row index and pattern + def join[T](xs: List[Option[T]]): List[T] = xs.flatMap(x => x) + val (moreSpecific, subsumed, remaining) : (List[Tree], List[(Int, List[Tree])], List[(Int, Tree)]) = unzip3( + for ((pat, j) <- column.zipWithIndex) yield { + val spat = strip2(pat) + val patType = spat.tpe + + // each pattern will yield a triple of options corresponding to the three lists, which will be flattened down to the values + spat match { + case Literal(Constant(null)) if !(headPatternType =:= patType) => // special case for constant null pattern + (None, None, Some((j, pat))) + case _ if objectPattern(pat) => // matching an object + (Some(EmptyTree), Some((j, dummies)), None) + case Typed(p, _) if (strip2(p).isInstanceOf[UnApply] && (patType <:< headPatternType)) => // <:< is never + (Some(p), Some((j, dummies)), None) + case q @ Typed(pp, _) if patternType_wrtEquals(patType) <:< headPatternType => + (Some(if (pat.tpe =:= headPatternType) pp else q), Some((j, dummies)), None) // never =:= for + case z: UnApply => + (None, None, Some((j, pat))) + case qq if subsumes_erased(patType, headPatternType) || (patternType_wrtEquals(patType) <:< headPatternType) && !isDefaultPattern(pat) => + (Some(if (pat.tpe =:= headPatternType) EmptyTree else pat), Some((j, subpatterns(pat))), None) // never =:= for + case _ if subsumes_erased(headPatternType, patType) || (headPatternType <:< patType) || isDefaultPattern(pat) => // never <:< for + (Some(EmptyTree), Some((j, dummies)), Some((j, pat))) // subsuming (matched *and* remaining pattern) case _ => - (ms,ss,(j,pat)::rs) + (None, None, Some((j, pat))) } } - this.moreSpecific = sr._1.reverse - this.subsumed = sr._2.reverse - this.remaining = sr._3.reverse - sr = null - } /* init block */ + ) match { case (x,y,z) => (join(x), join(y), join(z)) } override def toString = { "MixTypes("+scrutinee+":"+scrutinee.tpe+") {\n moreSpecific:"+moreSpecific+"\n subsumed:"+subsumed+"\n remaining"+remaining+"\n}" @@ -636,82 +586,75 @@ trait ParallelMatching { /** returns casted symbol, success matrix and optionally fail matrix for type test on the top of this column */ final def getTransition(implicit theOwner: Symbol): (Symbol, Rep, Option[Rep]) = { - casted = if (scrutinee.tpe =:= headPatternType) scrutinee else newVar(scrutinee.pos, headPatternType) - if (scrutinee.hasFlag(Flags.TRANS_FLAG)) - casted.setFlag(Flags.TRANS_FLAG) + val casted = if (scrutinee.tpe =:= headPatternType) scrutinee else newVar(scrutinee.pos, headPatternType) + propagateFlag(scrutinee, casted, Flags.TRANS_FLAG) // succeeding => transition to translate(subsumed) (taking into account more specific) val nmatrix = { - var ntemps = if (!isCaseHead) Nil else casted.caseFieldAccessors map { - meth => + var ntemps = + if (!isCaseHead) Nil + else for (meth <- casted.caseFieldAccessors) yield { val ctemp = newVar(scrutinee.pos, casted.tpe.memberType(meth).resultType) - if (scrutinee.hasFlag(Flags.TRANS_FLAG)) - ctemp.setFlag(Flags.TRANS_FLAG) + propagateFlag(scrutinee, ctemp, Flags.TRANS_FLAG) ctemp - } // (***) flag needed later - var subtests = subsumed - if (moreSpecific.exists { x => x != EmptyTree }) { - ntemps = casted::ntemps - subtests = moreSpecific.zip(subsumed) map { - case (mspat, (j,pats)) => (j,mspat::pats) } - } + val subtests = + if (!moreSpecific.exists(_ != EmptyTree)) subsumed + else { + ntemps = casted :: ntemps + moreSpecific.zip(subsumed) map { case (mspat, (j, pats)) => (j, mspat::pats) } + } + ntemps = ntemps ::: rest.temp - val ntriples = subtests map { - case (j, pats) => - val (vs, thePat) = strip(column(j)) - val r = rest.row(j) - val nsubst = r.subst.add(vs, casted) - r.insert2(pats, nsubst) + val ntriples = for ((j, pats) <- subtests) yield { + val (vs, thePat) = strip(column(j)) + val r = rest.row(j) + val nsubst = r.subst.add(vs, casted) + r.insert2(pats, nsubst) } rep.make(ntemps, ntriples) } // fails => transition to translate(remaining) val nmatrixFail: Option[Rep] = { val ntemps = scrutinee :: rest.temp - val ntriples = remaining map { case (j, pat) => rest.row(j).insert(pat) } + val ntriples = for ((j, pat) <- remaining) yield rest.row(j).insert(pat) if (ntriples.isEmpty) None else Some(rep.make(ntemps, ntriples)) } (casted, nmatrix, nmatrixFail) } final def tree(implicit theOwner: Symbol, failTree: Tree): Tree = { - val (casted,srep,frep) = this.getTransition + val (casted, srep, frep) = this.getTransition val condUntyped = condition(casted.tpe, this.scrutinee) var cond = rep.handleOuter(typer.typed { condUntyped }) if (needsOuterTest(casted.tpe, this.scrutinee.tpe, theOwner)) { // @todo merge into def condition cond = addOuterCondition(cond, casted.tpe, mkIdent(this.scrutinee), rep.handleOuter) } val succ = repToTree(srep) - - val fail = if (frep.isEmpty) failTree else repToTree(frep.get) + val fail = frep.map(repToTree) getOrElse failTree // dig out case field accessors that were buried in (***) val cfa = if (!isCaseHead) Nil else casted.caseFieldAccessors val caseTemps = (if (!srep.temp.isEmpty && srep.temp.head == casted) srep.temp.tail else srep.temp).zip(cfa) - var vdefs = caseTemps map { - p => - val tmp = p._1; - val accessorMethod = p._2 - val untypedAccess = Apply(Select(mkIdent(casted), accessorMethod),List()) - val typedAccess = typer.typed { untypedAccess } - typedValDef(tmp, typedAccess) + var vdefs = for ((tmp, accessorMethod) <- caseTemps) yield { + val untypedAccess = Apply(Select(mkIdent(casted), accessorMethod), Nil) + val typedAccess = typer.typed(untypedAccess) + typedValDef(tmp, typedAccess) } if (casted ne this.scrutinee) vdefs = ValDef(casted, gen.mkAsInstanceOf(mkIdent(this.scrutinee), casted.tpe)) :: vdefs - return typer.typed { If(cond, squeezedBlock(vdefs, succ), fail) } + return typer.typed( If(cond, squeezedBlock(vdefs, succ), fail) ) } } /** converts given rep to a tree - performs recursive call to translation in the process to get sub reps */ - final def repToTree(r: Rep)(implicit theOwner: Symbol, failTree: Tree, rep: RepFactory): Tree = { + final def repToTree(r: Rep)(implicit theOwner: Symbol, failTree: Tree, rep: RepFactory): Tree = r.applyRule.tree - } - case class Row(pat:List[Tree], subst:Binding, guard:Tree, bx:Int) { + case class Row(pat:List[Tree], subst: Binding, guard: Tree, bx: Int) { def insert(h: Tree) = Row(h :: pat, subst, guard, bx) // prepends supplied tree def insert(hs: List[Tree]) = Row(hs ::: pat, subst, guard, bx) def insert2(hs: List[Tree], b: Binding) = Row(hs ::: pat, b, guard, bx) // prepends and substitutes @@ -720,9 +663,10 @@ trait ParallelMatching { object Rep { type RepType = Product2[List[Symbol], List[Row]] - final def unapply(x:Rep)(implicit rep:RepFactory):Option[RepType] = + final def unapply(x:Rep)(implicit rep:RepFactory): Option[RepType] = if (x.isInstanceOf[rep.RepImpl]) Some(x.asInstanceOf[RepType]) else None } + class RepFactory(val handleOuter: Tree => Tree)(implicit val typer : Typer) { case class RepImpl(val temp:List[Symbol], val row:List[Row]) extends Rep with Rep.RepType { (row.find { case Row(pats, _, _, _) => temp.length != pats.length }) match { @@ -734,24 +678,24 @@ trait ParallelMatching { } var vss: List[SymList] = _ - var labels: Array[Symbol] = new Array[Symbol](4) + var labels: Array[Symbol] = new Array[Symbol](4) var targets: List[Tree] = _ - var reached : BitSet = _; - var shortCuts: List[Symbol] = Nil; + var reached: BitSet = _ + var shortCuts: List[Symbol] = Nil final def make(temp:List[Symbol], row:List[Row], targets: List[Tree], vss:List[SymList])(implicit theOwner: Symbol): Rep = { // ensured that labels(i) eq null for all i, cleanup() has to be called after translation - this.targets = targets + this.targets = targets if (targets.length > labels.length) - this.labels = new Array[Symbol](targets.length) - this.vss = vss - this.reached = new BitSet(targets.length); - return make(temp, row) + this.labels = new Array[Symbol](targets.length) + this.vss = vss + this.reached = new BitSet(targets.length) + make(temp, row) } final def shortCut(theLabel:Symbol): Int = { - this.shortCuts = shortCuts:::theLabel::Nil; - return -shortCuts.length + shortCuts = shortCuts ::: List(theLabel) + -shortCuts.length } final def cleanup(tree: Tree)(implicit theOwner: Symbol): Tree = { @@ -759,11 +703,8 @@ trait ParallelMatching { override def transform(tree:Tree): Tree = tree match { case blck @ Block(vdefs, ld @ LabelDef(name,params,body)) => val bx = labelIndex(ld.symbol) - if ((bx >= 0) && !isReachedTwice(bx)) { - squeezedBlock(vdefs,body) - } - else - blck + if (bx >= 0 && !isReachedTwice(bx)) squeezedBlock(vdefs,body) + else blck case If(cond, Literal(Constant(true)), Literal(Constant(false))) => super.transform(cond) @@ -779,22 +720,19 @@ trait ParallelMatching { cleanup() res } + final def cleanup() { - var i = targets.length; - while (i>0) { i-=1; labels(i) = null; }; - reached = null; + for (i <- 0 until targets.length) labels(i) = null + reached = null shortCuts = Nil } - final def isReached(bx:Int) = { labels(bx) ne null } + final def isReached(bx:Int) = labels(bx) ne null final def markReachedTwice(bx:Int) { reached += bx } /** @pre bx < 0 || labelIndex(bx) != -1 */ - final def isReachedTwice(bx:Int) = (bx < 0) || reached(bx) + final def isReachedTwice(bx: Int) = (bx < 0) || reached(bx) /* @returns bx such that labels(bx) eq label, -1 if no such bx exists */ - final def labelIndex(label:Symbol): Int = { - var bx = 0; while((bx < labels.length) && (labels(bx) ne label)) { bx += 1 } - if (bx >= targets.length) bx = -1 - return bx - } + final def labelIndex(label: Symbol): Int = labels.findIndexOf(_ eq label) + /** first time bx is requested, a LabelDef is returned. next time, a jump. * the function takes care of binding */ @@ -805,64 +743,144 @@ trait ParallelMatching { } if (!isReached(bx)) { // first time this bx is requested // might be bound elsewhere ( see `x @ unapply' ) <-- this comment refers to null check - val allVs = + val (vsyms, argts, vdefs) : (List[Symbol], List[Type], List[Tree]) = unzip3( for (v <- vss(bx) ; val substv = subst(v) ; if substv ne null) yield (v, v.tpe, typedValDef(v, substv)) + ) - val vsyms : List[Symbol] = allVs.map(_._1) - val argts : List[Type] = allVs.map(_._2) - val vdefs : List[Tree] = allVs.map(_._3) + val body = targets(bx) + // @bug: typer is not able to digest a body of type Nothing being assigned result type Unit + val tpe = if (body.tpe.typeSymbol eq definitions.NothingClass) body.tpe else resultType + val label = theOwner.newLabel(body.pos, "body%"+bx) setInfo MethodType(argts, tpe) + labels(bx) = label - val body = targets(bx) - // @bug: typer is not able to digest a body of type Nothing being assigned result type Unit - val tpe = if (body.tpe.typeSymbol eq definitions.NothingClass) body.tpe else resultType - val label = theOwner.newLabel(body.pos, "body%"+bx).setInfo(new MethodType(argts, tpe)) - labels(bx) = label + return body match { + case _: Throw | _: Literal => squeezedBlock(vdefs, body.duplicate setType tpe) + case _ => squeezedBlock(vdefs.reverse, LabelDef(label, vsyms, body setType tpe)) + } + } - return body match { - case _: Throw | _: Literal => squeezedBlock(vdefs, body.duplicate setType tpe) - case _ => squeezedBlock(vdefs.reverse, LabelDef(label, vsyms, body setType tpe)) + // if some bx is not reached twice, its LabelDef is replaced with body itself + markReachedTwice(bx) + val args: List[Ident] = vss(bx).map(subst) + val label = labels(bx) + val body = targets(bx) + val MethodType(fmls, _) = label.tpe + + // sanity checks + if (fmls.length != args.length) { + cunit.error(body.pos, "consistency problem in target generation ! I have args "+ + args+" and need to jump to a label with fmls "+fmls) + throw FatalError("consistency problem") + } + fmls.zip(args).find(x => !(x._2.tpe <:< x._1)) match { + case Some(Tuple2(f, a)) => cunit.error(body.pos, "consistency problem ! "+a.tpe+" "+f) ; throw FatalError("consistency problem") + case None => } - } - // if some bx is not reached twice, its LabelDef is replaced with body itself - markReachedTwice(bx) - val args : List[Ident] = vss(bx).map(subst) - val label = labels(bx) - val body = targets(bx) - val MethodType(fmls, _) = label.tpe - - // sanity checks - if (fmls.length != args.length) { - cunit.error(body.pos, "consistency problem in target generation ! I have args "+ - args+" and need to jump to a label with fmls "+fmls) - throw FatalError("consistency problem") - } - fmls.zip(args).find(x => !(x._2.tpe <:< x._1)) match { - case Some(Tuple2(f, a)) => cunit.error(body.pos, "consistency problem ! "+a.tpe+" "+f) ; throw FatalError("consistency problem") - case None => + body match { + case _: Throw | _: Literal => // might be bound elsewhere (see `x @ unapply') + val vdefs = for (v <- vss(bx) ; val substv = subst(v) ; if substv ne null) yield typedValDef(v, substv) + squeezedBlock(vdefs, body.duplicate setType resultType) + case _ => + Apply(mkIdent(label),args) + } } - body match { - case _: Throw | _: Literal => // might be bound elsewhere (see `x @ unapply') - val vdefs = for (v <- vss(bx) ; val substv = subst(v) ; if substv ne null) yield typedValDef(v, substv) - squeezedBlock(vdefs, body.duplicate setType resultType) - case _ => - Apply(mkIdent(label),args) - } - } + /** the injection here handles alternatives and unapply type tests */ + final def make(temp: List[Symbol], row1: List[Row])(implicit theOwner: Symbol): Rep = { + var unchanged: Boolean = true + // equals check: call singleType(NoPrefix, o.symbol) `stpe'. Then we could also return + // `typeRef(definitions.ScalaPackageClass.tpe, definitions.EqualsPatternClass, List(stpe))' + // and force an equality check. However, exhaustivity checking would not work anymore. + // so first, extend exhaustivity check to equalspattern + def sType(o: Tree) = singleType(o.tpe.prefix, o.symbol) + def equalsCheck(o: Tree) =if (o.symbol.isValue) singleType(NoPrefix, o.symbol) else sType(o) + + def classifyPat(opat: Tree, j: Int): Tree = { + val (vs, strippedPat) = strip(opat) match { case (vset, pat) => (vset.toList, pat) } + + (strippedPat: @unchecked) match { + case p @ Alternative(ps) => + DBG("Alternative") ; opat + case typat @ Typed(p, tpt) if strip2(p).isInstanceOf[UnApply]=> + DBG("Typed") + if (temp(j).tpe <:< tpt.tpe) makeBind(vs, p) else opat + + case Ident(nme.WILDCARD) | EmptyTree | _:Literal | _:Typed => + DBG("Ident(_)|EmptyTree") ; opat + case o @ Ident(n) => // n != nme.WILDCARD + DBG("Ident") + val tpe = equalsCheck(o) + val p = Ident(nme.WILDCARD) setType tpe + val q = Typed(p, TypeTree(tpe)) setType tpe + makeBind(vs, q) setType tpe + + case o @ Select(stor,_) => + DBG("Select") + val stpe = equalsCheck(o) + val p = Ident(nme.WILDCARD) setType stpe + makeBind(vs, Typed(p, TypeTree(stpe)) setType stpe) setType stpe + + case UnApply(Apply(TypeApply(sel @ Select(stor, nme.unapplySeq), List(tptArg)),_),ArrayValue(_,xs)::Nil) + if (stor.symbol eq definitions.ListModule) => + DBG("Unapply(...TypeApply...)") + // @pre: is not right-ignoring (no star pattern) + // no exhaustivity check, please + temp(j) setFlag Flags.TRANS_FLAG + val listType = typeRef(mkThisType(definitions.ScalaPackage), definitions.ListClass, List(tptArg.tpe)) + makeBind(vs, normalizedListPattern(xs, tptArg.tpe)) + + // @todo: rewrite, using __UnApply instead of UnApply like so: + // case ua @ __UnApply(_,argtpe,_) => + // val ua = prepat + // val npat = (if (temp(j).tpe <:< argtpe) ua else Typed(ua,TypeTree(argtpe)).setType(argtpe)) + // pats = (makeBind(vs, npat) setType argtpe)::pats + case ua @ UnApply(Apply(fn, _), _) => + DBG("Unapply(Apply())") + val MethodType(List(argtpe, _*), _) = fn.tpe + val npat = if (temp(j).tpe <:< argtpe) ua else Typed(ua, TypeTree(argtpe)).setType(argtpe) + makeBind(vs, npat) setType argtpe + + case o @ Apply(fn, Nil) if !isCaseClass(o.tpe) || /*see t301*/ !Apply_Value.unapply(o).isEmpty => + DBG("Apply !isCaseClass") + val stpe: Type = fn match { + case _ if o.symbol.isModule || o.tpe.termSymbol.isModule => sType(o) + case Select(path, sym) => path.tpe match { + case t @ ThisType(sym) => singleType(t, o.symbol) + // next two cases: e.g. `case Some(p._2)' in scala.collection.jcl.Map + case _ if path.isInstanceOf[Apply] => PseudoType(o) // outer-matching: test/files/pos/t154.scala + case _ => singleType(sType(path), o.symbol) // old + } + case o: Ident => equalsCheck(o) + } + val ttst = typeRef(NoPrefix, definitions.EqualsPatternClass, List(stpe)) + val p = Ident(nme.WILDCARD) setType ttst + makeBind(vs, Typed(p, TypeTree(stpe)) setType ttst) + + case Apply_Value(pre, sym) => + DBG("Apply_Value") + val tpe = typeRef(NoPrefix, definitions.EqualsPatternClass, List(singleType(pre, sym))) + makeBind(vs, Typed(EmptyTree, TypeTree(tpe)) setType tpe) + + case Apply_CaseClass_NoArgs(tpe) => // no-args case class pattern + DBG("Apply_CaseClass_NoArgs") + makeBind(vs, Typed(EmptyTree, TypeTree(tpe)) setType tpe) + + case Apply_CaseClass_WithArgs() => // case class pattern with args + DBG("Apply_CaseClass_WithArgs") ; opat + case ArrayValue(_,_) => + DBG("ArrayValue") ; opat + } + } - /** the injection here handles alternatives and unapply type tests */ - final def make(temp:List[Symbol], row1:List[Row])(implicit theOwner: Symbol): Rep = { - var unchanged: Boolean = true - val row = row1 flatMap { - xx => + val row = row1 flatMap { xx => def isAlternative(p: Tree): Boolean = p match { case Bind(_,p) => isAlternative(p) case Alternative(ps) => true case _ => false } - def getAlternativeBranches(p:Tree): List[Tree] = { + def getAlternativeBranches(p: Tree): List[Tree] = { def get_BIND(pctx:Tree => Tree, p:Tree):List[Tree] = p match { case b @ Bind(n,p) => get_BIND({ x:Tree => pctx(copy.Bind(b, n, x) setType x.tpe) }, p) case Alternative(ps) => ps map pctx @@ -870,224 +888,95 @@ trait ParallelMatching { get_BIND(x => x, p) } val Row(opats, subst, g, bx) = xx - var pats:List[Tree] = Nil - var indexOfAlternative = -1 - for((opat, j) <- opats.zipWithIndex){ - val (vars, strippedPat) = strip(opat) - val vs = vars.toList - (strippedPat: @unchecked) match { - - case p @ Alternative(ps) => - DBG("Alternative") - if (indexOfAlternative == -1) { - unchanged = false - indexOfAlternative = j - } - pats = opat :: pats - - case typat @ Typed(p,tpt) if strip2(p).isInstanceOf[UnApply]=> - DBG("Typed") - pats = (if (temp(j).tpe <:< tpt.tpe) makeBind(vs, p) else opat)::pats - - case Ident(nme.WILDCARD) | EmptyTree | _:Literal | _:Typed => - DBG("Ident(_)|EmptyTree") - pats = opat :: pats - - case o @ Ident(n) => // n != nme.WILDCARD - DBG("Ident") - val tpe = - if (!o.symbol.isValue) { - singleType(o.tpe.prefix, o.symbol) - } else { - singleType(NoPrefix, o.symbol) // equals-check - // call the above `stpe'. Then we could also return - // `typeRef(definitions.ScalaPackageClass.tpe, definitions.EqualsPatternClass, List(stpe))' - // and force an equality check. However, exhaustivity checking would not work anymore. - // so first, extend exhaustivity check to equalspattern - } - val p = Ident(nme.WILDCARD) setType tpe - val q = Typed(p, TypeTree(tpe)) setType tpe - pats = (makeBind( vs, q) setType tpe) :: pats - - - case o @ Select(stor,_) => - DBG("Select") - val stpe = - if (!o.symbol.isValue) { - singleType(o.tpe.prefix, o.symbol) - } else { - singleType(NoPrefix, o.symbol) // equals-check - } - val p = Ident(nme.WILDCARD) setType stpe - val q = makeBind(vs,Typed(p, TypeTree(stpe)) setType stpe) setType stpe - pats = q::pats - - case UnApply(Apply(TypeApply(sel @ Select(stor, nme.unapplySeq),List(tptArg)),_),ArrayValue(_,xs)::Nil) if (stor.symbol eq definitions.ListModule) => - DBG("Unapply(...TypeApply...)") - //@pre: is not right-ignoring (no star pattern) - // no exhaustivity check, please - temp(j).setFlag(Flags.TRANS_FLAG) - val listType = typeRef(mkThisType(definitions.ScalaPackage), definitions.ListClass, List(tptArg.tpe)) - val nmlzdPat = normalizedListPattern(xs, tptArg.tpe) - pats = makeBind(vs, nmlzdPat) :: pats - - //@todo: rewrite, using __UnApply instead of UnApply like so: - //case ua @ __UnApply(_,argtpe,_) => - //val ua = prepat - // val npat = (if (temp(j).tpe <:< argtpe) ua else Typed(ua,TypeTree(argtpe)).setType(argtpe)) - // pats = (makeBind(vs, npat) setType argtpe)::pats - - - case ua @ UnApply(Apply(fn, _), _) => - DBG("Unapply(Apply())") - fn.tpe match { - case MethodType(List(argtpe,_*),_) => - val npat = (if (temp(j).tpe <:< argtpe) ua else Typed(ua,TypeTree(argtpe)).setType(argtpe)) - pats = (makeBind(vs, npat) setType argtpe)::pats - } + val indexOfAlternative = opats.findIndexOf(isAlternative) + if (indexOfAlternative != -1) unchanged = false + val pats: List[Tree] = opats.zipWithIndex.map { case (opat, j) => classifyPat(opat, j) } - case o @ Apply(fn, List()) if !isCaseClass(o.tpe) || /*see t301*/ !Apply_Value.unapply(o).isEmpty => - DBG("Apply !isCaseClass") - val stpe: Type = fn match { - case _ if (o.symbol.isModule) => - singleType(o.tpe.prefix, o.symbol) - case _ if (o.tpe.termSymbol.isModule) => - singleType(o.tpe.prefix, o.symbol) - case Select(path,sym) => - path.tpe match { - case ThisType(sym) => - singleType(path.tpe, o.symbol) - - case _ => // e.g. `case Some(p._2)' in scala.collection.jcl.Map - if (path.isInstanceOf[Apply]) - new PseudoType(o) // outer-matching, see test/files/pos/t154.scala - else - singleType(singleType(path.tpe.prefix, path.symbol), o.symbol) // old - - } - case o @ Ident(_) => - if (!o.symbol.isValue) - singleType(o.tpe.prefix, o.symbol) - else - singleType(NoPrefix, o.symbol) - } - val ttst = typeRef(NoPrefix, definitions.EqualsPatternClass, List(stpe)) - val p = Ident(nme.WILDCARD) setType ttst - val q = makeBind(vs,Typed(p, TypeTree(stpe)) setType ttst) - pats = q::pats - - case Apply_Value(pre, sym) => - DBG("Apply_Value") - val tpe = typeRef(NoPrefix, definitions.EqualsPatternClass, singleType(pre, sym)::Nil) - val q = makeBind(vs,Typed(EmptyTree, TypeTree(tpe)) setType tpe) - pats = q :: pats - - case Apply_CaseClass_NoArgs(tpe) => // no-args case class pattern - DBG("Apply_CaseClass_NoArgs") - val q = makeBind(vs, Typed(EmptyTree, TypeTree(tpe)) setType tpe) - pats = q :: pats - - case Apply_CaseClass_WithArgs() => // case class pattern with args - DBG("Apply_CaseClass_WithArgs") - pats = opat :: pats - - case ArrayValue(_,xs) => - DBG("ArrayValue") - pats = opat :: pats - - } - } - pats = pats.reverse if (indexOfAlternative == -1) List(xx.replace(pats)) else { - val prefix = pats.take( indexOfAlternative ) - val alts = getAlternativeBranches(pats( indexOfAlternative )) - val suffix = pats.drop(indexOfAlternative + 1) - alts map { p => xx.replace(prefix ::: p :: suffix) } + val (prefix, alts :: suffix) = pats.splitAt(indexOfAlternative) + getAlternativeBranches(alts) map { p => xx.replace(prefix ::: p :: suffix) } } } - if (unchanged) RepImpl(temp,row).init - else this.make(temp,row) // recursive call + if (unchanged) RepImpl(temp, row).init + else this.make(temp, row) // recursive call } } abstract class Rep { - val temp:List[Symbol] - val row:List[Row] - var sealedCols = List[Int]() - var sealedComb = List[Set[Symbol]]() + val temp: List[Symbol] + val row: List[Row] final def init: this.type = { - temp.zipWithIndex.foreach { - case (sym,i) => - if (sym.hasFlag(Flags.MUTABLE) && // indicates that have not yet checked exhaustivity - !sym.hasFlag(Flags.TRANS_FLAG) && // indicates @unchecked - sym.tpe.typeSymbol.hasFlag(Flags.SEALED)) { - - sym.resetFlag(Flags.MUTABLE) - sealedCols = i::sealedCols - // this should enumerate all cases... however, also the superclass is taken if it is not abstract - def candidates(tpesym: Symbol): SymSet = - if (!tpesym.hasFlag(Flags.SEALED)) emptySymbolSet else - tpesym.children.flatMap { x => - val z = candidates(x) - if (x.hasFlag(Flags.ABSTRACT)) z else z + x - } - val cases = candidates(sym.tpe.typeSymbol) - sealedComb = cases::sealedComb - } - } + val setsToCombine: List[(Int, SymSet)] = + for { + (sym, i) <- temp.zipWithIndex + if sym hasFlag Flags.MUTABLE // indicates that have not yet checked exhaustivity + if !(sym hasFlag Flags.TRANS_FLAG) // indicates @unchecked + if sym.tpe.typeSymbol hasFlag Flags.SEALED + } yield { + sym resetFlag Flags.MUTABLE + // this should enumerate all cases... however, also the superclass is taken if it is not abstract + def candidates(tpesym: Symbol): SymSet = { + def countCandidates(x: Symbol) = if (x hasFlag Flags.ABSTRACT) candidates(x) else candidates(x) + x + if (tpesym hasFlag Flags.SEALED) tpesym.children.flatMap(countCandidates) + else emptySymbolSet + } + (i, candidates(sym.tpe.typeSymbol)) + } // .reverse ? XXX + + if (setsToCombine.isEmpty) return this + // computes cartesian product, keeps indices available - def combine(colcom: List[(Int,Set[Symbol])]): List[List[(Int,Symbol)]] = colcom match { + def combine(colcom: List[(Int, Set[Symbol])]): List[List[(Int, Symbol)]] = colcom match { case Nil => Nil case (i,syms)::Nil => syms.toList.map { sym => List((i,sym)) } case (i,syms)::cs => for (s <- syms.toList; rest <- combine(cs)) yield (i,s) :: rest } - if (!sealedCols.isEmpty) { - val allcomb = combine(sealedCols zip sealedComb) - /** returns true if pattern vector pats covers a type symbols "combination" - * @param pats pattern vector - * @param comb pairs of (column index, type symbol) - */ - def covers(pats: List[Tree], comb:List[(Int,Symbol)]) = - comb forall { - case (i,sym) => - val p = strip2(pats(i)); - val res = - isDefaultPattern(p) || p.isInstanceOf[UnApply] || p.isInstanceOf[ArrayValue] || { - val ptpe = patternType_wrtEquals(p.tpe) - val symtpe = if (sym.hasFlag(Flags.MODULE) && (sym.linkedModuleOfClass ne NoSymbol)) { - singleType(sym.tpe.prefix, sym.linkedModuleOfClass) // e.g. None, Nil - } else sym.tpe - (ptpe.typeSymbol == sym) || (symtpe <:< ptpe) || - (symtpe.parents.exists(_.typeSymbol eq ptpe.typeSymbol)) || // e.g. Some[Int] <: Option[&b] - /* outer, see scala.util.parsing.combinator.lexical.Scanner */ - (ptpe.prefix.memberType(sym) <:< ptpe) - } - res - } + val allcomb = combine(setsToCombine) - val coversAll = allcomb forall { combination => row exists { r => (r.guard eq EmptyTree) && covers(r.pat, combination)}} - if (!coversAll) { - val sb = new StringBuilder() - sb.append("match is not exhaustive!\n") - for (open <- allcomb if !(row exists { r => covers(r.pat, open)})) { - sb.append("missing combination ") - val NPAD = 15 - def pad(s:String) = { 1.until(NPAD - s.length).foreach { x => sb.append(" ") }; sb.append(s) } - List.range(0, temp.length) foreach { - i => open.find { case (j,sym) => j==i } match { - case None => pad("*") - case Some((_,sym)) => pad(sym.name.toString) - } - } - sb.append('\n') - } - cunit.warning(temp.head.pos, sb.toString) + /** returns true if pattern vector pats covers a type symbols "combination" + * @param pats pattern vector + * @param comb pairs of (column index, type symbol) + */ + def covers(pats: List[Tree], comb: List[(Int, Symbol)]) = { + val results = for ((i, sym) <- comb ; val p = strip2(pats(i))) yield p match { + case _ if isDefaultPattern(p) => true + case _: UnApply | _: ArrayValue => true + case _ => + val ptpe = patternType_wrtEquals(p.tpe) + val symtpe = + if ((sym hasFlag Flags.MODULE) && (sym.linkedModuleOfClass ne NoSymbol)) + singleType(sym.tpe.prefix, sym.linkedModuleOfClass) // e.g. None, Nil + else sym.tpe + + (ptpe.typeSymbol == sym) || + (symtpe <:< ptpe) || + (symtpe.parents.exists(_.typeSymbol eq ptpe.typeSymbol)) || // e.g. Some[Int] <: Option[&b] + (ptpe.prefix.memberType(sym) <:< ptpe) // outer, see combinator.lexical.Scanner } + results.forall(_ == true) } + + def comboCovers(combo: List[(Int, Symbol)]) = row exists { r => (r.guard eq EmptyTree) && covers(r.pat, combo) } + + if (!(allcomb forall comboCovers)) { + def mkMissingStr(xs: List[(Int, Symbol)], i: Int) = xs.find(_._1 == i) match { + case None => pad("*") + case Some(pair) => pad(pair._2.name.toString) + } + + val missingCombos = + (for (open <- allcomb ; if row.forall(r => !covers(r.pat, open))) yield + "missing combination " + + (for (i <- 0 until temp.length) yield + mkMissingStr(open, i)).mkString + "\n").mkString + + cunit.warning(temp.head.pos, "match is not exhaustive!\n" + missingCombos) + } + return this } @@ -1098,23 +987,20 @@ trait ParallelMatching { final def applyRule(implicit theOwner: Symbol, rep: RepFactory): RuleApplication = row match { case Nil => ErrorRule() - case Row(pats, subst, g, bx)::xs => - var px = 0; var rpats = pats; var bnd = subst; var temps = temp; while(rpats ne Nil){ - val (vs,p) = strip(rpats.head); - if (!isDefaultPattern(p)) { + case Row(pats, subst, g, bx) :: xs => + var bnd = subst + for (((rpat, t), px) <- pats.zip(temp).zipWithIndex) { + val (vs, p) = strip(rpat) + if (isDefaultPattern(p)) bnd = bnd.add(vs, t) + else { // Row( _ ... _ p_1i ... p_1n g_m b_m ) :: rows // cut out column px that contains the non-default pattern - val column = rpats.head :: (row.tail map { case Row(pats,_,_,_) => pats(px) }) - val restTemp = temp.take(px) ::: temp.drop(px+1) - val restRows = row map { r => r.replace(r.pat.take(px) ::: r.pat.drop(px+1)) } - val mr = MixtureRule(temps.head, column, rep.make(restTemp, restRows)) - DBG("\n---\nmixture rule is = " + mr.getClass.toString) + val column = rpat :: row.tail.map(_.pat(px)) + val restTemp = temp.dropIndex(px) + val restRows = row.map(r => r.replace(r.pat.dropIndex(px))) + val mr = MixtureRule(t, column, rep.make(restTemp, restRows)) + DBG("\n---\nmixture rule is = " + mr.getClass) return mr - } else { - bnd = bnd.add(vs,temps.head) - rpats = rpats.tail - temps = temps.tail - px += 1 // pattern index } } //Row( _ ... _ g_1 b_1 ) :: rows it's all default patterns @@ -1125,49 +1011,38 @@ trait ParallelMatching { // a fancy toString method for debugging override final def toString = { - val sb = new StringBuilder - val NPAD = 15 - def pad(s:String) = { 1.until(NPAD - s.length).foreach { x => sb.append(" ") }; sb.append(s) } - for (tmp <- temp) pad(tmp.name.toString) - sb.append('\n') - for ((r,i) <- row.zipWithIndex) { - for (c <- r.pat ::: List(r.subst, r.guard, r.bx)) { - pad(c.toString) - } - sb.append('\n') - } - sb.toString + val tempStr = temp.map(t => pad(t.name)).mkString + "\n" + val rowStr = row.map(r => (r.pat ::: List(r.subst, r.guard, r.bx)).map(pad).mkString + "\n").mkString + tempStr + rowStr } + + private val NPAD = 15 + private def pad(s: Any): String = pad(s.toString) + private def pad(s: String): String = List.make(NPAD - s.length - 1, " ").mkString + s } /** creates initial clause matrix */ - final def initRep(roots: List[Symbol], cases: List[Tree], rep:RepFactory)(implicit theOwner: Symbol) = { + final def initRep(roots: List[Symbol], cases: List[Tree], rep: RepFactory)(implicit theOwner: Symbol) = { // communicate whether exhaustiveness-checking is enabled via some flag - var bx = 0; - val targets = new ListBuffer[Tree] - val vss = new ListBuffer[SymList] - val row = new ListBuffer[Row] - - var cs = cases; while (cs ne Nil) cs.head match { // stash away pvars and bodies for later - case CaseDef(pat,g,b) => - vss += definedVars(pat) - targets += b - if (roots.length > 1) pat match { - case Apply(fn, pargs) => - row += Row(pargs, NoBinding, g, bx) - case Ident(nme.WILDCARD) => - row += Row(getDummies(roots.length), NoBinding, g, bx) - } else - row += Row(List(pat), NoBinding, g, bx) - bx += 1 - cs = cs.tail - } - rep.make(roots, row.toList, targets.toList, vss.toList) + val (rows, targets, vss): (List[Option[Row]], List[Tree], List[SymList]) = unzip3( + for ((CaseDef(pat, g, b), bx) <- cases.zipWithIndex) yield { // stash away pvars and bodies for later + def rowForPat: Option[Row] = pat match { + case _ if roots.length <= 1 => Some(Row(List(pat), NoBinding, g, bx)) + case Apply(fn, pargs) => Some(Row(pargs, NoBinding, g, bx)) + case Ident(nme.WILDCARD) => Some(Row(getDummies(roots.length), NoBinding, g, bx)) + case _ => None + } + (rowForPat, b, definedVars(pat)) + } + ) + + // flatMap the list of options yields the list of values + rep.make(roots, rows.flatMap(x => x), targets, vss) } final def newVar(pos: Position, name: Name, tpe: Type)(implicit theOwner: Symbol): Symbol = { - if (tpe eq null) assert(tpe ne null, "newVar("+name+", null)") + assert(tpe ne null, "newVar("+name+", null)") val sym = theOwner.newVariable(pos, name) // careful: pos has special meaning sym setInfo tpe sym @@ -1184,31 +1059,23 @@ trait ParallelMatching { } final def condition(tpe: Type, scrutineeTree: Tree)(implicit typer : Typer): Tree = { - assert(tpe ne NoType) - assert(scrutineeTree.tpe ne NoType) - if (tpe.isInstanceOf[SingletonType] && !tpe.isInstanceOf[ConstantType]) { - if (tpe.termSymbol.isModule) {// object - //if (scrutineeTree.tpe <:< definitions.AnyRefClass.tpe) - // Eq(gen.mkAttributedRef(tpe.termSymbol), scrutineeTree) // object - //else - Equals(gen.mkAttributedRef(tpe.termSymbol), scrutineeTree) // object - } else { - val x = - if (tpe.prefix ne NoPrefix) gen.mkIsInstanceOf(scrutineeTree, tpe) - else - Equals(gen.mkAttributedRef(tpe.termSymbol), scrutineeTree) - typer.typed { x } - } - } else if (tpe.isInstanceOf[ConstantType]) { - val value = tpe.asInstanceOf[ConstantType].value - if (value == Constant(null) && scrutineeTree.tpe <:< definitions.AnyRefClass.tpe) - Eq(scrutineeTree, Literal(value)) // constant - else - Equals(scrutineeTree, Literal(value)) // constant - } else if (scrutineeTree.tpe <:< tpe && tpe <:< definitions.AnyRefClass.tpe) { - NotNull(scrutineeTree) - } else { - gen.mkIsInstanceOf(scrutineeTree, tpe) + assert((tpe ne NoType) && (scrutineeTree.tpe ne NoType)) + + tpe match { + case _: SingletonType if !tpe.isInstanceOf[ConstantType] => + lazy val equalsRef = Equals(gen.mkAttributedRef(tpe.termSymbol), scrutineeTree) + if (tpe.termSymbol.isModule) equalsRef // object + else if (tpe.prefix ne NoPrefix) typer.typed(gen.mkIsInstanceOf(scrutineeTree, tpe)) + else typer.typed(equalsRef) + + case ct: ConstantType => ct.value match { // constant + case v @ Constant(null) if scrutineeTree.tpe <:< definitions.AnyRefClass.tpe => Eq(scrutineeTree, Literal(v)) + case v => Equals(scrutineeTree, Literal(v)) + } + case _ if scrutineeTree.tpe <:< tpe && tpe <:< definitions.AnyRefClass.tpe => + NotNull(scrutineeTree) + case _ => + gen.mkIsInstanceOf(scrutineeTree, tpe) } } @@ -1225,12 +1092,12 @@ trait ParallelMatching { val outerAcc = outerAccessor(tpe2test.typeSymbol) if (outerAcc == NoSymbol) { - if (settings_debug) cunit.warning(scrutinee.pos, "no outer acc for "+tpe2test.typeSymbol) + if (settings.debug.value) cunit.warning(scrutinee.pos, "no outer acc for "+tpe2test.typeSymbol) cond } else And(cond, Eq(Apply(Select( - gen.mkAsInstanceOf(scrutinee, tpe2test, true), outerAcc),List()), theRef)) + gen.mkAsInstanceOf(scrutinee, tpe2test, true), outerAcc), Nil), theRef)) } } diff --git a/src/compiler/scala/tools/nsc/matching/PatternNodes.scala b/src/compiler/scala/tools/nsc/matching/PatternNodes.scala index 191461398f..6ab57f4b5b 100644 --- a/src/compiler/scala/tools/nsc/matching/PatternNodes.scala +++ b/src/compiler/scala/tools/nsc/matching/PatternNodes.scala @@ -12,21 +12,14 @@ import scala.tools.nsc.util.{Position, NoPosition} * @author Burak Emir */ trait PatternNodes { self: transform.ExplicitOuter => - import global._ + import symtab.Flags + final def DBG(x: => String) = if (settings.debug.value) Console.println(x) - private val dummies = new Array[List[Tree]](8); - dummies(0) = Nil; - for (i <- 1 until dummies.length){ - dummies(i) = EmptyTree :: dummies(i - 1); - } - - final def getDummies(i:Int): List[Tree] = - if (i < dummies.length) dummies(i); - else EmptyTree::getDummies(i-1) + final def getDummies(i: Int): List[Tree] = List.make(i, EmptyTree) def makeBind(vs:SymList, pat:Tree): Tree = - if(vs eq Nil) pat else Bind(vs.head, makeBind(vs.tail, pat)) setType pat.tpe + if (vs eq Nil) pat else Bind(vs.head, makeBind(vs.tail, pat)) setType pat.tpe def normalizedListPattern(pats:List[Tree], tptArg:Type): Tree = pats match { case Nil => gen.mkAttributedRef(definitions.NilModule) @@ -67,7 +60,7 @@ trait PatternNodes { self: transform.ExplicitOuter => /* equality checks for named constant patterns like "Foo()" are encoded as "_:[Foo().type]" * and later compiled to "if(Foo() == scrutinee) ...". This method extracts type information from * such an encoded type, which is used in optimization. If the argument is not an encoded equals - * test, it is returned as is. + * test, it is returned as is. */ def patternType_wrtEquals(pattpe:Type) = pattpe match { case TypeRef(_,sym,arg::Nil) if sym eq definitions.EqualsPatternClass => @@ -76,7 +69,7 @@ trait PatternNodes { self: transform.ExplicitOuter => } /** returns if pattern can be considered a no-op test ??for expected type?? */ - final def isDefaultPattern(pattern:Tree): Boolean = pattern match { + final def isDefaultPattern(pattern: Tree): Boolean = pattern match { case Bind(_, p) => isDefaultPattern(p) case EmptyTree => true // dummy case Ident(nme.WILDCARD) => true @@ -85,15 +78,13 @@ trait PatternNodes { self: transform.ExplicitOuter => // case Typed(nme.WILDCARD,_) => pattern.tpe <:< scrutinee.tpe } - final def DBG(x : =>String) { if (settings_debug) Console.println(x) } - /** returns all variables that are binding the given pattern * @param x a pattern * @return vs variables bound, p pattern proper */ final def strip(x: Tree): (Set[Symbol], Tree) = x match { case b @ Bind(_,pat) => val (vs, p) = strip(pat); (vs + b.symbol, p) - case z => (emptySymbolSet,z) + case z => (emptySymbolSet, z) } final def strip1(x: Tree): Set[Symbol] = x match { // same as strip(x)._1 @@ -104,46 +95,39 @@ trait PatternNodes { self: transform.ExplicitOuter => case Bind(_,pat) => strip2(pat) case z => z } + object StrippedPat { + def unapply(x: Tree): Option[Tree] = Some(strip2(x)) + } - final def isCaseClass(tpe: Type): Boolean = - tpe match { - case TypeRef(_, sym, _) => - if(!sym.isAliasType) - sym.hasFlag(symtab.Flags.CASE) - else - tpe.normalize.typeSymbol.hasFlag(symtab.Flags.CASE) - case _ => false - } + final def isCaseClass(tpe: Type): Boolean = tpe match { + case TypeRef(_, sym, _) => + if (sym.isAliasType) tpe.normalize.typeSymbol hasFlag Flags.CASE + else sym hasFlag Flags.CASE + case _ => false + } - final def isEqualsPattern(tpe: Type): Boolean = - tpe match { - case TypeRef(_, sym, _) => sym eq definitions.EqualsPatternClass - case _ => false - } + final def isEqualsPattern(tpe: Type): Boolean = tpe match { + case TypeRef(_, sym, _) => sym eq definitions.EqualsPatternClass + case _ => false + } // this method obtains tag method in a defensive way final def getCaseTag(x:Type): Int = { x.typeSymbol.tag } - final def definedVars(x:Tree): SymList = { - // I commented out the no-op cases, but left them in case the order is somehow significant -- paulp - def definedVars1(x:Tree): SymList = x match { - // case Alternative(bs) => ; // must not have any variables - case Apply(_, args) => definedVars2(args) - case b @ Bind(_,p) => b.symbol :: definedVars1(p) - // case Ident(_) => ; - // case Literal(_) => ; - // case Select(_,_) => ; - case Typed(p,_) => definedVars1(p) //otherwise x @ (_:T) - case UnApply(_,args) => definedVars2(args) - // regexp specific - case ArrayValue(_,xs)=> definedVars2(xs) - // case Star(p) => ; // must not have variables + final def definedVars(x: Tree): SymList = { + implicit def listToStream[T](xs: List[T]): Stream[T] = xs.toStream + def definedVars1(x: Tree): Stream[Symbol] = x match { + case Apply(_, args) => definedVars2(args) + case b @ Bind(_,p) => Stream.cons(b.symbol, definedVars1(p)) + case Typed(p,_) => definedVars1(p) // otherwise x @ (_:T) + case UnApply(_,args) => definedVars2(args) + case ArrayValue(_,xs) => definedVars2(xs) case _ => Nil } - def definedVars2(args: List[Tree]): SymList = args.flatMap(definedVars1) + def definedVars2(args: Stream[Tree]): Stream[Symbol] = args flatMap definedVars1 - definedVars1(x) + definedVars1(x).reverse.toList } /** pvar: the symbol of the pattern variable @@ -161,17 +145,17 @@ trait PatternNodes { self: transform.ExplicitOuter => case Binding(pv2,tmp2,next2) => (pvar eq pv2) && (temp eq tmp2) && (next==next2) } } - def apply(v:Symbol): Ident = - if(v eq pvar) {Ident(temp).setType(v.tpe)} else next(v) + def apply(v:Symbol): Ident = { + if (v eq pvar) Ident(temp).setType(v.tpe) else next(v) + } } - object NoBinding extends Binding(null,null,null) { + + object NoBinding extends Binding(null, null, null) { override def apply(v:Symbol) = null // not found, means bound elsewhere (x @ unapply-call) override def toString = "." override def equals(x:Any) = x.isInstanceOf[Binding] && (x.asInstanceOf[Binding] eq this) } - // misc methods END --- - type SymSet = collection.immutable.Set[Symbol] type SymList = List[Symbol] diff --git a/src/compiler/scala/tools/nsc/matching/TransMatcher.scala b/src/compiler/scala/tools/nsc/matching/TransMatcher.scala index 924135c5ad..c71874fe8e 100644 --- a/src/compiler/scala/tools/nsc/matching/TransMatcher.scala +++ b/src/compiler/scala/tools/nsc/matching/TransMatcher.scala @@ -13,17 +13,12 @@ package scala.tools.nsc.matching */ trait TransMatcher { self: transform.ExplicitOuter with PatternNodes with ParallelMatching with CodeFactory => - import global.{typer => _, _} + import global.{ typer => _, _ } import analyzer.Typer; import definitions._ - import posAssigner.atPos import symtab.Flags - import collection.mutable.ListBuffer - - var cunit: CompilationUnit = _ // memory leak? - def fresh = cunit.fresh - var nPatterns = 0 + var cunit: CompilationUnit = _ // memory leak? var resultType: Type = _ // cache these @@ -45,58 +40,53 @@ trait TransMatcher { self: transform.ExplicitOuter with PatternNodes with Parall /** handles all translation of pattern matching */ def handlePattern(selector: Tree, cases: List[CaseDef], doCheckExhaustive: Boolean, owner: Symbol, handleOuter: Tree => Tree)(implicit typer : Typer): Tree = { + DBG("****") + DBG("**** initalize, selector = "+selector+" selector.tpe = "+selector.tpe) + DBG("**** doCheckExhaustive == "+doCheckExhaustive) + implicit val theOwner = owner - if (settings_debug) { - Console.println("****") - Console.println("**** initalize, selector = "+selector+" selector.tpe = "+selector.tpe) - Console.println("**** doCheckExhaustive == "+doCheckExhaustive) + implicit val rep = new RepFactory(handleOuter) + + def caseIsOk(c: CaseDef) = c match { + case CaseDef(_: Apply, _, _) => true + case CaseDef(Ident(nme.WILDCARD), _, _) => true + case _ => false } + def doApply(fn: Tree) = (fn.symbol eq selector.tpe.decls.lookup(nme.CONSTRUCTOR)) && (cases forall caseIsOk) - implicit val rep = new RepFactory(handleOuter) - val tmps = new ListBuffer[Symbol] - val vds = new ListBuffer[Tree] - var root:Symbol = newVar(selector.pos, selector.tpe) - if (!doCheckExhaustive) - root.setFlag(Flags.TRANS_FLAG) - - var vdef:Tree = typer.typed{ValDef(root, selector)} - var theFailTree:Tree = ThrowMatchError(selector.pos, mkIdent(root)) - - if (definitions.isTupleType(selector.tpe)) selector match { - case app @ Apply(fn, args) - if (fn.symbol eq selector.tpe.decls.lookup(nme.CONSTRUCTOR)) && - (cases forall { x => x match { - case CaseDef(Apply(fn, pargs),_,_) => true ; - case CaseDef(Ident(nme.WILDCARD),_,_) => true ; - case _ => false - }}) => - for ((ti, i) <- args.zipWithIndex){ + def processApply(app: Apply): (List[Symbol], List[Tree], Tree) = { + val Apply(fn, args) = app + val (tmps, vds) = List.unzip( + for ((ti, i) <- args.zipWithIndex) yield { val v = newVar(ti.pos, cunit.fresh.newName(ti.pos, "tp"), selector.tpe.typeArgs(i)) - if (!doCheckExhaustive) - v.setFlag(Flags.TRANS_FLAG) - vds += typedValDef(v, ti) - tmps += v + if (!doCheckExhaustive) v setFlag Flags.TRANS_FLAG + (v, typedValDef(v, ti)) } - theFailTree = ThrowMatchError(selector.pos, copy.Apply(app, fn, tmps.toList map mkIdent)) + ) + (tmps, vds, ThrowMatchError(selector.pos, copy.Apply(app, fn, tmps map mkIdent))) + } + + // sets temporaries, variable declarations, and the fail tree + val (tmps, vds, theFailTree) = selector match { + case app @ Apply(fn, _) if isTupleType(selector.tpe) && doApply(fn) => processApply(app) case _ => - tmps += root - vds += vdef - } else { - tmps += root - vds += vdef + val root: Symbol = newVar(selector.pos, selector.tpe) + if (!doCheckExhaustive) + root setFlag Flags.TRANS_FLAG + val vdef: Tree = typer.typed(ValDef(root, selector)) + val failTree: Tree = ThrowMatchError(selector.pos, mkIdent(root)) + (List(root), List(vdef), failTree) } - val irep = initRep(tmps.toList, cases, rep) implicit val fail: Tree = theFailTree + val irep = initRep(tmps, cases, rep) + val mch = typer.typed(repToTree(irep)) + var dfatree = typer.typed(Block(vds, mch)) - val mch = typer.typed{ repToTree(irep)} - var dfatree = typer.typed{Block(vds.toList, mch)} // cannot use squeezedBlock because of side-effects, see t275 - for ((cs, bx) <- cases.zipWithIndex){ - if (!rep.isReached(bx)) { - cunit.error(cs.body.pos, "unreachable code") - } - } + for ((cs, bx) <- cases.zipWithIndex) + if (!rep.isReached(bx)) cunit.error(cs.body.pos, "unreachable code") + dfatree = rep.cleanup(dfatree) resetTrav.traverse(dfatree) dfatree -- cgit v1.2.3