summaryrefslogtreecommitdiff
path: root/main/src/main/MainRunner.scala
blob: 354b6173ecae50ce09c1d2fdf780e87e9caedbf4 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
package mill.main
import java.io.{InputStream, PrintStream}

import ammonite.Main
import ammonite.interp.{Interpreter, Preprocessor}
import ammonite.util.Util.CodeSource
import ammonite.util._
import mill.eval.{Evaluator, PathRef}
import mill.util.PrintLogger

import scala.annotation.tailrec
import ammonite.runtime.ImportHook


/**
  * Customized version of [[ammonite.MainRunner]], allowing us to run Mill
  * `build.sc` scripts with mill-specific tweaks such as a custom
  * `scriptCodeWrapper` or with a persistent evaluator between runs.
  *
  * @param config           Ammonite CLI configuration; this runner reads its
  *                         `watch`, `wd` and `home` settings
  * @param disableTicker    forwarded to the [[PrintLogger]] created for each run
  * @param outprintStream   stream used for standard output, both by the parent
  *                         runner and by the per-run logger
  * @param errPrintStream   stream used for error output, both by the parent
  *                         runner and by the per-run logger
  * @param stdIn            input stream handed to the parent runner and logger
  * @param stateCache0      evaluator state from a previous run, if any; seeds
  *                         [[stateCache]]
  * @param env              environment variables forwarded to `RunScript.runScript`
  * @param setIdle          callback invoked with `true` while waiting for file
  *                         changes in watch mode, and with `false` once the
  *                         wait ends
  * @param debugLog         enables debug output in the per-run logger
  * @param keepGoing        forwarded to `RunScript.runScript`
  * @param systemProperties forwarded to `RunScript.runScript`
  */
