From 2985d1806b66d4bf59807f35a6427b81ef66961e Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Tue, 6 Aug 2013 12:48:28 +0200 Subject: Refined treatment of unapply --- src/dotty/tools/dotc/core/Definitions.scala | 10 ++--- src/dotty/tools/dotc/core/Types.scala | 52 +++++++++++++++++++--- src/dotty/tools/dotc/typer/Applications.scala | 63 +++++++++++++++++++-------- src/dotty/tools/dotc/typer/Inferencing.scala | 32 +++++++++++--- 4 files changed, 123 insertions(+), 34 deletions(-) (limited to 'src/dotty') diff --git a/src/dotty/tools/dotc/core/Definitions.scala b/src/dotty/tools/dotc/core/Definitions.scala index bde325897..df0173467 100644 --- a/src/dotty/tools/dotc/core/Definitions.scala +++ b/src/dotty/tools/dotc/core/Definitions.scala @@ -142,7 +142,7 @@ class Definitions(implicit ctx: Context) { ScalaPackageClass, tpnme.Null, AbstractFinal, List(AnyRefAlias.typeConstructor)).entered lazy val PredefModule = requiredModule("scala.Predef") - lazy val NilModule = requiredModule("scala.collection.immutable.Nil") + lazy val NilModule = requiredModule("scala.collection.immutable.Nil") // lazy val FunctionClass: ClassSymbol = requiredClass("scala.Function") lazy val SingletonClass: ClassSymbol = @@ -330,7 +330,7 @@ class Definitions(implicit ctx: Context) { * - v_i are the variances of the bound symbols (i.e. +, -, or empty). * - _$hk$i are hgiher-kinded parameter names, which are special treated in type application. */ - def hkTrait(variances: List[Int]) = { + def hkTrait(vcs: List[Int]) = { def varianceSuffix(v: Int) = v match { case -1 => "N" @@ -348,14 +348,14 @@ class Definitions(implicit ctx: Context) { def complete(denot: SymDenotation): Unit = { val cls = denot.asClass.classSymbol val paramDecls = newScope - for ((v, i) <- variances.zipWithIndex) + for ((v, i) <- vcs.zipWithIndex) newTypeParam(cls, tpnme.higherKindedParamName(i), varianceFlags(v), paramDecls) denot.info = ClassInfo(ScalaPackageClass.thisType, cls, List(ObjectClass.typeConstructor), paramDecls) } } val traitName = - tpnme.higherKindedTraitName(variances.length) ++ (variances map varianceSuffix).mkString + tpnme.higherKindedTraitName(vcs.length) ++ (vcs map varianceSuffix).mkString def createTrait = { val cls = ctx.newClassSymbol( @@ -367,7 +367,7 @@ class Definitions(implicit ctx: Context) { cls } - hkTraitOfArity.getOrElseUpdate(variances, createTrait) + hkTraitOfArity.getOrElseUpdate(vcs, createTrait) } diff --git a/src/dotty/tools/dotc/core/Types.scala b/src/dotty/tools/dotc/core/Types.scala index 07f3f284f..ff0f2e990 100644 --- a/src/dotty/tools/dotc/core/Types.scala +++ b/src/dotty/tools/dotc/core/Types.scala @@ -940,6 +940,31 @@ object Types { */ def varianceOf(tp: Type): FlagSet = ??? + type VarianceMap = Map[TypeVar, Int] + + /** All occurrences of type vars in this type that satisfy predicate + * `include` mapped to their variances (-1/0/1) in this type, where + * -1 means: only covariant occurrences + * +1 means: only covariant occurrences + * 0 means: mixed or non-variant occurrences + */ + def variances(include: TypeVar => Boolean)(implicit ctx: Context): VarianceMap = { + val accu = new TypeAccumulator[VarianceMap] { + def apply(vmap: VarianceMap, t: Type): VarianceMap = t match { + case t: TypeVar if include(t) => + vmap get t match { + case Some(v) => + if (v == variance) vmap else vmap updated (t, 0) + case None => + vmap updated (t, variance) + } + case _ => + foldOver(vmap, t) + } + } + accu(Map.empty, this) + } + // ----- hashing ------------------------------------------------------ /** customized hash code of this type. @@ -2086,6 +2111,8 @@ object Types { protected def apply(x: T, annot: Annotation): T = x // don't go into annotations + protected var variance = 1 + def foldOver(x: T, tp: Type): T = tp match { case tp: NamedType => this(x, tp.prefix) @@ -2097,21 +2124,36 @@ object Types { this(this(x, tp.parent), tp.refinedInfo) case tp @ MethodType(pnames, ptypes) => - this((x /: ptypes)(this), tp.resultType) + variance = -variance + val y = (x /: ptypes)(this) + variance = -variance + this(y, tp.resultType) case ExprType(restpe) => this(x, restpe) case tp @ PolyType(pnames) => - this((x /: tp.paramBounds)(this), tp.resultType) + variance = -variance + val y = (x /: tp.paramBounds)(this) + variance = -variance + this(y, tp.resultType) case SuperType(thistp, supertp) => this(this(x, thistp), supertp) case TypeBounds(lo, hi) => - if (lo eq hi) this(x, lo) - else this(this(x, lo), hi) - + if (lo eq hi) { + val saved = variance + variance = 0 + try this(x, lo) + finally variance = saved + } + else { + variance = -variance + val y = this(x, lo) + variance = -variance + this(y, hi) + } case AnnotatedType(annot, underlying) => this(this(x, annot), underlying) diff --git a/src/dotty/tools/dotc/typer/Applications.scala b/src/dotty/tools/dotc/typer/Applications.scala index 6424b96e4..4026b84d4 100644 --- a/src/dotty/tools/dotc/typer/Applications.scala +++ b/src/dotty/tools/dotc/typer/Applications.scala @@ -560,17 +560,18 @@ trait Applications extends Compatibility { self: Typer => def unapplyArgs(unapplyResult: Type)(implicit ctx: Context): List[Type] = { def recur(tp: Type): List[Type] = { - def nonOverloadedMember(name: Name) = { + def extractorMemberType(name: Name) = { val ref = tp member name - if (ref.isOverloaded) { - errorType(s"Overloaded reference to $ref is not allowed in extractor", tree.pos) - } + if (ref.isOverloaded) + errorType(s"Overloaded reference to ${ref.show} is not allowed in extractor", tree.pos) + else if (ref.info.isInstanceOf[PolyType]) + errorType(s"Reference to polymorphic ${ref.show}: ${ref.info.show} is not allowed in extractor", tree.pos) else ref.info } def productSelectors: List[Type] = { - val sels = for (n <- Iterator.from(0)) yield nonOverloadedMember(("_" + n).toTermName) + val sels = for (n <- Iterator.from(0)) yield extractorMemberType(("_" + n).toTermName) sels.takeWhile(_.exists).toList } def seqSelector = defn.RepeatedParamType.appliedTo(tp.elemType :: Nil) @@ -578,8 +579,8 @@ trait Applications extends Compatibility { self: Typer => if (tp derivesFrom defn.ProductClass) productSelectors else if (tp derivesFrom defn.SeqClass) seqSelector :: Nil else if (tp.typeSymbol == defn.BooleanClass) Nil - else if (nonOverloadedMember(nme.isDefined).exists && - nonOverloadedMember(nme.get).exists) recur(nonOverloadedMember(nme.get)) + else if (extractorMemberType(nme.isDefined).exists && + extractorMemberType(nme.get).exists) recur(extractorMemberType(nme.get)) else { ctx.error(s"${unapplyResult.show} is not a valid result type of an unapply method of an extractor", tree.pos) Nil @@ -589,7 +590,10 @@ trait Applications extends Compatibility { self: Typer => recur(unapplyResult) } - val fn = { + def notAnExtractor(tree: Tree) = + errorTree(tree, s"${qual.show} cannot be used as an extractor in a pattern because it lacks an unapply or unapplySeq method") + + val unapply = { val dummyArg = untpd.TypedSplice(dummyTreeOfType(WildcardType)) val unappProto = FunProtoType(dummyArg :: Nil, pt, this) tryEither { @@ -599,28 +603,51 @@ trait Applications extends Compatibility { self: Typer => tryEither { implicit ctx => typedExpr(untpd.Select(qual, nme.unapplySeq), unappProto) // for backwards compatibility; will be dropped } { - _ => errorTree(s.value, s"${qual.show} cannot be used as an extractor in a pattern because it lacks an unapply or unapplySeq method") + _ => notAnExtractor(s.value) } } } - fn.tpe.widen match { - case mt: MethodType => - val ownType = mt.resultType - ownType <:< pt // done for registering the constraints; error message would come later - var argTypes = unapplyArgs(ownType) + + unapply.tpe.widen match { + case mt: MethodType if !mt.isDependent => + val unapplyArgType = mt.paramTypes.head + val ownType = + if (pt <:< unapplyArgType) { + assert(isFullyDefined(unapplyArgType)) + pt + } + else if (unapplyArgType <:< pt) + ctx.maximizeType(unapplyArgType) match { + case None => unapplyArgType + case Some(tvar) => + errorType( + s"""There is no best instantiation of pattern type ${unapplyArgType.show} + |that makes it a subtype of selector type ${pt.show}. + |Non-variant type variable ${tvar.origin.show} cannot be uniquely instantiated.""".stripMargin, + tree.pos) + } + else errorType( + s"Pattern type ${unapplyArgType.show} is neither a subtype nor a supertype of selector type ${pt.show}", + tree.pos) + + var argTypes = unapplyArgs(mt.resultType) val bunchedArgs = argTypes match { case argType :: Nil if argType.isRepeatedParam => untpd.SeqLiteral(args) :: Nil case _ => args } if (argTypes.length != bunchedArgs.length) { - ctx.error(s"wrong number of argument patterns for ${err.patternConstrStr(fn)}", tree.pos) + ctx.error(s"wrong number of argument patterns for ${err.patternConstrStr(unapply)}", tree.pos) argTypes = argTypes.take(args.length) ++ List.fill(argTypes.length - args.length)(WildcardType) } val typedArgs = (bunchedArgs, argTypes).zipped map (typed(_, _)) - untpd.UnApply(fn, typedArgs).withPos(tree.pos).withType(ownType) - case et: ErrorType => - tree.withType(ErrorType) + val result = cpy.UnApply(tree, unapply, typedArgs) withType ownType + if ((ownType eq pt) || ownType.isError) result + else Typed(result, TypeTree(ownType)) + case tp => + val unapplyErr = if (tp.isError) unapply else notAnExtractor(unapply) + val typedArgsErr = args map (typed(_, defn.AnyType)) + cpy.UnApply(tree, unapplyErr, typedArgsErr) withType ErrorType } } diff --git a/src/dotty/tools/dotc/typer/Inferencing.scala b/src/dotty/tools/dotc/typer/Inferencing.scala index ff38ef997..fe1572631 100644 --- a/src/dotty/tools/dotc/typer/Inferencing.scala +++ b/src/dotty/tools/dotc/typer/Inferencing.scala @@ -91,13 +91,33 @@ object Inferencing { * approximate it by its lower bound. Otherwise, if it appears contravariantly * in type `tp` approximate it by its upper bound. */ - def interpolateUndetVars(tp: Type, pos: Position): Unit = - for (tvar <- ctx.typerState.undetVars) - if (pos contains tvar.pos) { - val v = tp varianceOf tvar - if (v is Covariant) tvar.instantiate(fromBelow = true) - else if (v is Contravariant) tvar.instantiate(fromBelow = false) + def interpolateUndetVars(tp: Type, pos: Position): Unit = { + val vs = tp.variances(tvar => + (ctx.typerState.undetVars contains tvar) && (pos contains tvar.pos)) + for ((tvar, v) <- vs) + if (v == 1) tvar.instantiate(fromBelow = true) + else if (v == -1) tvar.instantiate(fromBelow = false) + for (tvar <- ctx.typerState.undetVars if !(vs contains tvar)) + tvar.instantiate(fromBelow = false) + } + + /** Instantiate undetermined type variables to that type `tp` is + * maximized and return None. If this is not possible, because a non-variant + * typevar is not uniquely determined, return that typevar in a Some. + */ + def maximizeType(tp: Type): Option[TypeVar] = { + val vs = tp.variances(tvar => ctx.typerState.undetVars contains tvar) + var result: Option[TypeVar] = None + for ((tvar, v) <- vs) + if (v == 1) tvar.instantiate(fromBelow = false) + else if (v == -1) tvar.instantiate(fromBelow = true) + else { + val bounds @ TypeBounds(lo, hi) = ctx.typerState.constraint(tvar.origin) + if (hi <:< lo) tvar.instantiate(fromBelow = false) + else result = Some(tvar) } + result + } /** Create new type variables for the parameters of a poly type. * @param pos The position of the new type variables (relevant for -- cgit v1.2.3