summaryrefslogtreecommitdiff
path: root/src/compiler/scala/tools/nsc/transform/patmat/Solving.scala
diff options
context:
space:
mode:
authorGerard Basler <gerard.basler@gmail.com>2015-03-12 01:47:47 +0100
committerAdriaan Moors <adriaan.moors@typesafe.com>2015-04-06 14:58:46 -0700
commitd44a86f432a7f9ca250b014acdeab02ac9f2c304 (patch)
tree3014edc291460cfe0dec4b5f7345a0ff1d174313 /src/compiler/scala/tools/nsc/transform/patmat/Solving.scala
parent214d79841970be29bac126eb48f955c8f082e1bc (diff)
downloadscala-d44a86f432a7f9ca250b014acdeab02ac9f2c304.tar.gz
scala-d44a86f432a7f9ca250b014acdeab02ac9f2c304.tar.bz2
scala-d44a86f432a7f9ca250b014acdeab02ac9f2c304.zip
Patmat: efficient reasoning about mutual exclusion
Faster analysis of wide (but relatively flat) class hierarchies by using a more efficient encoding of mutual exclusion. The old CNF encoding for mutually exclusive symbols of a domain added a quadratic number of clauses to the formula to satisfy. E.g. if a domain has the symbols `a`, `b` and `c` then the clauses ``` !a \/ !b /\ !a \/ !c /\ !b \/ !c ``` were added. The first line prevents that `a` and `b` are both true at the same time, etc. There's a simple, more efficient encoding that can be used instead: consider a comparator circuit in hardware, that checks that out of `n` signals, at most 1 is true. Such a circuit can be built in the form of a sequential counter and thus requires only 3n-4 additional clauses [1]. A comprehensible comparison of different encodings can be found in [2]. [1]: http://www.carstensinz.de/papers/CP-2005.pdf [2]: http://www.wv.inf.tu-dresden.de/Publications/2013/report-13-04.pdf
Diffstat (limited to 'src/compiler/scala/tools/nsc/transform/patmat/Solving.scala')
-rw-r--r--src/compiler/scala/tools/nsc/transform/patmat/Solving.scala79
1 files changed, 71 insertions, 8 deletions
diff --git a/src/compiler/scala/tools/nsc/transform/patmat/Solving.scala b/src/compiler/scala/tools/nsc/transform/patmat/Solving.scala
index c43f1b6209..9710c5c66b 100644
--- a/src/compiler/scala/tools/nsc/transform/patmat/Solving.scala
+++ b/src/compiler/scala/tools/nsc/transform/patmat/Solving.scala
@@ -65,11 +65,22 @@ trait Solving extends Logic {
def size = symbols.size
}
+ def cnfString(f: Array[Clause]): String
+
final case class Solvable(cnf: Cnf, symbolMapping: SymbolMapping) {
def ++(other: Solvable) = {
require(this.symbolMapping eq other.symbolMapping)
Solvable(cnf ++ other.cnf, symbolMapping)
}
+
+ override def toString: String = {
+ "Solvable\nLiterals:\n" +
+ (for {
+ (lit, sym) <- symbolMapping.symForVar.toSeq.sortBy(_._1)
+ } yield {
+ s"$lit -> $sym"
+ }).mkString("\n") + "Cnf:\n" + cnfString(cnf)
+ }
}
trait CnfBuilder {
@@ -140,20 +151,23 @@ trait Solving extends Logic {
def apply(p: Prop): Solvable = {
- def convert(p: Prop): Lit = {
+ def convert(p: Prop): Option[Lit] = {
p match {
case And(fv) =>
- and(fv.map(convert))
+ Some(and(fv.flatMap(convert)))
case Or(fv) =>
- or(fv.map(convert))
+ Some(or(fv.flatMap(convert)))
case Not(a) =>
- not(convert(a))
+ convert(a).map(not)
case sym: Sym =>
- convertSym(sym)
+ Some(convertSym(sym))
case True =>
- constTrue
+ Some(constTrue)
case False =>
- constFalse
+ Some(constFalse)
+ case AtMostOne(ops) =>
+ atMostOne(ops)
+ None
case _: Eq =>
throw new MatchError(p)
}
@@ -199,8 +213,57 @@ trait Solving extends Logic {
// no need for auxiliary variable
def not(a: Lit): Lit = -a
+ /**
+ * This encoding adds 3n-4 variables auxiliary variables
+ * to encode that at most 1 symbol can be set.
+ * See also "Towards an Optimal CNF Encoding of Boolean Cardinality Constraints"
+ * http://www.carstensinz.de/papers/CP-2005.pdf
+ */
+ def atMostOne(ops: List[Sym]) {
+ (ops: @unchecked) match {
+ case hd :: Nil => convertSym(hd)
+ case x1 :: tail =>
+ // sequential counter: 3n-4 clauses
+ // pairwise encoding: n*(n-1)/2 clauses
+ // thus pays off only if n > 5
+ if (ops.lengthCompare(5) > 0) {
+
+ @inline
+ def /\(a: Lit, b: Lit) = addClauseProcessed(clause(a, b))
+
+ val (mid, xn :: Nil) = tail.splitAt(tail.size - 1)
+
+ // 1 <= x1,...,xn <==>
+ //
+ // (!x1 \/ s1) /\ (!xn \/ !sn-1) /\
+ //
+ // /\
+ // / \ (!xi \/ si) /\ (!si-1 \/ si) /\ (!xi \/ !si-1)
+ // 1 < i < n
+ val s1 = newLiteral()
+ /\(-convertSym(x1), s1)
+ val snMinus = mid.foldLeft(s1) {
+ case (siMinus, sym) =>
+ val xi = convertSym(sym)
+ val si = newLiteral()
+ /\(-xi, si)
+ /\(-siMinus, si)
+ /\(-xi, -siMinus)
+ si
+ }
+ /\(-convertSym(xn), -snMinus)
+ } else {
+ ops.map(convertSym).combinations(2).foreach {
+ case a :: b :: Nil =>
+ addClauseProcessed(clause(-a, -b))
+ case _ =>
+ }
+ }
+ }
+ }
+
// add intermediate variable since we want the formula to be SAT!
- addClauseProcessed(clause(convert(p)))
+ addClauseProcessed(convert(p).toSet)
Solvable(buildCnf, symbolMapping)
}