From aa5eb186c044e0c00d512e0c009e9d519a753e0c Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Tue, 17 Oct 2017 06:03:36 -0700 Subject: Include Tarjan's algorithm, for doing a topological sort which elegantly handles cycles --- build.sbt | 3 ++ src/main/scala/hbt/Main.scala | 57 +++++++++++++++++------ src/main/scala/hbt/Tarjans.scala | 61 ++++++++++++++++++++++++ src/test/scala/hbt/TarjanTests.scala | 89 ++++++++++++++++++++++++++++++++++++ 4 files changed, 197 insertions(+), 13 deletions(-) create mode 100644 src/main/scala/hbt/Tarjans.scala create mode 100644 src/test/scala/hbt/TarjanTests.scala diff --git a/build.sbt b/build.sbt index 6d412fc5..7478db4f 100644 --- a/build.sbt +++ b/build.sbt @@ -4,6 +4,9 @@ name := "hbt" organization := "com.lihaoyi" +libraryDependencies += "com.lihaoyi" %% "utest" % "0.5.4" % "test" + +testFrameworks += new TestFramework("utest.runner.Framework") libraryDependencies += "com.lihaoyi" %% "sourcecode" % "0.1.4" diff --git a/src/main/scala/hbt/Main.scala b/src/main/scala/hbt/Main.scala index fb86cecf..b4b5757c 100644 --- a/src/main/scala/hbt/Main.scala +++ b/src/main/scala/hbt/Main.scala @@ -6,8 +6,18 @@ import java.nio.{file => jnio} import java.util.jar.JarEntry import sourcecode.Enclosing +class Args(val args: IndexedSeq[_]){ + def length = args.length + def apply[T](index: Int): T = { + if (index >= 0 && index < args.length) args(index).asInstanceOf[T] + else throw new IndexOutOfBoundsException(s"Index $index outside of range 0 - ${args.length}") + } +} sealed trait Target[T]{ - def label: String + val label: String + def evaluate(args: Args): T + val inputs: Seq[Target[_]] + def map[V](f: T => V)(implicit path: Enclosing) = { Target.Mapped(this, f, path.value) } @@ -25,21 +35,36 @@ object Target{ def traverse[T](source: Seq[Target[T]])(implicit path: Enclosing) = { Traverse(source, path.value) } - case class Traverse[T](source: Seq[Target[T]], label: String) extends Target[Seq[T]] + case class Traverse[T](inputs: Seq[Target[T]], label: String) extends Target[Seq[T]]{ + def evaluate(args: Args) = { + for (i <- 0 until args.length) + yield args(i) + } + + } case class Mapped[T, V](source: Target[T], f: T => V, - label: String) extends Target[V] - case class Zipped[T, V](source: Target[T], + label: String) extends Target[V]{ + def evaluate(args: Args) = f(args(0)) + val inputs = List(source) + } + case class Zipped[T, V](source1: Target[T], source2: Target[V], - label: String) extends Target[(T, V)] - case class Path(path: jnio.Path, label: String) extends Target[jnio.Path] - case class Command(inputs: Seq[Target[jnio.Path]], - output: Seq[Target[jnio.Path]], - label: String) extends Target[Command.Result] - object Command{ - case class Result(stdout: String, - stderr: String, - writtenFiles: Seq[jnio.Path]) + label: String) extends Target[(T, V)]{ + def evaluate(args: Args) = (args(0), args(0)) + val inputs = List(source1, source1) + } + case class Path(path: jnio.Path, label: String) extends Target[jnio.Path]{ + def evaluate(args: Args) = path + val inputs = Nil } +// case class Command(inputs: Seq[Target[jnio.Path]], +// output: Seq[Target[jnio.Path]], +// label: String) extends Target[Command.Result] +// object Command{ +// case class Result(stdout: String, +// stderr: String, +// writtenFiles: Seq[jnio.Path]) +// } } object Main{ def compileAll(sources: Target[Seq[jnio.Path]]) @@ -87,4 +112,10 @@ object Main{ val classFiles = compileAll(allSources) val jar = jarUp(resourceRoot, classFiles) } + + def evaluateTargetGraph[T](t: Target[T]): T = { + ??? +// val evaluated = collection.mutable.Map.empty[Target[_], Any] +// val forwardEdges + } } \ No newline at end of file diff --git a/src/main/scala/hbt/Tarjans.scala b/src/main/scala/hbt/Tarjans.scala new file mode 100644 index 00000000..dc95b02f --- /dev/null +++ b/src/main/scala/hbt/Tarjans.scala @@ -0,0 +1,61 @@ +package hbt + +import collection.mutable +// Adapted from +// https://github.com/indy256/codelibrary/blob/c52247216258e84aac442a23273b7d8306ef757b/java/src/SCCTarjan.java +object Tarjans { + def main(args: Array[String]) = { + val components = Tarjans( + Vector( + Vector(1), + Vector(0), + Vector(0, 1) + ) + ) + println(components) + } + + def apply(graph0: Seq[Seq[Int]]): Seq[Seq[Int]] = { + val graph = graph0.map(_.toArray).toArray + val n = graph.length + val visited = new Array[Boolean](n) + val stack = mutable.ArrayBuffer.empty[Integer] + var time = 0 + val lowlink = new Array[Int](n) + val components = mutable.ArrayBuffer.empty[Seq[Int]] + + + for (u <- 0 until n) { + if (!visited(u)) dfs(u) + } + + def dfs(u: Int): Unit = { + lowlink(u) = time + time += 1 + visited(u) = true + stack.append(u) + var isComponentRoot = true + for (v <- graph(u)) { + if (!visited(v)) dfs(v) + if (lowlink(u) > lowlink(v)) { + lowlink(u) = lowlink(v) + isComponentRoot = false + } + } + if (isComponentRoot) { + val component = mutable.Buffer.empty[Int] + + var done = false + while (!done) { + val x = stack.last + stack.remove(stack.length - 1) + component.append(x) + lowlink(x) = Integer.MAX_VALUE + if (x == u) done = true + } + components.append(component) + } + } + components + } +} \ No newline at end of file diff --git a/src/test/scala/hbt/TarjanTests.scala b/src/test/scala/hbt/TarjanTests.scala new file mode 100644 index 00000000..48314cf4 --- /dev/null +++ b/src/test/scala/hbt/TarjanTests.scala @@ -0,0 +1,89 @@ +package hbt +import utest._ +object TarjanTests extends TestSuite{ + def check(input: Seq[Seq[Int]], expected: Seq[Seq[Int]]) = { + val result = Tarjans(input).map(_.sorted) + val sortedExpected = expected.map(_.sorted) + assert(result == sortedExpected) + } + val tests = Tests{ + // + 'empty - check(Seq(), Seq()) + + // (0) + 'singleton - check(Seq(Seq()), Seq(Seq(0))) + + + // (0)-. + // ^._/ + 'selfCycle - check(Seq(Seq(0)), Seq(Seq(0))) + + // (0) <-> (1) + 'simpleCycle- check(Seq(Seq(1), Seq(0)), Seq(Seq(1, 0))) + + // (0) (1) (2) + 'multipleSingletons - check( + Seq(Seq(), Seq(), Seq()), + Seq(Seq(0), Seq(1), Seq(2)) + ) + + // (0) -> (1) -> (2) + 'straightLineNoCycles- check( + Seq(Seq(1), Seq(2), Seq()), + Seq(Seq(2), Seq(1), Seq(0)) + ) + + // (0) <- (1) <- (2) + 'straightLineNoCyclesReversed- check( + Seq(Seq(), Seq(0), Seq(1)), + Seq(Seq(0), Seq(1), Seq(2)) + ) + + // (0) <-> (1) (2) -> (3) -> (4) + // ^.____________/ + 'independentSimpleCycles - check( + Seq(Seq(1), Seq(0), Seq(3), Seq(4), Seq(2)), + Seq(Seq(1, 0), Seq(4, 3, 2)) + ) + + // ___________________ + // v \ + // (0) <-> (1) (2) -> (3) -> (4) + // ^.____________/ + 'independentLinkedCycles - check( + Seq(Seq(1), Seq(0), Seq(3), Seq(4), Seq(2, 1)), + Seq(Seq(1, 0), Seq(4, 3, 2)) + ) + // _____________ + // / v + // (0) <-> (1) (2) -> (3) -> (4) + // ^.____________/ + 'independentLinkedCycles2 - check( + Seq(Seq(1, 2), Seq(0), Seq(3), Seq(4), Seq(2)), + Seq(Seq(4, 3, 2), Seq(1, 0)) + ) + + // _____________ + // / v + // (0) <-> (1) (2) -> (3) -> (4) + // ^. ^.____________/ + // \________________/ + 'combinedCycles - check( + Seq(Seq(1, 2), Seq(0), Seq(3), Seq(4), Seq(2, 1)), + Seq(Seq(4, 3, 2, 1, 0)) + ) + // + // (0) <-> (1) <- (2) <- (3) <-> (4) <- (5) + // ^.____________/ / / + // / / + // (6) <- (7) <-/ (8) <-' + // / / + // v / + // (9) <--------' + 'combinedCycles - check( + Seq(Seq(1), Seq(0), Seq(0, 1), Seq(2, 4, 7, 9), Seq(3), Seq(4, 8), Seq(9), Seq(6), Seq(), Seq()), + Seq(Seq(0, 1), Seq(2), Seq(9), Seq(6), Seq(7), Seq(3, 4), Seq(8), Seq(5)) + ) + + } +} \ No newline at end of file -- cgit v1.2.3