From e33a7385268084138e3d51faffdff33b540ad942 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Sun, 20 Jul 2014 14:48:34 +0200 Subject: Enabled variance checking Variance checking is now run as part of type-checking. Fixed tests that exhibited variance errors. Added tests where some classes of variance errors should be detected. --- src/dotty/tools/dotc/core/Types.scala | 1 + src/dotty/tools/dotc/typer/CheckVariances.scala | 130 --------------------- src/dotty/tools/dotc/typer/Typer.scala | 1 + src/dotty/tools/dotc/typer/VarianceChecker.scala | 138 +++++++++++++++++++++++ src/dotty/tools/dotc/typer/Variances.scala | 12 +- 5 files changed, 151 insertions(+), 131 deletions(-) delete mode 100644 src/dotty/tools/dotc/typer/CheckVariances.scala create mode 100644 src/dotty/tools/dotc/typer/VarianceChecker.scala (limited to 'src') diff --git a/src/dotty/tools/dotc/core/Types.scala b/src/dotty/tools/dotc/core/Types.scala index 19141a873..474958b86 100644 --- a/src/dotty/tools/dotc/core/Types.scala +++ b/src/dotty/tools/dotc/core/Types.scala @@ -2203,6 +2203,7 @@ object Types { override def variance = 1 override def toString = "Co" + super.toString } + final class ContraTypeBounds(lo: Type, hi: Type, hc: Int) extends CachedTypeBounds(lo, hi, hc) { override def variance = -1 override def toString = "Contra" + super.toString diff --git a/src/dotty/tools/dotc/typer/CheckVariances.scala b/src/dotty/tools/dotc/typer/CheckVariances.scala deleted file mode 100644 index f21a99566..000000000 --- a/src/dotty/tools/dotc/typer/CheckVariances.scala +++ /dev/null @@ -1,130 +0,0 @@ -package dotty.tools.dotc -package transform - -import dotty.tools.dotc.ast.{ Trees, tpd } -import core._ -import Types._, Contexts._, Flags._, Symbols._, Annotations._, Trees._ -import Decorators._ -import Variances._ - -object VarianceChecker { - - case class VarianceError(tvar: Symbol, required: Variance) -} - -/** See comments at scala.reflect.internal.Variance. - */ -class VarianceChecker(implicit ctx: Context) { - import VarianceChecker._ - import tpd._ - - private object Validator extends TypeAccumulator[Option[VarianceError]] { - private var base: Symbol = _ - - /** The variance of a symbol occurrence of `tvar` seen at the level of the definition of `base`. - * The search proceeds from `base` to the owner of `tvar`. - * Initially the state is covariant, but it might change along the search. - */ - def relativeVariance(tvar: Symbol, base: Symbol, v: Variance = Covariant): Variance = { - if (base.owner == tvar.owner) v - else if ((base is Param) && base.owner.isTerm) relativeVariance(tvar, base.owner.owner, flip(v)) - else if (base.isTerm) Bivariant - else if (base.isAliasType) relativeVariance(tvar, base.owner, Invariant) - else relativeVariance(tvar, base.owner, v) - } - - def isUncheckedVariance(tp: Type): Boolean = tp match { - case AnnotatedType(annot, tp1) => - annot.symbol == defn.UncheckedVarianceAnnot || isUncheckedVariance(tp1) - case _ => false - } - - private def checkVarianceOfSymbol(tvar: Symbol): Option[VarianceError] = { - val relative = relativeVariance(tvar, base) - val required = Variances.compose(relative, this.variance) - if (relative == Bivariant) None - else { - def tvar_s = s"$tvar (${tvar.variance}${tvar.showLocated})" - def base_s = s"$base in ${base.owner}" + (if (base.owner.isClass) "" else " in " + base.owner.enclosingClass) - ctx.log(s"verifying $tvar_s is $required at $base_s") - if (tvar.variance == required) None - else Some(VarianceError(tvar, required)) - } - } - - /** For PolyTypes, type parameters are skipped because they are defined - * explicitly (their TypeDefs will be passed here.) For MethodTypes, the - * same is true of the parameters (ValDefs) unless we are inside a - * refinement, in which case they are checked from here. - */ - def apply(status: Option[VarianceError], tp: Type): Option[VarianceError] = - if (status.isDefined) status - else tp match { - case tp: TypeRef => - val sym = tp.symbol - if (sym.variance != 0 && base.isContainedIn(sym.owner)) checkVarianceOfSymbol(sym) - else if (sym.isAliasType) this(status, sym.info) - else foldOver(status, tp) - case tp: MethodType => - this(status, tp.resultType) // params will be checked in their TypeDef nodes. - case tp: PolyType => - this(status, tp.resultType) // params will be checked in their ValDef nodes. - case AnnotatedType(annot, _) if annot.symbol == defn.UncheckedVarianceAnnot => - status - case tp: ClassInfo => - ??? - case _ => - foldOver(status, tp) - } - - def validateDefinition(base: Symbol): Option[VarianceError] = { - val saved = this.base - this.base = base - try apply(None, base.info) - finally this.base = saved - } - } - - def varianceString(v: Variance) = - if (v is Covariant) "covariant" - else if (v is Contravariant) "contravariant" - else "invariant" - - object Traverser extends TreeTraverser { - def checkVariance(sym: Symbol) = Validator.validateDefinition(sym) match { - case Some(VarianceError(tvar, required)) => - ctx.error( - i"${varianceString(tvar.flags)} $tvar occurs in ${varianceString(required)} position in type ${sym.info} of $sym", - sym.pos) - case None => - } - - override def traverse(tree: Tree) = { - def sym = tree.symbol - // No variance check for object-private/protected methods/values. - // Or constructors, or case class factory or extractor. - def skip = ( - sym == NoSymbol - || sym.is(Local) - || sym.owner.isConstructor - //|| sym.owner.isCaseApplyOrUnapply // not clear why needed - ) - tree match { - case defn: MemberDef if skip => - ctx.debuglog(s"Skipping variance check of ${sym.showDcl}") - case tree: TypeDef => - checkVariance(sym) - foldOver((), tree) - case tree: ValDef => - checkVariance(sym) - case DefDef(_, _, tparams, vparamss, _, _) => - checkVariance(sym) - tparams foreach traverse - vparamss foreach (_ foreach traverse) - case Template(_, _, _, body) => - foldOver((), tree) - case _ => - } - } - } -} diff --git a/src/dotty/tools/dotc/typer/Typer.scala b/src/dotty/tools/dotc/typer/Typer.scala index 729536308..0ce02d198 100644 --- a/src/dotty/tools/dotc/typer/Typer.scala +++ b/src/dotty/tools/dotc/typer/Typer.scala @@ -844,6 +844,7 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit checkNoDoubleDefs(cls) val impl1 = cpy.Template(impl, constr1, parents1, self1, body1) .withType(dummy.termRef) + VarianceChecker.check(impl1) assignType(cpy.TypeDef(cdef, mods1, name, impl1), cls) // todo later: check that diff --git a/src/dotty/tools/dotc/typer/VarianceChecker.scala b/src/dotty/tools/dotc/typer/VarianceChecker.scala new file mode 100644 index 000000000..36310767a --- /dev/null +++ b/src/dotty/tools/dotc/typer/VarianceChecker.scala @@ -0,0 +1,138 @@ +package dotty.tools.dotc +package typer + +import dotty.tools.dotc.ast.{ Trees, tpd } +import core._ +import Types._, Contexts._, Flags._, Symbols._, Annotations._, Trees._, NameOps._ +import Decorators._ +import Variances._ + +object VarianceChecker { + private case class VarianceError(tvar: Symbol, required: Variance) + def check(tree: tpd.Tree)(implicit ctx: Context) = + new VarianceChecker()(ctx).Traverser.traverse(tree) +} + +/** See comments at scala.reflect.internal.Variance. + */ +class VarianceChecker()(implicit ctx: Context) { + import VarianceChecker._ + import tpd._ + + private object Validator extends TypeAccumulator[Option[VarianceError]] { + private var base: Symbol = _ + + /** Is no variance checking needed within definition of `base`? */ + def ignoreVarianceIn(base: Symbol): Boolean = ( + base.isTerm + || base.is(Package) + || base.is(Local) + ) + + /** The variance of a symbol occurrence of `tvar` seen at the level of the definition of `base`. + * The search proceeds from `base` to the owner of `tvar`. + * Initially the state is covariant, but it might change along the search. + */ + def relativeVariance(tvar: Symbol, base: Symbol, v: Variance = Covariant): Variance = /*ctx.traceIndented(i"relative variance of $tvar wrt $base, so far: $v")*/ { + if (base == tvar.owner) v + else if ((base is Param) && base.owner.isTerm) + relativeVariance(tvar, paramOuter(base.owner), flip(v)) + else if (ignoreVarianceIn(base.owner)) Bivariant + else if (base.isAliasType) relativeVariance(tvar, base.owner, Invariant) + else relativeVariance(tvar, base.owner, v) + } + + /** The next level to take into account when determining the + * relative variance with a method parameter as base. The method + * is always skipped. If the method is a constructor, we also skip + * its class owner, because constructors are not checked for variance + * relative to the type parameters of their own class. On the other + * hand constructors do count for checking the variance of type parameters + * of enclosing classes. I believe the Scala 2 rules are too lenient in + * that respect. + */ + private def paramOuter(meth: Symbol) = + if (meth.isConstructor) meth.owner.owner else meth.owner + + /** Check variance of abstract type `tvar` when referred from `base`. */ + private def checkVarianceOfSymbol(tvar: Symbol): Option[VarianceError] = { + val relative = relativeVariance(tvar, base) + if (relative == Bivariant) None + else { + val required = compose(relative, this.variance) + def tvar_s = s"$tvar (${varianceString(tvar.flags)} ${tvar.showLocated})" + def base_s = s"$base in ${base.owner}" + (if (base.owner.isClass) "" else " in " + base.owner.enclosingClass) + ctx.log(s"verifying $tvar_s is ${varianceString(required)} at $base_s") + ctx.log(s"relative variance: ${varianceString(relative)}") + ctx.log(s"current variance: ${this.variance}") + ctx.log(s"owner chain: ${base.ownersIterator.toList}") + if (tvar is required) None + else Some(VarianceError(tvar, required)) + } + } + + /** For PolyTypes, type parameters are skipped because they are defined + * explicitly (their TypeDefs will be passed here.) For MethodTypes, the + * same is true of the parameters (ValDefs). + */ + def apply(status: Option[VarianceError], tp: Type): Option[VarianceError] = ctx.traceIndented(s"variance checking $tp of $base at $variance") { + if (status.isDefined) status + else tp match { + case tp: TypeRef => + val sym = tp.symbol + if (sym.variance != 0 && base.isContainedIn(sym.owner)) checkVarianceOfSymbol(sym) + else if (sym.isAliasType) this(status, sym.info) + else foldOver(status, tp) + case tp: MethodType => + this(status, tp.resultType) // params will be checked in their TypeDef nodes. + case tp: PolyType => + this(status, tp.resultType) // params will be checked in their ValDef nodes. + case AnnotatedType(annot, _) if annot.symbol == defn.UncheckedVarianceAnnot => + status + //case tp: ClassInfo => + // ??? not clear what to do here yet. presumably, it's all checked at local typedefs + case _ => + foldOver(status, tp) + } + } + + def validateDefinition(base: Symbol): Option[VarianceError] = { + val saved = this.base + this.base = base + try apply(None, base.info) + finally this.base = saved + } + } + + private object Traverser extends TreeTraverser { + def checkVariance(sym: Symbol) = Validator.validateDefinition(sym) match { + case Some(VarianceError(tvar, required)) => + ctx.error( + i"${varianceString(tvar.flags)} $tvar occurs in ${varianceString(required)} position in type ${sym.info} of $sym", + sym.pos) + case None => + } + + override def traverse(tree: Tree) = { + def sym = tree.symbol + // No variance check for private/protected[this] methods/values. + def skip = !sym.exists || sym.is(Local) + tree match { + case defn: MemberDef if skip => + ctx.debuglog(s"Skipping variance check of ${sym.showDcl}") + case tree: TypeDef => + checkVariance(sym) + foldOver((), tree) + case tree: ValDef => + checkVariance(sym) + case DefDef(_, _, tparams, vparamss, _, _) => + checkVariance(sym) + tparams foreach traverse + vparamss foreach (_ foreach traverse) + case Template(_, _, _, body) => + foldOver((), tree) + case _ => + } + } + } +} diff --git a/src/dotty/tools/dotc/typer/Variances.scala b/src/dotty/tools/dotc/typer/Variances.scala index d793f7942..fadf9f952 100644 --- a/src/dotty/tools/dotc/typer/Variances.scala +++ b/src/dotty/tools/dotc/typer/Variances.scala @@ -1,5 +1,5 @@ package dotty.tools.dotc -package transform +package typer import dotty.tools.dotc.ast.{Trees, tpd} import core._ @@ -90,4 +90,14 @@ object Variances { case _ => Bivariant } + + def varianceString(v: Variance) = + if (v is Covariant) "covariant" + else if (v is Contravariant) "contravariant" + else "invariant" + + def varianceString(v: Int) = + if (v > 0) "+" + else if (v < 0) "-" + else "" } -- cgit v1.2.3