From d44a86f432a7f9ca250b014acdeab02ac9f2c304 Mon Sep 17 00:00:00 2001 From: Gerard Basler Date: Thu, 12 Mar 2015 01:47:47 +0100 Subject: Patmat: efficient reasoning about mutual exclusion Faster analysis of wide (but relatively flat) class hierarchies by using a more efficient encoding of mutual exclusion. The old CNF encoding for mutually exclusive symbols of a domain added a quadratic number of clauses to the formula to satisfy. E.g. if a domain has the symbols `a`, `b` and `c` then the clauses ``` !a \/ !b /\ !a \/ !c /\ !b \/ !c ``` were added. The first line prevents that `a` and `b` are both true at the same time, etc. There's a simple, more efficient encoding that can be used instead: consider a comparator circuit in hardware, that checks that out of `n` signals, at most 1 is true. Such a circuit can be built in the form of a sequential counter and thus requires only 3n-4 additional clauses [1]. A comprehensible comparison of different encodings can be found in [2]. [1]: http://www.carstensinz.de/papers/CP-2005.pdf [2]: http://www.wv.inf.tu-dresden.de/Publications/2013/report-13-04.pdf --- .../scala/tools/nsc/transform/patmat/Logic.scala | 41 ++++++++++- .../tools/nsc/transform/patmat/MatchAnalysis.scala | 86 ++++++++++++++-------- .../scala/tools/nsc/transform/patmat/Solving.scala | 79 ++++++++++++++++++-- 3 files changed, 165 insertions(+), 41 deletions(-) (limited to 'src') diff --git a/src/compiler/scala/tools/nsc/transform/patmat/Logic.scala b/src/compiler/scala/tools/nsc/transform/patmat/Logic.scala index 4ea569c8e6..cef22d7d6b 100644 --- a/src/compiler/scala/tools/nsc/transform/patmat/Logic.scala +++ b/src/compiler/scala/tools/nsc/transform/patmat/Logic.scala @@ -10,6 +10,7 @@ package tools.nsc.transform.patmat import scala.language.postfixOps import scala.collection.mutable import scala.reflect.internal.util.{NoPosition, Position, Statistics, HashSet} +import scala.tools.nsc.Global trait Logic extends Debugging { import PatternMatchingStats._ @@ -90,6 +91,8 @@ trait Logic extends Debugging { // compute the domain and return it (call registerNull first!) def domainSyms: Option[Set[Sym]] + def groupedDomains: List[Set[Sym]] + // the symbol for this variable being equal to its statically known type // (only available if registerEquality has been called for that type before) def symForStaticTp: Option[Sym] @@ -118,6 +121,9 @@ trait Logic extends Debugging { final case class Not(a: Prop) extends Prop + // mutually exclusive (i.e., not more than one symbol is set) + final case class AtMostOne(ops: List[Sym]) extends Prop + case object True extends Prop case object False extends Prop @@ -192,7 +198,8 @@ trait Logic extends Debugging { case Not(negated) => negationNormalFormNot(negated) case True | False - | (_: Sym) => p + | (_: Sym) + | (_: AtMostOne) => p } def simplifyProp(p: Prop): Prop = p match { @@ -252,6 +259,7 @@ trait Logic extends Debugging { case Not(a) => apply(a) case Eq(a, b) => applyVar(a); applyConst(b) case s: Sym => applySymbol(s) + case AtMostOne(ops) => ops.foreach(applySymbol) case _ => } def applyVar(x: Var): Unit = {} @@ -374,7 +382,23 @@ trait Logic extends Debugging { // when sym is true, what must hold... implied foreach (impliedSym => addAxiom(Or(Not(sym), impliedSym))) // ... and what must not? - excluded foreach (excludedSym => addAxiom(Or(Not(sym), Not(excludedSym)))) + excluded foreach { + excludedSym => + val related = Set(sym, excludedSym) + val exclusive = v.groupedDomains.exists { + domain => related subsetOf domain.toSet + } + + // TODO: populate `v.exclusiveDomains` with `Set`s from the start, and optimize to: + // val exclusive = v.exclusiveDomains.exists { inDomain => inDomain(sym) && inDomain(excludedSym) } + if (!exclusive) + addAxiom(Or(Not(sym), Not(excludedSym))) + } + } + + // all symbols in a domain are mutually exclusive + v.groupedDomains.foreach { + syms => if (syms.size > 1) addAxiom(AtMostOne(syms.toList)) } } @@ -449,7 +473,9 @@ trait ScalaLogic extends Interface with Logic with TreeAndTypeAnalysis { // once we go to run-time checks (on Const's), convert them to checkable types // TODO: there seems to be bug for singleton domains (variable does not show up in model) lazy val domain: Option[Set[Const]] = { - val subConsts = enumerateSubtypes(staticTp).map{ tps => + val subConsts = + enumerateSubtypes(staticTp, grouped = false) + .headOption.map { tps => tps.toSet[Type].map{ tp => val domainC = TypeConst(tp) registerEquality(domainC) @@ -467,6 +493,15 @@ trait ScalaLogic extends Interface with Logic with TreeAndTypeAnalysis { observed(); allConsts } + lazy val groupedDomains: List[Set[Sym]] = { + val subtypes = enumerateSubtypes(staticTp, grouped = true) + subtypes.map { + subTypes => + val syms = subTypes.flatMap(tpe => symForEqualsTo.get(TypeConst(tpe))).toSet + if (mayBeNull) syms + symForEqualsTo(NullConst) else syms + }.filter(_.nonEmpty) + } + // populate equalitySyms // don't care about the result, but want only one fresh symbol per distinct constant c def registerEquality(c: Const): Unit = {ensureCanModify(); symForEqualsTo getOrElseUpdate(c, Sym(this, c))} diff --git a/src/compiler/scala/tools/nsc/transform/patmat/MatchAnalysis.scala b/src/compiler/scala/tools/nsc/transform/patmat/MatchAnalysis.scala index cecb5c37be..a11906ace1 100644 --- a/src/compiler/scala/tools/nsc/transform/patmat/MatchAnalysis.scala +++ b/src/compiler/scala/tools/nsc/transform/patmat/MatchAnalysis.scala @@ -95,58 +95,84 @@ trait TreeAndTypeAnalysis extends Debugging { val typer: Typer // TODO: domain of other feasibly enumerable built-in types (char?) - def enumerateSubtypes(tp: Type): Option[List[Type]] = + def enumerateSubtypes(tp: Type, grouped: Boolean): List[List[Type]] = tp.typeSymbol match { // TODO case _ if tp.isTupleType => // recurse into component types? - case UnitClass => - Some(List(UnitTpe)) - case BooleanClass => - Some(ConstantTrue :: ConstantFalse :: Nil) + case UnitClass if !grouped => + List(List(UnitTpe)) + case BooleanClass if !grouped => + List(ConstantTrue :: ConstantFalse :: Nil) // TODO case _ if tp.isTupleType => // recurse into component types - case modSym: ModuleClassSymbol => - Some(List(tp)) + case modSym: ModuleClassSymbol if !grouped => + List(List(tp)) case sym: RefinementClassSymbol => - val parentSubtypes: List[Option[List[Type]]] = tp.parents.map(parent => enumerateSubtypes(parent)) - if (parentSubtypes exists (_.isDefined)) + val parentSubtypes = tp.parents.flatMap(parent => enumerateSubtypes(parent, grouped)) + if (parentSubtypes exists (_.nonEmpty)) { // If any of the parents is enumerable, then the refinement type is enumerable. - Some( - // We must only include subtypes of the parents that conform to `tp`. - // See neg/virtpatmat_exhaust_compound.scala for an example. - parentSubtypes flatMap (_.getOrElse(Nil)) filter (_ <:< tp) - ) - else None + // We must only include subtypes of the parents that conform to `tp`. + // See neg/virtpatmat_exhaust_compound.scala for an example. + parentSubtypes map (_.filter(_ <:< tp)) + } + else Nil // make sure it's not a primitive, else (5: Byte) match { case 5 => ... } sees no Byte case sym if sym.isSealed => - val subclasses = debug.patmatResult(s"enum $sym sealed, subclasses")( - // symbols which are both sealed and abstract need not be covered themselves, because - // all of their children must be and they cannot otherwise be created. - sym.sealedDescendants.toList - sortBy (_.sealedSortName) - filterNot (x => x.isSealed && x.isAbstractClass && !isPrimitiveValueClass(x)) - ) val tpApprox = typer.infer.approximateAbstracts(tp) val pre = tpApprox.prefix - Some(debug.patmatResult(s"enum sealed tp=$tp, tpApprox=$tpApprox as") { - // valid subtypes are turned into checkable types, as we are entering the realm of the dynamic - subclasses flatMap { sym => + def filterChildren(children: List[Symbol]): List[Type] = { + children flatMap { sym => // have to filter out children which cannot match: see ticket #3683 for an example // compare to the fully known type `tp` (modulo abstract types), // so that we can rule out stuff like: sealed trait X[T]; class XInt extends X[Int] --> XInt not valid when enumerating X[String] // however, must approximate abstract types in - val memberType = nestedMemberType(sym, pre, tpApprox.typeSymbol.owner) - val subTp = appliedType(memberType, sym.typeParams.map(_ => WildcardType)) + val memberType = nestedMemberType(sym, pre, tpApprox.typeSymbol.owner) + val subTp = appliedType(memberType, sym.typeParams.map(_ => WildcardType)) val subTpApprox = typer.infer.approximateAbstracts(subTp) // TODO: needed? // debug.patmat("subtp"+(subTpApprox <:< tpApprox, subTpApprox, tpApprox)) if (subTpApprox <:< tpApprox) Some(checkableType(subTp)) else None } - }) + } + + if(grouped) { + def enumerateChildren(sym: Symbol) = { + sym.children.toList + .sortBy(_.sealedSortName) + .filterNot(x => x.isSealed && x.isAbstractClass && !isPrimitiveValueClass(x)) + } + + // enumerate only direct subclasses, + // subclasses of subclasses are enumerated in the next iteration + // and added to a new group + def groupChildren(wl: List[Symbol], + acc: List[List[Type]]): List[List[Type]] = wl match { + case hd :: tl => + val children = enumerateChildren(hd) + groupChildren(tl ++ children, acc :+ filterChildren(children)) + case Nil => acc + } + + groupChildren(sym :: Nil, Nil) + } else { + val subclasses = debug.patmatResult(s"enum $sym sealed, subclasses")( + // symbols which are both sealed and abstract need not be covered themselves, because + // all of their children must be and they cannot otherwise be created. + sym.sealedDescendants.toList + sortBy (_.sealedSortName) + filterNot (x => x.isSealed && x.isAbstractClass && !isPrimitiveValueClass(x)) + ) + + List(debug.patmatResult(s"enum sealed tp=$tp, tpApprox=$tpApprox as") { + // valid subtypes are turned into checkable types, as we are entering the realm of the dynamic + filterChildren(subclasses) + }) + } + case sym => debug.patmat("enum unsealed "+ ((tp, sym, sym.isSealed, isPrimitiveValueClass(sym)))) - None + Nil } // approximate a type to the static type that is fully checkable at run time, @@ -176,7 +202,7 @@ trait TreeAndTypeAnalysis extends Debugging { def uncheckableType(tp: Type): Boolean = { val checkable = ( (isTupleType(tp) && tupleComponents(tp).exists(tp => !uncheckableType(tp))) - || enumerateSubtypes(tp).nonEmpty) + || enumerateSubtypes(tp, grouped = false).nonEmpty) // if (!checkable) debug.patmat("deemed uncheckable: "+ tp) !checkable } diff --git a/src/compiler/scala/tools/nsc/transform/patmat/Solving.scala b/src/compiler/scala/tools/nsc/transform/patmat/Solving.scala index c43f1b6209..9710c5c66b 100644 --- a/src/compiler/scala/tools/nsc/transform/patmat/Solving.scala +++ b/src/compiler/scala/tools/nsc/transform/patmat/Solving.scala @@ -65,11 +65,22 @@ trait Solving extends Logic { def size = symbols.size } + def cnfString(f: Array[Clause]): String + final case class Solvable(cnf: Cnf, symbolMapping: SymbolMapping) { def ++(other: Solvable) = { require(this.symbolMapping eq other.symbolMapping) Solvable(cnf ++ other.cnf, symbolMapping) } + + override def toString: String = { + "Solvable\nLiterals:\n" + + (for { + (lit, sym) <- symbolMapping.symForVar.toSeq.sortBy(_._1) + } yield { + s"$lit -> $sym" + }).mkString("\n") + "Cnf:\n" + cnfString(cnf) + } } trait CnfBuilder { @@ -140,20 +151,23 @@ trait Solving extends Logic { def apply(p: Prop): Solvable = { - def convert(p: Prop): Lit = { + def convert(p: Prop): Option[Lit] = { p match { case And(fv) => - and(fv.map(convert)) + Some(and(fv.flatMap(convert))) case Or(fv) => - or(fv.map(convert)) + Some(or(fv.flatMap(convert))) case Not(a) => - not(convert(a)) + convert(a).map(not) case sym: Sym => - convertSym(sym) + Some(convertSym(sym)) case True => - constTrue + Some(constTrue) case False => - constFalse + Some(constFalse) + case AtMostOne(ops) => + atMostOne(ops) + None case _: Eq => throw new MatchError(p) } @@ -199,8 +213,57 @@ trait Solving extends Logic { // no need for auxiliary variable def not(a: Lit): Lit = -a + /** + * This encoding adds 3n-4 variables auxiliary variables + * to encode that at most 1 symbol can be set. + * See also "Towards an Optimal CNF Encoding of Boolean Cardinality Constraints" + * http://www.carstensinz.de/papers/CP-2005.pdf + */ + def atMostOne(ops: List[Sym]) { + (ops: @unchecked) match { + case hd :: Nil => convertSym(hd) + case x1 :: tail => + // sequential counter: 3n-4 clauses + // pairwise encoding: n*(n-1)/2 clauses + // thus pays off only if n > 5 + if (ops.lengthCompare(5) > 0) { + + @inline + def /\(a: Lit, b: Lit) = addClauseProcessed(clause(a, b)) + + val (mid, xn :: Nil) = tail.splitAt(tail.size - 1) + + // 1 <= x1,...,xn <==> + // + // (!x1 \/ s1) /\ (!xn \/ !sn-1) /\ + // + // /\ + // / \ (!xi \/ si) /\ (!si-1 \/ si) /\ (!xi \/ !si-1) + // 1 < i < n + val s1 = newLiteral() + /\(-convertSym(x1), s1) + val snMinus = mid.foldLeft(s1) { + case (siMinus, sym) => + val xi = convertSym(sym) + val si = newLiteral() + /\(-xi, si) + /\(-siMinus, si) + /\(-xi, -siMinus) + si + } + /\(-convertSym(xn), -snMinus) + } else { + ops.map(convertSym).combinations(2).foreach { + case a :: b :: Nil => + addClauseProcessed(clause(-a, -b)) + case _ => + } + } + } + } + // add intermediate variable since we want the formula to be SAT! - addClauseProcessed(clause(convert(p))) + addClauseProcessed(convert(p).toSet) Solvable(buildCnf, symbolMapping) } -- cgit v1.2.3