diff options
author | Martin Odersky <odersky@gmail.com> | 2006-10-27 17:18:58 +0000 |
---|---|---|
committer | Martin Odersky <odersky@gmail.com> | 2006-10-27 17:18:58 +0000 |
commit | d8e8ab6a9ec2550716278c8ddffa03d295531808 (patch) | |
tree | 8632c6e124817786a80f3cefeb6e2134950c8af3 | |
parent | 5c642cbca2725bc45b2e62ff224c34c92a9b1012 (diff) | |
download | scala-d8e8ab6a9ec2550716278c8ddffa03d295531808.tar.gz scala-d8e8ab6a9ec2550716278c8ddffa03d295531808.tar.bz2 scala-d8e8ab6a9ec2550716278c8ddffa03d295531808.zip |
changed unapply impl
10 files changed, 99 insertions, 141 deletions
diff --git a/src/compiler/scala/tools/nsc/ast/TreePrinters.scala b/src/compiler/scala/tools/nsc/ast/TreePrinters.scala index f63a1a3991..ad447d6196 100644 --- a/src/compiler/scala/tools/nsc/ast/TreePrinters.scala +++ b/src/compiler/scala/tools/nsc/ast/TreePrinters.scala @@ -230,6 +230,9 @@ abstract class TreePrinters { case Bind(name, t) => print("("); print(symName(tree, name)); print(" @ "); print(t); print(")") + case UnApply(fun, args) => + print(fun); print(" <unapply> "); printRow(args, "(", ", ", ")") + case ArrayValue(elemtpt, trees) => print("Array["); print(elemtpt); printRow(trees, "]{", ", ", "}") diff --git a/src/compiler/scala/tools/nsc/ast/Trees.scala b/src/compiler/scala/tools/nsc/ast/Trees.scala index 4a64067f8a..286f297b30 100644 --- a/src/compiler/scala/tools/nsc/ast/Trees.scala +++ b/src/compiler/scala/tools/nsc/ast/Trees.scala @@ -419,6 +419,9 @@ trait Trees requires Global { def Bind(sym: Symbol, body: Tree): Bind = Bind(sym.name, body) setSymbol sym + case class UnApply(fun: Tree, args: List[Tree]) + extends TermTree + /** Array of expressions, needs to be translated in backend, */ case class ArrayValue(elemtpt: Tree, elems: List[Tree]) @@ -598,6 +601,7 @@ trait Trees requires Global { case Alternative(trees) => (eliminated by transmatch) case Star(elem) => (eliminated by transmatch) case Bind(name, body) => (eliminated by transmatch) + case UnApply(fun: Tree, args) (introduced by typer, eliminated by transmatch) case ArrayValue(elemtpt, trees) => (introduced by uncurry) case Function(vparams, body) => (eliminated by lambdaLift) case Assign(lhs, rhs) => @@ -642,6 +646,7 @@ trait Trees requires Global { def Alternative(tree: Tree, trees: List[Tree]): Alternative def Star(tree: Tree, elem: Tree): Star def Bind(tree: Tree, name: Name, body: Tree): Bind + def UnApply(tree: Tree, fun: Tree, args: List[Tree]): UnApply def ArrayValue(tree: Tree, elemtpt: Tree, trees: List[Tree]): ArrayValue def Function(tree: Tree, vparams: List[ValDef], body: Tree): Function def Assign(tree: Tree, lhs: Tree, rhs: Tree): Assign @@ -704,6 +709,8 @@ trait Trees requires Global { new Star(elem).copyAttrs(tree) def Bind(tree: Tree, name: Name, body: Tree) = new Bind(name, body).copyAttrs(tree) + def UnApply(tree: Tree, fun: Tree, args: List[Tree]) = + new UnApply(fun, args).copyAttrs(tree) def ArrayValue(tree: Tree, elemtpt: Tree, trees: List[Tree]) = new ArrayValue(elemtpt, trees).copyAttrs(tree) def Function(tree: Tree, vparams: List[ValDef], body: Tree) = @@ -845,6 +852,11 @@ trait Trees requires Global { if (name0 == name) && (body0 == body) => t case _ => copy.Bind(tree, name, body) } + def UnApply(tree: Tree, fun: Tree, args: List[Tree]) = tree match { + case t @ UnApply(fun0, args0) + if (fun0 == fun) && (args0 == args) => t + case _ => copy.UnApply(tree, fun, args) + } def ArrayValue(tree: Tree, elemtpt: Tree, trees: List[Tree]) = tree match { case t @ ArrayValue(elemtpt0, trees0) if (elemtpt0 == elemtpt) && (trees0 == trees) => t @@ -1022,6 +1034,8 @@ trait Trees requires Global { copy.Star(tree, transform(elem)) case Bind(name, body) => copy.Bind(tree, name, transform(body)) + case UnApply(fun, args) => + copy.UnApply(tree, fun, args) case ArrayValue(elemtpt, trees) => copy.ArrayValue(tree, transform(elemtpt), transformTrees(trees)) case Function(vparams, body) => @@ -1156,6 +1170,8 @@ trait Trees requires Global { traverse(elem) case Bind(name, body) => traverse(body) + case UnApply(fun, args) => + traverse(fun); traverseTrees(args) case ArrayValue(elemtpt, trees) => traverse(elemtpt); traverseTrees(trees) case Function(vparams, body) => diff --git a/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala b/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala index 2d7ac53c3a..a3efad6ae9 100644 --- a/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala +++ b/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala @@ -1812,7 +1812,15 @@ trait Parsers requires SyntaxAnalyzer { } if (name != nme.ScalaObject.toTypeName) parents += scalaScalaObjectConstr - if (mods.hasFlag(Flags.CASE)) parents += caseClassConstr + if (mods.hasFlag(Flags.CASE)) { + parents += caseClassConstr + if (!vparamss.isEmpty) { + val argtypes: List[Tree] = vparamss.head map (.tpt.duplicate) //remove type annotation and you will get an interesting error message!!! + val nargs = argtypes.length + if (0 < nargs && nargs <= definitions.MaxTupleArity) + parents += productConstr(argtypes) + } + } val ps = parents.toList newLineOptWhenFollowedBy(LBRACE) var body = diff --git a/src/compiler/scala/tools/nsc/ast/parser/TreeBuilder.scala b/src/compiler/scala/tools/nsc/ast/parser/TreeBuilder.scala index ed51119904..e9fdb5c91d 100644 --- a/src/compiler/scala/tools/nsc/ast/parser/TreeBuilder.scala +++ b/src/compiler/scala/tools/nsc/ast/parser/TreeBuilder.scala @@ -30,6 +30,10 @@ abstract class TreeBuilder { def caseClassConstr: Tree = scalaDot(nme.CaseClass.toTypeName) + def productConstr(typeArgs: List[Tree]) = + AppliedTypeTree(scalaDot(newTypeName("Product"+typeArgs.length)), typeArgs) + + /** Convert all occurrences of (lower-case) variables in a pattern as follows: * x becomes x @ _ * x: T becomes x @ (_: T) diff --git a/src/compiler/scala/tools/nsc/symtab/Definitions.scala b/src/compiler/scala/tools/nsc/symtab/Definitions.scala index bbfa7d6d5b..382b5a2c0b 100644 --- a/src/compiler/scala/tools/nsc/symtab/Definitions.scala +++ b/src/compiler/scala/tools/nsc/symtab/Definitions.scala @@ -116,7 +116,7 @@ trait Definitions requires SymbolTable { /* <unapply> */ val ProductClass: Array[Symbol] = new Array(MaxTupleArity + 1) - def productProj(n: Int, j: Int) = getMember(ProductClass(n), "__" + j) + def productProj(n: Int, j: Int) = getMember(ProductClass(n), "_" + j) def isProductType(tp: Type): Boolean = tp match { case TypeRef(_, sym, elems) => elems.length <= MaxTupleArity && sym == ProductClass(elems.length); @@ -146,6 +146,15 @@ trait Definitions requires SymbolTable { def someType(tp: Type) = typeRef(SomeClass.typeConstructor.prefix, SomeClass, List(tp)) + def optionOfProductElems(tp: Type): List[Type] = { + assert(tp.symbol == OptionClass) + val prod = tp.typeArgs.head + if (prod.symbol == UnitClass) List() + else prod.baseClasses.find { x => isProductType(x.tpe) } match { + case Some(p) => prod.baseType(p).typeArgs + } + } + /* </unapply> */ val MaxFunctionArity = 9 val FunctionClass: Array[Symbol] = new Array(MaxFunctionArity + 1) diff --git a/src/compiler/scala/tools/nsc/symtab/StdNames.scala b/src/compiler/scala/tools/nsc/symtab/StdNames.scala index cf383bb446..951349e54f 100644 --- a/src/compiler/scala/tools/nsc/symtab/StdNames.scala +++ b/src/compiler/scala/tools/nsc/symtab/StdNames.scala @@ -82,6 +82,7 @@ trait StdNames requires SymbolTable { val MODULE_SUFFIX = newTermName("$module") val LOCALDUMMY_PREFIX = newTermName(LOCALDUMMY_PREFIX_STRING) val THIS_SUFFIX = newTermName(".this") + val SELECTOR_DUMMY = newTermName("<unapply-selector>") val MODULE_INSTANCE_FIELD = newTermName("MODULE$") @@ -209,7 +210,7 @@ trait StdNames requires SymbolTable { val ScalaRunTime = newTermName("ScalaRunTime") val Seq = newTermName("Seq") val Short = newTermName("Short") - val Some = newTypeName("Some") + val Some = newTermName("Some") val SourceFile = newTermName("SourceFile") val String = newTermName("String") val Symbol = newTermName("Symbol") diff --git a/src/compiler/scala/tools/nsc/typechecker/Namers.scala b/src/compiler/scala/tools/nsc/typechecker/Namers.scala index 1ef4585800..63a3bc9367 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Namers.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Namers.scala @@ -400,34 +400,8 @@ trait Namers requires Analyzer { case _ => tpe }); - //<unapply>bq: this should probably be in SyntheticMethods, but needs the typer - private def getCaseFields(templ_stats:List[Tree]):List[Pair[Type,Name]] = - for(val z <- templ_stats; // why does `z: ValDef <- ' not work? because translation of `for' is buggy? - z.isInstanceOf[ ValDef ]; - val x = z.asInstanceOf[ ValDef ]; - x.mods.hasFlag( CASEACCESSOR )) - yield Pair(typer.typedType(x.tpt).tpe, x.name) - //</unapply> - private def templateSig(templ0: Template): Type = { var templ = templ0 - //<unapply> - if(settings.Xunapply.value && (context.owner hasFlag CASE)) { - val caseFields = getCaseFields(templ.body) - if(caseFields.length > 0) { - //if(settings.debug.value) Console.println("[ templateSig("+templ+") of case class") - val addparent = TypeTree(productType(caseFields map (._1))) - var i = 0; - // CAREFUL for Tuple1, name `_1' will be added later by synthetic methods :/ - val addimpl = caseFields map { x => - val ident = Ident(x._2) - i = i + 1 - DefDef(Modifiers(OVERRIDE | FINAL), "__"+i.toString(), List(), List(List()), TypeTree(x._1), ident) // pick __1 for now - } - templ = copy.Template(templ, templ.parents:::List(addparent),templ.body:::addimpl) - } - } - //</unapply> val clazz = context.owner def checkParent(tpt: Tree): Type = { val tp = tpt.tpe diff --git a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala index 05bafc1020..4bec142283 100644 --- a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala +++ b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala @@ -32,42 +32,6 @@ trait SyntheticMethods requires Analyzer { import definitions._ // standard classes and methods import typer.{typed} // methods to type trees - /** adds ProductN parent class and methods */ - def addProductParts(clazz: Symbol, templ:Template): Template = { - def newSyntheticMethod(name: Name, flags: Int, tpe: Type) = { - val method = clazz.newMethod(clazz.pos, name) setFlag (flags) setInfo tpe - clazz.info.decls.enter(method) - method - } - - def addParent:List[Tree] = { - val caseFields = clazz.caseFieldAccessors - val caseTypes = caseFields map { x => TypeTree(x.tpe.resultType) } - val prodTree:Tree = TypeTree(productType(caseFields map { x => x.tpe.resultType })) - templ.parents ::: List(prodTree) - } - - def addImpl: List[Tree] = { // test - var i = 1; - val defs = clazz.caseFieldAccessors map { - x => - val ident = gen.mkAttributedRef(x) - val method = clazz.info.decl("__"+i.toString()) - i = i + 1 - DefDef(method, {vparamss => ident}) - } - templ.body ::: defs - } - - if(clazz.caseFieldAccessors.length == 0) - templ - else { - //Console.println("#[addProductParts("+clazz+","+templ) - //Console.println("(]#") - copy.Template(templ, addParent, addImpl) - } - } - /** * @param templ ... * @param clazz ... @@ -92,6 +56,11 @@ trait SyntheticMethods requires Analyzer { method } + def productSelectorMethod(n: int, accessor: Symbol): Tree = { + val method = syntheticMethod(newTermName("_"+n), FINAL, accessor.tpe) + typed(DefDef(method, vparamss => gen.mkAttributedRef(accessor))) + } + def caseElementMethod: Tree = { val method = syntheticMethod( nme.caseElement, FINAL, MethodType(List(IntClass.tpe), AnyClass.tpe)) @@ -224,6 +193,9 @@ trait SyntheticMethods requires Analyzer { else Apply(gen.mkAttributedRef(sym), List(Ident(vparamss.head.head))))) } + def isPublic(sym: Symbol) = + !sym.hasFlag(PRIVATE | PROTECTED) && sym.privateWithin == NoSymbol + if (!phase.erasedTypes) { if (clazz hasFlag CASE) { @@ -231,11 +203,10 @@ trait SyntheticMethods requires Analyzer { clazz.attributes = Triple(SerializableAttr.tpe, List(), List()) :: clazz.attributes for (val stat <- templ.body) { - if (stat.isDef && stat.symbol.isMethod && stat.symbol.hasFlag(CASEACCESSOR) && - (stat.symbol.hasFlag(PRIVATE | PROTECTED) || stat.symbol.privateWithin != NoSymbol)) { - ts += newAccessorMethod(stat) - stat.symbol.resetFlag(CASEACCESSOR) - } + if (stat.isDef && stat.symbol.isMethod && stat.symbol.hasFlag(CASEACCESSOR) && !isPublic(stat.symbol)) { + ts += newAccessorMethod(stat) + stat.symbol.resetFlag(CASEACCESSOR) + } } if (clazz.info.nonPrivateDecl(nme.tag) == NoSymbol) ts += tagMethod @@ -249,7 +220,12 @@ trait SyntheticMethods requires Analyzer { if (!hasImplementation(nme.caseElement)) ts += caseElementMethod if (!hasImplementation(nme.caseArity)) ts += caseArityMethod if (!hasImplementation(nme.caseName)) ts += caseNameMethod + for (val i <- 0 until clazz.caseFieldAccessors.length) { + val acc = clazz.caseFieldAccessors(i) + if (acc.name.toString != "_"+(i+1)) ts += productSelectorMethod(i+1, acc) + } } + if (clazz.isModuleClass && isSerializable(clazz)) { // If you serialize a singleton and then deserialize it twice, // you will have two instances of your singleton, unless you implement diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala index 9850f71ae4..47f52c2041 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala @@ -396,6 +396,12 @@ trait Typers requires Analyzer { } else tree } + def unapplyMember(tp: Type): Symbol = { + var unapp = tp.member(nme.unapply) + if (unapp == NoSymbol) unapp = tp.member(nme.unapplySeq) + unapp + } + /** Perform the following adaptations of expression, pattern or type `tree' wrt to * given mode `mode' and given prototype `pt': * (0) Convert expressions with constant types to literals @@ -517,15 +523,11 @@ trait Typers requires Analyzer { // fix symbol -- we are using the module not the class val consp = if(clazz.isModule) clazz else { val obj = clazz.linkedModuleOfClass - tree.setSymbol(obj) + if (obj != NoSymbol) tree.setSymbol(obj) obj } - var unapp = { // find unapply[Seq] mehod - val x = consp.tpe.decl(nme.unapply) - if (x != NoSymbol) x else consp.tpe.decl(nme.unapplySeq) - } - if(unapp != NoSymbol) tree - else errorTree(tree, "" + clazz + " does not have unapply/unapplySeq method") + if (unapplyMember(consp.tpe).exists) tree + else errorTree(tree, "" + clazz + " is not a case class, nor does it have unapply/unapplySeq method") } else { errorTree(tree, "" + clazz + " is neither a case class nor a sequence class") } @@ -741,14 +743,8 @@ trait Typers requires Analyzer { reenterTypeParams(cdef.tparams) val tparams1 = List.mapConserve(cdef.tparams)(typedAbsTypeDef) val tpt1 = checkNoEscaping.privates(clazz.thisSym, typedType(cdef.tpt)) - //<unapply> // add ProductN before typing - var impl0 = if(settings.Xunapply.value && (clazz hasFlag CASE) && !phase.erasedTypes) { - addProductParts(clazz, cdef.impl) - } else - cdef.impl - //</unapply> - val impl1 = newTyper(context.make(impl0, clazz, newScope)) - .typedTemplate(impl0, parentTypes(impl0)) + val impl1 = newTyper(context.make(cdef.impl, clazz, newScope)) + .typedTemplate(cdef.impl, parentTypes(cdef.impl)) val impl2 = addSyntheticMethods(impl1, clazz, context.unit) val ret = copy.ClassDef(cdef, cdef.mods, cdef.name, tparams1, tpt1, impl2) .setType(NoType) @@ -1196,61 +1192,32 @@ trait Typers requires Analyzer { typedApply(tree, adapt(fun, funMode(mode), WildcardType), args1, mode, pt) /* --- begin unapply --- */ case otpe @ SingleType(_,sym) if settings.Xunapply.value => // normally, an object 'Foo' cannot be applied -> unapply pattern - - val unapp = otpe.decl(nme.unapply) - if(unapp.exists) unapp.tpe match { // try unapply first - case MethodType(formals0,restpe) => // must take (x:Any) - - if(formals0.length != 1) - return errorTree(tree,"unapply should take exactly one argument but takes "+formals0) - - val argt = formals0(0) // @todo: check this against pt, e.g. cons[int](_hd:int_, tl:List[int]) - - var prodtpe: Type = null - var nargs: Int = -1 // signals error - - val sometpe = restpe match { - case TypeRef(_,sym, List(stpe)) if sym == definitions.getClass("scala.Option") => stpe - case _ => return errorTree(tree,"unapply should return option[.], not "+restpe); null - } - - val rsym = sometpe.symbol - - sometpe.baseClasses.find { x => isProductType(x.tpe) } match { - case Some(x) => - prodtpe = sometpe.baseType(x) - nargs = x.tpe.asInstanceOf[TypeRef].args.length - case _ => - if(sometpe =:= definitions.UnitClass.tpe) - nargs = 0 - else - return errorTree(tree, "result type '"+sometpe+"' of unapply is neither option[product] nor option[unit]") - } - - if(nargs != args.length) - return errorTree(tree, "wrong number of arguments for unapply, expects "+formals0) - - // check arg types - val child = for(val Pair(arg,atpe) <- args.zip(sometpe.typeArgs)) yield typed(arg, mode, atpe) - - // Product_N(...) - val child1 = Apply(Ident(prodtpe.symbol.name) setType prodtpe setSymbol prodtpe.symbol, child) setSymbol prodtpe.symbol setType prodtpe - - // Some(Product_N(...)) DBKK - val some = typed(Apply(Ident(nme.Some), List(child1)), mode, definitions.optionType(prodtpe) /*AnyClass.tpe*/) - - // Foo.unapply(Some(Product_N(...))) - return Apply(gen.mkAttributedSelect(fun0,unapp), List(some)) setType pt - - case PolyType(params,restpe) => - errorTree(tree, "polym unapply not implemented") - case tpe => - errorTree(tree, " can't handle unapply of type "+tpe) - } // no unapply - - val unappSeq = otpe.decl(nme.unapplySeq) - errorTree(tree, " can't handle unapplySeq") - + // !!! this is fragile, maybe needs to be revised when unapply patterns become terms + val unapp = unapplyMember(otpe) + assert(unapp.exists, tree) + assert(isFullyDefined(pt)) + val argDummy = context.owner.newValue(fun.pos, nme.SELECTOR_DUMMY) + .setFlag(SYNTHETIC) + .setInfo(pt) + if (args.length > MaxTupleArity) + error(fun.pos, "too many arguments for unapply pattern, maximum = "+MaxTupleArity) + val arg = Ident(argDummy) setType pt + val prod = + if (args.length == 0) UnitClass.tpe + else productType(args map (arg => WildcardType)) + val tupleSym = + if (args.length == 0) UnitClass + else TupleClass(args.length) + + val funPt = appliedType(OptionClass.typeConstructor, List(prod)) + val fun0 = Ident(sym) setPos fun.pos setType otpe // no longer needed when patterns are terms!!! + val fun1 = typed(atPos(fun.pos) { Apply(Select(fun0, unapp), List(arg)) }, EXPRmode, funPt) + if (fun1.tpe.isErroneous) setError(tree) + else { + val argPts = optionOfProductElems(fun1.tpe) + val args1 = List.map2(args, argPts)((arg, formal) => typedArg(arg, mode, formal)) + UnApply(fun1, args1) setPos tree.pos setType pt + } /* --- end unapply --- */ case MethodType(formals0, restpe) => val formals = formalTypes(formals0, args.length) diff --git a/test/pending/pos/unapply.scala b/test/pending/pos/unapply.scala index ee33845d81..4461c3324c 100644 --- a/test/pending/pos/unapply.scala +++ b/test/pending/pos/unapply.scala @@ -1,4 +1,4 @@ -case class MyTuple2[A,B](val _1:A, val _2:B) +case class MyTuple2[A,B](val _1:A, val snd:B) object Foo { def unapply(x:Any): Option[Product2[Int,String]] = { |