summaryrefslogtreecommitdiff
path: root/src/compiler
diff options
context:
space:
mode:
authorMartin Odersky <odersky@gmail.com>2006-10-27 17:18:58 +0000
committerMartin Odersky <odersky@gmail.com>2006-10-27 17:18:58 +0000
commitd8e8ab6a9ec2550716278c8ddffa03d295531808 (patch)
tree8632c6e124817786a80f3cefeb6e2134950c8af3 /src/compiler
parent5c642cbca2725bc45b2e62ff224c34c92a9b1012 (diff)
downloadscala-d8e8ab6a9ec2550716278c8ddffa03d295531808.tar.gz
scala-d8e8ab6a9ec2550716278c8ddffa03d295531808.tar.bz2
scala-d8e8ab6a9ec2550716278c8ddffa03d295531808.zip
changed unapply impl
Diffstat (limited to 'src/compiler')
-rw-r--r--src/compiler/scala/tools/nsc/ast/TreePrinters.scala3
-rw-r--r--src/compiler/scala/tools/nsc/ast/Trees.scala16
-rw-r--r--src/compiler/scala/tools/nsc/ast/parser/Parsers.scala10
-rw-r--r--src/compiler/scala/tools/nsc/ast/parser/TreeBuilder.scala4
-rw-r--r--src/compiler/scala/tools/nsc/symtab/Definitions.scala11
-rw-r--r--src/compiler/scala/tools/nsc/symtab/StdNames.scala3
-rw-r--r--src/compiler/scala/tools/nsc/typechecker/Namers.scala26
-rw-r--r--src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala58
-rw-r--r--src/compiler/scala/tools/nsc/typechecker/Typers.scala107
9 files changed, 98 insertions, 140 deletions
diff --git a/src/compiler/scala/tools/nsc/ast/TreePrinters.scala b/src/compiler/scala/tools/nsc/ast/TreePrinters.scala
index f63a1a3991..ad447d6196 100644
--- a/src/compiler/scala/tools/nsc/ast/TreePrinters.scala
+++ b/src/compiler/scala/tools/nsc/ast/TreePrinters.scala
@@ -230,6 +230,9 @@ abstract class TreePrinters {
case Bind(name, t) =>
print("("); print(symName(tree, name)); print(" @ "); print(t); print(")")
+ case UnApply(fun, args) =>
+ print(fun); print(" <unapply> "); printRow(args, "(", ", ", ")")
+
case ArrayValue(elemtpt, trees) =>
print("Array["); print(elemtpt); printRow(trees, "]{", ", ", "}")
diff --git a/src/compiler/scala/tools/nsc/ast/Trees.scala b/src/compiler/scala/tools/nsc/ast/Trees.scala
index 4a64067f8a..286f297b30 100644
--- a/src/compiler/scala/tools/nsc/ast/Trees.scala
+++ b/src/compiler/scala/tools/nsc/ast/Trees.scala
@@ -419,6 +419,9 @@ trait Trees requires Global {
def Bind(sym: Symbol, body: Tree): Bind =
Bind(sym.name, body) setSymbol sym
+ case class UnApply(fun: Tree, args: List[Tree])
+ extends TermTree
+
/** Array of expressions, needs to be translated in backend,
*/
case class ArrayValue(elemtpt: Tree, elems: List[Tree])
@@ -598,6 +601,7 @@ trait Trees requires Global {
case Alternative(trees) => (eliminated by transmatch)
case Star(elem) => (eliminated by transmatch)
case Bind(name, body) => (eliminated by transmatch)
+ case UnApply(fun: Tree, args) (introduced by typer, eliminated by transmatch)
case ArrayValue(elemtpt, trees) => (introduced by uncurry)
case Function(vparams, body) => (eliminated by lambdaLift)
case Assign(lhs, rhs) =>
@@ -642,6 +646,7 @@ trait Trees requires Global {
def Alternative(tree: Tree, trees: List[Tree]): Alternative
def Star(tree: Tree, elem: Tree): Star
def Bind(tree: Tree, name: Name, body: Tree): Bind
+ def UnApply(tree: Tree, fun: Tree, args: List[Tree]): UnApply
def ArrayValue(tree: Tree, elemtpt: Tree, trees: List[Tree]): ArrayValue
def Function(tree: Tree, vparams: List[ValDef], body: Tree): Function
def Assign(tree: Tree, lhs: Tree, rhs: Tree): Assign
@@ -704,6 +709,8 @@ trait Trees requires Global {
new Star(elem).copyAttrs(tree)
def Bind(tree: Tree, name: Name, body: Tree) =
new Bind(name, body).copyAttrs(tree)
+ def UnApply(tree: Tree, fun: Tree, args: List[Tree]) =
+ new UnApply(fun, args).copyAttrs(tree)
def ArrayValue(tree: Tree, elemtpt: Tree, trees: List[Tree]) =
new ArrayValue(elemtpt, trees).copyAttrs(tree)
def Function(tree: Tree, vparams: List[ValDef], body: Tree) =
@@ -845,6 +852,11 @@ trait Trees requires Global {
if (name0 == name) && (body0 == body) => t
case _ => copy.Bind(tree, name, body)
}
+ def UnApply(tree: Tree, fun: Tree, args: List[Tree]) = tree match {
+ case t @ UnApply(fun0, args0)
+ if (fun0 == fun) && (args0 == args) => t
+ case _ => copy.UnApply(tree, fun, args)
+ }
def ArrayValue(tree: Tree, elemtpt: Tree, trees: List[Tree]) = tree match {
case t @ ArrayValue(elemtpt0, trees0)
if (elemtpt0 == elemtpt) && (trees0 == trees) => t
@@ -1022,6 +1034,8 @@ trait Trees requires Global {
copy.Star(tree, transform(elem))
case Bind(name, body) =>
copy.Bind(tree, name, transform(body))
+ case UnApply(fun, args) =>
+ copy.UnApply(tree, fun, args)
case ArrayValue(elemtpt, trees) =>
copy.ArrayValue(tree, transform(elemtpt), transformTrees(trees))
case Function(vparams, body) =>
@@ -1156,6 +1170,8 @@ trait Trees requires Global {
traverse(elem)
case Bind(name, body) =>
traverse(body)
+ case UnApply(fun, args) =>
+ traverse(fun); traverseTrees(args)
case ArrayValue(elemtpt, trees) =>
traverse(elemtpt); traverseTrees(trees)
case Function(vparams, body) =>
diff --git a/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala b/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala
index 2d7ac53c3a..a3efad6ae9 100644
--- a/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala
+++ b/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala
@@ -1812,7 +1812,15 @@ trait Parsers requires SyntaxAnalyzer {
}
if (name != nme.ScalaObject.toTypeName)
parents += scalaScalaObjectConstr
- if (mods.hasFlag(Flags.CASE)) parents += caseClassConstr
+ if (mods.hasFlag(Flags.CASE)) {
+ parents += caseClassConstr
+ if (!vparamss.isEmpty) {
+ val argtypes: List[Tree] = vparamss.head map (.tpt.duplicate) //remove type annotation and you will get an interesting error message!!!
+ val nargs = argtypes.length
+ if (0 < nargs && nargs <= definitions.MaxTupleArity)
+ parents += productConstr(argtypes)
+ }
+ }
val ps = parents.toList
newLineOptWhenFollowedBy(LBRACE)
var body =
diff --git a/src/compiler/scala/tools/nsc/ast/parser/TreeBuilder.scala b/src/compiler/scala/tools/nsc/ast/parser/TreeBuilder.scala
index ed51119904..e9fdb5c91d 100644
--- a/src/compiler/scala/tools/nsc/ast/parser/TreeBuilder.scala
+++ b/src/compiler/scala/tools/nsc/ast/parser/TreeBuilder.scala
@@ -30,6 +30,10 @@ abstract class TreeBuilder {
def caseClassConstr: Tree =
scalaDot(nme.CaseClass.toTypeName)
+ def productConstr(typeArgs: List[Tree]) =
+ AppliedTypeTree(scalaDot(newTypeName("Product"+typeArgs.length)), typeArgs)
+
+
/** Convert all occurrences of (lower-case) variables in a pattern as follows:
* x becomes x @ _
* x: T becomes x @ (_: T)
diff --git a/src/compiler/scala/tools/nsc/symtab/Definitions.scala b/src/compiler/scala/tools/nsc/symtab/Definitions.scala
index bbfa7d6d5b..382b5a2c0b 100644
--- a/src/compiler/scala/tools/nsc/symtab/Definitions.scala
+++ b/src/compiler/scala/tools/nsc/symtab/Definitions.scala
@@ -116,7 +116,7 @@ trait Definitions requires SymbolTable {
/* <unapply> */
val ProductClass: Array[Symbol] = new Array(MaxTupleArity + 1)
- def productProj(n: Int, j: Int) = getMember(ProductClass(n), "__" + j)
+ def productProj(n: Int, j: Int) = getMember(ProductClass(n), "_" + j)
def isProductType(tp: Type): Boolean = tp match {
case TypeRef(_, sym, elems) =>
elems.length <= MaxTupleArity && sym == ProductClass(elems.length);
@@ -146,6 +146,15 @@ trait Definitions requires SymbolTable {
def someType(tp: Type) =
typeRef(SomeClass.typeConstructor.prefix, SomeClass, List(tp))
+ def optionOfProductElems(tp: Type): List[Type] = {
+ assert(tp.symbol == OptionClass)
+ val prod = tp.typeArgs.head
+ if (prod.symbol == UnitClass) List()
+ else prod.baseClasses.find { x => isProductType(x.tpe) } match {
+ case Some(p) => prod.baseType(p).typeArgs
+ }
+ }
+
/* </unapply> */
val MaxFunctionArity = 9
val FunctionClass: Array[Symbol] = new Array(MaxFunctionArity + 1)
diff --git a/src/compiler/scala/tools/nsc/symtab/StdNames.scala b/src/compiler/scala/tools/nsc/symtab/StdNames.scala
index cf383bb446..951349e54f 100644
--- a/src/compiler/scala/tools/nsc/symtab/StdNames.scala
+++ b/src/compiler/scala/tools/nsc/symtab/StdNames.scala
@@ -82,6 +82,7 @@ trait StdNames requires SymbolTable {
val MODULE_SUFFIX = newTermName("$module")
val LOCALDUMMY_PREFIX = newTermName(LOCALDUMMY_PREFIX_STRING)
val THIS_SUFFIX = newTermName(".this")
+ val SELECTOR_DUMMY = newTermName("<unapply-selector>")
val MODULE_INSTANCE_FIELD = newTermName("MODULE$")
@@ -209,7 +210,7 @@ trait StdNames requires SymbolTable {
val ScalaRunTime = newTermName("ScalaRunTime")
val Seq = newTermName("Seq")
val Short = newTermName("Short")
- val Some = newTypeName("Some")
+ val Some = newTermName("Some")
val SourceFile = newTermName("SourceFile")
val String = newTermName("String")
val Symbol = newTermName("Symbol")
diff --git a/src/compiler/scala/tools/nsc/typechecker/Namers.scala b/src/compiler/scala/tools/nsc/typechecker/Namers.scala
index 1ef4585800..63a3bc9367 100644
--- a/src/compiler/scala/tools/nsc/typechecker/Namers.scala
+++ b/src/compiler/scala/tools/nsc/typechecker/Namers.scala
@@ -400,34 +400,8 @@ trait Namers requires Analyzer {
case _ => tpe
});
- //<unapply>bq: this should probably be in SyntheticMethods, but needs the typer
- private def getCaseFields(templ_stats:List[Tree]):List[Pair[Type,Name]] =
- for(val z <- templ_stats; // why does `z: ValDef <- ' not work? because translation of `for' is buggy?
- z.isInstanceOf[ ValDef ];
- val x = z.asInstanceOf[ ValDef ];
- x.mods.hasFlag( CASEACCESSOR ))
- yield Pair(typer.typedType(x.tpt).tpe, x.name)
- //</unapply>
-
private def templateSig(templ0: Template): Type = {
var templ = templ0
- //<unapply>
- if(settings.Xunapply.value && (context.owner hasFlag CASE)) {
- val caseFields = getCaseFields(templ.body)
- if(caseFields.length > 0) {
- //if(settings.debug.value) Console.println("[ templateSig("+templ+") of case class")
- val addparent = TypeTree(productType(caseFields map (._1)))
- var i = 0;
- // CAREFUL for Tuple1, name `_1' will be added later by synthetic methods :/
- val addimpl = caseFields map { x =>
- val ident = Ident(x._2)
- i = i + 1
- DefDef(Modifiers(OVERRIDE | FINAL), "__"+i.toString(), List(), List(List()), TypeTree(x._1), ident) // pick __1 for now
- }
- templ = copy.Template(templ, templ.parents:::List(addparent),templ.body:::addimpl)
- }
- }
- //</unapply>
val clazz = context.owner
def checkParent(tpt: Tree): Type = {
val tp = tpt.tpe
diff --git a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala
index 05bafc1020..4bec142283 100644
--- a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala
+++ b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala
@@ -32,42 +32,6 @@ trait SyntheticMethods requires Analyzer {
import definitions._ // standard classes and methods
import typer.{typed} // methods to type trees
- /** adds ProductN parent class and methods */
- def addProductParts(clazz: Symbol, templ:Template): Template = {
- def newSyntheticMethod(name: Name, flags: Int, tpe: Type) = {
- val method = clazz.newMethod(clazz.pos, name) setFlag (flags) setInfo tpe
- clazz.info.decls.enter(method)
- method
- }
-
- def addParent:List[Tree] = {
- val caseFields = clazz.caseFieldAccessors
- val caseTypes = caseFields map { x => TypeTree(x.tpe.resultType) }
- val prodTree:Tree = TypeTree(productType(caseFields map { x => x.tpe.resultType }))
- templ.parents ::: List(prodTree)
- }
-
- def addImpl: List[Tree] = { // test
- var i = 1;
- val defs = clazz.caseFieldAccessors map {
- x =>
- val ident = gen.mkAttributedRef(x)
- val method = clazz.info.decl("__"+i.toString())
- i = i + 1
- DefDef(method, {vparamss => ident})
- }
- templ.body ::: defs
- }
-
- if(clazz.caseFieldAccessors.length == 0)
- templ
- else {
- //Console.println("#[addProductParts("+clazz+","+templ)
- //Console.println("(]#")
- copy.Template(templ, addParent, addImpl)
- }
- }
-
/**
* @param templ ...
* @param clazz ...
@@ -92,6 +56,11 @@ trait SyntheticMethods requires Analyzer {
method
}
+ def productSelectorMethod(n: int, accessor: Symbol): Tree = {
+ val method = syntheticMethod(newTermName("_"+n), FINAL, accessor.tpe)
+ typed(DefDef(method, vparamss => gen.mkAttributedRef(accessor)))
+ }
+
def caseElementMethod: Tree = {
val method = syntheticMethod(
nme.caseElement, FINAL, MethodType(List(IntClass.tpe), AnyClass.tpe))
@@ -224,6 +193,9 @@ trait SyntheticMethods requires Analyzer {
else Apply(gen.mkAttributedRef(sym), List(Ident(vparamss.head.head)))))
}
+ def isPublic(sym: Symbol) =
+ !sym.hasFlag(PRIVATE | PROTECTED) && sym.privateWithin == NoSymbol
+
if (!phase.erasedTypes) {
if (clazz hasFlag CASE) {
@@ -231,11 +203,10 @@ trait SyntheticMethods requires Analyzer {
clazz.attributes = Triple(SerializableAttr.tpe, List(), List()) :: clazz.attributes
for (val stat <- templ.body) {
- if (stat.isDef && stat.symbol.isMethod && stat.symbol.hasFlag(CASEACCESSOR) &&
- (stat.symbol.hasFlag(PRIVATE | PROTECTED) || stat.symbol.privateWithin != NoSymbol)) {
- ts += newAccessorMethod(stat)
- stat.symbol.resetFlag(CASEACCESSOR)
- }
+ if (stat.isDef && stat.symbol.isMethod && stat.symbol.hasFlag(CASEACCESSOR) && !isPublic(stat.symbol)) {
+ ts += newAccessorMethod(stat)
+ stat.symbol.resetFlag(CASEACCESSOR)
+ }
}
if (clazz.info.nonPrivateDecl(nme.tag) == NoSymbol) ts += tagMethod
@@ -249,7 +220,12 @@ trait SyntheticMethods requires Analyzer {
if (!hasImplementation(nme.caseElement)) ts += caseElementMethod
if (!hasImplementation(nme.caseArity)) ts += caseArityMethod
if (!hasImplementation(nme.caseName)) ts += caseNameMethod
+ for (val i <- 0 until clazz.caseFieldAccessors.length) {
+ val acc = clazz.caseFieldAccessors(i)
+ if (acc.name.toString != "_"+(i+1)) ts += productSelectorMethod(i+1, acc)
+ }
}
+
if (clazz.isModuleClass && isSerializable(clazz)) {
// If you serialize a singleton and then deserialize it twice,
// you will have two instances of your singleton, unless you implement
diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala
index 9850f71ae4..47f52c2041 100644
--- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala
+++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala
@@ -396,6 +396,12 @@ trait Typers requires Analyzer {
} else tree
}
+ def unapplyMember(tp: Type): Symbol = {
+ var unapp = tp.member(nme.unapply)
+ if (unapp == NoSymbol) unapp = tp.member(nme.unapplySeq)
+ unapp
+ }
+
/** Perform the following adaptations of expression, pattern or type `tree' wrt to
* given mode `mode' and given prototype `pt':
* (0) Convert expressions with constant types to literals
@@ -517,15 +523,11 @@ trait Typers requires Analyzer {
// fix symbol -- we are using the module not the class
val consp = if(clazz.isModule) clazz else {
val obj = clazz.linkedModuleOfClass
- tree.setSymbol(obj)
+ if (obj != NoSymbol) tree.setSymbol(obj)
obj
}
- var unapp = { // find unapply[Seq] mehod
- val x = consp.tpe.decl(nme.unapply)
- if (x != NoSymbol) x else consp.tpe.decl(nme.unapplySeq)
- }
- if(unapp != NoSymbol) tree
- else errorTree(tree, "" + clazz + " does not have unapply/unapplySeq method")
+ if (unapplyMember(consp.tpe).exists) tree
+ else errorTree(tree, "" + clazz + " is not a case class, nor does it have unapply/unapplySeq method")
} else {
errorTree(tree, "" + clazz + " is neither a case class nor a sequence class")
}
@@ -741,14 +743,8 @@ trait Typers requires Analyzer {
reenterTypeParams(cdef.tparams)
val tparams1 = List.mapConserve(cdef.tparams)(typedAbsTypeDef)
val tpt1 = checkNoEscaping.privates(clazz.thisSym, typedType(cdef.tpt))
- //<unapply> // add ProductN before typing
- var impl0 = if(settings.Xunapply.value && (clazz hasFlag CASE) && !phase.erasedTypes) {
- addProductParts(clazz, cdef.impl)
- } else
- cdef.impl
- //</unapply>
- val impl1 = newTyper(context.make(impl0, clazz, newScope))
- .typedTemplate(impl0, parentTypes(impl0))
+ val impl1 = newTyper(context.make(cdef.impl, clazz, newScope))
+ .typedTemplate(cdef.impl, parentTypes(cdef.impl))
val impl2 = addSyntheticMethods(impl1, clazz, context.unit)
val ret = copy.ClassDef(cdef, cdef.mods, cdef.name, tparams1, tpt1, impl2)
.setType(NoType)
@@ -1196,61 +1192,32 @@ trait Typers requires Analyzer {
typedApply(tree, adapt(fun, funMode(mode), WildcardType), args1, mode, pt)
/* --- begin unapply --- */
case otpe @ SingleType(_,sym) if settings.Xunapply.value => // normally, an object 'Foo' cannot be applied -> unapply pattern
-
- val unapp = otpe.decl(nme.unapply)
- if(unapp.exists) unapp.tpe match { // try unapply first
- case MethodType(formals0,restpe) => // must take (x:Any)
-
- if(formals0.length != 1)
- return errorTree(tree,"unapply should take exactly one argument but takes "+formals0)
-
- val argt = formals0(0) // @todo: check this against pt, e.g. cons[int](_hd:int_, tl:List[int])
-
- var prodtpe: Type = null
- var nargs: Int = -1 // signals error
-
- val sometpe = restpe match {
- case TypeRef(_,sym, List(stpe)) if sym == definitions.getClass("scala.Option") => stpe
- case _ => return errorTree(tree,"unapply should return option[.], not "+restpe); null
- }
-
- val rsym = sometpe.symbol
-
- sometpe.baseClasses.find { x => isProductType(x.tpe) } match {
- case Some(x) =>
- prodtpe = sometpe.baseType(x)
- nargs = x.tpe.asInstanceOf[TypeRef].args.length
- case _ =>
- if(sometpe =:= definitions.UnitClass.tpe)
- nargs = 0
- else
- return errorTree(tree, "result type '"+sometpe+"' of unapply is neither option[product] nor option[unit]")
- }
-
- if(nargs != args.length)
- return errorTree(tree, "wrong number of arguments for unapply, expects "+formals0)
-
- // check arg types
- val child = for(val Pair(arg,atpe) <- args.zip(sometpe.typeArgs)) yield typed(arg, mode, atpe)
-
- // Product_N(...)
- val child1 = Apply(Ident(prodtpe.symbol.name) setType prodtpe setSymbol prodtpe.symbol, child) setSymbol prodtpe.symbol setType prodtpe
-
- // Some(Product_N(...)) DBKK
- val some = typed(Apply(Ident(nme.Some), List(child1)), mode, definitions.optionType(prodtpe) /*AnyClass.tpe*/)
-
- // Foo.unapply(Some(Product_N(...)))
- return Apply(gen.mkAttributedSelect(fun0,unapp), List(some)) setType pt
-
- case PolyType(params,restpe) =>
- errorTree(tree, "polym unapply not implemented")
- case tpe =>
- errorTree(tree, " can't handle unapply of type "+tpe)
- } // no unapply
-
- val unappSeq = otpe.decl(nme.unapplySeq)
- errorTree(tree, " can't handle unapplySeq")
-
+ // !!! this is fragile, maybe needs to be revised when unapply patterns become terms
+ val unapp = unapplyMember(otpe)
+ assert(unapp.exists, tree)
+ assert(isFullyDefined(pt))
+ val argDummy = context.owner.newValue(fun.pos, nme.SELECTOR_DUMMY)
+ .setFlag(SYNTHETIC)
+ .setInfo(pt)
+ if (args.length > MaxTupleArity)
+ error(fun.pos, "too many arguments for unapply pattern, maximum = "+MaxTupleArity)
+ val arg = Ident(argDummy) setType pt
+ val prod =
+ if (args.length == 0) UnitClass.tpe
+ else productType(args map (arg => WildcardType))
+ val tupleSym =
+ if (args.length == 0) UnitClass
+ else TupleClass(args.length)
+
+ val funPt = appliedType(OptionClass.typeConstructor, List(prod))
+ val fun0 = Ident(sym) setPos fun.pos setType otpe // no longer needed when patterns are terms!!!
+ val fun1 = typed(atPos(fun.pos) { Apply(Select(fun0, unapp), List(arg)) }, EXPRmode, funPt)
+ if (fun1.tpe.isErroneous) setError(tree)
+ else {
+ val argPts = optionOfProductElems(fun1.tpe)
+ val args1 = List.map2(args, argPts)((arg, formal) => typedArg(arg, mode, formal))
+ UnApply(fun1, args1) setPos tree.pos setType pt
+ }
/* --- end unapply --- */
case MethodType(formals0, restpe) =>
val formals = formalTypes(formals0, args.length)