diff options
author | Burak Emir <emir@epfl.ch> | 2006-12-25 11:30:13 +0000 |
---|---|---|
committer | Burak Emir <emir@epfl.ch> | 2006-12-25 11:30:13 +0000 |
commit | 8c2a69d14e1566a7468a530ccfc2d08608483753 (patch) | |
tree | 3afcb07916c49a5c405a16c4a0e73229e7ec85ad | |
parent | 2820d1ff440f460e4fdd124cccec2bfa52dffaa8 (diff) | |
download | scala-8c2a69d14e1566a7468a530ccfc2d08608483753.tar.gz scala-8c2a69d14e1566a7468a530ccfc2d08608483753.tar.bz2 scala-8c2a69d14e1566a7468a530ccfc2d08608483753.zip |
unapply <-> as in "Matching with Objects"
-rw-r--r-- | src/compiler/scala/tools/nsc/matching/PatternMatchers.scala | 85 | ||||
-rw-r--r-- | src/compiler/scala/tools/nsc/matching/PatternNodes.scala | 4 | ||||
-rw-r--r-- | src/compiler/scala/tools/nsc/symtab/Definitions.scala | 85 | ||||
-rw-r--r-- | src/compiler/scala/tools/nsc/transform/ExplicitOuter.scala | 3 | ||||
-rw-r--r-- | src/compiler/scala/tools/nsc/transform/UnCurry.scala | 14 | ||||
-rw-r--r-- | src/compiler/scala/tools/nsc/typechecker/Typers.scala | 56 | ||||
-rw-r--r-- | src/library/scala/List.scala | 3 | ||||
-rw-r--r-- | test/files/run/patmatnew.scala | 20 | ||||
-rw-r--r-- | test/files/run/regularpatmatnew.scala | 20 |
9 files changed, 214 insertions, 76 deletions
diff --git a/src/compiler/scala/tools/nsc/matching/PatternMatchers.scala b/src/compiler/scala/tools/nsc/matching/PatternMatchers.scala index 7337d01155..b627a7027f 100644 --- a/src/compiler/scala/tools/nsc/matching/PatternMatchers.scala +++ b/src/compiler/scala/tools/nsc/matching/PatternMatchers.scala @@ -139,7 +139,10 @@ trait PatternMatchers requires (transform.ExplicitOuter with PatternNodes) { } def pUnapplyPat(pos: PositionType, fn: Tree) = { - val tpe = definitions.productType(definitions.optionOfProductElems(fn.tpe)) // is subtype of producttype + var tpe = definitions.unapplyUnwrap(fn.tpe) + if(definitions.isOptionOrSomeType(tpe)) { + tpe = tpe.typeArgs.head + } val node = new UnapplyPat(newVar(pos, tpe), fn) node.pos = pos node.setType(tpe) @@ -612,16 +615,23 @@ print() //assert index >= 0 : casted; if (index < 0) { Predef.error("error entering:" + casted); return null } - //access the index'th child of a case class - - curHeader = newHeader(pat.pos, casted, index) - target.and = curHeader; // (*) - - if (bodycond ne null) target.and = bodycond(target.and) // restores body with the guards + target match { + case u @ UnapplyPat(_,_) if u.returnsOne => + assert(index==0) + curHeader = pHeader(pat.pos, casted.tpe, Ident(casted) setType casted.tpe) + target.and = curHeader + curHeader.or = patternNode(pat, curHeader, env) + enter(patArgs, curHeader.or, casted, env) + case _ => + //access the index'th child of a case class + curHeader = newHeader(pat.pos, casted, index) + target.and = curHeader; // (*) - curHeader.or = patternNode(pat, curHeader, env) - enter(patArgs, curHeader.or, casted, env) + if (bodycond ne null) target.and = bodycond(target.and) // restores body with the guards + curHeader.or = patternNode(pat, curHeader, env) + enter(patArgs, curHeader.or, casted, env) + } } else { //Console.println(" enter: using old header for casted = "+casted) // DBG // find most recent header @@ -967,8 +977,6 @@ print() def generalSwitchToTree(): Tree = { this.exit = owner.newLabel(root.pos, "exit").setInfo(new MethodType(List(resultType), resultType)); val result = exit.newValueParameter(root.pos, "result").setInfo( resultType ); - - squeezedBlock( List( ValDef(root.symbol, selector), @@ -1182,21 +1190,46 @@ print() case DefaultPat() => return toTree(node.and); - case UnapplyPat(casted, fn @ Apply(fn2, _)) => - val fntpe = if(settings.Xkilloption.value) fn.tpe.typeArgs(0) else fn.tpe - //Console.println("unapply "+fntpe) - val v = newVar(fn.pos, fntpe) - var __opt_get__ = if(settings.Xkilloption.value) Ident(v) else typed(Select(Ident(v),nme.get)) - var __opt_nonemp__ = if(settings.Xkilloption.value) NotNull(Ident(v)) else Not(Select(Ident(v), nme.isEmpty)) - - var fn1 = fn2.duplicate; if(settings.Xkilloption.value) { - fn1.setType(transformInfo(owner, fn1.tpe)) - } - Or(squeezedBlock( - List(ValDef(v,Apply(fn1,List(selector)))), - And(__opt_nonemp__, - squeezedBlock(List(ValDef(casted, __opt_get__)),toTree(node.and)))), - toTree(node.or, selector.duplicate)) + case UnapplyPat(casted, Apply(fn1, _)) if casted.tpe.symbol == definitions.BooleanClass => // special case + var useSelector = selector + val checkType = fn1.tpe match { + case MethodType(List(argtpe),_) => + if(isSubType(selector.tpe, argtpe)) + Literal(Constant(true)) + else { + useSelector = gen.mkAsInstanceOf(selector, argtpe) + gen.mkIsInstanceOf(selector.duplicate, argtpe) + } + } + Or(And(checkType, + And(Apply(fn1,List(useSelector)), + toTree(node.and)) + ), + toTree(node.or, selector.duplicate)) + + case UnapplyPat(casted, fn @ Apply(fn1, _)) => + var useSelector = selector + val checkType = fn1.tpe match { + case MethodType(List(argtpe),_) => + if(isSubType(selector.tpe, argtpe)) + Literal(Constant(true)) + else { + useSelector = gen.mkAsInstanceOf(selector, argtpe) + gen.mkIsInstanceOf(selector.duplicate, argtpe) + } + } + val fntpe = fn.tpe + val v = newVar(fn.pos, fntpe) + var __opt_get__ = typed(Select(Ident(v),nme.get)) + var __opt_nonemp__ = Not(Select(Ident(v), nme.isEmpty)) + + Or(And(checkType, + squeezedBlock( + List(ValDef(v,Apply(fn1,List(useSelector)))), + And(__opt_nonemp__, + squeezedBlock(List(ValDef(casted, __opt_get__)),toTree(node.and)))) + ), + toTree(node.or, selector.duplicate)) case ConstrPat(casted) => if(!isSubType(casted.tpe,selector.tpe) && settings.Xkilloption.value && definitions.isOptionType(selector.tpe)) { diff --git a/src/compiler/scala/tools/nsc/matching/PatternNodes.scala b/src/compiler/scala/tools/nsc/matching/PatternNodes.scala index b048dd673b..11cf8f88fa 100644 --- a/src/compiler/scala/tools/nsc/matching/PatternNodes.scala +++ b/src/compiler/scala/tools/nsc/matching/PatternNodes.scala @@ -273,7 +273,9 @@ trait PatternNodes requires transform.ExplicitOuter { case class DefaultPat()extends PatternNode case class ConstrPat(casted:Symbol) extends PatternNode - case class UnapplyPat(prod:Symbol, fn:Tree) extends PatternNode + case class UnapplyPat(casted:Symbol, fn:Tree) extends PatternNode { + def returnsOne = !definitions.isProductType(casted.tpe) + } case class ConstantPat(value: Any /*AConstant*/) extends PatternNode case class VariablePat(tree: Tree) extends PatternNode case class AltPat(subheader: Header) extends PatternNode diff --git a/src/compiler/scala/tools/nsc/symtab/Definitions.scala b/src/compiler/scala/tools/nsc/symtab/Definitions.scala index 36906ad7e7..9103b7a69f 100644 --- a/src/compiler/scala/tools/nsc/symtab/Definitions.scala +++ b/src/compiler/scala/tools/nsc/symtab/Definitions.scala @@ -139,6 +139,10 @@ trait Definitions requires SymbolTable { case TypeRef(_, sym, List(_)) if sym == OptionClass => true case _ => false } + def isOptionOrSomeType(tp: Type) = tp match { + case TypeRef(_, sym, List(_)) => sym == OptionClass || sym == SomeClass + case _ => false + } def optionType(tp: Type) = typeRef(OptionClass.typeConstructor.prefix, OptionClass, List(tp)) @@ -154,21 +158,80 @@ trait Definitions requires SymbolTable { case _ => false } - def optionOfProductElems(tp: Type): List[Type] = { - assert(tp.symbol == OptionClass) - val prod = tp.typeArgs.head - if (prod.symbol == UnitClass) List() - else prod.baseClasses.find { x => isProductType(x.tpe) } match { - case Some(p) => prod.baseType(p).typeArgs + + def unapplyUnwrap(tpe:Type) = tpe match { + case PolyType(_,MethodType(_, res)) => res + case MethodType(_, res) => res + case tpe => tpe + } + + /** returns type list for return type of the extraction */ + def unapplyTypeList(ufn: Symbol, ufntpe: Type) = { + assert(ufn.isMethod) + //Console.println("utl "+ufntpe+" "+ufntpe.symbol) + ufn.name match { + case nme.unapply => unapplyTypeListFromReturnType(ufntpe) + case nme.unapplySeq => unapplyTypeListFromReturnTypeSeq(ufntpe) + case _ => throw new IllegalArgumentException("expected function symbol of extraction") + } + } + /** (the inverse of unapplyReturnTypeSeq) + * for type Boolean, returns Nil + * for type Option[T] or Some[T]: + * - returns T0...Tn if n>0 and T <: Product[T0...Tn]] + * - returns T otherwise + */ + def unapplyTypeListFromReturnType(tp1: Type): List[Type] = { // rename: unapplyTypeListFromReturnType + val tp = unapplyUnwrap(tp1) + val B = BooleanClass; + val O = OptionClass; val S = SomeClass; tp.symbol match { // unapplySeqResultToMethodSig + case B => Nil + case O | S => + val prod = tp.typeArgs.head + prod.baseClasses.find { x => isProductType(x.tpe) } match { + case Some(p) => prod.baseType(p).typeArgs + case _ => prod::Nil // special case n = 0 + } + case _ => throw new IllegalArgumentException(tp.symbol + " in not in {boolean, option, some}") } } - def unapplySeqResultToMethodSig(tp: Type) = { - val ts = optionOfProductElems(tp) - val last1 = ts.last.baseType(SeqClass) match { - case TypeRef(pre, seqClass, args) => typeRef(pre, RepeatedParamClass, args) + /** let type be the result type of the (possibly polymorphic) unapply method + * for type Option[T] or Some[T] + * -returns T0...Tn-1,Tn* if n>0 and T <: Product[T0...Tn-1,Seq[Tn]]], + * -returns R* if T = Seq[R] + */ + def unapplyTypeListFromReturnTypeSeq(tp1: Type): List[Type] = { + val tp = unapplyUnwrap(tp1) + val O = OptionClass; val S = SomeClass; tp.symbol match { + case O | S => + val ts = unapplyTypeListFromReturnType(tp1) + val last1 = ts.last.baseType(SeqClass) match { + case TypeRef(pre, seqClass, args) => typeRef(pre, RepeatedParamClass, args) + case _ => throw new IllegalArgumentException("last not seq") + } + ts.init ::: List(last1) + case _ => throw new IllegalArgumentException(tp.symbol + " in not in {option, some}") } - ts.init ::: List(last1) + } + + /** returns type of the unapply method returning T_0...T_n + * for n == 0, boolean + * for n == 1, Some[T0] + * else Some[Product[Ti]] + def unapplyReturnType(elems: List[Type], useWildCards: Boolean) = + if (elems.isEmpty) + BooleanClass.tpe + else if (elems.length == 1) + optionType(if(useWildCards) WildcardType else elems(0)) + else + productType({val es = elems; if(useWildCards) elems map { x => WildcardType} else elems}) + */ + + def unapplyReturnTypeExpected(argsLength: int) = argsLength match { + case 0 => BooleanClass.tpe + case 1 => optionType(WildcardType) + case n => optionType(productType(List.range(0,n).map (arg => WildcardType))) } /** returns unapply or unapplySeq if available */ diff --git a/src/compiler/scala/tools/nsc/transform/ExplicitOuter.scala b/src/compiler/scala/tools/nsc/transform/ExplicitOuter.scala index b205c38f55..14d4f7d5a9 100644 --- a/src/compiler/scala/tools/nsc/transform/ExplicitOuter.scala +++ b/src/compiler/scala/tools/nsc/transform/ExplicitOuter.scala @@ -304,6 +304,9 @@ abstract class ExplicitOuter extends InfoTransform with TransMatcher with Patter } catch {//debug case ex: Throwable => Console.println("exception when transforming " + tree) + //Console.println(ex.getMessage) + //Console.println(ex.printStackTrace()) + //System.exit(-1); throw ex } finally { outerParam = savedOuterParam diff --git a/src/compiler/scala/tools/nsc/transform/UnCurry.scala b/src/compiler/scala/tools/nsc/transform/UnCurry.scala index 167b8be1c8..05bd8b664d 100644 --- a/src/compiler/scala/tools/nsc/transform/UnCurry.scala +++ b/src/compiler/scala/tools/nsc/transform/UnCurry.scala @@ -368,12 +368,16 @@ abstract class UnCurry extends InfoTransform with TypingTransformers { mainTransform(new TreeSubstituter(vparams map (.symbol), args).transform(body)) } - case UnApply(fn, args) => - if(fn.symbol.name == nme.unapply) copy.UnApply(tree, fn, transformTrees(args)) - else { - val formals = definitions.unapplySeqResultToMethodSig(fn.tpe) + case UnApply(fn, args) if nme.unapply == fn.symbol.name => + copy.UnApply(tree, fn, transformTrees(args)) + + case UnApply(fn, args) if nme.unapplySeq == fn.symbol.name => + val formals = definitions.unapplyTypeListFromReturnTypeSeq(fn.tpe) copy.UnApply(tree, fn, transformTrees(transformArgs(tree.pos, args, formals))) - } + + case UnApply(_, _) => + assert(false,"internal error: UnApply node has wrong symbol"); null + case Apply(fn, args) => if (settings.noassertions.value && (fn.symbol ne null) && diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala index aa6f5148f8..d262ad4855 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala @@ -561,7 +561,7 @@ trait Typers requires Analyzer { throw t } tree1 - } else if (clazz.isNonBottomSubClass(SeqClass)) { // (5.2) + } else if (!settings.Xunapply.value && clazz.isNonBottomSubClass(SeqClass)) { // (5.2) val restpe = pt.baseType(clazz) restpe.baseType(SeqClass) match { case TypeRef(pre, seqClass, args) => @@ -1396,46 +1396,46 @@ trait Typers requires Analyzer { val unapp = definitions.unapplyMember(otpe) assert(unapp.exists, tree) assert(isFullyDefined(pt)) + + val unappArg:Type = unapp.tpe match { + case PolyType(_,MethodType(List(res), _)) => res + case MethodType(List(res), _) => res + case _ => error(fun.pos, "unapply takes too many arguments to be used as pattern"); NoType + } + val argDummy = context.owner.newValue(fun.pos, nme.SELECTOR_DUMMY) .setFlag(SYNTHETIC) - .setInfo(pt) + .setInfo(unappArg) // was: pt if (args.length > MaxTupleArity) error(fun.pos, "too many arguments for unapply pattern, maximum = "+MaxTupleArity) - val arg = Ident(argDummy) setType pt - var prod: Type = null - - if (nme.unapplySeq == unapp.name) { - def failSeq = { - error(fun.pos, " unapplySeq should return Option[T] for T<:Product?[...Seq[?]]") - setError(tree) - } - unapp.tpe.resultType match { // last should be seq... - case TypeRef(_, opt, List(tpe)) if opt == definitions.OptionClass => - // the following is almost definitions.optionOfProductElems, but error handling? - tpe.baseClasses.find { x => isProductType(x.tpe) } match { - case Some(p) => - val xs = tpe.baseType(p).typeArgs - if(xs.isEmpty) - return failSeq - prod = productType((xs.tail map {x => WildcardType}) ::: List(seqType(WildcardType))) - case None => return failSeq - } - case _ => return failSeq + val arg = Ident(argDummy) setType unappArg // was pt + var funPt: Type = null + try { + funPt = unapp.name match { + case nme.unapply => unapplyReturnTypeExpected(args.length) + case nme.unapplySeq => unapplyTypeListFromReturnTypeSeq(unapp.tpe) match { + case List() => null //fail + case List(TypeRef(pre,repeatedParam, tpe)) => optionType(seqType(WildcardType)) //succeed + case xs => optionType(productType((xs.tail map {x => WildcardType}) ::: List(seqType(WildcardType))))// succeed + } } - } else if (args.length == 0) prod = UnitClass.tpe - else prod = productType(args map (arg => WildcardType)) - - val funPt = appliedType(OptionClass.typeConstructor, List(prod)) + } catch { + case ex => //failure + //Console.println("DEBUG") + //ex.printStackTrace() + error(fun.pos, " unapplySeq should return Option[T] for T<:Product?[...Seq[?]]") + return setError(tree) + } val fun0 = Ident(fun.symbol) setPos fun.pos setType otpe // would this change when patterns are terms??? val fun1untyped = atPos(fun.pos) { Apply(Select(gen.mkAttributedRef(fun.tpe.prefix,fun.symbol), unapp), List(arg)) } //Console.println("UNAPP "+fun+"/"+fun.tpe+" "+fun1untyped) + //Console.println("funPt: "+funPt) val fun1 = typed(fun1untyped, EXPRmode, funPt) if (fun1.tpe.isErroneous) setError(tree) else { - val formals0 = if(unapp.name == nme.unapply) optionOfProductElems(fun1.tpe) - else unapplySeqResultToMethodSig(fun1.tpe) + val formals0 = unapplyTypeList(fun1.symbol, fun1.tpe) val formals1 = formalTypes(formals0, args.length) val args1 = typedArgs(args, mode, formals0, formals1) UnApply(fun1, args1) setPos tree.pos setType pt diff --git a/src/library/scala/List.scala b/src/library/scala/List.scala index 0b41e0221d..362b2744ff 100644 --- a/src/library/scala/List.scala +++ b/src/library/scala/List.scala @@ -31,8 +31,7 @@ object List { /** for unapply matching */ - def unapplySeq[A](x: Any): Option[Tuple1[List[A]]] = - if (x.isInstanceOf[List[A]]) Some(Tuple1(x.asInstanceOf[List[A]])) else None + def unapplySeq[A](x: List[A]): Option[List[A]] = Some(x) /** Create a sorted list of all integers in a range. * diff --git a/test/files/run/patmatnew.scala b/test/files/run/patmatnew.scala index 377e89a267..10b0021f0f 100644 --- a/test/files/run/patmatnew.scala +++ b/test/files/run/patmatnew.scala @@ -20,8 +20,8 @@ object Test { val tr = new TestResult new TestSuite( - new Test717 - + new Test717, + new TestGuards ).run(tr) for(val f <- tr.failures()) @@ -44,6 +44,22 @@ object Test { } } + class TestGuards extends TestCase("multiple guards for same pattern") with Shmeez { + val tree:Tree = Beez(2) + override def runTest = { + val res = tree match { + case Beez(x) if x == 3 => false + case Beez(x) if x == 2 => true + } + assertTrue("ok", res); + val ret = (Beez(3):Tree) match { + case Beez(x) if x == 3 => true + case Beez(x) if x == 2 => false + } + assertTrue("ok", ret); + } + } + class Test806_818 { // #806, #811 compile only -- type of bind // bug811 trait Core { diff --git a/test/files/run/regularpatmatnew.scala b/test/files/run/regularpatmatnew.scala index 9e056af9f8..a12ed5dfc6 100644 --- a/test/files/run/regularpatmatnew.scala +++ b/test/files/run/regularpatmatnew.scala @@ -10,7 +10,8 @@ object Test { new Test03, new Test04, new Test05, - new Test06 + new Test06, + new Test07 ).run(tr) @@ -131,4 +132,21 @@ object Test { } } + + class Test07 extends TestCase("sette List of chars") { + def doMatch1(xs:List[char]) = xs match { + case List(x, y, _*) => x::y::Nil + } + def doMatch2(xs:List[char]) = xs match { + case List(x, y, z, w) => List(z,w) + } + //def doMatch3(xs:List[char]) = xs match { + // case List(_*, z, w) => w::Nil + //} + override def runTest() { + assertEquals(doMatch1(List('a','b','c','d')), List('a','b')) + assertEquals(doMatch2(List('a','b','c','d')), List('c','d')) + //assertEquals(doMatch3(List('a','b','c','d')), List('d')) + } + } } |