From ea93654a6e9674efa9ab98328462d3b5f59ce91f Mon Sep 17 00:00:00 2001 From: Paul Phillips Date: Mon, 31 Dec 2012 11:54:31 -0800 Subject: Incorporated Variance value class in Variances. No one will ever know what it took for me to refine Variances into its current condition. A LONELY QUEST. --- src/compiler/scala/tools/nsc/package.scala | 3 + .../scala/tools/nsc/typechecker/Checkable.scala | 11 ++- .../scala/tools/nsc/typechecker/Contexts.scala | 2 +- .../scala/tools/nsc/typechecker/Infer.scala | 24 +++-- .../scala/tools/nsc/typechecker/RefChecks.scala | 50 +++++----- .../tools/nsc/typechecker/TypeDiagnostics.scala | 9 +- .../scala/tools/nsc/typechecker/Typers.scala | 7 +- .../scala/tools/nsc/typechecker/Variances.scala | 102 ++++++--------------- .../scala/reflect/internal/BaseTypeSeqs.scala | 2 +- src/reflect/scala/reflect/internal/Flags.scala | 2 +- src/reflect/scala/reflect/internal/Kinds.scala | 6 +- src/reflect/scala/reflect/internal/Symbols.scala | 18 ++-- src/reflect/scala/reflect/internal/Types.scala | 85 +++++++++-------- 13 files changed, 134 insertions(+), 187 deletions(-) diff --git a/src/compiler/scala/tools/nsc/package.scala b/src/compiler/scala/tools/nsc/package.scala index 00a9f3b39c..940df1bd1a 100644 --- a/src/compiler/scala/tools/nsc/package.scala +++ b/src/compiler/scala/tools/nsc/package.scala @@ -9,6 +9,9 @@ package object nsc { type Phase = scala.reflect.internal.Phase val NoPhase = scala.reflect.internal.NoPhase + type Variance = scala.reflect.internal.Variance + val Variance = scala.reflect.internal.Variance + type FatalError = scala.reflect.internal.FatalError val FatalError = scala.reflect.internal.FatalError diff --git a/src/compiler/scala/tools/nsc/typechecker/Checkable.scala b/src/compiler/scala/tools/nsc/typechecker/Checkable.scala index 9efa3f36b0..fb3909746f 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Checkable.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Checkable.scala @@ -200,11 +200,12 @@ trait Checkable { def isNeverSubClass(sym1: Symbol, sym2: Symbol) = areIrreconcilableAsParents(sym1, sym2) private def isNeverSubArgs(tps1: List[Type], tps2: List[Type], tparams: List[Symbol]): Boolean = /*logResult(s"isNeverSubArgs($tps1, $tps2, $tparams)")*/ { - def isNeverSubArg(t1: Type, t2: Type, variance: Int) = { - if (variance > 0) isNeverSubType(t2, t1) - else if (variance < 0) isNeverSubType(t1, t2) - else isNeverSameType(t1, t2) - } + def isNeverSubArg(t1: Type, t2: Type, variance: Variance) = ( + if (variance.isInvariant) isNeverSameType(t1, t2) + else if (variance.isCovariant) isNeverSubType(t2, t1) + else if (variance.isContravariant) isNeverSubType(t1, t2) + else false + ) exists3(tps1, tps2, tparams map (_.variance))(isNeverSubArg) } private def isNeverSameType(tp1: Type, tp2: Type): Boolean = (tp1, tp2) match { diff --git a/src/compiler/scala/tools/nsc/typechecker/Contexts.scala b/src/compiler/scala/tools/nsc/typechecker/Contexts.scala index e24f0bca1d..274a567075 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Contexts.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Contexts.scala @@ -138,7 +138,7 @@ trait Contexts { self: Analyzer => } var enclMethod: Context = _ // The next outer context whose tree is a method - var variance: Int = _ // Variance relative to enclosing class + var variance: Variance = Variance.Invariant // Variance relative to enclosing class private var _undetparams: List[Symbol] = List() // Undetermined type parameters, // not inherited to child contexts var depth: Int = 0 diff --git a/src/compiler/scala/tools/nsc/typechecker/Infer.scala b/src/compiler/scala/tools/nsc/typechecker/Infer.scala index 7188290688..88f0ccca98 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Infer.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Infer.scala @@ -198,7 +198,7 @@ trait Infer extends Checkable { * @throws NoInstance */ def solvedTypes(tvars: List[TypeVar], tparams: List[Symbol], - variances: List[Int], upper: Boolean, depth: Int): List[Type] = { + variances: List[Variance], upper: Boolean, depth: Int): List[Type] = { if (tvars.nonEmpty) printInference("[solve types] solving for " + tparams.map(_.name).mkString(", ") + " in " + tvars.mkString(", ")) @@ -488,7 +488,7 @@ trait Infer extends Checkable { pt: Type): List[Type] = { /** Map type variable to its instance, or, if `variance` is covariant/contravariant, * to its upper/lower bound */ - def instantiateToBound(tvar: TypeVar, variance: Int): Type = try { + def instantiateToBound(tvar: TypeVar, variance: Variance): Type = { lazy val hiBounds = tvar.constr.hiBounds lazy val loBounds = tvar.constr.loBounds lazy val upper = glb(hiBounds) @@ -501,23 +501,21 @@ trait Infer extends Checkable { //Console.println("instantiate "+tvar+tvar.constr+" variance = "+variance);//DEBUG if (tvar.constr.inst != NoType) instantiate(tvar.constr.inst) - else if ((variance & COVARIANT) != 0 && hiBounds.nonEmpty) - setInst(upper) - else if ((variance & CONTRAVARIANT) != 0 && loBounds.nonEmpty) + else if (loBounds.nonEmpty && variance.isContravariant) setInst(lower) - else if (hiBounds.nonEmpty && loBounds.nonEmpty && upper <:< lower) + else if (hiBounds.nonEmpty && (variance.isPositive || loBounds.nonEmpty && upper <:< lower)) setInst(upper) else WildcardType - } catch { - case ex: NoInstance => WildcardType } val tvars = tparams map freshVar if (isConservativelyCompatible(restpe.instantiateTypeParams(tparams, tvars), pt)) map2(tparams, tvars)((tparam, tvar) => - instantiateToBound(tvar, varianceInTypes(formals)(tparam))) + try instantiateToBound(tvar, varianceInTypes(formals)(tparam)) + catch { case ex: NoInstance => WildcardType } + ) else - tvars map (tvar => WildcardType) + tvars map (_ => WildcardType) } /** [Martin] Can someone comment this please? I have no idea what it's for @@ -571,8 +569,8 @@ trait Infer extends Checkable { foreach3(tparams, tvars, targs) { (tparam, tvar, targ) => val retract = ( - targ.typeSymbol == NothingClass // only retract Nothings - && (restpe.isWildcard || (varianceInType(restpe)(tparam) & COVARIANT) == 0) // don't retract covariant occurrences + targ.typeSymbol == NothingClass // only retract Nothings + && (restpe.isWildcard || !varianceInType(restpe)(tparam).isPositive) // don't retract covariant occurrences ) buf += ((tparam, @@ -1316,7 +1314,7 @@ trait Infer extends Checkable { val tvars1 = tvars map (_.cloneInternal) // Note: right now it's not clear that solving is complete, or how it can be made complete! // So we should come back to this and investigate. - solve(tvars1, tvars1 map (_.origin.typeSymbol), tvars1 map (x => COVARIANT), false) + solve(tvars1, tvars1 map (_.origin.typeSymbol), tvars1 map (_ => Variance.Covariant), false) } // this is quite nasty: it destructively changes the info of the syms of e.g., method type params (see #3692, where the type param T's bounds were set to >: T <: T, so that parts looped) diff --git a/src/compiler/scala/tools/nsc/typechecker/RefChecks.scala b/src/compiler/scala/tools/nsc/typechecker/RefChecks.scala index 12562fecf8..71e04db24c 100644 --- a/src/compiler/scala/tools/nsc/typechecker/RefChecks.scala +++ b/src/compiler/scala/tools/nsc/typechecker/RefChecks.scala @@ -11,6 +11,7 @@ import scala.collection.{ mutable, immutable } import transform.InfoTransform import scala.collection.mutable.ListBuffer import scala.language.postfixOps +import scala.reflect.internal.Variance._ /**

* Post-attribution checking and transformation. @@ -844,10 +845,6 @@ abstract class RefChecks extends InfoTransform with scala.reflect.internal.trans // Variance Checking -------------------------------------------------------- - private val NoVariance = 0 - private val CoVariance = 1 - private val AnyVariance = 2 - private val escapedPrivateLocals = new mutable.HashSet[Symbol] val varianceValidator = new Traverser { @@ -858,41 +855,36 @@ abstract class RefChecks extends InfoTransform with scala.reflect.internal.trans // need to be checked. var inRefinement = false - def varianceString(variance: Int): String = - if (variance == 1) "covariant" - else if (variance == -1) "contravariant" - else "invariant"; - /** The variance of a symbol occurrence of `tvar` * seen at the level of the definition of `base`. * The search proceeds from `base` to the owner of `tvar`. * Initially the state is covariant, but it might change along the search. */ - def relativeVariance(tvar: Symbol): Int = { + def relativeVariance(tvar: Symbol): Variance = { val clazz = tvar.owner var sym = base - var state = CoVariance - while (sym != clazz && state != AnyVariance) { + var state: Variance = Covariant + while (sym != clazz && state != Bivariant) { //Console.println("flip: " + sym + " " + sym.isParameter());//DEBUG // Flip occurrences of type parameters and parameters, unless // - it's a constructor, or case class factory or extractor // - it's a type parameter of tvar's owner. if (sym.isParameter && !sym.owner.isConstructor && !sym.owner.isCaseApplyOrUnapply && !(tvar.isTypeParameterOrSkolem && sym.isTypeParameterOrSkolem && - tvar.owner == sym.owner)) state = -state; + tvar.owner == sym.owner)) state = state.flip else if (!sym.owner.isClass || sym.isTerm && ((sym.isPrivateLocal || sym.isProtectedLocal || sym.isSuperAccessor /* super accessors are implicitly local #4345*/) && !(escapedPrivateLocals contains sym))) { // return AnyVariance if `sym` is local to a term // or is private[this] or protected[this] - state = AnyVariance + state = Bivariant } else if (sym.isAliasType) { - // return AnyVariance if `sym` is an alias type + // return Bivariant if `sym` is an alias type // that does not override anything. This is OK, because we always // expand aliases for variance checking. // However, if `sym` does override a type in a base class - // we have to assume NoVariance, as there might then be + // we have to assume Invariant, as there might then be // references to the type parameter that are not variance checked. - state = if (sym.isOverridingSymbol) NoVariance else AnyVariance + state = if (sym.isOverridingSymbol) Invariant else Bivariant } sym = sym.owner } @@ -902,7 +894,7 @@ abstract class RefChecks extends InfoTransform with scala.reflect.internal.trans /** Validate that the type `tp` is variance-correct, assuming * the type occurs itself at variance position given by `variance` */ - def validateVariance(tp: Type, variance: Int): Unit = tp match { + def validateVariance(tp: Type, variance: Variance): Unit = tp match { case ErrorType => case WildcardType => case BoundedWildcardType(bounds) => @@ -915,19 +907,19 @@ abstract class RefChecks extends InfoTransform with scala.reflect.internal.trans validateVariance(pre, variance) case TypeRef(pre, sym, args) => // println("validate "+sym+" at "+relativeVariance(sym)) - if (sym.isAliasType/* && relativeVariance(sym) == AnyVariance*/) + if (sym.isAliasType/* && relativeVariance(sym) == Bivariant*/) validateVariance(tp.normalize, variance) - else if (sym.variance != NoVariance) { + else if (!sym.variance.isInvariant) { val v = relativeVariance(sym) - if (v != AnyVariance && sym.variance != v * variance) { + if (!v.isBivariant && sym.variance != v * variance) { //Console.println("relativeVariance(" + base + "," + sym + ") = " + v);//DEBUG def tpString(tp: Type) = tp match { case ClassInfoType(parents, _, clazz) => "supertype "+intersectionType(parents, clazz.owner) case _ => "type "+tp } unit.error(base.pos, - varianceString(sym.variance) + " " + sym + - " occurs in " + varianceString(v * variance) + + sym.variance + " " + sym + + " occurs in " + (v * variance) + " position in " + tpString(base.info) + " of " + base); } } @@ -944,14 +936,14 @@ abstract class RefChecks extends InfoTransform with scala.reflect.internal.trans val saved = inRefinement inRefinement = true for (sym <- decls) - validateVariance(sym.info, if (sym.isAliasType) NoVariance else variance) + validateVariance(sym.info, if (sym.isAliasType) Invariant else variance) inRefinement = saved case TypeBounds(lo, hi) => - validateVariance(lo, -variance) + validateVariance(lo, variance.flip) validateVariance(hi, variance) case mt @ MethodType(formals, result) => if (inRefinement) - validateVariances(mt.paramTypes, -variance) + validateVariances(mt.paramTypes, variance.flip) validateVariance(result, variance) case NullaryMethodType(result) => validateVariance(result, variance) @@ -966,15 +958,15 @@ abstract class RefChecks extends InfoTransform with scala.reflect.internal.trans validateVariance(tp, variance) } - def validateVariances(tps: List[Type], variance: Int) { + def validateVariances(tps: List[Type], variance: Variance) { tps foreach (tp => validateVariance(tp, variance)) } - def validateVarianceArgs(tps: List[Type], variance: Int, tparams: List[Symbol]) { + def validateVarianceArgs(tps: List[Type], variance: Variance, tparams: List[Symbol]) { foreach2(tps, tparams)((tp, tparam) => validateVariance(tp, variance * tparam.variance)) } - validateVariance(base.info, CoVariance) + validateVariance(base.info, Covariant) } override def traverse(tree: Tree) { diff --git a/src/compiler/scala/tools/nsc/typechecker/TypeDiagnostics.scala b/src/compiler/scala/tools/nsc/typechecker/TypeDiagnostics.scala index 3bb6ae53dc..56b43eaa41 100644 --- a/src/compiler/scala/tools/nsc/typechecker/TypeDiagnostics.scala +++ b/src/compiler/scala/tools/nsc/typechecker/TypeDiagnostics.scala @@ -169,11 +169,6 @@ trait TypeDiagnostics { case xs => " where " + (disambiguate(xs map (_.existentialToString)) mkString ", ") } - def varianceWord(sym: Symbol): String = - if (sym.variance == 1) "covariant" - else if (sym.variance == -1) "contravariant" - else "invariant" - def explainAlias(tp: Type) = { // Don't automatically normalize standard aliases; they still will be // expanded if necessary to disambiguate simple identifiers. @@ -223,7 +218,7 @@ trait TypeDiagnostics { ) val explainDef = { val prepend = if (isJava) "Java-defined " else "" - "%s%s is %s in %s.".format(prepend, reqsym, varianceWord(param), param) + "%s%s is %s in %s.".format(prepend, reqsym, param.variance, param) } // Don't suggest they change the class declaration if it's somewhere // under scala.* or defined in a java class, because attempting either @@ -243,7 +238,7 @@ trait TypeDiagnostics { || ((arg <:< reqArg) && param.isCovariant) || ((reqArg <:< arg) && param.isContravariant) ) - val invariant = param.variance == 0 + val invariant = param.variance.isInvariant if (conforms) Some("") else if ((arg <:< reqArg) && invariant) mkMsg(true) // covariant relationship diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala index b853f687a7..243b7eaf99 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala @@ -938,7 +938,7 @@ trait Typers extends Modes with Adaptations with Tags { val unapply = unapplyMember(extractor.tpe) val clazz = unapplyParameterType(unapply) - if (unapply.isCase && clazz.isCase && !(clazz.ancestors exists (_.isCase))) { + if (unapply.isCase && clazz.isCase) { // convert synthetic unapply of case class to case class constructor val prefix = tree.tpe.prefix val tree1 = TypeTree(clazz.primaryConstructor.tpe.asSeenFrom(prefix, clazz.owner)) @@ -947,8 +947,9 @@ trait Typers extends Modes with Adaptations with Tags { val skolems = new mutable.ListBuffer[TypeSymbol] object variantToSkolem extends VariantTypeMap { def apply(tp: Type) = mapOver(tp) match { - case TypeRef(NoPrefix, tpSym, Nil) if variance != 0 && tpSym.isTypeParameterOrSkolem && tpSym.owner.isTerm => - val bounds = if (variance == 1) TypeBounds.upper(tpSym.tpe) else TypeBounds.lower(tpSym.tpe) + // !!! FIXME - skipping this when variance.isInvariant allows unsoundness, see SI-5189 + case TypeRef(NoPrefix, tpSym, Nil) if !variance.isInvariant && tpSym.isTypeParameterOrSkolem && tpSym.owner.isTerm => + val bounds = if (variance.isPositive) TypeBounds.upper(tpSym.tpe) else TypeBounds.lower(tpSym.tpe) // origin must be the type param so we can deskolemize val skolem = context.owner.newGADTSkolem(unit.freshTypeName("?"+tpSym.name), tpSym, bounds) // println("mapping "+ tpSym +" to "+ skolem + " : "+ bounds +" -- pt= "+ pt +" in "+ context.owner +" at "+ context.tree ) diff --git a/src/compiler/scala/tools/nsc/typechecker/Variances.scala b/src/compiler/scala/tools/nsc/typechecker/Variances.scala index aa66a8d00a..44c0ca25b0 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Variances.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Variances.scala @@ -6,89 +6,43 @@ package scala.tools.nsc package typechecker -import symtab.Flags.{ VarianceFlags => VARIANCES, _ } +// import scala.reflect.internal.Variance -/** Variances form a lattice, 0 <= COVARIANT <= Variances, 0 <= CONTRAVARIANT <= VARIANCES +/** See comments at scala.reflect.internal.Variance. */ trait Variances { - val global: Global import global._ - - /** Flip between covariant and contravariant */ - private def flip(v: Int): Int = { - if (v == COVARIANT) CONTRAVARIANT - else if (v == CONTRAVARIANT) COVARIANT - else v - } - - /** Map everything below VARIANCES to 0 */ - private def cut(v: Int): Int = - if (v == VARIANCES) v else 0 - - /** Compute variance of type parameter `tparam` in types of all symbols `sym`. */ - def varianceInSyms(syms: List[Symbol])(tparam: Symbol): Int = - (VARIANCES /: syms) ((v, sym) => v & varianceInSym(sym)(tparam)) - - /** Compute variance of type parameter `tparam` in type of symbol `sym`. */ - def varianceInSym(sym: Symbol)(tparam: Symbol): Int = - if (sym.isAliasType) cut(varianceInType(sym.info)(tparam)) - else varianceInType(sym.info)(tparam) + import scala.reflect.internal.Variance._ /** Compute variance of type parameter `tparam` in all types `tps`. */ - def varianceInTypes(tps: List[Type])(tparam: Symbol): Int = - (VARIANCES /: tps) ((v, tp) => v & varianceInType(tp)(tparam)) + def varianceInTypes(tps: List[Type])(tparam: Symbol): Variance = + fold(tps map (tp => varianceInType(tp)(tparam))) - /** Compute variance of type parameter `tparam` in all type arguments - * `tps` which correspond to formal type parameters `tparams1`. - */ - def varianceInArgs(tps: List[Type], tparams1: List[Symbol])(tparam: Symbol): Int = { - var v: Int = VARIANCES; - for ((tp, tparam1) <- tps zip tparams1) { - val v1 = varianceInType(tp)(tparam) - v = v & (if (tparam1.isCovariant) v1 - else if (tparam1.isContravariant) flip(v1) - else cut(v1)) + /** Compute variance of type parameter `tparam` in type `tp`. */ + def varianceInType(tp: Type)(tparam: Symbol): Variance = { + def inArgs(sym: Symbol, args: List[Type]): Variance = fold(map2(args, sym.typeParams)((a, p) => inType(a) * p.variance)) + def inSyms(syms: List[Symbol]): Variance = fold(syms map inSym) + def inTypes(tps: List[Type]): Variance = fold(tps map inType) + + def inSym(sym: Symbol): Variance = if (sym.isAliasType) inType(sym.info).cut else inType(sym.info) + def inType(tp: Type): Variance = tp match { + case ErrorType | WildcardType | NoType | NoPrefix => Bivariant + case ThisType(_) | ConstantType(_) => Bivariant + case TypeRef(_, `tparam`, _) => Covariant + case BoundedWildcardType(bounds) => inType(bounds) + case NullaryMethodType(restpe) => inType(restpe) + case SingleType(pre, sym) => inType(pre) + case TypeRef(pre, _, _) if tp.isHigherKinded => inType(pre) // a type constructor cannot occur in tp's args + case TypeRef(pre, sym, args) => inType(pre) & inArgs(sym, args) + case TypeBounds(lo, hi) => inType(lo).flip & inType(hi) + case RefinedType(parents, defs) => inTypes(parents) & inSyms(defs.toList) + case MethodType(params, restpe) => inSyms(params).flip & inType(restpe) + case PolyType(tparams, restpe) => inSyms(tparams).flip & inType(restpe) + case ExistentialType(tparams, restpe) => inSyms(tparams) & inType(restpe) + case AnnotatedType(annots, tp, _) => inTypes(annots map (_.atp)) & inType(tp) } - v - } - - /** Compute variance of type parameter `tparam` in all type annotations `annots`. */ - def varianceInAttribs(annots: List[AnnotationInfo])(tparam: Symbol): Int = { - (VARIANCES /: annots) ((v, annot) => v & varianceInAttrib(annot)(tparam)) - } - /** Compute variance of type parameter `tparam` in type annotation `annot`. */ - def varianceInAttrib(annot: AnnotationInfo)(tparam: Symbol): Int = { - varianceInType(annot.atp)(tparam) - } - - /** Compute variance of type parameter `tparam` in type `tp`. */ - def varianceInType(tp: Type)(tparam: Symbol): Int = tp match { - case ErrorType | WildcardType | NoType | NoPrefix | ThisType(_) | ConstantType(_) => - VARIANCES - case BoundedWildcardType(bounds) => - varianceInType(bounds)(tparam) - case SingleType(pre, sym) => - varianceInType(pre)(tparam) - case TypeRef(pre, sym, args) => - if (sym == tparam) COVARIANT - // tparam cannot occur in tp's args if tp is a type constructor (those don't have args) - else if (tp.isHigherKinded) varianceInType(pre)(tparam) - else varianceInType(pre)(tparam) & varianceInArgs(args, sym.typeParams)(tparam) - case TypeBounds(lo, hi) => - flip(varianceInType(lo)(tparam)) & varianceInType(hi)(tparam) - case RefinedType(parents, defs) => - varianceInTypes(parents)(tparam) & varianceInSyms(defs.toList)(tparam) - case MethodType(params, restpe) => - flip(varianceInSyms(params)(tparam)) & varianceInType(restpe)(tparam) - case NullaryMethodType(restpe) => - varianceInType(restpe)(tparam) - case PolyType(tparams, restpe) => - flip(varianceInSyms(tparams)(tparam)) & varianceInType(restpe)(tparam) - case ExistentialType(tparams, restpe) => - varianceInSyms(tparams)(tparam) & varianceInType(restpe)(tparam) - case AnnotatedType(annots, tp, _) => - varianceInAttribs(annots)(tparam) & varianceInType(tp)(tparam) + inType(tp) } } diff --git a/src/reflect/scala/reflect/internal/BaseTypeSeqs.scala b/src/reflect/scala/reflect/internal/BaseTypeSeqs.scala index 18a4a36840..9daf9504f1 100644 --- a/src/reflect/scala/reflect/internal/BaseTypeSeqs.scala +++ b/src/reflect/scala/reflect/internal/BaseTypeSeqs.scala @@ -64,7 +64,7 @@ trait BaseTypeSeqs { //Console.println("compute closure of "+this+" => glb("+variants+")") pending += i try { - mergePrefixAndArgs(variants, -1, lubDepth(variants)) match { + mergePrefixAndArgs(variants, Variance.Contravariant, lubDepth(variants)) match { case Some(tp0) => pending(i) = false elems(i) = tp0 diff --git a/src/reflect/scala/reflect/internal/Flags.scala b/src/reflect/scala/reflect/internal/Flags.scala index 654dbd76e7..6aa7eab689 100644 --- a/src/reflect/scala/reflect/internal/Flags.scala +++ b/src/reflect/scala/reflect/internal/Flags.scala @@ -508,4 +508,4 @@ class Flags extends ModifierFlags { final val rawFlagPickledOrder: Array[Long] = pickledListOrder.toArray } -object Flags extends Flags { } +object Flags extends Flags diff --git a/src/reflect/scala/reflect/internal/Kinds.scala b/src/reflect/scala/reflect/internal/Kinds.scala index 08686832ef..5fecc06128 100644 --- a/src/reflect/scala/reflect/internal/Kinds.scala +++ b/src/reflect/scala/reflect/internal/Kinds.scala @@ -93,8 +93,8 @@ trait Kinds { * If `sym2` is invariant, `sym1`'s variance is irrelevant. Otherwise they must be equal. */ private def variancesMatch(sym1: Symbol, sym2: Symbol) = ( - sym2.variance==0 - || sym1.variance==sym2.variance + sym2.variance.isInvariant + || sym1.variance == sym2.variance ) /** Check well-kindedness of type application (assumes arities are already checked) -- @M @@ -229,4 +229,4 @@ trait Kinds { } } } -} \ No newline at end of file +} diff --git a/src/reflect/scala/reflect/internal/Symbols.scala b/src/reflect/scala/reflect/internal/Symbols.scala index ae68c1bcfd..ec280f5413 100644 --- a/src/reflect/scala/reflect/internal/Symbols.scala +++ b/src/reflect/scala/reflect/internal/Symbols.scala @@ -12,6 +12,7 @@ import util.{ Statistics, shortClassOfInstance } import Flags._ import scala.annotation.tailrec import scala.reflect.io.{ AbstractFile, NoAbstractFile } +import Variance._ trait Symbols extends api.Symbols { self: SymbolTable => import definitions._ @@ -165,10 +166,7 @@ trait Symbols extends api.Symbols { self: SymbolTable => def debugFlagString: String = flagString(AllFlags) /** String representation of symbol's variance */ - def varianceString: String = - if (variance == 1) "+" - else if (variance == -1) "-" - else "" + def varianceString: String = variance.symbolicString override def flagMask = if (settings.debug.value && !isAbstractType) AllFlags @@ -890,11 +888,11 @@ trait Symbols extends api.Symbols { self: SymbolTable => return true } - /** The variance of this symbol as an integer */ - final def variance: Int = - if (isCovariant) 1 - else if (isContravariant) -1 - else 0 + /** The variance of this symbol. */ + final def variance: Variance = + if (isCovariant) Covariant + else if (isContravariant) Contravariant + else Invariant /** The sequence number of this parameter symbol among all type * and value parameters of symbol's owner. -1 if symbol does not @@ -3288,7 +3286,7 @@ trait Symbols extends api.Symbols { self: SymbolTable => // ----- Hoisted closures and convenience methods, for compile time reductions ------- private[scala] final val symbolIsPossibleInRefinement = (sym: Symbol) => sym.isPossibleInRefinement - private[scala] final val symbolIsNonVariant = (sym: Symbol) => sym.variance == 0 + private[scala] final val symbolIsNonVariant = (sym: Symbol) => sym.variance.isInvariant @tailrec private[scala] final def allSymbolsHaveOwner(syms: List[Symbol], owner: Symbol): Boolean = syms match { diff --git a/src/reflect/scala/reflect/internal/Types.scala b/src/reflect/scala/reflect/internal/Types.scala index 68f51e97bd..605e4e983a 100644 --- a/src/reflect/scala/reflect/internal/Types.scala +++ b/src/reflect/scala/reflect/internal/Types.scala @@ -16,6 +16,7 @@ import scala.annotation.tailrec import util.Statistics import scala.runtime.ObjectRef import util.ThreeValues._ +import Variance._ /* A standard type pattern match: case ErrorType => @@ -2807,7 +2808,7 @@ trait Types extends api.Types { self: SymbolTable => val tvars = quantifiedFresh map (tparam => TypeVar(tparam)) val underlying1 = underlying.instantiateTypeParams(quantified, tvars) // fuse subst quantified -> quantifiedFresh -> tvars op(underlying1) && { - solve(tvars, quantifiedFresh, quantifiedFresh map (x => 0), false, depth) && + solve(tvars, quantifiedFresh, quantifiedFresh map (_ => Invariant), false, depth) && isWithinBounds(NoPrefix, NoSymbol, quantifiedFresh, tvars map (_.constr.inst)) } } @@ -3928,17 +3929,17 @@ trait Types extends api.Types { self: SymbolTable => } trait VariantTypeMap extends TypeMap { - private[this] var _variance = 1 + private[this] var _variance: Variance = Covariant override def variance = _variance - def variance_=(x: Int) = _variance = x + def variance_=(x: Variance) = _variance = x override protected def noChangeToSymbols(origSyms: List[Symbol]) = //OPT inline from forall to save on #closures origSyms match { case sym :: rest => val v = variance - if (sym.isAliasType) variance = 0 + if (sym.isAliasType) variance = Invariant val result = this(sym.info) variance = v (result eq sym.info) && noChangeToSymbols(rest) @@ -3949,8 +3950,8 @@ trait Types extends api.Types { self: SymbolTable => override protected def mapOverArgs(args: List[Type], tparams: List[Symbol]): List[Type] = map2Conserve(args, tparams) { (arg, tparam) => val v = variance - if (tparam.isContravariant) variance = -variance - else if (!tparam.isCovariant) variance = 0 + if (tparam.isContravariant) variance = variance.flip + else if (!tparam.isCovariant) variance = Invariant val arg1 = this(arg) variance = v arg1 @@ -3959,23 +3960,23 @@ trait Types extends api.Types { self: SymbolTable => /** Map this function over given type */ override def mapOver(tp: Type): Type = tp match { case MethodType(params, result) => - variance = -variance + variance = variance.flip val params1 = mapOver(params) - variance = -variance + variance = variance.flip val result1 = this(result) if ((params1 eq params) && (result1 eq result)) tp else copyMethodType(tp, params1, result1.substSym(params, params1)) case PolyType(tparams, result) => - variance = -variance + variance = variance.flip val tparams1 = mapOver(tparams) - variance = -variance + variance = variance.flip val result1 = this(result) if ((tparams1 eq tparams) && (result1 eq result)) tp else PolyType(tparams1, result1.substSym(tparams, tparams1)) case TypeBounds(lo, hi) => - variance = -variance + variance = variance.flip val lo1 = this(lo) - variance = -variance + variance = variance.flip val hi1 = this(hi) if ((lo1 eq lo) && (hi1 eq hi)) tp else TypeBounds(lo1, hi1) @@ -3984,7 +3985,7 @@ trait Types extends api.Types { self: SymbolTable => val args1 = if (args.isEmpty) args - else if (variance == 0) // fast & safe path: don't need to look at typeparams + else if (variance.isInvariant) // fast & safe path: don't need to look at typeparams args mapConserve this else { val tparams = sym.typeParams @@ -4007,7 +4008,7 @@ trait Types extends api.Types { self: SymbolTable => /** Mix in VariantTypeMap if you want variances to be significant. */ - def variance = 0 + def variance: Variance = Invariant /** Map this function over given type */ def mapOver(tp: Type): Type = tp match { @@ -4280,10 +4281,10 @@ trait Types extends api.Types { self: SymbolTable => def apply(tp: Type): Type = { val tp1 = mapOver(tp) - if (variance == 0) tp1 + if (variance.isInvariant) tp1 else tp1 match { case TypeRef(pre, sym, args) if tparams contains sym => - val repl = if (variance == 1) dropSingletonType(tp1.bounds.hi) else tp1.bounds.lo + val repl = if (variance.isPositive) dropSingletonType(tp1.bounds.hi) else tp1.bounds.lo //println("eliminate "+sym+"/"+repl+"/"+occurCount(sym)+"/"+(tparams exists (repl.contains)))//DEBUG if (!repl.typeSymbol.isBottomClass && occurCount(sym) == 1 && !(tparams exists (repl.contains))) repl @@ -5059,17 +5060,19 @@ trait Types extends api.Types { self: SymbolTable => def isConsistent(tp1: Type, tp2: Type): Boolean = (tp1, tp2) match { case (TypeRef(pre1, sym1, args1), TypeRef(pre2, sym2, args2)) => assert(sym1 == sym2, (sym1, sym2)) - pre1 =:= pre2 && - forall3(args1, args2, sym1.typeParams)((arg1, arg2, tparam) => - if (tparam.variance == 0) arg1 =:= arg2 - // if left-hand argument is a typevar, make it compatible with variance - // this is for more precise pattern matching - // todo: work this in the spec of this method - // also: think what happens if there are embedded typevars? - else arg1 match { - case _: TypeVar => if (tparam.variance < 0) arg1 <:< arg2 else arg2 <:< arg1 - case _ => true - } + ( pre1 =:= pre2 + && forall3(args1, args2, sym1.typeParams) { (arg1, arg2, tparam) => + // if left-hand argument is a typevar, make it compatible with variance + // this is for more precise pattern matching + // todo: work this in the spec of this method + // also: think what happens if there are embedded typevars? + if (tparam.variance.isInvariant) + arg1 =:= arg2 + else !arg1.isInstanceOf[TypeVar] || { + if (tparam.variance.isContravariant) arg1 <:< arg2 + else arg2 <:< arg1 + } + } ) case (et: ExistentialType, _) => et.withTypeVars(isConsistent(_, tp2)) @@ -5639,9 +5642,11 @@ trait Types extends api.Types { self: SymbolTable => })) def isSubArgs(tps1: List[Type], tps2: List[Type], tparams: List[Symbol], depth: Int): Boolean = { - def isSubArg(t1: Type, t2: Type, variance: Int) = - (variance > 0 || isSubType(t2, t1, depth)) && - (variance < 0 || isSubType(t1, t2, depth)) + def isSubArg(t1: Type, t2: Type, variance: Variance) = ( + (variance.isContravariant || isSubType(t1, t2, depth)) + && (variance.isCovariant || isSubType(t2, t1, depth)) + ) + corresponds3(tps1, tps2, tparams map (_.variance))(isSubArg) } @@ -6037,15 +6042,15 @@ trait Types extends api.Types { self: SymbolTable => * @param upper When `true` search for max solution else min. */ def solve(tvars: List[TypeVar], tparams: List[Symbol], - variances: List[Int], upper: Boolean): Boolean = + variances: List[Variance], upper: Boolean): Boolean = solve(tvars, tparams, variances, upper, AnyDepth) def solve(tvars: List[TypeVar], tparams: List[Symbol], - variances: List[Int], upper: Boolean, depth: Int): Boolean = { + variances: List[Variance], upper: Boolean, depth: Int): Boolean = { - def solveOne(tvar: TypeVar, tparam: Symbol, variance: Int) { + def solveOne(tvar: TypeVar, tparam: Symbol, variance: Variance) { if (tvar.constr.inst == NoType) { - val up = if (variance != CONTRAVARIANT) upper else !upper + val up = if (variance.isContravariant) !upper else upper tvar.constr.inst = null val bound: Type = if (up) tparam.info.bounds.hi else tparam.info.bounds.lo //Console.println("solveOne0(tv, tp, v, b)="+(tvar, tparam, variance, bound)) @@ -6228,7 +6233,7 @@ trait Types extends api.Types { self: SymbolTable => else t } val tails = tsBts map (_.tail) - mergePrefixAndArgs(elimSub(ts1, depth) map elimHigherOrderTypeParam, 1, depth) match { + mergePrefixAndArgs(elimSub(ts1, depth) map elimHigherOrderTypeParam, Covariant, depth) match { case Some(tp) => loop(tp :: pretypes, tails) case _ => loop(pretypes, tails) } @@ -6711,18 +6716,18 @@ trait Types extends api.Types { self: SymbolTable => finally foreach2(tvs, saved)(_.suspended = _) } - /** Compute lub (if `variance == 1`) or glb (if `variance == -1`) of given list + /** Compute lub (if `variance == Covariant`) or glb (if `variance == Contravariant`) of given list * of types `tps`. All types in `tps` are typerefs or singletypes * with the same symbol. * Return `Some(x)` if the computation succeeds with result `x`. * Return `None` if the computation fails. */ - def mergePrefixAndArgs(tps: List[Type], variance: Int, depth: Int): Option[Type] = tps match { + def mergePrefixAndArgs(tps: List[Type], variance: Variance, depth: Int): Option[Type] = tps match { case List(tp) => Some(tp) case TypeRef(_, sym, _) :: rest => val pres = tps map (_.prefix) // prefix normalizes automatically - val pre = if (variance == 1) lub(pres, depth) else glb(pres, depth) + val pre = if (variance.isPositive) lub(pres, depth) else glb(pres, depth) val argss = tps map (_.normalize.typeArgs) // symbol equality (of the tp in tps) was checked using typeSymbol, which normalizes, so should normalize before retrieving arguments val capturedParams = new ListBuffer[Symbol] try { @@ -6760,7 +6765,7 @@ trait Types extends api.Types { self: SymbolTable => } else { if (tparam.variance == variance) lub(as, decr(depth)) - else if (tparam.variance == -variance) glb(as, decr(depth)) + else if (tparam.variance == variance.flip) glb(as, decr(depth)) else { val l = lub(as, decr(depth)) val g = glb(as, decr(depth)) @@ -6784,7 +6789,7 @@ trait Types extends api.Types { self: SymbolTable => } case SingleType(_, sym) :: rest => val pres = tps map (_.prefix) - val pre = if (variance == 1) lub(pres, depth) else glb(pres, depth) + val pre = if (variance.isPositive) lub(pres, depth) else glb(pres, depth) try { Some(singleType(pre, sym)) } catch { -- cgit v1.2.3