class MainRunner(val config: ammonite.main.Cli.Config,
                 disableTicker: Boolean,
                 outprintStream: PrintStream,
                 errPrintStream: PrintStream,
                 stdIn: InputStream,
                 stateCache0: Option[Evaluator.State] = None,
                 env : Map[String, String],
                 setIdle: Boolean => Unit,
                 debugLog: Boolean,
                 keepGoing: Boolean,
                 systemProperties: Map[String, String])
  extends ammonite.MainRunner(
    config, outprintStream, errPrintStream,
    stdIn, outprintStream, errPrintStream
  ){

  /**
    * Evaluator state carried between runs. It is passed to every
    * `RunScript.runScript` call and replaced after each successful run.
    */
  var stateCache: Option[Evaluator.State] = stateCache0

  /**
    * Blocks until the signature of any watched file differs from the recorded
    * one, polling every 100ms.
    *
    * `setIdle(true)` is signalled for the duration of the wait. `setIdle(false)`
    * is always signalled afterwards, even if the sleep is interrupted or a
    * signature computation throws, so the idle flag cannot be left stuck on.
    *
    * @param watched files paired with the signature recorded for them after
    *                the last run
    */
  override def watchAndWait(watched: Seq[(os.Path, Long)]): Unit = {
    printInfo(s"Watching for changes to ${watched.length} files... (Ctrl-C to exit)")
    // true while every watched file still has its recorded signature
    def statAll() = watched.forall{ case (file, lastSig) =>
      Interpreter.pathSignature(file) == lastSig
    }
    setIdle(true)
    try {
      while (statAll()) Thread.sleep(100)
    } finally {
      setIdle(false)
    }
  }

  /**
    * Custom version of [[watchLoop]] that lets us generate the watched-file
    * signature only on demand, so if we don't have config.watch enabled we do
    * not pay the cost of generating it.
    *
    * @param isRepl   forwarded to [[initMain]] on every iteration
    * @param printing forwarded to [[handleWatchRes]]
    * @param run      performs one run against a freshly initialised `Main`,
    *                 returning its result and a thunk that produces the
    *                 watched files with their signatures
    * @return whether the run succeeded. This only returns when `config.watch`
    *         is disabled; otherwise it waits for changes and re-runs
    *         indefinitely.
    */
  @tailrec final def watchLoop2[T](isRepl: Boolean,
                                   printing: Boolean,
                                   run: Main => (Res[T], () => Seq[(os.Path, Long)])): Boolean = {
    val (result, watched) = run(initMain(isRepl))

    val success = handleWatchRes(result, printing)
    if (!config.watch) success
    else{
      // the watched-file thunk is only forced in watch mode
      watchAndWait(watched())
      watchLoop2(isRepl, printing, run)
    }
  }

  /**
    * Runs a Mill build script through `RunScript.runScript`, re-running it on
    * file changes when `config.watch` is enabled.
    *
    * On success the evaluator state is stored in [[stateCache]]. The watched
    * set is then the interpreter's watched files plus the paths the evaluator
    * reported. On failure only the interpreter's watched files are watched.
    *
    * @param scriptPath path of the script to run
    * @param scriptArgs arguments passed to the script
    * @return whether the (final) run succeeded
    */
  override def runScript(scriptPath: os.Path, scriptArgs: List[String]): Boolean =
    watchLoop2(
      isRepl = false,
      printing = true,
      mainCfg => {
        val (result, interpWatched) = RunScript.runScript(
          config.home,
          mainCfg.wd,
          scriptPath,
          mainCfg.instantiateInterpreter(),
          scriptArgs,
          stateCache,
          new PrintLogger(
            colors != ammonite.util.Colors.BlackWhite,
            disableTicker,
            colors,
            outprintStream,
            errPrintStream,
            errPrintStream,
            stdIn,
            debugEnabled = debugLog
          ),
          env,
          keepGoing = keepGoing,
          systemProperties
        )

        result match{
          case Res.Success(data) =>
            val (eval, evalWatches, res) = data

            stateCache = Some(Evaluator.State(eval.rootModule, eval.classLoaderSig, eval.workerCache, interpWatched))
            val watched = () => {
              val alreadyStale = evalWatches.exists(p => p.sig != PathRef(p.path, p.quick).sig)
              // If the file changed between the creation of the original
              // `PathRef` and the current moment, use random junk .sig values
              // to force an immediate re-run. Otherwise calculate the
              // pathSignatures the same way Ammonite would and hand over the
              // values, so Ammonite can watch them and only re-run if they
              // subsequently change
              if (alreadyStale) evalWatches.map(_.path -> scala.util.Random.nextLong())
              else evalWatches.map(p => p.path -> Interpreter.pathSignature(p.path))
            }
            (Res(res), () => interpWatched ++ watched())
          case _ => (result, () => interpWatched)
        }
      }
    )

  /**
    * Treats any `Res.Success` as success. Every other result is delegated to
    * the parent implementation.
    */
  override def handleWatchRes[T](res: Res[T], printing: Boolean): Boolean = {
    res match{
      case Res.Success(_) => true
      case _ => super.handleWatchRes(res, printing)
    }
  }

  /**
    * Builds the Ammonite `Main` used for each run.
    *
    * This installs Mill's [[CustomCodeWrapper]], forwards `config.wd`, and
    * adds `MillIvyHook` for `import $ivy` on top of the default import hooks.
    */
  override def initMain(isRepl: Boolean): Main = {
    val hooks = ImportHook.defaults + (Seq("ivy") -> MillIvyHook)
    super.initMain(isRepl).copy(
      scriptCodeWrapper = CustomCodeWrapper,
      // Ammonite does not properly forward the wd from CliConfig to Main, so
      // force forward it ourselves
      wd = config.wd,
      importHooks = hooks
    )
  }

  /**
    * Wraps build-script code so that the script body becomes the body of a
    * sealed trait extending `mill.main.MainModule`. A same-named object then
    * mixes in that trait and extends `mill.define.BaseModule`, rooted at the
    * script's directory.
    */
  object CustomCodeWrapper extends ammonite.interp.CodeWrapper {
    /**
      * @return the code placed before the user's code, the code placed after
      *         it, and `1` (presumably the nesting level Ammonite expects for
      *         the user code; verify against `ammonite.interp.CodeWrapper`)
      */
    def apply(code: String,
              source: CodeSource,
              imports: ammonite.util.Imports,
              printCode: String,
              indexedWrapperName: ammonite.util.Name,
              extraCode: String): (String, String, Int) = {
      import source.pkgName
      val wrapName = indexedWrapperName.backticked
      // Module root: the script's directory, or the working directory when
      // the source has no path
      val path = source
        .path
        .map(path => path.toNIO.getParent)
        .getOrElse(config.wd.toNIO)
      val literalPath = pprint.Util.literalize(path.toString)
      // Scripts outside the working directory are built as foreign modules
      val external = path.compareTo(config.wd.toNIO) != 0
      val top = s"""
        |package ${pkgName.head.encoded}
        |package ${Util.encodeScalaSourcePath(pkgName.tail)}
        |$imports
        |import mill._
        |object $wrapName
        |extends mill.define.BaseModule(os.Path($literalPath), foreign0 = $external)(
        |  implicitly, implicitly, implicitly, implicitly, mill.define.Caller(())
        |)
        |with $wrapName{
        |  // Stub to make sure Ammonite has something to call after it evaluates a script,
        |  // even if it does nothing...
        |  def $$main() = Iterator[String]()
        |
        |  // Need to wrap the returned Module in Some(...) to make sure it
        |  // doesn't get picked up during reflective child-module discovery
        |  def millSelf = Some(this)
        |
        |  implicit lazy val millDiscover: mill.define.Discover[this.type] = mill.define.Discover[this.type]
        |}
        |
        |sealed trait $wrapName extends mill.main.MainModule{
        |""".stripMargin
      // closes the `sealed trait` opened at the end of `top`
      val bottom = "}"

      (top, bottom, 1)
    }
  }
}