-rw-r--r-- core/src/main/scala/org/apache/spark/SecurityManager.scala | 23
-rw-r--r-- core/src/main/scala/org/apache/spark/SparkConf.scala | 6
-rw-r--r-- core/src/main/scala/org/apache/spark/SparkContext.scala | 2
-rw-r--r-- core/src/main/scala/org/apache/spark/SparkEnv.scala | 3
-rw-r--r-- core/src/main/scala/org/apache/spark/SparkSaslClient.scala | 147
-rw-r--r-- core/src/main/scala/org/apache/spark/SparkSaslServer.scala | 176
-rw-r--r-- core/src/main/scala/org/apache/spark/executor/Executor.scala | 1
-rw-r--r-- core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala | 28
-rw-r--r-- core/src/main/scala/org/apache/spark/network/nio/Connection.scala | 5
-rw-r--r-- core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala | 7
-rw-r--r-- core/src/main/scala/org/apache/spark/storage/BlockManager.scala | 45
-rw-r--r-- core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala | 161
-rw-r--r-- core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala | 6
-rw-r--r-- core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala | 2
-rw-r--r-- core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala | 4
-rw-r--r-- docs/security.md | 1
-rw-r--r-- network/common/src/main/java/org/apache/spark/network/TransportContext.java | 15
-rw-r--r-- network/common/src/main/java/org/apache/spark/network/client/TransportClient.java | 11
-rw-r--r-- network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java | 32
-rw-r--r-- network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java | 64
-rw-r--r-- network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java | 2
-rw-r--r-- network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java | 19
-rw-r--r-- network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java | 1
-rw-r--r-- network/common/src/main/java/org/apache/spark/network/util/TransportConf.java | 3
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java | 74
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java | 74
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java | 97
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java | 35
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java | 138
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java | 170
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java | 2
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java | 15
-rw-r--r-- network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java | 11
-rw-r--r-- network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java | 172
-rw-r--r-- network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java | 89
-rw-r--r-- network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java | 7
-rw-r--r-- streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala | 1
37 files changed, 1257 insertions, 392 deletions
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 0e0f1a7b23..dee935ffad 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -22,6 +22,7 @@ import java.net.{Authenticator, PasswordAuthentication}
import org.apache.hadoop.io.Text
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.network.sasl.SecretKeyHolder
/**
* Spark class responsible for security.
@@ -84,7 +85,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
* Authenticator installed in the SecurityManager to how it does the authentication
* and in this case gets the user name and password from the request.
*
- * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously
+ * - BlockTransferService -> The Spark BlockTransferServices use java nio to asynchronously
* exchange messages. For this we use the Java SASL
* (Simple Authentication and Security Layer) API and again use DIGEST-MD5
* as the authentication mechanism. This means the shared secret is not passed
@@ -98,7 +99,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
* of protection they want. If we support those, the messages will also have to
* be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's.
*
- * Since the connectionManager does asynchronous messages passing, the SASL
+ * Since the NioBlockTransferService does asynchronous message passing, the SASL
* authentication is a bit more complex. A ConnectionManager can be both a client
* and a Server, so for a particular connection it has to determine what to do.
* A ConnectionId was added to be able to track connections and is used to
@@ -107,6 +108,10 @@ import org.apache.spark.deploy.SparkHadoopUtil
* and waits for the response from the server and does the handshake before sending
* the real message.
*
+ * The NettyBlockTransferService ensures that SASL authentication is performed
+ * synchronously prior to any other communication on a connection. This is done in
+ * SaslClientBootstrap on the client side and SaslRpcHandler on the server side.
+ *
* - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
* can be used. Yarn requires a specific AmIpFilter be installed for security to work
* properly. For non-Yarn deployments, users can write a filter to go through a
@@ -139,7 +144,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
* can take place.
*/
-private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
+private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with SecretKeyHolder {
// key used to store the spark secret in the Hadoop UGI
private val sparkSecretLookupKey = "sparkCookie"
@@ -337,4 +342,16 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
* @return the secret key as a String if authentication is enabled, otherwise returns null
*/
def getSecretKey(): String = secretKey
+
+ override def getSaslUser(appId: String): String = {
+ val myAppId = sparkConf.getAppId
+ require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}")
+ getSaslUser()
+ }
+
+ override def getSecretKey(appId: String): String = {
+ val myAppId = sparkConf.getAppId
+ require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}")
+ getSecretKey()
+ }
}
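
With these overrides, SecurityManager doubles as the network layer's SecretKeyHolder: the SASL user and secret are looked up per application id, and a mismatched id is rejected up front. As a point of reference, a minimal standalone holder with the same contract might look like the sketch below (assuming SecretKeyHolder exposes exactly these two String-keyed getters; the user name and secret values are placeholders):

    import org.apache.spark.network.sasl.SecretKeyHolder

    // Sketch: serves a single fixed secret, but only for the expected application id.
    class SingleAppSecretKeyHolder(expectedAppId: String, secret: String) extends SecretKeyHolder {
      override def getSaslUser(appId: String): String = {
        require(appId == expectedAppId, s"SASL appId $appId did not match $expectedAppId")
        "sparkSaslUser"  // placeholder user name
      }

      override def getSecretKey(appId: String): String = {
        require(appId == expectedAppId, s"SASL appId $appId did not match $expectedAppId")
        secret
      }
    }
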
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index ad0a9017af..4c6c86c7ba 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -217,6 +217,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
*/
getAll.filter { case (k, _) => isAkkaConf(k) }
+ /**
+ * Returns the Spark application id, valid in the Driver after TaskScheduler registration and
+ * from the start in the Executor.
+ */
+ def getAppId: String = get("spark.app.id")
+
/** Does the configuration contain a given parameter? */
def contains(key: String): Boolean = settings.contains(key)
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 40444c237b..3cdaa6a9cc 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -313,6 +313,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
val applicationId: String = taskScheduler.applicationId()
conf.set("spark.app.id", applicationId)
+ env.blockManager.initialize(applicationId)
+
val metricsSystem = env.metricsSystem
// The metrics system for Driver need to be set spark.app.id to app ID.
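
The ordering here matters: SparkConf.getAppId is a plain lookup of "spark.app.id", so on the driver it only works after the TaskScheduler has registered and the line above has run, while executors receive the value with their configuration from the start. A small hedged illustration (the id value is made up):

    val conf = new SparkConf()
    // conf.getAppId here would throw NoSuchElementException: "spark.app.id" is not yet set
    conf.set("spark.app.id", "app-20141104-0001")  // hypothetical id
    assert(conf.getAppId == "app-20141104-0001")
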
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index e2f13accdf..45e9d7f243 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -276,7 +276,7 @@ object SparkEnv extends Logging {
val blockTransferService =
conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match {
case "netty" =>
- new NettyBlockTransferService(conf)
+ new NettyBlockTransferService(conf, securityManager)
case "nio" =>
new NioBlockTransferService(conf, securityManager)
}
@@ -285,6 +285,7 @@ object SparkEnv extends Logging {
"BlockManagerMaster",
new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver)
+ // NB: blockManager is not valid until initialize() is called later.
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
serializer, conf, mapOutputTracker, shuffleManager, blockTransferService)
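
Both transports now receive the SecurityManager, so SASL can be switched on for either one purely through configuration. A hedged example of the relevant settings (the secret is a placeholder):

    val conf = new SparkConf()
      .set("spark.shuffle.blockTransferService", "netty")  // or "nio"
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "good")            // placeholder shared secret
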
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
deleted file mode 100644
index a954fcc0c3..0000000000
--- a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
+++ /dev/null
@@ -1,147 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import javax.security.auth.callback.Callback
-import javax.security.auth.callback.CallbackHandler
-import javax.security.auth.callback.NameCallback
-import javax.security.auth.callback.PasswordCallback
-import javax.security.auth.callback.UnsupportedCallbackException
-import javax.security.sasl.RealmCallback
-import javax.security.sasl.RealmChoiceCallback
-import javax.security.sasl.Sasl
-import javax.security.sasl.SaslClient
-import javax.security.sasl.SaslException
-
-import scala.collection.JavaConversions.mapAsJavaMap
-
-import com.google.common.base.Charsets.UTF_8
-
-/**
- * Implements SASL Client logic for Spark
- */
-private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging {
-
- /**
- * Used to respond to server's counterpart, SaslServer with SASL tokens
- * represented as byte arrays.
- *
- * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
- * configurable in the future.
- */
- private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST),
- null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
- new SparkSaslClientCallbackHandler(securityMgr))
-
- /**
- * Used to initiate SASL handshake with server.
- * @return response to challenge if needed
- */
- def firstToken(): Array[Byte] = {
- synchronized {
- val saslToken: Array[Byte] =
- if (saslClient != null && saslClient.hasInitialResponse()) {
- logDebug("has initial response")
- saslClient.evaluateChallenge(new Array[Byte](0))
- } else {
- new Array[Byte](0)
- }
- saslToken
- }
- }
-
- /**
- * Determines whether the authentication exchange has completed.
- * @return true is complete, otherwise false
- */
- def isComplete(): Boolean = {
- synchronized {
- if (saslClient != null) saslClient.isComplete() else false
- }
- }
-
- /**
- * Respond to server's SASL token.
- * @param saslTokenMessage contains server's SASL token
- * @return client's response SASL token
- */
- def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = {
- synchronized {
- if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0)
- }
- }
-
- /**
- * Disposes of any system resources or security-sensitive information the
- * SaslClient might be using.
- */
- def dispose() {
- synchronized {
- if (saslClient != null) {
- try {
- saslClient.dispose()
- } catch {
- case e: SaslException => // ignored
- } finally {
- saslClient = null
- }
- }
- }
- }
-
- /**
- * Implementation of javax.security.auth.callback.CallbackHandler
- * that works with share secrets.
- */
- private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends
- CallbackHandler {
-
- private val userName: String =
- SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8))
- private val secretKey = securityMgr.getSecretKey()
- private val userPassword: Array[Char] = SparkSaslServer.encodePassword(
- if (secretKey != null) secretKey.getBytes(UTF_8) else "".getBytes(UTF_8))
-
- /**
- * Implementation used to respond to SASL request from the server.
- *
- * @param callbacks objects that indicate what credential information the
- * server's SaslServer requires from the client.
- */
- override def handle(callbacks: Array[Callback]) {
- logDebug("in the sasl client callback handler")
- callbacks foreach {
- case nc: NameCallback => {
- logDebug("handle: SASL client callback: setting username: " + userName)
- nc.setName(userName)
- }
- case pc: PasswordCallback => {
- logDebug("handle: SASL client callback: setting userPassword")
- pc.setPassword(userPassword)
- }
- case rc: RealmCallback => {
- logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText())
- rc.setText(rc.getDefaultText())
- }
- case cb: RealmChoiceCallback => {}
- case cb: Callback => throw
- new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback")
- }
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
deleted file mode 100644
index 7c2afb3646..0000000000
--- a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
+++ /dev/null
@@ -1,176 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import javax.security.auth.callback.Callback
-import javax.security.auth.callback.CallbackHandler
-import javax.security.auth.callback.NameCallback
-import javax.security.auth.callback.PasswordCallback
-import javax.security.auth.callback.UnsupportedCallbackException
-import javax.security.sasl.AuthorizeCallback
-import javax.security.sasl.RealmCallback
-import javax.security.sasl.Sasl
-import javax.security.sasl.SaslException
-import javax.security.sasl.SaslServer
-import scala.collection.JavaConversions.mapAsJavaMap
-
-import com.google.common.base.Charsets.UTF_8
-import org.apache.commons.net.util.Base64
-
-/**
- * Encapsulates SASL server logic
- */
-private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging {
-
- /**
- * Actual SASL work done by this object from javax.security.sasl.
- */
- private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null,
- SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
- new SparkSaslDigestCallbackHandler(securityMgr))
-
- /**
- * Determines whether the authentication exchange has completed.
- * @return true is complete, otherwise false
- */
- def isComplete(): Boolean = {
- synchronized {
- if (saslServer != null) saslServer.isComplete() else false
- }
- }
-
- /**
- * Used to respond to server SASL tokens.
- * @param token Server's SASL token
- * @return response to send back to the server.
- */
- def response(token: Array[Byte]): Array[Byte] = {
- synchronized {
- if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0)
- }
- }
-
- /**
- * Disposes of any system resources or security-sensitive information the
- * SaslServer might be using.
- */
- def dispose() {
- synchronized {
- if (saslServer != null) {
- try {
- saslServer.dispose()
- } catch {
- case e: SaslException => // ignore
- } finally {
- saslServer = null
- }
- }
- }
- }
-
- /**
- * Implementation of javax.security.auth.callback.CallbackHandler
- * for SASL DIGEST-MD5 mechanism
- */
- private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager)
- extends CallbackHandler {
-
- private val userName: String =
- SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8))
-
- override def handle(callbacks: Array[Callback]) {
- logDebug("In the sasl server callback handler")
- callbacks foreach {
- case nc: NameCallback => {
- logDebug("handle: SASL server callback: setting username")
- nc.setName(userName)
- }
- case pc: PasswordCallback => {
- logDebug("handle: SASL server callback: setting userPassword")
- val password: Array[Char] =
- SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes(UTF_8))
- pc.setPassword(password)
- }
- case rc: RealmCallback => {
- logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText())
- rc.setText(rc.getDefaultText())
- }
- case ac: AuthorizeCallback => {
- val authid = ac.getAuthenticationID()
- val authzid = ac.getAuthorizationID()
- if (authid.equals(authzid)) {
- logDebug("set auth to true")
- ac.setAuthorized(true)
- } else {
- logDebug("set auth to false")
- ac.setAuthorized(false)
- }
- if (ac.isAuthorized()) {
- logDebug("sasl server is authorized")
- ac.setAuthorizedID(authzid)
- }
- }
- case cb: Callback => throw
- new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback")
- }
- }
- }
-}
-
-private[spark] object SparkSaslServer {
-
- /**
- * This is passed as the server name when creating the sasl client/server.
- * This could be changed to be configurable in the future.
- */
- val SASL_DEFAULT_REALM = "default"
-
- /**
- * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
- * configurable in the future.
- */
- val DIGEST = "DIGEST-MD5"
-
- /**
- * The quality of protection is just "auth". This means that we are doing
- * authentication only, we are not supporting integrity or privacy protection of the
- * communication channel after authentication. This could be changed to be configurable
- * in the future.
- */
- val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH ->"true")
-
- /**
- * Encode a byte[] identifier as a Base64-encoded string.
- *
- * @param identifier identifier to encode
- * @return Base64-encoded string
- */
- def encodeIdentifier(identifier: Array[Byte]): String = {
- new String(Base64.encodeBase64(identifier), UTF_8)
- }
-
- /**
- * Encode a password as a base64-encoded char[] array.
- * @param password as a byte array.
- * @return password as a char array.
- */
- def encodePassword(password: Array[Byte]): Array[Char] = {
- new String(Base64.encodeBase64(password), UTF_8).toCharArray()
- }
-}
-
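
The two deleted Scala helpers are replaced by Java classes of the same names under network/shuffle (see the diffstat above), so the nio path and the new netty path share one SASL implementation. Going only by the constructors and methods this patch exercises (firstToken, response, isComplete, dispose), a purely local handshake between the two sides would look roughly like this sketch; the server-side constructor and dispose surface are assumed to match the removed Scala class:

    import org.apache.spark.{SecurityManager, SparkConf}
    import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer}

    val conf = new SparkConf()
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "good")  // placeholder shared secret
      .set("spark.app.id", "app-id")
    val keyHolder = new SecurityManager(conf)

    val client = new SparkSaslClient("app-id", keyHolder)
    val server = new SparkSaslServer("app-id", keyHolder)

    // DIGEST-MD5 challenge/response loop: each side evaluates the other's token in turn.
    var token = client.firstToken()
    while (!client.isComplete()) {
      token = client.response(server.response(token))
    }
    client.dispose()
    server.dispose()
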
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 8b095e23f3..abc1dd0be6 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -86,6 +86,7 @@ private[spark] class Executor(
conf, executorId, slaveHostname, port, isLocal, actorSystem)
SparkEnv.set(_env)
_env.metricsSystem.registerSource(executorSource)
+ _env.blockManager.initialize(conf.getAppId)
_env
} else {
SparkEnv.get
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 1c4327cf13..0d1fc81d2a 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -17,13 +17,15 @@
package org.apache.spark.network.netty
+import scala.collection.JavaConversions._
import scala.concurrent.{Future, Promise}
-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.network._
import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.network.client.{RpcResponseCallback, TransportClientFactory}
+import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory}
import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock}
+import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap}
import org.apache.spark.network.server._
import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher}
import org.apache.spark.serializer.JavaSerializer
@@ -33,18 +35,30 @@ import org.apache.spark.util.Utils
/**
* A BlockTransferService that uses Netty to fetch a set of blocks at a time.
*/
-class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService {
+class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager)
+ extends BlockTransferService {
+
// TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
- val serializer = new JavaSerializer(conf)
+ private val serializer = new JavaSerializer(conf)
+ private val authEnabled = securityManager.isAuthenticationEnabled()
+ private val transportConf = SparkTransportConf.fromSparkConf(conf)
private[this] var transportContext: TransportContext = _
private[this] var server: TransportServer = _
private[this] var clientFactory: TransportClientFactory = _
override def init(blockDataManager: BlockDataManager): Unit = {
- val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
- transportContext = new TransportContext(SparkTransportConf.fromSparkConf(conf), rpcHandler)
- clientFactory = transportContext.createClientFactory()
+ val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = {
+ val nettyRpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
+ if (!authEnabled) {
+ (nettyRpcHandler, None)
+ } else {
+ (new SaslRpcHandler(nettyRpcHandler, securityManager),
+ Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager)))
+ }
+ }
+ transportContext = new TransportContext(transportConf, rpcHandler)
+ clientFactory = transportContext.createClientFactory(bootstrap.toList)
server = transportContext.createServer()
logInfo("Server created on " + server.getPort)
}
diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
index 4f6f5e2358..c2d9578be7 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
@@ -23,12 +23,13 @@ import java.nio.channels._
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.LinkedList
-import org.apache.spark._
-
import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.control.NonFatal
+import org.apache.spark._
+import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer}
+
private[nio]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId,
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
index 8408b75bb4..f198aa8564 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -34,6 +34,7 @@ import scala.language.postfixOps
import com.google.common.base.Charsets.UTF_8
import org.apache.spark._
+import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer}
import org.apache.spark.util.Utils
import scala.util.Try
@@ -600,7 +601,7 @@ private[nio] class ConnectionManager(
} else {
var replyToken : Array[Byte] = null
try {
- replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken)
+ replyToken = waitingConn.sparkSaslClient.response(securityMsg.getToken)
if (waitingConn.isSaslComplete()) {
logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
connectionsAwaitingSasl -= waitingConn.connectionId
@@ -634,7 +635,7 @@ private[nio] class ConnectionManager(
connection.synchronized {
if (connection.sparkSaslServer == null) {
logDebug("Creating sasl Server")
- connection.sparkSaslServer = new SparkSaslServer(securityManager)
+ connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager)
}
}
replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
@@ -778,7 +779,7 @@ private[nio] class ConnectionManager(
if (!conn.isSaslComplete()) {
conn.synchronized {
if (conn.sparkSaslClient == null) {
- conn.sparkSaslClient = new SparkSaslClient(securityManager)
+ conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager)
var firstResponse: Array[Byte] = null
try {
firstResponse = conn.sparkSaslClient.firstToken()
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 5f5dd0dc1c..655d16c65c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -57,6 +57,12 @@ private[spark] class BlockResult(
inputMetrics.bytesRead = bytes
}
+/**
+ * Manager running on every node (driver and executors) which provides interfaces for putting and
+ * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap).
+ *
+ * Note that #initialize() must be called before the BlockManager is usable.
+ */
private[spark] class BlockManager(
executorId: String,
actorSystem: ActorSystem,
@@ -69,8 +75,6 @@ private[spark] class BlockManager(
blockTransferService: BlockTransferService)
extends BlockDataManager with Logging {
- blockTransferService.init(this)
-
val diskBlockManager = new DiskBlockManager(this, conf)
private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
@@ -102,22 +106,16 @@ private[spark] class BlockManager(
+ " switch to sort-based shuffle.")
}
- val blockManagerId = BlockManagerId(
- executorId, blockTransferService.hostName, blockTransferService.port)
+ var blockManagerId: BlockManagerId = _
// Address of the server that serves this executor's shuffle files. This is either an external
// service, or just our own Executor's BlockManager.
- private[spark] val shuffleServerId = if (externalShuffleServiceEnabled) {
- BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort)
- } else {
- blockManagerId
- }
+ private[spark] var shuffleServerId: BlockManagerId = _
// Client to read other executors' shuffle files. This is either an external service, or just the
// standard BlockTranserService to directly connect to other Executors.
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
- val appId = conf.get("spark.app.id", "unknown-app-id")
- new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), appId)
+ new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf))
} else {
blockTransferService
}
@@ -150,8 +148,6 @@ private[spark] class BlockManager(
private val peerFetchLock = new Object
private var lastPeerFetchTime = 0L
- initialize()
-
/* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
* the initialization of the compression codec until it is first used. The reason is that a Spark
* program could be using a user-defined codec in a third party jar, which is loaded in
@@ -176,10 +172,27 @@ private[spark] class BlockManager(
}
/**
- * Initialize the BlockManager. Register to the BlockManagerMaster, and start the
- * BlockManagerWorker actor. Additionally registers with a local shuffle service if configured.
+ * Initializes the BlockManager with the given appId. This is not performed in the constructor as
+ * the appId may not be known at BlockManager instantiation time (in particular for the driver,
+ * where it is only learned after registration with the TaskScheduler).
+ *
+ * This method initializes the BlockTransferService and ShuffleClient, registers with the
+ * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle
+ * service if configured.
*/
- private def initialize(): Unit = {
+ def initialize(appId: String): Unit = {
+ blockTransferService.init(this)
+ shuffleClient.init(appId)
+
+ blockManagerId = BlockManagerId(
+ executorId, blockTransferService.hostName, blockTransferService.port)
+
+ shuffleServerId = if (externalShuffleServiceEnabled) {
+ BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort)
+ } else {
+ blockManagerId
+ }
+
master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
// Register Executors' configuration with the local shuffle service, if one should exist.
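
Construction and initialization are now distinct steps: the constructor only wires dependencies, while initialize(appId) touches the network, registers with the master, and starts the shuffle client. The test suites later in this patch follow the same two-phase pattern, roughly (a sketch assuming the convenience constructor with an explicit max memory, as the tests use):

    // Phase 1: construct; nothing registers or binds yet.
    val store = new BlockManager(executorId, actorSystem, master, serializer, maxMem, conf,
      mapOutputTracker, shuffleManager, blockTransferService)

    // Phase 2: initialize once the application id is known.
    store.initialize("app-id")
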
diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
new file mode 100644
index 0000000000..bed0ed9d71
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.nio._
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.duration._
+import scala.concurrent.{Await, Promise}
+import scala.util.{Failure, Success, Try}
+
+import org.apache.commons.io.IOUtils
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.shuffle.BlockFetchingListener
+import org.apache.spark.network.{BlockDataManager, BlockTransferService}
+import org.apache.spark.storage.{BlockId, ShuffleBlockId}
+import org.apache.spark.{SecurityManager, SparkConf}
+import org.mockito.Mockito._
+import org.scalatest.mock.MockitoSugar
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, ShouldMatchers}
+
+class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers {
+ test("security default off") {
+ testConnection(new SparkConf, new SparkConf) match {
+ case Success(_) => // expected
+ case Failure(t) => fail(t)
+ }
+ }
+
+ test("security on same password") {
+ val conf = new SparkConf()
+ .set("spark.authenticate", "true")
+ .set("spark.authenticate.secret", "good")
+ .set("spark.app.id", "app-id")
+ testConnection(conf, conf) match {
+ case Success(_) => // expected
+ case Failure(t) => fail(t)
+ }
+ }
+
+ test("security on mismatch password") {
+ val conf0 = new SparkConf()
+ .set("spark.authenticate", "true")
+ .set("spark.authenticate.secret", "good")
+ .set("spark.app.id", "app-id")
+ val conf1 = conf0.clone.set("spark.authenticate.secret", "bad")
+ testConnection(conf0, conf1) match {
+ case Success(_) => fail("Should have failed")
+ case Failure(t) => t.getMessage should include ("Mismatched response")
+ }
+ }
+
+ test("security mismatch auth off on server") {
+ val conf0 = new SparkConf()
+ .set("spark.authenticate", "true")
+ .set("spark.authenticate.secret", "good")
+ .set("spark.app.id", "app-id")
+ val conf1 = conf0.clone.set("spark.authenticate", "false")
+ testConnection(conf0, conf1) match {
+ case Success(_) => fail("Should have failed")
+ case Failure(t) => // any funny error may occur, server will interpret SASL token as RPC
+ }
+ }
+
+ test("security mismatch auth off on client") {
+ val conf0 = new SparkConf()
+ .set("spark.authenticate", "false")
+ .set("spark.authenticate.secret", "good")
+ .set("spark.app.id", "app-id")
+ val conf1 = conf0.clone.set("spark.authenticate", "true")
+ testConnection(conf0, conf1) match {
+ case Success(_) => fail("Should have failed")
+ case Failure(t) => t.getMessage should include ("Expected SaslMessage")
+ }
+ }
+
+ test("security mismatch app ids") {
+ val conf0 = new SparkConf()
+ .set("spark.authenticate", "true")
+ .set("spark.authenticate.secret", "good")
+ .set("spark.app.id", "app-id")
+ val conf1 = conf0.clone.set("spark.app.id", "other-id")
+ testConnection(conf0, conf1) match {
+ case Success(_) => fail("Should have failed")
+ case Failure(t) => t.getMessage should include ("SASL appId app-id did not match")
+ }
+ }
+
+ /**
+ * Creates two servers with different configurations and sees if they can talk.
+ * Returns Success() if they can transfer a block, and Failure() if the block transfer failed
+ * properly. We will throw an out-of-band exception if something other than that goes wrong.
+ */
+ private def testConnection(conf0: SparkConf, conf1: SparkConf): Try[Unit] = {
+ val blockManager = mock[BlockDataManager]
+ val blockId = ShuffleBlockId(0, 1, 2)
+ val blockString = "Hello, world!"
+ val blockBuffer = new NioManagedBuffer(ByteBuffer.wrap(blockString.getBytes))
+ when(blockManager.getBlockData(blockId)).thenReturn(blockBuffer)
+
+ val securityManager0 = new SecurityManager(conf0)
+ val exec0 = new NettyBlockTransferService(conf0, securityManager0)
+ exec0.init(blockManager)
+
+ val securityManager1 = new SecurityManager(conf1)
+ val exec1 = new NettyBlockTransferService(conf1, securityManager1)
+ exec1.init(blockManager)
+
+ val result = fetchBlock(exec0, exec1, "1", blockId) match {
+ case Success(buf) =>
+ IOUtils.toString(buf.createInputStream()) should equal(blockString)
+ buf.release()
+ Success()
+ case Failure(t) =>
+ Failure(t)
+ }
+ exec0.close()
+ exec1.close()
+ result
+ }
+
+ /** Synchronously fetches a single block, acting as the given executor fetching from another. */
+ private def fetchBlock(
+ self: BlockTransferService,
+ from: BlockTransferService,
+ execId: String,
+ blockId: BlockId): Try[ManagedBuffer] = {
+
+ val promise = Promise[ManagedBuffer]()
+
+ self.fetchBlocks(from.hostName, from.port, execId, Array(blockId.toString),
+ new BlockFetchingListener {
+ override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
+ promise.failure(exception)
+ }
+
+ override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
+ promise.success(data.retain())
+ }
+ })
+
+ Await.ready(promise.future, FiniteDuration(1000, TimeUnit.MILLISECONDS))
+ promise.future.value.get
+ }
+}
+
diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
index b70734dfe3..716f875d30 100644
--- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
@@ -60,6 +60,7 @@ class ConnectionManagerSuite extends FunSuite {
val conf = new SparkConf
conf.set("spark.authenticate", "true")
conf.set("spark.authenticate.secret", "good")
+ conf.set("spark.app.id", "app-id")
val securityManager = new SecurityManager(conf)
val manager = new ConnectionManager(0, conf, securityManager)
var numReceivedMessages = 0
@@ -95,6 +96,7 @@ class ConnectionManagerSuite extends FunSuite {
test("security mismatch password") {
val conf = new SparkConf
conf.set("spark.authenticate", "true")
+ conf.set("spark.app.id", "app-id")
conf.set("spark.authenticate.secret", "good")
val securityManager = new SecurityManager(conf)
val manager = new ConnectionManager(0, conf, securityManager)
@@ -105,9 +107,7 @@ class ConnectionManagerSuite extends FunSuite {
None
})
- val badconf = new SparkConf
- badconf.set("spark.authenticate", "true")
- badconf.set("spark.authenticate.secret", "bad")
+ val badconf = conf.clone.set("spark.authenticate.secret", "bad")
val badsecurityManager = new SecurityManager(badconf)
val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
var numReceivedServerMessages = 0
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index c6d7105592..1461fa69db 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -63,6 +63,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd
val transfer = new NioBlockTransferService(conf, securityMgr)
val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
mapOutputTracker, shuffleManager, transfer)
+ store.initialize("app-id")
allStores += store
store
}
@@ -263,6 +264,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd
when(failableTransfer.port).thenReturn(1000)
val failableStore = new BlockManager("failable-store", actorSystem, master, serializer,
10000, conf, mapOutputTracker, shuffleManager, failableTransfer)
+ failableStore.initialize("app-id")
allStores += failableStore // so that this gets stopped after test
assert(master.getPeers(store.blockManagerId).toSet === Set(failableStore.blockManagerId))
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 715b740b85..0782876c8e 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -73,8 +73,10 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
maxMem: Long,
name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
val transfer = new NioBlockTransferService(conf, securityMgr)
- new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
+ val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
mapOutputTracker, shuffleManager, transfer)
+ manager.initialize("app-id")
+ manager
}
before {
diff --git a/docs/security.md b/docs/security.md
index ec0523184d..1e206a139f 100644
--- a/docs/security.md
+++ b/docs/security.md
@@ -7,7 +7,6 @@ Spark currently supports authentication via a shared secret. Authentication can
* For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret.
* For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications.
-* **IMPORTANT NOTE:** *The experimental Netty shuffle path (`spark.shuffle.use.netty`) is not secured, so do not use Netty for shuffles if running with authentication.*
## Web UI
diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
index a271841e4e..5bc6e5a241 100644
--- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -17,12 +17,16 @@
package org.apache.spark.network;
+import java.util.List;
+
+import com.google.common.collect.Lists;
import io.netty.channel.Channel;
import io.netty.channel.socket.SocketChannel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.client.TransportResponseHandler;
import org.apache.spark.network.protocol.MessageDecoder;
@@ -64,8 +68,17 @@ public class TransportContext {
this.decoder = new MessageDecoder();
}
+ /**
+ * Initializes a ClientFactory which runs the given TransportClientBootstraps prior to returning
+ * a new Client. Bootstraps will be executed synchronously, and must run successfully in order
+ * to create a Client.
+ */
+ public TransportClientFactory createClientFactory(List<TransportClientBootstrap> bootstraps) {
+ return new TransportClientFactory(this, bootstraps);
+ }
+
public TransportClientFactory createClientFactory() {
- return new TransportClientFactory(this);
+ return createClientFactory(Lists.<TransportClientBootstrap>newArrayList());
}
/** Create a server which will attempt to bind to a specific port. */
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index 01c143fff4..a08cee02dd 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -19,10 +19,9 @@ package org.apache.spark.network.client;
import java.io.Closeable;
import java.util.UUID;
-import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
-import java.util.concurrent.TimeoutException;
+import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.util.concurrent.SettableFuture;
@@ -186,4 +185,12 @@ public class TransportClient implements Closeable {
// close is a local operation and should finish with milliseconds; timeout just to be safe
channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
}
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("remoteAdress", channel.remoteAddress())
+ .add("isActive", isActive())
+ .toString();
+ }
}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
new file mode 100644
index 0000000000..65e8020e34
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * A bootstrap which is executed on a TransportClient before it is returned to the user.
+ * This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per-
+ * connection basis.
+ *
+ * Since connections (and TransportClients) are reused as much as possible, it is generally
+ * reasonable to perform an expensive bootstrapping operation, as they often share a lifespan with
+ * the JVM itself.
+ */
+public interface TransportClientBootstrap {
+ /** Performs the bootstrapping operation, throwing an exception on failure. */
+ public void doBootstrap(TransportClient client) throws RuntimeException;
+}
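
Because the interface is a single method, any once-per-connection setup can be expressed as a bootstrap; SaslClientBootstrap (added later in this patch) is the first real implementation. A trivial hedged sketch, with a made-up ping payload and timeout:

    import org.apache.spark.network.client.{TransportClient, TransportClientBootstrap}

    // Sketch: send one blocking RPC when a connection is first established.
    class PingBootstrap extends TransportClientBootstrap {
      override def doBootstrap(client: TransportClient): Unit = {
        val response = client.sendRpcSync("ping".getBytes("UTF-8"), 10000 /* ms, arbitrary */)
        require(response != null, "bootstrap ping got no response")
      }
    }
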
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 0b4a1d8286..1723fed307 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -21,10 +21,14 @@ import java.io.Closeable;
import java.lang.reflect.Field;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
+import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.collect.Lists;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.Channel;
@@ -40,6 +44,7 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.server.TransportChannelHandler;
import org.apache.spark.network.util.IOMode;
+import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.TransportConf;
@@ -47,22 +52,29 @@ import org.apache.spark.network.util.TransportConf;
* Factory for creating {@link TransportClient}s by using createClient.
*
* The factory maintains a connection pool to other hosts and should return the same
- * {@link TransportClient} for the same remote host. It also shares a single worker thread pool for
- * all {@link TransportClient}s.
+ * TransportClient for the same remote host. It also shares a single worker thread pool for
+ * all TransportClients.
+ *
+ * TransportClients will be reused whenever possible. Prior to completing the creation of a new
+ * TransportClient, all given {@link TransportClientBootstrap}s will be run.
*/
public class TransportClientFactory implements Closeable {
private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class);
private final TransportContext context;
private final TransportConf conf;
+ private final List<TransportClientBootstrap> clientBootstraps;
private final ConcurrentHashMap<SocketAddress, TransportClient> connectionPool;
private final Class<? extends Channel> socketChannelClass;
private EventLoopGroup workerGroup;
- public TransportClientFactory(TransportContext context) {
- this.context = context;
+ public TransportClientFactory(
+ TransportContext context,
+ List<TransportClientBootstrap> clientBootstraps) {
+ this.context = Preconditions.checkNotNull(context);
this.conf = context.getConf();
+ this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
this.connectionPool = new ConcurrentHashMap<SocketAddress, TransportClient>();
IOMode ioMode = IOMode.valueOf(conf.ioMode());
@@ -72,9 +84,12 @@ public class TransportClientFactory implements Closeable {
}
/**
- * Create a new BlockFetchingClient connecting to the given remote host / port.
+ * Create a new {@link TransportClient} connecting to the given remote host / port. This will
+ * reuse TransportClients if they are still active and are for the same remote address. Prior
+ * to the creation of a new TransportClient, we will execute all {@link TransportClientBootstrap}s
+ * that are registered with this factory.
*
- * This blocks until a connection is successfully established.
+ * This blocks until a connection is successfully established and fully bootstrapped.
*
* Concurrency: This method is safe to call from multiple threads.
*/
@@ -104,17 +119,18 @@ public class TransportClientFactory implements Closeable {
// Use pooled buffers to reduce temporary buffer allocation
bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator());
- final AtomicReference<TransportClient> client = new AtomicReference<TransportClient>();
+ final AtomicReference<TransportClient> clientRef = new AtomicReference<TransportClient>();
bootstrap.handler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) {
TransportChannelHandler clientHandler = context.initializePipeline(ch);
- client.set(clientHandler.getClient());
+ clientRef.set(clientHandler.getClient());
}
});
// Connect to the remote server
+ long preConnect = System.currentTimeMillis();
ChannelFuture cf = bootstrap.connect(address);
if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
throw new RuntimeException(
@@ -123,15 +139,35 @@ public class TransportClientFactory implements Closeable {
throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause());
}
- // Successful connection -- in the event that two threads raced to create a client, we will
+ TransportClient client = clientRef.get();
+ assert client != null : "Channel future completed successfully with null client";
+
+ // Execute any client bootstraps synchronously before marking the Client as successful.
+ long preBootstrap = System.currentTimeMillis();
+ logger.debug("Connection to {} successful, running bootstraps...", address);
+ try {
+ for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
+ clientBootstrap.doBootstrap(client);
+ }
+ } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
+ long bootstrapTime = System.currentTimeMillis() - preBootstrap;
+ logger.error("Exception while bootstrapping client after " + bootstrapTime + " ms", e);
+ client.close();
+ throw Throwables.propagate(e);
+ }
+ long postBootstrap = System.currentTimeMillis();
+
+ // Successful connection & bootstrap -- in the event that two threads raced to create a client,
// use the first one that was put into the connectionPool and close the one we made here.
- assert client.get() != null : "Channel future completed successfully with null client";
- TransportClient oldClient = connectionPool.putIfAbsent(address, client.get());
+ TransportClient oldClient = connectionPool.putIfAbsent(address, client);
if (oldClient == null) {
- return client.get();
+ logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
+ address, postBootstrap - preConnect, postBootstrap - preBootstrap);
+ return client;
} else {
- logger.debug("Two clients were created concurrently, second one will be disposed.");
- client.get().close();
+ logger.debug("Two clients were created concurrently after {} ms, second will be disposed.",
+ postBootstrap - preConnect);
+ client.close();
return oldClient;
}
}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
index 5a3f003726..1502b7489e 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
@@ -21,7 +21,7 @@ import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
/** An RpcHandler suitable for a client-only TransportContext, which cannot receive RPCs. */
-public class NoOpRpcHandler implements RpcHandler {
+public class NoOpRpcHandler extends RpcHandler {
private final StreamManager streamManager;
public NoOpRpcHandler() {
diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
index 2369dc6203..2ba92a40f8 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
@@ -23,22 +23,33 @@ import org.apache.spark.network.client.TransportClient;
/**
* Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s.
*/
-public interface RpcHandler {
+public abstract class RpcHandler {
/**
* Receive a single RPC message. Any exception thrown while in this method will be sent back to
* the client in string form as a standard RPC failure.
*
+ * This method will not be called in parallel for a single TransportClient (i.e., channel).
+ *
* @param client A channel client which enables the handler to make requests back to the sender
- * of this RPC.
+ * of this RPC. This will always be the exact same object for a particular channel.
* @param message The serialized bytes of the RPC.
* @param callback Callback which should be invoked exactly once upon success or failure of the
* RPC.
*/
- void receive(TransportClient client, byte[] message, RpcResponseCallback callback);
+ public abstract void receive(
+ TransportClient client,
+ byte[] message,
+ RpcResponseCallback callback);
/**
* Returns the StreamManager which contains the state about which streams are currently being
* fetched by a TransportClient.
*/
- StreamManager getStreamManager();
+ public abstract StreamManager getStreamManager();
+
+ /**
+ * Invoked when the connection associated with the given client has been invalidated.
+ * No further requests will come from this client.
+ */
+ public void connectionTerminated(TransportClient client) { }
}
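
Making RpcHandler an abstract class (rather than an interface) lets new hooks such as connectionTerminated ship with a no-op default, so existing handlers keep compiling unchanged. A hedged Scala sketch of a handler that also cleans up on teardown (the OneForOneStreamManager choice is only illustrative):

    import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
    import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}

    class EchoRpcHandler extends RpcHandler {
      private val streamManager = new OneForOneStreamManager()

      override def receive(
          client: TransportClient,
          message: Array[Byte],
          callback: RpcResponseCallback): Unit = {
        callback.onSuccess(message)  // echo the payload straight back
      }

      override def getStreamManager(): StreamManager = streamManager

      override def connectionTerminated(client: TransportClient): Unit = {
        // drop any per-connection state here; no further requests will arrive from this client
      }
    }
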
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index 17fe9001b3..1580180cc1 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -86,6 +86,7 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
for (long streamId : streamIds) {
streamManager.connectionTerminated(streamId);
}
+ rpcHandler.connectionTerminated(reverseClient);
}
@Override
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
index a68f38e0e9..823790dd3c 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -55,4 +55,7 @@ public class TransportConf {
/** Send buffer size (SO_SNDBUF). */
public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); }
+
+ /** Timeout for a single round trip of SASL token exchange, in milliseconds. */
+ public int saslRTTimeout() { return conf.getInt("spark.shuffle.sasl.timeout", 30000); }
}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
new file mode 100644
index 0000000000..7bc91e3753
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Bootstraps a {@link TransportClient} by performing SASL authentication on the connection. The
+ * server should be set up with a {@link SaslRpcHandler} with matching keys for the given appId.
+ */
+public class SaslClientBootstrap implements TransportClientBootstrap {
+ private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class);
+
+ private final TransportConf conf;
+ private final String appId;
+ private final SecretKeyHolder secretKeyHolder;
+
+ public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) {
+ this.conf = conf;
+ this.appId = appId;
+ this.secretKeyHolder = secretKeyHolder;
+ }
+
+ /**
+ * Performs SASL authentication by sending a token, and then proceeding with the SASL
+ * challenge-response tokens until we either successfully authenticate or throw an exception
+ * due to a mismatch.
+ */
+ @Override
+ public void doBootstrap(TransportClient client) {
+ SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder);
+ try {
+ byte[] payload = saslClient.firstToken();
+
+ while (!saslClient.isComplete()) {
+ SaslMessage msg = new SaslMessage(appId, payload);
+ ByteBuf buf = Unpooled.buffer(msg.encodedLength());
+ msg.encode(buf);
+
+ byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeout());
+ payload = saslClient.response(response);
+ }
+ } finally {
+ try {
+ // Once authentication is complete, the server will trust all remaining communication.
+ saslClient.dispose();
+ } catch (RuntimeException e) {
+ logger.error("Error while disposing SASL client", e);
+ }
+ }
+ }
+}
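
Client-side wiring, for reference, mirrors SaslIntegrationSuite further down. This sketch is hypothetical: the host, port, class name, and inline SecretKeyHolder are placeholders, and it assumes the server wraps its handler in a SaslRpcHandler.

    import com.google.common.collect.Lists;

    import org.apache.spark.network.TransportContext;
    import org.apache.spark.network.client.TransportClient;
    import org.apache.spark.network.client.TransportClientBootstrap;
    import org.apache.spark.network.client.TransportClientFactory;
    import org.apache.spark.network.sasl.SaslClientBootstrap;
    import org.apache.spark.network.sasl.SecretKeyHolder;
    import org.apache.spark.network.server.NoOpRpcHandler;
    import org.apache.spark.network.util.SystemPropertyConfigProvider;
    import org.apache.spark.network.util.TransportConf;

    public class SaslClientSetupExample {  // hypothetical example class
      public static void main(String[] args) throws Exception {
        TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
        TransportContext context = new TransportContext(conf, new NoOpRpcHandler());

        // Placeholder credentials; a real deployment supplies the per-app secret.
        SecretKeyHolder keys = new SecretKeyHolder() {
          public String getSaslUser(String appId) { return "sparkSaslUser"; }
          public String getSecretKey(String appId) { return "shared-secret"; }
        };

        // Every client created by this factory authenticates before it is handed out.
        TransportClientFactory factory = context.createClientFactory(
          Lists.<TransportClientBootstrap>newArrayList(
            new SaslClientBootstrap(conf, "app-id", keys)));
        TransportClient client = factory.createClient("shuffle-host", 7337);  // handshake runs here
        System.out.println("Authenticated client: " + client);
      }
    }
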
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
new file mode 100644
index 0000000000..5b77e18c26
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import com.google.common.base.Charsets;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encodable;
+
+/**
+ * Encodes a SASL-related message which is attempting to authenticate using some credentials tagged
+ * with the given appId. This appId allows a single SaslRpcHandler to multiplex different
+ * applications which may be using different sets of credentials.
+ */
+class SaslMessage implements Encodable {
+
+ /** Serialization tag used to catch incorrect payloads. */
+ private static final byte TAG_BYTE = (byte) 0xEA;
+
+ public final String appId;
+ public final byte[] payload;
+
+ public SaslMessage(String appId, byte[] payload) {
+ this.appId = appId;
+ this.payload = payload;
+ }
+
+ @Override
+ public int encodedLength() {
+ // tag + appIdLength + appId + payloadLength + payload
+ return 1 + 4 + appId.getBytes(Charsets.UTF_8).length + 4 + payload.length;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeByte(TAG_BYTE);
+ byte[] idBytes = appId.getBytes(Charsets.UTF_8);
+ buf.writeInt(idBytes.length);
+ buf.writeBytes(idBytes);
+ buf.writeInt(payload.length);
+ buf.writeBytes(payload);
+ }
+
+ public static SaslMessage decode(ByteBuf buf) {
+ if (buf.readByte() != TAG_BYTE) {
+ throw new IllegalStateException("Expected SaslMessage, received something else");
+ }
+
+ int idLength = buf.readInt();
+ byte[] idBytes = new byte[idLength];
+ buf.readBytes(idBytes);
+
+ int payloadLength = buf.readInt();
+ byte[] payload = new byte[payloadLength];
+ buf.readBytes(payload);
+
+ return new SaslMessage(new String(idBytes, Charsets.UTF_8), payload);
+ }
+}
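
The wire format is a tag byte, a length-prefixed UTF-8 appId, and a length-prefixed payload. A hypothetical round-trip is sketched below; the example class is not part of this change and lives in the same package because SaslMessage is package-private.

    package org.apache.spark.network.sasl;

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;

    public class SaslMessageRoundTrip {  // hypothetical example class
      public static void main(String[] args) {
        SaslMessage original = new SaslMessage("app-id", new byte[] { 1, 2, 3 });

        // Encode: tag byte, appId length + bytes, payload length + bytes.
        ByteBuf buf = Unpooled.buffer(original.encodedLength());
        original.encode(buf);

        // Decode reads the fields back in the same order.
        SaslMessage decoded = SaslMessage.decode(buf);
        System.out.println(decoded.appId + " carries " + decoded.payload.length + " payload bytes");
      }
    }
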
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
new file mode 100644
index 0000000000..3777a18e33
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.util.concurrent.ConcurrentMap;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.Maps;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+
+/**
+ * RPC Handler which performs SASL authentication before delegating to a child RPC handler.
+ * The delegate will only receive messages if the given connection has been successfully
+ * authenticated. A connection may be authenticated at most once.
+ *
+ * Note that the authentication process consists of multiple challenge-response pairs, each of
+ * which is an individual RPC.
+ */
+public class SaslRpcHandler extends RpcHandler {
+ private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
+
+ /** RpcHandler we will delegate to for authenticated connections. */
+ private final RpcHandler delegate;
+
+ /** Class which provides secret keys which are shared by server and client on a per-app basis. */
+ private final SecretKeyHolder secretKeyHolder;
+
+ /** Maps each channel to its SASL authentication state. */
+ private final ConcurrentMap<TransportClient, SparkSaslServer> channelAuthenticationMap;
+
+ public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) {
+ this.delegate = delegate;
+ this.secretKeyHolder = secretKeyHolder;
+ this.channelAuthenticationMap = Maps.newConcurrentMap();
+ }
+
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ SparkSaslServer saslServer = channelAuthenticationMap.get(client);
+ if (saslServer != null && saslServer.isComplete()) {
+ // Authentication complete, delegate to base handler.
+ delegate.receive(client, message, callback);
+ return;
+ }
+
+ SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message));
+
+ if (saslServer == null) {
+ // First message in the handshake, setup the necessary state.
+ saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder);
+ channelAuthenticationMap.put(client, saslServer);
+ }
+
+ byte[] response = saslServer.response(saslMessage.payload);
+ if (saslServer.isComplete()) {
+ logger.debug("SASL authentication successful for channel {}", client);
+ }
+ callback.onSuccess(response);
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return delegate.getStreamManager();
+ }
+
+ @Override
+ public void connectionTerminated(TransportClient client) {
+ SparkSaslServer saslServer = channelAuthenticationMap.remove(client);
+ if (saslServer != null) {
+ saslServer.dispose();
+ }
+ }
+}
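
Server-side, this handler is meant to wrap an application RpcHandler, exactly as SaslIntegrationSuite does below; a minimal sketch (the helper class name is hypothetical):

    import org.apache.spark.network.TransportContext;
    import org.apache.spark.network.sasl.SaslRpcHandler;
    import org.apache.spark.network.sasl.SecretKeyHolder;
    import org.apache.spark.network.server.RpcHandler;
    import org.apache.spark.network.server.TransportServer;
    import org.apache.spark.network.util.SystemPropertyConfigProvider;
    import org.apache.spark.network.util.TransportConf;

    public class SaslServerSetupExample {  // hypothetical example class
      /** Starts a server on which every channel must authenticate before appHandler sees a message. */
      public static TransportServer start(RpcHandler appHandler, SecretKeyHolder keys) {
        TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
        RpcHandler secured = new SaslRpcHandler(appHandler, keys);
        TransportContext context = new TransportContext(conf, secured);
        return context.createServer();
      }
    }
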
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java
new file mode 100644
index 0000000000..81d5766794
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+/**
+ * Interface for getting a secret key associated with some application.
+ */
+public interface SecretKeyHolder {
+ /**
+ * Gets an appropriate SASL user for the given appId.
+ * @throws IllegalArgumentException if the given appId is not associated with a SASL user.
+ */
+ String getSaslUser(String appId);
+
+ /**
+ * Gets an appropriate SASL secret key for the given appId.
+ * @throws IllegalArgumentException if the given appId is not associated with a SASL secret key.
+ */
+ String getSecretKey(String appId);
+}
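
A holder can be as simple as a map keyed by appId; the implementation below is a hypothetical sketch, not part of this change.

    import java.util.concurrent.ConcurrentHashMap;
    import java.util.concurrent.ConcurrentMap;

    import org.apache.spark.network.sasl.SecretKeyHolder;

    public class MapSecretKeyHolder implements SecretKeyHolder {  // hypothetical example class
      private final ConcurrentMap<String, String> secrets = new ConcurrentHashMap<String, String>();

      public void register(String appId, String secret) {
        secrets.put(appId, secret);
      }

      @Override
      public String getSaslUser(String appId) {
        return "sparkSaslUser";  // one logical user; only the secret varies per application
      }

      @Override
      public String getSecretKey(String appId) {
        String secret = secrets.get(appId);
        if (secret == null) {
          throw new IllegalArgumentException("Unknown appId: " + appId);
        }
        return secret;
      }
    }
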
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
new file mode 100644
index 0000000000..72ba737b99
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.RealmChoiceCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslException;
+import java.io.IOException;
+
+import com.google.common.base.Throwables;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static org.apache.spark.network.sasl.SparkSaslServer.*;
+
+/**
+ * A SASL Client for Spark which simply keeps track of the state of a single SASL session, from the
+ * initial state to the "authenticated" state. This client initializes the protocol via a
+ * firstToken, which is then followed by a set of challenges and responses.
+ */
+public class SparkSaslClient {
+ private final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class);
+
+ private final String secretKeyId;
+ private final SecretKeyHolder secretKeyHolder;
+ private SaslClient saslClient;
+
+ public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder) {
+ this.secretKeyId = secretKeyId;
+ this.secretKeyHolder = secretKeyHolder;
+ try {
+ this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM,
+ SASL_PROPS, new ClientCallbackHandler());
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /** Used to initiate SASL handshake with server. */
+ public synchronized byte[] firstToken() {
+ if (saslClient != null && saslClient.hasInitialResponse()) {
+ try {
+ return saslClient.evaluateChallenge(new byte[0]);
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ } else {
+ return new byte[0];
+ }
+ }
+
+ /** Determines whether the authentication exchange has completed. */
+ public synchronized boolean isComplete() {
+ return saslClient != null && saslClient.isComplete();
+ }
+
+ /**
+ * Respond to server's SASL token.
+ * @param token contains server's SASL token
+ * @return client's response SASL token
+ */
+ public synchronized byte[] response(byte[] token) {
+ try {
+ return saslClient != null ? saslClient.evaluateChallenge(token) : new byte[0];
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /**
+ * Disposes of any system resources or security-sensitive information the
+ * SaslClient might be using.
+ */
+ public synchronized void dispose() {
+ if (saslClient != null) {
+ try {
+ saslClient.dispose();
+ } catch (SaslException e) {
+ // ignore
+ } finally {
+ saslClient = null;
+ }
+ }
+ }
+
+ /**
+ * Implementation of javax.security.auth.callback.CallbackHandler
+ * that works with shared secrets.
+ */
+ private class ClientCallbackHandler implements CallbackHandler {
+ @Override
+ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+
+ for (Callback callback : callbacks) {
+ if (callback instanceof NameCallback) {
+ logger.trace("SASL client callback: setting username");
+ NameCallback nc = (NameCallback) callback;
+ nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId)));
+ } else if (callback instanceof PasswordCallback) {
+ logger.trace("SASL client callback: setting password");
+ PasswordCallback pc = (PasswordCallback) callback;
+ pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId)));
+ } else if (callback instanceof RealmCallback) {
+ logger.trace("SASL client callback: setting realm");
+ RealmCallback rc = (RealmCallback) callback;
+ rc.setText(rc.getDefaultText());
+ logger.info("Realm callback");
+ } else if (callback instanceof RealmChoiceCallback) {
+ // ignore (?)
+ } else {
+ throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback");
+ }
+ }
+ }
+ }
+}
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
new file mode 100644
index 0000000000..2c0ce40c75
--- /dev/null
+++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import javax.security.auth.callback.Callback;
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.auth.callback.NameCallback;
+import javax.security.auth.callback.PasswordCallback;
+import javax.security.auth.callback.UnsupportedCallbackException;
+import javax.security.sasl.AuthorizeCallback;
+import javax.security.sasl.RealmCallback;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
+import java.io.IOException;
+import java.util.Map;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.io.BaseEncoding;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A SASL Server for Spark which simply keeps track of the state of a single SASL session, from the
+ * initial state to the "authenticated" state. (It is not a server in the sense of accepting
+ * connections on some socket.)
+ */
+public class SparkSaslServer {
+ private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class);
+
+ /**
+ * This is passed as the server name when creating the sasl client/server.
+ * This could be changed to be configurable in the future.
+ */
+ static final String DEFAULT_REALM = "default";
+
+ /**
+ * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
+ * configurable in the future.
+ */
+ static final String DIGEST = "DIGEST-MD5";
+
+ /**
+ * The quality of protection is just "auth". This means that we are doing
+ * authentication only; we are not supporting integrity or privacy protection of the
+ * communication channel after authentication. This could be changed to be configurable
+ * in the future.
+ */
+ static final Map<String, String> SASL_PROPS = ImmutableMap.<String, String>builder()
+ .put(Sasl.QOP, "auth")
+ .put(Sasl.SERVER_AUTH, "true")
+ .build();
+
+ /** Identifier for a certain secret key within the secretKeyHolder. */
+ private final String secretKeyId;
+ private final SecretKeyHolder secretKeyHolder;
+ private SaslServer saslServer;
+
+ public SparkSaslServer(String secretKeyId, SecretKeyHolder secretKeyHolder) {
+ this.secretKeyId = secretKeyId;
+ this.secretKeyHolder = secretKeyHolder;
+ try {
+ this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, SASL_PROPS,
+ new DigestCallbackHandler());
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /**
+ * Determines whether the authentication exchange has completed successfully.
+ */
+ public synchronized boolean isComplete() {
+ return saslServer != null && saslServer.isComplete();
+ }
+
+ /**
+ * Used to respond to the client's SASL token.
+ * @param token the client's SASL token
+ * @return response to send back to the client.
+ */
+ public synchronized byte[] response(byte[] token) {
+ try {
+ return saslServer != null ? saslServer.evaluateResponse(token) : new byte[0];
+ } catch (SaslException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /**
+ * Disposes of any system resources or security-sensitive information the
+ * SaslServer might be using.
+ */
+ public synchronized void dispose() {
+ if (saslServer != null) {
+ try {
+ saslServer.dispose();
+ } catch (SaslException e) {
+ // ignore
+ } finally {
+ saslServer = null;
+ }
+ }
+ }
+
+ /**
+ * Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism.
+ */
+ private class DigestCallbackHandler implements CallbackHandler {
+ @Override
+ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
+ for (Callback callback : callbacks) {
+ if (callback instanceof NameCallback) {
+ logger.trace("SASL server callback: setting username");
+ NameCallback nc = (NameCallback) callback;
+ nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId)));
+ } else if (callback instanceof PasswordCallback) {
+ logger.trace("SASL server callback: setting password");
+ PasswordCallback pc = (PasswordCallback) callback;
+ pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId)));
+ } else if (callback instanceof RealmCallback) {
+ logger.trace("SASL server callback: setting realm");
+ RealmCallback rc = (RealmCallback) callback;
+ rc.setText(rc.getDefaultText());
+ } else if (callback instanceof AuthorizeCallback) {
+ AuthorizeCallback ac = (AuthorizeCallback) callback;
+ String authId = ac.getAuthenticationID();
+ String authzId = ac.getAuthorizationID();
+ ac.setAuthorized(authId.equals(authzId));
+ if (ac.isAuthorized()) {
+ ac.setAuthorizedID(authzId);
+ }
+ logger.debug("SASL Authorization complete, authorized set to {}", ac.isAuthorized());
+ } else {
+ throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback");
+ }
+ }
+ }
+ }
+
+ /** Encode an identifier String as a Base64-encoded string. */
+ public static String encodeIdentifier(String identifier) {
+ Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled");
+ return BaseEncoding.base64().encode(identifier.getBytes(Charsets.UTF_8));
+ }
+
+ /** Encode a password as a Base64-encoded char array. */
+ public static char[] encodePassword(String password) {
+ Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled");
+ return BaseEncoding.base64().encode(password.getBytes(Charsets.UTF_8)).toCharArray();
+ }
+}
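
Client and server can be exercised together in a single JVM, much as SparkSaslSuite does below; a hypothetical local handshake with placeholder credentials:

    import org.apache.spark.network.sasl.SecretKeyHolder;
    import org.apache.spark.network.sasl.SparkSaslClient;
    import org.apache.spark.network.sasl.SparkSaslServer;

    public class LocalSaslHandshakeExample {  // hypothetical example class
      public static void main(String[] args) {
        SecretKeyHolder keys = new SecretKeyHolder() {
          public String getSaslUser(String appId) { return "user"; }
          public String getSecretKey(String appId) { return "shared-secret"; }  // placeholder secret
        };

        SparkSaslClient client = new SparkSaslClient("app-id", keys);
        SparkSaslServer server = new SparkSaslServer("app-id", keys);

        byte[] token = client.firstToken();
        while (!client.isComplete()) {
          // Each iteration corresponds to one challenge-response RPC over the real transport.
          token = client.response(server.response(token));
        }
        System.out.println("server complete: " + server.isComplete());

        client.dispose();
        server.dispose();
      }
    }
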
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index a9dff31dec..cd3fea85b1 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -41,7 +41,7 @@ import org.apache.spark.network.util.JavaUtils;
* with the "one-for-one" strategy, meaning each Transport-layer Chunk is equivalent to one Spark-
* level shuffle block.
*/
-public class ExternalShuffleBlockHandler implements RpcHandler {
+public class ExternalShuffleBlockHandler extends RpcHandler {
private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class);
private final ExternalShuffleBlockManager blockManager;
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index 6bbabc44b9..b0b19ba67b 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -17,8 +17,6 @@
package org.apache.spark.network.shuffle;
-import java.io.Closeable;
-
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -36,15 +34,20 @@ import org.apache.spark.network.util.TransportConf;
* BlockTransferService), which has the downside of losing the shuffle data if we lose the
* executors.
*/
-public class ExternalShuffleClient implements ShuffleClient {
+public class ExternalShuffleClient extends ShuffleClient {
private final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class);
private final TransportClientFactory clientFactory;
- private final String appId;
- public ExternalShuffleClient(TransportConf conf, String appId) {
+ private String appId;
+
+ public ExternalShuffleClient(TransportConf conf) {
TransportContext context = new TransportContext(conf, new NoOpRpcHandler());
this.clientFactory = context.createClientFactory();
+ }
+
+ @Override
+ public void init(String appId) {
this.appId = appId;
}
@@ -55,6 +58,7 @@ public class ExternalShuffleClient implements ShuffleClient {
String execId,
String[] blockIds,
BlockFetchingListener listener) {
+ assert appId != null : "Called before init()";
logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
try {
TransportClient client = clientFactory.createClient(host, port);
@@ -82,6 +86,7 @@ public class ExternalShuffleClient implements ShuffleClient {
int port,
String execId,
ExecutorShuffleInfo executorInfo) {
+ assert appId != null : "Called before init()";
TransportClient client = clientFactory.createClient(host, port);
byte[] registerExecutorMessage =
JavaUtils.serialize(new RegisterExecutor(appId, execId, executorInfo));
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
index d46a562394..f72ab40690 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
@@ -20,7 +20,14 @@ package org.apache.spark.network.shuffle;
import java.io.Closeable;
/** Provides an interface for reading shuffle files, either from an Executor or external service. */
-public interface ShuffleClient extends Closeable {
+public abstract class ShuffleClient implements Closeable {
+
+ /**
+ * Initializes the ShuffleClient, specifying this Executor's appId.
+ * Must be called before any other method on the ShuffleClient.
+ */
+ public void init(String appId) { }
+
/**
* Fetch a sequence of blocks from a remote node asynchronously,
*
@@ -28,7 +35,7 @@ public interface ShuffleClient extends Closeable {
* return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as
* the data of a block is fetched, rather than waiting for all blocks to be fetched.
*/
- public void fetchBlocks(
+ public abstract void fetchBlocks(
String host,
int port,
String execId,
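
Typical use of the reshaped API, mirroring the updated tests below: construct with only the conf, call init(appId) once, then fetch. The class name and arguments here are hypothetical placeholders.

    import org.apache.spark.network.shuffle.BlockFetchingListener;
    import org.apache.spark.network.shuffle.ExternalShuffleClient;
    import org.apache.spark.network.util.SystemPropertyConfigProvider;
    import org.apache.spark.network.util.TransportConf;

    public class ShuffleFetchExample {  // hypothetical example class
      public static void fetch(String host, int port, String execId,
          String[] blockIds, BlockFetchingListener listener) {
        TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
        ExternalShuffleClient client = new ExternalShuffleClient(conf);
        client.init("app-id");  // must precede fetchBlocks() or registerWithShuffleServer()
        client.fetchBlocks(host, port, execId, blockIds, listener);
        // Caller closes the client once the listener has observed all blocks.
      }
    }
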
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
new file mode 100644
index 0000000000..8478120786
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.io.IOException;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.TestUtils;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class SaslIntegrationSuite {
+ static ExternalShuffleBlockHandler handler;
+ static TransportServer server;
+ static TransportConf conf;
+ static TransportContext context;
+
+ TransportClientFactory clientFactory;
+
+ /** Provides a secret key holder which always returns the given secret key. */
+ static class TestSecretKeyHolder implements SecretKeyHolder {
+
+ private final String secretKey;
+
+ TestSecretKeyHolder(String secretKey) {
+ this.secretKey = secretKey;
+ }
+
+ @Override
+ public String getSaslUser(String appId) {
+ return "user";
+ }
+ @Override
+ public String getSecretKey(String appId) {
+ return secretKey;
+ }
+ }
+
+
+ @BeforeClass
+ public static void beforeAll() throws IOException {
+ SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key");
+ SaslRpcHandler handler = new SaslRpcHandler(new TestRpcHandler(), secretKeyHolder);
+ conf = new TransportConf(new SystemPropertyConfigProvider());
+ context = new TransportContext(conf, handler);
+ server = context.createServer();
+ }
+
+
+ @AfterClass
+ public static void afterAll() {
+ server.close();
+ }
+
+ @After
+ public void afterEach() {
+ if (clientFactory != null) {
+ clientFactory.close();
+ clientFactory = null;
+ }
+ }
+
+ @Test
+ public void testGoodClient() {
+ clientFactory = context.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key"))));
+
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ String msg = "Hello, World!";
+ byte[] resp = client.sendRpcSync(msg.getBytes(), 1000);
+ assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg
+ }
+
+ @Test
+ public void testBadClient() {
+ clientFactory = context.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("bad-key"))));
+
+ try {
+ // Bootstrap should fail on startup.
+ clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ } catch (Exception e) {
+ assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
+ }
+ }
+
+ @Test
+ public void testNoSaslClient() {
+ clientFactory = context.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList());
+
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ try {
+ client.sendRpcSync(new byte[13], 1000);
+ fail("Should have failed");
+ } catch (Exception e) {
+ assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage"));
+ }
+
+ try {
+ // Guessing the right tag byte doesn't magically get you in...
+ client.sendRpcSync(new byte[] { (byte) 0xEA }, 1000);
+ fail("Should have failed");
+ } catch (Exception e) {
+ assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException"));
+ }
+ }
+
+ @Test
+ public void testNoSaslServer() {
+ RpcHandler handler = new TestRpcHandler();
+ TransportContext context = new TransportContext(conf, handler);
+ clientFactory = context.createClientFactory(
+ Lists.<TransportClientBootstrap>newArrayList(
+ new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("key"))));
+ TransportServer server = context.createServer();
+ try {
+ clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ } catch (Exception e) {
+ assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation"));
+ } finally {
+ server.close();
+ }
+ }
+
+ /** RPC handler which simply responds with the message it received. */
+ public static class TestRpcHandler extends RpcHandler {
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ callback.onSuccess(message);
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return new OneForOneStreamManager();
+ }
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
new file mode 100644
index 0000000000..67a07f38eb
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.sasl;
+
+import java.util.Map;
+
+import com.google.common.collect.ImmutableMap;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+/**
+ * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes.
+ */
+public class SparkSaslSuite {
+
+ /** Provides a secret key holder which returns a secret key equal to the appId. */
+ private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() {
+ @Override
+ public String getSaslUser(String appId) {
+ return "user";
+ }
+
+ @Override
+ public String getSecretKey(String appId) {
+ return appId;
+ }
+ };
+
+ @Test
+ public void testMatching() {
+ SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder);
+ SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder);
+
+ assertFalse(client.isComplete());
+ assertFalse(server.isComplete());
+
+ byte[] clientMessage = client.firstToken();
+
+ while (!client.isComplete()) {
+ clientMessage = client.response(server.response(clientMessage));
+ }
+ assertTrue(server.isComplete());
+
+ // Disposal should invalidate
+ server.dispose();
+ assertFalse(server.isComplete());
+ client.dispose();
+ assertFalse(client.isComplete());
+ }
+
+
+ @Test
+ public void testNonMatching() {
+ SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder);
+ SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder);
+
+ assertFalse(client.isComplete());
+ assertFalse(server.isComplete());
+
+ byte[] clientMessage = client.firstToken();
+
+ try {
+ while (!client.isComplete()) {
+ clientMessage = client.response(server.response(clientMessage));
+ }
+ fail("Should not have completed");
+ } catch (Exception e) {
+ assertTrue(e.getMessage().contains("Mismatched response"));
+ assertFalse(client.isComplete());
+ assertFalse(server.isComplete());
+ }
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index b3bcf5fd68..bc101f5384 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -135,7 +135,8 @@ public class ExternalShuffleIntegrationSuite {
final Semaphore requestsRemaining = new Semaphore(0);
- ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID);
+ ExternalShuffleClient client = new ExternalShuffleClient(conf);
+ client.init(APP_ID);
client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds,
new BlockFetchingListener() {
@Override
@@ -164,6 +165,7 @@ public class ExternalShuffleIntegrationSuite {
if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) {
fail("Timeout getting response from the server");
}
+ client.close();
return res;
}
@@ -265,7 +267,8 @@ public class ExternalShuffleIntegrationSuite {
}
private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) {
- ExternalShuffleClient client = new ExternalShuffleClient(conf, APP_ID);
+ ExternalShuffleClient client = new ExternalShuffleClient(conf);
+ client.init(APP_ID);
client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(),
executorId, executorInfo);
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index ad1a6f01b3..0f27f55fec 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -74,6 +74,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche
blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer,
blockManagerSize, conf, mapOutputTracker, shuffleManager,
new NioBlockTransferService(conf, securityMgr))
+ blockManager.initialize("app-id")
tempDirectory = Files.createTempDir()
manualClock.setTime(0)