diff options
author | Li Haoyi <haoyi.sg@gmail.com> | 2017-11-03 23:44:39 -0700 |
---|---|---|
committer | Li Haoyi <haoyi.sg@gmail.com> | 2017-11-03 23:44:39 -0700 |
commit | 13270145903b457c906a9fa77bd152afb6448ef5 (patch) | |
tree | e85b7ed530e0c8e3c3041cbf17641857c448b602 /core/src/main/scala/forge/Evaluator.scala | |
parent | 66f1c5c2438aeb8f2496575f52c25b09cf5793a6 (diff) | |
download | mill-13270145903b457c906a9fa77bd152afb6448ef5.tar.gz mill-13270145903b457c906a9fa77bd152afb6448ef5.tar.bz2 mill-13270145903b457c906a9fa77bd152afb6448ef5.zip |
Split up forge into `scalaplugin` an `core` subprojects, to allow us to use the `T#apply` macro in the implementation of `scalaplugin.Subproject`
Also needed to implement inter-`Subproject` dependencies so the `MetacircularTests` can continue to support the new layout
Diffstat (limited to 'core/src/main/scala/forge/Evaluator.scala')
-rw-r--r-- | core/src/main/scala/forge/Evaluator.scala | 191 |
1 files changed, 191 insertions, 0 deletions
diff --git a/core/src/main/scala/forge/Evaluator.scala b/core/src/main/scala/forge/Evaluator.scala new file mode 100644 index 00000000..50dc46d4 --- /dev/null +++ b/core/src/main/scala/forge/Evaluator.scala @@ -0,0 +1,191 @@ +package forge + + +import play.api.libs.json.{Format, JsValue, Json} + +import scala.collection.mutable +import ammonite.ops._ +import forge.util.{Args, Labelled, MultiBiMap, OSet} +class Evaluator(workspacePath: Path, + labeling: Map[Target[_], Labelled[_]]){ + + def evaluate(targets: OSet[Target[_]]): Evaluator.Results = { + mkdir(workspacePath) + + val sortedGroups = Evaluator.groupAroundNamedTargets( + Evaluator.topoSortedTransitiveTargets(targets), + labeling + ) + + val evaluated = new OSet.Mutable[Target[_]] + val results = mutable.LinkedHashMap.empty[Target[_], Any] + + for (groupIndex <- sortedGroups.keys()){ + val group = sortedGroups.lookupKey(groupIndex) + + val (newResults, newEvaluated) = evaluateGroupCached( + group, + results, + sortedGroups + ) + evaluated.appendAll(newEvaluated) + for((k, v) <- newResults) results.put(k, v) + + } + + Evaluator.Results(targets.items.map(results), evaluated) + } + + def evaluateGroupCached(group: OSet[Target[_]], + results: collection.Map[Target[_], Any], + sortedGroups: MultiBiMap[Int, Target[_]]): (collection.Map[Target[_], Any], Seq[Target[_]]) = { + + + val (externalInputs, terminals) = partitionGroupInputOutput(group, results) + + val inputsHash = + externalInputs.toIterator.map(results).toVector.hashCode + + group.toIterator.map(_.sideHash).toVector.hashCode() + + val primeLabel = labeling(terminals.items(0)).segments + + + val targetDestPath = workspacePath / primeLabel + val metadataPath = targetDestPath / up / (targetDestPath.last + ".forge.json") + + val cached = for{ + json <- scala.util.Try(Json.parse(read.getInputStream(metadataPath))).toOption + (cachedHash, terminalResults) <- Json.fromJson[(Int, Seq[JsValue])](json).asOpt + if cachedHash == inputsHash + } yield terminalResults + + cached match{ + case Some(terminalResults) => + val newResults = mutable.LinkedHashMap.empty[Target[_], Any] + for((terminal, res) <- terminals.items.zip(terminalResults)){ + newResults(terminal) = labeling(terminal).format.reads(res).get + } + (newResults, Nil) + + case _ => + val (newResults, newEvaluated, terminalResults) = evaluateGroup(group, results, targetDestPath) + + write.over( + metadataPath, + Json.prettyPrint( + Json.toJson(inputsHash -> terminals.toList.map(terminalResults)) + ), + ) + + (newResults, newEvaluated) + } + } + + def partitionGroupInputOutput(group: OSet[Target[_]], + results: collection.Map[Target[_], Any]) = { + val allInputs = group.items.flatMap(_.inputs) + val (internalInputs, externalInputs) = allInputs.partition(group.contains) + val internalInputSet = internalInputs.toSet + val terminals = group.filter(!internalInputSet(_)) + (OSet.from(externalInputs.distinct), terminals) + } + + def evaluateGroup(group: OSet[Target[_]], + results: collection.Map[Target[_], Any], + targetDestPath: Path) = { + + rm(targetDestPath) + val terminalResults = mutable.LinkedHashMap.empty[Target[_], JsValue] + val newEvaluated = mutable.Buffer.empty[Target[_]] + val newResults = mutable.LinkedHashMap.empty[Target[_], Any] + for (target <- group.items) { + newEvaluated.append(target) + val targetInputValues = target.inputs.toVector.map(x => + newResults.getOrElse(x, results(x)) + ) + + val args = new Args(targetInputValues, targetDestPath) + val res = target.evaluate(args) + for(targetLabel <- labeling.get(target)){ + terminalResults(target) = targetLabel + .format + .asInstanceOf[Format[Any]] + .writes(res.asInstanceOf[Any]) + } + newResults(target) = res + } + + (newResults, newEvaluated, terminalResults) + } + +} + + +object Evaluator{ + class TopoSorted private[Evaluator] (val values: OSet[Target[_]]) + case class Results(values: Seq[Any], evaluated: OSet[Target[_]]) + def groupAroundNamedTargets(topoSortedTargets: TopoSorted, + labeling: Map[Target[_], Labelled[_]]): MultiBiMap[Int, Target[_]] = { + + val grouping = new MultiBiMap.Mutable[Int, Target[_]]() + + var groupCount = 0 + + for(target <- topoSortedTargets.values.items.reverseIterator){ + if (!grouping.containsValue(target)){ + grouping.add(groupCount, target) + groupCount += 1 + } + + val targetGroup = grouping.lookupValue(target) + for(upstream <- target.inputs){ + grouping.lookupValueOpt(upstream) match{ + case None if !labeling.contains(upstream) => + grouping.add(targetGroup, upstream) + case Some(upstreamGroup) if upstreamGroup == targetGroup => + val upstreamTargets = grouping.removeAll(upstreamGroup) + + grouping.addAll(targetGroup, upstreamTargets) + case _ => //donothing + } + } + } + + val targetOrdering = topoSortedTargets.values.items.zipWithIndex.toMap + val output = new MultiBiMap.Mutable[Int, Target[_]] + + // Sort groups amongst themselves, and sort the contents of each group + // before aggregating it into the final output + for(g <- grouping.values().toArray.sortBy(g => targetOrdering(g.items(0)))){ + output.addAll(output.keys.length, g.toArray.sortBy(targetOrdering)) + } + output + } + + /** + * Takes the given targets, finds all the targets they transitively depend + * on, and sort them topologically. Fails if there are dependency cycles + */ + def topoSortedTransitiveTargets(sourceTargets: OSet[Target[_]]): TopoSorted = { + val transitiveTargets = new OSet.Mutable[Target[_]] + def rec(t: Target[_]): Unit = { + if (transitiveTargets.contains(t)) () // do nothing + else { + transitiveTargets.append(t) + t.inputs.foreach(rec) + } + } + + sourceTargets.items.foreach(rec) + val targetIndices = transitiveTargets.items.zipWithIndex.toMap + + val numberedEdges = + for(t <- transitiveTargets.items) + yield t.inputs.map(targetIndices) + + val sortedClusters = Tarjans(numberedEdges) + val nonTrivialClusters = sortedClusters.filter(_.length > 1) + assert(nonTrivialClusters.isEmpty, nonTrivialClusters) + new TopoSorted(OSet.from(sortedClusters.flatten.map(transitiveTargets.items))) + } +}
\ No newline at end of file |