From 01443e42ed009c7125fe5b5c07ec20ff5eadbd17 Mon Sep 17 00:00:00 2001 From: Lex Spoon Date: Thu, 16 Feb 2006 23:08:08 +0000 Subject: Many, many cleanups and small improvements. --- src/compiler/scala/tools/nsc/Interpreter.scala | 657 ++++++++++++++------- src/compiler/scala/tools/nsc/MainInterpreter.scala | 166 ++++-- 2 files changed, 554 insertions(+), 269 deletions(-) diff --git a/src/compiler/scala/tools/nsc/Interpreter.scala b/src/compiler/scala/tools/nsc/Interpreter.scala index 847cdade6a..c67e9a191e 100644 --- a/src/compiler/scala/tools/nsc/Interpreter.scala +++ b/src/compiler/scala/tools/nsc/Interpreter.scala @@ -3,232 +3,459 @@ * @author Martin Odersky */ // $Id$ -package scala.tools.nsc; - -import scala.tools.nsc.reporters.Reporter; - -abstract class Interpreter { - import scala.collection.mutable.ListBuffer; - import symtab.Names; - - // list of names defined, for each line number - val prevDefines : ListBuffer[Pair[Int,ListBuffer[Names#Name]]] = new ListBuffer(); - - val compiler: Global; - - import scala.tools.nsc.ast.parser.SyntaxAnalyzer; - object syntaxAnalyzer extends SyntaxAnalyzer { - val global: compiler.type = compiler - } - - def interpret(line: String, reporter: Reporter): unit = { - import scala.tools.nsc.util.SourceFile; - - // convert input to a compilation unit, using SourceFile; - // and parse it, using syntaxAnalyzer, to get input ASTs - val inASTs = syntaxAnalyzer.interpreterParse( - new compiler.CompilationUnit( - //if SourceFile is not modified, then fix a bug - //here by adding an EOF character to the end of - //the 'line' - new SourceFile("",line.toCharArray()))); - - //todo: if (errors in parsing) after reporting them, exit method - - val dvt = new DefinedVarsTraverser; - dvt.traverseTrees(inASTs); - val definedVars = dvt.definedVars; - - val ivt = new ImportVarsTraverser(definedVars); - ivt.traverseTrees(inASTs); - val importVars = ivt.importVars; - - val lineno = prevDefines.length; - //todo: it is probably nice to include date & time, as well as a process id, in the filename - val filename = getTempPath().getPath()+java.io.File.separator+"InterpreterTempLine"+lineno+".scala"; - writeTempScalaFile(filename, line, lineno, definedVars, importVars); - - // first phase: compile auto-generated file - compiler.settings.outdir.value = getTempPath().getPath(); - val cr = new compiler.Run; - cr compile List(filename); - - //todo: if no errors in compilation then - // second phase: execute JVM, and print outcome - // else consider definition as if has not happened and exit method - //todo: use Scala's reflection API, which I designed, instead, for the following code - val cl = new java.net.URLClassLoader(Predef.Array(getTempPath().toURL())); - val interpreterResultObject: Class = Class.forName("InterpreterLine"+lineno+"Result",true,cl); - val resultValMethod: java.lang.reflect.Method = interpreterResultObject.getMethod("result",null); - var interpreterResultString: String = resultValMethod.invoke(interpreterResultObject,null).toString(); - - //var interpreterResultJavaTypeString: String = resultValMethod.getReturnType().getName(); - //Console.println(compiler.definitions.EmptyPackage.info.members); - val interpreterResultSym: compiler.Symbol = - compiler.definitions.getMember(compiler.definitions.EmptyPackage, - compiler.newTermName("InterpreterLine"+lineno+"Result")); - - def findSymbolWithName(ls: List[compiler.Symbol], name: compiler.Name): compiler.Symbol = - ls.find(s=>s.name == name) match { - case None => throw new IllegalStateException("Cannot find field '"+name+"' in InterpreterResult"); - case Some(s) => s; - } +package scala.tools.nsc + +import reporters.Reporter +import nsc.util.SourceFile +import scala.tools.util.PlainFile +import java.io.{File, Writer, PrintWriter, StringWriter} +import nsc.ast.parser.SyntaxAnalyzer +import scala.collection.mutable.{ListBuffer, HashSet, ArrayBuffer} +import scala.collection.immutable.{Map, ListMap} +import symtab.Flags + +/** An interpreter for Scala code. + + The main public entry points are compile() and interpret(). The compile() + method loads a complete Scala file. The interpret() method executes one + line of Scala code at the request of the user. - //var lastname: String = compiler.atPhase(cr.typerPhase.next){interpreterResultSym.info.decls.toList.last.name.toString()}; - //reporter.info(null,lastname,true); - //todo: similar to what I should be doing for Scala's reflection?? - var interpreterResultScalaTypeString: String = - compiler.atPhase(cr.typerPhase.next){ - findSymbolWithName(interpreterResultSym.info.decls.toList, - compiler.nme.getterToLocal(compiler.newTermName("result"))) - .tpe.toString() - }; - reporter.info(null,interpreterResultString+": "+interpreterResultScalaTypeString/*+" ("+interpreterResultJavaTypeString+")"*/,true); - -/* - val scalaInterpFile: File = ScalaInterpFile(filename); - scalaInterpFile.deleteOnExit(); - if(scalaInterpFile.exists()) - scalaInterpFile.delete(); - - getvalue of line#.last_var_defined_in_line (from defined_vars) - (works for 'it' as it was added as last val to definedvars) - and send it to reporter + The overall approach is based on compiling the requested code and then + using a Java classloader and using Java reflection to run the code + and access its results. + + In more detail, a single compiler instance is used + to accumulate all successfully compiled or interpreted Scala code. To + "interpret" a line of code, the compiler generates a fresh object that + includes the line of code and which has public member(s) to export + all variables defined by that code. To extract the result of an + interpreted line to show the user, a second "result object" is created + which exports a single member named "result". To accomodate user expressions + that read from variables or methods defined in previous statements, "import" + statements are used. + + This interpreter shares the strengths and weaknesses of using the + full compiler-to-Java. The main strength is that interpreted code + behaves exactly as does compiled code, including running at full speed. + The main weakness is that redefining classes and methods is not handled + properly, because rebinding at the Java level is technically difficult. */ +class Interpreter(val compiler: Global, output: (String => Unit)) { + import symtab.Names + import compiler.Traverser + import compiler.{Tree, TermTree, + ValOrDefDef, ValDef, DefDef, Assign, + ClassDef, ModuleDef, Ident, Select, AliasTypeDef} + import compiler.CompilationUnit + import compiler.Symbol + import compiler.Name - // book-keeping - //todo: if no errors in evaluation then - prevDefines += Pair(lineno,definedVars); - // else consider definition as if has not happened. - - // report some debug info - //reporter.info(null,"inASTs="+inASTs,true); - //reporter.info(null,"definedVars="+definedVars,true); - //reporter.info(null,"importVars="+importVars,true); - //reporter.info(null,"prevDefines="+prevDefines,true); - } - - import java.io.File; - def getTempPath(): File = { - val tempdir = { - val tempdir1 = System.getProperty("java.io.tmpdir"); - if (tempdir1 == null){ - val tempdir2 = System.getProperty("TEMP"); - if (tempdir2 == null){ - val tempdir3 = System.getProperty("TMP"); - if (tempdir3 == null) - throw new IllegalStateException("No temporary folder defined") - else tempdir3 } - else tempdir2 } - else tempdir1 - }; - val path = new File(tempdir); - if (!path.exists() || !path.isDirectory()) - throw new IllegalStateException("Invalid temporary directory") - else if (!path.canWrite()) - throw new IllegalStateException("Temporary directory not writable") - else path - }; - - def writeTempScalaFile(filename: String, line: String, lineno: Int, definedVars: ListBuffer[Names#Name], importVars: ListBuffer[Pair[Names#Name,Int]]) = { - import java.io.{File, PrintWriter, FileOutputStream}; - val scalaFile = new File(filename); - scalaFile.deleteOnExit(); - if(scalaFile.exists()) // to prevent old lingering files from having results from them reported! - scalaFile.delete(); - - val module = new PrintWriter(new FileOutputStream(scalaFile)); - //todo:"import "+LoadedModules?.getName - //module.println("\n"); - - for(val Pair(ivname,ivlineno) <- importVars.toList) yield - module.println("import line"+ivlineno+"."+ivname+";\n"); - - module.println("object line"+lineno+" {"); - var fullLine = line; - if(definedVars.length == 0) { // input is just an expression - fullLine = " var it = " + line; - definedVars += compiler.encode("it"); } - else fullLine = " " + line; - module.println(fullLine); - module.println("}"); - module.println(); - module.println("object InterpreterLine"+lineno+"Result "); - module.println("{ val result = (line"+lineno+"."+definedVars.toList.reverse.head+"); }"); - // reflection is used later (see above) to get the result value above - - module.flush(); - module.close(); - } - - import compiler.Traverser; - import compiler.Tree; - class DefinedVarsTraverser extends Traverser { - val definedVars = new ListBuffer[Names#Name]; - override def traverse(ast: Tree): unit = - if (!ast.isDef) () - else { - import compiler._; - ast match { - // only the outer level needed, so do not recurse to go deeper - // todo: combine similar cases in one case - case ClassDef(_,name,_,_,_) => definedVars += name - case ModuleDef(_, name,_) => definedVars += name - case ValDef(_, name,_,_) => definedVars += name - case DefDef(_,name,_,_,_,_) => definedVars += name - //todo:case Bind(name,_) => ((name != nme.WILDCARD) && (definedVars.elements forall (name !=))) definedVars += name - - //case Ident(name) => if (name...is defined) definedVars += name; - - //todo: - //case PackageDef(name, _) => throw new InterpIllegalDefException(name.toString()+": package definitions not allowed") - //case AbsTypeDef(_,name,_,_) => throw new InterpIllegalDefException(name.toString()+": absract type definitions not allowed") - //case AliasTypeDef(_,name,_,_) => throw new InterpIllegalDefException(name.toString()+": alias type definitions not allowed") - //case LabelDef(name,_,_) => throw new InterpIllegalDefException(name.toString()+": label definitions not allowed") - case _ => throw new InterpIllegalDefException("Unsupported interpreter definition. Contact Scala developers for adding interpreter support for it.")// () - } + /** construct an interpreter that prints to the compiler's reporter */ + def this(compiler: Global) = { + this(compiler, str => compiler.reporter.info(null, str, true)) + } + + private def reporter = compiler.reporter + + /** directory to save .class files to */ + private val classfilePath = File.createTempFile("scalaint", "") + classfilePath.delete // the file is created as a file; make it a directory + classfilePath.mkdirs + + + + /* set up the compiler's output directory */ + compiler.settings.outdir.value = classfilePath.getPath + + /** class loader used to load compiled code */ + /* A single class loader is used for all commands interpreted by this Interpreter. + It would also be possible to create a new class loader for each command + to interpret. The advantages of the current approach are: + + - Expressions are only evaluated one time. This is especially + significant for I/O, e.g. "val x = Console.readLine" + + The main disadvantage is: + + - Objects, classes, and methods cannot be rebound. Instead, definitions + shadow the old ones, and old code objects to refer to the old + definitions. + */ + private val classLoader = new java.net.URLClassLoader(Predef.Array(classfilePath.toURL)) + + + /** the previous requests this interpreter has processed */ + private val prevRequests = new ArrayBuffer[Request]() + + /** look up the request that bound a specified term or type */ + private def reqBinding(vname: Name): Option[Request] = { + prevRequests.toList.reverse.find(lin => lin.boundNames.contains(vname)) + } + + /** generate a string using a routine that wants to write on a stream */ + private def stringFrom(writer: PrintWriter=>Unit): String = { + val stringWriter = new StringWriter() + val stream = new PrintWriter(stringWriter) + writer(stream) + stream.close + stringWriter.toString + } + + /** parse a line into a sequence of trees */ + private def parse(line: String): List[Tree] = { + reporter.reset + + val unit = + new CompilationUnit( + new SourceFile("",line.toCharArray())) + + val trees = new compiler.syntaxAnalyzer.Parser(unit).templateStatSeq + + if(reporter.errors > 0) + return Nil // the result did not parse, so stop + + trees + } + + + /** Compile one source file */ + def compile(filename: String): Unit = { + val jfile = new File(filename) + if(!jfile.exists) { + reporter.error(null, "no such file: " + filename) + return () + } + val cr = new compiler.Run + cr.compileSources(List(new SourceFile(PlainFile.fromFile(jfile)))) + } + + /** build a request from the user. "tree" is "line" after being parsed */ + private def buildRequest(trees: List[Tree], line: String, lineName: String): Request = { + trees match { + /* This case for assignments is more specialized than desirable: it only + handles assignments to an identifier. It would be better to support + arbitrary paths being assigned, but that is technically difficult + because of the way objectSourceCode and resultObjectSourceCode are + implemented in class Request. */ + case List(Assign(Ident(lhs), _)) => new AssignReq(lhs, line, lineName) + + + case _ if trees.forall(t => t.isInstanceOf[ValOrDefDef]) => new DefReq(line, lineName) + case List(_:TermTree) | List(_:Ident) | List(_:Select) => new ExprReq(line, lineName) + case List(_:ModuleDef) => new ModuleReq(line, lineName) + case List(_:ClassDef) => new ClassReq(line, lineName) + case List(_:AliasTypeDef) => new TypeAliasReq(line, lineName) + case _ => { + reporter.error(null, "That kind of statement combination is not supported by the interpreter.") + null } + } + } - case class InterpIllegalDefException(msg: String) extends RuntimeException(msg); -// class ListTraverser extends Traverser { -// def traverse(trees: List[Tree]): Unit = -// trees foreach traverse; -// } -// -// class ListDefinedVarsTraverser extends DefinedVarsTraverser with ListTraverser; + /** interpret one line of input. All feedback, including parse errors + and evaluation results, are printed via the supplied compiler's + reporter. Values defined are available for future interpreted + strings. */ + def interpret(line: String): Unit = { + // parse + val trees = parse(line) + if(trees.isEmpty) return () // parse error or empty input + + // figure out what kind of request + val lineno = prevRequests.length + val lineName = "line" + lineno + + val req = buildRequest(trees, line, lineName) + if(req == null) return () // a disallowed statement type + + + if(!req.compile) + return () // an error happened during compilation, e.g. a type error + + val interpreterResultString = req.loadAndRun + + // print the result + output(interpreterResultString) + + // print out types of functions; they are not printed in the + // request printout + output(req.defTypesSummary) + + // book-keeping + prevRequests += req + } + + + /** Delete a directory tree recursively. Use with care! */ + private def deleteRecursively(path: File): Unit = { + path match { + case _ if(!path.exists) => () + case _ if(path.isDirectory) => + for(val p <- path.listFiles) + deleteRecursively(p) + path.delete + case _ => path.delete + } + } + + /** This instance is no longer needed, so release any resources + it is using. + + Specifically, this deletes the temporary directory used for holding + class files for this instance. This cannot safely be done as commands + are executed becaus of Java's demand loading. + */ + def close: Unit = { + deleteRecursively(classfilePath) + } + + + /** A traverser that finds all mentioned identifiers, i.e. things that need to be imported. + It might return extra names. */ + private class ImportVarsTraverser(definedVars: List[Name]) extends Traverser { + val importVars = new HashSet[Name]() - class ImportVarsTraverser(definedVars: ListBuffer[Names#Name]) extends Traverser { - val importVars = new ListBuffer[Pair[Names#Name,Int]]; - var curAST = 0; - import compiler.Ident; override def traverse(ast: Tree): unit = ast match { - case Ident(name) => { - var lastPrevDefsIdx = -1; - //reporter.info(null,"name="+name,true); - for(val Pair(lineno,defs) <- prevDefines.toList) yield { - //reporter.info(null,"line#="+lineno+", defs="+defs,true); - if (defs.indexOf(name) != -1) lastPrevDefsIdx = lineno - } - val foundInPrevDefines = (lastPrevDefsIdx != -1); - //reporter.info(null,"lastPrevDefsIdx="+lastPrevDefsIdx,true); - if(foundInPrevDefines) { - val firstCurDefIdx = definedVars.indexOf(name); - val foundInDefinedVars = (firstCurDefIdx != -1); - if((!foundInDefinedVars || - (foundInDefinedVars && (firstCurDefIdx > curAST))) - && (importVars.indexOf(Pair(name,lastPrevDefsIdx)) == -1)) - // to prevent duplicate imports (todo: use find instead of indexOf?) - importVars += Pair(name,lastPrevDefsIdx); + case Ident(name) => importVars += name + case _ => super.traverse(ast) + } + } + + + /** One line of code submitted by the user for interpretation */ + private abstract class Request(line: String, val lineName: String) { + val trees = parse(line) + + /** name to use for the object that will compute "line" */ + def objectName = lineName + "$object" // make it unlikely to clash with user variables + + /** name of the object that retrieves the result from the above object */ + def resultObjectName = "RequestResult$" + objectName + + /** whether the trees need a variable name, as opposed to standing + alone */ + def needsVarName: Boolean = false + + /** list of methods defined */ + val defNames = + for { + val DefDef(mods, name, _, _, _, _) <- trees + mods.isPublic + } yield name + + /** list of val's and var's defined */ + val valAndVarNames = { + val baseNames = + for { + val ValDef(mods, name, _, _) <- trees + mods.isPublic + } yield name + + if(needsVarName) + compiler.encode(lineName) :: baseNames // add a var name + else + baseNames + } + //XXXshorten all these for loops + /** list of modules defined */ + val moduleNames = { + val explicit = + for(val ModuleDef(mods, name, _) <- trees; mods.isPublic) + yield name + val caseClasses = + for {val ClassDef(mods, name, _, _, _) <- trees + mods.isPublic + mods.hasFlag(Flags.CASE)} + yield name.toTermName + explicit ::: caseClasses + } + + /** list of classes defined */ + val classNames = + for(val ClassDef(mods, name, _, _, _) <- trees; mods.isPublic) + yield name + + /** list of type aliases defined */ + val typeNames = + for(val AliasTypeDef(mods, name, _, _) <- trees; mods.isPublic) + yield name + + + /** all (public) names defined by these statements */ + val boundNames = defNames ::: valAndVarNames ::: moduleNames ::: classNames ::: typeNames + + + /** list of names used by this expression */ + val usedNames: List[Name] = { + val ivt = new ImportVarsTraverser(boundNames) + ivt.traverseTrees(trees) + ivt.importVars.toList + } + + /** names to print out to the user after evaluation */ + def namesToPrintForUser = valAndVarNames + + /** generate the source code for the object that computes this request */ + def objectSourceCode: String = + stringFrom(code => { + // write an import for each imported variable + for{val imv <- usedNames + val lastDefiner <- reqBinding(imv).toList } { + code.println("import " + lastDefiner.objectName + "." + imv) } + + // object header + code.println("object "+objectName+" {") + + // the line of code to compute + if(needsVarName) + code.println(" val " + lineName + " = " + line) + else + code.println(" " + line) + + //end + code.println(";}") + }) + + /** Types of variables defined by this request. They are computed + after compilation of the main object */ + var typeOf: Map[Name, String] = _ + + + /** generate source code for the object that retrieves the result + from objectSourceCode */ + def resultObjectSourceCode: String = + stringFrom(code => { + code.println("object " + resultObjectName) + code.println("{ val result:String = {") + code.println(objectName + ";") // evaluate the object, to make sure its constructor is run + code.print("\"\"") // print an initial empty string, so later code can + // uniformly be: + morestuff + resultExtractionCode(code) + code.println("}") + code.println(";}") + }) + + def resultExtractionCode(code: PrintWriter): Unit = { + for(val vname <- namesToPrintForUser) { + code.println(" + \"" + vname + ": " + typeOf(vname) + + " = \" + " + objectName + "." + vname + " + \"\\n\"") } - case _ => { - // using case x, instead of case _, we can have: reporter.info(null,"x="+x,true); - super.traverse(ast) + } + + + /** Compile the object file. Returns whether the compilation succeeded. + If all goes well, types is computed and set */ + def compile: Boolean = { + reporter.reset // without this, error counting is not correct, + // and the interpreter sometimes overlooks compile failures! + + // compile the main object + val objRun = new compiler.Run() + objRun.compileSources(List(new SourceFile("", objectSourceCode.toCharArray))) + if(reporter.errors > 0) return false + + // extract and remember types + typeOf = findTypes(objRun) + + // compile the result-extraction object + new compiler.Run().compileSources(List(new SourceFile("", resultObjectSourceCode.toCharArray))) + if(reporter.errors > 0) return false + + // success + true + } + + /** dig the types of all bound variables out of the compiler run */ + def findTypes(objRun: compiler.Run): Map[Name, String] = { + def getTypes(names: List[Name], nameMap: Name=>Name): Map[Name, String] = { + names.foldLeft[Map[Name,String]](new ListMap[Name, String]())((map, name) => { + val resObjSym: Symbol = + compiler.definitions.getMember(compiler.definitions.EmptyPackage, + compiler.newTermName(objectName)) + + val typeString = + compiler.atPhase(objRun.typerPhase.next) { + resObjSym.info.decls.toList.find(s=>s.name == nameMap(name)).get.tpe.toString() + } + + map + name -> typeString + }) } + + val names1 = getTypes(valAndVarNames, n=>compiler.nme.getterToLocal(n)) + val names2 = getTypes(defNames, id) + names1.incl(names2) + } + + /** load and run the code using reflection */ + def loadAndRun: String = { + val interpreterResultObject: Class = Class.forName(resultObjectName,true,classLoader) + val resultValMethod: java.lang.reflect.Method = interpreterResultObject.getMethod("result",null) + resultValMethod.invoke(interpreterResultObject,null).toString() + } + + /** return a summary of the defined methods */ + def defTypesSummary: String = + stringFrom(summ => { + for(val methname <- defNames) { + summ.println("" + methname + ": " + typeOf(methname)) + } + }) + } + + /** A sequence of definition's. val's, var's, def's. */ + private class DefReq(line: String, lineName: String) extends Request(line, lineName) { + } + + /** Assignment of a single variable: lhs = exp */ + private class AssignReq(val lhs: Name, line: String, lineName: String) extends Request(line, lineName) { + override def resultExtractionCode(code: PrintWriter): Unit = { + super.resultExtractionCode(code) + val bindReq = reqBinding(lhs).get + code.println(" + \"" + lhs + " = \" + " + bindReq.objectName + "." + lhs) + } + override def namesToPrintForUser = Nil + } + + /** A single expression */ + private class ExprReq(line: String, lineName: String) extends Request(line, lineName) { + override val needsVarName = true + } + + /** A module definition */ + private class ModuleReq(line: String, lineName: String) extends Request(line, lineName) { + def moduleName = trees match { + case List(ModuleDef(_, name, _)) => name + } + override def resultExtractionCode(code: PrintWriter): Unit = { + super.resultExtractionCode(code) + code.println(" + \"defined module " + moduleName + "\\n\"") + } + } + + /** A class definition */ + private class ClassReq(line: String, lineName: String) extends Request(line, lineName) { + def newClassName = trees match { + case List(ClassDef(_, name, _, _, _)) => name + } + + override def resultExtractionCode(code: PrintWriter): Unit = { + super.resultExtractionCode(code) + code.println(" + \"defined class " + newClassName + "\\n\"") + } + } + + /** a type alias */ + private class TypeAliasReq(line: String, lineName: String) extends Request(line, lineName) { + def newTypeName = trees match { + case List(AliasTypeDef(_, name, _, _)) => name + } + + override def resultExtractionCode(code: PrintWriter): Unit = { + super.resultExtractionCode(code) + code.println(" + \"defined type alias " + newTypeName + "\\n\"") } - override def traverseTrees(asts: List[Tree]): unit = - asts foreach { curAST = curAST+1; traverse; } } - //todo: unit-test cases } diff --git a/src/compiler/scala/tools/nsc/MainInterpreter.scala b/src/compiler/scala/tools/nsc/MainInterpreter.scala index b1ea38e065..4727cd8cc9 100644 --- a/src/compiler/scala/tools/nsc/MainInterpreter.scala +++ b/src/compiler/scala/tools/nsc/MainInterpreter.scala @@ -3,72 +3,130 @@ * @author emir */ // $Id$ -package scala.tools.nsc; +package scala.tools.nsc -import java.io._; -import scala.tools.nsc.util.{Position}; -import scala.tools.nsc.reporters.{Reporter, ConsoleReporter}; +import java.io._ +import scala.tools.nsc.util.{Position} +import scala.tools.nsc.reporters.{Reporter, ConsoleReporter} /** The main class for the new scala interpreter. */ -object MainInterpreter extends Object with EvalLoop { - // lots of stuff duplicated from Main - val PRODUCT: String = - System.getProperty("scala.product", "scalaint"); - val VERSION: String = - System.getProperty("scala.version", "unknown version"); - val versionMsg = PRODUCT + " " + VERSION + " -- (c) 2002-05 LAMP/EPFL"; - val prompt = "\nnsc> "; - - private var reporter: ConsoleReporter = _; - - def error(msg: String): unit = - reporter.error(new Position(PRODUCT), - msg + "\n " + PRODUCT + " -help gives more information"); - - def errors() = reporter.errors; - - def interpret(gCompiler: Global): unit = { - val interpreter = new Interpreter { - val compiler: gCompiler.type = gCompiler - }; - loop(line => try { - interpreter.interpret(line.trim(), reporter) - } catch { - case e: Exception => { - reporter.info(null,e.getMessage(),true); - //e.printStackTrace(); - } +object MainInterpreter { + val reporter = new ConsoleReporter() + + var interpreter: Interpreter = _ + + /** print a friendly help message */ + def printHelp = { + Console.println("This is an interpreter for Scala.") + Console.println("Type in expressions to have them evaluated.") + Console.println("Type :quit to exit the interpreter.") + Console.println("Type :compile followed by a filename to compile a complete Scala file.") + Console.println("Type :load followed by a filename to load a sequence of interpreter commands.") + Console.println("Type :help to repeat this message later.") + } + + /** A simple, generic read-eval-print loop with a pluggable eval-print function. + Loop reading lines of input and invoking the eval-print function. + Stop looping when eval-print returns false. */ + def repl(evpr: String => Boolean): Unit = { + val in = new BufferedReader(new InputStreamReader(System.in)) + + while(true) { + Console.print("\nscala> ") + var line = in.readLine() + if(line == null) + return () // assumes null means EOF + + val keepGoing = evpr(line) + + if(!keepGoing) + return () // the evpr function said to stop + } + } + + /** interpret one line of code submitted by the user */ + def interpretOne(line: String): Unit = { + try { + interpreter.interpret(line) + } catch { + case e: Exception => { + reporter.info(null,"Exception occurred: " + e.getMessage(),true) + //e.printStackTrace() + } + } + } + + /** interpret all lines from a specified file */ + def interpretAllFrom(filename: String): Unit = { + val fileIn = try { + new FileReader(filename) + } catch { + case _:IOException => + Console.println("Error opening file: " + filename) + null + } + if(fileIn == null) return () + val in = new BufferedReader(fileIn) + while(true) { + val line = in.readLine + if(line == null) { + fileIn.close + return () } - ) + command(line) + } } - def process(args: Array[String]): unit = { - reporter = new ConsoleReporter(); - val command = new CompilerCommand(List.fromArray(args), error, false); - reporter.prompt = (command.settings.prompt.value); - if (command.settings.version.value) - reporter.info(null, versionMsg, true) - else if (command.settings.help.value) // 2do replace with InterpCommand - reporter.info(null, command.usageMsg, true) + /** run one command submitted by the user */ + def command(line: String): Boolean = { + def withFile(command: String)(action: String=>Unit): Unit = { + val spaceIdx = command.indexOf(' ') + if(spaceIdx <= 0) { + Console.println("That command requires a filename to be specified.") + return () + } + val filename = command.substring(spaceIdx).trim + action(filename) + } - else { - try { - val compiler = new Global(command.settings, reporter); - interpret(compiler); - } catch { - case ex @ FatalError(msg) => - if (command.settings.debug.value) - ex.printStackTrace(); - reporter.error(null, "fatal error: " + msg); + if(line.startsWith(":")) + line match { + case ":help" => printHelp + case ":quit" => return false + case _ if line.startsWith(":compile") => withFile(line)(f => interpreter.compile(f)) + case _ if line.startsWith(":load") => withFile(line)(f => interpretAllFrom(f)) + case _ => Console.println("Unknown command. Type :help for help.") } - reporter.printSummary() + else + interpretOne(line) + true + } + + + + /** the main interpreter loop */ + def interpretLoop(compiler: Global): unit = { + interpreter = new Interpreter(compiler, str=>Console.print(str)) + repl(command) + interpreter.close + } + + + + /** process the command-line arguments and do as they request */ + def process(args: Array[String]): unit = { + val command = new CompilerCommand(List.fromArray(args), error, false) + reporter.prompt = command.settings.prompt.value + if (command.settings.help.value) { + reporter.info(null, command.usageMsg, true) + } else { + printHelp + interpretLoop(new Global(command.settings, reporter)) } } def main(args: Array[String]): unit = { - process(args); - System.exit(if (reporter.errors > 0) 1 else 0); + process(args) } - } -- cgit v1.2.3