summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Phillips <paulp@improving.org>2009-07-01 02:57:50 +0000
committerPaul Phillips <paulp@improving.org>2009-07-01 02:57:50 +0000
commitd17c979ce0259c8eb2ad678e63f1eb28f46737a6 (patch)
tree7529cafeeb8a69c2719e836e350ae11e53bb5b31
parent91c683e22d8ab80b9d7e41258f4053d2f08d507f (diff)
downloadscala-d17c979ce0259c8eb2ad678e63f1eb28f46737a6.tar.gz
scala-d17c979ce0259c8eb2ad678e63f1eb28f46737a6.tar.bz2
scala-d17c979ce0259c8eb2ad678e63f1eb28f46737a6.zip
More pattern matcher streamlining.
-rw-r--r--src/compiler/scala/tools/nsc/matching/ParallelMatching.scala24
-rw-r--r--src/compiler/scala/tools/nsc/matching/PatternNodes.scala82
2 files changed, 50 insertions, 56 deletions
diff --git a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala
index 99779a03fa..c028933266 100644
--- a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala
+++ b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala
@@ -234,7 +234,7 @@ trait ParallelMatching extends ast.TreeDSL {
*/
class MixLiterals(val pats: Patterns, val rest: Rep)(implicit rep:RepFactory) extends RuleApplication(rep) {
// e.g. (1,1) (1,3) (42,2) for column { case ..1.. => ;; case ..42..=> ;; case ..1.. => }
- var defaultV: immutable.Set[Symbol] = emptySymbolSet
+ var defaultV: Set[Symbol] = emptySymbolSet
var defaultIndexSet = new BitSet(pats.size)
def insertDefault(tag: Int, vs: Traversable[Symbol]) {
@@ -609,22 +609,18 @@ trait ParallelMatching extends ast.TreeDSL {
val (casted, srep, frep) = this.getTransition
val cond = condition(casted.tpe, scrut)
val succ = srep.toTree
- val fail = frep.map(_.toTree) getOrElse failTree
+ val fail = frep map (_.toTree) getOrElse (failTree)
// dig out case field accessors that were buried in (***)
val cfa = if (!pats.isCaseHead) Nil else casted.accessors
val caseTemps = srep.temp match { case x :: xs if x == casted.sym => xs ; case x => x }
+ def needCast = if (casted.sym ne scrut.sym) List(VAL(casted.sym) === (scrut.id AS_ANY casted.tpe)) else Nil
+ val vdefs = needCast ::: (
+ for ((tmp, accessor) <- caseTemps zip cfa) yield
+ typedValDef(tmp, typer typed fn(casted.id, accessor))
+ )
- var vdefs = for ((tmp, accessorMethod) <- caseTemps.zip(cfa)) yield {
- val untypedAccess = fn(casted.id, accessorMethod)
- val typedAccess = typer.typed(untypedAccess)
- typedValDef(tmp, typedAccess)
- }
-
- if (casted.sym ne scrut.sym)
- vdefs = ValDef(casted.sym, gen.mkAsInstanceOf(scrut.id, casted.tpe)) :: vdefs
-
- typer.typed( If(cond, squeezedBlock(vdefs, succ), fail) )
+ typer typed (IF (cond) THEN squeezedBlock(vdefs, succ) ELSE fail)
}
}
@@ -648,7 +644,7 @@ trait ParallelMatching extends ast.TreeDSL {
case _ => p.tpe coversSym sym
}
- guard.isEmpty && results.forall(_ == true)
+ guard.isEmpty && (results forall (true ==))
}
// returns this row with alternatives expanded
@@ -850,7 +846,7 @@ trait ParallelMatching extends ast.TreeDSL {
}
}
- val row = row1 flatMap { _.expand(classifyPat) }
+ val row = row1 flatMap (_ expand classifyPat)
if (row.length != row1.length) make(temp, row) // recursive call if any change
else Rep(temp, row).init
}
diff --git a/src/compiler/scala/tools/nsc/matching/PatternNodes.scala b/src/compiler/scala/tools/nsc/matching/PatternNodes.scala
index d25dfb0143..72ac832870 100644
--- a/src/compiler/scala/tools/nsc/matching/PatternNodes.scala
+++ b/src/compiler/scala/tools/nsc/matching/PatternNodes.scala
@@ -20,6 +20,7 @@ trait PatternNodes extends ast.TreeDSL
import symtab.Flags
import Types._
import CODE._
+ import definitions.{ ConsClass, ListClass, EqualsPatternClass, ListModule }
case class TypeComparison(x: Type, y: Type) {
def xIsaY = x <:< y
@@ -66,8 +67,8 @@ trait PatternNodes extends ast.TreeDSL
* test, it is returned as is.
*/
private def tpeWRTEquality(t: Type) = t match {
- case TypeRef(_, sym, arg::Nil) if sym eq EqualsPatternClass => arg
- case _ => t
+ case TypeRef(_, EqualsPatternClass, List(arg)) => arg
+ case _ => t
}
lazy val tpe = tpeWRTEquality(_tpe)
@@ -93,7 +94,7 @@ trait PatternNodes extends ast.TreeDSL
(tpe.typeSymbol == sym) ||
(symtpe <:< tpe) ||
- (symtpe.parents.exists(_.typeSymbol eq tpe.typeSymbol)) || // e.g. Some[Int] <: Option[&b]
+ (symtpe.parents exists (_.typeSymbol eq tpe.typeSymbol)) || // e.g. Some[Int] <: Option[&b]
(tpe.prefix.memberType(sym) <:< tpe) // outer, see combinator.lexical.Scanner
}
}
@@ -107,31 +108,32 @@ trait PatternNodes extends ast.TreeDSL
final def DBG(x: => String) = if (settings.debug.value) Console.println(x)
- final def getDummies(i: Int): List[Tree] = List.make(i, EmptyTree)
+ final def getDummies(i: Int): List[Tree] = List.fill(i)(EmptyTree)
- def makeBind(vs:List[Symbol], pat:Tree): Tree =
- if (vs eq Nil) pat else Bind(vs.head, makeBind(vs.tail, pat)) setType pat.tpe
-
- def mkTypedBind(vs: List[Symbol], tpe: Type) =
- makeBind(vs, Typed(WILD(tpe), TypeTree(tpe)) setType tpe)
+ def makeBind(vs: List[Symbol], pat: Tree): Tree = vs match {
+ case Nil => pat
+ case x :: xs => Bind(x, makeBind(xs, pat)) setType pat.tpe
+ }
- def mkEmptyTreeBind(vs: List[Symbol], tpe: Type) =
- makeBind(vs, Typed(EmptyTree, TypeTree(tpe)) setType tpe)
+ private def mkBind(vs: List[Symbol], tpe: Type, arg: Tree) =
+ makeBind(vs, Typed(arg, TypeTree(tpe)) setType tpe)
- def mkEqualsRef(xs: List[Type]) = typeRef(NoPrefix, definitions.EqualsPatternClass, xs)
+ def mkTypedBind(vs: List[Symbol], tpe: Type) = mkBind(vs, tpe, WILD(tpe))
+ def mkEmptyTreeBind(vs: List[Symbol], tpe: Type) = mkBind(vs, tpe, EmptyTree)
+ def mkEqualsRef(xs: List[Type]) = typeRef(NoPrefix, EqualsPatternClass, xs)
- def normalizedListPattern(pats:List[Tree], tptArg:Type): Tree = pats match {
+ def normalizedListPattern(pats: List[Tree], tptArg: Type): Tree = pats match {
case Nil => gen.mkNil
case (sp @ Strip(_, _: Star)) :: _ => makeBind(definedVars(sp), WILD(sp.tpe))
case x :: xs =>
- var resType: Type = null;
- val consType: Type = definitions.ConsClass.primaryConstructor.tpe match {
- case mt @ MethodType(args, res @ TypeRef(pre,sym,origArgs)) =>
- val listType = TypeRef(pre, definitions.ListClass, List(tptArg))
- resType = TypeRef(pre, sym , List(tptArg))
-
- val dummyMethod = new TermSymbol(NoSymbol, NoPosition, "matching$dummy")
- MethodType(dummyMethod.newSyntheticValueParams(List(tptArg, listType)), resType)
+ val (consType, resType): (Type, Type) = {
+ val MethodType(args, TypeRef(pre, sym, _)) = ConsClass.primaryConstructor.tpe
+
+ val mkTypeRef = TypeRef(pre, _: Symbol, List(tptArg))
+ val dummyMethod = new TermSymbol(NoSymbol, NoPosition, "matching$dummy")
+ val res = mkTypeRef(sym)
+
+ (MethodType(dummyMethod.newSyntheticValueParams(List(tptArg, mkTypeRef(ListClass))), res), res)
}
Apply(TypeTree(consType), List(x, normalizedListPattern(xs, tptArg))) setType resType
}
@@ -181,7 +183,7 @@ trait PatternNodes extends ast.TreeDSL
object UnApply_TypeApply {
def unapply(x: UnApply) = x match {
case UnApply(Apply(TypeApply(sel @ Select(stor, nme.unapplySeq), List(tptArg)), _), ArrayValue(_, xs)::Nil)
- if (stor.symbol eq definitions.ListModule) => Some(tptArg, xs)
+ if (stor.symbol eq ListModule) => Some(tptArg, xs)
case _ => None
}
}
@@ -223,18 +225,15 @@ trait PatternNodes extends ast.TreeDSL
}
final def definedVars(x: Tree): List[Symbol] = {
- implicit def listToStream[T](xs: List[T]): Stream[T] = xs.toStream
- def definedVars1(x: Tree): Stream[Symbol] = x match {
- case Apply(_, args) => definedVars2(args)
- case b @ Bind(_,p) => Stream.cons(b.symbol, definedVars1(p))
- case Typed(p,_) => definedVars1(p) // otherwise x @ (_:T)
- case UnApply(_,args) => definedVars2(args)
- case ArrayValue(_,xs) => definedVars2(xs)
- case _ => Nil
+ def vars(x: Tree): List[Symbol] = x match {
+ case Apply(_, args) => args flatMap vars
+ case b @ Bind(_, p) => b.symbol :: vars(p)
+ case Typed(p, _) => vars(p) // otherwise x @ (_:T)
+ case UnApply(_, args) => args flatMap vars
+ case ArrayValue(_, xs) => xs flatMap vars
+ case x => Nil
}
- def definedVars2(args: Stream[Tree]): Stream[Symbol] = args flatMap definedVars1
-
- definedVars1(x).reverse.toList
+ vars(x) reverse
}
/** pvar: the symbol of the pattern variable
@@ -243,20 +242,19 @@ trait PatternNodes extends ast.TreeDSL
case class Binding(pvar: Symbol, temp: Symbol)
case class Bindings(bindings: Binding*) extends Function1[Symbol, Option[Ident]] {
+ private def castIfNeeded(pvar: Symbol, t: Symbol) =
+ if (t.tpe <:< pvar.tpe) ID(t) else ID(t) AS_ANY pvar.tpe
+
def add(vs: Iterable[Symbol], temp: Symbol): Bindings =
- Bindings(vs.toList.map(Binding(_, temp)) ++ bindings : _*)
+ Bindings((vs.toList map (Binding(_: Symbol, temp))) ++ bindings : _*)
- def apply(v: Symbol): Option[Ident] = bindings.find(_.pvar eq v) match {
- case Some(b) => Some(Ident(b.temp) setType v.tpe)
- case None => None // abort("Symbol " + v + " has no binding in " + bindings)
- }
+ def apply(v: Symbol): Option[Ident] =
+ bindings find (_.pvar eq v) map (x => Ident(x.temp) setType v.tpe)
- /**
- * The corresponding list of value definitions.
- */
+ /** The corresponding list of value definitions. */
final def targetParams(implicit typer: Typer): List[ValDef] =
for (Binding(v, t) <- bindings.toList) yield
- VAL(v) === (typer typed (if (t.tpe <:< v.tpe) ID(t) else (ID(t) AS_ANY v.tpe)))
+ VAL(v) === (typer typed castIfNeeded(v, t))
}
val NoBinding: Bindings = Bindings()