diff options
author | Thomas Graves <tgraves@apache.org> | 2014-03-06 18:27:50 -0600 |
---|---|---|
committer | Thomas Graves <tgraves@apache.org> | 2014-03-06 18:27:50 -0600 |
commit | 7edbea41b43e0dc11a2de156be220db8b7952d01 (patch) | |
tree | 1e7156df70131660d28994795eb3df2c3fd0a819 /core/src | |
parent | 40566e10aae4b21ffc71ea72702b8df118ac5c8e (diff) | |
download | spark-7edbea41b43e0dc11a2de156be220db8b7952d01.tar.gz spark-7edbea41b43e0dc11a2de156be220db8b7952d01.tar.bz2 spark-7edbea41b43e0dc11a2de156be220db8b7952d01.zip |
SPARK-1189: Add Security to Spark - Akka, Http, ConnectionManager, UI use servlets
resubmit pull request. was https://github.com/apache/incubator-spark/pull/332.
Author: Thomas Graves <tgraves@apache.org>
Closes #33 from tgravescs/security-branch-0.9-with-client-rebase and squashes the following commits:
dfe3918 [Thomas Graves] Fix merge conflict since startUserClass now using runAsUser
05eebed [Thomas Graves] Fix dependency lost in upmerge
d1040ec [Thomas Graves] Fix up various imports
05ff5e0 [Thomas Graves] Fix up imports after upmerging to master
ac046b3 [Thomas Graves] Merge remote-tracking branch 'upstream/master' into security-branch-0.9-with-client-rebase
13733e1 [Thomas Graves] Pass securityManager and SparkConf around where we can. Switch to use sparkConf for reading config whereever possible. Added ConnectionManagerSuite unit tests.
4a57acc [Thomas Graves] Change UI createHandler routines to createServlet since they now return servlets
2f77147 [Thomas Graves] Rework from comments
50dd9f2 [Thomas Graves] fix header in SecurityManager
ecbfb65 [Thomas Graves] Fix spacing and formatting
b514bec [Thomas Graves] Fix reference to config
ed3d1c1 [Thomas Graves] Add security.md
6f7ddf3 [Thomas Graves] Convert SaslClient and SaslServer to scala, change spark.authenticate.ui to spark.ui.acls.enable, and fix up various other things from review comments
2d9e23e [Thomas Graves] Merge remote-tracking branch 'upstream/master' into security-branch-0.9-with-client-rebase_rework
5721c5a [Thomas Graves] update AkkaUtilsSuite test for the actorSelection changes, fix typos based on comments, and remove extra lines I missed in rebase from AkkaUtils
f351763 [Thomas Graves] Add Security to Spark - Akka, Http, ConnectionManager, UI to use servlets
Diffstat (limited to 'core/src')
56 files changed, 2027 insertions, 241 deletions
diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index d3264a4bb3..3d7692ea8a 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -23,7 +23,7 @@ import com.google.common.io.Files import org.apache.spark.util.Utils -private[spark] class HttpFileServer extends Logging { +private[spark] class HttpFileServer(securityManager: SecurityManager) extends Logging { var baseDir : File = null var fileDir : File = null @@ -38,9 +38,10 @@ private[spark] class HttpFileServer extends Logging { fileDir.mkdir() jarDir.mkdir() logInfo("HTTP File server directory is " + baseDir) - httpServer = new HttpServer(baseDir) + httpServer = new HttpServer(baseDir, securityManager) httpServer.start() serverUri = httpServer.uri + logDebug("HTTP file server started at: " + serverUri) } def stop() { diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index 759e68ee0c..cb5df25fa4 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -19,15 +19,18 @@ package org.apache.spark import java.io.File +import org.eclipse.jetty.util.security.{Constraint, Password} +import org.eclipse.jetty.security.authentication.DigestAuthenticator +import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService, SecurityHandler} + import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.bio.SocketConnector -import org.eclipse.jetty.server.handler.DefaultHandler -import org.eclipse.jetty.server.handler.HandlerList -import org.eclipse.jetty.server.handler.ResourceHandler +import org.eclipse.jetty.server.handler.{DefaultHandler, HandlerList, ResourceHandler} import org.eclipse.jetty.util.thread.QueuedThreadPool import org.apache.spark.util.Utils + /** * Exception type thrown by HttpServer when it is in the wrong state for an operation. */ @@ -38,7 +41,8 @@ private[spark] class ServerStateException(message: String) extends Exception(mes * as well as classes created by the interpreter when the user types in code. This is just a wrapper * around a Jetty server. */ -private[spark] class HttpServer(resourceBase: File) extends Logging { +private[spark] class HttpServer(resourceBase: File, securityManager: SecurityManager) + extends Logging { private var server: Server = null private var port: Int = -1 @@ -59,14 +63,60 @@ private[spark] class HttpServer(resourceBase: File) extends Logging { server.setThreadPool(threadPool) val resHandler = new ResourceHandler resHandler.setResourceBase(resourceBase.getAbsolutePath) + val handlerList = new HandlerList handlerList.setHandlers(Array(resHandler, new DefaultHandler)) - server.setHandler(handlerList) + + if (securityManager.isAuthenticationEnabled()) { + logDebug("HttpServer is using security") + val sh = setupSecurityHandler(securityManager) + // make sure we go through security handler to get resources + sh.setHandler(handlerList) + server.setHandler(sh) + } else { + logDebug("HttpServer is not using security") + server.setHandler(handlerList) + } + server.start() port = server.getConnectors()(0).getLocalPort() } } + /** + * Setup Jetty to the HashLoginService using a single user with our + * shared secret. Configure it to use DIGEST-MD5 authentication so that the password + * isn't passed in plaintext. + */ + private def setupSecurityHandler(securityMgr: SecurityManager): ConstraintSecurityHandler = { + val constraint = new Constraint() + // use DIGEST-MD5 as the authentication mechanism + constraint.setName(Constraint.__DIGEST_AUTH) + constraint.setRoles(Array("user")) + constraint.setAuthenticate(true) + constraint.setDataConstraint(Constraint.DC_NONE) + + val cm = new ConstraintMapping() + cm.setConstraint(constraint) + cm.setPathSpec("/*") + val sh = new ConstraintSecurityHandler() + + // the hashLoginService lets us do a single user and + // secret right now. This could be changed to use the + // JAASLoginService for other options. + val hashLogin = new HashLoginService() + + val userCred = new Password(securityMgr.getSecretKey()) + if (userCred == null) { + throw new Exception("Error: secret key is null with authentication on") + } + hashLogin.putUser(securityMgr.getHttpUser(), userCred, Array("user")) + sh.setLoginService(hashLogin) + sh.setAuthenticator(new DigestAuthenticator()); + sh.setConstraintMappings(Array(cm)) + sh + } + def stop() { if (server == null) { throw new ServerStateException("Server is already stopped") diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala new file mode 100644 index 0000000000..591978c1d3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.net.{Authenticator, PasswordAuthentication} +import org.apache.hadoop.io.Text +import org.apache.hadoop.security.Credentials +import org.apache.hadoop.security.UserGroupInformation +import org.apache.spark.deploy.SparkHadoopUtil + +import scala.collection.mutable.ArrayBuffer + +/** + * Spark class responsible for security. + * + * In general this class should be instantiated by the SparkEnv and most components + * should access it from that. There are some cases where the SparkEnv hasn't been + * initialized yet and this class must be instantiated directly. + * + * Spark currently supports authentication via a shared secret. + * Authentication can be configured to be on via the 'spark.authenticate' configuration + * parameter. This parameter controls whether the Spark communication protocols do + * authentication using the shared secret. This authentication is a basic handshake to + * make sure both sides have the same shared secret and are allowed to communicate. + * If the shared secret is not identical they will not be allowed to communicate. + * + * The Spark UI can also be secured by using javax servlet filters. A user may want to + * secure the UI if it has data that other users should not be allowed to see. The javax + * servlet filter specified by the user can authenticate the user and then once the user + * is logged in, Spark can compare that user versus the view acls to make sure they are + * authorized to view the UI. The configs 'spark.ui.acls.enable' and 'spark.ui.view.acls' + * control the behavior of the acls. Note that the person who started the application + * always has view access to the UI. + * + * Spark does not currently support encryption after authentication. + * + * At this point spark has multiple communication protocols that need to be secured and + * different underlying mechanisms are used depending on the protocol: + * + * - Akka -> The only option here is to use the Akka Remote secure-cookie functionality. + * Akka remoting allows you to specify a secure cookie that will be exchanged + * and ensured to be identical in the connection handshake between the client + * and the server. If they are not identical then the client will be refused + * to connect to the server. There is no control of the underlying + * authentication mechanism so its not clear if the password is passed in + * plaintext or uses DIGEST-MD5 or some other mechanism. + * Akka also has an option to turn on SSL, this option is not currently supported + * but we could add a configuration option in the future. + * + * - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty + * for the HttpServer. Jetty supports multiple authentication mechanisms - + * Basic, Digest, Form, Spengo, etc. It also supports multiple different login + * services - Hash, JAAS, Spnego, JDBC, etc. Spark currently uses the HashLoginService + * to authenticate using DIGEST-MD5 via a single user and the shared secret. + * Since we are using DIGEST-MD5, the shared secret is not passed on the wire + * in plaintext. + * We currently do not support SSL (https), but Jetty can be configured to use it + * so we could add a configuration option for this in the future. + * + * The Spark HttpServer installs the HashLoginServer and configures it to DIGEST-MD5. + * Any clients must specify the user and password. There is a default + * Authenticator installed in the SecurityManager to how it does the authentication + * and in this case gets the user name and password from the request. + * + * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously + * exchange messages. For this we use the Java SASL + * (Simple Authentication and Security Layer) API and again use DIGEST-MD5 + * as the authentication mechanism. This means the shared secret is not passed + * over the wire in plaintext. + * Note that SASL is pluggable as to what mechanism it uses. We currently use + * DIGEST-MD5 but this could be changed to use Kerberos or other in the future. + * Spark currently supports "auth" for the quality of protection, which means + * the connection is not supporting integrity or privacy protection (encryption) + * after authentication. SASL also supports "auth-int" and "auth-conf" which + * SPARK could be support in the future to allow the user to specify the quality + * of protection they want. If we support those, the messages will also have to + * be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's. + * + * Since the connectionManager does asynchronous messages passing, the SASL + * authentication is a bit more complex. A ConnectionManager can be both a client + * and a Server, so for a particular connection is has to determine what to do. + * A ConnectionId was added to be able to track connections and is used to + * match up incoming messages with connections waiting for authentication. + * If its acting as a client and trying to send a message to another ConnectionManager, + * it blocks the thread calling sendMessage until the SASL negotiation has occurred. + * The ConnectionManager tracks all the sendingConnections using the ConnectionId + * and waits for the response from the server and does the handshake. + * + * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters + * can be used. Yarn requires a specific AmIpFilter be installed for security to work + * properly. For non-Yarn deployments, users can write a filter to go through a + * companies normal login service. If an authentication filter is in place then the + * SparkUI can be configured to check the logged in user against the list of users who + * have view acls to see if that user is authorized. + * The filters can also be used for many different purposes. For instance filters + * could be used for logging, encryption, or compression. + * + * The exact mechanisms used to generate/distributed the shared secret is deployment specific. + * + * For Yarn deployments, the secret is automatically generated using the Akka remote + * Crypt.generateSecureCookie() API. The secret is placed in the Hadoop UGI which gets passed + * around via the Hadoop RPC mechanism. Hadoop RPC can be configured to support different levels + * of protection. See the Hadoop documentation for more details. Each Spark application on Yarn + * gets a different shared secret. On Yarn, the Spark UI gets configured to use the Hadoop Yarn + * AmIpFilter which requires the user to go through the ResourceManager Proxy. That Proxy is there + * to reduce the possibility of web based attacks through YARN. Hadoop can be configured to use + * filters to do authentication. That authentication then happens via the ResourceManager Proxy + * and Spark will use that to do authorization against the view acls. + * + * For other Spark deployments, the shared secret must be specified via the + * spark.authenticate.secret config. + * All the nodes (Master and Workers) and the applications need to have the same shared secret. + * This again is not ideal as one user could potentially affect another users application. + * This should be enhanced in the future to provide better protection. + * If the UI needs to be secured the user needs to install a javax servlet filter to do the + * authentication. Spark will then use that user to compare against the view acls to do + * authorization. If not filter is in place the user is generally null and no authorization + * can take place. + */ + +private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { + + // key used to store the spark secret in the Hadoop UGI + private val sparkSecretLookupKey = "sparkCookie" + + private val authOn = sparkConf.getBoolean("spark.authenticate", false) + private val uiAclsOn = sparkConf.getBoolean("spark.ui.acls.enable", false) + + // always add the current user and SPARK_USER to the viewAcls + private val aclUsers = ArrayBuffer[String](System.getProperty("user.name", ""), + Option(System.getenv("SPARK_USER")).getOrElse("")) + aclUsers ++= sparkConf.get("spark.ui.view.acls", "").split(',') + private val viewAcls = aclUsers.map(_.trim()).filter(!_.isEmpty).toSet + + private val secretKey = generateSecretKey() + logInfo("SecurityManager, is authentication enabled: " + authOn + + " are ui acls enabled: " + uiAclsOn + " users with view permissions: " + viewAcls.toString()) + + // Set our own authenticator to properly negotiate user/password for HTTP connections. + // This is needed by the HTTP client fetching from the HttpServer. Put here so its + // only set once. + if (authOn) { + Authenticator.setDefault( + new Authenticator() { + override def getPasswordAuthentication(): PasswordAuthentication = { + var passAuth: PasswordAuthentication = null + val userInfo = getRequestingURL().getUserInfo() + if (userInfo != null) { + val parts = userInfo.split(":", 2) + passAuth = new PasswordAuthentication(parts(0), parts(1).toCharArray()) + } + return passAuth + } + } + ) + } + + /** + * Generates or looks up the secret key. + * + * The way the key is stored depends on the Spark deployment mode. Yarn + * uses the Hadoop UGI. + * + * For non-Yarn deployments, If the config variable is not set + * we throw an exception. + */ + private def generateSecretKey(): String = { + if (!isAuthenticationEnabled) return null + // first check to see if the secret is already set, else generate a new one if on yarn + val sCookie = if (SparkHadoopUtil.get.isYarnMode) { + val secretKey = SparkHadoopUtil.get.getSecretKeyFromUserCredentials(sparkSecretLookupKey) + if (secretKey != null) { + logDebug("in yarn mode, getting secret from credentials") + return new Text(secretKey).toString + } else { + logDebug("getSecretKey: yarn mode, secret key from credentials is null") + } + val cookie = akka.util.Crypt.generateSecureCookie + // if we generated the secret then we must be the first so lets set it so t + // gets used by everyone else + SparkHadoopUtil.get.addSecretKeyToUserCredentials(sparkSecretLookupKey, cookie) + logInfo("adding secret to credentials in yarn mode") + cookie + } else { + // user must have set spark.authenticate.secret config + sparkConf.getOption("spark.authenticate.secret") match { + case Some(value) => value + case None => throw new Exception("Error: a secret key must be specified via the " + + "spark.authenticate.secret config") + } + } + sCookie + } + + /** + * Check to see if Acls for the UI are enabled + * @return true if UI authentication is enabled, otherwise false + */ + def uiAclsEnabled(): Boolean = uiAclsOn + + /** + * Checks the given user against the view acl list to see if they have + * authorization to view the UI. If the UI acls must are disabled + * via spark.ui.acls.enable, all users have view access. + * + * @param user to see if is authorized + * @return true is the user has permission, otherwise false + */ + def checkUIViewPermissions(user: String): Boolean = { + if (uiAclsEnabled() && (user != null) && (!viewAcls.contains(user))) false else true + } + + /** + * Check to see if authentication for the Spark communication protocols is enabled + * @return true if authentication is enabled, otherwise false + */ + def isAuthenticationEnabled(): Boolean = authOn + + /** + * Gets the user used for authenticating HTTP connections. + * For now use a single hardcoded user. + * @return the HTTP user as a String + */ + def getHttpUser(): String = "sparkHttpUser" + + /** + * Gets the user used for authenticating SASL connections. + * For now use a single hardcoded user. + * @return the SASL user as a String + */ + def getSaslUser(): String = "sparkSaslUser" + + /** + * Gets the secret key. + * @return the secret key as a String if authentication is enabled, otherwise returns null + */ + def getSecretKey(): String = secretKey +} diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index da778aa851..24731ad706 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -130,6 +130,8 @@ class SparkContext( val isLocal = (master == "local" || master.startsWith("local[")) + if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") + // Create the Spark execution environment (cache, map output tracker, etc) private[spark] val env = SparkEnv.create( conf, @@ -634,7 +636,7 @@ class SparkContext( addedFiles(key) = System.currentTimeMillis // Fetch the file locally in case a job is executed using DAGScheduler.runLocally(). - Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf) + Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf, env.securityManager) logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 7ac65828f6..5e43b51984 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -53,7 +53,8 @@ class SparkEnv private[spark] ( val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, - val conf: SparkConf) extends Logging { + val conf: SparkConf, + val securityManager: SecurityManager) extends Logging { // A mapping of thread ID to amount of memory used for shuffle in bytes // All accesses should be manually synchronized @@ -122,8 +123,9 @@ object SparkEnv extends Logging { isDriver: Boolean, isLocal: Boolean): SparkEnv = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port, - conf = conf) + val securityManager = new SecurityManager(conf) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port, conf = conf, + securityManager = securityManager) // Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port), // figure out which port number Akka actually bound to and set spark.driver.port to it. @@ -139,7 +141,6 @@ object SparkEnv extends Logging { val name = conf.get(propertyName, defaultClassName) Class.forName(name, true, classLoader).newInstance().asInstanceOf[T] } - val serializerManager = new SerializerManager val serializer = serializerManager.setDefault( @@ -167,12 +168,12 @@ object SparkEnv extends Logging { val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", new BlockManagerMasterActor(isLocal, conf)), conf) - val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, - serializer, conf) + val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, + serializer, conf, securityManager) val connectionManager = blockManager.connectionManager - val broadcastManager = new BroadcastManager(isDriver, conf) + val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) val cacheManager = new CacheManager(blockManager) @@ -190,14 +191,14 @@ object SparkEnv extends Logging { val shuffleFetcher = instantiateClass[ShuffleFetcher]( "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher") - val httpFileServer = new HttpFileServer() + val httpFileServer = new HttpFileServer(securityManager) httpFileServer.initialize() conf.set("spark.fileserver.uri", httpFileServer.serverUri) val metricsSystem = if (isDriver) { - MetricsSystem.createMetricsSystem("driver", conf) + MetricsSystem.createMetricsSystem("driver", conf, securityManager) } else { - MetricsSystem.createMetricsSystem("executor", conf) + MetricsSystem.createMetricsSystem("executor", conf, securityManager) } metricsSystem.start() @@ -231,6 +232,7 @@ object SparkEnv extends Logging { httpFileServer, sparkFilesDir, metricsSystem, - conf) + conf, + securityManager) } } diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala new file mode 100644 index 0000000000..a2a871cbd3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.io.IOException +import javax.security.auth.callback.Callback +import javax.security.auth.callback.CallbackHandler +import javax.security.auth.callback.NameCallback +import javax.security.auth.callback.PasswordCallback +import javax.security.auth.callback.UnsupportedCallbackException +import javax.security.sasl.RealmCallback +import javax.security.sasl.RealmChoiceCallback +import javax.security.sasl.Sasl +import javax.security.sasl.SaslClient +import javax.security.sasl.SaslException + +import scala.collection.JavaConversions.mapAsJavaMap + +/** + * Implements SASL Client logic for Spark + */ +private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging { + + /** + * Used to respond to server's counterpart, SaslServer with SASL tokens + * represented as byte arrays. + * + * The authentication mechanism used here is DIGEST-MD5. This could be changed to be + * configurable in the future. + */ + private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST), + null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, + new SparkSaslClientCallbackHandler(securityMgr)) + + /** + * Used to initiate SASL handshake with server. + * @return response to challenge if needed + */ + def firstToken(): Array[Byte] = { + synchronized { + val saslToken: Array[Byte] = + if (saslClient != null && saslClient.hasInitialResponse()) { + logDebug("has initial response") + saslClient.evaluateChallenge(new Array[Byte](0)) + } else { + new Array[Byte](0) + } + saslToken + } + } + + /** + * Determines whether the authentication exchange has completed. + * @return true is complete, otherwise false + */ + def isComplete(): Boolean = { + synchronized { + if (saslClient != null) saslClient.isComplete() else false + } + } + + /** + * Respond to server's SASL token. + * @param saslTokenMessage contains server's SASL token + * @return client's response SASL token + */ + def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = { + synchronized { + if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0) + } + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslClient might be using. + */ + def dispose() { + synchronized { + if (saslClient != null) { + try { + saslClient.dispose() + } catch { + case e: SaslException => // ignored + } finally { + saslClient = null + } + } + } + } + + /** + * Implementation of javax.security.auth.callback.CallbackHandler + * that works with share secrets. + */ + private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends + CallbackHandler { + + private val userName: String = + SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes()) + private val secretKey = securityMgr.getSecretKey() + private val userPassword: Array[Char] = + SparkSaslServer.encodePassword(if (secretKey != null) secretKey.getBytes() else "".getBytes()) + + /** + * Implementation used to respond to SASL request from the server. + * + * @param callbacks objects that indicate what credential information the + * server's SaslServer requires from the client. + */ + override def handle(callbacks: Array[Callback]) { + logDebug("in the sasl client callback handler") + callbacks foreach { + case nc: NameCallback => { + logDebug("handle: SASL client callback: setting username: " + userName) + nc.setName(userName) + } + case pc: PasswordCallback => { + logDebug("handle: SASL client callback: setting userPassword") + pc.setPassword(userPassword) + } + case rc: RealmCallback => { + logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText()) + rc.setText(rc.getDefaultText()) + } + case cb: RealmChoiceCallback => {} + case cb: Callback => throw + new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback") + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala new file mode 100644 index 0000000000..11fcb2ae3a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import javax.security.auth.callback.Callback +import javax.security.auth.callback.CallbackHandler +import javax.security.auth.callback.NameCallback +import javax.security.auth.callback.PasswordCallback +import javax.security.auth.callback.UnsupportedCallbackException +import javax.security.sasl.AuthorizeCallback +import javax.security.sasl.RealmCallback +import javax.security.sasl.Sasl +import javax.security.sasl.SaslException +import javax.security.sasl.SaslServer +import scala.collection.JavaConversions.mapAsJavaMap +import org.apache.commons.net.util.Base64 + +/** + * Encapsulates SASL server logic + */ +private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging { + + /** + * Actual SASL work done by this object from javax.security.sasl. + */ + private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null, + SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, + new SparkSaslDigestCallbackHandler(securityMgr)) + + /** + * Determines whether the authentication exchange has completed. + * @return true is complete, otherwise false + */ + def isComplete(): Boolean = { + synchronized { + if (saslServer != null) saslServer.isComplete() else false + } + } + + /** + * Used to respond to server SASL tokens. + * @param token Server's SASL token + * @return response to send back to the server. + */ + def response(token: Array[Byte]): Array[Byte] = { + synchronized { + if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0) + } + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslServer might be using. + */ + def dispose() { + synchronized { + if (saslServer != null) { + try { + saslServer.dispose() + } catch { + case e: SaslException => // ignore + } finally { + saslServer = null + } + } + } + } + + /** + * Implementation of javax.security.auth.callback.CallbackHandler + * for SASL DIGEST-MD5 mechanism + */ + private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager) + extends CallbackHandler { + + private val userName: String = + SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes()) + + override def handle(callbacks: Array[Callback]) { + logDebug("In the sasl server callback handler") + callbacks foreach { + case nc: NameCallback => { + logDebug("handle: SASL server callback: setting username") + nc.setName(userName) + } + case pc: PasswordCallback => { + logDebug("handle: SASL server callback: setting userPassword") + val password: Array[Char] = + SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes()) + pc.setPassword(password) + } + case rc: RealmCallback => { + logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText()) + rc.setText(rc.getDefaultText()) + } + case ac: AuthorizeCallback => { + val authid = ac.getAuthenticationID() + val authzid = ac.getAuthorizationID() + if (authid.equals(authzid)) { + logDebug("set auth to true") + ac.setAuthorized(true) + } else { + logDebug("set auth to false") + ac.setAuthorized(false) + } + if (ac.isAuthorized()) { + logDebug("sasl server is authorized") + ac.setAuthorizedID(authzid) + } + } + case cb: Callback => throw + new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback") + } + } + } +} + +private[spark] object SparkSaslServer { + + /** + * This is passed as the server name when creating the sasl client/server. + * This could be changed to be configurable in the future. + */ + val SASL_DEFAULT_REALM = "default" + + /** + * The authentication mechanism used here is DIGEST-MD5. This could be changed to be + * configurable in the future. + */ + val DIGEST = "DIGEST-MD5" + + /** + * The quality of protection is just "auth". This means that we are doing + * authentication only, we are not supporting integrity or privacy protection of the + * communication channel after authentication. This could be changed to be configurable + * in the future. + */ + val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH ->"true") + + /** + * Encode a byte[] identifier as a Base64-encoded string. + * + * @param identifier identifier to encode + * @return Base64-encoded string + */ + def encodeIdentifier(identifier: Array[Byte]): String = { + new String(Base64.encodeBase64(identifier)) + } + + /** + * Encode a password as a base64-encoded char[] array. + * @param password as a byte array. + * @return password as a char array. + */ + def encodePassword(password: Array[Byte]): Array[Char] = { + new String(Base64.encodeBase64(password)).toCharArray() + } +} + diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index d113d40405..e3c3a12d16 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -60,7 +60,8 @@ abstract class Broadcast[T](val id: Long) extends Serializable { } private[spark] -class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging with Serializable { +class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager) + extends Logging with Serializable { private var initialized = false private var broadcastFactory: BroadcastFactory = null @@ -78,7 +79,7 @@ class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] // Initialize appropriate BroadcastFactory and BroadcastObject - broadcastFactory.initialize(isDriver, conf) + broadcastFactory.initialize(isDriver, conf, securityManager) initialized = true } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 940e5ab805..6beecaeced 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.broadcast +import org.apache.spark.SecurityManager import org.apache.spark.SparkConf @@ -26,7 +27,7 @@ import org.apache.spark.SparkConf * entire Spark job. */ trait BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf): Unit + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 20207c2613..e8eb04bb10 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -18,13 +18,13 @@ package org.apache.spark.broadcast import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream} -import java.net.URL +import java.net.{URL, URLConnection, URI} import java.util.concurrent.TimeUnit import it.unimi.dsi.fastutil.io.FastBufferedInputStream import it.unimi.dsi.fastutil.io.FastBufferedOutputStream -import org.apache.spark.{HttpServer, Logging, SparkConf, SparkEnv} +import org.apache.spark.{SparkConf, HttpServer, Logging, SecurityManager, SparkEnv} import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils} @@ -67,7 +67,9 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea * A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium. */ class HttpBroadcastFactory extends BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf) { HttpBroadcast.initialize(isDriver, conf) } + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { + HttpBroadcast.initialize(isDriver, conf, securityMgr) + } def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = new HttpBroadcast[T](value_, isLocal, id) @@ -83,6 +85,7 @@ private object HttpBroadcast extends Logging { private var bufferSize: Int = 65536 private var serverUri: String = null private var server: HttpServer = null + private var securityManager: SecurityManager = null // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist private val files = new TimeStampedHashSet[String] @@ -92,11 +95,12 @@ private object HttpBroadcast extends Logging { private var compressionCodec: CompressionCodec = null - def initialize(isDriver: Boolean, conf: SparkConf) { + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { synchronized { if (!initialized) { bufferSize = conf.getInt("spark.buffer.size", 65536) compress = conf.getBoolean("spark.broadcast.compress", true) + securityManager = securityMgr if (isDriver) { createServer(conf) conf.set("spark.httpBroadcast.uri", serverUri) @@ -126,7 +130,7 @@ private object HttpBroadcast extends Logging { private def createServer(conf: SparkConf) { broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf)) - server = new HttpServer(broadcastDir) + server = new HttpServer(broadcastDir, securityManager) server.start() serverUri = server.uri logInfo("Broadcast server started at " + serverUri) @@ -149,11 +153,23 @@ private object HttpBroadcast extends Logging { } def read[T](id: Long): T = { + logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) val url = serverUri + "/" + BroadcastBlockId(id).name + + var uc: URLConnection = null + if (securityManager.isAuthenticationEnabled()) { + logDebug("broadcast security enabled") + val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager) + uc = newuri.toURL().openConnection() + uc.setAllowUserInteraction(false) + } else { + logDebug("broadcast not using security") + uc = new URL(url).openConnection() + } + val in = { - val httpConnection = new URL(url).openConnection() - httpConnection.setReadTimeout(httpReadTimeout) - val inputStream = httpConnection.getInputStream + uc.setReadTimeout(httpReadTimeout) + val inputStream = uc.getInputStream(); if (compress) { compressionCodec.compressedInputStream(inputStream) } else { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 22d783c859..3cd7121376 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -241,7 +241,9 @@ private[spark] case class TorrentInfo( */ class TorrentBroadcastFactory extends BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.initialize(isDriver, conf) } + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { + TorrentBroadcast.initialize(isDriver, conf) + } def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = new TorrentBroadcast[T](value_, isLocal, id) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index eb5676b51d..d9e3035e1a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -26,7 +26,7 @@ import akka.pattern.ask import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} import org.apache.log4j.{Level, Logger} -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.util.{AkkaUtils, Utils} @@ -141,7 +141,7 @@ object Client { // TODO: See if we can initialize akka so return messages are sent back using the same TCP // flow. Else, this (sadly) requires the DriverClient be routable from the Master. val (actorSystem, _) = AkkaUtils.createActorSystem( - "driverClient", Utils.localHostName(), 0, false, conf) + "driverClient", Utils.localHostName(), 0, false, conf, new SecurityManager(conf)) actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf)) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index ec15647e1d..d2d8d6d662 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -21,6 +21,7 @@ import java.security.PrivilegedExceptionAction import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.{SparkContext, SparkException} @@ -65,6 +66,15 @@ class SparkHadoopUtil { def addCredentials(conf: JobConf) {} def isYarnMode(): Boolean = { false } + + def getCurrentUserCredentials(): Credentials = { null } + + def addCurrentUserCredentials(creds: Credentials) {} + + def addSecretKeyToUserCredentials(key: String, secret: String) {} + + def getSecretKeyFromUserCredentials(key: String): Array[Byte] = { null } + } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index 1550c3eb42..63f166d401 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.client -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.util.{AkkaUtils, Utils} @@ -45,8 +45,9 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) + val conf = new SparkConf val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0, - conf = new SparkConf) + conf = conf, securityManager = new SecurityManager(conf)) val desc = new ApplicationDescription( "TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), Some("dummy-spark-home"), "ignored") diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 51794ce40c..2d6d0c33fa 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -30,7 +30,7 @@ import akka.pattern.ask import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import akka.serialization.SerializationExtension -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.DriverState.DriverState @@ -39,7 +39,8 @@ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{AkkaUtils, Utils} -private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging { +private[spark] class Master(host: String, port: Int, webUiPort: Int, + val securityMgr: SecurityManager) extends Actor with Logging { import context.dispatcher // to use Akka's scheduler.schedule() val conf = new SparkConf @@ -70,8 +71,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act Utils.checkHost(host, "Expected hostname") - val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf) - val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf) + val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr) + val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf, + securityMgr) val masterSource = new MasterSource(this) val webUi = new MasterWebUI(this, webUiPort) @@ -711,8 +713,11 @@ private[spark] object Master { def startSystemAndActor(host: String, port: Int, webUiPort: Int, conf: SparkConf) : (ActorSystem, Int, Int) = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf) - val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), actorName) + val securityMgr = new SecurityManager(conf) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf, + securityManager = securityMgr) + val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort, + securityMgr), actorName) val timeout = AkkaUtils.askTimeout(conf) val respFuture = actor.ask(RequestWebUIPort)(timeout) val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse] diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 5ab13e7aa6..a7bd01e284 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -18,8 +18,8 @@ package org.apache.spark.deploy.master.ui import javax.servlet.http.HttpServletRequest - -import org.eclipse.jetty.server.{Handler, Server} +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.Logging import org.apache.spark.deploy.master.Master @@ -46,7 +46,7 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging { def start() { try { - val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers) + val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers, master.conf) server = Some(srv) boundPort = Some(bPort) logInfo("Started Master web UI at http://%s:%d".format(host, boundPort.get)) @@ -60,12 +60,17 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging { val metricsHandlers = master.masterMetricsSystem.getServletHandlers ++ master.applicationMetricsSystem.getServletHandlers - val handlers = metricsHandlers ++ Array[(String, Handler)]( - ("/static", createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR)), - ("/app/json", (request: HttpServletRequest) => applicationPage.renderJson(request)), - ("/app", (request: HttpServletRequest) => applicationPage.render(request)), - ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)), - ("*", (request: HttpServletRequest) => indexPage.render(request)) + val handlers = metricsHandlers ++ Seq[ServletContextHandler]( + createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static/*"), + createServletHandler("/app/json", + createServlet((request: HttpServletRequest) => applicationPage.renderJson(request), + master.securityMgr)), + createServletHandler("/app", createServlet((request: HttpServletRequest) => applicationPage + .render(request), master.securityMgr)), + createServletHandler("/json", createServlet((request: HttpServletRequest) => indexPage + .renderJson(request), master.securityMgr)), + createServletHandler("*", createServlet((request: HttpServletRequest) => indexPage.render + (request), master.securityMgr)) ) def stop() { @@ -74,5 +79,5 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging { } private[spark] object MasterWebUI { - val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static" + val STATIC_RESOURCE_DIR = "org/apache/spark/ui" } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index a26e47950a..be15138f62 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.worker import akka.actor._ -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.util.{AkkaUtils, Utils} /** @@ -29,8 +29,9 @@ object DriverWrapper { def main(args: Array[String]) { args.toList match { case workerUrl :: mainClass :: extraArgs => + val conf = new SparkConf() val (actorSystem, _) = AkkaUtils.createActorSystem("Driver", - Utils.localHostName(), 0, false, new SparkConf()) + Utils.localHostName(), 0, false, conf, new SecurityManager(conf)) actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher") // Delegate to supplied main class diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 7b0b7861b7..afaabedffe 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -27,7 +27,7 @@ import scala.concurrent.duration._ import akka.actor._ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} @@ -48,7 +48,8 @@ private[spark] class Worker( actorSystemName: String, actorName: String, workDirPath: String = null, - val conf: SparkConf) + val conf: SparkConf, + val securityMgr: SecurityManager) extends Actor with Logging { import context.dispatcher @@ -91,7 +92,7 @@ private[spark] class Worker( var coresUsed = 0 var memoryUsed = 0 - val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf) + val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) val workerSource = new WorkerSource(this) def coresFree: Int = cores - coresUsed @@ -347,10 +348,11 @@ private[spark] object Worker { val conf = new SparkConf val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") val actorName = "Worker" + val securityMgr = new SecurityManager(conf) val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, - conf = conf) + conf = conf, securityManager = securityMgr) actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, - masterUrls, systemName, actorName, workDir, conf), name = actorName) + masterUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName) (actorSystem, boundPort) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index bdf126f93a..ffc05bd306 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -19,8 +19,8 @@ package org.apache.spark.deploy.worker.ui import java.io.File import javax.servlet.http.HttpServletRequest - -import org.eclipse.jetty.server.{Handler, Server} +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.Logging import org.apache.spark.deploy.worker.Worker @@ -33,7 +33,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} */ private[spark] class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[Int] = None) - extends Logging { + extends Logging { val timeout = AkkaUtils.askTimeout(worker.conf) val host = Utils.localHostName() val port = requestedPort.getOrElse( @@ -46,17 +46,21 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I val metricsHandlers = worker.metricsSystem.getServletHandlers - val handlers = metricsHandlers ++ Array[(String, Handler)]( - ("/static", createStaticHandler(WorkerWebUI.STATIC_RESOURCE_DIR)), - ("/log", (request: HttpServletRequest) => log(request)), - ("/logPage", (request: HttpServletRequest) => logPage(request)), - ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)), - ("*", (request: HttpServletRequest) => indexPage.render(request)) + val handlers = metricsHandlers ++ Seq[ServletContextHandler]( + createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static/*"), + createServletHandler("/log", createServlet((request: HttpServletRequest) => log(request), + worker.securityMgr)), + createServletHandler("/logPage", createServlet((request: HttpServletRequest) => logPage + (request), worker.securityMgr)), + createServletHandler("/json", createServlet((request: HttpServletRequest) => indexPage + .renderJson(request), worker.securityMgr)), + createServletHandler("*", createServlet((request: HttpServletRequest) => indexPage.render + (request), worker.securityMgr)) ) def start() { try { - val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers) + val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers, worker.conf) server = Some(srv) boundPort = Some(bPort) logInfo("Started Worker web UI at http://%s:%d".format(host, bPort)) @@ -198,6 +202,6 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I } private[spark] object WorkerWebUI { - val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static" + val STATIC_RESOURCE_BASE = "org/apache/spark/ui" val DEFAULT_PORT="8081" } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 0aae569b17..3486092a14 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import akka.actor._ import akka.remote._ -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ @@ -97,10 +97,11 @@ private[spark] object CoarseGrainedExecutorBackend { // Debug code Utils.checkHost(hostname) + val conf = new SparkConf // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // before getting started with all our system properties, etc val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0, - indestructible = true, conf = new SparkConf) + indestructible = true, conf = conf, new SecurityManager(conf)) // set it val sparkHostPort = hostname + ":" + boundPort actorSystem.actorOf( diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 989d666f15..e69f6f72d3 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -69,11 +69,6 @@ private[spark] class Executor( conf.set("spark.local.dir", getYarnLocalDirs()) } - // Create our ClassLoader and set it on this thread - private val urlClassLoader = createClassLoader() - private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) - Thread.currentThread.setContextClassLoader(replClassLoader) - if (!isLocal) { // Setup an uncaught exception handler for non-local mode. // Make any thread terminations due to uncaught exceptions kill the entire @@ -117,6 +112,12 @@ private[spark] class Executor( } } + // Create our ClassLoader and set it on this thread + // do this after SparkEnv creation so can access the SecurityManager + private val urlClassLoader = createClassLoader() + private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) + Thread.currentThread.setContextClassLoader(replClassLoader) + // Akka's message frame size. If task result is bigger than this, we use the block manager // to send the result back. private val akkaFrameSize = { @@ -338,12 +339,12 @@ private[spark] class Executor( // Fetch missing dependencies for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager) currentFiles(name) = timestamp } for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager) currentJars(name) = timestamp // Add it to our class loader val localName = name.split("/").last diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 966c092124..c5bda2078f 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry} -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.metrics.sink.{MetricsServlet, Sink} import org.apache.spark.metrics.source.Source @@ -64,7 +64,7 @@ import org.apache.spark.metrics.source.Source * [options] is the specific property of this source or sink. */ private[spark] class MetricsSystem private (val instance: String, - conf: SparkConf) extends Logging { + conf: SparkConf, securityMgr: SecurityManager) extends Logging { val confFile = conf.get("spark.metrics.conf", null) val metricsConfig = new MetricsConfig(Option(confFile)) @@ -131,8 +131,8 @@ private[spark] class MetricsSystem private (val instance: String, val classPath = kv._2.getProperty("class") try { val sink = Class.forName(classPath) - .getConstructor(classOf[Properties], classOf[MetricRegistry]) - .newInstance(kv._2, registry) + .getConstructor(classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager]) + .newInstance(kv._2, registry, securityMgr) if (kv._1 == "servlet") { metricsServlet = Some(sink.asInstanceOf[MetricsServlet]) } else { @@ -160,6 +160,7 @@ private[spark] object MetricsSystem { } } - def createMetricsSystem(instance: String, conf: SparkConf): MetricsSystem = - new MetricsSystem(instance, conf) + def createMetricsSystem(instance: String, conf: SparkConf, + securityMgr: SecurityManager): MetricsSystem = + new MetricsSystem(instance, conf, securityMgr) } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala index 98fa1dbd7c..4d2ffc54d8 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala @@ -22,9 +22,11 @@ import java.util.concurrent.TimeUnit import com.codahale.metrics.{ConsoleReporter, MetricRegistry} +import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem -class ConsoleSink(val property: Properties, val registry: MetricRegistry) extends Sink { +class ConsoleSink(val property: Properties, val registry: MetricRegistry, + securityMgr: SecurityManager) extends Sink { val CONSOLE_DEFAULT_PERIOD = 10 val CONSOLE_DEFAULT_UNIT = "SECONDS" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala index 40f64768e6..319f40815d 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala @@ -23,9 +23,11 @@ import java.util.concurrent.TimeUnit import com.codahale.metrics.{CsvReporter, MetricRegistry} +import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem -class CsvSink(val property: Properties, val registry: MetricRegistry) extends Sink { +class CsvSink(val property: Properties, val registry: MetricRegistry, + securityMgr: SecurityManager) extends Sink { val CSV_KEY_PERIOD = "period" val CSV_KEY_UNIT = "unit" val CSV_KEY_DIR = "directory" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala index 410ca0704b..cd37317da7 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala @@ -24,9 +24,11 @@ import com.codahale.metrics.MetricRegistry import com.codahale.metrics.ganglia.GangliaReporter import info.ganglia.gmetric4j.gmetric.GMetric +import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem -class GangliaSink(val property: Properties, val registry: MetricRegistry) extends Sink { +class GangliaSink(val property: Properties, val registry: MetricRegistry, + securityMgr: SecurityManager) extends Sink { val GANGLIA_KEY_PERIOD = "period" val GANGLIA_DEFAULT_PERIOD = 10 diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index e09be00142..0ffdf3846d 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -24,9 +24,11 @@ import java.util.concurrent.TimeUnit import com.codahale.metrics.MetricRegistry import com.codahale.metrics.graphite.{Graphite, GraphiteReporter} +import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem -class GraphiteSink(val property: Properties, val registry: MetricRegistry) extends Sink { +class GraphiteSink(val property: Properties, val registry: MetricRegistry, + securityMgr: SecurityManager) extends Sink { val GRAPHITE_DEFAULT_PERIOD = 10 val GRAPHITE_DEFAULT_UNIT = "SECONDS" val GRAPHITE_DEFAULT_PREFIX = "" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala index b5cf210af2..3b5edd5c37 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala @@ -20,8 +20,11 @@ package org.apache.spark.metrics.sink import java.util.Properties import com.codahale.metrics.{JmxReporter, MetricRegistry} +import org.apache.spark.SecurityManager + +class JmxSink(val property: Properties, val registry: MetricRegistry, + securityMgr: SecurityManager) extends Sink { -class JmxSink(val property: Properties, val registry: MetricRegistry) extends Sink { val reporter: JmxReporter = JmxReporter.forRegistry(registry).build() override def start() { diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 3cdfe26d40..3110eccdee 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -19,16 +19,19 @@ package org.apache.spark.metrics.sink import java.util.Properties import java.util.concurrent.TimeUnit + import javax.servlet.http.HttpServletRequest import com.codahale.metrics.MetricRegistry import com.codahale.metrics.json.MetricsModule import com.fasterxml.jackson.databind.ObjectMapper -import org.eclipse.jetty.server.Handler +import org.eclipse.jetty.servlet.ServletContextHandler +import org.apache.spark.SecurityManager import org.apache.spark.ui.JettyUtils -class MetricsServlet(val property: Properties, val registry: MetricRegistry) extends Sink { +class MetricsServlet(val property: Properties, val registry: MetricRegistry, + securityMgr: SecurityManager) extends Sink { val SERVLET_KEY_PATH = "path" val SERVLET_KEY_SAMPLE = "sample" @@ -42,8 +45,11 @@ class MetricsServlet(val property: Properties, val registry: MetricRegistry) ext val mapper = new ObjectMapper().registerModule( new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample)) - def getHandlers = Array[(String, Handler)]( - (servletPath, JettyUtils.createHandler(request => getMetricsSnapshot(request), "text/json")) + def getHandlers = Array[ServletContextHandler]( + JettyUtils.createServletHandler(servletPath, + JettyUtils.createServlet( + new JettyUtils.ServletParams(request => getMetricsSnapshot(request), "text/json"), + securityMgr) ) ) def getMetricsSnapshot(request: HttpServletRequest): String = { diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala index d3c09b1606..04df2f3b0d 100644 --- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala @@ -45,9 +45,10 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: throw new Exception("Max chunk size is " + maxChunkSize) } + val security = if (isSecurityNeg) 1 else 0 if (size == 0 && !gotChunkForSendingOnce) { val newChunk = new MessageChunk( - new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null) + new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null) gotChunkForSendingOnce = true return Some(newChunk) } @@ -65,7 +66,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: } buffer.position(buffer.position + newBuffer.remaining) val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) + typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer) gotChunkForSendingOnce = true return Some(newChunk) } @@ -79,6 +80,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: throw new Exception("Attempting to get chunk from message with multiple data buffers") } val buffer = buffers(0) + val security = if (isSecurityNeg) 1 else 0 if (buffer.remaining > 0) { if (buffer.remaining < chunkSize) { throw new Exception("Not enough space in data buffer for receiving chunk") @@ -86,7 +88,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] buffer.position(buffer.position + newBuffer.remaining) val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) + typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer) return Some(newChunk) } None diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala index 8219a185ea..8fd9c2b87d 100644 --- a/core/src/main/scala/org/apache/spark/network/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/Connection.scala @@ -17,6 +17,11 @@ package org.apache.spark.network +import org.apache.spark._ +import org.apache.spark.SparkSaslServer + +import scala.collection.mutable.{HashMap, Queue, ArrayBuffer} + import java.net._ import java.nio._ import java.nio.channels._ @@ -27,13 +32,16 @@ import org.apache.spark._ private[spark] abstract class Connection(val channel: SocketChannel, val selector: Selector, - val socketRemoteConnectionManagerId: ConnectionManagerId) + val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId) extends Logging { - def this(channel_ : SocketChannel, selector_ : Selector) = { + var sparkSaslServer: SparkSaslServer = null + var sparkSaslClient: SparkSaslClient = null + + def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = { this(channel_, selector_, ConnectionManagerId.fromSocketAddress( - channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress])) + channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]), id_) } channel.configureBlocking(false) @@ -49,6 +57,16 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, val remoteAddress = getRemoteAddress() + /** + * Used to synchronize client requests: client's work-related requests must + * wait until SASL authentication completes. + */ + private val authenticated = new Object() + + def getAuthenticated(): Object = authenticated + + def isSaslComplete(): Boolean + def resetForceReregister(): Boolean // Read channels typically do not register for write and write does not for read @@ -69,6 +87,16 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, // Will be true for ReceivingConnection, false for SendingConnection. def changeInterestForRead(): Boolean + private def disposeSasl() { + if (sparkSaslServer != null) { + sparkSaslServer.dispose(); + } + + if (sparkSaslClient != null) { + sparkSaslClient.dispose() + } + } + // On receiving a write event, should we change the interest for this channel or not ? // Will be false for ReceivingConnection, true for SendingConnection. // Actually, for now, should not get triggered for ReceivingConnection @@ -101,6 +129,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, k.cancel() } channel.close() + disposeSasl() callOnCloseCallback() } @@ -168,8 +197,12 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector, - remoteId_ : ConnectionManagerId) - extends Connection(SocketChannel.open, selector_, remoteId_) { + remoteId_ : ConnectionManagerId, id_ : ConnectionId) + extends Connection(SocketChannel.open, selector_, remoteId_, id_) { + + def isSaslComplete(): Boolean = { + if (sparkSaslClient != null) sparkSaslClient.isComplete() else false + } private class Outbox { val messages = new Queue[Message]() @@ -226,6 +259,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, data as detailed in https://github.com/mesos/spark/pull/791 */ private var needForceReregister = false + val currentBuffers = new ArrayBuffer[ByteBuffer]() /*channel.socket.setSendBufferSize(256 * 1024)*/ @@ -316,6 +350,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, // If we have 'seen' pending messages, then reset flag - since we handle that as // normal registering of event (below) if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister() + currentBuffers ++= buffers } case None => { @@ -384,8 +419,15 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, // Must be created within selector loop - else deadlock -private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) - extends Connection(channel_, selector_) { +private[spark] class ReceivingConnection( + channel_ : SocketChannel, + selector_ : Selector, + id_ : ConnectionId) + extends Connection(channel_, selector_, id_) { + + def isSaslComplete(): Boolean = { + if (sparkSaslServer != null) sparkSaslServer.isComplete() else false + } class Inbox() { val messages = new HashMap[Int, BufferMessage]() @@ -396,6 +438,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S val newMessage = Message.create(header).asInstanceOf[BufferMessage] newMessage.started = true newMessage.startTime = System.currentTimeMillis + newMessage.isSecurityNeg = header.securityNeg == 1 logDebug( "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") messages += ((newMessage.id, newMessage)) @@ -441,7 +484,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S val inbox = new Inbox() val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE) - var onReceiveCallback: (Connection , Message) => Unit = null + var onReceiveCallback: (Connection, Message) => Unit = null var currentChunk: MessageChunk = null channel.register(selector, SelectionKey.OP_READ) @@ -516,7 +559,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S } } } catch { - case e: Exception => { + case e: Exception => { logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) callOnExceptionCallback(e) close() diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala new file mode 100644 index 0000000000..ffaab677d4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network + +private[spark] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) { + override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId +} + +private[spark] object ConnectionId { + + def createConnectionIdFromString(connectionIdString: String): ConnectionId = { + val res = connectionIdString.split("_").map(_.trim()) + if (res.size != 3) { + throw new Exception("Error converting ConnectionId string: " + connectionIdString + + " to a ConnectionId Object") + } + new ConnectionId(new ConnectionManagerId(res(0), res(1).toInt), res(2).toInt) + } +} diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index a7f20f8c51..a75130cba2 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -21,6 +21,9 @@ import java.net._ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ +import java.net._ +import java.util.concurrent.atomic.AtomicInteger + import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor} import scala.collection.mutable.ArrayBuffer @@ -28,13 +31,15 @@ import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet import scala.collection.mutable.SynchronizedMap import scala.collection.mutable.SynchronizedQueue + import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.concurrent.duration._ import org.apache.spark._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{SystemClock, Utils} -private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Logging { +private[spark] class ConnectionManager(port: Int, conf: SparkConf, + securityManager: SecurityManager) extends Logging { class MessageStatus( val message: Message, @@ -50,6 +55,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi private val selector = SelectorProvider.provider.openSelector() + // default to 30 second timeout waiting for authentication + private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30) + private val handleMessageExecutor = new ThreadPoolExecutor( conf.getInt("spark.core.connection.handler.threads.min", 20), conf.getInt("spark.core.connection.handler.threads.max", 60), @@ -71,6 +79,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi new LinkedBlockingDeque[Runnable]()) private val serverChannel = ServerSocketChannel.open() + // used to track the SendingConnections waiting to do SASL negotiation + private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection] + with SynchronizedMap[ConnectionId, SendingConnection] private val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] @@ -84,6 +95,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null + private val authEnabled = securityManager.isAuthenticationEnabled() + serverChannel.configureBlocking(false) serverChannel.socket.setReuseAddress(true) serverChannel.socket.setReceiveBufferSize(256 * 1024) @@ -94,6 +107,10 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id) + // used in combination with the ConnectionManagerId to create unique Connection ids + // to be able to track asynchronous messages + private val idCount: AtomicInteger = new AtomicInteger(1) + private val selectorThread = new Thread("connection-manager-thread") { override def run() = ConnectionManager.this.run() } @@ -125,7 +142,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi } finally { writeRunnableStarted.synchronized { writeRunnableStarted -= key - val needReregister = register || conn.resetForceReregister() + val needReregister = register || conn.resetForceReregister() if (needReregister && conn.changeInterestForWrite()) { conn.registerInterest() } @@ -372,7 +389,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi // accept them all in a tight loop. non blocking accept with no processing, should be fine while (newChannel != null) { try { - val newConnection = new ReceivingConnection(newChannel, selector) + val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) + val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId) newConnection.onReceive(receiveMessage) addListeners(newConnection) addConnection(newConnection) @@ -406,6 +424,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi logInfo("Removing SendingConnection to " + sendingConnectionManagerId) connectionsById -= sendingConnectionManagerId + connectionsAwaitingSasl -= connection.connectionId messageStatuses.synchronized { messageStatuses @@ -481,7 +500,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi val creationTime = System.currentTimeMillis def run() { logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") - handleMessage(connectionManagerId, message) + handleMessage(connectionManagerId, message, connection) logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") } } @@ -489,10 +508,133 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi /*handleMessage(connection, message)*/ } - private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) { + private def handleClientAuthentication( + waitingConn: SendingConnection, + securityMsg: SecurityMessage, + connectionId : ConnectionId) { + if (waitingConn.isSaslComplete()) { + logDebug("Client sasl completed for id: " + waitingConn.connectionId) + connectionsAwaitingSasl -= waitingConn.connectionId + waitingConn.getAuthenticated().synchronized { + waitingConn.getAuthenticated().notifyAll(); + } + return + } else { + var replyToken : Array[Byte] = null + try { + replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken); + if (waitingConn.isSaslComplete()) { + logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId) + connectionsAwaitingSasl -= waitingConn.connectionId + waitingConn.getAuthenticated().synchronized { + waitingConn.getAuthenticated().notifyAll() + } + return + } + var securityMsgResp = SecurityMessage.fromResponse(replyToken, + securityMsg.getConnectionId.toString()) + var message = securityMsgResp.toBufferMessage + if (message == null) throw new Exception("Error creating security message") + sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) + } catch { + case e: Exception => { + logError("Error handling sasl client authentication", e) + waitingConn.close() + throw new Exception("Error evaluating sasl response: " + e) + } + } + } + } + + private def handleServerAuthentication( + connection: Connection, + securityMsg: SecurityMessage, + connectionId: ConnectionId) { + if (!connection.isSaslComplete()) { + logDebug("saslContext not established") + var replyToken : Array[Byte] = null + try { + connection.synchronized { + if (connection.sparkSaslServer == null) { + logDebug("Creating sasl Server") + connection.sparkSaslServer = new SparkSaslServer(securityManager) + } + } + replyToken = connection.sparkSaslServer.response(securityMsg.getToken) + if (connection.isSaslComplete()) { + logDebug("Server sasl completed: " + connection.connectionId) + } else { + logDebug("Server sasl not completed: " + connection.connectionId) + } + if (replyToken != null) { + var securityMsgResp = SecurityMessage.fromResponse(replyToken, + securityMsg.getConnectionId) + var message = securityMsgResp.toBufferMessage + if (message == null) throw new Exception("Error creating security Message") + sendSecurityMessage(connection.getRemoteConnectionManagerId(), message) + } + } catch { + case e: Exception => { + logError("Error in server auth negotiation: " + e) + // It would probably be better to send an error message telling other side auth failed + // but for now just close + connection.close() + } + } + } else { + logDebug("connection already established for this connection id: " + connection.connectionId) + } + } + + + private def handleAuthentication(conn: Connection, bufferMessage: BufferMessage): Boolean = { + if (bufferMessage.isSecurityNeg) { + logDebug("This is security neg message") + + // parse as SecurityMessage + val securityMsg = SecurityMessage.fromBufferMessage(bufferMessage) + val connectionId = ConnectionId.createConnectionIdFromString(securityMsg.getConnectionId) + + connectionsAwaitingSasl.get(connectionId) match { + case Some(waitingConn) => { + // Client - this must be in response to us doing Send + logDebug("Client handleAuth for id: " + waitingConn.connectionId) + handleClientAuthentication(waitingConn, securityMsg, connectionId) + } + case None => { + // Server - someone sent us something and we haven't authenticated yet + logDebug("Server handleAuth for id: " + connectionId) + handleServerAuthentication(conn, securityMsg, connectionId) + } + } + return true + } else { + if (!conn.isSaslComplete()) { + // We could handle this better and tell the client we need to do authentication + // negotiation, but for now just ignore them. + logError("message sent that is not security negotiation message on connection " + + "not authenticated yet, ignoring it!!") + return true + } + } + return false + } + + private def handleMessage( + connectionManagerId: ConnectionManagerId, + message: Message, + connection: Connection) { logDebug("Handling [" + message + "] from [" + connectionManagerId + "]") message match { case bufferMessage: BufferMessage => { + if (authEnabled) { + val res = handleAuthentication(connection, bufferMessage) + if (res == true) { + // message was security negotiation so skip the rest + logDebug("After handleAuth result was true, returning") + return + } + } if (bufferMessage.hasAckId) { val sentMessageStatus = messageStatuses.synchronized { messageStatuses.get(bufferMessage.ackId) match { @@ -541,17 +683,124 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi } } + private def checkSendAuthFirst(connManagerId: ConnectionManagerId, conn: SendingConnection) { + // see if we need to do sasl before writing + // this should only be the first negotiation as the Client!!! + if (!conn.isSaslComplete()) { + conn.synchronized { + if (conn.sparkSaslClient == null) { + conn.sparkSaslClient = new SparkSaslClient(securityManager) + var firstResponse: Array[Byte] = null + try { + firstResponse = conn.sparkSaslClient.firstToken() + var securityMsg = SecurityMessage.fromResponse(firstResponse, + conn.connectionId.toString()) + var message = securityMsg.toBufferMessage + if (message == null) throw new Exception("Error creating security message") + connectionsAwaitingSasl += ((conn.connectionId, conn)) + sendSecurityMessage(connManagerId, message) + logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId) + } catch { + case e: Exception => { + logError("Error getting first response from the SaslClient.", e) + conn.close() + throw new Exception("Error getting first response from the SaslClient") + } + } + } + } + } else { + logDebug("Sasl already established ") + } + } + + // allow us to add messages to the inbox for doing sasl negotiating + private def sendSecurityMessage(connManagerId: ConnectionManagerId, message: Message) { + def startNewConnection(): SendingConnection = { + val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port) + val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) + val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId, + newConnectionId) + logInfo("creating new sending connection for security! " + newConnectionId ) + registerRequests.enqueue(newConnection) + + newConnection + } + // I removed the lookupKey stuff as part of merge ... should I re-add it ? + // We did not find it useful in our test-env ... + // If we do re-add it, we should consistently use it everywhere I guess ? + message.senderAddress = id.toSocketAddress() + logTrace("Sending Security [" + message + "] to [" + connManagerId + "]") + val connection = connectionsById.getOrElseUpdate(connManagerId, startNewConnection()) + + //send security message until going connection has been authenticated + connection.send(message) + + wakeupSelector() + } + private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { def startNewConnection(): SendingConnection = { val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) - val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId) + val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) + val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId, + newConnectionId) + logTrace("creating new sending connection: " + newConnectionId) registerRequests.enqueue(newConnection) newConnection } val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection()) + if (authEnabled) { + checkSendAuthFirst(connectionManagerId, connection) + } message.senderAddress = id.toSocketAddress() + logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " + + "connectionid: " + connection.connectionId) + + if (authEnabled) { + // if we aren't authenticated yet lets block the senders until authentication completes + try { + connection.getAuthenticated().synchronized { + val clock = SystemClock + val startTime = clock.getTime() + + while (!connection.isSaslComplete()) { + logDebug("getAuthenticated wait connectionid: " + connection.connectionId) + // have timeout in case remote side never responds + connection.getAuthenticated().wait(500) + if (((clock.getTime() - startTime) >= (authTimeout * 1000)) + && (!connection.isSaslComplete())) { + // took to long to authenticate the connection, something probably went wrong + throw new Exception("Took to long for authentication to " + connectionManagerId + + ", waited " + authTimeout + "seconds, failing.") + } + } + } + } catch { + case e: Exception => logError("Exception while waiting for authentication.", e) + + // need to tell sender it failed + messageStatuses.synchronized { + val s = messageStatuses.get(message.id) + s match { + case Some(msgStatus) => { + messageStatuses -= message.id + logInfo("Notifying " + msgStatus.connectionManagerId) + msgStatus.synchronized { + msgStatus.attempted = true + msgStatus.acked = false + msgStatus.markDone() + } + } + case None => { + logError("no messageStatus for failed message id: " + message.id) + } + } + } + } + } logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") connection.send(message) @@ -603,7 +852,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi private[spark] object ConnectionManager { def main(args: Array[String]) { - val manager = new ConnectionManager(9999, new SparkConf) + val conf = new SparkConf + val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { println("Received [" + msg + "] from [" + id + "]") None diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala index 20fe676618..7caccfdbb4 100644 --- a/core/src/main/scala/org/apache/spark/network/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/Message.scala @@ -27,6 +27,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) { var started = false var startTime = -1L var finishTime = -1L + var isSecurityNeg = false def size: Int diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala index 9bcbc6141a..ead663ede7 100644 --- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala @@ -27,6 +27,7 @@ private[spark] class MessageChunkHeader( val totalSize: Int, val chunkSize: Int, val other: Int, + val securityNeg: Int, val address: InetSocketAddress) { lazy val buffer = { // No need to change this, at 'use' time, we do a reverse lookup of the hostname. @@ -40,6 +41,7 @@ private[spark] class MessageChunkHeader( putInt(totalSize). putInt(chunkSize). putInt(other). + putInt(securityNeg). putInt(ip.size). put(ip). putInt(port). @@ -48,12 +50,13 @@ private[spark] class MessageChunkHeader( } override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + - " and sizes " + totalSize + " / " + chunkSize + " bytes" + " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg + } private[spark] object MessageChunkHeader { - val HEADER_SIZE = 40 + val HEADER_SIZE = 44 def create(buffer: ByteBuffer): MessageChunkHeader = { if (buffer.remaining != HEADER_SIZE) { @@ -64,11 +67,13 @@ private[spark] object MessageChunkHeader { val totalSize = buffer.getInt() val chunkSize = buffer.getInt() val other = buffer.getInt() + val securityNeg = buffer.getInt() val ipSize = buffer.getInt() val ipBytes = new Array[Byte](ipSize) buffer.get(ipBytes) val ip = InetAddress.getByAddress(ipBytes) val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port)) + new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg, + new InetSocketAddress(ip, port)) } } diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala index 9976255c7e..3c09a713c6 100644 --- a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala +++ b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala @@ -18,12 +18,12 @@ package org.apache.spark.network import java.nio.ByteBuffer - -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} private[spark] object ReceiverTest { def main(args: Array[String]) { - val manager = new ConnectionManager(9999, new SparkConf) + val conf = new SparkConf + val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) println("Started connection manager with id = " + manager.id) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { diff --git a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala new file mode 100644 index 0000000000..0d9f743b36 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.StringBuilder + +import org.apache.spark._ +import org.apache.spark.network._ + +/** + * SecurityMessage is class that contains the connectionId and sasl token + * used in SASL negotiation. SecurityMessage has routines for converting + * it to and from a BufferMessage so that it can be sent by the ConnectionManager + * and easily consumed by users when received. + * The api was modeled after BlockMessage. + * + * The connectionId is the connectionId of the client side. Since + * message passing is asynchronous and its possible for the server side (receiving) + * to get multiple different types of messages on the same connection the connectionId + * is used to know which connnection the security message is intended for. + * + * For instance, lets say we are node_0. We need to send data to node_1. The node_0 side + * is acting as a client and connecting to node_1. SASL negotiation has to occur + * between node_0 and node_1 before node_1 trusts node_0 so node_0 sends a security message. + * node_1 receives the message from node_0 but before it can process it and send a response, + * some thread on node_1 decides it needs to send data to node_0 so it connects to node_0 + * and sends a security message of its own to authenticate as a client. Now node_0 gets + * the message and it needs to decide if this message is in response to it being a client + * (from the first send) or if its just node_1 trying to connect to it to send data. This + * is where the connectionId field is used. node_0 can lookup the connectionId to see if + * it is in response to it being a client or if its in response to someone sending other data. + * + * The format of a SecurityMessage as its sent is: + * - Length of the ConnectionId + * - ConnectionId + * - Length of the token + * - Token + */ +private[spark] class SecurityMessage() extends Logging { + + private var connectionId: String = null + private var token: Array[Byte] = null + + def set(byteArr: Array[Byte], newconnectionId: String) { + if (byteArr == null) { + token = new Array[Byte](0) + } else { + token = byteArr + } + connectionId = newconnectionId + } + + /** + * Read the given buffer and set the members of this class. + */ + def set(buffer: ByteBuffer) { + val idLength = buffer.getInt() + val idBuilder = new StringBuilder(idLength) + for (i <- 1 to idLength) { + idBuilder += buffer.getChar() + } + connectionId = idBuilder.toString() + + val tokenLength = buffer.getInt() + token = new Array[Byte](tokenLength) + if (tokenLength > 0) { + buffer.get(token, 0, tokenLength) + } + } + + def set(bufferMsg: BufferMessage) { + val buffer = bufferMsg.buffers.apply(0) + buffer.clear() + set(buffer) + } + + def getConnectionId: String = { + return connectionId + } + + def getToken: Array[Byte] = { + return token + } + + /** + * Create a BufferMessage that can be sent by the ConnectionManager containing + * the security information from this class. + * @return BufferMessage + */ + def toBufferMessage: BufferMessage = { + val startTime = System.currentTimeMillis + val buffers = new ArrayBuffer[ByteBuffer]() + + // 4 bytes for the length of the connectionId + // connectionId is of type char so multiple the length by 2 to get number of bytes + // 4 bytes for the length of token + // token is a byte buffer so just take the length + var buffer = ByteBuffer.allocate(4 + connectionId.length() * 2 + 4 + token.length) + buffer.putInt(connectionId.length()) + connectionId.foreach((x: Char) => buffer.putChar(x)) + buffer.putInt(token.length) + + if (token.length > 0) { + buffer.put(token) + } + buffer.flip() + buffers += buffer + + var message = Message.createBufferMessage(buffers) + logDebug("message total size is : " + message.size) + message.isSecurityNeg = true + return message + } + + override def toString: String = { + "SecurityMessage [connId= " + connectionId + ", Token = " + token + "]" + } +} + +private[spark] object SecurityMessage { + + /** + * Convert the given BufferMessage to a SecurityMessage by parsing the contents + * of the BufferMessage and populating the SecurityMessage fields. + * @param bufferMessage is a BufferMessage that was received + * @return new SecurityMessage + */ + def fromBufferMessage(bufferMessage: BufferMessage): SecurityMessage = { + val newSecurityMessage = new SecurityMessage() + newSecurityMessage.set(bufferMessage) + newSecurityMessage + } + + /** + * Create a SecurityMessage to send from a given saslResponse. + * @param response is the response to a challenge from the SaslClient or Saslserver + * @param connectionId the client connectionId we are negotiation authentication for + * @return a new SecurityMessage + */ + def fromResponse(response : Array[Byte], connectionId : String) : SecurityMessage = { + val newSecurityMessage = new SecurityMessage() + newSecurityMessage.set(response, connectionId) + newSecurityMessage + } +} diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala index 646f8425d9..aac2c24a46 100644 --- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala +++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala @@ -18,8 +18,7 @@ package org.apache.spark.network import java.nio.ByteBuffer - -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} private[spark] object SenderTest { def main(args: Array[String]) { @@ -32,8 +31,8 @@ private[spark] object SenderTest { val targetHost = args(0) val targetPort = args(1).toInt val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort) - - val manager = new ConnectionManager(0, new SparkConf) + val conf = new SparkConf + val manager = new ConnectionManager(0, conf, new SecurityManager(conf)) println("Started connection manager with id = " + manager.id) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 977c24687c..1bf3f4db32 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props} import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream} import sun.nio.ch.DirectBuffer -import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} +import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException, SecurityManager} import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer @@ -47,7 +47,8 @@ private[spark] class BlockManager( val master: BlockManagerMaster, val defaultSerializer: Serializer, maxMemory: Long, - val conf: SparkConf) + val conf: SparkConf, + securityManager: SecurityManager) extends Logging { val shuffleBlockManager = new ShuffleBlockManager(this) @@ -66,7 +67,7 @@ private[spark] class BlockManager( if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0 } - val connectionManager = new ConnectionManager(0, conf) + val connectionManager = new ConnectionManager(0, conf, securityManager) implicit val futureExecContext = connectionManager.futureExecContext val blockManagerId = BlockManagerId( @@ -122,8 +123,9 @@ private[spark] class BlockManager( * Construct a BlockManager with a memory limit set based on system properties. */ def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster, - serializer: Serializer, conf: SparkConf) = { - this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), conf) + serializer: Serializer, conf: SparkConf, securityManager: SecurityManager) = { + this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), conf, + securityManager) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index 1d81d006c0..36f2a0fd02 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -24,6 +24,7 @@ import util.Random import org.apache.spark.SparkConf import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.{SecurityManager, SparkConf} /** * This class tests the BlockManager and MemoryStore for thread safety and @@ -98,7 +99,8 @@ private[spark] object ThreadingTest { val blockManagerMaster = new BlockManagerMaster( actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf))), conf) val blockManager = new BlockManager( - "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf) + "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf, + new SecurityManager(conf)) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) producers.foreach(_.start) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 1b78c52ff6..7c35cd165a 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -18,7 +18,8 @@ package org.apache.spark.ui import java.net.InetSocketAddress -import javax.servlet.http.{HttpServletResponse, HttpServletRequest} +import java.net.URL +import javax.servlet.http.{HttpServlet, HttpServletResponse, HttpServletRequest} import scala.annotation.tailrec import scala.util.{Failure, Success, Try} @@ -26,11 +27,14 @@ import scala.xml.Node import org.json4s.JValue import org.json4s.jackson.JsonMethods.{pretty, render} -import org.eclipse.jetty.server.{Handler, Request, Server} -import org.eclipse.jetty.server.handler.{AbstractHandler, ContextHandler, HandlerList, ResourceHandler} + +import org.eclipse.jetty.server.{DispatcherType, Server} +import org.eclipse.jetty.server.handler.HandlerList +import org.eclipse.jetty.servlet.{DefaultServlet, FilterHolder, ServletContextHandler, ServletHolder} import org.eclipse.jetty.util.thread.QueuedThreadPool -import org.apache.spark.Logging +import org.apache.spark.{Logging, SecurityManager, SparkConf} + /** Utilities for launching a web server using Jetty's HTTP Server class */ private[spark] object JettyUtils extends Logging { @@ -39,57 +43,104 @@ private[spark] object JettyUtils extends Logging { type Responder[T] = HttpServletRequest => T - // Conversions from various types of Responder's to jetty Handlers - implicit def jsonResponderToHandler(responder: Responder[JValue]): Handler = - createHandler(responder, "text/json", (in: JValue) => pretty(render(in))) + class ServletParams[T <% AnyRef](val responder: Responder[T], + val contentType: String, + val extractFn: T => String = (in: Any) => in.toString) {} + + // Conversions from various types of Responder's to appropriate servlet parameters + implicit def jsonResponderToServlet(responder: Responder[JValue]): ServletParams[JValue] = + new ServletParams(responder, "text/json", (in: JValue) => pretty(render(in))) - implicit def htmlResponderToHandler(responder: Responder[Seq[Node]]): Handler = - createHandler(responder, "text/html", (in: Seq[Node]) => "<!DOCTYPE html>" + in.toString) + implicit def htmlResponderToServlet(responder: Responder[Seq[Node]]): ServletParams[Seq[Node]] = + new ServletParams(responder, "text/html", (in: Seq[Node]) => "<!DOCTYPE html>" + in.toString) - implicit def textResponderToHandler(responder: Responder[String]): Handler = - createHandler(responder, "text/plain") + implicit def textResponderToServlet(responder: Responder[String]): ServletParams[String] = + new ServletParams(responder, "text/plain") - def createHandler[T <% AnyRef](responder: Responder[T], contentType: String, - extractFn: T => String = (in: Any) => in.toString): Handler = { - new AbstractHandler { - def handle(target: String, - baseRequest: Request, - request: HttpServletRequest, + def createServlet[T <% AnyRef](servletParams: ServletParams[T], + securityMgr: SecurityManager): HttpServlet = { + new HttpServlet { + override def doGet(request: HttpServletRequest, response: HttpServletResponse) { - response.setContentType("%s;charset=utf-8".format(contentType)) - response.setStatus(HttpServletResponse.SC_OK) - baseRequest.setHandled(true) - val result = responder(request) - response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - response.getWriter().println(extractFn(result)) + if (securityMgr.checkUIViewPermissions(request.getRemoteUser())) { + response.setContentType("%s;charset=utf-8".format(servletParams.contentType)) + response.setStatus(HttpServletResponse.SC_OK) + val result = servletParams.responder(request) + response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.getWriter().println(servletParams.extractFn(result)) + } else { + response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) + response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.sendError(HttpServletResponse.SC_UNAUTHORIZED, + "User is not authorized to access this page."); + } } } } + def createServletHandler(path: String, servlet: HttpServlet): ServletContextHandler = { + val contextHandler = new ServletContextHandler() + val holder = new ServletHolder(servlet) + contextHandler.setContextPath(path) + contextHandler.addServlet(holder, "/") + contextHandler + } + /** Creates a handler that always redirects the user to a given path */ - def createRedirectHandler(newPath: String): Handler = { - new AbstractHandler { - def handle(target: String, - baseRequest: Request, - request: HttpServletRequest, + def createRedirectHandler(newPath: String, path: String): ServletContextHandler = { + val servlet = new HttpServlet { + override def doGet(request: HttpServletRequest, response: HttpServletResponse) { - response.setStatus(302) - response.setHeader("Location", baseRequest.getRootURL + newPath) - baseRequest.setHandled(true) + // make sure we don't end up with // in the middle + val newUri = new URL(new URL(request.getRequestURL.toString), newPath).toURI + response.sendRedirect(newUri.toString) } } + val contextHandler = new ServletContextHandler() + val holder = new ServletHolder(servlet) + contextHandler.setContextPath(path) + contextHandler.addServlet(holder, "/") + contextHandler } /** Creates a handler for serving files from a static directory */ - def createStaticHandler(resourceBase: String): ResourceHandler = { - val staticHandler = new ResourceHandler + def createStaticHandler(resourceBase: String, path: String): ServletContextHandler = { + val contextHandler = new ServletContextHandler() + val staticHandler = new DefaultServlet + val holder = new ServletHolder(staticHandler) Option(getClass.getClassLoader.getResource(resourceBase)) match { case Some(res) => - staticHandler.setResourceBase(res.toString) + holder.setInitParameter("resourceBase", res.toString) case None => throw new Exception("Could not find resource path for Web UI: " + resourceBase) } - staticHandler + contextHandler.addServlet(holder, path) + contextHandler + } + + private def addFilters(handlers: Seq[ServletContextHandler], conf: SparkConf) { + val filters: Array[String] = conf.get("spark.ui.filters", "").split(',').map(_.trim()) + filters.foreach { + case filter : String => + if (!filter.isEmpty) { + logInfo("Adding filter: " + filter) + val holder : FilterHolder = new FilterHolder() + holder.setClassName(filter) + // get any parameters for each filter + val paramName = "spark." + filter + ".params" + val params = conf.get(paramName, "").split(',').map(_.trim()).toSet + params.foreach { + case param : String => + if (!param.isEmpty) { + val parts = param.split("=") + if (parts.length == 2) holder.setInitParameter(parts(0), parts(1)) + } + } + val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.ERROR, + DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.REQUEST) + handlers.foreach { case(handler) => handler.addFilter(holder, "/*", enumDispatcher) } + } + } } /** @@ -99,17 +150,12 @@ private[spark] object JettyUtils extends Logging { * If the desired port number is contented, continues incrementing ports until a free port is * found. Returns the chosen port and the jetty Server object. */ - def startJettyServer(hostName: String, port: Int, handlers: Seq[(String, Handler)]): (Server, Int) - = { - - val handlersToRegister = handlers.map { case(path, handler) => - val contextHandler = new ContextHandler(path) - contextHandler.setHandler(handler) - contextHandler.asInstanceOf[org.eclipse.jetty.server.Handler] - } + def startJettyServer(hostName: String, port: Int, handlers: Seq[ServletContextHandler], + conf: SparkConf): (Server, Int) = { + addFilters(handlers, conf) val handlerList = new HandlerList - handlerList.setHandlers(handlersToRegister.toArray) + handlerList.setHandlers(handlers.toArray) @tailrec def connect(currentPort: Int): (Server, Int) = { @@ -119,7 +165,9 @@ private[spark] object JettyUtils extends Logging { server.setThreadPool(pool) server.setHandler(handlerList) - Try { server.start() } match { + Try { + server.start() + } match { case s: Success[_] => (server, server.getConnectors.head.getLocalPort) case f: Failure[_] => diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index af6b65860e..ca82c3da2f 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -17,7 +17,10 @@ package org.apache.spark.ui -import org.eclipse.jetty.server.{Handler, Server} +import javax.servlet.http.HttpServletRequest + +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{Logging, SparkContext, SparkEnv} import org.apache.spark.ui.JettyUtils._ @@ -34,9 +37,9 @@ private[spark] class SparkUI(sc: SparkContext) extends Logging { var boundPort: Option[Int] = None var server: Option[Server] = None - val handlers = Seq[(String, Handler)]( - ("/static", createStaticHandler(SparkUI.STATIC_RESOURCE_DIR)), - ("/", createRedirectHandler("/stages")) + val handlers = Seq[ServletContextHandler] ( + createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static/*"), + createRedirectHandler("/stages", "/") ) val storage = new BlockManagerUI(sc) val jobs = new JobProgressUI(sc) @@ -52,7 +55,7 @@ private[spark] class SparkUI(sc: SparkContext) extends Logging { /** Bind the HTTP server which backs this web interface */ def bind() { try { - val (srv, usedPort) = JettyUtils.startJettyServer(host, port, allHandlers) + val (srv, usedPort) = JettyUtils.startJettyServer(host, port, allHandlers, sc.conf) logInfo("Started Spark Web UI at http://%s:%d".format(host, usedPort)) server = Some(srv) boundPort = Some(usedPort) @@ -83,5 +86,5 @@ private[spark] class SparkUI(sc: SparkContext) extends Logging { private[spark] object SparkUI { val DEFAULT_PORT = "4040" - val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static" + val STATIC_RESOURCE_DIR = "org/apache/spark/ui" } diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala index 9e7cdc8816..14333476c0 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConversions._ import scala.util.Properties import scala.xml.Node -import org.eclipse.jetty.server.Handler +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.SparkContext import org.apache.spark.ui.JettyUtils._ @@ -32,8 +32,9 @@ import org.apache.spark.ui.UIUtils private[spark] class EnvironmentUI(sc: SparkContext) { - def getHandlers = Seq[(String, Handler)]( - ("/environment", (request: HttpServletRequest) => envDetails(request)) + def getHandlers = Seq[ServletContextHandler]( + createServletHandler("/environment", + createServlet((request: HttpServletRequest) => envDetails(request), sc.env.securityManager)) ) def envDetails(request: HttpServletRequest): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala index 1f3b7a4c23..4235cfeff9 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala @@ -22,7 +22,7 @@ import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{HashMap, HashSet} import scala.xml.Node -import org.eclipse.jetty.server.Handler +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{ExceptionFailure, Logging, SparkContext} import org.apache.spark.executor.TaskMetrics @@ -43,8 +43,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { sc.addSparkListener(listener) } - def getHandlers = Seq[(String, Handler)]( - ("/executors", (request: HttpServletRequest) => render(request)) + def getHandlers = Seq[ServletContextHandler]( + createServletHandler("/executors", createServlet((request: HttpServletRequest) => render + (request), sc.env.securityManager)) ) def render(request: HttpServletRequest): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala index 557bce6b66..2d95d47e15 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala @@ -23,6 +23,7 @@ import javax.servlet.http.HttpServletRequest import scala.Seq import org.eclipse.jetty.server.Handler +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.SparkContext import org.apache.spark.ui.JettyUtils._ @@ -45,9 +46,15 @@ private[spark] class JobProgressUI(val sc: SparkContext) { def formatDuration(ms: Long) = Utils.msDurationToString(ms) - def getHandlers = Seq[(String, Handler)]( - ("/stages/stage", (request: HttpServletRequest) => stagePage.render(request)), - ("/stages/pool", (request: HttpServletRequest) => poolPage.render(request)), - ("/stages", (request: HttpServletRequest) => indexPage.render(request)) + def getHandlers = Seq[ServletContextHandler]( + createServletHandler("/stages/stage", + createServlet((request: HttpServletRequest) => stagePage.render(request), + sc.env.securityManager)), + createServletHandler("/stages/pool", + createServlet((request: HttpServletRequest) => poolPage.render(request), + sc.env.securityManager)), + createServletHandler("/stages", + createServlet((request: HttpServletRequest) => indexPage.render(request), + sc.env.securityManager)) ) } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala index dc18eab74e..cb2083eb01 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala @@ -19,7 +19,7 @@ package org.apache.spark.ui.storage import javax.servlet.http.HttpServletRequest -import org.eclipse.jetty.server.Handler +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{Logging, SparkContext} import org.apache.spark.ui.JettyUtils._ @@ -29,8 +29,12 @@ private[spark] class BlockManagerUI(val sc: SparkContext) extends Logging { val indexPage = new IndexPage(this) val rddPage = new RDDPage(this) - def getHandlers = Seq[(String, Handler)]( - ("/storage/rdd", (request: HttpServletRequest) => rddPage.render(request)), - ("/storage", (request: HttpServletRequest) => indexPage.render(request)) + def getHandlers = Seq[ServletContextHandler]( + createServletHandler("/storage/rdd", + createServlet((request: HttpServletRequest) => rddPage.render(request), + sc.env.securityManager)), + createServletHandler("/storage", + createServlet((request: HttpServletRequest) => indexPage.render(request), + sc.env.securityManager)) ) } diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index f26ed47e58..a6c9a9aaba 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -24,12 +24,12 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, IndestructibleActorSystem} import com.typesafe.config.ConfigFactory import org.apache.log4j.{Level, Logger} -import org.apache.spark.SparkConf +import org.apache.spark.{Logging, SecurityManager, SparkConf} /** * Various utility classes for working with Akka. */ -private[spark] object AkkaUtils { +private[spark] object AkkaUtils extends Logging { /** * Creates an ActorSystem ready for remoting, with various Spark features. Returns both the @@ -42,7 +42,7 @@ private[spark] object AkkaUtils { * of a fatal exception. This is used by [[org.apache.spark.executor.Executor]]. */ def createActorSystem(name: String, host: String, port: Int, indestructible: Boolean = false, - conf: SparkConf): (ActorSystem, Int) = { + conf: SparkConf, securityManager: SecurityManager): (ActorSystem, Int) = { val akkaThreads = conf.getInt("spark.akka.threads", 4) val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15) @@ -65,6 +65,15 @@ private[spark] object AkkaUtils { conf.getDouble("spark.akka.failure-detector.threshold", 300.0) val akkaHeartBeatInterval = conf.getInt("spark.akka.heartbeat.interval", 1000) + val secretKey = securityManager.getSecretKey() + val isAuthOn = securityManager.isAuthenticationEnabled() + if (isAuthOn && secretKey == null) { + throw new Exception("Secret key is null with authentication on") + } + val requireCookie = if (isAuthOn) "on" else "off" + val secureCookie = if (isAuthOn) secretKey else "" + logDebug("In createActorSystem, requireCookie is: " + requireCookie) + val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String]).withFallback( ConfigFactory.parseString( s""" @@ -72,6 +81,8 @@ private[spark] object AkkaUtils { |akka.loggers = [""akka.event.slf4j.Slf4jLogger""] |akka.stdout-loglevel = "ERROR" |akka.jvm-exit-on-fatal-error = off + |akka.remote.require-cookie = "$requireCookie" + |akka.remote.secure-cookie = "$secureCookie" |akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatInterval s |akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPauses s |akka.remote.transport-failure-detector.threshold = $akkaFailureDetector diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 8e69f1d335..0eb2f78b73 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,7 +18,7 @@ package org.apache.spark.util import java.io._ -import java.net.{InetAddress, Inet4Address, NetworkInterface, URI, URL} +import java.net.{InetAddress, Inet4Address, NetworkInterface, URI, URL, URLConnection} import java.nio.ByteBuffer import java.util.{Locale, Random, UUID} import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor} @@ -33,10 +33,11 @@ import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} import org.apache.spark.deploy.SparkHadoopUtil + /** * Various utility methods used by Spark. */ @@ -233,13 +234,29 @@ private[spark] object Utils extends Logging { } /** + * Construct a URI container information used for authentication. + * This also sets the default authenticator to properly negotiation the + * user/password based on the URI. + * + * Note this relies on the Authenticator.setDefault being set properly to decode + * the user name and password. This is currently set in the SecurityManager. + */ + def constructURIForAuthentication(uri: URI, securityMgr: SecurityManager): URI = { + val userCred = securityMgr.getSecretKey() + if (userCred == null) throw new Exception("Secret key is null with authentication on") + val userInfo = securityMgr.getHttpUser() + ":" + userCred + new URI(uri.getScheme(), userInfo, uri.getHost(), uri.getPort(), uri.getPath(), + uri.getQuery(), uri.getFragment()) + } + + /** * Download a file requested by the executor. Supports fetching the file in a variety of ways, * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. * * Throws SparkException if the target file already exists and has different contents than * the requested file. */ - def fetchFile(url: String, targetDir: File, conf: SparkConf) { + def fetchFile(url: String, targetDir: File, conf: SparkConf, securityMgr: SecurityManager) { val filename = url.split("/").last val tempDir = getLocalDir(conf) val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir)) @@ -249,7 +266,19 @@ private[spark] object Utils extends Logging { uri.getScheme match { case "http" | "https" | "ftp" => logInfo("Fetching " + url + " to " + tempFile) - val in = new URL(url).openStream() + + var uc: URLConnection = null + if (securityMgr.isAuthenticationEnabled()) { + logDebug("fetchFile with security enabled") + val newuri = constructURIForAuthentication(uri, securityMgr) + uc = newuri.toURL().openConnection() + uc.setAllowUserInteraction(false) + } else { + logDebug("fetchFile not using security") + uc = new URL(url).openConnection() + } + + val in = uc.getInputStream(); val out = new FileOutputStream(tempFile) Utils.copyStream(in, out, true) if (targetFile.exists && !Files.equal(tempFile, targetFile)) { diff --git a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala new file mode 100644 index 0000000000..cd054c1f68 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite + +import akka.actor._ +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.AkkaUtils +import scala.concurrent.Await + +/** + * Test the AkkaUtils with various security settings. + */ +class AkkaUtilsSuite extends FunSuite with LocalSparkContext { + + test("remote fetch security bad password") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + + val securityManager = new SecurityManager(conf); + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + conf = conf, securityManager = securityManager) + System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + assert(securityManager.isAuthenticationEnabled() === true) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") + + val badconf = new SparkConf + badconf.set("spark.authenticate", "true") + badconf.set("spark.authenticate.secret", "bad") + val securityManagerBad = new SecurityManager(badconf); + + assert(securityManagerBad.isAuthenticationEnabled() === true) + + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + conf = conf, securityManager = securityManagerBad) + val slaveTracker = new MapOutputTracker(conf) + val selection = slaveSystem.actorSelection( + s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") + val timeout = AkkaUtils.lookupTimeout(conf) + intercept[akka.actor.ActorNotFound] { + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + } + + actorSystem.shutdown() + slaveSystem.shutdown() + } + + test("remote fetch security off") { + val conf = new SparkConf + conf.set("spark.authenticate", "false") + conf.set("spark.authenticate.secret", "bad") + val securityManager = new SecurityManager(conf); + + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + conf = conf, securityManager = securityManager) + System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + + assert(securityManager.isAuthenticationEnabled() === false) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") + + val badconf = new SparkConf + badconf.set("spark.authenticate", "false") + badconf.set("spark.authenticate.secret", "good") + val securityManagerBad = new SecurityManager(badconf); + + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + conf = badconf, securityManager = securityManagerBad) + val slaveTracker = new MapOutputTracker(conf) + val selection = slaveSystem.actorSelection( + s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") + val timeout = AkkaUtils.lookupTimeout(conf) + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + + assert(securityManagerBad.isAuthenticationEnabled() === false) + + masterTracker.registerShuffle(10, 1) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + masterTracker.registerMapOutput(10, 0, new MapStatus( + BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000))) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + // this should succeed since security off + assert(slaveTracker.getServerStatuses(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000, 0), size1000))) + + actorSystem.shutdown() + slaveSystem.shutdown() + } + + test("remote fetch security pass") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + val securityManager = new SecurityManager(conf); + + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + conf = conf, securityManager = securityManager) + System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + + assert(securityManager.isAuthenticationEnabled() === true) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") + + val goodconf = new SparkConf + goodconf.set("spark.authenticate", "true") + goodconf.set("spark.authenticate.secret", "good") + val securityManagerGood = new SecurityManager(goodconf); + + assert(securityManagerGood.isAuthenticationEnabled() === true) + + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + conf = goodconf, securityManager = securityManagerGood) + val slaveTracker = new MapOutputTracker(conf) + val selection = slaveSystem.actorSelection( + s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") + val timeout = AkkaUtils.lookupTimeout(conf) + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + + masterTracker.registerShuffle(10, 1) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + masterTracker.registerMapOutput(10, 0, new MapStatus( + BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000))) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + // this should succeed since security on and passwords match + assert(slaveTracker.getServerStatuses(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000, 0), size1000))) + + actorSystem.shutdown() + slaveSystem.shutdown() + } + + test("remote fetch security off client") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + + val securityManager = new SecurityManager(conf); + + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + conf = conf, securityManager = securityManager) + System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + + assert(securityManager.isAuthenticationEnabled() === true) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") + + val badconf = new SparkConf + badconf.set("spark.authenticate", "false") + badconf.set("spark.authenticate.secret", "bad") + val securityManagerBad = new SecurityManager(badconf); + + assert(securityManagerBad.isAuthenticationEnabled() === false) + + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + conf = badconf, securityManager = securityManagerBad) + val slaveTracker = new MapOutputTracker(conf) + val selection = slaveSystem.actorSelection( + s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") + val timeout = AkkaUtils.lookupTimeout(conf) + intercept[akka.actor.ActorNotFound] { + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + } + + actorSystem.shutdown() + slaveSystem.shutdown() + } + +} diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index e022accee6..96ba3929c1 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.FunSuite class BroadcastSuite extends FunSuite with LocalSparkContext { + override def afterEach() { super.afterEach() System.clearProperty("spark.broadcast.factory") diff --git a/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala new file mode 100644 index 0000000000..80f7ec00c7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite + +import java.nio._ + +import org.apache.spark.network.{ConnectionManager, Message, ConnectionManagerId} +import scala.concurrent.Await +import scala.concurrent.TimeoutException +import scala.concurrent.duration._ + + +/** + * Test the ConnectionManager with various security settings. + */ +class ConnectionManagerSuite extends FunSuite { + + test("security default off") { + val conf = new SparkConf + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + var receivedMessage = false + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + receivedMessage = true + None + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliablySync(manager.id, bufferMessage) + + assert(receivedMessage == true) + + manager.stop() + } + + test("security on same password") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + var numReceivedMessages = 0 + + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedMessages += 1 + None + }) + val managerServer = new ConnectionManager(0, conf, securityManager) + var numReceivedServerMessages = 0 + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedServerMessages += 1 + None + }) + + val size = 10 * 1024 * 1024 + val count = 10 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + + (0 until count).map(i => { + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliablySync(managerServer.id, bufferMessage) + }) + + assert(numReceivedServerMessages == 10) + assert(numReceivedMessages == 0) + + manager.stop() + managerServer.stop() + } + + test("security mismatch password") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + var numReceivedMessages = 0 + + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedMessages += 1 + None + }) + + val badconf = new SparkConf + badconf.set("spark.authenticate", "true") + badconf.set("spark.authenticate.secret", "bad") + val badsecurityManager = new SecurityManager(badconf) + val managerServer = new ConnectionManager(0, badconf, badsecurityManager) + var numReceivedServerMessages = 0 + + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedServerMessages += 1 + None + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliablySync(managerServer.id, bufferMessage) + + assert(numReceivedServerMessages == 0) + assert(numReceivedMessages == 0) + + manager.stop() + managerServer.stop() + } + + test("security mismatch auth off") { + val conf = new SparkConf + conf.set("spark.authenticate", "false") + conf.set("spark.authenticate.secret", "good") + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + var numReceivedMessages = 0 + + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedMessages += 1 + None + }) + + val badconf = new SparkConf + badconf.set("spark.authenticate", "true") + badconf.set("spark.authenticate.secret", "good") + val badsecurityManager = new SecurityManager(badconf) + val managerServer = new ConnectionManager(0, badconf, badsecurityManager) + var numReceivedServerMessages = 0 + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedServerMessages += 1 + None + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + (0 until 1).map(i => { + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliably(managerServer.id, bufferMessage) + }).foreach(f => { + try { + val g = Await.result(f, 1 second) + assert(false) + } catch { + case e: TimeoutException => { + // we should timeout here since the client can't do the negotiation + assert(true) + } + } + }) + + assert(numReceivedServerMessages == 0) + assert(numReceivedMessages == 0) + manager.stop() + managerServer.stop() + } + + test("security auth off") { + val conf = new SparkConf + conf.set("spark.authenticate", "false") + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + var numReceivedMessages = 0 + + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedMessages += 1 + None + }) + + val badconf = new SparkConf + badconf.set("spark.authenticate", "false") + val badsecurityManager = new SecurityManager(badconf) + val managerServer = new ConnectionManager(0, badconf, badsecurityManager) + var numReceivedServerMessages = 0 + + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedServerMessages += 1 + None + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + (0 until 10).map(i => { + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliably(managerServer.id, bufferMessage) + }).foreach(f => { + try { + val g = Await.result(f, 1 second) + if (!g.isDefined) assert(false) else assert(true) + } catch { + case e: Exception => { + assert(false) + } + } + }) + assert(numReceivedServerMessages == 10) + assert(numReceivedMessages == 0) + + manager.stop() + managerServer.stop() + } + + + +} + diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index e0e8011278..9cbdfc54a3 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -30,6 +30,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.util.Utils class DriverSuite extends FunSuite with Timeouts { + test("driver should exit after finishing") { val sparkHome = sys.env.get("SPARK_HOME").orElse(sys.props.get("spark.home")).get // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 9be67b3c95..aee9ab9091 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -30,6 +30,12 @@ class FileServerSuite extends FunSuite with LocalSparkContext { @transient var tmpFile: File = _ @transient var tmpJarUrl: String = _ + override def beforeEach() { + super.beforeEach() + resetSparkContext() + System.setProperty("spark.authenticate", "false") + } + override def beforeAll() { super.beforeAll() val tmpDir = new File(Files.createTempDir(), "test") @@ -43,6 +49,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { val jarFile = new File(tmpDir, "test.jar") val jarStream = new FileOutputStream(jarFile) val jar = new JarOutputStream(jarStream, new java.util.jar.Manifest()) + System.setProperty("spark.authenticate", "false") val jarEntry = new JarEntry(textFile.getName) jar.putNextEntry(jarEntry) @@ -77,6 +84,25 @@ class FileServerSuite extends FunSuite with LocalSparkContext { assert(result.toSet === Set((1,200), (2,300), (3,500))) } + test("Distributing files locally security On") { + val sparkConf = new SparkConf(false) + sparkConf.set("spark.authenticate", "true") + sparkConf.set("spark.authenticate.secret", "good") + sc = new SparkContext("local[4]", "test", sparkConf) + + sc.addFile(tmpFile.toString) + assert(sc.env.securityManager.isAuthenticationEnabled() === true) + val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) + val result = sc.parallelize(testData).reduceByKey { + val path = SparkFiles.get("FileServerSuite.txt") + val in = new BufferedReader(new FileReader(path)) + val fileVal = in.readLine().toInt + in.close() + _ * fileVal + _ * fileVal + }.collect() + assert(result.toSet === Set((1,200), (2,300), (3,500))) + } + test("Distributing files locally using URL as input") { // addFile("file:///....") sc = new SparkContext("local[4]", "test") diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 6c1e325f6f..8efa072a97 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -98,14 +98,16 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("remote fetch") { val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, + securityManager = new SecurityManager(conf)) System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext val masterTracker = new MapOutputTrackerMaster(conf) masterTracker.trackerActor = actorSystem.actorOf( Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf) + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf, + securityManager = new SecurityManager(conf)) val slaveTracker = new MapOutputTracker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala index c1e8b295df..96a5a12318 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -18,21 +18,22 @@ package org.apache.spark.metrics import org.scalatest.{BeforeAndAfter, FunSuite} - -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.master.MasterSource class MetricsSystemSuite extends FunSuite with BeforeAndAfter { var filePath: String = _ var conf: SparkConf = null + var securityMgr: SecurityManager = null before { filePath = getClass.getClassLoader.getResource("test_metrics_system.properties").getFile() conf = new SparkConf(false).set("spark.metrics.conf", filePath) + securityMgr = new SecurityManager(conf) } test("MetricsSystem with default config") { - val metricsSystem = MetricsSystem.createMetricsSystem("default", conf) + val metricsSystem = MetricsSystem.createMetricsSystem("default", conf, securityMgr) val sources = metricsSystem.sources val sinks = metricsSystem.sinks @@ -42,7 +43,7 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter { } test("MetricsSystem with sources add") { - val metricsSystem = MetricsSystem.createMetricsSystem("test", conf) + val metricsSystem = MetricsSystem.createMetricsSystem("test", conf, securityMgr) val sources = metricsSystem.sources val sinks = metricsSystem.sinks diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 9f011d9c8d..121e47c7b1 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.concurrent.Timeouts._ import org.scalatest.matchers.ShouldMatchers._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SecurityManager, SparkConf, SparkContext} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} @@ -39,6 +39,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT var actorSystem: ActorSystem = null var master: BlockManagerMaster = null var oldArch: String = null + conf.set("spark.authenticate", "false") + val securityMgr = new SecurityManager(conf) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test conf.set("spark.kryoserializer.buffer.mb", "1") @@ -49,7 +51,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId) before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, conf = conf) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, conf = conf, + securityManager = securityMgr) this.actorSystem = actorSystem conf.set("spark.driver.port", boundPort.toString) @@ -125,7 +128,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("master + 1 manager interaction") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -155,8 +158,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("master + 2 managers interaction") { - store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf) - store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf) + store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr) + store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf, + securityMgr) val peers = master.getPeers(store.blockManagerId, 1) assert(peers.size === 1, "master did not return the other manager as a peer") @@ -171,7 +175,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("removing block") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -219,7 +223,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("removing rdd") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -253,7 +257,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("reregistration on heart beat") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) @@ -269,7 +273,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("reregistration on block update") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) @@ -288,7 +292,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("reregistration doesn't dead lock") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = List(new Array[Byte](400)) @@ -325,7 +329,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU storage") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -344,7 +348,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU storage with serialization") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -363,7 +367,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU for partitions of same RDD") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -382,7 +386,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU for partitions of multiple RDDs") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr) store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 2), new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(1, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) @@ -405,7 +409,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("on-disk storage") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -418,7 +422,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -433,7 +437,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with getLocalBytes") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -448,7 +452,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with serialization") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -463,7 +467,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with serialization and getLocalBytes") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -478,7 +482,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("LRU with mixed storage levels") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -503,7 +507,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU with streams") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -527,7 +531,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("LRU with mixed storage levels and streams") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -573,7 +577,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("overly large block") { - store = new BlockManager("<driver>", actorSystem, master, serializer, 500, conf) + store = new BlockManager("<driver>", actorSystem, master, serializer, 500, conf, securityMgr) store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.getSingle("a1") === None, "a1 was in store") store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK) @@ -584,7 +588,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("block compression") { try { conf.set("spark.shuffle.compress", "true") - store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100, "shuffle_0_0_0 was not compressed") @@ -592,7 +596,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store = null conf.set("spark.shuffle.compress", "false") - store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 1000, "shuffle_0_0_0 was compressed") @@ -600,7 +604,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store = null conf.set("spark.broadcast.compress", "true") - store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 100, "broadcast_0 was not compressed") @@ -608,28 +612,28 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store = null conf.set("spark.broadcast.compress", "false") - store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 1000, "broadcast_0 was compressed") store.stop() store = null conf.set("spark.rdd.compress", "true") - store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(rdd(0, 0)) <= 100, "rdd_0_0 was not compressed") store.stop() store = null conf.set("spark.rdd.compress", "false") - store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(rdd(0, 0)) >= 1000, "rdd_0_0 was compressed") store.stop() store = null // Check that any other block types are also kept uncompressed - store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed") store.stop() @@ -643,7 +647,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("block store put failure") { // Use Java serializer so we can create an unserializable error. - store = new BlockManager("<driver>", actorSystem, master, new JavaSerializer(conf), 1200, conf) + store = new BlockManager("<driver>", actorSystem, master, new JavaSerializer(conf), 1200, conf, + securityMgr) // The put should fail since a1 is not serializable. class UnserializableClass diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 20ebb1897e..30415814ad 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -24,6 +24,8 @@ import scala.util.{Failure, Success, Try} import org.eclipse.jetty.server.Server import org.scalatest.FunSuite +import org.apache.spark.SparkConf + class UISuite extends FunSuite { test("jetty port increases under contention") { val startPort = 4040 @@ -34,15 +36,17 @@ class UISuite extends FunSuite { case Failure(e) => // Either case server port is busy hence setup for test complete } - val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq()) - val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq()) + val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq(), + new SparkConf) + val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq(), + new SparkConf) // Allow some wiggle room in case ports on the machine are under contention assert(boundPort1 > startPort && boundPort1 < startPort + 10) assert(boundPort2 > boundPort1 && boundPort2 < boundPort1 + 10) } test("jetty binds to port 0 correctly") { - val (jettyServer, boundPort) = JettyUtils.startJettyServer("0.0.0.0", 0, Seq()) + val (jettyServer, boundPort) = JettyUtils.startJettyServer("0.0.0.0", 0, Seq(), new SparkConf) assert(jettyServer.getState === "STARTED") assert(boundPort != 0) Try {new ServerSocket(boundPort)} match { |