diff options
author | Burak Emir <emir@epfl.ch> | 2007-07-07 20:44:54 +0000 |
---|---|---|
committer | Burak Emir <emir@epfl.ch> | 2007-07-07 20:44:54 +0000 |
commit | b98eb1d74141a4159539d373e6216e799d6b6dcd (patch) | |
tree | a81ca918045150fae01cb5e8c2e49af4c7178a90 | |
parent | f2e4362a5a3548e018c116d6588ddc97cdf4a16c (diff) | |
download | scala-b98eb1d74141a4159539d373e6216e799d6b6dcd.tar.gz scala-b98eb1d74141a4159539d373e6216e799d6b6dcd.tar.bz2 scala-b98eb1d74141a4159539d373e6216e799d6b6dcd.zip |
optimizing case class matches via tags (experim...
optimizing case class matches via tags (experimental)
-rw-r--r-- | src/compiler/scala/tools/nsc/matching/ParallelMatching.scala | 310 |
1 files changed, 225 insertions, 85 deletions
diff --git a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala index 6f16ea8d14..e671e60adc 100644 --- a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala +++ b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala @@ -65,15 +65,15 @@ trait ParallelMatching { case a @ Apply(fn,_) => ((a.tpe.symbol.flags & symtab.Flags.CASE) != 0) && directSubtype( a.tpe ) && isFlatCases(col.tail) case t @ Typed(_,tpt) => - ( (tpt.symbol.flags & symtab.Flags.CASE) != 0) && directSubtype( t.tpe ) && isFlatCases(col.tail) + ( (tpt.tpe.symbol.flags & symtab.Flags.CASE) != 0) && directSubtype( t.tpe ) && isFlatCases(col.tail) case Ident(nme.WILDCARD) => - true // treat col.tail specially? + isFlatCases(col.tail) // treat col.tail specially? case i @ Ident(n) => // n ne nme.WILDCARD ( (i.symbol.flags & symtab.Flags.CASE) != 0) && directSubtype( i.tpe ) && isFlatCases(col.tail) case s @ Select(_,_) => // i.e. scala.Nil ( (s.symbol.flags & symtab.Flags.CASE) != 0) && directSubtype( s.tpe ) && isFlatCases(col.tail) case p => - Console.println(p.getClass) + //Console.println(p.getClass) false } } @@ -87,14 +87,20 @@ trait ParallelMatching { if(settings_debug) { Console.println("MixLiteral") } return new MixLiterals(scrutinee, column, rest) } - /* - if(isFlatCases(column)) { - Console.println("flat cases!"+column) - Console.println(scrutinee.tpe.symbol.children) - Console.println(scrutinee.tpe.member(nme.tag)) - Console.println(column.map { x => x.tpe.symbol.tag }) + if((column.length > 1) && isFlatCases(column)) { + if(settings_debug) { + Console.println("flat cases!"+column) + Console.println(scrutinee.tpe.symbol.children) + Console.println(scrutinee.tpe.member(nme.tag)) + Console.println(column.map { x => x.tpe.symbol.tag }) + } + return new MixCases(scrutinee, column, rest) + /*if(scrutinee.tpe.symbol.hasFlag(symtab.Flags.SEALED)) + new MixCasesSealed(scrutinee, column, rest) + else + new MixCases(scrutinee, column, rest) */ + } - */ //Console.println("isUnapplyHead") if(isUnapplyHead) { if(settings_debug) { Console.println("MixUnapply") } @@ -105,18 +111,21 @@ trait ParallelMatching { return new MixTypes(scrutinee, column, rest) // todo: handle type tests in unapply } - /** - * mixture rule for literals - * - * this rule gets translated to a switch. all defaults/same literals are collected - **/ - class MixLiterals(val scrutinee:Symbol, val column:List[Tree], val rest:Rep) extends RuleApplication { + + abstract class CaseRuleApplication extends RuleApplication { + + def rest:Rep + + class TagIndexPair(val tag:Int, val index: Int, val next: TagIndexPair) { + def find(tag:Int): Int = if(this.tag == tag) index else next.find(tag) // assumes argument can always be found + } + // e.g. (1,1) (1,3) (42,2) for column {case ..1.. => ;; case ..42..=> ;; case ..1.. => } var defaults: List[Int] = Nil var defaultV: List[Symbol] = Nil // sorted e.g. case _ => 1,5,7 - private def insertDefault(tag: Int,vs:List[Symbol]) { + protected def insertDefault(tag: Int,vs:List[Symbol]) { def ins(xs:List[Int]):List[Int] = xs match { case y::ys if y > tag => y::ins(ys) case ys => tag :: ys @@ -125,10 +134,44 @@ trait ParallelMatching { defaults = ins(defaults) } - class TagIndexPair(val tag:Int, val index: Int, val next: TagIndexPair) {} - // e.g. (1,1) (1,3) (42,2) for column {case ..1.. => ;; case ..42..=> ;; case ..1.. => } + var tagIndexPairs: TagIndexPair = null + + protected def grabTemps: List[Symbol] + protected def grabRow(index:Int): (List[Tree], List[(Symbol,Symbol)], Tree, Tree) + + /** inserts row indices using in to list of tagindexpairs*/ + protected def tagIndicesToReps = { + var trs:List[(Int,Rep)] = Nil + var old = tagIndexPairs + while(tagIndexPairs ne null) { // collect all with same tag + val tag = tagIndexPairs.tag + var tagRows = this.grabRow(tagIndexPairs.index)::Nil + tagIndexPairs = tagIndexPairs.next + while((tagIndexPairs ne null) && tagIndexPairs.tag == tag) { + tagRows = this.grabRow(tagIndexPairs.index)::tagRows + tagIndexPairs = tagIndexPairs.next + } + var tagRowsD:List[(List[Tree], List[(Symbol,Symbol)], Tree, Tree)] = Nil + var ds = defaults; while(ds ne Nil) { + tagRowsD = rest.row(ds.head) :: tagRowsD + ds = ds.tail + } + trs = (tag, Rep(this.grabTemps, tagRows ::: tagRowsD)) :: trs + } + tagIndexPairs = old + trs + } + + protected def defaultsToRep = { + var drows:List[(List[Tree], List[(Symbol,Symbol)], Tree, Tree)] = Nil + while(defaults ne Nil) { + drows = rest.row(defaults.head) :: drows + defaults = defaults.tail + } + Rep(rest.temp, drows) + } - private def insertTagIndexPair(tag: Int, index: Int) { + protected def insertTagIndexPair(tag: Int, index: Int) { def ins(current: TagIndexPair): TagIndexPair = { if (current eq null) new TagIndexPair(tag, index, null) @@ -140,7 +183,70 @@ trait ParallelMatching { tagIndexPairs = ins(tagIndexPairs) } - var tagIndexPairs: TagIndexPair = null + /** returns + * @return list of continuations, + * @return variables bound to default continuation, + * @return optionally, a default continuation, + **/ + def getTransition(implicit theOwner: Symbol): (List[(Int,Rep)],List[Symbol],Option[Rep]) = + (tagIndicesToReps, defaultV, {if(!defaults.isEmpty) Some(defaultsToRep) else None}) + } + + /** + * mixture rule for flat case class (using tags) + * + * this rule gets translated to a switch of _.$tag() + **/ + class MixCases(val scrutinee:Symbol, val column:List[Tree], val rest:Rep) extends CaseRuleApplication { + + /** insert row indices using in to list of tagindexpairs */ + { + var xs = column + var i = 0 + while(!xs.isEmpty) { // forall + val (pvars,p) = strip(xs.head) + if(isDefaultPattern(p)) + insertDefault(i,pvars) + else + insertTagIndexPair(p.tpe.symbol.tag, i) + i = i + 1 + xs = xs.tail + } + } + + def grabTemps = scrutinee::rest.temp + def grabRow(index:Int) = rest.row(tagIndexPairs.index) match { + case (pats,s,g,b) => (column(tagIndexPairs.index)::pats,s,g,b) + } + + /*{ + var trs = tagIndexPairs + while(trs ne null) { + Console.println((trs.tag,trs.index)) + trs = trs.next + } + }*/ + //System.exit(-1) + + + } + + /** + * mixture rule for flat case class (using tags) + * + * this rule gets translated to a switch of _.$tag() + class MixCasesSealed(scrut:Symbol, col:List[Tree], res:Rep) extends MixCases(scrut, col, res) { + override def grabTemps = rest.temp + override def grabRow(index:Int) = rest.row(tagIndexPairs.index) + } + **/ + + /** + * mixture rule for literals + * + * this rule gets translated to a switch. all defaults/same literals are collected + **/ + class MixLiterals(val scrutinee:Symbol, val column:List[Tree], val rest:Rep) extends CaseRuleApplication { private def sanity(pos:Position,pvars:List[Symbol]) { if(!pvars.isEmpty) cunit.error(pos, "nonsensical variable binding") @@ -148,7 +254,7 @@ trait ParallelMatching { { var xs = column - var i = 0; + var i = 0; while(!xs.isEmpty) { // forall val (pvars,p) = strip(xs.head) p match { @@ -160,36 +266,10 @@ trait ParallelMatching { xs = xs.tail } } - def tagIndicesToReps = { - var trs:List[(Int,Rep)] = Nil - while(tagIndexPairs ne null) { // collect all with same tag - val tag = tagIndexPairs.tag - var tagRows = rest.row(tagIndexPairs.index)::Nil - tagIndexPairs = tagIndexPairs.next - while((tagIndexPairs ne null) && tagIndexPairs.tag == tag) { - tagRows = rest.row(tagIndexPairs.index)::tagRows - tagIndexPairs = tagIndexPairs.next - } - var tagRowsD:List[(List[Tree], List[(Symbol,Symbol)], Tree, Tree)] = Nil - var ds = defaults; while(ds ne Nil) { - tagRowsD = rest.row(ds.head) :: tagRowsD - ds = ds.tail - } - trs = (tag, Rep(rest.temp, tagRows ::: tagRowsD)) :: trs - } - trs - } - def defaultsToRep = { - var drows:List[(List[Tree], List[(Symbol,Symbol)], Tree, Tree)] = Nil - while(defaults ne Nil) { - drows = rest.row(defaults.head) :: drows - defaults = defaults.tail - } - Rep(rest.temp, drows) - } - def getTransition(implicit theOwner: Symbol): (List[(Int,Rep)],List[Symbol],Option[Rep]) = - (tagIndicesToReps, defaultV, {if(!defaults.isEmpty) Some(defaultsToRep) else None}) + def grabTemps = rest.temp + def grabRow(index:Int) = rest.row(tagIndexPairs.index) + } /** @@ -484,44 +564,102 @@ trait ParallelMatching { } // end getTransitions } + + def genBody(subst: List[Pair[Symbol,Symbol]], b:Tree)(implicit theOwner: Symbol, bodies: collection.mutable.Map[Tree,(Tree, Tree, Symbol)]): Tree = { + bodies.get(b) match { + + case Some(EmptyTree, nb, theLabel) => //Console.println("H E L L O"+subst+" "+b) + if(b.isInstanceOf[Literal]) + return b + // recover the order of the idents that is expected for the labeldef + val args = nb match { case Block(_, LabelDef(_, origs, _)) => + origs.map { p => Ident(subst.find { q => q._1 == p.symbol }.get._2) } // wrong! + } // using this instead would not work: subst.map { p => Ident(p._2) } + // the order can be garbled, when default patterns are used... #bug 1163 + + val body = Apply(Ident(theLabel), args) + return body + + case None => + if(b.isInstanceOf[Literal]) { + bodies(b) = (EmptyTree, b, null) // unreachable code is detected via missing hash entries + return b + } + // this seems weird, but is necessary for sharing bodies. unnecessary for bodies that are not shared + var argtpes = subst map { case (v,_) => v.tpe } + val theLabel = targetLabel(theOwner, b.pos, "body"+b.hashCode, argtpes, b.tpe) + // make new value parameter for each vsym in subst + val vdefs = targetParams(subst) + + var nbody: Tree = b + val vrefs = vdefs.map { p:ValDef => Ident(p.symbol) } + nbody = squeezedBlock(vdefs:::List(Apply(Ident(theLabel), vrefs)), LabelDef(theLabel, subst.map(_._1), nbody)) + bodies(b) = (EmptyTree, nbody, theLabel) + nbody + } + } /** converts given rep to a tree - performs recursive call to translation in the process to get sub reps */ def repToTree(rep:Rep, typed:Tree => Tree, handleOuter: Tree => Tree)(implicit theOwner: Symbol, failTree: Tree, bodies: collection.mutable.Map[Tree,(Tree, Tree, Symbol)]): Tree = { rep.applyRule match { - case VariableRule(subst, EmptyTree, b) => bodies.get(b) match { + case VariableRule(subst, EmptyTree, b) => + genBody(subst,b) - case Some(EmptyTree, nb, theLabel) => //Console.println("H E L L O"+subst+" "+b) - if(b.isInstanceOf[Literal]) - return b - // recover the order of the idents that is expected for the labeldef - val args = nb match { case Block(_, LabelDef(_, origs, _)) => - origs.map { p => Ident(subst.find { q => q._1 == p.symbol }.get._2) } // wrong! - } // using this instead would not work: subst.map { p => Ident(p._2) } - // the order can be garbled, when default patterns are used... #bug 1163 - - val body = Apply(Ident(theLabel), args) - return body - - case None => - if(b.isInstanceOf[Literal]) { - bodies(b) = (EmptyTree, b, null) // unreachable code is detected via missing hash entries - return b - } - // this seems weird, but is necessary for sharing bodies. unnecessary for bodies that are not shared - var argtpes = subst map { case (v,_) => v.tpe } - val theLabel = targetLabel(theOwner, b.pos, "body"+b.hashCode, argtpes, b.tpe) - // make new value parameter for each vsym in subst - val vdefs = targetParams(subst) - - var nbody: Tree = b - val vrefs = vdefs.map { p:ValDef => Ident(p.symbol) } - nbody = squeezedBlock(vdefs:::List(Apply(Ident(theLabel), vrefs)), LabelDef(theLabel, subst.map(_._1), nbody)) - bodies(b) = (EmptyTree, nbody, theLabel) - nbody - } case VariableRule(subst,g,b) => throw CantHandleGuard + + case mc: MixCases => + val (branches, defaultV, default) = mc.getTransition // tag body pairs + if(settings_debug) { + Console.println("[[mix cases transition: branches \n"+(branches.mkString("","\n",""))) + Console.println("defaults:"+defaultV) + Console.println(default) + Console.println("]]") + } + + var ndefault = if(default.isEmpty) failTree else repToTree(default.get, typed, handleOuter) + + var cases = branches map { + case (tag, rep) => + CaseDef(Literal(tag), + EmptyTree, + { + val pat = mc.column(mc.tagIndexPairs.find(tag)); + val ptpe = pat.tpe + if(mc.scrutinee.tpe.symbol.hasFlag(symtab.Flags.SEALED) && strip(pat)._2.isInstanceOf[Apply]) { + //cast + val vtmp = newVar(pat.pos, ptpe) + squeezedBlock( + List(ValDef(vtmp, gen.mkAsInstanceOf(Ident(mc.scrutinee),ptpe))), + repToTree(Rep(vtmp::rep.temp.tail, rep.row),typed,handleOuter) + ) + } else repToTree(rep, typed, handleOuter) + } + )} + if(!defaultV.isEmpty) { + var to:List[Symbol] = Nil + var i = 0; while(i < defaultV.length) { i = i + 1; to = mc.scrutinee::to } + val tss = new TreeSymSubstituter(defaultV, to) + tss.traverse( ndefault ) + } + + // make first case a default case. + if(mc.scrutinee.tpe.symbol.hasFlag(symtab.Flags.SEALED) && defaultV.isEmpty) { + ndefault = cases.head.body + cases = cases.tail + } + + if(cases.length == 0) { + genBody(Nil, ndefault) + } else if(cases.length == 1) { + val CaseDef(lit,_,body) = cases.head + makeIf(Equals(Select(Ident(mc.scrutinee),nme.tag),lit), body, ndefault) // * + } else { + val defCase = CaseDef(mk_(definitions.IntClass.tpe), EmptyTree, ndefault) + Match(Select(Ident(mc.scrutinee),nme.tag), cases ::: defCase :: Nil) // * + } + case ml: MixLiterals => val (branches, defaultV, default) = ml.getTransition // tag body pairs if(settings_debug) { @@ -574,7 +712,9 @@ trait ParallelMatching { typed { ValDef(tmp, typedAccess) } } - vdefs = typed { ValDef(casted, gen.mkAsInstanceOf(typed{Ident(mm.scrutinee)}, casted.tpe))} :: vdefs + if(casted ne mm.scrutinee) { + vdefs = typed { ValDef(casted, gen.mkAsInstanceOf(typed{Ident(mm.scrutinee)}, casted.tpe))} :: vdefs + } typed { makeIf(cond, squeezedBlock(vdefs,succ), fail) } case mu: MixUnapply => @@ -912,9 +1052,9 @@ object Rep { Eq(scrutineeTree, Literal(value)) // constant else Equals(scrutineeTree, Literal(value)) // constant - } else if(scrutineeTree.tpe <:< tpe && tpe <:< definitions.AnyRefClass.tpe) - NotNull(scrutineeTree) - else if(tpe.prefix.symbol.isTerm && tpe.symbol.linkedModuleOfClass != NoSymbol) { // object + } else if(scrutineeTree.tpe <:< tpe && tpe <:< definitions.AnyRefClass.tpe) { + if(scrutineeTree.symbol.hasFlag(symtab.Flags.SYNTHETIC)) Literal(Constant(true)) else NotNull(scrutineeTree) + } else if(tpe.prefix.symbol.isTerm && tpe.symbol.linkedModuleOfClass != NoSymbol) { // object //Console.println("iT"+tpe.prefix.symbol.isTerm) //Console.println("lmoc"+tpe.symbol.linkedModuleOfClass) Eq(gen.mkAttributedRef(tpe.prefix, tpe.symbol.linkedModuleOfClass), scrutineeTree) |