summaryrefslogtreecommitdiff
path: root/src/compiler/scala/tools/nsc/transform/patmat
diff options
context:
space:
mode:
authorGerard Basler <gerard.basler@gmail.com>2015-03-12 01:47:47 +0100
committerAdriaan Moors <adriaan.moors@typesafe.com>2015-04-06 14:58:46 -0700
commitd44a86f432a7f9ca250b014acdeab02ac9f2c304 (patch)
tree3014edc291460cfe0dec4b5f7345a0ff1d174313 /src/compiler/scala/tools/nsc/transform/patmat
parent214d79841970be29bac126eb48f955c8f082e1bc (diff)
downloadscala-d44a86f432a7f9ca250b014acdeab02ac9f2c304.tar.gz
scala-d44a86f432a7f9ca250b014acdeab02ac9f2c304.tar.bz2
scala-d44a86f432a7f9ca250b014acdeab02ac9f2c304.zip
Patmat: efficient reasoning about mutual exclusion
Faster analysis of wide (but relatively flat) class hierarchies by using a more efficient encoding of mutual exclusion. The old CNF encoding for mutually exclusive symbols of a domain added a quadratic number of clauses to the formula to satisfy. E.g. if a domain has the symbols `a`, `b` and `c` then the clauses ``` !a \/ !b /\ !a \/ !c /\ !b \/ !c ``` were added. The first line prevents that `a` and `b` are both true at the same time, etc. There's a simple, more efficient encoding that can be used instead: consider a comparator circuit in hardware, that checks that out of `n` signals, at most 1 is true. Such a circuit can be built in the form of a sequential counter and thus requires only 3n-4 additional clauses [1]. A comprehensible comparison of different encodings can be found in [2]. [1]: http://www.carstensinz.de/papers/CP-2005.pdf [2]: http://www.wv.inf.tu-dresden.de/Publications/2013/report-13-04.pdf
Diffstat (limited to 'src/compiler/scala/tools/nsc/transform/patmat')
-rw-r--r--src/compiler/scala/tools/nsc/transform/patmat/Logic.scala41
-rw-r--r--src/compiler/scala/tools/nsc/transform/patmat/MatchAnalysis.scala86
-rw-r--r--src/compiler/scala/tools/nsc/transform/patmat/Solving.scala79
3 files changed, 165 insertions, 41 deletions
diff --git a/src/compiler/scala/tools/nsc/transform/patmat/Logic.scala b/src/compiler/scala/tools/nsc/transform/patmat/Logic.scala
index 4ea569c8e6..cef22d7d6b 100644
--- a/src/compiler/scala/tools/nsc/transform/patmat/Logic.scala
+++ b/src/compiler/scala/tools/nsc/transform/patmat/Logic.scala
@@ -10,6 +10,7 @@ package tools.nsc.transform.patmat
import scala.language.postfixOps
import scala.collection.mutable
import scala.reflect.internal.util.{NoPosition, Position, Statistics, HashSet}
+import scala.tools.nsc.Global
trait Logic extends Debugging {
import PatternMatchingStats._
@@ -90,6 +91,8 @@ trait Logic extends Debugging {
// compute the domain and return it (call registerNull first!)
def domainSyms: Option[Set[Sym]]
+ def groupedDomains: List[Set[Sym]]
+
// the symbol for this variable being equal to its statically known type
// (only available if registerEquality has been called for that type before)
def symForStaticTp: Option[Sym]
@@ -118,6 +121,9 @@ trait Logic extends Debugging {
final case class Not(a: Prop) extends Prop
+ // mutually exclusive (i.e., not more than one symbol is set)
+ final case class AtMostOne(ops: List[Sym]) extends Prop
+
case object True extends Prop
case object False extends Prop
@@ -192,7 +198,8 @@ trait Logic extends Debugging {
case Not(negated) => negationNormalFormNot(negated)
case True
| False
- | (_: Sym) => p
+ | (_: Sym)
+ | (_: AtMostOne) => p
}
def simplifyProp(p: Prop): Prop = p match {
@@ -252,6 +259,7 @@ trait Logic extends Debugging {
case Not(a) => apply(a)
case Eq(a, b) => applyVar(a); applyConst(b)
case s: Sym => applySymbol(s)
+ case AtMostOne(ops) => ops.foreach(applySymbol)
case _ =>
}
def applyVar(x: Var): Unit = {}
@@ -374,7 +382,23 @@ trait Logic extends Debugging {
// when sym is true, what must hold...
implied foreach (impliedSym => addAxiom(Or(Not(sym), impliedSym)))
// ... and what must not?
- excluded foreach (excludedSym => addAxiom(Or(Not(sym), Not(excludedSym))))
+ excluded foreach {
+ excludedSym =>
+ val related = Set(sym, excludedSym)
+ val exclusive = v.groupedDomains.exists {
+ domain => related subsetOf domain.toSet
+ }
+
+ // TODO: populate `v.exclusiveDomains` with `Set`s from the start, and optimize to:
+ // val exclusive = v.exclusiveDomains.exists { inDomain => inDomain(sym) && inDomain(excludedSym) }
+ if (!exclusive)
+ addAxiom(Or(Not(sym), Not(excludedSym)))
+ }
+ }
+
+ // all symbols in a domain are mutually exclusive
+ v.groupedDomains.foreach {
+ syms => if (syms.size > 1) addAxiom(AtMostOne(syms.toList))
}
}
@@ -449,7 +473,9 @@ trait ScalaLogic extends Interface with Logic with TreeAndTypeAnalysis {
// once we go to run-time checks (on Const's), convert them to checkable types
// TODO: there seems to be bug for singleton domains (variable does not show up in model)
lazy val domain: Option[Set[Const]] = {
- val subConsts = enumerateSubtypes(staticTp).map{ tps =>
+ val subConsts =
+ enumerateSubtypes(staticTp, grouped = false)
+ .headOption.map { tps =>
tps.toSet[Type].map{ tp =>
val domainC = TypeConst(tp)
registerEquality(domainC)
@@ -467,6 +493,15 @@ trait ScalaLogic extends Interface with Logic with TreeAndTypeAnalysis {
observed(); allConsts
}
+ lazy val groupedDomains: List[Set[Sym]] = {
+ val subtypes = enumerateSubtypes(staticTp, grouped = true)
+ subtypes.map {
+ subTypes =>
+ val syms = subTypes.flatMap(tpe => symForEqualsTo.get(TypeConst(tpe))).toSet
+ if (mayBeNull) syms + symForEqualsTo(NullConst) else syms
+ }.filter(_.nonEmpty)
+ }
+
// populate equalitySyms
// don't care about the result, but want only one fresh symbol per distinct constant c
def registerEquality(c: Const): Unit = {ensureCanModify(); symForEqualsTo getOrElseUpdate(c, Sym(this, c))}
diff --git a/src/compiler/scala/tools/nsc/transform/patmat/MatchAnalysis.scala b/src/compiler/scala/tools/nsc/transform/patmat/MatchAnalysis.scala
index cecb5c37be..a11906ace1 100644
--- a/src/compiler/scala/tools/nsc/transform/patmat/MatchAnalysis.scala
+++ b/src/compiler/scala/tools/nsc/transform/patmat/MatchAnalysis.scala
@@ -95,58 +95,84 @@ trait TreeAndTypeAnalysis extends Debugging {
val typer: Typer
// TODO: domain of other feasibly enumerable built-in types (char?)
- def enumerateSubtypes(tp: Type): Option[List[Type]] =
+ def enumerateSubtypes(tp: Type, grouped: Boolean): List[List[Type]] =
tp.typeSymbol match {
// TODO case _ if tp.isTupleType => // recurse into component types?
- case UnitClass =>
- Some(List(UnitTpe))
- case BooleanClass =>
- Some(ConstantTrue :: ConstantFalse :: Nil)
+ case UnitClass if !grouped =>
+ List(List(UnitTpe))
+ case BooleanClass if !grouped =>
+ List(ConstantTrue :: ConstantFalse :: Nil)
// TODO case _ if tp.isTupleType => // recurse into component types
- case modSym: ModuleClassSymbol =>
- Some(List(tp))
+ case modSym: ModuleClassSymbol if !grouped =>
+ List(List(tp))
case sym: RefinementClassSymbol =>
- val parentSubtypes: List[Option[List[Type]]] = tp.parents.map(parent => enumerateSubtypes(parent))
- if (parentSubtypes exists (_.isDefined))
+ val parentSubtypes = tp.parents.flatMap(parent => enumerateSubtypes(parent, grouped))
+ if (parentSubtypes exists (_.nonEmpty)) {
// If any of the parents is enumerable, then the refinement type is enumerable.
- Some(
- // We must only include subtypes of the parents that conform to `tp`.
- // See neg/virtpatmat_exhaust_compound.scala for an example.
- parentSubtypes flatMap (_.getOrElse(Nil)) filter (_ <:< tp)
- )
- else None
+ // We must only include subtypes of the parents that conform to `tp`.
+ // See neg/virtpatmat_exhaust_compound.scala for an example.
+ parentSubtypes map (_.filter(_ <:< tp))
+ }
+ else Nil
// make sure it's not a primitive, else (5: Byte) match { case 5 => ... } sees no Byte
case sym if sym.isSealed =>
- val subclasses = debug.patmatResult(s"enum $sym sealed, subclasses")(
- // symbols which are both sealed and abstract need not be covered themselves, because
- // all of their children must be and they cannot otherwise be created.
- sym.sealedDescendants.toList
- sortBy (_.sealedSortName)
- filterNot (x => x.isSealed && x.isAbstractClass && !isPrimitiveValueClass(x))
- )
val tpApprox = typer.infer.approximateAbstracts(tp)
val pre = tpApprox.prefix
- Some(debug.patmatResult(s"enum sealed tp=$tp, tpApprox=$tpApprox as") {
- // valid subtypes are turned into checkable types, as we are entering the realm of the dynamic
- subclasses flatMap { sym =>
+ def filterChildren(children: List[Symbol]): List[Type] = {
+ children flatMap { sym =>
// have to filter out children which cannot match: see ticket #3683 for an example
// compare to the fully known type `tp` (modulo abstract types),
// so that we can rule out stuff like: sealed trait X[T]; class XInt extends X[Int] --> XInt not valid when enumerating X[String]
// however, must approximate abstract types in
- val memberType = nestedMemberType(sym, pre, tpApprox.typeSymbol.owner)
- val subTp = appliedType(memberType, sym.typeParams.map(_ => WildcardType))
+ val memberType = nestedMemberType(sym, pre, tpApprox.typeSymbol.owner)
+ val subTp = appliedType(memberType, sym.typeParams.map(_ => WildcardType))
val subTpApprox = typer.infer.approximateAbstracts(subTp) // TODO: needed?
// debug.patmat("subtp"+(subTpApprox <:< tpApprox, subTpApprox, tpApprox))
if (subTpApprox <:< tpApprox) Some(checkableType(subTp))
else None
}
- })
+ }
+
+ if(grouped) {
+ def enumerateChildren(sym: Symbol) = {
+ sym.children.toList
+ .sortBy(_.sealedSortName)
+ .filterNot(x => x.isSealed && x.isAbstractClass && !isPrimitiveValueClass(x))
+ }
+
+ // enumerate only direct subclasses,
+ // subclasses of subclasses are enumerated in the next iteration
+ // and added to a new group
+ def groupChildren(wl: List[Symbol],
+ acc: List[List[Type]]): List[List[Type]] = wl match {
+ case hd :: tl =>
+ val children = enumerateChildren(hd)
+ groupChildren(tl ++ children, acc :+ filterChildren(children))
+ case Nil => acc
+ }
+
+ groupChildren(sym :: Nil, Nil)
+ } else {
+ val subclasses = debug.patmatResult(s"enum $sym sealed, subclasses")(
+ // symbols which are both sealed and abstract need not be covered themselves, because
+ // all of their children must be and they cannot otherwise be created.
+ sym.sealedDescendants.toList
+ sortBy (_.sealedSortName)
+ filterNot (x => x.isSealed && x.isAbstractClass && !isPrimitiveValueClass(x))
+ )
+
+ List(debug.patmatResult(s"enum sealed tp=$tp, tpApprox=$tpApprox as") {
+ // valid subtypes are turned into checkable types, as we are entering the realm of the dynamic
+ filterChildren(subclasses)
+ })
+ }
+
case sym =>
debug.patmat("enum unsealed "+ ((tp, sym, sym.isSealed, isPrimitiveValueClass(sym))))
- None
+ Nil
}
// approximate a type to the static type that is fully checkable at run time,
@@ -176,7 +202,7 @@ trait TreeAndTypeAnalysis extends Debugging {
def uncheckableType(tp: Type): Boolean = {
val checkable = (
(isTupleType(tp) && tupleComponents(tp).exists(tp => !uncheckableType(tp)))
- || enumerateSubtypes(tp).nonEmpty)
+ || enumerateSubtypes(tp, grouped = false).nonEmpty)
// if (!checkable) debug.patmat("deemed uncheckable: "+ tp)
!checkable
}
diff --git a/src/compiler/scala/tools/nsc/transform/patmat/Solving.scala b/src/compiler/scala/tools/nsc/transform/patmat/Solving.scala
index c43f1b6209..9710c5c66b 100644
--- a/src/compiler/scala/tools/nsc/transform/patmat/Solving.scala
+++ b/src/compiler/scala/tools/nsc/transform/patmat/Solving.scala
@@ -65,11 +65,22 @@ trait Solving extends Logic {
def size = symbols.size
}
+ def cnfString(f: Array[Clause]): String
+
final case class Solvable(cnf: Cnf, symbolMapping: SymbolMapping) {
def ++(other: Solvable) = {
require(this.symbolMapping eq other.symbolMapping)
Solvable(cnf ++ other.cnf, symbolMapping)
}
+
+ override def toString: String = {
+ "Solvable\nLiterals:\n" +
+ (for {
+ (lit, sym) <- symbolMapping.symForVar.toSeq.sortBy(_._1)
+ } yield {
+ s"$lit -> $sym"
+ }).mkString("\n") + "Cnf:\n" + cnfString(cnf)
+ }
}
trait CnfBuilder {
@@ -140,20 +151,23 @@ trait Solving extends Logic {
def apply(p: Prop): Solvable = {
- def convert(p: Prop): Lit = {
+ def convert(p: Prop): Option[Lit] = {
p match {
case And(fv) =>
- and(fv.map(convert))
+ Some(and(fv.flatMap(convert)))
case Or(fv) =>
- or(fv.map(convert))
+ Some(or(fv.flatMap(convert)))
case Not(a) =>
- not(convert(a))
+ convert(a).map(not)
case sym: Sym =>
- convertSym(sym)
+ Some(convertSym(sym))
case True =>
- constTrue
+ Some(constTrue)
case False =>
- constFalse
+ Some(constFalse)
+ case AtMostOne(ops) =>
+ atMostOne(ops)
+ None
case _: Eq =>
throw new MatchError(p)
}
@@ -199,8 +213,57 @@ trait Solving extends Logic {
// no need for auxiliary variable
def not(a: Lit): Lit = -a
+ /**
+ * This encoding adds 3n-4 variables auxiliary variables
+ * to encode that at most 1 symbol can be set.
+ * See also "Towards an Optimal CNF Encoding of Boolean Cardinality Constraints"
+ * http://www.carstensinz.de/papers/CP-2005.pdf
+ */
+ def atMostOne(ops: List[Sym]) {
+ (ops: @unchecked) match {
+ case hd :: Nil => convertSym(hd)
+ case x1 :: tail =>
+ // sequential counter: 3n-4 clauses
+ // pairwise encoding: n*(n-1)/2 clauses
+ // thus pays off only if n > 5
+ if (ops.lengthCompare(5) > 0) {
+
+ @inline
+ def /\(a: Lit, b: Lit) = addClauseProcessed(clause(a, b))
+
+ val (mid, xn :: Nil) = tail.splitAt(tail.size - 1)
+
+ // 1 <= x1,...,xn <==>
+ //
+ // (!x1 \/ s1) /\ (!xn \/ !sn-1) /\
+ //
+ // /\
+ // / \ (!xi \/ si) /\ (!si-1 \/ si) /\ (!xi \/ !si-1)
+ // 1 < i < n
+ val s1 = newLiteral()
+ /\(-convertSym(x1), s1)
+ val snMinus = mid.foldLeft(s1) {
+ case (siMinus, sym) =>
+ val xi = convertSym(sym)
+ val si = newLiteral()
+ /\(-xi, si)
+ /\(-siMinus, si)
+ /\(-xi, -siMinus)
+ si
+ }
+ /\(-convertSym(xn), -snMinus)
+ } else {
+ ops.map(convertSym).combinations(2).foreach {
+ case a :: b :: Nil =>
+ addClauseProcessed(clause(-a, -b))
+ case _ =>
+ }
+ }
+ }
+ }
+
// add intermediate variable since we want the formula to be SAT!
- addClauseProcessed(clause(convert(p)))
+ addClauseProcessed(convert(p).toSet)
Solvable(buildCnf, symbolMapping)
}