summaryrefslogtreecommitdiff
path: root/src/compiler/scala/tools/nsc/ScriptRunner.scala
blob: 40ad350fdaf7ec96b828dd18330dfe49d8332c8b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
/* NSC -- new Scala compiler
 * Copyright 2005-2007 LAMP/EPFL
 * @author  Martin Odersky
 */
// $Id$

package scala.tools.nsc

import java.io.{BufferedReader, File, FileInputStream, FileOutputStream,
                FileReader, InputStreamReader, PrintWriter}
import java.lang.reflect.InvocationTargetException
import java.util.jar.{JarEntry, JarOutputStream}

import scala.tools.nsc.io.PlainFile
import scala.tools.nsc.reporters.ConsoleReporter
import scala.tools.nsc.util.{CompoundSourceFile, SourceFile, SourceFileFragment}

/** An object that runs Scala code in script files.
 *
 *  <p>For example, here is a complete Scala script on Unix:</p>
 *  <pre>
 *    #!/bin/sh
 *    exec scala "$0" "$@"
 *    !#
 *    Console.println("Hello, world!")
 *    argv.toList foreach Console.println
 *  </pre>
 *  <p>And here is a batch file example on Windows XP:</p>
 *  <pre>
 *    ::#!
 *    @echo off
 *    call scala %0 %*
 *    goto :eof
 *    ::!#
 *    Console.println("Hello, world!")
 *    argv.toList foreach Console.println
 *  </pre>
 *
 *  @author  Lex Spoon
 *  @version 1.0, 15/05/2006
 *  @todo    It would be better if error output went to stderr instead
 *           of stdout...
 */
object ScriptRunner {

  /** Pick the jar file that caches the compiled form of a script.
   *
   *  If the script name ends in an extension (a dot followed by
   *  characters that are neither a dot nor a path separator), that
   *  extension is replaced by <code>.jar</code>; otherwise
   *  <code>.jar</code> is appended to the name.
   *
   *  @param scriptFile the path of the script
   *  @return           the jar file associated with the script
   */
  private def jarFileFor(scriptFile: String): File = {
    val hasExtension = scriptFile.matches(".*\\.[^.\\\\/]*")

    val jarName =
      if (hasExtension) scriptFile.replaceFirst("\\.[^.\\\\/]*$", ".jar")
      else scriptFile + ".jar"

    new File(jarName)
  }

  /** Try to create a jar file out of all the contents
   *  of the directory <code>sourcePath</code>.
   *
   *  This is a best-effort operation: if an I/O failure (or a JVM
   *  <code>Error</code>) occurs while writing, the partially written
   *  jar is deleted and the method returns normally.  Callers should
   *  check afterwards whether the jar file is usable.
   *
   *  @param jarFile    the jar file to create
   *  @param sourcePath the directory whose files are added recursively;
   *                    entry names are relative to it and use "/" as
   *                    separator
   */
  private def tryMakeJar(jarFile: File, sourcePath: File): Unit = {
    try {
      val jar = new JarOutputStream(new FileOutputStream(jarFile))
      try {
        val buf = new Array[Byte](10240)

        // Copy a single file into the jar under the given entry name.
        def addFile(file: File, entryName: String): Unit = {
          jar.putNextEntry(new JarEntry(entryName))
          val input = new FileInputStream(file)
          try {
            var n = input.read(buf, 0, buf.length)
            while (n >= 0) {
              jar.write(buf, 0, n)
              n = input.read(buf, 0, buf.length)
            }
          } finally {
            input.close()
          }
          jar.closeEntry()
        }

        def addFromDir(dir: File, prefix: String): Unit = {
          // listFiles returns null when dir is not a readable directory
          val entries = dir.listFiles
          if (entries ne null) {
            for (entry <- entries) {
              if (entry.isFile)
                addFile(entry, prefix + entry.getName)
              else
                addFromDir(entry, prefix + entry.getName + "/")
            }
          }
        }

        addFromDir(sourcePath, "")
      } finally {
        // close before any cleanup so the file is no longer held open
        // when it gets deleted below
        jar.close()
      }
    } catch {
      // I/O problems are the expected failure mode of a best-effort write
      case _: java.io.IOException => jarFile.delete()
      // kept from the original implementation
      case _: Error => jarFile.delete()
    }
  }

  /** Read the entire contents of a file as a String.
   *
   *  The file is decoded by <code>java.io.FileReader</code>, and the
   *  reader is always closed, even when reading fails.
   *
   *  @param filename the path of the file to read
   *  @return         the complete text of the file
   */
  private def contentsOfFile(filename: String): String = {
    val reader = new FileReader(filename)
    try {
      val strbuf = new StringBuilder
      val cbuf = new Array[Char](1024)
      var n = reader.read(cbuf)
      // read yields a non-positive count once the input is exhausted
      while (n > 0) {
        strbuf.append(cbuf, 0, n)
        n = reader.read(cbuf)
      }
      strbuf.toString
    } finally {
      reader.close()
    }
  }

