From aafc0fe172b8db946ccf74110d2f6fa257ffa094 Mon Sep 17 00:00:00 2001 From: Paul Phillips Date: Thu, 21 Jul 2011 23:35:54 +0000 Subject: As per discussion documented in SI-1799, brough... As per discussion documented in SI-1799, brought back the ProductN traits and synthesized them into case classes. It's -Xexperimental for now because there may be minor implications for existing code which should be discussed. And also because I snuck in another "improvement" but it's probably too dangerous to be touching productIterator directly and it should go into something else. scala> case class Bippy(x: Int, y: Int) defined class Bippy scala> Bippy(5, 10).productIterator res0: Iterator[Int] = non-empty iterator ^^^----- as opposed to Iterator[Any] There is an even better idea available than lubbing the case class field types: it starts with "H" and ends with "List"... Review by oderksy. --- .../scala/reflect/internal/Definitions.scala | 1 + src/compiler/scala/reflect/internal/StdNames.scala | 1 + .../scala/tools/nsc/ast/parser/Parsers.scala | 16 ++++++- .../scala/tools/nsc/ast/parser/TreeBuilder.scala | 1 + .../tools/nsc/typechecker/SyntheticMethods.scala | 51 ++++++++++++++++++---- src/library/scala/runtime/ScalaRunTime.scala | 14 ++++++ test/files/pos/caseclass-productN.flags | 1 + test/files/pos/caseclass-productN.scala | 20 +++++++++ 8 files changed, 96 insertions(+), 9 deletions(-) create mode 100644 test/files/pos/caseclass-productN.flags create mode 100644 test/files/pos/caseclass-productN.scala diff --git a/src/compiler/scala/reflect/internal/Definitions.scala b/src/compiler/scala/reflect/internal/Definitions.scala index 314fe1ed82..e1a3e732b0 100644 --- a/src/compiler/scala/reflect/internal/Definitions.scala +++ b/src/compiler/scala/reflect/internal/Definitions.scala @@ -415,6 +415,7 @@ trait Definitions extends reflect.api.StandardDefinitions { def Product_productArity = getMember(ProductRootClass, nme.productArity) def Product_productElement = getMember(ProductRootClass, nme.productElement) // def Product_productElementName = getMember(ProductRootClass, nme.productElementName) + def Product_iterator = getMember(ProductRootClass, nme.productIterator) def Product_productPrefix = getMember(ProductRootClass, nme.productPrefix) def Product_canEqual = getMember(ProductRootClass, nme.canEqual_) diff --git a/src/compiler/scala/reflect/internal/StdNames.scala b/src/compiler/scala/reflect/internal/StdNames.scala index fdbe918d55..43fe89ed63 100644 --- a/src/compiler/scala/reflect/internal/StdNames.scala +++ b/src/compiler/scala/reflect/internal/StdNames.scala @@ -247,6 +247,7 @@ trait StdNames extends /*reflect.generic.StdNames with*/ NameManglers { self: Sy val ofDim: NameType = "ofDim" val productArity: NameType = "productArity" val productElement: NameType = "productElement" + val productIterator: NameType = "productIterator" val productPrefix: NameType = "productPrefix" val readResolve: NameType = "readResolve" val sameElements: NameType = "sameElements" diff --git a/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala b/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala index a90cb6c409..f0d6f4c5bf 100644 --- a/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala +++ b/src/compiler/scala/tools/nsc/ast/parser/Parsers.scala @@ -2649,6 +2649,20 @@ self => * }}} */ def templateOpt(mods: Modifiers, name: Name, constrMods: Modifiers, vparamss: List[List[ValDef]], tstart: Int): Template = { + /** A synthetic ProductN parent for case classes. */ + def extraCaseParents = ( + if (settings.Xexperimental.value && mods.isCase) { + val arity = if (vparamss.isEmpty || vparamss.head.isEmpty) 0 else vparamss.head.size + if (arity == 0) Nil + else List( + AppliedTypeTree( + productConstrN(arity), + vparamss.head map (vd => vd.tpt) + ) + ) + } + else Nil + ) val (parents0, argss, self, body) = ( if (in.token == EXTENDS || in.token == SUBTYPE && mods.hasTraitFlag) { in.nextToken() @@ -2673,7 +2687,7 @@ self => else if (parents0.isEmpty) List(scalaAnyRefConstr) else parents0 ) ++ ( - if (mods.isCase) List(productConstr, serializableConstr) + if (mods.isCase) List(productConstr, serializableConstr) ++ extraCaseParents else Nil ) diff --git a/src/compiler/scala/tools/nsc/ast/parser/TreeBuilder.scala b/src/compiler/scala/tools/nsc/ast/parser/TreeBuilder.scala index b82ff4fb4f..79aa69efe7 100644 --- a/src/compiler/scala/tools/nsc/ast/parser/TreeBuilder.scala +++ b/src/compiler/scala/tools/nsc/ast/parser/TreeBuilder.scala @@ -33,6 +33,7 @@ abstract class TreeBuilder { def scalaUnitConstr = gen.scalaUnitConstr def scalaScalaObjectConstr = gen.scalaScalaObjectConstr def productConstr = gen.productConstr + def productConstrN(n: Int) = scalaDot(newTypeName("Product" + n)) def serializableConstr = gen.serializableConstr def convertToTypeName(t: Tree) = gen.convertToTypeName(t) diff --git a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala index 63456849cd..95a86b25c8 100644 --- a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala +++ b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala @@ -44,6 +44,8 @@ trait SyntheticMethods extends ast.TreeDSL { val localTyper = newTyper( if (reporter.hasErrors) context makeSilent false else context ) + def accessorTypes = clazz.caseFieldAccessors map (_.tpe.finalResultType) + def accessorLub = global.weakLub(accessorTypes)._1 def hasOverridingImplementation(meth: Symbol): Boolean = { val sym = clazz.info nonPrivateMember meth.name @@ -72,14 +74,30 @@ trait SyntheticMethods extends ast.TreeDSL { import CODE._ - def productPrefixMethod: Tree = typer.typed { - val method = syntheticMethod(nme.productPrefix, 0, sym => NullaryMethodType(StringClass.tpe)) - DEF(method) === LIT(clazz.name.decode) + def newNullaryMethod(name: Name, tpe: Type, body: Tree) = { + val flags = if (clazz.tpe.member(name.toTermName) != NoSymbol) OVERRIDE else 0 + val method = clazz.newMethod(clazz.pos.focus, name.toTermName) setFlag flags + + method setInfo NullaryMethodType(tpe) + clazz.info.decls enter method + + typer typed (DEF(method) === body) + } + def productPrefixMethod = newNullaryMethod(nme.productPrefix, StringClass.tpe, LIT(clazz.name.decode)) + def productArityMethod(arity: Int) = newNullaryMethod(nme.productArity, IntClass.tpe, LIT(arity)) + def productIteratorMethod = { + val method = getMember(ScalaRunTimeModule, "typedProductIterator") + val iteratorType = typeRef(NoPrefix, IteratorClass, List(accessorLub)) + + newNullaryMethod( + nme.productIterator, + iteratorType, + gen.mkMethodCall(method, List(accessorLub), List(This(clazz))) + ) } - def productArityMethod(nargs: Int): Tree = { - val method = syntheticMethod(nme.productArity, 0, sym => NullaryMethodType(IntClass.tpe)) - typer typed { DEF(method) === LIT(nargs) } + def projectionMethod(accessor: Symbol, num: Int) = { + newNullaryMethod(nme.productAccessorName(num), accessor.tpe.resultType, REF(accessor)) } /** Common code for productElement and (currently disabled) productElementName @@ -230,7 +248,9 @@ trait SyntheticMethods extends ast.TreeDSL { if (!phase.erasedTypes) try { if (clazz.isCase) { - val isTop = clazz.ancestors forall (x => !x.isCase) + val isTop = clazz.ancestors forall (x => !x.isCase) + val accessors = clazz.caseFieldAccessors + val arity = accessors.size if (isTop) { // If this case class has fields with less than public visibility, their getter at this @@ -243,13 +263,29 @@ trait SyntheticMethods extends ast.TreeDSL { stat.symbol resetFlag CASEACCESSOR } } + /** The _1, _2, etc. methods to implement ProductN, and an override + * of productIterator with a more specific element type. + * Only enabled under -Xexperimental. + */ + def productNMethods = { + val projectionMethods = (accessors, 1 to arity).zipped map ((accessor, num) => + productProj(arity, num) -> (() => projectionMethod(accessor, num)) + ) + projectionMethods :+ ( + Product_iterator -> (() => productIteratorMethod) + ) + } // methods for case classes only def classMethods = List( Object_hashCode -> (() => forwardingMethod(nme.hashCode_, "_" + nme.hashCode_)), Object_toString -> (() => forwardingMethod(nme.toString_, "_" + nme.toString_)), Object_equals -> (() => equalsClassMethod) + ) ++ ( + if (settings.Xexperimental.value) productNMethods + else Nil ) + // methods for case objects only def objectMethods = List( Object_hashCode -> (() => moduleHashCodeMethod), @@ -257,7 +293,6 @@ trait SyntheticMethods extends ast.TreeDSL { ) // methods for both classes and objects def everywhereMethods = { - val accessors = clazz.caseFieldAccessors List( Product_productPrefix -> (() => productPrefixMethod), Product_productArity -> (() => productArityMethod(accessors.length)), diff --git a/src/library/scala/runtime/ScalaRunTime.scala b/src/library/scala/runtime/ScalaRunTime.scala index 4d5f783d61..c8cc624468 100644 --- a/src/library/scala/runtime/ScalaRunTime.scala +++ b/src/library/scala/runtime/ScalaRunTime.scala @@ -199,6 +199,20 @@ object ScalaRunTime { } } + /** A helper for case classes. */ + def typedProductIterator[T](x: Product): Iterator[T] = { + new Iterator[T] { + private var c: Int = 0 + private val cmax = x.productArity + def hasNext = c < cmax + def next() = { + val result = x.productElement(c) + c += 1 + result.asInstanceOf[T] + } + } + } + /** Fast path equality method for inlining; used when -optimise is set. */ @inline def inlinedEquals(x: Object, y: Object): Boolean = diff --git a/test/files/pos/caseclass-productN.flags b/test/files/pos/caseclass-productN.flags new file mode 100644 index 0000000000..e1b37447c9 --- /dev/null +++ b/test/files/pos/caseclass-productN.flags @@ -0,0 +1 @@ +-Xexperimental \ No newline at end of file diff --git a/test/files/pos/caseclass-productN.scala b/test/files/pos/caseclass-productN.scala new file mode 100644 index 0000000000..a0964218ac --- /dev/null +++ b/test/files/pos/caseclass-productN.scala @@ -0,0 +1,20 @@ +object Test { + class A + class B extends A + class C extends B + + case class Bippy[T](x: Int, y: List[T], z: T) { } + case class Bippy2[T](x: Int, y: List[T], z: T) { } + + def bippies = List( + Bippy(5, List(new C), new B), + Bippy2(5, List(new B), new C) + ) + + def bmethod(x: B) = () + + def main(args: Array[String]): Unit = { + bippies flatMap (_._2) foreach bmethod + bippies map (_._3) foreach bmethod + } +} -- cgit v1.2.3