From 4d8b3694b34d3a3f328e7eb7d55d75bc66d842c1 Mon Sep 17 00:00:00 2001 From: Lex Spoon Date: Sun, 1 Apr 2007 17:23:12 +0000 Subject: 1. the ones that appear relevant. 2. Such imports are nested, each with its own wrapper object. 3. Interpreter output is cleaned up with a regular expression, so that all of these wrapper objects do not apper. --- src/compiler/scala/tools/nsc/Interpreter.scala | 196 ++++++++++++--------- src/compiler/scala/tools/nsc/symtab/StdNames.scala | 3 + 2 files changed, 118 insertions(+), 81 deletions(-) (limited to 'src') diff --git a/src/compiler/scala/tools/nsc/Interpreter.scala b/src/compiler/scala/tools/nsc/Interpreter.scala index 331e63108c..447da447a6 100644 --- a/src/compiler/scala/tools/nsc/Interpreter.scala +++ b/src/compiler/scala/tools/nsc/Interpreter.scala @@ -10,8 +10,8 @@ import java.lang.{Class, ClassLoader} import java.io.{File, PrintWriter, StringWriter} import java.net.{URL, URLClassLoader} +import scala.collection.mutable import scala.collection.mutable.{ListBuffer, HashSet, ArrayBuffer} -import scala.collection.immutable.{Map, ListMap} import ast.parser.SyntaxAnalyzer import io.PlainFile @@ -140,10 +140,6 @@ class Interpreter(val settings: Settings, reporter: Reporter, out: PrintWriter) /** the previous requests this interpreter has processed */ private val prevRequests = new ArrayBuffer[Request]() - /** look up the request that bound a specified term or type */ - private def reqBinding(vname: Name): Option[Request] = - prevRequests.toList.reverse.find(lin => lin.boundNames.contains(vname)) - /** next line number to use */ private var nextLineNo = 0 @@ -161,23 +157,10 @@ class Interpreter(val settings: Settings, reporter: Reporter, out: PrintWriter) private def newVarName() = { val num = nextVarNameNo nextVarNameNo = nextVarNameNo + 1 - "unnamed" + num + compiler.nme.INTERPRETER_VAR_PREFIX + num } - /** import statements that should be used for submitted code */ - private def importLines: List[String] = - for { - val req <- prevRequests.toList - req.isInstanceOf[ImportReq] - } - yield req.line - - //private var importLinesRev: List[String] = List("import scala.collection.immutable._") - - /** a string of import code corresponding to all of the current importLines */ - private def codeForImports: String = importLines.mkString("", ";\n", ";\n") - /** generate a string using a routine that wants to write on a stream */ private def stringFrom(writer: PrintWriter => Unit): String = { val stringWriter = new StringWriter() @@ -204,6 +187,57 @@ class Interpreter(val settings: Settings, reporter: Reporter, out: PrintWriter) return str.substring(0, maxpr) } + /** Compute imports that allow definitions from previous + * requests to be visible in a new request. Returns + * three pieces of related code: + * + * 1. An initial code fragment that should go before + * the code of the new request. + * + * 2. A code fragment that should go after the code + * of the new request. + * + * 3. An access path which can be traverested to access + * any bindings inside code wrapped by #1 and #2 . + * + */ + private def importsCode: (String, String, String) = { + val code = new StringBuffer + val trailingBraces = new StringBuffer + val accessPath = new StringBuffer + val impname = compiler.nme.INTERPRETER_IMPORT_WRAPPER + + def addWrapper() { + code.append("object " + impname + "{\n") + trailingBraces.append("}\n") + accessPath.append("." + impname) + } + for(val req <- prevRequests) { + req match { + case req:ImportReq => + // If the user entered an import, then just use it + addWrapper() + code.append(req.line + ";\n") + case req => + // For other requests, import each bound variable. + // import them explicitly instead of with _, so that + // ambiguity errors will not be generated. + for(val imv <- req.boundNames) { + addWrapper() + code.append("import ") + code.append(req.objectName + req.accessPath + "." + imv + ";\n") + } + } + } + + addWrapper() // Add one extra wrapper, to prevent warnings + // in the (frequent!) case of redefining + // the value bound in the last interpreter + // request. + + (code.toString, trailingBraces.toString, accessPath.toString) + } + /** Parse a line into a sequence of trees. Returns None if the input * is incomplete. */ private def parse(line: String): Option[List[Tree]] = { @@ -219,18 +253,15 @@ class Interpreter(val settings: Settings, reporter: Reporter, out: PrintWriter) // parse the main code along with the imports reporter.reset - val trees = simpleParse(codeForImports + line) + + val trees= simpleParse(line) + if (justNeedsMore) None else if (reporter.hasErrors) Some(Nil) // the result did not parse, so stop - else { - // parse the imports alone - val importTrees = simpleParse(codeForImports) - - // return just the new trees, not the import trees - Some(trees.drop(importTrees.length)) - } + else + Some(trees) } } @@ -296,6 +327,9 @@ class Interpreter(val settings: Settings, reporter: Reporter, out: PrintWriter) * @return ... */ def interpret(line: String): IR.Result = { + def clean(str: String) = + truncPrintString(Interpreter.stripWrapperGunk(str)) + // parse val trees = parse(line) match { case None => return IR.Incomplete @@ -316,11 +350,11 @@ class Interpreter(val settings: Settings, reporter: Reporter, out: PrintWriter) if (printResults || !succeeded) { // print the result - out.print(truncPrintString(interpreterResultString)) + out.print(clean(interpreterResultString)) // print out types of functions; they are not printed in the // request printout - out.print(req.defTypesSummary) + out.print(clean(req.defTypesSummary)) } // book-keeping @@ -341,7 +375,7 @@ class Interpreter(val settings: Settings, reporter: Reporter, out: PrintWriter) * @param value ... * @return ... */ - def bind(name: String, boundType: String, value: Any) = { + def bind(name: String, boundType: String, value: Any): IR.Result = { val binderName = "binder" + binderNum binderNum = binderNum + 1 @@ -380,28 +414,21 @@ class Interpreter(val settings: Settings, reporter: Reporter, out: PrintWriter) def close: Unit = Interpreter.deleteRecursively(classfilePath) - /** A traverser that finds all mentioned identifiers, i.e. things - * that need to be imported. It might return extra names. - */ - private class ImportVarsTraverser(definedVars: List[Name]) extends Traverser { - val importVars = new HashSet[Name]() - - override def traverse(ast: Tree): unit = ast match { - case Ident(name) => importVars += name - case _ => super.traverse(ast) - } - } /** One line of code submitted by the user for interpretation */ private abstract class Request(val line: String, val lineName: String) { val Some(trees) = parse(line) /** name to use for the object that will compute "line" */ - def objectName = lineName + compiler.nme.INTERPRETER_WRAPPER_SUFFIX // make it unlikely to clash with user variables + def objectName = lineName + compiler.nme.INTERPRETER_WRAPPER_SUFFIX /** name of the object that retrieves the result from the above object */ def resultObjectName = "RequestResult$" + objectName + /** Code to append to objectName to access anything that + * the request binds. */ + val accessPath = importsCode._3 + /** whether the trees need a variable name, as opposed to standing alone */ val needsVarName: Boolean = false @@ -466,42 +493,29 @@ class Interpreter(val settings: Settings, reporter: Reporter, out: PrintWriter) val boundNames = defNames ::: valAndVarNames ::: moduleNames ::: classNames ::: typeNames - /** list of names used by this expression */ - val usedNames: List[Name] = { - val ivt = new ImportVarsTraverser(boundNames) - ivt.traverseTrees(trees) - ivt.importVars.toList - } - /** names to print out to the user after evaluation */ def namesToPrintForUser = valAndVarNames /** generate the source code for the object that computes this request */ def objectSourceCode: String = stringFrom(code => { - // add the user-specified imports first - code.println(codeForImports) - - // object header + // header for the wrapper object code.println("object " + objectName + " {") - // Write an import for each imported variable. - // Note that the imports are inside the object wrapper; otherwise, - // the names defined at the package level will override these - // imported values. - for {val imv <- usedNames - val lastDefiner <- reqBinding(imv).toList } { - code.println("import " + lastDefiner.objectName + "." + imv) - } + val (importsPreamble, importsTrailer, _) = importsCode + code.print(importsPreamble) - // the line of code to compute + // the variable to compute, if any if (needsVarName) - code.println(" val " + varName + " = " + line) - else - code.println(" " + line) + code.print(" val " + varName + " = ") + + // the line of code to compute + code.println(line) - //end + code.println(importsTrailer) + + //end the wrapper object code.println(";}") }) @@ -526,18 +540,19 @@ class Interpreter(val settings: Settings, reporter: Reporter, out: PrintWriter) def resultExtractionCode(code: PrintWriter): Unit = for (val vname <- namesToPrintForUser) { code.print(" + \"" + vname + ": " + typeOf(vname) + - " = \" + " + objectName + "." + vname + " + \"\\n\"") + " = \" + " + objectName + accessPath + + "." + vname + " + \"\\n\"") } /** Compile the object file. Returns whether the compilation succeeded. - If all goes well, types is computed and set */ - def compile: Boolean = { + * If all goes well, the "types" map is computed. */ + def compile(): Boolean = { reporter.reset // without this, error counting is not correct, // and the interpreter sometimes overlooks compile failures! // compile the main object val objRun = new compiler.Run() - //Console.println("source: "+objectSourceCode) //DEBUG + //println("source: "+objectSourceCode) //DEBUG objRun.compileSources( List(new SourceFile("", objectSourceCode.toCharArray)) ) @@ -562,18 +577,31 @@ class Interpreter(val settings: Settings, reporter: Reporter, out: PrintWriter) */ def findTypes(objRun: compiler.Run): Map[Name, String] = { def getTypes(names: List[Name], nameMap: Name=>Name): Map[Name, String] = { - names.foldLeft[Map[Name,String]](new ListMap[Name, String]())((map, name) => { - val resObjSym: Symbol = - compiler.definitions.getMember(compiler.definitions.EmptyPackage, - compiler.newTermName(objectName)) + /** the outermost wrapper object */ + val outerResObjSym: Symbol = + compiler.definitions.getMember(compiler.definitions.EmptyPackage, + compiler.newTermName(objectName)) + + /** the innermost object inside the wrapper, found by + * following accessPath into the outer one. */ + val resObjSym = + (accessPath.split("\\.")).foldLeft(outerResObjSym)((sym,name) => + if(name == "") sym else + compiler.atPhase(objRun.typerPhase.next) { + sym.info.member(compiler.newTermName(name)) }) - val typeString = + names.foldLeft(Map.empty[Name, String])((map, name) => { + val rawType = compiler.atPhase(objRun.typerPhase.next) { - resObjSym.info.decls.toList.find(s => - s.name == nameMap(name)).get.tpe.toString() + resObjSym.info.member(name).tpe } + // the types are all =>T; remove the => + val cleanedType= rawType match { + case compiler.PolyType(Nil, rt) => rt + case rawType => rawType + } - map + name -> typeString + map + name -> cleanedType.toString }) } @@ -619,8 +647,7 @@ class Interpreter(val settings: Settings, reporter: Reporter, out: PrintWriter) extends Request(line, lineName) { override def resultExtractionCode(code: PrintWriter): Unit = { super.resultExtractionCode(code) - val bindReq = reqBinding(lhs).get - code.println(" + \"" + lhs + " = \" + " + bindReq.objectName + "." + lhs) + code.println(" + \"" + lhs + " = \" + " + lhs) } override def namesToPrintForUser = Nil } @@ -680,7 +707,6 @@ class Interpreter(val settings: Settings, reporter: Reporter, out: PrintWriter) private class ImportReq(line: String, lineName: String) extends Request(line, lineName) { override val boundNames = Nil - override val usedNames = Nil override def resultExtractionCode(code: PrintWriter): Unit = { code.println("+ \"" + trees.head.toString + "\"") } @@ -706,4 +732,12 @@ object Interpreter { } } + /** Heuristically strip interpreter wrapper prefixes + * from an interpreter output string. + */ + def stripWrapperGunk(str: String): String = { + val wrapregex = "line[0-9]+\\$object(\\$\\$import)*" + str.replaceAll(wrapregex+"\\.", "") + .replaceAll(wrapregex+"\\$", "") + } } diff --git a/src/compiler/scala/tools/nsc/symtab/StdNames.scala b/src/compiler/scala/tools/nsc/symtab/StdNames.scala index 91b9cd163a..0466242b7c 100644 --- a/src/compiler/scala/tools/nsc/symtab/StdNames.scala +++ b/src/compiler/scala/tools/nsc/symtab/StdNames.scala @@ -70,8 +70,11 @@ trait StdNames requires SymbolTable { val EXPAND_SEPARATOR_STRING = "$$" val TUPLE_FIELD_PREFIX_STRING = "_" val CHECK_IF_REFUTABLE_STRING = "check$ifrefutable$" + val INTERPRETER_WRAPPER_SUFFIX = "$object" val INTERPRETER_LINE_PREFIX = "line" + val INTERPRETER_VAR_PREFIX = "unnamed" + val INTERPRETER_IMPORT_WRAPPER = "$import" def LOCAL(clazz: Symbol) = newTermName(LOCALDUMMY_PREFIX_STRING + clazz.name+">") def TUPLE_FIELD(index: int) = newTermName(TUPLE_FIELD_PREFIX_STRING + index) -- cgit v1.2.3