  /** Find the length of the header in the specified file, if
    * there is one.  The header part starts with "#!" or "::#!"
    * and ends with a line that begins with "!#" or "::!#".
    *
    * The returned length includes the whole line terminator of the
    * closing line, whether it is "\n", "\r" or "\r\n".
    *
    * @param filename the path of the script file
    * @return         the number of characters taken up by the header,
    *                 or 0 if the file does not start with a header
    * @throws Error   if a header is opened but never closed
    */
  private def headerLength(filename: String): Int = {
    import java.util.regex._

    val fileContents = contentsOfFile(filename)

    if (fileContents.startsWith("#!") || fileContents.startsWith("::#!")) {
      // "\r\n" must be tried before "\r": alternatives are attempted
      // left to right, so listing "\r" first would leave the "\n" of a
      // CRLF terminator outside the header.
      val matcher =
        (Pattern.compile("^(::)?!#.*(\\r\\n|\\r|\\n)", Pattern.MULTILINE)
                .matcher(fileContents))
      if (! matcher.find)
        throw new Error("script file does not close its header with !# or ::!#")

      matcher.end
    } else {
      0
    }
  }

  /** Wrap a script file into a runnable object named
   *  <code>$scalascript.Main</code>.
   *
   *  The result is the concatenation of three pieces: a preamble that
   *  opens the package, the <code>Main</code> object and its
   *  <code>main</code> method (binding <code>args</code> to
   *  <code>argv</code>); the script text that follows the header; and
   *  a trailer that closes the method and the object.
   *
   *  @param filename      the path of the script file
   *  @param getSourceFile function used to obtain the source file
   *                       for the script
   *  @return              the compound source file for the wrapped script
   */
  def wrappedScript(filename: String, getSourceFile: PlainFile => SourceFile): SourceFile = {
    val preambleText =
      "package $scalascript\n" +
      "object Main {\n" +
      "  def main(argv: Array[String]): Unit = {\n" +
      "  val args = argv;\n"
    val trailerText = "\n} }\n"

    val preamble = new SourceFile("<script preamble>", preambleText.toCharArray)

    // The script body runs from the end of the header up to the size
    // reported by File.length.
    // NOTE(review): the header length is counted in characters while
    // File.length is in bytes -- confirm SourceFileFragment copes with
    // that for multi-byte encodings.
    val scriptFile = new File(filename)
    val underlying = getSourceFile(new PlainFile(scriptFile))
    val start = headerLength(filename)
    val stop = scriptFile.length.toInt
    val body = new SourceFileFragment(underlying, start, stop)

    val trailer = new SourceFile("<script trailer>", trailerText.toCharArray)

    new CompoundSourceFile(preamble, body, trailer)
  }

  /** Compile a script using the fsc compilation daemon.
   *
   *  The path-valued settings (classpath, sourcepath, bootclasspath,
   *  extdirs and outdir) are rewritten in place through
   *  <code>CompileClient.absFileNames</code> before being sent.  Every
   *  line received from the server is echoed to the console.
   *
   *  @param settings     the runner settings; only those settings whose
   *                      name is also a setting of <code>Settings</code>
   *                      are forwarded to the daemon
   *  @param scriptFileIn the path of the script to compile
   *  @return             <code>false</code> if any line of server output
   *                      matched <code>CompileSocket.errorPattern</code>,
   *                      <code>true</code> otherwise
   */
  private def compileWithDaemon(
      settings: GenericRunnerSettings,
      scriptFileIn: String): Boolean =
  {
    val scriptFile = CompileClient.absFileName(scriptFileIn)
    for (setting:settings.StringSetting <- List(
            settings.classpath,
            settings.sourcepath,
            settings.bootclasspath,
            settings.extdirs,
            settings.outdir))
      setting.value = CompileClient.absFileNames(setting.value)

    val compSettingNames =
      (new Settings(error)).allSettings.map(stg => stg.name)

    val compSettings =
      settings.allSettings.filter(stg =>
        compSettingNames.contains(stg.name))

    val coreCompArgs =
      compSettings.foldLeft[List[String]](Nil)((args, stg) =>
        stg.unparse ::: args)

    val compArgs = coreCompArgs ::: List("-Xscript", scriptFile)

    val socket = CompileSocket.getOrCreateSocket("")
    try {
      val out = new PrintWriter(socket.getOutputStream(), true)
      val in = new BufferedReader(new InputStreamReader(socket.getInputStream()))
      try {
        // first the password, then the arguments separated by NUL characters
        out.println(CompileSocket.getPassword(socket.getPort))
        out.println(compArgs.mkString("", "\0", ""))

        var compok = true

        var fromServer = in.readLine()
        while (fromServer ne null) {
          Console.println(fromServer)
          if (CompileSocket.errorPattern.matcher(fromServer).matches)
            compok = false

          fromServer = in.readLine()
        }

        compok
      } finally {
        in.close()
        out.close()
      }
    } finally {
      // release the connection even if talking to the server failed
      socket.close()
    }
  }

