summaryrefslogtreecommitdiff
path: root/src/compiler/scala/tools/nsc/ScriptRunner.scala
blob: 3266ec3799fd966865f66c2906ae2ff6939a2257 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
/* NSC -- new Scala compiler
 * Copyright 2005-2007 LAMP/EPFL
 * @author  Martin Odersky
 */
// $Id$

package scala.tools.nsc

import java.io.{BufferedReader, File, FileInputStream, FileOutputStream,
                FileReader, InputStreamReader, PrintWriter}
import java.lang.reflect.InvocationTargetException
import java.net.URL
import java.util.jar.{JarEntry, JarOutputStream}

import scala.tools.nsc.io.PlainFile
import scala.tools.nsc.reporters.{Reporter,ConsoleReporter}
import scala.tools.nsc.util.{ClassPath, CompoundSourceFile, SourceFile, SourceFileFragment}

/** An object that runs Scala code in script files.
 *
 *  <p>For example, here is a complete Scala script on Unix:</p>
 *  <pre>
 *    #!/bin/sh
 *    exec scala "$0" "$@"
 *    !#
 *    Console.println("Hello, world!")
 *    argv.toList foreach Console.println
 *  </pre>
 *  <p>And here is a batch file example on Windows XP:</p>
 *  <pre>
 *    ::#!
 *    @echo off
 *    call scala %0 %*
 *    goto :eof
 *    ::!#
 *    Console.println("Hello, world!")
 *    argv.toList foreach Console.println
 *  </pre>
 *
 *  @author  Lex Spoon
 *  @version 1.0, 15/05/2006
 *  @todo    It would be better if error output went to stderr instead
 *           of stdout...
 */
class ScriptRunner {
  /** The fsc client that <code>compileWithDaemon</code> uses to turn file
   *  names into absolute ones.  Protected so a subclass can substitute
   *  a different client.
   */
  protected def compileClient: StandardCompileClient = CompileClient //todo: lazy val

  /** The socket helper that <code>compileWithDaemon</code> uses to reach
   *  the fsc compilation daemon.  Protected so a subclass can substitute
   *  a different one.
   */
  protected def compileSocket: CompileSocket = CompileSocket //todo: make lazy val


  /** Default name to use for the wrapped script */
  val defaultScriptMain = "Main"

  /** Pick a main object name from the specified settings: the value of the
   *  <code>script</code> setting, or <code>defaultScriptMain</code> when
   *  that setting is empty.
   */
  def scriptMain(settings: Settings) =
    if (settings.script.value == "")
      defaultScriptMain
    else
      settings.script.value

  /** Choose a jar filename to hold the compiled version of a script.
   *  If the last path component of <code>scriptFile</code> has an extension,
   *  it is replaced by <code>.jar</code>; otherwise <code>.jar</code> is
   *  appended.
   */
  private def jarFileFor(scriptFile: String): File = {
    val filename =
      if (scriptFile.matches(".*\\.[^.\\\\/]*"))
        scriptFile.replaceFirst("\\.[^.\\\\/]*$", ".jar")
      else
        scriptFile + ".jar"

    new File(filename)
  }

