summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorPaul Phillips <paulp@improving.org>2009-07-10 15:26:30 +0000
committerPaul Phillips <paulp@improving.org>2009-07-10 15:26:30 +0000
commit69fb6eaa7d9d65f974cc54a084ec27d347d054bf (patch)
tree3fd6e3d0191b1cee8df8bc22a5ebbd3d4882203c /src
parent79dc3b49f0ed25fcc7cb33fc8fe1c13a6fdc21b3 (diff)
downloadscala-69fb6eaa7d9d65f974cc54a084ec27d347d054bf.tar.gz
scala-69fb6eaa7d9d65f974cc54a084ec27d347d054bf.tar.bz2
scala-69fb6eaa7d9d65f974cc54a084ec27d347d054bf.zip
More of the same in the pattern matcher.
targets the mixing rule which generates switch statements. We should soon be generating switches for quite a few more cases than we are at present.
Diffstat (limited to 'src')
-rw-r--r--src/compiler/scala/tools/nsc/matching/ParallelMatching.scala242
-rw-r--r--src/compiler/scala/tools/nsc/matching/PatternNodes.scala18
2 files changed, 141 insertions, 119 deletions
diff --git a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala
index 1d8c842dbd..534624eae5 100644
--- a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala
+++ b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala
@@ -96,30 +96,7 @@ trait ParallelMatching extends ast.TreeDSL {
final def isDefault = isDefaultPattern(tree)
- /* a Seq ending in _* */
-
- final def getAlternativeBranches: List[Tree] = {
- def get_BIND(pctx: TreeFunction1, p: Tree): List[Tree] = p match {
- case b @ Bind(n, p) => get_BIND((x: Tree) => pctx(treeCopy.Bind(b, n, x) setType x.tpe), p)
- case Alternative(ps) => ps map pctx
- }
- get_BIND(x => x, tree)
- }
-
- // XXX move right place
- def allBindings: List[Bind] = allBindingsInt(tree)
- private def allBindingsInt(t: Tree): List[Bind] = t match {
- case b @ Bind(_, t) => b :: allBindingsInt(t)
- case _ => Nil
- }
-
- /** All variables binding the given pattern. */
- def boundVariables: List[Symbol] = allBindings.foldRight(List[Symbol]())(_.symbol :: _)
-
- /** The pattern with its variable bindings stripped. */
- def stripped: Tree = if (allBindings.isEmpty) tree else allBindings.last.body
-
- final def definedVars: List[Symbol] = ParallelMatching.this.definedVars(tree)
+ lazy val Strip(boundVariables, stripped) = tree
/** returns true if pattern tests an object */
final def isObjectTest(head: Type) =
@@ -135,10 +112,10 @@ trait ParallelMatching extends ast.TreeDSL {
val shortCuts = new ListBuffer[Symbol]()
lazy val reached = new BitSet(targets.size)
- private lazy val expandResult = expand(roots, cases)
- lazy val targets: List[Tree] = expandResult._2
- lazy val vss: List[List[Symbol]] = expandResult._3
- lazy val expansion: Rep = make(roots, expandResult._1)
+ private lazy val expandResult = expand(roots, cases)
+ lazy val targets: List[FinalState] = expandResult._2
+ lazy val vss: List[List[Symbol]] = expandResult._3
+ lazy val expansion: Rep = make(roots, expandResult._1)
final def shortCut(theLabel: Symbol): Int = {
shortCuts += theLabel
@@ -219,7 +196,7 @@ trait ParallelMatching extends ast.TreeDSL {
(v, typedValDef(v, substv))
)
- val body = targets(bx)
+ val body = targets(bx).body
// @bug: typer is not able to digest a body of type Nothing being assigned result type Unit
val tpe = if (body.tpe.isNothing) body.tpe else resultType
val newType = MethodType(vsyms, tpe)
@@ -237,7 +214,7 @@ trait ParallelMatching extends ast.TreeDSL {
val args = vss(bx) map subst flatten
val label = labels(bx)
- val body = targets(bx)
+ val body = targets(bx).body
val fmls = label.tpe.paramTypes
def debugConsistencyFailure(): String = {
@@ -346,7 +323,7 @@ trait ParallelMatching extends ast.TreeDSL {
f(unbind(opat))
}
- val rows = row1 flatMap (_ expand classifyPat)
+ val rows = row1 flatMap (_ expandAlternatives classifyPat)
if (rows.length != row1.length) make(tvars, rows) // recursive call if any change
else Rep(tvars, rows).checkExhaustive
}
@@ -404,7 +381,7 @@ trait ParallelMatching extends ast.TreeDSL {
}
case class Patterns(scrut: Scrutinee, ps: List[Pattern]) {
- private lazy val column = ps map (_.tree)
+ private lazy val trees = ps map (_.tree)
lazy val head = ps.head
lazy val tail = Patterns(scrut, ps.tail)
lazy val last = ps.last.tree
@@ -414,27 +391,47 @@ trait ParallelMatching extends ast.TreeDSL {
lazy val size = ps.length
def apply(i: Int): Tree = ps(i).tree
- def zip() = column.zipWithIndex
- def zip[T](ys: List[T]) = column zip ys
+ def zip() = trees.zipWithIndex
+ def zip[T](others: List[T]) = trees zip others
def isObjectTest(pat: Pattern) = pat isObjectTest headType
def isObjectTest(pat: Tree) = Pattern(pat) isObjectTest headType
- // an unapply for which we don't need a type test
- def isUnapplyHead = cond (head.tree) { case __UnApply(_,tpe,_) => scrut.tpe <:< tpe }
- def isSimpleSwitch: Boolean =
- scrut.isSimple && (column.init forall isSwitchableConst) &&
- // TODO: This needs to also allow the case that the last is a compatible type pattern.
- (isSwitchableConst(last) || isDefaultPattern(last))
+ def extractSimpleSwitch(): Option[(List[Tree], Option[Tree])] = {
+ def isSwitchableDefault(x: Tree) = isSwitchableConst(x) || isDefaultPattern(x)
+ val (lits, others) = trees span isSwitchableConst
+ others match {
+ case Nil => Some(lits, None)
+ // TODO: This needs to also allow the case that the last is a compatible type pattern.
+ case List(x) if isSwitchableDefault(x) => Some(lits, Some(x))
+ case _ => None
+ }
+ }
+
+ // an unapply for which we don't need a type test (the scrutinee's static type conforms
+ // to the unapply's argument type.)
+ object SafeUnapply {
+ def unapply(x: Tree): Boolean = cond(x) { case __UnApply(_,tpe,_) => scrut.tpe <:< tpe }
+ }
+
+ object SimpleSwitch {
+ // TODO - scala> (5: Any) match { case 5 => 5 ; case 6 => 7 }
+ // ... should compile to a switch. It doesn't because the scrut isn't Int/Char, but
+ // that could be handle in an if/else since every pattern requires an Int.
+ // More immediately, Byte and Short scruts should also work.
+ def unapply(x: Patterns) = if (x.scrut.isSimple) x.extractSimpleSwitch else None
+ }
def mkRule(rest: Rep): RuleApplication =
- logAndReturn("mkRule: ", head match {
- case x if isEquals(x.tpe) => new MixEquals(this, rest)
- case Pattern(x: ArrayValue) => if (isRightIgnoring(x)) new MixSequenceStar(this, rest)
- else new MixSequence(this, rest)
- case _ if isSimpleSwitch => new MixLiterals(this, rest)
- case _ if isUnapplyHead => new MixUnapply(this, rest)
- case _ => new MixTypes(this, rest)
+ logAndReturn("mkRule: ", head.tree match {
+ case x if isEquals(x.tpe) => new MixEquals(this, rest)
+ case x: ArrayValue if isRightIgnoring(x) => new MixSequenceStar(this, rest)
+ case x: ArrayValue => new MixSequence(this, rest)
+ case SafeUnapply() => new MixUnapply(this, rest)
+ case _ => this match {
+ case SimpleSwitch(lits, d) => new MixLiteralInts(this, rest, lits, d)
+ case _ => new MixTypes(this, rest)
+ }
}
)
}
@@ -487,7 +484,7 @@ trait ParallelMatching extends ast.TreeDSL {
**/
/** picks which rewrite rule to apply
- * @precondition: column does not contain alternatives (ensured by initRep)
+ * @precondition: column does not contain alternatives
*/
def MixtureRule(scrut: Scrutinee, column: List[Tree], rest: Rep): RuleApplication =
Patterns(scrut, column map Pattern) mkRule rest
@@ -536,61 +533,49 @@ trait ParallelMatching extends ast.TreeDSL {
}
}
- /** mixture rule for literals
+ /** Mixture rule for all literal ints (and chars) i.e. hopefully a switch
+ * will be emitted on the JVM.
*/
- class MixLiterals(val pats: Patterns, val rest: Rep) extends RuleApplication {
- // e.g. (1,1) (1,3) (42,2) for column { case ..1.. => ;; case ..42..=> ;; case ..1.. => }
- var defaultV: Set[Symbol] = emptySymbolSet
- var defaultIndexSet = new BitSet(pats.size)
- protected var tagIndices = IntMap.empty[List[Int]]
+ class MixLiteralInts(
+ val pats: Patterns,
+ val rest: Rep,
+ literals: List[Tree],
+ defaultPattern: Option[Tree])
+ extends RuleApplication
+ {
+ lazy val defaultVars = defaultPattern.toList flatMap (_.boundVariables)
+ lazy val defaultRows =
+ if (defaultPattern.isEmpty) Nil
+ else List((rest rows literals.size).rebind2(defaultVars, scrut.sym))
- def insertDefault(tag: Int, vs: Traversable[Symbol]): Unit = {
- defaultIndexSet += tag
- defaultV = defaultV ++ vs
- }
-
- def haveDefault: Boolean = !defaultIndexSet.isEmpty
- def defaultRows: List[Row] = defaultIndexSet.toList reverseMap grabRow
+ protected var tagIndices = IntMap.empty[List[Int]]
protected def grabRow(index: Int): Row = {
val r = rest.rows(index)
- if (defaultV.isEmpty) r
+ if (defaultVars.isEmpty) r
else r.rebind2(pats(index).boundVariables, scrut.sym) // get vars
}
- /** inserts rows indices using in to list of tagIndices */
- protected def tagIndicesToReps() : List[(Int, Rep)] =
- tagIndices map { case (k, v) => (k, make(rest.tvars, (v reverseMap grabRow) ::: defaultRows)) } toList
-
- protected def defaultsToRep() = make(rest.tvars, defaultRows)
-
- protected def insertTagIndexPair(tag: Int, index: Int) =
- tagIndices = tagIndices.update(tag, index :: tagIndices.getOrElse(tag, Nil))
-
- /** returns
- * @return list of continuations,
- * @return variables bound to default continuation,
- * @return optionally, a default continuation
- **/
- def getTransition(): (List[(Int,Rep)], Set[Symbol], Option[Rep]) =
- (tagIndicesToReps, defaultV, if (haveDefault) Some(defaultsToRep) else None)
+ private def listToRep(indices: List[Int]) =
+ make(rest.tvars, (indices reverseMap grabRow) ::: defaultRows)
val varMap: List[(Int, List[Symbol])] = {
- def insertPair(c: Int, index: Int, x: Tree) = {
- insertTagIndexPair(c, index)
- Some(c, definedVars(x))
+ def insertPair(tag: Int, index: Int, x: Tree) = {
+ tagIndices = tagIndices.update(tag, index :: tagIndices.getOrElse(tag, Nil))
+ Some(tag, definedVars(x))
}
(for ((p, i) <- pats.zip) yield unbind(p) match {
+ case LIT(c: Byte) => insertPair(c, i, p)
+ case LIT(c: Short) => insertPair(c, i, p)
case LIT(c: Int) => insertPair(c, i, p)
case LIT(c: Char) => insertPair(c.toInt, i, p)
- case _ => insertDefault(i, p.boundVariables) ; None
+ case _ => None
}) flatMap (x => x) reverse
}
- // lazy
private def bindVars(Tag: Int, orig: Bindings): Bindings = {
- def myBindVars(rest:List[(Int,List[Symbol])], bnd: Bindings): Bindings = rest match {
+ def myBindVars(rest: List[(Int, List[Symbol])], bnd: Bindings): Bindings = rest match {
case Nil => bnd
case (Tag,vs)::xs => myBindVars(xs, bnd.add(vs, scrut.sym))
case (_, vs)::xs => myBindVars(xs, bnd)
@@ -599,24 +584,29 @@ trait ParallelMatching extends ast.TreeDSL {
}
final def tree(): Tree = {
- val (branches, defaultV, defaultRep) = this.getTransition // tag body pairs
- val cases = for ((tag, r) <- branches) yield {
- val r2 = make(r.tvars, r.rows map (x => x rebind bindVars(tag, x.subst)))
+ val cases =
+ for ((tag, vs) <- tagIndices.toList) yield {
+ val r = listToRep(vs)
+ val r2 = make(r.tvars, r.rows map (x => x rebind bindVars(tag, x.subst)))
- CASE(Literal(tag)) ==> r2.toTree
- }
- lazy val ndefault = defaultRep map (_.toTree) getOrElse (failTree)
- lazy val casesWithDefault = cases ::: List(CASE(WILD(IntClass.tpe)) ==> ndefault)
+ CASE(Literal(tag)) ==> r2.toTree
+ }
+
+ val defaultTree = make(rest.tvars, defaultRows).toTree
+ def casesWithDefault = cases ::: List(CASE(WILD(IntClass.tpe)) ==> defaultTree)
cases match {
- case CaseDef(lit,_,body) :: Nil => IF (scrut.id ANY_== lit) THEN body ELSE ndefault
+ case List(CaseDef(lit, _, body)) =>
+ // only one case becomes if/else
+ IF (scrut.id ANY_== lit) THEN body ELSE defaultTree
case _ =>
- val target: Tree = if (scrut.tpe.isChar) scrut.id DOT nme.toInt else scrut.id // chars to ints
+ // otherwise cast to an Int if necessary and run match
+ val target: Tree = if (!scrut.tpe.isInt) scrut.id DOT nme.toInt else scrut.id
target MATCH (casesWithDefault: _*)
}
}
override def toString = {
- "MixLiterals {\n pats: %s\n varMap: %s\n}".format(
+ "MixLiteralInts {\n pats: %s\n varMap: %s\n}".format(
pats, varMap
)
}
@@ -680,9 +670,9 @@ trait ParallelMatching extends ast.TreeDSL {
val nrows = mkNewRows(identity, ts.size)
val (vdefs: List[Tree], vsyms: List[Symbol]) = List.unzip(
- for ((vtpe, i) <- ts.zip((1 to ts.size).toList)) yield {
+ for ((vtpe, i) <- ts.zipWithIndex) yield {
val vchild = mkVar(vtpe)
- val accSym = productProj(uresGet, i)
+ val accSym = productProj(uresGet, i+1)
val rhs = typer typed fn(ID(uresGet), accSym)
(typedValDef(vchild, rhs), vchild)
@@ -907,6 +897,19 @@ trait ParallelMatching extends ast.TreeDSL {
def xIsaY = s <:< p
def yIsaX = p <:< s
+ // XXX exploring what breaks things and what doesn't
+ // def dummyIsOk = {
+ // val old = erased.yIsaX || yIsaX || isDef
+ // println("Old logic: %s || %s || %s == %s".format(erased.yIsaX, yIsaX, isDef, erased.yIsaX || yIsaX || isDef))
+ // println("isCaseClass(spat.tpe) = %s, isCaseClass(pats.headType) = %s".format(
+ // isCaseClass(spat.tpe), isCaseClass(pats.headType)))
+ // println("spat.tpe = %s, pats.head = %s, pats.headType = %s".format(
+ // spat.tpe, pats.head, pats.headType))
+ //
+ // (erased.yIsaX || yIsaX || isDef)
+ // // (!isCaseClass(spat.tpe) || !isCaseClass(pats.headType))
+ // }
+
// each pattern will yield a triple of options corresponding to the three lists,
// which will be flattened down to the values
implicit def mkOpt[T](x: T): Option[T] = Some(x) // limits noise from Some(value)
@@ -1015,21 +1018,27 @@ trait ParallelMatching extends ast.TreeDSL {
guard.isEmpty && (combos forall (c => c isCovered pat(c.index)))
// returns this rows with alternatives expanded
- def expand(classifyPat: (Tree, Int) => Tree): List[Row] = {
- def getAlternativeBranches(p: Tree): List[Tree] = {
- def get_BIND(pctx: TreeFunction1, p:Tree): List[Tree] = p match {
- case b @ Bind(n,p) => get_BIND((x: Tree) => pctx(treeCopy.Bind(b, n, x) setType x.tpe), p)
- case Alternative(ps) => ps map pctx
- }
- logAndReturn("get_BIND: ", get_BIND(x => x, p))
+ def expandAlternatives(classifyPat: (Tree, Int) => Tree): List[Row] = {
+ // If the given pattern contains alternatives, return it as a list of patterns.
+ // Makes typed copies of any bindings found so all alternatives point to final state.
+ def newPrev(b: Bind): TreeFunction1 = (x: Tree) => treeCopy.Bind(b, b.name, x) setType x.tpe
+ def extractBindings(p: Tree, prevBindings: TreeFunction1 = identity[Tree] _): List[Tree] = p match {
+ case b @ Bind(_, body) => extractBindings(body, newPrev(b))
+ case Alternative(ps) => ps map prevBindings
+ case x => List(x) // this shouldn't happen
}
- val indexOfAlternative = pat findIndexOf isAlternative
- val pats: List[Tree] = List.map2(pat, pat.indices.toList)(classifyPat)
- lazy val (prefix, alts :: suffix) = pats splitAt indexOfAlternative
- lazy val alternativeBranches = getAlternativeBranches(alts) map (x => replace(prefix ::: x :: suffix))
+ // classify all the top level patterns - alternatives come back unaltered
+ val newPats: List[Tree] = List.map2(pat, pat.indices.toList)(classifyPat)
- if (indexOfAlternative == -1) List(replace(pats)) else alternativeBranches
+ // expand alternatives if any are present
+ (newPats findIndexOf isAlternative) match {
+ case -1 => List(replace(newPats))
+ case index =>
+ val (prefix, alts :: suffix) = newPats splitAt index
+ // make a new row for each alternative, with it spliced into the original position
+ extractBindings(alts) map (x => replace(prefix ::: x :: suffix))
+ }
}
override def toString() = {
val patStr = pat.mkString
@@ -1039,6 +1048,9 @@ trait ParallelMatching extends ast.TreeDSL {
"Row(%d) %s%s".format(bx, patStr, otherStr)
}
}
+
+ case class FinalState(subst: Bindings, body: Tree)
+
case class Combo(index: Int, sym: Symbol) {
// is this combination covered by the given pattern?
def isCovered(p: Tree) = cond(unbind(p)) {
@@ -1050,6 +1062,16 @@ trait ParallelMatching extends ast.TreeDSL {
case class Branch[T](action: T, succ: Rep, fail: Option[Rep])
case class UnapplyCall(ua: Tree, args: List[Tree])
+ // sealed abstract class Pat {
+ // def isSimple: Boolean
+ // }
+ // case class SimplePat(pat: Tree) extends Pat {
+ // val isSimple = true
+ // }
+ // case class ComplexPat(pat: Tree) extends Pat {
+ // val isSimple = false
+ // }
+
case class Rep(val tvars: List[Symbol], val rows: List[Row]) {
import Flags._
@@ -1173,7 +1195,7 @@ trait ParallelMatching extends ast.TreeDSL {
}
/** Expands the patterns recursively. */
- final def expand(roots: List[Symbol], cases: List[Tree]): (List[Row], List[Tree], List[List[Symbol]]) = {
+ final def expand(roots: List[Symbol], cases: List[Tree]): (List[Row], List[FinalState], List[List[Symbol]]) = {
val res = unzip3(
for ((CaseDef(pat, guard, body), index) <- cases.zipWithIndex) yield {
def mkRow(ps: List[Tree]) = Row(ps, NoBinding, Guard(guard), index)
@@ -1184,7 +1206,7 @@ trait ParallelMatching extends ast.TreeDSL {
case WILD() => mkRow(getDummies(roots.length))
}
- (rowForPat, body, definedVars(pat))
+ (rowForPat, FinalState(NoBinding, body), definedVars(pat))
}
)
diff --git a/src/compiler/scala/tools/nsc/matching/PatternNodes.scala b/src/compiler/scala/tools/nsc/matching/PatternNodes.scala
index 4ce321aeb7..256019721c 100644
--- a/src/compiler/scala/tools/nsc/matching/PatternNodes.scala
+++ b/src/compiler/scala/tools/nsc/matching/PatternNodes.scala
@@ -25,9 +25,9 @@ trait PatternNodes extends ast.TreeDSL
type TypeComparison = (Type, Type) => Boolean
// Tests on Types
- def isEquals(t: Type) = cond(t) { case TypeRef(_, EqualsPatternClass, _) => true }
- def isAnyRef(t: Type) = t <:< AnyRefClass.tpe
- def isCaseClass(t: Type) = t.typeSymbol hasFlag Flags.CASE
+ def isEquals(t: Type) = cond(t) { case TypeRef(_, EqualsPatternClass, _) => true }
+ def isAnyRef(t: Type) = t <:< AnyRefClass.tpe
+ def isCaseClass(t: Type) = t.typeSymbol hasFlag Flags.CASE
// Comparisons on types
// def sameSymbols: TypeComparison = _.typeSymbol eq _.typeSymbol
@@ -109,7 +109,7 @@ trait PatternNodes extends ast.TreeDSL
(tpe.typeSymbol == sym) ||
(symtpe <:< tpe) ||
- (symtpe.parents exists (_.typeSymbol eq tpe.typeSymbol)) || // e.g. Some[Int] <: Option[&b]
+ (symtpe.parents exists (x => cmpSymbols(x, tpe))) || // e.g. Some[Int] <: Option[&b]
((tpe.prefix memberType sym) <:< tpe) // outer, see combinator.lexical.Scanner
}
}
@@ -212,12 +212,12 @@ trait PatternNodes extends ast.TreeDSL
}
// break a pattern down into bound variables and underlying tree.
- object Strip {
- private def strip(syms: Set[Symbol], t: Tree): (Set[Symbol], Tree) = t match {
- case b @ Bind(_, pat) => strip(syms + b.symbol, pat)
- case _ => (syms, t)
+ object Strip {
+ private def strip(t: Tree, syms: List[Symbol] = Nil): (Tree, List[Symbol]) = t match {
+ case b @ Bind(_, pat) => strip(pat, b.symbol :: syms)
+ case _ => (t, syms)
}
- def unapply(x: Tree): Option[(Set[Symbol], Tree)] = Some(strip(Set(), x))
+ def unapply(x: Tree): Option[(List[Symbol], Tree)] = Some(strip(x).swap)
}
object Stripped {