summaryrefslogblamecommitdiff
path: root/src/compiler/scala/tools/nsc/typechecker/Variances.scala
blob: b9f2b9abd824884495a3dfd4ff1015aa22c52fef (plain) (tree)
1
2
3
4
5
6
7
8
9
                            
                                

                          
 

                       
 
                                                     
 
                                                                                          
   
                 
 

                    
 
                                                 


                                          


          
                                            

                                
 
                                                                                   

                                                                    
 
                                                                             
                                                       
                                                              
                                         
 
                                                                        

                                                                  
 

                                                                             




                                                                                      

                                                         




                            
                                                                                      

                                                                              

   
                                                                                

                                                                      

   

                                                                                        
                                                                                        
               
                                
                                 

                                   


                                                                                             



                                                                            

                                                                           

                                     

                                                                            

                                                                      

                                                                    

   
/* NSC -- new scala compiler
 * Copyright 2005-2012 LAMP/EPFL
 * @author  Martin Odersky
 */

package scala.tools.nsc
package typechecker

import symtab.Flags.{ VarianceFlags => VARIANCES, _ }

/** Variances form a lattice, 0 <= COVARIANT <= VARIANCES, 0 <= CONTRAVARIANT <= VARIANCES
 */
trait Variances {

  val global: Global
  import global._

  /** Flip between covariant and contravariant; any other value is returned unchanged. */
  private def flip(v: Int): Int = {
    if (v == COVARIANT) CONTRAVARIANT
    else if (v == CONTRAVARIANT) COVARIANT
    else v
  }

  /** Map everything below VARIANCES to 0; VARIANCES itself is kept. */
  private def cut(v: Int): Int =
    if (v == VARIANCES) v else 0

  /** Compute variance of type parameter `tparam` in types of all symbols `syms`.
   *
   *  The per-symbol results are combined with `&`, starting from VARIANCES,
   *  so an empty list yields VARIANCES.
   */
  def varianceInSyms(syms: List[Symbol])(tparam: Symbol): Int =
    syms.foldLeft(VARIANCES)((acc, sym) => acc & varianceInSym(sym)(tparam))

  /** Compute variance of type parameter `tparam` in type of symbol `sym`.
   *
   *  For an alias type the result is passed through `cut`, i.e. it is 0
   *  unless the variance of `tparam` in the alias' info is VARIANCES.
   */
  def varianceInSym(sym: Symbol)(tparam: Symbol): Int =
    if (sym.isAliasType) cut(varianceInType(sym.info)(tparam))
    else varianceInType(sym.info)(tparam)

  /** Compute variance of type parameter `tparam` in all types `tps`.
   *
   *  The per-type results are combined with `&`, starting from VARIANCES.
   */
  def varianceInTypes(tps: List[Type])(tparam: Symbol): Int =
    tps.foldLeft(VARIANCES)((acc, tp) => acc & varianceInType(tp)(tparam))

  /** Compute variance of type parameter `tparam` in all type arguments
   *  `tps` which correspond to formal type parameters `tparams1`.
   *
   *  Each argument's variance is kept as is under a covariant formal
   *  parameter, flipped under a contravariant one, and `cut` otherwise.
   *  Arguments and formals are paired with `zip`, so surplus elements of
   *  the longer list are ignored.
   */
  def varianceInArgs(tps: List[Type], tparams1: List[Symbol])(tparam: Symbol): Int =
    (tps zip tparams1).foldLeft(VARIANCES) { case (acc, (tp, tparam1)) =>
      val v1 = varianceInType(tp)(tparam)
      val adjusted =
        if (tparam1.isCovariant) v1
        else if (tparam1.isContravariant) flip(v1)
        else cut(v1)
      acc & adjusted
    }

  /** Compute variance of type parameter `tparam` in all type annotations `annots`. */
  def varianceInAttribs(annots: List[AnnotationInfo])(tparam: Symbol): Int =
    annots.foldLeft(VARIANCES)((acc, annot) => acc & varianceInAttrib(annot)(tparam))

  /** Compute variance of type parameter `tparam` in type annotation `annot`,
   *  i.e. in the annotation's type `annot.atp`.
   */
  def varianceInAttrib(annot: AnnotationInfo)(tparam: Symbol): Int =
    varianceInType(annot.atp)(tparam)

  /** Compute variance of type parameter `tparam` in type `tp`.
   *
   *  Returns VARIANCES for the leaf types that are matched without inspecting
   *  any component, COVARIANT for a direct reference to `tparam`, and
   *  otherwise the `&`-combination of the variances found in the components
   *  of `tp`, flipped where a component sits in a lower-bound, value-parameter
   *  or type-parameter position.
   */
  def varianceInType(tp: Type)(tparam: Symbol): Int = tp match {
    case ErrorType | WildcardType | NoType | NoPrefix | ThisType(_) | ConstantType(_) =>
      VARIANCES
    // A bounded wildcard was previously unmatched and raised a MatchError;
    // its variance is that of its bounds (handled by the TypeBounds case).
    case BoundedWildcardType(bounds) =>
      varianceInType(bounds)(tparam)
    case SingleType(pre, sym) =>
      varianceInType(pre)(tparam)
    case TypeRef(pre, sym, args) =>
      if (sym == tparam) COVARIANT
      // tparam cannot occur in tp's args if tp is a type constructor (those don't have args)
      else if (tp.isHigherKinded) varianceInType(pre)(tparam)
      else varianceInType(pre)(tparam) & varianceInArgs(args, sym.typeParams)(tparam)
    case TypeBounds(lo, hi) =>
      // the lower bound is a contravariant position, hence the flip
      flip(varianceInType(lo)(tparam)) & varianceInType(hi)(tparam)
    case RefinedType(parents, defs) =>
      varianceInTypes(parents)(tparam) & varianceInSyms(defs.toList)(tparam)
    case MethodType(params, restpe) =>
      // value parameters are contravariant positions, hence the flip
      flip(varianceInSyms(params)(tparam)) & varianceInType(restpe)(tparam)
    case NullaryMethodType(restpe) =>
      varianceInType(restpe)(tparam)
    case PolyType(tparams, restpe) =>
      flip(varianceInSyms(tparams)(tparam)) & varianceInType(restpe)(tparam)
    case ExistentialType(tparams, restpe) =>
      varianceInSyms(tparams)(tparam) & varianceInType(restpe)(tparam)
    case AnnotatedType(annots, underlying, _) =>
      varianceInAttribs(annots)(tparam) & varianceInType(underlying)(tparam)
  }
}