diff options
-rw-r--r-- | src/dotty/tools/dotc/Compiler.scala | 2 | ||||
-rw-r--r-- | src/dotty/tools/dotc/core/Definitions.scala | 2 | ||||
-rw-r--r-- | src/dotty/tools/dotc/core/Types.scala | 6 | ||||
-rw-r--r-- | src/dotty/tools/dotc/transform/SyntheticMethods.scala | 174 | ||||
-rw-r--r-- | test/dotc/tests.scala | 2 |
5 files changed, 183 insertions, 3 deletions
diff --git a/src/dotty/tools/dotc/Compiler.scala b/src/dotty/tools/dotc/Compiler.scala index a63a69b9b..13fcc8a70 100644 --- a/src/dotty/tools/dotc/Compiler.scala +++ b/src/dotty/tools/dotc/Compiler.scala @@ -19,7 +19,7 @@ class Compiler { def phases: List[List[Phase]] = List( List(new FrontEnd), - List(new FirstTransform), + List(new FirstTransform, new SyntheticMethods), List(new SuperAccessors), // pickling goes here List(/*new RefChecks, */new ElimRepeated, new ElimLocals), diff --git a/src/dotty/tools/dotc/core/Definitions.scala b/src/dotty/tools/dotc/core/Definitions.scala index 123e6e671..82ed3ce1d 100644 --- a/src/dotty/tools/dotc/core/Definitions.scala +++ b/src/dotty/tools/dotc/core/Definitions.scala @@ -423,6 +423,8 @@ class Definitions { lazy val RootImports = List[Symbol](JavaLangPackageVal, ScalaPackageVal, ScalaPredefModule, DottyPredefModule) + lazy val overriddenBySynthetic = Set[Symbol](Any_equals, Any_hashCode, Any_toString, Product_canEqual) + def isTupleType(tp: Type)(implicit ctx: Context) = { val arity = tp.dealias.argInfos.length arity <= MaxTupleArity && (tp isRef TupleClass(arity)) diff --git a/src/dotty/tools/dotc/core/Types.scala b/src/dotty/tools/dotc/core/Types.scala index 8ec5c7295..50c729df1 100644 --- a/src/dotty/tools/dotc/core/Types.scala +++ b/src/dotty/tools/dotc/core/Types.scala @@ -1149,7 +1149,11 @@ object Types { (lastSymbol eq null) || (lastSymbol.defRunId != sym.defRunId) || (lastSymbol.defRunId == NoRunId) || - (lastSymbol.infoOrCompleter == ErrorType), + (lastSymbol.infoOrCompleter == ErrorType || + defn.overriddenBySynthetic.contains(lastSymbol) + // for overriddenBySynthetic symbols a TermRef such as SomeCaseClass.this.hashCode + // might be rewritten from Object#hashCode to the hashCode generated at SyntheticMethods + ), s"data race? overwriting symbol of $this / ${this.getClass} / ${lastSymbol.id} / ${sym.id}") protected def sig: Signature = Signature.NotAMethod diff --git a/src/dotty/tools/dotc/transform/SyntheticMethods.scala b/src/dotty/tools/dotc/transform/SyntheticMethods.scala new file mode 100644 index 000000000..4f9b2c9fa --- /dev/null +++ b/src/dotty/tools/dotc/transform/SyntheticMethods.scala @@ -0,0 +1,174 @@ +package dotty.tools.dotc +package transform + +import core._ +import Symbols._, Types._, Contexts._, Names._, StdNames._, Constants._ +import scala.collection.{ mutable, immutable } +import Flags._ +import TreeTransforms._ +import DenotTransformers._ +import ast.Trees._ +import ast.untpd +import Decorators._ +import ValueClasses.isDerivedValueClass +import scala.collection.mutable.ListBuffer +import scala.language.postfixOps + +/** Synthetic method implementations for case classes, case objects, + * and value classes. + * Selectively added to case classes/objects, unless a non-default + * implementation already exists: + * def equals(other: Any): Boolean + * def hashCode(): Int + * def canEqual(other: Any): Boolean + * def toString(): String + * Special handling: + * protected def readResolve(): AnyRef + * + * Selectively added to value classes, unless a non-default + * implementation already exists: + * + * def equals(other: Any): Boolean + * def hashCode(): Int + */ +class SyntheticMethods extends MiniPhaseTransform with IdentityDenotTransformer { thisTransformer => + import ast.tpd._ + + val name = "synthetics" + + private var valueSymbols: List[Symbol] = _ + private var caseSymbols: List[Symbol] = _ + + override def init(implicit ctx: Context, info: TransformerInfo) = { + valueSymbols = List(defn.Any_hashCode, defn.Any_equals) + caseSymbols = valueSymbols ++ List(defn.Any_toString, defn.Product_canEqual) + } + + /** The synthetic methods of the case or value class `clazz`. + */ + def syntheticMethods(clazz: ClassSymbol)(implicit ctx: Context): List[Tree] = { + val clazzType = clazz.typeRef + def accessors = clazz.decls.filter(_ is CaseAccessor) + + val symbolsToSynthesize: List[Symbol] = + if (clazz.is(Case)) caseSymbols + else if (isDerivedValueClass(clazz)) valueSymbols + else Nil + + def syntheticDefIfMissing(sym: Symbol): List[Tree] = { + val existing = sym.matchingMember(clazz.thisType) + if (existing == sym || existing.is(Deferred)) syntheticDef(sym) :: Nil + else Nil + } + + def syntheticDef(sym: Symbol): Tree = { + val synthetic = sym.copy( + owner = clazz, + flags = sym.flags &~ Deferred | Synthetic | Override, + coord = clazz.coord).enteredAfter(thisTransformer).asTerm + + def forwardToRuntime(vrefss: List[List[Tree]]): Tree = + ref(defn.runtimeMethod("_" + sym.name.toString)).appliedToArgs(This(clazz) :: vrefss.head) + + def syntheticRHS(implicit ctx: Context): List[List[Tree]] => Tree = synthetic.name match { + case nme.hashCode_ => vrefss => hashCodeBody + case nme.toString_ => forwardToRuntime + case nme.equals_ => vrefss => equalsBody(vrefss.head.head) + case nme.canEqual_ => vrefss => canEqualBody(vrefss.head.head) + } + ctx.log(s"adding $synthetic to $clazz at ${ctx.phase}") + DefDef(synthetic, syntheticRHS(ctx.withOwner(synthetic))) + } + + /** The class + * + * case class C(x: T, y: U) + * + * gets the equals method: + * + * def equals(that: Any): Boolean = + * (this eq that) || { + * that match { + * case x$0 @ (_: C) => this.x == this$0.x && this.y == x$0.y + * case _ => false + * } + * + * If C is a value class the initial `eq` test is omitted. + */ + def equalsBody(that: Tree)(implicit ctx: Context): Tree = { + val thatAsClazz = ctx.newSymbol(ctx.owner, nme.x_0, Synthetic, clazzType, coord = ctx.owner.pos) // x$0 + def wildcardAscription(tp: Type) = + Typed(untpd.Ident(nme.WILDCARD).withType(tp), TypeTree(tp)) + val pattern = Bind(thatAsClazz, wildcardAscription(clazzType)) // x$0 @ (_: C) + val comparisons = accessors map (accessor => + This(clazz).select(accessor).select(defn.Any_==).appliedTo(ref(thatAsClazz).select(accessor))) + val rhs = // this.x == this$0.x && this.y == x$0.y + if (comparisons.isEmpty) Literal(Constant(true)) else comparisons.reduceLeft(_ and _) + val matchingCase = CaseDef(pattern, EmptyTree, rhs) // case x$0 @ (_: C) => this.x == this$0.x && this.y == x$0.y + val defaultCase = CaseDef(wildcardAscription(defn.AnyType), EmptyTree, Literal(Constant(false))) // case _ => false + val matchExpr = Match(that, List(matchingCase, defaultCase)) + if (isDerivedValueClass(clazz)) matchExpr + else { + val eqCompare = This(clazz).select(defn.Object_eq).appliedTo(that.asInstance(defn.ObjectType)) + eqCompare or matchExpr + } + } + + /** The class + * + * case class C(x: T, y: T) + * + * get the hashCode method: + * + * def hashCode: Int = { + * <synthetic> var acc: Int = 0xcafebabe; + * acc = Statics.mix(acc, x); + * acc = Statics.mix(acc, Statics.this.anyHash(y)); + * Statics.finalizeHash(acc, 2) + * } + */ + def hashCodeBody(implicit ctx: Context): Tree = { + val acc = ctx.newSymbol(ctx.owner, "acc".toTermName, Mutable | Synthetic, defn.IntType, coord = ctx.owner.pos) + val accDef = ValDef(acc, Literal(Constant(0xcafebabe))) + val mixes = for (accessor <- accessors.toList) yield + Assign(ref(acc), ref(defn.staticsMethod("mix")).appliedTo(ref(acc), hashImpl(accessor))) + val finish = ref(defn.staticsMethod("finalizeHash")).appliedTo(ref(acc), Literal(Constant(accessors.size))) + Block(accDef :: mixes, finish) + } + + /** The hashCode implementation for given symbol `sym`. */ + def hashImpl(sym: Symbol)(implicit ctx: Context): Tree = { + val d = defn + import d._ + sym.info.finalResultType.typeSymbol match { + case UnitClass | NullClass => Literal(Constant(0)) + case BooleanClass => If(ref(sym), Literal(Constant(1231)), Literal(Constant(1237))) + case IntClass => ref(sym) + case ShortClass | ByteClass | CharClass => ref(sym).select(nme.toInt) + case LongClass => ref(staticsMethod("longHash")).appliedTo(ref(sym)) + case DoubleClass => ref(staticsMethod("doubleHash")).appliedTo(ref(sym)) + case FloatClass => ref(staticsMethod("floatHash")).appliedTo(ref(sym)) + case _ => ref(staticsMethod("anyHash")).appliedTo(ref(sym)) + } + } + + /** The class + * + * case class C(...) + * + * gets the canEqual method + * + * def canEqual(that: Any) = that.isInstanceOf[C] + */ + def canEqualBody(that: Tree): Tree = that.isInstance(clazzType) + + symbolsToSynthesize flatMap syntheticDefIfMissing + } + + override def transformTemplate(impl: Template)(implicit ctx: Context, info: TransformerInfo) = + if (ctx.owner.is(Case) || isDerivedValueClass(ctx.owner)) + cpy.Template(impl, impl.constr, impl.parents, impl.self, + impl.body ++ syntheticMethods(ctx.owner.asClass)(ctx.withPhase(thisTransformer.next))) + else + impl +} diff --git a/test/dotc/tests.scala b/test/dotc/tests.scala index 8bbe30ad8..6035b661a 100644 --- a/test/dotc/tests.scala +++ b/test/dotc/tests.scala @@ -14,7 +14,7 @@ class tests extends CompilerTest { "-pagewidth", "160") implicit val defaultOptions = noCheckOptions ++ List( - "-Ycheck:tailrec" + "-Ycheck:synthetic,tailrec" ) val twice = List("#runs", "2", "-YnoDoubleBindings") |