aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorThomas Graves <tgraves@apache.org>2014-03-06 18:27:50 -0600
committerThomas Graves <tgraves@apache.org>2014-03-06 18:27:50 -0600
commit7edbea41b43e0dc11a2de156be220db8b7952d01 (patch)
tree1e7156df70131660d28994795eb3df2c3fd0a819 /core
parent40566e10aae4b21ffc71ea72702b8df118ac5c8e (diff)
downloadspark-7edbea41b43e0dc11a2de156be220db8b7952d01.tar.gz
spark-7edbea41b43e0dc11a2de156be220db8b7952d01.tar.bz2
spark-7edbea41b43e0dc11a2de156be220db8b7952d01.zip
SPARK-1189: Add Security to Spark - Akka, Http, ConnectionManager, UI use servlets
resubmit pull request. was https://github.com/apache/incubator-spark/pull/332. Author: Thomas Graves <tgraves@apache.org> Closes #33 from tgravescs/security-branch-0.9-with-client-rebase and squashes the following commits: dfe3918 [Thomas Graves] Fix merge conflict since startUserClass now using runAsUser 05eebed [Thomas Graves] Fix dependency lost in upmerge d1040ec [Thomas Graves] Fix up various imports 05ff5e0 [Thomas Graves] Fix up imports after upmerging to master ac046b3 [Thomas Graves] Merge remote-tracking branch 'upstream/master' into security-branch-0.9-with-client-rebase 13733e1 [Thomas Graves] Pass securityManager and SparkConf around where we can. Switch to use sparkConf for reading config whereever possible. Added ConnectionManagerSuite unit tests. 4a57acc [Thomas Graves] Change UI createHandler routines to createServlet since they now return servlets 2f77147 [Thomas Graves] Rework from comments 50dd9f2 [Thomas Graves] fix header in SecurityManager ecbfb65 [Thomas Graves] Fix spacing and formatting b514bec [Thomas Graves] Fix reference to config ed3d1c1 [Thomas Graves] Add security.md 6f7ddf3 [Thomas Graves] Convert SaslClient and SaslServer to scala, change spark.authenticate.ui to spark.ui.acls.enable, and fix up various other things from review comments 2d9e23e [Thomas Graves] Merge remote-tracking branch 'upstream/master' into security-branch-0.9-with-client-rebase_rework 5721c5a [Thomas Graves] update AkkaUtilsSuite test for the actorSelection changes, fix typos based on comments, and remove extra lines I missed in rebase from AkkaUtils f351763 [Thomas Graves] Add Security to Spark - Akka, Http, ConnectionManager, UI to use servlets
Diffstat (limited to 'core')
-rw-r--r--core/pom.xml16
-rw-r--r--core/src/main/scala/org/apache/spark/HttpFileServer.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/HttpServer.scala60
-rw-r--r--core/src/main/scala/org/apache/spark/SecurityManager.scala253
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/SparkEnv.scala24
-rw-r--r--core/src/main/scala/org/apache/spark/SparkSaslClient.scala146
-rw-r--r--core/src/main/scala/org/apache/spark/SparkSaslServer.scala174
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala32
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/Client.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/Master.scala17
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala25
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala5
-rwxr-xr-xcore/src/main/scala/org/apache/spark/deploy/worker/Worker.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala26
-rw-r--r--core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala15
-rw-r--r--core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala13
-rw-r--r--core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/network/BufferMessage.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/network/Connection.scala61
-rw-r--r--core/src/main/scala/org/apache/spark/network/ConnectionId.scala34
-rw-r--r--core/src/main/scala/org/apache/spark/network/ConnectionManager.scala266
-rw-r--r--core/src/main/scala/org/apache/spark/network/Message.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/network/ReceiverTest.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/network/SecurityMessage.scala163
-rw-r--r--core/src/main/scala/org/apache/spark/network/SenderTest.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManager.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/ui/JettyUtils.scala138
-rw-r--r--core/src/main/scala/org/apache/spark/ui/SparkUI.scala15
-rw-r--r--core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala15
-rw-r--r--core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/util/AkkaUtils.scala17
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala37
-rw-r--r--core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala215
-rw-r--r--core/src/test/scala/org/apache/spark/BroadcastSuite.scala1
-rw-r--r--core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala230
-rw-r--r--core/src/test/scala/org/apache/spark/DriverSuite.scala1
-rw-r--r--core/src/test/scala/org/apache/spark/FileServerSuite.scala26
-rw-r--r--core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala6
-rw-r--r--core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala9
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala67
-rw-r--r--core/src/test/scala/org/apache/spark/ui/UISuite.scala10
57 files changed, 2043 insertions, 241 deletions
diff --git a/core/pom.xml b/core/pom.xml
index 99c841472b..4c1c2d4da5 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -66,6 +66,18 @@
</dependency>
<dependency>
<groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-plus</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-security</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-util</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-server</artifactId>
</dependency>
<dependency>
@@ -119,6 +131,10 @@
<version>0.3.1</version>
</dependency>
<dependency>
+ <groupId>commons-net</groupId>
+ <artifactId>commons-net</artifactId>
+ </dependency>
+ <dependency>
<groupId>${akka.group}</groupId>
<artifactId>akka-remote_${scala.binary.version}</artifactId>
</dependency>
diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
index d3264a4bb3..3d7692ea8a 100644
--- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
@@ -23,7 +23,7 @@ import com.google.common.io.Files
import org.apache.spark.util.Utils
-private[spark] class HttpFileServer extends Logging {
+private[spark] class HttpFileServer(securityManager: SecurityManager) extends Logging {
var baseDir : File = null
var fileDir : File = null
@@ -38,9 +38,10 @@ private[spark] class HttpFileServer extends Logging {
fileDir.mkdir()
jarDir.mkdir()
logInfo("HTTP File server directory is " + baseDir)
- httpServer = new HttpServer(baseDir)
+ httpServer = new HttpServer(baseDir, securityManager)
httpServer.start()
serverUri = httpServer.uri
+ logDebug("HTTP file server started at: " + serverUri)
}
def stop() {
diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala
index 759e68ee0c..cb5df25fa4 100644
--- a/core/src/main/scala/org/apache/spark/HttpServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpServer.scala
@@ -19,15 +19,18 @@ package org.apache.spark
import java.io.File
+import org.eclipse.jetty.util.security.{Constraint, Password}
+import org.eclipse.jetty.security.authentication.DigestAuthenticator
+import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService, SecurityHandler}
+
import org.eclipse.jetty.server.Server
import org.eclipse.jetty.server.bio.SocketConnector
-import org.eclipse.jetty.server.handler.DefaultHandler
-import org.eclipse.jetty.server.handler.HandlerList
-import org.eclipse.jetty.server.handler.ResourceHandler
+import org.eclipse.jetty.server.handler.{DefaultHandler, HandlerList, ResourceHandler}
import org.eclipse.jetty.util.thread.QueuedThreadPool
import org.apache.spark.util.Utils
+
/**
* Exception type thrown by HttpServer when it is in the wrong state for an operation.
*/
@@ -38,7 +41,8 @@ private[spark] class ServerStateException(message: String) extends Exception(mes
* as well as classes created by the interpreter when the user types in code. This is just a wrapper
* around a Jetty server.
*/
-private[spark] class HttpServer(resourceBase: File) extends Logging {
+private[spark] class HttpServer(resourceBase: File, securityManager: SecurityManager)
+ extends Logging {
private var server: Server = null
private var port: Int = -1
@@ -59,14 +63,60 @@ private[spark] class HttpServer(resourceBase: File) extends Logging {
server.setThreadPool(threadPool)
val resHandler = new ResourceHandler
resHandler.setResourceBase(resourceBase.getAbsolutePath)
+
val handlerList = new HandlerList
handlerList.setHandlers(Array(resHandler, new DefaultHandler))
- server.setHandler(handlerList)
+
+ if (securityManager.isAuthenticationEnabled()) {
+ logDebug("HttpServer is using security")
+ val sh = setupSecurityHandler(securityManager)
+ // make sure we go through security handler to get resources
+ sh.setHandler(handlerList)
+ server.setHandler(sh)
+ } else {
+ logDebug("HttpServer is not using security")
+ server.setHandler(handlerList)
+ }
+
server.start()
port = server.getConnectors()(0).getLocalPort()
}
}
+ /**
+ * Setup Jetty to the HashLoginService using a single user with our
+ * shared secret. Configure it to use DIGEST-MD5 authentication so that the password
+ * isn't passed in plaintext.
+ */
+ private def setupSecurityHandler(securityMgr: SecurityManager): ConstraintSecurityHandler = {
+ val constraint = new Constraint()
+ // use DIGEST-MD5 as the authentication mechanism
+ constraint.setName(Constraint.__DIGEST_AUTH)
+ constraint.setRoles(Array("user"))
+ constraint.setAuthenticate(true)
+ constraint.setDataConstraint(Constraint.DC_NONE)
+
+ val cm = new ConstraintMapping()
+ cm.setConstraint(constraint)
+ cm.setPathSpec("/*")
+ val sh = new ConstraintSecurityHandler()
+
+ // the hashLoginService lets us do a single user and
+ // secret right now. This could be changed to use the
+ // JAASLoginService for other options.
+ val hashLogin = new HashLoginService()
+
+ val userCred = new Password(securityMgr.getSecretKey())
+ if (userCred == null) {
+ throw new Exception("Error: secret key is null with authentication on")
+ }
+ hashLogin.putUser(securityMgr.getHttpUser(), userCred, Array("user"))
+ sh.setLoginService(hashLogin)
+ sh.setAuthenticator(new DigestAuthenticator());
+ sh.setConstraintMappings(Array(cm))
+ sh
+ }
+
def stop() {
if (server == null) {
throw new ServerStateException("Server is already stopped")
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
new file mode 100644
index 0000000000..591978c1d3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.net.{Authenticator, PasswordAuthentication}
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.security.Credentials
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.spark.deploy.SparkHadoopUtil
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * Spark class responsible for security.
+ *
+ * In general this class should be instantiated by the SparkEnv and most components
+ * should access it from that. There are some cases where the SparkEnv hasn't been
+ * initialized yet and this class must be instantiated directly.
+ *
+ * Spark currently supports authentication via a shared secret.
+ * Authentication can be configured to be on via the 'spark.authenticate' configuration
+ * parameter. This parameter controls whether the Spark communication protocols do
+ * authentication using the shared secret. This authentication is a basic handshake to
+ * make sure both sides have the same shared secret and are allowed to communicate.
+ * If the shared secret is not identical they will not be allowed to communicate.
+ *
+ * The Spark UI can also be secured by using javax servlet filters. A user may want to
+ * secure the UI if it has data that other users should not be allowed to see. The javax
+ * servlet filter specified by the user can authenticate the user and then once the user
+ * is logged in, Spark can compare that user versus the view acls to make sure they are
+ * authorized to view the UI. The configs 'spark.ui.acls.enable' and 'spark.ui.view.acls'
+ * control the behavior of the acls. Note that the person who started the application
+ * always has view access to the UI.
+ *
+ * Spark does not currently support encryption after authentication.
+ *
+ * At this point spark has multiple communication protocols that need to be secured and
+ * different underlying mechanisms are used depending on the protocol:
+ *
+ * - Akka -> The only option here is to use the Akka Remote secure-cookie functionality.
+ * Akka remoting allows you to specify a secure cookie that will be exchanged
+ * and ensured to be identical in the connection handshake between the client
+ * and the server. If they are not identical then the client will be refused
+ * to connect to the server. There is no control of the underlying
+ * authentication mechanism so its not clear if the password is passed in
+ * plaintext or uses DIGEST-MD5 or some other mechanism.
+ * Akka also has an option to turn on SSL, this option is not currently supported
+ * but we could add a configuration option in the future.
+ *
+ * - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty
+ * for the HttpServer. Jetty supports multiple authentication mechanisms -
+ * Basic, Digest, Form, Spengo, etc. It also supports multiple different login
+ * services - Hash, JAAS, Spnego, JDBC, etc. Spark currently uses the HashLoginService
+ * to authenticate using DIGEST-MD5 via a single user and the shared secret.
+ * Since we are using DIGEST-MD5, the shared secret is not passed on the wire
+ * in plaintext.
+ * We currently do not support SSL (https), but Jetty can be configured to use it
+ * so we could add a configuration option for this in the future.
+ *
+ * The Spark HttpServer installs the HashLoginServer and configures it to DIGEST-MD5.
+ * Any clients must specify the user and password. There is a default
+ * Authenticator installed in the SecurityManager to how it does the authentication
+ * and in this case gets the user name and password from the request.
+ *
+ * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously
+ * exchange messages. For this we use the Java SASL
+ * (Simple Authentication and Security Layer) API and again use DIGEST-MD5
+ * as the authentication mechanism. This means the shared secret is not passed
+ * over the wire in plaintext.
+ * Note that SASL is pluggable as to what mechanism it uses. We currently use
+ * DIGEST-MD5 but this could be changed to use Kerberos or other in the future.
+ * Spark currently supports "auth" for the quality of protection, which means
+ * the connection is not supporting integrity or privacy protection (encryption)
+ * after authentication. SASL also supports "auth-int" and "auth-conf" which
+ * SPARK could be support in the future to allow the user to specify the quality
+ * of protection they want. If we support those, the messages will also have to
+ * be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's.
+ *
+ * Since the connectionManager does asynchronous messages passing, the SASL
+ * authentication is a bit more complex. A ConnectionManager can be both a client
+ * and a Server, so for a particular connection is has to determine what to do.
+ * A ConnectionId was added to be able to track connections and is used to
+ * match up incoming messages with connections waiting for authentication.
+ * If its acting as a client and trying to send a message to another ConnectionManager,
+ * it blocks the thread calling sendMessage until the SASL negotiation has occurred.
+ * The ConnectionManager tracks all the sendingConnections using the ConnectionId
+ * and waits for the response from the server and does the handshake.
+ *
+ * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
+ * can be used. Yarn requires a specific AmIpFilter be installed for security to work
+ * properly. For non-Yarn deployments, users can write a filter to go through a
+ * companies normal login service. If an authentication filter is in place then the
+ * SparkUI can be configured to check the logged in user against the list of users who
+ * have view acls to see if that user is authorized.
+ * The filters can also be used for many different purposes. For instance filters
+ * could be used for logging, encryption, or compression.
+ *
+ * The exact mechanisms used to generate/distributed the shared secret is deployment specific.
+ *
+ * For Yarn deployments, the secret is automatically generated using the Akka remote
+ * Crypt.generateSecureCookie() API. The secret is placed in the Hadoop UGI which gets passed
+ * around via the Hadoop RPC mechanism. Hadoop RPC can be configured to support different levels
+ * of protection. See the Hadoop documentation for more details. Each Spark application on Yarn
+ * gets a different shared secret. On Yarn, the Spark UI gets configured to use the Hadoop Yarn
+ * AmIpFilter which requires the user to go through the ResourceManager Proxy. That Proxy is there
+ * to reduce the possibility of web based attacks through YARN. Hadoop can be configured to use
+ * filters to do authentication. That authentication then happens via the ResourceManager Proxy
+ * and Spark will use that to do authorization against the view acls.
+ *
+ * For other Spark deployments, the shared secret must be specified via the
+ * spark.authenticate.secret config.
+ * All the nodes (Master and Workers) and the applications need to have the same shared secret.
+ * This again is not ideal as one user could potentially affect another users application.
+ * This should be enhanced in the future to provide better protection.
+ * If the UI needs to be secured the user needs to install a javax servlet filter to do the
+ * authentication. Spark will then use that user to compare against the view acls to do
+ * authorization. If not filter is in place the user is generally null and no authorization
+ * can take place.
+ */
+
+private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
+
+ // key used to store the spark secret in the Hadoop UGI
+ private val sparkSecretLookupKey = "sparkCookie"
+
+ private val authOn = sparkConf.getBoolean("spark.authenticate", false)
+ private val uiAclsOn = sparkConf.getBoolean("spark.ui.acls.enable", false)
+
+ // always add the current user and SPARK_USER to the viewAcls
+ private val aclUsers = ArrayBuffer[String](System.getProperty("user.name", ""),
+ Option(System.getenv("SPARK_USER")).getOrElse(""))
+ aclUsers ++= sparkConf.get("spark.ui.view.acls", "").split(',')
+ private val viewAcls = aclUsers.map(_.trim()).filter(!_.isEmpty).toSet
+
+ private val secretKey = generateSecretKey()
+ logInfo("SecurityManager, is authentication enabled: " + authOn +
+ " are ui acls enabled: " + uiAclsOn + " users with view permissions: " + viewAcls.toString())
+
+ // Set our own authenticator to properly negotiate user/password for HTTP connections.
+ // This is needed by the HTTP client fetching from the HttpServer. Put here so its
+ // only set once.
+ if (authOn) {
+ Authenticator.setDefault(
+ new Authenticator() {
+ override def getPasswordAuthentication(): PasswordAuthentication = {
+ var passAuth: PasswordAuthentication = null
+ val userInfo = getRequestingURL().getUserInfo()
+ if (userInfo != null) {
+ val parts = userInfo.split(":", 2)
+ passAuth = new PasswordAuthentication(parts(0), parts(1).toCharArray())
+ }
+ return passAuth
+ }
+ }
+ )
+ }
+
+ /**
+ * Generates or looks up the secret key.
+ *
+ * The way the key is stored depends on the Spark deployment mode. Yarn
+ * uses the Hadoop UGI.
+ *
+ * For non-Yarn deployments, If the config variable is not set
+ * we throw an exception.
+ */
+ private def generateSecretKey(): String = {
+ if (!isAuthenticationEnabled) return null
+ // first check to see if the secret is already set, else generate a new one if on yarn
+ val sCookie = if (SparkHadoopUtil.get.isYarnMode) {
+ val secretKey = SparkHadoopUtil.get.getSecretKeyFromUserCredentials(sparkSecretLookupKey)
+ if (secretKey != null) {
+ logDebug("in yarn mode, getting secret from credentials")
+ return new Text(secretKey).toString
+ } else {
+ logDebug("getSecretKey: yarn mode, secret key from credentials is null")
+ }
+ val cookie = akka.util.Crypt.generateSecureCookie
+ // if we generated the secret then we must be the first so lets set it so t
+ // gets used by everyone else
+ SparkHadoopUtil.get.addSecretKeyToUserCredentials(sparkSecretLookupKey, cookie)
+ logInfo("adding secret to credentials in yarn mode")
+ cookie
+ } else {
+ // user must have set spark.authenticate.secret config
+ sparkConf.getOption("spark.authenticate.secret") match {
+ case Some(value) => value
+ case None => throw new Exception("Error: a secret key must be specified via the " +
+ "spark.authenticate.secret config")
+ }
+ }
+ sCookie
+ }
+
+ /**
+ * Check to see if Acls for the UI are enabled
+ * @return true if UI authentication is enabled, otherwise false
+ */
+ def uiAclsEnabled(): Boolean = uiAclsOn
+
+ /**
+ * Checks the given user against the view acl list to see if they have
+ * authorization to view the UI. If the UI acls must are disabled
+ * via spark.ui.acls.enable, all users have view access.
+ *
+ * @param user to see if is authorized
+ * @return true is the user has permission, otherwise false
+ */
+ def checkUIViewPermissions(user: String): Boolean = {
+ if (uiAclsEnabled() && (user != null) && (!viewAcls.contains(user))) false else true
+ }
+
+ /**
+ * Check to see if authentication for the Spark communication protocols is enabled
+ * @return true if authentication is enabled, otherwise false
+ */
+ def isAuthenticationEnabled(): Boolean = authOn
+
+ /**
+ * Gets the user used for authenticating HTTP connections.
+ * For now use a single hardcoded user.
+ * @return the HTTP user as a String
+ */
+ def getHttpUser(): String = "sparkHttpUser"
+
+ /**
+ * Gets the user used for authenticating SASL connections.
+ * For now use a single hardcoded user.
+ * @return the SASL user as a String
+ */
+ def getSaslUser(): String = "sparkSaslUser"
+
+ /**
+ * Gets the secret key.
+ * @return the secret key as a String if authentication is enabled, otherwise returns null
+ */
+ def getSecretKey(): String = secretKey
+}
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index da778aa851..24731ad706 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -130,6 +130,8 @@ class SparkContext(
val isLocal = (master == "local" || master.startsWith("local["))
+ if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true")
+
// Create the Spark execution environment (cache, map output tracker, etc)
private[spark] val env = SparkEnv.create(
conf,
@@ -634,7 +636,7 @@ class SparkContext(
addedFiles(key) = System.currentTimeMillis
// Fetch the file locally in case a job is executed using DAGScheduler.runLocally().
- Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf)
+ Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf, env.securityManager)
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 7ac65828f6..5e43b51984 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -53,7 +53,8 @@ class SparkEnv private[spark] (
val httpFileServer: HttpFileServer,
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
- val conf: SparkConf) extends Logging {
+ val conf: SparkConf,
+ val securityManager: SecurityManager) extends Logging {
// A mapping of thread ID to amount of memory used for shuffle in bytes
// All accesses should be manually synchronized
@@ -122,8 +123,9 @@ object SparkEnv extends Logging {
isDriver: Boolean,
isLocal: Boolean): SparkEnv = {
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port,
- conf = conf)
+ val securityManager = new SecurityManager(conf)
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port, conf = conf,
+ securityManager = securityManager)
// Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port),
// figure out which port number Akka actually bound to and set spark.driver.port to it.
@@ -139,7 +141,6 @@ object SparkEnv extends Logging {
val name = conf.get(propertyName, defaultClassName)
Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
}
-
val serializerManager = new SerializerManager
val serializer = serializerManager.setDefault(
@@ -167,12 +168,12 @@ object SparkEnv extends Logging {
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
"BlockManagerMaster",
new BlockManagerMasterActor(isLocal, conf)), conf)
- val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
- serializer, conf)
+ val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
+ serializer, conf, securityManager)
val connectionManager = blockManager.connectionManager
- val broadcastManager = new BroadcastManager(isDriver, conf)
+ val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
val cacheManager = new CacheManager(blockManager)
@@ -190,14 +191,14 @@ object SparkEnv extends Logging {
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
- val httpFileServer = new HttpFileServer()
+ val httpFileServer = new HttpFileServer(securityManager)
httpFileServer.initialize()
conf.set("spark.fileserver.uri", httpFileServer.serverUri)
val metricsSystem = if (isDriver) {
- MetricsSystem.createMetricsSystem("driver", conf)
+ MetricsSystem.createMetricsSystem("driver", conf, securityManager)
} else {
- MetricsSystem.createMetricsSystem("executor", conf)
+ MetricsSystem.createMetricsSystem("executor", conf, securityManager)
}
metricsSystem.start()
@@ -231,6 +232,7 @@ object SparkEnv extends Logging {
httpFileServer,
sparkFilesDir,
metricsSystem,
- conf)
+ conf,
+ securityManager)
}
}
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
new file mode 100644
index 0000000000..a2a871cbd3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.IOException
+import javax.security.auth.callback.Callback
+import javax.security.auth.callback.CallbackHandler
+import javax.security.auth.callback.NameCallback
+import javax.security.auth.callback.PasswordCallback
+import javax.security.auth.callback.UnsupportedCallbackException
+import javax.security.sasl.RealmCallback
+import javax.security.sasl.RealmChoiceCallback
+import javax.security.sasl.Sasl
+import javax.security.sasl.SaslClient
+import javax.security.sasl.SaslException
+
+import scala.collection.JavaConversions.mapAsJavaMap
+
+/**
+ * Implements SASL Client logic for Spark
+ */
+private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging {
+
+ /**
+ * Used to respond to server's counterpart, SaslServer with SASL tokens
+ * represented as byte arrays.
+ *
+ * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
+ * configurable in the future.
+ */
+ private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST),
+ null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
+ new SparkSaslClientCallbackHandler(securityMgr))
+
+ /**
+ * Used to initiate SASL handshake with server.
+ * @return response to challenge if needed
+ */
+ def firstToken(): Array[Byte] = {
+ synchronized {
+ val saslToken: Array[Byte] =
+ if (saslClient != null && saslClient.hasInitialResponse()) {
+ logDebug("has initial response")
+ saslClient.evaluateChallenge(new Array[Byte](0))
+ } else {
+ new Array[Byte](0)
+ }
+ saslToken
+ }
+ }
+
+ /**
+ * Determines whether the authentication exchange has completed.
+ * @return true is complete, otherwise false
+ */
+ def isComplete(): Boolean = {
+ synchronized {
+ if (saslClient != null) saslClient.isComplete() else false
+ }
+ }
+
+ /**
+ * Respond to server's SASL token.
+ * @param saslTokenMessage contains server's SASL token
+ * @return client's response SASL token
+ */
+ def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = {
+ synchronized {
+ if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0)
+ }
+ }
+
+ /**
+ * Disposes of any system resources or security-sensitive information the
+ * SaslClient might be using.
+ */
+ def dispose() {
+ synchronized {
+ if (saslClient != null) {
+ try {
+ saslClient.dispose()
+ } catch {
+ case e: SaslException => // ignored
+ } finally {
+ saslClient = null
+ }
+ }
+ }
+ }
+
+ /**
+ * Implementation of javax.security.auth.callback.CallbackHandler
+ * that works with share secrets.
+ */
+ private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends
+ CallbackHandler {
+
+ private val userName: String =
+ SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes())
+ private val secretKey = securityMgr.getSecretKey()
+ private val userPassword: Array[Char] =
+ SparkSaslServer.encodePassword(if (secretKey != null) secretKey.getBytes() else "".getBytes())
+
+ /**
+ * Implementation used to respond to SASL request from the server.
+ *
+ * @param callbacks objects that indicate what credential information the
+ * server's SaslServer requires from the client.
+ */
+ override def handle(callbacks: Array[Callback]) {
+ logDebug("in the sasl client callback handler")
+ callbacks foreach {
+ case nc: NameCallback => {
+ logDebug("handle: SASL client callback: setting username: " + userName)
+ nc.setName(userName)
+ }
+ case pc: PasswordCallback => {
+ logDebug("handle: SASL client callback: setting userPassword")
+ pc.setPassword(userPassword)
+ }
+ case rc: RealmCallback => {
+ logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText())
+ rc.setText(rc.getDefaultText())
+ }
+ case cb: RealmChoiceCallback => {}
+ case cb: Callback => throw
+ new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback")
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
new file mode 100644
index 0000000000..11fcb2ae3a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import javax.security.auth.callback.Callback
+import javax.security.auth.callback.CallbackHandler
+import javax.security.auth.callback.NameCallback
+import javax.security.auth.callback.PasswordCallback
+import javax.security.auth.callback.UnsupportedCallbackException
+import javax.security.sasl.AuthorizeCallback
+import javax.security.sasl.RealmCallback
+import javax.security.sasl.Sasl
+import javax.security.sasl.SaslException
+import javax.security.sasl.SaslServer
+import scala.collection.JavaConversions.mapAsJavaMap
+import org.apache.commons.net.util.Base64
+
+/**
+ * Encapsulates SASL server logic
+ */
+private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging {
+
+ /**
+ * Actual SASL work done by this object from javax.security.sasl.
+ */
+ private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null,
+ SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
+ new SparkSaslDigestCallbackHandler(securityMgr))
+
+ /**
+ * Determines whether the authentication exchange has completed.
+ * @return true is complete, otherwise false
+ */
+ def isComplete(): Boolean = {
+ synchronized {
+ if (saslServer != null) saslServer.isComplete() else false
+ }
+ }
+
+ /**
+ * Used to respond to server SASL tokens.
+ * @param token Server's SASL token
+ * @return response to send back to the server.
+ */
+ def response(token: Array[Byte]): Array[Byte] = {
+ synchronized {
+ if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0)
+ }
+ }
+
+ /**
+ * Disposes of any system resources or security-sensitive information the
+ * SaslServer might be using.
+ */
+ def dispose() {
+ synchronized {
+ if (saslServer != null) {
+ try {
+ saslServer.dispose()
+ } catch {
+ case e: SaslException => // ignore
+ } finally {
+ saslServer = null
+ }
+ }
+ }
+ }
+
+ /**
+ * Implementation of javax.security.auth.callback.CallbackHandler
+ * for SASL DIGEST-MD5 mechanism
+ */
+ private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager)
+ extends CallbackHandler {
+
+ private val userName: String =
+ SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes())
+
+ override def handle(callbacks: Array[Callback]) {
+ logDebug("In the sasl server callback handler")
+ callbacks foreach {
+ case nc: NameCallback => {
+ logDebug("handle: SASL server callback: setting username")
+ nc.setName(userName)
+ }
+ case pc: PasswordCallback => {
+ logDebug("handle: SASL server callback: setting userPassword")
+ val password: Array[Char] =
+ SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes())
+ pc.setPassword(password)
+ }
+ case rc: RealmCallback => {
+ logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText())
+ rc.setText(rc.getDefaultText())
+ }
+ case ac: AuthorizeCallback => {
+ val authid = ac.getAuthenticationID()
+ val authzid = ac.getAuthorizationID()
+ if (authid.equals(authzid)) {
+ logDebug("set auth to true")
+ ac.setAuthorized(true)
+ } else {
+ logDebug("set auth to false")
+ ac.setAuthorized(false)
+ }
+ if (ac.isAuthorized()) {
+ logDebug("sasl server is authorized")
+ ac.setAuthorizedID(authzid)
+ }
+ }
+ case cb: Callback => throw
+ new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback")
+ }
+ }
+ }
+}
+
+private[spark] object SparkSaslServer {
+
+ /**
+ * This is passed as the server name when creating the sasl client/server.
+ * This could be changed to be configurable in the future.
+ */
+ val SASL_DEFAULT_REALM = "default"
+
+ /**
+ * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
+ * configurable in the future.
+ */
+ val DIGEST = "DIGEST-MD5"
+
+ /**
+ * The quality of protection is just "auth". This means that we are doing
+ * authentication only, we are not supporting integrity or privacy protection of the
+ * communication channel after authentication. This could be changed to be configurable
+ * in the future.
+ */
+ val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH ->"true")
+
+ /**
+ * Encode a byte[] identifier as a Base64-encoded string.
+ *
+ * @param identifier identifier to encode
+ * @return Base64-encoded string
+ */
+ def encodeIdentifier(identifier: Array[Byte]): String = {
+ new String(Base64.encodeBase64(identifier))
+ }
+
+ /**
+ * Encode a password as a base64-encoded char[] array.
+ * @param password as a byte array.
+ * @return password as a char array.
+ */
+ def encodePassword(password: Array[Byte]): Array[Char] = {
+ new String(Base64.encodeBase64(password)).toCharArray()
+ }
+}
+
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
index d113d40405..e3c3a12d16 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -60,7 +60,8 @@ abstract class Broadcast[T](val id: Long) extends Serializable {
}
private[spark]
-class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging with Serializable {
+class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager)
+ extends Logging with Serializable {
private var initialized = false
private var broadcastFactory: BroadcastFactory = null
@@ -78,7 +79,7 @@ class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
// Initialize appropriate BroadcastFactory and BroadcastObject
- broadcastFactory.initialize(isDriver, conf)
+ broadcastFactory.initialize(isDriver, conf, securityManager)
initialized = true
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
index 940e5ab805..6beecaeced 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -16,6 +16,7 @@
*/
package org.apache.spark.broadcast
+import org.apache.spark.SecurityManager
import org.apache.spark.SparkConf
@@ -26,7 +27,7 @@ import org.apache.spark.SparkConf
* entire Spark job.
*/
trait BroadcastFactory {
- def initialize(isDriver: Boolean, conf: SparkConf): Unit
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def stop(): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 20207c2613..e8eb04bb10 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -18,13 +18,13 @@
package org.apache.spark.broadcast
import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream}
-import java.net.URL
+import java.net.{URL, URLConnection, URI}
import java.util.concurrent.TimeUnit
import it.unimi.dsi.fastutil.io.FastBufferedInputStream
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-import org.apache.spark.{HttpServer, Logging, SparkConf, SparkEnv}
+import org.apache.spark.{SparkConf, HttpServer, Logging, SecurityManager, SparkEnv}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}
@@ -67,7 +67,9 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
* A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium.
*/
class HttpBroadcastFactory extends BroadcastFactory {
- def initialize(isDriver: Boolean, conf: SparkConf) { HttpBroadcast.initialize(isDriver, conf) }
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
+ HttpBroadcast.initialize(isDriver, conf, securityMgr)
+ }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)
@@ -83,6 +85,7 @@ private object HttpBroadcast extends Logging {
private var bufferSize: Int = 65536
private var serverUri: String = null
private var server: HttpServer = null
+ private var securityManager: SecurityManager = null
// TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist
private val files = new TimeStampedHashSet[String]
@@ -92,11 +95,12 @@ private object HttpBroadcast extends Logging {
private var compressionCodec: CompressionCodec = null
- def initialize(isDriver: Boolean, conf: SparkConf) {
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
synchronized {
if (!initialized) {
bufferSize = conf.getInt("spark.buffer.size", 65536)
compress = conf.getBoolean("spark.broadcast.compress", true)
+ securityManager = securityMgr
if (isDriver) {
createServer(conf)
conf.set("spark.httpBroadcast.uri", serverUri)
@@ -126,7 +130,7 @@ private object HttpBroadcast extends Logging {
private def createServer(conf: SparkConf) {
broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf))
- server = new HttpServer(broadcastDir)
+ server = new HttpServer(broadcastDir, securityManager)
server.start()
serverUri = server.uri
logInfo("Broadcast server started at " + serverUri)
@@ -149,11 +153,23 @@ private object HttpBroadcast extends Logging {
}
def read[T](id: Long): T = {
+ logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id)
val url = serverUri + "/" + BroadcastBlockId(id).name
+
+ var uc: URLConnection = null
+ if (securityManager.isAuthenticationEnabled()) {
+ logDebug("broadcast security enabled")
+ val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager)
+ uc = newuri.toURL().openConnection()
+ uc.setAllowUserInteraction(false)
+ } else {
+ logDebug("broadcast not using security")
+ uc = new URL(url).openConnection()
+ }
+
val in = {
- val httpConnection = new URL(url).openConnection()
- httpConnection.setReadTimeout(httpReadTimeout)
- val inputStream = httpConnection.getInputStream
+ uc.setReadTimeout(httpReadTimeout)
+ val inputStream = uc.getInputStream();
if (compress) {
compressionCodec.compressedInputStream(inputStream)
} else {
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 22d783c859..3cd7121376 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -241,7 +241,9 @@ private[spark] case class TorrentInfo(
*/
class TorrentBroadcastFactory extends BroadcastFactory {
- def initialize(isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.initialize(isDriver, conf) }
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
+ TorrentBroadcast.initialize(isDriver, conf)
+ }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TorrentBroadcast[T](value_, isLocal, id)
diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala
index eb5676b51d..d9e3035e1a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -26,7 +26,7 @@ import akka.pattern.ask
import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent}
import org.apache.log4j.{Level, Logger}
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.{DriverState, Master}
import org.apache.spark.util.{AkkaUtils, Utils}
@@ -141,7 +141,7 @@ object Client {
// TODO: See if we can initialize akka so return messages are sent back using the same TCP
// flow. Else, this (sadly) requires the DriverClient be routable from the Master.
val (actorSystem, _) = AkkaUtils.createActorSystem(
- "driverClient", Utils.localHostName(), 0, false, conf)
+ "driverClient", Utils.localHostName(), 0, false, conf, new SecurityManager(conf))
actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf))
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index ec15647e1d..d2d8d6d662 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -21,6 +21,7 @@ import java.security.PrivilegedExceptionAction
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.security.Credentials
import org.apache.hadoop.security.UserGroupInformation
import org.apache.spark.{SparkContext, SparkException}
@@ -65,6 +66,15 @@ class SparkHadoopUtil {
def addCredentials(conf: JobConf) {}
def isYarnMode(): Boolean = { false }
+
+ def getCurrentUserCredentials(): Credentials = { null }
+
+ def addCurrentUserCredentials(creds: Credentials) {}
+
+ def addSecretKeyToUserCredentials(key: String, secret: String) {}
+
+ def getSecretKeyFromUserCredentials(key: String): Array[Byte] = { null }
+
}
object SparkHadoopUtil {
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
index 1550c3eb42..63f166d401 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy.client
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{SecurityManager, SparkConf, Logging}
import org.apache.spark.deploy.{ApplicationDescription, Command}
import org.apache.spark.util.{AkkaUtils, Utils}
@@ -45,8 +45,9 @@ private[spark] object TestClient {
def main(args: Array[String]) {
val url = args(0)
+ val conf = new SparkConf
val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0,
- conf = new SparkConf)
+ conf = conf, securityManager = new SecurityManager(conf))
val desc = new ApplicationDescription(
"TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()),
Some("dummy-spark-home"), "ignored")
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 51794ce40c..2d6d0c33fa 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -30,7 +30,7 @@ import akka.pattern.ask
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
import akka.serialization.SerializationExtension
-import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.DriverState.DriverState
@@ -39,7 +39,8 @@ import org.apache.spark.deploy.master.ui.MasterWebUI
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.{AkkaUtils, Utils}
-private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
+private[spark] class Master(host: String, port: Int, webUiPort: Int,
+ val securityMgr: SecurityManager) extends Actor with Logging {
import context.dispatcher // to use Akka's scheduler.schedule()
val conf = new SparkConf
@@ -70,8 +71,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
Utils.checkHost(host, "Expected hostname")
- val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf)
- val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf)
+ val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr)
+ val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf,
+ securityMgr)
val masterSource = new MasterSource(this)
val webUi = new MasterWebUI(this, webUiPort)
@@ -711,8 +713,11 @@ private[spark] object Master {
def startSystemAndActor(host: String, port: Int, webUiPort: Int, conf: SparkConf)
: (ActorSystem, Int, Int) =
{
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf)
- val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), actorName)
+ val securityMgr = new SecurityManager(conf)
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf,
+ securityManager = securityMgr)
+ val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort,
+ securityMgr), actorName)
val timeout = AkkaUtils.askTimeout(conf)
val respFuture = actor.ask(RequestWebUIPort)(timeout)
val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse]
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
index 5ab13e7aa6..a7bd01e284 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
@@ -18,8 +18,8 @@
package org.apache.spark.deploy.master.ui
import javax.servlet.http.HttpServletRequest
-
-import org.eclipse.jetty.server.{Handler, Server}
+import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.Logging
import org.apache.spark.deploy.master.Master
@@ -46,7 +46,7 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging {
def start() {
try {
- val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers)
+ val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers, master.conf)
server = Some(srv)
boundPort = Some(bPort)
logInfo("Started Master web UI at http://%s:%d".format(host, boundPort.get))
@@ -60,12 +60,17 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging {
val metricsHandlers = master.masterMetricsSystem.getServletHandlers ++
master.applicationMetricsSystem.getServletHandlers
- val handlers = metricsHandlers ++ Array[(String, Handler)](
- ("/static", createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR)),
- ("/app/json", (request: HttpServletRequest) => applicationPage.renderJson(request)),
- ("/app", (request: HttpServletRequest) => applicationPage.render(request)),
- ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)),
- ("*", (request: HttpServletRequest) => indexPage.render(request))
+ val handlers = metricsHandlers ++ Seq[ServletContextHandler](
+ createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static/*"),
+ createServletHandler("/app/json",
+ createServlet((request: HttpServletRequest) => applicationPage.renderJson(request),
+ master.securityMgr)),
+ createServletHandler("/app", createServlet((request: HttpServletRequest) => applicationPage
+ .render(request), master.securityMgr)),
+ createServletHandler("/json", createServlet((request: HttpServletRequest) => indexPage
+ .renderJson(request), master.securityMgr)),
+ createServletHandler("*", createServlet((request: HttpServletRequest) => indexPage.render
+ (request), master.securityMgr))
)
def stop() {
@@ -74,5 +79,5 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging {
}
private[spark] object MasterWebUI {
- val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static"
+ val STATIC_RESOURCE_DIR = "org/apache/spark/ui"
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
index a26e47950a..be15138f62 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
@@ -19,7 +19,7 @@ package org.apache.spark.deploy.worker
import akka.actor._
-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.util.{AkkaUtils, Utils}
/**
@@ -29,8 +29,9 @@ object DriverWrapper {
def main(args: Array[String]) {
args.toList match {
case workerUrl :: mainClass :: extraArgs =>
+ val conf = new SparkConf()
val (actorSystem, _) = AkkaUtils.createActorSystem("Driver",
- Utils.localHostName(), 0, false, new SparkConf())
+ Utils.localHostName(), 0, false, conf, new SecurityManager(conf))
actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher")
// Delegate to supplied main class
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 7b0b7861b7..afaabedffe 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -27,7 +27,7 @@ import scala.concurrent.duration._
import akka.actor._
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
-import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.{DriverState, Master}
@@ -48,7 +48,8 @@ private[spark] class Worker(
actorSystemName: String,
actorName: String,
workDirPath: String = null,
- val conf: SparkConf)
+ val conf: SparkConf,
+ val securityMgr: SecurityManager)
extends Actor with Logging {
import context.dispatcher
@@ -91,7 +92,7 @@ private[spark] class Worker(
var coresUsed = 0
var memoryUsed = 0
- val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf)
+ val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr)
val workerSource = new WorkerSource(this)
def coresFree: Int = cores - coresUsed
@@ -347,10 +348,11 @@ private[spark] object Worker {
val conf = new SparkConf
val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
val actorName = "Worker"
+ val securityMgr = new SecurityManager(conf)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port,
- conf = conf)
+ conf = conf, securityManager = securityMgr)
actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory,
- masterUrls, systemName, actorName, workDir, conf), name = actorName)
+ masterUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName)
(actorSystem, boundPort)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index bdf126f93a..ffc05bd306 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -19,8 +19,8 @@ package org.apache.spark.deploy.worker.ui
import java.io.File
import javax.servlet.http.HttpServletRequest
-
-import org.eclipse.jetty.server.{Handler, Server}
+import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.Logging
import org.apache.spark.deploy.worker.Worker
@@ -33,7 +33,7 @@ import org.apache.spark.util.{AkkaUtils, Utils}
*/
private[spark]
class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[Int] = None)
- extends Logging {
+ extends Logging {
val timeout = AkkaUtils.askTimeout(worker.conf)
val host = Utils.localHostName()
val port = requestedPort.getOrElse(
@@ -46,17 +46,21 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I
val metricsHandlers = worker.metricsSystem.getServletHandlers
- val handlers = metricsHandlers ++ Array[(String, Handler)](
- ("/static", createStaticHandler(WorkerWebUI.STATIC_RESOURCE_DIR)),
- ("/log", (request: HttpServletRequest) => log(request)),
- ("/logPage", (request: HttpServletRequest) => logPage(request)),
- ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)),
- ("*", (request: HttpServletRequest) => indexPage.render(request))
+ val handlers = metricsHandlers ++ Seq[ServletContextHandler](
+ createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static/*"),
+ createServletHandler("/log", createServlet((request: HttpServletRequest) => log(request),
+ worker.securityMgr)),
+ createServletHandler("/logPage", createServlet((request: HttpServletRequest) => logPage
+ (request), worker.securityMgr)),
+ createServletHandler("/json", createServlet((request: HttpServletRequest) => indexPage
+ .renderJson(request), worker.securityMgr)),
+ createServletHandler("*", createServlet((request: HttpServletRequest) => indexPage.render
+ (request), worker.securityMgr))
)
def start() {
try {
- val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers)
+ val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers, worker.conf)
server = Some(srv)
boundPort = Some(bPort)
logInfo("Started Worker web UI at http://%s:%d".format(host, bPort))
@@ -198,6 +202,6 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I
}
private[spark] object WorkerWebUI {
- val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static"
+ val STATIC_RESOURCE_BASE = "org/apache/spark/ui"
val DEFAULT_PORT="8081"
}
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 0aae569b17..3486092a14 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -22,7 +22,7 @@ import java.nio.ByteBuffer
import akka.actor._
import akka.remote._
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{SecurityManager, SparkConf, Logging}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.deploy.worker.WorkerWatcher
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
@@ -97,10 +97,11 @@ private[spark] object CoarseGrainedExecutorBackend {
// Debug code
Utils.checkHost(hostname)
+ val conf = new SparkConf
// Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
// before getting started with all our system properties, etc
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0,
- indestructible = true, conf = new SparkConf)
+ indestructible = true, conf = conf, new SecurityManager(conf))
// set it
val sparkHostPort = hostname + ":" + boundPort
actorSystem.actorOf(
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 989d666f15..e69f6f72d3 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -69,11 +69,6 @@ private[spark] class Executor(
conf.set("spark.local.dir", getYarnLocalDirs())
}
- // Create our ClassLoader and set it on this thread
- private val urlClassLoader = createClassLoader()
- private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
- Thread.currentThread.setContextClassLoader(replClassLoader)
-
if (!isLocal) {
// Setup an uncaught exception handler for non-local mode.
// Make any thread terminations due to uncaught exceptions kill the entire
@@ -117,6 +112,12 @@ private[spark] class Executor(
}
}
+ // Create our ClassLoader and set it on this thread
+ // do this after SparkEnv creation so can access the SecurityManager
+ private val urlClassLoader = createClassLoader()
+ private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
+ Thread.currentThread.setContextClassLoader(replClassLoader)
+
// Akka's message frame size. If task result is bigger than this, we use the block manager
// to send the result back.
private val akkaFrameSize = {
@@ -338,12 +339,12 @@ private[spark] class Executor(
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager)
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager)
currentJars(name) = timestamp
// Add it to our class loader
val localName = name.split("/").last
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index 966c092124..c5bda2078f 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable
import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry}
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.metrics.sink.{MetricsServlet, Sink}
import org.apache.spark.metrics.source.Source
@@ -64,7 +64,7 @@ import org.apache.spark.metrics.source.Source
* [options] is the specific property of this source or sink.
*/
private[spark] class MetricsSystem private (val instance: String,
- conf: SparkConf) extends Logging {
+ conf: SparkConf, securityMgr: SecurityManager) extends Logging {
val confFile = conf.get("spark.metrics.conf", null)
val metricsConfig = new MetricsConfig(Option(confFile))
@@ -131,8 +131,8 @@ private[spark] class MetricsSystem private (val instance: String,
val classPath = kv._2.getProperty("class")
try {
val sink = Class.forName(classPath)
- .getConstructor(classOf[Properties], classOf[MetricRegistry])
- .newInstance(kv._2, registry)
+ .getConstructor(classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager])
+ .newInstance(kv._2, registry, securityMgr)
if (kv._1 == "servlet") {
metricsServlet = Some(sink.asInstanceOf[MetricsServlet])
} else {
@@ -160,6 +160,7 @@ private[spark] object MetricsSystem {
}
}
- def createMetricsSystem(instance: String, conf: SparkConf): MetricsSystem =
- new MetricsSystem(instance, conf)
+ def createMetricsSystem(instance: String, conf: SparkConf,
+ securityMgr: SecurityManager): MetricsSystem =
+ new MetricsSystem(instance, conf, securityMgr)
}
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
index 98fa1dbd7c..4d2ffc54d8 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
@@ -22,9 +22,11 @@ import java.util.concurrent.TimeUnit
import com.codahale.metrics.{ConsoleReporter, MetricRegistry}
+import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
-class ConsoleSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+class ConsoleSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val CONSOLE_DEFAULT_PERIOD = 10
val CONSOLE_DEFAULT_UNIT = "SECONDS"
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
index 40f64768e6..319f40815d 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
@@ -23,9 +23,11 @@ import java.util.concurrent.TimeUnit
import com.codahale.metrics.{CsvReporter, MetricRegistry}
+import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
-class CsvSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+class CsvSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val CSV_KEY_PERIOD = "period"
val CSV_KEY_UNIT = "unit"
val CSV_KEY_DIR = "directory"
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala
index 410ca0704b..cd37317da7 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala
@@ -24,9 +24,11 @@ import com.codahale.metrics.MetricRegistry
import com.codahale.metrics.ganglia.GangliaReporter
import info.ganglia.gmetric4j.gmetric.GMetric
+import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
-class GangliaSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+class GangliaSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val GANGLIA_KEY_PERIOD = "period"
val GANGLIA_DEFAULT_PERIOD = 10
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
index e09be00142..0ffdf3846d 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
@@ -24,9 +24,11 @@ import java.util.concurrent.TimeUnit
import com.codahale.metrics.MetricRegistry
import com.codahale.metrics.graphite.{Graphite, GraphiteReporter}
+import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
-class GraphiteSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+class GraphiteSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val GRAPHITE_DEFAULT_PERIOD = 10
val GRAPHITE_DEFAULT_UNIT = "SECONDS"
val GRAPHITE_DEFAULT_PREFIX = ""
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
index b5cf210af2..3b5edd5c37 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
@@ -20,8 +20,11 @@ package org.apache.spark.metrics.sink
import java.util.Properties
import com.codahale.metrics.{JmxReporter, MetricRegistry}
+import org.apache.spark.SecurityManager
+
+class JmxSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
-class JmxSink(val property: Properties, val registry: MetricRegistry) extends Sink {
val reporter: JmxReporter = JmxReporter.forRegistry(registry).build()
override def start() {
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
index 3cdfe26d40..3110eccdee 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
@@ -19,16 +19,19 @@ package org.apache.spark.metrics.sink
import java.util.Properties
import java.util.concurrent.TimeUnit
+
import javax.servlet.http.HttpServletRequest
import com.codahale.metrics.MetricRegistry
import com.codahale.metrics.json.MetricsModule
import com.fasterxml.jackson.databind.ObjectMapper
-import org.eclipse.jetty.server.Handler
+import org.eclipse.jetty.servlet.ServletContextHandler
+import org.apache.spark.SecurityManager
import org.apache.spark.ui.JettyUtils
-class MetricsServlet(val property: Properties, val registry: MetricRegistry) extends Sink {
+class MetricsServlet(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val SERVLET_KEY_PATH = "path"
val SERVLET_KEY_SAMPLE = "sample"
@@ -42,8 +45,11 @@ class MetricsServlet(val property: Properties, val registry: MetricRegistry) ext
val mapper = new ObjectMapper().registerModule(
new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample))
- def getHandlers = Array[(String, Handler)](
- (servletPath, JettyUtils.createHandler(request => getMetricsSnapshot(request), "text/json"))
+ def getHandlers = Array[ServletContextHandler](
+ JettyUtils.createServletHandler(servletPath,
+ JettyUtils.createServlet(
+ new JettyUtils.ServletParams(request => getMetricsSnapshot(request), "text/json"),
+ securityMgr) )
)
def getMetricsSnapshot(request: HttpServletRequest): String = {
diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
index d3c09b1606..04df2f3b0d 100644
--- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
+++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
@@ -45,9 +45,10 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
throw new Exception("Max chunk size is " + maxChunkSize)
}
+ val security = if (isSecurityNeg) 1 else 0
if (size == 0 && !gotChunkForSendingOnce) {
val newChunk = new MessageChunk(
- new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
+ new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null)
gotChunkForSendingOnce = true
return Some(newChunk)
}
@@ -65,7 +66,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
}
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
gotChunkForSendingOnce = true
return Some(newChunk)
}
@@ -79,6 +80,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
throw new Exception("Attempting to get chunk from message with multiple data buffers")
}
val buffer = buffers(0)
+ val security = if (isSecurityNeg) 1 else 0
if (buffer.remaining > 0) {
if (buffer.remaining < chunkSize) {
throw new Exception("Not enough space in data buffer for receiving chunk")
@@ -86,7 +88,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
return Some(newChunk)
}
None
diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala
index 8219a185ea..8fd9c2b87d 100644
--- a/core/src/main/scala/org/apache/spark/network/Connection.scala
+++ b/core/src/main/scala/org/apache/spark/network/Connection.scala
@@ -17,6 +17,11 @@
package org.apache.spark.network
+import org.apache.spark._
+import org.apache.spark.SparkSaslServer
+
+import scala.collection.mutable.{HashMap, Queue, ArrayBuffer}
+
import java.net._
import java.nio._
import java.nio.channels._
@@ -27,13 +32,16 @@ import org.apache.spark._
private[spark]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
- val socketRemoteConnectionManagerId: ConnectionManagerId)
+ val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId)
extends Logging {
- def this(channel_ : SocketChannel, selector_ : Selector) = {
+ var sparkSaslServer: SparkSaslServer = null
+ var sparkSaslClient: SparkSaslClient = null
+
+ def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = {
this(channel_, selector_,
ConnectionManagerId.fromSocketAddress(
- channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]))
+ channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]), id_)
}
channel.configureBlocking(false)
@@ -49,6 +57,16 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
val remoteAddress = getRemoteAddress()
+ /**
+ * Used to synchronize client requests: client's work-related requests must
+ * wait until SASL authentication completes.
+ */
+ private val authenticated = new Object()
+
+ def getAuthenticated(): Object = authenticated
+
+ def isSaslComplete(): Boolean
+
def resetForceReregister(): Boolean
// Read channels typically do not register for write and write does not for read
@@ -69,6 +87,16 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
// Will be true for ReceivingConnection, false for SendingConnection.
def changeInterestForRead(): Boolean
+ private def disposeSasl() {
+ if (sparkSaslServer != null) {
+ sparkSaslServer.dispose();
+ }
+
+ if (sparkSaslClient != null) {
+ sparkSaslClient.dispose()
+ }
+ }
+
// On receiving a write event, should we change the interest for this channel or not ?
// Will be false for ReceivingConnection, true for SendingConnection.
// Actually, for now, should not get triggered for ReceivingConnection
@@ -101,6 +129,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
k.cancel()
}
channel.close()
+ disposeSasl()
callOnCloseCallback()
}
@@ -168,8 +197,12 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
private[spark]
class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
- remoteId_ : ConnectionManagerId)
- extends Connection(SocketChannel.open, selector_, remoteId_) {
+ remoteId_ : ConnectionManagerId, id_ : ConnectionId)
+ extends Connection(SocketChannel.open, selector_, remoteId_, id_) {
+
+ def isSaslComplete(): Boolean = {
+ if (sparkSaslClient != null) sparkSaslClient.isComplete() else false
+ }
private class Outbox {
val messages = new Queue[Message]()
@@ -226,6 +259,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
data as detailed in https://github.com/mesos/spark/pull/791
*/
private var needForceReregister = false
+
val currentBuffers = new ArrayBuffer[ByteBuffer]()
/*channel.socket.setSendBufferSize(256 * 1024)*/
@@ -316,6 +350,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
// If we have 'seen' pending messages, then reset flag - since we handle that as
// normal registering of event (below)
if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister()
+
currentBuffers ++= buffers
}
case None => {
@@ -384,8 +419,15 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
// Must be created within selector loop - else deadlock
-private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
- extends Connection(channel_, selector_) {
+private[spark] class ReceivingConnection(
+ channel_ : SocketChannel,
+ selector_ : Selector,
+ id_ : ConnectionId)
+ extends Connection(channel_, selector_, id_) {
+
+ def isSaslComplete(): Boolean = {
+ if (sparkSaslServer != null) sparkSaslServer.isComplete() else false
+ }
class Inbox() {
val messages = new HashMap[Int, BufferMessage]()
@@ -396,6 +438,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S
val newMessage = Message.create(header).asInstanceOf[BufferMessage]
newMessage.started = true
newMessage.startTime = System.currentTimeMillis
+ newMessage.isSecurityNeg = header.securityNeg == 1
logDebug(
"Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]")
messages += ((newMessage.id, newMessage))
@@ -441,7 +484,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S
val inbox = new Inbox()
val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
- var onReceiveCallback: (Connection , Message) => Unit = null
+ var onReceiveCallback: (Connection, Message) => Unit = null
var currentChunk: MessageChunk = null
channel.register(selector, SelectionKey.OP_READ)
@@ -516,7 +559,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S
}
}
} catch {
- case e: Exception => {
+ case e: Exception => {
logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala
new file mode 100644
index 0000000000..ffaab677d4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+private[spark] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) {
+ override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId
+}
+
+private[spark] object ConnectionId {
+
+ def createConnectionIdFromString(connectionIdString: String): ConnectionId = {
+ val res = connectionIdString.split("_").map(_.trim())
+ if (res.size != 3) {
+ throw new Exception("Error converting ConnectionId string: " + connectionIdString +
+ " to a ConnectionId Object")
+ }
+ new ConnectionId(new ConnectionManagerId(res(0), res(1).toInt), res(2).toInt)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
index a7f20f8c51..a75130cba2 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
@@ -21,6 +21,9 @@ import java.net._
import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
+import java.net._
+import java.util.concurrent.atomic.AtomicInteger
+
import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor}
import scala.collection.mutable.ArrayBuffer
@@ -28,13 +31,15 @@ import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.collection.mutable.SynchronizedMap
import scala.collection.mutable.SynchronizedQueue
+
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.concurrent.duration._
import org.apache.spark._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{SystemClock, Utils}
-private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Logging {
+private[spark] class ConnectionManager(port: Int, conf: SparkConf,
+ securityManager: SecurityManager) extends Logging {
class MessageStatus(
val message: Message,
@@ -50,6 +55,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
private val selector = SelectorProvider.provider.openSelector()
+ // default to 30 second timeout waiting for authentication
+ private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30)
+
private val handleMessageExecutor = new ThreadPoolExecutor(
conf.getInt("spark.core.connection.handler.threads.min", 20),
conf.getInt("spark.core.connection.handler.threads.max", 60),
@@ -71,6 +79,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
new LinkedBlockingDeque[Runnable]())
private val serverChannel = ServerSocketChannel.open()
+ // used to track the SendingConnections waiting to do SASL negotiation
+ private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection]
+ with SynchronizedMap[ConnectionId, SendingConnection]
private val connectionsByKey =
new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection]
@@ -84,6 +95,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
+ private val authEnabled = securityManager.isAuthenticationEnabled()
+
serverChannel.configureBlocking(false)
serverChannel.socket.setReuseAddress(true)
serverChannel.socket.setReceiveBufferSize(256 * 1024)
@@ -94,6 +107,10 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
+ // used in combination with the ConnectionManagerId to create unique Connection ids
+ // to be able to track asynchronous messages
+ private val idCount: AtomicInteger = new AtomicInteger(1)
+
private val selectorThread = new Thread("connection-manager-thread") {
override def run() = ConnectionManager.this.run()
}
@@ -125,7 +142,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
} finally {
writeRunnableStarted.synchronized {
writeRunnableStarted -= key
- val needReregister = register || conn.resetForceReregister()
+ val needReregister = register || conn.resetForceReregister()
if (needReregister && conn.changeInterestForWrite()) {
conn.registerInterest()
}
@@ -372,7 +389,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
// accept them all in a tight loop. non blocking accept with no processing, should be fine
while (newChannel != null) {
try {
- val newConnection = new ReceivingConnection(newChannel, selector)
+ val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
+ val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId)
newConnection.onReceive(receiveMessage)
addListeners(newConnection)
addConnection(newConnection)
@@ -406,6 +424,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
connectionsById -= sendingConnectionManagerId
+ connectionsAwaitingSasl -= connection.connectionId
messageStatuses.synchronized {
messageStatuses
@@ -481,7 +500,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
val creationTime = System.currentTimeMillis
def run() {
logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms")
- handleMessage(connectionManagerId, message)
+ handleMessage(connectionManagerId, message, connection)
logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms")
}
}
@@ -489,10 +508,133 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
/*handleMessage(connection, message)*/
}
- private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) {
+ private def handleClientAuthentication(
+ waitingConn: SendingConnection,
+ securityMsg: SecurityMessage,
+ connectionId : ConnectionId) {
+ if (waitingConn.isSaslComplete()) {
+ logDebug("Client sasl completed for id: " + waitingConn.connectionId)
+ connectionsAwaitingSasl -= waitingConn.connectionId
+ waitingConn.getAuthenticated().synchronized {
+ waitingConn.getAuthenticated().notifyAll();
+ }
+ return
+ } else {
+ var replyToken : Array[Byte] = null
+ try {
+ replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken);
+ if (waitingConn.isSaslComplete()) {
+ logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
+ connectionsAwaitingSasl -= waitingConn.connectionId
+ waitingConn.getAuthenticated().synchronized {
+ waitingConn.getAuthenticated().notifyAll()
+ }
+ return
+ }
+ var securityMsgResp = SecurityMessage.fromResponse(replyToken,
+ securityMsg.getConnectionId.toString())
+ var message = securityMsgResp.toBufferMessage
+ if (message == null) throw new Exception("Error creating security message")
+ sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message)
+ } catch {
+ case e: Exception => {
+ logError("Error handling sasl client authentication", e)
+ waitingConn.close()
+ throw new Exception("Error evaluating sasl response: " + e)
+ }
+ }
+ }
+ }
+
+ private def handleServerAuthentication(
+ connection: Connection,
+ securityMsg: SecurityMessage,
+ connectionId: ConnectionId) {
+ if (!connection.isSaslComplete()) {
+ logDebug("saslContext not established")
+ var replyToken : Array[Byte] = null
+ try {
+ connection.synchronized {
+ if (connection.sparkSaslServer == null) {
+ logDebug("Creating sasl Server")
+ connection.sparkSaslServer = new SparkSaslServer(securityManager)
+ }
+ }
+ replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
+ if (connection.isSaslComplete()) {
+ logDebug("Server sasl completed: " + connection.connectionId)
+ } else {
+ logDebug("Server sasl not completed: " + connection.connectionId)
+ }
+ if (replyToken != null) {
+ var securityMsgResp = SecurityMessage.fromResponse(replyToken,
+ securityMsg.getConnectionId)
+ var message = securityMsgResp.toBufferMessage
+ if (message == null) throw new Exception("Error creating security Message")
+ sendSecurityMessage(connection.getRemoteConnectionManagerId(), message)
+ }
+ } catch {
+ case e: Exception => {
+ logError("Error in server auth negotiation: " + e)
+ // It would probably be better to send an error message telling other side auth failed
+ // but for now just close
+ connection.close()
+ }
+ }
+ } else {
+ logDebug("connection already established for this connection id: " + connection.connectionId)
+ }
+ }
+
+
+ private def handleAuthentication(conn: Connection, bufferMessage: BufferMessage): Boolean = {
+ if (bufferMessage.isSecurityNeg) {
+ logDebug("This is security neg message")
+
+ // parse as SecurityMessage
+ val securityMsg = SecurityMessage.fromBufferMessage(bufferMessage)
+ val connectionId = ConnectionId.createConnectionIdFromString(securityMsg.getConnectionId)
+
+ connectionsAwaitingSasl.get(connectionId) match {
+ case Some(waitingConn) => {
+ // Client - this must be in response to us doing Send
+ logDebug("Client handleAuth for id: " + waitingConn.connectionId)
+ handleClientAuthentication(waitingConn, securityMsg, connectionId)
+ }
+ case None => {
+ // Server - someone sent us something and we haven't authenticated yet
+ logDebug("Server handleAuth for id: " + connectionId)
+ handleServerAuthentication(conn, securityMsg, connectionId)
+ }
+ }
+ return true
+ } else {
+ if (!conn.isSaslComplete()) {
+ // We could handle this better and tell the client we need to do authentication
+ // negotiation, but for now just ignore them.
+ logError("message sent that is not security negotiation message on connection " +
+ "not authenticated yet, ignoring it!!")
+ return true
+ }
+ }
+ return false
+ }
+
+ private def handleMessage(
+ connectionManagerId: ConnectionManagerId,
+ message: Message,
+ connection: Connection) {
logDebug("Handling [" + message + "] from [" + connectionManagerId + "]")
message match {
case bufferMessage: BufferMessage => {
+ if (authEnabled) {
+ val res = handleAuthentication(connection, bufferMessage)
+ if (res == true) {
+ // message was security negotiation so skip the rest
+ logDebug("After handleAuth result was true, returning")
+ return
+ }
+ }
if (bufferMessage.hasAckId) {
val sentMessageStatus = messageStatuses.synchronized {
messageStatuses.get(bufferMessage.ackId) match {
@@ -541,17 +683,124 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
}
}
+ private def checkSendAuthFirst(connManagerId: ConnectionManagerId, conn: SendingConnection) {
+ // see if we need to do sasl before writing
+ // this should only be the first negotiation as the Client!!!
+ if (!conn.isSaslComplete()) {
+ conn.synchronized {
+ if (conn.sparkSaslClient == null) {
+ conn.sparkSaslClient = new SparkSaslClient(securityManager)
+ var firstResponse: Array[Byte] = null
+ try {
+ firstResponse = conn.sparkSaslClient.firstToken()
+ var securityMsg = SecurityMessage.fromResponse(firstResponse,
+ conn.connectionId.toString())
+ var message = securityMsg.toBufferMessage
+ if (message == null) throw new Exception("Error creating security message")
+ connectionsAwaitingSasl += ((conn.connectionId, conn))
+ sendSecurityMessage(connManagerId, message)
+ logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId)
+ } catch {
+ case e: Exception => {
+ logError("Error getting first response from the SaslClient.", e)
+ conn.close()
+ throw new Exception("Error getting first response from the SaslClient")
+ }
+ }
+ }
+ }
+ } else {
+ logDebug("Sasl already established ")
+ }
+ }
+
+ // allow us to add messages to the inbox for doing sasl negotiating
+ private def sendSecurityMessage(connManagerId: ConnectionManagerId, message: Message) {
+ def startNewConnection(): SendingConnection = {
+ val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port)
+ val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
+ val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId,
+ newConnectionId)
+ logInfo("creating new sending connection for security! " + newConnectionId )
+ registerRequests.enqueue(newConnection)
+
+ newConnection
+ }
+ // I removed the lookupKey stuff as part of merge ... should I re-add it ?
+ // We did not find it useful in our test-env ...
+ // If we do re-add it, we should consistently use it everywhere I guess ?
+ message.senderAddress = id.toSocketAddress()
+ logTrace("Sending Security [" + message + "] to [" + connManagerId + "]")
+ val connection = connectionsById.getOrElseUpdate(connManagerId, startNewConnection())
+
+ //send security message until going connection has been authenticated
+ connection.send(message)
+
+ wakeupSelector()
+ }
+
private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
def startNewConnection(): SendingConnection = {
val inetSocketAddress = new InetSocketAddress(connectionManagerId.host,
connectionManagerId.port)
- val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId)
+ val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
+ val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId,
+ newConnectionId)
+ logTrace("creating new sending connection: " + newConnectionId)
registerRequests.enqueue(newConnection)
newConnection
}
val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
+ if (authEnabled) {
+ checkSendAuthFirst(connectionManagerId, connection)
+ }
message.senderAddress = id.toSocketAddress()
+ logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " +
+ "connectionid: " + connection.connectionId)
+
+ if (authEnabled) {
+ // if we aren't authenticated yet lets block the senders until authentication completes
+ try {
+ connection.getAuthenticated().synchronized {
+ val clock = SystemClock
+ val startTime = clock.getTime()
+
+ while (!connection.isSaslComplete()) {
+ logDebug("getAuthenticated wait connectionid: " + connection.connectionId)
+ // have timeout in case remote side never responds
+ connection.getAuthenticated().wait(500)
+ if (((clock.getTime() - startTime) >= (authTimeout * 1000))
+ && (!connection.isSaslComplete())) {
+ // took to long to authenticate the connection, something probably went wrong
+ throw new Exception("Took to long for authentication to " + connectionManagerId +
+ ", waited " + authTimeout + "seconds, failing.")
+ }
+ }
+ }
+ } catch {
+ case e: Exception => logError("Exception while waiting for authentication.", e)
+
+ // need to tell sender it failed
+ messageStatuses.synchronized {
+ val s = messageStatuses.get(message.id)
+ s match {
+ case Some(msgStatus) => {
+ messageStatuses -= message.id
+ logInfo("Notifying " + msgStatus.connectionManagerId)
+ msgStatus.synchronized {
+ msgStatus.attempted = true
+ msgStatus.acked = false
+ msgStatus.markDone()
+ }
+ }
+ case None => {
+ logError("no messageStatus for failed message id: " + message.id)
+ }
+ }
+ }
+ }
+ }
logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
connection.send(message)
@@ -603,7 +852,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
private[spark] object ConnectionManager {
def main(args: Array[String]) {
- val manager = new ConnectionManager(9999, new SparkConf)
+ val conf = new SparkConf
+ val manager = new ConnectionManager(9999, conf, new SecurityManager(conf))
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
println("Received [" + msg + "] from [" + id + "]")
None
diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala
index 20fe676618..7caccfdbb4 100644
--- a/core/src/main/scala/org/apache/spark/network/Message.scala
+++ b/core/src/main/scala/org/apache/spark/network/Message.scala
@@ -27,6 +27,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) {
var started = false
var startTime = -1L
var finishTime = -1L
+ var isSecurityNeg = false
def size: Int
diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
index 9bcbc6141a..ead663ede7 100644
--- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
+++ b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
@@ -27,6 +27,7 @@ private[spark] class MessageChunkHeader(
val totalSize: Int,
val chunkSize: Int,
val other: Int,
+ val securityNeg: Int,
val address: InetSocketAddress) {
lazy val buffer = {
// No need to change this, at 'use' time, we do a reverse lookup of the hostname.
@@ -40,6 +41,7 @@ private[spark] class MessageChunkHeader(
putInt(totalSize).
putInt(chunkSize).
putInt(other).
+ putInt(securityNeg).
putInt(ip.size).
put(ip).
putInt(port).
@@ -48,12 +50,13 @@ private[spark] class MessageChunkHeader(
}
override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
- " and sizes " + totalSize + " / " + chunkSize + " bytes"
+ " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg
+
}
private[spark] object MessageChunkHeader {
- val HEADER_SIZE = 40
+ val HEADER_SIZE = 44
def create(buffer: ByteBuffer): MessageChunkHeader = {
if (buffer.remaining != HEADER_SIZE) {
@@ -64,11 +67,13 @@ private[spark] object MessageChunkHeader {
val totalSize = buffer.getInt()
val chunkSize = buffer.getInt()
val other = buffer.getInt()
+ val securityNeg = buffer.getInt()
val ipSize = buffer.getInt()
val ipBytes = new Array[Byte](ipSize)
buffer.get(ipBytes)
val ip = InetAddress.getByAddress(ipBytes)
val port = buffer.getInt()
- new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
+ new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg,
+ new InetSocketAddress(ip, port))
}
}
diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
index 9976255c7e..3c09a713c6 100644
--- a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
+++ b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
@@ -18,12 +18,12 @@
package org.apache.spark.network
import java.nio.ByteBuffer
-
-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
private[spark] object ReceiverTest {
def main(args: Array[String]) {
- val manager = new ConnectionManager(9999, new SparkConf)
+ val conf = new SparkConf
+ val manager = new ConnectionManager(9999, conf, new SecurityManager(conf))
println("Started connection manager with id = " + manager.id)
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
diff --git a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala
new file mode 100644
index 0000000000..0d9f743b36
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.StringBuilder
+
+import org.apache.spark._
+import org.apache.spark.network._
+
+/**
+ * SecurityMessage is class that contains the connectionId and sasl token
+ * used in SASL negotiation. SecurityMessage has routines for converting
+ * it to and from a BufferMessage so that it can be sent by the ConnectionManager
+ * and easily consumed by users when received.
+ * The api was modeled after BlockMessage.
+ *
+ * The connectionId is the connectionId of the client side. Since
+ * message passing is asynchronous and its possible for the server side (receiving)
+ * to get multiple different types of messages on the same connection the connectionId
+ * is used to know which connnection the security message is intended for.
+ *
+ * For instance, lets say we are node_0. We need to send data to node_1. The node_0 side
+ * is acting as a client and connecting to node_1. SASL negotiation has to occur
+ * between node_0 and node_1 before node_1 trusts node_0 so node_0 sends a security message.
+ * node_1 receives the message from node_0 but before it can process it and send a response,
+ * some thread on node_1 decides it needs to send data to node_0 so it connects to node_0
+ * and sends a security message of its own to authenticate as a client. Now node_0 gets
+ * the message and it needs to decide if this message is in response to it being a client
+ * (from the first send) or if its just node_1 trying to connect to it to send data. This
+ * is where the connectionId field is used. node_0 can lookup the connectionId to see if
+ * it is in response to it being a client or if its in response to someone sending other data.
+ *
+ * The format of a SecurityMessage as its sent is:
+ * - Length of the ConnectionId
+ * - ConnectionId
+ * - Length of the token
+ * - Token
+ */
+private[spark] class SecurityMessage() extends Logging {
+
+ private var connectionId: String = null
+ private var token: Array[Byte] = null
+
+ def set(byteArr: Array[Byte], newconnectionId: String) {
+ if (byteArr == null) {
+ token = new Array[Byte](0)
+ } else {
+ token = byteArr
+ }
+ connectionId = newconnectionId
+ }
+
+ /**
+ * Read the given buffer and set the members of this class.
+ */
+ def set(buffer: ByteBuffer) {
+ val idLength = buffer.getInt()
+ val idBuilder = new StringBuilder(idLength)
+ for (i <- 1 to idLength) {
+ idBuilder += buffer.getChar()
+ }
+ connectionId = idBuilder.toString()
+
+ val tokenLength = buffer.getInt()
+ token = new Array[Byte](tokenLength)
+ if (tokenLength > 0) {
+ buffer.get(token, 0, tokenLength)
+ }
+ }
+
+ def set(bufferMsg: BufferMessage) {
+ val buffer = bufferMsg.buffers.apply(0)
+ buffer.clear()
+ set(buffer)
+ }
+
+ def getConnectionId: String = {
+ return connectionId
+ }
+
+ def getToken: Array[Byte] = {
+ return token
+ }
+
+ /**
+ * Create a BufferMessage that can be sent by the ConnectionManager containing
+ * the security information from this class.
+ * @return BufferMessage
+ */
+ def toBufferMessage: BufferMessage = {
+ val startTime = System.currentTimeMillis
+ val buffers = new ArrayBuffer[ByteBuffer]()
+
+ // 4 bytes for the length of the connectionId
+ // connectionId is of type char so multiple the length by 2 to get number of bytes
+ // 4 bytes for the length of token
+ // token is a byte buffer so just take the length
+ var buffer = ByteBuffer.allocate(4 + connectionId.length() * 2 + 4 + token.length)
+ buffer.putInt(connectionId.length())
+ connectionId.foreach((x: Char) => buffer.putChar(x))
+ buffer.putInt(token.length)
+
+ if (token.length > 0) {
+ buffer.put(token)
+ }
+ buffer.flip()
+ buffers += buffer
+
+ var message = Message.createBufferMessage(buffers)
+ logDebug("message total size is : " + message.size)
+ message.isSecurityNeg = true
+ return message
+ }
+
+ override def toString: String = {
+ "SecurityMessage [connId= " + connectionId + ", Token = " + token + "]"
+ }
+}
+
+private[spark] object SecurityMessage {
+
+ /**
+ * Convert the given BufferMessage to a SecurityMessage by parsing the contents
+ * of the BufferMessage and populating the SecurityMessage fields.
+ * @param bufferMessage is a BufferMessage that was received
+ * @return new SecurityMessage
+ */
+ def fromBufferMessage(bufferMessage: BufferMessage): SecurityMessage = {
+ val newSecurityMessage = new SecurityMessage()
+ newSecurityMessage.set(bufferMessage)
+ newSecurityMessage
+ }
+
+ /**
+ * Create a SecurityMessage to send from a given saslResponse.
+ * @param response is the response to a challenge from the SaslClient or Saslserver
+ * @param connectionId the client connectionId we are negotiation authentication for
+ * @return a new SecurityMessage
+ */
+ def fromResponse(response : Array[Byte], connectionId : String) : SecurityMessage = {
+ val newSecurityMessage = new SecurityMessage()
+ newSecurityMessage.set(response, connectionId)
+ newSecurityMessage
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
index 646f8425d9..aac2c24a46 100644
--- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala
+++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
@@ -18,8 +18,7 @@
package org.apache.spark.network
import java.nio.ByteBuffer
-
-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
private[spark] object SenderTest {
def main(args: Array[String]) {
@@ -32,8 +31,8 @@ private[spark] object SenderTest {
val targetHost = args(0)
val targetPort = args(1).toInt
val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort)
-
- val manager = new ConnectionManager(0, new SparkConf)
+ val conf = new SparkConf
+ val manager = new ConnectionManager(0, conf, new SecurityManager(conf))
println("Started connection manager with id = " + manager.id)
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 977c24687c..1bf3f4db32 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props}
import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}
import sun.nio.ch.DirectBuffer
-import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
+import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException, SecurityManager}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
import org.apache.spark.serializer.Serializer
@@ -47,7 +47,8 @@ private[spark] class BlockManager(
val master: BlockManagerMaster,
val defaultSerializer: Serializer,
maxMemory: Long,
- val conf: SparkConf)
+ val conf: SparkConf,
+ securityManager: SecurityManager)
extends Logging {
val shuffleBlockManager = new ShuffleBlockManager(this)
@@ -66,7 +67,7 @@ private[spark] class BlockManager(
if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0
}
- val connectionManager = new ConnectionManager(0, conf)
+ val connectionManager = new ConnectionManager(0, conf, securityManager)
implicit val futureExecContext = connectionManager.futureExecContext
val blockManagerId = BlockManagerId(
@@ -122,8 +123,9 @@ private[spark] class BlockManager(
* Construct a BlockManager with a memory limit set based on system properties.
*/
def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster,
- serializer: Serializer, conf: SparkConf) = {
- this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), conf)
+ serializer: Serializer, conf: SparkConf, securityManager: SecurityManager) = {
+ this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), conf,
+ securityManager)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
index 1d81d006c0..36f2a0fd02 100644
--- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
@@ -24,6 +24,7 @@ import util.Random
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.{SecurityManager, SparkConf}
/**
* This class tests the BlockManager and MemoryStore for thread safety and
@@ -98,7 +99,8 @@ private[spark] object ThreadingTest {
val blockManagerMaster = new BlockManagerMaster(
actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf))), conf)
val blockManager = new BlockManager(
- "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf)
+ "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf,
+ new SecurityManager(conf))
val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue))
producers.foreach(_.start)
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index 1b78c52ff6..7c35cd165a 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -18,7 +18,8 @@
package org.apache.spark.ui
import java.net.InetSocketAddress
-import javax.servlet.http.{HttpServletResponse, HttpServletRequest}
+import java.net.URL
+import javax.servlet.http.{HttpServlet, HttpServletResponse, HttpServletRequest}
import scala.annotation.tailrec
import scala.util.{Failure, Success, Try}
@@ -26,11 +27,14 @@ import scala.xml.Node
import org.json4s.JValue
import org.json4s.jackson.JsonMethods.{pretty, render}
-import org.eclipse.jetty.server.{Handler, Request, Server}
-import org.eclipse.jetty.server.handler.{AbstractHandler, ContextHandler, HandlerList, ResourceHandler}
+
+import org.eclipse.jetty.server.{DispatcherType, Server}
+import org.eclipse.jetty.server.handler.HandlerList
+import org.eclipse.jetty.servlet.{DefaultServlet, FilterHolder, ServletContextHandler, ServletHolder}
import org.eclipse.jetty.util.thread.QueuedThreadPool
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
+
/** Utilities for launching a web server using Jetty's HTTP Server class */
private[spark] object JettyUtils extends Logging {
@@ -39,57 +43,104 @@ private[spark] object JettyUtils extends Logging {
type Responder[T] = HttpServletRequest => T
- // Conversions from various types of Responder's to jetty Handlers
- implicit def jsonResponderToHandler(responder: Responder[JValue]): Handler =
- createHandler(responder, "text/json", (in: JValue) => pretty(render(in)))
+ class ServletParams[T <% AnyRef](val responder: Responder[T],
+ val contentType: String,
+ val extractFn: T => String = (in: Any) => in.toString) {}
+
+ // Conversions from various types of Responder's to appropriate servlet parameters
+ implicit def jsonResponderToServlet(responder: Responder[JValue]): ServletParams[JValue] =
+ new ServletParams(responder, "text/json", (in: JValue) => pretty(render(in)))
- implicit def htmlResponderToHandler(responder: Responder[Seq[Node]]): Handler =
- createHandler(responder, "text/html", (in: Seq[Node]) => "<!DOCTYPE html>" + in.toString)
+ implicit def htmlResponderToServlet(responder: Responder[Seq[Node]]): ServletParams[Seq[Node]] =
+ new ServletParams(responder, "text/html", (in: Seq[Node]) => "<!DOCTYPE html>" + in.toString)
- implicit def textResponderToHandler(responder: Responder[String]): Handler =
- createHandler(responder, "text/plain")
+ implicit def textResponderToServlet(responder: Responder[String]): ServletParams[String] =
+ new ServletParams(responder, "text/plain")
- def createHandler[T <% AnyRef](responder: Responder[T], contentType: String,
- extractFn: T => String = (in: Any) => in.toString): Handler = {
- new AbstractHandler {
- def handle(target: String,
- baseRequest: Request,
- request: HttpServletRequest,
+ def createServlet[T <% AnyRef](servletParams: ServletParams[T],
+ securityMgr: SecurityManager): HttpServlet = {
+ new HttpServlet {
+ override def doGet(request: HttpServletRequest,
response: HttpServletResponse) {
- response.setContentType("%s;charset=utf-8".format(contentType))
- response.setStatus(HttpServletResponse.SC_OK)
- baseRequest.setHandled(true)
- val result = responder(request)
- response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
- response.getWriter().println(extractFn(result))
+ if (securityMgr.checkUIViewPermissions(request.getRemoteUser())) {
+ response.setContentType("%s;charset=utf-8".format(servletParams.contentType))
+ response.setStatus(HttpServletResponse.SC_OK)
+ val result = servletParams.responder(request)
+ response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
+ response.getWriter().println(servletParams.extractFn(result))
+ } else {
+ response.setStatus(HttpServletResponse.SC_UNAUTHORIZED)
+ response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
+ response.sendError(HttpServletResponse.SC_UNAUTHORIZED,
+ "User is not authorized to access this page.");
+ }
}
}
}
+ def createServletHandler(path: String, servlet: HttpServlet): ServletContextHandler = {
+ val contextHandler = new ServletContextHandler()
+ val holder = new ServletHolder(servlet)
+ contextHandler.setContextPath(path)
+ contextHandler.addServlet(holder, "/")
+ contextHandler
+ }
+
/** Creates a handler that always redirects the user to a given path */
- def createRedirectHandler(newPath: String): Handler = {
- new AbstractHandler {
- def handle(target: String,
- baseRequest: Request,
- request: HttpServletRequest,
+ def createRedirectHandler(newPath: String, path: String): ServletContextHandler = {
+ val servlet = new HttpServlet {
+ override def doGet(request: HttpServletRequest,
response: HttpServletResponse) {
- response.setStatus(302)
- response.setHeader("Location", baseRequest.getRootURL + newPath)
- baseRequest.setHandled(true)
+ // make sure we don't end up with // in the middle
+ val newUri = new URL(new URL(request.getRequestURL.toString), newPath).toURI
+ response.sendRedirect(newUri.toString)
}
}
+ val contextHandler = new ServletContextHandler()
+ val holder = new ServletHolder(servlet)
+ contextHandler.setContextPath(path)
+ contextHandler.addServlet(holder, "/")
+ contextHandler
}
/** Creates a handler for serving files from a static directory */
- def createStaticHandler(resourceBase: String): ResourceHandler = {
- val staticHandler = new ResourceHandler
+ def createStaticHandler(resourceBase: String, path: String): ServletContextHandler = {
+ val contextHandler = new ServletContextHandler()
+ val staticHandler = new DefaultServlet
+ val holder = new ServletHolder(staticHandler)
Option(getClass.getClassLoader.getResource(resourceBase)) match {
case Some(res) =>
- staticHandler.setResourceBase(res.toString)
+ holder.setInitParameter("resourceBase", res.toString)
case None =>
throw new Exception("Could not find resource path for Web UI: " + resourceBase)
}
- staticHandler
+ contextHandler.addServlet(holder, path)
+ contextHandler
+ }
+
+ private def addFilters(handlers: Seq[ServletContextHandler], conf: SparkConf) {
+ val filters: Array[String] = conf.get("spark.ui.filters", "").split(',').map(_.trim())
+ filters.foreach {
+ case filter : String =>
+ if (!filter.isEmpty) {
+ logInfo("Adding filter: " + filter)
+ val holder : FilterHolder = new FilterHolder()
+ holder.setClassName(filter)
+ // get any parameters for each filter
+ val paramName = "spark." + filter + ".params"
+ val params = conf.get(paramName, "").split(',').map(_.trim()).toSet
+ params.foreach {
+ case param : String =>
+ if (!param.isEmpty) {
+ val parts = param.split("=")
+ if (parts.length == 2) holder.setInitParameter(parts(0), parts(1))
+ }
+ }
+ val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.ERROR,
+ DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.REQUEST)
+ handlers.foreach { case(handler) => handler.addFilter(holder, "/*", enumDispatcher) }
+ }
+ }
}
/**
@@ -99,17 +150,12 @@ private[spark] object JettyUtils extends Logging {
* If the desired port number is contented, continues incrementing ports until a free port is
* found. Returns the chosen port and the jetty Server object.
*/
- def startJettyServer(hostName: String, port: Int, handlers: Seq[(String, Handler)]): (Server, Int)
- = {
-
- val handlersToRegister = handlers.map { case(path, handler) =>
- val contextHandler = new ContextHandler(path)
- contextHandler.setHandler(handler)
- contextHandler.asInstanceOf[org.eclipse.jetty.server.Handler]
- }
+ def startJettyServer(hostName: String, port: Int, handlers: Seq[ServletContextHandler],
+ conf: SparkConf): (Server, Int) = {
+ addFilters(handlers, conf)
val handlerList = new HandlerList
- handlerList.setHandlers(handlersToRegister.toArray)
+ handlerList.setHandlers(handlers.toArray)
@tailrec
def connect(currentPort: Int): (Server, Int) = {
@@ -119,7 +165,9 @@ private[spark] object JettyUtils extends Logging {
server.setThreadPool(pool)
server.setHandler(handlerList)
- Try { server.start() } match {
+ Try {
+ server.start()
+ } match {
case s: Success[_] =>
(server, server.getConnectors.head.getLocalPort)
case f: Failure[_] =>
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index af6b65860e..ca82c3da2f 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -17,7 +17,10 @@
package org.apache.spark.ui
-import org.eclipse.jetty.server.{Handler, Server}
+import javax.servlet.http.HttpServletRequest
+
+import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.{Logging, SparkContext, SparkEnv}
import org.apache.spark.ui.JettyUtils._
@@ -34,9 +37,9 @@ private[spark] class SparkUI(sc: SparkContext) extends Logging {
var boundPort: Option[Int] = None
var server: Option[Server] = None
- val handlers = Seq[(String, Handler)](
- ("/static", createStaticHandler(SparkUI.STATIC_RESOURCE_DIR)),
- ("/", createRedirectHandler("/stages"))
+ val handlers = Seq[ServletContextHandler] (
+ createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static/*"),
+ createRedirectHandler("/stages", "/")
)
val storage = new BlockManagerUI(sc)
val jobs = new JobProgressUI(sc)
@@ -52,7 +55,7 @@ private[spark] class SparkUI(sc: SparkContext) extends Logging {
/** Bind the HTTP server which backs this web interface */
def bind() {
try {
- val (srv, usedPort) = JettyUtils.startJettyServer(host, port, allHandlers)
+ val (srv, usedPort) = JettyUtils.startJettyServer(host, port, allHandlers, sc.conf)
logInfo("Started Spark Web UI at http://%s:%d".format(host, usedPort))
server = Some(srv)
boundPort = Some(usedPort)
@@ -83,5 +86,5 @@ private[spark] class SparkUI(sc: SparkContext) extends Logging {
private[spark] object SparkUI {
val DEFAULT_PORT = "4040"
- val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static"
+ val STATIC_RESOURCE_DIR = "org/apache/spark/ui"
}
diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala
index 9e7cdc8816..14333476c0 100644
--- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConversions._
import scala.util.Properties
import scala.xml.Node
-import org.eclipse.jetty.server.Handler
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.SparkContext
import org.apache.spark.ui.JettyUtils._
@@ -32,8 +32,9 @@ import org.apache.spark.ui.UIUtils
private[spark] class EnvironmentUI(sc: SparkContext) {
- def getHandlers = Seq[(String, Handler)](
- ("/environment", (request: HttpServletRequest) => envDetails(request))
+ def getHandlers = Seq[ServletContextHandler](
+ createServletHandler("/environment",
+ createServlet((request: HttpServletRequest) => envDetails(request), sc.env.securityManager))
)
def envDetails(request: HttpServletRequest): Seq[Node] = {
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
index 1f3b7a4c23..4235cfeff9 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
@@ -22,7 +22,7 @@ import javax.servlet.http.HttpServletRequest
import scala.collection.mutable.{HashMap, HashSet}
import scala.xml.Node
-import org.eclipse.jetty.server.Handler
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.{ExceptionFailure, Logging, SparkContext}
import org.apache.spark.executor.TaskMetrics
@@ -43,8 +43,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
sc.addSparkListener(listener)
}
- def getHandlers = Seq[(String, Handler)](
- ("/executors", (request: HttpServletRequest) => render(request))
+ def getHandlers = Seq[ServletContextHandler](
+ createServletHandler("/executors", createServlet((request: HttpServletRequest) => render
+ (request), sc.env.securityManager))
)
def render(request: HttpServletRequest): Seq[Node] = {
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
index 557bce6b66..2d95d47e15 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
@@ -23,6 +23,7 @@ import javax.servlet.http.HttpServletRequest
import scala.Seq
import org.eclipse.jetty.server.Handler
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.SparkContext
import org.apache.spark.ui.JettyUtils._
@@ -45,9 +46,15 @@ private[spark] class JobProgressUI(val sc: SparkContext) {
def formatDuration(ms: Long) = Utils.msDurationToString(ms)
- def getHandlers = Seq[(String, Handler)](
- ("/stages/stage", (request: HttpServletRequest) => stagePage.render(request)),
- ("/stages/pool", (request: HttpServletRequest) => poolPage.render(request)),
- ("/stages", (request: HttpServletRequest) => indexPage.render(request))
+ def getHandlers = Seq[ServletContextHandler](
+ createServletHandler("/stages/stage",
+ createServlet((request: HttpServletRequest) => stagePage.render(request),
+ sc.env.securityManager)),
+ createServletHandler("/stages/pool",
+ createServlet((request: HttpServletRequest) => poolPage.render(request),
+ sc.env.securityManager)),
+ createServletHandler("/stages",
+ createServlet((request: HttpServletRequest) => indexPage.render(request),
+ sc.env.securityManager))
)
}
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
index dc18eab74e..cb2083eb01 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ui.storage
import javax.servlet.http.HttpServletRequest
-import org.eclipse.jetty.server.Handler
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.ui.JettyUtils._
@@ -29,8 +29,12 @@ private[spark] class BlockManagerUI(val sc: SparkContext) extends Logging {
val indexPage = new IndexPage(this)
val rddPage = new RDDPage(this)
- def getHandlers = Seq[(String, Handler)](
- ("/storage/rdd", (request: HttpServletRequest) => rddPage.render(request)),
- ("/storage", (request: HttpServletRequest) => indexPage.render(request))
+ def getHandlers = Seq[ServletContextHandler](
+ createServletHandler("/storage/rdd",
+ createServlet((request: HttpServletRequest) => rddPage.render(request),
+ sc.env.securityManager)),
+ createServletHandler("/storage",
+ createServlet((request: HttpServletRequest) => indexPage.render(request),
+ sc.env.securityManager))
)
}
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index f26ed47e58..a6c9a9aaba 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -24,12 +24,12 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, IndestructibleActorSystem}
import com.typesafe.config.ConfigFactory
import org.apache.log4j.{Level, Logger}
-import org.apache.spark.SparkConf
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
/**
* Various utility classes for working with Akka.
*/
-private[spark] object AkkaUtils {
+private[spark] object AkkaUtils extends Logging {
/**
* Creates an ActorSystem ready for remoting, with various Spark features. Returns both the
@@ -42,7 +42,7 @@ private[spark] object AkkaUtils {
* of a fatal exception. This is used by [[org.apache.spark.executor.Executor]].
*/
def createActorSystem(name: String, host: String, port: Int, indestructible: Boolean = false,
- conf: SparkConf): (ActorSystem, Int) = {
+ conf: SparkConf, securityManager: SecurityManager): (ActorSystem, Int) = {
val akkaThreads = conf.getInt("spark.akka.threads", 4)
val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15)
@@ -65,6 +65,15 @@ private[spark] object AkkaUtils {
conf.getDouble("spark.akka.failure-detector.threshold", 300.0)
val akkaHeartBeatInterval = conf.getInt("spark.akka.heartbeat.interval", 1000)
+ val secretKey = securityManager.getSecretKey()
+ val isAuthOn = securityManager.isAuthenticationEnabled()
+ if (isAuthOn && secretKey == null) {
+ throw new Exception("Secret key is null with authentication on")
+ }
+ val requireCookie = if (isAuthOn) "on" else "off"
+ val secureCookie = if (isAuthOn) secretKey else ""
+ logDebug("In createActorSystem, requireCookie is: " + requireCookie)
+
val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String]).withFallback(
ConfigFactory.parseString(
s"""
@@ -72,6 +81,8 @@ private[spark] object AkkaUtils {
|akka.loggers = [""akka.event.slf4j.Slf4jLogger""]
|akka.stdout-loglevel = "ERROR"
|akka.jvm-exit-on-fatal-error = off
+ |akka.remote.require-cookie = "$requireCookie"
+ |akka.remote.secure-cookie = "$secureCookie"
|akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatInterval s
|akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPauses s
|akka.remote.transport-failure-detector.threshold = $akkaFailureDetector
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 8e69f1d335..0eb2f78b73 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -18,7 +18,7 @@
package org.apache.spark.util
import java.io._
-import java.net.{InetAddress, Inet4Address, NetworkInterface, URI, URL}
+import java.net.{InetAddress, Inet4Address, NetworkInterface, URI, URL, URLConnection}
import java.nio.ByteBuffer
import java.util.{Locale, Random, UUID}
import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor}
@@ -33,10 +33,11 @@ import com.google.common.io.Files
import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
-import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
import org.apache.spark.deploy.SparkHadoopUtil
+
/**
* Various utility methods used by Spark.
*/
@@ -233,13 +234,29 @@ private[spark] object Utils extends Logging {
}
/**
+ * Construct a URI container information used for authentication.
+ * This also sets the default authenticator to properly negotiation the
+ * user/password based on the URI.
+ *
+ * Note this relies on the Authenticator.setDefault being set properly to decode
+ * the user name and password. This is currently set in the SecurityManager.
+ */
+ def constructURIForAuthentication(uri: URI, securityMgr: SecurityManager): URI = {
+ val userCred = securityMgr.getSecretKey()
+ if (userCred == null) throw new Exception("Secret key is null with authentication on")
+ val userInfo = securityMgr.getHttpUser() + ":" + userCred
+ new URI(uri.getScheme(), userInfo, uri.getHost(), uri.getPort(), uri.getPath(),
+ uri.getQuery(), uri.getFragment())
+ }
+
+ /**
* Download a file requested by the executor. Supports fetching the file in a variety of ways,
* including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
*
* Throws SparkException if the target file already exists and has different contents than
* the requested file.
*/
- def fetchFile(url: String, targetDir: File, conf: SparkConf) {
+ def fetchFile(url: String, targetDir: File, conf: SparkConf, securityMgr: SecurityManager) {
val filename = url.split("/").last
val tempDir = getLocalDir(conf)
val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir))
@@ -249,7 +266,19 @@ private[spark] object Utils extends Logging {
uri.getScheme match {
case "http" | "https" | "ftp" =>
logInfo("Fetching " + url + " to " + tempFile)
- val in = new URL(url).openStream()
+
+ var uc: URLConnection = null
+ if (securityMgr.isAuthenticationEnabled()) {
+ logDebug("fetchFile with security enabled")
+ val newuri = constructURIForAuthentication(uri, securityMgr)
+ uc = newuri.toURL().openConnection()
+ uc.setAllowUserInteraction(false)
+ } else {
+ logDebug("fetchFile not using security")
+ uc = new URL(url).openConnection()
+ }
+
+ val in = uc.getInputStream();
val out = new FileOutputStream(tempFile)
Utils.copyStream(in, out, true)
if (targetFile.exists && !Files.equal(tempFile, targetFile)) {
diff --git a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
new file mode 100644
index 0000000000..cd054c1f68
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+
+import akka.actor._
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.AkkaUtils
+import scala.concurrent.Await
+
+/**
+ * Test the AkkaUtils with various security settings.
+ */
+class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
+
+ test("remote fetch security bad password") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "true")
+ conf.set("spark.authenticate.secret", "good")
+
+ val securityManager = new SecurityManager(conf);
+ val hostname = "localhost"
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
+ conf = conf, securityManager = securityManager)
+ System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+ assert(securityManager.isAuthenticationEnabled() === true)
+
+ val masterTracker = new MapOutputTrackerMaster(conf)
+ masterTracker.trackerActor = actorSystem.actorOf(
+ Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "true")
+ badconf.set("spark.authenticate.secret", "bad")
+ val securityManagerBad = new SecurityManager(badconf);
+
+ assert(securityManagerBad.isAuthenticationEnabled() === true)
+
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
+ conf = conf, securityManager = securityManagerBad)
+ val slaveTracker = new MapOutputTracker(conf)
+ val selection = slaveSystem.actorSelection(
+ s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
+ val timeout = AkkaUtils.lookupTimeout(conf)
+ intercept[akka.actor.ActorNotFound] {
+ slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+ }
+
+ actorSystem.shutdown()
+ slaveSystem.shutdown()
+ }
+
+ test("remote fetch security off") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "false")
+ conf.set("spark.authenticate.secret", "bad")
+ val securityManager = new SecurityManager(conf);
+
+ val hostname = "localhost"
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
+ conf = conf, securityManager = securityManager)
+ System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+
+ assert(securityManager.isAuthenticationEnabled() === false)
+
+ val masterTracker = new MapOutputTrackerMaster(conf)
+ masterTracker.trackerActor = actorSystem.actorOf(
+ Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "false")
+ badconf.set("spark.authenticate.secret", "good")
+ val securityManagerBad = new SecurityManager(badconf);
+
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
+ conf = badconf, securityManager = securityManagerBad)
+ val slaveTracker = new MapOutputTracker(conf)
+ val selection = slaveSystem.actorSelection(
+ s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
+ val timeout = AkkaUtils.lookupTimeout(conf)
+ slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+
+ assert(securityManagerBad.isAuthenticationEnabled() === false)
+
+ masterTracker.registerShuffle(10, 1)
+ masterTracker.incrementEpoch()
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+
+ val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+ val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+ masterTracker.registerMapOutput(10, 0, new MapStatus(
+ BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000)))
+ masterTracker.incrementEpoch()
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+
+ // this should succeed since security off
+ assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
+ Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
+
+ actorSystem.shutdown()
+ slaveSystem.shutdown()
+ }
+
+ test("remote fetch security pass") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "true")
+ conf.set("spark.authenticate.secret", "good")
+ val securityManager = new SecurityManager(conf);
+
+ val hostname = "localhost"
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
+ conf = conf, securityManager = securityManager)
+ System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+
+ assert(securityManager.isAuthenticationEnabled() === true)
+
+ val masterTracker = new MapOutputTrackerMaster(conf)
+ masterTracker.trackerActor = actorSystem.actorOf(
+ Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+
+ val goodconf = new SparkConf
+ goodconf.set("spark.authenticate", "true")
+ goodconf.set("spark.authenticate.secret", "good")
+ val securityManagerGood = new SecurityManager(goodconf);
+
+ assert(securityManagerGood.isAuthenticationEnabled() === true)
+
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
+ conf = goodconf, securityManager = securityManagerGood)
+ val slaveTracker = new MapOutputTracker(conf)
+ val selection = slaveSystem.actorSelection(
+ s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
+ val timeout = AkkaUtils.lookupTimeout(conf)
+ slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+
+ masterTracker.registerShuffle(10, 1)
+ masterTracker.incrementEpoch()
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+
+ val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+ val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+ masterTracker.registerMapOutput(10, 0, new MapStatus(
+ BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000)))
+ masterTracker.incrementEpoch()
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+
+ // this should succeed since security on and passwords match
+ assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
+ Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
+
+ actorSystem.shutdown()
+ slaveSystem.shutdown()
+ }
+
+ test("remote fetch security off client") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "true")
+ conf.set("spark.authenticate.secret", "good")
+
+ val securityManager = new SecurityManager(conf);
+
+ val hostname = "localhost"
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
+ conf = conf, securityManager = securityManager)
+ System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+
+ assert(securityManager.isAuthenticationEnabled() === true)
+
+ val masterTracker = new MapOutputTrackerMaster(conf)
+ masterTracker.trackerActor = actorSystem.actorOf(
+ Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "false")
+ badconf.set("spark.authenticate.secret", "bad")
+ val securityManagerBad = new SecurityManager(badconf);
+
+ assert(securityManagerBad.isAuthenticationEnabled() === false)
+
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
+ conf = badconf, securityManager = securityManagerBad)
+ val slaveTracker = new MapOutputTracker(conf)
+ val selection = slaveSystem.actorSelection(
+ s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
+ val timeout = AkkaUtils.lookupTimeout(conf)
+ intercept[akka.actor.ActorNotFound] {
+ slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+ }
+
+ actorSystem.shutdown()
+ slaveSystem.shutdown()
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
index e022accee6..96ba3929c1 100644
--- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
@@ -21,6 +21,7 @@ import org.scalatest.FunSuite
class BroadcastSuite extends FunSuite with LocalSparkContext {
+
override def afterEach() {
super.afterEach()
System.clearProperty("spark.broadcast.factory")
diff --git a/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala
new file mode 100644
index 0000000000..80f7ec00c7
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+
+import java.nio._
+
+import org.apache.spark.network.{ConnectionManager, Message, ConnectionManagerId}
+import scala.concurrent.Await
+import scala.concurrent.TimeoutException
+import scala.concurrent.duration._
+
+
+/**
+ * Test the ConnectionManager with various security settings.
+ */
+class ConnectionManagerSuite extends FunSuite {
+
+ test("security default off") {
+ val conf = new SparkConf
+ val securityManager = new SecurityManager(conf)
+ val manager = new ConnectionManager(0, conf, securityManager)
+ var receivedMessage = false
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ receivedMessage = true
+ None
+ })
+
+ val size = 10 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliablySync(manager.id, bufferMessage)
+
+ assert(receivedMessage == true)
+
+ manager.stop()
+ }
+
+ test("security on same password") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "true")
+ conf.set("spark.authenticate.secret", "good")
+ val securityManager = new SecurityManager(conf)
+ val manager = new ConnectionManager(0, conf, securityManager)
+ var numReceivedMessages = 0
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedMessages += 1
+ None
+ })
+ val managerServer = new ConnectionManager(0, conf, securityManager)
+ var numReceivedServerMessages = 0
+ managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedServerMessages += 1
+ None
+ })
+
+ val size = 10 * 1024 * 1024
+ val count = 10
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliablySync(managerServer.id, bufferMessage)
+ })
+
+ assert(numReceivedServerMessages == 10)
+ assert(numReceivedMessages == 0)
+
+ manager.stop()
+ managerServer.stop()
+ }
+
+ test("security mismatch password") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "true")
+ conf.set("spark.authenticate.secret", "good")
+ val securityManager = new SecurityManager(conf)
+ val manager = new ConnectionManager(0, conf, securityManager)
+ var numReceivedMessages = 0
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedMessages += 1
+ None
+ })
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "true")
+ badconf.set("spark.authenticate.secret", "bad")
+ val badsecurityManager = new SecurityManager(badconf)
+ val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
+ var numReceivedServerMessages = 0
+
+ managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedServerMessages += 1
+ None
+ })
+
+ val size = 10 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliablySync(managerServer.id, bufferMessage)
+
+ assert(numReceivedServerMessages == 0)
+ assert(numReceivedMessages == 0)
+
+ manager.stop()
+ managerServer.stop()
+ }
+
+ test("security mismatch auth off") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "false")
+ conf.set("spark.authenticate.secret", "good")
+ val securityManager = new SecurityManager(conf)
+ val manager = new ConnectionManager(0, conf, securityManager)
+ var numReceivedMessages = 0
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedMessages += 1
+ None
+ })
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "true")
+ badconf.set("spark.authenticate.secret", "good")
+ val badsecurityManager = new SecurityManager(badconf)
+ val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
+ var numReceivedServerMessages = 0
+ managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedServerMessages += 1
+ None
+ })
+
+ val size = 10 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ (0 until 1).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliably(managerServer.id, bufferMessage)
+ }).foreach(f => {
+ try {
+ val g = Await.result(f, 1 second)
+ assert(false)
+ } catch {
+ case e: TimeoutException => {
+ // we should timeout here since the client can't do the negotiation
+ assert(true)
+ }
+ }
+ })
+
+ assert(numReceivedServerMessages == 0)
+ assert(numReceivedMessages == 0)
+ manager.stop()
+ managerServer.stop()
+ }
+
+ test("security auth off") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "false")
+ val securityManager = new SecurityManager(conf)
+ val manager = new ConnectionManager(0, conf, securityManager)
+ var numReceivedMessages = 0
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedMessages += 1
+ None
+ })
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "false")
+ val badsecurityManager = new SecurityManager(badconf)
+ val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
+ var numReceivedServerMessages = 0
+
+ managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedServerMessages += 1
+ None
+ })
+
+ val size = 10 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ (0 until 10).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliably(managerServer.id, bufferMessage)
+ }).foreach(f => {
+ try {
+ val g = Await.result(f, 1 second)
+ if (!g.isDefined) assert(false) else assert(true)
+ } catch {
+ case e: Exception => {
+ assert(false)
+ }
+ }
+ })
+ assert(numReceivedServerMessages == 10)
+ assert(numReceivedMessages == 0)
+
+ manager.stop()
+ managerServer.stop()
+ }
+
+
+
+}
+
diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala
index e0e8011278..9cbdfc54a3 100644
--- a/core/src/test/scala/org/apache/spark/DriverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala
@@ -30,6 +30,7 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark.util.Utils
class DriverSuite extends FunSuite with Timeouts {
+
test("driver should exit after finishing") {
val sparkHome = sys.env.get("SPARK_HOME").orElse(sys.props.get("spark.home")).get
// Regression test for SPARK-530: "Spark driver process doesn't exit after finishing"
diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
index 9be67b3c95..aee9ab9091 100644
--- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
@@ -30,6 +30,12 @@ class FileServerSuite extends FunSuite with LocalSparkContext {
@transient var tmpFile: File = _
@transient var tmpJarUrl: String = _
+ override def beforeEach() {
+ super.beforeEach()
+ resetSparkContext()
+ System.setProperty("spark.authenticate", "false")
+ }
+
override def beforeAll() {
super.beforeAll()
val tmpDir = new File(Files.createTempDir(), "test")
@@ -43,6 +49,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext {
val jarFile = new File(tmpDir, "test.jar")
val jarStream = new FileOutputStream(jarFile)
val jar = new JarOutputStream(jarStream, new java.util.jar.Manifest())
+ System.setProperty("spark.authenticate", "false")
val jarEntry = new JarEntry(textFile.getName)
jar.putNextEntry(jarEntry)
@@ -77,6 +84,25 @@ class FileServerSuite extends FunSuite with LocalSparkContext {
assert(result.toSet === Set((1,200), (2,300), (3,500)))
}
+ test("Distributing files locally security On") {
+ val sparkConf = new SparkConf(false)
+ sparkConf.set("spark.authenticate", "true")
+ sparkConf.set("spark.authenticate.secret", "good")
+ sc = new SparkContext("local[4]", "test", sparkConf)
+
+ sc.addFile(tmpFile.toString)
+ assert(sc.env.securityManager.isAuthenticationEnabled() === true)
+ val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
+ val result = sc.parallelize(testData).reduceByKey {
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
+ val fileVal = in.readLine().toInt
+ in.close()
+ _ * fileVal + _ * fileVal
+ }.collect()
+ assert(result.toSet === Set((1,200), (2,300), (3,500)))
+ }
+
test("Distributing files locally using URL as input") {
// addFile("file:///....")
sc = new SparkContext("local[4]", "test")
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 6c1e325f6f..8efa072a97 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -98,14 +98,16 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("remote fetch") {
val hostname = "localhost"
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf)
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf,
+ securityManager = new SecurityManager(conf))
System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
val masterTracker = new MapOutputTrackerMaster(conf)
masterTracker.trackerActor = actorSystem.actorOf(
Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
- val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf)
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf,
+ securityManager = new SecurityManager(conf))
val slaveTracker = new MapOutputTracker(conf)
val selection = slaveSystem.actorSelection(
s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
index c1e8b295df..96a5a12318 100644
--- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
+++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
@@ -18,21 +18,22 @@
package org.apache.spark.metrics
import org.scalatest.{BeforeAndAfter, FunSuite}
-
-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.master.MasterSource
class MetricsSystemSuite extends FunSuite with BeforeAndAfter {
var filePath: String = _
var conf: SparkConf = null
+ var securityMgr: SecurityManager = null
before {
filePath = getClass.getClassLoader.getResource("test_metrics_system.properties").getFile()
conf = new SparkConf(false).set("spark.metrics.conf", filePath)
+ securityMgr = new SecurityManager(conf)
}
test("MetricsSystem with default config") {
- val metricsSystem = MetricsSystem.createMetricsSystem("default", conf)
+ val metricsSystem = MetricsSystem.createMetricsSystem("default", conf, securityMgr)
val sources = metricsSystem.sources
val sinks = metricsSystem.sinks
@@ -42,7 +43,7 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter {
}
test("MetricsSystem with sources add") {
- val metricsSystem = MetricsSystem.createMetricsSystem("test", conf)
+ val metricsSystem = MetricsSystem.createMetricsSystem("test", conf, securityMgr)
val sources = metricsSystem.sources
val sinks = metricsSystem.sinks
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 9f011d9c8d..121e47c7b1 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -28,7 +28,7 @@ import org.scalatest.concurrent.Timeouts._
import org.scalatest.matchers.ShouldMatchers._
import org.scalatest.time.SpanSugar._
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.{SecurityManager, SparkConf, SparkContext}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils}
@@ -39,6 +39,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
var actorSystem: ActorSystem = null
var master: BlockManagerMaster = null
var oldArch: String = null
+ conf.set("spark.authenticate", "false")
+ val securityMgr = new SecurityManager(conf)
// Reuse a serializer across tests to avoid creating a new thread-local buffer on each test
conf.set("spark.kryoserializer.buffer.mb", "1")
@@ -49,7 +51,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId)
before {
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, conf = conf)
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, conf = conf,
+ securityManager = securityMgr)
this.actorSystem = actorSystem
conf.set("spark.driver.port", boundPort.toString)
@@ -125,7 +128,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("master + 1 manager interaction") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -155,8 +158,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("master + 2 managers interaction") {
- store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf)
- store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf)
+ store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr)
+ store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf,
+ securityMgr)
val peers = master.getPeers(store.blockManagerId, 1)
assert(peers.size === 1, "master did not return the other manager as a peer")
@@ -171,7 +175,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("removing block") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -219,7 +223,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("removing rdd") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -253,7 +257,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("reregistration on heart beat") {
val heartBeat = PrivateMethod[Unit]('heartBeat)
- store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
val a1 = new Array[Byte](400)
store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
@@ -269,7 +273,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("reregistration on block update") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
@@ -288,7 +292,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("reregistration doesn't dead lock") {
val heartBeat = PrivateMethod[Unit]('heartBeat)
- store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = List(new Array[Byte](400))
@@ -325,7 +329,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU storage") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -344,7 +348,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU storage with serialization") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -363,7 +367,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU for partitions of same RDD") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -382,7 +386,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU for partitions of multiple RDDs") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle(rdd(0, 2), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle(rdd(1, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
@@ -405,7 +409,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("on-disk storage") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -418,7 +422,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -433,7 +437,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with getLocalBytes") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -448,7 +452,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with serialization") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -463,7 +467,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with serialization and getLocalBytes") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -478,7 +482,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("LRU with mixed storage levels") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -503,7 +507,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU with streams") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
val list1 = List(new Array[Byte](200), new Array[Byte](200))
val list2 = List(new Array[Byte](200), new Array[Byte](200))
val list3 = List(new Array[Byte](200), new Array[Byte](200))
@@ -527,7 +531,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("LRU with mixed storage levels and streams") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
val list1 = List(new Array[Byte](200), new Array[Byte](200))
val list2 = List(new Array[Byte](200), new Array[Byte](200))
val list3 = List(new Array[Byte](200), new Array[Byte](200))
@@ -573,7 +577,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("overly large block") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 500, conf)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 500, conf, securityMgr)
store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY)
assert(store.getSingle("a1") === None, "a1 was in store")
store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK)
@@ -584,7 +588,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("block compression") {
try {
conf.set("spark.shuffle.compress", "true")
- store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr)
store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100,
"shuffle_0_0_0 was not compressed")
@@ -592,7 +596,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
store = null
conf.set("spark.shuffle.compress", "false")
- store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf, securityMgr)
store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 1000,
"shuffle_0_0_0 was compressed")
@@ -600,7 +604,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
store = null
conf.set("spark.broadcast.compress", "true")
- store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf, securityMgr)
store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 100,
"broadcast_0 was not compressed")
@@ -608,28 +612,28 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
store = null
conf.set("spark.broadcast.compress", "false")
- store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf, securityMgr)
store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 1000, "broadcast_0 was compressed")
store.stop()
store = null
conf.set("spark.rdd.compress", "true")
- store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf, securityMgr)
store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(rdd(0, 0)) <= 100, "rdd_0_0 was not compressed")
store.stop()
store = null
conf.set("spark.rdd.compress", "false")
- store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf, securityMgr)
store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(rdd(0, 0)) >= 1000, "rdd_0_0 was compressed")
store.stop()
store = null
// Check that any other block types are also kept uncompressed
- store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf, securityMgr)
store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY)
assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed")
store.stop()
@@ -643,7 +647,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("block store put failure") {
// Use Java serializer so we can create an unserializable error.
- store = new BlockManager("<driver>", actorSystem, master, new JavaSerializer(conf), 1200, conf)
+ store = new BlockManager("<driver>", actorSystem, master, new JavaSerializer(conf), 1200, conf,
+ securityMgr)
// The put should fail since a1 is not serializable.
class UnserializableClass
diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
index 20ebb1897e..30415814ad 100644
--- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
@@ -24,6 +24,8 @@ import scala.util.{Failure, Success, Try}
import org.eclipse.jetty.server.Server
import org.scalatest.FunSuite
+import org.apache.spark.SparkConf
+
class UISuite extends FunSuite {
test("jetty port increases under contention") {
val startPort = 4040
@@ -34,15 +36,17 @@ class UISuite extends FunSuite {
case Failure(e) =>
// Either case server port is busy hence setup for test complete
}
- val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq())
- val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq())
+ val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq(),
+ new SparkConf)
+ val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq(),
+ new SparkConf)
// Allow some wiggle room in case ports on the machine are under contention
assert(boundPort1 > startPort && boundPort1 < startPort + 10)
assert(boundPort2 > boundPort1 && boundPort2 < boundPort1 + 10)
}
test("jetty binds to port 0 correctly") {
- val (jettyServer, boundPort) = JettyUtils.startJettyServer("0.0.0.0", 0, Seq())
+ val (jettyServer, boundPort) = JettyUtils.startJettyServer("0.0.0.0", 0, Seq(), new SparkConf)
assert(jettyServer.getState === "STARTED")
assert(boundPort != 0)
Try {new ServerSocket(boundPort)} match {