diff options
Diffstat (limited to 'src/compiler/scala/tools/nsc/ScriptRunner.scala')
-rw-r--r-- | src/compiler/scala/tools/nsc/ScriptRunner.scala | 57 |
1 files changed, 47 insertions, 10 deletions
diff --git a/src/compiler/scala/tools/nsc/ScriptRunner.scala b/src/compiler/scala/tools/nsc/ScriptRunner.scala index 2f0a0fdd4c..7afc0f5b5c 100644 --- a/src/compiler/scala/tools/nsc/ScriptRunner.scala +++ b/src/compiler/scala/tools/nsc/ScriptRunner.scala @@ -43,6 +43,15 @@ import scala.tools.nsc.util.{CompoundSourceFile, SourceFile, SourceFileFragment} * of stdout... */ object ScriptRunner { + /** Default name to use for the wrapped script */ + val defaultScriptMain = "Main" + + /** Pick a main object name from the specified settings */ + def scriptMain(settings: Settings) = + if (settings.script.value == "") + defaultScriptMain + else + settings.script.value /** Choose a jar filename to hold the compiled version * of a script @@ -127,18 +136,39 @@ object ScriptRunner { return matcher.end } + /** Split a fully qualified object name into a + * package and an unqualified object name */ + private def splitObjectName(fullname: String): + (Option[String],String) = + { + val idx = fullname.lastIndexOf('.') + if (idx < 0) + (None, fullname) + else + (Some(fullname.substring(0,idx)), fullname.substring(idx+1)) + } + + /** Wrap a script file into a runnable object named * <code>scala.scripting.Main</code>. - * - * @param filename ... - * @param getSourceFile ... - * @return ... */ - def wrappedScript(filename: String, getSourceFile: PlainFile => SourceFile): SourceFile = { + def wrappedScript( + objectName: String, + filename: String, + getSourceFile: PlainFile => SourceFile): SourceFile = + { + val (maybePack, objName) = splitObjectName(objectName) + + val packageDecl = + maybePack match { + case Some(pack) => "package " + pack + "\n" + case None => "" + } + val preamble = new SourceFile("<script preamble>", - ("package $scalascript\n" + - "object Main {\n" + + (packageDecl + + "object " + objName + " {\n" + " def main(argv: Array[String]): Unit = {\n" + " val args = argv;\n").toCharArray) @@ -184,7 +214,9 @@ object ScriptRunner { compSettings.foldLeft[List[String]](Nil)((args, stg) => stg.unparse ::: args) - val compArgs = coreCompArgs ::: List("-Xscript", scriptFile) + val compArgs = + (coreCompArgs ::: + List("-script", scriptMain(settings), scriptFile)) val socket = CompileSocket.getOrCreateSocket("") val out = new PrintWriter(socket.getOutputStream(), true) @@ -235,7 +267,12 @@ object ScriptRunner { val reporter = new ConsoleReporter(settings) val compiler = new Global(settings, reporter) val cr = new compiler.Run - cr.compileSources(List(wrappedScript(scriptFile, compiler.getSourceFile _))) + val wrapped = + wrappedScript( + scriptMain(settings), + scriptFile, + compiler.getSourceFile _) + cr.compileSources(List(wrapped)) (compiledPath, !reporter.hasErrors) } else { val compok = compileWithDaemon(settings, scriptFile) @@ -324,7 +361,7 @@ object ScriptRunner { try { ObjectRunner.run( classpath, - "$scalascript.Main", + scriptMain(settings), scriptArgs.toArray) } catch { case e:InvocationTargetException => |