diff options
Diffstat (limited to 'main/core/src/define/Graph.scala')
-rw-r--r-- | main/core/src/define/Graph.scala | 72 |
1 files changed, 72 insertions, 0 deletions
diff --git a/main/core/src/define/Graph.scala b/main/core/src/define/Graph.scala new file mode 100644 index 00000000..5b29bd7b --- /dev/null +++ b/main/core/src/define/Graph.scala @@ -0,0 +1,72 @@ +package mill.define + +import mill.eval.Tarjans +import mill.util.MultiBiMap +import mill.api.Strict.Agg + +object Graph { + + /** + * The `values` [[Agg]] is guaranteed to be topological sorted and cycle free. + * That's why the constructor is package private. + * @see [[Graph.topoSorted]] + */ + class TopoSorted private[Graph] (val values: Agg[Task[_]]) + + def groupAroundImportantTargets[T](topoSortedTargets: TopoSorted) + (important: PartialFunction[Task[_], T]): MultiBiMap[T, Task[_]] = { + + val output = new MultiBiMap.Mutable[T, Task[_]]() + for ((target, t) <- topoSortedTargets.values.flatMap(t => important.lift(t).map((t, _)))) { + + val transitiveTargets = new Agg.Mutable[Task[_]] + def rec(t: Task[_]): Unit = { + if (transitiveTargets.contains(t)) () // do nothing + else if (important.isDefinedAt(t) && t != target) () // do nothing + else { + transitiveTargets.append(t) + t.inputs.foreach(rec) + } + } + rec(target) + output.addAll(t, topoSorted(transitiveTargets).values) + } + output + } + + /** + * Collects all transitive dependencies (targets) of the given targets, + * including the given targets. + */ + def transitiveTargets(sourceTargets: Agg[Task[_]]): Agg[Task[_]] = { + val transitiveTargets = new Agg.Mutable[Task[_]] + def rec(t: Task[_]): Unit = { + if (transitiveTargets.contains(t)) () // do nothing + else { + transitiveTargets.append(t) + t.inputs.foreach(rec) + } + } + + sourceTargets.items.foreach(rec) + transitiveTargets + } + /** + * Takes the given targets, finds all the targets they transitively depend + * on, and sort them topologically. Fails if there are dependency cycles + */ + def topoSorted(transitiveTargets: Agg[Task[_]]): TopoSorted = { + + val indexed = transitiveTargets.indexed + val targetIndices = indexed.zipWithIndex.toMap + + val numberedEdges = + for(t <- transitiveTargets.items) + yield t.inputs.collect(targetIndices) + + val sortedClusters = Tarjans(numberedEdges) + val nonTrivialClusters = sortedClusters.filter(_.length > 1) + assert(nonTrivialClusters.isEmpty, nonTrivialClusters) + new TopoSorted(Agg.from(sortedClusters.flatten.map(indexed))) + } +} |