diff options
Diffstat (limited to 'repl/src')
3 files changed, 33 insertions, 15 deletions
diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index e3bcf7f30a..1aa94079fd 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -18,12 +18,15 @@ package org.apache.spark.repl import java.io.{ByteArrayOutputStream, InputStream} -import java.net.{URI, URL, URLClassLoader, URLEncoder} +import java.net.{URI, URL, URLEncoder} import java.util.concurrent.{Executors, ExecutorService} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.spark.SparkEnv +import org.apache.spark.util.Utils + import org.objectweb.asm._ import org.objectweb.asm.Opcodes._ @@ -53,7 +56,13 @@ extends ClassLoader(parent) { if (fileSystem != null) { fileSystem.open(new Path(directory, pathInDirectory)) } else { - new URL(classUri + "/" + urlEncode(pathInDirectory)).openStream() + if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { + val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) + val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager) + newuri.toURL().openStream() + } else { + new URL(classUri + "/" + urlEncode(pathInDirectory)).openStream() + } } } val bytes = readAndTransformClass(name, inputStream) diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala index f52ebe4a15..9b1da19500 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -881,6 +881,8 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, }) def process(settings: Settings): Boolean = savingContextLoader { + if (getMaster() == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") + this.settings = settings createInterpreter() @@ -939,16 +941,9 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, def createSparkContext(): SparkContext = { val execUri = System.getenv("SPARK_EXECUTOR_URI") - val master = this.master match { - case Some(m) => m - case None => { - val prop = System.getenv("MASTER") - if (prop != null) prop else "local" - } - } val jars = SparkILoop.getAddedJars.map(new java.io.File(_).getAbsolutePath) val conf = new SparkConf() - .setMaster(master) + .setMaster(getMaster()) .setAppName("Spark shell") .setJars(jars) .set("spark.repl.class.uri", intp.classServer.uri) @@ -963,6 +958,17 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, sparkContext } + private def getMaster(): String = { + val master = this.master match { + case Some(m) => m + case None => { + val prop = System.getenv("MASTER") + if (prop != null) prop else "local" + } + } + master + } + /** process command-line arguments and do as they request */ def process(args: Array[String]): Boolean = { val command = new SparkCommandLine(args.toList, msg => echo(msg)) diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 1d73d0b699..90a96ad383 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -36,7 +36,7 @@ import scala.tools.reflect.StdRuntimeTags._ import scala.util.control.ControlThrowable import util.stackTraceString -import org.apache.spark.{HttpServer, SparkConf, Logging} +import org.apache.spark.{Logging, HttpServer, SecurityManager, SparkConf} import org.apache.spark.util.Utils // /** directory to save .class files to */ @@ -83,15 +83,17 @@ import org.apache.spark.util.Utils * @author Moez A. Abdel-Gawad * @author Lex Spoon */ - class SparkIMain(initialSettings: Settings, val out: JPrintWriter) extends SparkImports with Logging { + class SparkIMain(initialSettings: Settings, val out: JPrintWriter) + extends SparkImports with Logging { imain => - val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") + val conf = new SparkConf() + val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") /** Local directory to save .class files too */ val outputDir = { val tmp = System.getProperty("java.io.tmpdir") - val rootDir = new SparkConf().get("spark.repl.classdir", tmp) + val rootDir = conf.get("spark.repl.classdir", tmp) Utils.createTempDir(rootDir) } if (SPARK_DEBUG_REPL) { @@ -99,7 +101,8 @@ import org.apache.spark.util.Utils } val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles - val classServer = new HttpServer(outputDir) /** Jetty server that will serve our classes to worker nodes */ + val classServer = new HttpServer(outputDir, + new SecurityManager(conf)) /** Jetty server that will serve our classes to worker nodes */ private var currentSettings: Settings = initialSettings var printResults = true // whether to print result lines var totalSilence = false // whether to print anything |