diff options
3 files changed, 121 insertions, 109 deletions
diff --git a/src/compiler/scala/tools/nsc/ast/TreeDSL.scala b/src/compiler/scala/tools/nsc/ast/TreeDSL.scala index ed9e933659..9c12ac6bd5 100644 --- a/src/compiler/scala/tools/nsc/ast/TreeDSL.scala +++ b/src/compiler/scala/tools/nsc/ast/TreeDSL.scala @@ -43,10 +43,7 @@ trait TreeDSL { object LIT extends (Any => Literal) { def apply(x: Any) = Literal(Constant(x)) - def unapply(x: Any) = x match { - case Literal(Constant(value)) => Some(value) - case _ => None - } + def unapply(x: Any) = condOpt(x) { case Literal(Constant(value)) => value } } // You might think these could all be vals, but empirically I have found that diff --git a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala index 0c7c51d6a1..2fe1bb4051 100644 --- a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala +++ b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala @@ -307,8 +307,8 @@ trait ParallelMatching extends ast.TreeDSL { def rebindEmpty(tpe: Type) = mkEmptyTreeBind(vars, tpe) def rebindTyped() = mkTypedBind(vars, equalsCheck(unbind(opat))) - // @pre for UnApply_TypeApply: is not right-ignoring (no star pattern) ; no exhaustivity check - def doUnapplyTypeApply(tptArg: Tree, xs: List[Tree]) = { + // @pre for doUnapplySeq: is not right-ignoring (no star pattern) ; no exhaustivity check + def doUnapplySeq(tptArg: Tree, xs: List[Tree]) = { temp(j) setFlag Flags.TRANS_FLAG rebind(normalizedListPattern(xs, tptArg.tpe)) } @@ -345,7 +345,7 @@ trait ParallelMatching extends ast.TreeDSL { { case x if doReturnOriginal(x) => opat } , { case x if doRebindTyped(x) => rebindTyped() } , // Ident(_) != nme.WILDCARD { case _: This => opat } , - { case UnApply_TypeApply(tptArg, xs) => doUnapplyTypeApply(tptArg, xs) } , + { case UnapplySeq(tptArg, xs) => doUnapplySeq(tptArg, xs) } , { case ua @ UnApply(Apply(fn, _), _) => doUnapplyApply(ua, fn) } , { case o @ Apply_Function(fn) => doApplyFunction(o, fn) } , { case Apply_Value(pre, sym) => rebindEmpty(mkEqualsRef(List(singleType(pre, sym)))) } , @@ -582,12 +582,18 @@ trait ParallelMatching extends ast.TreeDSL { def getTransition(): (List[(Int,Rep)], Set[Symbol], Option[Rep]) = (tagIndicesToReps, defaultV, if (haveDefault) Some(defaultsToRep) else None) - val varMap: List[(Int, List[Symbol])] = - (for ((x, i) <- pats.zip) yield unbind(x) match { - case p @ LIT(c: Int) => insertTagIndexPair(c, i) ; Some(c, definedVars(x)) - case p @ LIT(c: Char) => insertTagIndexPair(c.toInt, i) ; Some(c.toInt, definedVars(x)) - case p if isDefaultPattern(p) => insertDefault(i, x.boundVariables) ; None + val varMap: List[(Int, List[Symbol])] = { + def insertPair(c: Int, index: Int, x: Tree) = { + insertTagIndexPair(c, index) + Some(c, definedVars(x)) + } + + (for ((p, i) <- pats.zip) yield unbind(p) match { + case LIT(c: Int) => insertPair(c, i, p) + case LIT(c: Char) => insertPair(c.toInt, i, p) + case _ => insertDefault(i, p.boundVariables) ; None }) flatMap (x => x) reverse + } // lazy private def bindVars(Tag: Int, orig: Bindings): Bindings = { @@ -630,10 +636,7 @@ trait ParallelMatching extends ast.TreeDSL { object sameUnapplyCall { def sameFunction(fn1: Tree) = fxn.symbol == fn1.symbol && (fxn equalsStructure fn1) - def unapply(t: Tree) = t match { - case UnApply(Apply(fn1, _), args) if sameFunction(fn1) => Some(args) - case _ => None - } + def unapply(t: Tree) = condOpt(t) { case UnApply(Apply(fn1, _), args) if sameFunction(fn1) => args } } def newVarCapture(pos: Position, tpe: Type) = @@ -841,22 +844,26 @@ trait ParallelMatching extends ast.TreeDSL { class MixEquals(val pats: Patterns, val rest: Rep) extends RuleApplication { /** condition (to be used in IF), success and failure Rep */ final def getTransition(): (Branch[Tree], Symbol) = { - val vlue = (head.tpe: @unchecked) match { - case TypeRef(_,_,List(SingleType(pre,sym))) => REF(pre, sym) - case TypeRef(_,_,List(PseudoType(o))) => o.duplicate + val value = { + // how is it known these are the only possible typerefs here? + val TypeRef(_,_,List(arg)) = head.tpe + arg match { + case SingleType(pre, sym) => REF(pre, sym) + case PseudoType(o) => o.duplicate + } } - assert(vlue.tpe ne null, "value tpe is null") - val vs = head.boundVariables - val nsuccFst = rest.row.head.insert2(List(EmptyTree), vs, scrut.sym) - val failLabel = owner.newLabel(scrut.pos, newName(scrut.pos, "failCont%")) // warning, untyped - val sx = shortCut(failLabel) // register shortcut - val nsuccRow = nsuccFst :: Row(getDummies( 1 /* scrutinee */ + rest.temp.length), NoBinding, NoGuard, sx) :: Nil + + val label = owner.newLabel(scrut.pos, newName(scrut.pos, "failCont%")) // warning, untyped + val succ = List( + rest.row.head.insert2(List(EmptyTree), head.boundVariables, scrut.sym), + Row(getDummies(1 + rest.temp.length), NoBinding, NoGuard, shortCut(label)) + ) // todo: optimize if no guard, and no further tests - val nsucc = mkNewRep(Nil, rest.temp, nsuccRow) - val nfail = mkFail(List.map2(rest.row.tail, pats.tail.ps)(_ insert _)) + val fail = mkFail(List.map2(rest.row.tail, pats.tail.ps)(_ insert _)) + val action = typer typed (scrut.id ANY_== value) - (Branch(typer typed (scrut.id ANY_== vlue), nsucc, Some(nfail)), failLabel) + (Branch(action, mkNewRep(Nil, rest.temp, succ), Some(fail)), label) } final def tree() = { @@ -972,7 +979,7 @@ trait ParallelMatching extends ast.TreeDSL { val fail = frep map (_.toTree) getOrElse (failTree) // dig out case field accessors that were buried in (***) - val cfa = if (!pats.isCaseHead) Nil else casted.accessors + val cfa = if (pats.isCaseHead) casted.accessors else Nil val caseTemps = srep.temp match { case x :: xs if x == casted.sym => xs ; case x => x } def needCast = if (casted.sym ne scrut.sym) List(VAL(casted.sym) === (scrut.id AS_ANY castedTpe)) else Nil val vdefs = needCast ::: ( @@ -1076,32 +1083,29 @@ trait ParallelMatching extends ast.TreeDSL { */ final def applyRule(): RuleApplication = { def dropIndex[T](xs: List[T], n: Int) = (xs take n) ::: (xs drop (n + 1)) - if (row.isEmpty) - return ErrorRule() - - val Row(pats, subst, guard, index) = row.head - def guardedRest = - if (guard.isEmpty) null else make(temp, row.tail) - val (defaults, others) = pats span (p => isDefaultPattern(unbind(p))) + lazy val Row(pats, subst, guard, index) = row.head + lazy val guardedRest = if (guard.isEmpty) null else make(temp, row.tail) + lazy val (defaults, others) = pats span (p => isDefaultPattern(unbind(p))) - /** top-most row contains only variables/wildcards */ - if (others.isEmpty) { - val binding = (defaults map { case Strip(vs, _) => vs } zip temp) . - foldLeft(subst)((b, pair) => b.add(pair._1, pair._2)) + if (row.isEmpty) ErrorRule() + else others match { + /** top-most row contains only variables/wildcards */ + case Nil => + val binding = (defaults map (_.boundVariables) zip temp) . + foldLeft(subst)((b, pair) => b.add(pair._1, pair._2)) - return VariableRule(binding, guard, guardedRest, index) - } + VariableRule(binding, guard, guardedRest, index) - val rpat @ Strip(vs, p) = others.head - val px = defaults.size - // Row( _ ... _ p_1i ... p_1n g_m b_m ) :: rows - // cut out column px that contains the non-default pattern - val column = rpat :: (row.tail map (_ pat px)) - val restTemp = dropIndex(temp, px) - val restRows = row map (r => r replace dropIndex(r.pat, px)) + /** cut out the column (px) containing the non-default pattern. */ + case (rpat @ Strip(vs, _)) :: _ => + val px = defaults.size + val column = rpat :: (row.tail map (_ pat px)) + val restTemp = dropIndex(temp, px) + val restRows = row map (r => r replace dropIndex(r.pat, px)) - MixtureRule(new Scrutinee(temp(px)), column, make(restTemp, restRows)) + MixtureRule(new Scrutinee(temp(px)), column, make(restTemp, restRows)) + } } def checkExhaustive: this.type = { @@ -1168,12 +1172,12 @@ trait ParallelMatching extends ast.TreeDSL { final def expand(roots: List[Symbol], cases: List[Tree]): (List[Row], List[Tree], List[List[Symbol]]) = { val res = unzip3( for ((CaseDef(pat, guard, body), index) <- cases.zipWithIndex) yield { - def mkRow(ps: List[Tree]) = Some(Row(ps, NoBinding, Guard(guard), index)) - def rowForPat: Option[Row] = pat match { + def mkRow(ps: List[Tree]) = Row(ps, NoBinding, Guard(guard), index) + + def rowForPat: Option[Row] = condOpt(pat) { case _ if roots.length <= 1 => mkRow(List(pat)) case Apply(fn, args) => mkRow(args) case WILD() => mkRow(getDummies(roots.length)) - case _ => None } (rowForPat, body, definedVars(pat)) diff --git a/src/compiler/scala/tools/nsc/matching/PatternNodes.scala b/src/compiler/scala/tools/nsc/matching/PatternNodes.scala index 5a671792ea..3e101da5fe 100644 --- a/src/compiler/scala/tools/nsc/matching/PatternNodes.scala +++ b/src/compiler/scala/tools/nsc/matching/PatternNodes.scala @@ -123,61 +123,79 @@ trait PatternNodes extends ast.TreeDSL def mkEmptyTreeBind(vs: List[Symbol], tpe: Type) = mkBind(vs, tpe, EmptyTree) def mkEqualsRef(xs: List[Type]) = typeRef(NoPrefix, EqualsPatternClass, xs) - def normalizedListPattern(pats: List[Tree], tptArg: Type): Tree = pats match { - case Nil => gen.mkNil - case (sp @ Strip(_, _: Star)) :: _ => makeBind(definedVars(sp), WILD(sp.tpe)) - case x :: xs => - val (consType, resType): (Type, Type) = { - val MethodType(args, TypeRef(pre, sym, _)) = ConsClass.primaryConstructor.tpe - - val mkTypeRef = TypeRef(pre, _: Symbol, List(tptArg)) + /** For folding a list into a well-typed x :: y :: etc :: tree. */ + private def listFolder(tpe: Type) = { + val MethodType(_, TypeRef(pre, sym, _)) = ConsClass.primaryConstructor.tpe + val consRef = typeRef(pre, sym, List(tpe)) + val listRef = typeRef(pre, ListClass, List(tpe)) + + def fold(x: Tree, xs: Tree) = x match { + case sp @ Strip(_, _: Star) => makeBind(definedVars(sp), WILD(sp.tpe)) + case _ => val dummyMethod = new TermSymbol(NoSymbol, NoPosition, "matching$dummy") - val res = mkTypeRef(sym) + val consType = MethodType(dummyMethod newSyntheticValueParams List(tpe, listRef), consRef) - (MethodType(dummyMethod newSyntheticValueParams List(tptArg, mkTypeRef(ListClass)), res), res) - } - Apply(TypeTree(consType), List(x, normalizedListPattern(xs, tptArg))) setType resType + Apply(TypeTree(consType), List(x, xs)) setType consRef + } + + fold _ } - // if Apply target !isType and takes no args, returns target prefix and symbol - object Apply_Value { - def unapply(x: Any) = x match { - case x @ Apply(fun, args) if !fun.isType && args.isEmpty => Some(x.tpe.prefix, x.symbol) - case _ => None + def normalizedListPattern(pats: List[Tree], tptArg: Type): Tree = + pats.foldRight(gen.mkNil)(listFolder(tptArg)) + + // An Apply that's a constructor pattern (case class) + // foo match { case C() => true } + object Apply_CaseClass { + def unapply(x: Any) = condOpt(x) { + case x @ Apply(fn, args) if fn.isType => (x.tpe, args) } } - // if Apply tpe !isCaseClass and Apply_Value says false, return the Apply target - object Apply_Function { - /* see t301 */ - def isApplyFunction(t: Apply) = !isCaseClass(t.tpe) || !Apply_Value.unapply(t).isEmpty - def unapply(x: Any) = x match { - case x @ Apply(fun, Nil) if isApplyFunction(x) => Some(fun) - case _ => None + // No-args Apply where fn is not a type - looks like, case object with prefix? + // + // class Pip { + // object opcodes { case object EmptyInstr } + // def bop(x: Any) = x match { case opcodes.EmptyInstr => true } + // } + object Apply_Value { + def unapply(x: Any) = condOpt(x) { + case x @ Apply(fn, Nil) if !fn.isType => (x.tpe.prefix, x.symbol) } } - // if Apply target isType (? what does this imply exactly) returns type and arguments. - object Apply_CaseClass { - def unapply(x: Any) = x match { - case x @ Apply(fun, args) if fun.isType => Some(x.tpe, args) - case _ => None + // No-args Apply for all the other cases + // val Bop = Nil + // def foo(x: Any) = x match { case Bop => true } + object Apply_Function { + def isApplyFunction(t: Apply) = cond(t) { + case Apply_Value(_, _) => true + case x if !isCaseClass(x.tpe) => true + } + def unapply(x: Any) = condOpt(x) { + case x @ Apply(fn, Nil) if isApplyFunction(x) => fn } } - object UnApply_TypeApply { - def unapply(x: UnApply) = x match { - case UnApply(Apply(TypeApply(sel @ Select(stor, nme.unapplySeq), List(tptArg)), _), ArrayValue(_, xs)::Nil) - if (stor.symbol eq ListModule) => Some(tptArg, xs) - case _ => None + // unapplySeq extractor + // val List(x,y) = List(1,2) + object UnapplySeq { + private object TypeApp { + def unapply(x: Any) = condOpt(x) { + case TypeApply(sel @ Select(stor, nme.unapplySeq), List(tpe)) if stor.symbol eq ListModule => tpe + } + } + def unapply(x: UnApply) = condOpt(x) { + case UnApply(Apply(TypeApp(tptArg), _), List(ArrayValue(_, xs))) => (tptArg, xs) } } + // unapply extractor + // val Pair(_,x) = Pair(1,2) object __UnApply { private def paramType(fn: Tree) = fn.tpe match { case m: MethodType => m.paramTypes.head } - def unapply(x: Tree) = x match { - case Strip(vs, UnApply(Apply(fn, _), args)) => Some(vs, paramType(fn), args) - case _ => None + def unapply(x: Tree) = condOpt(x) { + case Strip(vs, UnApply(Apply(fn, _), args)) => (vs, paramType(fn), args) } } @@ -195,11 +213,8 @@ trait PatternNodes extends ast.TreeDSL def unapply(x: Tree): Option[(Set[Symbol], Tree)] = Some(strip(Set(), x)) } - object BoundVariables { - def unapply(x: Tree): Option[List[Symbol]] = Some(Pattern(x).boundVariables) - } object Stripped { - def unapply(x: Tree): Option[Tree] = Some(Pattern(x).stripped) + def unapply(x: Tree): Option[Tree] = Some(unbind(x)) } final def definedVars(x: Tree): List[Symbol] = { @@ -215,29 +230,25 @@ trait PatternNodes extends ast.TreeDSL } /** pvar: the symbol of the pattern variable - * temp: the temp variable that holds the actual value + * tvar: the temporary variable that holds the actual value */ - case class Binding(pvar: Symbol, temp: Symbol) { - override def toString() = "%s @ %s".format(pvar.name, temp.name) + case class Binding(pvar: Symbol, tvar: Symbol) { + override def toString() = "%s @ %s".format(pvar.name, tvar.name) } case class Bindings(bindings: Binding*) extends Function1[Symbol, Option[Ident]] { - private def castIfNeeded(pvar: Symbol, t: Symbol) = - if (t.tpe <:< pvar.tpe) ID(t) else ID(t) AS_ANY pvar.tpe + private def castIfNeeded(pvar: Symbol, tvar: Symbol) = + if (tvar.tpe <:< pvar.tpe) ID(tvar) + else ID(tvar) AS_ANY pvar.tpe - def add(vs: Iterable[Symbol], temp: Symbol): Bindings = - Bindings((vs.toList map (Binding(_: Symbol, temp))) ++ bindings : _*) + def add(vs: Iterable[Symbol], tvar: Symbol): Bindings = + Bindings((vs.toList map (Binding(_: Symbol, tvar))) ++ bindings : _*) def apply(v: Symbol): Option[Ident] = - (bindings find (_.pvar eq v) map (x => Ident(x.temp) setType v.tpe)) match { - case res @ Some(_) => traceAndReturn("Bindings.apply = ", res) - case res => res - } + bindings find (_.pvar eq v) map (x => Ident(x.tvar) setType v.tpe) - override def toString() = bindings.toList match { - case Nil => "" - case xs => xs.mkString(" Bound(", ", ", ")") - } + override def toString() = + if (bindings.isEmpty) "" else bindings.mkString(" Bound(", ", ", ")") /** The corresponding list of value definitions. */ final def targetParams(implicit typer: Typer): List[ValDef] = |