summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/compiler/scala/tools/nsc/ast/TreeBrowsers.scala7
-rw-r--r--src/compiler/scala/tools/nsc/matching/CodeFactory.scala6
-rw-r--r--src/compiler/scala/tools/nsc/matching/ParallelMatching.scala257
-rw-r--r--src/compiler/scala/tools/nsc/matching/PatternMatchers.scala16
4 files changed, 224 insertions, 62 deletions
diff --git a/src/compiler/scala/tools/nsc/ast/TreeBrowsers.scala b/src/compiler/scala/tools/nsc/ast/TreeBrowsers.scala
index 78e05d8e72..d5ee4018cc 100644
--- a/src/compiler/scala/tools/nsc/ast/TreeBrowsers.scala
+++ b/src/compiler/scala/tools/nsc/ast/TreeBrowsers.scala
@@ -376,6 +376,10 @@ abstract class TreeBrowsers {
case Star(t) =>
("Star", EMPTY)
+
+ case UnApply(fn, args) =>
+ ("Unapply", EMPTY)
+
}
/** Return a list of children for the given tree node */
@@ -520,6 +524,9 @@ abstract class TreeBrowsers {
case Star(t) =>
List(t)
+
+ case UnApply(fn,args) =>
+ fn::args
}
/** Return a textual representation of this t's symbol */
diff --git a/src/compiler/scala/tools/nsc/matching/CodeFactory.scala b/src/compiler/scala/tools/nsc/matching/CodeFactory.scala
index 8494b9e4bf..13472a8d19 100644
--- a/src/compiler/scala/tools/nsc/matching/CodeFactory.scala
+++ b/src/compiler/scala/tools/nsc/matching/CodeFactory.scala
@@ -80,6 +80,12 @@ trait CodeFactory {
result
}
+ def makeIf(cond:Tree, thenp:Tree, elsep:Tree) = cond match {
+ case Literal(Constant(true)) => thenp
+ case Literal(Constant(false)) => elsep
+ case _ => If(cond, thenp, elsep)
+ }
+
/** returns code `<seqObj>.elements' */
def newIterator(seqObj: Tree): Tree =
Apply(Select(seqObj, newTermName("elements")), List())
diff --git a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala
index 90a36070f9..ea0d904b9b 100644
--- a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala
+++ b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala
@@ -35,22 +35,36 @@ trait ParallelMatching {
}
return true
}}
- def containsUnapply = column exists { _.isInstanceOf[UnApply] }
-
- if(isSimpleIntSwitch) {
- new MixLiterals(scrutinee, column, rest)
- } else {
- new MixTypes(scrutinee, column, rest)
- }
- }
- /*
- class SwitchRule(val scrutinee:Symbol, val column:List[Tree], val rest:Rep) extends MixtureRule(scrutinee,column,rest) {
- def getTransition() = {
- //here, we generate a switch tree, bodies of the switch is the corresponding [[row in rest]]
- //if a default case has to be handled, body is [[that row + the default row]]
+ // an unapply for which we don't need a type test
+ def isUnapplyHead: Boolean = {
+ column.head match {
+ case UnApply(Apply(fn,_),arg) => fn.tpe match {
+ case MethodType(List(argtpe),_) =>
+ //Console.println("scrutinee.tpe"+scrutinee.tpe)
+ //Console.println("argtpe"+argtpe)
+ val r = scrutinee.tpe <:< argtpe
+ //Console.println("<: ???"+r)
+ r
+ case _ =>
+ //Console.println("wrong tpe")
+ false
+ }
+ case x =>
+ //Console.println("is something else"+x+" "+x.getClass)
+ false
+ }
}
+ //def containsUnapply = column exists { _.isInstanceOf[UnApply] }
+ //Console.println("intswitch? "+scrutinee.tpe.widen+" . "+column)
+ if(isSimpleIntSwitch)
+ return new MixLiterals(scrutinee, column, rest)
+
+ //Console.println("isUnapplyHead")
+ if(isUnapplyHead)
+ return new MixUnapply(scrutinee, column, rest)
+
+ return new MixTypes(scrutinee, column, rest) // todo: handle type tests in unapply
}
- */
/** this rule gets translated to a switch. all defaults/same literals are collected
*/
@@ -128,8 +142,66 @@ trait ParallelMatching {
class MixUnapply(val scrutinee:Symbol, val column:List[Tree], val rest:Rep) extends RuleApplication {
+ /** returns the (un)apply and two continuations */
+ def getTransition(implicit theOwner: Symbol): (Tree, List[Tree], Rep, Option[Rep]) = {
+ column.head match {
+ case ua @ UnApply(app @ Apply(fn, _), args) =>
+ val ures = newVar(ua.pos, app.tpe)
+ val n = args.length
+ val uacall = typer.typed { ValDef(ures, Apply(fn, List(Ident(scrutinee)))) }
+ //Console.println("uacall:"+uacall)
+
+ val nrowsOther = column.tail.zip(rest.row.tail) flatMap { case (pat,(ps,subst,g,b)) => pat match {
+ case UnApply(Apply(fn1,_),args) if fn.symbol==fn1.symbol => Nil
+ case p => List((p::ps, subst, g, b))
+ }}
+ val nrepFail = if(nrowsOther.isEmpty) None else Some(Rep(scrutinee::rest.temp, nrowsOther))
+ //Console.println("active = "+column.head+" / nrepFail = "+nrepFail)
+ n match {
+ case 0 => //special case for unapply(), app.tpe is boolean
+ val ntemps = rest.temp
+ val nrows = column.zip(rest.row) map { case (pat,(ps,subst,g,b)) => pat match {
+ case UnApply(Apply(fn1,_),args) if fn.symbol==fn1.symbol => (EmptyTree::ps, subst, g, b)
+ case p => (p::ps, subst, g, b)
+ }}
+ (uacall, List(), Rep(ntemps, nrows), nrepFail)
+
+ case 1 => //special case for unapply(p), app.tpe is Option[T]
+ val vsym = newVar(ua.pos, app.tpe.typeArgs(0))
+ val ntemps = vsym :: scrutinee :: rest.temp
+ val nrows = column.zip(rest.row) map { case (pat,(ps,subst,g,b)) => pat match {
+ case UnApply(Apply(fn1,_),args) if fn.symbol==fn1.symbol => (args(0) ::Ident(nme.WILDCARD)::ps, subst, g, b)
+ case p => (Ident(nme.WILDCARD) :: p ::ps, subst, g, b)
+ }}
+ (uacall, List( typer.typed { ValDef(vsym, Select(Ident(ures), nme.get) )}), Rep(ntemps, nrows), nrepFail)
+
+ case _ => // app.tpe is Option[? <: ProductN[T1,...,Tn]]
+ val uresGet = newVar(ua.pos, app.tpe.typeArgs(0))
+ var vdefs = ValDef(uresGet, Select(Ident(ures), nme.get))::Nil
+ var ts = definitions.getProductArgs(uresGet.tpe).get
+ var i = 1;
+ //Console.println("typeargs"+ts)
+ var vsyms:List[Symbol] = Nil
+ var dummies:List[Tree] = Nil
+ while(ts ne Nil) {
+ val vchild = newVar(ua.pos, ts.head)
+ val accSym = definitions.productProj(uresGet, i)
+ val rhs = typer.typed(Apply(Select(Ident(uresGet), accSym), List())) // nsc !
+ vdefs = ValDef(vchild, rhs)::vdefs
+ vsyms = vchild :: vsyms
+ ts = ts.tail
+ i = i + 1
+ dummies = Ident(nme.WILDCARD)::dummies
+ }
+ val ntemps = vsyms.reverse ::: scrutinee :: rest.temp
+ val nrows = column.zip(rest.row) map { case (pat,(ps,subst,g,b)) => pat match {
+ case UnApply(Apply(fn1,_),args) if fn.symbol==fn1.symbol =>( args:::Ident(nme.WILDCARD)::ps, subst, g, b)
+ case p =>(dummies::: p ::ps, subst, g, b)
+ }}
+ (uacall, vdefs.reverse, Rep(ntemps, nrows), nrepFail)
+ }}
+ }
}
-
/** this rule handles a column of type tests (and also tests for literals)
*/
class MixTypes(val scrutinee:Symbol, val column:List[Tree], val rest:Rep) extends RuleApplication {
@@ -163,10 +235,12 @@ trait ParallelMatching {
*/ private def subpatterns(pat:Tree): List[Tree] = pat match {
case Bind(_,p) => subpatterns(p)
case app @ Apply(fn, pats) if app.tpe.symbol.hasFlag(symtab.Flags.CASE) => pats
- case _: UnApply => throw CantHandleUnapply
+ case _: UnApply => dummies //throw CantHandleUnapply
+ //case Typed(p,_) => subpatterns(p)
case pat => /*//DEBUG("dummy patterns for "+pat+" of class "+pat.getClass);*/dummies
}
+ /** returns true if pattern tests an object */
def objectPattern(pat:Tree): Boolean = try {
(pat.symbol ne null) &&
(pat.symbol != NoSymbol) &&
@@ -180,20 +254,66 @@ trait ParallelMatching {
throw e
}
- // more specific patterns, subpatterns, remaining patterns
+ {
+ var sr = (moreSpecific,subsumed,remaining)
+ var j = 0; var pats = column; while(pats ne Nil) {
+ val (ms,ss,rs) = sr // more specific, more general(subsuming current), remaining patterns
+ val pat = pats.head
+
+ //Console.println("pat = "+pat)
+ //Console.println("current pat is wild? = "+isDefault(pat))
+ //Console.println("current pat.symbol = "+pat.symbol+", pat.tpe "+pat.tpe)
+ //Console.println("patternType = "+patternType)
+ //Console.println("(current)pat.tpe <:< patternType = "+(pat.tpe <:< patternType))
+ //Console.println("patternType <:< (current)pat.tpe = "+(patternType <:< pat.tpe))
+ //Console.println("(current)pat.tpe =:= patternType = "+(pat.tpe <:< patternType))
+
+ sr = pat match {
+
+ case Literal(Constant(null)) if !(patternType =:= pat.tpe) => //special case for constant null pattern
+ //Console.println("[1")
+ (ms,ss,(j,pat)::rs);
+ case _ if objectPattern(pat) =>
+ //Console.println("[2")
+ (EmptyTree::ms, (j,dummies)::ss, rs); // matching an object
+
+ case Typed(p,_) if (pat.tpe <:< patternType) =>
+ //Console.println("current pattern is same or *more* specific")
+ (p::ms, (j, dummies)::ss, rs);
+
+ case _ if (pat.tpe <:< patternType) && !isDefaultPattern(pat) =>
+ //Console.println("current pattern is same or *more* specific")
+ ({if(pat.tpe =:= patternType) EmptyTree else pat}::ms, (j,subpatterns(pat))::ss, rs);
+
+ case _ if (patternType <:< pat.tpe) || isDefaultPattern(pat) =>
+ //Console.println("current pattern is *more general*")
+ (EmptyTree::ms, (j,dummies)::ss, (j,pat)::rs); // subsuming (matched *and* remaining pattern)
+
+ case _ =>
+ //Console.println("current pattern tests something else")
+ (ms,ss,(j,pat)::rs)
+ }
+ j = j + 1
+ pats = pats.tail
+ }
+ this.moreSpecific = sr._1.reverse
+ this.subsumed = sr._2.reverse
+ this.remaining = sr._3.reverse
+ sr = null
+ }
+ //Console.println("GAGA:"+this)
+ /*
private var sr = column.zipWithIndex.foldLeft (moreSpecific,subsumed,remaining) {
(p,patAndIndex) =>
val (ms,ss,rs) = p
val (pat,j) = patAndIndex
-/*
- Console.println("pat = "+pat)
- Console.println("current pat is wild? = "+isDefault(pat))
- Console.println("current pat.symbol = "+pat.symbol+", pat.tpe "+pat.tpe)
- Console.println("patternType = "+patternType)
- Console.println("(current)pat.tpe <:< patternType = "+(pat.tpe <:< patternType))
- Console.println("patternType <:< (current)pat.tpe = "+(patternType <:< pat.tpe))
- Console.println("(current)pat.tpe =:= patternType = "+(pat.tpe <:< patternType))
-*/
+ //Console.println("pat = "+pat)
+ //Console.println("current pat is wild? = "+isDefault(pat))
+ //Console.println("current pat.symbol = "+pat.symbol+", pat.tpe "+pat.tpe)
+ //Console.println("patternType = "+patternType)
+ //Console.println("(current)pat.tpe <:< patternType = "+(pat.tpe <:< patternType))
+ //Console.println("patternType <:< (current)pat.tpe = "+(patternType <:< pat.tpe))
+ //Console.println("(current)pat.tpe =:= patternType = "+(pat.tpe <:< patternType))
pat match {
case Literal(Constant(null)) if !(patternType =:= pat.tpe) => //special case for constant null pattern
//Console.println("[1")
@@ -215,11 +335,7 @@ trait ParallelMatching {
(ms,ss,(j,pat)::rs)
}
}
-
- this.moreSpecific = sr._1.reverse
- this.subsumed = sr._2.reverse
- this.remaining = sr._3.reverse
- sr = null
+*/
override def toString = {
"MixTypes("+scrutinee+":"+scrutinee.tpe+") {\n moreSpecific:"+moreSpecific+"\n subsumed:"+subsumed+"\n remaining"+remaining+"\n}"
}
@@ -246,6 +362,7 @@ trait ParallelMatching {
if(moreSpecific.exists { x => x != EmptyTree }) {
ntemps = casted::ntemps // (***)
subtests = moreSpecific.zip(subsumed) map { case (mspat, (j,pats)) => (j,mspat::pats) }
+ //Console.println("MOS "+subtests)
}
ntemps = ntemps ::: rest.temp
val ntriples = subtests map {
@@ -263,9 +380,9 @@ trait ParallelMatching {
// BTW duplicating body/guard, gets you "defined twice" error cause hashing breaks
}
+ //Console.println("ntriples = "+ntriples.mkString("[[","\n","]]"))
Rep(ntemps, ntriples) /*setParent this*/
}
-
// fails => transition to translate(remaining)
val nmatrixFail: Option[Rep] = {
@@ -360,12 +477,14 @@ trait ParallelMatching {
}
vdefs = typed { ValDef(casted, gen.mkAsInstanceOf(typed{Ident(mm.scrutinee)}, casted.tpe))} :: vdefs
- def makeIf(cond:Tree, thenp:Tree, elsep:Tree) = cond match {
- case Literal(Constant(true)) => thenp
- case Literal(Constant(false)) => elsep
- case _ => If(cond, thenp, elsep)
- }
typed { makeIf(cond, Block(vdefs,succ), fail) }
+
+ case mu:MixUnapply =>
+ val (uacall,vdefs,srep,frep) = mu.getTransition
+ val succ = repToTree(srep, typed, handleOuter)
+ val fail = if(frep.isEmpty) failTree else repToTree(frep.get, typed, handleOuter)
+ val cond = typed { Not(Select(Ident(uacall.symbol),nme.isEmpty)) }
+ typed { Block(List(uacall), makeIf(cond,Block(vdefs,succ),fail)) }
}
}
@@ -381,7 +500,7 @@ object Rep {
def _2 = row
}
- /** the injection here handles alternatives */
+ /** the injection here handles alternatives and unapply type tests */
def apply(temp:List[Symbol], row1:List[(List[Tree], List[(Symbol,Symbol)], Tree, Tree)]): Rep = {
var i = -1
val row = row1 flatMap {
@@ -391,27 +510,51 @@ object Rep {
case Alternative(ps) => true
case _ => false
}
- def getAlternativeBranches(p:Tree): List[Tree] = {
- def get_BIND(pctx:Tree => Tree, p:Tree):List[Tree] = p match {
- case b @ Bind(n,p) => get_BIND({ x:Tree => pctx(copy.Bind(b, n, x) setType x.tpe) }, p)
- case Alternative(ps) => ps map pctx
+ def getAlternativeBranches(p:Tree): List[Tree] = {
+ def get_BIND(pctx:Tree => Tree, p:Tree):List[Tree] = p match {
+ case b @ Bind(n,p) => get_BIND({ x:Tree => pctx(copy.Bind(b, n, x) setType x.tpe) }, p)
+ case Alternative(ps) => ps map pctx
+ }
+ get_BIND({x=>x}, p)
+ }
+ val (opatso,subst,g,b) = xx
+ var opats = opatso
+ var pats:List[Tree] = Nil
+ var j = 0; while(opats ne Nil) {
+ opats.head match {
+ case p if isAlternative(p) && j == -1 =>
+ i = j
+ pats = p :: pats
+
+ case typat @ Typed(p,tpt) =>
+ pats = (if (tpt.tpe <:< temp(j).tpe) p else typat)::pats
+
+ case ua @ UnApply(Apply(fn,_),arg) => fn.tpe match {
+ case MethodType(List(argtpe),_) =>
+ pats = (if (temp(j).tpe <:< argtpe) ua else Typed(ua,TypeTree(argtpe)).setType(argtpe))::pats
+ }
+ case p =>
+ pats = p :: pats
+ }
+ opats = opats.tail
+ j = j + 1
+ }
+ pats = pats.reverse
+ //i = pats findIndexOf isAlternative
+ if(i == -1)
+ List((pats,subst,g,b))
+ else {
+ val prefix:List[Tree] = pats.take(i)
+ val alts = getAlternativeBranches(pats(i))
+ val suffix:List[Tree] = pats.drop(i+1)
+ alts map { p => (prefix ::: p :: suffix, subst, g, b) }
}
- get_BIND({x=>x}, p)
- }
- val (pats,subst,g,b) = xx
- i = pats findIndexOf isAlternative
- if(i == -1)
- List((pats,subst,g,b))
- else {
- val prefix:List[Tree] = pats.take(i)
- val alts = getAlternativeBranches(pats(i))
- val suffix:List[Tree] = pats.drop(i+1)
- alts map { p => (prefix ::: p :: suffix, subst, g, b) }
- }
}
- if(i == -1)
- RepImpl(temp,row).init
- else
+ if(i == -1) {
+ val ri = RepImpl(temp,row).init
+ //Console.println("ri = "+ri)
+ ri
+ } else
Rep(temp,row) // recursive call
}
}
diff --git a/src/compiler/scala/tools/nsc/matching/PatternMatchers.scala b/src/compiler/scala/tools/nsc/matching/PatternMatchers.scala
index bf5c6919e6..45b951666e 100644
--- a/src/compiler/scala/tools/nsc/matching/PatternMatchers.scala
+++ b/src/compiler/scala/tools/nsc/matching/PatternMatchers.scala
@@ -123,6 +123,9 @@ trait PatternMatchers { self: transform.ExplicitOuter with PatternNodes with Par
def construct(cases: List[Tree]): Unit = {
nPatterns = nPatterns + 1
+ if (settings.debug.value) {
+ Console.println(cases.mkString("construct{{{","\n","}}}"))
+ }
if(global.settings.Xmatchalgo.value != "incr")
try {
constructParallel(cases)
@@ -134,7 +137,7 @@ trait PatternMatchers { self: transform.ExplicitOuter with PatternNodes with Par
if (settings.debug.value) {
e.printStackTrace()
Console.println("****")
- Console.println("**** falling back, cause " + e.getMessage)
+ Console.println("**** falling back, cause " + e.toString +"/msg "+ e.getMessage)
Console.println("****")
for (CaseDef(pat,guard,_) <- cases)
Console.println(pat.toString)
@@ -203,7 +206,8 @@ trait PatternMatchers { self: transform.ExplicitOuter with PatternNodes with Par
//Console.println("pat:'"+x+"' / "+x.getClass)
x match {
case _:ArrayValue => throw CantHandleSeq
- case _:UnApply => throw CantHandleUnapply
+ case UnApply(fn,args) => //throw CantHandleUnapply // EXPR
+ traverseTrees(args)
case Ident(n) if n!= nme.WILDCARD =>
//DEBUG("I can't handle IDENT pattern:"+x)
//DEBUG("x.tpe.symbols:"+x.tpe.symbol)
@@ -213,9 +217,11 @@ trait PatternMatchers { self: transform.ExplicitOuter with PatternNodes with Par
// //DEBUG("I can't handle SELECT pattern:"+p)
// //DEBUG("p.tpe.symbols:"+p.tpe.symbol)
//throw CantHandleUnapply
- case p@Apply(_,_) if !p.tpe.symbol.hasFlag(symtab.Flags.CASE) =>
- //DEBUG("I can't handle APPLY pattern:"+p)
- //DEBUG("p.tpe.symbols:"+p.tpe.symbol)
+ case p@Apply(fn,args) if !p.tpe.symbol.hasFlag(symtab.Flags.CASE) =>
+ //Console.println("I can't handle APPLY pattern:"+p)
+ //Console.println("fn:"+fn)
+ //Console.println("args:"+args)
+ //Console.println("p.tpe.symbols:"+p.tpe.symbol)
throw CantHandleApply
//case p@Apply(_,_) if !p.tpe.symbol.hasFlag(symtab.Flags.CASE) => throw CantHandleUnapply //@todo