summaryrefslogblamecommitdiff
path: root/core/src/main/scala/mill/eval/Evaluator.scala
blob: da6e2656913f314e23629f9a9a86ced601cc2df5 (plain) (tree)
1
2
3
4
5
6
7
8
9
                 
 
                     


                                         
                                                 

                               
                                    
                                                     
 
                                                             
                        
 

                                                         
                                                                              
 

                                                           
 

                                                    
 
                                                           

                

                    



                                                 
 
                                                                        

   


                                                                                                                   
 
 
                                                                               

                    

                                                                
 


                                                                                
                                                                                     

                                                  
     
 
                     
                                  
                                                                                    

                                                                                     
                           

                 
                                   
                                                                  
                                                                     
                                                                         
         
                         

               




                                                                                                      
                                                                                                       
 






                                                                              

         
                                  

     
 

                                                                          


                                                                              
                                                                                   
                                                   

   

                                                          
                                                     
 
                              


                                                                       




                                                             
 
                                                                   





                                               
       
                              
     
 
                                               
   
 
 


                 
                                                                
                                                                                           
                                                            
                                                                                                
 
                                                         


                      
                                                                  







                                                    
                                                      
                                               

                                                                              





                                                                   
 


                                                                          











                                                                                                


                                                                                  
                                   
     
          

   
                                                                        

                                                     
                                                         
            
                                   



                             
                                    







                                                                             
                                                                  

                       
                                       
                                       



                                                                
                                                                                  

   
package mill.eval

import ammonite.ops._
import mill.define.Task
import mill.discover.Labelled
import mill.util.{Args, MultiBiMap, OSet}
import play.api.libs.json.{Format, JsValue, Json}