  /** Try to create a jar file out of all the contents
   *  of the directory <code>sourcePath</code>.
   *
   *  This is best-effort.  If writing fails (an I/O error, an unreadable
   *  directory, a security violation), the jar is deleted.  The caller then
   *  finds no usable jar and runs from the class files in
   *  <code>sourcePath</code> instead.
   *
   *  @param jarFile    the jar to create (overwritten if it exists)
   *  @param sourcePath directory whose files become jar entries, named
   *                    relative to this directory
   */
  private def tryMakeJar(jarFile: File, sourcePath: File): Unit = {
    val buf = new Array[Byte](10240)

    /** Copy all bytes of <code>file</code> into the current jar entry. */
    def copyFile(file: File, jar: JarOutputStream): Unit = {
      val input = new FileInputStream(file)
      try {
        var n = input.read(buf, 0, buf.length)
        while (n >= 0) {
          jar.write(buf, 0, n)
          n = input.read(buf, 0, buf.length)
        }
      } finally {
        input.close
      }
    }

    /** Add every file below <code>dir</code> to <code>jar</code>;
     *  <code>prefix</code> is the path of <code>dir</code> inside the jar.
     */
    def addFromDir(jar: JarOutputStream, dir: File, prefix: String): Unit = {
      val entries = dir.listFiles
      // listFiles returns null when the directory cannot be read
      if (entries eq null)
        throw new java.io.IOException("cannot list directory " + dir)
      for (entry <- entries) {
        if (entry.isFile) {
          jar.putNextEntry(new JarEntry(prefix + entry.getName))
          copyFile(entry, jar)
          jar.closeEntry
        } else if (entry.isDirectory) {
          addFromDir(jar, entry, prefix + entry.getName + "/")
        }
      }
    }

    var written = false
    try {
      val jar = new JarOutputStream(new FileOutputStream(jarFile))
      try {
        addFromDir(jar, sourcePath, "")
      } finally {
        jar.close
      }
      written = true
    } catch {
      // I/O and security failures are Exceptions, not Errors; catching only
      // Error let them escape instead of falling back to the class files.
      case _: Exception => ()
    } finally {
      // Never leave a truncated jar behind: a later run would find it newer
      // than the script and try to run it.
      if (!written) jarFile.delete
    }
  }

  /** Read the entire contents of a file as a String.
   *
   *  The file is decoded by <code>FileReader</code>, i.e. with the platform
   *  default charset.  The reader is closed even if reading fails.
   */
  private def contentsOfFile(filename: String): String = {
    val strbuf = new StringBuilder
    val reader = new FileReader(filename)
    try {
      val cbuf = new Array[Char](1024)
      var n = reader.read(cbuf)
      while (n > 0) {
        strbuf.append(cbuf, 0, n)
        n = reader.read(cbuf)
      }
    } finally {
      reader.close
    }
    strbuf.toString
  }

  /** Find the length, in characters, of the header in the specified file,
   *  if there is one.  The header part starts with "#!" or "::#!" and ends
   *  with a line that begins with "!#" or "::!#".  The closing line's
   *  terminator ("\r\n", "\r" or "\n") is counted as part of the header.
   *  The closing line may also be the last line of the file.
   *
   *  NOTE(review): the offset is computed on text decoded with the platform
   *  default charset, while the compiler may decode the script with another
   *  encoding -- confirm they agree for non-ASCII headers.
   *
   *  @return 0 if the file has no header
   *  @throws Error if a header is opened but never closed
   */
  private def headerLength(filename: String): Int = {
    import java.util.regex._

    val fileContents = contentsOfFile(filename)

    if (!(fileContents.startsWith("#!") || fileContents.startsWith("::#!")))
      0
    else {
      // "\r\n" must be tried before "\r" so that a Windows line ending is
      // consumed whole; "$" accepts a closing line at end of input.
      val matcher =
        (Pattern.compile("^(::)?!#.*(\\r\\n|\\r|\\n|$)", Pattern.MULTILINE)
                .matcher(fileContents))
      if (!matcher.find)
        throw new Error("script file does not close its header with !# or ::!#")

      matcher.end
    }
  }

  /** Split a fully qualified object name into a
   *  package and an unqualified object name.
   *  For example "a.b.Main" gives (Some("a.b"), "Main") and "Main" gives
   *  (None, "Main").
   */
  private def splitObjectName(fullname: String):
  (Option[String],String) =
  {
    val idx = fullname.lastIndexOf('.')
    if (idx < 0)
      (None, fullname)
    else
      (Some(fullname.substring(0,idx)), fullname.substring(idx+1))
  }

  /** Code that is added to the beginning of a script file to make
   *  it a complete Scala compilation unit.  It contains an optional package
   *  clause, then opens <code>object objName</code> and its
   *  <code>main</code> method, which binds <code>args</code> to
   *  <code>argv</code>.
   */
  protected def preambleCode(objectName: String) = {
    val (maybePack, objName) = splitObjectName(objectName)

    val packageDecl =
      maybePack match {
        case Some(pack) => "package " + pack + "\n"
        case None       => ""
      }

    (packageDecl +
     "object " + objName + " {\n" +
     "  def main(argv: Array[String]): Unit = {\n" +
     "  val args = argv;\n")
  }

