From 882f8e640b034cc69b122ac221f75cbe0018e2c3 Mon Sep 17 00:00:00 2001 From: Paul Phillips Date: Tue, 1 Jan 2013 15:57:49 -0800 Subject: Eliminated VariantTypeMap. Made variance tracking a constructor parameter to TypeMap. Eliminated more variance duplication. --- .../scala/tools/nsc/typechecker/RefChecks.scala | 2 +- .../scala/tools/nsc/typechecker/Typers.scala | 2 +- src/reflect/scala/reflect/internal/Types.scala | 81 ++++++++++--- src/reflect/scala/reflect/internal/Variance.scala | 5 + src/reflect/scala/reflect/internal/Variances.scala | 127 +++++---------------- 5 files changed, 98 insertions(+), 119 deletions(-) diff --git a/src/compiler/scala/tools/nsc/typechecker/RefChecks.scala b/src/compiler/scala/tools/nsc/typechecker/RefChecks.scala index 704aebc50b..750a4f65ec 100644 --- a/src/compiler/scala/tools/nsc/typechecker/RefChecks.scala +++ b/src/compiler/scala/tools/nsc/typechecker/RefChecks.scala @@ -1459,7 +1459,7 @@ abstract class RefChecks extends InfoTransform with scala.reflect.internal.trans hidden = o.isTerm || o.isPrivateLocal o = o.owner } - if (!hidden) varianceValidator.escapedPrivateLocals += sym + if (!hidden) varianceValidator.escapedLocals += sym } def checkSuper(mix: Name) = diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala index 243b7eaf99..416eaa9996 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala @@ -945,7 +945,7 @@ trait Typers extends Modes with Adaptations with Tags { .setOriginal(tree) val skolems = new mutable.ListBuffer[TypeSymbol] - object variantToSkolem extends VariantTypeMap { + object variantToSkolem extends TypeMap(isTrackingVariance = true) { def apply(tp: Type) = mapOver(tp) match { // !!! FIXME - skipping this when variance.isInvariant allows unsoundness, see SI-5189 case TypeRef(NoPrefix, tpSym, Nil) if !variance.isInvariant && tpSym.isTypeParameterOrSkolem && tpSym.owner.isTerm => diff --git a/src/reflect/scala/reflect/internal/Types.scala b/src/reflect/scala/reflect/internal/Types.scala index ec7537daf3..f613b6fa17 100644 --- a/src/reflect/scala/reflect/internal/Types.scala +++ b/src/reflect/scala/reflect/internal/Types.scala @@ -3932,18 +3932,25 @@ trait Types extends api.Types { self: SymbolTable => /** A prototype for mapping a function over all possible types */ - abstract class TypeMap extends (Type => Type) { + abstract class TypeMap(isTrackingVariance: Boolean) extends (Type => Type) { + def this() = this(isTrackingVariance = false) def apply(tp: Type): Type - /** Mix in VariantTypeMap if you want variances to be significant. - */ - def variance: Variance = Invariant + private[this] var _variance: Variance = if (isTrackingVariance) Covariant else Invariant + + def variance_=(x: Variance) = { assert(isTrackingVariance, this) ; _variance = x } + def variance = _variance /** Map this function over given type */ def mapOver(tp: Type): Type = tp match { case tr @ TypeRef(pre, sym, args) => val pre1 = this(pre) - val args1 = args mapConserve this + val args1 = ( + if (isTrackingVariance && args.nonEmpty && !variance.isInvariant && sym.typeParams.nonEmpty) + mapOverArgs(args, sym.typeParams) + else + args mapConserve this + ) if ((pre1 eq pre) && (args1 eq args)) tp else copyTypeRef(tp, pre1, tr.coevolveSym(pre1), args1) case ThisType(_) => tp @@ -3955,12 +3962,12 @@ trait Types extends api.Types { self: SymbolTable => else singleType(pre1, sym) } case MethodType(params, result) => - val params1 = mapOver(params) + val params1 = flipped(mapOver(params)) val result1 = this(result) if ((params1 eq params) && (result1 eq result)) tp else copyMethodType(tp, params1, result1.substSym(params, params1)) case PolyType(tparams, result) => - val tparams1 = mapOver(tparams) + val tparams1 = flipped(mapOver(tparams)) val result1 = this(result) if ((tparams1 eq tparams) && (result1 eq result)) tp else PolyType(tparams1, result1.substSym(tparams, tparams1)) @@ -3975,7 +3982,7 @@ trait Types extends api.Types { self: SymbolTable => if ((thistp1 eq thistp) && (supertp1 eq supertp)) tp else SuperType(thistp1, supertp1) case TypeBounds(lo, hi) => - val lo1 = this(lo) + val lo1 = flipped(this(lo)) val hi1 = this(hi) if ((lo1 eq lo) && (hi1 eq hi)) tp else TypeBounds(lo1, hi1) @@ -3998,7 +4005,7 @@ trait Types extends api.Types { self: SymbolTable => else OverloadedType(pre1, alts) case AntiPolyType(pre, args) => val pre1 = this(pre) - val args1 = args mapConserve (this) + val args1 = args mapConserve this if ((pre1 eq pre) && (args1 eq args)) tp else AntiPolyType(pre1, args1) case tv@TypeVar(_, constr) => @@ -4026,8 +4033,33 @@ trait Types extends api.Types { self: SymbolTable => // throw new Error("mapOver inapplicable for " + tp); } - protected def mapOverArgs(args: List[Type], tparams: List[Symbol]): List[Type] = - args mapConserve this + private def flip() = if (isTrackingVariance) variance = variance.flip + @inline final def flipped[T](body: => T): T = { + flip() + try body finally flip() + } + @inline final def varyOn(tparam: Symbol)(body: => Type): Type = { + val saved = variance + variance *= tparam.variance + try body finally variance = saved + } + protected def mapOverArgs(args: List[Type], tparams: List[Symbol]): List[Type] = ( + if (isTrackingVariance) + map2Conserve(args, tparams)((arg, tparam) => varyOn(tparam)(this(arg))) + else + args mapConserve this + ) + private def isInfoUnchanged(sym: Symbol) = { + val forceInvariance = isTrackingVariance && !variance.isInvariant && sym.isAliasType + val result = if (forceInvariance) { + val saved = variance + variance = Invariant + try this(sym.info) finally variance = saved + } + else this(sym.info) + + (sym.info eq result) + } /** Called by mapOver to determine whether the original symbols can * be returned, or whether they must be cloned. Overridden in VariantTypeMap. @@ -4035,7 +4067,7 @@ trait Types extends api.Types { self: SymbolTable => protected def noChangeToSymbols(origSyms: List[Symbol]): Boolean = { @tailrec def loop(syms: List[Symbol]): Boolean = syms match { case Nil => true - case x :: xs => (x.info eq this(x.info)) && loop(xs) + case x :: xs => isInfoUnchanged(x) && loop(xs) } loop(origSyms) } @@ -4189,7 +4221,7 @@ trait Types extends api.Types { self: SymbolTable => /** Used by existentialAbstraction. */ - class ExistentialExtrapolation(tparams: List[Symbol]) extends VariantTypeMap { + class ExistentialExtrapolation(tparams: List[Symbol]) extends TypeMap(isTrackingVariance = true) { private val occurCount = mutable.HashMap[Symbol, Int]() private def countOccs(tp: Type) = { tp foreach { @@ -4208,16 +4240,29 @@ trait Types extends api.Types { self: SymbolTable => apply(tpe) } + /** If these conditions all hold: + * 1) we are in covariant (or contravariant) position + * 2) this type occurs exactly once in the existential scope + * 3) the widened upper (or lower) bound of this type contains no references to tparams + * Then we replace this lone occurrence of the type with the widened upper (or lower) bound. + * All other types pass through unchanged. + */ def apply(tp: Type): Type = { val tp1 = mapOver(tp) if (variance.isInvariant) tp1 else tp1 match { case TypeRef(pre, sym, args) if tparams contains sym => val repl = if (variance.isPositive) dropSingletonType(tp1.bounds.hi) else tp1.bounds.lo - //println("eliminate "+sym+"/"+repl+"/"+occurCount(sym)+"/"+(tparams exists (repl.contains)))//DEBUG - if (!repl.typeSymbol.isBottomClass && occurCount(sym) == 1 && !(tparams exists (repl.contains))) - repl - else tp1 + val count = occurCount(sym) + val containsTypeParam = tparams exists (repl contains _) + def msg = { + val word = if (variance.isPositive) "upper" else "lower" + s"Widened lone occurrence of $tp1 inside existential to $word bound" + } + if (!repl.typeSymbol.isBottomClass && count == 1 && !containsTypeParam) + logResult(msg)(repl) + else + tp1 case _ => tp1 } @@ -5954,7 +5999,7 @@ trait Types extends api.Types { self: SymbolTable => * `f` maps all elements to themselves. */ def map2Conserve[A <: AnyRef, B](xs: List[A], ys: List[B])(f: (A, B) => A): List[A] = - if (xs.isEmpty) xs + if (xs.isEmpty || ys.isEmpty) xs else { val x1 = f(xs.head, ys.head) val xs1 = map2Conserve(xs.tail, ys.tail)(f) diff --git a/src/reflect/scala/reflect/internal/Variance.scala b/src/reflect/scala/reflect/internal/Variance.scala index f5b994d69e..007d56eb35 100644 --- a/src/reflect/scala/reflect/internal/Variance.scala +++ b/src/reflect/scala/reflect/internal/Variance.scala @@ -74,6 +74,11 @@ final class Variance private (val flags: Int) extends AnyVal { } object Variance { + implicit class SbtCompat(val v: Variance) { + def < (other: Int) = v.flags < other + def > (other: Int) = v.flags > other + } + def fold(variances: List[Variance]): Variance = ( if (variances.isEmpty) Bivariant else variances reduceLeft (_ & _) diff --git a/src/reflect/scala/reflect/internal/Variances.scala b/src/reflect/scala/reflect/internal/Variances.scala index ca2c52a41e..a54420dd94 100644 --- a/src/reflect/scala/reflect/internal/Variances.scala +++ b/src/reflect/scala/reflect/internal/Variances.scala @@ -14,82 +14,11 @@ import scala.collection.{ mutable, immutable } trait Variances { self: SymbolTable => - /** Used by ExistentialExtrapolation and adaptConstrPattern(). - * TODO - eliminate duplication with all the rest. - */ - trait VariantTypeMap extends TypeMap { - private[this] var _variance: Variance = Covariant - - override def variance = _variance - def variance_=(x: Variance) = _variance = x - - override protected def noChangeToSymbols(origSyms: List[Symbol]) = - //OPT inline from forall to save on #closures - origSyms match { - case sym :: rest => - val v = variance - if (sym.isAliasType) variance = Invariant - val result = this(sym.info) - variance = v - (result eq sym.info) && noChangeToSymbols(rest) - case _ => - true - } - - override protected def mapOverArgs(args: List[Type], tparams: List[Symbol]): List[Type] = - map2Conserve(args, tparams) { (arg, tparam) => - val saved = variance - variance *= tparam.variance - try this(arg) finally variance = saved - } - - /** Map this function over given type */ - override def mapOver(tp: Type): Type = tp match { - case MethodType(params, result) => - variance = variance.flip - val params1 = mapOver(params) - variance = variance.flip - val result1 = this(result) - if ((params1 eq params) && (result1 eq result)) tp - else copyMethodType(tp, params1, result1.substSym(params, params1)) - case PolyType(tparams, result) => - variance = variance.flip - val tparams1 = mapOver(tparams) - variance = variance.flip - val result1 = this(result) - if ((tparams1 eq tparams) && (result1 eq result)) tp - else PolyType(tparams1, result1.substSym(tparams, tparams1)) - case TypeBounds(lo, hi) => - variance = variance.flip - val lo1 = this(lo) - variance = variance.flip - val hi1 = this(hi) - if ((lo1 eq lo) && (hi1 eq hi)) tp - else TypeBounds(lo1, hi1) - case tr @ TypeRef(pre, sym, args) => - val pre1 = this(pre) - val args1 = - if (args.isEmpty) - args - else if (variance.isInvariant) // fast & safe path: don't need to look at typeparams - args mapConserve this - else { - val tparams = sym.typeParams - if (tparams.isEmpty) args - else mapOverArgs(args, tparams) - } - if ((pre1 eq pre) && (args1 eq args)) tp - else copyTypeRef(tp, pre1, tr.coevolveSym(pre1), args1) - case _ => - super.mapOver(tp) - } - } - /** Used in Refchecks. * TODO - eliminate duplication with varianceInType */ class VarianceValidator extends Traverser { - val escapedPrivateLocals = mutable.HashSet[Symbol]() + val escapedLocals = mutable.HashSet[Symbol]() protected def issueVarianceError(base: Symbol, sym: Symbol, required: Variance): Unit = () @@ -104,10 +33,10 @@ trait Variances { ) // return Bivariant if `sym` is local to a term // or is private[this] or protected[this] - private def isLocalOnly(sym: Symbol) = !sym.owner.isClass || ( + def isLocalOnly(sym: Symbol) = !sym.owner.isClass || ( sym.isTerm && (sym.isPrivateLocal || sym.isProtectedLocal || sym.isSuperAccessor) // super accessors are implicitly local #4345 - && !escapedPrivateLocals(sym) + && !escapedLocals(sym) ) /** Validate variance of info of symbol `base` */ @@ -126,20 +55,18 @@ trait Variances { * because there may be references to the type parameter that are not checked. */ def relativeVariance(tvar: Symbol): Variance = { - val clazz = tvar.owner - var sym = base - var state: Variance = Covariant - while (sym != clazz && !state.isBivariant) { - if (isFlipped(sym, tvar)) - state = state.flip - else if (isLocalOnly(sym)) - state = Bivariant - else if (sym.isAliasType) - state = if (sym.isOverridingSymbol) Invariant else Bivariant - - sym = sym.owner - } - state + def nextVariance(sym: Symbol, v: Variance): Variance = ( + if (isFlipped(sym, tvar)) v.flip + else if (isLocalOnly(sym)) Bivariant + else if (!sym.isAliasType) v + else if (sym.isOverridingSymbol) Invariant + else Bivariant + ) + def loop(sym: Symbol, v: Variance): Variance = ( + if (sym == tvar.owner || v.isBivariant) v + else loop(sym.owner, nextVariance(sym, v)) + ) + loop(base, Covariant) } /** Validate that the type `tp` is variance-correct, assuming @@ -156,15 +83,17 @@ trait Variances { case ConstantType(_) => case SingleType(pre, sym) => validateVariance(pre, variance) + case TypeRef(_, sym, _) if sym.isAliasType => validateVariance(tp.normalize, variance) + case TypeRef(pre, sym, args) => -// println("validate "+sym+" at "+relativeVariance(sym)) - if (sym.isAliasType/* && relativeVariance(sym) == Bivariant*/) - validateVariance(tp.normalize, variance) - else if (!sym.variance.isInvariant) { - val v = relativeVariance(sym) - val requiredVariance = v * variance - if (!v.isBivariant && sym.variance != requiredVariance) - issueVarianceError(base, sym, requiredVariance) + if (!sym.variance.isInvariant) { + val relative = relativeVariance(sym) + val required = relative * variance + if (!relative.isBivariant) { + log(s"verifying $sym (${sym.variance}${sym.locationString}) is $required at $base in ${base.owner}") + if (sym.variance != required) + issueVarianceError(base, sym, required) + } } validateVariance(pre, variance) // @M for higher-kinded typeref, args.isEmpty @@ -213,18 +142,18 @@ trait Variances { } override def traverse(tree: Tree) { - // def local = tree.symbol.hasLocalFlag + def local = tree.symbol.hasLocalFlag tree match { case ClassDef(_, _, _, _) | TypeDef(_, _, _, _) => validateVariance(tree.symbol) super.traverse(tree) // ModuleDefs need not be considered because they have been eliminated already case ValDef(_, _, _, _) => - if (!tree.symbol.hasLocalFlag) + if (!local) validateVariance(tree.symbol) case DefDef(_, _, tparams, vparamss, _, _) => // No variance check for object-private/protected methods/values. - if (!tree.symbol.hasLocalFlag) { + if (!local) { validateVariance(tree.symbol) traverseTrees(tparams) traverseTreess(vparamss) -- cgit v1.2.3