import scala.collection.mutable
class Evaluator(workspacePath: Path,
                labeling: Map[Task[_], Labelled[_]]){

  /**
    * Evaluates the requested tasks along with everything they transitively
    * depend on. Tasks are processed group by group, in the order of the keys
    * of the grouping returned by `Evaluator.groupAroundNamedTargets`.
    *
    * @param targets the tasks whose values the caller is interested in
    * @return the values of `targets` (in the order of `targets.items`), the
    *         tasks the individual groups reported as evaluated, and the
    *         transitive closure of `targets`
    */
  def evaluate(targets: OSet[Task[_]]): Evaluator.Results = {
    mkdir(workspacePath)

    val transitive = Evaluator.transitiveTargets(targets)
    val sortedGroups = Evaluator.groupAroundNamedTargets(
      Evaluator.topoSorted(transitive),
      labeling
    )

    // Accumulators shared across groups: later groups read the values
    // produced by earlier ones out of `results`.
    val evaluated = new OSet.Mutable[Task[_]]
    val results = mutable.LinkedHashMap.empty[Task[_], Any]

    sortedGroups.keys().foreach { groupIndex =>
      val (groupResults, groupEvaluated) = evaluateGroupCached(
        sortedGroups.lookupKey(groupIndex),
        results,
        sortedGroups
      )
      evaluated.appendAll(groupEvaluated)
      groupResults.foreach { case (task, value) => results.put(task, value) }
    }

    Evaluator.Results(targets.items.map(results), evaluated, transitive)
  }

  /**
    * Evaluates one group of tasks, reusing the values recorded in the group's
    * `.mill.json` metadata file when the stored inputs hash matches the
    * current one.
    *
    * @param group        the tasks to evaluate together
    * @param results      values of the tasks evaluated by earlier groups
    * @param sortedGroups the complete grouping (not consulted by this method)
    * @return the values obtained for this group, paired with the tasks that
    *         actually had to be run (`Nil` on a cache hit)
    */
  def evaluateGroupCached(group: OSet[Task[_]],
                          results: collection.Map[Task[_], Any],
                          sortedGroups: MultiBiMap[Int, Task[_]]): (collection.Map[Task[_], Any], Seq[Task[_]]) = {

    val (externalInputs, terminals) = partitionGroupInputOutput(group, results)

    // Cache key: the values flowing into the group from outside, combined
    // with the `sideHash` of every task in the group.
    val inputsHash =
      externalInputs.toIterator.map(results).toVector.hashCode +
      group.toIterator.map(_.sideHash).toVector.hashCode()

    // Only a group whose first terminal is labelled gets a destination path
    // and a metadata file. NOTE(review): `items(0)` assumes the group has at
    // least one terminal - confirm no caller passes an empty group.
    val (targetDestPath, metadataPath) = labeling.get(terminals.items(0)) match{
      case Some(labelled) =>
        val targetDestPath = workspacePath / labelled.segments
        val metadataPath = targetDestPath / up / (targetDestPath.last + ".mill.json")
        (Some(targetDestPath), Some(metadataPath))
      case None => (None, None)
    }

    val cached = for{
      path <- metadataPath
      json <- scala.util.Try{
        // Close the stream explicitly; nothing else in this method releases it.
        val in = read.getInputStream(path)
        try Json.parse(in) finally in.close()
      }.toOption
      (cachedHash, terminalResults) <- Json.fromJson[(Int, Seq[JsValue])](json).asOpt
      if cachedHash == inputsHash
      // A cache entry with a different number of values than there are
      // terminals cannot be paired up reliably; treat it as a miss rather
      // than letting `zip` silently drop terminals.
      if terminalResults.size == terminals.items.size
    } yield terminalResults

    cached match{
      case Some(terminalResults) =>
        // Cache hit: deserialize each stored value with its terminal's format.
        val newResults = mutable.LinkedHashMap.empty[Task[_], Any]
        for((terminal, res) <- terminals.items.zip(terminalResults)){
          newResults(terminal) = labeling(terminal).format.reads(res).get
        }
        (newResults, Nil)

      case _ =>
        // Cache miss: run the group and record the fresh values on disk.
        val labeled = group.collect(labeling)
        if (labeled.nonEmpty){
          println(fansi.Color.Blue("Running " + labeled.map(_.segments.mkString(".")).mkString(", ")))
        }
        val (newResults, newEvaluated, terminalResults) = evaluateGroup(group, results, targetDestPath)

        metadataPath.foreach { path =>
          write.over(
            path,
            Json.prettyPrint(
              Json.toJson(inputsHash -> terminals.toList.map(terminalResults))
            )
          )
        }

        (newResults, newEvaluated)
    }
  }

  /**
    * Splits a group into the tasks it consumes from outside and the tasks it
    * exposes.
    *
    * @param group   the tasks forming the group
    * @param results values computed so far (not consulted by this method)
    * @return a pair of: the de-duplicated inputs of the group's tasks that are
    *         not themselves members of the group; and the "terminal" members,
    *         i.e. those that are labelled or that no other member of the
    *         group uses as an input
    */
  def partitionGroupInputOutput(group: OSet[Task[_]],
                                results: collection.Map[Task[_], Any]) = {
    val (insideGroup, outsideGroup) =
      group.items.flatMap(_.inputs).partition(group.contains)

    val consumedInternally = insideGroup.toSet
    val terminals = group.filter { task =>
      labeling.contains(task) || !consumedInternally(task)
    }

    (OSet.from(outsideGroup.distinct), terminals)
  }

  /**
    * Runs every task of a group, in the order of `group.items`.
    *
    * The group's destination path, if there is one, is removed first. Each
    * task receives the values of its inputs, taken from the values produced
    * earlier in this same call when available and from `results` otherwise.
    *
    * @param group          the tasks to run
    * @param results        values of tasks evaluated before this group
    * @param targetDestPath destination path handed to every task through
    *                       `Args`; `null` is passed when it is `None`
    * @return the value of every task that ran, the tasks that ran (in order),
    *         and the JSON-serialized value of every labelled task
    */
  def evaluateGroup(group: OSet[Task[_]],
                    results: collection.Map[Task[_], Any],
                    targetDestPath: Option[Path]) = {

    targetDestPath.foreach(rm)

    val newResults = mutable.LinkedHashMap.empty[Task[_], Any]
    val newEvaluated = mutable.Buffer.empty[Task[_]]
    val terminalResults = mutable.LinkedHashMap.empty[Task[_], JsValue]

    group.items.foreach { task =>
      newEvaluated += task

      // Prefer values computed within this group; fall back to earlier groups.
      val inputValues = task.inputs.toVector.map { input =>
        newResults.getOrElse(input, results(input))
      }
      val res = task.evaluate(new Args(inputValues, targetDestPath.orNull))

      // Only labelled tasks have a format with which to serialize their value.
      labeling.get(task).foreach { label =>
        val format = label.format.asInstanceOf[Format[Any]]
        terminalResults(task) = format.writes(res.asInstanceOf[Any])
      }

      newResults.update(task, res)
    }

    (newResults, newEvaluated, terminalResults)
  }

}


