diff options
author | Matei Zaharia <matei@eecs.berkeley.edu> | 2010-06-11 01:10:03 -0700 |
---|---|---|
committer | Matei Zaharia <matei@eecs.berkeley.edu> | 2010-06-11 01:10:03 -0700 |
commit | 396f48e5a45fd17b156b396c835acb1cccf5a021 (patch) | |
tree | 35872e5e581924dd001709388bd3039b9d1f8e87 /src | |
parent | 4eb39e0c8a4d4b636b6539a5ad52d8923393471d (diff) | |
download | spark-396f48e5a45fd17b156b396c835acb1cccf5a021.tar.gz spark-396f48e5a45fd17b156b396c835acb1cccf5a021.tar.bz2 spark-396f48e5a45fd17b156b396c835acb1cccf5a021.zip |
New interpreter port for Scala 2.8 interpreter
Diffstat (limited to 'src')
-rw-r--r-- | src/scala/spark/ClosureCleaner.scala | 10 | ||||
-rw-r--r-- | src/scala/spark/Executor.scala | 2 | ||||
-rw-r--r-- | src/scala/spark/repl/SparkCompletion.scala | 353 | ||||
-rw-r--r-- | src/scala/spark/repl/SparkCompletionOutput.scala | 92 | ||||
-rw-r--r-- | src/scala/spark/repl/SparkInteractiveReader.scala | 60 | ||||
-rw-r--r-- | src/scala/spark/repl/SparkInterpreter.scala | 1604 | ||||
-rw-r--r-- | src/scala/spark/repl/SparkInterpreterLoop.scala | 730 | ||||
-rw-r--r-- | src/scala/spark/repl/SparkInterpreterSettings.scala | 112 | ||||
-rw-r--r-- | src/scala/spark/repl/SparkJLineReader.scala | 38 | ||||
-rw-r--r-- | src/scala/spark/repl/SparkSimpleReader.scala | 33 |
10 files changed, 2198 insertions, 836 deletions
diff --git a/src/scala/spark/ClosureCleaner.scala b/src/scala/spark/ClosureCleaner.scala index 3426e56f60..8037434c38 100644 --- a/src/scala/spark/ClosureCleaner.scala +++ b/src/scala/spark/ClosureCleaner.scala @@ -9,8 +9,10 @@ import org.objectweb.asm.Opcodes._ object ClosureCleaner { - private def getClassReader(cls: Class[_]): ClassReader = new ClassReader( - cls.getResourceAsStream(cls.getName.replaceFirst("^.*\\.", "") + ".class")) + private def getClassReader(cls: Class[_]): ClassReader = { + new ClassReader(cls.getResourceAsStream( + cls.getName.replaceFirst("^.*\\.", "") + ".class")) + } private def getOuterClasses(obj: AnyRef): List[Class[_]] = { for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { @@ -84,7 +86,6 @@ object ClosureCleaner { } private def instantiateClass(cls: Class[_], outer: AnyRef): AnyRef = { - /* // TODO: Fix for Scala 2.8 if (spark.repl.Main.interp == null) { // This is a bona fide closure class, whose constructor has no effects // other than to set its fields, so use its constructor @@ -94,7 +95,6 @@ object ClosureCleaner { params(0) = outer // First param is always outer object return cons.newInstance(params: _*).asInstanceOf[AnyRef] } else { - */ // Use reflection to instantiate object without calling constructor val rf = sun.reflect.ReflectionFactory.getReflectionFactory(); val parentCtor = classOf[java.lang.Object].getDeclaredConstructor(); @@ -107,9 +107,7 @@ object ClosureCleaner { field.set(obj, outer) } return obj - /* } - */ } } diff --git a/src/scala/spark/Executor.scala b/src/scala/spark/Executor.scala index c7ee4e594d..395790893b 100644 --- a/src/scala/spark/Executor.scala +++ b/src/scala/spark/Executor.scala @@ -25,13 +25,11 @@ object Executor { // If the REPL is in use, create a ClassLoader that will be able to // read new classes defined by the REPL as the user types code classLoader = this.getClass.getClassLoader - /* // TODO: Fix for Scala 2.8 val classDir = System.getProperty("spark.repl.classdir") if (classDir != null) { println("Using REPL classdir: " + classDir) classLoader = new repl.ExecutorClassLoader(classDir, classLoader) } - */ Thread.currentThread.setContextClassLoader(classLoader) // Start worker thread pool (they will inherit our context ClassLoader) diff --git a/src/scala/spark/repl/SparkCompletion.scala b/src/scala/spark/repl/SparkCompletion.scala new file mode 100644 index 0000000000..d67438445b --- /dev/null +++ b/src/scala/spark/repl/SparkCompletion.scala @@ -0,0 +1,353 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2010 LAMP/EPFL + * @author Paul Phillips + */ + + +package spark.repl + +import scala.tools.nsc +import scala.tools.nsc._ +import scala.tools.nsc.interpreter +import scala.tools.nsc.interpreter._ + +import jline._ +import java.util.{ List => JList } +import util.returning + +object SparkCompletion { + def looksLikeInvocation(code: String) = ( + (code != null) + && (code startsWith ".") + && !(code == ".") + && !(code startsWith "./") + && !(code startsWith "..") + ) + + object Forwarder { + def apply(forwardTo: () => Option[CompletionAware]): CompletionAware = new CompletionAware { + def completions(verbosity: Int) = forwardTo() map (_ completions verbosity) getOrElse Nil + override def follow(s: String) = forwardTo() flatMap (_ follow s) + } + } +} +import SparkCompletion._ + +// REPL completor - queries supplied interpreter for valid +// completions based on current contents of buffer. +class SparkCompletion(val repl: SparkInterpreter) extends SparkCompletionOutput { + // verbosity goes up with consecutive tabs + private var verbosity: Int = 0 + def resetVerbosity() = verbosity = 0 + + def isCompletionDebug = repl.isCompletionDebug + def DBG(msg: => Any) = if (isCompletionDebug) println(msg.toString) + def debugging[T](msg: String): T => T = (res: T) => returning[T](res)(x => DBG(msg + x)) + + lazy val global: repl.compiler.type = repl.compiler + import global._ + import definitions.{ PredefModule, RootClass, AnyClass, AnyRefClass, ScalaPackage, JavaLangPackage } + + // XXX not yet used. + lazy val dottedPaths = { + def walk(tp: Type): scala.List[Symbol] = { + val pkgs = tp.nonPrivateMembers filter (_.isPackage) + pkgs ++ (pkgs map (_.tpe) flatMap walk) + } + walk(RootClass.tpe) + } + + def getType(name: String, isModule: Boolean) = { + val f = if (isModule) definitions.getModule(_: Name) else definitions.getClass(_: Name) + try Some(f(name).tpe) + catch { case _: MissingRequirementError => None } + } + + def typeOf(name: String) = getType(name, false) + def moduleOf(name: String) = getType(name, true) + + trait CompilerCompletion { + def tp: Type + def effectiveTp = tp match { + case MethodType(Nil, resType) => resType + case PolyType(Nil, resType) => resType + case _ => tp + } + + // for some reason any's members don't show up in subclasses, which + // we need so 5.<tab> offers asInstanceOf etc. + private def anyMembers = AnyClass.tpe.nonPrivateMembers + def anyRefMethodsToShow = List("isInstanceOf", "asInstanceOf", "toString") + + def tos(sym: Symbol) = sym.name.decode.toString + def memberNamed(s: String) = members find (x => tos(x) == s) + def hasMethod(s: String) = methods exists (x => tos(x) == s) + + // XXX we'd like to say "filterNot (_.isDeprecated)" but this causes the + // compiler to crash for reasons not yet known. + def members = (effectiveTp.nonPrivateMembers ++ anyMembers) filter (_.isPublic) + def methods = members filter (_.isMethod) + def packages = members filter (_.isPackage) + def aliases = members filter (_.isAliasType) + + def memberNames = members map tos + def methodNames = methods map tos + def packageNames = packages map tos + def aliasNames = aliases map tos + } + + object TypeMemberCompletion { + def apply(tp: Type): TypeMemberCompletion = { + if (tp.typeSymbol.isPackageClass) new PackageCompletion(tp) + else new TypeMemberCompletion(tp) + } + def imported(tp: Type) = new ImportCompletion(tp) + } + + class TypeMemberCompletion(val tp: Type) extends CompletionAware with CompilerCompletion { + def excludeEndsWith: List[String] = Nil + def excludeStartsWith: List[String] = List("<") // <byname>, <repeated>, etc. + def excludeNames: List[String] = anyref.methodNames -- anyRefMethodsToShow ++ List("_root_") + + def methodSignatureString(sym: Symbol) = { + def asString = new MethodSymbolOutput(sym).methodString() + + if (isCompletionDebug) + repl.power.showAtAllPhases(asString) + + atPhase(currentRun.typerPhase)(asString) + } + + def exclude(name: String): Boolean = ( + (name contains "$") || + (excludeNames contains name) || + (excludeEndsWith exists (name endsWith _)) || + (excludeStartsWith exists (name startsWith _)) + ) + def filtered(xs: List[String]) = xs filterNot exclude distinct + + def completions(verbosity: Int) = + debugging(tp + " completions ==> ")(filtered(memberNames)) + + override def follow(s: String): Option[CompletionAware] = + debugging(tp + " -> '" + s + "' ==> ")(memberNamed(s) map (x => TypeMemberCompletion(x.tpe))) + + override def alternativesFor(id: String): List[String] = + debugging(id + " alternatives ==> ") { + val alts = members filter (x => x.isMethod && tos(x) == id) map methodSignatureString + + if (alts.nonEmpty) "" :: alts else Nil + } + + override def toString = "TypeMemberCompletion(%s)".format(tp) + } + + class PackageCompletion(tp: Type) extends TypeMemberCompletion(tp) { + override def excludeNames = anyref.methodNames + } + + class LiteralCompletion(lit: Literal) extends TypeMemberCompletion(lit.value.tpe) { + override def completions(verbosity: Int) = verbosity match { + case 0 => filtered(memberNames) + case _ => memberNames + } + } + + class ImportCompletion(tp: Type) extends TypeMemberCompletion(tp) { + override def completions(verbosity: Int) = verbosity match { + case 0 => filtered(members filterNot (_.isSetter) map tos) + case _ => super.completions(verbosity) + } + } + + // not for completion but for excluding + object anyref extends TypeMemberCompletion(AnyRefClass.tpe) { } + + // the unqualified vals/defs/etc visible in the repl + object ids extends CompletionAware { + override def completions(verbosity: Int) = repl.unqualifiedIds ::: List("classOf") + // we try to use the compiler and fall back on reflection if necessary + // (which at present is for anything defined in the repl session.) + override def follow(id: String) = + if (completions(0) contains id) { + for (clazz <- repl clazzForIdent id) yield { + // XXX The isMemberClass check is a workaround for the crasher described + // in the comments of #3431. The issue as described by iulian is: + // + // Inner classes exist as symbols + // inside their enclosing class, but also inside their package, with a mangled + // name (A$B). The mangled names should never be loaded, and exist only for the + // optimizer, which sometimes cannot get the right symbol, but it doesn't care + // and loads the bytecode anyway. + // + // So this solution is incorrect, but in the short term the simple fix is + // to skip the compiler any time completion is requested on a nested class. + if (clazz.isMemberClass) new InstanceCompletion(clazz) + else (typeOf(clazz.getName) map TypeMemberCompletion.apply) getOrElse new InstanceCompletion(clazz) + } + } + else None + } + + // wildcard imports in the repl like "import global._" or "import String._" + private def imported = repl.wildcardImportedTypes map TypeMemberCompletion.imported + + // literal Ints, Strings, etc. + object literals extends CompletionAware { + def simpleParse(code: String): Tree = { + val unit = new CompilationUnit(new util.BatchSourceFile("<console>", code)) + val scanner = new syntaxAnalyzer.UnitParser(unit) + val tss = scanner.templateStatSeq(false)._2 + + if (tss.size == 1) tss.head else EmptyTree + } + + def completions(verbosity: Int) = Nil + + override def follow(id: String) = simpleParse(id) match { + case x: Literal => Some(new LiteralCompletion(x)) + case _ => None + } + } + + // top level packages + object rootClass extends TypeMemberCompletion(RootClass.tpe) { } + // members of Predef + object predef extends TypeMemberCompletion(PredefModule.tpe) { + override def excludeEndsWith = super.excludeEndsWith ++ List("Wrapper", "ArrayOps") + override def excludeStartsWith = super.excludeStartsWith ++ List("wrap") + override def excludeNames = anyref.methodNames + + override def exclude(name: String) = super.exclude(name) || ( + (name contains "2") + ) + + override def completions(verbosity: Int) = verbosity match { + case 0 => Nil + case _ => super.completions(verbosity) + } + } + // members of scala.* + object scalalang extends PackageCompletion(ScalaPackage.tpe) { + def arityClasses = List("Product", "Tuple", "Function") + def skipArity(name: String) = arityClasses exists (x => name != x && (name startsWith x)) + override def exclude(name: String) = super.exclude(name) || ( + skipArity(name) + ) + + override def completions(verbosity: Int) = verbosity match { + case 0 => filtered(packageNames ++ aliasNames) + case _ => super.completions(verbosity) + } + } + // members of java.lang.* + object javalang extends PackageCompletion(JavaLangPackage.tpe) { + override lazy val excludeEndsWith = super.excludeEndsWith ++ List("Exception", "Error") + override lazy val excludeStartsWith = super.excludeStartsWith ++ List("CharacterData") + + override def completions(verbosity: Int) = verbosity match { + case 0 => filtered(packageNames) + case _ => super.completions(verbosity) + } + } + + // the list of completion aware objects which should be consulted + lazy val topLevelBase: List[CompletionAware] = List(ids, rootClass, predef, scalalang, javalang, literals) + def topLevel = topLevelBase ++ imported + + // the first tier of top level objects (doesn't include file completion) + def topLevelFor(parsed: Parsed) = topLevel flatMap (_ completionsFor parsed) + + // the most recent result + def lastResult = Forwarder(() => ids follow repl.mostRecentVar) + + def lastResultFor(parsed: Parsed) = { + /** The logic is a little tortured right now because normally '.' is + * ignored as a delimiter, but on .<tab> it needs to be propagated. + */ + val xs = lastResult completionsFor parsed + if (parsed.isEmpty) xs map ("." + _) else xs + } + + // chasing down results which won't parse + def execute(line: String): Option[Any] = { + val parsed = Parsed(line) + def noDotOrSlash = line forall (ch => ch != '.' && ch != '/') + + if (noDotOrSlash) None // we defer all unqualified ids to the repl. + else { + (ids executionFor parsed) orElse + (rootClass executionFor parsed) orElse + (FileCompletion executionFor line) + } + } + + // generic interface for querying (e.g. interpreter loop, testing) + def completions(buf: String): List[String] = + topLevelFor(Parsed.dotted(buf + ".", buf.length + 1)) + + // jline's entry point + lazy val jline: ArgumentCompletor = + returning(new ArgumentCompletor(new JLineCompletion, new JLineDelimiter))(_ setStrict false) + + /** This gets a little bit hairy. It's no small feat delegating everything + * and also keeping track of exactly where the cursor is and where it's supposed + * to end up. The alternatives mechanism is a little hacky: if there is an empty + * string in the list of completions, that means we are expanding a unique + * completion, so don't update the "last" buffer because it'll be wrong. + */ + class JLineCompletion extends Completor { + // For recording the buffer on the last tab hit + private var lastBuf: String = "" + private var lastCursor: Int = -1 + + // Does this represent two consecutive tabs? + def isConsecutiveTabs(buf: String, cursor: Int) = cursor == lastCursor && buf == lastBuf + + // Longest common prefix + def commonPrefix(xs: List[String]) = + if (xs.isEmpty) "" + else xs.reduceLeft(_ zip _ takeWhile (x => x._1 == x._2) map (_._1) mkString) + + // This is jline's entry point for completion. + override def complete(_buf: String, cursor: Int, candidates: JList[String]): Int = { + val buf = onull(_buf) + verbosity = if (isConsecutiveTabs(buf, cursor)) verbosity + 1 else 0 + DBG("complete(%s, %d) last = (%s, %d), verbosity: %s".format(buf, cursor, lastBuf, lastCursor, verbosity)) + + // we don't try lower priority completions unless higher ones return no results. + def tryCompletion(p: Parsed, completionFunction: Parsed => List[String]): Option[Int] = { + completionFunction(p) match { + case Nil => None + case xs => + // modify in place and return the position + xs foreach (candidates add _) + + // update the last buffer unless this is an alternatives list + if (xs contains "") Some(p.cursor) + else { + val advance = commonPrefix(xs) + lastCursor = p.position + advance.length + lastBuf = (buf take p.position) + advance + + DBG("tryCompletion(%s, _) lastBuf = %s, lastCursor = %s, p.position = %s".format(p, lastBuf, lastCursor, p.position)) + Some(p.position) + } + } + } + + def mkDotted = Parsed.dotted(buf, cursor) withVerbosity verbosity + def mkUndelimited = Parsed.undelimited(buf, cursor) withVerbosity verbosity + + // a single dot is special cased to completion on the previous result + def lastResultCompletion = + if (!looksLikeInvocation(buf)) None + else tryCompletion(Parsed.dotted(buf drop 1, cursor), lastResultFor) + + def regularCompletion = tryCompletion(mkDotted, topLevelFor) + def fileCompletion = tryCompletion(mkUndelimited, FileCompletion completionsFor _.buffer) + + (lastResultCompletion orElse regularCompletion orElse fileCompletion) getOrElse cursor + } + } +} diff --git a/src/scala/spark/repl/SparkCompletionOutput.scala b/src/scala/spark/repl/SparkCompletionOutput.scala new file mode 100644 index 0000000000..5ac46e3412 --- /dev/null +++ b/src/scala/spark/repl/SparkCompletionOutput.scala @@ -0,0 +1,92 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2010 LAMP/EPFL + * @author Paul Phillips + */ + +package spark.repl + +import scala.tools.nsc +import scala.tools.nsc._ +import scala.tools.nsc.interpreter +import scala.tools.nsc.interpreter._ + +/** This has a lot of duplication with other methods in Symbols and Types, + * but repl completion utility is very sensitive to precise output. Best + * thing would be to abstract an interface for how such things are printed, + * as is also in progress with error messages. + */ +trait SparkCompletionOutput { + self: SparkCompletion => + + import global._ + import definitions.{ NothingClass, AnyClass, isTupleType, isFunctionType, isRepeatedParamType } + + /** Reducing fully qualified noise for some common packages. + */ + val typeTransforms = List( + "java.lang." -> "", + "scala.collection.immutable." -> "immutable.", + "scala.collection.mutable." -> "mutable.", + "scala.collection.generic." -> "generic." + ) + + def quietString(tp: String): String = + typeTransforms.foldLeft(tp) { + case (str, (prefix, replacement)) => + if (str startsWith prefix) replacement + (str stripPrefix prefix) + else str + } + + class MethodSymbolOutput(method: Symbol) { + val pkg = method.ownerChain find (_.isPackageClass) map (_.fullName) getOrElse "" + + def relativize(str: String): String = quietString(str stripPrefix (pkg + ".")) + def relativize(tp: Type): String = relativize(tp.normalize.toString) + def relativize(sym: Symbol): String = relativize(sym.info) + + def braceList(tparams: List[String]) = if (tparams.isEmpty) "" else (tparams map relativize).mkString("[", ", ", "]") + def parenList(params: List[Any]) = params.mkString("(", ", ", ")") + + def methodTypeToString(mt: MethodType) = + (mt.paramss map paramsString mkString "") + ": " + relativize(mt.finalResultType) + + def typeToString(tp: Type): String = relativize( + tp match { + case x if isFunctionType(x) => functionString(x) + case x if isTupleType(x) => tupleString(x) + case x if isRepeatedParamType(x) => typeToString(x.typeArgs.head) + "*" + case mt @ MethodType(_, _) => methodTypeToString(mt) + case x => x.toString + } + ) + + def tupleString(tp: Type) = parenList(tp.normalize.typeArgs map relativize) + def functionString(tp: Type) = tp.normalize.typeArgs match { + case List(t, r) => t + " => " + r + case xs => parenList(xs.init) + " => " + xs.last + } + + def tparamsString(tparams: List[Symbol]) = braceList(tparams map (_.defString)) + def paramsString(params: List[Symbol]) = { + def paramNameString(sym: Symbol) = if (sym.isSynthetic) "" else sym.nameString + ": " + def paramString(sym: Symbol) = paramNameString(sym) + typeToString(sym.info.normalize) + + val isImplicit = params.nonEmpty && params.head.isImplicit + val strs = (params map paramString) match { + case x :: xs if isImplicit => ("implicit " + x) :: xs + case xs => xs + } + parenList(strs) + } + + def methodString() = + method.keyString + " " + method.nameString + (method.info.normalize match { + case PolyType(Nil, resType) => ": " + typeToString(resType) // nullary method + case PolyType(tparams, resType) => tparamsString(tparams) + typeToString(resType) + case mt @ MethodType(_, _) => methodTypeToString(mt) + case x => + DBG("methodString(): %s / %s".format(x.getClass, x)) + x.toString + }) + } +} diff --git a/src/scala/spark/repl/SparkInteractiveReader.scala b/src/scala/spark/repl/SparkInteractiveReader.scala new file mode 100644 index 0000000000..4f5a0a6fa0 --- /dev/null +++ b/src/scala/spark/repl/SparkInteractiveReader.scala @@ -0,0 +1,60 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2010 LAMP/EPFL + * @author Stepan Koltsov + */ + +package spark.repl + +import scala.tools.nsc +import scala.tools.nsc._ +import scala.tools.nsc.interpreter +import scala.tools.nsc.interpreter._ + +import scala.util.control.Exception._ + +/** Reads lines from an input stream */ +trait SparkInteractiveReader { + import SparkInteractiveReader._ + import java.io.IOException + + protected def readOneLine(prompt: String): String + val interactive: Boolean + + def readLine(prompt: String): String = { + def handler: Catcher[String] = { + case e: IOException if restartSystemCall(e) => readLine(prompt) + } + catching(handler) { readOneLine(prompt) } + } + + // override if history is available + def history: Option[History] = None + def historyList = history map (_.asList) getOrElse Nil + + // override if completion is available + def completion: Option[SparkCompletion] = None + + // hack necessary for OSX jvm suspension because read calls are not restarted after SIGTSTP + private def restartSystemCall(e: Exception): Boolean = + Properties.isMac && (e.getMessage == msgEINTR) +} + + +object SparkInteractiveReader { + val msgEINTR = "Interrupted system call" + private val exes = List(classOf[Exception], classOf[NoClassDefFoundError]) + + def createDefault(): SparkInteractiveReader = createDefault(null) + + /** Create an interactive reader. Uses <code>JLineReader</code> if the + * library is available, but otherwise uses a <code>SimpleReader</code>. + */ + def createDefault(interpreter: SparkInterpreter): SparkInteractiveReader = + try new SparkJLineReader(interpreter) + catch { + case e @ (_: Exception | _: NoClassDefFoundError) => + // println("Failed to create SparkJLineReader(%s): %s".format(interpreter, e)) + new SparkSimpleReader + } +} + diff --git a/src/scala/spark/repl/SparkInterpreter.scala b/src/scala/spark/repl/SparkInterpreter.scala index 2377f0c7d6..85313b55b5 100644 --- a/src/scala/spark/repl/SparkInterpreter.scala +++ b/src/scala/spark/repl/SparkInterpreter.scala @@ -1,30 +1,40 @@ /* NSC -- new Scala compiler - * Copyright 2005-2009 LAMP/EPFL + * Copyright 2005-2010 LAMP/EPFL * @author Martin Odersky */ -// $Id: Interpreter.scala 17013 2009-02-02 11:59:53Z washburn $ package spark.repl import scala.tools.nsc import scala.tools.nsc._ -import java.io.{File, IOException, PrintWriter, StringWriter, Writer} -import java.lang.{Class, ClassLoader} -import java.net.{MalformedURLException, URL, URLClassLoader} +import Predef.{ println => _, _ } +import java.io.{ File, IOException, PrintWriter, StringWriter, Writer } +import File.pathSeparator +import java.lang.{ Class, ClassLoader } +import java.net.{ MalformedURLException, URL } +import java.lang.reflect +import reflect.InvocationTargetException import java.util.UUID -import scala.collection.immutable.ListSet +import scala.PartialFunction.{ cond, condOpt } +import scala.tools.util.PathResolver +import scala.reflect.Manifest import scala.collection.mutable -import scala.collection.mutable.{ListBuffer, HashSet, ArrayBuffer} - -//import ast.parser.SyntaxAnalyzer -import io.{PlainFile, VirtualDirectory} -import reporters.{ConsoleReporter, Reporter} -import symtab.Flags -import util.{SourceFile,BatchSourceFile,ClassPath,NameTransformer} -import nsc.{InterpreterResults=>IR} -import scala.tools.nsc.interpreter._ +import scala.collection.mutable.{ ListBuffer, HashSet, HashMap, ArrayBuffer } +import scala.collection.immutable.Set +import scala.tools.nsc.util.ScalaClassLoader +import ScalaClassLoader.URLClassLoader +import scala.util.control.Exception.{ Catcher, catching, ultimately, unwrapping } + +import io.{ PlainFile, VirtualDirectory } +import reporters.{ ConsoleReporter, Reporter } +import symtab.{ Flags, Names } +import util.{ SourceFile, BatchSourceFile, ScriptSourceFile, ClassPath, Chars, stringFromWriter } +import scala.reflect.NameTransformer +import scala.tools.nsc.{ InterpreterResults => IR } +import interpreter._ +import SparkInterpreter._ /** <p> * An interpreter for Scala code. @@ -51,7 +61,7 @@ import scala.tools.nsc.interpreter._ * all variables defined by that code. To extract the result of an * interpreted line to show the user, a second "result object" is created * which imports the variables exported by the above object and then - * exports a single member named "result". To accomodate user expressions + * exports a single member named "scala_repl_result". To accomodate user expressions * that read from variables or methods defined in previous statements, "import" * statements are used. * </p> @@ -67,14 +77,18 @@ import scala.tools.nsc.interpreter._ * @author Lex Spoon */ class SparkInterpreter(val settings: Settings, out: PrintWriter) { - import symtab.Names - - /* If the interpreter is running on pre-jvm-1.5 JVM, - it is necessary to force the target setting to jvm-1.4 */ - private val major = System.getProperty("java.class.version").split("\\.")(0) - if (major.toInt < 49) { - this.settings.target.value = "jvm-1.4" + repl => + + def println(x: Any) = { + out.println(x) + out.flush() } + + /** construct an interpreter that reports to Console */ + def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true)) + def this() = this(new Settings()) + + val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") /** directory to save .class files to */ //val virtualDirectory = new VirtualDirectory("(memory)", None) @@ -97,77 +111,120 @@ class SparkInterpreter(val settings: Settings, out: PrintWriter) { } System.setProperty("spark.repl.classdir", "file://" + outputDir.getAbsolutePath + "/") - //println("Output dir: " + outputDir) + if (SPARK_DEBUG_REPL) + println("Output directory: " + outputDir) new PlainFile(outputDir) } + + /** reporter */ + object reporter extends ConsoleReporter(settings, null, out) { + override def printMessage(msg: String) { + out println clean(msg) + out.flush() + } + } - /** the compiler to compile expressions with */ - val compiler: scala.tools.nsc.Global = newCompiler(settings, reporter) + /** We're going to go to some trouble to initialize the compiler asynchronously. + * It's critical that nothing call into it until it's been initialized or we will + * run into unrecoverable issues, but the perceived repl startup time goes + * through the roof if we wait for it. So we initialize it with a future and + * use a lazy val to ensure that any attempt to use the compiler object waits + * on the future. + */ + private val _compiler: Global = newCompiler(settings, reporter) + private def _initialize(): Boolean = { + val source = """ + |// this is assembled to force the loading of approximately the + |// classes which will be loaded on the first expression anyway. + |class $repl_$init { + | val x = "abc".reverse.length + (5 max 5) + | scala.runtime.ScalaRunTime.stringOf(x) + |} + |""".stripMargin + + try { + new _compiler.Run() compileSources List(new BatchSourceFile("<init>", source)) + if (isReplDebug || settings.debug.value) + println("Repl compiler initialized.") + true + } + catch { + case MissingRequirementError(msg) => println(""" + |Failed to initialize compiler: %s not found. + |** Note that as of 2.8 scala does not assume use of the java classpath. + |** For the old behavior pass -usejavacp to scala, or if using a Settings + |** object programatically, settings.usejavacp.value = true.""".stripMargin.format(msg) + ) + false + } + } + + // set up initialization future + private var _isInitialized: () => Boolean = null + def initialize() = synchronized { + if (_isInitialized == null) + _isInitialized = scala.concurrent.ops future _initialize() + } - import compiler.Traverser - import compiler.{Tree, TermTree, - ValOrDefDef, ValDef, DefDef, Assign, - ClassDef, ModuleDef, Ident, Select, TypeDef, - Import, MemberDef, DocDef} - import compiler.CompilationUnit - import compiler.{Symbol,Name,Type} - import compiler.nme - import compiler.newTermName - import compiler.newTypeName - import compiler.nme.{INTERPRETER_VAR_PREFIX, INTERPRETER_SYNTHVAR_PREFIX} - import Interpreter.string2code + /** the public, go through the future compiler */ + lazy val compiler: Global = { + initialize() - /** construct an interpreter that reports to Console */ - def this(settings: Settings) = - this(settings, - new NewLinePrintWriter(new ConsoleWriter, true)) + // blocks until it is ; false means catastrophic failure + if (_isInitialized()) _compiler + else null + } + + import compiler.{ Traverser, CompilationUnit, Symbol, Name, Type } + import compiler.{ + Tree, TermTree, ValOrDefDef, ValDef, DefDef, Assign, ClassDef, + ModuleDef, Ident, Select, TypeDef, Import, MemberDef, DocDef, + ImportSelector, EmptyTree, NoType } + import compiler.{ nme, newTermName, newTypeName } + import nme.{ + INTERPRETER_VAR_PREFIX, INTERPRETER_SYNTHVAR_PREFIX, INTERPRETER_LINE_PREFIX, + INTERPRETER_IMPORT_WRAPPER, INTERPRETER_WRAPPER_SUFFIX, USCOREkw + } + + import compiler.definitions + import definitions.{ EmptyPackage, getMember } /** whether to print out result lines */ - private var printResults: Boolean = true + private[repl] var printResults: Boolean = true - /** Be quiet. Do not print out the results of each - * submitted command unless an exception is thrown. */ - def beQuiet = { printResults = false } - /** Temporarily be quiet */ - def beQuietDuring[T](operation: => T): T = { - val wasPrinting = printResults - try { + def beQuietDuring[T](operation: => T): T = { + val wasPrinting = printResults + ultimately(printResults = wasPrinting) { printResults = false operation - } finally { - printResults = wasPrinting + } + } + + /** whether to bind the lastException variable */ + private var bindLastException = true + + /** Temporarily stop binding lastException */ + def withoutBindingLastException[T](operation: => T): T = { + val wasBinding = bindLastException + ultimately(bindLastException = wasBinding) { + bindLastException = false + operation } } /** interpreter settings */ - lazy val isettings = new InterpreterSettings - - object reporter extends ConsoleReporter(settings, null, out) { - //override def printMessage(msg: String) { out.println(clean(msg)) } - override def printMessage(msg: String) { out.print(clean(msg) + "\n"); out.flush() } - } + lazy val isettings = new SparkInterpreterSettings(this) /** Instantiate a compiler. Subclasses can override this to * change the compiler class used by this interpreter. */ protected def newCompiler(settings: Settings, reporter: Reporter) = { - val comp = new scala.tools.nsc.Global(settings, reporter) - comp.genJVM.outputDir = virtualDirectory - comp + settings.outputDirs setSingleOutput virtualDirectory + new Global(settings, reporter) } - - + /** the compiler's classpath, as URL's */ - val compilerClasspath: List[URL] = { - val classpathPart = - (ClassPath.expandPath(compiler.settings.classpath.value). - map(s => new File(s).toURL)) - def parseURL(s: String): Option[URL] = - try { Some(new URL(s)) } - catch { case _:MalformedURLException => None } - val codebasePart = (compiler.settings.Xcodebase.value.split(" ")).toList.flatMap(parseURL) - classpathPart ::: codebasePart - } + lazy val compilerClasspath: List[URL] = new PathResolver(settings) asURLs /* A single class loader is used for all commands interpreted by this Interpreter. It would also be possible to create a new class loader for each command @@ -182,113 +239,176 @@ class SparkInterpreter(val settings: Settings, out: PrintWriter) { shadow the old ones, and old code objects refer to the old definitions. */ - /** class loader used to load compiled code */ - private val classLoader = { + private var _classLoader: ClassLoader = null + def resetClassLoader() = _classLoader = makeClassLoader() + def classLoader: ClassLoader = { + if (_classLoader == null) + resetClassLoader() + + _classLoader + } + private def makeClassLoader(): ClassLoader = { + /* + val parent = + if (parentClassLoader == null) ScalaClassLoader fromURLs compilerClasspath + else new URLClassLoader(compilerClasspath, parentClassLoader) + + new AbstractFileClassLoader(virtualDirectory, parent) + */ val parent = if (parentClassLoader == null) - new URLClassLoader(compilerClasspath.toArray) + new java.net.URLClassLoader(compilerClasspath.toArray) else - new URLClassLoader(compilerClasspath.toArray, + new java.net.URLClassLoader(compilerClasspath.toArray, parentClassLoader) val virtualDirUrl = new URL("file://" + virtualDirectory.path + "/") - new URLClassLoader(Array(virtualDirUrl), parent) - //new InterpreterClassLoader(Array(virtualDirUrl), parent) - //new AbstractFileClassLoader(virtualDirectory, parent) + new java.net.URLClassLoader(Array(virtualDirUrl), parent) } - /** Set the current Java "context" class loader to this - * interpreter's class loader */ - def setContextClassLoader() { - Thread.currentThread.setContextClassLoader(classLoader) - } + private def loadByName(s: String): Class[_] = // (classLoader tryToInitializeClass s).get + Class.forName(s, true, classLoader) + private def methodByName(c: Class[_], name: String): reflect.Method = + c.getMethod(name, classOf[Object]) + protected def parentClassLoader: ClassLoader = this.getClass.getClassLoader() + def getInterpreterClassLoader() = classLoader + + // Set the current Java "context" class loader to this interpreter's class loader + def setContextClassLoader() = Thread.currentThread.setContextClassLoader(classLoader) /** the previous requests this interpreter has processed */ private val prevRequests = new ArrayBuffer[Request]() + private val usedNameMap = new HashMap[Name, Request]() + private val boundNameMap = new HashMap[Name, Request]() + private def allHandlers = prevRequests.toList flatMap (_.handlers) + private def allReqAndHandlers = prevRequests.toList flatMap (req => req.handlers map (req -> _)) + + def printAllTypeOf = { + prevRequests foreach { req => + req.typeOf foreach { case (k, v) => Console.println(k + " => " + v) } + } + } - /** next line number to use */ - private var nextLineNo = 0 + /** Most recent tree handled which wasn't wholly synthetic. */ + private def mostRecentlyHandledTree: Option[Tree] = { + for { + req <- prevRequests.reverse + handler <- req.handlers.reverse + name <- handler.generatesValue + if !isSynthVarName(name) + } return Some(handler.member) - /** allocate a fresh line name */ - private def newLineName = { - val num = nextLineNo - nextLineNo += 1 - compiler.nme.INTERPRETER_LINE_PREFIX + num + None } - - /** next result variable number to use */ - private var nextVarNameNo = 0 - /** allocate a fresh variable name */ - private def newVarName() = { - val num = nextVarNameNo - nextVarNameNo += 1 - INTERPRETER_VAR_PREFIX + num - } + def recordRequest(req: Request) { + def tripart[T](set1: Set[T], set2: Set[T]) = { + val intersect = set1 intersect set2 + List(set1 -- intersect, intersect, set2 -- intersect) + } - /** next internal variable number to use */ - private var nextInternalVarNo = 0 - - /** allocate a fresh internal variable name */ - private def newInternalVarName() = { - val num = nextVarNameNo - nextVarNameNo += 1 - INTERPRETER_SYNTHVAR_PREFIX + num + prevRequests += req + req.usedNames foreach (x => usedNameMap(x) = req) + req.boundNames foreach (x => boundNameMap(x) = req) + + // XXX temporarily putting this here because of tricky initialization order issues + // so right now it's not bound until after you issue a command. + if (prevRequests.size == 1) + quietBind("settings", "spark.repl.SparkInterpreterSettings", isettings) + + // println("\n s1 = %s\n s2 = %s\n s3 = %s".format( + // tripart(usedNameMap.keysIterator.toSet, boundNameMap.keysIterator.toSet): _* + // )) } + private def keyList[T](x: collection.Map[T, _]): List[T] = x.keys.toList sortBy (_.toString) + def allUsedNames = keyList(usedNameMap) + def allBoundNames = keyList(boundNameMap) + def allSeenTypes = prevRequests.toList flatMap (_.typeOf.values.toList) distinct + def allValueGeneratingNames = allHandlers flatMap (_.generatesValue) + def allImplicits = partialFlatMap(allHandlers) { + case x: MemberHandler if x.definesImplicit => x.boundNames + } - /** Check if a name looks like it was generated by newVarName */ - private def isGeneratedVarName(name: String): Boolean = - name.startsWith(INTERPRETER_VAR_PREFIX) && { - val suffix = name.drop(INTERPRETER_VAR_PREFIX.length) - suffix.forall(_.isDigit) + /** Generates names pre0, pre1, etc. via calls to apply method */ + class NameCreator(pre: String) { + private var x = -1 + var mostRecent: String = null + + def apply(): String = { + x += 1 + val name = pre + x.toString + // make sure we don't overwrite their unwisely named res3 etc. + mostRecent = + if (allBoundNames exists (_.toString == name)) apply() + else name + + mostRecent } + def reset(): Unit = x = -1 + def didGenerate(name: String) = + (name startsWith pre) && ((name drop pre.length) forall (_.isDigit)) + } + /** allocate a fresh line name */ + private lazy val lineNameCreator = new NameCreator(INTERPRETER_LINE_PREFIX) + + /** allocate a fresh var name */ + private lazy val varNameCreator = new NameCreator(INTERPRETER_VAR_PREFIX) + + /** allocate a fresh internal variable name */ + private lazy val synthVarNameCreator = new NameCreator(INTERPRETER_SYNTHVAR_PREFIX) - /** generate a string using a routine that wants to write on a stream */ - private def stringFrom(writer: PrintWriter => Unit): String = { - val stringWriter = new StringWriter() - val stream = new NewLinePrintWriter(stringWriter) - writer(stream) - stream.close - stringWriter.toString - } + /** Check if a name looks like it was generated by varNameCreator */ + private def isGeneratedVarName(name: String): Boolean = varNameCreator didGenerate name + private def isSynthVarName(name: String): Boolean = synthVarNameCreator didGenerate name + private def isSynthVarName(name: Name): Boolean = synthVarNameCreator didGenerate name.toString + + def getVarName = varNameCreator() + def getSynthVarName = synthVarNameCreator() - /** Truncate a string if it is longer than settings.maxPrintString */ + /** Truncate a string if it is longer than isettings.maxPrintString */ private def truncPrintString(str: String): String = { val maxpr = isettings.maxPrintString - - if (maxpr <= 0) - return str - - if (str.length <= maxpr) - return str - val trailer = "..." - if (maxpr >= trailer.length+1) - return str.substring(0, maxpr-3) + trailer - - str.substring(0, maxpr) + + if (maxpr <= 0 || str.length <= maxpr) str + else str.substring(0, maxpr-3) + trailer } - /** Clean up a string for output */ - private def clean(str: String) = - truncPrintString(Interpreter.stripWrapperGunk(str)) + /** Clean up a string for output */ + private def clean(str: String) = truncPrintString( + if (isettings.unwrapStrings && !SPARK_DEBUG_REPL) stripWrapperGunk(str) + else str + ) + + /** Heuristically strip interpreter wrapper prefixes + * from an interpreter output string. + * MATEI: Copied from interpreter package object + */ + def stripWrapperGunk(str: String): String = { + val wrapregex = """(line[0-9]+\$object[$.])?(\$VAL.?)*(\$iwC?(.this?)[$.])*""" + str.replaceAll(wrapregex, "") + } /** Indent some code by the width of the scala> prompt. - * This way, compiler error messages read beettr. + * This way, compiler error messages read better. */ + private final val spaces = List.fill(7)(" ").mkString def indentCode(code: String) = { - val spaces = " " - - stringFrom(str => + /** Heuristic to avoid indenting and thereby corrupting """-strings and XML literals. */ + val noIndent = (code contains "\n") && (List("\"\"\"", "</", "/>") exists (code contains _)) + stringFromWriter(str => for (line <- code.lines) { - str.print(spaces) + if (!noIndent) + str.print(spaces) + str.print(line + "\n") str.flush() }) } + def indentString(s: String) = s split "\n" map (spaces + _ + "\n") mkString implicit def name2string(name: Name) = name.toString @@ -315,71 +435,55 @@ class SparkInterpreter(val settings: Settings, out: PrintWriter) { * (3) It imports multiple same-named implicits, but only the * last one imported is actually usable. */ - private def importsCode(wanted: Set[Name]): (String, String, String) = { + private case class ComputedImports(prepend: String, append: String, access: String) + private def importsCode(wanted: Set[Name]): ComputedImports = { /** Narrow down the list of requests from which imports * should be taken. Removes requests which cannot contribute * useful imports for the specified set of wanted names. */ - def reqsToUse: List[(Request,MemberHandler)] = { - /** Loop through a list of MemberHandlers and select - * which ones to keep. 'wanted' is the set of - * names that need to be imported, and - * 'shadowed' is the list of names useless to import - * because a later request will re-import it anyway. + case class ReqAndHandler(req: Request, handler: MemberHandler) { } + + def reqsToUse: List[ReqAndHandler] = { + /** Loop through a list of MemberHandlers and select which ones to keep. + * 'wanted' is the set of names that need to be imported. */ - def select(reqs: List[(Request,MemberHandler)], wanted: Set[Name]): - List[(Request,MemberHandler)] = { + def select(reqs: List[ReqAndHandler], wanted: Set[Name]): List[ReqAndHandler] = { + val isWanted = wanted contains _ + // Single symbol imports might be implicits! See bug #1752. Rather than + // try to finesse this, we will mimic all imports for now. + def keepHandler(handler: MemberHandler) = handler match { + case _: ImportHandler => true + case x => x.definesImplicit || (x.boundNames exists isWanted) + } + reqs match { - case Nil => Nil - - case (req,handler)::rest => - val keepit = - (handler.definesImplicit || - handler.importsWildcard || - handler.importedNames.exists(wanted.contains(_)) || - handler.boundNames.exists(wanted.contains(_))) - - val newWanted = - if (keepit) { - (wanted - ++ handler.usedNames - -- handler.boundNames - -- handler.importedNames) - } else { - wanted - } - - val restToKeep = select(rest, newWanted) - - if(keepit) - (req,handler) :: restToKeep - else - restToKeep + case Nil => Nil + case rh :: rest if !keepHandler(rh.handler) => select(rest, wanted) + case rh :: rest => + val importedNames = rh.handler match { case x: ImportHandler => x.importedNames ; case _ => Nil } + import rh.handler._ + val newWanted = wanted ++ usedNames -- boundNames -- importedNames + rh :: select(rest, newWanted) } } - - val rhpairs = for { - req <- prevRequests.toList.reverse - handler <- req.handlers - } yield (req, handler) - - select(rhpairs, wanted).reverse + + /** Flatten the handlers out and pair each with the original request */ + select(allReqAndHandlers reverseMap { case (r, h) => ReqAndHandler(r, h) }, wanted).reverse } - val code = new StringBuffer - val trailingLines = new ArrayBuffer[String] - val accessPath = new StringBuffer - val impname = compiler.nme.INTERPRETER_IMPORT_WRAPPER - val currentImps = mutable.Set.empty[Name] + val code, trailingLines, accessPath = new StringBuffer + val currentImps = HashSet[Name]() // add code for a new object to hold some imports - /*def addWrapper() { - code.append("object " + impname + "{\n") - trailingLines.append("}\n") - accessPath.append("." + impname) - currentImps.clear - }*/ def addWrapper() { + /* + val impname = INTERPRETER_IMPORT_WRAPPER + code append "object %s {\n".format(impname) + trailingLines append "}\n" + accessPath append ("." + impname) + currentImps.clear + */ + val impname = INTERPRETER_IMPORT_WRAPPER code.append("@serializable class " + impname + "C {\n") trailingLines.append("}\nval " + impname + " = new " + impname + "C;\n") accessPath.append("." + impname) @@ -388,85 +492,74 @@ class SparkInterpreter(val settings: Settings, out: PrintWriter) { addWrapper() - // loop through previous requests, adding imports - // for each one - for ((req,handler) <- reqsToUse) { - // If the user entered an import, then just use it - - // add an import wrapping level if the import might - // conflict with some other import - if(handler.importsWildcard || - currentImps.exists(handler.importedNames.contains)) - if(!currentImps.isEmpty) - addWrapper() - - if (handler.member.isInstanceOf[Import]) - code.append(handler.member.toString + ";\n") - + // loop through previous requests, adding imports for each one + for (ReqAndHandler(req, handler) <- reqsToUse) { + handler match { + // If the user entered an import, then just use it; add an import wrapping + // level if the import might conflict with some other import + case x: ImportHandler => + if (x.importsWildcard || (currentImps exists (x.importedNames contains _))) + addWrapper() + + code append (x.member.toString + "\n") + // give wildcard imports a import wrapper all to their own - if(handler.importsWildcard) - addWrapper() - else - currentImps ++= handler.importedNames - - // For other requests, import each bound variable. - // import them explicitly instead of with _, so that - // ambiguity errors will not be generated. Also, quote - // the name of the variable, so that we don't need to - // handle quoting keywords separately. - for (imv <- handler.boundNames) { - if (currentImps.contains(imv)) - addWrapper() + if (x.importsWildcard) addWrapper() + else currentImps ++= x.importedNames + + // For other requests, import each bound variable. + // import them explicitly instead of with _, so that + // ambiguity errors will not be generated. Also, quote + // the name of the variable, so that we don't need to + // handle quoting keywords separately. + case x => + for (imv <- x.boundNames) { + // MATEI: Commented this check out because it was messing up for case classes + // (trying to import them twice within the same wrapper), but that is more likely + // due to a miscomputation of names that makes the code think they're unique. + // Need to evaluate whether having so many wrappers is a bad thing. + /*if (currentImps contains imv) */ addWrapper() + code.append("val " + req.objectName + "$VAL = " + req.objectName + ".INSTANCE;\n") - code.append("import ") - code.append(req.objectName + "$VAL" + req.accessPath + ".`" + imv + "`;\n") - // The code below is less likely to pull in bad variables, but prevents use of vars & classes - //code.append("val `" + imv + "` = " + req.objectName + ".INSTANCE" + req.accessPath + ".`" + imv + "`;\n") + code.append("import " + req.objectName + "$VAL" + req.accessPath + ".`" + imv + "`;\n") + + //code append ("import %s\n" format (req fullPath imv)) currentImps += imv } + } } - - addWrapper() // Add one extra wrapper, to prevent warnings - // in the frequent case of redefining - // the value bound in the last interpreter - // request. - - (code.toString, trailingLines.reverse.mkString, accessPath.toString) + // add one extra wrapper, to prevent warnings in the common case of + // redefining the value bound in the last interpreter request. + addWrapper() + ComputedImports(code.toString, trailingLines.toString, accessPath.toString) } - /** Parse a line into a sequence of trees. Returns None if the input - * is incomplete. */ - private def parse(line: String): Option[List[Tree]] = { + /** Parse a line into a sequence of trees. Returns None if the input is incomplete. */ + private def parse(line: String): Option[List[Tree]] = { var justNeedsMore = false reporter.withIncompleteHandler((pos,msg) => {justNeedsMore = true}) { // simple parse: just parse it, nothing else def simpleParse(code: String): List[Tree] = { reporter.reset - val unit = - new CompilationUnit( - new BatchSourceFile("<console>", code.toCharArray())) - val scanner = new compiler.syntaxAnalyzer.UnitParser(unit); - val xxx = scanner.templateStatSeq(false); - (xxx._2) - } - val (trees) = simpleParse(line) - if (reporter.hasErrors) { - Some(Nil) // the result did not parse, so stop - } else if (justNeedsMore) { - None - } else { - Some(trees) + val unit = new CompilationUnit(new BatchSourceFile("<console>", code)) + val scanner = new compiler.syntaxAnalyzer.UnitParser(unit) + + scanner.templateStatSeq(false)._2 } + val trees = simpleParse(line) + + if (reporter.hasErrors) Some(Nil) // the result did not parse, so stop + else if (justNeedsMore) None + else Some(trees) } } /** Compile an nsc SourceFile. Returns true if there are - * no compilation errors, or false othrewise. + * no compilation errors, or false otherwise. */ - def compileSources(sources: List[SourceFile]): Boolean = { - val cr = new compiler.Run + def compileSources(sources: SourceFile*): Boolean = { reporter.reset - cr.compileSources(sources) + new compiler.Run() compileSources sources.toList !reporter.hasErrors } @@ -474,28 +567,62 @@ class SparkInterpreter(val settings: Settings, out: PrintWriter) { * compilation errors, or false otherwise. */ def compileString(code: String): Boolean = - compileSources(List(new BatchSourceFile("<script>", code.toCharArray))) + compileSources(new BatchSourceFile("<script>", code)) + + def compileAndSaveRun(label: String, code: String) = { + if (SPARK_DEBUG_REPL) + println(code) + if (isReplDebug) { + parse(code) match { + case Some(trees) => trees foreach (t => DBG(compiler.asCompactString(t))) + case _ => DBG("Parse error:\n\n" + code) + } + } + val run = new compiler.Run() + run.compileSources(List(new BatchSourceFile(label, code))) + run + } /** Build a request from the user. <code>trees</code> is <code>line</code> * after being parsed. */ - private def buildRequest(trees: List[Tree], line: String, lineName: String): Request = - new Request(line, lineName) - - private def chooseHandler(member: Tree): Option[MemberHandler] = - member match { - case member: DefDef => - Some(new DefHandler(member)) - case member: ValDef => - Some(new ValHandler(member)) - case member@Assign(Ident(_), _) => Some(new AssignHandler(member)) - case member: ModuleDef => Some(new ModuleHandler(member)) - case member: ClassDef => Some(new ClassHandler(member)) - case member: TypeDef => Some(new TypeAliasHandler(member)) - case member: Import => Some(new ImportHandler(member)) - case DocDef(_, documented) => chooseHandler(documented) - case member => Some(new GenericHandler(member)) + private def buildRequest(line: String, lineName: String, trees: List[Tree]): Request = + new Request(line, lineName, trees) + + private def chooseHandler(member: Tree): MemberHandler = member match { + case member: DefDef => new DefHandler(member) + case member: ValDef => new ValHandler(member) + case member@Assign(Ident(_), _) => new AssignHandler(member) + case member: ModuleDef => new ModuleHandler(member) + case member: ClassDef => new ClassHandler(member) + case member: TypeDef => new TypeAliasHandler(member) + case member: Import => new ImportHandler(member) + case DocDef(_, documented) => chooseHandler(documented) + case member => new GenericHandler(member) + } + + private def requestFromLine(line: String, synthetic: Boolean): Either[IR.Result, Request] = { + val trees = parse(indentCode(line)) match { + case None => return Left(IR.Incomplete) + case Some(Nil) => return Left(IR.Error) // parse error or empty input + case Some(trees) => trees + } + + // use synthetic vars to avoid filling up the resXX slots + def varName = if (synthetic) getSynthVarName else getVarName + + // Treat a single bare expression specially. This is necessary due to it being hard to + // modify code at a textual level, and it being hard to submit an AST to the compiler. + if (trees.size == 1) trees.head match { + case _:Assign => // we don't want to include assignments + case _:TermTree | _:Ident | _:Select => // ... but do want these as valdefs. + return requestFromLine("val %s =\n%s".format(varName, line), synthetic) + case _ => } + + // figure out what kind of request + Right(buildRequest(line, lineNameCreator(), trees)) + } /** <p> * Interpret one line of input. All feedback, including parse errors @@ -511,55 +638,34 @@ class SparkInterpreter(val settings: Settings, out: PrintWriter) { * @param line ... * @return ... */ - def interpret(line: String): IR.Result = { - if (prevRequests.isEmpty) - new compiler.Run // initialize the compiler - - // parse - val trees = parse(indentCode(line)) match { - case None => return IR.Incomplete - case (Some(Nil)) => return IR.Error // parse error or empty input - case Some(trees) => trees - } - - trees match { - case List(_:Assign) => () - - case List(_:TermTree) | List(_:Ident) | List(_:Select) => - // Treat a single bare expression specially. - // This is necessary due to it being hard to modify - // code at a textual level, and it being hard to - // submit an AST to the compiler. - return interpret("val "+newVarName()+" = \n"+line) - - case _ => () + def interpret(line: String): IR.Result = interpret(line, false) + def interpret(line: String, synthetic: Boolean): IR.Result = { + def loadAndRunReq(req: Request) = { + val (result, succeeded) = req.loadAndRun + if (printResults || !succeeded) + out print clean(result) + + // book-keeping + if (succeeded && !synthetic) + recordRequest(req) + + if (succeeded) IR.Success + else IR.Error } - - val lineName = newLineName - - // figure out what kind of request - val req = buildRequest(trees, line, lineName) - if (req eq null) return IR.Error // a disallowed statement type - - if (!req.compile) - return IR.Error // an error happened during compilation, e.g. a type error - - val (interpreterResultString, succeeded) = req.loadAndRun - - if (printResults || !succeeded) { - // print the result - out.print(clean(interpreterResultString)) + + if (compiler == null) IR.Error + else requestFromLine(line, synthetic) match { + case Left(result) => result + case Right(req) => + // null indicates a disallowed statement type; otherwise compile and + // fail if false (implying e.g. a type error) + if (req == null || !req.compile) IR.Error + else loadAndRunReq(req) } - - // book-keeping - if (succeeded) - prevRequests += req - - if (succeeded) IR.Success else IR.Error } - /** A counter used for numbering objects created by <code>bind()</code>. */ - private var binderNum = 0 + /** A name creator used for objects created by <code>bind()</code>. */ + private lazy val newBinder = new NameCreator("binder") /** Bind a specified name to a specified value. The name may * later be used by expressions passed to interpret. @@ -570,30 +676,35 @@ class SparkInterpreter(val settings: Settings, out: PrintWriter) { * @return an indication of whether the binding succeeded */ def bind(name: String, boundType: String, value: Any): IR.Result = { - val binderName = "binder" + binderNum - binderNum += 1 - - compileString( - "object " + binderName + - "{ var value: " + boundType + " = _; " + - " def set(x: Any) = value=x.asInstanceOf[" + boundType + "]; }") + val binderName = newBinder() - val binderObject = - Class.forName(binderName, true, classLoader) - val setterMethod = - (binderObject - .getDeclaredMethods - .toList - .find(meth => meth.getName == "set") - .get) - var argsHolder: Array[Any] = null // this roundabout approach is to try and - // make sure the value is boxed - argsHolder = List(value).toArray - setterMethod.invoke(null, argsHolder.asInstanceOf[Array[AnyRef]]: _*) + compileString(""" + |object %s { + | var value: %s = _ + | def set(x: Any) = value = x.asInstanceOf[%s] + |} + """.stripMargin.format(binderName, boundType, boundType)) - interpret("val " + name + " = " + binderName + ".value") + val binderObject = loadByName(binderName) + val setterMethod = methodByName(binderObject, "set") + + setterMethod.invoke(null, value.asInstanceOf[AnyRef]) + interpret("val %s = %s.value".format(name, binderName)) + } + + def quietBind(name: String, boundType: String, value: Any): IR.Result = + beQuietDuring { bind(name, boundType, value) } + + /** Reset this interpreter, forgetting all user-specified requests. */ + def reset() { + //virtualDirectory.clear + virtualDirectory.delete + virtualDirectory.create + resetClassLoader() + lineNameCreator.reset() + varNameCreator.reset() + prevRequests.clear } - /** <p> * This instance is no longer needed, so release any resources @@ -601,206 +712,191 @@ class SparkInterpreter(val settings: Settings, out: PrintWriter) { * </p> */ def close() { - reporter.flush() + reporter.flush } /** A traverser that finds all mentioned identifiers, i.e. things - * that need to be imported. - * It might return extra names. + * that need to be imported. It might return extra names. */ - private class ImportVarsTraverser(definedVars: List[Name]) extends Traverser { + private class ImportVarsTraverser extends Traverser { val importVars = new HashSet[Name]() - override def traverse(ast: Tree) { - ast match { - case Ident(name) => importVars += name - case _ => super.traverse(ast) - } + override def traverse(ast: Tree) = ast match { + // XXX this is obviously inadequate but it's going to require some effort + // to get right. + case Ident(name) if !(name.toString startsWith "x$") => importVars += name + case _ => super.traverse(ast) } } - /** Class to handle one member among all the members included * in a single interpreter request. */ private sealed abstract class MemberHandler(val member: Tree) { - val usedNames: List[Name] = { - val ivt = new ImportVarsTraverser(boundNames) - ivt.traverseTrees(List(member)) + lazy val usedNames: List[Name] = { + val ivt = new ImportVarsTraverser() + ivt traverse member ivt.importVars.toList } def boundNames: List[Name] = Nil - def valAndVarNames: List[Name] = Nil - def defNames: List[Name] = Nil - val importsWildcard = false - val importedNames: Seq[Name] = Nil - val definesImplicit = - member match { - case tree:MemberDef => - tree.mods.hasFlag(symtab.Flags.IMPLICIT) - case _ => false - } + val definesImplicit = cond(member) { + case tree: MemberDef => tree.mods hasFlag Flags.IMPLICIT + } + def generatesValue: Option[Name] = None def extraCodeToEvaluate(req: Request, code: PrintWriter) { } def resultExtractionCode(req: Request, code: PrintWriter) { } + + override def toString = "%s(used = %s)".format(this.getClass.toString split '.' last, usedNames) } private class GenericHandler(member: Tree) extends MemberHandler(member) - + private class ValHandler(member: ValDef) extends MemberHandler(member) { - override lazy val boundNames = List(member.name) - override def valAndVarNames = boundNames + lazy val ValDef(mods, vname, _, _) = member + lazy val prettyName = NameTransformer.decode(vname) + lazy val isLazy = mods hasFlag Flags.LAZY + + override lazy val boundNames = List(vname) + override def generatesValue = Some(vname) override def resultExtractionCode(req: Request, code: PrintWriter) { - val vname = member.name - if (member.mods.isPublic && - !(isGeneratedVarName(vname) && - req.typeOf(compiler.encode(vname)) == "Unit")) - { - val prettyName = NameTransformer.decode(vname) - code.print(" + \"" + prettyName + ": " + - string2code(req.typeOf(vname)) + - " = \" + " + - " { val tmp = scala.runtime.ScalaRunTime.stringOf(" + - req.fullPath(vname) + - "); " + - " (if(tmp.toSeq.contains('\\n')) \"\\n\" else \"\") + tmp + \"\\n\"} ") - } + val isInternal = isGeneratedVarName(vname) && req.typeOfEnc(vname) == "Unit" + if (!mods.isPublic || isInternal) return + + lazy val extractor = "scala.runtime.ScalaRunTime.stringOf(%s)".format(req fullPath vname) + + // if this is a lazy val we avoid evaluating it here + val resultString = if (isLazy) codegenln(false, "<lazy>") else extractor + val codeToPrint = + """ + "%s: %s = " + %s""".format(prettyName, string2code(req typeOf vname), resultString) + + code print codeToPrint } } private class DefHandler(defDef: DefDef) extends MemberHandler(defDef) { - override lazy val boundNames = List(defDef.name) - override def defNames = boundNames - - override def resultExtractionCode(req: Request, code: PrintWriter) { - if (defDef.mods.isPublic) - code.print("+\""+string2code(defDef.name)+": "+ - string2code(req.typeOf(defDef.name))+"\\n\"") - } + lazy val DefDef(mods, name, _, vparamss, _, _) = defDef + override lazy val boundNames = List(name) + // true if 0-arity + override def generatesValue = + if (vparamss.isEmpty || vparamss.head.isEmpty) Some(name) + else None + + override def resultExtractionCode(req: Request, code: PrintWriter) = + if (mods.isPublic) code print codegenln(name, ": ", req.typeOf(name)) } private class AssignHandler(member: Assign) extends MemberHandler(member) { - val lhs = member. lhs.asInstanceOf[Ident] // an unfortunate limitation + val lhs = member.lhs.asInstanceOf[Ident] // an unfortunate limitation + val helperName = newTermName(synthVarNameCreator()) + override def generatesValue = Some(helperName) - val helperName = newTermName(newInternalVarName()) - override val valAndVarNames = List(helperName) - - override def extraCodeToEvaluate(req: Request, code: PrintWriter) { - code.println("val "+helperName+" = "+member.lhs+";") - } + override def extraCodeToEvaluate(req: Request, code: PrintWriter) = + code println """val %s = %s""".format(helperName, lhs) /** Print out lhs instead of the generated varName */ override def resultExtractionCode(req: Request, code: PrintWriter) { - code.print(" + \"" + lhs + ": " + - string2code(req.typeOf(compiler.encode(helperName))) + - " = \" + " + - string2code(req.fullPath(helperName)) - + " + \"\\n\"") + val lhsType = string2code(req typeOfEnc helperName) + val res = string2code(req fullPath helperName) + val codeToPrint = """ + "%s: %s = " + %s + "\n" """.format(lhs, lhsType, res) + + code println codeToPrint } } private class ModuleHandler(module: ModuleDef) extends MemberHandler(module) { - override lazy val boundNames = List(module.name) + lazy val ModuleDef(mods, name, _) = module + override lazy val boundNames = List(name) + override def generatesValue = Some(name) - override def resultExtractionCode(req: Request, code: PrintWriter) { - code.println(" + \"defined module " + - string2code(module.name) - + "\\n\"") - } + override def resultExtractionCode(req: Request, code: PrintWriter) = + code println codegenln("defined module ", name) } - private class ClassHandler(classdef: ClassDef) - extends MemberHandler(classdef) - { - override lazy val boundNames = - List(classdef.name) ::: - (if (classdef.mods.hasFlag(Flags.CASE)) - List(classdef.name.toTermName) - else - Nil) - - // TODO: MemberDef.keyword does not include "trait"; - // otherwise it could be used here - def keyword: String = - if (classdef.mods.isTrait) "trait" else "class" + private class ClassHandler(classdef: ClassDef) extends MemberHandler(classdef) { + lazy val ClassDef(mods, name, _, _) = classdef + override lazy val boundNames = + name :: (if (mods hasFlag Flags.CASE) List(name.toTermName) else Nil) - override def resultExtractionCode(req: Request, code: PrintWriter) { - code.print( - " + \"defined " + - keyword + - " " + - string2code(classdef.name) + - "\\n\"") - } + override def resultExtractionCode(req: Request, code: PrintWriter) = + code print codegenln("defined %s %s".format(classdef.keyword, name)) } - private class TypeAliasHandler(typeDef: TypeDef) - extends MemberHandler(typeDef) - { - override lazy val boundNames = - if (typeDef.mods.isPublic && compiler.treeInfo.isAliasTypeDef(typeDef)) - List(typeDef.name) - else - Nil + private class TypeAliasHandler(typeDef: TypeDef) extends MemberHandler(typeDef) { + lazy val TypeDef(mods, name, _, _) = typeDef + def isAlias() = mods.isPublic && compiler.treeInfo.isAliasTypeDef(typeDef) + override lazy val boundNames = if (isAlias) List(name) else Nil - override def resultExtractionCode(req: Request, code: PrintWriter) { - code.println(" + \"defined type alias " + - string2code(typeDef.name) + "\\n\"") - } + override def resultExtractionCode(req: Request, code: PrintWriter) = + code println codegenln("defined type alias ", name) } private class ImportHandler(imp: Import) extends MemberHandler(imp) { - override def resultExtractionCode(req: Request, code: PrintWriter) { - code.println("+ \"" + imp.toString + "\\n\"") + lazy val Import(expr, selectors) = imp + def targetType = stringToCompilerType(expr.toString) match { + case NoType => None + case x => Some(x) } - + + private def selectorWild = selectors filter (_.name == USCOREkw) // wildcard imports, e.g. import foo._ + private def selectorMasked = selectors filter (_.rename == USCOREkw) // masking imports, e.g. import foo.{ bar => _ } + private def selectorNames = selectors map (_.name) + private def selectorRenames = selectors map (_.rename) filterNot (_ == null) + /** Whether this import includes a wildcard import */ - override val importsWildcard = - imp.selectors.map(_._1).contains(nme.USCOREkw) + val importsWildcard = selectorWild.nonEmpty + + /** Complete list of names imported by a wildcard */ + def wildcardImportedNames: List[Name] = ( + for (tpe <- targetType ; if importsWildcard) yield + tpe.nonPrivateMembers filter (x => x.isMethod && x.isPublic) map (_.name) distinct + ).toList.flatten /** The individual names imported by this statement */ - override val importedNames: Seq[Name] = - for { - val (_,sel) <- imp.selectors - sel != null - sel != nme.USCOREkw - val name <- List(sel.toTypeName, sel.toTermName) - } - yield name + /** XXX come back to this and see what can be done with wildcards now that + * we know how to enumerate the identifiers. + */ + val importedNames: List[Name] = + selectorRenames filterNot (_ == USCOREkw) flatMap (x => List(x.toTypeName, x.toTermName)) + + override def resultExtractionCode(req: Request, code: PrintWriter) = + code println codegenln(imp.toString) } /** One line of code submitted by the user for interpretation */ - private class Request(val line: String, val lineName: String) { - val trees = parse(line) match { - case Some(ts) => ts - case None => Nil - } - + private class Request(val line: String, val lineName: String, val trees: List[Tree]) { /** name to use for the object that will compute "line" */ - def objectName = lineName + compiler.nme.INTERPRETER_WRAPPER_SUFFIX + def objectName = lineName + INTERPRETER_WRAPPER_SUFFIX /** name of the object that retrieves the result from the above object */ def resultObjectName = "RequestResult$" + objectName - val handlers: List[MemberHandler] = trees.flatMap(chooseHandler(_)) + /** handlers for each tree in this request */ + val handlers: List[MemberHandler] = trees map chooseHandler /** all (public) names defined by these statements */ - val boundNames = (ListSet() ++ handlers.flatMap(_.boundNames)).toList + val boundNames = handlers flatMap (_.boundNames) /** list of names used by this expression */ - val usedNames: List[Name] = handlers.flatMap(_.usedNames) - - def myImportsCode = importsCode(Set.empty ++ usedNames) - - /** Code to append to objectName to access anything that - * the request binds. */ - val accessPath = myImportsCode._3 + val usedNames: List[Name] = handlers flatMap (_.usedNames) + + /** def and val names */ + def defNames = partialFlatMap(handlers) { case x: DefHandler => x.boundNames } + def valueNames = partialFlatMap(handlers) { + case x: AssignHandler => List(x.helperName) + case x: ValHandler => boundNames + case x: ModuleHandler => List(x.name) + } + /** Code to import bound names from previous lines - accessPath is code to + * append to objectName to access anything bound by request. + */ + val ComputedImports(importsPreamble, importsTrailer, accessPath) = + importsCode(Set.empty ++ usedNames) /** Code to access a variable with the specified name */ - def fullPath(vname: String): String = - objectName + ".INSTANCE" + accessPath + ".`" + vname + "`" + def fullPath(vname: String): String = "%s.`%s`".format(objectName + ".INSTANCE" + accessPath, vname) /** Code to access a variable with the specified name */ def fullPath(vname: Name): String = fullPath(vname.toString) @@ -809,196 +905,486 @@ class SparkInterpreter(val settings: Settings, out: PrintWriter) { def toCompute = line /** generate the source code for the object that computes this request */ - def objectSourceCode: String = { - val src = stringFrom { code => - // header for the wrapper object - code.println("@serializable class " + objectName + " {") - - val (importsPreamble, importsTrailer, _) = myImportsCode - - code.print(importsPreamble) - - code.println(indentCode(toCompute)) - - handlers.foreach(_.extraCodeToEvaluate(this,code)) - - code.println(importsTrailer) - - //end the wrapper object - code.println(";}") - - //create an object - code.println("object " + objectName + " {") - code.println(" val INSTANCE = new " + objectName + "();") - code.println("}") - } - //println(src) - src + def objectSourceCode: String = stringFromWriter { code => + val preamble = """ + |@serializable class %s { + | %s%s + """.stripMargin.format(objectName, importsPreamble, indentCode(toCompute)) + val postamble = importsTrailer + "\n}" + + code println preamble + handlers foreach { _.extraCodeToEvaluate(this, code) } + code println postamble + + //create an object + code.println("object " + objectName + " {") + code.println(" val INSTANCE = new " + objectName + "();") + code.println("}") } - - /** Types of variables defined by this request. They are computed - after compilation of the main object */ - var typeOf: Map[Name, String] = _ /** generate source code for the object that retrieves the result from objectSourceCode */ - def resultObjectSourceCode: String = - stringFrom(code => { - code.println("object " + resultObjectName) - code.println("{ val result: String = {") - code.println(objectName + ".INSTANCE" + accessPath + ";") // evaluate the object, to make sure its constructor is run - code.print("(\"\"") // print an initial empty string, so later code can - // uniformly be: + morestuff - handlers.foreach(_.resultExtractionCode(this, code)) - code.println("\n)}") - code.println(";}") - }) + def resultObjectSourceCode: String = stringFromWriter { code => + /** We only want to generate this code when the result + * is a value which can be referred to as-is. + */ + val valueExtractor = handlers.last.generatesValue match { + case Some(vname) if typeOf contains vname => + """ + |lazy val scala_repl_value = { + | scala_repl_result + | %s + |}""".stripMargin.format(fullPath(vname)) + case _ => "" + } + + // first line evaluates object to make sure constructor is run + // initial "" so later code can uniformly be: + etc + val preamble = """ + |object %s { + | %s + | val scala_repl_result: String = { + | %s + | ("" + """.stripMargin.format(resultObjectName, valueExtractor, objectName + ".INSTANCE" + accessPath) + + val postamble = """ + | ) + | } + |} + """.stripMargin + + code println preamble + handlers foreach { _.resultExtractionCode(this, code) } + code println postamble + } + + // compile the object containing the user's code + lazy val objRun = compileAndSaveRun("<console>", objectSourceCode) + + // compile the result-extraction object + lazy val extractionObjectRun = compileAndSaveRun("<console>", resultObjectSourceCode) + lazy val loadedResultObject = loadByName(resultObjectName) + + def extractionValue(): Option[AnyRef] = { + // ensure it has run + extractionObjectRun + + // load it and retrieve the value + try Some(loadedResultObject getMethod "scala_repl_value" invoke loadedResultObject) + catch { case _: Exception => None } + } /** Compile the object file. Returns whether the compilation succeeded. * If all goes well, the "types" map is computed. */ def compile(): Boolean = { - reporter.reset // without this, error counting is not correct, - // and the interpreter sometimes overlooks compile failures! + // error counting is wrong, hence interpreter may overlook failure - so we reset + reporter.reset // compile the main object - val objRun = new compiler.Run() - //println("source: "+objectSourceCode) //DEBUG - objRun.compileSources( - List(new BatchSourceFile("<console>", objectSourceCode.toCharArray)) - ) - if (reporter.hasErrors) return false - + objRun + + // bail on error + if (reporter.hasErrors) + return false // extract and remember types - typeOf = findTypes(objRun) + typeOf // compile the result-extraction object - new compiler.Run().compileSources( - List(new BatchSourceFile("<console>", resultObjectSourceCode.toCharArray)) - ) + extractionObjectRun // success !reporter.hasErrors } - /** Dig the types of all bound variables out of the compiler run. - * - * @param objRun ... - * @return ... - */ - def findTypes(objRun: compiler.Run): Map[Name, String] = { - def valAndVarNames = handlers.flatMap(_.valAndVarNames) - def defNames = handlers.flatMap(_.defNames) - - def getTypes(names: List[Name], nameMap: Name=>Name): Map[Name, String] = { - /** the outermost wrapper object */ - val outerResObjSym: Symbol = - compiler.definitions.getMember(compiler.definitions.EmptyPackage, - newTermName(objectName).toTypeName) // MATEI: added toTypeName - - /** the innermost object inside the wrapper, found by - * following accessPath into the outer one. */ - val resObjSym = - (accessPath.split("\\.")).foldLeft(outerResObjSym)((sym,name) => - if(name == "") sym else - compiler.atPhase(objRun.typerPhase.next) { - sym.info.member(newTermName(name)) }) - - names.foldLeft(Map.empty[Name,String])((map, name) => { - val rawType = - compiler.atPhase(objRun.typerPhase.next) { - resObjSym.info.member(name).tpe - } + def atNextPhase[T](op: => T): T = compiler.atPhase(objRun.typerPhase.next)(op) + + /** The outermost wrapper object */ + lazy val outerResObjSym: Symbol = getMember(EmptyPackage, newTermName(objectName).toTypeName) + /** The innermost object inside the wrapper, found by + * following accessPath into the outer one. */ + lazy val resObjSym = + accessPath.split("\\.").foldLeft(outerResObjSym) { (sym, name) => + if (name == "") sym else + atNextPhase(sym.info member newTermName(name)) + } + + /* typeOf lookup with encoding */ + def typeOfEnc(vname: Name) = typeOf(compiler encode vname) + + /** Types of variables defined by this request. */ + lazy val typeOf: Map[Name, String] = { + def getTypes(names: List[Name], nameMap: Name => Name): Map[Name, String] = { + names.foldLeft(Map.empty[Name, String]) { (map, name) => + val rawType = atNextPhase(resObjSym.info.member(name).tpe) // the types are all =>T; remove the => - val cleanedType= rawType match { + val cleanedType = rawType match { case compiler.PolyType(Nil, rt) => rt case rawType => rawType } - map + (name -> compiler.atPhase(objRun.typerPhase.next) { cleanedType.toString }) - }) + map + (name -> atNextPhase(cleanedType.toString)) + } } - val names1 = getTypes(valAndVarNames, n => compiler.nme.getterToLocal(n)) - val names2 = getTypes(defNames, identity) - names1 ++ names2 + getTypes(valueNames, nme.getterToLocal(_)) ++ getTypes(defNames, identity) } /** load and run the code using reflection */ def loadAndRun: (String, Boolean) = { - val interpreterResultObject: Class[_] = - Class.forName(resultObjectName, true, classLoader) - val resultValMethod: java.lang.reflect.Method = - interpreterResultObject.getMethod("result") - try { - (resultValMethod.invoke(interpreterResultObject).toString(), - true) - } catch { - case e => - def caus(e: Throwable): Throwable = - if (e.getCause eq null) e else caus(e.getCause) - val orig = caus(e) - (stringFrom(str => orig.printStackTrace(str)), false) + val resultValMethod: reflect.Method = loadedResultObject getMethod "scala_repl_result" + // XXX if wrapperExceptions isn't type-annotated we crash scalac + val wrapperExceptions: List[Class[_ <: Throwable]] = + List(classOf[InvocationTargetException], classOf[ExceptionInInitializerError]) + + /** We turn off the binding to accomodate ticket #2817 */ + def onErr: Catcher[(String, Boolean)] = { + case t: Throwable if bindLastException => + withoutBindingLastException { + quietBind("lastException", "java.lang.Throwable", t) + (stringFromWriter(t.printStackTrace(_)), false) + } + } + + catching(onErr) { + unwrapping(wrapperExceptions: _*) { + (resultValMethod.invoke(loadedResultObject).toString, true) + } } } + + override def toString = "Request(line=%s, %s trees)".format(line, trees.size) } -} + + /** A container class for methods to be injected into the repl + * in power mode. + */ + object power { + lazy val compiler: repl.compiler.type = repl.compiler + import compiler.{ phaseNames, atPhase, currentRun } + + def mkContext(code: String = "") = compiler.analyzer.rootContext(mkUnit(code)) + def mkAlias(name: String, what: String) = interpret("type %s = %s".format(name, what)) + def mkSourceFile(code: String) = new BatchSourceFile("<console>", code) + def mkUnit(code: String) = new CompilationUnit(mkSourceFile(code)) + + def mkTree(code: String): Tree = mkTrees(code).headOption getOrElse EmptyTree + def mkTrees(code: String): List[Tree] = parse(code) getOrElse Nil + def mkTypedTrees(code: String*): List[compiler.Tree] = { + class TyperRun extends compiler.Run { + override def stopPhase(name: String) = name == "superaccessors" + } -/** Utility methods for the Interpreter. */ -object Interpreter { - /** Delete a directory tree recursively. Use with care! + reporter.reset + val run = new TyperRun + run compileSources (code.toList.zipWithIndex map { + case (s, i) => new BatchSourceFile("<console %d>".format(i), s) + }) + run.units.toList map (_.body) + } + def mkTypedTree(code: String) = mkTypedTrees(code).head + def mkType(id: String): compiler.Type = stringToCompilerType(id) + + def dump(): String = ( + ("Names used: " :: allUsedNames) ++ + ("\nIdentifiers: " :: unqualifiedIds) + ) mkString " " + + lazy val allPhases: List[Phase] = phaseNames map (currentRun phaseNamed _) + def atAllPhases[T](op: => T): List[(String, T)] = allPhases map (ph => (ph.name, atPhase(ph)(op))) + def showAtAllPhases(op: => Any): Unit = + atAllPhases(op.toString) foreach { case (ph, op) => Console.println("%15s -> %s".format(ph, op take 240)) } + } + + def unleash(): Unit = beQuietDuring { + interpret("import scala.tools.nsc._") + repl.bind("repl", "spark.repl.SparkInterpreter", this) + interpret("val global: repl.compiler.type = repl.compiler") + interpret("val power: repl.power.type = repl.power") + // interpret("val replVars = repl.replVars") + } + + /** Artificial object demonstrating completion */ + // lazy val replVars = CompletionAware( + // Map[String, CompletionAware]( + // "ids" -> CompletionAware(() => unqualifiedIds, completionAware _), + // "synthVars" -> CompletionAware(() => allBoundNames filter isSynthVarName map (_.toString)), + // "types" -> CompletionAware(() => allSeenTypes map (_.toString)), + // "implicits" -> CompletionAware(() => allImplicits map (_.toString)) + // ) + // ) + + /** Returns the name of the most recent interpreter result. + * Mostly this exists so you can conveniently invoke methods on + * the previous result. + */ + def mostRecentVar: String = + if (mostRecentlyHandledTree.isEmpty) "" + else mostRecentlyHandledTree.get match { + case x: ValOrDefDef => x.name + case Assign(Ident(name), _) => name + case ModuleDef(_, name, _) => name + case _ => onull(varNameCreator.mostRecent) + } + + private def requestForName(name: Name): Option[Request] = + prevRequests.reverse find (_.boundNames contains name) + + private def requestForIdent(line: String): Option[Request] = requestForName(newTermName(line)) + + def stringToCompilerType(id: String): compiler.Type = { + // if it's a recognized identifier, the type of that; otherwise treat the + // String like a value (e.g. scala.collection.Map) . + def findType = typeForIdent(id) match { + case Some(x) => definitions.getClass(newTermName(x)).tpe + case _ => definitions.getModule(newTermName(id)).tpe + } + + try findType catch { case _: MissingRequirementError => NoType } + } + + def typeForIdent(id: String): Option[String] = + requestForIdent(id) flatMap (x => x.typeOf get newTermName(id)) + + def methodsOf(name: String) = + evalExpr[List[String]](methodsCode(name)) map (x => NameTransformer.decode(getOriginalName(x))) + + def completionAware(name: String) = { + // XXX working around "object is not a value" crash, i.e. + // import java.util.ArrayList ; ArrayList.<tab> + clazzForIdent(name) flatMap (_ => evalExpr[Option[CompletionAware]](asCompletionAwareCode(name))) + } + + def extractionValueForIdent(id: String): Option[AnyRef] = + requestForIdent(id) flatMap (_.extractionValue) + + /** Executes code looking for a manifest of type T. + */ + def manifestFor[T: Manifest] = + evalExpr[Manifest[T]]("""manifest[%s]""".format(manifest[T])) + + /** Executes code looking for an implicit value of type T. + */ + def implicitFor[T: Manifest] = { + val s = manifest[T].toString + evalExpr[Option[T]]("{ def f(implicit x: %s = null): %s = x ; Option(f) }".format(s, s)) + // We don't use implicitly so as to fail without failing. + // evalExpr[T]("""implicitly[%s]""".format(manifest[T])) + } + /** Executes code looking for an implicit conversion from the type + * of the given identifier to CompletionAware. + */ + def completionAwareImplicit[T](id: String) = { + val f1string = "%s => %s".format(typeForIdent(id).get, classOf[CompletionAware].getName) + val code = """{ + | def f(implicit x: (%s) = null): %s = x + | val f1 = f + | if (f1 == null) None else Some(f1(%s)) + |}""".stripMargin.format(f1string, f1string, id) + + evalExpr[Option[CompletionAware]](code) + } + + def clazzForIdent(id: String): Option[Class[_]] = + extractionValueForIdent(id) flatMap (x => Option(x) map (_.getClass)) + + private def methodsCode(name: String) = + "%s.%s(%s)".format(classOf[ReflectionCompletion].getName, "methodsOf", name) + + private def asCompletionAwareCode(name: String) = + "%s.%s(%s)".format(classOf[CompletionAware].getName, "unapply", name) + + private def getOriginalName(name: String): String = + nme.originalName(newTermName(name)).toString + + case class InterpreterEvalException(msg: String) extends Exception(msg) + def evalError(msg: String) = throw InterpreterEvalException(msg) + + /** The user-facing eval in :power mode wraps an Option. */ - def deleteRecursively(path: File) { - path match { - case _ if !path.exists => - () - case _ if path.isDirectory => - for (p <- path.listFiles) - deleteRecursively(p) - path.delete - case _ => - path.delete + def eval[T: Manifest](line: String): Option[T] = + try Some(evalExpr[T](line)) + catch { case InterpreterEvalException(msg) => out println indentString(msg) ; None } + + def evalExpr[T: Manifest](line: String): T = { + // Nothing means the type could not be inferred. + if (manifest[T] eq Manifest.Nothing) + evalError("Could not infer type: try 'eval[SomeType](%s)' instead".format(line)) + + val lhs = getSynthVarName + beQuietDuring { interpret("val " + lhs + " = { " + line + " } ") } + + // TODO - can we meaningfully compare the inferred type T with + // the internal compiler Type assigned to lhs? + // def assignedType = prevRequests.last.typeOf(newTermName(lhs)) + + val req = requestFromLine(lhs, true) match { + case Left(result) => evalError(result.toString) + case Right(req) => req + } + if (req == null || !req.compile || req.handlers.size != 1) + evalError("Eval error.") + + try req.extractionValue.get.asInstanceOf[T] catch { + case e: Exception => evalError(e.getMessage) + } + } + + def interpretExpr[T: Manifest](code: String): Option[T] = beQuietDuring { + interpret(code) match { + case IR.Success => + try prevRequests.last.extractionValue map (_.asInstanceOf[T]) + catch { case e: Exception => out println e ; None } + case _ => None } } - /** Heuristically strip interpreter wrapper prefixes - * from an interpreter output string. + /** Another entry point for tab-completion, ids in scope */ + private def unqualifiedIdNames() = partialFlatMap(allHandlers) { + case x: AssignHandler => List(x.helperName) + case x: ValHandler => List(x.vname) + case x: ModuleHandler => List(x.name) + case x: DefHandler => List(x.name) + case x: ImportHandler => x.importedNames + } filterNot isSynthVarName + + /** Types which have been wildcard imported, such as: + * val x = "abc" ; import x._ // type java.lang.String + * import java.lang.String._ // object java.lang.String + * + * Used by tab completion. + * + * XXX right now this gets import x._ and import java.lang.String._, + * but doesn't figure out import String._. There's a lot of ad hoc + * scope twiddling which should be swept away in favor of digging + * into the compiler scopes. */ - def stripWrapperGunk(str: String): String = { - //val wrapregex = "(line[0-9]+\\$object[$.])?(\\$iw[$.])*" - //str.replaceAll(wrapregex, "") - str + def wildcardImportedTypes(): List[Type] = { + val xs = allHandlers collect { case x: ImportHandler if x.importsWildcard => x.targetType } + xs.flatten.reverse.distinct } + + /** Another entry point for tab-completion, ids in scope */ + def unqualifiedIds() = (unqualifiedIdNames() map (_.toString)).distinct.sorted - /** Convert a string into code that can recreate the string. - * This requires replacing all special characters by escape - * codes. It does not add the surrounding " marks. */ - def string2code(str: String): String = { - /** Convert a character to a backslash-u escape */ - def char2uescape(c: Char): String = { - var rest = c.toInt - val buf = new StringBuilder - for (i <- 1 to 4) { - buf ++= (rest % 16).toHexString - rest = rest / 16 + /** For static/object method completion */ + def getClassObject(path: String): Option[Class[_]] = //classLoader tryToLoadClass path + try { + Some(Class.forName(path, true, classLoader)) + } catch { + case e: Exception => None + } + + /** Parse the ScalaSig to find type aliases */ + def aliasForType(path: String) = ByteCode.aliasForType(path) + + // Coming soon + // implicit def string2liftedcode(s: String): LiftedCode = new LiftedCode(s) + // case class LiftedCode(code: String) { + // val lifted: String = { + // beQuietDuring { interpret(code) } + // eval2[String]("({ " + code + " }).toString") + // } + // def >> : String = lifted + // } + + // debugging + def isReplDebug = settings.Yrepldebug.value + def isCompletionDebug = settings.Ycompletion.value + def DBG(s: String) = if (isReplDebug) out println s else () +} + +/** Utility methods for the Interpreter. */ +object SparkInterpreter { + + import scala.collection.generic.CanBuildFrom + def partialFlatMap[A, B, CC[X] <: Traversable[X]] + (coll: CC[A]) + (pf: PartialFunction[A, CC[B]]) + (implicit bf: CanBuildFrom[CC[A], B, CC[B]]) = + { + val b = bf(coll) + for (x <- coll collect pf) + b ++= x + + b.result + } + + object DebugParam { + implicit def tuple2debugparam[T](x: (String, T))(implicit m: Manifest[T]): DebugParam[T] = + DebugParam(x._1, x._2) + + implicit def any2debugparam[T](x: T)(implicit m: Manifest[T]): DebugParam[T] = + DebugParam("p" + getCount(), x) + + private var counter = 0 + def getCount() = { counter += 1; counter } + } + case class DebugParam[T](name: String, param: T)(implicit m: Manifest[T]) { + val manifest = m + val typeStr = { + val str = manifest.toString + // I'm sure there are more to be discovered... + val regexp1 = """(.*?)\[(.*)\]""".r + val regexp2str = """.*\.type#""" + val regexp2 = (regexp2str + """(.*)""").r + + (str.replaceAll("""\n""", "")) match { + case regexp1(clazz, typeArgs) => "%s[%s]".format(clazz, typeArgs.replaceAll(regexp2str, "")) + case regexp2(clazz) => clazz + case _ => str } - "\\" + "u" + buf.toString.reverse } + } + def breakIf(assertion: => Boolean, args: DebugParam[_]*): Unit = + if (assertion) break(args.toList) + + // start a repl, binding supplied args + def break(args: List[DebugParam[_]]): Unit = { + val intLoop = new SparkInterpreterLoop + intLoop.settings = new Settings(Console.println) + // XXX come back to the dot handling + intLoop.settings.classpath.value = "." + intLoop.createInterpreter + intLoop.in = SparkInteractiveReader.createDefault(intLoop.interpreter) + // rebind exit so people don't accidentally call System.exit by way of predef + intLoop.interpreter.beQuietDuring { + intLoop.interpreter.interpret("""def exit = println("Type :quit to resume program execution.")""") + for (p <- args) { + intLoop.interpreter.bind(p.name, p.typeStr, p.param) + Console println "%s: %s".format(p.name, p.typeStr) + } + } + intLoop.repl() + intLoop.closeInterpreter + } + + def codegenln(leadingPlus: Boolean, xs: String*): String = codegen(leadingPlus, (xs ++ Array("\n")): _*) + def codegenln(xs: String*): String = codegenln(true, xs: _*) + def codegen(xs: String*): String = codegen(true, xs: _*) + def codegen(leadingPlus: Boolean, xs: String*): String = { + val front = if (leadingPlus) "+ " else "" + front + (xs map string2codeQuoted mkString " + ") + } + + def string2codeQuoted(str: String) = "\"" + string2code(str) + "\"" + + /** Convert a string into code that can recreate the string. + * This requires replacing all special characters by escape + * codes. It does not add the surrounding " marks. */ + def string2code(str: String): String = { val res = new StringBuilder - for (c <- str) { - if ("'\"\\" contains c) { - res += '\\' - res += c - } else if (!c.isControl) { - res += c - } else { - res ++= char2uescape(c) - } + for (c <- str) c match { + case '"' | '\'' | '\\' => res += '\\' ; res += c + case _ if c.isControl => res ++= Chars.char2uescape(c) + case _ => res += c } res.toString } } + diff --git a/src/scala/spark/repl/SparkInterpreterLoop.scala b/src/scala/spark/repl/SparkInterpreterLoop.scala index 4aab60fd11..26361fdc25 100644 --- a/src/scala/spark/repl/SparkInterpreterLoop.scala +++ b/src/scala/spark/repl/SparkInterpreterLoop.scala @@ -1,23 +1,73 @@ /* NSC -- new Scala compiler - * Copyright 2005-2009 LAMP/EPFL + * Copyright 2005-2010 LAMP/EPFL * @author Alexander Spoon */ -// $Id: InterpreterLoop.scala 16881 2009-01-09 16:28:11Z cunei $ package spark.repl import scala.tools.nsc import scala.tools.nsc._ -import java.io.{BufferedReader, File, FileReader, PrintWriter} +import Predef.{ println => _, _ } +import java.io.{ BufferedReader, FileReader, PrintWriter } import java.io.IOException -import java.lang.{ClassLoader, System} -import scala.tools.nsc.{InterpreterResults => IR} -import scala.tools.nsc.interpreter._ +import scala.tools.nsc.{ InterpreterResults => IR } +import scala.annotation.tailrec +import scala.collection.mutable.ListBuffer +import scala.concurrent.ops +import util.{ ClassPath } +import interpreter._ +import io.{ File, Process } import spark.SparkContext +// Classes to wrap up interpreter commands and their results +// You can add new commands by adding entries to val commands +// inside InterpreterLoop. +trait InterpreterControl { + self: SparkInterpreterLoop => + + // the default result means "keep running, and don't record that line" + val defaultResult = Result(true, None) + + // a single interpreter command + sealed abstract class Command extends Function1[List[String], Result] { + def name: String + def help: String + def error(msg: String) = { + out.println(":" + name + " " + msg + ".") + Result(true, None) + } + def usage(): String + } + + case class NoArgs(name: String, help: String, f: () => Result) extends Command { + def usage(): String = ":" + name + def apply(args: List[String]) = if (args.isEmpty) f() else error("accepts no arguments") + } + + case class LineArg(name: String, help: String, f: (String) => Result) extends Command { + def usage(): String = ":" + name + " <line>" + def apply(args: List[String]) = f(args mkString " ") + } + + case class OneArg(name: String, help: String, f: (String) => Result) extends Command { + def usage(): String = ":" + name + " <arg>" + def apply(args: List[String]) = + if (args.size == 1) f(args.head) + else error("requires exactly one argument") + } + + case class VarArgs(name: String, help: String, f: (List[String]) => Result) extends Command { + def usage(): String = ":" + name + " [arg]" + def apply(args: List[String]) = f(args) + } + + // the result of a single command + case class Result(keepRunning: Boolean, lineToRecord: Option[String]) +} + /** The * <a href="http://scala-lang.org/" target="_top">Scala</a> * interactive shell. It provides a read-eval-print loop around @@ -32,9 +82,9 @@ import spark.SparkContext * @author Lex Spoon * @version 1.2 */ -class SparkInterpreterLoop(in0: Option[BufferedReader], val out: PrintWriter, - master: Option[String]) -{ +class SparkInterpreterLoop( + in0: Option[BufferedReader], val out: PrintWriter, master: Option[String]) +extends InterpreterControl { def this(in0: BufferedReader, out: PrintWriter, master: String) = this(Some(in0), out, Some(master)) @@ -43,30 +93,28 @@ class SparkInterpreterLoop(in0: Option[BufferedReader], val out: PrintWriter, def this() = this(None, new PrintWriter(Console.out), None) - /** The input stream from which interpreter commands come */ - var in: InteractiveReader = _ //set by main() + /** The input stream from which commands come, set by main() */ + var in: SparkInteractiveReader = _ /** The context class loader at the time this object was created */ - protected val originalClassLoader = - Thread.currentThread.getContextClassLoader + protected val originalClassLoader = Thread.currentThread.getContextClassLoader - var settings: Settings = _ // set by main() - var interpreter: SparkInterpreter = null // set by createInterpreter() - def isettings = interpreter.isettings + var settings: Settings = _ // set by main() + var interpreter: SparkInterpreter = _ // set by createInterpreter() + + // classpath entries added via :cp + var addedClasspath: String = "" - /** A reverse list of commands to replay if the user - * requests a :replay */ - var replayCommandsRev: List[String] = Nil + /** A reverse list of commands to replay if the user requests a :replay */ + var replayCommandStack: List[String] = Nil /** A list of commands to replay if the user requests a :replay */ - def replayCommands = replayCommandsRev.reverse + def replayCommands = replayCommandStack.reverse - /** Record a command for replay should the user requset a :replay */ - def addReplay(cmd: String) = - replayCommandsRev = cmd :: replayCommandsRev + /** Record a command for replay should the user request a :replay */ + def addReplay(cmd: String) = replayCommandStack ::= cmd - /** Close the interpreter, if there is one, and set - * interpreter to <code>null</code>. */ + /** Close the interpreter and set the var to <code>null</code>. */ def closeInterpreter() { if (interpreter ne null) { interpreter.close @@ -75,45 +123,30 @@ class SparkInterpreterLoop(in0: Option[BufferedReader], val out: PrintWriter, } } - /** Create a new interpreter. Close the old one, if there - * is one. */ + /** Create a new interpreter. */ def createInterpreter() { - //closeInterpreter() - + if (addedClasspath != "") + settings.classpath append addedClasspath + interpreter = new SparkInterpreter(settings, out) { - override protected def parentClassLoader = - classOf[SparkInterpreterLoop].getClassLoader + override protected def parentClassLoader = classOf[SparkInterpreterLoop].getClassLoader } interpreter.setContextClassLoader() + // interpreter.quietBind("settings", "spark.repl.SparkInterpreterSettings", interpreter.isettings) } - /** Bind the settings so that evaluated code can modiy them */ - def bindSettings() { - interpreter.beQuietDuring { - interpreter.compileString(InterpreterSettings.sourceCodeForClass) - - interpreter.bind( - "settings", - "scala.tools.nsc.InterpreterSettings", - isettings) - } - } - - /** print a friendly help message */ - def printHelp { - //printWelcome - out.println("This is Scala " + Properties.versionString + " (" + - System.getProperty("java.vm.name") + ", Java " + System.getProperty("java.version") + ")." ) - out.println("Type in expressions to have them evaluated.") - out.println("Type :load followed by a filename to load a Scala file.") - //out.println("Type :replay to reset execution and replay all previous commands.") - out.println("Type :quit to exit the interpreter.") - } + def printHelp() = { + out println "All commands can be abbreviated - for example :he instead of :help.\n" + val cmds = commands map (x => (x.usage, x.help)) + val width: Int = cmds map { case (x, _) => x.length } max + val formatStr = "%-" + width + "s %s" + cmds foreach { case (usage, help) => out println formatStr.format(usage, help) } + } /** Print a welcome message */ def printWelcome() { - out.println("""Welcome to + plushln("""Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ @@ -121,12 +154,111 @@ class SparkInterpreterLoop(in0: Option[BufferedReader], val out: PrintWriter, /_/ """) - out.println("Using Scala " + Properties.versionString + " (" + - System.getProperty("java.vm.name") + ", Java " + - System.getProperty("java.version") + ")." ) - out.flush() + import Properties._ + val welcomeMsg = "Using Scala %s (%s, Java %s)".format( + versionString, javaVmName, javaVersion) + plushln(welcomeMsg) + } + + /** Show the history */ + def printHistory(xs: List[String]) { + val defaultLines = 20 + + if (in.history.isEmpty) + return println("No history available.") + + val current = in.history.get.index + val count = try xs.head.toInt catch { case _: Exception => defaultLines } + val lines = in.historyList takeRight count + val offset = current - lines.size + 1 + + for ((line, index) <- lines.zipWithIndex) + println("%d %s".format(index + offset, line)) + } + + /** Some print conveniences */ + def println(x: Any) = out println x + def plush(x: Any) = { out print x ; out.flush() } + def plushln(x: Any) = { out println x ; out.flush() } + + /** Search the history */ + def searchHistory(_cmdline: String) { + val cmdline = _cmdline.toLowerCase + + if (in.history.isEmpty) + return println("No history available.") + + val current = in.history.get.index + val offset = current - in.historyList.size + 1 + + for ((line, index) <- in.historyList.zipWithIndex ; if line.toLowerCase contains cmdline) + println("%d %s".format(index + offset, line)) + } + + /** Prompt to print when awaiting input */ + val prompt = Properties.shellPromptString + + // most commands do not want to micromanage the Result, but they might want + // to print something to the console, so we accomodate Unit and String returns. + object CommandImplicits { + implicit def u2ir(x: Unit): Result = defaultResult + implicit def s2ir(s: String): Result = { + out println s + defaultResult + } + } + + /** Standard commands **/ + val standardCommands: List[Command] = { + import CommandImplicits._ + List( + OneArg("cp", "add an entry (jar or directory) to the classpath", addClasspath), + NoArgs("help", "print this help message", printHelp), + VarArgs("history", "show the history (optional arg: lines to show)", printHistory), + LineArg("h?", "search the history", searchHistory), + OneArg("load", "load and interpret a Scala file", load), + NoArgs("power", "enable power user mode", power), + NoArgs("quit", "exit the interpreter", () => Result(false, None)), + NoArgs("replay", "reset execution and replay all previous commands", replay), + LineArg("sh", "fork a shell and run a command", runShellCmd), + NoArgs("silent", "disable/enable automatic printing of results", verbosity) + ) + } + + /** Power user commands */ + var powerUserOn = false + val powerCommands: List[Command] = { + import CommandImplicits._ + List( + OneArg("completions", "generate list of completions for a given String", completions), + NoArgs("dump", "displays a view of the interpreter's internal state", () => interpreter.power.dump()) + + // VarArgs("tree", "displays ASTs for specified identifiers", + // (xs: List[String]) => interpreter dumpTrees xs) + // LineArg("meta", "given code which produces scala code, executes the results", + // (xs: List[String]) => ) + ) } + /** Available commands */ + def commands: List[Command] = standardCommands ::: (if (powerUserOn) powerCommands else Nil) + + def initializeSpark() { + interpreter.beQuietDuring { + command(""" + spark.repl.Main.interp.out.println("Registering with Mesos..."); + spark.repl.Main.interp.out.flush(); + @transient val sc = spark.repl.Main.interp.createSparkContext(); + sc.waitForRegister(); + spark.repl.Main.interp.out.println("Spark context available as sc."); + spark.repl.Main.interp.out.flush(); + """) + command("import spark.SparkContext._"); + } + plushln("Type in expressions to have them evaluated.") + plushln("Type :help for more information.") + } + def createSparkContext(): SparkContext = { val master = this.master match { case Some(m) => m @@ -137,88 +269,41 @@ class SparkInterpreterLoop(in0: Option[BufferedReader], val out: PrintWriter, } new SparkContext(master, "Spark REPL") } - - /** Prompt to print when awaiting input */ - val prompt = Properties.shellPromptString /** The main read-eval-print loop for the interpreter. It calls * <code>command()</code> for each line of input, and stops when * <code>command()</code> returns <code>false</code>. */ - def repl() { - out.println("Intializing...") - out.flush() - interpreter.beQuietDuring { - command(""" - spark.repl.Main.interp.out.println("Registering with Nexus..."); - @transient val sc = spark.repl.Main.interp.createSparkContext(); - sc.waitForRegister(); - spark.repl.Main.interp.out.println("Spark context available as sc.") - """) - command("import spark.SparkContext._"); + def repl() { + def readOneLine() = { + out.flush + in readLine prompt } - out.println("Type in expressions to have them evaluated.") - out.println("Type :help for more information.") - out.flush() - - var first = true - while (true) { - out.flush() - - val line = - if (first) { - /* For some reason, the first interpreted command always takes - * a second or two. So, wait until the welcome message - * has been printed before calling bindSettings. That way, - * the user can read the welcome message while this - * command executes. - */ - val futLine = scala.concurrent.ops.future(in.readLine(prompt)) - - bindSettings() - first = false - - futLine() - } else { - in.readLine(prompt) - } - - if (line eq null) - return () // assumes null means EOF - - val (keepGoing, finalLineMaybe) = command(line) - - if (!keepGoing) - return - - finalLineMaybe match { - case Some(finalLine) => addReplay(finalLine) - case None => () + // return false if repl should exit + def processLine(line: String): Boolean = + if (line eq null) false // assume null means EOF + else command(line) match { + case Result(false, _) => false + case Result(_, Some(finalLine)) => addReplay(finalLine) ; true + case _ => true } - } + + while (processLine(readOneLine)) { } } /** interpret all lines from a specified file */ - def interpretAllFrom(filename: String) { - val fileIn = try { - new FileReader(filename) - } catch { - case _:IOException => - out.println("Error opening file: " + filename) - return - } + def interpretAllFrom(file: File) { val oldIn = in - val oldReplay = replayCommandsRev - try { - val inFile = new BufferedReader(fileIn) - in = new SimpleReader(inFile, out, false) - out.println("Loading " + filename + "...") - out.flush - repl - } finally { + val oldReplay = replayCommandStack + + try file applyReader { reader => + in = new SparkSimpleReader(reader, out, false) + plushln("Loading " + file + "...") + repl() + } + finally { in = oldIn - replayCommandsRev = oldReplay - fileIn.close + replayCommandStack = oldReplay } } @@ -227,58 +312,195 @@ class SparkInterpreterLoop(in0: Option[BufferedReader], val out: PrintWriter, closeInterpreter() createInterpreter() for (cmd <- replayCommands) { - out.println("Replaying: " + cmd) - out.flush() // because maybe cmd will have its own output + plushln("Replaying: " + cmd) // flush because maybe cmd will have its own output command(cmd) out.println } } + + /** fork a shell and run a command */ + def runShellCmd(line: String) { + // we assume if they're using :sh they'd appreciate being able to pipeline + interpreter.beQuietDuring { + interpreter.interpret("import _root_.scala.tools.nsc.io.Process.Pipe._") + } + val p = Process(line) + // only bind non-empty streams + def add(name: String, it: Iterator[String]) = + if (it.hasNext) interpreter.bind(name, "scala.List[String]", it.toList) + + List(("stdout", p.stdout), ("stderr", p.stderr)) foreach (add _).tupled + } + + def withFile(filename: String)(action: File => Unit) { + val f = File(filename) + + if (f.exists) action(f) + else out.println("That file does not exist") + } + + def load(arg: String) = { + var shouldReplay: Option[String] = None + withFile(arg)(f => { + interpretAllFrom(f) + shouldReplay = Some(":load " + arg) + }) + Result(true, shouldReplay) + } - /** Run one command submitted by the user. Three values are returned: + def addClasspath(arg: String): Unit = { + val f = File(arg).normalize + if (f.exists) { + addedClasspath = ClassPath.join(addedClasspath, f.path) + val totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath) + println("Added '%s'. Your new classpath is:\n%s".format(f.path, totalClasspath)) + replay() + } + else out.println("The path '" + f + "' doesn't seem to exist.") + } + + def completions(arg: String): Unit = { + val comp = in.completion getOrElse { return println("Completion unavailable.") } + val xs = comp completions arg + + injectAndName(xs) + } + + def power() { + val powerUserBanner = + """** Power User mode enabled - BEEP BOOP ** + |** scala.tools.nsc._ has been imported ** + |** New vals! Try repl, global, power ** + |** New cmds! :help to discover them ** + |** New defs! Type power.<tab> to reveal **""".stripMargin + + powerUserOn = true + interpreter.unleash() + injectOne("history", in.historyList) + in.completion foreach (x => injectOne("completion", x)) + out println powerUserBanner + } + + def verbosity() = { + val old = interpreter.printResults + interpreter.printResults = !old + out.println("Switched " + (if (old) "off" else "on") + " result printing.") + } + + /** Run one command submitted by the user. Two values are returned: * (1) whether to keep running, (2) the line to record for replay, * if any. */ - def command(line: String): (Boolean, Option[String]) = { - def withFile(command: String)(action: String => Unit) { - val spaceIdx = command.indexOf(' ') - if (spaceIdx <= 0) { - out.println("That command requires a filename to be specified.") - return () - } - val filename = command.substring(spaceIdx).trim - if (! new File(filename).exists) { - out.println("That file does not exist") - return () - } - action(filename) + def command(line: String): Result = { + def withError(msg: String) = { + out println msg + Result(true, None) } + def ambiguous(cmds: List[Command]) = "Ambiguous: did you mean " + cmds.map(":" + _.name).mkString(" or ") + "?" - val helpRegexp = ":h(e(l(p)?)?)?" - val quitRegexp = ":q(u(i(t)?)?)?" - val loadRegexp = ":l(o(a(d)?)?)?.*" - //val replayRegexp = ":r(e(p(l(a(y)?)?)?)?)?.*" - - var shouldReplay: Option[String] = None + // not a command + if (!line.startsWith(":")) { + // Notice failure to create compiler + if (interpreter.compiler == null) return Result(false, None) + else return Result(true, interpretStartingWith(line)) + } - if (line.matches(helpRegexp)) - printHelp - else if (line.matches(quitRegexp)) - return (false, None) - else if (line.matches(loadRegexp)) { - withFile(line)(f => { - interpretAllFrom(f) - shouldReplay = Some(line) - }) + val tokens = (line drop 1 split """\s+""").toList + if (tokens.isEmpty) + return withError(ambiguous(commands)) + + val (cmd :: args) = tokens + + // this lets us add commands willy-nilly and only requires enough command to disambiguate + commands.filter(_.name startsWith cmd) match { + case List(x) => x(args) + case Nil => withError("Unknown command. Type :help for help.") + case xs => withError(ambiguous(xs)) } - //else if (line matches replayRegexp) - // replay - else if (line startsWith ":") - out.println("Unknown command. Type :help for help.") - else - shouldReplay = interpretStartingWith(line) + } - (true, shouldReplay) + private val CONTINUATION_STRING = " | " + private val PROMPT_STRING = "scala> " + + /** If it looks like they're pasting in a scala interpreter + * transcript, remove all the formatting we inserted so we + * can make some sense of it. + */ + private var pasteStamp: Long = 0 + + /** Returns true if it's long enough to quit. */ + def updatePasteStamp(): Boolean = { + /* Enough milliseconds between readLines to call it a day. */ + val PASTE_FINISH = 1000 + + val prevStamp = pasteStamp + pasteStamp = System.currentTimeMillis + + (pasteStamp - prevStamp > PASTE_FINISH) + } + /** TODO - we could look for the usage of resXX variables in the transcript. + * Right now backreferences to auto-named variables will break. + */ + /** The trailing lines complication was an attempt to work around the introduction + * of newlines in e.g. email messages of repl sessions. It doesn't work because + * an unlucky newline can always leave you with a syntactically valid first line, + * which is executed before the next line is considered. So this doesn't actually + * accomplish anything, but I'm leaving it in case I decide to try harder. + */ + case class PasteCommand(cmd: String, trailing: ListBuffer[String] = ListBuffer[String]()) + + /** Commands start on lines beginning with "scala>" and each successive + * line which begins with the continuation string is appended to that command. + * Everything else is discarded. When the end of the transcript is spotted, + * all the commands are replayed. + */ + @tailrec private def cleanTranscript(lines: List[String], acc: List[PasteCommand]): List[PasteCommand] = lines match { + case Nil => acc.reverse + case x :: xs if x startsWith PROMPT_STRING => + val first = x stripPrefix PROMPT_STRING + val (xs1, xs2) = xs span (_ startsWith CONTINUATION_STRING) + val rest = xs1 map (_ stripPrefix CONTINUATION_STRING) + val result = (first :: rest).mkString("", "\n", "\n") + + cleanTranscript(xs2, PasteCommand(result) :: acc) + + case ln :: lns => + val newacc = acc match { + case Nil => Nil + case PasteCommand(cmd, trailing) :: accrest => + PasteCommand(cmd, trailing :+ ln) :: accrest + } + cleanTranscript(lns, newacc) + } + + /** The timestamp is for safety so it doesn't hang looking for the end + * of a transcript. Ad hoc parsing can't be too demanding. You can + * also use ctrl-D to start it parsing. + */ + @tailrec private def interpretAsPastedTranscript(lines: List[String]) { + val line = in.readLine("") + val finished = updatePasteStamp() + + if (line == null || finished || line.trim == PROMPT_STRING.trim) { + val xs = cleanTranscript(lines.reverse, Nil) + println("Replaying %d commands from interpreter transcript." format xs.size) + for (PasteCommand(cmd, trailing) <- xs) { + out.flush() + def runCode(code: String, extraLines: List[String]) { + (interpreter interpret code) match { + case IR.Incomplete if extraLines.nonEmpty => + runCode(code + "\n" + extraLines.head, extraLines.tail) + case _ => () + } + } + runCode(cmd, trailing.toList) + } + } + else + interpretAsPastedTranscript(line :: lines) + } + /** Interpret expressions starting with the first line. * Read lines until a complete compilation unit is available * or until a syntax error has been seen. If a full unit is @@ -286,81 +508,151 @@ class SparkInterpreterLoop(in0: Option[BufferedReader], val out: PrintWriter, * to be recorded for replay, if any. */ def interpretStartingWith(code: String): Option[String] = { - interpreter.interpret(code) match { - case IR.Success => Some(code) - case IR.Error => None - case IR.Incomplete => + // signal completion non-completion input has been received + in.completion foreach (_.resetVerbosity()) + + def reallyInterpret = interpreter.interpret(code) match { + case IR.Error => None + case IR.Success => Some(code) + case IR.Incomplete => if (in.interactive && code.endsWith("\n\n")) { out.println("You typed two blank lines. Starting a new command.") None - } else { - val nextLine = in.readLine(" | ") - if (nextLine == null) - None // end of file - else - interpretStartingWith(code + "\n" + nextLine) + } + else in.readLine(CONTINUATION_STRING) match { + case null => + // we know compilation is going to fail since we're at EOF and the + // parser thinks the input is still incomplete, but since this is + // a file being read non-interactively we want to fail. So we send + // it straight to the compiler for the nice error message. + interpreter.compileString(code) + None + + case line => interpretStartingWith(code + "\n" + line) } } + + /** Here we place ourselves between the user and the interpreter and examine + * the input they are ostensibly submitting. We intervene in several cases: + * + * 1) If the line starts with "scala> " it is assumed to be an interpreter paste. + * 2) If the line starts with "." (but not ".." or "./") it is treated as an invocation + * on the previous result. + * 3) If the Completion object's execute returns Some(_), we inject that value + * and avoid the interpreter, as it's likely not valid scala code. + */ + if (code == "") None + else if (code startsWith PROMPT_STRING) { + updatePasteStamp() + interpretAsPastedTranscript(List(code)) + None + } + else if (Completion.looksLikeInvocation(code) && interpreter.mostRecentVar != "") { + interpretStartingWith(interpreter.mostRecentVar + code) + } + else { + val result = for (comp <- in.completion ; res <- comp execute code) yield res + result match { + case Some(res) => injectAndName(res) ; None // completion took responsibility, so do not parse + case _ => reallyInterpret + } + } } - def loadFiles(settings: Settings) { - settings match { - case settings: GenericRunnerSettings => - for (filename <- settings.loadfiles.value) { - val cmd = ":load " + filename - command(cmd) - replayCommandsRev = cmd :: replayCommandsRev - out.println() - } - case _ => - } + // runs :load <file> on any files passed via -i + def loadFiles(settings: Settings) = settings match { + case settings: GenericRunnerSettings => + for (filename <- settings.loadfiles.value) { + val cmd = ":load " + filename + command(cmd) + addReplay(cmd) + out.println() + } + case _ => } def main(settings: Settings) { this.settings = settings - - in = - in0 match { - case Some(in0) => - new SimpleReader(in0, out, true) - - case None => - val emacsShell = System.getProperty("env.emacs", "") != "" - //println("emacsShell="+emacsShell) //debug - if (settings.Xnojline.value || emacsShell) - new SimpleReader() - else - InteractiveReader.createDefault() - } - createInterpreter() + + // sets in to some kind of reader depending on environmental cues + in = in0 match { + case Some(in0) => new SparkSimpleReader(in0, out, true) + case None => + // the interpreter is passed as an argument to expose tab completion info + if (settings.Xnojline.value || Properties.isEmacsShell) new SparkSimpleReader + else if (settings.noCompletion.value) SparkInteractiveReader.createDefault() + else SparkInteractiveReader.createDefault(interpreter) + } loadFiles(settings) - try { - if (interpreter.reporter.hasErrors) { - return // it is broken on startup; go ahead and exit - } + // it is broken on startup; go ahead and exit + if (interpreter.reporter.hasErrors) return + printWelcome() + + // this is about the illusion of snappiness. We call initialize() + // which spins off a separate thread, then print the prompt and try + // our best to look ready. Ideally the user will spend a + // couple seconds saying "wow, it starts so fast!" and by the time + // they type a command the compiler is ready to roll. + interpreter.initialize() + initializeSpark() repl() - } finally { - closeInterpreter() } + finally closeInterpreter() + } + + private def objClass(x: Any) = x.asInstanceOf[AnyRef].getClass + private def objName(x: Any) = { + val clazz = objClass(x) + val typeParams = clazz.getTypeParameters + val basename = clazz.getName + val tpString = if (typeParams.isEmpty) "" else "[%s]".format(typeParams map (_ => "_") mkString ", ") + + basename + tpString + } + + // injects one value into the repl; returns pair of name and class + def injectOne(name: String, obj: Any): Tuple2[String, String] = { + val className = objName(obj) + interpreter.quietBind(name, className, obj) + (name, className) + } + def injectAndName(obj: Any): Tuple2[String, String] = { + val name = interpreter.getVarName + val className = objName(obj) + interpreter.bind(name, className, obj) + (name, className) + } + + // injects list of values into the repl; returns summary string + def injectDebug(args: List[Any]): String = { + val strs = + for ((arg, i) <- args.zipWithIndex) yield { + val varName = "p" + (i + 1) + val (vname, vtype) = injectOne(varName, arg) + vname + ": " + vtype + } + + if (strs.size == 0) "Set no variables." + else "Variables set:\n" + strs.mkString("\n") } /** process command-line arguments and do as they request */ def main(args: Array[String]) { - def error1(msg: String) { out.println("scala: " + msg) } - val command = new InterpreterCommand(List.fromArray(args), error1) - - if (!command.ok || command.settings.help.value || command.settings.Xhelp.value) { - // either the command line is wrong, or the user - // explicitly requested a help listing - if (command.settings.help.value) out.println(command.usageMsg) - if (command.settings.Xhelp.value) out.println(command.xusageMsg) - out.flush + def error1(msg: String) = out println ("scala: " + msg) + val command = new InterpreterCommand(args.toList, error1) + def neededHelp(): String = + (if (command.settings.help.value) command.usageMsg + "\n" else "") + + (if (command.settings.Xhelp.value) command.xusageMsg + "\n" else "") + + // if they asked for no help and command is valid, we call the real main + neededHelp() match { + case "" => if (command.ok) main(command.settings) // else nothing + case help => plush(help) } - else - main(command.settings) } } + diff --git a/src/scala/spark/repl/SparkInterpreterSettings.scala b/src/scala/spark/repl/SparkInterpreterSettings.scala new file mode 100644 index 0000000000..ffa477785b --- /dev/null +++ b/src/scala/spark/repl/SparkInterpreterSettings.scala @@ -0,0 +1,112 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2010 LAMP/EPFL + * @author Alexander Spoon + */ + +package spark.repl + +import scala.tools.nsc +import scala.tools.nsc._ + +/** Settings for the interpreter + * + * @version 1.0 + * @author Lex Spoon, 2007/3/24 + **/ +class SparkInterpreterSettings(repl: SparkInterpreter) { + /** A list of paths where :load should look */ + var loadPath = List(".") + + /** The maximum length of toString to use when printing the result + * of an evaluation. 0 means no maximum. If a printout requires + * more than this number of characters, then the printout is + * truncated. + */ + var maxPrintString = 800 + + /** The maximum number of completion candidates to print for tab + * completion without requiring confirmation. + */ + var maxAutoprintCompletion = 250 + + /** String unwrapping can be disabled if it is causing issues. + * Settings this to false means you will see Strings like "$iw.$iw.". + */ + var unwrapStrings = true + + def deprecation_=(x: Boolean) = { + val old = repl.settings.deprecation.value + repl.settings.deprecation.value = x + if (!old && x) println("Enabled -deprecation output.") + else if (old && !x) println("Disabled -deprecation output.") + } + def deprecation: Boolean = repl.settings.deprecation.value + + def allSettings = Map( + "maxPrintString" -> maxPrintString, + "maxAutoprintCompletion" -> maxAutoprintCompletion, + "unwrapStrings" -> unwrapStrings, + "deprecation" -> deprecation + ) + + private def allSettingsString = + allSettings.toList sortBy (_._1) map { case (k, v) => " " + k + " = " + v + "\n" } mkString + + override def toString = """ + | SparkInterpreterSettings { + | %s + | }""".stripMargin.format(allSettingsString) +} + +/* Utilities for the InterpreterSettings class + * + * @version 1.0 + * @author Lex Spoon, 2007/5/24 + */ +object SparkInterpreterSettings { + /** Source code for the InterpreterSettings class. This is + * used so that the interpreter is sure to have the code + * available. + * + * XXX I'm not seeing why this degree of defensiveness is necessary. + * If files are missing the repl's not going to work, it's not as if + * we have string source backups for anything else. + */ + val sourceCodeForClass = +""" +package scala.tools.nsc + +/** Settings for the interpreter + * + * @version 1.0 + * @author Lex Spoon, 2007/3/24 + **/ +class SparkInterpreterSettings(repl: Interpreter) { + /** A list of paths where :load should look */ + var loadPath = List(".") + + /** The maximum length of toString to use when printing the result + * of an evaluation. 0 means no maximum. If a printout requires + * more than this number of characters, then the printout is + * truncated. + */ + var maxPrintString = 2400 + + def deprecation_=(x: Boolean) = { + val old = repl.settings.deprecation.value + repl.settings.deprecation.value = x + if (!old && x) println("Enabled -deprecation output.") + else if (old && !x) println("Disabled -deprecation output.") + } + def deprecation: Boolean = repl.settings.deprecation.value + + override def toString = + "SparkInterpreterSettings {\n" + +// " loadPath = " + loadPath + "\n" + + " maxPrintString = " + maxPrintString + "\n" + + "}" +} + +""" + +} diff --git a/src/scala/spark/repl/SparkJLineReader.scala b/src/scala/spark/repl/SparkJLineReader.scala new file mode 100644 index 0000000000..9d761c06fc --- /dev/null +++ b/src/scala/spark/repl/SparkJLineReader.scala @@ -0,0 +1,38 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2010 LAMP/EPFL + * @author Stepan Koltsov + */ + +package spark.repl + +import scala.tools.nsc +import scala.tools.nsc._ +import scala.tools.nsc.interpreter +import scala.tools.nsc.interpreter._ + +import java.io.File +import jline.{ ConsoleReader, ArgumentCompletor, History => JHistory } + +/** Reads from the console using JLine */ +class SparkJLineReader(interpreter: SparkInterpreter) extends SparkInteractiveReader { + def this() = this(null) + + override lazy val history = Some(History(consoleReader)) + override lazy val completion = Option(interpreter) map (x => new SparkCompletion(x)) + + val consoleReader = { + val r = new jline.ConsoleReader() + r setHistory (History().jhistory) + r setBellEnabled false + completion foreach { c => + r addCompletor c.jline + r setAutoprintThreshhold 250 + } + + r + } + + def readOneLine(prompt: String) = consoleReader readLine prompt + val interactive = true +} + diff --git a/src/scala/spark/repl/SparkSimpleReader.scala b/src/scala/spark/repl/SparkSimpleReader.scala new file mode 100644 index 0000000000..2b24c4bf63 --- /dev/null +++ b/src/scala/spark/repl/SparkSimpleReader.scala @@ -0,0 +1,33 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2010 LAMP/EPFL + * @author Stepan Koltsov + */ + +package spark.repl + +import scala.tools.nsc +import scala.tools.nsc._ +import scala.tools.nsc.interpreter +import scala.tools.nsc.interpreter._ + +import java.io.{ BufferedReader, PrintWriter } +import io.{ Path, File, Directory } + +/** Reads using standard JDK API */ +class SparkSimpleReader( + in: BufferedReader, + out: PrintWriter, + val interactive: Boolean) +extends SparkInteractiveReader { + def this() = this(Console.in, new PrintWriter(Console.out), true) + def this(in: File, out: PrintWriter, interactive: Boolean) = this(in.bufferedReader(), out, interactive) + + def close() = in.close() + def readOneLine(prompt: String): String = { + if (interactive) { + out.print(prompt) + out.flush() + } + in.readLine() + } +} |