-rw-r--r--  core/pom.xml | 16
-rw-r--r--  core/src/main/scala/org/apache/spark/HttpFileServer.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/HttpServer.scala | 60
-rw-r--r--  core/src/main/scala/org/apache/spark/SecurityManager.scala | 253
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala | 24
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkSaslClient.scala | 146
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkSaslServer.scala | 174
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala | 3
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala | 32
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/Client.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala | 10
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/Master.scala | 17
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala | 25
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala | 5
-rwxr-xr-x  core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala | 26
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala | 15
-rw-r--r--  core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala | 13
-rw-r--r--  core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala | 14
-rw-r--r--  core/src/main/scala/org/apache/spark/network/BufferMessage.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/network/Connection.scala | 61
-rw-r--r--  core/src/main/scala/org/apache/spark/network/ConnectionId.scala | 34
-rw-r--r--  core/src/main/scala/org/apache/spark/network/ConnectionManager.scala | 266
-rw-r--r--  core/src/main/scala/org/apache/spark/network/Message.scala | 1
-rw-r--r--  core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala | 11
-rw-r--r--  core/src/main/scala/org/apache/spark/network/ReceiverTest.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/network/SecurityMessage.scala | 163
-rw-r--r--  core/src/main/scala/org/apache/spark/network/SenderTest.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/JettyUtils.scala | 138
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/SparkUI.scala | 15
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala | 15
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/util/AkkaUtils.scala | 17
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Utils.scala | 37
-rw-r--r--  core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala | 215
-rw-r--r--  core/src/test/scala/org/apache/spark/BroadcastSuite.scala | 1
-rw-r--r--  core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala | 230
-rw-r--r--  core/src/test/scala/org/apache/spark/DriverSuite.scala | 1
-rw-r--r--  core/src/test/scala/org/apache/spark/FileServerSuite.scala | 26
-rw-r--r--  core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala | 6
-rw-r--r--  core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala | 9
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala | 67
-rw-r--r--  core/src/test/scala/org/apache/spark/ui/UISuite.scala | 10
-rw-r--r--  docs/configuration.md | 51
-rw-r--r--  docs/index.md | 1
-rw-r--r--  docs/security.md | 18
-rw-r--r--  examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala | 7
-rw-r--r--  pom.xml | 20
-rw-r--r--  project/SparkBuild.scala | 4
-rw-r--r--  repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala | 13
-rw-r--r--  repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala | 22
-rw-r--r--  repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala | 13
-rw-r--r--  yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala | 44
-rw-r--r--  yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala | 6
-rw-r--r--  yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala | 2
-rw-r--r--  yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala | 24
-rw-r--r--  yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala | 28
-rw-r--r--  yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala | 6
72 files changed, 2251 insertions(+), 292 deletions(-)
diff --git a/core/pom.xml b/core/pom.xml
index 99c841472b..4c1c2d4da5 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -66,6 +66,18 @@
</dependency>
<dependency>
<groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-plus</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-security</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-util</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-server</artifactId>
</dependency>
<dependency>
@@ -119,6 +131,10 @@
<version>0.3.1</version>
</dependency>
<dependency>
+ <groupId>commons-net</groupId>
+ <artifactId>commons-net</artifactId>
+ </dependency>
+ <dependency>
<groupId>${akka.group}</groupId>
<artifactId>akka-remote_${scala.binary.version}</artifactId>
</dependency>
diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
index d3264a4bb3..3d7692ea8a 100644
--- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
@@ -23,7 +23,7 @@ import com.google.common.io.Files
import org.apache.spark.util.Utils
-private[spark] class HttpFileServer extends Logging {
+private[spark] class HttpFileServer(securityManager: SecurityManager) extends Logging {
var baseDir : File = null
var fileDir : File = null
@@ -38,9 +38,10 @@ private[spark] class HttpFileServer extends Logging {
fileDir.mkdir()
jarDir.mkdir()
logInfo("HTTP File server directory is " + baseDir)
- httpServer = new HttpServer(baseDir)
+ httpServer = new HttpServer(baseDir, securityManager)
httpServer.start()
serverUri = httpServer.uri
+ logDebug("HTTP file server started at: " + serverUri)
}
def stop() {
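
HttpFileServer (and the HttpServer it wraps) now takes a SecurityManager, so the embedded file server is DIGEST-protected whenever authentication is on. A minimal usage sketch, assuming code that lives in the org.apache.spark package (the classes are private[spark]) and a placeholder secret:

    val conf = new SparkConf()
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "placeholder-secret")   // hypothetical value
    val securityManager = new SecurityManager(conf)

    // The file server now requires the SecurityManager up front.
    val fileServer = new HttpFileServer(securityManager)
    fileServer.initialize()
    println("serving files at " + fileServer.serverUri)
    fileServer.stop()
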
diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala
index 759e68ee0c..cb5df25fa4 100644
--- a/core/src/main/scala/org/apache/spark/HttpServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpServer.scala
@@ -19,15 +19,18 @@ package org.apache.spark
import java.io.File
+import org.eclipse.jetty.util.security.{Constraint, Password}
+import org.eclipse.jetty.security.authentication.DigestAuthenticator
+import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService, SecurityHandler}
+
import org.eclipse.jetty.server.Server
import org.eclipse.jetty.server.bio.SocketConnector
-import org.eclipse.jetty.server.handler.DefaultHandler
-import org.eclipse.jetty.server.handler.HandlerList
-import org.eclipse.jetty.server.handler.ResourceHandler
+import org.eclipse.jetty.server.handler.{DefaultHandler, HandlerList, ResourceHandler}
import org.eclipse.jetty.util.thread.QueuedThreadPool
import org.apache.spark.util.Utils
+
/**
* Exception type thrown by HttpServer when it is in the wrong state for an operation.
*/
@@ -38,7 +41,8 @@ private[spark] class ServerStateException(message: String) extends Exception(mes
* as well as classes created by the interpreter when the user types in code. This is just a wrapper
* around a Jetty server.
*/
-private[spark] class HttpServer(resourceBase: File) extends Logging {
+private[spark] class HttpServer(resourceBase: File, securityManager: SecurityManager)
+ extends Logging {
private var server: Server = null
private var port: Int = -1
@@ -59,14 +63,60 @@ private[spark] class HttpServer(resourceBase: File) extends Logging {
server.setThreadPool(threadPool)
val resHandler = new ResourceHandler
resHandler.setResourceBase(resourceBase.getAbsolutePath)
+
val handlerList = new HandlerList
handlerList.setHandlers(Array(resHandler, new DefaultHandler))
- server.setHandler(handlerList)
+
+ if (securityManager.isAuthenticationEnabled()) {
+ logDebug("HttpServer is using security")
+ val sh = setupSecurityHandler(securityManager)
+ // make sure we go through security handler to get resources
+ sh.setHandler(handlerList)
+ server.setHandler(sh)
+ } else {
+ logDebug("HttpServer is not using security")
+ server.setHandler(handlerList)
+ }
+
server.start()
port = server.getConnectors()(0).getLocalPort()
}
}
+ /**
+ * Set up Jetty to use the HashLoginService with a single user and our
+ * shared secret. Configure it to use DIGEST-MD5 authentication so that the password
+ * isn't passed in plaintext.
+ */
+ private def setupSecurityHandler(securityMgr: SecurityManager): ConstraintSecurityHandler = {
+ val constraint = new Constraint()
+ // use DIGEST-MD5 as the authentication mechanism
+ constraint.setName(Constraint.__DIGEST_AUTH)
+ constraint.setRoles(Array("user"))
+ constraint.setAuthenticate(true)
+ constraint.setDataConstraint(Constraint.DC_NONE)
+
+ val cm = new ConstraintMapping()
+ cm.setConstraint(constraint)
+ cm.setPathSpec("/*")
+ val sh = new ConstraintSecurityHandler()
+
+ // the hashLoginService lets us do a single user and
+ // secret right now. This could be changed to use the
+ // JAASLoginService for other options.
+ val hashLogin = new HashLoginService()
+
+ // the password is the shared secret; fail fast if authentication is on but no secret exists
+ val secretKey = securityMgr.getSecretKey()
+ if (secretKey == null) {
+ throw new Exception("Error: secret key is null with authentication on")
+ }
+ val userCred = new Password(secretKey)
+ hashLogin.putUser(securityMgr.getHttpUser(), userCred, Array("user"))
+ sh.setLoginService(hashLogin)
+ sh.setAuthenticator(new DigestAuthenticator());
+ sh.setConstraintMappings(Array(cm))
+ sh
+ }
+
def stop() {
if (server == null) {
throw new ServerStateException("Server is already stopped")
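
To see the effect of the security handler in isolation, a small sketch (again assuming the org.apache.spark package and a placeholder secret): with authentication on, an unauthenticated GET against the embedded Jetty server is answered with a 401 DIGEST challenge rather than the resource.

    import java.io.File
    import java.net.{HttpURLConnection, URL}

    val conf = new SparkConf()
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "placeholder-secret")
    val server = new HttpServer(new File(System.getProperty("java.io.tmpdir")),
      new SecurityManager(conf))
    server.start()

    // No credentials supplied, so the DIGEST-protected server rejects the request.
    val conn = new URL(server.uri + "/").openConnection().asInstanceOf[HttpURLConnection]
    println(conn.getResponseCode())   // expected: 401
    server.stop()
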
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
new file mode 100644
index 0000000000..591978c1d3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.net.{Authenticator, PasswordAuthentication}
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.security.Credentials
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.spark.deploy.SparkHadoopUtil
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * Spark class responsible for security.
+ *
+ * In general this class should be instantiated by the SparkEnv and most components
+ * should access it from that. There are some cases where the SparkEnv hasn't been
+ * initialized yet and this class must be instantiated directly.
+ *
+ * Spark currently supports authentication via a shared secret.
+ * Authentication can be configured to be on via the 'spark.authenticate' configuration
+ * parameter. This parameter controls whether the Spark communication protocols do
+ * authentication using the shared secret. This authentication is a basic handshake to
+ * make sure both sides have the same shared secret and are allowed to communicate.
+ * If the shared secret is not identical they will not be allowed to communicate.
+ *
+ * The Spark UI can also be secured by using javax servlet filters. A user may want to
+ * secure the UI if it has data that other users should not be allowed to see. The javax
+ * servlet filter specified by the user can authenticate the user and then once the user
+ * is logged in, Spark can compare that user versus the view acls to make sure they are
+ * authorized to view the UI. The configs 'spark.ui.acls.enable' and 'spark.ui.view.acls'
+ * control the behavior of the acls. Note that the person who started the application
+ * always has view access to the UI.
+ *
+ * Spark does not currently support encryption after authentication.
+ *
+ * At this point Spark has multiple communication protocols that need to be secured and
+ * different underlying mechanisms are used depending on the protocol:
+ *
+ * - Akka -> The only option here is to use the Akka Remote secure-cookie functionality.
+ * Akka remoting allows you to specify a secure cookie that will be exchanged
+ * and ensured to be identical in the connection handshake between the client
+ *            and the server. If they are not identical then the client will not be allowed
+ *            to connect to the server. There is no control of the underlying
+ *            authentication mechanism so it's not clear if the password is passed in
+ *            plaintext or uses DIGEST-MD5 or some other mechanism.
+ *            Akka also has an option to turn on SSL; this option is not currently supported,
+ *            but we could add a configuration option for it in the future.
+ *
+ * - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty
+ * for the HttpServer. Jetty supports multiple authentication mechanisms -
+ *            Basic, Digest, Form, SPNEGO, etc. It also supports multiple different login
+ *            services - Hash, JAAS, SPNEGO, JDBC, etc. Spark currently uses the HashLoginService
+ * to authenticate using DIGEST-MD5 via a single user and the shared secret.
+ * Since we are using DIGEST-MD5, the shared secret is not passed on the wire
+ * in plaintext.
+ * We currently do not support SSL (https), but Jetty can be configured to use it
+ * so we could add a configuration option for this in the future.
+ *
+ *            The Spark HttpServer installs the HashLoginService and configures it to use DIGEST-MD5.
+ *            Any clients must specify the user and password. A default
+ *            Authenticator is installed in the SecurityManager to handle the authentication;
+ *            in this case it gets the user name and password from the request.
+ *
+ * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously
+ * exchange messages. For this we use the Java SASL
+ * (Simple Authentication and Security Layer) API and again use DIGEST-MD5
+ * as the authentication mechanism. This means the shared secret is not passed
+ * over the wire in plaintext.
+ * Note that SASL is pluggable as to what mechanism it uses. We currently use
+ * DIGEST-MD5 but this could be changed to use Kerberos or other in the future.
+ *            Spark currently supports "auth" for the quality of protection, which means
+ *            the connection does not support integrity or privacy protection (encryption)
+ *            after authentication. SASL also supports "auth-int" and "auth-conf", which
+ *            Spark could support in the future to allow the user to specify the quality
+ *            of protection they want. If we support those, the messages will also have to
+ *            be wrapped and unwrapped via the SaslServer/SaslClient wrap/unwrap APIs.
+ *
+ *        Since the connectionManager does asynchronous message passing, the SASL
+ *        authentication is a bit more complex. A ConnectionManager can be both a client
+ *        and a Server, so for a particular connection it has to determine what to do.
+ *        A ConnectionId was added to be able to track connections and is used to
+ *        match up incoming messages with connections waiting for authentication.
+ *        If it is acting as a client and trying to send a message to another ConnectionManager,
+ * it blocks the thread calling sendMessage until the SASL negotiation has occurred.
+ * The ConnectionManager tracks all the sendingConnections using the ConnectionId
+ * and waits for the response from the server and does the handshake.
+ *
+ * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
+ * can be used. Yarn requires a specific AmIpFilter be installed for security to work
+ * properly. For non-Yarn deployments, users can write a filter to go through a
+ * companies normal login service. If an authentication filter is in place then the
+ *     company's normal login service. If an authentication filter is in place then the
+ * have view acls to see if that user is authorized.
+ * The filters can also be used for many different purposes. For instance filters
+ * could be used for logging, encryption, or compression.
+ *
+ * The exact mechanism used to generate and distribute the shared secret is deployment specific.
+ *
+ * For Yarn deployments, the secret is automatically generated using the Akka remote
+ * Crypt.generateSecureCookie() API. The secret is placed in the Hadoop UGI which gets passed
+ * around via the Hadoop RPC mechanism. Hadoop RPC can be configured to support different levels
+ * of protection. See the Hadoop documentation for more details. Each Spark application on Yarn
+ * gets a different shared secret. On Yarn, the Spark UI gets configured to use the Hadoop Yarn
+ * AmIpFilter which requires the user to go through the ResourceManager Proxy. That Proxy is there
+ * to reduce the possibility of web based attacks through YARN. Hadoop can be configured to use
+ * filters to do authentication. That authentication then happens via the ResourceManager Proxy
+ * and Spark will use that to do authorization against the view acls.
+ *
+ * For other Spark deployments, the shared secret must be specified via the
+ * spark.authenticate.secret config.
+ * All the nodes (Master and Workers) and the applications need to have the same shared secret.
+ * This again is not ideal as one user could potentially affect another user's application.
+ * This should be enhanced in the future to provide better protection.
+ * If the UI needs to be secured, the user needs to install a javax servlet filter to do the
+ * authentication. Spark will then use that user to compare against the view acls to do
+ * authorization. If no filter is in place the user is generally null and no authorization
+ * can take place.
+ */
+
+private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
+
+ // key used to store the spark secret in the Hadoop UGI
+ private val sparkSecretLookupKey = "sparkCookie"
+
+ private val authOn = sparkConf.getBoolean("spark.authenticate", false)
+ private val uiAclsOn = sparkConf.getBoolean("spark.ui.acls.enable", false)
+
+ // always add the current user and SPARK_USER to the viewAcls
+ private val aclUsers = ArrayBuffer[String](System.getProperty("user.name", ""),
+ Option(System.getenv("SPARK_USER")).getOrElse(""))
+ aclUsers ++= sparkConf.get("spark.ui.view.acls", "").split(',')
+ private val viewAcls = aclUsers.map(_.trim()).filter(!_.isEmpty).toSet
+
+ private val secretKey = generateSecretKey()
+ logInfo("SecurityManager, is authentication enabled: " + authOn +
+ " are ui acls enabled: " + uiAclsOn + " users with view permissions: " + viewAcls.toString())
+
+ // Set our own authenticator to properly negotiate user/password for HTTP connections.
+ // This is needed by the HTTP client fetching from the HttpServer. Put here so it's
+ // only set once.
+ if (authOn) {
+ Authenticator.setDefault(
+ new Authenticator() {
+ override def getPasswordAuthentication(): PasswordAuthentication = {
+ var passAuth: PasswordAuthentication = null
+ val userInfo = getRequestingURL().getUserInfo()
+ if (userInfo != null) {
+ val parts = userInfo.split(":", 2)
+ passAuth = new PasswordAuthentication(parts(0), parts(1).toCharArray())
+ }
+ return passAuth
+ }
+ }
+ )
+ }
+
+ /**
+ * Generates or looks up the secret key.
+ *
+ * The way the key is stored depends on the Spark deployment mode. Yarn
+ * uses the Hadoop UGI.
+ *
+ * For non-Yarn deployments, if the config variable is not set,
+ * we throw an exception.
+ */
+ private def generateSecretKey(): String = {
+ if (!isAuthenticationEnabled) return null
+ // first check to see if the secret is already set, else generate a new one if on yarn
+ val sCookie = if (SparkHadoopUtil.get.isYarnMode) {
+ val secretKey = SparkHadoopUtil.get.getSecretKeyFromUserCredentials(sparkSecretLookupKey)
+ if (secretKey != null) {
+ logDebug("in yarn mode, getting secret from credentials")
+ return new Text(secretKey).toString
+ } else {
+ logDebug("getSecretKey: yarn mode, secret key from credentials is null")
+ }
+ val cookie = akka.util.Crypt.generateSecureCookie
+ // if we generated the secret then we must be the first, so let's set it so it
+ // gets used by everyone else
+ SparkHadoopUtil.get.addSecretKeyToUserCredentials(sparkSecretLookupKey, cookie)
+ logInfo("adding secret to credentials in yarn mode")
+ cookie
+ } else {
+ // user must have set spark.authenticate.secret config
+ sparkConf.getOption("spark.authenticate.secret") match {
+ case Some(value) => value
+ case None => throw new Exception("Error: a secret key must be specified via the " +
+ "spark.authenticate.secret config")
+ }
+ }
+ sCookie
+ }
+
+ /**
+ * Check to see if Acls for the UI are enabled
+ * @return true if UI authentication is enabled, otherwise false
+ */
+ def uiAclsEnabled(): Boolean = uiAclsOn
+
+ /**
+ * Checks the given user against the view acl list to see if they have
+ * authorization to view the UI. If the UI acls are disabled
+ * via spark.ui.acls.enable, all users have view access.
+ *
+ * @param user the user to check for view authorization
+ * @return true if the user has permission, otherwise false
+ */
+ def checkUIViewPermissions(user: String): Boolean = {
+ if (uiAclsEnabled() && (user != null) && (!viewAcls.contains(user))) false else true
+ }
+
+ /**
+ * Check to see if authentication for the Spark communication protocols is enabled
+ * @return true if authentication is enabled, otherwise false
+ */
+ def isAuthenticationEnabled(): Boolean = authOn
+
+ /**
+ * Gets the user used for authenticating HTTP connections.
+ * For now use a single hardcoded user.
+ * @return the HTTP user as a String
+ */
+ def getHttpUser(): String = "sparkHttpUser"
+
+ /**
+ * Gets the user used for authenticating SASL connections.
+ * For now use a single hardcoded user.
+ * @return the SASL user as a String
+ */
+ def getSaslUser(): String = "sparkSaslUser"
+
+ /**
+ * Gets the secret key.
+ * @return the secret key as a String if authentication is enabled, otherwise returns null
+ */
+ def getSecretKey(): String = secretKey
+}
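
A short sketch of how the new class behaves, assuming a standalone deployment (so the secret must come from spark.authenticate.secret) and hypothetical user names in the acl list:

    val conf = new SparkConf()
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "placeholder-secret")
      .set("spark.ui.acls.enable", "true")
      .set("spark.ui.view.acls", "alice,bob")            // hypothetical users

    val securityManager = new SecurityManager(conf)
    securityManager.isAuthenticationEnabled()            // true
    securityManager.checkUIViewPermissions("alice")      // true: listed in the view acls
    securityManager.checkUIViewPermissions("mallory")    // false: acls on, user not listed
    securityManager.checkUIViewPermissions(null)         // true: no authenticated user known
    // HTTP and SASL use fixed user names but share the same secret:
    securityManager.getHttpUser()                        // "sparkHttpUser"
    securityManager.getSecretKey()                       // the shared secret
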
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index da778aa851..24731ad706 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -130,6 +130,8 @@ class SparkContext(
val isLocal = (master == "local" || master.startsWith("local["))
+ if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true")
+
// Create the Spark execution environment (cache, map output tracker, etc)
private[spark] val env = SparkEnv.create(
conf,
@@ -634,7 +636,7 @@ class SparkContext(
addedFiles(key) = System.currentTimeMillis
// Fetch the file locally in case a job is executed using DAGScheduler.runLocally().
- Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf)
+ Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf, env.securityManager)
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 7ac65828f6..5e43b51984 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -53,7 +53,8 @@ class SparkEnv private[spark] (
val httpFileServer: HttpFileServer,
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
- val conf: SparkConf) extends Logging {
+ val conf: SparkConf,
+ val securityManager: SecurityManager) extends Logging {
// A mapping of thread ID to amount of memory used for shuffle in bytes
// All accesses should be manually synchronized
@@ -122,8 +123,9 @@ object SparkEnv extends Logging {
isDriver: Boolean,
isLocal: Boolean): SparkEnv = {
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port,
- conf = conf)
+ val securityManager = new SecurityManager(conf)
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port, conf = conf,
+ securityManager = securityManager)
// Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port),
// figure out which port number Akka actually bound to and set spark.driver.port to it.
@@ -139,7 +141,6 @@ object SparkEnv extends Logging {
val name = conf.get(propertyName, defaultClassName)
Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
}
-
val serializerManager = new SerializerManager
val serializer = serializerManager.setDefault(
@@ -167,12 +168,12 @@ object SparkEnv extends Logging {
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
"BlockManagerMaster",
new BlockManagerMasterActor(isLocal, conf)), conf)
- val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
- serializer, conf)
+ val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
+ serializer, conf, securityManager)
val connectionManager = blockManager.connectionManager
- val broadcastManager = new BroadcastManager(isDriver, conf)
+ val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
val cacheManager = new CacheManager(blockManager)
@@ -190,14 +191,14 @@ object SparkEnv extends Logging {
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
- val httpFileServer = new HttpFileServer()
+ val httpFileServer = new HttpFileServer(securityManager)
httpFileServer.initialize()
conf.set("spark.fileserver.uri", httpFileServer.serverUri)
val metricsSystem = if (isDriver) {
- MetricsSystem.createMetricsSystem("driver", conf)
+ MetricsSystem.createMetricsSystem("driver", conf, securityManager)
} else {
- MetricsSystem.createMetricsSystem("executor", conf)
+ MetricsSystem.createMetricsSystem("executor", conf, securityManager)
}
metricsSystem.start()
@@ -231,6 +232,7 @@ object SparkEnv extends Logging {
httpFileServer,
sparkFilesDir,
metricsSystem,
- conf)
+ conf,
+ securityManager)
}
}
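
Since SparkEnv now carries the SecurityManager, components that already have an environment can reuse the shared instance instead of constructing their own; a minimal sketch (SparkEnv.get is the existing thread-local accessor):

    val securityManager = SparkEnv.get.securityManager
    if (securityManager.isAuthenticationEnabled()) {
      // e.g. require a SASL handshake before exchanging ConnectionManager messages
    }
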
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
new file mode 100644
index 0000000000..a2a871cbd3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.IOException
+import javax.security.auth.callback.Callback
+import javax.security.auth.callback.CallbackHandler
+import javax.security.auth.callback.NameCallback
+import javax.security.auth.callback.PasswordCallback
+import javax.security.auth.callback.UnsupportedCallbackException
+import javax.security.sasl.RealmCallback
+import javax.security.sasl.RealmChoiceCallback
+import javax.security.sasl.Sasl
+import javax.security.sasl.SaslClient
+import javax.security.sasl.SaslException
+
+import scala.collection.JavaConversions.mapAsJavaMap
+
+/**
+ * Implements SASL Client logic for Spark
+ */
+private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging {
+
+ /**
+ * Used to respond to the server's counterpart, SaslServer, with SASL tokens
+ * represented as byte arrays.
+ *
+ * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
+ * configurable in the future.
+ */
+ private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST),
+ null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
+ new SparkSaslClientCallbackHandler(securityMgr))
+
+ /**
+ * Used to initiate SASL handshake with server.
+ * @return response to challenge if needed
+ */
+ def firstToken(): Array[Byte] = {
+ synchronized {
+ val saslToken: Array[Byte] =
+ if (saslClient != null && saslClient.hasInitialResponse()) {
+ logDebug("has initial response")
+ saslClient.evaluateChallenge(new Array[Byte](0))
+ } else {
+ new Array[Byte](0)
+ }
+ saslToken
+ }
+ }
+
+ /**
+ * Determines whether the authentication exchange has completed.
+ * @return true if complete, otherwise false
+ */
+ def isComplete(): Boolean = {
+ synchronized {
+ if (saslClient != null) saslClient.isComplete() else false
+ }
+ }
+
+ /**
+ * Respond to server's SASL token.
+ * @param saslTokenMessage contains server's SASL token
+ * @return client's response SASL token
+ */
+ def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = {
+ synchronized {
+ if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0)
+ }
+ }
+
+ /**
+ * Disposes of any system resources or security-sensitive information the
+ * SaslClient might be using.
+ */
+ def dispose() {
+ synchronized {
+ if (saslClient != null) {
+ try {
+ saslClient.dispose()
+ } catch {
+ case e: SaslException => // ignored
+ } finally {
+ saslClient = null
+ }
+ }
+ }
+ }
+
+ /**
+ * Implementation of javax.security.auth.callback.CallbackHandler
+ * that works with shared secrets.
+ */
+ private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends
+ CallbackHandler {
+
+ private val userName: String =
+ SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes())
+ private val secretKey = securityMgr.getSecretKey()
+ private val userPassword: Array[Char] =
+ SparkSaslServer.encodePassword(if (secretKey != null) secretKey.getBytes() else "".getBytes())
+
+ /**
+ * Implementation used to respond to SASL request from the server.
+ *
+ * @param callbacks objects that indicate what credential information the
+ * server's SaslServer requires from the client.
+ */
+ override def handle(callbacks: Array[Callback]) {
+ logDebug("in the sasl client callback handler")
+ callbacks foreach {
+ case nc: NameCallback => {
+ logDebug("handle: SASL client callback: setting username: " + userName)
+ nc.setName(userName)
+ }
+ case pc: PasswordCallback => {
+ logDebug("handle: SASL client callback: setting userPassword")
+ pc.setPassword(userPassword)
+ }
+ case rc: RealmCallback => {
+ logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText())
+ rc.setText(rc.getDefaultText())
+ }
+ case cb: RealmChoiceCallback => {}
+ case cb: Callback => throw
+ new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback")
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
new file mode 100644
index 0000000000..11fcb2ae3a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import javax.security.auth.callback.Callback
+import javax.security.auth.callback.CallbackHandler
+import javax.security.auth.callback.NameCallback
+import javax.security.auth.callback.PasswordCallback
+import javax.security.auth.callback.UnsupportedCallbackException
+import javax.security.sasl.AuthorizeCallback
+import javax.security.sasl.RealmCallback
+import javax.security.sasl.Sasl
+import javax.security.sasl.SaslException
+import javax.security.sasl.SaslServer
+import scala.collection.JavaConversions.mapAsJavaMap
+import org.apache.commons.net.util.Base64
+
+/**
+ * Encapsulates SASL server logic
+ */
+private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging {
+
+ /**
+ * Actual SASL work done by this object from javax.security.sasl.
+ */
+ private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null,
+ SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
+ new SparkSaslDigestCallbackHandler(securityMgr))
+
+ /**
+ * Determines whether the authentication exchange has completed.
+ * @return true if complete, otherwise false
+ */
+ def isComplete(): Boolean = {
+ synchronized {
+ if (saslServer != null) saslServer.isComplete() else false
+ }
+ }
+
+ /**
+ * Used to respond to the client's SASL tokens.
+ * @param token client's SASL token
+ * @return response to send back to the client.
+ */
+ def response(token: Array[Byte]): Array[Byte] = {
+ synchronized {
+ if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0)
+ }
+ }
+
+ /**
+ * Disposes of any system resources or security-sensitive information the
+ * SaslServer might be using.
+ */
+ def dispose() {
+ synchronized {
+ if (saslServer != null) {
+ try {
+ saslServer.dispose()
+ } catch {
+ case e: SaslException => // ignore
+ } finally {
+ saslServer = null
+ }
+ }
+ }
+ }
+
+ /**
+ * Implementation of javax.security.auth.callback.CallbackHandler
+ * for SASL DIGEST-MD5 mechanism
+ */
+ private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager)
+ extends CallbackHandler {
+
+ private val userName: String =
+ SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes())
+
+ override def handle(callbacks: Array[Callback]) {
+ logDebug("In the sasl server callback handler")
+ callbacks foreach {
+ case nc: NameCallback => {
+ logDebug("handle: SASL server callback: setting username")
+ nc.setName(userName)
+ }
+ case pc: PasswordCallback => {
+ logDebug("handle: SASL server callback: setting userPassword")
+ val password: Array[Char] =
+ SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes())
+ pc.setPassword(password)
+ }
+ case rc: RealmCallback => {
+ logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText())
+ rc.setText(rc.getDefaultText())
+ }
+ case ac: AuthorizeCallback => {
+ val authid = ac.getAuthenticationID()
+ val authzid = ac.getAuthorizationID()
+ if (authid.equals(authzid)) {
+ logDebug("set auth to true")
+ ac.setAuthorized(true)
+ } else {
+ logDebug("set auth to false")
+ ac.setAuthorized(false)
+ }
+ if (ac.isAuthorized()) {
+ logDebug("sasl server is authorized")
+ ac.setAuthorizedID(authzid)
+ }
+ }
+ case cb: Callback => throw
+ new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback")
+ }
+ }
+ }
+}
+
+private[spark] object SparkSaslServer {
+
+ /**
+ * This is passed as the server name when creating the sasl client/server.
+ * This could be changed to be configurable in the future.
+ */
+ val SASL_DEFAULT_REALM = "default"
+
+ /**
+ * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
+ * configurable in the future.
+ */
+ val DIGEST = "DIGEST-MD5"
+
+ /**
+ * The quality of protection is just "auth". This means that we are doing
+ * authentication only; we are not supporting integrity or privacy protection of the
+ * communication channel after authentication. This could be changed to be configurable
+ * in the future.
+ */
+ val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH ->"true")
+
+ /**
+ * Encode a byte[] identifier as a Base64-encoded string.
+ *
+ * @param identifier identifier to encode
+ * @return Base64-encoded string
+ */
+ def encodeIdentifier(identifier: Array[Byte]): String = {
+ new String(Base64.encodeBase64(identifier))
+ }
+
+ /**
+ * Encode a password as a base64-encoded char[] array.
+ * @param password as a byte array.
+ * @return password as a char array.
+ */
+ def encodePassword(password: Array[Byte]): Array[Char] = {
+ new String(Base64.encodeBase64(password)).toCharArray()
+ }
+}
+
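
The client and server halves are meant to be driven as a token-exchange loop by the ConnectionManager; below is a self-contained sketch of that DIGEST-MD5 handshake, assuming the org.apache.spark package and a placeholder secret:

    val conf = new SparkConf()
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "placeholder-secret")
    val securityManager = new SecurityManager(conf)

    val saslClient = new SparkSaslClient(securityManager)
    val saslServer = new SparkSaslServer(securityManager)

    // DIGEST-MD5 is server-challenge-first: the client's initial token is empty,
    // then challenges and responses alternate until both sides report completion.
    var clientToken = saslClient.firstToken()
    while (!saslClient.isComplete() || !saslServer.isComplete()) {
      val serverToken = saslServer.response(clientToken)
      if (!saslClient.isComplete()) {
        clientToken = saslClient.saslResponse(serverToken)
      }
    }
    println("handshake complete: " + saslServer.isComplete())
    saslClient.dispose()
    saslServer.dispose()
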
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
index d113d40405..e3c3a12d16 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -60,7 +60,8 @@ abstract class Broadcast[T](val id: Long) extends Serializable {
}
private[spark]
-class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging with Serializable {
+class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager)
+ extends Logging with Serializable {
private var initialized = false
private var broadcastFactory: BroadcastFactory = null
@@ -78,7 +79,7 @@ class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
// Initialize appropriate BroadcastFactory and BroadcastObject
- broadcastFactory.initialize(isDriver, conf)
+ broadcastFactory.initialize(isDriver, conf, securityManager)
initialized = true
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
index 940e5ab805..6beecaeced 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -16,6 +16,7 @@
*/
package org.apache.spark.broadcast
+import org.apache.spark.SecurityManager
import org.apache.spark.SparkConf
@@ -26,7 +27,7 @@ import org.apache.spark.SparkConf
* entire Spark job.
*/
trait BroadcastFactory {
- def initialize(isDriver: Boolean, conf: SparkConf): Unit
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def stop(): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 20207c2613..e8eb04bb10 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -18,13 +18,13 @@
package org.apache.spark.broadcast
import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream}
-import java.net.URL
+import java.net.{URL, URLConnection, URI}
import java.util.concurrent.TimeUnit
import it.unimi.dsi.fastutil.io.FastBufferedInputStream
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-import org.apache.spark.{HttpServer, Logging, SparkConf, SparkEnv}
+import org.apache.spark.{SparkConf, HttpServer, Logging, SecurityManager, SparkEnv}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}
@@ -67,7 +67,9 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
* A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium.
*/
class HttpBroadcastFactory extends BroadcastFactory {
- def initialize(isDriver: Boolean, conf: SparkConf) { HttpBroadcast.initialize(isDriver, conf) }
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
+ HttpBroadcast.initialize(isDriver, conf, securityMgr)
+ }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)
@@ -83,6 +85,7 @@ private object HttpBroadcast extends Logging {
private var bufferSize: Int = 65536
private var serverUri: String = null
private var server: HttpServer = null
+ private var securityManager: SecurityManager = null
// TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist
private val files = new TimeStampedHashSet[String]
@@ -92,11 +95,12 @@ private object HttpBroadcast extends Logging {
private var compressionCodec: CompressionCodec = null
- def initialize(isDriver: Boolean, conf: SparkConf) {
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
synchronized {
if (!initialized) {
bufferSize = conf.getInt("spark.buffer.size", 65536)
compress = conf.getBoolean("spark.broadcast.compress", true)
+ securityManager = securityMgr
if (isDriver) {
createServer(conf)
conf.set("spark.httpBroadcast.uri", serverUri)
@@ -126,7 +130,7 @@ private object HttpBroadcast extends Logging {
private def createServer(conf: SparkConf) {
broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf))
- server = new HttpServer(broadcastDir)
+ server = new HttpServer(broadcastDir, securityManager)
server.start()
serverUri = server.uri
logInfo("Broadcast server started at " + serverUri)
@@ -149,11 +153,23 @@ private object HttpBroadcast extends Logging {
}
def read[T](id: Long): T = {
+ logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id)
val url = serverUri + "/" + BroadcastBlockId(id).name
+
+ var uc: URLConnection = null
+ if (securityManager.isAuthenticationEnabled()) {
+ logDebug("broadcast security enabled")
+ val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager)
+ uc = newuri.toURL().openConnection()
+ uc.setAllowUserInteraction(false)
+ } else {
+ logDebug("broadcast not using security")
+ uc = new URL(url).openConnection()
+ }
+
val in = {
- val httpConnection = new URL(url).openConnection()
- httpConnection.setReadTimeout(httpReadTimeout)
- val inputStream = httpConnection.getInputStream
+ uc.setReadTimeout(httpReadTimeout)
+ val inputStream = uc.getInputStream();
if (compress) {
compressionCodec.compressedInputStream(inputStream)
} else {
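
Utils.constructURIForAuthentication itself is not shown in this diff; the sketch below illustrates the pattern the read path relies on (an assumption about the helper, not its actual code): embed the HTTP user and shared secret as URI user-info so the JVM-wide Authenticator installed by SecurityManager can answer the server's DIGEST-MD5 challenge.

    import java.net.URI

    // Hypothetical helper name; mirrors what constructURIForAuthentication appears to do.
    def withUserInfo(uri: URI, securityMgr: SecurityManager): URI = {
      val userInfo = securityMgr.getHttpUser() + ":" + securityMgr.getSecretKey()
      new URI(uri.getScheme(), userInfo, uri.getHost(), uri.getPort(),
        uri.getPath(), uri.getQuery(), uri.getFragment())
    }

    // read() then opens the connection from the rewritten URI:
    //   withUserInfo(new URI(url), securityManager).toURL().openConnection()
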
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 22d783c859..3cd7121376 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -241,7 +241,9 @@ private[spark] case class TorrentInfo(
*/
class TorrentBroadcastFactory extends BroadcastFactory {
- def initialize(isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.initialize(isDriver, conf) }
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
+ TorrentBroadcast.initialize(isDriver, conf)
+ }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TorrentBroadcast[T](value_, isLocal, id)
diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala
index eb5676b51d..d9e3035e1a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -26,7 +26,7 @@ import akka.pattern.ask
import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent}
import org.apache.log4j.{Level, Logger}
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.{DriverState, Master}
import org.apache.spark.util.{AkkaUtils, Utils}
@@ -141,7 +141,7 @@ object Client {
// TODO: See if we can initialize akka so return messages are sent back using the same TCP
// flow. Else, this (sadly) requires the DriverClient be routable from the Master.
val (actorSystem, _) = AkkaUtils.createActorSystem(
- "driverClient", Utils.localHostName(), 0, false, conf)
+ "driverClient", Utils.localHostName(), 0, false, conf, new SecurityManager(conf))
actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf))
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index ec15647e1d..d2d8d6d662 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -21,6 +21,7 @@ import java.security.PrivilegedExceptionAction
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.security.Credentials
import org.apache.hadoop.security.UserGroupInformation
import org.apache.spark.{SparkContext, SparkException}
@@ -65,6 +66,15 @@ class SparkHadoopUtil {
def addCredentials(conf: JobConf) {}
def isYarnMode(): Boolean = { false }
+
+ def getCurrentUserCredentials(): Credentials = { null }
+
+ def addCurrentUserCredentials(creds: Credentials) {}
+
+ def addSecretKeyToUserCredentials(key: String, secret: String) {}
+
+ def getSecretKeyFromUserCredentials(key: String): Array[Byte] = { null }
+
}
object SparkHadoopUtil {
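
These hooks are no-ops outside of YARN. The YarnSparkHadoopUtil changes are not shown in this section, but here is a hedged sketch of how a YARN-mode subclass could back them with the Hadoop credential store (the class name and details are illustrative, not the patch's actual code):

    import org.apache.hadoop.io.Text
    import org.apache.hadoop.security.{Credentials, UserGroupInformation}

    class YarnModeHadoopUtil extends SparkHadoopUtil {
      override def isYarnMode(): Boolean = true

      override def getCurrentUserCredentials(): Credentials =
        UserGroupInformation.getCurrentUser().getCredentials()

      override def addCurrentUserCredentials(creds: Credentials) {
        UserGroupInformation.getCurrentUser().addCredentials(creds)
      }

      override def addSecretKeyToUserCredentials(key: String, secret: String) {
        val creds = new Credentials()
        creds.addSecretKey(new Text(key), secret.getBytes())
        addCurrentUserCredentials(creds)
      }

      override def getSecretKeyFromUserCredentials(key: String): Array[Byte] =
        getCurrentUserCredentials().getSecretKey(new Text(key))
    }
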
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
index 1550c3eb42..63f166d401 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy.client
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{SecurityManager, SparkConf, Logging}
import org.apache.spark.deploy.{ApplicationDescription, Command}
import org.apache.spark.util.{AkkaUtils, Utils}
@@ -45,8 +45,9 @@ private[spark] object TestClient {
def main(args: Array[String]) {
val url = args(0)
+ val conf = new SparkConf
val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0,
- conf = new SparkConf)
+ conf = conf, securityManager = new SecurityManager(conf))
val desc = new ApplicationDescription(
"TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()),
Some("dummy-spark-home"), "ignored")
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 51794ce40c..2d6d0c33fa 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -30,7 +30,7 @@ import akka.pattern.ask
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
import akka.serialization.SerializationExtension
-import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.DriverState.DriverState
@@ -39,7 +39,8 @@ import org.apache.spark.deploy.master.ui.MasterWebUI
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.{AkkaUtils, Utils}
-private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
+private[spark] class Master(host: String, port: Int, webUiPort: Int,
+ val securityMgr: SecurityManager) extends Actor with Logging {
import context.dispatcher // to use Akka's scheduler.schedule()
val conf = new SparkConf
@@ -70,8 +71,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
Utils.checkHost(host, "Expected hostname")
- val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf)
- val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf)
+ val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr)
+ val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf,
+ securityMgr)
val masterSource = new MasterSource(this)
val webUi = new MasterWebUI(this, webUiPort)
@@ -711,8 +713,11 @@ private[spark] object Master {
def startSystemAndActor(host: String, port: Int, webUiPort: Int, conf: SparkConf)
: (ActorSystem, Int, Int) =
{
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf)
- val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), actorName)
+ val securityMgr = new SecurityManager(conf)
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf,
+ securityManager = securityMgr)
+ val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort,
+ securityMgr), actorName)
val timeout = AkkaUtils.askTimeout(conf)
val respFuture = actor.ask(RequestWebUIPort)(timeout)
val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse]
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
index 5ab13e7aa6..a7bd01e284 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
@@ -18,8 +18,8 @@
package org.apache.spark.deploy.master.ui
import javax.servlet.http.HttpServletRequest
-
-import org.eclipse.jetty.server.{Handler, Server}
+import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.Logging
import org.apache.spark.deploy.master.Master
@@ -46,7 +46,7 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging {
def start() {
try {
- val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers)
+ val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers, master.conf)
server = Some(srv)
boundPort = Some(bPort)
logInfo("Started Master web UI at http://%s:%d".format(host, boundPort.get))
@@ -60,12 +60,17 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging {
val metricsHandlers = master.masterMetricsSystem.getServletHandlers ++
master.applicationMetricsSystem.getServletHandlers
- val handlers = metricsHandlers ++ Array[(String, Handler)](
- ("/static", createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR)),
- ("/app/json", (request: HttpServletRequest) => applicationPage.renderJson(request)),
- ("/app", (request: HttpServletRequest) => applicationPage.render(request)),
- ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)),
- ("*", (request: HttpServletRequest) => indexPage.render(request))
+ val handlers = metricsHandlers ++ Seq[ServletContextHandler](
+ createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static/*"),
+ createServletHandler("/app/json",
+ createServlet((request: HttpServletRequest) => applicationPage.renderJson(request),
+ master.securityMgr)),
+ createServletHandler("/app", createServlet((request: HttpServletRequest) => applicationPage
+ .render(request), master.securityMgr)),
+ createServletHandler("/json", createServlet((request: HttpServletRequest) => indexPage
+ .renderJson(request), master.securityMgr)),
+ createServletHandler("*", createServlet((request: HttpServletRequest) => indexPage.render
+ (request), master.securityMgr))
)
def stop() {
@@ -74,5 +79,5 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging {
}
private[spark] object MasterWebUI {
- val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static"
+ val STATIC_RESOURCE_DIR = "org/apache/spark/ui"
}
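
The createServlet/createServletHandler helpers live in org.apache.spark.ui.JettyUtils, whose diff is not part of this section. A hedged sketch of the pattern they implement (names and details are illustrative): wrap a render function in an HttpServlet and gate it on SecurityManager.checkUIViewPermissions using the user established by whatever servlet filter is installed.

    import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}

    // Hypothetical wrapper; sits inside an org.apache.spark.* package so the
    // private[spark] SecurityManager is visible.
    def servletFor(render: HttpServletRequest => String,
                   securityMgr: SecurityManager): HttpServlet = new HttpServlet {
      override def doGet(request: HttpServletRequest, response: HttpServletResponse) {
        // getRemoteUser is populated by the installed javax servlet auth filter, if any.
        if (securityMgr.checkUIViewPermissions(request.getRemoteUser())) {
          response.setStatus(HttpServletResponse.SC_OK)
          response.getWriter().println(render(request))
        } else {
          response.sendError(HttpServletResponse.SC_UNAUTHORIZED,
            "User is not authorized to access this page.")
        }
      }
    }
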
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
index a26e47950a..be15138f62 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
@@ -19,7 +19,7 @@ package org.apache.spark.deploy.worker
import akka.actor._
-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.util.{AkkaUtils, Utils}
/**
@@ -29,8 +29,9 @@ object DriverWrapper {
def main(args: Array[String]) {
args.toList match {
case workerUrl :: mainClass :: extraArgs =>
+ val conf = new SparkConf()
val (actorSystem, _) = AkkaUtils.createActorSystem("Driver",
- Utils.localHostName(), 0, false, new SparkConf())
+ Utils.localHostName(), 0, false, conf, new SecurityManager(conf))
actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher")
// Delegate to supplied main class
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 7b0b7861b7..afaabedffe 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -27,7 +27,7 @@ import scala.concurrent.duration._
import akka.actor._
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
-import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.{DriverState, Master}
@@ -48,7 +48,8 @@ private[spark] class Worker(
actorSystemName: String,
actorName: String,
workDirPath: String = null,
- val conf: SparkConf)
+ val conf: SparkConf,
+ val securityMgr: SecurityManager)
extends Actor with Logging {
import context.dispatcher
@@ -91,7 +92,7 @@ private[spark] class Worker(
var coresUsed = 0
var memoryUsed = 0
- val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf)
+ val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr)
val workerSource = new WorkerSource(this)
def coresFree: Int = cores - coresUsed
@@ -347,10 +348,11 @@ private[spark] object Worker {
val conf = new SparkConf
val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
val actorName = "Worker"
+ val securityMgr = new SecurityManager(conf)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port,
- conf = conf)
+ conf = conf, securityManager = securityMgr)
actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory,
- masterUrls, systemName, actorName, workDir, conf), name = actorName)
+ masterUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName)
(actorSystem, boundPort)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index bdf126f93a..ffc05bd306 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -19,8 +19,8 @@ package org.apache.spark.deploy.worker.ui
import java.io.File
import javax.servlet.http.HttpServletRequest
-
-import org.eclipse.jetty.server.{Handler, Server}
+import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.Logging
import org.apache.spark.deploy.worker.Worker
@@ -33,7 +33,7 @@ import org.apache.spark.util.{AkkaUtils, Utils}
*/
private[spark]
class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[Int] = None)
- extends Logging {
+ extends Logging {
val timeout = AkkaUtils.askTimeout(worker.conf)
val host = Utils.localHostName()
val port = requestedPort.getOrElse(
@@ -46,17 +46,21 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I
val metricsHandlers = worker.metricsSystem.getServletHandlers
- val handlers = metricsHandlers ++ Array[(String, Handler)](
- ("/static", createStaticHandler(WorkerWebUI.STATIC_RESOURCE_DIR)),
- ("/log", (request: HttpServletRequest) => log(request)),
- ("/logPage", (request: HttpServletRequest) => logPage(request)),
- ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)),
- ("*", (request: HttpServletRequest) => indexPage.render(request))
+ val handlers = metricsHandlers ++ Seq[ServletContextHandler](
+ createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static/*"),
+ createServletHandler("/log", createServlet((request: HttpServletRequest) => log(request),
+ worker.securityMgr)),
+ createServletHandler("/logPage", createServlet((request: HttpServletRequest) => logPage
+ (request), worker.securityMgr)),
+ createServletHandler("/json", createServlet((request: HttpServletRequest) => indexPage
+ .renderJson(request), worker.securityMgr)),
+ createServletHandler("*", createServlet((request: HttpServletRequest) => indexPage.render
+ (request), worker.securityMgr))
)
def start() {
try {
- val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers)
+ val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers, worker.conf)
server = Some(srv)
boundPort = Some(bPort)
logInfo("Started Worker web UI at http://%s:%d".format(host, bPort))
@@ -198,6 +202,6 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I
}
private[spark] object WorkerWebUI {
- val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static"
+ val STATIC_RESOURCE_BASE = "org/apache/spark/ui"
val DEFAULT_PORT="8081"
}
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 0aae569b17..3486092a14 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -22,7 +22,7 @@ import java.nio.ByteBuffer
import akka.actor._
import akka.remote._
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{SecurityManager, SparkConf, Logging}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.deploy.worker.WorkerWatcher
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
@@ -97,10 +97,11 @@ private[spark] object CoarseGrainedExecutorBackend {
// Debug code
Utils.checkHost(hostname)
+ val conf = new SparkConf
// Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
// before getting started with all our system properties, etc
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0,
- indestructible = true, conf = new SparkConf)
+ indestructible = true, conf = conf, new SecurityManager(conf))
// set it
val sparkHostPort = hostname + ":" + boundPort
actorSystem.actorOf(
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 989d666f15..e69f6f72d3 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -69,11 +69,6 @@ private[spark] class Executor(
conf.set("spark.local.dir", getYarnLocalDirs())
}
- // Create our ClassLoader and set it on this thread
- private val urlClassLoader = createClassLoader()
- private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
- Thread.currentThread.setContextClassLoader(replClassLoader)
-
if (!isLocal) {
// Setup an uncaught exception handler for non-local mode.
// Make any thread terminations due to uncaught exceptions kill the entire
@@ -117,6 +112,12 @@ private[spark] class Executor(
}
}
+ // Create our ClassLoader and set it on this thread
+ // do this after SparkEnv creation so we can access the SecurityManager
+ private val urlClassLoader = createClassLoader()
+ private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
+ Thread.currentThread.setContextClassLoader(replClassLoader)
+
// Akka's message frame size. If task result is bigger than this, we use the block manager
// to send the result back.
private val akkaFrameSize = {
@@ -338,12 +339,12 @@ private[spark] class Executor(
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager)
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager)
currentJars(name) = timestamp
// Add it to our class loader
val localName = name.split("/").last
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index 966c092124..c5bda2078f 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable
import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry}
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.metrics.sink.{MetricsServlet, Sink}
import org.apache.spark.metrics.source.Source
@@ -64,7 +64,7 @@ import org.apache.spark.metrics.source.Source
* [options] is the specific property of this source or sink.
*/
private[spark] class MetricsSystem private (val instance: String,
- conf: SparkConf) extends Logging {
+ conf: SparkConf, securityMgr: SecurityManager) extends Logging {
val confFile = conf.get("spark.metrics.conf", null)
val metricsConfig = new MetricsConfig(Option(confFile))
@@ -131,8 +131,8 @@ private[spark] class MetricsSystem private (val instance: String,
val classPath = kv._2.getProperty("class")
try {
val sink = Class.forName(classPath)
- .getConstructor(classOf[Properties], classOf[MetricRegistry])
- .newInstance(kv._2, registry)
+ .getConstructor(classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager])
+ .newInstance(kv._2, registry, securityMgr)
if (kv._1 == "servlet") {
metricsServlet = Some(sink.asInstanceOf[MetricsServlet])
} else {
@@ -160,6 +160,7 @@ private[spark] object MetricsSystem {
}
}
- def createMetricsSystem(instance: String, conf: SparkConf): MetricsSystem =
- new MetricsSystem(instance, conf)
+ def createMetricsSystem(instance: String, conf: SparkConf,
+ securityMgr: SecurityManager): MetricsSystem =
+ new MetricsSystem(instance, conf, securityMgr)
}
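Because the sink is now instantiated reflectively through a three-argument constructor, any sink class named in the metrics configuration has to accept (Properties, MetricRegistry, SecurityManager). Below is a minimal sketch of a custom sink against that contract; the class name and println bodies are invented, it lives in the org.apache.spark.metrics.sink package because the Sink trait is private[spark], and it assumes the trait's start()/stop() methods as overridden by the sinks later in this diff.

    package org.apache.spark.metrics.sink

    import java.util.Properties

    import com.codahale.metrics.MetricRegistry

    import org.apache.spark.SecurityManager

    // Illustrative sink: does nothing useful, but matches the constructor shape
    // that MetricsSystem now looks up via reflection.
    class NoOpSink(val property: Properties, val registry: MetricRegistry,
        securityMgr: SecurityManager) extends Sink {

      override def start() {
        println("NoOpSink started, authentication enabled = " +
          securityMgr.isAuthenticationEnabled())
      }

      override def stop() {
        println("NoOpSink stopped")
      }
    }

Such a sink would then be registered through the metrics properties file the same way the built-in sinks are.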
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
index 98fa1dbd7c..4d2ffc54d8 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
@@ -22,9 +22,11 @@ import java.util.concurrent.TimeUnit
import com.codahale.metrics.{ConsoleReporter, MetricRegistry}
+import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
-class ConsoleSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+class ConsoleSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val CONSOLE_DEFAULT_PERIOD = 10
val CONSOLE_DEFAULT_UNIT = "SECONDS"
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
index 40f64768e6..319f40815d 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
@@ -23,9 +23,11 @@ import java.util.concurrent.TimeUnit
import com.codahale.metrics.{CsvReporter, MetricRegistry}
+import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
-class CsvSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+class CsvSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val CSV_KEY_PERIOD = "period"
val CSV_KEY_UNIT = "unit"
val CSV_KEY_DIR = "directory"
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala
index 410ca0704b..cd37317da7 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala
@@ -24,9 +24,11 @@ import com.codahale.metrics.MetricRegistry
import com.codahale.metrics.ganglia.GangliaReporter
import info.ganglia.gmetric4j.gmetric.GMetric
+import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
-class GangliaSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+class GangliaSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val GANGLIA_KEY_PERIOD = "period"
val GANGLIA_DEFAULT_PERIOD = 10
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
index e09be00142..0ffdf3846d 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
@@ -24,9 +24,11 @@ import java.util.concurrent.TimeUnit
import com.codahale.metrics.MetricRegistry
import com.codahale.metrics.graphite.{Graphite, GraphiteReporter}
+import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
-class GraphiteSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+class GraphiteSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val GRAPHITE_DEFAULT_PERIOD = 10
val GRAPHITE_DEFAULT_UNIT = "SECONDS"
val GRAPHITE_DEFAULT_PREFIX = ""
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
index b5cf210af2..3b5edd5c37 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
@@ -20,8 +20,11 @@ package org.apache.spark.metrics.sink
import java.util.Properties
import com.codahale.metrics.{JmxReporter, MetricRegistry}
+import org.apache.spark.SecurityManager
+
+class JmxSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
-class JmxSink(val property: Properties, val registry: MetricRegistry) extends Sink {
val reporter: JmxReporter = JmxReporter.forRegistry(registry).build()
override def start() {
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
index 3cdfe26d40..3110eccdee 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
@@ -19,16 +19,19 @@ package org.apache.spark.metrics.sink
import java.util.Properties
import java.util.concurrent.TimeUnit
+
import javax.servlet.http.HttpServletRequest
import com.codahale.metrics.MetricRegistry
import com.codahale.metrics.json.MetricsModule
import com.fasterxml.jackson.databind.ObjectMapper
-import org.eclipse.jetty.server.Handler
+import org.eclipse.jetty.servlet.ServletContextHandler
+import org.apache.spark.SecurityManager
import org.apache.spark.ui.JettyUtils
-class MetricsServlet(val property: Properties, val registry: MetricRegistry) extends Sink {
+class MetricsServlet(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val SERVLET_KEY_PATH = "path"
val SERVLET_KEY_SAMPLE = "sample"
@@ -42,8 +45,11 @@ class MetricsServlet(val property: Properties, val registry: MetricRegistry) ext
val mapper = new ObjectMapper().registerModule(
new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample))
- def getHandlers = Array[(String, Handler)](
- (servletPath, JettyUtils.createHandler(request => getMetricsSnapshot(request), "text/json"))
+ def getHandlers = Array[ServletContextHandler](
+ JettyUtils.createServletHandler(servletPath,
+ JettyUtils.createServlet(
+ new JettyUtils.ServletParams(request => getMetricsSnapshot(request), "text/json"),
+ securityMgr) )
)
def getMetricsSnapshot(request: HttpServletRequest): String = {
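The getHandlers change above is the same servlet wiring every UI endpoint in this patch uses: a Responder goes into a ServletParams, createServlet wraps it with the checkUIViewPermissions gate, and createServletHandler mounts the result at a path. A small sketch of that wiring in isolation; the /ping path and response text are invented, and it sits in the org.apache.spark.ui package because JettyUtils and SecurityManager are private[spark].

    package org.apache.spark.ui

    import javax.servlet.http.HttpServletRequest

    import org.eclipse.jetty.servlet.ServletContextHandler

    import org.apache.spark.{SecurityManager, SparkConf}

    object PingExample {
      def main(args: Array[String]) {
        val conf = new SparkConf()
        val securityMgr = new SecurityManager(conf)

        // Plain-text responder wrapped in ServletParams (the implicit
        // textResponderToServlet conversion would build the same thing).
        val params = new JettyUtils.ServletParams(
          (request: HttpServletRequest) => "pong", "text/plain")

        // createServlet adds the UI view permission check around the responder;
        // createServletHandler mounts the resulting servlet under /ping.
        val handler: ServletContextHandler =
          JettyUtils.createServletHandler("/ping", JettyUtils.createServlet(params, securityMgr))

        // startJettyServer now also takes the SparkConf so spark.ui.filters can be applied.
        val (server, boundPort) = JettyUtils.startJettyServer("localhost", 0, Seq(handler), conf)
        println("Serving /ping on port " + boundPort)
      }
    }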
diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
index d3c09b1606..04df2f3b0d 100644
--- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
+++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
@@ -45,9 +45,10 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
throw new Exception("Max chunk size is " + maxChunkSize)
}
+ val security = if (isSecurityNeg) 1 else 0
if (size == 0 && !gotChunkForSendingOnce) {
val newChunk = new MessageChunk(
- new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
+ new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null)
gotChunkForSendingOnce = true
return Some(newChunk)
}
@@ -65,7 +66,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
}
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
gotChunkForSendingOnce = true
return Some(newChunk)
}
@@ -79,6 +80,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
throw new Exception("Attempting to get chunk from message with multiple data buffers")
}
val buffer = buffers(0)
+ val security = if (isSecurityNeg) 1 else 0
if (buffer.remaining > 0) {
if (buffer.remaining < chunkSize) {
throw new Exception("Not enough space in data buffer for receiving chunk")
@@ -86,7 +88,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
return Some(newChunk)
}
None
diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala
index 8219a185ea..8fd9c2b87d 100644
--- a/core/src/main/scala/org/apache/spark/network/Connection.scala
+++ b/core/src/main/scala/org/apache/spark/network/Connection.scala
@@ -17,6 +17,11 @@
package org.apache.spark.network
+import org.apache.spark._
+import org.apache.spark.SparkSaslServer
+
+import scala.collection.mutable.{HashMap, Queue, ArrayBuffer}
+
import java.net._
import java.nio._
import java.nio.channels._
@@ -27,13 +32,16 @@ import org.apache.spark._
private[spark]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
- val socketRemoteConnectionManagerId: ConnectionManagerId)
+ val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId)
extends Logging {
- def this(channel_ : SocketChannel, selector_ : Selector) = {
+ var sparkSaslServer: SparkSaslServer = null
+ var sparkSaslClient: SparkSaslClient = null
+
+ def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = {
this(channel_, selector_,
ConnectionManagerId.fromSocketAddress(
- channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]))
+ channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]), id_)
}
channel.configureBlocking(false)
@@ -49,6 +57,16 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
val remoteAddress = getRemoteAddress()
+ /**
+ * Used to synchronize client requests: client's work-related requests must
+ * wait until SASL authentication completes.
+ */
+ private val authenticated = new Object()
+
+ def getAuthenticated(): Object = authenticated
+
+ def isSaslComplete(): Boolean
+
def resetForceReregister(): Boolean
// Read channels typically do not register for write and write does not for read
@@ -69,6 +87,16 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
// Will be true for ReceivingConnection, false for SendingConnection.
def changeInterestForRead(): Boolean
+ private def disposeSasl() {
+ if (sparkSaslServer != null) {
+ sparkSaslServer.dispose();
+ }
+
+ if (sparkSaslClient != null) {
+ sparkSaslClient.dispose()
+ }
+ }
+
// On receiving a write event, should we change the interest for this channel or not ?
// Will be false for ReceivingConnection, true for SendingConnection.
// Actually, for now, should not get triggered for ReceivingConnection
@@ -101,6 +129,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
k.cancel()
}
channel.close()
+ disposeSasl()
callOnCloseCallback()
}
@@ -168,8 +197,12 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
private[spark]
class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
- remoteId_ : ConnectionManagerId)
- extends Connection(SocketChannel.open, selector_, remoteId_) {
+ remoteId_ : ConnectionManagerId, id_ : ConnectionId)
+ extends Connection(SocketChannel.open, selector_, remoteId_, id_) {
+
+ def isSaslComplete(): Boolean = {
+ if (sparkSaslClient != null) sparkSaslClient.isComplete() else false
+ }
private class Outbox {
val messages = new Queue[Message]()
@@ -226,6 +259,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
data as detailed in https://github.com/mesos/spark/pull/791
*/
private var needForceReregister = false
+
val currentBuffers = new ArrayBuffer[ByteBuffer]()
/*channel.socket.setSendBufferSize(256 * 1024)*/
@@ -316,6 +350,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
// If we have 'seen' pending messages, then reset flag - since we handle that as
// normal registering of event (below)
if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister()
+
currentBuffers ++= buffers
}
case None => {
@@ -384,8 +419,15 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
// Must be created within selector loop - else deadlock
-private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
- extends Connection(channel_, selector_) {
+private[spark] class ReceivingConnection(
+ channel_ : SocketChannel,
+ selector_ : Selector,
+ id_ : ConnectionId)
+ extends Connection(channel_, selector_, id_) {
+
+ def isSaslComplete(): Boolean = {
+ if (sparkSaslServer != null) sparkSaslServer.isComplete() else false
+ }
class Inbox() {
val messages = new HashMap[Int, BufferMessage]()
@@ -396,6 +438,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S
val newMessage = Message.create(header).asInstanceOf[BufferMessage]
newMessage.started = true
newMessage.startTime = System.currentTimeMillis
+ newMessage.isSecurityNeg = header.securityNeg == 1
logDebug(
"Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]")
messages += ((newMessage.id, newMessage))
@@ -441,7 +484,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S
val inbox = new Inbox()
val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
- var onReceiveCallback: (Connection , Message) => Unit = null
+ var onReceiveCallback: (Connection, Message) => Unit = null
var currentChunk: MessageChunk = null
channel.register(selector, SelectionKey.OP_READ)
@@ -516,7 +559,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S
}
}
} catch {
- case e: Exception => {
+ case e: Exception => {
logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala
new file mode 100644
index 0000000000..ffaab677d4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+private[spark] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) {
+ override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId
+}
+
+private[spark] object ConnectionId {
+
+ def createConnectionIdFromString(connectionIdString: String): ConnectionId = {
+ val res = connectionIdString.split("_").map(_.trim())
+ if (res.size != 3) {
+ throw new Exception("Error converting ConnectionId string: " + connectionIdString +
+ " to a ConnectionId Object")
+ }
+ new ConnectionId(new ConnectionManagerId(res(0), res(1).toInt), res(2).toInt)
+ }
+}
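A ConnectionId simply round-trips through the host_port_uniqId string form that SecurityMessage carries. A quick sketch (host and numbers arbitrary; it lives in org.apache.spark.network since these classes are private[spark]):

    package org.apache.spark.network

    object ConnectionIdExample {
      def main(args: Array[String]) {
        val cmId = new ConnectionManagerId("host1.example.com", 7077)
        val connId = new ConnectionId(cmId, 42)

        // toString yields "host1.example.com_7077_42" ...
        val asString = connId.toString

        // ... and parsing splits on "_"; anything other than exactly three fields
        // throws an Exception.
        val parsed = ConnectionId.createConnectionIdFromString(asString)
        assert(parsed == connId)
        println(asString)
      }
    }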
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
index a7f20f8c51..a75130cba2 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
@@ -21,6 +21,9 @@ import java.net._
import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
+import java.net._
+import java.util.concurrent.atomic.AtomicInteger
+
import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor}
import scala.collection.mutable.ArrayBuffer
@@ -28,13 +31,15 @@ import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.collection.mutable.SynchronizedMap
import scala.collection.mutable.SynchronizedQueue
+
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.concurrent.duration._
import org.apache.spark._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{SystemClock, Utils}
-private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Logging {
+private[spark] class ConnectionManager(port: Int, conf: SparkConf,
+ securityManager: SecurityManager) extends Logging {
class MessageStatus(
val message: Message,
@@ -50,6 +55,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
private val selector = SelectorProvider.provider.openSelector()
+ // default to 30 second timeout waiting for authentication
+ private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30)
+
private val handleMessageExecutor = new ThreadPoolExecutor(
conf.getInt("spark.core.connection.handler.threads.min", 20),
conf.getInt("spark.core.connection.handler.threads.max", 60),
@@ -71,6 +79,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
new LinkedBlockingDeque[Runnable]())
private val serverChannel = ServerSocketChannel.open()
+ // used to track the SendingConnections waiting to do SASL negotiation
+ private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection]
+ with SynchronizedMap[ConnectionId, SendingConnection]
private val connectionsByKey =
new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection]
@@ -84,6 +95,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
+ private val authEnabled = securityManager.isAuthenticationEnabled()
+
serverChannel.configureBlocking(false)
serverChannel.socket.setReuseAddress(true)
serverChannel.socket.setReceiveBufferSize(256 * 1024)
@@ -94,6 +107,10 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
+ // used in combination with the ConnectionManagerId to create unique Connection ids
+ // to be able to track asynchronous messages
+ private val idCount: AtomicInteger = new AtomicInteger(1)
+
private val selectorThread = new Thread("connection-manager-thread") {
override def run() = ConnectionManager.this.run()
}
@@ -125,7 +142,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
} finally {
writeRunnableStarted.synchronized {
writeRunnableStarted -= key
- val needReregister = register || conn.resetForceReregister()
+ val needReregister = register || conn.resetForceReregister()
if (needReregister && conn.changeInterestForWrite()) {
conn.registerInterest()
}
@@ -372,7 +389,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
// accept them all in a tight loop. non blocking accept with no processing, should be fine
while (newChannel != null) {
try {
- val newConnection = new ReceivingConnection(newChannel, selector)
+ val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
+ val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId)
newConnection.onReceive(receiveMessage)
addListeners(newConnection)
addConnection(newConnection)
@@ -406,6 +424,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
connectionsById -= sendingConnectionManagerId
+ connectionsAwaitingSasl -= connection.connectionId
messageStatuses.synchronized {
messageStatuses
@@ -481,7 +500,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
val creationTime = System.currentTimeMillis
def run() {
logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms")
- handleMessage(connectionManagerId, message)
+ handleMessage(connectionManagerId, message, connection)
logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms")
}
}
@@ -489,10 +508,133 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
/*handleMessage(connection, message)*/
}
- private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) {
+ private def handleClientAuthentication(
+ waitingConn: SendingConnection,
+ securityMsg: SecurityMessage,
+ connectionId : ConnectionId) {
+ if (waitingConn.isSaslComplete()) {
+ logDebug("Client sasl completed for id: " + waitingConn.connectionId)
+ connectionsAwaitingSasl -= waitingConn.connectionId
+ waitingConn.getAuthenticated().synchronized {
+ waitingConn.getAuthenticated().notifyAll();
+ }
+ return
+ } else {
+ var replyToken : Array[Byte] = null
+ try {
+ replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken);
+ if (waitingConn.isSaslComplete()) {
+ logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
+ connectionsAwaitingSasl -= waitingConn.connectionId
+ waitingConn.getAuthenticated().synchronized {
+ waitingConn.getAuthenticated().notifyAll()
+ }
+ return
+ }
+ var securityMsgResp = SecurityMessage.fromResponse(replyToken,
+ securityMsg.getConnectionId.toString())
+ var message = securityMsgResp.toBufferMessage
+ if (message == null) throw new Exception("Error creating security message")
+ sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message)
+ } catch {
+ case e: Exception => {
+ logError("Error handling sasl client authentication", e)
+ waitingConn.close()
+ throw new Exception("Error evaluating sasl response: " + e)
+ }
+ }
+ }
+ }
+
+ private def handleServerAuthentication(
+ connection: Connection,
+ securityMsg: SecurityMessage,
+ connectionId: ConnectionId) {
+ if (!connection.isSaslComplete()) {
+ logDebug("saslContext not established")
+ var replyToken : Array[Byte] = null
+ try {
+ connection.synchronized {
+ if (connection.sparkSaslServer == null) {
+ logDebug("Creating sasl Server")
+ connection.sparkSaslServer = new SparkSaslServer(securityManager)
+ }
+ }
+ replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
+ if (connection.isSaslComplete()) {
+ logDebug("Server sasl completed: " + connection.connectionId)
+ } else {
+ logDebug("Server sasl not completed: " + connection.connectionId)
+ }
+ if (replyToken != null) {
+ var securityMsgResp = SecurityMessage.fromResponse(replyToken,
+ securityMsg.getConnectionId)
+ var message = securityMsgResp.toBufferMessage
+ if (message == null) throw new Exception("Error creating security Message")
+ sendSecurityMessage(connection.getRemoteConnectionManagerId(), message)
+ }
+ } catch {
+ case e: Exception => {
+ logError("Error in server auth negotiation: " + e)
+ // It would probably be better to send an error message telling the other side auth failed
+ // but for now just close
+ connection.close()
+ }
+ }
+ } else {
+ logDebug("connection already established for this connection id: " + connection.connectionId)
+ }
+ }
+
+
+ private def handleAuthentication(conn: Connection, bufferMessage: BufferMessage): Boolean = {
+ if (bufferMessage.isSecurityNeg) {
+ logDebug("This is security neg message")
+
+ // parse as SecurityMessage
+ val securityMsg = SecurityMessage.fromBufferMessage(bufferMessage)
+ val connectionId = ConnectionId.createConnectionIdFromString(securityMsg.getConnectionId)
+
+ connectionsAwaitingSasl.get(connectionId) match {
+ case Some(waitingConn) => {
+ // Client - this must be in response to us doing Send
+ logDebug("Client handleAuth for id: " + waitingConn.connectionId)
+ handleClientAuthentication(waitingConn, securityMsg, connectionId)
+ }
+ case None => {
+ // Server - someone sent us something and we haven't authenticated yet
+ logDebug("Server handleAuth for id: " + connectionId)
+ handleServerAuthentication(conn, securityMsg, connectionId)
+ }
+ }
+ return true
+ } else {
+ if (!conn.isSaslComplete()) {
+ // We could handle this better and tell the client we need to do authentication
+ // negotiation, but for now just ignore them.
+ logError("Received a message that is not a security negotiation message on a connection " +
+ "that has not been authenticated yet; ignoring it!")
+ return true
+ }
+ }
+ return false
+ }
+
+ private def handleMessage(
+ connectionManagerId: ConnectionManagerId,
+ message: Message,
+ connection: Connection) {
logDebug("Handling [" + message + "] from [" + connectionManagerId + "]")
message match {
case bufferMessage: BufferMessage => {
+ if (authEnabled) {
+ val res = handleAuthentication(connection, bufferMessage)
+ if (res == true) {
+ // message was security negotiation so skip the rest
+ logDebug("After handleAuth result was true, returning")
+ return
+ }
+ }
if (bufferMessage.hasAckId) {
val sentMessageStatus = messageStatuses.synchronized {
messageStatuses.get(bufferMessage.ackId) match {
@@ -541,17 +683,124 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
}
}
+ private def checkSendAuthFirst(connManagerId: ConnectionManagerId, conn: SendingConnection) {
+ // see if we need to do SASL before writing
+ // this should only happen for the first negotiation, when we are acting as the client
+ if (!conn.isSaslComplete()) {
+ conn.synchronized {
+ if (conn.sparkSaslClient == null) {
+ conn.sparkSaslClient = new SparkSaslClient(securityManager)
+ var firstResponse: Array[Byte] = null
+ try {
+ firstResponse = conn.sparkSaslClient.firstToken()
+ var securityMsg = SecurityMessage.fromResponse(firstResponse,
+ conn.connectionId.toString())
+ var message = securityMsg.toBufferMessage
+ if (message == null) throw new Exception("Error creating security message")
+ connectionsAwaitingSasl += ((conn.connectionId, conn))
+ sendSecurityMessage(connManagerId, message)
+ logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId)
+ } catch {
+ case e: Exception => {
+ logError("Error getting first response from the SaslClient.", e)
+ conn.close()
+ throw new Exception("Error getting first response from the SaslClient")
+ }
+ }
+ }
+ }
+ } else {
+ logDebug("Sasl already established ")
+ }
+ }
+
+ // allows us to add messages to the inbox for doing SASL negotiation
+ private def sendSecurityMessage(connManagerId: ConnectionManagerId, message: Message) {
+ def startNewConnection(): SendingConnection = {
+ val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port)
+ val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
+ val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId,
+ newConnectionId)
+ logInfo("creating new sending connection for security! " + newConnectionId )
+ registerRequests.enqueue(newConnection)
+
+ newConnection
+ }
+ // I removed the lookupKey stuff as part of merge ... should I re-add it ?
+ // We did not find it useful in our test-env ...
+ // If we do re-add it, we should consistently use it everywhere I guess ?
+ message.senderAddress = id.toSocketAddress()
+ logTrace("Sending Security [" + message + "] to [" + connManagerId + "]")
+ val connection = connectionsById.getOrElseUpdate(connManagerId, startNewConnection())
+
+ // send security messages until the outgoing connection has been authenticated
+ connection.send(message)
+
+ wakeupSelector()
+ }
+
private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
def startNewConnection(): SendingConnection = {
val inetSocketAddress = new InetSocketAddress(connectionManagerId.host,
connectionManagerId.port)
- val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId)
+ val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
+ val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId,
+ newConnectionId)
+ logTrace("creating new sending connection: " + newConnectionId)
registerRequests.enqueue(newConnection)
newConnection
}
val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
+ if (authEnabled) {
+ checkSendAuthFirst(connectionManagerId, connection)
+ }
message.senderAddress = id.toSocketAddress()
+ logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " +
+ "connectionid: " + connection.connectionId)
+
+ if (authEnabled) {
+ // if we aren't authenticated yet let's block the senders until authentication completes
+ try {
+ connection.getAuthenticated().synchronized {
+ val clock = SystemClock
+ val startTime = clock.getTime()
+
+ while (!connection.isSaslComplete()) {
+ logDebug("getAuthenticated wait connectionid: " + connection.connectionId)
+ // have timeout in case remote side never responds
+ connection.getAuthenticated().wait(500)
+ if (((clock.getTime() - startTime) >= (authTimeout * 1000))
+ && (!connection.isSaslComplete())) {
+ // took too long to authenticate the connection, something probably went wrong
+ throw new Exception("Took too long for authentication to " + connectionManagerId +
+ ", waited " + authTimeout + " seconds, failing.")
+ }
+ }
+ }
+ } catch {
+ case e: Exception => logError("Exception while waiting for authentication.", e)
+
+ // need to tell sender it failed
+ messageStatuses.synchronized {
+ val s = messageStatuses.get(message.id)
+ s match {
+ case Some(msgStatus) => {
+ messageStatuses -= message.id
+ logInfo("Notifying " + msgStatus.connectionManagerId)
+ msgStatus.synchronized {
+ msgStatus.attempted = true
+ msgStatus.acked = false
+ msgStatus.markDone()
+ }
+ }
+ case None => {
+ logError("no messageStatus for failed message id: " + message.id)
+ }
+ }
+ }
+ }
+ }
logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
connection.send(message)
@@ -603,7 +852,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
private[spark] object ConnectionManager {
def main(args: Array[String]) {
- val manager = new ConnectionManager(9999, new SparkConf)
+ val conf = new SparkConf
+ val manager = new ConnectionManager(9999, conf, new SecurityManager(conf))
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
println("Received [" + msg + "] from [" + id + "]")
None
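For callers, the visible changes are the extra SecurityManager constructor argument and the new spark.core.connection.auth.wait.timeout setting (in seconds), which bounds how long sendMessage blocks while waiting for SASL negotiation to complete. A minimal sketch mirroring the test main() above; whether SASL actually runs depends on the SecurityManager's authentication settings, which are outside this excerpt.

    package org.apache.spark.network

    import org.apache.spark.{SecurityManager, SparkConf}

    object SecureConnectionManagerExample {
      def main(args: Array[String]) {
        val conf = new SparkConf()
        // Fail sends after 10 seconds if SASL negotiation has not completed
        // (the default used above is 30 seconds).
        conf.set("spark.core.connection.auth.wait.timeout", "10")

        val manager = new ConnectionManager(9999, conf, new SecurityManager(conf))
        manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
          println("Received [" + msg + "] from [" + id + "]")
          None
        })
        println("Started connection manager with id = " + manager.id)
      }
    }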
diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala
index 20fe676618..7caccfdbb4 100644
--- a/core/src/main/scala/org/apache/spark/network/Message.scala
+++ b/core/src/main/scala/org/apache/spark/network/Message.scala
@@ -27,6 +27,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) {
var started = false
var startTime = -1L
var finishTime = -1L
+ var isSecurityNeg = false
def size: Int
diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
index 9bcbc6141a..ead663ede7 100644
--- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
+++ b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
@@ -27,6 +27,7 @@ private[spark] class MessageChunkHeader(
val totalSize: Int,
val chunkSize: Int,
val other: Int,
+ val securityNeg: Int,
val address: InetSocketAddress) {
lazy val buffer = {
// No need to change this, at 'use' time, we do a reverse lookup of the hostname.
@@ -40,6 +41,7 @@ private[spark] class MessageChunkHeader(
putInt(totalSize).
putInt(chunkSize).
putInt(other).
+ putInt(securityNeg).
putInt(ip.size).
put(ip).
putInt(port).
@@ -48,12 +50,13 @@ private[spark] class MessageChunkHeader(
}
override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
- " and sizes " + totalSize + " / " + chunkSize + " bytes"
+ " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg
+
}
private[spark] object MessageChunkHeader {
- val HEADER_SIZE = 40
+ val HEADER_SIZE = 44
def create(buffer: ByteBuffer): MessageChunkHeader = {
if (buffer.remaining != HEADER_SIZE) {
@@ -64,11 +67,13 @@ private[spark] object MessageChunkHeader {
val totalSize = buffer.getInt()
val chunkSize = buffer.getInt()
val other = buffer.getInt()
+ val securityNeg = buffer.getInt()
val ipSize = buffer.getInt()
val ipBytes = new Array[Byte](ipSize)
buffer.get(ipBytes)
val ip = InetAddress.getByAddress(ipBytes)
val port = buffer.getInt()
- new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
+ new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg,
+ new InetSocketAddress(ip, port))
}
}
diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
index 9976255c7e..3c09a713c6 100644
--- a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
+++ b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
@@ -18,12 +18,12 @@
package org.apache.spark.network
import java.nio.ByteBuffer
-
-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
private[spark] object ReceiverTest {
def main(args: Array[String]) {
- val manager = new ConnectionManager(9999, new SparkConf)
+ val conf = new SparkConf
+ val manager = new ConnectionManager(9999, conf, new SecurityManager(conf))
println("Started connection manager with id = " + manager.id)
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
diff --git a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala
new file mode 100644
index 0000000000..0d9f743b36
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.StringBuilder
+
+import org.apache.spark._
+import org.apache.spark.network._
+
+/**
+ * SecurityMessage is a class that contains the connectionId and SASL token
+ * used in SASL negotiation. SecurityMessage has routines for converting
+ * it to and from a BufferMessage so that it can be sent by the ConnectionManager
+ * and easily consumed by users when received.
+ * The API was modeled after BlockMessage.
+ *
+ * The connectionId is the connectionId of the client side. Since
+ * message passing is asynchronous and it is possible for the server side (receiving)
+ * to get multiple different types of messages on the same connection, the connectionId
+ * is used to know which connection the security message is intended for.
+ *
+ * For instance, let's say we are node_0. We need to send data to node_1. The node_0 side
+ * is acting as a client and connecting to node_1. SASL negotiation has to occur
+ * between node_0 and node_1 before node_1 trusts node_0, so node_0 sends a security message.
+ * node_1 receives the message from node_0, but before it can process it and send a response,
+ * some thread on node_1 decides it needs to send data to node_0, so it connects to node_0
+ * and sends a security message of its own to authenticate as a client. Now node_0 gets
+ * the message and needs to decide whether this message is in response to it acting as a client
+ * (from the first send) or whether it is just node_1 trying to connect to it to send data. This
+ * is where the connectionId field is used: node_0 can look up the connectionId to see whether
+ * it is in response to it acting as a client or in response to someone sending other data.
+ *
+ * The format of a SecurityMessage as it is sent is:
+ * - Length of the ConnectionId
+ * - ConnectionId
+ * - Length of the token
+ * - Token
+ */
+private[spark] class SecurityMessage() extends Logging {
+
+ private var connectionId: String = null
+ private var token: Array[Byte] = null
+
+ def set(byteArr: Array[Byte], newconnectionId: String) {
+ if (byteArr == null) {
+ token = new Array[Byte](0)
+ } else {
+ token = byteArr
+ }
+ connectionId = newconnectionId
+ }
+
+ /**
+ * Read the given buffer and set the members of this class.
+ */
+ def set(buffer: ByteBuffer) {
+ val idLength = buffer.getInt()
+ val idBuilder = new StringBuilder(idLength)
+ for (i <- 1 to idLength) {
+ idBuilder += buffer.getChar()
+ }
+ connectionId = idBuilder.toString()
+
+ val tokenLength = buffer.getInt()
+ token = new Array[Byte](tokenLength)
+ if (tokenLength > 0) {
+ buffer.get(token, 0, tokenLength)
+ }
+ }
+
+ def set(bufferMsg: BufferMessage) {
+ val buffer = bufferMsg.buffers.apply(0)
+ buffer.clear()
+ set(buffer)
+ }
+
+ def getConnectionId: String = {
+ return connectionId
+ }
+
+ def getToken: Array[Byte] = {
+ return token
+ }
+
+ /**
+ * Create a BufferMessage that can be sent by the ConnectionManager containing
+ * the security information from this class.
+ * @return BufferMessage
+ */
+ def toBufferMessage: BufferMessage = {
+ val startTime = System.currentTimeMillis
+ val buffers = new ArrayBuffer[ByteBuffer]()
+
+ // 4 bytes for the length of the connectionId
+ // connectionId is made up of chars, so multiply the length by 2 to get the number of bytes
+ // 4 bytes for the length of token
+ // token is a byte buffer so just take the length
+ var buffer = ByteBuffer.allocate(4 + connectionId.length() * 2 + 4 + token.length)
+ buffer.putInt(connectionId.length())
+ connectionId.foreach((x: Char) => buffer.putChar(x))
+ buffer.putInt(token.length)
+
+ if (token.length > 0) {
+ buffer.put(token)
+ }
+ buffer.flip()
+ buffers += buffer
+
+ var message = Message.createBufferMessage(buffers)
+ logDebug("message total size is : " + message.size)
+ message.isSecurityNeg = true
+ return message
+ }
+
+ override def toString: String = {
+ "SecurityMessage [connId= " + connectionId + ", Token = " + token + "]"
+ }
+}
+
+private[spark] object SecurityMessage {
+
+ /**
+ * Convert the given BufferMessage to a SecurityMessage by parsing the contents
+ * of the BufferMessage and populating the SecurityMessage fields.
+ * @param bufferMessage is a BufferMessage that was received
+ * @return new SecurityMessage
+ */
+ def fromBufferMessage(bufferMessage: BufferMessage): SecurityMessage = {
+ val newSecurityMessage = new SecurityMessage()
+ newSecurityMessage.set(bufferMessage)
+ newSecurityMessage
+ }
+
+ /**
+ * Create a SecurityMessage to send from a given saslResponse.
+ * @param response is the response to a challenge from the SaslClient or SaslServer
+ * @param connectionId the client connectionId we are negotiating authentication for
+ * @return a new SecurityMessage
+ */
+ def fromResponse(response : Array[Byte], connectionId : String) : SecurityMessage = {
+ val newSecurityMessage = new SecurityMessage()
+ newSecurityMessage.set(response, connectionId)
+ newSecurityMessage
+ }
+}
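Given the buffer layout documented above, a SecurityMessage should survive a round trip through its BufferMessage form. A small sketch (the token bytes and connection id string are arbitrary; it sits in org.apache.spark.network because these classes are private[spark]):

    package org.apache.spark.network

    object SecurityMessageExample {
      def main(args: Array[String]) {
        // A fake SASL token and the client-side connection id it belongs to.
        val token = Array[Byte](1, 2, 3, 4)
        val connectionId = "host1.example.com_7077_42"

        // Client side: wrap the token and flag the message as security negotiation.
        val secMsg = SecurityMessage.fromResponse(token, connectionId)
        val bufferMsg = secMsg.toBufferMessage
        assert(bufferMsg.isSecurityNeg)

        // Receiving side: parse it back out of the BufferMessage.
        val parsed = SecurityMessage.fromBufferMessage(bufferMsg)
        assert(parsed.getConnectionId == connectionId)
        assert(parsed.getToken.toSeq == token.toSeq)
      }
    }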
diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
index 646f8425d9..aac2c24a46 100644
--- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala
+++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
@@ -18,8 +18,7 @@
package org.apache.spark.network
import java.nio.ByteBuffer
-
-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
private[spark] object SenderTest {
def main(args: Array[String]) {
@@ -32,8 +31,8 @@ private[spark] object SenderTest {
val targetHost = args(0)
val targetPort = args(1).toInt
val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort)
-
- val manager = new ConnectionManager(0, new SparkConf)
+ val conf = new SparkConf
+ val manager = new ConnectionManager(0, conf, new SecurityManager(conf))
println("Started connection manager with id = " + manager.id)
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 977c24687c..1bf3f4db32 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props}
import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}
import sun.nio.ch.DirectBuffer
-import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
+import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException, SecurityManager}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
import org.apache.spark.serializer.Serializer
@@ -47,7 +47,8 @@ private[spark] class BlockManager(
val master: BlockManagerMaster,
val defaultSerializer: Serializer,
maxMemory: Long,
- val conf: SparkConf)
+ val conf: SparkConf,
+ securityManager: SecurityManager)
extends Logging {
val shuffleBlockManager = new ShuffleBlockManager(this)
@@ -66,7 +67,7 @@ private[spark] class BlockManager(
if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0
}
- val connectionManager = new ConnectionManager(0, conf)
+ val connectionManager = new ConnectionManager(0, conf, securityManager)
implicit val futureExecContext = connectionManager.futureExecContext
val blockManagerId = BlockManagerId(
@@ -122,8 +123,9 @@ private[spark] class BlockManager(
* Construct a BlockManager with a memory limit set based on system properties.
*/
def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster,
- serializer: Serializer, conf: SparkConf) = {
- this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), conf)
+ serializer: Serializer, conf: SparkConf, securityManager: SecurityManager) = {
+ this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), conf,
+ securityManager)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
index 1d81d006c0..36f2a0fd02 100644
--- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
@@ -24,6 +24,7 @@ import util.Random
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.{SecurityManager, SparkConf}
/**
* This class tests the BlockManager and MemoryStore for thread safety and
@@ -98,7 +99,8 @@ private[spark] object ThreadingTest {
val blockManagerMaster = new BlockManagerMaster(
actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf))), conf)
val blockManager = new BlockManager(
- "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf)
+ "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf,
+ new SecurityManager(conf))
val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue))
producers.foreach(_.start)
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index 1b78c52ff6..7c35cd165a 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -18,7 +18,8 @@
package org.apache.spark.ui
import java.net.InetSocketAddress
-import javax.servlet.http.{HttpServletResponse, HttpServletRequest}
+import java.net.URL
+import javax.servlet.http.{HttpServlet, HttpServletResponse, HttpServletRequest}
import scala.annotation.tailrec
import scala.util.{Failure, Success, Try}
@@ -26,11 +27,14 @@ import scala.xml.Node
import org.json4s.JValue
import org.json4s.jackson.JsonMethods.{pretty, render}
-import org.eclipse.jetty.server.{Handler, Request, Server}
-import org.eclipse.jetty.server.handler.{AbstractHandler, ContextHandler, HandlerList, ResourceHandler}
+
+import org.eclipse.jetty.server.{DispatcherType, Server}
+import org.eclipse.jetty.server.handler.HandlerList
+import org.eclipse.jetty.servlet.{DefaultServlet, FilterHolder, ServletContextHandler, ServletHolder}
import org.eclipse.jetty.util.thread.QueuedThreadPool
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
+
/** Utilities for launching a web server using Jetty's HTTP Server class */
private[spark] object JettyUtils extends Logging {
@@ -39,57 +43,104 @@ private[spark] object JettyUtils extends Logging {
type Responder[T] = HttpServletRequest => T
- // Conversions from various types of Responder's to jetty Handlers
- implicit def jsonResponderToHandler(responder: Responder[JValue]): Handler =
- createHandler(responder, "text/json", (in: JValue) => pretty(render(in)))
+ class ServletParams[T <% AnyRef](val responder: Responder[T],
+ val contentType: String,
+ val extractFn: T => String = (in: Any) => in.toString) {}
+
+ // Conversions from various types of Responder's to appropriate servlet parameters
+ implicit def jsonResponderToServlet(responder: Responder[JValue]): ServletParams[JValue] =
+ new ServletParams(responder, "text/json", (in: JValue) => pretty(render(in)))
- implicit def htmlResponderToHandler(responder: Responder[Seq[Node]]): Handler =
- createHandler(responder, "text/html", (in: Seq[Node]) => "<!DOCTYPE html>" + in.toString)
+ implicit def htmlResponderToServlet(responder: Responder[Seq[Node]]): ServletParams[Seq[Node]] =
+ new ServletParams(responder, "text/html", (in: Seq[Node]) => "<!DOCTYPE html>" + in.toString)
- implicit def textResponderToHandler(responder: Responder[String]): Handler =
- createHandler(responder, "text/plain")
+ implicit def textResponderToServlet(responder: Responder[String]): ServletParams[String] =
+ new ServletParams(responder, "text/plain")
- def createHandler[T <% AnyRef](responder: Responder[T], contentType: String,
- extractFn: T => String = (in: Any) => in.toString): Handler = {
- new AbstractHandler {
- def handle(target: String,
- baseRequest: Request,
- request: HttpServletRequest,
+ def createServlet[T <% AnyRef](servletParams: ServletParams[T],
+ securityMgr: SecurityManager): HttpServlet = {
+ new HttpServlet {
+ override def doGet(request: HttpServletRequest,
response: HttpServletResponse) {
- response.setContentType("%s;charset=utf-8".format(contentType))
- response.setStatus(HttpServletResponse.SC_OK)
- baseRequest.setHandled(true)
- val result = responder(request)
- response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
- response.getWriter().println(extractFn(result))
+ if (securityMgr.checkUIViewPermissions(request.getRemoteUser())) {
+ response.setContentType("%s;charset=utf-8".format(servletParams.contentType))
+ response.setStatus(HttpServletResponse.SC_OK)
+ val result = servletParams.responder(request)
+ response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
+ response.getWriter().println(servletParams.extractFn(result))
+ } else {
+ response.setStatus(HttpServletResponse.SC_UNAUTHORIZED)
+ response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
+ response.sendError(HttpServletResponse.SC_UNAUTHORIZED,
+ "User is not authorized to access this page.");
+ }
}
}
}
+ def createServletHandler(path: String, servlet: HttpServlet): ServletContextHandler = {
+ val contextHandler = new ServletContextHandler()
+ val holder = new ServletHolder(servlet)
+ contextHandler.setContextPath(path)
+ contextHandler.addServlet(holder, "/")
+ contextHandler
+ }
+
/** Creates a handler that always redirects the user to a given path */
- def createRedirectHandler(newPath: String): Handler = {
- new AbstractHandler {
- def handle(target: String,
- baseRequest: Request,
- request: HttpServletRequest,
+ def createRedirectHandler(newPath: String, path: String): ServletContextHandler = {
+ val servlet = new HttpServlet {
+ override def doGet(request: HttpServletRequest,
response: HttpServletResponse) {
- response.setStatus(302)
- response.setHeader("Location", baseRequest.getRootURL + newPath)
- baseRequest.setHandled(true)
+ // make sure we don't end up with // in the middle
+ val newUri = new URL(new URL(request.getRequestURL.toString), newPath).toURI
+ response.sendRedirect(newUri.toString)
}
}
+ val contextHandler = new ServletContextHandler()
+ val holder = new ServletHolder(servlet)
+ contextHandler.setContextPath(path)
+ contextHandler.addServlet(holder, "/")
+ contextHandler
}
/** Creates a handler for serving files from a static directory */
- def createStaticHandler(resourceBase: String): ResourceHandler = {
- val staticHandler = new ResourceHandler
+ def createStaticHandler(resourceBase: String, path: String): ServletContextHandler = {
+ val contextHandler = new ServletContextHandler()
+ val staticHandler = new DefaultServlet
+ val holder = new ServletHolder(staticHandler)
Option(getClass.getClassLoader.getResource(resourceBase)) match {
case Some(res) =>
- staticHandler.setResourceBase(res.toString)
+ holder.setInitParameter("resourceBase", res.toString)
case None =>
throw new Exception("Could not find resource path for Web UI: " + resourceBase)
}
- staticHandler
+ contextHandler.addServlet(holder, path)
+ contextHandler
+ }
+
+ private def addFilters(handlers: Seq[ServletContextHandler], conf: SparkConf) {
+ val filters: Array[String] = conf.get("spark.ui.filters", "").split(',').map(_.trim())
+ filters.foreach {
+ case filter : String =>
+ if (!filter.isEmpty) {
+ logInfo("Adding filter: " + filter)
+ val holder : FilterHolder = new FilterHolder()
+ holder.setClassName(filter)
+ // get any parameters for each filter
+ val paramName = "spark." + filter + ".params"
+ val params = conf.get(paramName, "").split(',').map(_.trim()).toSet
+ params.foreach {
+ case param : String =>
+ if (!param.isEmpty) {
+ val parts = param.split("=")
+ if (parts.length == 2) holder.setInitParameter(parts(0), parts(1))
+ }
+ }
+ val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.ERROR,
+ DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.REQUEST)
+ handlers.foreach { case(handler) => handler.addFilter(holder, "/*", enumDispatcher) }
+ }
+ }
}
/**
@@ -99,17 +150,12 @@ private[spark] object JettyUtils extends Logging {
 * If the desired port number is contended, continues incrementing ports until a free port is
* found. Returns the chosen port and the jetty Server object.
*/
- def startJettyServer(hostName: String, port: Int, handlers: Seq[(String, Handler)]): (Server, Int)
- = {
-
- val handlersToRegister = handlers.map { case(path, handler) =>
- val contextHandler = new ContextHandler(path)
- contextHandler.setHandler(handler)
- contextHandler.asInstanceOf[org.eclipse.jetty.server.Handler]
- }
+ def startJettyServer(hostName: String, port: Int, handlers: Seq[ServletContextHandler],
+ conf: SparkConf): (Server, Int) = {
+ addFilters(handlers, conf)
val handlerList = new HandlerList
- handlerList.setHandlers(handlersToRegister.toArray)
+ handlerList.setHandlers(handlers.toArray)
@tailrec
def connect(currentPort: Int): (Server, Int) = {
@@ -119,7 +165,9 @@ private[spark] object JettyUtils extends Logging {
server.setThreadPool(pool)
server.setHandler(handlerList)
- Try { server.start() } match {
+ Try {
+ server.start()
+ } match {
case s: Success[_] =>
(server, server.getConnectors.head.getLocalPort)
case f: Failure[_] =>
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index af6b65860e..ca82c3da2f 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -17,7 +17,10 @@
package org.apache.spark.ui
-import org.eclipse.jetty.server.{Handler, Server}
+import javax.servlet.http.HttpServletRequest
+
+import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.{Logging, SparkContext, SparkEnv}
import org.apache.spark.ui.JettyUtils._
@@ -34,9 +37,9 @@ private[spark] class SparkUI(sc: SparkContext) extends Logging {
var boundPort: Option[Int] = None
var server: Option[Server] = None
- val handlers = Seq[(String, Handler)](
- ("/static", createStaticHandler(SparkUI.STATIC_RESOURCE_DIR)),
- ("/", createRedirectHandler("/stages"))
+ val handlers = Seq[ServletContextHandler] (
+ createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static/*"),
+ createRedirectHandler("/stages", "/")
)
val storage = new BlockManagerUI(sc)
val jobs = new JobProgressUI(sc)
@@ -52,7 +55,7 @@ private[spark] class SparkUI(sc: SparkContext) extends Logging {
/** Bind the HTTP server which backs this web interface */
def bind() {
try {
- val (srv, usedPort) = JettyUtils.startJettyServer(host, port, allHandlers)
+ val (srv, usedPort) = JettyUtils.startJettyServer(host, port, allHandlers, sc.conf)
logInfo("Started Spark Web UI at http://%s:%d".format(host, usedPort))
server = Some(srv)
boundPort = Some(usedPort)
@@ -83,5 +86,5 @@ private[spark] class SparkUI(sc: SparkContext) extends Logging {
private[spark] object SparkUI {
val DEFAULT_PORT = "4040"
- val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static"
+ val STATIC_RESOURCE_DIR = "org/apache/spark/ui"
}
diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala
index 9e7cdc8816..14333476c0 100644
--- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConversions._
import scala.util.Properties
import scala.xml.Node
-import org.eclipse.jetty.server.Handler
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.SparkContext
import org.apache.spark.ui.JettyUtils._
@@ -32,8 +32,9 @@ import org.apache.spark.ui.UIUtils
private[spark] class EnvironmentUI(sc: SparkContext) {
- def getHandlers = Seq[(String, Handler)](
- ("/environment", (request: HttpServletRequest) => envDetails(request))
+ def getHandlers = Seq[ServletContextHandler](
+ createServletHandler("/environment",
+ createServlet((request: HttpServletRequest) => envDetails(request), sc.env.securityManager))
)
def envDetails(request: HttpServletRequest): Seq[Node] = {
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
index 1f3b7a4c23..4235cfeff9 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
@@ -22,7 +22,7 @@ import javax.servlet.http.HttpServletRequest
import scala.collection.mutable.{HashMap, HashSet}
import scala.xml.Node
-import org.eclipse.jetty.server.Handler
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.{ExceptionFailure, Logging, SparkContext}
import org.apache.spark.executor.TaskMetrics
@@ -43,8 +43,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
sc.addSparkListener(listener)
}
- def getHandlers = Seq[(String, Handler)](
- ("/executors", (request: HttpServletRequest) => render(request))
+ def getHandlers = Seq[ServletContextHandler](
+    createServletHandler("/executors",
+      createServlet((request: HttpServletRequest) => render(request), sc.env.securityManager))
)
def render(request: HttpServletRequest): Seq[Node] = {
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
index 557bce6b66..2d95d47e15 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
@@ -23,6 +23,7 @@ import javax.servlet.http.HttpServletRequest
import scala.Seq
import org.eclipse.jetty.server.Handler
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.SparkContext
import org.apache.spark.ui.JettyUtils._
@@ -45,9 +46,15 @@ private[spark] class JobProgressUI(val sc: SparkContext) {
def formatDuration(ms: Long) = Utils.msDurationToString(ms)
- def getHandlers = Seq[(String, Handler)](
- ("/stages/stage", (request: HttpServletRequest) => stagePage.render(request)),
- ("/stages/pool", (request: HttpServletRequest) => poolPage.render(request)),
- ("/stages", (request: HttpServletRequest) => indexPage.render(request))
+ def getHandlers = Seq[ServletContextHandler](
+ createServletHandler("/stages/stage",
+ createServlet((request: HttpServletRequest) => stagePage.render(request),
+ sc.env.securityManager)),
+ createServletHandler("/stages/pool",
+ createServlet((request: HttpServletRequest) => poolPage.render(request),
+ sc.env.securityManager)),
+ createServletHandler("/stages",
+ createServlet((request: HttpServletRequest) => indexPage.render(request),
+ sc.env.securityManager))
)
}
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
index dc18eab74e..cb2083eb01 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ui.storage
import javax.servlet.http.HttpServletRequest
-import org.eclipse.jetty.server.Handler
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.ui.JettyUtils._
@@ -29,8 +29,12 @@ private[spark] class BlockManagerUI(val sc: SparkContext) extends Logging {
val indexPage = new IndexPage(this)
val rddPage = new RDDPage(this)
- def getHandlers = Seq[(String, Handler)](
- ("/storage/rdd", (request: HttpServletRequest) => rddPage.render(request)),
- ("/storage", (request: HttpServletRequest) => indexPage.render(request))
+ def getHandlers = Seq[ServletContextHandler](
+ createServletHandler("/storage/rdd",
+ createServlet((request: HttpServletRequest) => rddPage.render(request),
+ sc.env.securityManager)),
+ createServletHandler("/storage",
+ createServlet((request: HttpServletRequest) => indexPage.render(request),
+ sc.env.securityManager))
)
}
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index f26ed47e58..a6c9a9aaba 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -24,12 +24,12 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, IndestructibleActorSystem}
import com.typesafe.config.ConfigFactory
import org.apache.log4j.{Level, Logger}
-import org.apache.spark.SparkConf
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
/**
* Various utility classes for working with Akka.
*/
-private[spark] object AkkaUtils {
+private[spark] object AkkaUtils extends Logging {
/**
* Creates an ActorSystem ready for remoting, with various Spark features. Returns both the
@@ -42,7 +42,7 @@ private[spark] object AkkaUtils {
* of a fatal exception. This is used by [[org.apache.spark.executor.Executor]].
*/
def createActorSystem(name: String, host: String, port: Int, indestructible: Boolean = false,
- conf: SparkConf): (ActorSystem, Int) = {
+ conf: SparkConf, securityManager: SecurityManager): (ActorSystem, Int) = {
val akkaThreads = conf.getInt("spark.akka.threads", 4)
val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15)
@@ -65,6 +65,15 @@ private[spark] object AkkaUtils {
conf.getDouble("spark.akka.failure-detector.threshold", 300.0)
val akkaHeartBeatInterval = conf.getInt("spark.akka.heartbeat.interval", 1000)
+ val secretKey = securityManager.getSecretKey()
+ val isAuthOn = securityManager.isAuthenticationEnabled()
+ if (isAuthOn && secretKey == null) {
+ throw new Exception("Secret key is null with authentication on")
+ }
+ val requireCookie = if (isAuthOn) "on" else "off"
+ val secureCookie = if (isAuthOn) secretKey else ""
+ logDebug("In createActorSystem, requireCookie is: " + requireCookie)
+
val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String]).withFallback(
ConfigFactory.parseString(
s"""
@@ -72,6 +81,8 @@ private[spark] object AkkaUtils {
|akka.loggers = [""akka.event.slf4j.Slf4jLogger""]
|akka.stdout-loglevel = "ERROR"
|akka.jvm-exit-on-fatal-error = off
+ |akka.remote.require-cookie = "$requireCookie"
+ |akka.remote.secure-cookie = "$secureCookie"
|akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatInterval s
|akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPauses s
|akka.remote.transport-failure-detector.threshold = $akkaFailureDetector
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 8e69f1d335..0eb2f78b73 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -18,7 +18,7 @@
package org.apache.spark.util
import java.io._
-import java.net.{InetAddress, Inet4Address, NetworkInterface, URI, URL}
+import java.net.{InetAddress, Inet4Address, NetworkInterface, URI, URL, URLConnection}
import java.nio.ByteBuffer
import java.util.{Locale, Random, UUID}
import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor}
@@ -33,10 +33,11 @@ import com.google.common.io.Files
import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
-import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
import org.apache.spark.deploy.SparkHadoopUtil
+
/**
* Various utility methods used by Spark.
*/
@@ -233,13 +234,29 @@ private[spark] object Utils extends Logging {
}
/**
+ * Construct a URI containing the information used for authentication.
+ * This also sets the default authenticator to properly negotiate the
+ * user/password based on the URI.
+ *
+ * Note this relies on the Authenticator.setDefault being set properly to decode
+ * the user name and password. This is currently set in the SecurityManager.
+ */
+ def constructURIForAuthentication(uri: URI, securityMgr: SecurityManager): URI = {
+ val userCred = securityMgr.getSecretKey()
+ if (userCred == null) throw new Exception("Secret key is null with authentication on")
+ val userInfo = securityMgr.getHttpUser() + ":" + userCred
+ new URI(uri.getScheme(), userInfo, uri.getHost(), uri.getPort(), uri.getPath(),
+ uri.getQuery(), uri.getFragment())
+ }
+
+ /**
* Download a file requested by the executor. Supports fetching the file in a variety of ways,
* including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
*
* Throws SparkException if the target file already exists and has different contents than
* the requested file.
*/
- def fetchFile(url: String, targetDir: File, conf: SparkConf) {
+ def fetchFile(url: String, targetDir: File, conf: SparkConf, securityMgr: SecurityManager) {
val filename = url.split("/").last
val tempDir = getLocalDir(conf)
val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir))
@@ -249,7 +266,19 @@ private[spark] object Utils extends Logging {
uri.getScheme match {
case "http" | "https" | "ftp" =>
logInfo("Fetching " + url + " to " + tempFile)
- val in = new URL(url).openStream()
+
+ var uc: URLConnection = null
+ if (securityMgr.isAuthenticationEnabled()) {
+ logDebug("fetchFile with security enabled")
+ val newuri = constructURIForAuthentication(uri, securityMgr)
+ uc = newuri.toURL().openConnection()
+ uc.setAllowUserInteraction(false)
+ } else {
+ logDebug("fetchFile not using security")
+ uc = new URL(url).openConnection()
+ }
+
+ val in = uc.getInputStream();
val out = new FileOutputStream(tempFile)
Utils.copyStream(in, out, true)
if (targetFile.exists && !Files.equal(tempFile, targetFile)) {
diff --git a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
new file mode 100644
index 0000000000..cd054c1f68
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+
+import akka.actor._
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.AkkaUtils
+import scala.concurrent.Await
+
+/**
+ * Test the AkkaUtils with various security settings.
+ */
+class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
+
+ test("remote fetch security bad password") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "true")
+ conf.set("spark.authenticate.secret", "good")
+
+ val securityManager = new SecurityManager(conf);
+ val hostname = "localhost"
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
+ conf = conf, securityManager = securityManager)
+ System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+ assert(securityManager.isAuthenticationEnabled() === true)
+
+ val masterTracker = new MapOutputTrackerMaster(conf)
+ masterTracker.trackerActor = actorSystem.actorOf(
+ Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "true")
+ badconf.set("spark.authenticate.secret", "bad")
+ val securityManagerBad = new SecurityManager(badconf);
+
+ assert(securityManagerBad.isAuthenticationEnabled() === true)
+
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
+ conf = conf, securityManager = securityManagerBad)
+ val slaveTracker = new MapOutputTracker(conf)
+ val selection = slaveSystem.actorSelection(
+ s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
+ val timeout = AkkaUtils.lookupTimeout(conf)
+ intercept[akka.actor.ActorNotFound] {
+ slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+ }
+
+ actorSystem.shutdown()
+ slaveSystem.shutdown()
+ }
+
+ test("remote fetch security off") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "false")
+ conf.set("spark.authenticate.secret", "bad")
+ val securityManager = new SecurityManager(conf);
+
+ val hostname = "localhost"
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
+ conf = conf, securityManager = securityManager)
+ System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+
+ assert(securityManager.isAuthenticationEnabled() === false)
+
+ val masterTracker = new MapOutputTrackerMaster(conf)
+ masterTracker.trackerActor = actorSystem.actorOf(
+ Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "false")
+ badconf.set("spark.authenticate.secret", "good")
+ val securityManagerBad = new SecurityManager(badconf);
+
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
+ conf = badconf, securityManager = securityManagerBad)
+ val slaveTracker = new MapOutputTracker(conf)
+ val selection = slaveSystem.actorSelection(
+ s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
+ val timeout = AkkaUtils.lookupTimeout(conf)
+ slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+
+ assert(securityManagerBad.isAuthenticationEnabled() === false)
+
+ masterTracker.registerShuffle(10, 1)
+ masterTracker.incrementEpoch()
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+
+ val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+ val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+ masterTracker.registerMapOutput(10, 0, new MapStatus(
+ BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000)))
+ masterTracker.incrementEpoch()
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+
+ // this should succeed since security off
+ assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
+ Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
+
+ actorSystem.shutdown()
+ slaveSystem.shutdown()
+ }
+
+ test("remote fetch security pass") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "true")
+ conf.set("spark.authenticate.secret", "good")
+ val securityManager = new SecurityManager(conf);
+
+ val hostname = "localhost"
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
+ conf = conf, securityManager = securityManager)
+ System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+
+ assert(securityManager.isAuthenticationEnabled() === true)
+
+ val masterTracker = new MapOutputTrackerMaster(conf)
+ masterTracker.trackerActor = actorSystem.actorOf(
+ Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+
+ val goodconf = new SparkConf
+ goodconf.set("spark.authenticate", "true")
+ goodconf.set("spark.authenticate.secret", "good")
+ val securityManagerGood = new SecurityManager(goodconf);
+
+ assert(securityManagerGood.isAuthenticationEnabled() === true)
+
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
+ conf = goodconf, securityManager = securityManagerGood)
+ val slaveTracker = new MapOutputTracker(conf)
+ val selection = slaveSystem.actorSelection(
+ s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
+ val timeout = AkkaUtils.lookupTimeout(conf)
+ slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+
+ masterTracker.registerShuffle(10, 1)
+ masterTracker.incrementEpoch()
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+
+ val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+ val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+ masterTracker.registerMapOutput(10, 0, new MapStatus(
+ BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000)))
+ masterTracker.incrementEpoch()
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+
+ // this should succeed since security on and passwords match
+ assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
+ Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
+
+ actorSystem.shutdown()
+ slaveSystem.shutdown()
+ }
+
+ test("remote fetch security off client") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "true")
+ conf.set("spark.authenticate.secret", "good")
+
+ val securityManager = new SecurityManager(conf);
+
+ val hostname = "localhost"
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
+ conf = conf, securityManager = securityManager)
+ System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+
+ assert(securityManager.isAuthenticationEnabled() === true)
+
+ val masterTracker = new MapOutputTrackerMaster(conf)
+ masterTracker.trackerActor = actorSystem.actorOf(
+ Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "false")
+ badconf.set("spark.authenticate.secret", "bad")
+ val securityManagerBad = new SecurityManager(badconf);
+
+ assert(securityManagerBad.isAuthenticationEnabled() === false)
+
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
+ conf = badconf, securityManager = securityManagerBad)
+ val slaveTracker = new MapOutputTracker(conf)
+ val selection = slaveSystem.actorSelection(
+ s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
+ val timeout = AkkaUtils.lookupTimeout(conf)
+ intercept[akka.actor.ActorNotFound] {
+ slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+ }
+
+ actorSystem.shutdown()
+ slaveSystem.shutdown()
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
index e022accee6..96ba3929c1 100644
--- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
@@ -21,6 +21,7 @@ import org.scalatest.FunSuite
class BroadcastSuite extends FunSuite with LocalSparkContext {
+
override def afterEach() {
super.afterEach()
System.clearProperty("spark.broadcast.factory")
diff --git a/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala
new file mode 100644
index 0000000000..80f7ec00c7
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+
+import java.nio._
+
+import org.apache.spark.network.{ConnectionManager, Message, ConnectionManagerId}
+import scala.concurrent.Await
+import scala.concurrent.TimeoutException
+import scala.concurrent.duration._
+
+
+/**
+ * Test the ConnectionManager with various security settings.
+ */
+class ConnectionManagerSuite extends FunSuite {
+
+ test("security default off") {
+ val conf = new SparkConf
+ val securityManager = new SecurityManager(conf)
+ val manager = new ConnectionManager(0, conf, securityManager)
+ var receivedMessage = false
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ receivedMessage = true
+ None
+ })
+
+ val size = 10 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliablySync(manager.id, bufferMessage)
+
+ assert(receivedMessage == true)
+
+ manager.stop()
+ }
+
+ test("security on same password") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "true")
+ conf.set("spark.authenticate.secret", "good")
+ val securityManager = new SecurityManager(conf)
+ val manager = new ConnectionManager(0, conf, securityManager)
+ var numReceivedMessages = 0
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedMessages += 1
+ None
+ })
+ val managerServer = new ConnectionManager(0, conf, securityManager)
+ var numReceivedServerMessages = 0
+ managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedServerMessages += 1
+ None
+ })
+
+ val size = 10 * 1024 * 1024
+ val count = 10
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliablySync(managerServer.id, bufferMessage)
+ })
+
+ assert(numReceivedServerMessages == 10)
+ assert(numReceivedMessages == 0)
+
+ manager.stop()
+ managerServer.stop()
+ }
+
+ test("security mismatch password") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "true")
+ conf.set("spark.authenticate.secret", "good")
+ val securityManager = new SecurityManager(conf)
+ val manager = new ConnectionManager(0, conf, securityManager)
+ var numReceivedMessages = 0
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedMessages += 1
+ None
+ })
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "true")
+ badconf.set("spark.authenticate.secret", "bad")
+ val badsecurityManager = new SecurityManager(badconf)
+ val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
+ var numReceivedServerMessages = 0
+
+ managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedServerMessages += 1
+ None
+ })
+
+ val size = 10 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliablySync(managerServer.id, bufferMessage)
+
+ assert(numReceivedServerMessages == 0)
+ assert(numReceivedMessages == 0)
+
+ manager.stop()
+ managerServer.stop()
+ }
+
+ test("security mismatch auth off") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "false")
+ conf.set("spark.authenticate.secret", "good")
+ val securityManager = new SecurityManager(conf)
+ val manager = new ConnectionManager(0, conf, securityManager)
+ var numReceivedMessages = 0
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedMessages += 1
+ None
+ })
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "true")
+ badconf.set("spark.authenticate.secret", "good")
+ val badsecurityManager = new SecurityManager(badconf)
+ val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
+ var numReceivedServerMessages = 0
+ managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedServerMessages += 1
+ None
+ })
+
+ val size = 10 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ (0 until 1).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliably(managerServer.id, bufferMessage)
+ }).foreach(f => {
+ try {
+ val g = Await.result(f, 1 second)
+ assert(false)
+ } catch {
+ case e: TimeoutException => {
+ // we should timeout here since the client can't do the negotiation
+ assert(true)
+ }
+ }
+ })
+
+ assert(numReceivedServerMessages == 0)
+ assert(numReceivedMessages == 0)
+ manager.stop()
+ managerServer.stop()
+ }
+
+ test("security auth off") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "false")
+ val securityManager = new SecurityManager(conf)
+ val manager = new ConnectionManager(0, conf, securityManager)
+ var numReceivedMessages = 0
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedMessages += 1
+ None
+ })
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "false")
+ val badsecurityManager = new SecurityManager(badconf)
+ val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
+ var numReceivedServerMessages = 0
+
+ managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedServerMessages += 1
+ None
+ })
+
+ val size = 10 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ (0 until 10).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliably(managerServer.id, bufferMessage)
+ }).foreach(f => {
+ try {
+ val g = Await.result(f, 1 second)
+ if (!g.isDefined) assert(false) else assert(true)
+ } catch {
+ case e: Exception => {
+ assert(false)
+ }
+ }
+ })
+ assert(numReceivedServerMessages == 10)
+ assert(numReceivedMessages == 0)
+
+ manager.stop()
+ managerServer.stop()
+ }
+
+
+
+}
+
diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala
index e0e8011278..9cbdfc54a3 100644
--- a/core/src/test/scala/org/apache/spark/DriverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala
@@ -30,6 +30,7 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark.util.Utils
class DriverSuite extends FunSuite with Timeouts {
+
test("driver should exit after finishing") {
val sparkHome = sys.env.get("SPARK_HOME").orElse(sys.props.get("spark.home")).get
// Regression test for SPARK-530: "Spark driver process doesn't exit after finishing"
diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
index 9be67b3c95..aee9ab9091 100644
--- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
@@ -30,6 +30,12 @@ class FileServerSuite extends FunSuite with LocalSparkContext {
@transient var tmpFile: File = _
@transient var tmpJarUrl: String = _
+ override def beforeEach() {
+ super.beforeEach()
+ resetSparkContext()
+ System.setProperty("spark.authenticate", "false")
+ }
+
override def beforeAll() {
super.beforeAll()
val tmpDir = new File(Files.createTempDir(), "test")
@@ -43,6 +49,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext {
val jarFile = new File(tmpDir, "test.jar")
val jarStream = new FileOutputStream(jarFile)
val jar = new JarOutputStream(jarStream, new java.util.jar.Manifest())
+ System.setProperty("spark.authenticate", "false")
val jarEntry = new JarEntry(textFile.getName)
jar.putNextEntry(jarEntry)
@@ -77,6 +84,25 @@ class FileServerSuite extends FunSuite with LocalSparkContext {
assert(result.toSet === Set((1,200), (2,300), (3,500)))
}
+ test("Distributing files locally security On") {
+ val sparkConf = new SparkConf(false)
+ sparkConf.set("spark.authenticate", "true")
+ sparkConf.set("spark.authenticate.secret", "good")
+ sc = new SparkContext("local[4]", "test", sparkConf)
+
+ sc.addFile(tmpFile.toString)
+ assert(sc.env.securityManager.isAuthenticationEnabled() === true)
+ val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
+ val result = sc.parallelize(testData).reduceByKey {
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
+ val fileVal = in.readLine().toInt
+ in.close()
+ _ * fileVal + _ * fileVal
+ }.collect()
+ assert(result.toSet === Set((1,200), (2,300), (3,500)))
+ }
+
test("Distributing files locally using URL as input") {
// addFile("file:///....")
sc = new SparkContext("local[4]", "test")
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 6c1e325f6f..8efa072a97 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -98,14 +98,16 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("remote fetch") {
val hostname = "localhost"
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf)
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf,
+ securityManager = new SecurityManager(conf))
System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
val masterTracker = new MapOutputTrackerMaster(conf)
masterTracker.trackerActor = actorSystem.actorOf(
Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
- val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf)
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf,
+ securityManager = new SecurityManager(conf))
val slaveTracker = new MapOutputTracker(conf)
val selection = slaveSystem.actorSelection(
s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
index c1e8b295df..96a5a12318 100644
--- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
+++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
@@ -18,21 +18,22 @@
package org.apache.spark.metrics
import org.scalatest.{BeforeAndAfter, FunSuite}
-
-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.master.MasterSource
class MetricsSystemSuite extends FunSuite with BeforeAndAfter {
var filePath: String = _
var conf: SparkConf = null
+ var securityMgr: SecurityManager = null
before {
filePath = getClass.getClassLoader.getResource("test_metrics_system.properties").getFile()
conf = new SparkConf(false).set("spark.metrics.conf", filePath)
+ securityMgr = new SecurityManager(conf)
}
test("MetricsSystem with default config") {
- val metricsSystem = MetricsSystem.createMetricsSystem("default", conf)
+ val metricsSystem = MetricsSystem.createMetricsSystem("default", conf, securityMgr)
val sources = metricsSystem.sources
val sinks = metricsSystem.sinks
@@ -42,7 +43,7 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter {
}
test("MetricsSystem with sources add") {
- val metricsSystem = MetricsSystem.createMetricsSystem("test", conf)
+ val metricsSystem = MetricsSystem.createMetricsSystem("test", conf, securityMgr)
val sources = metricsSystem.sources
val sinks = metricsSystem.sinks
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 9f011d9c8d..121e47c7b1 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -28,7 +28,7 @@ import org.scalatest.concurrent.Timeouts._
import org.scalatest.matchers.ShouldMatchers._
import org.scalatest.time.SpanSugar._
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.{SecurityManager, SparkConf, SparkContext}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils}
@@ -39,6 +39,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
var actorSystem: ActorSystem = null
var master: BlockManagerMaster = null
var oldArch: String = null
+ conf.set("spark.authenticate", "false")
+ val securityMgr = new SecurityManager(conf)
// Reuse a serializer across tests to avoid creating a new thread-local buffer on each test
conf.set("spark.kryoserializer.buffer.mb", "1")
@@ -49,7 +51,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId)
before {
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, conf = conf)
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, conf = conf,
+ securityManager = securityMgr)
this.actorSystem = actorSystem
conf.set("spark.driver.port", boundPort.toString)
@@ -125,7 +128,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("master + 1 manager interaction") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -155,8 +158,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("master + 2 managers interaction") {
- store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf)
- store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf)
+ store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr)
+ store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf,
+ securityMgr)
val peers = master.getPeers(store.blockManagerId, 1)
assert(peers.size === 1, "master did not return the other manager as a peer")
@@ -171,7 +175,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("removing block") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -219,7 +223,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("removing rdd") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -253,7 +257,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("reregistration on heart beat") {
val heartBeat = PrivateMethod[Unit]('heartBeat)
- store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
val a1 = new Array[Byte](400)
store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
@@ -269,7 +273,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("reregistration on block update") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
@@ -288,7 +292,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("reregistration doesn't dead lock") {
val heartBeat = PrivateMethod[Unit]('heartBeat)
- store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = List(new Array[Byte](400))
@@ -325,7 +329,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU storage") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -344,7 +348,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU storage with serialization") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -363,7 +367,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU for partitions of same RDD") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -382,7 +386,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU for partitions of multiple RDDs") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle(rdd(0, 2), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle(rdd(1, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
@@ -405,7 +409,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("on-disk storage") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -418,7 +422,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -433,7 +437,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with getLocalBytes") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -448,7 +452,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with serialization") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -463,7 +467,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with serialization and getLocalBytes") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -478,7 +482,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("LRU with mixed storage levels") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -503,7 +507,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU with streams") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
val list1 = List(new Array[Byte](200), new Array[Byte](200))
val list2 = List(new Array[Byte](200), new Array[Byte](200))
val list3 = List(new Array[Byte](200), new Array[Byte](200))
@@ -527,7 +531,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("LRU with mixed storage levels and streams") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
val list1 = List(new Array[Byte](200), new Array[Byte](200))
val list2 = List(new Array[Byte](200), new Array[Byte](200))
val list3 = List(new Array[Byte](200), new Array[Byte](200))
@@ -573,7 +577,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("overly large block") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 500, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 500, conf, securityMgr)
store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY)
assert(store.getSingle("a1") === None, "a1 was in store")
store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK)
@@ -584,7 +588,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("block compression") {
try {
conf.set("spark.shuffle.compress", "true")
- store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr)
store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100,
"shuffle_0_0_0 was not compressed")
@@ -592,7 +596,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
store = null
conf.set("spark.shuffle.compress", "false")
- store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf, securityMgr)
store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 1000,
"shuffle_0_0_0 was compressed")
@@ -600,7 +604,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
store = null
conf.set("spark.broadcast.compress", "true")
- store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf, securityMgr)
store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 100,
"broadcast_0 was not compressed")
@@ -608,28 +612,28 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
store = null
conf.set("spark.broadcast.compress", "false")
- store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf, securityMgr)
store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 1000, "broadcast_0 was compressed")
store.stop()
store = null
conf.set("spark.rdd.compress", "true")
- store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf, securityMgr)
store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(rdd(0, 0)) <= 100, "rdd_0_0 was not compressed")
store.stop()
store = null
conf.set("spark.rdd.compress", "false")
- store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf, securityMgr)
store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(rdd(0, 0)) >= 1000, "rdd_0_0 was compressed")
store.stop()
store = null
// Check that any other block types are also kept uncompressed
- store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf, securityMgr)
store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY)
assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed")
store.stop()
@@ -643,7 +647,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("block store put failure") {
// Use Java serializer so we can create an unserializable error.
- store = new BlockManager("<driver>", actorSystem, master, new JavaSerializer(conf), 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, new JavaSerializer(conf), 1200, conf,
+ securityMgr)
// The put should fail since a1 is not serializable.
class UnserializableClass
diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
index 20ebb1897e..30415814ad 100644
--- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
@@ -24,6 +24,8 @@ import scala.util.{Failure, Success, Try}
import org.eclipse.jetty.server.Server
import org.scalatest.FunSuite
+import org.apache.spark.SparkConf
+
class UISuite extends FunSuite {
test("jetty port increases under contention") {
val startPort = 4040
@@ -34,15 +36,17 @@ class UISuite extends FunSuite {
case Failure(e) =>
// In either case the server port is busy, so setup for the test is complete
}
- val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq())
- val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq())
+ val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq(),
+ new SparkConf)
+ val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq(),
+ new SparkConf)
// Allow some wiggle room in case ports on the machine are under contention
assert(boundPort1 > startPort && boundPort1 < startPort + 10)
assert(boundPort2 > boundPort1 && boundPort2 < boundPort1 + 10)
}
test("jetty binds to port 0 correctly") {
- val (jettyServer, boundPort) = JettyUtils.startJettyServer("0.0.0.0", 0, Seq())
+ val (jettyServer, boundPort) = JettyUtils.startJettyServer("0.0.0.0", 0, Seq(), new SparkConf)
assert(jettyServer.getState === "STARTED")
assert(boundPort != 0)
Try {new ServerSocket(boundPort)} match {
diff --git a/docs/configuration.md b/docs/configuration.md
index 017d509854..913c653b0d 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -147,6 +147,34 @@ Apart from these, the following properties are also available, and may be useful
How many stages the Spark UI remembers before garbage collecting.
</td>
</tr>
+<tr>
+  <td>spark.ui.filters</td>
+  <td>None</td>
+  <td>
+    Comma-separated list of filter class names to apply to the Spark web UI. The filter should be a
+    standard javax servlet Filter. Parameters to each filter can also be specified by setting a
+    Java system property of spark.<class name of filter>.params='param1=value1,param2=value2'
+    (e.g. -Dspark.ui.filters=com.test.filter1 -Dspark.com.test.filter1.params='param1=foo,param2=testing'); see the configuration sketch after this table section.
+  </td>
+</tr>
+<tr>
+ <td>spark.ui.acls.enable</td>
+ <td>false</td>
+ <td>
+    Whether Spark web UI ACLs should be enabled. If enabled, this checks to see if the user has
+    access permissions to view the web UI. See <code>spark.ui.view.acls</code> for more details.
+    Also note that this requires the user to be known; if the user comes across as null, no checks
+    are done. Filters can be used to authenticate and set the user.
+ </td>
+</tr>
+<tr>
+ <td>spark.ui.view.acls</td>
+ <td>Empty</td>
+ <td>
+    Comma-separated list of users that have view access to the Spark web UI. By default, only the
+    user that started the Spark job has view access.
+ </td>
+</tr>
<tr>
<td>spark.shuffle.compress</td>
<td>true</td>
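
A minimal sketch of wiring the UI properties above together through SparkConf, assuming a hypothetical filter class com.example.MyAuthFilter on the driver's classpath and placeholder user names:

```scala
import org.apache.spark.{SparkConf, SparkContext}

object UiSecurityExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("ui-security-example")
      // Comma-separated filter class names applied to every web UI handler (hypothetical class).
      .set("spark.ui.filters", "com.example.MyAuthFilter")
      // Per-filter init parameters, using the spark.<filter class>.params convention.
      .set("spark.com.example.MyAuthFilter.params", "param1=foo,param2=testing")
      // Check the user authenticated by the filter against the view ACLs (placeholder users).
      .set("spark.ui.acls.enable", "true")
      .set("spark.ui.view.acls", "alice,bob")

    val sc = new SparkContext(conf)
    // ... run the job; the web UI is now behind the filter and the ACL check ...
    sc.stop()
  }
}
```

The addFilters() helper added to JettyUtils above is what reads spark.ui.filters and the matching .params property and attaches each filter to every UI handler.
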
@@ -495,6 +523,29 @@ Apart from these, the following properties are also available, and may be useful
<td>
Whether to overwrite files added through SparkContext.addFile() when the target file exists and its contents do not match those of the source.
</td>
+</tr>
+<tr>
+ <td>spark.authenticate</td>
+ <td>false</td>
+ <td>
+    Whether Spark authenticates its internal connections. See <code>spark.authenticate.secret</code> if not
+    running on YARN. A configuration sketch follows this table.
+ </td>
+</tr>
+<tr>
+ <td>spark.authenticate.secret</td>
+ <td>None</td>
+ <td>
+ Set the secret key used for Spark to authenticate between components. This needs to be set if
+ not running on YARN and authentication is enabled.
+ </td>
+</tr>
+<tr>
+ <td>spark.core.connection.auth.wait.timeout</td>
+ <td>30</td>
+ <td>
+ Number of seconds for the connection to wait for authentication to occur before timing
+ out and giving up.
+ </td>
</tr>
</table>
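
As a quick illustration of the configuration keys added above, here is a minimal, hypothetical driver sketch that turns on authentication and UI ACLs programmatically. The secret and user names are placeholders; on YARN the secret is generated automatically and `spark.authenticate.secret` should not be set by hand.

```scala
import org.apache.spark.{SparkConf, SparkContext}

object SecureAppExample {
  def main(args: Array[String]) {
    // Placeholder values for illustration only. Outside of YARN, the same
    // spark.authenticate.secret must be configured on every node.
    val conf = new SparkConf()
      .setAppName("SecureAppExample")
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "change-me")         // required when not on YARN
      .set("spark.ui.acls.enable", "true")
      .set("spark.ui.view.acls", "alice,bob")                 // extra users allowed to view the UI
      .set("spark.core.connection.auth.wait.timeout", "30")   // seconds to wait for auth

    val sc = new SparkContext(conf)
    try {
      println(sc.parallelize(1 to 100).sum())
    } finally {
      sc.stop()
    }
  }
}
```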
diff --git a/docs/index.md b/docs/index.md
index 4eb297df39..c4f4d79edb 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -103,6 +103,7 @@ For this version of Spark (0.8.1) Hadoop 2.2.x (or newer) users will have to bui
* [Configuration](configuration.html): customize Spark via its configuration system
* [Tuning Guide](tuning.html): best practices to optimize performance and memory use
+* [Security](security.html): Spark security support
* [Hardware Provisioning](hardware-provisioning.html): recommendations for cluster hardware
* [Job Scheduling](job-scheduling.html): scheduling resources across and within Spark applications
* [Building Spark with Maven](building-with-maven.html): build Spark using the Maven system
diff --git a/docs/security.md b/docs/security.md
new file mode 100644
index 0000000000..9e4218fbcf
--- /dev/null
+++ b/docs/security.md
@@ -0,0 +1,18 @@
+---
+layout: global
+title: Spark Security
+---
+
+Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the `spark.authenticate` configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate.
+
+The Spark UI can also be secured by using javax servlet filters. A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user, and once the user is logged in, Spark can compare that user against the view ACLs to make sure they are authorized to view the UI. The configs `spark.ui.acls.enable` and `spark.ui.view.acls` control the behavior of the ACLs. Note that the user who started the application always has view access to the UI.
+
+For Spark on YARN deployments, configuring `spark.authenticate` to true will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. The Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. If an authentication filter is enabled, the ACL controls can be used to control which users can view the Spark UI.
+
+For other types of Spark deployments, the Spark config `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Masters/Workers and applications. The UI can be secured using a javax servlet filter installed via `spark.ui.filters`. If an authentication filter is enabled, the ACL controls can be used to control which users can view the Spark UI.
+
+IMPORTANT NOTE: The NettyBlockFetcherIterator is not secured, so do not use netty for the shuffle if running with authentication on.
+
+See [Spark Configuration](configuration.html) for more details on the security configs.
+
+See <a href="api/core/index.html#org.apache.spark.SecurityManager"><code>org.apache.spark.SecurityManager</code></a> for implementation details about security.
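
The note above says filters can be used to authenticate and set the user. A minimal, purely illustrative sketch of such a filter is shown below; the class name and the header it trusts are hypothetical, and a real deployment would register its own implementation via `spark.ui.filters`.

```scala
import javax.servlet._
import javax.servlet.http.{HttpServletRequest, HttpServletRequestWrapper}

// Hypothetical filter that trusts an upstream proxy to pass the user name in a
// request header and exposes it through getRemoteUser(), which the UI ACL
// checks can then consult. Not production code.
class HeaderAuthFilter extends Filter {
  private var headerName = "X-Authenticated-User"

  override def init(config: FilterConfig) {
    // Can be overridden via spark.HeaderAuthFilter.params='header=...'
    val configured = config.getInitParameter("header")
    if (configured != null) headerName = configured
  }

  override def doFilter(req: ServletRequest, res: ServletResponse, chain: FilterChain) {
    val httpReq = req.asInstanceOf[HttpServletRequest]
    val user = httpReq.getHeader(headerName)
    val wrapped = new HttpServletRequestWrapper(httpReq) {
      override def getRemoteUser(): String = user
    }
    chain.doFilter(wrapped, res)
  }

  override def destroy() {}
}
```

It could then be enabled with something like `-Dspark.ui.filters=HeaderAuthFilter -Dspark.HeaderAuthFilter.params='header=X-Authenticated-User'` (names hypothetical).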
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala
index 3d7b390724..62d3a52615 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala
@@ -23,7 +23,7 @@ import scala.util.Random
import akka.actor.{Actor, ActorRef, Props, actorRef2Scala}
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SecurityManager}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions
import org.apache.spark.streaming.receivers.Receiver
@@ -112,8 +112,9 @@ object FeederActor {
}
val Seq(host, port) = args.toSeq
-
- val actorSystem = AkkaUtils.createActorSystem("test", host, port.toInt, conf = new SparkConf)._1
+ val conf = new SparkConf
+ val actorSystem = AkkaUtils.createActorSystem("test", host, port.toInt, conf = conf,
+ securityManager = new SecurityManager(conf))._1
val feeder = actorSystem.actorOf(Props[FeederActor], "FeederActor")
println("Feeder started as:" + feeder)
diff --git a/pom.xml b/pom.xml
index c59fada5cd..3b863856e4 100644
--- a/pom.xml
+++ b/pom.xml
@@ -157,6 +157,21 @@
<dependencies>
<dependency>
<groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-util</artifactId>
+ <version>7.6.8.v20121106</version>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-security</artifactId>
+ <version>7.6.8.v20121106</version>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-plus</artifactId>
+ <version>7.6.8.v20121106</version>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-server</artifactId>
<version>7.6.8.v20121106</version>
</dependency>
@@ -296,6 +311,11 @@
<version>${mesos.version}</version>
</dependency>
<dependency>
+ <groupId>commons-net</groupId>
+ <artifactId>commons-net</artifactId>
+ <version>2.2</version>
+ </dependency>
+ <dependency>
<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
<version>4.0.17.Final</version>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index aa17848975..138aad7561 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -226,6 +226,9 @@ object SparkBuild extends Build {
libraryDependencies ++= Seq(
"io.netty" % "netty-all" % "4.0.17.Final",
"org.eclipse.jetty" % "jetty-server" % "7.6.8.v20121106",
+ "org.eclipse.jetty" % "jetty-util" % "7.6.8.v20121106",
+ "org.eclipse.jetty" % "jetty-plus" % "7.6.8.v20121106",
+ "org.eclipse.jetty" % "jetty-security" % "7.6.8.v20121106",
/** Workaround for SPARK-959. Dependency used by org.eclipse.jetty. Fixed in ivy 2.3.0. */
"org.eclipse.jetty.orbit" % "javax.servlet" % "2.5.0.v201103041518" artifacts Artifact("javax.servlet", "jar", "jar"),
"org.scalatest" %% "scalatest" % "1.9.1" % "test",
@@ -285,6 +288,7 @@ object SparkBuild extends Build {
"it.unimi.dsi" % "fastutil" % "6.4.4",
"colt" % "colt" % "1.2.0",
"org.apache.mesos" % "mesos" % "0.13.0",
+ "commons-net" % "commons-net" % "2.2",
"net.java.dev.jets3t" % "jets3t" % "0.7.1" excludeAll(excludeCommonsLogging),
"org.apache.derby" % "derby" % "10.4.2.0" % "test",
"org.apache.hadoop" % hadoopClient % hadoopVersion excludeAll(excludeNetty, excludeAsm, excludeCommonsLogging, excludeSLF4J),
diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
index e3bcf7f30a..1aa94079fd 100644
--- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
@@ -18,12 +18,15 @@
package org.apache.spark.repl
import java.io.{ByteArrayOutputStream, InputStream}
-import java.net.{URI, URL, URLClassLoader, URLEncoder}
+import java.net.{URI, URL, URLEncoder}
import java.util.concurrent.{Executors, ExecutorService}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.spark.SparkEnv
+import org.apache.spark.util.Utils
+
import org.objectweb.asm._
import org.objectweb.asm.Opcodes._
@@ -53,7 +56,13 @@ extends ClassLoader(parent) {
if (fileSystem != null) {
fileSystem.open(new Path(directory, pathInDirectory))
} else {
- new URL(classUri + "/" + urlEncode(pathInDirectory)).openStream()
+ if (SparkEnv.get.securityManager.isAuthenticationEnabled()) {
+ val uri = new URI(classUri + "/" + urlEncode(pathInDirectory))
+ val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager)
+ newuri.toURL().openStream()
+ } else {
+ new URL(classUri + "/" + urlEncode(pathInDirectory)).openStream()
+ }
}
}
val bytes = readAndTransformClass(name, inputStream)
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
index f52ebe4a15..9b1da19500 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
@@ -881,6 +881,8 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
})
def process(settings: Settings): Boolean = savingContextLoader {
+ if (getMaster() == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true")
+
this.settings = settings
createInterpreter()
@@ -939,16 +941,9 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
def createSparkContext(): SparkContext = {
val execUri = System.getenv("SPARK_EXECUTOR_URI")
- val master = this.master match {
- case Some(m) => m
- case None => {
- val prop = System.getenv("MASTER")
- if (prop != null) prop else "local"
- }
- }
val jars = SparkILoop.getAddedJars.map(new java.io.File(_).getAbsolutePath)
val conf = new SparkConf()
- .setMaster(master)
+ .setMaster(getMaster())
.setAppName("Spark shell")
.setJars(jars)
.set("spark.repl.class.uri", intp.classServer.uri)
@@ -963,6 +958,17 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
sparkContext
}
+ private def getMaster(): String = {
+ val master = this.master match {
+ case Some(m) => m
+ case None => {
+ val prop = System.getenv("MASTER")
+ if (prop != null) prop else "local"
+ }
+ }
+ master
+ }
+
/** process command-line arguments and do as they request */
def process(args: Array[String]): Boolean = {
val command = new SparkCommandLine(args.toList, msg => echo(msg))
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
index 1d73d0b699..90a96ad383 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
@@ -36,7 +36,7 @@ import scala.tools.reflect.StdRuntimeTags._
import scala.util.control.ControlThrowable
import util.stackTraceString
-import org.apache.spark.{HttpServer, SparkConf, Logging}
+import org.apache.spark.{Logging, HttpServer, SecurityManager, SparkConf}
import org.apache.spark.util.Utils
// /** directory to save .class files to */
@@ -83,15 +83,17 @@ import org.apache.spark.util.Utils
* @author Moez A. Abdel-Gawad
* @author Lex Spoon
*/
- class SparkIMain(initialSettings: Settings, val out: JPrintWriter) extends SparkImports with Logging {
+ class SparkIMain(initialSettings: Settings, val out: JPrintWriter)
+ extends SparkImports with Logging {
imain =>
- val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1")
+ val conf = new SparkConf()
+ val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1")
/** Local directory to save .class files too */
val outputDir = {
val tmp = System.getProperty("java.io.tmpdir")
- val rootDir = new SparkConf().get("spark.repl.classdir", tmp)
+ val rootDir = conf.get("spark.repl.classdir", tmp)
Utils.createTempDir(rootDir)
}
if (SPARK_DEBUG_REPL) {
@@ -99,7 +101,8 @@ import org.apache.spark.util.Utils
}
val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles
- val classServer = new HttpServer(outputDir) /** Jetty server that will serve our classes to worker nodes */
+ val classServer = new HttpServer(outputDir,
+ new SecurityManager(conf)) /** Jetty server that will serve our classes to worker nodes */
private var currentSettings: Settings = initialSettings
var printResults = true // whether to print result lines
var totalSilence = false // whether to print anything
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index e045b9f024..bb574f4152 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -27,7 +27,6 @@ import scala.collection.JavaConversions._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.net.NetUtils
-import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.util.ShutdownHookManager
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.records._
@@ -36,7 +35,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
-import org.apache.spark.{SparkConf, SparkContext, Logging}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.util.Utils
@@ -87,27 +86,16 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
isLastAMRetry = appAttemptId.getAttemptId() >= maxAppAttempts
resourceManager = registerWithResourceManager()
- // Workaround until hadoop moves to something which has
- // https://issues.apache.org/jira/browse/HADOOP-8406 - fixed in (2.0.2-alpha but no 0.23 line)
- // ignore result.
- // This does not, unfortunately, always work reliably ... but alleviates the bug a lot of times
- // Hence args.workerCores = numCore disabled above. Any better option?
-
- // Compute number of threads for akka
- //val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory()
- //if (minimumMemory > 0) {
- // val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD
- // val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0)
-
- // if (numCore > 0) {
- // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406
- // TODO: Uncomment when hadoop is on a version which has this fixed.
- // args.workerCores = numCore
- // }
- //}
- // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf)
+ // setup AmIpFilter for the SparkUI - do this before we start the UI
+ addAmIpFilter()
ApplicationMaster.register(this)
+
+ // Call this to force generation of secret so it gets populated into the
+ // hadoop UGI. This has to happen before the startUserClass which does a
+ // doAs in order for the credentials to be passed on to the worker containers.
+ val securityMgr = new SecurityManager(sparkConf)
+
// Start the user's JAR
userThread = startUserClass()
@@ -132,6 +120,20 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
System.exit(0)
}
+ // add the yarn amIpFilter that Yarn requires for properly securing the UI
+ private def addAmIpFilter() {
+ val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter"
+ System.setProperty("spark.ui.filters", amFilter)
+ val proxy = YarnConfiguration.getProxyHostAndPort(conf)
+ val parts : Array[String] = proxy.split(":")
+ val uriBase = "http://" + proxy +
+ System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV)
+
+ val params = "PROXY_HOST=" + parts(0) + "," + "PROXY_URI_BASE=" + uriBase
+ System.setProperty("spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.params",
+ params)
+ }
+
/** Get the Yarn approved local directories. */
private def getLocalDirs(): String = {
// Hadoop 0.23 and 2.x have different Environment variable names for the
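
For reference, the addAmIpFilter() helper added above amounts to setting two system properties before the UI starts. A standalone sketch with made-up proxy values (normally read from the YARN configuration and the APPLICATION_WEB_PROXY_BASE environment variable) would look like this:

```scala
// Illustrative only: mirrors what addAmIpFilter() does, with hard-coded values
// in place of YarnConfiguration.getProxyHostAndPort(conf) and the
// APPLICATION_WEB_PROXY_BASE environment variable.
val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter"
val proxy = "rm-host.example.com:8088"             // assumed proxy host:port
val proxyBase = "/proxy/application_1234_0001"     // assumed proxy base path
val uriBase = "http://" + proxy + proxyBase

System.setProperty("spark.ui.filters", amFilter)
System.setProperty("spark." + amFilter + ".params",
  "PROXY_HOST=" + proxy.split(":")(0) + ",PROXY_URI_BASE=" + uriBase)
```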
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
index 138c27910b..b735d01df8 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
@@ -29,7 +29,7 @@ import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
import akka.actor._
import akka.remote._
import akka.actor.Terminated
-import org.apache.spark.{SparkConf, SparkContext, Logging}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext}
import org.apache.spark.util.{Utils, AkkaUtils}
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
import org.apache.spark.scheduler.SplitInfo
@@ -50,8 +50,9 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar
private var yarnAllocator: YarnAllocationHandler = _
private var driverClosed:Boolean = false
+ val securityManager = new SecurityManager(sparkConf)
val actorSystem : ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0,
- conf = sparkConf)._1
+ conf = sparkConf, securityManager = securityManager)._1
var actor: ActorRef = _
// This actor just working as a monitor to watch on Driver Actor.
@@ -110,6 +111,7 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar
// we want to be reasonably responsive without causing too many requests to RM.
val schedulerInterval =
System.getProperty("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong
+
// must be <= timeoutInterval / 2.
val interval = math.min(timeoutInterval / 2, schedulerInterval)
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
index fe37168e5a..11322b1202 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -134,7 +134,7 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) {
" --args ARGS Arguments to be passed to your application's main class.\n" +
" Mutliple invocations are possible, each will be passed in order.\n" +
" --num-workers NUM Number of workers to start (Default: 2)\n" +
- " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" +
+ " --worker-cores NUM Number of cores for the workers (Default: 1).\n" +
" --master-class CLASS_NAME Class Name for Master (Default: spark.deploy.yarn.ApplicationMaster)\n" +
" --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" +
" --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
index d6c12a9f59..4c6e1dcd6d 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
@@ -17,11 +17,13 @@
package org.apache.spark.deploy.yarn
-import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.hadoop.io.Text
import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.security.Credentials
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.conf.Configuration
+import org.apache.spark.deploy.SparkHadoopUtil
/**
* Contains util methods to interact with Hadoop from spark.
@@ -44,4 +46,24 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil {
val jobCreds = conf.getCredentials()
jobCreds.mergeAll(UserGroupInformation.getCurrentUser().getCredentials())
}
+
+ override def getCurrentUserCredentials(): Credentials = {
+ UserGroupInformation.getCurrentUser().getCredentials()
+ }
+
+ override def addCurrentUserCredentials(creds: Credentials) {
+ UserGroupInformation.getCurrentUser().addCredentials(creds)
+ }
+
+ override def addSecretKeyToUserCredentials(key: String, secret: String) {
+ val creds = new Credentials()
+ creds.addSecretKey(new Text(key), secret.getBytes())
+ addCurrentUserCredentials(creds)
+ }
+
+ override def getSecretKeyFromUserCredentials(key: String): Array[Byte] = {
+ val credentials = getCurrentUserCredentials()
+ if (credentials != null) credentials.getSecretKey(new Text(key)) else null
+ }
+
}
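
A short sketch of how the new credential helpers might be exercised is shown below. The key name and secret are placeholders; the actual key used by the SecurityManager is an internal detail, and this assumes SparkHadoopUtil.get resolves to the YARN implementation.

```scala
import org.apache.spark.deploy.SparkHadoopUtil

// Hypothetical round trip through the overrides added above: stash a secret in
// the current user's Hadoop credentials, then read it back.
val util = SparkHadoopUtil.get
util.addSecretKeyToUserCredentials("example.secret.key", "change-me")

val secretBytes: Array[Byte] = util.getSecretKeyFromUserCredentials("example.secret.key")
if (secretBytes != null) {
  println("recovered secret: " + new String(secretBytes, "UTF-8"))
}
```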
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index dd117d5810..b48a2d50db 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -27,7 +27,6 @@ import scala.collection.JavaConversions._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.net.NetUtils
-import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.util.ShutdownHookManager
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.protocolrecords._
@@ -37,8 +36,9 @@ import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
+import org.apache.hadoop.yarn.webapp.util.WebAppUtils;
-import org.apache.spark.{SparkConf, SparkContext, Logging}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.util.Utils
@@ -91,12 +91,16 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
amClient.init(yarnConf)
amClient.start()
- // Workaround until hadoop moves to something which has
- // https://issues.apache.org/jira/browse/HADOOP-8406 - fixed in (2.0.2-alpha but no 0.23 line)
- // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf)
+ // setup AmIpFilter for the SparkUI - do this before we start the UI
+ addAmIpFilter()
ApplicationMaster.register(this)
+ // Call this to force generation of secret so it gets populated into the
+ // hadoop UGI. This has to happen before the startUserClass which does a
+ // doAs in order for the credentials to be passed on to the worker containers.
+ val securityMgr = new SecurityManager(sparkConf)
+
// Start the user's JAR
userThread = startUserClass()
@@ -121,6 +125,19 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
System.exit(0)
}
+ // add the yarn amIpFilter that Yarn requires for properly securing the UI
+ private def addAmIpFilter() {
+ val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter"
+ System.setProperty("spark.ui.filters", amFilter)
+ val proxy = WebAppUtils.getProxyHostAndPort(conf)
+ val parts : Array[String] = proxy.split(":")
+ val uriBase = "http://" + proxy +
+ System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV)
+
+ val params = "PROXY_HOST=" + parts(0) + "," + "PROXY_URI_BASE=" + uriBase
+ System.setProperty("spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.params", params)
+ }
+
/** Get the Yarn approved local directories. */
private def getLocalDirs(): String = {
// Hadoop 0.23 and 2.x have different Environment variable names for the
@@ -261,7 +278,6 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
val schedulerInterval =
sparkConf.getLong("spark.yarn.scheduler.heartbeat.interval-ms", 5000)
-
// must be <= timeoutInterval / 2.
val interval = math.min(timeoutInterval / 2, schedulerInterval)
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
index 40600f38e5..f1c1fea0b5 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
@@ -28,7 +28,7 @@ import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
import akka.actor._
import akka.remote._
import akka.actor.Terminated
-import org.apache.spark.{SparkConf, SparkContext, Logging}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext}
import org.apache.spark.util.{Utils, AkkaUtils}
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
import org.apache.spark.scheduler.SplitInfo
@@ -52,8 +52,9 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar
private var amClient: AMRMClient[ContainerRequest] = _
+ val securityManager = new SecurityManager(sparkConf)
val actorSystem: ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0,
- conf = sparkConf)._1
+ conf = sparkConf, securityManager = securityManager)._1
var actor: ActorRef = _
// This actor just working as a monitor to watch on Driver Actor.
@@ -105,6 +106,7 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar
val interval = math.min(timeoutInterval / 2, schedulerInterval)
reporterThread = launchReporterThread(interval)
+
// Wait for the reporter thread to Finish.
reporterThread.join()