From f53db8482c86f30c917d16b6312ad4804b37f2df Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Fri, 27 Oct 2017 08:15:01 -0700 Subject: Migrate everything which shouldn't have duplicates over to a new `OSet` data structure --- src/main/scala/forge/Evaluator.scala | 134 +++++++++++++++++++---------- src/main/scala/forge/Util.scala | 78 +++++++++++++++++ src/test/scala/forge/EvaluationTests.scala | 46 +++++----- src/test/scala/forge/GraphTests.scala | 78 ++++++++--------- src/test/scala/forge/TestUtil.scala | 4 +- 5 files changed, 229 insertions(+), 111 deletions(-) (limited to 'src') diff --git a/src/main/scala/forge/Evaluator.scala b/src/main/scala/forge/Evaluator.scala index 2a3e4470..ecbf4260 100644 --- a/src/main/scala/forge/Evaluator.scala +++ b/src/main/scala/forge/Evaluator.scala @@ -10,44 +10,81 @@ import scala.collection.mutable class Evaluator(workspacePath: jnio.Path, enclosingBase: DefCtx){ - val resultCache = mutable.Map.empty[String, (Int, String)] - def evaluate(targets: Seq[Target[_]]): Evaluator.Results = { + /** + * Cache from the ID of the first terminal target in a group to the has of + * all the group's distinct inputs, and the results of the possibly-multiple + * terminal nodes + */ + val resultCache = mutable.Map.empty[String, (Int, Seq[String])] + def evaluate(targets: OSet[Target[_]]): Evaluator.Results = { jnio.Files.createDirectories(workspacePath) - val sortedTargets = Evaluator.topoSortedTransitiveTargets(targets) - pprint.log(sortedTargets.values) - val evaluated = mutable.Buffer.empty[Target[_]] + val sortedGroups = Evaluator.groupAroundNamedTargets( + Evaluator.topoSortedTransitiveTargets(targets) + ) + + val evaluated = new MutableOSet[Target[_]] val results = mutable.Map.empty[Target[_], Any] - for (target <- sortedTargets.values){ - val inputResults = target.inputs.map(results).toIndexedSeq - - val enclosingStr = target.defCtx.label - val targetDestPath = workspacePath.resolve( - jnio.Paths.get(enclosingStr.stripSuffix(enclosingBase.label)) - ) - deleteRec(targetDestPath) - - val inputsHash = inputResults.hashCode - (target.dirty, resultCache.get(target.defCtx.label)) match{ - case (Some(dirtyCheck), Some((hash, res))) - if hash == inputsHash && !dirtyCheck() => - results(target) = target.formatter.reads(Json.parse(res)).get - - case _ => - evaluated.append(target) + for (group <- sortedGroups){ + val (newResults, newEvaluated) = evaluateGroup(group, results) + evaluated.appendAll(newEvaluated) + for((k, v) <- newResults) results.put(k, v) + + } + Evaluator.Results(targets.items.map(results), evaluated) + } + + def evaluateGroup(group: OSet[Target[_]], + results: collection.Map[Target[_], Any]) = { + val allInputs = group.items.flatMap(_.inputs) + val (internalInputs, externalInputs) = allInputs.partition(group.contains) + val internalInputSet = internalInputs.toSet + val inputResults = externalInputs.distinct.map(results).toIndexedSeq + + val newResults = mutable.Map.empty[Target[_], Any] + val newEvaluated = mutable.Buffer.empty[Target[_]] + + val terminals = group.filter(!internalInputSet(_)) + val primeTerminal = terminals.items(0) + val enclosingStr = primeTerminal.defCtx.label + val targetDestPath = workspacePath.resolve( + jnio.Paths.get(enclosingStr.stripSuffix(enclosingBase.label)) + ) + deleteRec(targetDestPath) + + val inputsHash = inputResults.hashCode + (primeTerminal.dirty, resultCache.get(primeTerminal.defCtx.label)) match{ + case (Some(dirtyCheck), Some((hash, terminalResults))) + if hash == inputsHash && !dirtyCheck() => + for((terminal, res) <- terminals.items.zip(terminalResults)){ + + newResults(terminal) = primeTerminal.formatter.reads(Json.parse(res)).get + } + + case _ => + val terminalResults = mutable.Buffer.empty[String] + for(target <- group.items){ + + newEvaluated.append(target) if (target.defCtx.anonId.isDefined && target.dirty.isEmpty) { val res = target.evaluate(new Args(inputResults, targetDestPath)) - results(target) = res + newResults(target) = res }else{ - val (res, serialized) = target.evaluateAndWrite(new Args(inputResults, targetDestPath)) - resultCache(target.defCtx.label) = (inputsHash, serialized) - results(target) = res + val (res, serialized) = target.evaluateAndWrite( + new Args(inputResults, targetDestPath) + ) + if (!internalInputSet(target)){ + terminalResults.append(serialized) + + } + newResults(target) = res } - - } + } + resultCache(primeTerminal.defCtx.label) = (inputsHash, terminalResults) } - Evaluator.Results(targets.map(results), evaluated) + + (newResults, newEvaluated) } def deleteRec(path: jnio.Path) = { if (jnio.Files.exists(path)){ @@ -63,14 +100,14 @@ class Evaluator(workspacePath: jnio.Path, object Evaluator{ - class TopoSorted private[Evaluator] (val values: Seq[Target[_]]) - case class Results(values: Seq[Any], evaluated: Seq[Target[_]]) - def groupAroundNamedTargets(topoSortedTargets: TopoSorted): Seq[Seq[Target[_]]] = { + class TopoSorted private[Evaluator] (val values: OSet[Target[_]]) + case class Results(values: Seq[Any], evaluated: OSet[Target[_]]) + def groupAroundNamedTargets(topoSortedTargets: TopoSorted): OSet[OSet[Target[_]]] = { val grouping = new MultiBiMap[Int, Target[_]]() var groupCount = 0 - for(target <- topoSortedTargets.values.reverseIterator){ + for(target <- topoSortedTargets.values.items.reverseIterator){ if (!grouping.containsValue(target)){ grouping.add(groupCount, target) @@ -90,40 +127,43 @@ object Evaluator{ } } } - val output = mutable.Buffer.empty[Seq[Target[_]]] - for(target <- topoSortedTargets.values.reverseIterator){ + val output = mutable.Buffer.empty[OSet[Target[_]]] + for(target <- topoSortedTargets.values.items.reverseIterator){ for(targetGroup <- grouping.lookupValueOpt(target)){ - output.append(grouping.removeAll(targetGroup)) + output.append( + OSet.from( + grouping.removeAll(targetGroup) + .sortBy(topoSortedTargets.values.items.indexOf) + ) + ) } } - output.map(_.sortBy(topoSortedTargets.values.indexOf)).reverse + OSet.from(output.reverse) } /** * Takes the given targets, finds */ - def topoSortedTransitiveTargets(sourceTargets: Seq[Target[_]]): TopoSorted = { - val transitiveTargetSet = mutable.Set.empty[Target[_]] - val transitiveTargets = mutable.Buffer.empty[Target[_]] + def topoSortedTransitiveTargets(sourceTargets: OSet[Target[_]]): TopoSorted = { + val transitiveTargets = new MutableOSet[Target[_]] def rec(t: Target[_]): Unit = { - if (transitiveTargetSet.contains(t)) () // do nothing + if (transitiveTargets.contains(t)) () // do nothing else { - transitiveTargetSet.add(t) transitiveTargets.append(t) t.inputs.foreach(rec) } } - sourceTargets.foreach(rec) - val targetIndices = transitiveTargets.zipWithIndex.toMap + sourceTargets.items.foreach(rec) + val targetIndices = transitiveTargets.items.zipWithIndex.toMap val numberedEdges = - for(i <- transitiveTargets.indices) - yield transitiveTargets(i).inputs.map(targetIndices) + for(t <- transitiveTargets.items) + yield t.inputs.map(targetIndices) val sortedClusters = Tarjans(numberedEdges) val nonTrivialClusters = sortedClusters.filter(_.length > 1) assert(nonTrivialClusters.isEmpty, nonTrivialClusters) - new TopoSorted(sortedClusters.flatten.map(transitiveTargets)) + new TopoSorted(OSet.from(sortedClusters.flatten.map(transitiveTargets.items))) } } \ No newline at end of file diff --git a/src/main/scala/forge/Util.scala b/src/main/scala/forge/Util.scala index 15d3e176..e558e975 100644 --- a/src/main/scala/forge/Util.scala +++ b/src/main/scala/forge/Util.scala @@ -32,6 +32,84 @@ class MultiBiMap[K, V](){ keyToValues(k) = vs ++: keyToValues.getOrElse(k, Nil) } } + +/** + * A collection with enforced uniqueness, fast contains and deterministic + * ordering. When a duplicate happens, it can be configured to either remove + * it automatically or to throw an exception and fail loudly + */ +trait OSet[V] extends TraversableOnce[V]{ + def contains(v: V): Boolean + def items: IndexedSeq[V] + def flatMap[T](f: V => TraversableOnce[T]): OSet[T] + def map[T](f: V => T): OSet[T] + def filter(f: V => Boolean): OSet[V] + + +} +object OSet{ + def apply[V](items: V*) = from(items) + def dedup[V](items: V*) = from(items, dedup = true) + + def from[V](items: TraversableOnce[V], dedup: Boolean = false): OSet[V] = { + val set = new MutableOSet[V](dedup) + items.foreach(set.append) + set + } +} +class MutableOSet[V](dedup: Boolean = false) extends OSet[V]{ + private[this] val items0 = mutable.ArrayBuffer.empty[V] + private[this] val set0 = mutable.Set.empty[V] + def contains(v: V) = set0.contains(v) + def append(v: V) = if (!contains(v)){ + set0.add(v) + items0.append(v) + }else if (!dedup) { + throw new Exception("Duplicated item inserted into OrderedSet: " + v) + } + def appendAll(vs: Seq[V]) = vs.foreach(append) + def items: IndexedSeq[V] = items0 + def set: collection.Set[V] = set0 + + def map[T](f: V => T): OSet[T] = { + val output = new MutableOSet[T] + for(i <- items) output.append(f(i)) + output + } + def flatMap[T](f: V => TraversableOnce[T]): OSet[T] = { + val output = new MutableOSet[T] + for(i <- items) for(i0 <- f(i)) output.append(i0) + output + } + def filter(f: V => Boolean): OSet[V] = { + val output = new MutableOSet[V] + for(i <- items) if (f(i)) output.append(i) + output + } + + // Members declared in scala.collection.GenTraversableOnce + def isTraversableAgain: Boolean = items.isTraversableAgain + def toIterator: Iterator[V] = items.toIterator + def toStream: Stream[V] = items.toStream + + // Members declared in scala.collection.TraversableOnce + def copyToArray[B >: V](xs: Array[B],start: Int,len: Int): Unit = items.copyToArray(xs, start, len) + def exists(p: V => Boolean): Boolean = items.exists(p) + def find(p: V => Boolean): Option[V] = items.find(p) + def forall(p: V => Boolean): Boolean = items.forall(p) + def foreach[U](f: V => U): Unit = items.foreach(f) + def hasDefiniteSize: Boolean = items.hasDefiniteSize + def isEmpty: Boolean = items.isEmpty + def seq: scala.collection.TraversableOnce[V] = items + def toTraversable: Traversable[V] = items + + override def hashCode() = items.hashCode() + override def equals(other: Any) = other match{ + case s: OSet[_] => items.equals(s.items) + case _ => super.equals(other) + } + override def toString = items.mkString("OSet(", ", ", ")") +} object Util{ def compileAll(sources: Target[Seq[jnio.Path]]) (implicit defCtx: DefCtx) = { diff --git a/src/test/scala/forge/EvaluationTests.scala b/src/test/scala/forge/EvaluationTests.scala index 34ebd684..2ab6d31e 100644 --- a/src/test/scala/forge/EvaluationTests.scala +++ b/src/test/scala/forge/EvaluationTests.scala @@ -14,89 +14,89 @@ object EvaluationTests extends TestSuite{ 'evaluateSingle - { val evaluator = new Evaluator(jnio.Paths.get("target/workspace"), baseCtx) - def check(target: Target[_], expValue: Any, expEvaled: Seq[Target[_]]) = { - val Evaluator.Results(returnedValues, returnedEvaluated) = evaluator.evaluate(Seq(target)) + def check(target: Target[_], expValue: Any, expEvaled: OSet[Target[_]]) = { + val Evaluator.Results(returnedValues, returnedEvaluated) = evaluator.evaluate(OSet(target)) assert( returnedValues == Seq(expValue), returnedEvaluated == expEvaled ) // Second time the value is already cached, so no evaluation needed - val Evaluator.Results(returnedValues2, returnedEvaluated2) = evaluator.evaluate(Seq(target)) + val Evaluator.Results(returnedValues2, returnedEvaluated2) = evaluator.evaluate(OSet(target)) assert( returnedValues2 == returnedValues, - returnedEvaluated2 == Nil + returnedEvaluated2 == OSet() ) } 'singleton - { import singleton._ // First time the target is evaluated - check(single, expValue = 0, expEvaled = Seq(single)) + check(single, expValue = 0, expEvaled = OSet(single)) single.counter += 1 // After incrementing the counter, it forces re-evaluation - check(single, expValue = 1, expEvaled = Seq(single)) + check(single, expValue = 1, expEvaled = OSet(single)) } 'pair - { import pair._ - check(down, expValue = 0, expEvaled = Seq(up, down)) + check(down, expValue = 0, expEvaled = OSet(up, down)) down.counter += 1 - check(down, expValue = 1, expEvaled = Seq(down)) + check(down, expValue = 1, expEvaled = OSet(down)) up.counter += 1 - check(down, expValue = 2, expEvaled = Seq(up, down)) + check(down, expValue = 2, expEvaled = OSet(up, down)) } 'anonTriple - { import anonTriple._ val middle = down.inputs(0) - check(down, expValue = 0, expEvaled = Seq(up, middle, down)) + check(down, expValue = 0, expEvaled = OSet(up, middle, down)) down.counter += 1 - check(down, expValue = 1, expEvaled = Seq(down)) + check(down, expValue = 1, expEvaled = OSet(down)) up.counter += 1 - check(down, expValue = 2, expEvaled = Seq(up, middle, down)) + check(down, expValue = 2, expEvaled = OSet(up, middle, down)) middle.asInstanceOf[Target.Test].counter += 1 - check(down, expValue = 3, expEvaled = Seq(middle, down)) + check(down, expValue = 3, expEvaled = OSet(middle, down)) } 'diamond - { import diamond._ - check(down, expValue = 0, expEvaled = Seq(up, left, right, down)) + check(down, expValue = 0, expEvaled = OSet(up, left, right, down)) down.counter += 1 - check(down, expValue = 1, expEvaled = Seq(down)) + check(down, expValue = 1, expEvaled = OSet(down)) up.counter += 1 // Increment by 2 because up is referenced twice: once by left once by right - check(down, expValue = 3, expEvaled = Seq(up, left, right, down)) + check(down, expValue = 3, expEvaled = OSet(up, left, right, down)) left.counter += 1 - check(down, expValue = 4, expEvaled = Seq(left, down)) + check(down, expValue = 4, expEvaled = OSet(left, down)) right.counter += 1 - check(down, expValue = 5, expEvaled = Seq(right, down)) + check(down, expValue = 5, expEvaled = OSet(right, down)) } 'anonDiamond - { import anonDiamond._ val left = down.inputs(0).asInstanceOf[Target.Test] val right = down.inputs(1).asInstanceOf[Target.Test] - check(down, expValue = 0, expEvaled = Seq(up, left, right, down)) + check(down, expValue = 0, expEvaled = OSet(up, left, right, down)) down.counter += 1 - check(down, expValue = 1, expEvaled = Seq(down)) + check(down, expValue = 1, expEvaled = OSet(down)) up.counter += 1 // Increment by 2 because up is referenced twice: once by left once by right - check(down, expValue = 3, expEvaled = Seq(up, left, right, down)) + check(down, expValue = 3, expEvaled = OSet(up, left, right, down)) left.counter += 1 - check(down, expValue = 4, expEvaled = Seq(left, down)) + check(down, expValue = 4, expEvaled = OSet(left, down)) right.counter += 1 - check(down, expValue = 5, expEvaled = Seq(right, down)) + check(down, expValue = 5, expEvaled = OSet(right, down)) } // 'anonImpureDiamond - { // import AnonImpureDiamond._ diff --git a/src/test/scala/forge/GraphTests.scala b/src/test/scala/forge/GraphTests.scala index 85b63131..648ad873 100644 --- a/src/test/scala/forge/GraphTests.scala +++ b/src/test/scala/forge/GraphTests.scala @@ -75,31 +75,31 @@ object GraphTests extends TestSuite{ 'topoSortedTransitiveTargets - { - def check(targets: Seq[Target[_]], expected: Set[Target[_]]) = { + def check(targets: OSet[Target[_]], expected: OSet[Target[_]]) = { val result = Evaluator.topoSortedTransitiveTargets(targets).values TestUtil.checkTopological(result) - assert(result.toSet == expected) + assert(result == expected) } 'singleton - check( - targets = Seq(singleton.single), - expected = Set(singleton.single) + targets = OSet(singleton.single), + expected = OSet(singleton.single) ) 'pair - check( - targets = Seq(pair.down), - expected = Set(pair.up, pair.down) + targets = OSet(pair.down), + expected = OSet(pair.up, pair.down) ) 'anonTriple - check( - targets = Seq(anonTriple.down), - expected = Set(anonTriple.up, anonTriple.down.inputs(0), anonTriple.down) + targets = OSet(anonTriple.down), + expected = OSet(anonTriple.up, anonTriple.down.inputs(0), anonTriple.down) ) 'diamond - check( - targets = Seq(diamond.down), - expected = Set(diamond.up, diamond.left, diamond.right, diamond.down) + targets = OSet(diamond.down), + expected = OSet(diamond.up, diamond.left, diamond.right, diamond.down) ) 'anonDiamond - check( - targets = Seq(diamond.down), - expected = Set( + targets = OSet(diamond.down), + expected = OSet( diamond.up, diamond.down.inputs(0), diamond.down.inputs(1), @@ -107,59 +107,59 @@ object GraphTests extends TestSuite{ ) ) 'bigSingleTerminal - { - val result = Evaluator.topoSortedTransitiveTargets(Seq(bigSingleTerminal.j)).values + val result = Evaluator.topoSortedTransitiveTargets(OSet(bigSingleTerminal.j)).values TestUtil.checkTopological(result) - assert(result.distinct.length == 28) + assert(result.size == 28) } } 'groupAroundNamedTargets - { - def check(target: Target[_], expected: Seq[Set[String]]) = { + def check(target: Target[_], expected: OSet[OSet[String]]) = { val grouped = Evaluator.groupAroundNamedTargets( - Evaluator.topoSortedTransitiveTargets(Seq(target)) + Evaluator.topoSortedTransitiveTargets(OSet(target)) ) - TestUtil.checkTopological(grouped.flatten) - val stringified = grouped.map(_.map(_.toString).toSet) - assert(stringified == expected.map(_.toSet)) + TestUtil.checkTopological(grouped.flatMap(_.items)) + val stringified = grouped.map(_.map(_.toString)) + assert(stringified == expected) } 'singleton - check( singleton.single, - Seq(Set("single")) + OSet(OSet("single")) ) 'pair - check( pair.down, - Seq(Set("up"), Set("down")) + OSet(OSet("up"), OSet("down")) ) 'anonTriple - check( anonTriple.down, - Seq(Set("up"), Set("down1", "down")) + OSet(OSet("up"), OSet("down1", "down")) ) 'diamond - check( diamond.down, - Seq(Set("up"), Set("left"), Set("right"), Set("down")) + OSet(OSet("up"), OSet("left"), OSet("right"), OSet("down")) ) 'anonDiamond - check( anonDiamond.down, - Seq( - Set("up"), - Set("down2", "down1", "down") + OSet( + OSet("up"), + OSet("down2", "down1", "down") ) ) 'bigSingleTerminal - check( bigSingleTerminal.j, - Seq( - Set("i1"), - Set("e4"), - Set("a1"), - Set("a2"), - Set("a"), - Set("b1"), - Set("b"), - Set("e5", "e2", "e8", "e1", "e7", "e6", "e3", "e"), - Set("i2", "i5", "i4", "i3", "i"), - Set("f2"), - Set("f3", "f1", "f"), - Set("j3", "j2", "j1", "j") + OSet( + OSet("i1"), + OSet("e4"), + OSet("a1"), + OSet("a2"), + OSet("a"), + OSet("b1"), + OSet("b"), + OSet("e5", "e2", "e8", "e1", "e7", "e6", "e3", "e"), + OSet("i2", "i5", "i4", "i3", "i"), + OSet("f2"), + OSet("f3", "f1", "f"), + OSet("j3", "j2", "j1", "j") ) ) } diff --git a/src/test/scala/forge/TestUtil.scala b/src/test/scala/forge/TestUtil.scala index 5a328d87..d0dcd755 100644 --- a/src/test/scala/forge/TestUtil.scala +++ b/src/test/scala/forge/TestUtil.scala @@ -7,9 +7,9 @@ import scala.collection.mutable object TestUtil { - def checkTopological(targets: Seq[Target[_]]) = { + def checkTopological(targets: OSet[Target[_]]) = { val seen = mutable.Set.empty[Target[_]] - for(t <- targets.reverseIterator){ + for(t <- targets.items.reverseIterator){ seen.add(t) for(upstream <- t.inputs){ assert(!seen(upstream)) -- cgit v1.2.3