summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorLi Haoyi <haoyi.sg@gmail.com>2017-10-27 08:15:01 -0700
committerLi Haoyi <haoyi.sg@gmail.com>2017-10-27 08:15:01 -0700
commitf53db8482c86f30c917d16b6312ad4804b37f2df (patch)
tree8758d12808473a6f77e3465243ee848d55272e75 /src
parentafd5de935cb11394848c5040909f2f02fc26335f (diff)
downloadmill-f53db8482c86f30c917d16b6312ad4804b37f2df.tar.gz
mill-f53db8482c86f30c917d16b6312ad4804b37f2df.tar.bz2
mill-f53db8482c86f30c917d16b6312ad4804b37f2df.zip
Migrate everything which shouldn't have duplicates over to a new `OSet` data structure
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/forge/Evaluator.scala134
-rw-r--r--src/main/scala/forge/Util.scala78
-rw-r--r--src/test/scala/forge/EvaluationTests.scala46
-rw-r--r--src/test/scala/forge/GraphTests.scala78
-rw-r--r--src/test/scala/forge/TestUtil.scala4
5 files changed, 229 insertions, 111 deletions
diff --git a/src/main/scala/forge/Evaluator.scala b/src/main/scala/forge/Evaluator.scala
index 2a3e4470..ecbf4260 100644
--- a/src/main/scala/forge/Evaluator.scala
+++ b/src/main/scala/forge/Evaluator.scala
@@ -10,44 +10,81 @@ import scala.collection.mutable
class Evaluator(workspacePath: jnio.Path,
enclosingBase: DefCtx){
- val resultCache = mutable.Map.empty[String, (Int, String)]
- def evaluate(targets: Seq[Target[_]]): Evaluator.Results = {
+ /**
+ * Cache from the ID of the first terminal target in a group to the has of
+ * all the group's distinct inputs, and the results of the possibly-multiple
+ * terminal nodes
+ */
+ val resultCache = mutable.Map.empty[String, (Int, Seq[String])]
+ def evaluate(targets: OSet[Target[_]]): Evaluator.Results = {
jnio.Files.createDirectories(workspacePath)
- val sortedTargets = Evaluator.topoSortedTransitiveTargets(targets)
- pprint.log(sortedTargets.values)
- val evaluated = mutable.Buffer.empty[Target[_]]
+ val sortedGroups = Evaluator.groupAroundNamedTargets(
+ Evaluator.topoSortedTransitiveTargets(targets)
+ )
+
+ val evaluated = new MutableOSet[Target[_]]
val results = mutable.Map.empty[Target[_], Any]
- for (target <- sortedTargets.values){
- val inputResults = target.inputs.map(results).toIndexedSeq
-
- val enclosingStr = target.defCtx.label
- val targetDestPath = workspacePath.resolve(
- jnio.Paths.get(enclosingStr.stripSuffix(enclosingBase.label))
- )
- deleteRec(targetDestPath)
-
- val inputsHash = inputResults.hashCode
- (target.dirty, resultCache.get(target.defCtx.label)) match{
- case (Some(dirtyCheck), Some((hash, res)))
- if hash == inputsHash && !dirtyCheck() =>
- results(target) = target.formatter.reads(Json.parse(res)).get
-
- case _ =>
- evaluated.append(target)
+ for (group <- sortedGroups){
+ val (newResults, newEvaluated) = evaluateGroup(group, results)
+ evaluated.appendAll(newEvaluated)
+ for((k, v) <- newResults) results.put(k, v)
+
+ }
+ Evaluator.Results(targets.items.map(results), evaluated)
+ }
+
+ def evaluateGroup(group: OSet[Target[_]],
+ results: collection.Map[Target[_], Any]) = {
+ val allInputs = group.items.flatMap(_.inputs)
+ val (internalInputs, externalInputs) = allInputs.partition(group.contains)
+ val internalInputSet = internalInputs.toSet
+ val inputResults = externalInputs.distinct.map(results).toIndexedSeq
+
+ val newResults = mutable.Map.empty[Target[_], Any]
+ val newEvaluated = mutable.Buffer.empty[Target[_]]
+
+ val terminals = group.filter(!internalInputSet(_))
+ val primeTerminal = terminals.items(0)
+ val enclosingStr = primeTerminal.defCtx.label
+ val targetDestPath = workspacePath.resolve(
+ jnio.Paths.get(enclosingStr.stripSuffix(enclosingBase.label))
+ )
+ deleteRec(targetDestPath)
+
+ val inputsHash = inputResults.hashCode
+ (primeTerminal.dirty, resultCache.get(primeTerminal.defCtx.label)) match{
+ case (Some(dirtyCheck), Some((hash, terminalResults)))
+ if hash == inputsHash && !dirtyCheck() =>
+ for((terminal, res) <- terminals.items.zip(terminalResults)){
+
+ newResults(terminal) = primeTerminal.formatter.reads(Json.parse(res)).get
+ }
+
+ case _ =>
+ val terminalResults = mutable.Buffer.empty[String]
+ for(target <- group.items){
+
+ newEvaluated.append(target)
if (target.defCtx.anonId.isDefined && target.dirty.isEmpty) {
val res = target.evaluate(new Args(inputResults, targetDestPath))
- results(target) = res
+ newResults(target) = res
}else{
- val (res, serialized) = target.evaluateAndWrite(new Args(inputResults, targetDestPath))
- resultCache(target.defCtx.label) = (inputsHash, serialized)
- results(target) = res
+ val (res, serialized) = target.evaluateAndWrite(
+ new Args(inputResults, targetDestPath)
+ )
+ if (!internalInputSet(target)){
+ terminalResults.append(serialized)
+
+ }
+ newResults(target) = res
}
-
- }
+ }
+ resultCache(primeTerminal.defCtx.label) = (inputsHash, terminalResults)
}
- Evaluator.Results(targets.map(results), evaluated)
+
+ (newResults, newEvaluated)
}
def deleteRec(path: jnio.Path) = {
if (jnio.Files.exists(path)){
@@ -63,14 +100,14 @@ class Evaluator(workspacePath: jnio.Path,
object Evaluator{
- class TopoSorted private[Evaluator] (val values: Seq[Target[_]])
- case class Results(values: Seq[Any], evaluated: Seq[Target[_]])
- def groupAroundNamedTargets(topoSortedTargets: TopoSorted): Seq[Seq[Target[_]]] = {
+ class TopoSorted private[Evaluator] (val values: OSet[Target[_]])
+ case class Results(values: Seq[Any], evaluated: OSet[Target[_]])
+ def groupAroundNamedTargets(topoSortedTargets: TopoSorted): OSet[OSet[Target[_]]] = {
val grouping = new MultiBiMap[Int, Target[_]]()
var groupCount = 0
- for(target <- topoSortedTargets.values.reverseIterator){
+ for(target <- topoSortedTargets.values.items.reverseIterator){
if (!grouping.containsValue(target)){
grouping.add(groupCount, target)
@@ -90,40 +127,43 @@ object Evaluator{
}
}
}
- val output = mutable.Buffer.empty[Seq[Target[_]]]
- for(target <- topoSortedTargets.values.reverseIterator){
+ val output = mutable.Buffer.empty[OSet[Target[_]]]
+ for(target <- topoSortedTargets.values.items.reverseIterator){
for(targetGroup <- grouping.lookupValueOpt(target)){
- output.append(grouping.removeAll(targetGroup))
+ output.append(
+ OSet.from(
+ grouping.removeAll(targetGroup)
+ .sortBy(topoSortedTargets.values.items.indexOf)
+ )
+ )
}
}
- output.map(_.sortBy(topoSortedTargets.values.indexOf)).reverse
+ OSet.from(output.reverse)
}
/**
* Takes the given targets, finds
*/
- def topoSortedTransitiveTargets(sourceTargets: Seq[Target[_]]): TopoSorted = {
- val transitiveTargetSet = mutable.Set.empty[Target[_]]
- val transitiveTargets = mutable.Buffer.empty[Target[_]]
+ def topoSortedTransitiveTargets(sourceTargets: OSet[Target[_]]): TopoSorted = {
+ val transitiveTargets = new MutableOSet[Target[_]]
def rec(t: Target[_]): Unit = {
- if (transitiveTargetSet.contains(t)) () // do nothing
+ if (transitiveTargets.contains(t)) () // do nothing
else {
- transitiveTargetSet.add(t)
transitiveTargets.append(t)
t.inputs.foreach(rec)
}
}
- sourceTargets.foreach(rec)
- val targetIndices = transitiveTargets.zipWithIndex.toMap
+ sourceTargets.items.foreach(rec)
+ val targetIndices = transitiveTargets.items.zipWithIndex.toMap
val numberedEdges =
- for(i <- transitiveTargets.indices)
- yield transitiveTargets(i).inputs.map(targetIndices)
+ for(t <- transitiveTargets.items)
+ yield t.inputs.map(targetIndices)
val sortedClusters = Tarjans(numberedEdges)
val nonTrivialClusters = sortedClusters.filter(_.length > 1)
assert(nonTrivialClusters.isEmpty, nonTrivialClusters)
- new TopoSorted(sortedClusters.flatten.map(transitiveTargets))
+ new TopoSorted(OSet.from(sortedClusters.flatten.map(transitiveTargets.items)))
}
} \ No newline at end of file
diff --git a/src/main/scala/forge/Util.scala b/src/main/scala/forge/Util.scala
index 15d3e176..e558e975 100644
--- a/src/main/scala/forge/Util.scala
+++ b/src/main/scala/forge/Util.scala
@@ -32,6 +32,84 @@ class MultiBiMap[K, V](){
keyToValues(k) = vs ++: keyToValues.getOrElse(k, Nil)
}
}
+
+/**
+ * A collection with enforced uniqueness, fast contains and deterministic
+ * ordering. When a duplicate happens, it can be configured to either remove
+ * it automatically or to throw an exception and fail loudly
+ */
+trait OSet[V] extends TraversableOnce[V]{
+ def contains(v: V): Boolean
+ def items: IndexedSeq[V]
+ def flatMap[T](f: V => TraversableOnce[T]): OSet[T]
+ def map[T](f: V => T): OSet[T]
+ def filter(f: V => Boolean): OSet[V]
+
+
+}
+object OSet{
+ def apply[V](items: V*) = from(items)
+ def dedup[V](items: V*) = from(items, dedup = true)
+
+ def from[V](items: TraversableOnce[V], dedup: Boolean = false): OSet[V] = {
+ val set = new MutableOSet[V](dedup)
+ items.foreach(set.append)
+ set
+ }
+}
+class MutableOSet[V](dedup: Boolean = false) extends OSet[V]{
+ private[this] val items0 = mutable.ArrayBuffer.empty[V]
+ private[this] val set0 = mutable.Set.empty[V]
+ def contains(v: V) = set0.contains(v)
+ def append(v: V) = if (!contains(v)){
+ set0.add(v)
+ items0.append(v)
+ }else if (!dedup) {
+ throw new Exception("Duplicated item inserted into OrderedSet: " + v)
+ }
+ def appendAll(vs: Seq[V]) = vs.foreach(append)
+ def items: IndexedSeq[V] = items0
+ def set: collection.Set[V] = set0
+
+ def map[T](f: V => T): OSet[T] = {
+ val output = new MutableOSet[T]
+ for(i <- items) output.append(f(i))
+ output
+ }
+ def flatMap[T](f: V => TraversableOnce[T]): OSet[T] = {
+ val output = new MutableOSet[T]
+ for(i <- items) for(i0 <- f(i)) output.append(i0)
+ output
+ }
+ def filter(f: V => Boolean): OSet[V] = {
+ val output = new MutableOSet[V]
+ for(i <- items) if (f(i)) output.append(i)
+ output
+ }
+
+ // Members declared in scala.collection.GenTraversableOnce
+ def isTraversableAgain: Boolean = items.isTraversableAgain
+ def toIterator: Iterator[V] = items.toIterator
+ def toStream: Stream[V] = items.toStream
+
+ // Members declared in scala.collection.TraversableOnce
+ def copyToArray[B >: V](xs: Array[B],start: Int,len: Int): Unit = items.copyToArray(xs, start, len)
+ def exists(p: V => Boolean): Boolean = items.exists(p)
+ def find(p: V => Boolean): Option[V] = items.find(p)
+ def forall(p: V => Boolean): Boolean = items.forall(p)
+ def foreach[U](f: V => U): Unit = items.foreach(f)
+ def hasDefiniteSize: Boolean = items.hasDefiniteSize
+ def isEmpty: Boolean = items.isEmpty
+ def seq: scala.collection.TraversableOnce[V] = items
+ def toTraversable: Traversable[V] = items
+
+ override def hashCode() = items.hashCode()
+ override def equals(other: Any) = other match{
+ case s: OSet[_] => items.equals(s.items)
+ case _ => super.equals(other)
+ }
+ override def toString = items.mkString("OSet(", ", ", ")")
+}
object Util{
def compileAll(sources: Target[Seq[jnio.Path]])
(implicit defCtx: DefCtx) = {
diff --git a/src/test/scala/forge/EvaluationTests.scala b/src/test/scala/forge/EvaluationTests.scala
index 34ebd684..2ab6d31e 100644
--- a/src/test/scala/forge/EvaluationTests.scala
+++ b/src/test/scala/forge/EvaluationTests.scala
@@ -14,89 +14,89 @@ object EvaluationTests extends TestSuite{
'evaluateSingle - {
val evaluator = new Evaluator(jnio.Paths.get("target/workspace"), baseCtx)
- def check(target: Target[_], expValue: Any, expEvaled: Seq[Target[_]]) = {
- val Evaluator.Results(returnedValues, returnedEvaluated) = evaluator.evaluate(Seq(target))
+ def check(target: Target[_], expValue: Any, expEvaled: OSet[Target[_]]) = {
+ val Evaluator.Results(returnedValues, returnedEvaluated) = evaluator.evaluate(OSet(target))
assert(
returnedValues == Seq(expValue),
returnedEvaluated == expEvaled
)
// Second time the value is already cached, so no evaluation needed
- val Evaluator.Results(returnedValues2, returnedEvaluated2) = evaluator.evaluate(Seq(target))
+ val Evaluator.Results(returnedValues2, returnedEvaluated2) = evaluator.evaluate(OSet(target))
assert(
returnedValues2 == returnedValues,
- returnedEvaluated2 == Nil
+ returnedEvaluated2 == OSet()
)
}
'singleton - {
import singleton._
// First time the target is evaluated
- check(single, expValue = 0, expEvaled = Seq(single))
+ check(single, expValue = 0, expEvaled = OSet(single))
single.counter += 1
// After incrementing the counter, it forces re-evaluation
- check(single, expValue = 1, expEvaled = Seq(single))
+ check(single, expValue = 1, expEvaled = OSet(single))
}
'pair - {
import pair._
- check(down, expValue = 0, expEvaled = Seq(up, down))
+ check(down, expValue = 0, expEvaled = OSet(up, down))
down.counter += 1
- check(down, expValue = 1, expEvaled = Seq(down))
+ check(down, expValue = 1, expEvaled = OSet(down))
up.counter += 1
- check(down, expValue = 2, expEvaled = Seq(up, down))
+ check(down, expValue = 2, expEvaled = OSet(up, down))
}
'anonTriple - {
import anonTriple._
val middle = down.inputs(0)
- check(down, expValue = 0, expEvaled = Seq(up, middle, down))
+ check(down, expValue = 0, expEvaled = OSet(up, middle, down))
down.counter += 1
- check(down, expValue = 1, expEvaled = Seq(down))
+ check(down, expValue = 1, expEvaled = OSet(down))
up.counter += 1
- check(down, expValue = 2, expEvaled = Seq(up, middle, down))
+ check(down, expValue = 2, expEvaled = OSet(up, middle, down))
middle.asInstanceOf[Target.Test].counter += 1
- check(down, expValue = 3, expEvaled = Seq(middle, down))
+ check(down, expValue = 3, expEvaled = OSet(middle, down))
}
'diamond - {
import diamond._
- check(down, expValue = 0, expEvaled = Seq(up, left, right, down))
+ check(down, expValue = 0, expEvaled = OSet(up, left, right, down))
down.counter += 1
- check(down, expValue = 1, expEvaled = Seq(down))
+ check(down, expValue = 1, expEvaled = OSet(down))
up.counter += 1
// Increment by 2 because up is referenced twice: once by left once by right
- check(down, expValue = 3, expEvaled = Seq(up, left, right, down))
+ check(down, expValue = 3, expEvaled = OSet(up, left, right, down))
left.counter += 1
- check(down, expValue = 4, expEvaled = Seq(left, down))
+ check(down, expValue = 4, expEvaled = OSet(left, down))
right.counter += 1
- check(down, expValue = 5, expEvaled = Seq(right, down))
+ check(down, expValue = 5, expEvaled = OSet(right, down))
}
'anonDiamond - {
import anonDiamond._
val left = down.inputs(0).asInstanceOf[Target.Test]
val right = down.inputs(1).asInstanceOf[Target.Test]
- check(down, expValue = 0, expEvaled = Seq(up, left, right, down))
+ check(down, expValue = 0, expEvaled = OSet(up, left, right, down))
down.counter += 1
- check(down, expValue = 1, expEvaled = Seq(down))
+ check(down, expValue = 1, expEvaled = OSet(down))
up.counter += 1
// Increment by 2 because up is referenced twice: once by left once by right
- check(down, expValue = 3, expEvaled = Seq(up, left, right, down))
+ check(down, expValue = 3, expEvaled = OSet(up, left, right, down))
left.counter += 1
- check(down, expValue = 4, expEvaled = Seq(left, down))
+ check(down, expValue = 4, expEvaled = OSet(left, down))
right.counter += 1
- check(down, expValue = 5, expEvaled = Seq(right, down))
+ check(down, expValue = 5, expEvaled = OSet(right, down))
}
// 'anonImpureDiamond - {
// import AnonImpureDiamond._
diff --git a/src/test/scala/forge/GraphTests.scala b/src/test/scala/forge/GraphTests.scala
index 85b63131..648ad873 100644
--- a/src/test/scala/forge/GraphTests.scala
+++ b/src/test/scala/forge/GraphTests.scala
@@ -75,31 +75,31 @@ object GraphTests extends TestSuite{
'topoSortedTransitiveTargets - {
- def check(targets: Seq[Target[_]], expected: Set[Target[_]]) = {
+ def check(targets: OSet[Target[_]], expected: OSet[Target[_]]) = {
val result = Evaluator.topoSortedTransitiveTargets(targets).values
TestUtil.checkTopological(result)
- assert(result.toSet == expected)
+ assert(result == expected)
}
'singleton - check(
- targets = Seq(singleton.single),
- expected = Set(singleton.single)
+ targets = OSet(singleton.single),
+ expected = OSet(singleton.single)
)
'pair - check(
- targets = Seq(pair.down),
- expected = Set(pair.up, pair.down)
+ targets = OSet(pair.down),
+ expected = OSet(pair.up, pair.down)
)
'anonTriple - check(
- targets = Seq(anonTriple.down),
- expected = Set(anonTriple.up, anonTriple.down.inputs(0), anonTriple.down)
+ targets = OSet(anonTriple.down),
+ expected = OSet(anonTriple.up, anonTriple.down.inputs(0), anonTriple.down)
)
'diamond - check(
- targets = Seq(diamond.down),
- expected = Set(diamond.up, diamond.left, diamond.right, diamond.down)
+ targets = OSet(diamond.down),
+ expected = OSet(diamond.up, diamond.left, diamond.right, diamond.down)
)
'anonDiamond - check(
- targets = Seq(diamond.down),
- expected = Set(
+ targets = OSet(diamond.down),
+ expected = OSet(
diamond.up,
diamond.down.inputs(0),
diamond.down.inputs(1),
@@ -107,59 +107,59 @@ object GraphTests extends TestSuite{
)
)
'bigSingleTerminal - {
- val result = Evaluator.topoSortedTransitiveTargets(Seq(bigSingleTerminal.j)).values
+ val result = Evaluator.topoSortedTransitiveTargets(OSet(bigSingleTerminal.j)).values
TestUtil.checkTopological(result)
- assert(result.distinct.length == 28)
+ assert(result.size == 28)
}
}
'groupAroundNamedTargets - {
- def check(target: Target[_], expected: Seq[Set[String]]) = {
+ def check(target: Target[_], expected: OSet[OSet[String]]) = {
val grouped = Evaluator.groupAroundNamedTargets(
- Evaluator.topoSortedTransitiveTargets(Seq(target))
+ Evaluator.topoSortedTransitiveTargets(OSet(target))
)
- TestUtil.checkTopological(grouped.flatten)
- val stringified = grouped.map(_.map(_.toString).toSet)
- assert(stringified == expected.map(_.toSet))
+ TestUtil.checkTopological(grouped.flatMap(_.items))
+ val stringified = grouped.map(_.map(_.toString))
+ assert(stringified == expected)
}
'singleton - check(
singleton.single,
- Seq(Set("single"))
+ OSet(OSet("single"))
)
'pair - check(
pair.down,
- Seq(Set("up"), Set("down"))
+ OSet(OSet("up"), OSet("down"))
)
'anonTriple - check(
anonTriple.down,
- Seq(Set("up"), Set("down1", "down"))
+ OSet(OSet("up"), OSet("down1", "down"))
)
'diamond - check(
diamond.down,
- Seq(Set("up"), Set("left"), Set("right"), Set("down"))
+ OSet(OSet("up"), OSet("left"), OSet("right"), OSet("down"))
)
'anonDiamond - check(
anonDiamond.down,
- Seq(
- Set("up"),
- Set("down2", "down1", "down")
+ OSet(
+ OSet("up"),
+ OSet("down2", "down1", "down")
)
)
'bigSingleTerminal - check(
bigSingleTerminal.j,
- Seq(
- Set("i1"),
- Set("e4"),
- Set("a1"),
- Set("a2"),
- Set("a"),
- Set("b1"),
- Set("b"),
- Set("e5", "e2", "e8", "e1", "e7", "e6", "e3", "e"),
- Set("i2", "i5", "i4", "i3", "i"),
- Set("f2"),
- Set("f3", "f1", "f"),
- Set("j3", "j2", "j1", "j")
+ OSet(
+ OSet("i1"),
+ OSet("e4"),
+ OSet("a1"),
+ OSet("a2"),
+ OSet("a"),
+ OSet("b1"),
+ OSet("b"),
+ OSet("e5", "e2", "e8", "e1", "e7", "e6", "e3", "e"),
+ OSet("i2", "i5", "i4", "i3", "i"),
+ OSet("f2"),
+ OSet("f3", "f1", "f"),
+ OSet("j3", "j2", "j1", "j")
)
)
}
diff --git a/src/test/scala/forge/TestUtil.scala b/src/test/scala/forge/TestUtil.scala
index 5a328d87..d0dcd755 100644
--- a/src/test/scala/forge/TestUtil.scala
+++ b/src/test/scala/forge/TestUtil.scala
@@ -7,9 +7,9 @@ import scala.collection.mutable
object TestUtil {
- def checkTopological(targets: Seq[Target[_]]) = {
+ def checkTopological(targets: OSet[Target[_]]) = {
val seen = mutable.Set.empty[Target[_]]
- for(t <- targets.reverseIterator){
+ for(t <- targets.items.reverseIterator){
seen.add(t)
for(upstream <- t.inputs){
assert(!seen(upstream))