summaryrefslogblamecommitdiff
path: root/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala
blob: cf905779597840de59db6c642a09fec9cf394601 (plain) (tree)
1
2
3
4
5
6
7
8
9
                            
                                
                         
   
 

                       
 
                   
                     
                               
                                          
 
















                                                                       
   


                                            


                      
 




                                                                                          
 
                                               



                                                                               
 




                                                                        
                                                                               


                              









                                                                                 
                             



                                                                                         
       

                       

                                                
                                                                                                     









                                                                                        
     
                                                                                    
                                                 



                                                                                      
                                 
                                                                         
                                                                                                            



                                                                                             

     
                                                                                  
       


                                                                                            
                                                                                                                          
 

                                                           
       

                                                                             
                                                               
     
 









                                                                                                       
       
                                                                                                        
                                         
                                                                                        
                                                 


                                                                    
                                                                                 








                                                                                                                                
         





                                                                
 





                                                                                                     
 




                                                                                                 
                                                                                                    





                                                                                      

     



                                                                    

     


                                                                                           
                                                          
                                                                                                            
     
 
                                                                   



                                                                          
       

                            


                                          
 



















                                                                                                                         
     
 












                                                                                                          
                           





                                                     
 
                                               
     








                                                            

   
/* NSC -- new Scala compiler
 * Copyright 2005-2011 LAMP/EPFL
 * @author Martin Odersky
 */

package scala.tools.nsc
package typechecker

import symtab.Flags
import symtab.Flags._
import scala.collection.mutable
import scala.collection.mutable.ListBuffer

/** Synthetic method implementations for case classes and case objects.
 *
 *  Added to all case classes/objects:
 *    def productArity: Int
 *    def productElement(n: Int): Any
 *    def productPrefix: String
 *    def productIterator: Iterator[Any]
 *
 *  Selectively added to case classes/objects, unless a non-default
 *  implementation already exists:
 *    def equals(other: Any): Boolean
 *    def hashCode(): Int
 *    def canEqual(other: Any): Boolean
 *    def toString(): String
 *
 *  Special handling:
 *    protected def readResolve(): AnyRef
 */
trait SyntheticMethods extends ast.TreeDSL {
  self: Analyzer =>

  import global._
  import definitions._
  import CODE._

