author    Aaron Davidson <aaron@databricks.com>  2014-11-04 16:15:38 -0800
committer Reynold Xin <rxin@databricks.com>  2014-11-04 16:15:38 -0800
commit    5e73138a0152b78380b3f1def4b969b58e70dd11 (patch)
tree      d899e389f29e7c87c8f40e84872477a9b5e52277 /core/src
parent    f90ad5d426cb726079c490a9bb4b1100e2b4e602 (diff)
[SPARK-2938] Support SASL authentication in NettyBlockTransferService
Also lays the groundwork for supporting it inside the external shuffle service.

Author: Aaron Davidson <aaron@databricks.com>

Closes #3087 from aarondav/sasl and squashes the following commits:

3481718 [Aaron Davidson] Delete rogue println
44f8410 [Aaron Davidson] Delete documentation - muahaha!
eb9f065 [Aaron Davidson] Improve documentation and add end-to-end test at Spark-level
a6b95f1 [Aaron Davidson] Address comments
785bbde [Aaron Davidson] Cleanup
79973cb [Aaron Davidson] Remove unused file
151b3c5 [Aaron Davidson] Add docs, timeout config, better failure handling
f6177d7 [Aaron Davidson] Cleanup SASL state upon connection termination
7b42adb [Aaron Davidson] Add unit tests
8191bcb [Aaron Davidson] [SPARK-2938] Support SASL authentication in NettyBlockTransferService
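For context, the new authentication path is driven purely by configuration. A minimal sketch of the settings involved, with values borrowed from the NettyBlockTransferSecuritySuite added below (both endpoints must agree on all three values for the handshake to succeed):

  import org.apache.spark.SparkConf

  val conf = new SparkConf()
    .set("spark.authenticate", "true")          // enable SASL (DIGEST-MD5) authentication
    .set("spark.authenticate.secret", "good")   // shared secret, used as the SASL password
    .set("spark.app.id", "app-id")              // checked against the peer's appId during SASL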
Diffstat (limited to 'core/src')
-rw-r--r--  core/src/main/scala/org/apache/spark/SecurityManager.scala | 23
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkConf.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala | 3
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkSaslClient.scala | 147
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkSaslServer.scala | 176
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala | 1
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala | 28
-rw-r--r--  core/src/main/scala/org/apache/spark/network/nio/Connection.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala | 45
-rw-r--r--  core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala | 161
-rw-r--r--  core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala | 6
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala | 4
15 files changed, 257 insertions(+), 359 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 0e0f1a7b23..dee935ffad 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -22,6 +22,7 @@ import java.net.{Authenticator, PasswordAuthentication}
import org.apache.hadoop.io.Text
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.network.sasl.SecretKeyHolder
/**
* Spark class responsible for security.
@@ -84,7 +85,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
* Authenticator installed in the SecurityManager, which does the authentication
* and in this case gets the user name and password from the request.
*
- * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously
+ * - BlockTransferService -> The Spark BlockTransferServices use java nio to asynchronously
* exchange messages. For this we use the Java SASL
* (Simple Authentication and Security Layer) API and again use DIGEST-MD5
* as the authentication mechanism. This means the shared secret is not passed
@@ -98,7 +99,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
* of protection they want. If we support those, the messages will also have to
* be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's.
*
- * Since the connectionManager does asynchronous messages passing, the SASL
+ * Since the NioBlockTransferService does asynchronous message passing, the SASL
* authentication is a bit more complex. A ConnectionManager can be both a client
* and a Server, so for a particular connection it has to determine what to do.
* A ConnectionId was added to be able to track connections and is used to
@@ -107,6 +108,10 @@ import org.apache.spark.deploy.SparkHadoopUtil
* and waits for the response from the server and does the handshake before sending
* the real message.
*
+ * The NettyBlockTransferService ensures that SASL authentication is performed
+ * synchronously prior to any other communication on a connection. This is done in
+ * SaslClientBootstrap on the client side and SaslRpcHandler on the server side.
+ *
* - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
* can be used. Yarn requires a specific AmIpFilter be installed for security to work
* properly. For non-Yarn deployments, users can write a filter to go through a
@@ -139,7 +144,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
* can take place.
*/
-private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
+private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with SecretKeyHolder {
// key used to store the spark secret in the Hadoop UGI
private val sparkSecretLookupKey = "sparkCookie"
@@ -337,4 +342,16 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
* @return the secret key as a String if authentication is enabled, otherwise returns null
*/
def getSecretKey(): String = secretKey
+
+ override def getSaslUser(appId: String): String = {
+ val myAppId = sparkConf.getAppId
+ require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}")
+ getSaslUser()
+ }
+
+ override def getSecretKey(appId: String): String = {
+ val myAppId = sparkConf.getAppId
+ require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}")
+ getSecretKey()
+ }
}
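SecurityManager now doubles as the SecretKeyHolder consulted by the network/sasl layer, with credentials keyed by application id. A hedged sketch of the resulting behavior, assuming a conf like the one above where spark.authenticate.secret supplies the secret (as in non-YARN deployments):

  val sm = new SecurityManager(conf)
  sm.getSaslUser("app-id")     // resolves the SASL user when the appId matches spark.app.id
  sm.getSecretKey("app-id")    // returns the shared secret, here "good"
  sm.getSecretKey("other-id")  // require() fails with IllegalArgumentException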
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index ad0a9017af..4c6c86c7ba 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -217,6 +217,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
*/
getAll.filter { case (k, _) => isAkkaConf(k) }
+ /**
+ * Returns the Spark application id, valid in the Driver after TaskScheduler registration and
+ * from the start in the Executor.
+ */
+ def getAppId: String = get("spark.app.id")
+
/** Does the configuration contain a given parameter? */
def contains(key: String): Boolean = settings.contains(key)
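Note that getAppId is a strict lookup, so callers fail fast if the id is not yet known; a small sketch of the contract (the NoSuchElementException comes from SparkConf.get on a missing key):

  val conf = new SparkConf(false)
  // conf.getAppId                    // would throw NoSuchElementException at this point
  conf.set("spark.app.id", "app-id")  // done by SparkContext on the driver
  assert(conf.getAppId == "app-id")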
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 40444c237b..3cdaa6a9cc 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -313,6 +313,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
val applicationId: String = taskScheduler.applicationId()
conf.set("spark.app.id", applicationId)
+ env.blockManager.initialize(applicationId)
+
val metricsSystem = env.metricsSystem
// The metrics system for Driver need to be set spark.app.id to app ID.
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index e2f13accdf..45e9d7f243 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -276,7 +276,7 @@ object SparkEnv extends Logging {
val blockTransferService =
conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match {
case "netty" =>
- new NettyBlockTransferService(conf)
+ new NettyBlockTransferService(conf, securityManager)
case "nio" =>
new NioBlockTransferService(conf, securityManager)
}
@@ -285,6 +285,7 @@ object SparkEnv extends Logging {
"BlockManagerMaster",
new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver)
+ // NB: blockManager is not valid until initialize() is called later.
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
serializer, conf, mapOutputTracker, shuffleManager, blockTransferService)
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
deleted file mode 100644
index a954fcc0c3..0000000000
--- a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
+++ /dev/null
@@ -1,147 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import javax.security.auth.callback.Callback
-import javax.security.auth.callback.CallbackHandler
-import javax.security.auth.callback.NameCallback
-import javax.security.auth.callback.PasswordCallback
-import javax.security.auth.callback.UnsupportedCallbackException
-import javax.security.sasl.RealmCallback
-import javax.security.sasl.RealmChoiceCallback
-import javax.security.sasl.Sasl
-import javax.security.sasl.SaslClient
-import javax.security.sasl.SaslException
-
-import scala.collection.JavaConversions.mapAsJavaMap
-
-import com.google.common.base.Charsets.UTF_8
-
-/**
- * Implements SASL Client logic for Spark
- */
-private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging {
-
- /**
- * Used to respond to server's counterpart, SaslServer with SASL tokens
- * represented as byte arrays.
- *
- * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
- * configurable in the future.
- */
- private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST),
- null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
- new SparkSaslClientCallbackHandler(securityMgr))
-
- /**
- * Used to initiate SASL handshake with server.
- * @return response to challenge if needed
- */
- def firstToken(): Array[Byte] = {
- synchronized {
- val saslToken: Array[Byte] =
- if (saslClient != null && saslClient.hasInitialResponse()) {
- logDebug("has initial response")
- saslClient.evaluateChallenge(new Array[Byte](0))
- } else {
- new Array[Byte](0)
- }
- saslToken
- }
- }
-
- /**
- * Determines whether the authentication exchange has completed.
- * @return true is complete, otherwise false
- */
- def isComplete(): Boolean = {
- synchronized {
- if (saslClient != null) saslClient.isComplete() else false
- }
- }
-
- /**
- * Respond to server's SASL token.
- * @param saslTokenMessage contains server's SASL token
- * @return client's response SASL token
- */
- def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = {
- synchronized {
- if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0)
- }
- }
-
- /**
- * Disposes of any system resources or security-sensitive information the
- * SaslClient might be using.
- */
- def dispose() {
- synchronized {
- if (saslClient != null) {
- try {
- saslClient.dispose()
- } catch {
- case e: SaslException => // ignored
- } finally {
- saslClient = null
- }
- }
- }
- }
-
- /**
- * Implementation of javax.security.auth.callback.CallbackHandler
- * that works with share secrets.
- */
- private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends
- CallbackHandler {
-
- private val userName: String =
- SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8))
- private val secretKey = securityMgr.getSecretKey()
- private val userPassword: Array[Char] = SparkSaslServer.encodePassword(
- if (secretKey != null) secretKey.getBytes(UTF_8) else "".getBytes(UTF_8))
-
- /**
- * Implementation used to respond to SASL request from the server.
- *
- * @param callbacks objects that indicate what credential information the
- * server's SaslServer requires from the client.
- */
- override def handle(callbacks: Array[Callback]) {
- logDebug("in the sasl client callback handler")
- callbacks foreach {
- case nc: NameCallback => {
- logDebug("handle: SASL client callback: setting username: " + userName)
- nc.setName(userName)
- }
- case pc: PasswordCallback => {
- logDebug("handle: SASL client callback: setting userPassword")
- pc.setPassword(userPassword)
- }
- case rc: RealmCallback => {
- logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText())
- rc.setText(rc.getDefaultText())
- }
- case cb: RealmChoiceCallback => {}
- case cb: Callback => throw
- new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback")
- }
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
deleted file mode 100644
index 7c2afb3646..0000000000
--- a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
+++ /dev/null
@@ -1,176 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import javax.security.auth.callback.Callback
-import javax.security.auth.callback.CallbackHandler
-import javax.security.auth.callback.NameCallback
-import javax.security.auth.callback.PasswordCallback
-import javax.security.auth.callback.UnsupportedCallbackException
-import javax.security.sasl.AuthorizeCallback
-import javax.security.sasl.RealmCallback
-import javax.security.sasl.Sasl
-import javax.security.sasl.SaslException
-import javax.security.sasl.SaslServer
-import scala.collection.JavaConversions.mapAsJavaMap
-
-import com.google.common.base.Charsets.UTF_8
-import org.apache.commons.net.util.Base64
-
-/**
- * Encapsulates SASL server logic
- */
-private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging {
-
- /**
- * Actual SASL work done by this object from javax.security.sasl.
- */
- private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null,
- SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
- new SparkSaslDigestCallbackHandler(securityMgr))
-
- /**
- * Determines whether the authentication exchange has completed.
- * @return true is complete, otherwise false
- */
- def isComplete(): Boolean = {
- synchronized {
- if (saslServer != null) saslServer.isComplete() else false
- }
- }
-
- /**
- * Used to respond to server SASL tokens.
- * @param token Server's SASL token
- * @return response to send back to the server.
- */
- def response(token: Array[Byte]): Array[Byte] = {
- synchronized {
- if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0)
- }
- }
-
- /**
- * Disposes of any system resources or security-sensitive information the
- * SaslServer might be using.
- */
- def dispose() {
- synchronized {
- if (saslServer != null) {
- try {
- saslServer.dispose()
- } catch {
- case e: SaslException => // ignore
- } finally {
- saslServer = null
- }
- }
- }
- }
-
- /**
- * Implementation of javax.security.auth.callback.CallbackHandler
- * for SASL DIGEST-MD5 mechanism
- */
- private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager)
- extends CallbackHandler {
-
- private val userName: String =
- SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8))
-
- override def handle(callbacks: Array[Callback]) {
- logDebug("In the sasl server callback handler")
- callbacks foreach {
- case nc: NameCallback => {
- logDebug("handle: SASL server callback: setting username")
- nc.setName(userName)
- }
- case pc: PasswordCallback => {
- logDebug("handle: SASL server callback: setting userPassword")
- val password: Array[Char] =
- SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes(UTF_8))
- pc.setPassword(password)
- }
- case rc: RealmCallback => {
- logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText())
- rc.setText(rc.getDefaultText())
- }
- case ac: AuthorizeCallback => {
- val authid = ac.getAuthenticationID()
- val authzid = ac.getAuthorizationID()
- if (authid.equals(authzid)) {
- logDebug("set auth to true")
- ac.setAuthorized(true)
- } else {
- logDebug("set auth to false")
- ac.setAuthorized(false)
- }
- if (ac.isAuthorized()) {
- logDebug("sasl server is authorized")
- ac.setAuthorizedID(authzid)
- }
- }
- case cb: Callback => throw
- new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback")
- }
- }
- }
-}
-
-private[spark] object SparkSaslServer {
-
- /**
- * This is passed as the server name when creating the sasl client/server.
- * This could be changed to be configurable in the future.
- */
- val SASL_DEFAULT_REALM = "default"
-
- /**
- * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
- * configurable in the future.
- */
- val DIGEST = "DIGEST-MD5"
-
- /**
- * The quality of protection is just "auth". This means that we are doing
- * authentication only, we are not supporting integrity or privacy protection of the
- * communication channel after authentication. This could be changed to be configurable
- * in the future.
- */
- val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH ->"true")
-
- /**
- * Encode a byte[] identifier as a Base64-encoded string.
- *
- * @param identifier identifier to encode
- * @return Base64-encoded string
- */
- def encodeIdentifier(identifier: Array[Byte]): String = {
- new String(Base64.encodeBase64(identifier), UTF_8)
- }
-
- /**
- * Encode a password as a base64-encoded char[] array.
- * @param password as a byte array.
- * @return password as a char array.
- */
- def encodePassword(password: Array[Byte]): Array[Char] = {
- new String(Base64.encodeBase64(password), UTF_8).toCharArray()
- }
-}
-
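These two deleted classes are superseded by org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer}, which the nio changes below import and which additionally take the appId. A hedged sketch of the DIGEST-MD5 challenge/response loop they drive, using the method names from the removed API that survive in the replacement (firstToken, response, isComplete, dispose) and assuming conf and securityManager are in scope:

  val client = new SparkSaslClient(conf.getAppId, securityManager)
  val server = new SparkSaslServer(conf.getAppId, securityManager)

  var token = client.firstToken()    // empty for DIGEST-MD5: no initial response
  while (!server.isComplete()) {
    token = server.response(token)   // server challenge, or the final rspauth
    if (!client.isComplete()) {
      token = client.response(token) // client answers (and finally verifies rspauth)
    }
  }
  client.dispose()                   // clear security-sensitive state on both ends
  server.dispose()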
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 8b095e23f3..abc1dd0be6 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -86,6 +86,7 @@ private[spark] class Executor(
conf, executorId, slaveHostname, port, isLocal, actorSystem)
SparkEnv.set(_env)
_env.metricsSystem.registerSource(executorSource)
+ _env.blockManager.initialize(conf.getAppId)
_env
} else {
SparkEnv.get
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 1c4327cf13..0d1fc81d2a 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -17,13 +17,15 @@
package org.apache.spark.network.netty
+import scala.collection.JavaConversions._
import scala.concurrent.{Future, Promise}
-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.network._
import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.network.client.{RpcResponseCallback, TransportClientFactory}
+import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory}
import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock}
+import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap}
import org.apache.spark.network.server._
import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher}
import org.apache.spark.serializer.JavaSerializer
@@ -33,18 +35,30 @@ import org.apache.spark.util.Utils
/**
* A BlockTransferService that uses Netty to fetch a set of blocks at a time.
*/
-class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService {
+class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager)
+ extends BlockTransferService {
+
// TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
- val serializer = new JavaSerializer(conf)
+ private val serializer = new JavaSerializer(conf)
+ private val authEnabled = securityManager.isAuthenticationEnabled()
+ private val transportConf = SparkTransportConf.fromSparkConf(conf)
private[this] var transportContext: TransportContext = _
private[this] var server: TransportServer = _
private[this] var clientFactory: TransportClientFactory = _
override def init(blockDataManager: BlockDataManager): Unit = {
- val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
- transportContext = new TransportContext(SparkTransportConf.fromSparkConf(conf), rpcHandler)
- clientFactory = transportContext.createClientFactory()
+ val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = {
+ val nettyRpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
+ if (!authEnabled) {
+ (nettyRpcHandler, None)
+ } else {
+ (new SaslRpcHandler(nettyRpcHandler, securityManager),
+ Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager)))
+ }
+ }
+ transportContext = new TransportContext(transportConf, rpcHandler)
+ clientFactory = transportContext.createClientFactory(bootstrap.toList)
server = transportContext.createServer()
logInfo("Server created on " + server.getPort)
}
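Construction now requires a SecurityManager, and the SASL pieces are wired in transparently when authentication is on; a hedged usage sketch (blockDataManager is a stand-in for whatever BlockDataManager the caller supplies):

  val service = new NettyBlockTransferService(conf, new SecurityManager(conf))
  service.init(blockDataManager)  // with auth on: SaslRpcHandler guards the server side,
                                  // SaslClientBootstrap authenticates each outgoing client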
diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
index 4f6f5e2358..c2d9578be7 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
@@ -23,12 +23,13 @@ import java.nio.channels._
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.LinkedList
-import org.apache.spark._
-
import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.control.NonFatal
+import org.apache.spark._
+import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer}
+
private[nio]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId,
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
index 8408b75bb4..f198aa8564 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -34,6 +34,7 @@ import scala.language.postfixOps
import com.google.common.base.Charsets.UTF_8
import org.apache.spark._
+import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer}
import org.apache.spark.util.Utils
import scala.util.Try
@@ -600,7 +601,7 @@ private[nio] class ConnectionManager(
} else {
var replyToken : Array[Byte] = null
try {
- replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken)
+ replyToken = waitingConn.sparkSaslClient.response(securityMsg.getToken)
if (waitingConn.isSaslComplete()) {
logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
connectionsAwaitingSasl -= waitingConn.connectionId
@@ -634,7 +635,7 @@ private[nio] class ConnectionManager(
connection.synchronized {
if (connection.sparkSaslServer == null) {
logDebug("Creating sasl Server")
- connection.sparkSaslServer = new SparkSaslServer(securityManager)
+ connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager)
}
}
replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
@@ -778,7 +779,7 @@ private[nio] class ConnectionManager(
if (!conn.isSaslComplete()) {
conn.synchronized {
if (conn.sparkSaslClient == null) {
- conn.sparkSaslClient = new SparkSaslClient(securityManager)
+ conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager)
var firstResponse: Array[Byte] = null
try {
firstResponse = conn.sparkSaslClient.firstToken()
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 5f5dd0dc1c..655d16c65c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -57,6 +57,12 @@ private[spark] class BlockResult(
inputMetrics.bytesRead = bytes
}
+/**
+ * Manager running on every node (driver and executors) which provides interfaces for putting and
+ * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap).
+ *
+ * Note that #initialize() must be called before the BlockManager is usable.
+ */
private[spark] class BlockManager(
executorId: String,
actorSystem: ActorSystem,
@@ -69,8 +75,6 @@ private[spark] class BlockManager(
blockTransferService: BlockTransferService)
extends BlockDataManager with Logging {
- blockTransferService.init(this)
-
val diskBlockManager = new DiskBlockManager(this, conf)
private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
@@ -102,22 +106,16 @@ private[spark] class BlockManager(
+ " switch to sort-based shuffle.")
}
- val blockManagerId = BlockManagerId(
- executorId, blockTransferService.hostName, blockTransferService.port)
+ var blockManagerId: BlockManagerId = _
// Address of the server that serves this executor's shuffle files. This is either an external
// service, or just our own Executor's BlockManager.
- private[spark] val shuffleServerId = if (externalShuffleServiceEnabled) {
- BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort)
- } else {
- blockManagerId
- }
+ private[spark] var shuffleServerId: BlockManagerId = _
// Client to read other executors' shuffle files. This is either an external service, or just the
// standard BlockTransferService to directly connect to other Executors.
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
- val appId = conf.get("spark.app.id", "unknown-app-id")
- new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), appId)
+ new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf))
} else {
blockTransferService
}
@@ -150,8 +148,6 @@ private[spark] class BlockManager(
private val peerFetchLock = new Object
private var lastPeerFetchTime = 0L
- initialize()
-
/* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
* the initialization of the compression codec until it is first used. The reason is that a Spark
* program could be using a user-defined codec in a third party jar, which is loaded in
@@ -176,10 +172,27 @@ private[spark] class BlockManager(
}
/**
- * Initialize the BlockManager. Register to the BlockManagerMaster, and start the
- * BlockManagerWorker actor. Additionally registers with a local shuffle service if configured.
+ * Initializes the BlockManager with the given appId. This is not performed in the constructor as
+ * the appId may not be known at BlockManager instantiation time (in particular for the driver,
+ * where it is only learned after registration with the TaskScheduler).
+ *
+ * This method initializes the BlockTransferService and ShuffleClient, registers with the
+ * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle
+ * service if configured.
*/
- private def initialize(): Unit = {
+ def initialize(appId: String): Unit = {
+ blockTransferService.init(this)
+ shuffleClient.init(appId)
+
+ blockManagerId = BlockManagerId(
+ executorId, blockTransferService.hostName, blockTransferService.port)
+
+ shuffleServerId = if (externalShuffleServiceEnabled) {
+ BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort)
+ } else {
+ blockManagerId
+ }
+
master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
// Register Executors' configuration with the local shuffle service, if one should exist.
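The net effect is an explicit two-phase lifecycle; a hedged sketch using the constructor arguments shown above:

  // Phase 1: construction no longer touches the network and needs no appId.
  val bm = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer,
    conf, mapOutputTracker, shuffleManager, blockTransferService)

  // Phase 2: once spark.app.id is known (immediately on executors, after
  // TaskScheduler registration on the driver), wire up and register.
  bm.initialize(conf.getAppId)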
diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
new file mode 100644
index 0000000000..bed0ed9d71
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.nio._
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.duration._
+import scala.concurrent.{Await, Promise}
+import scala.util.{Failure, Success, Try}
+
+import org.apache.commons.io.IOUtils
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.shuffle.BlockFetchingListener
+import org.apache.spark.network.{BlockDataManager, BlockTransferService}
+import org.apache.spark.storage.{BlockId, ShuffleBlockId}
+import org.apache.spark.{SecurityManager, SparkConf}
+import org.mockito.Mockito._
+import org.scalatest.mock.MockitoSugar
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, ShouldMatchers}
+
+class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers {
+ test("security default off") {
+ testConnection(new SparkConf, new SparkConf) match {
+ case Success(_) => // expected
+ case Failure(t) => fail(t)
+ }
+ }
+
+ test("security on same password") {
+ val conf = new SparkConf()
+ .set("spark.authenticate", "true")
+ .set("spark.authenticate.secret", "good")
+ .set("spark.app.id", "app-id")
+ testConnection(conf, conf) match {
+ case Success(_) => // expected
+ case Failure(t) => fail(t)
+ }
+ }
+
+ test("security on mismatch password") {
+ val conf0 = new SparkConf()
+ .set("spark.authenticate", "true")
+ .set("spark.authenticate.secret", "good")
+ .set("spark.app.id", "app-id")
+ val conf1 = conf0.clone.set("spark.authenticate.secret", "bad")
+ testConnection(conf0, conf1) match {
+ case Success(_) => fail("Should have failed")
+ case Failure(t) => t.getMessage should include ("Mismatched response")
+ }
+ }
+
+ test("security mismatch auth off on server") {
+ val conf0 = new SparkConf()
+ .set("spark.authenticate", "true")
+ .set("spark.authenticate.secret", "good")
+ .set("spark.app.id", "app-id")
+ val conf1 = conf0.clone.set("spark.authenticate", "false")
+ testConnection(conf0, conf1) match {
+ case Success(_) => fail("Should have failed")
+ case Failure(t) => // any error may occur; the server will interpret the SASL token as an RPC
+ }
+ }
+
+ test("security mismatch auth off on client") {
+ val conf0 = new SparkConf()
+ .set("spark.authenticate", "false")
+ .set("spark.authenticate.secret", "good")
+ .set("spark.app.id", "app-id")
+ val conf1 = conf0.clone.set("spark.authenticate", "true")
+ testConnection(conf0, conf1) match {
+ case Success(_) => fail("Should have failed")
+ case Failure(t) => t.getMessage should include ("Expected SaslMessage")
+ }
+ }
+
+ test("security mismatch app ids") {
+ val conf0 = new SparkConf()
+ .set("spark.authenticate", "true")
+ .set("spark.authenticate.secret", "good")
+ .set("spark.app.id", "app-id")
+ val conf1 = conf0.clone.set("spark.app.id", "other-id")
+ testConnection(conf0, conf1) match {
+ case Success(_) => fail("Should have failed")
+ case Failure(t) => t.getMessage should include ("SASL appId app-id did not match")
+ }
+ }
+
+ /**
+ * Creates two servers with different configurations and sees if they can talk.
+ * Returns Success() if they can transfer a block, and Failure() if the block transfer failed
+ * in an expected, properly-handled way. Anything else that goes wrong is thrown as an
+ * out-of-band exception.
+ */
+ private def testConnection(conf0: SparkConf, conf1: SparkConf): Try[Unit] = {
+ val blockManager = mock[BlockDataManager]
+ val blockId = ShuffleBlockId(0, 1, 2)
+ val blockString = "Hello, world!"
+ val blockBuffer = new NioManagedBuffer(ByteBuffer.wrap(blockString.getBytes))
+ when(blockManager.getBlockData(blockId)).thenReturn(blockBuffer)
+
+ val securityManager0 = new SecurityManager(conf0)
+ val exec0 = new NettyBlockTransferService(conf0, securityManager0)
+ exec0.init(blockManager)
+
+ val securityManager1 = new SecurityManager(conf1)
+ val exec1 = new NettyBlockTransferService(conf1, securityManager1)
+ exec1.init(blockManager)
+
+ val result = fetchBlock(exec0, exec1, "1", blockId) match {
+ case Success(buf) =>
+ IOUtils.toString(buf.createInputStream()) should equal(blockString)
+ buf.release()
+ Success()
+ case Failure(t) =>
+ Failure(t)
+ }
+ exec0.close()
+ exec1.close()
+ result
+ }
+
+ /** Synchronously fetches a single block, acting as the given executor fetching from another. */
+ private def fetchBlock(
+ self: BlockTransferService,
+ from: BlockTransferService,
+ execId: String,
+ blockId: BlockId): Try[ManagedBuffer] = {
+
+ val promise = Promise[ManagedBuffer]()
+
+ self.fetchBlocks(from.hostName, from.port, execId, Array(blockId.toString),
+ new BlockFetchingListener {
+ override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
+ promise.failure(exception)
+ }
+
+ override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
+ promise.success(data.retain())
+ }
+ })
+
+ Await.ready(promise.future, FiniteDuration(1000, TimeUnit.MILLISECONDS))
+ promise.future.value.get
+ }
+}
+
diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
index b70734dfe3..716f875d30 100644
--- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
@@ -60,6 +60,7 @@ class ConnectionManagerSuite extends FunSuite {
val conf = new SparkConf
conf.set("spark.authenticate", "true")
conf.set("spark.authenticate.secret", "good")
+ conf.set("spark.app.id", "app-id")
val securityManager = new SecurityManager(conf)
val manager = new ConnectionManager(0, conf, securityManager)
var numReceivedMessages = 0
@@ -95,6 +96,7 @@ class ConnectionManagerSuite extends FunSuite {
test("security mismatch password") {
val conf = new SparkConf
conf.set("spark.authenticate", "true")
+ conf.set("spark.app.id", "app-id")
conf.set("spark.authenticate.secret", "good")
val securityManager = new SecurityManager(conf)
val manager = new ConnectionManager(0, conf, securityManager)
@@ -105,9 +107,7 @@ class ConnectionManagerSuite extends FunSuite {
None
})
- val badconf = new SparkConf
- badconf.set("spark.authenticate", "true")
- badconf.set("spark.authenticate.secret", "bad")
+ val badconf = conf.clone.set("spark.authenticate.secret", "bad")
val badsecurityManager = new SecurityManager(badconf)
val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
var numReceivedServerMessages = 0
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index c6d7105592..1461fa69db 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -63,6 +63,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd
val transfer = new NioBlockTransferService(conf, securityMgr)
val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
mapOutputTracker, shuffleManager, transfer)
+ store.initialize("app-id")
allStores += store
store
}
@@ -263,6 +264,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd
when(failableTransfer.port).thenReturn(1000)
val failableStore = new BlockManager("failable-store", actorSystem, master, serializer,
10000, conf, mapOutputTracker, shuffleManager, failableTransfer)
+ failableStore.initialize("app-id")
allStores += failableStore // so that this gets stopped after test
assert(master.getPeers(store.blockManagerId).toSet === Set(failableStore.blockManagerId))
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 715b740b85..0782876c8e 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -73,8 +73,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
maxMem: Long,
name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
val transfer = new NioBlockTransferService(conf, securityMgr)
- new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
+ val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
mapOutputTracker, shuffleManager, transfer)
+ manager.initialize("app-id")
+ manager
}
before {