  /** Code that is added to the end of a script file to make
   *  it a complete Scala compilation unit.  It closes the
   *  <code>main</code> method and the object opened by
   *  <code>preambleCode</code>.
   */
  val endCode = "\n} }\n"


  /** Wrap a script file into a runnable object named
   *  <code>objectName</code>.  The result is the preamble, then the
   *  script body without its header, then the trailer.
   *
   *  @param objectName    possibly package-qualified name of the wrapper object
   *  @param filename      the script file
   *  @param getSourceFile how to load the script file as a source file
   */
  def wrappedScript(
    objectName: String,
    filename: String,
    getSourceFile: PlainFile => SourceFile): SourceFile =
  {
    val preamble =
      new SourceFile("<script preamble>", preambleCode(objectName).toCharArray)

    val middle = {
      val scriptSource = getSourceFile(new PlainFile(new File(filename)))
      // End at the decoded content's length in characters, not at the
      // file's length in bytes: the two differ for non-ASCII scripts.
      new SourceFileFragment(
          scriptSource,
          headerLength(filename),
          scriptSource.content.length)
    }

    val end = new SourceFile("<script trailer>", endCode.toCharArray)

    new CompoundSourceFile(preamble, middle, end)
  }

  /** Compile a script using the fsc compilation daemon.
   *
   *  The classpath, sourcepath, bootclasspath, extdirs and outdir settings
   *  are made absolute in place.  Presumably this is because the daemon does
   *  not share this process's working directory.  Every setting that a plain
   *  compiler <code>Settings</code> also defines is forwarded, followed by
   *  <code>-Xscript</code>, the main object name and the script file.
   *  Each line the daemon sends back is echoed to the console.
   *
   *  @param settings     runner settings; path settings are rewritten in place
   *  @param scriptFileIn the script file, possibly relative
   *  @return             false if no daemon socket could be obtained or if any
   *                      output line matched the daemon's error pattern
   */
  private def compileWithDaemon(
      settings: GenericRunnerSettings,
      scriptFileIn: String): Boolean =
  {
    val scriptFile = compileClient.absFileName(scriptFileIn)
    for (setting:settings.StringSetting <- List(
            settings.classpath,
            settings.sourcepath,
            settings.bootclasspath,
            settings.extdirs,
            settings.outdir))
      setting.value = compileClient.absFileNames(setting.value)

    val compSettingNames =
      (new Settings(error)).allSettings.map(_.name)

    val compSettings =
      settings.allSettings.filter(stg =>
        compSettingNames.contains(stg.name))

    val coreCompArgs =
      compSettings.foldLeft[List[String]](Nil)((args, stg) =>
        stg.unparse ::: args)

    val compArgs =
      (coreCompArgs :::
        List("-Xscript", scriptMain(settings), scriptFile))

    val socket = compileSocket.getOrCreateSocket("")
    if (socket eq null)
      false
    else {
      try {
        val out = new PrintWriter(socket.getOutputStream(), true)
        val in = new BufferedReader(new InputStreamReader(socket.getInputStream()))

        out.println(compileSocket.getPassword(socket.getPort))
        // arguments are sent on one line, separated by NUL characters
        out.println(compArgs.mkString("", "\0", ""))

        var compok = true

        var fromServer = in.readLine()
        while (fromServer ne null) {
          Console.println(fromServer)
          if (compileSocket.errorPattern.matcher(fromServer).matches)
            compok = false

          fromServer = in.readLine()
        }
        in.close()
        out.close()

        compok
      } finally {
        // also releases the socket if talking to the daemon failed
        socket.close()
      }
    }
  }

  /** Create the compiler used when the compilation daemon is disabled.
   *  Protected so a subclass can supply a different <code>Global</code>.
   */
  protected def newGlobal(settings: Settings, reporter: Reporter) =
    new Global(settings, reporter)