object Evaluator{
  /**
    * Wrapper around a set of tasks. The constructor is private to
    * [[Evaluator]], so instances can only be created from within this object
    * (see `topoSorted`).
    */
  class TopoSorted private[Evaluator](val values: OSet[Task[_]])
  /**
    * Outcome of `Evaluator#evaluate`.
    *
    * @param values     the values of the requested tasks
    * @param evaluated  the tasks reported as evaluated while processing the groups
    * @param transitive the transitive closure of the requested tasks
    */
  case class Results(values: Seq[Any], evaluated: OSet[Task[_]], transitive: OSet[Task[_]])
  /**
    * Partitions the sorted tasks into groups: every unlabelled input of a task
    * is placed in (or merged into) the group of the task that uses it, while
    * labelled inputs are never pulled into their consumer's group.
    *
    * @param topoSortedTargets the tasks to group, as produced by `topoSorted`
    * @param labeling          the labelled tasks
    * @return the groups, keyed by their position in the ordering computed by
    *         `Tarjans` over the group dependency graph; the tasks within each
    *         group follow the order of `topoSortedTargets`
    */
  def groupAroundNamedTargets(topoSortedTargets: TopoSorted,
                              labeling: Map[Task[_], Labelled[_]]): MultiBiMap[Int, Task[_]] = {

    val grouping = new MultiBiMap.Mutable[Int, Task[_]]()
    var nextGroupId = 0

    // Walk the sorted tasks back to front, so that a task is visited before
    // the inputs listed ahead of it (presumably: dependents before the tasks
    // they depend on - confirm against the ordering `Tarjans` returns).
    topoSortedTargets.values.items.reverseIterator.foreach { task =>
      // A task nobody has claimed yet starts a group of its own
      if (!grouping.containsValue(task)) {
        grouping.add(nextGroupId, task)
        nextGroupId += 1
      }

      val taskGroup = grouping.lookupValue(task)
      task.inputs.foreach { upstream =>
        // Labelled inputs are left alone
        if (!labeling.contains(upstream)) {
          grouping.lookupValueOpt(upstream) match {
            case None =>
              grouping.add(taskGroup, upstream)
            case Some(upstreamGroup) if upstreamGroup != taskGroup =>
              // The input already sits in another group: fold that whole
              // group into this task's group
              val absorbed = grouping.removeAll(upstreamGroup)
              grouping.addAll(taskGroup, absorbed)
            case _ => // already in this task's group
          }
        }
      }
    }

    // Order the groups amongst themselves: build the graph of which group
    // depends on which. Ids of groups that were merged away keep no edges.
    val groupGraph = mutable.Buffer.fill[Seq[Int]](nextGroupId)(Nil)
    grouping.items().foreach { case (groupId, members) =>
      groupGraph(groupId) =
        members.toIterator.flatMap(_.inputs).map(grouping.lookupValue).toArray.distinct.toSeq
    }
    // Given input topoSortedTargets has no cycles, group graph should not have cycles
    val groupOrdering = Tarjans(groupGraph)

    // Re-number the groups by their position in the ordering and sort the
    // contents of each one by the original task order.
    val targetOrdering = topoSortedTargets.values.items.zipWithIndex.toMap
    val output = new MultiBiMap.Mutable[Int, Task[_]]
    groupOrdering.zipWithIndex.foreach { case (groupIndices, position) =>
      val members = groupIndices.flatMap(grouping.lookupKeyOpt).flatten.sortBy(targetOrdering)
      output.addAll(position, OSet.from(members))
    }
    output
  }

  /**
    * Collects the given tasks together with every task reachable from them
    * through `inputs`. Each task is recorded the first time it is reached,
    * immediately before its own inputs are visited.
    *
    * @param sourceTargets the tasks to start from
    * @return the start tasks and all their transitive inputs, without duplicates
    */
  def transitiveTargets(sourceTargets: OSet[Task[_]]): OSet[Task[_]] = {
    val collected = new OSet.Mutable[Task[_]]

    def visit(task: Task[_]): Unit =
      if (!collected.contains(task)) {
        collected.append(task)
        task.inputs.foreach(visit)
      }

    sourceTargets.items.foreach(visit)
    collected
  }
  /**
    * Sorts the given tasks topologically. Every input of every task must be
    * part of `transitiveTargets` (e.g. the output of
    * [[Evaluator.transitiveTargets]]), since inputs are looked up by index.
    *
    * Fails with an assertion error, reporting the offending clusters, if the
    * tasks contain dependency cycles.
    *
    * @param transitiveTargets the tasks to sort
    * @return the same tasks in the order produced by `Tarjans`
    */
  def topoSorted(transitiveTargets: OSet[Task[_]]): TopoSorted = {
    val tasks = transitiveTargets.items
    val indexOf = tasks.zipWithIndex.toMap

    // Adjacency list: for every task, the indices of the tasks it depends on
    val edges = tasks.map(_.inputs.map(indexOf))

    val clusters = Tarjans(edges)
    // A cluster with more than one member means those tasks depend on each other
    val cycles = clusters.filter(_.length > 1)
    assert(cycles.isEmpty, cycles)

    new TopoSorted(OSet.from(clusters.flatten.map(tasks)))
  }
}