summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBurak Emir <emir@epfl.ch>2006-12-25 11:30:13 +0000
committerBurak Emir <emir@epfl.ch>2006-12-25 11:30:13 +0000
commit8c2a69d14e1566a7468a530ccfc2d08608483753 (patch)
tree3afcb07916c49a5c405a16c4a0e73229e7ec85ad
parent2820d1ff440f460e4fdd124cccec2bfa52dffaa8 (diff)
downloadscala-8c2a69d14e1566a7468a530ccfc2d08608483753.tar.gz
scala-8c2a69d14e1566a7468a530ccfc2d08608483753.tar.bz2
scala-8c2a69d14e1566a7468a530ccfc2d08608483753.zip
unapply <-> as in "Matching with Objects"
-rw-r--r--src/compiler/scala/tools/nsc/matching/PatternMatchers.scala85
-rw-r--r--src/compiler/scala/tools/nsc/matching/PatternNodes.scala4
-rw-r--r--src/compiler/scala/tools/nsc/symtab/Definitions.scala85
-rw-r--r--src/compiler/scala/tools/nsc/transform/ExplicitOuter.scala3
-rw-r--r--src/compiler/scala/tools/nsc/transform/UnCurry.scala14
-rw-r--r--src/compiler/scala/tools/nsc/typechecker/Typers.scala56
-rw-r--r--src/library/scala/List.scala3
-rw-r--r--test/files/run/patmatnew.scala20
-rw-r--r--test/files/run/regularpatmatnew.scala20
9 files changed, 214 insertions, 76 deletions
diff --git a/src/compiler/scala/tools/nsc/matching/PatternMatchers.scala b/src/compiler/scala/tools/nsc/matching/PatternMatchers.scala
index 7337d01155..b627a7027f 100644
--- a/src/compiler/scala/tools/nsc/matching/PatternMatchers.scala
+++ b/src/compiler/scala/tools/nsc/matching/PatternMatchers.scala
@@ -139,7 +139,10 @@ trait PatternMatchers requires (transform.ExplicitOuter with PatternNodes) {
}
def pUnapplyPat(pos: PositionType, fn: Tree) = {
- val tpe = definitions.productType(definitions.optionOfProductElems(fn.tpe)) // is subtype of producttype
+ var tpe = definitions.unapplyUnwrap(fn.tpe)
+ if(definitions.isOptionOrSomeType(tpe)) {
+ tpe = tpe.typeArgs.head
+ }
val node = new UnapplyPat(newVar(pos, tpe), fn)
node.pos = pos
node.setType(tpe)
@@ -612,16 +615,23 @@ print()
//assert index >= 0 : casted;
if (index < 0) { Predef.error("error entering:" + casted); return null }
- //access the index'th child of a case class
-
- curHeader = newHeader(pat.pos, casted, index)
- target.and = curHeader; // (*)
-
- if (bodycond ne null) target.and = bodycond(target.and) // restores body with the guards
+ target match {
+ case u @ UnapplyPat(_,_) if u.returnsOne =>
+ assert(index==0)
+ curHeader = pHeader(pat.pos, casted.tpe, Ident(casted) setType casted.tpe)
+ target.and = curHeader
+ curHeader.or = patternNode(pat, curHeader, env)
+ enter(patArgs, curHeader.or, casted, env)
+ case _ =>
+ //access the index'th child of a case class
+ curHeader = newHeader(pat.pos, casted, index)
+ target.and = curHeader; // (*)
- curHeader.or = patternNode(pat, curHeader, env)
- enter(patArgs, curHeader.or, casted, env)
+ if (bodycond ne null) target.and = bodycond(target.and) // restores body with the guards
+ curHeader.or = patternNode(pat, curHeader, env)
+ enter(patArgs, curHeader.or, casted, env)
+ }
} else {
//Console.println(" enter: using old header for casted = "+casted) // DBG
// find most recent header
@@ -967,8 +977,6 @@ print()
def generalSwitchToTree(): Tree = {
this.exit = owner.newLabel(root.pos, "exit").setInfo(new MethodType(List(resultType), resultType));
val result = exit.newValueParameter(root.pos, "result").setInfo( resultType );
-
-
squeezedBlock(
List(
ValDef(root.symbol, selector),
@@ -1182,21 +1190,46 @@ print()
case DefaultPat() =>
return toTree(node.and);
- case UnapplyPat(casted, fn @ Apply(fn2, _)) =>
- val fntpe = if(settings.Xkilloption.value) fn.tpe.typeArgs(0) else fn.tpe
- //Console.println("unapply "+fntpe)
- val v = newVar(fn.pos, fntpe)
- var __opt_get__ = if(settings.Xkilloption.value) Ident(v) else typed(Select(Ident(v),nme.get))
- var __opt_nonemp__ = if(settings.Xkilloption.value) NotNull(Ident(v)) else Not(Select(Ident(v), nme.isEmpty))
-
- var fn1 = fn2.duplicate; if(settings.Xkilloption.value) {
- fn1.setType(transformInfo(owner, fn1.tpe))
- }
- Or(squeezedBlock(
- List(ValDef(v,Apply(fn1,List(selector)))),
- And(__opt_nonemp__,
- squeezedBlock(List(ValDef(casted, __opt_get__)),toTree(node.and)))),
- toTree(node.or, selector.duplicate))
+ case UnapplyPat(casted, Apply(fn1, _)) if casted.tpe.symbol == definitions.BooleanClass => // special case
+ var useSelector = selector
+ val checkType = fn1.tpe match {
+ case MethodType(List(argtpe),_) =>
+ if(isSubType(selector.tpe, argtpe))
+ Literal(Constant(true))
+ else {
+ useSelector = gen.mkAsInstanceOf(selector, argtpe)
+ gen.mkIsInstanceOf(selector.duplicate, argtpe)
+ }
+ }
+ Or(And(checkType,
+ And(Apply(fn1,List(useSelector)),
+ toTree(node.and))
+ ),
+ toTree(node.or, selector.duplicate))
+
+ case UnapplyPat(casted, fn @ Apply(fn1, _)) =>
+ var useSelector = selector
+ val checkType = fn1.tpe match {
+ case MethodType(List(argtpe),_) =>
+ if(isSubType(selector.tpe, argtpe))
+ Literal(Constant(true))
+ else {
+ useSelector = gen.mkAsInstanceOf(selector, argtpe)
+ gen.mkIsInstanceOf(selector.duplicate, argtpe)
+ }
+ }
+ val fntpe = fn.tpe
+ val v = newVar(fn.pos, fntpe)
+ var __opt_get__ = typed(Select(Ident(v),nme.get))
+ var __opt_nonemp__ = Not(Select(Ident(v), nme.isEmpty))
+
+ Or(And(checkType,
+ squeezedBlock(
+ List(ValDef(v,Apply(fn1,List(useSelector)))),
+ And(__opt_nonemp__,
+ squeezedBlock(List(ValDef(casted, __opt_get__)),toTree(node.and))))
+ ),
+ toTree(node.or, selector.duplicate))
case ConstrPat(casted) =>
if(!isSubType(casted.tpe,selector.tpe) && settings.Xkilloption.value && definitions.isOptionType(selector.tpe)) {
diff --git a/src/compiler/scala/tools/nsc/matching/PatternNodes.scala b/src/compiler/scala/tools/nsc/matching/PatternNodes.scala
index b048dd673b..11cf8f88fa 100644
--- a/src/compiler/scala/tools/nsc/matching/PatternNodes.scala
+++ b/src/compiler/scala/tools/nsc/matching/PatternNodes.scala
@@ -273,7 +273,9 @@ trait PatternNodes requires transform.ExplicitOuter {
case class DefaultPat()extends PatternNode
case class ConstrPat(casted:Symbol) extends PatternNode
- case class UnapplyPat(prod:Symbol, fn:Tree) extends PatternNode
+ case class UnapplyPat(casted:Symbol, fn:Tree) extends PatternNode {
+ def returnsOne = !definitions.isProductType(casted.tpe)
+ }
case class ConstantPat(value: Any /*AConstant*/) extends PatternNode
case class VariablePat(tree: Tree) extends PatternNode
case class AltPat(subheader: Header) extends PatternNode
diff --git a/src/compiler/scala/tools/nsc/symtab/Definitions.scala b/src/compiler/scala/tools/nsc/symtab/Definitions.scala
index 36906ad7e7..9103b7a69f 100644
--- a/src/compiler/scala/tools/nsc/symtab/Definitions.scala
+++ b/src/compiler/scala/tools/nsc/symtab/Definitions.scala
@@ -139,6 +139,10 @@ trait Definitions requires SymbolTable {
case TypeRef(_, sym, List(_)) if sym == OptionClass => true
case _ => false
}
+ def isOptionOrSomeType(tp: Type) = tp match {
+ case TypeRef(_, sym, List(_)) => sym == OptionClass || sym == SomeClass
+ case _ => false
+ }
def optionType(tp: Type) =
typeRef(OptionClass.typeConstructor.prefix, OptionClass, List(tp))
@@ -154,21 +158,80 @@ trait Definitions requires SymbolTable {
case _ => false
}
- def optionOfProductElems(tp: Type): List[Type] = {
- assert(tp.symbol == OptionClass)
- val prod = tp.typeArgs.head
- if (prod.symbol == UnitClass) List()
- else prod.baseClasses.find { x => isProductType(x.tpe) } match {
- case Some(p) => prod.baseType(p).typeArgs
+
+ def unapplyUnwrap(tpe:Type) = tpe match {
+ case PolyType(_,MethodType(_, res)) => res
+ case MethodType(_, res) => res
+ case tpe => tpe
+ }
+
+ /** returns type list for return type of the extraction */
+ def unapplyTypeList(ufn: Symbol, ufntpe: Type) = {
+ assert(ufn.isMethod)
+ //Console.println("utl "+ufntpe+" "+ufntpe.symbol)
+ ufn.name match {
+ case nme.unapply => unapplyTypeListFromReturnType(ufntpe)
+ case nme.unapplySeq => unapplyTypeListFromReturnTypeSeq(ufntpe)
+ case _ => throw new IllegalArgumentException("expected function symbol of extraction")
+ }
+ }
+ /** (the inverse of unapplyReturnTypeSeq)
+ * for type Boolean, returns Nil
+ * for type Option[T] or Some[T]:
+ * - returns T0...Tn if n>0 and T <: Product[T0...Tn]]
+ * - returns T otherwise
+ */
+ def unapplyTypeListFromReturnType(tp1: Type): List[Type] = { // rename: unapplyTypeListFromReturnType
+ val tp = unapplyUnwrap(tp1)
+ val B = BooleanClass;
+ val O = OptionClass; val S = SomeClass; tp.symbol match { // unapplySeqResultToMethodSig
+ case B => Nil
+ case O | S =>
+ val prod = tp.typeArgs.head
+ prod.baseClasses.find { x => isProductType(x.tpe) } match {
+ case Some(p) => prod.baseType(p).typeArgs
+ case _ => prod::Nil // special case n = 0
+ }
+ case _ => throw new IllegalArgumentException(tp.symbol + " in not in {boolean, option, some}")
}
}
- def unapplySeqResultToMethodSig(tp: Type) = {
- val ts = optionOfProductElems(tp)
- val last1 = ts.last.baseType(SeqClass) match {
- case TypeRef(pre, seqClass, args) => typeRef(pre, RepeatedParamClass, args)
+ /** let type be the result type of the (possibly polymorphic) unapply method
+ * for type Option[T] or Some[T]
+ * -returns T0...Tn-1,Tn* if n>0 and T <: Product[T0...Tn-1,Seq[Tn]]],
+ * -returns R* if T = Seq[R]
+ */
+ def unapplyTypeListFromReturnTypeSeq(tp1: Type): List[Type] = {
+ val tp = unapplyUnwrap(tp1)
+ val O = OptionClass; val S = SomeClass; tp.symbol match {
+ case O | S =>
+ val ts = unapplyTypeListFromReturnType(tp1)
+ val last1 = ts.last.baseType(SeqClass) match {
+ case TypeRef(pre, seqClass, args) => typeRef(pre, RepeatedParamClass, args)
+ case _ => throw new IllegalArgumentException("last not seq")
+ }
+ ts.init ::: List(last1)
+ case _ => throw new IllegalArgumentException(tp.symbol + " in not in {option, some}")
}
- ts.init ::: List(last1)
+ }
+
+ /** returns type of the unapply method returning T_0...T_n
+ * for n == 0, boolean
+ * for n == 1, Some[T0]
+ * else Some[Product[Ti]]
+ def unapplyReturnType(elems: List[Type], useWildCards: Boolean) =
+ if (elems.isEmpty)
+ BooleanClass.tpe
+ else if (elems.length == 1)
+ optionType(if(useWildCards) WildcardType else elems(0))
+ else
+ productType({val es = elems; if(useWildCards) elems map { x => WildcardType} else elems})
+ */
+
+ def unapplyReturnTypeExpected(argsLength: int) = argsLength match {
+ case 0 => BooleanClass.tpe
+ case 1 => optionType(WildcardType)
+ case n => optionType(productType(List.range(0,n).map (arg => WildcardType)))
}
/** returns unapply or unapplySeq if available */
diff --git a/src/compiler/scala/tools/nsc/transform/ExplicitOuter.scala b/src/compiler/scala/tools/nsc/transform/ExplicitOuter.scala
index b205c38f55..14d4f7d5a9 100644
--- a/src/compiler/scala/tools/nsc/transform/ExplicitOuter.scala
+++ b/src/compiler/scala/tools/nsc/transform/ExplicitOuter.scala
@@ -304,6 +304,9 @@ abstract class ExplicitOuter extends InfoTransform with TransMatcher with Patter
} catch {//debug
case ex: Throwable =>
Console.println("exception when transforming " + tree)
+ //Console.println(ex.getMessage)
+ //Console.println(ex.printStackTrace())
+ //System.exit(-1);
throw ex
} finally {
outerParam = savedOuterParam
diff --git a/src/compiler/scala/tools/nsc/transform/UnCurry.scala b/src/compiler/scala/tools/nsc/transform/UnCurry.scala
index 167b8be1c8..05bd8b664d 100644
--- a/src/compiler/scala/tools/nsc/transform/UnCurry.scala
+++ b/src/compiler/scala/tools/nsc/transform/UnCurry.scala
@@ -368,12 +368,16 @@ abstract class UnCurry extends InfoTransform with TypingTransformers {
mainTransform(new TreeSubstituter(vparams map (.symbol), args).transform(body))
}
- case UnApply(fn, args) =>
- if(fn.symbol.name == nme.unapply) copy.UnApply(tree, fn, transformTrees(args))
- else {
- val formals = definitions.unapplySeqResultToMethodSig(fn.tpe)
+ case UnApply(fn, args) if nme.unapply == fn.symbol.name =>
+ copy.UnApply(tree, fn, transformTrees(args))
+
+ case UnApply(fn, args) if nme.unapplySeq == fn.symbol.name =>
+ val formals = definitions.unapplyTypeListFromReturnTypeSeq(fn.tpe)
copy.UnApply(tree, fn, transformTrees(transformArgs(tree.pos, args, formals)))
- }
+
+ case UnApply(_, _) =>
+ assert(false,"internal error: UnApply node has wrong symbol"); null
+
case Apply(fn, args) =>
if (settings.noassertions.value &&
(fn.symbol ne null) &&
diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala
index aa6f5148f8..d262ad4855 100644
--- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala
+++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala
@@ -561,7 +561,7 @@ trait Typers requires Analyzer {
throw t
}
tree1
- } else if (clazz.isNonBottomSubClass(SeqClass)) { // (5.2)
+ } else if (!settings.Xunapply.value && clazz.isNonBottomSubClass(SeqClass)) { // (5.2)
val restpe = pt.baseType(clazz)
restpe.baseType(SeqClass) match {
case TypeRef(pre, seqClass, args) =>
@@ -1396,46 +1396,46 @@ trait Typers requires Analyzer {
val unapp = definitions.unapplyMember(otpe)
assert(unapp.exists, tree)
assert(isFullyDefined(pt))
+
+ val unappArg:Type = unapp.tpe match {
+ case PolyType(_,MethodType(List(res), _)) => res
+ case MethodType(List(res), _) => res
+ case _ => error(fun.pos, "unapply takes too many arguments to be used as pattern"); NoType
+ }
+
val argDummy = context.owner.newValue(fun.pos, nme.SELECTOR_DUMMY)
.setFlag(SYNTHETIC)
- .setInfo(pt)
+ .setInfo(unappArg) // was: pt
if (args.length > MaxTupleArity)
error(fun.pos, "too many arguments for unapply pattern, maximum = "+MaxTupleArity)
- val arg = Ident(argDummy) setType pt
- var prod: Type = null
-
- if (nme.unapplySeq == unapp.name) {
- def failSeq = {
- error(fun.pos, " unapplySeq should return Option[T] for T<:Product?[...Seq[?]]")
- setError(tree)
- }
- unapp.tpe.resultType match { // last should be seq...
- case TypeRef(_, opt, List(tpe)) if opt == definitions.OptionClass =>
- // the following is almost definitions.optionOfProductElems, but error handling?
- tpe.baseClasses.find { x => isProductType(x.tpe) } match {
- case Some(p) =>
- val xs = tpe.baseType(p).typeArgs
- if(xs.isEmpty)
- return failSeq
- prod = productType((xs.tail map {x => WildcardType}) ::: List(seqType(WildcardType)))
- case None => return failSeq
- }
- case _ => return failSeq
+ val arg = Ident(argDummy) setType unappArg // was pt
+ var funPt: Type = null
+ try {
+ funPt = unapp.name match {
+ case nme.unapply => unapplyReturnTypeExpected(args.length)
+ case nme.unapplySeq => unapplyTypeListFromReturnTypeSeq(unapp.tpe) match {
+ case List() => null //fail
+ case List(TypeRef(pre,repeatedParam, tpe)) => optionType(seqType(WildcardType)) //succeed
+ case xs => optionType(productType((xs.tail map {x => WildcardType}) ::: List(seqType(WildcardType))))// succeed
+ }
}
- } else if (args.length == 0) prod = UnitClass.tpe
- else prod = productType(args map (arg => WildcardType))
-
- val funPt = appliedType(OptionClass.typeConstructor, List(prod))
+ } catch {
+ case ex => //failure
+ //Console.println("DEBUG")
+ //ex.printStackTrace()
+ error(fun.pos, " unapplySeq should return Option[T] for T<:Product?[...Seq[?]]")
+ return setError(tree)
+ }
val fun0 = Ident(fun.symbol) setPos fun.pos setType otpe // would this change when patterns are terms???
val fun1untyped = atPos(fun.pos) {
Apply(Select(gen.mkAttributedRef(fun.tpe.prefix,fun.symbol), unapp), List(arg))
}
//Console.println("UNAPP "+fun+"/"+fun.tpe+" "+fun1untyped)
+ //Console.println("funPt: "+funPt)
val fun1 = typed(fun1untyped, EXPRmode, funPt)
if (fun1.tpe.isErroneous) setError(tree)
else {
- val formals0 = if(unapp.name == nme.unapply) optionOfProductElems(fun1.tpe)
- else unapplySeqResultToMethodSig(fun1.tpe)
+ val formals0 = unapplyTypeList(fun1.symbol, fun1.tpe)
val formals1 = formalTypes(formals0, args.length)
val args1 = typedArgs(args, mode, formals0, formals1)
UnApply(fun1, args1) setPos tree.pos setType pt
diff --git a/src/library/scala/List.scala b/src/library/scala/List.scala
index 0b41e0221d..362b2744ff 100644
--- a/src/library/scala/List.scala
+++ b/src/library/scala/List.scala
@@ -31,8 +31,7 @@ object List {
/** for unapply matching
*/
- def unapplySeq[A](x: Any): Option[Tuple1[List[A]]] =
- if (x.isInstanceOf[List[A]]) Some(Tuple1(x.asInstanceOf[List[A]])) else None
+ def unapplySeq[A](x: List[A]): Option[List[A]] = Some(x)
/** Create a sorted list of all integers in a range.
*
diff --git a/test/files/run/patmatnew.scala b/test/files/run/patmatnew.scala
index 377e89a267..10b0021f0f 100644
--- a/test/files/run/patmatnew.scala
+++ b/test/files/run/patmatnew.scala
@@ -20,8 +20,8 @@ object Test {
val tr = new TestResult
new TestSuite(
- new Test717
-
+ new Test717,
+ new TestGuards
).run(tr)
for(val f <- tr.failures())
@@ -44,6 +44,22 @@ object Test {
}
}
+ class TestGuards extends TestCase("multiple guards for same pattern") with Shmeez {
+ val tree:Tree = Beez(2)
+ override def runTest = {
+ val res = tree match {
+ case Beez(x) if x == 3 => false
+ case Beez(x) if x == 2 => true
+ }
+ assertTrue("ok", res);
+ val ret = (Beez(3):Tree) match {
+ case Beez(x) if x == 3 => true
+ case Beez(x) if x == 2 => false
+ }
+ assertTrue("ok", ret);
+ }
+ }
+
class Test806_818 { // #806, #811 compile only -- type of bind
// bug811
trait Core {
diff --git a/test/files/run/regularpatmatnew.scala b/test/files/run/regularpatmatnew.scala
index 9e056af9f8..a12ed5dfc6 100644
--- a/test/files/run/regularpatmatnew.scala
+++ b/test/files/run/regularpatmatnew.scala
@@ -10,7 +10,8 @@ object Test {
new Test03,
new Test04,
new Test05,
- new Test06
+ new Test06,
+ new Test07
).run(tr)
@@ -131,4 +132,21 @@ object Test {
}
}
+
+ class Test07 extends TestCase("sette List of chars") {
+ def doMatch1(xs:List[char]) = xs match {
+ case List(x, y, _*) => x::y::Nil
+ }
+ def doMatch2(xs:List[char]) = xs match {
+ case List(x, y, z, w) => List(z,w)
+ }
+ //def doMatch3(xs:List[char]) = xs match {
+ // case List(_*, z, w) => w::Nil
+ //}
+ override def runTest() {
+ assertEquals(doMatch1(List('a','b','c','d')), List('a','b'))
+ assertEquals(doMatch2(List('a','b','c','d')), List('c','d'))
+ //assertEquals(doMatch3(List('a','b','c','d')), List('d'))
+ }
+ }
}