  /** Compile a script and then run the specified closure with
    * the location of the compiled script, which is either a
    * directory of class files or a jar file.
    *
    * When <code>settings.savecompiled</code> is set, a jar next to the
    * script is used as a cache: it is reused if it is readable and
    * newer than the script, and rebuilt otherwise.  The temporary
    * directory holding the class files is always deleted afterwards.
    * The handler is not invoked when compilation fails.
    *
    * @param settings   the runner settings; <code>outdir</code> is
    *                   overwritten when the script gets compiled
    * @param scriptFile the path of the script
    * @param handler    the closure that receives the location of the
    *                   compiled script
    */
  private def withCompiledScript
        (settings: GenericRunnerSettings, scriptFile: String)
        (handler: String => Unit)
        : Unit =
  {
    import Interpreter.deleteRecursively

    /** Compiles the script into a fresh temporary directory.
      * Returns that directory together with a flag telling
      * whether the compilation succeeded.
      */
    def compile: (File, Boolean) = {
      // createTempFile yields a plain file; swap it for a directory
      val outputDir = File.createTempFile("scalascript", "")
      outputDir.delete
      outputDir.mkdirs

      settings.outdir.value = outputDir.getPath

      val succeeded =
        if (settings.nocompdaemon.value) {
          val reporter = new ConsoleReporter(settings)
          val compiler = new Global(settings, reporter)
          val run = new compiler.Run
          val source = wrappedScript(scriptFile, f => compiler.getSourceFile(f))
          run.compileSources(List(source))
          !reporter.hasErrors
        } else {
          compileWithDaemon(settings, scriptFile)
        }

      (outputDir, succeeded)
    }

    if (!settings.savecompiled.value) {
      // no cache requested: run straight from the temporary directory
      val (classDir, succeeded) = compile
      try {
        if (succeeded)
          handler(classDir.getPath)
      } finally {
        deleteRecursively(classDir)
      }
    } else {
      val cachedJar = jarFileFor(scriptFile)

      def jarIsCurrent =
        cachedJar.canRead &&
        cachedJar.lastModified > new File(scriptFile).lastModified

      if (jarIsCurrent) {
        // the cached jar is up to date
        handler(cachedJar.getAbsolutePath)
      } else {
        // the cached jar is missing or stale: recompile the script
        cachedJar.delete
        val (classDir, succeeded) = compile
        try {
          if (succeeded) {
            tryMakeJar(cachedJar, classDir)
            val location =
              if (jarIsCurrent) {
                // the jar was written; the class files are no longer needed
                deleteRecursively(classDir)
                cachedJar.getAbsolutePath
              } else {
                // no usable jar: fall back to the temporary directory
                classDir.getPath
              }
            handler(location)
          }
        } finally {
          deleteRecursively(classDir)
        }
      }
    }
  }

  /** Run a script file with the specified arguments and compilation
   *  settings.
   *
   *  If the file does not exist or is a directory, a message is printed
   *  and nothing is run.  Otherwise the script is compiled and
   *  <code>$scalascript.Main</code> is run on a classpath made of the
   *  boot classpath, the compiled script and the regular classpath, in
   *  that order.  An exception raised by the script is reported by
   *  printing the stack trace of its cause.
   *
   *  @param settings   the runner and compiler settings
   *  @param scriptFile the path of the script to run
   *  @param scriptArgs the arguments passed to the script
   */
  def runScript(
      settings: GenericRunnerSettings,
      scriptFile: String,
      scriptArgs: List[String]): Unit =
  {
    def splitPath(path: String): List[String] =
      path.split(File.pathSeparator).toList

    val script = new File(scriptFile)
    if (script.exists && !script.isDirectory) {
      withCompiledScript(settings, scriptFile)(compiledLocation => {
        val bootEntries = splitPath(settings.bootclasspath.value)
        val userEntries = splitPath(settings.classpath.value)
        val classpath = bootEntries ::: (compiledLocation :: userEntries)

        try {
          ObjectRunner.run(classpath, "$scalascript.Main", scriptArgs.toArray)
        } catch {
          case e: InvocationTargetException =>
            e.getCause.printStackTrace
        }
      })
    } else {
      scala.Console.println("no such file: " + scriptFile)
    }
  }
}