  /** Add the synthetic methods to case classes.
   */
  def addSyntheticMethods(templ: Template, clazz0: Symbol, context: Context): Template = {
    if (phase.erasedTypes)
      return templ

    val synthesizer = new ClassMethodSynthesis(
      clazz0,
      newTyper( if (reporter.hasErrors) context makeSilent false else context )
    )
    import synthesizer._

    val originalAccessors = clazz.caseFieldAccessors
    // private ones will have been renamed -- make sure they are entered
    // in the original order.
    def accessors = clazz.caseFieldAccessors sortBy { acc =>
      originalAccessors indexWhere { orig =>
        (acc.name == orig.name) || (acc.name startsWith (orig.name append "$"))
      }
    }
    val arity = accessors.size
    // If this is ProductN[T1, T2, ...], accessorLub is the lub of T1, T2, ..., .
    // !!! Hidden behind -Xexperimental due to bummer type inference bugs.
    // Refining from Iterator[Any] leads to types like
    //
    //    Option[Int] { def productIterator: Iterator[String] }
    //
    // appearing legitimately, but this breaks invariant places
    // like Manifests and Arrays which are not robust and infer things
    // which they shouldn't.
    val accessorLub  = (
      if (opt.experimental) {
        global.weakLub(accessors map (_.tpe.finalResultType))._1 match {
          case RefinedType(parents, decls) if !decls.isEmpty => intersectionType(parents)
          case tp                                            => tp
        }
      }
      else AnyClass.tpe
    )

    def forwardToRuntime(method: Symbol): Tree =
      forwardMethod(method, getMember(ScalaRunTimeModule, method.name prepend "_"))(This(clazz) :: _)

    // Any member, including private
    def hasConcreteImpl(name: Name) =
      clazz.info.member(name).alternatives exists (m => !m.isDeferred && !m.isSynthetic)

    def hasOverridingImplementation(meth: Symbol) = {
      val sym = clazz.info nonPrivateMember meth.name
      sym.alternatives filterNot (_ eq meth) exists { m0 =>
        !m0.isDeferred && !m0.isSynthetic && (typeInClazz(m0) matches typeInClazz(meth))
      }
    }
    def readConstantValue[T](name: String, default: T = null.asInstanceOf[T]): T = {
      clazzMember(newTermName(name)).info match {
        case NullaryMethodType(ConstantType(Constant(value))) => value.asInstanceOf[T]
        case _                                                => default
      }
    }
    def productIteratorMethod = {
      createMethod(nme.productIterator, iteratorOfType(accessorLub))(_ =>
        gen.mkMethodCall(ScalaRunTimeModule, nme.typedProductIterator, List(accessorLub), List(This(clazz)))
      )
    }
    def projectionMethod(accessor: Symbol, num: Int) = {
      createMethod(nme.productAccessorName(num), accessor.tpe.resultType)(_ => REF(accessor))
    }

    /** Common code for productElement and (currently disabled) productElementName
     */
    def perElementMethod(name: Name, returnType: Type)(caseFn: Symbol => Tree): Tree =
      createSwitchMethod(name, accessors.indices, returnType)(idx => caseFn(accessors(idx)))

    // def productElementNameMethod = perElementMethod(nme.productElementName, StringClass.tpe)(x => LIT(x.name.toString))

    /** The canEqual method for case classes.
     *    def canEqual(that: Any) = that.isInstanceOf[This]
     */
    def canEqualMethod: Tree = (
      createMethod(nme.canEqual_, List(AnyClass.tpe), BooleanClass.tpe)(m => 
        Ident(m.firstParam) IS_OBJ classExistentialType(clazz))
    )

    /** The equality method for case classes.
     *  0 args:
     *    def equals(that: Any) = that.isInstanceOf[this.C] && that.asInstanceOf[this.C].canEqual(this)
     *  1+ args:
     *    def equals(that: Any) = (this eq that.asInstanceOf[AnyRef]) || {
     *      (that.isInstanceOf[this.C]) && {
     *        val x$1 = that.asInstanceOf[this.C]
     *        (this.arg_1 == x$1.arg_1) && (this.arg_2 == x$1.arg_2) && ... && (x$1 canEqual this)
     *       }
     *    }
     */
    def equalsClassMethod: Tree = createMethod(nme.equals_, List(AnyClass.tpe), BooleanClass.tpe) { m =>
      val arg0      = Ident(m.firstParam)
      val thatTest  = gen.mkIsInstanceOf(arg0, classExistentialType(clazz), true, false)
      val thatCast  = gen.mkCast(arg0, clazz.tpe)

      def argsBody: Tree = {
        val otherName = context.unit.freshTermName(clazz.name + "$")
        val otherSym  = m.newValue(otherName, m.pos, SYNTHETIC) setInfo clazz.tpe
        val pairwise  = accessors map (acc => fn(Select(This(clazz), acc), acc.tpe member nme.EQ, Select(Ident(otherSym), acc)))
        val canEq     = gen.mkMethodCall(otherSym, nme.canEqual_, Nil, List(This(clazz)))
        def block     = Block(ValDef(otherSym, thatCast), AND(pairwise :+ canEq: _*))

        (This(clazz) ANY_EQ arg0) OR {
          thatTest AND Block(
            ValDef(otherSym, thatCast),
            AND(pairwise :+ canEq: _*)
          )
        }
      }
      if (accessors.isEmpty)
        thatTest AND ((thatCast DOT nme.canEqual_)(This(clazz)))
      else
        argsBody
    }

    /** The _1, _2, etc. methods to implement ProductN.
     */
    def productNMethods = {
      val accs = accessors.toIndexedSeq
      1 to arity map (num => productProj(arity, num) -> (() => projectionMethod(accs(num - 1), num)))
    }

    // methods for both classes and objects
    def productMethods = {
      List(
        Product_productPrefix   -> (() => constantNullary(nme.productPrefix, clazz.name.decode)),
        Product_productArity    -> (() => constantNullary(nme.productArity, arity)),
        Product_productElement  -> (() => perElementMethod(nme.productElement, accessorLub)(Ident)),
        Product_iterator        -> (() => productIteratorMethod),
        Product_canEqual        -> (() => canEqualMethod)
        // This is disabled pending a reimplementation which doesn't add any
        // weight to case classes (i.e. inspects the bytecode.)
        // Product_productElementName  -> (() => productElementNameMethod(accessors)),
      )
    }

    def caseClassMethods = productMethods ++ productNMethods ++ Seq(
      Object_hashCode -> (() => forwardToRuntime(Object_hashCode)),
      Object_toString -> (() => forwardToRuntime(Object_toString)),
      Object_equals   -> (() => equalsClassMethod)
    )

    def caseObjectMethods = productMethods ++ Seq(
      Object_hashCode -> (() => constantMethod(nme.hashCode_, clazz.name.decode.hashCode)),
      Object_toString -> (() => constantMethod(nme.toString_, clazz.name.decode))
      // Not needed, as reference equality is the default.
      // Object_equals   -> (() => createMethod(Object_equals)(m => This(clazz) ANY_EQ Ident(m.firstParam)))
    )

    /** If you serialize a singleton and then deserialize it twice,
     *  you will have two instances of your singleton unless you implement
     *  readResolve.  Here it is implemented for all objects which have
     *  no implementation and which are marked serializable (which is true
     *  for all case objects.)
     */
    def needsReadResolve = (
         clazz.isModuleClass
      && clazz.isSerializable
      && !hasConcreteImpl(nme.readResolve)
    )

    def synthesize(): List[Tree] = {
      val methods = (
        if (!clazz.isCase) Nil
        else if (clazz.isModuleClass) caseObjectMethods
        else caseClassMethods
      )
      def impls = for ((m, impl) <- methods ; if !hasOverridingImplementation(m)) yield impl()
      def extras = (
        if (needsReadResolve) {
          // Aha, I finally decoded the original comment.
          // This method should be generated as private, but apparently if it is, then
          // it is name mangled afterward.  (Wonder why that is.) So it's only protected.
          // For sure special methods like "readResolve" should not be mangled.
          List(createMethod(nme.readResolve, Nil, ObjectClass.tpe)(m => { m setFlag PRIVATE ; REF(clazz.sourceModule) }))
        }
        else Nil
      )

      try impls ++ extras
      catch { case _: TypeError if reporter.hasErrors => Nil }
    }

    /** If this case class has any less than public accessors,
     *  adds new accessors at the correct locations to preserve ordering.
     *  Note that this must be done before the other method synthesis
     *  because synthesized methods need refer to the new symbols.
     *  Care must also be taken to preserve the case accessor order.
     */
    def caseTemplateBody(): List[Tree] = {
      val lb = ListBuffer[Tree]()
      def isRewrite(sym: Symbol) = sym.isCaseAccessorMethod && !sym.isPublic

      for (ddef @ DefDef(_, _, _, _, _, _) <- templ.body ; if isRewrite(ddef.symbol)) {
        val original = ddef.symbol
        val newAcc = deriveMethod(ddef.symbol, name => context.unit.freshTermName(name + "$")) { newAcc =>
          newAcc.makePublic
          newAcc resetFlag (ACCESSOR | PARAMACCESSOR)
          ddef.rhs.duplicate
        }
        ddef.symbol resetFlag CASEACCESSOR
        lb += logResult("case accessor new")(newAcc)
      }

      lb ++= templ.body ++= synthesize() toList
    }

    if (phase.id > currentRun.typerPhase.id) templ
    else treeCopy.Template(templ, templ.parents, templ.self,
      if (clazz.isCase) caseTemplateBody()
      else synthesize() match {
        case Nil  => templ.body // avoiding unnecessary copy
        case ms   => templ.body ++ ms
      }
    )
  }
}