From d8c23bbf9063404c334bf2abc9ad102729126ead Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Thu, 26 Oct 2017 20:50:08 -0700 Subject: Fleshed out basic `groupAroundNamedTargets` logic --- src/main/scala/forge/Evaluator.scala | 40 ++++++++++++++++++++++++++++++++--- src/main/scala/forge/Target.scala | 5 +---- src/main/scala/forge/Util.scala | 23 ++++++++++++++++++++ src/test/scala/forge/ForgeTests.scala | 39 +++++++++++++++++++++++++++++++--- 4 files changed, 97 insertions(+), 10 deletions(-) (limited to 'src') diff --git a/src/main/scala/forge/Evaluator.scala b/src/main/scala/forge/Evaluator.scala index 1bff722b..1220f56c 100644 --- a/src/main/scala/forge/Evaluator.scala +++ b/src/main/scala/forge/Evaluator.scala @@ -17,7 +17,7 @@ class Evaluator(workspacePath: jnio.Path, val sortedTargets = Evaluator.topoSortedTransitiveTargets(targets) val evaluated = mutable.Buffer.empty[Target[_]] val results = mutable.Map.empty[Target[_], Any] - for (target <- sortedTargets){ + for (target <- sortedTargets.values){ val inputResults = target.inputs.map(results).toIndexedSeq val enclosingStr = target.defCtx.label @@ -62,11 +62,45 @@ class Evaluator(workspacePath: jnio.Path, object Evaluator{ + class TopoSorted private[Evaluator] (val values: Seq[Target[_]]) case class Results(values: Seq[Any], evaluated: Seq[Target[_]]) + def groupAroundNamedTargets(topoSortedTargets: TopoSorted): Seq[Seq[Target[_]]] = { + val grouping = new MultiBiMap[Int, Target[_]]() + + var groupCount = 0 + + for(target <- topoSortedTargets.values.reverseIterator){ + + if (!grouping.containsValue(target)){ + grouping.add(groupCount, target) + groupCount += 1 + } + + val targetGroup = grouping.lookupValue(target) + for(upstream <- target.inputs){ + grouping.lookupValueOpt(upstream) match{ + case None if upstream.dirty.isEmpty && upstream.defCtx.anonId.nonEmpty => + grouping.add(targetGroup, upstream) + case Some(upstreamGroup) if upstreamGroup == targetGroup => + val upstreamTargets = grouping.removeAll(upstreamGroup) + grouping.addAll(targetGroup, upstreamTargets) + case _ => //donothing + } + } + } + val output = mutable.Buffer.empty[Seq[Target[_]]] + for(target <- topoSortedTargets.values){ + for(targetGroup <- grouping.lookupValueOpt(target)){ + output.append(grouping.removeAll(targetGroup)) + } + } + output + } + /** * Takes the given targets, finds */ - def topoSortedTransitiveTargets(sourceTargets: Seq[Target[_]]) = { + def topoSortedTransitiveTargets(sourceTargets: Seq[Target[_]]): TopoSorted = { val transitiveTargetSet = mutable.Set.empty[Target[_]] val transitiveTargets = mutable.Buffer.empty[Target[_]] def rec(t: Target[_]): Unit = { @@ -88,6 +122,6 @@ object Evaluator{ val sortedClusters = Tarjans(numberedEdges) val nonTrivialClusters = sortedClusters.filter(_.length > 1) assert(nonTrivialClusters.isEmpty, nonTrivialClusters) - sortedClusters.flatten.map(transitiveTargets) + new TopoSorted(sortedClusters.flatten.map(transitiveTargets)) } } \ No newline at end of file diff --git a/src/main/scala/forge/Target.scala b/src/main/scala/forge/Target.scala index 93fa022e..46c4b2a1 100644 --- a/src/main/scala/forge/Target.scala +++ b/src/main/scala/forge/Target.scala @@ -52,10 +52,7 @@ object Target{ override def toString = this.getClass.getName + "@" + defCtx.label } def test(inputs: Target[Int]*)(implicit defCtx: DefCtx) = { - new Test(inputs, defCtx, pure = false) - } - def testPure(inputs: Target[Int]*)(implicit defCtx: DefCtx) = { - new Test(inputs, defCtx, pure = true) + new Test(inputs, defCtx, pure = inputs.nonEmpty) } /** diff --git a/src/main/scala/forge/Util.scala b/src/main/scala/forge/Util.scala index 7beffd4e..0a751da5 100644 --- a/src/main/scala/forge/Util.scala +++ b/src/main/scala/forge/Util.scala @@ -7,7 +7,30 @@ import java.util.jar.JarEntry import sourcecode.Enclosing import scala.collection.JavaConverters._ +import scala.collection.mutable +class MultiBiMap[K, V](){ + private[this] val valueToKey = mutable.Map.empty[V, K] + private[this] val keyToValues = mutable.Map.empty[K, List[V]] + def containsValue(v: V) = valueToKey.contains(v) + def lookupValue(v: V) = valueToKey(v) + def lookupValueOpt(v: V) = valueToKey.get(v) + def add(k: K, v: V): Unit = { + valueToKey(v) = k + keyToValues(k) = v :: keyToValues.getOrElse(k, Nil) + } + def removeAll(k: K): Seq[V] = { + val vs = keyToValues(k) + for(v <- vs){ + valueToKey.remove(v) + } + vs + } + def addAll(k: K, vs: Seq[V]): Unit = { + for(v <- vs) valueToKey(v) = k + keyToValues(k) = vs ++: keyToValues.getOrElse(k, Nil) + } +} object Util{ def compileAll(sources: Target[Seq[jnio.Path]]) (implicit defCtx: DefCtx) = { diff --git a/src/test/scala/forge/ForgeTests.scala b/src/test/scala/forge/ForgeTests.scala index 6c588d4f..5e9d8e37 100644 --- a/src/test/scala/forge/ForgeTests.scala +++ b/src/test/scala/forge/ForgeTests.scala @@ -1,7 +1,7 @@ package forge import utest._ -import Target.{test, testPure} +import Target.test import java.nio.{file => jnio} object ForgeTests extends TestSuite{ @@ -33,7 +33,7 @@ object ForgeTests extends TestSuite{ object AnonImpureDiamond{ val up = T{ test() } - val down = T{ test(testPure(up), test(up)) } + val down = T{ test(test(up), test(up)) } } @@ -101,7 +101,7 @@ object ForgeTests extends TestSuite{ 'topoSortedTransitiveTargets - { def check(targets: Seq[Target[_]], expected: Seq[Target[_]]) = { - val result = Evaluator.topoSortedTransitiveTargets(targets) + val result = Evaluator.topoSortedTransitiveTargets(targets).values assert(result == expected) } @@ -131,6 +131,39 @@ object ForgeTests extends TestSuite{ ) ) } + + 'groupAroundNamedTargets - { + def check(target: Target[_], expected: Seq[Seq[Target[_]]]) = { + val grouped = Evaluator.groupAroundNamedTargets( + Evaluator.topoSortedTransitiveTargets(Seq(target)) + ) + assert(grouped == expected) + } + 'singleton - check( + Singleton.single, + Seq(Seq(Singleton.single)) + ) + 'pair - check( + Pair.down, + Seq(Seq(Pair.up), Seq(Pair.down)) + ) + 'anonTriple - check( + AnonTriple.down, + Seq(Seq(AnonTriple.up), Seq(AnonTriple.down.inputs(0), AnonTriple.down)) + ) + 'diamond - check( + Diamond.down, + Seq(Seq(Diamond.up), Seq(Diamond.left), Seq(Diamond.right), Seq(Diamond.down)) + ) + 'anonDiamond - check( + AnonDiamond.down, + Seq( + Seq(AnonDiamond.up), + Seq(AnonDiamond.down.inputs(1), AnonDiamond.down.inputs(0), AnonDiamond.down) + ) + ) + } + 'labeling - { def check(t: Target[_], relPath: String) = { -- cgit v1.2.3