From 54e284d669418ebc6445bd0ec66804b9067f6dd3 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Wed, 7 Mar 2012 14:13:03 +0100 Subject: Allows case classes as value classes --- .../tools/nsc/typechecker/SyntheticMethods.scala | 26 +++--- src/library/scala/Serializable.scala | 2 +- test/files/pos/t715/meredith_1.scala | 2 +- test/files/run/MeterCaseClass.check | 21 +++++ test/files/run/MeterCaseClass.scala | 99 ++++++++++++++++++++++ 5 files changed, 138 insertions(+), 12 deletions(-) create mode 100644 test/files/run/MeterCaseClass.check create mode 100644 test/files/run/MeterCaseClass.scala diff --git a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala index 9bee731e1e..1c78426c20 100644 --- a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala +++ b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala @@ -217,7 +217,7 @@ trait SyntheticMethods extends ast.TreeDSL { List( Product_productPrefix -> (() => constantNullary(nme.productPrefix, clazz.name.decode)), Product_productArity -> (() => constantNullary(nme.productArity, arity)), - Product_productElement -> (() => perElementMethod(nme.productElement, accessorLub)(Ident)), + Product_productElement -> (() => perElementMethod(nme.productElement, accessorLub)(Select(This(clazz), _))), Product_iterator -> (() => productIteratorMethod), Product_canEqual -> (() => canEqualMethod) // This is disabled pending a reimplementation which doesn't add any @@ -226,12 +226,21 @@ trait SyntheticMethods extends ast.TreeDSL { ) } + def valueClassMethods = List( + Any_hashCode -> (() => hashCodeDerivedValueClassMethod), + Any_equals -> (() => equalsDerivedValueClassMethod) + ) + def caseClassMethods = productMethods ++ productNMethods ++ Seq( Object_hashCode -> (() => forwardToRuntime(Object_hashCode)), Object_toString -> (() => forwardToRuntime(Object_toString)), Object_equals -> (() => equalsCaseClassMethod) ) + def valueCaseClassMethods = productMethods ++ productNMethods ++ valueClassMethods ++ Seq( + Any_toString -> (() => forwardToRuntime(Object_toString)) + ) + def caseObjectMethods = productMethods ++ Seq( Object_hashCode -> (() => constantMethod(nme.hashCode_, clazz.name.decode.hashCode)), Object_toString -> (() => constantMethod(nme.toString_, clazz.name.decode)) @@ -239,11 +248,6 @@ trait SyntheticMethods extends ast.TreeDSL { // Object_equals -> (() => createMethod(Object_equals)(m => This(clazz) ANY_EQ Ident(m.firstParam))) ) - def inlineClassMethods = List( - Any_hashCode -> (() => hashCodeDerivedValueClassMethod), - Any_equals -> (() => equalsDerivedValueClassMethod) - ) - /** If you serialize a singleton and then deserialize it twice, * you will have two instances of your singleton unless you implement * readResolve. Here it is implemented for all objects which have @@ -258,10 +262,12 @@ trait SyntheticMethods extends ast.TreeDSL { def synthesize(): List[Tree] = { val methods = ( - if (clazz.isDerivedValueClass) inlineClassMethods - else if (!clazz.isCase) Nil - else if (clazz.isModuleClass) caseObjectMethods - else caseClassMethods + if (clazz.isCase) + if (clazz.isDerivedValueClass) valueCaseClassMethods + else if (clazz.isModuleClass) caseObjectMethods + else caseClassMethods + else if (clazz.isDerivedValueClass) valueClassMethods + else Nil ) def impls = for ((m, impl) <- methods ; if !hasOverridingImplementation(m)) yield impl() diff --git a/src/library/scala/Serializable.scala b/src/library/scala/Serializable.scala index 9be258bb83..9b9456e0a0 100644 --- a/src/library/scala/Serializable.scala +++ b/src/library/scala/Serializable.scala @@ -11,4 +11,4 @@ package scala /** * Classes extending this trait are serializable across platforms (Java, .NET). */ -trait Serializable extends java.io.Serializable +trait Serializable extends Any with java.io.Serializable diff --git a/test/files/pos/t715/meredith_1.scala b/test/files/pos/t715/meredith_1.scala index 3ed2e57d7a..8261b9881a 100644 --- a/test/files/pos/t715/meredith_1.scala +++ b/test/files/pos/t715/meredith_1.scala @@ -3,7 +3,7 @@ package com.sap.dspace.model.othello; import scala.xml._ trait XMLRenderer { - type T <: {def getClass() : java.lang.Class[_]} + type T <: Any {def getClass() : java.lang.Class[_]} val valueTypes = List( classOf[java.lang.Boolean], diff --git a/test/files/run/MeterCaseClass.check b/test/files/run/MeterCaseClass.check new file mode 100644 index 0000000000..08370d2097 --- /dev/null +++ b/test/files/run/MeterCaseClass.check @@ -0,0 +1,21 @@ +2.0 +Meter(4.0) +false +x.isInstanceOf[Meter]: true +x.hashCode: 1072693248 +x == 1: false +x == y: true +a == b: true +testing native arrays +Array(Meter(1.0), Meter(2.0)) +Meter(1.0) +>>>Meter(1.0)<<< Meter(1.0) +>>>Meter(2.0)<<< Meter(2.0) +testing wrapped arrays +FlatArray(Meter(1.0), Meter(2.0)) +Meter(1.0) +>>>Meter(1.0)<<< Meter(1.0) +>>>Meter(2.0)<<< Meter(2.0) +FlatArray(Meter(2.0), Meter(3.0)) +ArrayBuffer(1.0, 2.0) +FlatArray(0.3048ft, 0.6096ft) diff --git a/test/files/run/MeterCaseClass.scala b/test/files/run/MeterCaseClass.scala new file mode 100644 index 0000000000..4f082b5252 --- /dev/null +++ b/test/files/run/MeterCaseClass.scala @@ -0,0 +1,99 @@ +package a { + case class Meter(underlying: Double) extends AnyVal with _root_.b.Printable { + def + (other: Meter): Meter = + new Meter(this.underlying + other.underlying) + def / (other: Meter): Double = this.underlying / other.underlying + def / (factor: Double): Meter = new Meter(this.underlying / factor) + def < (other: Meter): Boolean = this.underlying < other.underlying + def toFoot: Foot = new Foot(this.underlying * 0.3048) + override def print = { Console.print(">>>"); super.print; proprint } + } + + object Meter extends (Double => Meter) { + + implicit val boxings = new BoxingConversions[Meter, Double] { + def box(x: Double) = new Meter(x) + def unbox(m: Meter) = m.underlying + } + } + + class Foot(val unbox: Double) extends AnyVal { + def + (other: Foot): Foot = + new Foot(this.unbox + other.unbox) + override def toString = unbox.toString+"ft" + } + object Foot { + implicit val boxings = new BoxingConversions[Foot, Double] { + def box(x: Double) = new Foot(x) + def unbox(m: Foot) = m.unbox + } + } + +} +package b { + trait Printable extends Any { + def print: Unit = Console.print(this) + protected def proprint = Console.print("<<<") + } +} +import a._ +import _root_.b._ +object Test extends App { + + { + val x: Meter = new Meter(1) + val a: Object = x.asInstanceOf[Object] + val y: Meter = a.asInstanceOf[Meter] + + val u: Double = 1 + val b: Object = u.asInstanceOf[Object] + val v: Double = b.asInstanceOf[Double] + } + + val x = new Meter(1) + val y = x + println((x + x) / x) + println((x + x) / 0.5) + println((x < x).toString) + println("x.isInstanceOf[Meter]: "+x.isInstanceOf[Meter]) + + + println("x.hashCode: "+x.hashCode) + println("x == 1: "+(x == 1)) + println("x == y: "+(x == y)) + assert(x.hashCode == (1.0).hashCode) + + val a: Any = x + val b: Any = y + println("a == b: "+(a == b)) + + { println("testing native arrays") + val arr = Array(x, y + x) + println(arr.deep) + def foo[T <: Printable](x: Array[T]) { + for (i <- 0 until x.length) { x(i).print; println(" "+x(i)) } + } + val m = arr(0) + println(m) + foo(arr) + } + + { println("testing wrapped arrays") + import collection.mutable.FlatArray + val arr = FlatArray(x, y + x) + println(arr) + def foo(x: FlatArray[Meter]) { + for (i <- 0 until x.length) { x(i).print; println(" "+x(i)) } + } + val m = arr(0) + println(m) + foo(arr) + val ys: Seq[Meter] = arr map (_ + new Meter(1)) + println(ys) + val zs = arr map (_ / Meter(1)) + println(zs) + val fs = arr map (_.toFoot) + println(fs) + } + +} -- cgit v1.2.3