/* NSC -- new Scala compiler
* Copyright 2005-2013 LAMP/EPFL
* @author Martin Odersky
*/
package scala.tools.nsc
package typechecker
import scala.language.postfixOps
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import symtab.Flags._
/** Synthetic method implementations for case classes and case objects.
*
* Added to all case classes/objects:
* def productArity: Int
* def productElement(n: Int): Any
* def productPrefix: String
* def productIterator: Iterator[Any]
*
* Selectively added to case classes/objects, unless a non-default
* implementation already exists:
* def equals(other: Any): Boolean
* def hashCode(): Int
* def canEqual(other: Any): Boolean
* def toString(): String
*
* Special handling:
* protected def readResolve(): AnyRef
*/
trait SyntheticMethods extends ast.TreeDSL {
self: Analyzer =>
import global._
import definitions._
import CODE._
private lazy val productSymbols = List(Product_productPrefix, Product_productArity, Product_productElement, Product_iterator, Product_canEqual)
private lazy val valueSymbols = List(Any_hashCode, Any_equals)
private lazy val caseSymbols = List(Object_hashCode, Object_toString) ::: productSymbols
private lazy val caseValueSymbols = Any_toString :: valueSymbols ::: productSymbols
private lazy val caseObjectSymbols = Object_equals :: caseSymbols
private def symbolsToSynthesize(clazz: Symbol): List[Symbol] = {
if (clazz.isCase) {
if (clazz.isDerivedValueClass) caseValueSymbols
else if (clazz.isModuleClass) caseSymbols
else caseObjectSymbols
}
else if (clazz.isDerivedValueClass) valueSymbols
else Nil
}
private lazy val renamedCaseAccessors = perRunCaches.newMap[Symbol, mutable.Map[TermName, TermName]]()
/** Does not force the info of `caseclazz` */
final def caseAccessorName(caseclazz: Symbol, paramName: TermName) =
(renamedCaseAccessors get caseclazz).fold(paramName)(_(paramName))
final def clearRenamedCaseAccessors(caseclazz: Symbol): Unit = {
renamedCaseAccessors -= caseclazz
}
/** Add the synthetic methods to case classes.
*/
def addSyntheticMethods(templ: Template, clazz0: Symbol, context: Context): Template = {
val syntheticsOk = (phase.id <= currentRun.typerPhase.id) && {
symbolsToSynthesize(clazz0) filter (_ matchingSymbol clazz0.info isSynthetic) match {
case Nil => true
case syms => log("Not adding synthetic methods: already has " + syms.mkString(", ")) ; false
}
}
if (!syntheticsOk)
return templ
val synthesizer = new ClassMethodSynthesis(
clazz0,
newTyper( if (reporter.hasErrors) context makeSilent false else context )
)
import synthesizer._
if (clazz0 == AnyValClass || isPrimitiveValueClass(clazz0)) return {
if ((clazz0.info member nme.getClass_).isDeferred) {
// XXX dummy implementation for now
val getClassMethod = createMethod(nme.getClass_, getClassReturnType(clazz.tpe))(_ => NULL)
deriveTemplate(templ)(_ :+ getClassMethod)
}
else templ
}
def accessors = clazz.caseFieldAccessors
val arity = accessors.size
def forwardToRuntime(method: Symbol): Tree =
forwardMethod(method, getMember(ScalaRunTimeModule, (method.name prepend "_")))(mkThis :: _)
def callStaticsMethodName(name: TermName)(args: Tree*): Tree = {
val method = RuntimeStaticsModule.info.member(name)
Apply(gen.mkAttributedRef(method), args.toList)
}
def callStaticsMethod(name: String)(args: Tree*): Tree =
callStaticsMethodName(newTermName(name))(args: _*)
// Any concrete member, including private
def hasConcreteImpl(name: Name) =
clazz.info.member(name).alternatives exists (m => !m.isDeferred)
def hasOverridingImplementation(meth: Symbol) = {
val sym = clazz.info nonPrivateMember meth.name
sym.alternatives exists { m0 =>
(m0 ne meth) && !m0.isDeferred && !m0.isSynthetic && (m0.owner != AnyValClass) && (typeInClazz(m0) matches typeInClazz(meth))
}
}
def productIteratorMethod = {
createMethod(nme.productIterator, iteratorOfType(AnyTpe))(_ =>
gen.mkMethodCall(ScalaRunTimeModule, nme.typedProductIterator, List(AnyTpe), List(mkThis))
)
}
/* Common code for productElement and (currently disabled) productElementName */
def perElementMethod(name: Name, returnType: Type)(caseFn: Symbol => Tree): Tree =
createSwitchMethod(name, accessors.indices, returnType)(idx => caseFn(accessors(idx)))
// def productElementNameMethod = perElementMethod(nme.productElementName, StringTpe)(x => LIT(x.name.toString))
var syntheticCanEqual = false
/* The canEqual method for case classes.
* def canEqual(that: Any) = that.isInstanceOf[This]
*/
def canEqualMethod: Tree = {
syntheticCanEqual = true
createMethod(nme.canEqual_, List(AnyTpe), BooleanTpe) { m =>
Ident(m.firstParam) IS_OBJ classExistentialType(context.prefix, clazz)
}
}
/* that match { case _: this.C => true ; case _ => false }
* where `that` is the given method's first parameter.
*
* An isInstanceOf test is insufficient because it has weaker
* requirements than a pattern match. Given an inner class Foo and
* two different instantiations of the container, an x.Foo and and a y.Foo
* are both .isInstanceOf[Foo], but the one does not match as the other.
*/
def thatTest(eqmeth: Symbol): Tree = {
Match(
Ident(eqmeth.firstParam),
List(
CaseDef(Typed(Ident(nme.WILDCARD), TypeTree(clazz.tpe)), EmptyTree, TRUE),
CaseDef(Ident(nme.WILDCARD), EmptyTree, FALSE)
)
)
}
/* (that.asInstanceOf[this.C])
* where that is the given methods first parameter.
*/
def thatCast(eqmeth: Symbol): Tree =
gen.mkCast(Ident(eqmeth.firstParam), clazz.tpe)
/* The equality method core for case classes and inline classes.
* 1+ args:
* (that.isInstanceOf[this.C]) && {
* val x$1 = that.asInstanceOf[this.C]
* (this.arg_1 == x$1.arg_1) && (this.arg_2 == x$1.arg_2) && ... && (x$1 canEqual this)
* }
* Drop canBuildFrom part if class is final and canBuildFrom is synthesized
*/
def equalsCore(eqmeth: Symbol, accessors: List[Symbol]) = {
val otherName = context.unit.freshTermName(clazz.name + "$")
val otherSym = eqmeth.newValue(otherName, eqmeth.pos, SYNTHETIC) setInfo clazz.tpe
val pairwise = accessors map (acc => fn(Select(mkThis, acc), acc.tpe member nme.EQ, Select(Ident(otherSym), acc)))
val canEq = gen.mkMethodCall(otherSym, nme.canEqual_, Nil, List(mkThis))
val tests = if (clazz.isDerivedValueClass || clazz.isFinal && syntheticCanEqual) pairwise else pairwise :+ canEq
thatTest(eqmeth) AND Block(
ValDef(otherSym, thatCast(eqmeth)),
AND(tests: _*)
)
}
/* The equality method for case classes.
* 0 args:
* def equals(that: Any) = that.isInstanceOf[this.C] && that.asInstanceOf[this.C].canEqual(this)
* 1+ args:
* def equals(that: Any) = (this eq that.asInstanceOf[AnyRef]) || {
* (that.isInstanceOf[this.C]) && {
* val x$1 = that.asInstanceOf[this.C]
* (this.arg_1 == x$1.arg_1) && (this.arg_2 == x$1.arg_2) && ... && (x$1 canEqual this)
* }
* }
*/
def equalsCaseClassMethod: Tree = createMethod(nme.equals_, List(AnyTpe), BooleanTpe) { m =>
if (accessors.isEmpty)
if (clazz.isFinal) thatTest(m)
else thatTest(m) AND ((thatCast(m) DOT nme.canEqual_)(mkThis))
else
(mkThis ANY_EQ Ident(m.firstParam)) OR equalsCore(m, accessors)
}
/* The equality method for value classes
* def equals(that: Any) = (this.asInstanceOf[AnyRef]) eq that.asInstanceOf[AnyRef]) || {
* (that.isInstanceOf[this.C]) && {
* val x$1 = that.asInstanceOf[this.C]
* (this.underlying == that.underlying
*/
def equalsDerivedValueClassMethod: Tree = createMethod(nme.equals_, List(AnyTpe), BooleanTpe) { m =>
equalsCore(m, List(clazz.derivedValueClassUnbox))
}
/* The hashcode method for value classes
* def hashCode(): Int = this.underlying.hashCode
*/
def hashCodeDerivedValueClassMethod: Tree = createMethod(nme.hashCode_, Nil, IntTpe) { m =>
Select(mkThisSelect(clazz.derivedValueClassUnbox), nme.hashCode_)
}
/* The _1, _2, etc. methods to implement ProductN, disabled
* until we figure out how to introduce ProductN without cycles.
*/
/****
def productNMethods = {
val accs = accessors.toIndexedSeq
1 to arity map (num => productProj(arity, num) -> (() => projectionMethod(accs(num - 1), num)))
}
def projectionMethod(accessor: Symbol, num: Int) = {
createMethod(nme.productAccessorName(num), accessor.tpe.resultType)(_ => REF(accessor))
}
****/
// methods for both classes and objects
def productMethods = {
List(
Product_productPrefix -> (() => constantNullary(nme.productPrefix, clazz.name.decode)),
Product_productArity -> (() => constantNullary(nme.productArity, arity)),
Product_productElement -> (() => perElementMethod(nme.productElement, AnyTpe)(mkThisSelect)),
Product_iterator -> (() => productIteratorMethod),
Product_canEqual -> (() => canEqualMethod)
// This is disabled pending a reimplementation which doesn't add any
// weight to case classes (i.e. inspects the bytecode.)
// Product_productElementName -> (() => productElementNameMethod(accessors)),
)
}
def hashcodeImplementation(sym: Symbol): Tree = {
sym.tpe.finalResultType.typeSymbol match {
case UnitClass | NullClass => Literal(Constant(0))
case BooleanClass => If(Ident(sym), Literal(Constant(1231)), Literal(Constant(1237)))
case IntClass => Ident(sym)
case ShortClass | ByteClass | CharClass => Select(Ident(sym), nme.toInt)
case LongClass => callStaticsMethodName(nme.longHash)(Ident(sym))
case DoubleClass => callStaticsMethodName(nme.doubleHash)(Ident(sym))
case FloatClass => callStaticsMethodName(nme.floatHash)(Ident(sym))
case _ => callStaticsMethodName(nme.anyHash)(Ident(sym))
}
}
def specializedHashcode = {
createMethod(nme.hashCode_, Nil, IntTpe) { m =>
val accumulator = m.newVariable(newTermName("acc"), m.pos, SYNTHETIC) setInfo IntTpe
val valdef = ValDef(accumulator, Literal(Constant(0xcafebabe)))
val mixes = accessors map (acc =>
Assign(
Ident(accumulator),
callStaticsMethod("mix")(Ident(accumulator), hashcodeImplementation(acc))
)
)
val finish = callStaticsMethod("finalizeHash")(Ident(accumulator), Literal(Constant(arity)))
Block(valdef :: mixes, finish)
}
}
def chooseHashcode = {
if (accessors exists (x => isPrimitiveValueType(x.tpe.finalResultType)))
specializedHashcode
else
forwardToRuntime(Object_hashCode)
}
def valueClassMethods = List(
Any_hashCode -> (() => hashCodeDerivedValueClassMethod),
Any_equals -> (() => equalsDerivedValueClassMethod)
)
def caseClassMethods = productMethods ++ /*productNMethods ++*/ Seq(
Object_hashCode -> (() => chooseHashcode),
Object_toString -> (() => forwardToRuntime(Object_toString)),
Object_equals -> (() => equalsCaseClassMethod)
)
def valueCaseClassMethods = productMethods ++ /*productNMethods ++*/ valueClassMethods ++ Seq(
Any_toString -> (() => forwardToRuntime(Object_toString))
)
def caseObjectMethods = productMethods ++ Seq(
Object_hashCode -> (() => constantMethod(nme.hashCode_, clazz.name.decode.hashCode)),
Object_toString -> (() => constantMethod(nme.toString_, clazz.name.decode))
// Not needed, as reference equality is the default.
// Object_equals -> (() => createMethod(Object_equals)(m => This(clazz) ANY_EQ Ident(m.firstParam)))
)
/* If you serialize a singleton and then deserialize it twice,
* you will have two instances of your singleton unless you implement
* readResolve. Here it is implemented for all objects which have
* no implementation and which are marked serializable (which is true
* for all case objects.)
*/
def needsReadResolve = (
clazz.isModuleClass
&& clazz.isSerializable
&& !hasConcreteImpl(nme.readResolve)
&& clazz.isStatic
)
def synthesize(): List[Tree] = {
val methods = (
if (clazz.isCase)
if (clazz.isDerivedValueClass) valueCaseClassMethods
else if (clazz.isModuleClass) caseObjectMethods
else caseClassMethods
else if (clazz.isDerivedValueClass) valueClassMethods
else Nil
)
/* Always generate overrides for equals and hashCode in value classes,
* so they can appear in universal traits without breaking value semantics.
*/
def impls = {
def shouldGenerate(m: Symbol) = {
!hasOverridingImplementation(m) || {
clazz.isDerivedValueClass && (m == Any_hashCode || m == Any_equals) && {
// Without a means to suppress this warning, I've thought better of it.
if (settings.warnValueOverrides) {
(clazz.info nonPrivateMember m.name) filter (m => (m.owner != AnyClass) && (m.owner != clazz) && !m.isDeferred) andAlso { m =>
typer.context.warning(clazz.pos, s"Implementation of ${m.name} inherited from ${m.owner} overridden in $clazz to enforce value class semantics")
}
}
true
}
}
}
for ((m, impl) <- methods ; if shouldGenerate(m)) yield impl()
}
def extras = {
if (needsReadResolve) {
// Aha, I finally decoded the original comment.
// This method should be generated as private, but apparently if it is, then
// it is name mangled afterward. (Wonder why that is.) So it's only protected.
// For sure special methods like "readResolve" should not be mangled.
List(createMethod(nme.readResolve, Nil, ObjectTpe)(m => {
m setFlag PRIVATE; REF(clazz.sourceModule)
}))
}
else Nil
}
try impls ++ extras
catch { case _: TypeError if reporter.hasErrors => Nil }
}
/* If this case class has any less than public accessors,
* adds new accessors at the correct locations to preserve ordering.
* Note that this must be done before the other method synthesis
* because synthesized methods need refer to the new symbols.
* Care must also be taken to preserve the case accessor order.
*/
def caseTemplateBody(): List[Tree] = {
val lb = ListBuffer[Tree]()
def isRewrite(sym: Symbol) = sym.isCaseAccessorMethod && !sym.isPublic
for (ddef @ DefDef(_, _, _, _, _, _) <- templ.body ; if isRewrite(ddef.symbol)) {
val original = ddef.symbol
val i = original.owner.caseFieldAccessors.indexOf(original)
def freshAccessorName = {
devWarning(s"Unable to find $original among case accessors of ${original.owner}: ${original.owner.caseFieldAccessors}")
context.unit.freshTermName(original.name + "$")
}
def nameSuffixedByParamIndex = original.name.append(nme.CASE_ACCESSOR + "$" + i).toTermName
val newName = if (i < 0) freshAccessorName else nameSuffixedByParamIndex
val newAcc = deriveMethod(ddef.symbol, name => newName) { newAcc =>
newAcc.makePublic
newAcc resetFlag (ACCESSOR | PARAMACCESSOR | OVERRIDE)
ddef.rhs.duplicate
}
// TODO: shouldn't the next line be: `original resetFlag CASEACCESSOR`?
ddef.symbol resetFlag CASEACCESSOR
lb += logResult("case accessor new")(newAcc)
val renamedInClassMap = renamedCaseAccessors.getOrElseUpdate(clazz, mutable.Map() withDefault(x => x))
renamedInClassMap(original.name.toTermName) = newAcc.symbol.name.toTermName
}
(lb ++= templ.body ++= synthesize()).toList
}
deriveTemplate(templ)(body =>
if (clazz.isCase) caseTemplateBody()
else synthesize() match {
case Nil => body // avoiding unnecessary copy
case ms => body ++ ms
}
)
}
}