diff options
Diffstat (limited to 'repl')
23 files changed, 3692 insertions, 5 deletions
diff --git a/repl/pom.xml b/repl/pom.xml index af528c8914..bd688c8c1e 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -39,6 +39,11 @@ <dependencies> <dependency> + <groupId>${jline.groupid}</groupId> + <artifactId>jline</artifactId> + <version>${jline.version}</version> + </dependency> + <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-core_${scala.binary.version}</artifactId> <version>${project.version}</version> @@ -76,11 +81,6 @@ <version>${scala.version}</version> </dependency> <dependency> - <groupId>org.scala-lang</groupId> - <artifactId>jline</artifactId> - <version>${scala.version}</version> - </dependency> - <dependency> <groupId>org.slf4j</groupId> <artifactId>jul-to-slf4j</artifactId> </dependency> @@ -124,4 +124,84 @@ </plugin> </plugins> </build> + <profiles> + <profile> + <id>scala-2.10</id> + <build> + <plugins> + <plugin> + <groupId>org.codehaus.mojo</groupId> + <artifactId>build-helper-maven-plugin</artifactId> + <executions> + <execution> + <id>add-scala-sources</id> + <phase>generate-sources</phase> + <goals> + <goal>add-source</goal> + </goals> + <configuration> + <sources> + <source>src/main/scala</source> + <source>scala-2.10/src/main/scala</source> + </sources> + </configuration> + </execution> + <execution> + <id>add-scala-test-sources</id> + <phase>generate-test-sources</phase> + <goals> + <goal>add-test-source</goal> + </goals> + <configuration> + <sources> + <source>src/test/scala</source> + <source>scala-2.10/src/test/scala</source> + </sources> + </configuration> + </execution> + </executions> + </plugin> + </plugins> + </build> + </profile> + <profile> + <id>scala-2.11</id> + <build> + <plugins> + <plugin> + <groupId>org.codehaus.mojo</groupId> + <artifactId>build-helper-maven-plugin</artifactId> + <executions> + <execution> + <id>add-scala-sources</id> + <phase>generate-sources</phase> + <goals> + <goal>add-source</goal> + </goals> + <configuration> + <sources> + <source>src/main/scala</source> + <source>scala-2.11/src/main/scala</source> + </sources> + </configuration> + </execution> + <execution> + <id>add-scala-test-sources</id> + <phase>generate-test-sources</phase> + <goals> + <goal>add-test-source</goal> + </goals> + <configuration> + <sources> + <source>src/test/scala</source> + <source>scala-2.11/src/test/scala</source> + </sources> + </configuration> + </execution> + </executions> + </plugin> + </plugins> + </build> + </profile> + </profiles> </project> diff --git a/repl/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala index 14b448d076..14b448d076 100644 --- a/repl/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala index 05816941b5..05816941b5 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala index f8432c8af6..f8432c8af6 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkHelper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala index 5340951d91..5340951d91 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkHelper.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala index e56b74edba..e56b74edba 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index 7667a9c119..7667a9c119 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 646c68e60c..646c68e60c 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala index 193a42dcde..193a42dcde 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala index 3159b70008..3159b70008 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala index 0db26c3407..0db26c3407 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala index 13cd2b7fa5..13cd2b7fa5 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala index 7fd5fbb424..7fd5fbb424 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 91c9c52c3c..91c9c52c3c 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala new file mode 100644 index 0000000000..5e93a71995 --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.repl + +import org.apache.spark.util.Utils +import org.apache.spark._ + +import scala.tools.nsc.Settings +import scala.tools.nsc.interpreter.SparkILoop + +object Main extends Logging { + + val conf = new SparkConf() + val tmp = System.getProperty("java.io.tmpdir") + val rootDir = conf.get("spark.repl.classdir", tmp) + val outputDir = Utils.createTempDir(rootDir) + val s = new Settings() + s.processArguments(List("-Yrepl-class-based", + "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", "-Yrepl-sync"), true) + val classServer = new HttpServer(outputDir, new SecurityManager(conf)) + var sparkContext: SparkContext = _ + var interp = new SparkILoop // this is a public var because tests reset it. + + def main(args: Array[String]) { + if (getMaster == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") + // Start the classServer and store its URI in a spark system property + // (which will be passed to executors so that they can connect to it) + classServer.start() + interp.process(s) // Repl starts and goes in loop of R.E.P.L + classServer.stop() + Option(sparkContext).map(_.stop) + } + + + def getAddedJars: Array[String] = { + val envJars = sys.env.get("ADD_JARS") + val propJars = sys.props.get("spark.jars").flatMap { p => if (p == "") None else Some(p) } + val jars = propJars.orElse(envJars).getOrElse("") + Utils.resolveURIs(jars).split(",").filter(_.nonEmpty) + } + + def createSparkContext(): SparkContext = { + val execUri = System.getenv("SPARK_EXECUTOR_URI") + val jars = getAddedJars + val conf = new SparkConf() + .setMaster(getMaster) + .setAppName("Spark shell") + .setJars(jars) + .set("spark.repl.class.uri", classServer.uri) + logInfo("Spark class server started at " + classServer.uri) + if (execUri != null) { + conf.set("spark.executor.uri", execUri) + } + if (System.getenv("SPARK_HOME") != null) { + conf.setSparkHome(System.getenv("SPARK_HOME")) + } + sparkContext = new SparkContext(conf) + logInfo("Created spark context..") + sparkContext + } + + private def getMaster: String = { + val master = { + val envMaster = sys.env.get("MASTER") + val propMaster = sys.props.get("spark.master") + propMaster.orElse(envMaster).getOrElse("local[*]") + } + master + } +} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala new file mode 100644 index 0000000000..8e519fa67f --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala @@ -0,0 +1,86 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2013 LAMP/EPFL + * @author Paul Phillips + */ + +package scala.tools.nsc +package interpreter + +import scala.tools.nsc.ast.parser.Tokens.EOF + +trait SparkExprTyper { + val repl: SparkIMain + + import repl._ + import global.{ reporter => _, Import => _, _ } + import naming.freshInternalVarName + + def symbolOfLine(code: String): Symbol = { + def asExpr(): Symbol = { + val name = freshInternalVarName() + // Typing it with a lazy val would give us the right type, but runs + // into compiler bugs with things like existentials, so we compile it + // behind a def and strip the NullaryMethodType which wraps the expr. + val line = "def " + name + " = " + code + + interpretSynthetic(line) match { + case IR.Success => + val sym0 = symbolOfTerm(name) + // drop NullaryMethodType + sym0.cloneSymbol setInfo exitingTyper(sym0.tpe_*.finalResultType) + case _ => NoSymbol + } + } + def asDefn(): Symbol = { + val old = repl.definedSymbolList.toSet + + interpretSynthetic(code) match { + case IR.Success => + repl.definedSymbolList filterNot old match { + case Nil => NoSymbol + case sym :: Nil => sym + case syms => NoSymbol.newOverloaded(NoPrefix, syms) + } + case _ => NoSymbol + } + } + def asError(): Symbol = { + interpretSynthetic(code) + NoSymbol + } + beSilentDuring(asExpr()) orElse beSilentDuring(asDefn()) orElse asError() + } + + private var typeOfExpressionDepth = 0 + def typeOfExpression(expr: String, silent: Boolean = true): Type = { + if (typeOfExpressionDepth > 2) { + repldbg("Terminating typeOfExpression recursion for expression: " + expr) + return NoType + } + typeOfExpressionDepth += 1 + // Don't presently have a good way to suppress undesirable success output + // while letting errors through, so it is first trying it silently: if there + // is an error, and errors are desired, then it re-evaluates non-silently + // to induce the error message. + try beSilentDuring(symbolOfLine(expr).tpe) match { + case NoType if !silent => symbolOfLine(expr).tpe // generate error + case tpe => tpe + } + finally typeOfExpressionDepth -= 1 + } + + // This only works for proper types. + def typeOfTypeString(typeString: String): Type = { + def asProperType(): Option[Type] = { + val name = freshInternalVarName() + val line = "def %s: %s = ???" format (name, typeString) + interpretSynthetic(line) match { + case IR.Success => + val sym0 = symbolOfTerm(name) + Some(sym0.asMethod.returnType) + case _ => None + } + } + beSilentDuring(asProperType()) getOrElse NoType + } +} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala new file mode 100644 index 0000000000..a591e9fc46 --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -0,0 +1,966 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2013 LAMP/EPFL + * @author Alexander Spoon + */ + +package scala +package tools.nsc +package interpreter + +import scala.language.{ implicitConversions, existentials } +import scala.annotation.tailrec +import Predef.{ println => _, _ } +import interpreter.session._ +import StdReplTags._ +import scala.reflect.api.{Mirror, Universe, TypeCreator} +import scala.util.Properties.{ jdkHome, javaVersion, versionString, javaVmName } +import scala.tools.nsc.util.{ ClassPath, Exceptional, stringFromWriter, stringFromStream } +import scala.reflect.{ClassTag, classTag} +import scala.reflect.internal.util.{ BatchSourceFile, ScalaClassLoader } +import ScalaClassLoader._ +import scala.reflect.io.{ File, Directory } +import scala.tools.util._ +import scala.collection.generic.Clearable +import scala.concurrent.{ ExecutionContext, Await, Future, future } +import ExecutionContext.Implicits._ +import java.io.{ BufferedReader, FileReader } + +/** The Scala interactive shell. It provides a read-eval-print loop + * around the Interpreter class. + * After instantiation, clients should call the main() method. + * + * If no in0 is specified, then input will come from the console, and + * the class will attempt to provide input editing feature such as + * input history. + * + * @author Moez A. Abdel-Gawad + * @author Lex Spoon + * @version 1.2 + */ +class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter) + extends AnyRef + with LoopCommands +{ + def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) + def this() = this(None, new JPrintWriter(Console.out, true)) +// +// @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp +// @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: Interpreter): Unit = intp = i + + var in: InteractiveReader = _ // the input stream from which commands come + var settings: Settings = _ + var intp: SparkIMain = _ + + var globalFuture: Future[Boolean] = _ + + protected def asyncMessage(msg: String) { + if (isReplInfo || isReplPower) + echoAndRefresh(msg) + } + + def initializeSpark() { + intp.beQuietDuring { + command( """ + @transient val sc = org.apache.spark.repl.Main.createSparkContext(); + """) + command("import org.apache.spark.SparkContext._") + } + echo("Spark context available as sc.") + } + + /** Print a welcome message */ + def printWelcome() { + import org.apache.spark.SPARK_VERSION + echo("""Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ version %s + /_/ + """.format(SPARK_VERSION)) + val welcomeMsg = "Using Scala %s (%s, Java %s)".format( + versionString, javaVmName, javaVersion) + echo(welcomeMsg) + echo("Type in expressions to have them evaluated.") + echo("Type :help for more information.") + } + + override def echoCommandMessage(msg: String) { + intp.reporter printUntruncatedMessage msg + } + + // lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals]) + def history = in.history + + // classpath entries added via :cp + var addedClasspath: String = "" + + /** A reverse list of commands to replay if the user requests a :replay */ + var replayCommandStack: List[String] = Nil + + /** A list of commands to replay if the user requests a :replay */ + def replayCommands = replayCommandStack.reverse + + /** Record a command for replay should the user request a :replay */ + def addReplay(cmd: String) = replayCommandStack ::= cmd + + def savingReplayStack[T](body: => T): T = { + val saved = replayCommandStack + try body + finally replayCommandStack = saved + } + def savingReader[T](body: => T): T = { + val saved = in + try body + finally in = saved + } + + /** Close the interpreter and set the var to null. */ + def closeInterpreter() { + if (intp ne null) { + intp.close() + intp = null + } + } + + class SparkILoopInterpreter extends SparkIMain(settings, out) { + outer => + + override lazy val formatting = new Formatting { + def prompt = SparkILoop.this.prompt + } + override protected def parentClassLoader = + settings.explicitParentLoader.getOrElse( classOf[SparkILoop].getClassLoader ) + } + + /** Create a new interpreter. */ + def createInterpreter() { + if (addedClasspath != "") + settings.classpath append addedClasspath + + intp = new SparkILoopInterpreter + } + + /** print a friendly help message */ + def helpCommand(line: String): Result = { + if (line == "") helpSummary() + else uniqueCommand(line) match { + case Some(lc) => echo("\n" + lc.help) + case _ => ambiguousError(line) + } + } + private def helpSummary() = { + val usageWidth = commands map (_.usageMsg.length) max + val formatStr = "%-" + usageWidth + "s %s" + + echo("All commands can be abbreviated, e.g. :he instead of :help.") + + commands foreach { cmd => + echo(formatStr.format(cmd.usageMsg, cmd.help)) + } + } + private def ambiguousError(cmd: String): Result = { + matchingCommands(cmd) match { + case Nil => echo(cmd + ": no such command. Type :help for help.") + case xs => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?") + } + Result(keepRunning = true, None) + } + private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd) + private def uniqueCommand(cmd: String): Option[LoopCommand] = { + // this lets us add commands willy-nilly and only requires enough command to disambiguate + matchingCommands(cmd) match { + case List(x) => Some(x) + // exact match OK even if otherwise appears ambiguous + case xs => xs find (_.name == cmd) + } + } + + /** Show the history */ + lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { + override def usage = "[num]" + def defaultLines = 20 + + def apply(line: String): Result = { + if (history eq NoHistory) + return "No history available." + + val xs = words(line) + val current = history.index + val count = try xs.head.toInt catch { case _: Exception => defaultLines } + val lines = history.asStrings takeRight count + val offset = current - lines.size + 1 + + for ((line, index) <- lines.zipWithIndex) + echo("%3d %s".format(index + offset, line)) + } + } + + // When you know you are most likely breaking into the middle + // of a line being typed. This softens the blow. + protected def echoAndRefresh(msg: String) = { + echo("\n" + msg) + in.redrawLine() + } + protected def echo(msg: String) = { + out println msg + out.flush() + } + + /** Search the history */ + def searchHistory(_cmdline: String) { + val cmdline = _cmdline.toLowerCase + val offset = history.index - history.size + 1 + + for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline) + echo("%d %s".format(index + offset, line)) + } + + private val currentPrompt = Properties.shellPromptString + + /** Prompt to print when awaiting input */ + def prompt = currentPrompt + + import LoopCommand.{ cmd, nullary } + + /** Standard commands **/ + lazy val standardCommands = List( + cmd("cp", "<path>", "add a jar or directory to the classpath", addClasspath), + cmd("edit", "<id>|<line>", "edit history", editCommand), + cmd("help", "[command]", "print this summary or command-specific help", helpCommand), + historyCommand, + cmd("h?", "<string>", "search the history", searchHistory), + cmd("imports", "[name name ...]", "show import history, identifying sources of names", importsCommand), + //cmd("implicits", "[-v]", "show the implicits in scope", intp.implicitsCommand), + cmd("javap", "<path|class>", "disassemble a file or class name", javapCommand), + cmd("line", "<id>|<line>", "place line(s) at the end of history", lineCommand), + cmd("load", "<path>", "interpret lines in a file", loadCommand), + cmd("paste", "[-raw] [path]", "enter paste mode or paste a file", pasteCommand), + // nullary("power", "enable power user mode", powerCmd), + nullary("quit", "exit the interpreter", () => Result(keepRunning = false, None)), + nullary("replay", "reset execution and replay all previous commands", replay), + nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand), + cmd("save", "<path>", "save replayable session to a file", saveCommand), + shCommand, + cmd("settings", "[+|-]<options>", "+enable/-disable flags, set compiler options", changeSettings), + nullary("silent", "disable/enable automatic printing of results", verbosity), +// cmd("type", "[-v] <expr>", "display the type of an expression without evaluating it", typeCommand), +// cmd("kind", "[-v] <expr>", "display the kind of expression's type", kindCommand), + nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand) + ) + + /** Power user commands */ +// lazy val powerCommands: List[LoopCommand] = List( +// cmd("phase", "<phase>", "set the implicit phase for power commands", phaseCommand) +// ) + + private def importsCommand(line: String): Result = { + val tokens = words(line) + val handlers = intp.languageWildcardHandlers ++ intp.importHandlers + + handlers.filterNot(_.importedSymbols.isEmpty).zipWithIndex foreach { + case (handler, idx) => + val (types, terms) = handler.importedSymbols partition (_.name.isTypeName) + val imps = handler.implicitSymbols + val found = tokens filter (handler importsSymbolNamed _) + val typeMsg = if (types.isEmpty) "" else types.size + " types" + val termMsg = if (terms.isEmpty) "" else terms.size + " terms" + val implicitMsg = if (imps.isEmpty) "" else imps.size + " are implicit" + val foundMsg = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "") + val statsMsg = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")") + + intp.reporter.printMessage("%2d) %-30s %s%s".format( + idx + 1, + handler.importString, + statsMsg, + foundMsg + )) + } + } + + private def findToolsJar() = PathResolver.SupplementalLocations.platformTools + + private def addToolsJarToLoader() = { + val cl = findToolsJar() match { + case Some(tools) => ScalaClassLoader.fromURLs(Seq(tools.toURL), intp.classLoader) + case _ => intp.classLoader + } + if (Javap.isAvailable(cl)) { + repldbg(":javap available.") + cl + } + else { + repldbg(":javap unavailable: no tools.jar at " + jdkHome) + intp.classLoader + } + } +// +// protected def newJavap() = +// JavapClass(addToolsJarToLoader(), new IMain.ReplStrippingWriter(intp), Some(intp)) +// +// private lazy val javap = substituteAndLog[Javap]("javap", NoJavap)(newJavap()) + + // Still todo: modules. +// private def typeCommand(line0: String): Result = { +// line0.trim match { +// case "" => ":type [-v] <expression>" +// case s => intp.typeCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ") +// } +// } + +// private def kindCommand(expr: String): Result = { +// expr.trim match { +// case "" => ":kind [-v] <expression>" +// case s => intp.kindCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ") +// } +// } + + private def warningsCommand(): Result = { + if (intp.lastWarnings.isEmpty) + "Can't find any cached warnings." + else + intp.lastWarnings foreach { case (pos, msg) => intp.reporter.warning(pos, msg) } + } + + private def changeSettings(args: String): Result = { + def showSettings() = { + for (s <- settings.userSetSettings.toSeq.sorted) echo(s.toString) + } + def updateSettings() = { + // put aside +flag options + val (pluses, rest) = (args split "\\s+").toList partition (_.startsWith("+")) + val tmps = new Settings + val (ok, leftover) = tmps.processArguments(rest, processAll = true) + if (!ok) echo("Bad settings request.") + else if (leftover.nonEmpty) echo("Unprocessed settings.") + else { + // boolean flags set-by-user on tmp copy should be off, not on + val offs = tmps.userSetSettings filter (_.isInstanceOf[Settings#BooleanSetting]) + val (minuses, nonbools) = rest partition (arg => offs exists (_ respondsTo arg)) + // update non-flags + settings.processArguments(nonbools, processAll = true) + // also snag multi-value options for clearing, e.g. -Ylog: and -language: + for { + s <- settings.userSetSettings + if s.isInstanceOf[Settings#MultiStringSetting] || s.isInstanceOf[Settings#PhasesSetting] + if nonbools exists (arg => arg.head == '-' && arg.last == ':' && (s respondsTo arg.init)) + } s match { + case c: Clearable => c.clear() + case _ => + } + def update(bs: Seq[String], name: String=>String, setter: Settings#Setting=>Unit) = { + for (b <- bs) + settings.lookupSetting(name(b)) match { + case Some(s) => + if (s.isInstanceOf[Settings#BooleanSetting]) setter(s) + else echo(s"Not a boolean flag: $b") + case _ => + echo(s"Not an option: $b") + } + } + update(minuses, identity, _.tryToSetFromPropertyValue("false")) // turn off + update(pluses, "-" + _.drop(1), _.tryToSet(Nil)) // turn on + } + } + if (args.isEmpty) showSettings() else updateSettings() + } + + private def javapCommand(line: String): Result = { +// if (javap == null) +// ":javap unavailable, no tools.jar at %s. Set JDK_HOME.".format(jdkHome) +// else if (line == "") +// ":javap [-lcsvp] [path1 path2 ...]" +// else +// javap(words(line)) foreach { res => +// if (res.isError) return "Failed: " + res.value +// else res.show() +// } + } + + private def pathToPhaseWrapper = intp.originalPath("$r") + ".phased.atCurrent" + + private def phaseCommand(name: String): Result = { +// val phased: Phased = power.phased +// import phased.NoPhaseName +// +// if (name == "clear") { +// phased.set(NoPhaseName) +// intp.clearExecutionWrapper() +// "Cleared active phase." +// } +// else if (name == "") phased.get match { +// case NoPhaseName => "Usage: :phase <expr> (e.g. typer, erasure.next, erasure+3)" +// case ph => "Active phase is '%s'. (To clear, :phase clear)".format(phased.get) +// } +// else { +// val what = phased.parse(name) +// if (what.isEmpty || !phased.set(what)) +// "'" + name + "' does not appear to represent a valid phase." +// else { +// intp.setExecutionWrapper(pathToPhaseWrapper) +// val activeMessage = +// if (what.toString.length == name.length) "" + what +// else "%s (%s)".format(what, name) +// +// "Active phase is now: " + activeMessage +// } +// } + } + + /** Available commands */ + def commands: List[LoopCommand] = standardCommands ++ ( + // if (isReplPower) + // powerCommands + // else + Nil + ) + + val replayQuestionMessage = + """|That entry seems to have slain the compiler. Shall I replay + |your session? I can re-run each line except the last one. + |[y/n] + """.trim.stripMargin + + private val crashRecovery: PartialFunction[Throwable, Boolean] = { + case ex: Throwable => + val (err, explain) = ( + if (intp.isInitializeComplete) + (intp.global.throwableAsString(ex), "") + else + (ex.getMessage, "The compiler did not initialize.\n") + ) + echo(err) + + ex match { + case _: NoSuchMethodError | _: NoClassDefFoundError => + echo("\nUnrecoverable error.") + throw ex + case _ => + def fn(): Boolean = + try in.readYesOrNo(explain + replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() }) + catch { case _: RuntimeException => false } + + if (fn()) replay() + else echo("\nAbandoning crashed session.") + } + true + } + + // return false if repl should exit + def processLine(line: String): Boolean = { + import scala.concurrent.duration._ + Await.ready(globalFuture, 60.seconds) + + (line ne null) && (command(line) match { + case Result(false, _) => false + case Result(_, Some(line)) => addReplay(line) ; true + case _ => true + }) + } + + private def readOneLine() = { + out.flush() + in readLine prompt + } + + /** The main read-eval-print loop for the repl. It calls + * command() for each line of input, and stops when + * command() returns false. + */ + @tailrec final def loop() { + if ( try processLine(readOneLine()) catch crashRecovery ) + loop() + } + + /** interpret all lines from a specified file */ + def interpretAllFrom(file: File) { + savingReader { + savingReplayStack { + file applyReader { reader => + in = SimpleReader(reader, out, interactive = false) + echo("Loading " + file + "...") + loop() + } + } + } + } + + /** create a new interpreter and replay the given commands */ + def replay() { + reset() + if (replayCommandStack.isEmpty) + echo("Nothing to replay.") + else for (cmd <- replayCommands) { + echo("Replaying: " + cmd) // flush because maybe cmd will have its own output + command(cmd) + echo("") + } + } + def resetCommand() { + echo("Resetting interpreter state.") + if (replayCommandStack.nonEmpty) { + echo("Forgetting this session history:\n") + replayCommands foreach echo + echo("") + replayCommandStack = Nil + } + if (intp.namedDefinedTerms.nonEmpty) + echo("Forgetting all expression results and named terms: " + intp.namedDefinedTerms.mkString(", ")) + if (intp.definedTypes.nonEmpty) + echo("Forgetting defined types: " + intp.definedTypes.mkString(", ")) + + reset() + } + def reset() { + intp.reset() + unleashAndSetPhase() + } + + def lineCommand(what: String): Result = editCommand(what, None) + + // :edit id or :edit line + def editCommand(what: String): Result = editCommand(what, Properties.envOrNone("EDITOR")) + + def editCommand(what: String, editor: Option[String]): Result = { + def diagnose(code: String) = { + echo("The edited code is incomplete!\n") + val errless = intp compileSources new BatchSourceFile("<pastie>", s"object pastel {\n$code\n}") + if (errless) echo("The compiler reports no errors.") + } + def historicize(text: String) = history match { + case jlh: JLineHistory => text.lines foreach jlh.add ; jlh.moveToEnd() ; true + case _ => false + } + def edit(text: String): Result = editor match { + case Some(ed) => + val tmp = File.makeTemp() + tmp.writeAll(text) + try { + val pr = new ProcessResult(s"$ed ${tmp.path}") + pr.exitCode match { + case 0 => + tmp.safeSlurp() match { + case Some(edited) if edited.trim.isEmpty => echo("Edited text is empty.") + case Some(edited) => + echo(edited.lines map ("+" + _) mkString "\n") + val res = intp interpret edited + if (res == IR.Incomplete) diagnose(edited) + else { + historicize(edited) + Result(lineToRecord = Some(edited), keepRunning = true) + } + case None => echo("Can't read edited text. Did you delete it?") + } + case x => echo(s"Error exit from $ed ($x), ignoring") + } + } finally { + tmp.delete() + } + case None => + if (historicize(text)) echo("Placing text in recent history.") + else echo(f"No EDITOR defined and you can't change history, echoing your text:%n$text") + } + + // if what is a number, use it as a line number or range in history + def isNum = what forall (c => c.isDigit || c == '-' || c == '+') + // except that "-" means last value + def isLast = (what == "-") + if (isLast || !isNum) { + val name = if (isLast) intp.mostRecentVar else what + val sym = intp.symbolOfIdent(name) + intp.prevRequestList collectFirst { case r if r.defines contains sym => r } match { + case Some(req) => edit(req.line) + case None => echo(s"No symbol in scope: $what") + } + } else try { + val s = what + // line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur) + val (start, len) = + if ((s indexOf '+') > 0) { + val (a,b) = s splitAt (s indexOf '+') + (a.toInt, b.drop(1).toInt) + } else { + (s indexOf '-') match { + case -1 => (s.toInt, 1) + case 0 => val n = s.drop(1).toInt ; (history.index - n, n) + case _ if s.last == '-' => val n = s.init.toInt ; (n, history.index - n) + case i => val n = s.take(i).toInt ; (n, s.drop(i+1).toInt - n) + } + } + import scala.collection.JavaConverters._ + val index = (start - 1) max 0 + val text = history match { + case jlh: JLineHistory => jlh.entries(index).asScala.take(len) map (_.value) mkString "\n" + case _ => history.asStrings.slice(index, index + len) mkString "\n" + } + edit(text) + } catch { + case _: NumberFormatException => echo(s"Bad range '$what'") + echo("Use line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur)") + } + } + + /** fork a shell and run a command */ + lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") { + override def usage = "<command line>" + def apply(line: String): Result = line match { + case "" => showUsage() + case _ => + val toRun = s"new ${classOf[ProcessResult].getName}(${string2codeQuoted(line)})" + intp interpret toRun + () + } + } + + def withFile[A](filename: String)(action: File => A): Option[A] = { + val res = Some(File(filename)) filter (_.exists) map action + if (res.isEmpty) echo("That file does not exist") // courtesy side-effect + res + } + + def loadCommand(arg: String) = { + var shouldReplay: Option[String] = None + withFile(arg)(f => { + interpretAllFrom(f) + shouldReplay = Some(":load " + arg) + }) + Result(keepRunning = true, shouldReplay) + } + + def saveCommand(filename: String): Result = ( + if (filename.isEmpty) echo("File name is required.") + else if (replayCommandStack.isEmpty) echo("No replay commands in session") + else File(filename).printlnAll(replayCommands: _*) + ) + + def addClasspath(arg: String): Unit = { + val f = File(arg).normalize + if (f.exists) { + addedClasspath = ClassPath.join(addedClasspath, f.path) + val totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath) + echo("Added '%s'. Your new classpath is:\n\"%s\"".format(f.path, totalClasspath)) + replay() + } + else echo("The path '" + f + "' doesn't seem to exist.") + } + + def powerCmd(): Result = { + if (isReplPower) "Already in power mode." + else enablePowerMode(isDuringInit = false) + } + def enablePowerMode(isDuringInit: Boolean) = { + replProps.power setValue true + unleashAndSetPhase() + // asyncEcho(isDuringInit, power.banner) + } + private def unleashAndSetPhase() { + if (isReplPower) { + // power.unleash() + // Set the phase to "typer" + // intp beSilentDuring phaseCommand("typer") + } + } + + def asyncEcho(async: Boolean, msg: => String) { + if (async) asyncMessage(msg) + else echo(msg) + } + + def verbosity() = { + val old = intp.printResults + intp.printResults = !old + echo("Switched " + (if (old) "off" else "on") + " result printing.") + } + + /** Run one command submitted by the user. Two values are returned: + * (1) whether to keep running, (2) the line to record for replay, + * if any. */ + def command(line: String): Result = { + if (line startsWith ":") { + val cmd = line.tail takeWhile (x => !x.isWhitespace) + uniqueCommand(cmd) match { + case Some(lc) => lc(line.tail stripPrefix cmd dropWhile (_.isWhitespace)) + case _ => ambiguousError(cmd) + } + } + else if (intp.global == null) Result(keepRunning = false, None) // Notice failure to create compiler + else Result(keepRunning = true, interpretStartingWith(line)) + } + + private def readWhile(cond: String => Boolean) = { + Iterator continually in.readLine("") takeWhile (x => x != null && cond(x)) + } + + def pasteCommand(arg: String): Result = { + var shouldReplay: Option[String] = None + def result = Result(keepRunning = true, shouldReplay) + val (raw, file) = + if (arg.isEmpty) (false, None) + else { + val r = """(-raw)?(\s+)?([^\-]\S*)?""".r + arg match { + case r(flag, sep, name) => + if (flag != null && name != null && sep == null) + echo(s"""I assume you mean "$flag $name"?""") + (flag != null, Option(name)) + case _ => + echo("usage: :paste -raw file") + return result + } + } + val code = file match { + case Some(name) => + withFile(name)(f => { + shouldReplay = Some(s":paste $arg") + val s = f.slurp.trim + if (s.isEmpty) echo(s"File contains no code: $f") + else echo(s"Pasting file $f...") + s + }) getOrElse "" + case None => + echo("// Entering paste mode (ctrl-D to finish)\n") + val text = (readWhile(_ => true) mkString "\n").trim + if (text.isEmpty) echo("\n// Nothing pasted, nothing gained.\n") + else echo("\n// Exiting paste mode, now interpreting.\n") + text + } + def interpretCode() = { + val res = intp interpret code + // if input is incomplete, let the compiler try to say why + if (res == IR.Incomplete) { + echo("The pasted code is incomplete!\n") + // Remembrance of Things Pasted in an object + val errless = intp compileSources new BatchSourceFile("<pastie>", s"object pastel {\n$code\n}") + if (errless) echo("...but compilation found no error? Good luck with that.") + } + } + def compileCode() = { + val errless = intp compileSources new BatchSourceFile("<pastie>", code) + if (!errless) echo("There were compilation errors!") + } + if (code.nonEmpty) { + if (raw) compileCode() else interpretCode() + } + result + } + + private object paste extends Pasted { + val ContinueString = " | " + val PromptString = "scala> " + + def interpret(line: String): Unit = { + echo(line.trim) + intp interpret line + echo("") + } + + def transcript(start: String) = { + echo("\n// Detected repl transcript paste: ctrl-D to finish.\n") + apply(Iterator(start) ++ readWhile(_.trim != PromptString.trim)) + } + } + import paste.{ ContinueString, PromptString } + + /** Interpret expressions starting with the first line. + * Read lines until a complete compilation unit is available + * or until a syntax error has been seen. If a full unit is + * read, go ahead and interpret it. Return the full string + * to be recorded for replay, if any. + */ + def interpretStartingWith(code: String): Option[String] = { + // signal completion non-completion input has been received + in.completion.resetVerbosity() + + def reallyInterpret = { + val reallyResult = intp.interpret(code) + (reallyResult, reallyResult match { + case IR.Error => None + case IR.Success => Some(code) + case IR.Incomplete => + if (in.interactive && code.endsWith("\n\n")) { + echo("You typed two blank lines. Starting a new command.") + None + } + else in.readLine(ContinueString) match { + case null => + // we know compilation is going to fail since we're at EOF and the + // parser thinks the input is still incomplete, but since this is + // a file being read non-interactively we want to fail. So we send + // it straight to the compiler for the nice error message. + intp.compileString(code) + None + + case line => interpretStartingWith(code + "\n" + line) + } + }) + } + + /** Here we place ourselves between the user and the interpreter and examine + * the input they are ostensibly submitting. We intervene in several cases: + * + * 1) If the line starts with "scala> " it is assumed to be an interpreter paste. + * 2) If the line starts with "." (but not ".." or "./") it is treated as an invocation + * on the previous result. + * 3) If the Completion object's execute returns Some(_), we inject that value + * and avoid the interpreter, as it's likely not valid scala code. + */ + if (code == "") None + else if (!paste.running && code.trim.startsWith(PromptString)) { + paste.transcript(code) + None + } + else if (Completion.looksLikeInvocation(code) && intp.mostRecentVar != "") { + interpretStartingWith(intp.mostRecentVar + code) + } + else if (code.trim startsWith "//") { + // line comment, do nothing + None + } + else + reallyInterpret._2 + } + + // runs :load `file` on any files passed via -i + def loadFiles(settings: Settings) = settings match { + case settings: GenericRunnerSettings => + for (filename <- settings.loadfiles.value) { + val cmd = ":load " + filename + command(cmd) + addReplay(cmd) + echo("") + } + case _ => + } + + /** Tries to create a JLineReader, falling back to SimpleReader: + * unless settings or properties are such that it should start + * with SimpleReader. + */ + def chooseReader(settings: Settings): InteractiveReader = { + if (settings.Xnojline || Properties.isEmacsShell) + SimpleReader() + else try new JLineReader( + if (settings.noCompletion) NoCompletion + else new SparkJLineCompletion(intp) + ) + catch { + case ex @ (_: Exception | _: NoClassDefFoundError) => + echo("Failed to created JLineReader: " + ex + "\nFalling back to SimpleReader.") + SimpleReader() + } + } + protected def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] = + u.TypeTag[T]( + m, + new TypeCreator { + def apply[U <: Universe with Singleton](m: Mirror[U]): U # Type = + m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type] + }) + + private def loopPostInit() { + // Bind intp somewhere out of the regular namespace where + // we can get at it in generated code. + intp.quietBind(NamedParam[SparkIMain]("$intp", intp)(tagOfStaticClass[SparkIMain], classTag[SparkIMain])) + // Auto-run code via some setting. + ( replProps.replAutorunCode.option + flatMap (f => io.File(f).safeSlurp()) + foreach (intp quietRun _) + ) + // classloader and power mode setup + intp.setContextClassLoader() + if (isReplPower) { + // replProps.power setValue true + // unleashAndSetPhase() + // asyncMessage(power.banner) + } + // SI-7418 Now, and only now, can we enable TAB completion. + in match { + case x: JLineReader => x.consoleReader.postInit + case _ => + } + } + def process(settings: Settings): Boolean = savingContextLoader { + this.settings = settings + createInterpreter() + + // sets in to some kind of reader depending on environmental cues + in = in0.fold(chooseReader(settings))(r => SimpleReader(r, out, interactive = true)) + globalFuture = future { + intp.initializeSynchronous() + loopPostInit() + !intp.reporter.hasErrors + } + import scala.concurrent.duration._ + Await.ready(globalFuture, 10 seconds) + printWelcome() + initializeSpark() + loadFiles(settings) + + try loop() + catch AbstractOrMissingHandler() + finally closeInterpreter() + + true + } + + @deprecated("Use `process` instead", "2.9.0") + def main(settings: Settings): Unit = process(settings) //used by sbt +} + +object SparkILoop { + implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp + + // Designed primarily for use by test code: take a String with a + // bunch of code, and prints out a transcript of what it would look + // like if you'd just typed it into the repl. + def runForTranscript(code: String, settings: Settings): String = { + import java.io.{ BufferedReader, StringReader, OutputStreamWriter } + + stringFromStream { ostream => + Console.withOut(ostream) { + val output = new JPrintWriter(new OutputStreamWriter(ostream), true) { + override def write(str: String) = { + // completely skip continuation lines + if (str forall (ch => ch.isWhitespace || ch == '|')) () + else super.write(str) + } + } + val input = new BufferedReader(new StringReader(code.trim + "\n")) { + override def readLine(): String = { + val s = super.readLine() + // helping out by printing the line being interpreted. + if (s != null) + output.println(s) + s + } + } + val repl = new SparkILoop(input, output) + if (settings.classpath.isDefault) + settings.classpath.value = sys.props("java.class.path") + + repl process settings + } + } + } + + /** Creates an interpreter loop with default settings and feeds + * the given code to it as input. + */ + def run(code: String, sets: Settings = new Settings): String = { + import java.io.{ BufferedReader, StringReader, OutputStreamWriter } + + stringFromStream { ostream => + Console.withOut(ostream) { + val input = new BufferedReader(new StringReader(code)) + val output = new JPrintWriter(new OutputStreamWriter(ostream), true) + val repl = new SparkILoop(input, output) + + if (sets.classpath.isDefault) + sets.classpath.value = sys.props("java.class.path") + + repl process sets + } + } + } + def run(lines: List[String]): String = run(lines map (_ + "\n") mkString) +} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala new file mode 100644 index 0000000000..1bb62c84ab --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -0,0 +1,1319 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2013 LAMP/EPFL + * @author Martin Odersky + */ + +package scala +package tools.nsc +package interpreter + +import PartialFunction.cond +import scala.language.implicitConversions +import scala.beans.BeanProperty +import scala.collection.mutable +import scala.concurrent.{ Future, ExecutionContext } +import scala.reflect.runtime.{ universe => ru } +import scala.reflect.{ ClassTag, classTag } +import scala.reflect.internal.util.{ BatchSourceFile, SourceFile } +import scala.tools.util.PathResolver +import scala.tools.nsc.io.AbstractFile +import scala.tools.nsc.typechecker.{ TypeStrings, StructuredTypeStrings } +import scala.tools.nsc.util.{ ScalaClassLoader, stringFromReader, stringFromWriter, StackTraceOps } +import scala.tools.nsc.util.Exceptional.unwrap +import javax.script.{AbstractScriptEngine, Bindings, ScriptContext, ScriptEngine, ScriptEngineFactory, ScriptException, CompiledScript, Compilable} + +/** An interpreter for Scala code. + * + * The main public entry points are compile(), interpret(), and bind(). + * The compile() method loads a complete Scala file. The interpret() method + * executes one line of Scala code at the request of the user. The bind() + * method binds an object to a variable that can then be used by later + * interpreted code. + * + * The overall approach is based on compiling the requested code and then + * using a Java classloader and Java reflection to run the code + * and access its results. + * + * In more detail, a single compiler instance is used + * to accumulate all successfully compiled or interpreted Scala code. To + * "interpret" a line of code, the compiler generates a fresh object that + * includes the line of code and which has public member(s) to export + * all variables defined by that code. To extract the result of an + * interpreted line to show the user, a second "result object" is created + * which imports the variables exported by the above object and then + * exports members called "$eval" and "$print". To accomodate user expressions + * that read from variables or methods defined in previous statements, "import" + * statements are used. + * + * This interpreter shares the strengths and weaknesses of using the + * full compiler-to-Java. The main strength is that interpreted code + * behaves exactly as does compiled code, including running at full speed. + * The main weakness is that redefining classes and methods is not handled + * properly, because rebinding at the Java level is technically difficult. + * + * @author Moez A. Abdel-Gawad + * @author Lex Spoon + */ +class SparkIMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings: Settings, + protected val out: JPrintWriter) extends AbstractScriptEngine with Compilable with SparkImports { + imain => + + setBindings(createBindings, ScriptContext.ENGINE_SCOPE) + object replOutput extends ReplOutput(settings.Yreploutdir) { } + + @deprecated("Use replOutput.dir instead", "2.11.0") + def virtualDirectory = replOutput.dir + // Used in a test case. + def showDirectory() = replOutput.show(out) + + private[nsc] var printResults = true // whether to print result lines + private[nsc] var totalSilence = false // whether to print anything + private var _initializeComplete = false // compiler is initialized + private var _isInitialized: Future[Boolean] = null // set up initialization future + private var bindExceptions = true // whether to bind the lastException variable + private var _executionWrapper = "" // code to be wrapped around all lines + + /** We're going to go to some trouble to initialize the compiler asynchronously. + * It's critical that nothing call into it until it's been initialized or we will + * run into unrecoverable issues, but the perceived repl startup time goes + * through the roof if we wait for it. So we initialize it with a future and + * use a lazy val to ensure that any attempt to use the compiler object waits + * on the future. + */ + private var _classLoader: util.AbstractFileClassLoader = null // active classloader + private val _compiler: ReplGlobal = newCompiler(settings, reporter) // our private compiler + + def compilerClasspath: Seq[java.net.URL] = ( + if (isInitializeComplete) global.classPath.asURLs + else new PathResolver(settings).result.asURLs // the compiler's classpath + ) + def settings = initialSettings + // Run the code body with the given boolean settings flipped to true. + def withoutWarnings[T](body: => T): T = beQuietDuring { + val saved = settings.nowarn.value + if (!saved) + settings.nowarn.value = true + + try body + finally if (!saved) settings.nowarn.value = false + } + + /** construct an interpreter that reports to Console */ + def this(settings: Settings, out: JPrintWriter) = this(null, settings, out) + def this(factory: ScriptEngineFactory, settings: Settings) = this(factory, settings, new NewLinePrintWriter(new ConsoleWriter, true)) + def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true)) + def this(factory: ScriptEngineFactory) = this(factory, new Settings()) + def this() = this(new Settings()) + + lazy val formatting: Formatting = new Formatting { + val prompt = Properties.shellPromptString + } + lazy val reporter: SparkReplReporter = new SparkReplReporter(this) + + import formatting._ + import reporter.{ printMessage, printUntruncatedMessage } + + // This exists mostly because using the reporter too early leads to deadlock. + private def echo(msg: String) { Console println msg } + private def _initSources = List(new BatchSourceFile("<init>", "class $repl_$init { }")) + private def _initialize() = { + try { + // if this crashes, REPL will hang its head in shame + val run = new _compiler.Run() + assert(run.typerPhase != NoPhase, "REPL requires a typer phase.") + run compileSources _initSources + _initializeComplete = true + true + } + catch AbstractOrMissingHandler() + } + private def tquoted(s: String) = "\"\"\"" + s + "\"\"\"" + private val logScope = scala.sys.props contains "scala.repl.scope" + private def scopelog(msg: String) = if (logScope) Console.err.println(msg) + + // argument is a thunk to execute after init is done + def initialize(postInitSignal: => Unit) { + synchronized { + if (_isInitialized == null) { + _isInitialized = + Future(try _initialize() finally postInitSignal)(ExecutionContext.global) + } + } + } + def initializeSynchronous(): Unit = { + if (!isInitializeComplete) { + _initialize() + assert(global != null, global) + } + } + def isInitializeComplete = _initializeComplete + + lazy val global: Global = { + if (!isInitializeComplete) _initialize() + _compiler + } + + import global._ + import definitions.{ ObjectClass, termMember, dropNullaryMethod} + + lazy val runtimeMirror = ru.runtimeMirror(classLoader) + + private def noFatal(body: => Symbol): Symbol = try body catch { case _: FatalError => NoSymbol } + + def getClassIfDefined(path: String) = ( + noFatal(runtimeMirror staticClass path) + orElse noFatal(rootMirror staticClass path) + ) + def getModuleIfDefined(path: String) = ( + noFatal(runtimeMirror staticModule path) + orElse noFatal(rootMirror staticModule path) + ) + + implicit class ReplTypeOps(tp: Type) { + def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp) + } + + // TODO: If we try to make naming a lazy val, we run into big time + // scalac unhappiness with what look like cycles. It has not been easy to + // reduce, but name resolution clearly takes different paths. + object naming extends { + val global: imain.global.type = imain.global + } with Naming { + // make sure we don't overwrite their unwisely named res3 etc. + def freshUserTermName(): TermName = { + val name = newTermName(freshUserVarName()) + if (replScope containsName name) freshUserTermName() + else name + } + def isInternalTermName(name: Name) = isInternalVarName("" + name) + } + import naming._ + + object deconstruct extends { + val global: imain.global.type = imain.global + } with StructuredTypeStrings + + lazy val memberHandlers = new { + val intp: imain.type = imain + } with SparkMemberHandlers + import memberHandlers._ + + /** Temporarily be quiet */ + def beQuietDuring[T](body: => T): T = { + val saved = printResults + printResults = false + try body + finally printResults = saved + } + def beSilentDuring[T](operation: => T): T = { + val saved = totalSilence + totalSilence = true + try operation + finally totalSilence = saved + } + + def quietRun[T](code: String) = beQuietDuring(interpret(code)) + + /** takes AnyRef because it may be binding a Throwable or an Exceptional */ + private def withLastExceptionLock[T](body: => T, alt: => T): T = { + assert(bindExceptions, "withLastExceptionLock called incorrectly.") + bindExceptions = false + + try beQuietDuring(body) + catch logAndDiscard("withLastExceptionLock", alt) + finally bindExceptions = true + } + + def executionWrapper = _executionWrapper + def setExecutionWrapper(code: String) = _executionWrapper = code + def clearExecutionWrapper() = _executionWrapper = "" + + /** interpreter settings */ + lazy val isettings = new SparkISettings(this) + + /** Instantiate a compiler. Overridable. */ + protected def newCompiler(settings: Settings, reporter: reporters.Reporter): ReplGlobal = { + settings.outputDirs setSingleOutput replOutput.dir + settings.exposeEmptyPackage.value = true + new Global(settings, reporter) with ReplGlobal { override def toString: String = "<global>" } + } + + /** Parent classloader. Overridable. */ + protected def parentClassLoader: ClassLoader = + settings.explicitParentLoader.getOrElse( this.getClass.getClassLoader() ) + + /* A single class loader is used for all commands interpreted by this Interpreter. + It would also be possible to create a new class loader for each command + to interpret. The advantages of the current approach are: + + - Expressions are only evaluated one time. This is especially + significant for I/O, e.g. "val x = Console.readLine" + + The main disadvantage is: + + - Objects, classes, and methods cannot be rebound. Instead, definitions + shadow the old ones, and old code objects refer to the old + definitions. + */ + def resetClassLoader() = { + repldbg("Setting new classloader: was " + _classLoader) + _classLoader = null + ensureClassLoader() + } + final def ensureClassLoader() { + if (_classLoader == null) + _classLoader = makeClassLoader() + } + def classLoader: util.AbstractFileClassLoader = { + ensureClassLoader() + _classLoader + } + + def backticked(s: String): String = ( + (s split '.').toList map { + case "_" => "_" + case s if nme.keywords(newTermName(s)) => s"`$s`" + case s => s + } mkString "." + ) + def readRootPath(readPath: String) = getModuleIfDefined(readPath) + + abstract class PhaseDependentOps { + def shift[T](op: => T): T + + def path(name: => Name): String = shift(path(symbolOfName(name))) + def path(sym: Symbol): String = backticked(shift(sym.fullName)) + def sig(sym: Symbol): String = shift(sym.defString) + } + object typerOp extends PhaseDependentOps { + def shift[T](op: => T): T = exitingTyper(op) + } + object flatOp extends PhaseDependentOps { + def shift[T](op: => T): T = exitingFlatten(op) + } + + def originalPath(name: String): String = originalPath(name: TermName) + def originalPath(name: Name): String = typerOp path name + def originalPath(sym: Symbol): String = typerOp path sym + def flatPath(sym: Symbol): String = flatOp shift sym.javaClassName + def translatePath(path: String) = { + val sym = if (path endsWith "$") symbolOfTerm(path.init) else symbolOfIdent(path) + sym.toOption map flatPath + } + def translateEnclosingClass(n: String) = symbolOfTerm(n).enclClass.toOption map flatPath + + private class TranslatingClassLoader(parent: ClassLoader) extends util.AbstractFileClassLoader(replOutput.dir, parent) { + /** Overridden here to try translating a simple name to the generated + * class name if the original attempt fails. This method is used by + * getResourceAsStream as well as findClass. + */ + override protected def findAbstractFile(name: String): AbstractFile = + super.findAbstractFile(name) match { + case null if _initializeComplete => translatePath(name) map (super.findAbstractFile(_)) orNull + case file => file + } + } + private def makeClassLoader(): util.AbstractFileClassLoader = + new TranslatingClassLoader(parentClassLoader match { + case null => ScalaClassLoader fromURLs compilerClasspath + case p => new ScalaClassLoader.URLClassLoader(compilerClasspath, p) + }) + + // Set the current Java "context" class loader to this interpreter's class loader + def setContextClassLoader() = classLoader.setAsContext() + + def allDefinedNames: List[Name] = exitingTyper(replScope.toList.map(_.name).sorted) + def unqualifiedIds: List[String] = allDefinedNames map (_.decode) sorted + + /** Most recent tree handled which wasn't wholly synthetic. */ + private def mostRecentlyHandledTree: Option[Tree] = { + prevRequests.reverse foreach { req => + req.handlers.reverse foreach { + case x: MemberDefHandler if x.definesValue && !isInternalTermName(x.name) => return Some(x.member) + case _ => () + } + } + None + } + + private def updateReplScope(sym: Symbol, isDefined: Boolean) { + def log(what: String) { + val mark = if (sym.isType) "t " else "v " + val name = exitingTyper(sym.nameString) + val info = cleanTypeAfterTyper(sym) + val defn = sym defStringSeenAs info + + scopelog(f"[$mark$what%6s] $name%-25s $defn%s") + } + if (ObjectClass isSubClass sym.owner) return + // unlink previous + replScope lookupAll sym.name foreach { sym => + log("unlink") + replScope unlink sym + } + val what = if (isDefined) "define" else "import" + log(what) + replScope enter sym + } + + def recordRequest(req: Request) { + if (req == null) + return + + prevRequests += req + + // warning about serially defining companions. It'd be easy + // enough to just redefine them together but that may not always + // be what people want so I'm waiting until I can do it better. + exitingTyper { + req.defines filterNot (s => req.defines contains s.companionSymbol) foreach { newSym => + val oldSym = replScope lookup newSym.name.companionName + if (Seq(oldSym, newSym).permutations exists { case Seq(s1, s2) => s1.isClass && s2.isModule }) { + replwarn(s"warning: previously defined $oldSym is not a companion to $newSym.") + replwarn("Companions must be defined together; you may wish to use :paste mode for this.") + } + } + } + exitingTyper { + req.imports foreach (sym => updateReplScope(sym, isDefined = false)) + req.defines foreach (sym => updateReplScope(sym, isDefined = true)) + } + } + + private[nsc] def replwarn(msg: => String) { + if (!settings.nowarnings) + printMessage(msg) + } + + def compileSourcesKeepingRun(sources: SourceFile*) = { + val run = new Run() + assert(run.typerPhase != NoPhase, "REPL requires a typer phase.") + reporter.reset() + run compileSources sources.toList + (!reporter.hasErrors, run) + } + + /** Compile an nsc SourceFile. Returns true if there are + * no compilation errors, or false otherwise. + */ + def compileSources(sources: SourceFile*): Boolean = + compileSourcesKeepingRun(sources: _*)._1 + + /** Compile a string. Returns true if there are no + * compilation errors, or false otherwise. + */ + def compileString(code: String): Boolean = + compileSources(new BatchSourceFile("<script>", code)) + + /** Build a request from the user. `trees` is `line` after being parsed. + */ + private def buildRequest(line: String, trees: List[Tree]): Request = { + executingRequest = new Request(line, trees) + executingRequest + } + + private def safePos(t: Tree, alt: Int): Int = + try t.pos.start + catch { case _: UnsupportedOperationException => alt } + + // Given an expression like 10 * 10 * 10 we receive the parent tree positioned + // at a '*'. So look at each subtree and find the earliest of all positions. + private def earliestPosition(tree: Tree): Int = { + var pos = Int.MaxValue + tree foreach { t => + pos = math.min(pos, safePos(t, Int.MaxValue)) + } + pos + } + + private def requestFromLine(line: String, synthetic: Boolean): Either[IR.Result, Request] = { + val content = indentCode(line) + val trees = parse(content) match { + case parse.Incomplete => return Left(IR.Incomplete) + case parse.Error => return Left(IR.Error) + case parse.Success(trees) => trees + } + repltrace( + trees map (t => { + // [Eugene to Paul] previously it just said `t map ...` + // because there was an implicit conversion from Tree to a list of Trees + // however Martin and I have removed the conversion + // (it was conflicting with the new reflection API), + // so I had to rewrite this a bit + val subs = t collect { case sub => sub } + subs map (t0 => + " " + safePos(t0, -1) + ": " + t0.shortClass + "\n" + ) mkString "" + }) mkString "\n" + ) + // If the last tree is a bare expression, pinpoint where it begins using the + // AST node position and snap the line off there. Rewrite the code embodied + // by the last tree as a ValDef instead, so we can access the value. + val last = trees.lastOption.getOrElse(EmptyTree) + last match { + case _:Assign => // we don't want to include assignments + case _:TermTree | _:Ident | _:Select => // ... but do want other unnamed terms. + val varName = if (synthetic) freshInternalVarName() else freshUserVarName() + val rewrittenLine = ( + // In theory this would come out the same without the 1-specific test, but + // it's a cushion against any more sneaky parse-tree position vs. code mismatches: + // this way such issues will only arise on multiple-statement repl input lines, + // which most people don't use. + if (trees.size == 1) "val " + varName + " =\n" + content + else { + // The position of the last tree + val lastpos0 = earliestPosition(last) + // Oh boy, the parser throws away parens so "(2+2)" is mispositioned, + // with increasingly hard to decipher positions as we move on to "() => 5", + // (x: Int) => x + 1, and more. So I abandon attempts to finesse and just + // look for semicolons and newlines, which I'm sure is also buggy. + val (raw1, raw2) = content splitAt lastpos0 + repldbg("[raw] " + raw1 + " <---> " + raw2) + + val adjustment = (raw1.reverse takeWhile (ch => (ch != ';') && (ch != '\n'))).size + val lastpos = lastpos0 - adjustment + + // the source code split at the laboriously determined position. + val (l1, l2) = content splitAt lastpos + repldbg("[adj] " + l1 + " <---> " + l2) + + val prefix = if (l1.trim == "") "" else l1 + ";\n" + // Note to self: val source needs to have this precise structure so that + // error messages print the user-submitted part without the "val res0 = " part. + val combined = prefix + "val " + varName + " =\n" + l2 + + repldbg(List( + " line" -> line, + " content" -> content, + " was" -> l2, + "combined" -> combined) map { + case (label, s) => label + ": '" + s + "'" + } mkString "\n" + ) + combined + } + ) + // Rewriting "foo ; bar ; 123" + // to "foo ; bar ; val resXX = 123" + requestFromLine(rewrittenLine, synthetic) match { + case Right(req) => return Right(req withOriginalLine line) + case x => return x + } + case _ => + } + Right(buildRequest(line, trees)) + } + + // dealias non-public types so we don't see protected aliases like Self + def dealiasNonPublic(tp: Type) = tp match { + case TypeRef(_, sym, _) if sym.isAliasType && !sym.isPublic => tp.dealias + case _ => tp + } + + /** + * Interpret one line of input. All feedback, including parse errors + * and evaluation results, are printed via the supplied compiler's + * reporter. Values defined are available for future interpreted strings. + * + * The return value is whether the line was interpreter successfully, + * e.g. that there were no parse errors. + */ + def interpret(line: String): IR.Result = interpret(line, synthetic = false) + def interpretSynthetic(line: String): IR.Result = interpret(line, synthetic = true) + def interpret(line: String, synthetic: Boolean): IR.Result = compile(line, synthetic) match { + case Left(result) => result + case Right(req) => new WrappedRequest(req).loadAndRunReq + } + + private def compile(line: String, synthetic: Boolean): Either[IR.Result, Request] = { + if (global == null) Left(IR.Error) + else requestFromLine(line, synthetic) match { + case Left(result) => Left(result) + case Right(req) => + // null indicates a disallowed statement type; otherwise compile and + // fail if false (implying e.g. a type error) + if (req == null || !req.compile) Left(IR.Error) else Right(req) + } + } + + var code = "" + var bound = false + def compiled(script: String): CompiledScript = { + if (!bound) { + quietBind("engine" -> this.asInstanceOf[ScriptEngine]) + bound = true + } + val cat = code + script + compile(cat, false) match { + case Left(result) => result match { + case IR.Incomplete => { + code = cat + "\n" + new CompiledScript { + def eval(context: ScriptContext): Object = null + def getEngine: ScriptEngine = SparkIMain.this + } + } + case _ => { + code = "" + throw new ScriptException("compile-time error") + } + } + case Right(req) => { + code = "" + new WrappedRequest(req) + } + } + } + + private class WrappedRequest(val req: Request) extends CompiledScript { + var recorded = false + + /** In Java we would have to wrap any checked exception in the declared + * ScriptException. Runtime exceptions and errors would be ok and would + * not need to be caught. So let us do the same in Scala : catch and + * wrap any checked exception, and let runtime exceptions and errors + * escape. We could have wrapped runtime exceptions just like other + * exceptions in ScriptException, this is a choice. + */ + @throws[ScriptException] + def eval(context: ScriptContext): Object = { + val result = req.lineRep.evalEither match { + case Left(e: RuntimeException) => throw e + case Left(e: Exception) => throw new ScriptException(e) + case Left(e) => throw e + case Right(result) => result.asInstanceOf[Object] + } + if (!recorded) { + recordRequest(req) + recorded = true + } + result + } + + def loadAndRunReq = classLoader.asContext { + val (result, succeeded) = req.loadAndRun + + /** To our displeasure, ConsoleReporter offers only printMessage, + * which tacks a newline on the end. Since that breaks all the + * output checking, we have to take one off to balance. + */ + if (succeeded) { + if (printResults && result != "") + printMessage(result stripSuffix "\n") + else if (isReplDebug) // show quiet-mode activity + printMessage(result.trim.lines map ("[quiet] " + _) mkString "\n") + + // Book-keeping. Have to record synthetic requests too, + // as they may have been issued for information, e.g. :type + recordRequest(req) + IR.Success + } + else { + // don't truncate stack traces + printUntruncatedMessage(result) + IR.Error + } + } + + def getEngine: ScriptEngine = SparkIMain.this + } + + /** Bind a specified name to a specified value. The name may + * later be used by expressions passed to interpret. + * + * @param name the variable name to bind + * @param boundType the type of the variable, as a string + * @param value the object value to bind to it + * @return an indication of whether the binding succeeded + */ + def bind(name: String, boundType: String, value: Any, modifiers: List[String] = Nil): IR.Result = { + val bindRep = new ReadEvalPrint() + bindRep.compile(""" + |object %s { + | var value: %s = _ + | def set(x: Any) = value = x.asInstanceOf[%s] + |} + """.stripMargin.format(bindRep.evalName, boundType, boundType) + ) + bindRep.callEither("set", value) match { + case Left(ex) => + repldbg("Set failed in bind(%s, %s, %s)".format(name, boundType, value)) + repldbg(util.stackTraceString(ex)) + IR.Error + + case Right(_) => + val line = "%sval %s = %s.value".format(modifiers map (_ + " ") mkString, name, bindRep.evalPath) + repldbg("Interpreting: " + line) + interpret(line) + } + } + def directBind(name: String, boundType: String, value: Any): IR.Result = { + val result = bind(name, boundType, value) + if (result == IR.Success) + directlyBoundNames += newTermName(name) + result + } + def directBind(p: NamedParam): IR.Result = directBind(p.name, p.tpe, p.value) + def directBind[T: ru.TypeTag : ClassTag](name: String, value: T): IR.Result = directBind((name, value)) + + def rebind(p: NamedParam): IR.Result = { + val name = p.name + val newType = p.tpe + val tempName = freshInternalVarName() + + quietRun("val %s = %s".format(tempName, name)) + quietRun("val %s = %s.asInstanceOf[%s]".format(name, tempName, newType)) + } + def quietBind(p: NamedParam): IR.Result = beQuietDuring(bind(p)) + def bind(p: NamedParam): IR.Result = bind(p.name, p.tpe, p.value) + def bind[T: ru.TypeTag : ClassTag](name: String, value: T): IR.Result = bind((name, value)) + + /** Reset this interpreter, forgetting all user-specified requests. */ + def reset() { + clearExecutionWrapper() + resetClassLoader() + resetAllCreators() + prevRequests.clear() + resetReplScope() + replOutput.dir.clear() + } + + /** This instance is no longer needed, so release any resources + * it is using. The reporter's output gets flushed. + */ + def close() { + reporter.flush() + } + + /** Here is where we: + * + * 1) Read some source code, and put it in the "read" object. + * 2) Evaluate the read object, and put the result in the "eval" object. + * 3) Create a String for human consumption, and put it in the "print" object. + * + * Read! Eval! Print! Some of that not yet centralized here. + */ + class ReadEvalPrint(val lineId: Int) { + def this() = this(freshLineId()) + + val packageName = sessionNames.line + lineId + val readName = sessionNames.read + val evalName = sessionNames.eval + val printName = sessionNames.print + val resultName = sessionNames.result + + def bindError(t: Throwable) = { + if (!bindExceptions) // avoid looping if already binding + throw t + + val unwrapped = unwrap(t) + + // Example input: $line3.$read$$iw$$iw$ + val classNameRegex = (naming.lineRegex + ".*").r + def isWrapperInit(x: StackTraceElement) = cond(x.getClassName) { + case classNameRegex() if x.getMethodName == nme.CONSTRUCTOR.decoded => true + } + val stackTrace = unwrapped stackTracePrefixString (!isWrapperInit(_)) + + withLastExceptionLock[String]({ + directBind[Throwable]("lastException", unwrapped)(StdReplTags.tagOfThrowable, classTag[Throwable]) + stackTrace + }, stackTrace) + } + + // TODO: split it out into a package object and a regular + // object and we can do that much less wrapping. + def packageDecl = "package " + packageName + + def pathTo(name: String) = packageName + "." + name + def packaged(code: String) = packageDecl + "\n\n" + code + + def readPath = pathTo(readName) + def evalPath = pathTo(evalName) + + def call(name: String, args: Any*): AnyRef = { + val m = evalMethod(name) + repldbg("Invoking: " + m) + if (args.nonEmpty) + repldbg(" with args: " + args.mkString(", ")) + + m.invoke(evalClass, args.map(_.asInstanceOf[AnyRef]): _*) + } + + def callEither(name: String, args: Any*): Either[Throwable, AnyRef] = + try Right(call(name, args: _*)) + catch { case ex: Throwable => Left(ex) } + + class EvalException(msg: String, cause: Throwable) extends RuntimeException(msg, cause) { } + + private def evalError(path: String, ex: Throwable) = + throw new EvalException("Failed to load '" + path + "': " + ex.getMessage, ex) + + private def load(path: String): Class[_] = { + try Class.forName(path, true, classLoader) + catch { case ex: Throwable => evalError(path, unwrap(ex)) } + } + + lazy val evalClass = load(evalPath) + + def evalEither = callEither(resultName) match { + case Left(ex) => ex match { + case ex: NullPointerException => Right(null) + case ex => Left(unwrap(ex)) + } + case Right(result) => Right(result) + } + + def compile(source: String): Boolean = compileAndSaveRun("<console>", source) + + /** The innermost object inside the wrapper, found by + * following accessPath into the outer one. + */ + def resolvePathToSymbol(accessPath: String): Symbol = { + val readRoot: global.Symbol = readRootPath(readPath) // the outermost wrapper + ((".INSTANCE" + accessPath) split '.').foldLeft(readRoot: Symbol) { + case (sym, "") => sym + case (sym, name) => exitingTyper(termMember(sym, name)) + } + } + /** We get a bunch of repeated warnings for reasons I haven't + * entirely figured out yet. For now, squash. + */ + private def updateRecentWarnings(run: Run) { + def loop(xs: List[(Position, String)]): List[(Position, String)] = xs match { + case Nil => Nil + case ((pos, msg)) :: rest => + val filtered = rest filter { case (pos0, msg0) => + (msg != msg0) || (pos.lineContent.trim != pos0.lineContent.trim) || { + // same messages and same line content after whitespace removal + // but we want to let through multiple warnings on the same line + // from the same run. The untrimmed line will be the same since + // there's no whitespace indenting blowing it. + (pos.lineContent == pos0.lineContent) + } + } + ((pos, msg)) :: loop(filtered) + } + val warnings = loop(run.reporting.allConditionalWarnings) + if (warnings.nonEmpty) + mostRecentWarnings = warnings + } + private def evalMethod(name: String) = evalClass.getMethods filter (_.getName == name) match { + case Array() => null + case Array(method) => method + case xs => sys.error("Internal error: eval object " + evalClass + ", " + xs.mkString("\n", "\n", "")) + } + private def compileAndSaveRun(label: String, code: String) = { + showCodeIfDebugging(code) + val (success, run) = compileSourcesKeepingRun(new BatchSourceFile(label, packaged(code))) + updateRecentWarnings(run) + success + } + } + + /** One line of code submitted by the user for interpretation */ + class Request(val line: String, val trees: List[Tree]) { + def defines = defHandlers flatMap (_.definedSymbols) + def imports = importedSymbols + def value = Some(handlers.last) filter (h => h.definesValue) map (h => definedSymbols(h.definesTerm.get)) getOrElse NoSymbol + + val lineRep = new ReadEvalPrint() + + private var _originalLine: String = null + def withOriginalLine(s: String): this.type = { _originalLine = s ; this } + def originalLine = if (_originalLine == null) line else _originalLine + + /** handlers for each tree in this request */ + val handlers: List[MemberHandler] = trees map (memberHandlers chooseHandler _) + def defHandlers = handlers collect { case x: MemberDefHandler => x } + + /** list of names used by this expression */ + val referencedNames: List[Name] = handlers flatMap (_.referencedNames) + + /** def and val names */ + def termNames = handlers flatMap (_.definesTerm) + def typeNames = handlers flatMap (_.definesType) + def importedSymbols = handlers flatMap { + case x: ImportHandler => x.importedSymbols + case _ => Nil + } + + val definedClasses = handlers.exists { + case _: ClassHandler => true + case _ => false + } + /** Code to import bound names from previous lines - accessPath is code to + * append to objectName to access anything bound by request. + */ + lazy val ComputedImports(importsPreamble, importsTrailer, accessPath) = + exitingTyper(importsCode(referencedNames.toSet, ObjectSourceCode, definedClasses)) + + /** the line of code to compute */ + def toCompute = line + + /** The path of the value that contains the user code. */ + def fullAccessPath = s"${lineRep.readPath}.INSTANCE$accessPath" + + /** The path of the given member of the wrapping instance. */ + def fullPath(vname: String) = s"$fullAccessPath.`$vname`" + + /** generate the source code for the object that computes this request */ + abstract class Wrapper extends SparkIMain.CodeAssembler[MemberHandler] { + def path = originalPath("$intp") + def envLines = { + if (!isReplPower) Nil // power mode only for now + else List("def %s = %s".format("$line", tquoted(originalLine)), "def %s = Nil".format("$trees")) + } + def preamble = s""" + |$preambleHeader + |%s%s%s + """.stripMargin.format(lineRep.readName, envLines.map(" " + _ + ";\n").mkString, + importsPreamble, indentCode(toCompute)) + + val generate = (m: MemberHandler) => m extraCodeToEvaluate Request.this + + /** A format string with %s for $read, specifying the wrapper definition. */ + def preambleHeader: String + + /** Like preambleHeader for an import wrapper. */ + def prewrap: String = preambleHeader + "\n" + + /** Like postamble for an import wrapper. */ + def postwrap: String + } + + private class ObjectBasedWrapper extends Wrapper { + def preambleHeader = "object %s {" + + def postamble = importsTrailer + "\n}" + + def postwrap = "}\n" + } + + private class ClassBasedWrapper extends Wrapper { + def preambleHeader = "class %s extends Serializable {" + + /** Adds an object that instantiates the outer wrapping class. */ + def postamble = s""" + |$importsTrailer + |} + |object ${lineRep.readName} { + | val INSTANCE = new ${lineRep.readName}(); + |} + |""".stripMargin + + import nme.{ INTERPRETER_IMPORT_WRAPPER => iw } + + /** Adds a val that instantiates the wrapping class. */ + def postwrap = s"}\nval $iw = new $iw\n" + } + + private lazy val ObjectSourceCode: Wrapper = new ClassBasedWrapper + private object ResultObjectSourceCode extends SparkIMain.CodeAssembler[MemberHandler] { + /** We only want to generate this code when the result + * is a value which can be referred to as-is. + */ + val evalResult = Request.this.value match { + case NoSymbol => "" + case sym => + "lazy val %s = %s".format(lineRep.resultName, fullPath(sym.decodedName)) + } + // first line evaluates object to make sure constructor is run + // initial "" so later code can uniformly be: + etc + val preamble = """ + |object %s { + | %s + | lazy val %s: String = %s { + | %s + | ("" + """.stripMargin.format( + lineRep.evalName, evalResult, lineRep.printName, + executionWrapper, fullAccessPath + ) + + val postamble = """ + | ) + | } + |} + """.stripMargin + val generate = (m: MemberHandler) => m resultExtractionCode Request.this + } + + /** Compile the object file. Returns whether the compilation succeeded. + * If all goes well, the "types" map is computed. */ + lazy val compile: Boolean = { + // error counting is wrong, hence interpreter may overlook failure - so we reset + reporter.reset() + + // compile the object containing the user's code + lineRep.compile(ObjectSourceCode(handlers)) && { + // extract and remember types + typeOf + typesOfDefinedTerms + + // Assign symbols to the original trees + // TODO - just use the new trees. + defHandlers foreach { dh => + val name = dh.member.name + definedSymbols get name foreach { sym => + dh.member setSymbol sym + repldbg("Set symbol of " + name + " to " + symbolDefString(sym)) + } + } + + // compile the result-extraction object + val handls = if (printResults) handlers else Nil + withoutWarnings(lineRep compile ResultObjectSourceCode(handls)) + } + } + + lazy val resultSymbol = lineRep.resolvePathToSymbol(accessPath) + + def applyToResultMember[T](name: Name, f: Symbol => T) = exitingTyper(f(resultSymbol.info.nonPrivateDecl(name))) + + /* typeOf lookup with encoding */ + def lookupTypeOf(name: Name) = typeOf.getOrElse(name, typeOf(global.encode(name.toString))) + + private def typeMap[T](f: Type => T) = + mapFrom[Name, Name, T](termNames ++ typeNames)(x => f(cleanMemberDecl(resultSymbol, x))) + + /** Types of variables defined by this request. */ + lazy val compilerTypeOf = typeMap[Type](x => x) withDefaultValue NoType + /** String representations of same. */ + lazy val typeOf = typeMap[String](tp => exitingTyper(tp.toString)) + + lazy val definedSymbols = ( + termNames.map(x => x -> applyToResultMember(x, x => x)) ++ + typeNames.map(x => x -> compilerTypeOf(x).typeSymbolDirect) + ).toMap[Name, Symbol] withDefaultValue NoSymbol + + lazy val typesOfDefinedTerms = mapFrom[Name, Name, Type](termNames)(x => applyToResultMember(x, _.tpe)) + + /** load and run the code using reflection */ + def loadAndRun: (String, Boolean) = { + try { ("" + (lineRep call sessionNames.print), true) } + catch { case ex: Throwable => (lineRep.bindError(ex), false) } + } + + override def toString = "Request(line=%s, %s trees)".format(line, trees.size) + } + + def createBindings: Bindings = new IBindings { + override def put(name: String, value: Object): Object = { + val n = name.indexOf(":") + val p: NamedParam = if (n < 0) (name, value) else { + val nme = name.substring(0, n).trim + val tpe = name.substring(n + 1).trim + NamedParamClass(nme, tpe, value) + } + if (!p.name.startsWith("javax.script")) bind(p) + null + } + } + + @throws[ScriptException] + def compile(script: String): CompiledScript = eval("new javax.script.CompiledScript { def eval(context: javax.script.ScriptContext): Object = { " + script + " }.asInstanceOf[Object]; def getEngine: javax.script.ScriptEngine = engine }").asInstanceOf[CompiledScript] + + @throws[ScriptException] + def compile(reader: java.io.Reader): CompiledScript = compile(stringFromReader(reader)) + + @throws[ScriptException] + def eval(script: String, context: ScriptContext): Object = compiled(script).eval(context) + + @throws[ScriptException] + def eval(reader: java.io.Reader, context: ScriptContext): Object = eval(stringFromReader(reader), context) + + override def finalize = close + + /** Returns the name of the most recent interpreter result. + * Mostly this exists so you can conveniently invoke methods on + * the previous result. + */ + def mostRecentVar: String = + if (mostRecentlyHandledTree.isEmpty) "" + else "" + (mostRecentlyHandledTree.get match { + case x: ValOrDefDef => x.name + case Assign(Ident(name), _) => name + case ModuleDef(_, name, _) => name + case _ => naming.mostRecentVar + }) + + private var mostRecentWarnings: List[(global.Position, String)] = Nil + def lastWarnings = mostRecentWarnings + + private lazy val importToGlobal = global mkImporter ru + private lazy val importToRuntime = ru.internal createImporter global + private lazy val javaMirror = ru.rootMirror match { + case x: ru.JavaMirror => x + case _ => null + } + private implicit def importFromRu(sym: ru.Symbol): Symbol = importToGlobal importSymbol sym + private implicit def importToRu(sym: Symbol): ru.Symbol = importToRuntime importSymbol sym + + def classOfTerm(id: String): Option[JClass] = symbolOfTerm(id) match { + case NoSymbol => None + case sym => Some(javaMirror runtimeClass importToRu(sym).asClass) + } + + def typeOfTerm(id: String): Type = symbolOfTerm(id).tpe + + def valueOfTerm(id: String): Option[Any] = exitingTyper { + def value() = { + val sym0 = symbolOfTerm(id) + val sym = (importToRuntime importSymbol sym0).asTerm + val module = runtimeMirror.reflectModule(sym.owner.companionSymbol.asModule).instance + val module1 = runtimeMirror.reflect(module) + val invoker = module1.reflectField(sym) + + invoker.get + } + + try Some(value()) catch { case _: Exception => None } + } + + /** It's a bit of a shotgun approach, but for now we will gain in + * robustness. Try a symbol-producing operation at phase typer, and + * if that is NoSymbol, try again at phase flatten. I'll be able to + * lose this and run only from exitingTyper as soon as I figure out + * exactly where a flat name is sneaking in when calculating imports. + */ + def tryTwice(op: => Symbol): Symbol = exitingTyper(op) orElse exitingFlatten(op) + + def symbolOfIdent(id: String): Symbol = symbolOfType(id) orElse symbolOfTerm(id) + def symbolOfType(id: String): Symbol = tryTwice(replScope lookup (id: TypeName)) + def symbolOfTerm(id: String): Symbol = tryTwice(replScope lookup (id: TermName)) + def symbolOfName(id: Name): Symbol = replScope lookup id + + def runtimeClassAndTypeOfTerm(id: String): Option[(JClass, Type)] = { + classOfTerm(id) flatMap { clazz => + clazz.supers find (!_.isScalaAnonymous) map { nonAnon => + (nonAnon, runtimeTypeOfTerm(id)) + } + } + } + + def runtimeTypeOfTerm(id: String): Type = { + typeOfTerm(id) andAlso { tpe => + val clazz = classOfTerm(id) getOrElse { return NoType } + val staticSym = tpe.typeSymbol + val runtimeSym = getClassIfDefined(clazz.getName) + + if ((runtimeSym != NoSymbol) && (runtimeSym != staticSym) && (runtimeSym isSubClass staticSym)) + runtimeSym.info + else NoType + } + } + + def cleanTypeAfterTyper(sym: => Symbol): Type = { + exitingTyper( + dealiasNonPublic( + dropNullaryMethod( + sym.tpe_* + ) + ) + ) + } + def cleanMemberDecl(owner: Symbol, member: Name): Type = + cleanTypeAfterTyper(owner.info nonPrivateDecl member) + + object exprTyper extends { + val repl: SparkIMain.this.type = imain + } with SparkExprTyper { } + + /** Parse a line into and return parsing result (error, incomplete or success with list of trees) */ + object parse { + abstract sealed class Result + case object Error extends Result + case object Incomplete extends Result + case class Success(trees: List[Tree]) extends Result + + def apply(line: String): Result = debugging(s"""parse("$line")""") { + var isIncomplete = false + currentRun.reporting.withIncompleteHandler((_, _) => isIncomplete = true) { + reporter.reset() + val trees = newUnitParser(line).parseStats() + if (reporter.hasErrors) Error + else if (isIncomplete) Incomplete + else Success(trees) + } + } + } + + def symbolOfLine(code: String): Symbol = + exprTyper.symbolOfLine(code) + + def typeOfExpression(expr: String, silent: Boolean = true): Type = + exprTyper.typeOfExpression(expr, silent) + + protected def onlyTerms(xs: List[Name]): List[TermName] = xs collect { case x: TermName => x } + protected def onlyTypes(xs: List[Name]): List[TypeName] = xs collect { case x: TypeName => x } + + def definedTerms = onlyTerms(allDefinedNames) filterNot isInternalTermName + def definedTypes = onlyTypes(allDefinedNames) + def definedSymbolList = prevRequestList flatMap (_.defines) filterNot (s => isInternalTermName(s.name)) + + // Terms with user-given names (i.e. not res0 and not synthetic) + def namedDefinedTerms = definedTerms filterNot (x => isUserVarName("" + x) || directlyBoundNames(x)) + + private var _replScope: Scope = _ + private def resetReplScope() { + _replScope = newScope + } + def replScope = { + if (_replScope eq null) + _replScope = newScope + + _replScope + } + + private var executingRequest: Request = _ + private val prevRequests = mutable.ListBuffer[Request]() + private val directlyBoundNames = mutable.Set[Name]() + + def allHandlers = prevRequestList flatMap (_.handlers) + def lastRequest = if (prevRequests.isEmpty) null else prevRequests.last + def prevRequestList = prevRequests.toList + def importHandlers = allHandlers collect { case x: ImportHandler => x } + + def withoutUnwrapping(op: => Unit): Unit = { + val saved = isettings.unwrapStrings + isettings.unwrapStrings = false + try op + finally isettings.unwrapStrings = saved + } + + def symbolDefString(sym: Symbol) = { + TypeStrings.quieter( + exitingTyper(sym.defString), + sym.owner.name + ".this.", + sym.owner.fullName + "." + ) + } + + def showCodeIfDebugging(code: String) { + /** Secret bookcase entrance for repl debuggers: end the line + * with "// show" and see what's going on. + */ + def isShow = code.lines exists (_.trim endsWith "// show") + if (isReplDebug || isShow) { + beSilentDuring(parse(code)) match { + case parse.Success(ts) => + ts foreach { t => + withoutUnwrapping(echo(asCompactString(t))) + } + case _ => + } + } + } + + // debugging + def debugging[T](msg: String)(res: T) = { + repldbg(msg + " " + res) + res + } +} + +/** Utility methods for the Interpreter. */ +object SparkIMain { + import java.util.Arrays.{ asList => asJavaList } + + // The two name forms this is catching are the two sides of this assignment: + // + // $line3.$read.$iw.$iw.Bippy = + // $line3.$read$$iw$$iw$Bippy@4a6a00ca + private def removeLineWrapper(s: String) = s.replaceAll("""\$line\d+[./]\$(read|eval|print)[$.]""", "") + private def removeIWPackages(s: String) = s.replaceAll("""\$(iw|read|eval|print)[$.]""", "") + private def removeSparkVals(s: String) = s.replaceAll("""\$VAL[0-9]+[$.]""", "") + def stripString(s: String) = removeSparkVals(removeIWPackages(removeLineWrapper(s))) + + trait CodeAssembler[T] { + def preamble: String + def generate: T => String + def postamble: String + + def apply(contributors: List[T]): String = stringFromWriter { code => + code println preamble + contributors map generate foreach (code println _) + code println postamble + } + } + + trait StrippingWriter { + def isStripping: Boolean + def stripImpl(str: String): String + def strip(str: String): String = if (isStripping) stripImpl(str) else str + } + trait TruncatingWriter { + def maxStringLength: Int + def isTruncating: Boolean + def truncate(str: String): String = { + if (isTruncating && (maxStringLength != 0 && str.length > maxStringLength)) + (str take maxStringLength - 3) + "..." + else str + } + } + abstract class StrippingTruncatingWriter(out: JPrintWriter) + extends JPrintWriter(out) + with StrippingWriter + with TruncatingWriter { + self => + + def clean(str: String): String = truncate(strip(str)) + override def write(str: String) = super.write(clean(str)) + } + class SparkReplStrippingWriter(intp: SparkIMain) extends StrippingTruncatingWriter(intp.out) { + import intp._ + def maxStringLength = isettings.maxPrintString + def isStripping = isettings.unwrapStrings + def isTruncating = reporter.truncationOK + + def stripImpl(str: String): String = naming.unmangle(str) + } +} + +/** Settings for the interpreter + * + * @version 1.0 + * @author Lex Spoon, 2007/3/24 + **/ +class SparkISettings(intp: SparkIMain) { + /** The maximum length of toString to use when printing the result + * of an evaluation. 0 means no maximum. If a printout requires + * more than this number of characters, then the printout is + * truncated. + */ + var maxPrintString = replProps.maxPrintString.option.getOrElse(800) + + /** The maximum number of completion candidates to print for tab + * completion without requiring confirmation. + */ + var maxAutoprintCompletion = 250 + + /** String unwrapping can be disabled if it is causing issues. + * Setting this to false means you will see Strings like "$iw.$iw.". + */ + var unwrapStrings = true + + def deprecation_=(x: Boolean) = { + val old = intp.settings.deprecation.value + intp.settings.deprecation.value = x + if (!old && x) println("Enabled -deprecation output.") + else if (old && !x) println("Disabled -deprecation output.") + } + def deprecation: Boolean = intp.settings.deprecation.value + + def allSettings = Map[String, Any]( + "maxPrintString" -> maxPrintString, + "maxAutoprintCompletion" -> maxAutoprintCompletion, + "unwrapStrings" -> unwrapStrings, + "deprecation" -> deprecation + ) + + private def allSettingsString = + allSettings.toList sortBy (_._1) map { case (k, v) => " " + k + " = " + v + "\n" } mkString + + override def toString = """ + | SparkISettings { + | %s + | }""".stripMargin.format(allSettingsString) +} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkImports.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkImports.scala new file mode 100644 index 0000000000..e60406d1e5 --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkImports.scala @@ -0,0 +1,201 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2013 LAMP/EPFL + * @author Paul Phillips + */ + +package scala.tools.nsc +package interpreter + +import scala.collection.{ mutable, immutable } + +trait SparkImports { + self: SparkIMain => + + import global._ + import definitions.{ ObjectClass, ScalaPackage, JavaLangPackage, PredefModule } + import memberHandlers._ + + /** Synthetic import handlers for the language defined imports. */ + private def makeWildcardImportHandler(sym: Symbol): ImportHandler = { + val hd :: tl = sym.fullName.split('.').toList map newTermName + val tree = Import( + tl.foldLeft(Ident(hd): Tree)((x, y) => Select(x, y)), + ImportSelector.wildList + ) + tree setSymbol sym + new ImportHandler(tree) + } + + /** Symbols whose contents are language-defined to be imported. */ + def languageWildcardSyms: List[Symbol] = List(JavaLangPackage, ScalaPackage, PredefModule) + def languageWildcardHandlers = languageWildcardSyms map makeWildcardImportHandler + + def allImportedNames = importHandlers flatMap (_.importedNames) + + /** Types which have been wildcard imported, such as: + * val x = "abc" ; import x._ // type java.lang.String + * import java.lang.String._ // object java.lang.String + * + * Used by tab completion. + * + * XXX right now this gets import x._ and import java.lang.String._, + * but doesn't figure out import String._. There's a lot of ad hoc + * scope twiddling which should be swept away in favor of digging + * into the compiler scopes. + */ + def sessionWildcards: List[Type] = { + importHandlers filter (_.importsWildcard) map (_.targetType) distinct + } + + def languageSymbols = languageWildcardSyms flatMap membersAtPickler + def sessionImportedSymbols = importHandlers flatMap (_.importedSymbols) + def importedSymbols = languageSymbols ++ sessionImportedSymbols + def importedTermSymbols = importedSymbols collect { case x: TermSymbol => x } + + /** Tuples of (source, imported symbols) in the order they were imported. + */ + def importedSymbolsBySource: List[(Symbol, List[Symbol])] = { + val lang = languageWildcardSyms map (sym => (sym, membersAtPickler(sym))) + val session = importHandlers filter (_.targetType != NoType) map { mh => + (mh.targetType.typeSymbol, mh.importedSymbols) + } + + lang ++ session + } + def implicitSymbolsBySource: List[(Symbol, List[Symbol])] = { + importedSymbolsBySource map { + case (k, vs) => (k, vs filter (_.isImplicit)) + } filterNot (_._2.isEmpty) + } + + /** Compute imports that allow definitions from previous + * requests to be visible in a new request. Returns + * three pieces of related code: + * + * 1. An initial code fragment that should go before + * the code of the new request. + * + * 2. A code fragment that should go after the code + * of the new request. + * + * 3. An access path which can be traversed to access + * any bindings inside code wrapped by #1 and #2 . + * + * The argument is a set of Names that need to be imported. + * + * Limitations: This method is not as precise as it could be. + * (1) It does not process wildcard imports to see what exactly + * they import. + * (2) If it imports any names from a request, it imports all + * of them, which is not really necessary. + * (3) It imports multiple same-named implicits, but only the + * last one imported is actually usable. + */ + case class ComputedImports(prepend: String, append: String, access: String) + protected def importsCode(wanted: Set[Name], wrapper: Request#Wrapper, definedClass: Boolean): ComputedImports = { + /** Narrow down the list of requests from which imports + * should be taken. Removes requests which cannot contribute + * useful imports for the specified set of wanted names. + */ + case class ReqAndHandler(req: Request, handler: MemberHandler) { } + + def reqsToUse: List[ReqAndHandler] = { + /** Loop through a list of MemberHandlers and select which ones to keep. + * 'wanted' is the set of names that need to be imported. + */ + def select(reqs: List[ReqAndHandler], wanted: Set[Name]): List[ReqAndHandler] = { + // Single symbol imports might be implicits! See bug #1752. Rather than + // try to finesse this, we will mimic all imports for now. + def keepHandler(handler: MemberHandler) = handler match { + case h: ImportHandler if definedClass => h.importedNames.exists(x => wanted.contains(x)) + case _: ImportHandler => true + case x => x.definesImplicit || (x.definedNames exists wanted) + } + + reqs match { + case Nil => Nil + case rh :: rest if !keepHandler(rh.handler) => select(rest, wanted) + case rh :: rest => + import rh.handler._ + val newWanted = wanted ++ referencedNames -- definedNames -- importedNames + rh :: select(rest, newWanted) + } + } + + /** Flatten the handlers out and pair each with the original request */ + select(allReqAndHandlers reverseMap { case (r, h) => ReqAndHandler(r, h) }, wanted).reverse + } + + val code, trailingBraces, accessPath = new StringBuilder + val currentImps = mutable.HashSet[Name]() + + // add code for a new object to hold some imports + def addWrapper() { + import nme.{ INTERPRETER_IMPORT_WRAPPER => iw } + code append (wrapper.prewrap format iw) + trailingBraces append wrapper.postwrap + accessPath append s".$iw" + currentImps.clear() + } + + def maybeWrap(names: Name*) = if (names exists currentImps) addWrapper() + + def wrapBeforeAndAfter[T](op: => T): T = { + addWrapper() + try op finally addWrapper() + } + + // loop through previous requests, adding imports for each one + wrapBeforeAndAfter { + for (ReqAndHandler(req, handler) <- reqsToUse) { + handler match { + // If the user entered an import, then just use it; add an import wrapping + // level if the import might conflict with some other import + case x: ImportHandler if x.importsWildcard => + wrapBeforeAndAfter(code append (x.member + "\n")) + case x: ImportHandler => + maybeWrap(x.importedNames: _*) + code append (x.member + "\n") + currentImps ++= x.importedNames + + case x: ClassHandler => + for (imv <- x.definedNames) { + val objName = req.lineRep.readPath + code.append("import " + objName + ".INSTANCE" + req.accessPath + ".`" + imv + "`\n") + } + + // For other requests, import each defined name. + // import them explicitly instead of with _, so that + // ambiguity errors will not be generated. Also, quote + // the name of the variable, so that we don't need to + // handle quoting keywords separately. + case x => + for (imv <- x.definedNames) { + if (currentImps contains imv) addWrapper() + val objName = req.lineRep.readPath + val valName = "$VAL" + newValId() + if(!code.toString.endsWith(".`" + imv + "`;\n")) { // Which means already imported + code.append("val " + valName + " = " + objName + ".INSTANCE\n") + code.append("import " + valName + req.accessPath + ".`" + imv + "`;\n") + } + currentImps += imv + } + } + } + } + addWrapper() + ComputedImports(code.toString, trailingBraces.toString, accessPath.toString) + } + private var curValId = 0 + + private def newValId(): Int = { + curValId += 1 + curValId + } + + private def allReqAndHandlers = + prevRequestList flatMap (req => req.handlers map (req -> _)) + + private def membersAtPickler(sym: Symbol): List[Symbol] = + enteringPickler(sym.info.nonPrivateMembers.toList) +} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala new file mode 100644 index 0000000000..7fe6dcb328 --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala @@ -0,0 +1,350 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2013 LAMP/EPFL + * @author Paul Phillips + */ + +package scala.tools.nsc +package interpreter + +import Completion._ +import scala.collection.mutable.ListBuffer +import scala.reflect.internal.util.StringOps.longestCommonPrefix + +// REPL completor - queries supplied interpreter for valid +// completions based on current contents of buffer. +class SparkJLineCompletion(val intp: SparkIMain) extends Completion with CompletionOutput { + val global: intp.global.type = intp.global + import global._ + import definitions._ + import rootMirror.{ RootClass, getModuleIfDefined } + import intp.{ debugging } + + // verbosity goes up with consecutive tabs + private var verbosity: Int = 0 + def resetVerbosity() = verbosity = 0 + + def getSymbol(name: String, isModule: Boolean) = ( + if (isModule) getModuleIfDefined(name) + else getModuleIfDefined(name) + ) + + trait CompilerCompletion { + def tp: Type + def effectiveTp = tp match { + case MethodType(Nil, resType) => resType + case NullaryMethodType(resType) => resType + case _ => tp + } + + // for some reason any's members don't show up in subclasses, which + // we need so 5.<tab> offers asInstanceOf etc. + private def anyMembers = AnyTpe.nonPrivateMembers + def anyRefMethodsToShow = Set("isInstanceOf", "asInstanceOf", "toString") + + def tos(sym: Symbol): String = sym.decodedName + def memberNamed(s: String) = exitingTyper(effectiveTp member newTermName(s)) + + // XXX we'd like to say "filterNot (_.isDeprecated)" but this causes the + // compiler to crash for reasons not yet known. + def members = exitingTyper((effectiveTp.nonPrivateMembers.toList ++ anyMembers) filter (_.isPublic)) + def methods = members.toList filter (_.isMethod) + def packages = members.toList filter (_.hasPackageFlag) + def aliases = members.toList filter (_.isAliasType) + + def memberNames = members map tos + def methodNames = methods map tos + def packageNames = packages map tos + def aliasNames = aliases map tos + } + + object NoTypeCompletion extends TypeMemberCompletion(NoType) { + override def memberNamed(s: String) = NoSymbol + override def members = Nil + override def follow(s: String) = None + override def alternativesFor(id: String) = Nil + } + + object TypeMemberCompletion { + def apply(tp: Type, runtimeType: Type, param: NamedParam): TypeMemberCompletion = { + new TypeMemberCompletion(tp) { + var upgraded = false + lazy val upgrade = { + intp rebind param + intp.reporter.printMessage("\nRebinding stable value %s from %s to %s".format(param.name, tp, param.tpe)) + upgraded = true + new TypeMemberCompletion(runtimeType) + } + override def completions(verbosity: Int) = { + super.completions(verbosity) ++ ( + if (verbosity == 0) Nil + else upgrade.completions(verbosity) + ) + } + override def follow(s: String) = super.follow(s) orElse { + if (upgraded) upgrade.follow(s) + else None + } + override def alternativesFor(id: String) = super.alternativesFor(id) ++ ( + if (upgraded) upgrade.alternativesFor(id) + else Nil + ) distinct + } + } + def apply(tp: Type): TypeMemberCompletion = { + if (tp eq NoType) NoTypeCompletion + else if (tp.typeSymbol.isPackageClass) new PackageCompletion(tp) + else new TypeMemberCompletion(tp) + } + def imported(tp: Type) = new ImportCompletion(tp) + } + + class TypeMemberCompletion(val tp: Type) extends CompletionAware + with CompilerCompletion { + def excludeEndsWith: List[String] = Nil + def excludeStartsWith: List[String] = List("<") // <byname>, <repeated>, etc. + def excludeNames: List[String] = (anyref.methodNames filterNot anyRefMethodsToShow) :+ "_root_" + + def methodSignatureString(sym: Symbol) = { + IMain stripString exitingTyper(new MethodSymbolOutput(sym).methodString()) + } + + def exclude(name: String): Boolean = ( + (name contains "$") || + (excludeNames contains name) || + (excludeEndsWith exists (name endsWith _)) || + (excludeStartsWith exists (name startsWith _)) + ) + def filtered(xs: List[String]) = xs filterNot exclude distinct + + def completions(verbosity: Int) = + debugging(tp + " completions ==> ")(filtered(memberNames)) + + override def follow(s: String): Option[CompletionAware] = + debugging(tp + " -> '" + s + "' ==> ")(Some(TypeMemberCompletion(memberNamed(s).tpe)) filterNot (_ eq NoTypeCompletion)) + + override def alternativesFor(id: String): List[String] = + debugging(id + " alternatives ==> ") { + val alts = members filter (x => x.isMethod && tos(x) == id) map methodSignatureString + + if (alts.nonEmpty) "" :: alts else Nil + } + + override def toString = "%s (%d members)".format(tp, members.size) + } + + class PackageCompletion(tp: Type) extends TypeMemberCompletion(tp) { + override def excludeNames = anyref.methodNames + } + + class LiteralCompletion(lit: Literal) extends TypeMemberCompletion(lit.value.tpe) { + override def completions(verbosity: Int) = verbosity match { + case 0 => filtered(memberNames) + case _ => memberNames + } + } + + class ImportCompletion(tp: Type) extends TypeMemberCompletion(tp) { + override def completions(verbosity: Int) = verbosity match { + case 0 => filtered(members filterNot (_.isSetter) map tos) + case _ => super.completions(verbosity) + } + } + + // not for completion but for excluding + object anyref extends TypeMemberCompletion(AnyRefTpe) { } + + // the unqualified vals/defs/etc visible in the repl + object ids extends CompletionAware { + override def completions(verbosity: Int) = intp.unqualifiedIds ++ List("classOf") //, "_root_") + // now we use the compiler for everything. + override def follow(id: String): Option[CompletionAware] = { + if (!completions(0).contains(id)) + return None + + val tpe = intp typeOfExpression id + if (tpe == NoType) + return None + + def default = Some(TypeMemberCompletion(tpe)) + + // only rebinding vals in power mode for now. + if (!isReplPower) default + else intp runtimeClassAndTypeOfTerm id match { + case Some((clazz, runtimeType)) => + val sym = intp.symbolOfTerm(id) + if (sym.isStable) { + val param = new NamedParam.Untyped(id, intp valueOfTerm id getOrElse null) + Some(TypeMemberCompletion(tpe, runtimeType, param)) + } + else default + case _ => + default + } + } + override def toString = "<repl ids> (%s)".format(completions(0).size) + } + + // user-issued wildcard imports like "import global._" or "import String._" + private def imported = intp.sessionWildcards map TypeMemberCompletion.imported + + // literal Ints, Strings, etc. + object literals extends CompletionAware { + def simpleParse(code: String): Option[Tree] = newUnitParser(code).parseStats().lastOption + def completions(verbosity: Int) = Nil + + override def follow(id: String) = simpleParse(id).flatMap { + case x: Literal => Some(new LiteralCompletion(x)) + case _ => None + } + } + + // top level packages + object rootClass extends TypeMemberCompletion(RootClass.tpe) { + override def completions(verbosity: Int) = super.completions(verbosity) :+ "_root_" + override def follow(id: String) = id match { + case "_root_" => Some(this) + case _ => super.follow(id) + } + } + // members of Predef + object predef extends TypeMemberCompletion(PredefModule.tpe) { + override def excludeEndsWith = super.excludeEndsWith ++ List("Wrapper", "ArrayOps") + override def excludeStartsWith = super.excludeStartsWith ++ List("wrap") + override def excludeNames = anyref.methodNames + + override def exclude(name: String) = super.exclude(name) || ( + (name contains "2") + ) + + override def completions(verbosity: Int) = verbosity match { + case 0 => Nil + case _ => super.completions(verbosity) + } + } + // members of scala.* + object scalalang extends PackageCompletion(ScalaPackage.tpe) { + def arityClasses = List("Product", "Tuple", "Function") + def skipArity(name: String) = arityClasses exists (x => name != x && (name startsWith x)) + override def exclude(name: String) = super.exclude(name) || ( + skipArity(name) + ) + + override def completions(verbosity: Int) = verbosity match { + case 0 => filtered(packageNames ++ aliasNames) + case _ => super.completions(verbosity) + } + } + // members of java.lang.* + object javalang extends PackageCompletion(JavaLangPackage.tpe) { + override lazy val excludeEndsWith = super.excludeEndsWith ++ List("Exception", "Error") + override lazy val excludeStartsWith = super.excludeStartsWith ++ List("CharacterData") + + override def completions(verbosity: Int) = verbosity match { + case 0 => filtered(packageNames) + case _ => super.completions(verbosity) + } + } + + // the list of completion aware objects which should be consulted + // for top level unqualified, it's too noisy to let much in. + lazy val topLevelBase: List[CompletionAware] = List(ids, rootClass, predef, scalalang, javalang, literals) + def topLevel = topLevelBase ++ imported + def topLevelThreshold = 50 + + // the first tier of top level objects (doesn't include file completion) + def topLevelFor(parsed: Parsed): List[String] = { + val buf = new ListBuffer[String] + topLevel foreach { ca => + buf ++= (ca completionsFor parsed) + + if (buf.size > topLevelThreshold) + return buf.toList.sorted + } + buf.toList + } + + // the most recent result + def lastResult = Forwarder(() => ids follow intp.mostRecentVar) + + def lastResultFor(parsed: Parsed) = { + /** The logic is a little tortured right now because normally '.' is + * ignored as a delimiter, but on .<tab> it needs to be propagated. + */ + val xs = lastResult completionsFor parsed + if (parsed.isEmpty) xs map ("." + _) else xs + } + + def completer(): ScalaCompleter = new JLineTabCompletion + + /** This gets a little bit hairy. It's no small feat delegating everything + * and also keeping track of exactly where the cursor is and where it's supposed + * to end up. The alternatives mechanism is a little hacky: if there is an empty + * string in the list of completions, that means we are expanding a unique + * completion, so don't update the "last" buffer because it'll be wrong. + */ + class JLineTabCompletion extends ScalaCompleter { + // For recording the buffer on the last tab hit + private var lastBuf: String = "" + private var lastCursor: Int = -1 + + // Does this represent two consecutive tabs? + def isConsecutiveTabs(buf: String, cursor: Int) = + cursor == lastCursor && buf == lastBuf + + // This is jline's entry point for completion. + override def complete(buf: String, cursor: Int): Candidates = { + verbosity = if (isConsecutiveTabs(buf, cursor)) verbosity + 1 else 0 + repldbg(f"%ncomplete($buf, $cursor%d) last = ($lastBuf, $lastCursor%d), verbosity: $verbosity") + + // we don't try lower priority completions unless higher ones return no results. + def tryCompletion(p: Parsed, completionFunction: Parsed => List[String]): Option[Candidates] = { + val winners = completionFunction(p) + if (winners.isEmpty) + return None + val newCursor = + if (winners contains "") p.cursor + else { + val advance = longestCommonPrefix(winners) + lastCursor = p.position + advance.length + lastBuf = (buf take p.position) + advance + repldbg(s"tryCompletion($p, _) lastBuf = $lastBuf, lastCursor = $lastCursor, p.position = ${p.position}") + p.position + } + + Some(Candidates(newCursor, winners)) + } + + def mkDotted = Parsed.dotted(buf, cursor) withVerbosity verbosity + + // a single dot is special cased to completion on the previous result + def lastResultCompletion = + if (!looksLikeInvocation(buf)) None + else tryCompletion(Parsed.dotted(buf drop 1, cursor), lastResultFor) + + def tryAll = ( + lastResultCompletion + orElse tryCompletion(mkDotted, topLevelFor) + getOrElse Candidates(cursor, Nil) + ) + + /** + * This is the kickoff point for all manner of theoretically + * possible compiler unhappiness. The fault may be here or + * elsewhere, but we don't want to crash the repl regardless. + * The compiler makes it impossible to avoid catching Throwable + * with its unfortunate tendency to throw java.lang.Errors and + * AssertionErrors as the hats drop. We take two swings at it + * because there are some spots which like to throw an assertion + * once, then work after that. Yeah, what can I say. + */ + try tryAll + catch { case ex: Throwable => + repldbg("Error: complete(%s, %s) provoked".format(buf, cursor) + ex) + Candidates(cursor, + if (isReplDebug) List("<error:" + ex + ">") + else Nil + ) + } + } + } +} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala new file mode 100644 index 0000000000..0e22bc806d --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala @@ -0,0 +1,221 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2013 LAMP/EPFL + * @author Martin Odersky + */ + +package scala.tools.nsc +package interpreter + +import scala.collection.{ mutable, immutable } +import scala.language.implicitConversions + +trait SparkMemberHandlers { + val intp: SparkIMain + + import intp.{ Request, global, naming } + import global._ + import naming._ + + private def codegenln(leadingPlus: Boolean, xs: String*): String = codegen(leadingPlus, (xs ++ Array("\n")): _*) + private def codegenln(xs: String*): String = codegenln(true, xs: _*) + private def codegen(leadingPlus: Boolean, xs: String*): String = { + val front = if (leadingPlus) "+ " else "" + front + (xs map string2codeQuoted mkString " + ") + } + private implicit def name2string(name: Name) = name.toString + + /** A traverser that finds all mentioned identifiers, i.e. things + * that need to be imported. It might return extra names. + */ + private class ImportVarsTraverser extends Traverser { + val importVars = new mutable.HashSet[Name]() + + override def traverse(ast: Tree) = ast match { + case Ident(name) => + // XXX this is obviously inadequate but it's going to require some effort + // to get right. + if (name.toString startsWith "x$") () + else importVars += name + case _ => super.traverse(ast) + } + } + private object ImportVarsTraverser { + def apply(member: Tree) = { + val ivt = new ImportVarsTraverser() + ivt traverse member + ivt.importVars.toList + } + } + + private def isTermMacro(ddef: DefDef): Boolean = ddef.mods.isMacro + + def chooseHandler(member: Tree): MemberHandler = member match { + case member: DefDef if isTermMacro(member) => new TermMacroHandler(member) + case member: DefDef => new DefHandler(member) + case member: ValDef => new ValHandler(member) + case member: ModuleDef => new ModuleHandler(member) + case member: ClassDef => new ClassHandler(member) + case member: TypeDef => new TypeAliasHandler(member) + case member: Assign => new AssignHandler(member) + case member: Import => new ImportHandler(member) + case DocDef(_, documented) => chooseHandler(documented) + case member => new GenericHandler(member) + } + + sealed abstract class MemberDefHandler(override val member: MemberDef) extends MemberHandler(member) { + override def name: Name = member.name + def mods: Modifiers = member.mods + def keyword = member.keyword + def prettyName = name.decode + + override def definesImplicit = member.mods.isImplicit + override def definesTerm: Option[TermName] = Some(name.toTermName) filter (_ => name.isTermName) + override def definesType: Option[TypeName] = Some(name.toTypeName) filter (_ => name.isTypeName) + override def definedSymbols = if (symbol.exists) symbol :: Nil else Nil + } + + /** Class to handle one member among all the members included + * in a single interpreter request. + */ + sealed abstract class MemberHandler(val member: Tree) { + def name: Name = nme.NO_NAME + def path = intp.originalPath(symbol).replaceFirst("read", "read.INSTANCE") + def symbol = if (member.symbol eq null) NoSymbol else member.symbol + def definesImplicit = false + def definesValue = false + + def definesTerm = Option.empty[TermName] + def definesType = Option.empty[TypeName] + + private lazy val _referencedNames = ImportVarsTraverser(member) + def referencedNames = _referencedNames + def importedNames = List[Name]() + def definedNames = definesTerm.toList ++ definesType.toList + def definedSymbols = List[Symbol]() + + def extraCodeToEvaluate(req: Request): String = "" + def resultExtractionCode(req: Request): String = "" + + private def shortName = this.getClass.toString split '.' last + override def toString = shortName + referencedNames.mkString(" (refs: ", ", ", ")") + } + + class GenericHandler(member: Tree) extends MemberHandler(member) + + class ValHandler(member: ValDef) extends MemberDefHandler(member) { + val maxStringElements = 1000 // no need to mkString billions of elements + override def definesValue = true + + override def resultExtractionCode(req: Request): String = { + + val isInternal = isUserVarName(name) && req.lookupTypeOf(name) == "Unit" + if (!mods.isPublic || isInternal) "" + else { + // if this is a lazy val we avoid evaluating it here + val resultString = + if (mods.isLazy) codegenln(false, "<lazy>") + else any2stringOf(path, maxStringElements) + + val vidString = + if (replProps.vids) s"""" + " @ " + "%%8x".format(System.identityHashCode($path)) + " """.trim + else "" + + """ + "%s%s: %s = " + %s""".format(string2code(prettyName), vidString, string2code(req typeOf name), resultString) + } + } + } + + class DefHandler(member: DefDef) extends MemberDefHandler(member) { + override def definesValue = flattensToEmpty(member.vparamss) // true if 0-arity + override def resultExtractionCode(req: Request) = + if (mods.isPublic) codegenln(name, ": ", req.typeOf(name)) else "" + } + + abstract class MacroHandler(member: DefDef) extends MemberDefHandler(member) { + override def referencedNames = super.referencedNames.flatMap(name => List(name.toTermName, name.toTypeName)) + override def definesValue = false + override def definesTerm: Option[TermName] = Some(name.toTermName) + override def definesType: Option[TypeName] = None + override def resultExtractionCode(req: Request) = if (mods.isPublic) codegenln(notification(req)) else "" + def notification(req: Request): String + } + + class TermMacroHandler(member: DefDef) extends MacroHandler(member) { + def notification(req: Request) = s"defined term macro $name: ${req.typeOf(name)}" + } + + class AssignHandler(member: Assign) extends MemberHandler(member) { + val Assign(lhs, rhs) = member + override lazy val name = newTermName(freshInternalVarName()) + + override def definesTerm = Some(name) + override def definesValue = true + override def extraCodeToEvaluate(req: Request) = + """val %s = %s""".format(name, lhs) + + /** Print out lhs instead of the generated varName */ + override def resultExtractionCode(req: Request) = { + val lhsType = string2code(req lookupTypeOf name) + val res = string2code(req fullPath name) + """ + "%s: %s = " + %s + "\n" """.format(string2code(lhs.toString), lhsType, res) + "\n" + } + } + + class ModuleHandler(module: ModuleDef) extends MemberDefHandler(module) { + override def definesTerm = Some(name.toTermName) + override def definesValue = true + + override def resultExtractionCode(req: Request) = codegenln("defined object ", name) + } + + class ClassHandler(member: ClassDef) extends MemberDefHandler(member) { + override def definedSymbols = List(symbol, symbol.companionSymbol) filterNot (_ == NoSymbol) + override def definesType = Some(name.toTypeName) + override def definesTerm = Some(name.toTermName) filter (_ => mods.isCase) + + override def resultExtractionCode(req: Request) = + codegenln("defined %s %s".format(keyword, name)) + } + + class TypeAliasHandler(member: TypeDef) extends MemberDefHandler(member) { + private def isAlias = mods.isPublic && treeInfo.isAliasTypeDef(member) + override def definesType = Some(name.toTypeName) filter (_ => isAlias) + + override def resultExtractionCode(req: Request) = + codegenln("defined type alias ", name) + "\n" + } + + class ImportHandler(imp: Import) extends MemberHandler(imp) { + val Import(expr, selectors) = imp + def targetType = intp.global.rootMirror.getModuleIfDefined("" + expr) match { + case NoSymbol => intp.typeOfExpression("" + expr) + case sym => sym.thisType + } + private def importableTargetMembers = importableMembers(targetType).toList + // wildcard imports, e.g. import foo._ + private def selectorWild = selectors filter (_.name == nme.USCOREkw) + // renamed imports, e.g. import foo.{ bar => baz } + private def selectorRenames = selectors map (_.rename) filterNot (_ == null) + + /** Whether this import includes a wildcard import */ + val importsWildcard = selectorWild.nonEmpty + + def implicitSymbols = importedSymbols filter (_.isImplicit) + def importedSymbols = individualSymbols ++ wildcardSymbols + + private val selectorNames = selectorRenames filterNot (_ == nme.USCOREkw) flatMap (_.bothNames) toSet + lazy val individualSymbols: List[Symbol] = exitingTyper(importableTargetMembers filter (m => selectorNames(m.name))) + lazy val wildcardSymbols: List[Symbol] = exitingTyper(if (importsWildcard) importableTargetMembers else Nil) + + /** Complete list of names imported by a wildcard */ + lazy val wildcardNames: List[Name] = wildcardSymbols map (_.name) + lazy val individualNames: List[Name] = individualSymbols map (_.name) + + /** The names imported by this statement */ + override lazy val importedNames: List[Name] = wildcardNames ++ individualNames + lazy val importsSymbolNamed: Set[String] = importedNames map (_.toString) toSet + + def importString = imp.toString + override def resultExtractionCode(req: Request) = codegenln(importString) + "\n" + } +} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkReplReporter.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkReplReporter.scala new file mode 100644 index 0000000000..0711ed4871 --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkReplReporter.scala @@ -0,0 +1,53 @@ +/* NSC -- new Scala compiler + * Copyright 2002-2013 LAMP/EPFL + * @author Paul Phillips + */ + +package scala.tools.nsc +package interpreter + +import reporters._ +import SparkIMain._ + +import scala.reflect.internal.util.Position + +/** Like ReplGlobal, a layer for ensuring extra functionality. + */ +class SparkReplReporter(intp: SparkIMain) extends ConsoleReporter(intp.settings, Console.in, new SparkReplStrippingWriter(intp)) { + def printUntruncatedMessage(msg: String) = withoutTruncating(printMessage(msg)) + + /** Whether very long lines can be truncated. This exists so important + * debugging information (like printing the classpath) is not rendered + * invisible due to the max message length. + */ + private var _truncationOK: Boolean = !intp.settings.verbose + def truncationOK = _truncationOK + def withoutTruncating[T](body: => T): T = { + val saved = _truncationOK + _truncationOK = false + try body + finally _truncationOK = saved + } + + override def warning(pos: Position, msg: String): Unit = withoutTruncating(super.warning(pos, msg)) + override def error(pos: Position, msg: String): Unit = withoutTruncating(super.error(pos, msg)) + + override def printMessage(msg: String) { + // Avoiding deadlock if the compiler starts logging before + // the lazy val is complete. + if (intp.isInitializeComplete) { + if (intp.totalSilence) { + if (isReplTrace) + super.printMessage("[silent] " + msg) + } + else super.printMessage(msg) + } + else Console.println("[init] " + msg) + } + + override def displayPrompt() { + if (intp.totalSilence) () + else super.displayPrompt() + } + +} diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala new file mode 100644 index 0000000000..f966f25c5a --- /dev/null +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -0,0 +1,326 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.repl + +import java.io._ +import java.net.URLClassLoader + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.tools.nsc.interpreter.SparkILoop + +import com.google.common.io.Files +import org.scalatest.FunSuite +import org.apache.commons.lang3.StringEscapeUtils +import org.apache.spark.SparkContext +import org.apache.spark.util.Utils + + + +class ReplSuite extends FunSuite { + + def runInterpreter(master: String, input: String): String = { + val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath" + + val in = new BufferedReader(new StringReader(input + "\n")) + val out = new StringWriter() + val cl = getClass.getClassLoader + var paths = new ArrayBuffer[String] + if (cl.isInstanceOf[URLClassLoader]) { + val urlLoader = cl.asInstanceOf[URLClassLoader] + for (url <- urlLoader.getURLs) { + if (url.getProtocol == "file") { + paths += url.getFile + } + } + } + val classpath = paths.mkString(File.pathSeparator) + + val oldExecutorClasspath = System.getProperty(CONF_EXECUTOR_CLASSPATH) + System.setProperty(CONF_EXECUTOR_CLASSPATH, classpath) + + System.setProperty("spark.master", master) + val interp = { + new SparkILoop(in, new PrintWriter(out)) + } + org.apache.spark.repl.Main.interp = interp + Main.s.processArguments(List("-classpath", classpath), true) + Main.main(Array()) // call main + org.apache.spark.repl.Main.interp = null + + if (oldExecutorClasspath != null) { + System.setProperty(CONF_EXECUTOR_CLASSPATH, oldExecutorClasspath) + } else { + System.clearProperty(CONF_EXECUTOR_CLASSPATH) + } + return out.toString + } + + def assertContains(message: String, output: String) { + val isContain = output.contains(message) + assert(isContain, + "Interpreter output did not contain '" + message + "':\n" + output) + } + + def assertDoesNotContain(message: String, output: String) { + val isContain = output.contains(message) + assert(!isContain, + "Interpreter output contained '" + message + "':\n" + output) + } + + test("propagation of local properties") { + // A mock ILoop that doesn't install the SIGINT handler. + class ILoop(out: PrintWriter) extends SparkILoop(None, out) { + settings = new scala.tools.nsc.Settings + settings.usejavacp.value = true + org.apache.spark.repl.Main.interp = this + override def createInterpreter() { + intp = new SparkILoopInterpreter + intp.setContextClassLoader() + } + } + + val out = new StringWriter() + Main.interp = new ILoop(new PrintWriter(out)) + Main.sparkContext = new SparkContext("local", "repl-test") + Main.interp.createInterpreter() + + Main.sparkContext.setLocalProperty("someKey", "someValue") + + // Make sure the value we set in the caller to interpret is propagated in the thread that + // interprets the command. + Main.interp.interpret("org.apache.spark.repl.Main.sparkContext.getLocalProperty(\"someKey\")") + assert(out.toString.contains("someValue")) + + Main.sparkContext.stop() + System.clearProperty("spark.driver.port") + } + + test("simple foreach with accumulator") { + val output = runInterpreter("local", + """ + |val accum = sc.accumulator(0) + |sc.parallelize(1 to 10).foreach(x => accum += x) + |accum.value + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res1: Int = 55", output) + } + + test("external vars") { + val output = runInterpreter("local", + """ + |var v = 7 + |sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_) + |v = 10 + |sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_) + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 70", output) + assertContains("res1: Int = 100", output) + } + + test("external classes") { + val output = runInterpreter("local", + """ + |class C { + |def foo = 5 + |} + |sc.parallelize(1 to 10).map(x => (new C).foo).collect.reduceLeft(_+_) + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 50", output) + } + + test("external functions") { + val output = runInterpreter("local", + """ + |def double(x: Int) = x + x + |sc.parallelize(1 to 10).map(x => double(x)).collect.reduceLeft(_+_) + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 110", output) + } + + test("external functions that access vars") { + val output = runInterpreter("local", + """ + |var v = 7 + |def getV() = v + |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |v = 10 + |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 70", output) + assertContains("res1: Int = 100", output) + } + + test("broadcast vars") { + // Test that the value that a broadcast var had when it was created is used, + // even if that variable is then modified in the driver program + // TODO: This doesn't actually work for arrays when we run in local mode! + val output = runInterpreter("local", + """ + |var array = new Array[Int](5) + |val broadcastArray = sc.broadcast(array) + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |array(0) = 5 + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Array[Int] = Array(0, 0, 0, 0, 0)", output) + assertContains("res2: Array[Int] = Array(5, 0, 0, 0, 0)", output) + } + + test("interacting with files") { + val tempDir = Files.createTempDir() + tempDir.deleteOnExit() + val out = new FileWriter(tempDir + "/input") + out.write("Hello world!\n") + out.write("What's up?\n") + out.write("Goodbye\n") + out.close() + val output = runInterpreter("local", + """ + |var file = sc.textFile("%s").cache() + |file.count() + |file.count() + |file.count() + """.stripMargin.format(StringEscapeUtils.escapeJava( + tempDir.getAbsolutePath + File.separator + "input"))) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Long = 3", output) + assertContains("res1: Long = 3", output) + assertContains("res2: Long = 3", output) + Utils.deleteRecursively(tempDir) + } + + test("local-cluster mode") { + val output = runInterpreter("local-cluster[1,1,512]", + """ + |var v = 7 + |def getV() = v + |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |v = 10 + |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |var array = new Array[Int](5) + |val broadcastArray = sc.broadcast(array) + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |array(0) = 5 + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 70", output) + assertContains("res1: Int = 100", output) + assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output) + assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output) + } + + test("SPARK-1199 two instances of same class don't type check.") { + val output = runInterpreter("local-cluster[1,1,512]", + """ + |case class Sum(exp: String, exp2: String) + |val a = Sum("A", "B") + |def b(a: Sum): String = a match { case Sum(_, _) => "Found Sum" } + |b(a) + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + + test("SPARK-2452 compound statements.") { + val output = runInterpreter("local", + """ + |val x = 4 ; def f() = x + |f() + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + + test("SPARK-2576 importing SQLContext.createSchemaRDD.") { + // We need to use local-cluster to test this case. + val output = runInterpreter("local-cluster[1,1,512]", + """ + |val sqlContext = new org.apache.spark.sql.SQLContext(sc) + |import sqlContext.createSchemaRDD + |case class TestCaseClass(value: Int) + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toSchemaRDD.collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + + test("SPARK-2632 importing a method from non serializable class and not using it.") { + val output = runInterpreter("local", + """ + |class TestClass() { def testMethod = 3 } + |val t = new TestClass + |import t.testMethod + |case class TestCaseClass(value: Int) + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + + if (System.getenv("MESOS_NATIVE_LIBRARY") != null) { + test("running on Mesos") { + val output = runInterpreter("localquiet", + """ + |var v = 7 + |def getV() = v + |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |v = 10 + |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |var array = new Array[Int](5) + |val broadcastArray = sc.broadcast(array) + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |array(0) = 5 + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 70", output) + assertContains("res1: Int = 100", output) + assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output) + assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output) + } + } + + test("collecting objects of class defined in repl") { + val output = runInterpreter("local[2]", + """ + |case class Foo(i: Int) + |val ret = sc.parallelize((1 to 100).map(Foo), 10).collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("ret: Array[Foo] = Array(Foo(1),", output) + } +} |