  /** Compile a script and then run the specified closure with
   *  a classpath for the compiled script.  The classpath is either a cached
   *  jar (when <code>savecompiled</code> is set) or a temporary directory
   *  of class files.  The handler is not called if compilation fails.
   */
  private def withCompiledScript
        (settings: GenericRunnerSettings, scriptFile: String)
        (handler: String => Unit)
        : Unit =
  {
    import Interpreter.deleteRecursively

    /** Compiles the script file, and returns two things:
      * the directory with the compiled class files,
      * and a flag for whether the compilation succeeded.
      * The directory is deleted when the JVM shuts down.
      */
    def compile: (File, Boolean) = {
      val compiledPath = File.createTempFile("scalascript", "")
      compiledPath.delete  // the file is created as a file; make it a directory
      compiledPath.mkdirs

      // delete the directory after the user code has finished
      Runtime.getRuntime.addShutdownHook(new Thread {
        override def run(): Unit = deleteRecursively(compiledPath)
      })

      settings.outdir.value = compiledPath.getPath

      if (settings.nocompdaemon.value) {
        val reporter = new ConsoleReporter(settings)
        val compiler = newGlobal(settings, reporter)
        val cr = new compiler.Run
        val wrapped =
          wrappedScript(
            scriptMain(settings),
            scriptFile,
            compiler.getSourceFile _)
        cr.compileSources(List(wrapped))
        (compiledPath, !reporter.hasErrors)
      } else {
        val compok = compileWithDaemon(settings, scriptFile)
        (compiledPath, compok)
      }
    }

    if (settings.savecompiled.value) {
      val jarFile = jarFileFor(scriptFile)

      // the cached jar is usable only if it is newer than the script
      def jarOK = (jarFile.canRead &&
        (jarFile.lastModified > new File(scriptFile).lastModified))

      if (jarOK) {
        // pre-compiled jar is current
        handler(jarFile.getAbsolutePath)
      } else {
        // The pre-compiled jar is old.  Recompile the script.
        jarFile.delete
        val (compiledPath, compok) = compile

        if (compok) {
          tryMakeJar(jarFile, compiledPath)
          if (jarOK) {
            deleteRecursively(compiledPath)  // may as well do it now
            handler(jarFile.getAbsolutePath)
          } else {
            // jar failed; run directly from the class files
            handler(compiledPath.getPath)
          }
        }
      }
    } else {
      // don't use a cache jar at all--just use the class files
      val (compiledPath, compok) = compile

      if (compok)
        handler(compiledPath.getPath)
    }
  }

  /** Run a script file with the specified arguments and compilation
   *  settings.
   *
   *  The script is compiled, then its main object is run with a classpath
   *  made of the boot classpath, the compiled script and the user classpath.
   *  An exception thrown by the script itself has its stack trace printed.
   *
   *  @param settings   compilation and runner settings
   *  @param scriptFile the script to run; a message is printed if it is not
   *                    an existing file
   *  @param scriptArgs arguments passed to the script's <code>main</code>
   */
  def runScript(
      settings: GenericRunnerSettings,
      scriptFile: String,
      scriptArgs: List[String]): Unit =
  {
    if (!new File(scriptFile).isFile)
      Console.println("no such file: " + scriptFile)
    else
      withCompiledScript(settings, scriptFile) { compiledLocation =>
        /** Convert a file to a class-path URL; failures are printed and
         *  the file is skipped.  toURI.toURL escapes characters such as
         *  spaces, which File.toURL leaves unescaped.
         */
        def fileToURL(f: File): Option[URL] =
          try { Some(f.toURI.toURL) }
          catch { case e: Exception => Console.println(e); None }

        /** The existing entries of a path string, as URLs. */
        def paths(str: String, expandStar: Boolean): List[URL] =
          for {
            file <- ClassPath.expandPath(str, expandStar) map (new File(_))
            if file.exists
            url <- fileToURL(file).toList
          } yield url

        val classpath: List[URL] =
          paths(settings.bootclasspath.value, true) :::
          paths(compiledLocation, false) :::
          paths(settings.classpath.value, true)

        try {
          ObjectRunner.run(
            classpath,
            scriptMain(settings),
            scriptArgs.toArray)
        } catch {
          case e: InvocationTargetException =>
            e.getCause.printStackTrace
        }
      }
  }
}


/** The default <code>ScriptRunner</code> instance; none of its members
 *  are overridden.
 */
object ScriptRunner extends ScriptRunner {
}