summaryrefslogtreecommitdiff
path: root/src/compiler/scala/tools/nsc/transform/patmat/Logic.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/compiler/scala/tools/nsc/transform/patmat/Logic.scala')
-rw-r--r--src/compiler/scala/tools/nsc/transform/patmat/Logic.scala41
1 files changed, 38 insertions, 3 deletions
diff --git a/src/compiler/scala/tools/nsc/transform/patmat/Logic.scala b/src/compiler/scala/tools/nsc/transform/patmat/Logic.scala
index 4ea569c8e6..cef22d7d6b 100644
--- a/src/compiler/scala/tools/nsc/transform/patmat/Logic.scala
+++ b/src/compiler/scala/tools/nsc/transform/patmat/Logic.scala
@@ -10,6 +10,7 @@ package tools.nsc.transform.patmat
import scala.language.postfixOps
import scala.collection.mutable
import scala.reflect.internal.util.{NoPosition, Position, Statistics, HashSet}
+import scala.tools.nsc.Global
trait Logic extends Debugging {
import PatternMatchingStats._
@@ -90,6 +91,8 @@ trait Logic extends Debugging {
// compute the domain and return it (call registerNull first!)
def domainSyms: Option[Set[Sym]]
+ def groupedDomains: List[Set[Sym]]
+
// the symbol for this variable being equal to its statically known type
// (only available if registerEquality has been called for that type before)
def symForStaticTp: Option[Sym]
@@ -118,6 +121,9 @@ trait Logic extends Debugging {
final case class Not(a: Prop) extends Prop
+ // mutually exclusive (i.e., not more than one symbol is set)
+ final case class AtMostOne(ops: List[Sym]) extends Prop
+
case object True extends Prop
case object False extends Prop
@@ -192,7 +198,8 @@ trait Logic extends Debugging {
case Not(negated) => negationNormalFormNot(negated)
case True
| False
- | (_: Sym) => p
+ | (_: Sym)
+ | (_: AtMostOne) => p
}
def simplifyProp(p: Prop): Prop = p match {
@@ -252,6 +259,7 @@ trait Logic extends Debugging {
case Not(a) => apply(a)
case Eq(a, b) => applyVar(a); applyConst(b)
case s: Sym => applySymbol(s)
+ case AtMostOne(ops) => ops.foreach(applySymbol)
case _ =>
}
def applyVar(x: Var): Unit = {}
@@ -374,7 +382,23 @@ trait Logic extends Debugging {
// when sym is true, what must hold...
implied foreach (impliedSym => addAxiom(Or(Not(sym), impliedSym)))
// ... and what must not?
- excluded foreach (excludedSym => addAxiom(Or(Not(sym), Not(excludedSym))))
+ excluded foreach {
+ excludedSym =>
+ val related = Set(sym, excludedSym)
+ val exclusive = v.groupedDomains.exists {
+ domain => related subsetOf domain.toSet
+ }
+
+ // TODO: populate `v.exclusiveDomains` with `Set`s from the start, and optimize to:
+ // val exclusive = v.exclusiveDomains.exists { inDomain => inDomain(sym) && inDomain(excludedSym) }
+ if (!exclusive)
+ addAxiom(Or(Not(sym), Not(excludedSym)))
+ }
+ }
+
+ // all symbols in a domain are mutually exclusive
+ v.groupedDomains.foreach {
+ syms => if (syms.size > 1) addAxiom(AtMostOne(syms.toList))
}
}
@@ -449,7 +473,9 @@ trait ScalaLogic extends Interface with Logic with TreeAndTypeAnalysis {
// once we go to run-time checks (on Const's), convert them to checkable types
// TODO: there seems to be bug for singleton domains (variable does not show up in model)
lazy val domain: Option[Set[Const]] = {
- val subConsts = enumerateSubtypes(staticTp).map{ tps =>
+ val subConsts =
+ enumerateSubtypes(staticTp, grouped = false)
+ .headOption.map { tps =>
tps.toSet[Type].map{ tp =>
val domainC = TypeConst(tp)
registerEquality(domainC)
@@ -467,6 +493,15 @@ trait ScalaLogic extends Interface with Logic with TreeAndTypeAnalysis {
observed(); allConsts
}
+ lazy val groupedDomains: List[Set[Sym]] = {
+ val subtypes = enumerateSubtypes(staticTp, grouped = true)
+ subtypes.map {
+ subTypes =>
+ val syms = subTypes.flatMap(tpe => symForEqualsTo.get(TypeConst(tpe))).toSet
+ if (mayBeNull) syms + symForEqualsTo(NullConst) else syms
+ }.filter(_.nonEmpty)
+ }
+
// populate equalitySyms
// don't care about the result, but want only one fresh symbol per distinct constant c
def registerEquality(c: Const): Unit = {ensureCanModify(); symForEqualsTo getOrElseUpdate(c, Sym(this, c))}