aboutsummaryrefslogtreecommitdiff
path: root/src/dotty/tools/dotc/typer/Variances.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/dotty/tools/dotc/typer/Variances.scala')
-rw-r--r--src/dotty/tools/dotc/typer/Variances.scala93
1 files changed, 93 insertions, 0 deletions
diff --git a/src/dotty/tools/dotc/typer/Variances.scala b/src/dotty/tools/dotc/typer/Variances.scala
new file mode 100644
index 000000000..d793f7942
--- /dev/null
+++ b/src/dotty/tools/dotc/typer/Variances.scala
@@ -0,0 +1,93 @@
+package dotty.tools.dotc
+package transform
+
+import dotty.tools.dotc.ast.{Trees, tpd}
+import core._
+import Types._, Contexts._, Flags._, Symbols._, Annotations._, Trees._
+import Decorators._
+
+object Variances {
+ import tpd._
+
+ type Variance = FlagSet
+ val Bivariant = VarianceFlags
+ val Invariant = EmptyFlags
+
+ /** Flip between covariant and contravariant */
+ def flip(v: Variance): Variance = {
+ if (v == Covariant) Contravariant
+ else if (v == Contravariant) Covariant
+ else v
+ }
+
+ /** Map everything below Bivariant to Invariant */
+ def cut(v: Variance): Variance =
+ if (v == Bivariant) v else Invariant
+
+ def compose(v: Variance, boundsVariance: Int) =
+ if (boundsVariance == 1) v
+ else if (boundsVariance == -1) flip(v)
+ else cut(v)
+
+ /** Compute variance of type parameter `tparam' in types of all symbols `sym'. */
+ def varianceInSyms(syms: List[Symbol])(tparam: Symbol)(implicit ctx: Context): Variance =
+ (Bivariant /: syms) ((v, sym) => v & varianceInSym(sym)(tparam))
+
+ /** Compute variance of type parameter `tparam' in type of symbol `sym'. */
+ def varianceInSym(sym: Symbol)(tparam: Symbol)(implicit ctx: Context): Variance =
+ if (sym.isAliasType) cut(varianceInType(sym.info)(tparam))
+ else varianceInType(sym.info)(tparam)
+
+ /** Compute variance of type parameter `tparam' in all types `tps'. */
+ def varianceInTypes(tps: List[Type])(tparam: Symbol)(implicit ctx: Context): Variance =
+ (Bivariant /: tps) ((v, tp) => v & varianceInType(tp)(tparam))
+
+ /** Compute variance of type parameter `tparam' in all type arguments
+ * <code>tps</code> which correspond to formal type parameters `tparams1'.
+ */
+ def varianceInArgs(tps: List[Type], tparams1: List[Symbol])(tparam: Symbol)(implicit ctx: Context): Variance = {
+ var v: Variance = Bivariant;
+ for ((tp, tparam1) <- tps zip tparams1) {
+ val v1 = varianceInType(tp)(tparam)
+ v = v & (if (tparam1.is(Covariant)) v1
+ else if (tparam1.is(Contravariant)) flip(v1)
+ else cut(v1))
+ }
+ v
+ }
+
+ /** Compute variance of type parameter `tparam' in all type annotations `annots'. */
+ def varianceInAnnots(annots: List[Annotation])(tparam: Symbol)(implicit ctx: Context): Variance = {
+ (Bivariant /: annots) ((v, annot) => v & varianceInAnnot(annot)(tparam))
+ }
+
+ /** Compute variance of type parameter `tparam' in type annotation `annot'. */
+ def varianceInAnnot(annot: Annotation)(tparam: Symbol)(implicit ctx: Context): Variance = {
+ varianceInType(annot.tree.tpe)(tparam)
+ }
+
+ /** Compute variance of type parameter <code>tparam</code> in type <code>tp</code>. */
+ def varianceInType(tp: Type)(tparam: Symbol)(implicit ctx: Context): Variance = tp match {
+ case TermRef(pre, sym) =>
+ varianceInType(pre)(tparam)
+ case TypeRef(pre, sym) =>
+ if (sym == tparam) Covariant else varianceInType(pre)(tparam)
+ case tp @ TypeBounds(lo, hi) =>
+ if (lo eq hi) compose(varianceInType(hi)(tparam), tp.variance)
+ else flip(varianceInType(lo)(tparam)) & varianceInType(hi)(tparam)
+ case tp @ RefinedType(parent, _) =>
+ varianceInType(parent)(tparam) & varianceInType(tp.refinedInfo)(tparam)
+ case tp @ MethodType(_, paramTypes) =>
+ flip(varianceInTypes(paramTypes)(tparam)) & varianceInType(tp.resultType)(tparam)
+ case ExprType(restpe) =>
+ varianceInType(restpe)(tparam)
+ case tp @ PolyType(_) =>
+ flip(varianceInTypes(tp.paramBounds)(tparam)) & varianceInType(tp.resultType)(tparam)
+ case AnnotatedType(annot, tp) =>
+ varianceInAnnot(annot)(tparam) & varianceInType(tp)(tparam)
+ case tp: AndOrType =>
+ varianceInType(tp.tp1)(tparam) & varianceInType(tp.tp2)(tparam)
+ case _ =>
+ Bivariant
+ }
+}