summaryrefslogtreecommitdiff
path: root/main/core/src/mill/define/Graph.scala
blob: 3119f2fbcd10fccae94c0ff8a29b617fc972e2c3 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
package mill.define

import mill.eval.Tarjans
import mill.util.MultiBiMap
import mill.util.Strict.Agg

object Graph {

  /**
    * The `values` [[Agg]] is guaranteed to be topologically sorted and cycle free.
    * That's why the constructor is private to [[Graph]]: instances can only be
    * obtained from [[Graph.topoSorted]].
    * @see [[Graph.topoSorted]]
    */
  class TopoSorted private[Graph] (val values: Agg[Task[_]])

  /**
    * Partitions a topologically sorted task list into groups keyed by the
    * "important" tasks.
    *
    * For every task in `topoSortedTargets` on which `important` is defined, the
    * group under key `important(task)` contains that task plus its transitive
    * inputs. The walk stops at any other task on which `important` is defined,
    * so such tasks (and anything reachable only through them) are excluded.
    * Each group is itself topologically sorted via [[topoSorted]].
    *
    * @param topoSortedTargets tasks to scan for important ones, in topological order
    * @param important selects the important tasks and computes their group key
    * @return a multi-map from group key to the tasks in that group
    */
  def groupAroundImportantTargets[T](topoSortedTargets: TopoSorted)
                                    (important: PartialFunction[Task[_], T]): MultiBiMap[T, Task[_]] = {

    val output = new MultiBiMap.Mutable[T, Task[_]]()
    for ((target, key) <- topoSortedTargets.values.flatMap(t => important.lift(t).map((t, _)))) {
      // Other important tasks head their own groups, so stop descending there.
      val group = collectTransitive(
        List(target),
        t => important.isDefinedAt(t) && t != target
      )
      output.addAll(key, topoSorted(group).values)
    }
    output
  }

  /**
   * Collects all transitive dependencies (targets) of the given targets,
   * including the given targets.
   *
   * Each task appears once. Order is depth-first pre-order over `inputs`,
   * starting from the given targets in their iteration order.
   */
  def transitiveTargets(sourceTargets: Agg[Task[_]]): Agg[Task[_]] =
    collectTransitive(sourceTargets.items.toList, _ => false)

  /**
    * Sorts the given tasks topologically.
    *
    * The input is expected to already be closed under `inputs` (e.g. the result
    * of [[transitiveTargets]]). Any input edge pointing at a task that is not in
    * `transitiveTargets` is ignored rather than followed.
    *
    * @throws AssertionError if the tasks contain a dependency cycle, including a
    *                        task that lists itself as one of its own inputs
    */
  def topoSorted(transitiveTargets: Agg[Task[_]]): TopoSorted = {

    val indexed = transitiveTargets.indexed
    val targetIndices = indexed.zipWithIndex.toMap

    // `collect` with the index map drops edges to tasks outside the input set.
    val numberedEdges =
      for (t <- transitiveTargets.items)
        yield t.inputs.collect(targetIndices)

    val sortedClusters = Tarjans(numberedEdges)

    // A strongly connected component with several members is a cycle. A task
    // that depends on itself is also a cycle, but it forms a one-element
    // component, so it has to be detected separately.
    val multiTaskCycles: Seq[Seq[Task[_]]] =
      sortedClusters.filter(_.length > 1).map(_.map(indexed))
    val selfCycles: Seq[Seq[Task[_]]] =
      indexed.filter(t => t.inputs.contains(t)).map(t => Seq(t))
    val cycles = multiTaskCycles ++ selfCycles
    // Report the offending tasks themselves rather than their internal indices.
    assert(cycles.isEmpty, cycles)

    new TopoSorted(Agg.from(sortedClusters.flatten.map(indexed)))
  }

  /**
    * Depth-first walk over `inputs` starting from `roots`.
    *
    * Returns every visited task exactly once, in the same pre-order a recursive
    * walk would produce. A task for which `prune` returns true is neither
    * recorded nor descended into; `prune` is only consulted for tasks not yet
    * visited.
    *
    * An explicit work list is used instead of recursion so that long dependency
    * chains cannot overflow the JVM stack.
    */
  private def collectTransitive(roots: List[Task[_]],
                                prune: Task[_] => Boolean): Agg.Mutable[Task[_]] = {
    val visited = new Agg.Mutable[Task[_]]

    // The head of `pending` is visited next. A task's inputs are prepended in
    // order, so they are visited in order and before any earlier-pending task.
    @scala.annotation.tailrec
    def loop(pending: List[Task[_]]): Unit = pending match {
      case Nil => ()
      case current :: rest =>
        if (!visited.contains(current) && !prune(current)) {
          visited.append(current)
          loop(current.inputs.toList ::: rest)
        } else loop(rest)
    }

    loop(roots)
    visited
  }
}