aboutsummaryrefslogtreecommitdiff
path: root/repl/src
diff options
context:
space:
mode:
authorIsmael Juma <ismael@juma.me.uk>2011-05-27 09:37:34 +0100
committerIsmael Juma <ismael@juma.me.uk>2011-05-27 11:22:50 +0100
commit1396678baa0a0b9b47e50bb2da4970aca1351b2e (patch)
tree6bedbe61370dfba294ae7043c3d5615914b79d85 /repl/src
parent3e8114ddbdf598724d0e6cee8507b6afa111f7f3 (diff)
downloadspark-1396678baa0a0b9b47e50bb2da4970aca1351b2e.tar.gz
spark-1396678baa0a0b9b47e50bb2da4970aca1351b2e.tar.bz2
spark-1396678baa0a0b9b47e50bb2da4970aca1351b2e.zip
Move REPL classes to separate module.
Diffstat (limited to 'repl/src')
-rw-r--r--repl/src/main/scala/spark/repl/ExecutorClassLoader.scala108
-rw-r--r--repl/src/main/scala/spark/repl/Main.scala16
-rw-r--r--repl/src/main/scala/spark/repl/SparkCompletion.scala353
-rw-r--r--repl/src/main/scala/spark/repl/SparkCompletionOutput.scala92
-rw-r--r--repl/src/main/scala/spark/repl/SparkInteractiveReader.scala60
-rw-r--r--repl/src/main/scala/spark/repl/SparkInterpreter.scala1395
-rw-r--r--repl/src/main/scala/spark/repl/SparkInterpreterLoop.scala662
-rw-r--r--repl/src/main/scala/spark/repl/SparkInterpreterSettings.scala112
-rw-r--r--repl/src/main/scala/spark/repl/SparkJLineReader.scala38
-rw-r--r--repl/src/main/scala/spark/repl/SparkSimpleReader.scala33
-rw-r--r--repl/src/test/scala/spark/repl/ReplSuite.scala144
11 files changed, 3013 insertions, 0 deletions
diff --git a/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala
new file mode 100644
index 0000000000..13d81ec1cf
--- /dev/null
+++ b/repl/src/main/scala/spark/repl/ExecutorClassLoader.scala
@@ -0,0 +1,108 @@
+package spark.repl
+
+import java.io.{ByteArrayOutputStream, InputStream}
+import java.net.{URI, URL, URLClassLoader, URLEncoder}
+import java.util.concurrent.{Executors, ExecutorService}
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+
+import org.objectweb.asm._
+import org.objectweb.asm.commons.EmptyVisitor
+import org.objectweb.asm.Opcodes._
+
+
+/**
+ * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI,
+ * used to load classes defined by the interpreter when the REPL is used
+ */
+class ExecutorClassLoader(classUri: String, parent: ClassLoader)
+extends ClassLoader(parent) {
+ val uri = new URI(classUri)
+ val directory = uri.getPath
+
+ // Hadoop FileSystem object for our URI, if it isn't using HTTP
+ var fileSystem: FileSystem = {
+ if (uri.getScheme() == "http")
+ null
+ else
+ FileSystem.get(uri, new Configuration())
+ }
+
+ override def findClass(name: String): Class[_] = {
+ try {
+ val pathInDirectory = name.replace('.', '/') + ".class"
+ val inputStream = {
+ if (fileSystem != null)
+ fileSystem.open(new Path(directory, pathInDirectory))
+ else
+ new URL(classUri + "/" + urlEncode(pathInDirectory)).openStream()
+ }
+ val bytes = readAndTransformClass(name, inputStream)
+ inputStream.close()
+ return defineClass(name, bytes, 0, bytes.length)
+ } catch {
+ case e: Exception => throw new ClassNotFoundException(name, e)
+ }
+ }
+
+ def readAndTransformClass(name: String, in: InputStream): Array[Byte] = {
+ if (name.startsWith("line") && name.endsWith("$iw$")) {
+ // Class seems to be an interpreter "wrapper" object storing a val or var.
+ // Replace its constructor with a dummy one that does not run the
+ // initialization code placed there by the REPL. The val or var will
+ // be initialized later through reflection when it is used in a task.
+ val cr = new ClassReader(in)
+ val cw = new ClassWriter(
+ ClassWriter.COMPUTE_FRAMES + ClassWriter.COMPUTE_MAXS)
+ val cleaner = new ConstructorCleaner(name, cw)
+ cr.accept(cleaner, 0)
+ return cw.toByteArray
+ } else {
+ // Pass the class through unmodified
+ val bos = new ByteArrayOutputStream
+ val bytes = new Array[Byte](4096)
+ var done = false
+ while (!done) {
+ val num = in.read(bytes)
+ if (num >= 0)
+ bos.write(bytes, 0, num)
+ else
+ done = true
+ }
+ return bos.toByteArray
+ }
+ }
+
+ /**
+ * URL-encode a string, preserving only slashes
+ */
+ def urlEncode(str: String): String = {
+ str.split('/').map(part => URLEncoder.encode(part, "UTF-8")).mkString("/")
+ }
+}
+
+class ConstructorCleaner(className: String, cv: ClassVisitor)
+extends ClassAdapter(cv) {
+ override def visitMethod(access: Int, name: String, desc: String,
+ sig: String, exceptions: Array[String]): MethodVisitor = {
+ val mv = cv.visitMethod(access, name, desc, sig, exceptions)
+ if (name == "<init>" && (access & ACC_STATIC) == 0) {
+ // This is the constructor, time to clean it; just output some new
+ // instructions to mv that create the object and set the static MODULE$
+ // field in the class to point to it, but do nothing otherwise.
+ mv.visitCode()
+ mv.visitVarInsn(ALOAD, 0) // load this
+ mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "<init>", "()V")
+ mv.visitVarInsn(ALOAD, 0) // load this
+ //val classType = className.replace('.', '/')
+ //mv.visitFieldInsn(PUTSTATIC, classType, "MODULE$", "L" + classType + ";")
+ mv.visitInsn(RETURN)
+ mv.visitMaxs(-1, -1) // stack size and local vars will be auto-computed
+ mv.visitEnd()
+ return null
+ } else {
+ return mv
+ }
+ }
+}
diff --git a/repl/src/main/scala/spark/repl/Main.scala b/repl/src/main/scala/spark/repl/Main.scala
new file mode 100644
index 0000000000..f00df5aa58
--- /dev/null
+++ b/repl/src/main/scala/spark/repl/Main.scala
@@ -0,0 +1,16 @@
+package spark.repl
+
+import scala.collection.mutable.Set
+
+object Main {
+ private var _interp: SparkInterpreterLoop = null
+
+ def interp = _interp
+
+ private[repl] def interp_=(i: SparkInterpreterLoop) { _interp = i }
+
+ def main(args: Array[String]) {
+ _interp = new SparkInterpreterLoop
+ _interp.main(args)
+ }
+}
diff --git a/repl/src/main/scala/spark/repl/SparkCompletion.scala b/repl/src/main/scala/spark/repl/SparkCompletion.scala
new file mode 100644
index 0000000000..c6ed1860f0
--- /dev/null
+++ b/repl/src/main/scala/spark/repl/SparkCompletion.scala
@@ -0,0 +1,353 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2010 LAMP/EPFL
+ * @author Paul Phillips
+ */
+
+
+package spark.repl
+
+import scala.tools.nsc
+import scala.tools.nsc._
+import scala.tools.nsc.interpreter
+import scala.tools.nsc.interpreter._
+
+import jline._
+import java.util.{ List => JList }
+import util.returning
+
+object SparkCompletion {
+ def looksLikeInvocation(code: String) = (
+ (code != null)
+ && (code startsWith ".")
+ && !(code == ".")
+ && !(code startsWith "./")
+ && !(code startsWith "..")
+ )
+
+ object Forwarder {
+ def apply(forwardTo: () => Option[CompletionAware]): CompletionAware = new CompletionAware {
+ def completions(verbosity: Int) = forwardTo() map (_ completions verbosity) getOrElse Nil
+ override def follow(s: String) = forwardTo() flatMap (_ follow s)
+ }
+ }
+}
+import SparkCompletion._
+
+// REPL completor - queries supplied interpreter for valid
+// completions based on current contents of buffer.
+class SparkCompletion(val repl: SparkInterpreter) extends SparkCompletionOutput {
+ // verbosity goes up with consecutive tabs
+ private var verbosity: Int = 0
+ def resetVerbosity() = verbosity = 0
+
+ def isCompletionDebug = repl.isCompletionDebug
+ def DBG(msg: => Any) = if (isCompletionDebug) println(msg.toString)
+ def debugging[T](msg: String): T => T = (res: T) => returning[T](res)(x => DBG(msg + x))
+
+ lazy val global: repl.compiler.type = repl.compiler
+ import global._
+ import definitions.{ PredefModule, RootClass, AnyClass, AnyRefClass, ScalaPackage, JavaLangPackage }
+
+ // XXX not yet used.
+ lazy val dottedPaths = {
+ def walk(tp: Type): scala.List[Symbol] = {
+ val pkgs = tp.nonPrivateMembers filter (_.isPackage)
+ pkgs ++ (pkgs map (_.tpe) flatMap walk)
+ }
+ walk(RootClass.tpe)
+ }
+
+ def getType(name: String, isModule: Boolean) = {
+ val f = if (isModule) definitions.getModule(_: Name) else definitions.getClass(_: Name)
+ try Some(f(name).tpe)
+ catch { case _: MissingRequirementError => None }
+ }
+
+ def typeOf(name: String) = getType(name, false)
+ def moduleOf(name: String) = getType(name, true)
+
+ trait CompilerCompletion {
+ def tp: Type
+ def effectiveTp = tp match {
+ case MethodType(Nil, resType) => resType
+ case PolyType(Nil, resType) => resType
+ case _ => tp
+ }
+
+ // for some reason any's members don't show up in subclasses, which
+ // we need so 5.<tab> offers asInstanceOf etc.
+ private def anyMembers = AnyClass.tpe.nonPrivateMembers
+ def anyRefMethodsToShow = List("isInstanceOf", "asInstanceOf", "toString")
+
+ def tos(sym: Symbol) = sym.name.decode.toString
+ def memberNamed(s: String) = members find (x => tos(x) == s)
+ def hasMethod(s: String) = methods exists (x => tos(x) == s)
+
+ // XXX we'd like to say "filterNot (_.isDeprecated)" but this causes the
+ // compiler to crash for reasons not yet known.
+ def members = (effectiveTp.nonPrivateMembers ++ anyMembers) filter (_.isPublic)
+ def methods = members filter (_.isMethod)
+ def packages = members filter (_.isPackage)
+ def aliases = members filter (_.isAliasType)
+
+ def memberNames = members map tos
+ def methodNames = methods map tos
+ def packageNames = packages map tos
+ def aliasNames = aliases map tos
+ }
+
+ object TypeMemberCompletion {
+ def apply(tp: Type): TypeMemberCompletion = {
+ if (tp.typeSymbol.isPackageClass) new PackageCompletion(tp)
+ else new TypeMemberCompletion(tp)
+ }
+ def imported(tp: Type) = new ImportCompletion(tp)
+ }
+
+ class TypeMemberCompletion(val tp: Type) extends CompletionAware with CompilerCompletion {
+ def excludeEndsWith: List[String] = Nil
+ def excludeStartsWith: List[String] = List("<") // <byname>, <repeated>, etc.
+ def excludeNames: List[String] = anyref.methodNames.filterNot(anyRefMethodsToShow.contains) ++ List("_root_")
+
+ def methodSignatureString(sym: Symbol) = {
+ def asString = new MethodSymbolOutput(sym).methodString()
+
+ if (isCompletionDebug)
+ repl.power.showAtAllPhases(asString)
+
+ atPhase(currentRun.typerPhase)(asString)
+ }
+
+ def exclude(name: String): Boolean = (
+ (name contains "$") ||
+ (excludeNames contains name) ||
+ (excludeEndsWith exists (name endsWith _)) ||
+ (excludeStartsWith exists (name startsWith _))
+ )
+ def filtered(xs: List[String]) = xs filterNot exclude distinct
+
+ def completions(verbosity: Int) =
+ debugging(tp + " completions ==> ")(filtered(memberNames))
+
+ override def follow(s: String): Option[CompletionAware] =
+ debugging(tp + " -> '" + s + "' ==> ")(memberNamed(s) map (x => TypeMemberCompletion(x.tpe)))
+
+ override def alternativesFor(id: String): List[String] =
+ debugging(id + " alternatives ==> ") {
+ val alts = members filter (x => x.isMethod && tos(x) == id) map methodSignatureString
+
+ if (alts.nonEmpty) "" :: alts else Nil
+ }
+
+ override def toString = "TypeMemberCompletion(%s)".format(tp)
+ }
+
+ class PackageCompletion(tp: Type) extends TypeMemberCompletion(tp) {
+ override def excludeNames = anyref.methodNames
+ }
+
+ class LiteralCompletion(lit: Literal) extends TypeMemberCompletion(lit.value.tpe) {
+ override def completions(verbosity: Int) = verbosity match {
+ case 0 => filtered(memberNames)
+ case _ => memberNames
+ }
+ }
+
+ class ImportCompletion(tp: Type) extends TypeMemberCompletion(tp) {
+ override def completions(verbosity: Int) = verbosity match {
+ case 0 => filtered(members filterNot (_.isSetter) map tos)
+ case _ => super.completions(verbosity)
+ }
+ }
+
+ // not for completion but for excluding
+ object anyref extends TypeMemberCompletion(AnyRefClass.tpe) { }
+
+ // the unqualified vals/defs/etc visible in the repl
+ object ids extends CompletionAware {
+ override def completions(verbosity: Int) = repl.unqualifiedIds ::: List("classOf")
+ // we try to use the compiler and fall back on reflection if necessary
+ // (which at present is for anything defined in the repl session.)
+ override def follow(id: String) =
+ if (completions(0) contains id) {
+ for (clazz <- repl clazzForIdent id) yield {
+ // XXX The isMemberClass check is a workaround for the crasher described
+ // in the comments of #3431. The issue as described by iulian is:
+ //
+ // Inner classes exist as symbols
+ // inside their enclosing class, but also inside their package, with a mangled
+ // name (A$B). The mangled names should never be loaded, and exist only for the
+ // optimizer, which sometimes cannot get the right symbol, but it doesn't care
+ // and loads the bytecode anyway.
+ //
+ // So this solution is incorrect, but in the short term the simple fix is
+ // to skip the compiler any time completion is requested on a nested class.
+ if (clazz.isMemberClass) new InstanceCompletion(clazz)
+ else (typeOf(clazz.getName) map TypeMemberCompletion.apply) getOrElse new InstanceCompletion(clazz)
+ }
+ }
+ else None
+ }
+
+ // wildcard imports in the repl like "import global._" or "import String._"
+ private def imported = repl.wildcardImportedTypes map TypeMemberCompletion.imported
+
+ // literal Ints, Strings, etc.
+ object literals extends CompletionAware {
+ def simpleParse(code: String): Tree = {
+ val unit = new CompilationUnit(new util.BatchSourceFile("<console>", code))
+ val scanner = new syntaxAnalyzer.UnitParser(unit)
+ val tss = scanner.templateStatSeq(false)._2
+
+ if (tss.size == 1) tss.head else EmptyTree
+ }
+
+ def completions(verbosity: Int) = Nil
+
+ override def follow(id: String) = simpleParse(id) match {
+ case x: Literal => Some(new LiteralCompletion(x))
+ case _ => None
+ }
+ }
+
+ // top level packages
+ object rootClass extends TypeMemberCompletion(RootClass.tpe) { }
+ // members of Predef
+ object predef extends TypeMemberCompletion(PredefModule.tpe) {
+ override def excludeEndsWith = super.excludeEndsWith ++ List("Wrapper", "ArrayOps")
+ override def excludeStartsWith = super.excludeStartsWith ++ List("wrap")
+ override def excludeNames = anyref.methodNames
+
+ override def exclude(name: String) = super.exclude(name) || (
+ (name contains "2")
+ )
+
+ override def completions(verbosity: Int) = verbosity match {
+ case 0 => Nil
+ case _ => super.completions(verbosity)
+ }
+ }
+ // members of scala.*
+ object scalalang extends PackageCompletion(ScalaPackage.tpe) {
+ def arityClasses = List("Product", "Tuple", "Function")
+ def skipArity(name: String) = arityClasses exists (x => name != x && (name startsWith x))
+ override def exclude(name: String) = super.exclude(name) || (
+ skipArity(name)
+ )
+
+ override def completions(verbosity: Int) = verbosity match {
+ case 0 => filtered(packageNames ++ aliasNames)
+ case _ => super.completions(verbosity)
+ }
+ }
+ // members of java.lang.*
+ object javalang extends PackageCompletion(JavaLangPackage.tpe) {
+ override lazy val excludeEndsWith = super.excludeEndsWith ++ List("Exception", "Error")
+ override lazy val excludeStartsWith = super.excludeStartsWith ++ List("CharacterData")
+
+ override def completions(verbosity: Int) = verbosity match {
+ case 0 => filtered(packageNames)
+ case _ => super.completions(verbosity)
+ }
+ }
+
+ // the list of completion aware objects which should be consulted
+ lazy val topLevelBase: List[CompletionAware] = List(ids, rootClass, predef, scalalang, javalang, literals)
+ def topLevel = topLevelBase ++ imported
+
+ // the first tier of top level objects (doesn't include file completion)
+ def topLevelFor(parsed: Parsed) = topLevel flatMap (_ completionsFor parsed)
+
+ // the most recent result
+ def lastResult = Forwarder(() => ids follow repl.mostRecentVar)
+
+ def lastResultFor(parsed: Parsed) = {
+ /** The logic is a little tortured right now because normally '.' is
+ * ignored as a delimiter, but on .<tab> it needs to be propagated.
+ */
+ val xs = lastResult completionsFor parsed
+ if (parsed.isEmpty) xs map ("." + _) else xs
+ }
+
+ // chasing down results which won't parse
+ def execute(line: String): Option[Any] = {
+ val parsed = Parsed(line)
+ def noDotOrSlash = line forall (ch => ch != '.' && ch != '/')
+
+ if (noDotOrSlash) None // we defer all unqualified ids to the repl.
+ else {
+ (ids executionFor parsed) orElse
+ (rootClass executionFor parsed) orElse
+ (FileCompletion executionFor line)
+ }
+ }
+
+ // generic interface for querying (e.g. interpreter loop, testing)
+ def completions(buf: String): List[String] =
+ topLevelFor(Parsed.dotted(buf + ".", buf.length + 1))
+
+ // jline's entry point
+ lazy val jline: ArgumentCompletor =
+ returning(new ArgumentCompletor(new JLineCompletion, new JLineDelimiter))(_ setStrict false)
+
+ /** This gets a little bit hairy. It's no small feat delegating everything
+ * and also keeping track of exactly where the cursor is and where it's supposed
+ * to end up. The alternatives mechanism is a little hacky: if there is an empty
+ * string in the list of completions, that means we are expanding a unique
+ * completion, so don't update the "last" buffer because it'll be wrong.
+ */
+ class JLineCompletion extends Completor {
+ // For recording the buffer on the last tab hit
+ private var lastBuf: String = ""
+ private var lastCursor: Int = -1
+
+ // Does this represent two consecutive tabs?
+ def isConsecutiveTabs(buf: String, cursor: Int) = cursor == lastCursor && buf == lastBuf
+
+ // Longest common prefix
+ def commonPrefix(xs: List[String]) =
+ if (xs.isEmpty) ""
+ else xs.reduceLeft(_ zip _ takeWhile (x => x._1 == x._2) map (_._1) mkString)
+
+ // This is jline's entry point for completion.
+ override def complete(_buf: String, cursor: Int, candidates: java.util.List[java.lang.String]): Int = {
+ val buf = onull(_buf)
+ verbosity = if (isConsecutiveTabs(buf, cursor)) verbosity + 1 else 0
+ DBG("complete(%s, %d) last = (%s, %d), verbosity: %s".format(buf, cursor, lastBuf, lastCursor, verbosity))
+
+ // we don't try lower priority completions unless higher ones return no results.
+ def tryCompletion(p: Parsed, completionFunction: Parsed => List[String]): Option[Int] = {
+ completionFunction(p) match {
+ case Nil => None
+ case xs =>
+ // modify in place and return the position
+ xs.foreach(x => candidates.add(x))
+
+ // update the last buffer unless this is an alternatives list
+ if (xs contains "") Some(p.cursor)
+ else {
+ val advance = commonPrefix(xs)
+ lastCursor = p.position + advance.length
+ lastBuf = (buf take p.position) + advance
+
+ DBG("tryCompletion(%s, _) lastBuf = %s, lastCursor = %s, p.position = %s".format(p, lastBuf, lastCursor, p.position))
+ Some(p.position)
+ }
+ }
+ }
+
+ def mkDotted = Parsed.dotted(buf, cursor) withVerbosity verbosity
+ def mkUndelimited = Parsed.undelimited(buf, cursor) withVerbosity verbosity
+
+ // a single dot is special cased to completion on the previous result
+ def lastResultCompletion =
+ if (!looksLikeInvocation(buf)) None
+ else tryCompletion(Parsed.dotted(buf drop 1, cursor), lastResultFor)
+
+ def regularCompletion = tryCompletion(mkDotted, topLevelFor)
+ def fileCompletion = tryCompletion(mkUndelimited, FileCompletion completionsFor _.buffer)
+
+ (lastResultCompletion orElse regularCompletion orElse fileCompletion) getOrElse cursor
+ }
+ }
+}
diff --git a/repl/src/main/scala/spark/repl/SparkCompletionOutput.scala b/repl/src/main/scala/spark/repl/SparkCompletionOutput.scala
new file mode 100644
index 0000000000..5ac46e3412
--- /dev/null
+++ b/repl/src/main/scala/spark/repl/SparkCompletionOutput.scala
@@ -0,0 +1,92 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2010 LAMP/EPFL
+ * @author Paul Phillips
+ */
+
+package spark.repl
+
+import scala.tools.nsc
+import scala.tools.nsc._
+import scala.tools.nsc.interpreter
+import scala.tools.nsc.interpreter._
+
+/** This has a lot of duplication with other methods in Symbols and Types,
+ * but repl completion utility is very sensitive to precise output. Best
+ * thing would be to abstract an interface for how such things are printed,
+ * as is also in progress with error messages.
+ */
+trait SparkCompletionOutput {
+ self: SparkCompletion =>
+
+ import global._
+ import definitions.{ NothingClass, AnyClass, isTupleType, isFunctionType, isRepeatedParamType }
+
+ /** Reducing fully qualified noise for some common packages.
+ */
+ val typeTransforms = List(
+ "java.lang." -> "",
+ "scala.collection.immutable." -> "immutable.",
+ "scala.collection.mutable." -> "mutable.",
+ "scala.collection.generic." -> "generic."
+ )
+
+ def quietString(tp: String): String =
+ typeTransforms.foldLeft(tp) {
+ case (str, (prefix, replacement)) =>
+ if (str startsWith prefix) replacement + (str stripPrefix prefix)
+ else str
+ }
+
+ class MethodSymbolOutput(method: Symbol) {
+ val pkg = method.ownerChain find (_.isPackageClass) map (_.fullName) getOrElse ""
+
+ def relativize(str: String): String = quietString(str stripPrefix (pkg + "."))
+ def relativize(tp: Type): String = relativize(tp.normalize.toString)
+ def relativize(sym: Symbol): String = relativize(sym.info)
+
+ def braceList(tparams: List[String]) = if (tparams.isEmpty) "" else (tparams map relativize).mkString("[", ", ", "]")
+ def parenList(params: List[Any]) = params.mkString("(", ", ", ")")
+
+ def methodTypeToString(mt: MethodType) =
+ (mt.paramss map paramsString mkString "") + ": " + relativize(mt.finalResultType)
+
+ def typeToString(tp: Type): String = relativize(
+ tp match {
+ case x if isFunctionType(x) => functionString(x)
+ case x if isTupleType(x) => tupleString(x)
+ case x if isRepeatedParamType(x) => typeToString(x.typeArgs.head) + "*"
+ case mt @ MethodType(_, _) => methodTypeToString(mt)
+ case x => x.toString
+ }
+ )
+
+ def tupleString(tp: Type) = parenList(tp.normalize.typeArgs map relativize)
+ def functionString(tp: Type) = tp.normalize.typeArgs match {
+ case List(t, r) => t + " => " + r
+ case xs => parenList(xs.init) + " => " + xs.last
+ }
+
+ def tparamsString(tparams: List[Symbol]) = braceList(tparams map (_.defString))
+ def paramsString(params: List[Symbol]) = {
+ def paramNameString(sym: Symbol) = if (sym.isSynthetic) "" else sym.nameString + ": "
+ def paramString(sym: Symbol) = paramNameString(sym) + typeToString(sym.info.normalize)
+
+ val isImplicit = params.nonEmpty && params.head.isImplicit
+ val strs = (params map paramString) match {
+ case x :: xs if isImplicit => ("implicit " + x) :: xs
+ case xs => xs
+ }
+ parenList(strs)
+ }
+
+ def methodString() =
+ method.keyString + " " + method.nameString + (method.info.normalize match {
+ case PolyType(Nil, resType) => ": " + typeToString(resType) // nullary method
+ case PolyType(tparams, resType) => tparamsString(tparams) + typeToString(resType)
+ case mt @ MethodType(_, _) => methodTypeToString(mt)
+ case x =>
+ DBG("methodString(): %s / %s".format(x.getClass, x))
+ x.toString
+ })
+ }
+}
diff --git a/repl/src/main/scala/spark/repl/SparkInteractiveReader.scala b/repl/src/main/scala/spark/repl/SparkInteractiveReader.scala
new file mode 100644
index 0000000000..4f5a0a6fa0
--- /dev/null
+++ b/repl/src/main/scala/spark/repl/SparkInteractiveReader.scala
@@ -0,0 +1,60 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2010 LAMP/EPFL
+ * @author Stepan Koltsov
+ */
+
+package spark.repl
+
+import scala.tools.nsc
+import scala.tools.nsc._
+import scala.tools.nsc.interpreter
+import scala.tools.nsc.interpreter._
+
+import scala.util.control.Exception._
+
+/** Reads lines from an input stream */
+trait SparkInteractiveReader {
+ import SparkInteractiveReader._
+ import java.io.IOException
+
+ protected def readOneLine(prompt: String): String
+ val interactive: Boolean
+
+ def readLine(prompt: String): String = {
+ def handler: Catcher[String] = {
+ case e: IOException if restartSystemCall(e) => readLine(prompt)
+ }
+ catching(handler) { readOneLine(prompt) }
+ }
+
+ // override if history is available
+ def history: Option[History] = None
+ def historyList = history map (_.asList) getOrElse Nil
+
+ // override if completion is available
+ def completion: Option[SparkCompletion] = None
+
+ // hack necessary for OSX jvm suspension because read calls are not restarted after SIGTSTP
+ private def restartSystemCall(e: Exception): Boolean =
+ Properties.isMac && (e.getMessage == msgEINTR)
+}
+
+
+object SparkInteractiveReader {
+ val msgEINTR = "Interrupted system call"
+ private val exes = List(classOf[Exception], classOf[NoClassDefFoundError])
+
+ def createDefault(): SparkInteractiveReader = createDefault(null)
+
+ /** Create an interactive reader. Uses <code>JLineReader</code> if the
+ * library is available, but otherwise uses a <code>SimpleReader</code>.
+ */
+ def createDefault(interpreter: SparkInterpreter): SparkInteractiveReader =
+ try new SparkJLineReader(interpreter)
+ catch {
+ case e @ (_: Exception | _: NoClassDefFoundError) =>
+ // println("Failed to create SparkJLineReader(%s): %s".format(interpreter, e))
+ new SparkSimpleReader
+ }
+}
+
diff --git a/repl/src/main/scala/spark/repl/SparkInterpreter.scala b/repl/src/main/scala/spark/repl/SparkInterpreter.scala
new file mode 100644
index 0000000000..10ea346658
--- /dev/null
+++ b/repl/src/main/scala/spark/repl/SparkInterpreter.scala
@@ -0,0 +1,1395 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2010 LAMP/EPFL
+ * @author Martin Odersky
+ */
+
+package spark.repl
+
+import scala.tools.nsc
+import scala.tools.nsc._
+
+import Predef.{ println => _, _ }
+import java.io.{ File, IOException, PrintWriter, StringWriter, Writer }
+import File.pathSeparator
+import java.lang.{ Class, ClassLoader }
+import java.net.{ MalformedURLException, URL }
+import java.lang.reflect
+import reflect.InvocationTargetException
+import java.util.UUID
+
+import scala.PartialFunction.{ cond, condOpt }
+import scala.tools.util.PathResolver
+import scala.reflect.Manifest
+import scala.collection.mutable
+import scala.collection.mutable.{ ListBuffer, HashSet, HashMap, ArrayBuffer }
+import scala.collection.immutable.Set
+import scala.tools.nsc.util.ScalaClassLoader
+import ScalaClassLoader.URLClassLoader
+import scala.util.control.Exception.{ Catcher, catching, ultimately, unwrapping }
+
+import io.{ PlainFile, VirtualDirectory }
+import reporters.{ ConsoleReporter, Reporter }
+import symtab.{ Flags, Names }
+import util.{ SourceFile, BatchSourceFile, ScriptSourceFile, ClassPath, Chars, stringFromWriter }
+import scala.reflect.NameTransformer
+import scala.tools.nsc.{ InterpreterResults => IR }
+import interpreter._
+import SparkInterpreter._
+
+import spark.HttpServer
+import spark.Utils
+
+/** <p>
+ * An interpreter for Scala code.
+ * </p>
+ * <p>
+ * The main public entry points are <code>compile()</code>,
+ * <code>interpret()</code>, and <code>bind()</code>.
+ * The <code>compile()</code> method loads a
+ * complete Scala file. The <code>interpret()</code> method executes one
+ * line of Scala code at the request of the user. The <code>bind()</code>
+ * method binds an object to a variable that can then be used by later
+ * interpreted code.
+ * </p>
+ * <p>
+ * The overall approach is based on compiling the requested code and then
+ * using a Java classloader and Java reflection to run the code
+ * and access its results.
+ * </p>
+ * <p>
+ * In more detail, a single compiler instance is used
+ * to accumulate all successfully compiled or interpreted Scala code. To
+ * "interpret" a line of code, the compiler generates a fresh object that
+ * includes the line of code and which has public member(s) to export
+ * all variables defined by that code. To extract the result of an
+ * interpreted line to show the user, a second "result object" is created
+ * which imports the variables exported by the above object and then
+ * exports a single member named "scala_repl_result". To accomodate user expressions
+ * that read from variables or methods defined in previous statements, "import"
+ * statements are used.
+ * </p>
+ * <p>
+ * This interpreter shares the strengths and weaknesses of using the
+ * full compiler-to-Java. The main strength is that interpreted code
+ * behaves exactly as does compiled code, including running at full speed.
+ * The main weakness is that redefining classes and methods is not handled
+ * properly, because rebinding at the Java level is technically difficult.
+ * </p>
+ *
+ * @author Moez A. Abdel-Gawad
+ * @author Lex Spoon
+ */
+class SparkInterpreter(val settings: Settings, out: PrintWriter) {
+ repl =>
+
+ def println(x: Any) = {
+ out.println(x)
+ out.flush()
+ }
+
+ /** construct an interpreter that reports to Console */
+ def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true))
+ def this() = this(new Settings())
+
+ val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1")
+
+ /** Local directory to save .class files too */
+ val outputDir = {
+ val tmp = System.getProperty("java.io.tmpdir")
+ val rootDir = System.getProperty("spark.repl.classdir", tmp)
+ Utils.createTempDir(rootDir)
+ }
+ if (SPARK_DEBUG_REPL) {
+ println("Output directory: " + outputDir)
+ }
+
+ /** Scala compiler virtual directory for outputDir */
+ //val virtualDirectory = new VirtualDirectory("(memory)", None)
+ val virtualDirectory = new PlainFile(outputDir)
+
+ /** Jetty server that will serve our classes to worker nodes */
+ val classServer = new HttpServer(outputDir)
+
+ // Start the classServer and store its URI in a spark system property
+ // (which will be passed to executors so that they can connect to it)
+ classServer.start()
+ System.setProperty("spark.repl.class.uri", classServer.uri)
+ if (SPARK_DEBUG_REPL) {
+ println("Class server started, URI = " + classServer.uri)
+ }
+
+ /** reporter */
+ object reporter extends ConsoleReporter(settings, null, out) {
+ override def printMessage(msg: String) {
+ out println clean(msg)
+ out.flush()
+ }
+ }
+
+ /** We're going to go to some trouble to initialize the compiler asynchronously.
+ * It's critical that nothing call into it until it's been initialized or we will
+ * run into unrecoverable issues, but the perceived repl startup time goes
+ * through the roof if we wait for it. So we initialize it with a future and
+ * use a lazy val to ensure that any attempt to use the compiler object waits
+ * on the future.
+ */
+ private val _compiler: Global = newCompiler(settings, reporter)
+ private def _initialize(): Boolean = {
+ val source = """
+ |// this is assembled to force the loading of approximately the
+ |// classes which will be loaded on the first expression anyway.
+ |class $repl_$init {
+ | val x = "abc".reverse.length + (5 max 5)
+ | scala.runtime.ScalaRunTime.stringOf(x)
+ |}
+ |""".stripMargin
+
+ try {
+ new _compiler.Run() compileSources List(new BatchSourceFile("<init>", source))
+ if (isReplDebug || settings.debug.value)
+ println("Repl compiler initialized.")
+ true
+ }
+ catch {
+ case MissingRequirementError(msg) => println("""
+ |Failed to initialize compiler: %s not found.
+ |** Note that as of 2.8 scala does not assume use of the java classpath.
+ |** For the old behavior pass -usejavacp to scala, or if using a Settings
+ |** object programatically, settings.usejavacp.value = true.""".stripMargin.format(msg)
+ )
+ false
+ }
+ }
+
+ // set up initialization future
+ private var _isInitialized: () => Boolean = null
+ def initialize() = synchronized {
+ if (_isInitialized == null)
+ _isInitialized = scala.concurrent.ops future _initialize()
+ }
+
+ /** the public, go through the future compiler */
+ lazy val compiler: Global = {
+ initialize()
+
+ // blocks until it is ; false means catastrophic failure
+ if (_isInitialized()) _compiler
+ else null
+ }
+
+ import compiler.{ Traverser, CompilationUnit, Symbol, Name, Type }
+ import compiler.{
+ Tree, TermTree, ValOrDefDef, ValDef, DefDef, Assign, ClassDef,
+ ModuleDef, Ident, Select, TypeDef, Import, MemberDef, DocDef,
+ ImportSelector, EmptyTree, NoType }
+ import compiler.{ nme, newTermName, newTypeName }
+ import nme.{
+ INTERPRETER_VAR_PREFIX, INTERPRETER_SYNTHVAR_PREFIX, INTERPRETER_LINE_PREFIX,
+ INTERPRETER_IMPORT_WRAPPER, INTERPRETER_WRAPPER_SUFFIX, USCOREkw
+ }
+
+ import compiler.definitions
+ import definitions.{ EmptyPackage, getMember }
+
+ /** whether to print out result lines */
+ private[repl] var printResults: Boolean = true
+
+ /** Temporarily be quiet */
+ def beQuietDuring[T](operation: => T): T = {
+ val wasPrinting = printResults
+ ultimately(printResults = wasPrinting) {
+ printResults = false
+ operation
+ }
+ }
+
+ /** whether to bind the lastException variable */
+ private var bindLastException = true
+
+ /** Temporarily stop binding lastException */
+ def withoutBindingLastException[T](operation: => T): T = {
+ val wasBinding = bindLastException
+ ultimately(bindLastException = wasBinding) {
+ bindLastException = false
+ operation
+ }
+ }
+
+ /** interpreter settings */
+ lazy val isettings = new SparkInterpreterSettings(this)
+
+ /** Instantiate a compiler. Subclasses can override this to
+ * change the compiler class used by this interpreter. */
+ protected def newCompiler(settings: Settings, reporter: Reporter) = {
+ settings.outputDirs setSingleOutput virtualDirectory
+ new Global(settings, reporter)
+ }
+
+ /** the compiler's classpath, as URL's */
+ lazy val compilerClasspath: List[URL] = new PathResolver(settings) asURLs
+
+ /* A single class loader is used for all commands interpreted by this Interpreter.
+ It would also be possible to create a new class loader for each command
+ to interpret. The advantages of the current approach are:
+
+ - Expressions are only evaluated one time. This is especially
+ significant for I/O, e.g. "val x = Console.readLine"
+
+ The main disadvantage is:
+
+ - Objects, classes, and methods cannot be rebound. Instead, definitions
+ shadow the old ones, and old code objects refer to the old
+ definitions.
+ */
+ private var _classLoader: ClassLoader = null
+ def resetClassLoader() = _classLoader = makeClassLoader()
+ def classLoader: ClassLoader = {
+ if (_classLoader == null)
+ resetClassLoader()
+
+ _classLoader
+ }
+ private def makeClassLoader(): ClassLoader = {
+ /*
+ val parent =
+ if (parentClassLoader == null) ScalaClassLoader fromURLs compilerClasspath
+ else new URLClassLoader(compilerClasspath, parentClassLoader)
+
+ new AbstractFileClassLoader(virtualDirectory, parent)
+ */
+ val parent =
+ if (parentClassLoader == null)
+ new java.net.URLClassLoader(compilerClasspath.toArray)
+ else
+ new java.net.URLClassLoader(compilerClasspath.toArray,
+ parentClassLoader)
+ val virtualDirUrl = new URL("file://" + virtualDirectory.path + "/")
+ new java.net.URLClassLoader(Array(virtualDirUrl), parent)
+ }
+
+ private def loadByName(s: String): Class[_] = // (classLoader tryToInitializeClass s).get
+ Class.forName(s, true, classLoader)
+
+ private def methodByName(c: Class[_], name: String): reflect.Method =
+ c.getMethod(name, classOf[Object])
+
+ protected def parentClassLoader: ClassLoader = this.getClass.getClassLoader()
+ def getInterpreterClassLoader() = classLoader
+
+ // Set the current Java "context" class loader to this interpreter's class loader
+ def setContextClassLoader() = Thread.currentThread.setContextClassLoader(classLoader)
+
+ /** the previous requests this interpreter has processed */
+ private val prevRequests = new ArrayBuffer[Request]()
+ private val usedNameMap = new HashMap[Name, Request]()
+ private val boundNameMap = new HashMap[Name, Request]()
+ private def allHandlers = prevRequests.toList flatMap (_.handlers)
+ private def allReqAndHandlers = prevRequests.toList flatMap (req => req.handlers map (req -> _))
+
+ def printAllTypeOf = {
+ prevRequests foreach { req =>
+ req.typeOf foreach { case (k, v) => Console.println(k + " => " + v) }
+ }
+ }
+
+ /** Most recent tree handled which wasn't wholly synthetic. */
+ private def mostRecentlyHandledTree: Option[Tree] = {
+ for {
+ req <- prevRequests.reverse
+ handler <- req.handlers.reverse
+ name <- handler.generatesValue
+ if !isSynthVarName(name)
+ } return Some(handler.member)
+
+ None
+ }
+
+ def recordRequest(req: Request) {
+ def tripart[T](set1: Set[T], set2: Set[T]) = {
+ val intersect = set1 intersect set2
+ List(set1 -- intersect, intersect, set2 -- intersect)
+ }
+
+ prevRequests += req
+ req.usedNames foreach (x => usedNameMap(x) = req)
+ req.boundNames foreach (x => boundNameMap(x) = req)
+
+ // XXX temporarily putting this here because of tricky initialization order issues
+ // so right now it's not bound until after you issue a command.
+ if (prevRequests.size == 1)
+ quietBind("settings", "spark.repl.SparkInterpreterSettings", isettings)
+
+ // println("\n s1 = %s\n s2 = %s\n s3 = %s".format(
+ // tripart(usedNameMap.keysIterator.toSet, boundNameMap.keysIterator.toSet): _*
+ // ))
+ }
+
+ private def keyList[T](x: collection.Map[T, _]): List[T] = x.keys.toList sortBy (_.toString)
+ def allUsedNames = keyList(usedNameMap)
+ def allBoundNames = keyList(boundNameMap)
+ def allSeenTypes = prevRequests.toList flatMap (_.typeOf.values.toList) distinct
+ def allValueGeneratingNames = allHandlers flatMap (_.generatesValue)
+ def allImplicits = partialFlatMap(allHandlers) {
+ case x: MemberHandler if x.definesImplicit => x.boundNames
+ }
+
+ /** Generates names pre0, pre1, etc. via calls to apply method */
+ class NameCreator(pre: String) {
+ private var x = -1
+ var mostRecent: String = null
+
+ def apply(): String = {
+ x += 1
+ val name = pre + x.toString
+ // make sure we don't overwrite their unwisely named res3 etc.
+ mostRecent =
+ if (allBoundNames exists (_.toString == name)) apply()
+ else name
+
+ mostRecent
+ }
+ def reset(): Unit = x = -1
+ def didGenerate(name: String) =
+ (name startsWith pre) && ((name drop pre.length) forall (_.isDigit))
+ }
+
+ /** allocate a fresh line name */
+ private lazy val lineNameCreator = new NameCreator(INTERPRETER_LINE_PREFIX)
+
+ /** allocate a fresh var name */
+ private lazy val varNameCreator = new NameCreator(INTERPRETER_VAR_PREFIX)
+
+ /** allocate a fresh internal variable name */
+ private lazy val synthVarNameCreator = new NameCreator(INTERPRETER_SYNTHVAR_PREFIX)
+
+ /** Check if a name looks like it was generated by varNameCreator */
+ private def isGeneratedVarName(name: String): Boolean = varNameCreator didGenerate name
+ private def isSynthVarName(name: String): Boolean = synthVarNameCreator didGenerate name
+ private def isSynthVarName(name: Name): Boolean = synthVarNameCreator didGenerate name.toString
+
+ def getVarName = varNameCreator()
+ def getSynthVarName = synthVarNameCreator()
+
+ /** Truncate a string if it is longer than isettings.maxPrintString */
+ private def truncPrintString(str: String): String = {
+ val maxpr = isettings.maxPrintString
+ val trailer = "..."
+
+ if (maxpr <= 0 || str.length <= maxpr) str
+ else str.substring(0, maxpr-3) + trailer
+ }
+
+ /** Clean up a string for output */
+ private def clean(str: String) = truncPrintString(
+ if (isettings.unwrapStrings && !SPARK_DEBUG_REPL) stripWrapperGunk(str)
+ else str
+ )
+
+ /** Heuristically strip interpreter wrapper prefixes
+ * from an interpreter output string.
+ * MATEI: Copied from interpreter package object
+ */
+ def stripWrapperGunk(str: String): String = {
+ val wrapregex = """(line[0-9]+\$object[$.])?(\$?VAL.?)*(\$iwC?(.this)?[$.])*"""
+ str.replaceAll(wrapregex, "")
+ }
+
+ /** Indent some code by the width of the scala> prompt.
+ * This way, compiler error messages read better.
+ */
+ private final val spaces = List.fill(7)(" ").mkString
+ def indentCode(code: String) = {
+ /** Heuristic to avoid indenting and thereby corrupting """-strings and XML literals. */
+ val noIndent = (code contains "\n") && (List("\"\"\"", "</", "/>") exists (code contains _))
+ stringFromWriter(str =>
+ for (line <- code.lines) {
+ if (!noIndent)
+ str.print(spaces)
+
+ str.print(line + "\n")
+ str.flush()
+ })
+ }
+ def indentString(s: String) = s split "\n" map (spaces + _ + "\n") mkString
+
+ implicit def name2string(name: Name) = name.toString
+
+ /** Compute imports that allow definitions from previous
+ * requests to be visible in a new request. Returns
+ * three pieces of related code:
+ *
+ * 1. An initial code fragment that should go before
+ * the code of the new request.
+ *
+ * 2. A code fragment that should go after the code
+ * of the new request.
+ *
+ * 3. An access path which can be traverested to access
+ * any bindings inside code wrapped by #1 and #2 .
+ *
+ * The argument is a set of Names that need to be imported.
+ *
+ * Limitations: This method is not as precise as it could be.
+ * (1) It does not process wildcard imports to see what exactly
+ * they import.
+ * (2) If it imports any names from a request, it imports all
+ * of them, which is not really necessary.
+ * (3) It imports multiple same-named implicits, but only the
+ * last one imported is actually usable.
+ */
+ private case class ComputedImports(prepend: String, append: String, access: String)
+ private def importsCode(wanted: Set[Name]): ComputedImports = {
+ /** Narrow down the list of requests from which imports
+ * should be taken. Removes requests which cannot contribute
+ * useful imports for the specified set of wanted names.
+ */
+ case class ReqAndHandler(req: Request, handler: MemberHandler) { }
+
+ def reqsToUse: List[ReqAndHandler] = {
+ /** Loop through a list of MemberHandlers and select which ones to keep.
+ * 'wanted' is the set of names that need to be imported.
+ */
+ def select(reqs: List[ReqAndHandler], wanted: Set[Name]): List[ReqAndHandler] = {
+ val isWanted = wanted contains _
+ // Single symbol imports might be implicits! See bug #1752. Rather than
+ // try to finesse this, we will mimic all imports for now.
+ def keepHandler(handler: MemberHandler) = handler match {
+ case _: ImportHandler => true
+ case x => x.definesImplicit || (x.boundNames exists isWanted)
+ }
+
+ reqs match {
+ case Nil => Nil
+ case rh :: rest if !keepHandler(rh.handler) => select(rest, wanted)
+ case rh :: rest =>
+ val importedNames = rh.handler match { case x: ImportHandler => x.importedNames ; case _ => Nil }
+ import rh.handler._
+ val newWanted = wanted ++ usedNames -- boundNames -- importedNames
+ rh :: select(rest, newWanted)
+ }
+ }
+
+ /** Flatten the handlers out and pair each with the original request */
+ select(allReqAndHandlers reverseMap { case (r, h) => ReqAndHandler(r, h) }, wanted).reverse
+ }
+
+ val code, trailingLines, accessPath = new StringBuffer
+ val currentImps = HashSet[Name]()
+
+ // add code for a new object to hold some imports
+ def addWrapper() {
+ /*
+ val impname = INTERPRETER_IMPORT_WRAPPER
+ code append "object %s {\n".format(impname)
+ trailingLines append "}\n"
+ accessPath append ("." + impname)
+ currentImps.clear
+ */
+ val impname = INTERPRETER_IMPORT_WRAPPER
+ code.append("@serializable class " + impname + "C {\n")
+ trailingLines.append("}\nval " + impname + " = new " + impname + "C;\n")
+ accessPath.append("." + impname)
+ currentImps.clear
+ }
+
+ addWrapper()
+
+ // loop through previous requests, adding imports for each one
+ for (ReqAndHandler(req, handler) <- reqsToUse) {
+ handler match {
+ // If the user entered an import, then just use it; add an import wrapping
+ // level if the import might conflict with some other import
+ case x: ImportHandler =>
+ if (x.importsWildcard || (currentImps exists (x.importedNames contains _)))
+ addWrapper()
+
+ code append (x.member.toString + "\n")
+
+ // give wildcard imports a import wrapper all to their own
+ if (x.importsWildcard) addWrapper()
+ else currentImps ++= x.importedNames
+
+ // For other requests, import each bound variable.
+ // import them explicitly instead of with _, so that
+ // ambiguity errors will not be generated. Also, quote
+ // the name of the variable, so that we don't need to
+ // handle quoting keywords separately.
+ case x =>
+ for (imv <- x.boundNames) {
+ // MATEI: Commented this check out because it was messing up for case classes
+ // (trying to import them twice within the same wrapper), but that is more likely
+ // due to a miscomputation of names that makes the code think they're unique.
+ // Need to evaluate whether having so many wrappers is a bad thing.
+ /*if (currentImps contains imv) */ addWrapper()
+
+ code.append("val " + req.objectName + "$VAL = " + req.objectName + ".INSTANCE;\n")
+ code.append("import " + req.objectName + "$VAL" + req.accessPath + ".`" + imv + "`;\n")
+
+ //code append ("import %s\n" format (req fullPath imv))
+ currentImps += imv
+ }
+ }
+ }
+ // add one extra wrapper, to prevent warnings in the common case of
+ // redefining the value bound in the last interpreter request.
+ addWrapper()
+ ComputedImports(code.toString, trailingLines.toString, accessPath.toString)
+ }
+
+ /** Parse a line into a sequence of trees. Returns None if the input is incomplete. */
+ private def parse(line: String): Option[List[Tree]] = {
+ var justNeedsMore = false
+ reporter.withIncompleteHandler((pos,msg) => {justNeedsMore = true}) {
+ // simple parse: just parse it, nothing else
+ def simpleParse(code: String): List[Tree] = {
+ reporter.reset
+ val unit = new CompilationUnit(new BatchSourceFile("<console>", code))
+ val scanner = new compiler.syntaxAnalyzer.UnitParser(unit)
+
+ scanner.templateStatSeq(false)._2
+ }
+ val trees = simpleParse(line)
+
+ if (reporter.hasErrors) Some(Nil) // the result did not parse, so stop
+ else if (justNeedsMore) None
+ else Some(trees)
+ }
+ }
+
+ /** Compile an nsc SourceFile. Returns true if there are
+ * no compilation errors, or false otherwise.
+ */
+ def compileSources(sources: SourceFile*): Boolean = {
+ reporter.reset
+ new compiler.Run() compileSources sources.toList
+ !reporter.hasErrors
+ }
+
+ /** Compile a string. Returns true if there are no
+ * compilation errors, or false otherwise.
+ */
+ def compileString(code: String): Boolean =
+ compileSources(new BatchSourceFile("<script>", code))
+
+ def compileAndSaveRun(label: String, code: String) = {
+ if (SPARK_DEBUG_REPL)
+ println(code)
+ if (isReplDebug) {
+ parse(code) match {
+ case Some(trees) => trees foreach (t => DBG(compiler.asCompactString(t)))
+ case _ => DBG("Parse error:\n\n" + code)
+ }
+ }
+ val run = new compiler.Run()
+ run.compileSources(List(new BatchSourceFile(label, code)))
+ run
+ }
+
+ /** Build a request from the user. <code>trees</code> is <code>line</code>
+ * after being parsed.
+ */
+ private def buildRequest(line: String, lineName: String, trees: List[Tree]): Request =
+ new Request(line, lineName, trees)
+
+ private def chooseHandler(member: Tree): MemberHandler = member match {
+ case member: DefDef => new DefHandler(member)
+ case member: ValDef => new ValHandler(member)
+ case member@Assign(Ident(_), _) => new AssignHandler(member)
+ case member: ModuleDef => new ModuleHandler(member)
+ case member: ClassDef => new ClassHandler(member)
+ case member: TypeDef => new TypeAliasHandler(member)
+ case member: Import => new ImportHandler(member)
+ case DocDef(_, documented) => chooseHandler(documented)
+ case member => new GenericHandler(member)
+ }
+
+ private def requestFromLine(line: String, synthetic: Boolean): Either[IR.Result, Request] = {
+ val trees = parse(indentCode(line)) match {
+ case None => return Left(IR.Incomplete)
+ case Some(Nil) => return Left(IR.Error) // parse error or empty input
+ case Some(trees) => trees
+ }
+
+ // use synthetic vars to avoid filling up the resXX slots
+ def varName = if (synthetic) getSynthVarName else getVarName
+
+ // Treat a single bare expression specially. This is necessary due to it being hard to
+ // modify code at a textual level, and it being hard to submit an AST to the compiler.
+ if (trees.size == 1) trees.head match {
+ case _:Assign => // we don't want to include assignments
+ case _:TermTree | _:Ident | _:Select => // ... but do want these as valdefs.
+ return requestFromLine("val %s =\n%s".format(varName, line), synthetic)
+ case _ =>
+ }
+
+ // figure out what kind of request
+ Right(buildRequest(line, lineNameCreator(), trees))
+ }
+
+ /** <p>
+ * Interpret one line of input. All feedback, including parse errors
+ * and evaluation results, are printed via the supplied compiler's
+ * reporter. Values defined are available for future interpreted
+ * strings.
+ * </p>
+ * <p>
+ * The return value is whether the line was interpreter successfully,
+ * e.g. that there were no parse errors.
+ * </p>
+ *
+ * @param line ...
+ * @return ...
+ */
+ def interpret(line: String): IR.Result = interpret(line, false)
+ def interpret(line: String, synthetic: Boolean): IR.Result = {
+ def loadAndRunReq(req: Request) = {
+ val (result, succeeded) = req.loadAndRun
+ if (printResults || !succeeded)
+ out print clean(result)
+
+ // book-keeping
+ if (succeeded && !synthetic)
+ recordRequest(req)
+
+ if (succeeded) IR.Success
+ else IR.Error
+ }
+
+ if (compiler == null) IR.Error
+ else requestFromLine(line, synthetic) match {
+ case Left(result) => result
+ case Right(req) =>
+ // null indicates a disallowed statement type; otherwise compile and
+ // fail if false (implying e.g. a type error)
+ if (req == null || !req.compile) IR.Error
+ else loadAndRunReq(req)
+ }
+ }
+
+ /** A name creator used for objects created by <code>bind()</code>. */
+ private lazy val newBinder = new NameCreator("binder")
+
+ /** Bind a specified name to a specified value. The name may
+ * later be used by expressions passed to interpret.
+ *
+ * @param name the variable name to bind
+ * @param boundType the type of the variable, as a string
+ * @param value the object value to bind to it
+ * @return an indication of whether the binding succeeded
+ */
+ def bind(name: String, boundType: String, value: Any): IR.Result = {
+ val binderName = newBinder()
+
+ compileString("""
+ |object %s {
+ | var value: %s = _
+ | def set(x: Any) = value = x.asInstanceOf[%s]
+ |}
+ """.stripMargin.format(binderName, boundType, boundType))
+
+ val binderObject = loadByName(binderName)
+ val setterMethod = methodByName(binderObject, "set")
+
+ setterMethod.invoke(null, value.asInstanceOf[AnyRef])
+ interpret("val %s = %s.value".format(name, binderName))
+ }
+
+ def quietBind(name: String, boundType: String, value: Any): IR.Result =
+ beQuietDuring { bind(name, boundType, value) }
+
+ /** Reset this interpreter, forgetting all user-specified requests. */
+ def reset() {
+ //virtualDirectory.clear
+ virtualDirectory.delete
+ virtualDirectory.create
+ resetClassLoader()
+ lineNameCreator.reset()
+ varNameCreator.reset()
+ prevRequests.clear
+ }
+
+ /** <p>
+ * This instance is no longer needed, so release any resources
+ * it is using. The reporter's output gets flushed.
+ * </p>
+ */
+ def close() {
+ reporter.flush
+ classServer.stop()
+ }
+
+ /** A traverser that finds all mentioned identifiers, i.e. things
+ * that need to be imported. It might return extra names.
+ */
+ private class ImportVarsTraverser extends Traverser {
+ val importVars = new HashSet[Name]()
+
+ override def traverse(ast: Tree) = ast match {
+ // XXX this is obviously inadequate but it's going to require some effort
+ // to get right.
+ case Ident(name) if !(name.toString startsWith "x$") => importVars += name
+ case _ => super.traverse(ast)
+ }
+ }
+
+ /** Class to handle one member among all the members included
+ * in a single interpreter request.
+ */
+ private sealed abstract class MemberHandler(val member: Tree) {
+ lazy val usedNames: List[Name] = {
+ val ivt = new ImportVarsTraverser()
+ ivt traverse member
+ ivt.importVars.toList
+ }
+ def boundNames: List[Name] = Nil
+ val definesImplicit = cond(member) {
+ case tree: MemberDef => tree.mods hasFlag Flags.IMPLICIT
+ }
+ def generatesValue: Option[Name] = None
+
+ def extraCodeToEvaluate(req: Request, code: PrintWriter) { }
+ def resultExtractionCode(req: Request, code: PrintWriter) { }
+
+ override def toString = "%s(used = %s)".format(this.getClass.toString split '.' last, usedNames)
+ }
+
+ private class GenericHandler(member: Tree) extends MemberHandler(member)
+
+ private class ValHandler(member: ValDef) extends MemberHandler(member) {
+ lazy val ValDef(mods, vname, _, _) = member
+ lazy val prettyName = NameTransformer.decode(vname)
+ lazy val isLazy = mods hasFlag Flags.LAZY
+
+ override lazy val boundNames = List(vname)
+ override def generatesValue = Some(vname)
+
+ override def resultExtractionCode(req: Request, code: PrintWriter) {
+ val isInternal = isGeneratedVarName(vname) && req.typeOfEnc(vname) == "Unit"
+ if (!mods.isPublic || isInternal) return
+
+ lazy val extractor = "scala.runtime.ScalaRunTime.stringOf(%s)".format(req fullPath vname)
+
+ // if this is a lazy val we avoid evaluating it here
+ val resultString = if (isLazy) codegenln(false, "<lazy>") else extractor
+ val codeToPrint =
+ """ + "%s: %s = " + %s""".format(prettyName, string2code(req typeOf vname), resultString)
+
+ code print codeToPrint
+ }
+ }
+
+ private class DefHandler(defDef: DefDef) extends MemberHandler(defDef) {
+ lazy val DefDef(mods, name, _, vparamss, _, _) = defDef
+ override lazy val boundNames = List(name)
+ // true if 0-arity
+ override def generatesValue =
+ if (vparamss.isEmpty || vparamss.head.isEmpty) Some(name)
+ else None
+
+ override def resultExtractionCode(req: Request, code: PrintWriter) =
+ if (mods.isPublic) code print codegenln(name, ": ", req.typeOf(name))
+ }
+
+ private class AssignHandler(member: Assign) extends MemberHandler(member) {
+ val lhs = member.lhs.asInstanceOf[Ident] // an unfortunate limitation
+ val helperName = newTermName(synthVarNameCreator())
+ override def generatesValue = Some(helperName)
+
+ override def extraCodeToEvaluate(req: Request, code: PrintWriter) =
+ code println """val %s = %s""".format(helperName, lhs)
+
+ /** Print out lhs instead of the generated varName */
+ override def resultExtractionCode(req: Request, code: PrintWriter) {
+ val lhsType = string2code(req typeOfEnc helperName)
+ val res = string2code(req fullPath helperName)
+ val codeToPrint = """ + "%s: %s = " + %s + "\n" """.format(lhs, lhsType, res)
+
+ code println codeToPrint
+ }
+ }
+
+ private class ModuleHandler(module: ModuleDef) extends MemberHandler(module) {
+ lazy val ModuleDef(mods, name, _) = module
+ override lazy val boundNames = List(name)
+ override def generatesValue = Some(name)
+
+ override def resultExtractionCode(req: Request, code: PrintWriter) =
+ code println codegenln("defined module ", name)
+ }
+
+ private class ClassHandler(classdef: ClassDef) extends MemberHandler(classdef) {
+ lazy val ClassDef(mods, name, _, _) = classdef
+ override lazy val boundNames =
+ name :: (if (mods hasFlag Flags.CASE) List(name.toTermName) else Nil)
+
+ override def resultExtractionCode(req: Request, code: PrintWriter) =
+ code print codegenln("defined %s %s".format(classdef.keyword, name))
+ }
+
+ private class TypeAliasHandler(typeDef: TypeDef) extends MemberHandler(typeDef) {
+ lazy val TypeDef(mods, name, _, _) = typeDef
+ def isAlias() = mods.isPublic && compiler.treeInfo.isAliasTypeDef(typeDef)
+ override lazy val boundNames = if (isAlias) List(name) else Nil
+
+ override def resultExtractionCode(req: Request, code: PrintWriter) =
+ code println codegenln("defined type alias ", name)
+ }
+
+ private class ImportHandler(imp: Import) extends MemberHandler(imp) {
+ lazy val Import(expr, selectors) = imp
+ def targetType = stringToCompilerType(expr.toString) match {
+ case NoType => None
+ case x => Some(x)
+ }
+
+ private def selectorWild = selectors filter (_.name == USCOREkw) // wildcard imports, e.g. import foo._
+ private def selectorMasked = selectors filter (_.rename == USCOREkw) // masking imports, e.g. import foo.{ bar => _ }
+ private def selectorNames = selectors map (_.name)
+ private def selectorRenames = selectors map (_.rename) filterNot (_ == null)
+
+ /** Whether this import includes a wildcard import */
+ val importsWildcard = selectorWild.nonEmpty
+
+ /** Complete list of names imported by a wildcard */
+ def wildcardImportedNames: List[Name] = (
+ for (tpe <- targetType ; if importsWildcard) yield
+ tpe.nonPrivateMembers filter (x => x.isMethod && x.isPublic) map (_.name) distinct
+ ).toList.flatten
+
+ /** The individual names imported by this statement */
+ /** XXX come back to this and see what can be done with wildcards now that
+ * we know how to enumerate the identifiers.
+ */
+ val importedNames: List[Name] =
+ selectorRenames filterNot (_ == USCOREkw) flatMap (x => List(x.toTypeName, x.toTermName))
+
+ override def resultExtractionCode(req: Request, code: PrintWriter) =
+ code println codegenln(imp.toString)
+ }
+
+ /** One line of code submitted by the user for interpretation */
+ private class Request(val line: String, val lineName: String, val trees: List[Tree]) {
+ /** name to use for the object that will compute "line" */
+ def objectName = lineName + INTERPRETER_WRAPPER_SUFFIX
+
+ /** name of the object that retrieves the result from the above object */
+ def resultObjectName = "RequestResult$" + objectName
+
+ /** handlers for each tree in this request */
+ val handlers: List[MemberHandler] = trees map chooseHandler
+
+ /** all (public) names defined by these statements */
+ val boundNames = handlers flatMap (_.boundNames)
+
+ /** list of names used by this expression */
+ val usedNames: List[Name] = handlers flatMap (_.usedNames)
+
+ /** def and val names */
+ def defNames = partialFlatMap(handlers) { case x: DefHandler => x.boundNames }
+ def valueNames = partialFlatMap(handlers) {
+ case x: AssignHandler => List(x.helperName)
+ case x: ValHandler => boundNames
+ case x: ModuleHandler => List(x.name)
+ }
+
+ /** Code to import bound names from previous lines - accessPath is code to
+ * append to objectName to access anything bound by request.
+ */
+ val ComputedImports(importsPreamble, importsTrailer, accessPath) =
+ importsCode(Set.empty ++ usedNames)
+
+ /** Code to access a variable with the specified name */
+ def fullPath(vname: String): String = "%s.`%s`".format(objectName + ".INSTANCE" + accessPath, vname)
+
+ /** Code to access a variable with the specified name */
+ def fullPath(vname: Name): String = fullPath(vname.toString)
+
+ /** the line of code to compute */
+ def toCompute = line
+
+ /** generate the source code for the object that computes this request */
+ def objectSourceCode: String = stringFromWriter { code =>
+ val preamble = """
+ |@serializable class %s {
+ | %s%s
+ """.stripMargin.format(objectName, importsPreamble, indentCode(toCompute))
+ val postamble = importsTrailer + "\n}"
+
+ code println preamble
+ handlers foreach { _.extraCodeToEvaluate(this, code) }
+ code println postamble
+
+ //create an object
+ code.println("object " + objectName + " {")
+ code.println(" val INSTANCE = new " + objectName + "();")
+ code.println("}")
+ }
+
+ /** generate source code for the object that retrieves the result
+ from objectSourceCode */
+ def resultObjectSourceCode: String = stringFromWriter { code =>
+ /** We only want to generate this code when the result
+ * is a value which can be referred to as-is.
+ */
+ val valueExtractor = handlers.last.generatesValue match {
+ case Some(vname) if typeOf contains vname =>
+ """
+ |lazy val scala_repl_value = {
+ | scala_repl_result
+ | %s
+ |}""".stripMargin.format(fullPath(vname))
+ case _ => ""
+ }
+
+ // first line evaluates object to make sure constructor is run
+ // initial "" so later code can uniformly be: + etc
+ val preamble = """
+ |object %s {
+ | %s
+ | val scala_repl_result: String = {
+ | %s
+ | (""
+ """.stripMargin.format(resultObjectName, valueExtractor, objectName + ".INSTANCE" + accessPath)
+
+ val postamble = """
+ | )
+ | }
+ |}
+ """.stripMargin
+
+ code println preamble
+ if (printResults) {
+ handlers foreach { _.resultExtractionCode(this, code) }
+ }
+ code println postamble
+ }
+
+ // compile the object containing the user's code
+ lazy val objRun = compileAndSaveRun("<console>", objectSourceCode)
+
+ // compile the result-extraction object
+ lazy val extractionObjectRun = compileAndSaveRun("<console>", resultObjectSourceCode)
+
+ lazy val loadedResultObject = loadByName(resultObjectName)
+
+ def extractionValue(): Option[AnyRef] = {
+ // ensure it has run
+ extractionObjectRun
+
+ // load it and retrieve the value
+ try Some(loadedResultObject getMethod "scala_repl_value" invoke loadedResultObject)
+ catch { case _: Exception => None }
+ }
+
+ /** Compile the object file. Returns whether the compilation succeeded.
+ * If all goes well, the "types" map is computed. */
+ def compile(): Boolean = {
+ // error counting is wrong, hence interpreter may overlook failure - so we reset
+ reporter.reset
+
+ // compile the main object
+ objRun
+
+ // bail on error
+ if (reporter.hasErrors)
+ return false
+
+ // extract and remember types
+ typeOf
+
+ // compile the result-extraction object
+ extractionObjectRun
+
+ // success
+ !reporter.hasErrors
+ }
+
+ def atNextPhase[T](op: => T): T = compiler.atPhase(objRun.typerPhase.next)(op)
+
+ /** The outermost wrapper object */
+ lazy val outerResObjSym: Symbol = getMember(EmptyPackage, newTermName(objectName).toTypeName)
+
+ /** The innermost object inside the wrapper, found by
+ * following accessPath into the outer one. */
+ lazy val resObjSym =
+ accessPath.split("\\.").foldLeft(outerResObjSym) { (sym, name) =>
+ if (name == "") sym else
+ atNextPhase(sym.info member newTermName(name))
+ }
+
+ /* typeOf lookup with encoding */
+ def typeOfEnc(vname: Name) = typeOf(compiler encode vname)
+
+ /** Types of variables defined by this request. */
+ lazy val typeOf: Map[Name, String] = {
+ def getTypes(names: List[Name], nameMap: Name => Name): Map[Name, String] = {
+ names.foldLeft(Map.empty[Name, String]) { (map, name) =>
+ val rawType = atNextPhase(resObjSym.info.member(name).tpe)
+ // the types are all =>T; remove the =>
+ val cleanedType = rawType match {
+ case compiler.PolyType(Nil, rt) => rt
+ case rawType => rawType
+ }
+
+ map + (name -> atNextPhase(cleanedType.toString))
+ }
+ }
+
+ getTypes(valueNames, nme.getterToLocal(_)) ++ getTypes(defNames, identity)
+ }
+
+ /** load and run the code using reflection */
+ def loadAndRun: (String, Boolean) = {
+ val resultValMethod: reflect.Method = loadedResultObject getMethod "scala_repl_result"
+ // XXX if wrapperExceptions isn't type-annotated we crash scalac
+ val wrapperExceptions: List[Class[_ <: Throwable]] =
+ List(classOf[InvocationTargetException], classOf[ExceptionInInitializerError])
+
+ /** We turn off the binding to accomodate ticket #2817 */
+ def onErr: Catcher[(String, Boolean)] = {
+ case t: Throwable if bindLastException =>
+ withoutBindingLastException {
+ quietBind("lastException", "java.lang.Throwable", t)
+ (stringFromWriter(t.printStackTrace(_)), false)
+ }
+ }
+
+ catching(onErr) {
+ unwrapping(wrapperExceptions: _*) {
+ (resultValMethod.invoke(loadedResultObject).toString, true)
+ }
+ }
+ }
+
+ override def toString = "Request(line=%s, %s trees)".format(line, trees.size)
+ }
+
+ /** A container class for methods to be injected into the repl
+ * in power mode.
+ */
+ object power {
+ lazy val compiler: repl.compiler.type = repl.compiler
+ import compiler.{ phaseNames, atPhase, currentRun }
+
+ def mkContext(code: String = "") = compiler.analyzer.rootContext(mkUnit(code))
+ def mkAlias(name: String, what: String) = interpret("type %s = %s".format(name, what))
+ def mkSourceFile(code: String) = new BatchSourceFile("<console>", code)
+ def mkUnit(code: String) = new CompilationUnit(mkSourceFile(code))
+
+ def mkTree(code: String): Tree = mkTrees(code).headOption getOrElse EmptyTree
+ def mkTrees(code: String): List[Tree] = parse(code) getOrElse Nil
+ def mkTypedTrees(code: String*): List[compiler.Tree] = {
+ class TyperRun extends compiler.Run {
+ override def stopPhase(name: String) = name == "superaccessors"
+ }
+
+ reporter.reset
+ val run = new TyperRun
+ run compileSources (code.toList.zipWithIndex map {
+ case (s, i) => new BatchSourceFile("<console %d>".format(i), s)
+ })
+ run.units.toList map (_.body)
+ }
+ def mkTypedTree(code: String) = mkTypedTrees(code).head
+ def mkType(id: String): compiler.Type = stringToCompilerType(id)
+
+ def dump(): String = (
+ ("Names used: " :: allUsedNames) ++
+ ("\nIdentifiers: " :: unqualifiedIds)
+ ) mkString " "
+
+ lazy val allPhases: List[Phase] = phaseNames map (currentRun phaseNamed _)
+ def atAllPhases[T](op: => T): List[(String, T)] = allPhases map (ph => (ph.name, atPhase(ph)(op)))
+ def showAtAllPhases(op: => Any): Unit =
+ atAllPhases(op.toString) foreach { case (ph, op) => Console.println("%15s -> %s".format(ph, op take 240)) }
+ }
+
+ def unleash(): Unit = beQuietDuring {
+ interpret("import scala.tools.nsc._")
+ repl.bind("repl", "spark.repl.SparkInterpreter", this)
+ interpret("val global: repl.compiler.type = repl.compiler")
+ interpret("val power: repl.power.type = repl.power")
+ // interpret("val replVars = repl.replVars")
+ }
+
+ /** Artificial object demonstrating completion */
+ // lazy val replVars = CompletionAware(
+ // Map[String, CompletionAware](
+ // "ids" -> CompletionAware(() => unqualifiedIds, completionAware _),
+ // "synthVars" -> CompletionAware(() => allBoundNames filter isSynthVarName map (_.toString)),
+ // "types" -> CompletionAware(() => allSeenTypes map (_.toString)),
+ // "implicits" -> CompletionAware(() => allImplicits map (_.toString))
+ // )
+ // )
+
+ /** Returns the name of the most recent interpreter result.
+ * Mostly this exists so you can conveniently invoke methods on
+ * the previous result.
+ */
+ def mostRecentVar: String =
+ if (mostRecentlyHandledTree.isEmpty) ""
+ else mostRecentlyHandledTree.get match {
+ case x: ValOrDefDef => x.name
+ case Assign(Ident(name), _) => name
+ case ModuleDef(_, name, _) => name
+ case _ => onull(varNameCreator.mostRecent)
+ }
+
+ private def requestForName(name: Name): Option[Request] =
+ prevRequests.reverse find (_.boundNames contains name)
+
+ private def requestForIdent(line: String): Option[Request] = requestForName(newTermName(line))
+
+ def stringToCompilerType(id: String): compiler.Type = {
+ // if it's a recognized identifier, the type of that; otherwise treat the
+ // String like a value (e.g. scala.collection.Map) .
+ def findType = typeForIdent(id) match {
+ case Some(x) => definitions.getClass(newTermName(x)).tpe
+ case _ => definitions.getModule(newTermName(id)).tpe
+ }
+
+ try findType catch { case _: MissingRequirementError => NoType }
+ }
+
+ def typeForIdent(id: String): Option[String] =
+ requestForIdent(id) flatMap (x => x.typeOf get newTermName(id))
+
+ def methodsOf(name: String) =
+ evalExpr[List[String]](methodsCode(name)) map (x => NameTransformer.decode(getOriginalName(x)))
+
+ def completionAware(name: String) = {
+ // XXX working around "object is not a value" crash, i.e.
+ // import java.util.ArrayList ; ArrayList.<tab>
+ clazzForIdent(name) flatMap (_ => evalExpr[Option[CompletionAware]](asCompletionAwareCode(name)))
+ }
+
+ def extractionValueForIdent(id: String): Option[AnyRef] =
+ requestForIdent(id) flatMap (_.extractionValue)
+
+ /** Executes code looking for a manifest of type T.
+ */
+ def manifestFor[T: Manifest] =
+ evalExpr[Manifest[T]]("""manifest[%s]""".format(manifest[T]))
+
+ /** Executes code looking for an implicit value of type T.
+ */
+ def implicitFor[T: Manifest] = {
+ val s = manifest[T].toString
+ evalExpr[Option[T]]("{ def f(implicit x: %s = null): %s = x ; Option(f) }".format(s, s))
+ // We don't use implicitly so as to fail without failing.
+ // evalExpr[T]("""implicitly[%s]""".format(manifest[T]))
+ }
+ /** Executes code looking for an implicit conversion from the type
+ * of the given identifier to CompletionAware.
+ */
+ def completionAwareImplicit[T](id: String) = {
+ val f1string = "%s => %s".format(typeForIdent(id).get, classOf[CompletionAware].getName)
+ val code = """{
+ | def f(implicit x: (%s) = null): %s = x
+ | val f1 = f
+ | if (f1 == null) None else Some(f1(%s))
+ |}""".stripMargin.format(f1string, f1string, id)
+
+ evalExpr[Option[CompletionAware]](code)
+ }
+
+ def clazzForIdent(id: String): Option[Class[_]] =
+ extractionValueForIdent(id) flatMap (x => Option(x) map (_.getClass))
+
+ private def methodsCode(name: String) =
+ "%s.%s(%s)".format(classOf[ReflectionCompletion].getName, "methodsOf", name)
+
+ private def asCompletionAwareCode(name: String) =
+ "%s.%s(%s)".format(classOf[CompletionAware].getName, "unapply", name)
+
+ private def getOriginalName(name: String): String =
+ nme.originalName(newTermName(name)).toString
+
+ case class InterpreterEvalException(msg: String) extends Exception(msg)
+ def evalError(msg: String) = throw InterpreterEvalException(msg)
+
+ /** The user-facing eval in :power mode wraps an Option.
+ */
+ def eval[T: Manifest](line: String): Option[T] =
+ try Some(evalExpr[T](line))
+ catch { case InterpreterEvalException(msg) => out println indentString(msg) ; None }
+
+ def evalExpr[T: Manifest](line: String): T = {
+ // Nothing means the type could not be inferred.
+ if (manifest[T] eq Manifest.Nothing)
+ evalError("Could not infer type: try 'eval[SomeType](%s)' instead".format(line))
+
+ val lhs = getSynthVarName
+ beQuietDuring { interpret("val " + lhs + " = { " + line + " } ") }
+
+ // TODO - can we meaningfully compare the inferred type T with
+ // the internal compiler Type assigned to lhs?
+ // def assignedType = prevRequests.last.typeOf(newTermName(lhs))
+
+ val req = requestFromLine(lhs, true) match {
+ case Left(result) => evalError(result.toString)
+ case Right(req) => req
+ }
+ if (req == null || !req.compile || req.handlers.size != 1)
+ evalError("Eval error.")
+
+ try req.extractionValue.get.asInstanceOf[T] catch {
+ case e: Exception => evalError(e.getMessage)
+ }
+ }
+
+ def interpretExpr[T: Manifest](code: String): Option[T] = beQuietDuring {
+ interpret(code) match {
+ case IR.Success =>
+ try prevRequests.last.extractionValue map (_.asInstanceOf[T])
+ catch { case e: Exception => out println e ; None }
+ case _ => None
+ }
+ }
+
+ /** Another entry point for tab-completion, ids in scope */
+ private def unqualifiedIdNames() = partialFlatMap(allHandlers) {
+ case x: AssignHandler => List(x.helperName)
+ case x: ValHandler => List(x.vname)
+ case x: ModuleHandler => List(x.name)
+ case x: DefHandler => List(x.name)
+ case x: ImportHandler => x.importedNames
+ } filterNot isSynthVarName
+
+ /** Types which have been wildcard imported, such as:
+ * val x = "abc" ; import x._ // type java.lang.String
+ * import java.lang.String._ // object java.lang.String
+ *
+ * Used by tab completion.
+ *
+ * XXX right now this gets import x._ and import java.lang.String._,
+ * but doesn't figure out import String._. There's a lot of ad hoc
+ * scope twiddling which should be swept away in favor of digging
+ * into the compiler scopes.
+ */
+ def wildcardImportedTypes(): List[Type] = {
+ val xs = allHandlers collect { case x: ImportHandler if x.importsWildcard => x.targetType }
+ xs.flatten.reverse.distinct
+ }
+
+ /** Another entry point for tab-completion, ids in scope */
+ def unqualifiedIds() = (unqualifiedIdNames() map (_.toString)).distinct.sorted
+
+ /** For static/object method completion */
+ def getClassObject(path: String): Option[Class[_]] = //classLoader tryToLoadClass path
+ try {
+ Some(Class.forName(path, true, classLoader))
+ } catch {
+ case e: Exception => None
+ }
+
+ /** Parse the ScalaSig to find type aliases */
+ def aliasForType(path: String) = ByteCode.aliasForType(path)
+
+ // Coming soon
+ // implicit def string2liftedcode(s: String): LiftedCode = new LiftedCode(s)
+ // case class LiftedCode(code: String) {
+ // val lifted: String = {
+ // beQuietDuring { interpret(code) }
+ // eval2[String]("({ " + code + " }).toString")
+ // }
+ // def >> : String = lifted
+ // }
+
+ // debugging
+ def isReplDebug = settings.Yrepldebug.value
+ def isCompletionDebug = settings.Ycompletion.value
+ def DBG(s: String) = if (isReplDebug) out println s else ()
+}
+
+/** Utility methods for the Interpreter. */
+object SparkInterpreter {
+
+ import scala.collection.generic.CanBuildFrom
+ def partialFlatMap[A, B, CC[X] <: Traversable[X]]
+ (coll: CC[A])
+ (pf: PartialFunction[A, CC[B]])
+ (implicit bf: CanBuildFrom[CC[A], B, CC[B]]) =
+ {
+ val b = bf(coll)
+ for (x <- coll collect pf)
+ b ++= x
+
+ b.result
+ }
+
+ object DebugParam {
+ implicit def tuple2debugparam[T](x: (String, T))(implicit m: Manifest[T]): DebugParam[T] =
+ DebugParam(x._1, x._2)
+
+ implicit def any2debugparam[T](x: T)(implicit m: Manifest[T]): DebugParam[T] =
+ DebugParam("p" + getCount(), x)
+
+ private var counter = 0
+ def getCount() = { counter += 1; counter }
+ }
+ case class DebugParam[T](name: String, param: T)(implicit m: Manifest[T]) {
+ val manifest = m
+ val typeStr = {
+ val str = manifest.toString
+ // I'm sure there are more to be discovered...
+ val regexp1 = """(.*?)\[(.*)\]""".r
+ val regexp2str = """.*\.type#"""
+ val regexp2 = (regexp2str + """(.*)""").r
+
+ (str.replaceAll("""\n""", "")) match {
+ case regexp1(clazz, typeArgs) => "%s[%s]".format(clazz, typeArgs.replaceAll(regexp2str, ""))
+ case regexp2(clazz) => clazz
+ case _ => str
+ }
+ }
+ }
+ def breakIf(assertion: => Boolean, args: DebugParam[_]*): Unit =
+ if (assertion) break(args.toList)
+
+ // start a repl, binding supplied args
+ def break(args: List[DebugParam[_]]): Unit = {
+ val intLoop = new SparkInterpreterLoop
+ intLoop.settings = new Settings(Console.println)
+ // XXX come back to the dot handling
+ intLoop.settings.classpath.value = "."
+ intLoop.createInterpreter
+ intLoop.in = SparkInteractiveReader.createDefault(intLoop.interpreter)
+
+ // rebind exit so people don't accidentally call System.exit by way of predef
+ intLoop.interpreter.beQuietDuring {
+ intLoop.interpreter.interpret("""def exit = println("Type :quit to resume program execution.")""")
+ for (p <- args) {
+ intLoop.interpreter.bind(p.name, p.typeStr, p.param)
+ Console println "%s: %s".format(p.name, p.typeStr)
+ }
+ }
+ intLoop.repl()
+ intLoop.closeInterpreter
+ }
+
+ def codegenln(leadingPlus: Boolean, xs: String*): String = codegen(leadingPlus, (xs ++ Array("\n")): _*)
+ def codegenln(xs: String*): String = codegenln(true, xs: _*)
+
+ def codegen(xs: String*): String = codegen(true, xs: _*)
+ def codegen(leadingPlus: Boolean, xs: String*): String = {
+ val front = if (leadingPlus) "+ " else ""
+ front + (xs map string2codeQuoted mkString " + ")
+ }
+
+ def string2codeQuoted(str: String) = "\"" + string2code(str) + "\""
+
+ /** Convert a string into code that can recreate the string.
+ * This requires replacing all special characters by escape
+ * codes. It does not add the surrounding " marks. */
+ def string2code(str: String): String = {
+ val res = new StringBuilder
+ for (c <- str) c match {
+ case '"' | '\'' | '\\' => res += '\\' ; res += c
+ case _ if c.isControl => res ++= Chars.char2uescape(c)
+ case _ => res += c
+ }
+ res.toString
+ }
+}
+
diff --git a/repl/src/main/scala/spark/repl/SparkInterpreterLoop.scala b/repl/src/main/scala/spark/repl/SparkInterpreterLoop.scala
new file mode 100644
index 0000000000..a118abf3ca
--- /dev/null
+++ b/repl/src/main/scala/spark/repl/SparkInterpreterLoop.scala
@@ -0,0 +1,662 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2010 LAMP/EPFL
+ * @author Alexander Spoon
+ */
+
+package spark.repl
+
+import scala.tools.nsc
+import scala.tools.nsc._
+
+import Predef.{ println => _, _ }
+import java.io.{ BufferedReader, FileReader, PrintWriter }
+import java.io.IOException
+
+import scala.tools.nsc.{ InterpreterResults => IR }
+import scala.annotation.tailrec
+import scala.collection.mutable.ListBuffer
+import scala.concurrent.ops
+import util.{ ClassPath }
+import interpreter._
+import io.{ File, Process }
+
+import spark.SparkContext
+
+// Classes to wrap up interpreter commands and their results
+// You can add new commands by adding entries to val commands
+// inside InterpreterLoop.
+trait InterpreterControl {
+ self: SparkInterpreterLoop =>
+
+ // the default result means "keep running, and don't record that line"
+ val defaultResult = Result(true, None)
+
+ // a single interpreter command
+ sealed abstract class Command extends Function1[List[String], Result] {
+ def name: String
+ def help: String
+ def error(msg: String) = {
+ out.println(":" + name + " " + msg + ".")
+ Result(true, None)
+ }
+ def usage(): String
+ }
+
+ case class NoArgs(name: String, help: String, f: () => Result) extends Command {
+ def usage(): String = ":" + name
+ def apply(args: List[String]) = if (args.isEmpty) f() else error("accepts no arguments")
+ }
+
+ case class LineArg(name: String, help: String, f: (String) => Result) extends Command {
+ def usage(): String = ":" + name + " <line>"
+ def apply(args: List[String]) = f(args mkString " ")
+ }
+
+ case class OneArg(name: String, help: String, f: (String) => Result) extends Command {
+ def usage(): String = ":" + name + " <arg>"
+ def apply(args: List[String]) =
+ if (args.size == 1) f(args.head)
+ else error("requires exactly one argument")
+ }
+
+ case class VarArgs(name: String, help: String, f: (List[String]) => Result) extends Command {
+ def usage(): String = ":" + name + " [arg]"
+ def apply(args: List[String]) = f(args)
+ }
+
+ // the result of a single command
+ case class Result(keepRunning: Boolean, lineToRecord: Option[String])
+}
+
+/** The
+ * <a href="http://scala-lang.org/" target="_top">Scala</a>
+ * interactive shell. It provides a read-eval-print loop around
+ * the Interpreter class.
+ * After instantiation, clients should call the <code>main()</code> method.
+ *
+ * <p>If no in0 is specified, then input will come from the console, and
+ * the class will attempt to provide input editing feature such as
+ * input history.
+ *
+ * @author Moez A. Abdel-Gawad
+ * @author Lex Spoon
+ * @version 1.2
+ */
+class SparkInterpreterLoop(
+ in0: Option[BufferedReader], val out: PrintWriter, master: Option[String])
+extends InterpreterControl {
+ def this(in0: BufferedReader, out: PrintWriter, master: String) =
+ this(Some(in0), out, Some(master))
+
+ def this(in0: BufferedReader, out: PrintWriter) =
+ this(Some(in0), out, None)
+
+ def this() = this(None, new PrintWriter(Console.out), None)
+
+ /** The input stream from which commands come, set by main() */
+ var in: SparkInteractiveReader = _
+
+ /** The context class loader at the time this object was created */
+ protected val originalClassLoader = Thread.currentThread.getContextClassLoader
+
+ var settings: Settings = _ // set by main()
+ var interpreter: SparkInterpreter = _ // set by createInterpreter()
+
+ // classpath entries added via :cp
+ var addedClasspath: String = ""
+
+ /** A reverse list of commands to replay if the user requests a :replay */
+ var replayCommandStack: List[String] = Nil
+
+ /** A list of commands to replay if the user requests a :replay */
+ def replayCommands = replayCommandStack.reverse
+
+ /** Record a command for replay should the user request a :replay */
+ def addReplay(cmd: String) = replayCommandStack ::= cmd
+
+ /** Close the interpreter and set the var to <code>null</code>. */
+ def closeInterpreter() {
+ if (interpreter ne null) {
+ interpreter.close
+ interpreter = null
+ Thread.currentThread.setContextClassLoader(originalClassLoader)
+ }
+ }
+
+ /** Create a new interpreter. */
+ def createInterpreter() {
+ if (addedClasspath != "")
+ settings.classpath append addedClasspath
+
+ interpreter = new SparkInterpreter(settings, out) {
+ override protected def parentClassLoader =
+ classOf[SparkInterpreterLoop].getClassLoader
+ }
+ interpreter.setContextClassLoader()
+ // interpreter.quietBind("settings", "spark.repl.SparkInterpreterSettings", interpreter.isettings)
+ }
+
+ /** print a friendly help message */
+ def printHelp() = {
+ out println "All commands can be abbreviated - for example :he instead of :help.\n"
+ val cmds = commands map (x => (x.usage, x.help))
+ val width: Int = cmds map { case (x, _) => x.length } max
+ val formatStr = "%-" + width + "s %s"
+ cmds foreach { case (usage, help) => out println formatStr.format(usage, help) }
+ }
+
+ /** Print a welcome message */
+ def printWelcome() {
+ plushln("""Welcome to
+ ____ __
+ / __/__ ___ _____/ /__
+ _\ \/ _ \/ _ `/ __/ '_/
+ /___/ .__/\_,_/_/ /_/\_\ version 0.0
+ /_/
+""")
+
+ import Properties._
+ val welcomeMsg = "Using Scala %s (%s, Java %s)".format(
+ versionString, javaVmName, javaVersion)
+ plushln(welcomeMsg)
+ }
+
+ /** Show the history */
+ def printHistory(xs: List[String]) {
+ val defaultLines = 20
+
+ if (in.history.isEmpty)
+ return println("No history available.")
+
+ val current = in.history.get.index
+ val count = try xs.head.toInt catch { case _: Exception => defaultLines }
+ val lines = in.historyList takeRight count
+ val offset = current - lines.size + 1
+
+ for ((line, index) <- lines.zipWithIndex)
+ println("%d %s".format(index + offset, line))
+ }
+
+ /** Some print conveniences */
+ def println(x: Any) = out println x
+ def plush(x: Any) = { out print x ; out.flush() }
+ def plushln(x: Any) = { out println x ; out.flush() }
+
+ /** Search the history */
+ def searchHistory(_cmdline: String) {
+ val cmdline = _cmdline.toLowerCase
+
+ if (in.history.isEmpty)
+ return println("No history available.")
+
+ val current = in.history.get.index
+ val offset = current - in.historyList.size + 1
+
+ for ((line, index) <- in.historyList.zipWithIndex ; if line.toLowerCase contains cmdline)
+ println("%d %s".format(index + offset, line))
+ }
+
+ /** Prompt to print when awaiting input */
+ val prompt = Properties.shellPromptString
+
+ // most commands do not want to micromanage the Result, but they might want
+ // to print something to the console, so we accomodate Unit and String returns.
+ object CommandImplicits {
+ implicit def u2ir(x: Unit): Result = defaultResult
+ implicit def s2ir(s: String): Result = {
+ out println s
+ defaultResult
+ }
+ }
+
+ /** Standard commands **/
+ val standardCommands: List[Command] = {
+ import CommandImplicits._
+ List(
+ OneArg("cp", "add an entry (jar or directory) to the classpath", addClasspath),
+ NoArgs("help", "print this help message", printHelp),
+ VarArgs("history", "show the history (optional arg: lines to show)", printHistory),
+ LineArg("h?", "search the history", searchHistory),
+ OneArg("load", "load and interpret a Scala file", load),
+ NoArgs("power", "enable power user mode", power),
+ NoArgs("quit", "exit the interpreter", () => Result(false, None)),
+ NoArgs("replay", "reset execution and replay all previous commands", replay),
+ LineArg("sh", "fork a shell and run a command", runShellCmd),
+ NoArgs("silent", "disable/enable automatic printing of results", verbosity)
+ )
+ }
+
+ /** Power user commands */
+ var powerUserOn = false
+ val powerCommands: List[Command] = {
+ import CommandImplicits._
+ List(
+ OneArg("completions", "generate list of completions for a given String", completions),
+ NoArgs("dump", "displays a view of the interpreter's internal state", () => interpreter.power.dump())
+
+ // VarArgs("tree", "displays ASTs for specified identifiers",
+ // (xs: List[String]) => interpreter dumpTrees xs)
+ // LineArg("meta", "given code which produces scala code, executes the results",
+ // (xs: List[String]) => )
+ )
+ }
+
+ /** Available commands */
+ def commands: List[Command] = standardCommands ::: (if (powerUserOn) powerCommands else Nil)
+
+ def initializeSpark() {
+ interpreter.beQuietDuring {
+ command("""
+ spark.repl.Main.interp.out.println("Registering with Mesos...");
+ spark.repl.Main.interp.out.flush();
+ @transient val sc = spark.repl.Main.interp.createSparkContext();
+ sc.waitForRegister();
+ spark.repl.Main.interp.out.println("Spark context available as sc.");
+ spark.repl.Main.interp.out.flush();
+ """)
+ command("import spark.SparkContext._");
+ }
+ plushln("Type in expressions to have them evaluated.")
+ plushln("Type :help for more information.")
+ }
+
+ var sparkContext: SparkContext = null
+
+ def createSparkContext(): SparkContext = {
+ val master = this.master match {
+ case Some(m) => m
+ case None => {
+ val prop = System.getenv("MASTER")
+ if (prop != null) prop else "local"
+ }
+ }
+ sparkContext = new SparkContext(master, "Spark shell")
+ sparkContext
+ }
+
+ /** The main read-eval-print loop for the interpreter. It calls
+ * <code>command()</code> for each line of input, and stops when
+ * <code>command()</code> returns <code>false</code>.
+ */
+ def repl() {
+ def readOneLine() = {
+ out.flush
+ in readLine prompt
+ }
+ // return false if repl should exit
+ def processLine(line: String): Boolean =
+ if (line eq null) false // assume null means EOF
+ else command(line) match {
+ case Result(false, _) => false
+ case Result(_, Some(finalLine)) => addReplay(finalLine) ; true
+ case _ => true
+ }
+
+ while (processLine(readOneLine)) { }
+ }
+
+ /** interpret all lines from a specified file */
+ def interpretAllFrom(file: File) {
+ val oldIn = in
+ val oldReplay = replayCommandStack
+
+ try file applyReader { reader =>
+ in = new SparkSimpleReader(reader, out, false)
+ plushln("Loading " + file + "...")
+ repl()
+ }
+ finally {
+ in = oldIn
+ replayCommandStack = oldReplay
+ }
+ }
+
+ /** create a new interpreter and replay all commands so far */
+ def replay() {
+ closeInterpreter()
+ createInterpreter()
+ for (cmd <- replayCommands) {
+ plushln("Replaying: " + cmd) // flush because maybe cmd will have its own output
+ command(cmd)
+ out.println
+ }
+ }
+
+ /** fork a shell and run a command */
+ def runShellCmd(line: String) {
+ // we assume if they're using :sh they'd appreciate being able to pipeline
+ interpreter.beQuietDuring {
+ interpreter.interpret("import _root_.scala.tools.nsc.io.Process.Pipe._")
+ }
+ val p = Process(line)
+ // only bind non-empty streams
+ def add(name: String, it: Iterator[String]) =
+ if (it.hasNext) interpreter.bind(name, "scala.List[String]", it.toList)
+
+ List(("stdout", p.stdout), ("stderr", p.stderr)) foreach (add _).tupled
+ }
+
+ def withFile(filename: String)(action: File => Unit) {
+ val f = File(filename)
+
+ if (f.exists) action(f)
+ else out.println("That file does not exist")
+ }
+
+ def load(arg: String) = {
+ var shouldReplay: Option[String] = None
+ withFile(arg)(f => {
+ interpretAllFrom(f)
+ shouldReplay = Some(":load " + arg)
+ })
+ Result(true, shouldReplay)
+ }
+
+ def addClasspath(arg: String): Unit = {
+ val f = File(arg).normalize
+ if (f.exists) {
+ addedClasspath = ClassPath.join(addedClasspath, f.path)
+ val totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath)
+ println("Added '%s'. Your new classpath is:\n%s".format(f.path, totalClasspath))
+ replay()
+ }
+ else out.println("The path '" + f + "' doesn't seem to exist.")
+ }
+
+ def completions(arg: String): Unit = {
+ val comp = in.completion getOrElse { return println("Completion unavailable.") }
+ val xs = comp completions arg
+
+ injectAndName(xs)
+ }
+
+ def power() {
+ val powerUserBanner =
+ """** Power User mode enabled - BEEP BOOP **
+ |** scala.tools.nsc._ has been imported **
+ |** New vals! Try repl, global, power **
+ |** New cmds! :help to discover them **
+ |** New defs! Type power.<tab> to reveal **""".stripMargin
+
+ powerUserOn = true
+ interpreter.unleash()
+ injectOne("history", in.historyList)
+ in.completion foreach (x => injectOne("completion", x))
+ out println powerUserBanner
+ }
+
+ def verbosity() = {
+ val old = interpreter.printResults
+ interpreter.printResults = !old
+ out.println("Switched " + (if (old) "off" else "on") + " result printing.")
+ }
+
+ /** Run one command submitted by the user. Two values are returned:
+ * (1) whether to keep running, (2) the line to record for replay,
+ * if any. */
+ def command(line: String): Result = {
+ def withError(msg: String) = {
+ out println msg
+ Result(true, None)
+ }
+ def ambiguous(cmds: List[Command]) = "Ambiguous: did you mean " + cmds.map(":" + _.name).mkString(" or ") + "?"
+
+ // not a command
+ if (!line.startsWith(":")) {
+ // Notice failure to create compiler
+ if (interpreter.compiler == null) return Result(false, None)
+ else return Result(true, interpretStartingWith(line))
+ }
+
+ val tokens = (line drop 1 split """\s+""").toList
+ if (tokens.isEmpty)
+ return withError(ambiguous(commands))
+
+ val (cmd :: args) = tokens
+
+ // this lets us add commands willy-nilly and only requires enough command to disambiguate
+ commands.filter(_.name startsWith cmd) match {
+ case List(x) => x(args)
+ case Nil => withError("Unknown command. Type :help for help.")
+ case xs => withError(ambiguous(xs))
+ }
+ }
+
+ private val CONTINUATION_STRING = " | "
+ private val PROMPT_STRING = "scala> "
+
+ /** If it looks like they're pasting in a scala interpreter
+ * transcript, remove all the formatting we inserted so we
+ * can make some sense of it.
+ */
+ private var pasteStamp: Long = 0
+
+ /** Returns true if it's long enough to quit. */
+ def updatePasteStamp(): Boolean = {
+ /* Enough milliseconds between readLines to call it a day. */
+ val PASTE_FINISH = 1000
+
+ val prevStamp = pasteStamp
+ pasteStamp = System.currentTimeMillis
+
+ (pasteStamp - prevStamp > PASTE_FINISH)
+
+ }
+ /** TODO - we could look for the usage of resXX variables in the transcript.
+ * Right now backreferences to auto-named variables will break.
+ */
+
+ /** The trailing lines complication was an attempt to work around the introduction
+ * of newlines in e.g. email messages of repl sessions. It doesn't work because
+ * an unlucky newline can always leave you with a syntactically valid first line,
+ * which is executed before the next line is considered. So this doesn't actually
+ * accomplish anything, but I'm leaving it in case I decide to try harder.
+ */
+ case class PasteCommand(cmd: String, trailing: ListBuffer[String] = ListBuffer[String]())
+
+ /** Commands start on lines beginning with "scala>" and each successive
+ * line which begins with the continuation string is appended to that command.
+ * Everything else is discarded. When the end of the transcript is spotted,
+ * all the commands are replayed.
+ */
+ @tailrec private def cleanTranscript(lines: List[String], acc: List[PasteCommand]): List[PasteCommand] = lines match {
+ case Nil => acc.reverse
+ case x :: xs if x startsWith PROMPT_STRING =>
+ val first = x stripPrefix PROMPT_STRING
+ val (xs1, xs2) = xs span (_ startsWith CONTINUATION_STRING)
+ val rest = xs1 map (_ stripPrefix CONTINUATION_STRING)
+ val result = (first :: rest).mkString("", "\n", "\n")
+
+ cleanTranscript(xs2, PasteCommand(result) :: acc)
+
+ case ln :: lns =>
+ val newacc = acc match {
+ case Nil => Nil
+ case PasteCommand(cmd, trailing) :: accrest =>
+ PasteCommand(cmd, trailing :+ ln) :: accrest
+ }
+ cleanTranscript(lns, newacc)
+ }
+
+ /** The timestamp is for safety so it doesn't hang looking for the end
+ * of a transcript. Ad hoc parsing can't be too demanding. You can
+ * also use ctrl-D to start it parsing.
+ */
+ @tailrec private def interpretAsPastedTranscript(lines: List[String]) {
+ val line = in.readLine("")
+ val finished = updatePasteStamp()
+
+ if (line == null || finished || line.trim == PROMPT_STRING.trim) {
+ val xs = cleanTranscript(lines.reverse, Nil)
+ println("Replaying %d commands from interpreter transcript." format xs.size)
+ for (PasteCommand(cmd, trailing) <- xs) {
+ out.flush()
+ def runCode(code: String, extraLines: List[String]) {
+ (interpreter interpret code) match {
+ case IR.Incomplete if extraLines.nonEmpty =>
+ runCode(code + "\n" + extraLines.head, extraLines.tail)
+ case _ => ()
+ }
+ }
+ runCode(cmd, trailing.toList)
+ }
+ }
+ else
+ interpretAsPastedTranscript(line :: lines)
+ }
+
+ /** Interpret expressions starting with the first line.
+ * Read lines until a complete compilation unit is available
+ * or until a syntax error has been seen. If a full unit is
+ * read, go ahead and interpret it. Return the full string
+ * to be recorded for replay, if any.
+ */
+ def interpretStartingWith(code: String): Option[String] = {
+ // signal completion non-completion input has been received
+ in.completion foreach (_.resetVerbosity())
+
+ def reallyInterpret = interpreter.interpret(code) match {
+ case IR.Error => None
+ case IR.Success => Some(code)
+ case IR.Incomplete =>
+ if (in.interactive && code.endsWith("\n\n")) {
+ out.println("You typed two blank lines. Starting a new command.")
+ None
+ }
+ else in.readLine(CONTINUATION_STRING) match {
+ case null =>
+ // we know compilation is going to fail since we're at EOF and the
+ // parser thinks the input is still incomplete, but since this is
+ // a file being read non-interactively we want to fail. So we send
+ // it straight to the compiler for the nice error message.
+ interpreter.compileString(code)
+ None
+
+ case line => interpretStartingWith(code + "\n" + line)
+ }
+ }
+
+ /** Here we place ourselves between the user and the interpreter and examine
+ * the input they are ostensibly submitting. We intervene in several cases:
+ *
+ * 1) If the line starts with "scala> " it is assumed to be an interpreter paste.
+ * 2) If the line starts with "." (but not ".." or "./") it is treated as an invocation
+ * on the previous result.
+ * 3) If the Completion object's execute returns Some(_), we inject that value
+ * and avoid the interpreter, as it's likely not valid scala code.
+ */
+ if (code == "") None
+ else if (code startsWith PROMPT_STRING) {
+ updatePasteStamp()
+ interpretAsPastedTranscript(List(code))
+ None
+ }
+ else if (Completion.looksLikeInvocation(code) && interpreter.mostRecentVar != "") {
+ interpretStartingWith(interpreter.mostRecentVar + code)
+ }
+ else {
+ val result = for (comp <- in.completion ; res <- comp execute code) yield res
+ result match {
+ case Some(res) => injectAndName(res) ; None // completion took responsibility, so do not parse
+ case _ => reallyInterpret
+ }
+ }
+ }
+
+ // runs :load <file> on any files passed via -i
+ def loadFiles(settings: Settings) = settings match {
+ case settings: GenericRunnerSettings =>
+ for (filename <- settings.loadfiles.value) {
+ val cmd = ":load " + filename
+ command(cmd)
+ addReplay(cmd)
+ out.println()
+ }
+ case _ =>
+ }
+
+ def main(settings: Settings) {
+ this.settings = settings
+ createInterpreter()
+
+ // sets in to some kind of reader depending on environmental cues
+ in = in0 match {
+ case Some(in0) => new SparkSimpleReader(in0, out, true)
+ case None =>
+ // the interpreter is passed as an argument to expose tab completion info
+ if (settings.Xnojline.value || Properties.isEmacsShell) new SparkSimpleReader
+ else if (settings.noCompletion.value) SparkInteractiveReader.createDefault()
+ else SparkInteractiveReader.createDefault(interpreter)
+ }
+
+ loadFiles(settings)
+ try {
+ // it is broken on startup; go ahead and exit
+ if (interpreter.reporter.hasErrors) return
+
+ printWelcome()
+
+ // this is about the illusion of snappiness. We call initialize()
+ // which spins off a separate thread, then print the prompt and try
+ // our best to look ready. Ideally the user will spend a
+ // couple seconds saying "wow, it starts so fast!" and by the time
+ // they type a command the compiler is ready to roll.
+ interpreter.initialize()
+ initializeSpark()
+ repl()
+ }
+ finally closeInterpreter()
+ }
+
+ private def objClass(x: Any) = x.asInstanceOf[AnyRef].getClass
+ private def objName(x: Any) = {
+ val clazz = objClass(x)
+ val typeParams = clazz.getTypeParameters
+ val basename = clazz.getName
+ val tpString = if (typeParams.isEmpty) "" else "[%s]".format(typeParams map (_ => "_") mkString ", ")
+
+ basename + tpString
+ }
+
+ // injects one value into the repl; returns pair of name and class
+ def injectOne(name: String, obj: Any): Tuple2[String, String] = {
+ val className = objName(obj)
+ interpreter.quietBind(name, className, obj)
+ (name, className)
+ }
+ def injectAndName(obj: Any): Tuple2[String, String] = {
+ val name = interpreter.getVarName
+ val className = objName(obj)
+ interpreter.bind(name, className, obj)
+ (name, className)
+ }
+
+ // injects list of values into the repl; returns summary string
+ def injectDebug(args: List[Any]): String = {
+ val strs =
+ for ((arg, i) <- args.zipWithIndex) yield {
+ val varName = "p" + (i + 1)
+ val (vname, vtype) = injectOne(varName, arg)
+ vname + ": " + vtype
+ }
+
+ if (strs.size == 0) "Set no variables."
+ else "Variables set:\n" + strs.mkString("\n")
+ }
+
+ /** process command-line arguments and do as they request */
+ def main(args: Array[String]) {
+ def error1(msg: String) = out println ("scala: " + msg)
+ val command = new InterpreterCommand(args.toList, error1)
+ def neededHelp(): String =
+ (if (command.settings.help.value) command.usageMsg + "\n" else "") +
+ (if (command.settings.Xhelp.value) command.xusageMsg + "\n" else "")
+
+ // if they asked for no help and command is valid, we call the real main
+ neededHelp() match {
+ case "" => if (command.ok) main(command.settings) // else nothing
+ case help => plush(help)
+ }
+ }
+}
+
diff --git a/repl/src/main/scala/spark/repl/SparkInterpreterSettings.scala b/repl/src/main/scala/spark/repl/SparkInterpreterSettings.scala
new file mode 100644
index 0000000000..ffa477785b
--- /dev/null
+++ b/repl/src/main/scala/spark/repl/SparkInterpreterSettings.scala
@@ -0,0 +1,112 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2010 LAMP/EPFL
+ * @author Alexander Spoon
+ */
+
+package spark.repl
+
+import scala.tools.nsc
+import scala.tools.nsc._
+
+/** Settings for the interpreter
+ *
+ * @version 1.0
+ * @author Lex Spoon, 2007/3/24
+ **/
+class SparkInterpreterSettings(repl: SparkInterpreter) {
+ /** A list of paths where :load should look */
+ var loadPath = List(".")
+
+ /** The maximum length of toString to use when printing the result
+ * of an evaluation. 0 means no maximum. If a printout requires
+ * more than this number of characters, then the printout is
+ * truncated.
+ */
+ var maxPrintString = 800
+
+ /** The maximum number of completion candidates to print for tab
+ * completion without requiring confirmation.
+ */
+ var maxAutoprintCompletion = 250
+
+ /** String unwrapping can be disabled if it is causing issues.
+ * Settings this to false means you will see Strings like "$iw.$iw.".
+ */
+ var unwrapStrings = true
+
+ def deprecation_=(x: Boolean) = {
+ val old = repl.settings.deprecation.value
+ repl.settings.deprecation.value = x
+ if (!old && x) println("Enabled -deprecation output.")
+ else if (old && !x) println("Disabled -deprecation output.")
+ }
+ def deprecation: Boolean = repl.settings.deprecation.value
+
+ def allSettings = Map(
+ "maxPrintString" -> maxPrintString,
+ "maxAutoprintCompletion" -> maxAutoprintCompletion,
+ "unwrapStrings" -> unwrapStrings,
+ "deprecation" -> deprecation
+ )
+
+ private def allSettingsString =
+ allSettings.toList sortBy (_._1) map { case (k, v) => " " + k + " = " + v + "\n" } mkString
+
+ override def toString = """
+ | SparkInterpreterSettings {
+ | %s
+ | }""".stripMargin.format(allSettingsString)
+}
+
+/* Utilities for the InterpreterSettings class
+ *
+ * @version 1.0
+ * @author Lex Spoon, 2007/5/24
+ */
+object SparkInterpreterSettings {
+ /** Source code for the InterpreterSettings class. This is
+ * used so that the interpreter is sure to have the code
+ * available.
+ *
+ * XXX I'm not seeing why this degree of defensiveness is necessary.
+ * If files are missing the repl's not going to work, it's not as if
+ * we have string source backups for anything else.
+ */
+ val sourceCodeForClass =
+"""
+package scala.tools.nsc
+
+/** Settings for the interpreter
+ *
+ * @version 1.0
+ * @author Lex Spoon, 2007/3/24
+ **/
+class SparkInterpreterSettings(repl: Interpreter) {
+ /** A list of paths where :load should look */
+ var loadPath = List(".")
+
+ /** The maximum length of toString to use when printing the result
+ * of an evaluation. 0 means no maximum. If a printout requires
+ * more than this number of characters, then the printout is
+ * truncated.
+ */
+ var maxPrintString = 2400
+
+ def deprecation_=(x: Boolean) = {
+ val old = repl.settings.deprecation.value
+ repl.settings.deprecation.value = x
+ if (!old && x) println("Enabled -deprecation output.")
+ else if (old && !x) println("Disabled -deprecation output.")
+ }
+ def deprecation: Boolean = repl.settings.deprecation.value
+
+ override def toString =
+ "SparkInterpreterSettings {\n" +
+// " loadPath = " + loadPath + "\n" +
+ " maxPrintString = " + maxPrintString + "\n" +
+ "}"
+}
+
+"""
+
+}
diff --git a/repl/src/main/scala/spark/repl/SparkJLineReader.scala b/repl/src/main/scala/spark/repl/SparkJLineReader.scala
new file mode 100644
index 0000000000..9d761c06fc
--- /dev/null
+++ b/repl/src/main/scala/spark/repl/SparkJLineReader.scala
@@ -0,0 +1,38 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2010 LAMP/EPFL
+ * @author Stepan Koltsov
+ */
+
+package spark.repl
+
+import scala.tools.nsc
+import scala.tools.nsc._
+import scala.tools.nsc.interpreter
+import scala.tools.nsc.interpreter._
+
+import java.io.File
+import jline.{ ConsoleReader, ArgumentCompletor, History => JHistory }
+
+/** Reads from the console using JLine */
+class SparkJLineReader(interpreter: SparkInterpreter) extends SparkInteractiveReader {
+ def this() = this(null)
+
+ override lazy val history = Some(History(consoleReader))
+ override lazy val completion = Option(interpreter) map (x => new SparkCompletion(x))
+
+ val consoleReader = {
+ val r = new jline.ConsoleReader()
+ r setHistory (History().jhistory)
+ r setBellEnabled false
+ completion foreach { c =>
+ r addCompletor c.jline
+ r setAutoprintThreshhold 250
+ }
+
+ r
+ }
+
+ def readOneLine(prompt: String) = consoleReader readLine prompt
+ val interactive = true
+}
+
diff --git a/repl/src/main/scala/spark/repl/SparkSimpleReader.scala b/repl/src/main/scala/spark/repl/SparkSimpleReader.scala
new file mode 100644
index 0000000000..2b24c4bf63
--- /dev/null
+++ b/repl/src/main/scala/spark/repl/SparkSimpleReader.scala
@@ -0,0 +1,33 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2010 LAMP/EPFL
+ * @author Stepan Koltsov
+ */
+
+package spark.repl
+
+import scala.tools.nsc
+import scala.tools.nsc._
+import scala.tools.nsc.interpreter
+import scala.tools.nsc.interpreter._
+
+import java.io.{ BufferedReader, PrintWriter }
+import io.{ Path, File, Directory }
+
+/** Reads using standard JDK API */
+class SparkSimpleReader(
+ in: BufferedReader,
+ out: PrintWriter,
+ val interactive: Boolean)
+extends SparkInteractiveReader {
+ def this() = this(Console.in, new PrintWriter(Console.out), true)
+ def this(in: File, out: PrintWriter, interactive: Boolean) = this(in.bufferedReader(), out, interactive)
+
+ def close() = in.close()
+ def readOneLine(prompt: String): String = {
+ if (interactive) {
+ out.print(prompt)
+ out.flush()
+ }
+ in.readLine()
+ }
+}
diff --git a/repl/src/test/scala/spark/repl/ReplSuite.scala b/repl/src/test/scala/spark/repl/ReplSuite.scala
new file mode 100644
index 0000000000..829b1d934e
--- /dev/null
+++ b/repl/src/test/scala/spark/repl/ReplSuite.scala
@@ -0,0 +1,144 @@
+package spark.repl
+
+import java.io._
+import java.net.URLClassLoader
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.JavaConversions._
+
+import org.scalatest.FunSuite
+
+class ReplSuite extends FunSuite {
+ def runInterpreter(master: String, input: String): String = {
+ val in = new BufferedReader(new StringReader(input + "\n"))
+ val out = new StringWriter()
+ val cl = getClass.getClassLoader
+ var paths = new ArrayBuffer[String]
+ if (cl.isInstanceOf[URLClassLoader]) {
+ val urlLoader = cl.asInstanceOf[URLClassLoader]
+ for (url <- urlLoader.getURLs) {
+ if (url.getProtocol == "file") {
+ paths += url.getFile
+ }
+ }
+ }
+ val interp = new SparkInterpreterLoop(in, new PrintWriter(out), master)
+ spark.repl.Main.interp = interp
+ val separator = System.getProperty("path.separator")
+ interp.main(Array("-classpath", paths.mkString(separator)))
+ spark.repl.Main.interp = null
+ if (interp.sparkContext != null)
+ interp.sparkContext.stop()
+ return out.toString
+ }
+
+ def assertContains(message: String, output: String) {
+ assert(output contains message,
+ "Interpreter output did not contain '" + message + "':\n" + output)
+ }
+
+ def assertDoesNotContain(message: String, output: String) {
+ assert(!(output contains message),
+ "Interpreter output contained '" + message + "':\n" + output)
+ }
+
+ test ("simple foreach with accumulator") {
+ val output = runInterpreter("local", """
+ val accum = sc.accumulator(0)
+ sc.parallelize(1 to 10).foreach(x => accum += x)
+ accum.value
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res1: Int = 55", output)
+ }
+
+ test ("external vars") {
+ val output = runInterpreter("local", """
+ var v = 7
+ sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_)
+ v = 10
+ sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_)
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Int = 70", output)
+ assertContains("res1: Int = 100", output)
+ }
+
+ test ("external classes") {
+ val output = runInterpreter("local", """
+ class C {
+ def foo = 5
+ }
+ sc.parallelize(1 to 10).map(x => (new C).foo).collect.reduceLeft(_+_)
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Int = 50", output)
+ }
+
+ test ("external functions") {
+ val output = runInterpreter("local", """
+ def double(x: Int) = x + x
+ sc.parallelize(1 to 10).map(x => double(x)).collect.reduceLeft(_+_)
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Int = 110", output)
+ }
+
+ test ("external functions that access vars") {
+ val output = runInterpreter("local", """
+ var v = 7
+ def getV() = v
+ sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+ v = 10
+ sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Int = 70", output)
+ assertContains("res1: Int = 100", output)
+ }
+
+ test ("broadcast vars") {
+ // Test that the value that a broadcast var had when it was created is used,
+ // even if that variable is then modified in the driver program
+ // TODO: This doesn't actually work for arrays when we run in local mode!
+ val output = runInterpreter("local", """
+ var array = new Array[Int](5)
+ val broadcastArray = sc.broadcast(array)
+ sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+ array(0) = 5
+ sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+ assertContains("res2: Array[Int] = Array(5, 0, 0, 0, 0)", output)
+ }
+
+ if (System.getenv("MESOS_HOME") != null) {
+ test ("running on Mesos") {
+ val output = runInterpreter("localquiet", """
+ var v = 7
+ def getV() = v
+ sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+ v = 10
+ sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+ var array = new Array[Int](5)
+ val broadcastArray = sc.broadcast(array)
+ sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+ array(0) = 5
+ sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Int = 70", output)
+ assertContains("res1: Int = 100", output)
+ assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+ assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+ }
+ }
+}