summaryrefslogtreecommitdiff
path: root/main/core/src/util/Scripts.scala
blob: f61d5cb5fcb6da7bf34b253d63448d26e62988ca (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
package mill.util

import java.nio.file.NoSuchFileException


import ammonite.interp.api.AmmoniteExit
import ammonite.util.Name.backtickWrap
import ammonite.util.Util.CodeSource
import ammonite.util.{Name, Res, Util}
import fastparse.internal.Util.literalize
import mill.util.Router.{ArgSig, EntryPoint}

/**
  * Logic around using Ammonite as a script-runner; invoking scripts via the
  * macro-generated [[Router]], and pretty-printing any output or error messages
  */
object Scripts {
  /**
    * Pairs up a flat list of command-line tokens into `(token, value)` tuples.
    *
    * A token that starts with `-` and is followed by at least one more token
    * consumes that following token as its value (even if the following token
    * itself starts with `-`). Every other token, including a trailing
    * `-flag` with nothing after it, is emitted with `None` as its value.
    *
    * @param flatArgs the raw tokens, in command-line order
    * @return the grouped tokens, in the same order
    */
  def groupArgs(flatArgs: List[String]): Seq[(String, Option[String])] = {
    @scala.annotation.tailrec
    def loop(remaining: List[String],
             grouped: Vector[(String, Option[String])]): Vector[(String, Option[String])] =
      remaining match {
        case Nil => grouped
        case flag :: value :: rest if flag.startsWith("-") =>
          loop(rest, grouped :+ ((flag, Some(value))))
        case token :: rest =>
          loop(rest, grouped :+ ((token, None)))
      }

    loop(flatArgs, Vector.empty)
  }

  /**
    * Runs the script at `path` through the Ammonite interpreter and then
    * invokes one of its `@main` entry points with `scriptArgs`.
    *
    * The work is a single `Res` for-comprehension, so any failing step
    * short-circuits the rest:
    *
    *  1. read the script text (a missing file becomes a `Res.Failure`),
    *  2. compile it with `interp.processModule`, appending a `$routes` object
    *     whose `apply()` returns the macro-generated entry points,
    *  3. reflectively load the script's wrapper object and its `$routes`
    *     object from `interp.evalClassloader`,
    *  4. dispatch on the number of entry points: none (only valid when no
    *     arguments were given), exactly one (receives all arguments), or
    *     several (the first argument names the subcommand, the remaining
    *     arguments are passed to it).
    *
    * @param wd         directory that `path` is made relative to when deriving
    *                   the package and wrapper names
    * @param path       the script file to run
    * @param interp     interpreter used to compile the script and to load the
    *                   compiled classes
    * @param scriptArgs arguments as `(token, optional value)` pairs, in the
    *                   shape produced by [[groupArgs]]
    * @return the `Res` yielded by the for-comprehension: the outcome of the
    *         selected entry point, `Res.Success(())` when there was nothing to
    *         run, or the first failure encountered
    */
  def runScript(wd: os.Path,
                path: os.Path,
                interp: ammonite.interp.Interpreter,
                scriptArgs: Seq[(String, Option[String])] = Nil) = {
    // Register the script with the interpreter's watch mechanism
    // (presumably so edits to the file can trigger a re-run -- confirm in Ammonite)
    interp.watch(path)
    val (pkg, wrapper) = Util.pathToPackageWrapper(Seq(), path relativeTo wd)

    for{
      // Only a missing file is translated into a Res.Failure; any other
      // exception thrown while reading propagates to the caller
      scriptTxt <- try Res.Success(Util.normalizeNewlines(os.read(path))) catch{
        case e: NoSuchFileException => Res.Failure("Script file not found: " + path)
      }

      processed <- interp.processModule(
        scriptTxt,
        CodeSource(wrapper, pkg, Seq(Name("ammonite"), Name("$file")), Some(path)),
        autoImport = true,
        // Not sure why we need to wrap this in a separate `$routes` object,
        // but if we don't do it for some reason the `generateRoutes` macro
        // does not see the annotations on the methods of the outer-wrapper.
        // It can inspect the type and its methods fine, it's just the
        // `methodsymbol.annotations` ends up being empty.
        extraCode = Util.normalizeNewlines(
          s"""
             |val $$routesOuter = this
             |object $$routes
             |extends scala.Function0[scala.Seq[ammonite.main.Router.EntryPoint[$$routesOuter.type]]]{
             |  def apply() = ammonite.main.Router.generateRoutes[$$routesOuter.type]
             |}
          """.stripMargin
        ),
        hardcoded = true
      )

      // The wrapper path of the last compiled block names the class that
      // holds the `$routes` object; no blocks at all means nothing to run
      routeClsName <- processed.blockInfo.lastOption match{
        case Some(meta) => Res.Success(meta.id.wrapperPath)
        case None => Res.Skip
      }

      // Trailing `$` selects the JVM class backing the Scala `object`.
      // `.last` is evaluated only after the `lastOption` generator above has
      // produced a value (presumably `Res.Skip` stops the comprehension
      // before this point -- confirm against `Res.map`)
      mainCls =
      interp
        .evalClassloader
        .loadClass(processed.blockInfo.last.id.wrapperPath + "$")

      routesCls =
      interp
        .evalClassloader
        .loadClass(routeClsName + "$$routes$")

      // `MODULE$` is the static field holding the singleton instance of a
      // Scala `object`; the `$routes` object is a Function0 returning the
      // entry points.
      // NOTE(review): the generated code above refers to
      // `ammonite.main.Router`, while `Router` here resolves within package
      // `mill.util`. The cast itself is unchecked, so a mismatch would only
      // surface when the entry points are used -- confirm the two are
      // compatible.
      scriptMains =
      routesCls
        .getField("MODULE$")
        .get(null)
        .asInstanceOf[() => Seq[Router.EntryPoint[Any]]]
        .apply()


      // Singleton instance of the script's wrapper object; passed as the
      // receiver (`base`) when invoking an entry point
      mainObj = mainCls.getField("MODULE$").get(null)

      res <- Util.withContextClassloader(interp.evalClassloader){
        scriptMains match {
          // If there are no @main methods, there's nothing to do
          case Seq() =>
            if (scriptArgs.isEmpty) Res.Success(())
            else {
              // Flatten the (token, value) pairs back into the original
              // token sequence for the error message
              val scriptArgString =
                scriptArgs.flatMap{case (a, b) => Seq(a) ++ b}.map(literalize(_))
                  .mkString(" ")

              Res.Failure("Script " + path.last + " does not take arguments: " + scriptArgString)
            }

          // If there's one @main method, we run it with all args
          case Seq(main) => runMainMethod(mainObj, main, scriptArgs)

          // If there are multiple @main methods, we use the first arg to decide
          // which method to run, and pass the rest to that main method
          case mainMethods =>
            val suffix = formatMainMethods(mainObj, mainMethods)
            scriptArgs match{
              case Seq() =>
                Res.Failure(
                  s"Need to specify a subcommand to call when running " + path.last + suffix
                )
              // The first token was grouped with a value, i.e. it started
              // with `-`, so it cannot be a subcommand name.
              // NOTE(review): `head.drop(2)` assumes a `--` prefix, but
              // `groupArgs` pairs any token starting with a single `-` --
              // confirm whether single-dash input matters here.
              case Seq((head, Some(_)), tail @ _*) =>
                Res.Failure(
                  "To select a subcommand to run, you don't need --s." + Util.newLine +
                    s"Did you mean `${head.drop(2)}` instead of `$head`?"
                )
              case Seq((head, None), tail @ _*) =>
                mainMethods.find(_.name == head) match{
                  case None =>
                    Res.Failure(
                      s"Unable to find subcommand: " + backtickWrap(head) + suffix
                    )
                  case Some(main) =>
                    runMainMethod(mainObj, main, tail)
                }
            }
        }
      }
    } yield res
  }
  /**
    * Renders the "Available subcommands" help section listing every entry
    * point in `mainMethods`, or the empty string when there are none.
    *
    * All entry points share one left-column width, computed over the
    * arguments of every method, so their argument descriptions line up.
    *
    * @param base        instance the entry points belong to; forwarded to the
    *                    signature renderer
    * @param mainMethods the entry points to list
    */
  def formatMainMethods[T](base: T, mainMethods: Seq[Router.EntryPoint[T]]) = {
    if (mainMethods.nonEmpty) {
      val sharedColWidth = getLeftColWidth(mainMethods.flatMap(_.argSignatures))
      // Each signature is indented by 2 beneath the section header
      val renderedSignatures =
        mainMethods.map(formatMainMethodSignature(base, _, 2, sharedColWidth))

      Util.normalizeNewlines(
        s"""
           |
           |Available subcommands:
           |
           |${renderedSignatures.mkString(Util.newLine)}""".stripMargin
      )
    } else ""
  }
  /**
    * Width of the left-hand (argument name) column needed to render `items`:
    * the length of the longest argument name plus 2, or 0 when `items` is
    * empty.
    */
  def getLeftColWidth[T](items: Seq[ArgSig[T, _]]) = {
    if (items.isEmpty) 0
    else items.map(sig => sig.name.length + 2).max
  }
  /**
    * Renders one entry point as help text: its name, its optional doc string
    * on the following line, and then one line per argument made of the
    * argument name padded to `leftColWidth` followed by its description.
    *
    * @param base         instance the entry point belongs to; used to evaluate
    *                     argument defaults for display
    * @param main         the entry point to render
    * @param leftIndent   number of spaces the whole signature is indented by
    * @param leftColWidth width the argument-name column is padded to
    */
  def formatMainMethodSignature[T](base: T,
                                   main: Router.EntryPoint[T],
                                   leftIndent: Int,
                                   leftColWidth: Int) = {
    val indent = " " * leftIndent
    // Column where argument descriptions start: the indent, the 2 spaces in
    // front of the name, the name column and the 2-space gutter after it
    val descriptionOffset = leftColWidth + leftIndent + 2 + 2

    val renderedArgs = main.argSignatures.map(renderArg(base, _, descriptionOffset, 80))

    val argLines = renderedArgs.map { case (lhs, rhs) =>
      val paddedName = lhs.padTo(leftColWidth, ' ')
      val description = rhs.linesIterator.mkString(Util.newLine)
      s"$indent  $paddedName  $description"
    }

    val docSuffix =
      main.doc.map(d => Util.newLine + indent + softWrap(d, leftIndent, 80)).getOrElse("")

    s"""$indent${main.name}$docSuffix
       |${argLines.map(_ + Util.newLine).mkString}""".stripMargin
  }
  /**
    * Invokes `mainMethod` on `base` with `scriptArgs` and converts the
    * router's result into a `Res`.
    *
    *  - a successful invocation becomes `Res.Success` of the returned value,
    *  - an `AmmoniteExit` thrown by the method is treated as a normal exit
    *    and becomes `Res.Success` of the exit value,
    *  - any other exception becomes `Res.Exception`,
    *  - argument mismatches and unparseable arguments become a `Res.Failure`
    *    whose message lists the problems followed by the expected signature.
    *
    * @param base       instance to invoke the entry point on
    * @param mainMethod the entry point to run
    * @param scriptArgs arguments as `(token, optional value)` pairs
    */
  def runMainMethod[T](base: T,
                       mainMethod: Router.EntryPoint[T],
                       scriptArgs: Seq[(String, Option[String])]): Res[Any] = {
    val leftColWidth = getLeftColWidth(mainMethod.argSignatures)

    // A `def`, so the signature is only rendered by the branches that report it
    def expectedMsg = formatMainMethodSignature(base: T, mainMethod, 0, leftColWidth)

    def pluralize(s: String, n: Int) = if (n == 1) s else s + "s"

    mainMethod.invoke(base, scriptArgs) match{
      case Router.Result.Success(x) => Res.Success(x)
      case Router.Result.Error.Exception(x: AmmoniteExit) => Res.Success(x.value)
      case Router.Result.Error.Exception(x) => Res.Exception(x, "")

      case Router.Result.Error.MismatchedArguments(missing, unknown, duplicate, incomplete) =>
        // Each section is either empty or a complete, newline-terminated line
        val missingStr =
          if (missing.nonEmpty) {
            val chunks = missing.map(x => "--" + x.name + ": " + x.typeString)
            val argumentsStr = pluralize("argument", chunks.length)
            s"Missing $argumentsStr: (${chunks.mkString(", ")})" + Util.newLine
          } else ""

        val unknownStr =
          if (unknown.nonEmpty) {
            val argumentsStr = pluralize("argument", unknown.length)
            s"Unknown $argumentsStr: " + unknown.map(literalize(_)).mkString(" ") + Util.newLine
          } else ""

        val duplicateStr =
          if (duplicate.nonEmpty) {
            duplicate.map { case (sig, options) =>
              s"Duplicate arguments for (--${sig.name}: ${sig.typeString}): " +
                options.map(literalize(_)).mkString(" ") + Util.newLine
            }.mkString
          } else ""

        val incompleteStr = incomplete.map { sig =>
          s"Option (--${sig.name}: ${sig.typeString}) is missing a corresponding value" +
            Util.newLine
        }.getOrElse("")

        Res.Failure(
          Util.normalizeNewlines(
            s"""$missingStr$unknownStr$duplicateStr$incompleteStr
               |Arguments provided did not match expected signature:
               |
               |$expectedMsg
               |""".stripMargin
          )
        )

      case Router.Result.Error.InvalidArguments(x) =>
        val argumentsStr = pluralize("argument", x.length)
        val thingies = x.map{
          case Router.Result.ParamError.Invalid(p, v, ex) =>
            s"${renderArgShort(p)}: ${p.typeString} = ${literalize(v)} failed to parse with $ex"
          case Router.Result.ParamError.DefaultFailed(p, ex) =>
            s"${renderArgShort(p)}'s default value failed to evaluate with $ex"
        }

        // The closing delimiter's line has no margin character, so its
        // leading whitespace is part of the emitted message
        Res.Failure(
          Util.normalizeNewlines(
            s"""The following $argumentsStr failed to parse:
               |
               |${thingies.mkString(Util.newLine)}
               |
               |expected signature:
               |
               |$expectedMsg
            """.stripMargin
          )
        )
    }
  }

  /**
    * Re-flows `s` into lines of at most `maxWidth` characters, breaking only
    * at single spaces.
    *
    * Existing line breaks in `s` are first replaced by spaces. Words are then
    * appended greedily; when the next word would push the current line past
    * `maxWidth`, a line break followed by `leftOffset` spaces is emitted in
    * place of the separating space. The indent is not counted towards
    * `maxWidth` and the first line is not indented, so callers account for
    * both themselves. A single word longer than `maxWidth` is never split.
    *
    * @param s          text to wrap; may be empty or consist only of spaces
    * @param leftOffset number of spaces inserted after every line break
    * @param maxWidth   maximum width of a line, excluding the indent
    * @return the wrapped text, or the empty string when `s` contains no words
    */
  def softWrap(s: String, leftOffset: Int, maxWidth: Int) = {
    val words = s.linesIterator.mkString(" ").split(' ')

    words.headOption match {
      // `split` drops trailing empty strings, so input made up solely of
      // spaces (e.g. " ") yields an empty array; `.head` would throw on it
      case None => ""
      case Some(first) =>
        val lineBreak = Util.newLine + " " * leftOffset
        val output = new StringBuilder(first)
        var currentLineWidth = first.length
        for (word <- words.tail) {
          // +1 for the space separating this word from the previous one
          val widthWithWord = currentLineWidth + word.length + 1
          if (widthWithWord > maxWidth) {
            output.append(lineBreak)
            output.append(word)
            currentLineWidth = word.length
          } else {
            output.append(' ')
            output.append(word)
            currentLineWidth = widthWithWord
          }
        }
        output.mkString
    }
  }
  /** Renders an argument as `--` followed by its backtick-wrapped name. */
  def renderArgShort[T](arg: ArgSig[T, _]) = s"--${backtickWrap(arg.name)}"
  /**
    * Renders one argument as a `(name, description)` pair for help output.
    *
    * The description is the argument's type, followed by its default value
    * (evaluated against `base`) and its doc string when present, soft-wrapped
    * so that it fits between `leftOffset` and `wrappedWidth`.
    *
    * @param base         instance used to evaluate the argument's default
    * @param arg          the argument to render
    * @param leftOffset   column at which the description starts; also used as
    *                     the indent of wrapped continuation lines
    * @param wrappedWidth total line width available, including `leftOffset`
    * @return the short `--name` form and the wrapped description
    */
  def renderArg[T](base: T,
                   arg: ArgSig[T, _],
                   leftOffset: Int,
                   wrappedWidth: Int): (String, String) = {
    val defaultNote = arg.default.map(f => " (default " + f(base) + ")").getOrElse("")
    val docNote = arg.doc.map(d => ": " + d).getOrElse("")

    val description = softWrap(
      arg.typeString + defaultNote + docNote,
      leftOffset,
      wrappedWidth - leftOffset
    )
    renderArgShort(arg) -> description
  }


  /**
    * Lists the documented arguments of `ep`, one per line in the form
    * `name // doc`, each preceded by a line break. Arguments without a doc
    * string are omitted.
    */
  def mainMethodDetails[T](ep: EntryPoint[T]) = {
    ep.argSignatures
      .flatMap(sig => sig.doc.map(doc => Util.newLine + sig.name + " // " + doc))
      .mkString
  }

  /**
    * Extra [[scopt.Read]] instance that lets scopt parse `os.Path` arguments:
    * the raw string is turned into a path via `os.Path(_, os.pwd)`, i.e. with
    * the current working directory as the base.
    */
  implicit def pathScoptRead: scopt.Read[os.Path] =
    scopt.Read.stringRead.map(rawPath => os.Path(rawPath, os.pwd))

}