diff options
Diffstat (limited to 'src/dotty/tools/dotc/core/TrackingConstraint.scala')
-rw-r--r-- | src/dotty/tools/dotc/core/TrackingConstraint.scala | 283 |
1 files changed, 130 insertions, 153 deletions
diff --git a/src/dotty/tools/dotc/core/TrackingConstraint.scala b/src/dotty/tools/dotc/core/TrackingConstraint.scala index 6e9ef5a2b..18e37a73e 100644 --- a/src/dotty/tools/dotc/core/TrackingConstraint.scala +++ b/src/dotty/tools/dotc/core/TrackingConstraint.scala @@ -93,7 +93,7 @@ class TrackingConstraint(private val myMap: ParamInfo, private def typeVar(entries: Array[Type], n: Int): Type = entries(paramCount(entries) + n) - private def entry(param: PolyParam): Type = { + def entry(param: PolyParam): Type = { val entries = myMap(param.binder) if (entries == null) NoType else entries(param.paramNum) @@ -119,55 +119,79 @@ class TrackingConstraint(private val myMap: ParamInfo, // ---------- Dependency handling ---------------------------------------------- - private def upperBits(i: Int): BitSet = less(i) + private def upperBits(less: Array[BitSet], i: Int): BitSet = less(i) - private def lowerBits(i: Int): BitSet = - (BitSet() /: less.indices) ((bits, j) => if (less(i)(j)) bits + j else bits) + private def lowerBits(less: Array[BitSet], i: Int): BitSet = + (BitSet() /: less.indices) ((bits, j) => if (less(j)(i)) bits + j else bits) - private def minUpperBits(i: Int): BitSet = { - val all = upperBits(i) + private def minUpperBits(less: Array[BitSet], i: Int): BitSet = { + val all = upperBits(less, i) all.filterNot(j => all.exists(k => less(k)(j))) } - private def minLowerBits(i: Int): BitSet = { - val all = lowerBits(i) + private def minLowerBits(less: Array[BitSet], i: Int): BitSet = { + val all = lowerBits(less, i) all.filterNot(j => all.exists(k => less(j)(k))) } - private def overParams(op: Int => BitSet): PolyParam => List[PolyParam] = param => - op(paramIndex(param)).toList.map(params).filter(contains) + private def overParams(op: (Array[BitSet], Int) => BitSet): PolyParam => List[PolyParam] = param => + op(less, paramIndex(param)).toList.map(params).filter(contains) - val upper = overParams(upperBits) - val lower = overParams(lowerBits) + val allUpper = overParams(upperBits) + val allLower = overParams(lowerBits) val minUpper = overParams(minUpperBits) val minLower = overParams(minLowerBits) + def upper(param: PolyParam): List[PolyParam] = allUpper(param) + def lower(param: PolyParam): List[PolyParam] = allLower(param) + + def exclusiveLower(param: PolyParam, butNot: PolyParam): List[PolyParam] = { + val excluded = lowerBits(less, paramIndex(butNot)) + overParams(lowerBits(_, _) &~ excluded)(param) + } + + def exclusiveUpper(param: PolyParam, butNot: PolyParam): List[PolyParam] = { + val excluded = upperBits(less, paramIndex(butNot)) + overParams(upperBits(_, _) &~ excluded)(param) + } // ---------- Info related to PolyParams ------------------------------------------- - def related(param1: PolyParam, param2: PolyParam, firstIsLower: Boolean)(implicit ctx: Context): Boolean = { - val i1 = paramIndex(param1) - val i2 = paramIndex(param2) - if (firstIsLower) less(i1)(i2) else less(i2)(i1) - } + def isLess(param1: PolyParam, param2: PolyParam)(implicit ctx: Context): Boolean = + less(paramIndex(param1))(paramIndex(param2)) - def nonParamBounds(param: PolyParam)(implicit ctx: Context): TypeBounds = + def nonParamBounds(param: PolyParam): TypeBounds = entry(param).asInstanceOf[TypeBounds] - def bounds(param: PolyParam)(implicit ctx: Context): TypeBounds = { - val bounds @ TypeBounds(lo, hi) = nonParamBounds(param) - bounds.derivedTypeBounds( - (lo /: minLower(param))(OrType.apply), - (hi /: minUpper(param))(AndType.apply)) - } - - def at(param: PolyParam)(implicit ctx: Context): Type = { - entry(param) match { - case _: TypeBounds => bounds(param) - case e => e + def checkBound(param: PolyParam, bound: Type)(implicit ctx: Context): Type = { + assert(param != bound) + bound match { + case TypeBounds(lo, hi) => + checkBound(param, lo) + checkBound(param, hi) + case bound: TypeVar => + checkBound(param, bound.underlying) + case bound: RefinedType => + checkBound(param, bound.underlying) + case bound: AndOrType => + checkBound(param, bound.tp1) + checkBound(param, bound.tp2) + case _ => } + bound } - + + def fullLowerBound(param: PolyParam)(implicit ctx: Context): Type = { + val lo = checkBound(param, nonParamBounds(param).lo) + checkBound(param, (lo /: minLower(param))(_ | _)) + } + + def fullUpperBound(param: PolyParam)(implicit ctx: Context): Type = + (nonParamBounds(param).hi /: minUpper(param))(_ & _) + + def fullBounds(param: PolyParam)(implicit ctx: Context): TypeBounds = + nonParamBounds(param).derivedTypeBounds(fullLowerBound(param), fullUpperBound(param)) + def typeVarOfParam(param: PolyParam): Type = { val entries = myMap(param.binder) if (entries == null) NoType @@ -177,30 +201,19 @@ class TrackingConstraint(private val myMap: ParamInfo, } } -// ---------- Type splitting -------------------------------------------------- - - /** The set of "dependent" constrained parameters that unconditionally strengthen bound `tp`. - * @param seenFromBelow If true, `bound` is an upper bound, else a lower bound. - */ - private def depParams(tp: Type, seenFromBelow: Boolean): Set[PolyParam] = tp match { - case tp: PolyParam if contains(tp) => - Set(tp) - case tp: AndOrType if seenFromBelow == tp.isAnd => - depParams(tp.tp1, seenFromBelow) | depParams(tp.tp2, seenFromBelow) - case _ => - Set.empty - } +// ---------- Adding PolyTypes -------------------------------------------------- - /** The bound type `tp` without dependent parameters. + /** The bound type `tp` without dependent parameters * NoType if type consists only of dependent parameters. * @param seenFromBelow If true, `bound` is an upper bound, else a lower bound. */ - private def stripParams(tp: Type, seenFromBelow: Boolean)(implicit ctx: Context): Type = tp match { - case tp: PolyParam if contains(tp) => - NoType + private def stripParams(tp: Type, handleParam: (PolyParam, Boolean) => Type, + seenFromBelow: Boolean)(implicit ctx: Context): Type = tp match { + case tp: PolyParam => + handleParam(tp, seenFromBelow) case tp: AndOrType if seenFromBelow == tp.isAnd => - val tp1 = nonParamType(tp.tp1, seenFromBelow) - val tp2 = nonParamType(tp.tp2, seenFromBelow) + val tp1 = stripParams(tp.tp1, handleParam, seenFromBelow) + val tp2 = stripParams(tp.tp2, handleParam, seenFromBelow) if (tp1.exists) if (tp2.exists) tp.derivedAndOrType(tp1, tp2) else tp1 @@ -213,95 +226,94 @@ class TrackingConstraint(private val myMap: ParamInfo, * A top or bottom type if type consists only of dependent parameters. * @param seenFromBelow If true, `bound` is an upper bound, else a lower bound. */ - private def nonParamType(tp: Type, seenFromBelow: Boolean)(implicit ctx: Context): Type = - stripParams(tp, seenFromBelow).orElse(if (seenFromBelow) defn.AnyType else defn.NothingType) + private def nonParamType(tp: Type, handleParam: (PolyParam, Boolean) => Type, + seenFromBelow: Boolean)(implicit ctx: Context): Type = + stripParams(tp, handleParam, seenFromBelow) + .orElse(if (seenFromBelow) defn.AnyType else defn.NothingType) - /** The `tp1 is a TypeBounds type, the bounds without dependent parameters, - * otherwise `tp`. + /** The bounds of `tp1` without dependent parameters. + * @pre `tp` is a TypeBounds type. */ - private def nonParamType(tp: Type)(implicit ctx: Context): Type = tp match { + private def nonParamBounds(tp: Type, handleParam: (PolyParam, Boolean) => Type)(implicit ctx: Context): Type = tp match { case tp @ TypeBounds(lo, hi) => tp.derivedTypeBounds( - nonParamType(lo, seenFromBelow = false), - nonParamType(hi, seenFromBelow = true)) - case _ => - tp + nonParamType(lo, handleParam, seenFromBelow = false), + nonParamType(hi, handleParam, seenFromBelow = true)) } - /** An updated partial order matrix that incorporates `less` and also reflects the new `bounds` - * for parameter `param`. - */ - private def updatedLess(less: Array[BitSet], param: PolyParam, bounds: Type): Array[BitSet] = bounds match { - case TypeBounds(lo, hi) => - updatedLess( - updatedLess(less, param, lo, seenFromBelow = false), - param, hi, seenFromBelow = true) - case _ => - less + def add(poly: PolyType, tvars: List[TypeVar])(implicit ctx: Context): This = { + assert(!contains(poly)) + val nparams = poly.paramNames.length + val entries1 = new Array[Type](nparams * 2) + poly.paramBounds.copyToArray(entries1, 0) + tvars.copyToArray(entries1, nparams) + val is = poly.paramBounds.indices + val newParams = is.map(PolyParam(poly, _)) + val params1 = params ++ newParams + var less1 = less ++ is.map(Function.const(BitSet.empty)) + for (i <- is) { + def handleParam(param: PolyParam, seenFromBelow: Boolean): Type = { + def record(paramIdx: Int): Type = { + less1 = + if (seenFromBelow) updatedLess(less1, nparams + i, paramIdx) + else updatedLess(less1, paramIdx, nparams + i) + NoType + } + if (param.binder eq poly) record(nparams + param.paramNum) + else if (contains(param.binder)) record(paramIndex(param)) + else param + } + entries1(i) = checkBound(newParams(i), nonParamBounds(entries1(i), handleParam)) + } + newConstraint(myMap.updated(poly, entries1), less1, params1) } - /** An updated partial order matrix that incorporates `less` and also reflects that `param` has a new - * `bound`, where `seenFromBelow` is true iff `bound` is an upper bound for `param`. - */ - def updatedLess(less: Array[BitSet], param: PolyParam, bound: Type, seenFromBelow: Boolean): Array[BitSet] = - updatedLess(less, param, depParams(bound, seenFromBelow).iterator, inOrder = seenFromBelow) - - /** An updated partial order matrix that incorporates `less` and also reflects that `param` relates - * to all parameters in `ps2` wrt <:< if `inOrder` is true, `>:>` otherwise. - */ - def updatedLess(less: Array[BitSet], p1: PolyParam, ps2: Iterator[PolyParam], inOrder: Boolean): Array[BitSet] = - if (ps2.hasNext) updatedLess(updatedLess(less, p1, ps2.next, inOrder), p1, ps2, inOrder) - else less +// ---------- Updates ------------------------------------------------------------ /** An updated partial order matrix that incorporates `less` and also reflects that `param` relates * to `p2` wrt <:< if `inOrder` is true, `>:>` otherwise. */ - def updatedLess(less: Array[BitSet], p1: PolyParam, p2: PolyParam, inOrder: Boolean): Array[BitSet] = - if (!inOrder) updatedLess(less, p2, p1, true) - else { - val i1 = paramIndex(p1) - val i2 = paramIndex(p2) + private def updatedLess(less: Array[BitSet], i1: Int, i2: Int): Array[BitSet] = { if (i1 == i2 || less(i1)(i2)) less else { val result = less.clone - result(i1) = result(i1) + i2 | upperBits(i2) - assert(!result(i1)(i1)) - for (j <- lowerBits(i1)) { - result(j) = result(j) + i2 | upperBits(i2) + val newUpper = upperBits(less, i2) + i2 + def update(j: Int) = { + result(j) |= newUpper assert(!result(j)(j)) } + update(i1) + lowerBits(less, i1).foreach(update) result } - } + } -// ---------- Updates ------------------------------------------------------------ - - def order(param: PolyParam, bound: PolyParam, inOrder: Boolean)(implicit ctx: Context): This = { - val less1 = updatedLess(less, param, bound, inOrder) + def addLess(p1: PolyParam, p2: PolyParam)(implicit ctx: Context): This = { + val less1 = updatedLess(less, paramIndex(p1), paramIndex(p2)) if (less1 eq less) this else newConstraint(myMap, less1, params) } - def nonParamUpdated(param: PolyParam, tpe: Type)(implicit ctx: Context): This = { - val entries1 = myMap(param.binder).clone - entries1(param.paramNum) = tpe - newConstraint(myMap.updated(param.binder, entries1), less, params) + def updateEntry(param: PolyParam, tp: Type)(implicit ctx: Context): This = { + val entries = myMap(param.binder) + val entry = entries(param.paramNum) + if (entry eq tp) this + else { + if (!tp.isInstanceOf[TypeBounds]) typr.println(i"inst entry $param to $tp") + val entries1 = entries.clone + entries1(param.paramNum) = checkBound(param, tp) + newConstraint(myMap.updated(param.binder, entries1), less, params) + } } - def updated(param: PolyParam, tpe: Type)(implicit ctx: Context): This = { - val less1 = updatedLess(less, param, tpe) - val entries = myMap(param.binder) - val entry1 = nonParamType(tpe) - val idx = param.paramNum - val entries1 = - if (entry1 eq entries(idx)) entries - else { - val entries1 = entries.clone - entries1(idx) = entry1 - entries1 - } - newConstraint(myMap.updated(param.binder, entries1), less1, params) + def unify(p1: PolyParam, p2: PolyParam)(implicit ctx: Context): This = { + val p1Bounds = + dropParamIn(nonParamBounds(p1), p2.binder, p2.paramNum) & + dropParamIn(nonParamBounds(p2), p1.binder, p1.paramNum) + this.updateEntry(p1, p1Bounds).updateEntry(p2, p1) } - + +// ---------- Removals ------------------------------------------------------------ + /** Drop parameter `PolyParam(poly, n)` from `bounds`, * replacing with Nothing in the lower bound and by `Any` in the upper bound. */ @@ -324,10 +336,6 @@ class TrackingConstraint(private val myMap: ParamInfo, approx(bounds.lo, defn.NothingType), approx(bounds.hi, defn.AnyType)) } - /** A new constraint which is derived from this constraint by removing - * the type parameter `param` from the domain and replacing all occurrences - * of the parameter elsewhere in the constraint by type `tp`. - */ def replace(param: PolyParam, tp: Type)(implicit ctx: Context): TrackingConstraint = { val replacement = tp.dealias.stripTypeVar @@ -340,7 +348,7 @@ class TrackingConstraint(private val myMap: ParamInfo, val newBounds = oldBounds.substParam(param, replacement).asInstanceOf[TypeBounds] if (oldBounds ne newBounds) { if (result eq entries) result = entries.clone - result(i) = dropParamIn(newBounds, poly, i) + result(i) = checkBound(PolyParam(poly, i), dropParamIn(newBounds, poly, i)) } case _ => } @@ -351,39 +359,13 @@ class TrackingConstraint(private val myMap: ParamInfo, if (param == replacement) this else { + assert(replacement.isValueType) val pt = param.binder - val constr1 = if (isRemovable(pt, param.paramNum)) remove(pt) else updated(param, replacement) - val result = new TrackingConstraint(constr1.myMap mapValues subst, constr1.less, constr1.params) - if (Config.checkConstraintsNonCyclic) result.checkNonCyclic() - result + val constr1 = if (isRemovable(pt, param.paramNum)) remove(pt) else updateEntry(param, replacement) + newConstraint(constr1.myMap mapValues subst, constr1.less, constr1.params) } } - def unify(p1: PolyParam, p2: PolyParam)(implicit ctx: Context): This = { - val p1Bounds = - dropParamIn(nonParamBounds(p1), p2.binder, p2.paramNum) & - dropParamIn(nonParamBounds(p2), p1.binder, p1.paramNum) - this.nonParamUpdated(p1, p1Bounds).nonParamUpdated(p2, p1) - } - - def add(poly: PolyType, tvars: List[TypeVar])(implicit ctx: Context): This = { - assert(!contains(poly)) - val nparams = poly.paramNames.length - val entries1 = new Array[Type](nparams * 2) - poly.paramBounds.copyToArray(entries1, 0) - tvars.copyToArray(entries1, nparams) - val is = poly.paramBounds.indices - val newParams = is.map(PolyParam(poly, _)) - val params1 = params ++ newParams - var less1 = less ++ is.map(Function.const(BitSet.empty)) - for (i <- is) { - less1 = updatedLess(less1, newParams(i), entries1(i)) - entries1(i) = nonParamType(entries1(i)) - } - newConstraint(myMap.updated(poly, entries1), less1, params1) - } - - /** A new constraint with all entries coming from `pt` removed. */ def remove(pt: PolyType)(implicit ctx: Context): This = { val start = polyStart(pt) val skipped = pt.paramNames.length @@ -425,7 +407,7 @@ class TrackingConstraint(private val myMap: ParamInfo, def domainPolys: List[PolyType] = polyTypes.toList - def domainParams: List[PolyParam] = params.toList + def domainParams: List[PolyParam] = params.toList.filter(contains) def forallParams(p: PolyParam => Boolean): Boolean = { myMap.foreachBinding { (poly, entries) => @@ -454,7 +436,7 @@ class TrackingConstraint(private val myMap: ParamInfo, myMap.foreachBinding { (poly, entries) => for (i <- 0 until paramCount(entries)) { typeVar(entries, i) match { - case tv: TypeVar if isBounds(entries(i)) => myUninstVars += tv + case tv: TypeVar if !tv.inst.exists && isBounds(entries(i)) => myUninstVars += tv case _ => } } @@ -468,13 +450,8 @@ class TrackingConstraint(private val myMap: ParamInfo, private def checkNonCyclic(idx: Int)(implicit ctx: Context): Unit = assert(!less(idx)(idx), i"cyclic constraint involving ${params(idx)}") - def checkNonCyclic(pt: PolyType, entries: Array[Type])(implicit ctx: Context): Unit = - for (i <- entries.indices) checkNonCyclic(paramIndex(PolyParam(pt, i))) - def checkNonCyclic()(implicit ctx: Context): Unit = for (i <- params.indices) checkNonCyclic(i) - - def checkNonCyclicTrans()(implicit ctx: Context): Unit = checkNonCyclic() // ---------- toText ----------------------------------------------------- |