From 78a29a40028d18d3506014a0e068f53b29c52f4a Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Mon, 21 Jul 2014 18:55:25 +0200 Subject: Initial versions of Variances and CheckVariances Not yet integrated or tested. --- src/dotty/tools/dotc/core/Definitions.scala | 3 +- src/dotty/tools/dotc/core/Flags.scala | 3 + src/dotty/tools/dotc/core/SymDenotations.scala | 2 +- src/dotty/tools/dotc/typer/CheckVariances.scala | 130 ++++++++++++++++++++++++ src/dotty/tools/dotc/typer/Variances.scala | 93 +++++++++++++++++ 5 files changed, 229 insertions(+), 2 deletions(-) create mode 100644 src/dotty/tools/dotc/typer/CheckVariances.scala create mode 100644 src/dotty/tools/dotc/typer/Variances.scala (limited to 'src') diff --git a/src/dotty/tools/dotc/core/Definitions.scala b/src/dotty/tools/dotc/core/Definitions.scala index 9c479de73..286d1437f 100644 --- a/src/dotty/tools/dotc/core/Definitions.scala +++ b/src/dotty/tools/dotc/core/Definitions.scala @@ -196,7 +196,6 @@ class Definitions { lazy val Array_update = ctx.requiredMethod(ArrayClass, nme.update) lazy val Array_length = ctx.requiredMethod(ArrayClass, nme.length) lazy val Array_clone = ctx.requiredMethod(ArrayClass, nme.clone_) - lazy val uncheckedStableClass: ClassSymbol = ctx.requiredClass("scala.annotation.unchecked.uncheckedStable") lazy val UnitClass = valueClassSymbol("scala.Unit", BoxedUnitClass, java.lang.Void.TYPE, UnitEnc) lazy val BooleanClass = valueClassSymbol("scala.Boolean", BoxedBooleanClass, java.lang.Boolean.TYPE, BooleanEnc) @@ -296,6 +295,8 @@ class Definitions { lazy val AnnotationDefaultAnnot = ctx.requiredClass("dotty.annotation.internal.AnnotationDefault") lazy val ThrowsAnnot = ctx.requiredClass("scala.throws") lazy val UncheckedAnnot = ctx.requiredClass("scala.unchecked") + lazy val UncheckedStableAnnot = ctx.requiredClass("scala.annotation.unchecked.uncheckedStable") + lazy val UncheckedVarianceAnnot = ctx.requiredClass("scala.annotation.unchecked.uncheckedVariance") lazy val VolatileAnnot = ctx.requiredClass("scala.volatile") // convenient one-parameter method types diff --git a/src/dotty/tools/dotc/core/Flags.scala b/src/dotty/tools/dotc/core/Flags.scala index c67768861..c527cef62 100644 --- a/src/dotty/tools/dotc/core/Flags.scala +++ b/src/dotty/tools/dotc/core/Flags.scala @@ -500,6 +500,9 @@ object Flags { /** A parameter or parameter accessor */ final val ParamOrAccessor = Param | ParamAccessor + /** A type parameter or type parameter accessor */ + final val TypeParamOrAccessor = TypeParam | TypeParamAccessor + /** A covariant type parameter instance */ final val LocalCovariant = allOf(Local, Covariant) diff --git a/src/dotty/tools/dotc/core/SymDenotations.scala b/src/dotty/tools/dotc/core/SymDenotations.scala index 61c700760..91a8e4345 100644 --- a/src/dotty/tools/dotc/core/SymDenotations.scala +++ b/src/dotty/tools/dotc/core/SymDenotations.scala @@ -402,7 +402,7 @@ object SymDenotations { final def isStable(implicit ctx: Context) = { val isUnstable = (this is UnstableValue) || - ctx.isVolatile(info) && !hasAnnotation(defn.uncheckedStableClass) + ctx.isVolatile(info) && !hasAnnotation(defn.UncheckedStableAnnot) (this is Stable) || isType || { if (isUnstable) false else { setFlag(Stable); true } diff --git a/src/dotty/tools/dotc/typer/CheckVariances.scala b/src/dotty/tools/dotc/typer/CheckVariances.scala new file mode 100644 index 000000000..1e711aa17 --- /dev/null +++ b/src/dotty/tools/dotc/typer/CheckVariances.scala @@ -0,0 +1,130 @@ +package dotty.tools.dotc +package transform + +import dotty.tools.dotc.ast.{ Trees, tpd } +import core._ +import Types._, Contexts._, Flags._, Symbols._, Annotations._, Trees._ +import Decorators._ +import Variances._ + +object VarianceChecker { + + case class VarianceError(tvar: Symbol, required: Variance) +} + +/** See comments at scala.reflect.internal.Variance. + */ +class VarianceChecker(implicit ctx: Context) { + import VarianceChecker._ + import tpd._ + + private object Validator extends TypeAccumulator[Option[VarianceError]] { + private var base: Symbol = _ + + /** The variance of a symbol occurrence of `tvar` seen at the level of the definition of `base`. + * The search proceeds from `base` to the owner of `tvar`. + * Initially the state is covariant, but it might change along the search. + */ + def relativeVariance(tvar: Symbol, base: Symbol, v: Variance = Covariant): Variance = { + if (base.owner == tvar.owner) v + else if ((base is Param) && base.owner.isTerm) relativeVariance(tvar, base.owner.owner, flip(v)) + else if (base.isTerm) Bivariant + else if (base.isAliasType) relativeVariance(tvar, base.owner, Invariant) + else relativeVariance(tvar, base.owner, v) + } + + def isUncheckedVariance(tp: Type): Boolean = tp match { + case AnnotatedType(annot, tp1) => + annot.symbol == defn.UncheckedVarianceAnnot || isUncheckedVariance(tp1) + case _ => false + } + + private def checkVarianceOfSymbol(tvar: Symbol): Option[VarianceError] = { + val relative = relativeVariance(tvar, base) + val required = Variances.compose(relative, this.variance) + if (relative == Bivariant) None + else { + def tvar_s = s"$tvar (${tvar.variance}${tvar.showLocated})" + def base_s = s"$base in ${base.owner}" + (if (base.owner.isClass) "" else " in " + base.owner.enclosingClass) + ctx.log(s"verifying $tvar_s is $required at $base_s") + if (tvar.variance == required) None + else Some(VarianceError(tvar, required)) + } + } + + /** For PolyTypes, type parameters are skipped because they are defined + * explicitly (their TypeDefs will be passed here.) For MethodTypes, the + * same is true of the parameters (ValDefs) unless we are inside a + * refinement, in which case they are checked from here. + */ + def apply(status: Option[VarianceError], tp: Type): Option[VarianceError] = + if (status.isDefined) status + else tp match { + case tp: TypeRef => + val sym = tp.symbol + if (sym.variance != 0 && base.isContainedIn(sym.owner)) checkVarianceOfSymbol(sym) + else if (sym.isAliasType) this(status, sym.info) + else foldOver(status, tp) + case tp: MethodType => + this(status, tp.resultType) // params will be checked in their TypeDef nodes. + case tp: PolyType => + this(status, tp.resultType) // params will be checked in their ValDef nodes. + case AnnotatedType(annot, _) if annot.symbol == defn.UncheckedVarianceAnnot => + status + case tp: ClassInfo => + ??? + case _ => + foldOver(status, tp) + } + + def validateDefinition(base: Symbol): Option[VarianceError] = { + val saved = this.base + this.base = base + try apply(None, base.info) + finally this.base = saved + } + } + + def varianceString(v: Variance) = + if (v is Covariant) "covariant" + else if (v is Contravariant) "contravariant" + else "invariant" + + object Traverser extends TreeTraverser { + def checkVariance(sym: Symbol) = Validator.validateDefinition(sym) match { + case Some(VarianceError(tvar, required)) => + ctx.error( + i"${varianceString(tvar.flags)} $tvar occurs in ${varianceString(required)} position in type ${sym.info} of $sym", + sym.pos) + case None => + } + + override def traverse(tree: Tree) = { + def sym = tree.symbol + // No variance check for object-private/protected methods/values. + // Or constructors, or case class factory or extractor. + def skip = ( + sym == NoSymbol + || sym.is(Local) + || sym.owner.isConstructor + //|| sym.owner.isCaseApplyOrUnapply // not clear why needed + ) + tree match { + case defn: MemberDef if skip => + ctx.debuglog(s"Skipping variance check of ${sym.showDcl}") + case tree: TypeDef => + checkVariance(sym) + super.traverse(tree) + case tree: ValDef => + checkVariance(sym) + case DefDef(_, _, tparams, vparamss, _, _) => + checkVariance(sym) + tparams foreach traverse + vparamss foreach (_ foreach traverse) + case Template(_, _, _, body) => + super.traverse(tree) + case _ => + } + } + } +} diff --git a/src/dotty/tools/dotc/typer/Variances.scala b/src/dotty/tools/dotc/typer/Variances.scala new file mode 100644 index 000000000..d793f7942 --- /dev/null +++ b/src/dotty/tools/dotc/typer/Variances.scala @@ -0,0 +1,93 @@ +package dotty.tools.dotc +package transform + +import dotty.tools.dotc.ast.{Trees, tpd} +import core._ +import Types._, Contexts._, Flags._, Symbols._, Annotations._, Trees._ +import Decorators._ + +object Variances { + import tpd._ + + type Variance = FlagSet + val Bivariant = VarianceFlags + val Invariant = EmptyFlags + + /** Flip between covariant and contravariant */ + def flip(v: Variance): Variance = { + if (v == Covariant) Contravariant + else if (v == Contravariant) Covariant + else v + } + + /** Map everything below Bivariant to Invariant */ + def cut(v: Variance): Variance = + if (v == Bivariant) v else Invariant + + def compose(v: Variance, boundsVariance: Int) = + if (boundsVariance == 1) v + else if (boundsVariance == -1) flip(v) + else cut(v) + + /** Compute variance of type parameter `tparam' in types of all symbols `sym'. */ + def varianceInSyms(syms: List[Symbol])(tparam: Symbol)(implicit ctx: Context): Variance = + (Bivariant /: syms) ((v, sym) => v & varianceInSym(sym)(tparam)) + + /** Compute variance of type parameter `tparam' in type of symbol `sym'. */ + def varianceInSym(sym: Symbol)(tparam: Symbol)(implicit ctx: Context): Variance = + if (sym.isAliasType) cut(varianceInType(sym.info)(tparam)) + else varianceInType(sym.info)(tparam) + + /** Compute variance of type parameter `tparam' in all types `tps'. */ + def varianceInTypes(tps: List[Type])(tparam: Symbol)(implicit ctx: Context): Variance = + (Bivariant /: tps) ((v, tp) => v & varianceInType(tp)(tparam)) + + /** Compute variance of type parameter `tparam' in all type arguments + * tps which correspond to formal type parameters `tparams1'. + */ + def varianceInArgs(tps: List[Type], tparams1: List[Symbol])(tparam: Symbol)(implicit ctx: Context): Variance = { + var v: Variance = Bivariant; + for ((tp, tparam1) <- tps zip tparams1) { + val v1 = varianceInType(tp)(tparam) + v = v & (if (tparam1.is(Covariant)) v1 + else if (tparam1.is(Contravariant)) flip(v1) + else cut(v1)) + } + v + } + + /** Compute variance of type parameter `tparam' in all type annotations `annots'. */ + def varianceInAnnots(annots: List[Annotation])(tparam: Symbol)(implicit ctx: Context): Variance = { + (Bivariant /: annots) ((v, annot) => v & varianceInAnnot(annot)(tparam)) + } + + /** Compute variance of type parameter `tparam' in type annotation `annot'. */ + def varianceInAnnot(annot: Annotation)(tparam: Symbol)(implicit ctx: Context): Variance = { + varianceInType(annot.tree.tpe)(tparam) + } + + /** Compute variance of type parameter tparam in type tp. */ + def varianceInType(tp: Type)(tparam: Symbol)(implicit ctx: Context): Variance = tp match { + case TermRef(pre, sym) => + varianceInType(pre)(tparam) + case TypeRef(pre, sym) => + if (sym == tparam) Covariant else varianceInType(pre)(tparam) + case tp @ TypeBounds(lo, hi) => + if (lo eq hi) compose(varianceInType(hi)(tparam), tp.variance) + else flip(varianceInType(lo)(tparam)) & varianceInType(hi)(tparam) + case tp @ RefinedType(parent, _) => + varianceInType(parent)(tparam) & varianceInType(tp.refinedInfo)(tparam) + case tp @ MethodType(_, paramTypes) => + flip(varianceInTypes(paramTypes)(tparam)) & varianceInType(tp.resultType)(tparam) + case ExprType(restpe) => + varianceInType(restpe)(tparam) + case tp @ PolyType(_) => + flip(varianceInTypes(tp.paramBounds)(tparam)) & varianceInType(tp.resultType)(tparam) + case AnnotatedType(annot, tp) => + varianceInAnnot(annot)(tparam) & varianceInType(tp)(tparam) + case tp: AndOrType => + varianceInType(tp.tp1)(tparam) & varianceInType(tp.tp2)(tparam) + case _ => + Bivariant + } +} -- cgit v1.2.3