aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorThomas Graves <tgraves@apache.org>2014-10-02 13:52:54 -0700
committerReynold Xin <rxin@apache.org>2014-10-02 13:52:54 -0700
commit127e97bee1e6aae7b70263bc5944b7be6f4e6fea (patch)
tree7a6f5b7ae2fc2f058a803bcb85a71354de2b1ca9
parent5db78e6b87d33ac2d48a997e69b46e9be3b63137 (diff)
downloadspark-127e97bee1e6aae7b70263bc5944b7be6f4e6fea.tar.gz
spark-127e97bee1e6aae7b70263bc5944b7be6f4e6fea.tar.bz2
spark-127e97bee1e6aae7b70263bc5944b7be6f4e6fea.zip
[SPARK-3632] ConnectionManager can run out of receive threads with authentication on
If you turn authentication on and you are using a lot of executors, there is a chance that all of the threads in the handleMessageExecutor could be waiting to send a message because they are blocked waiting on authentication to happen. This can cause a temporary deadlock until the connection times out. To fix it, I got rid of the wait/notify and now use a single outbox, but only send security messages from it until authentication has completed. Author: Thomas Graves <tgraves@apache.org> Closes #2484 from tgravescs/cm_threads_auth and squashes the following commits: a0a961d [Thomas Graves] give it a type b6bc80b [Thomas Graves] Rework comments d6d4175 [Thomas Graves] update from comments 081b765 [Thomas Graves] cleanup 4d7f8f5 [Thomas Graves] Change to not use wait/notify while waiting for authentication
-rw-r--r--core/src/main/scala/org/apache/spark/SecurityManager.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/network/nio/Connection.scala65
-rw-r--r--core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala72
3 files changed, 63 insertions, 79 deletions
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 3832a780ec..0e0f1a7b23 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -103,10 +103,9 @@ import org.apache.spark.deploy.SparkHadoopUtil
* and a Server, so for a particular connection is has to determine what to do.
* A ConnectionId was added to be able to track connections and is used to
* match up incoming messages with connections waiting for authentication.
- * If its acting as a client and trying to send a message to another ConnectionManager,
- * it blocks the thread calling sendMessage until the SASL negotiation has occurred.
* The ConnectionManager tracks all the sendingConnections using the ConnectionId
- * and waits for the response from the server and does the handshake.
+ * and waits for the response from the server and does the handshake before sending
+ * the real message.
*
* - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
* can be used. Yarn requires a specific AmIpFilter be installed for security to work
diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
index 18172d359c..f368209980 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
@@ -20,23 +20,27 @@ package org.apache.spark.network.nio
import java.net._
import java.nio._
import java.nio.channels._
+import java.util.LinkedList
import org.apache.spark._
-import scala.collection.mutable.{ArrayBuffer, HashMap, Queue}
+import scala.collection.mutable.{ArrayBuffer, HashMap}
private[nio]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
- val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId)
+ val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId,
+ val securityMgr: SecurityManager)
extends Logging {
var sparkSaslServer: SparkSaslServer = null
var sparkSaslClient: SparkSaslClient = null
- def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = {
+ def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId,
+ securityMgr_ : SecurityManager) = {
this(channel_, selector_,
ConnectionManagerId.fromSocketAddress(
- channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), id_)
+ channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]),
+ id_, securityMgr_)
}
channel.configureBlocking(false)
@@ -52,14 +56,6 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
val remoteAddress = getRemoteAddress()
- /**
- * Used to synchronize client requests: client's work-related requests must
- * wait until SASL authentication completes.
- */
- private val authenticated = new Object()
-
- def getAuthenticated(): Object = authenticated
-
def isSaslComplete(): Boolean
def resetForceReregister(): Boolean
@@ -192,22 +188,22 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
private[nio]
class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
- remoteId_ : ConnectionManagerId, id_ : ConnectionId)
- extends Connection(SocketChannel.open, selector_, remoteId_, id_) {
+ remoteId_ : ConnectionManagerId, id_ : ConnectionId,
+ securityMgr_ : SecurityManager)
+ extends Connection(SocketChannel.open, selector_, remoteId_, id_, securityMgr_) {
def isSaslComplete(): Boolean = {
if (sparkSaslClient != null) sparkSaslClient.isComplete() else false
}
private class Outbox {
- val messages = new Queue[Message]()
+ val messages = new LinkedList[Message]()
val defaultChunkSize = 65536
var nextMessageToBeUsed = 0
def addMessage(message: Message) {
messages.synchronized {
- /* messages += message */
- messages.enqueue(message)
+ messages.add(message)
logDebug("Added [" + message + "] to outbox for sending to " +
"[" + getRemoteConnectionManagerId() + "]")
}
@@ -218,10 +214,27 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
while (!messages.isEmpty) {
/* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
/* val message = messages(nextMessageToBeUsed) */
- val message = messages.dequeue()
+
+ val message = if (securityMgr.isAuthenticationEnabled() && !isSaslComplete()) {
+ // only allow sending of security messages until sasl is complete
+ var pos = 0
+ var securityMsg: Message = null
+ while (pos < messages.size() && securityMsg == null) {
+ if (messages.get(pos).isSecurityNeg) {
+ securityMsg = messages.remove(pos)
+ }
+ pos = pos + 1
+ }
+ // didn't find any security messages and auth isn't completed so return
+ if (securityMsg == null) return None
+ securityMsg
+ } else {
+ messages.removeFirst()
+ }
+
val chunk = message.getChunkForSending(defaultChunkSize)
if (chunk.isDefined) {
- messages.enqueue(message)
+ messages.add(message)
nextMessageToBeUsed = nextMessageToBeUsed + 1
if (!message.started) {
logDebug(
@@ -273,6 +286,15 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
changeConnectionKeyInterest(DEFAULT_INTEREST)
}
+ def registerAfterAuth(): Unit = {
+ outbox.synchronized {
+ needForceReregister = true
+ }
+ if (channel.isConnected) {
+ registerInterest()
+ }
+ }
+
def send(message: Message) {
outbox.synchronized {
outbox.addMessage(message)
@@ -415,8 +437,9 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
private[spark] class ReceivingConnection(
channel_ : SocketChannel,
selector_ : Selector,
- id_ : ConnectionId)
- extends Connection(channel_, selector_, id_) {
+ id_ : ConnectionId,
+ securityMgr_ : SecurityManager)
+ extends Connection(channel_, selector_, id_, securityMgr_) {
def isSaslComplete(): Boolean = {
if (sparkSaslServer != null) sparkSaslServer.isComplete() else false
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
index 5aa7e94943..01cd27a907 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -32,7 +32,7 @@ import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.language.postfixOps
import org.apache.spark._
-import org.apache.spark.util.{SystemClock, Utils}
+import org.apache.spark.util.Utils
private[nio] class ConnectionManager(
@@ -65,8 +65,6 @@ private[nio] class ConnectionManager(
private val selector = SelectorProvider.provider.openSelector()
private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true)
- // default to 30 second timeout waiting for authentication
- private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30)
private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60)
private val handleMessageExecutor = new ThreadPoolExecutor(
@@ -409,7 +407,8 @@ private[nio] class ConnectionManager(
while (newChannel != null) {
try {
val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
- val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId)
+ val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId,
+ securityManager)
newConnection.onReceive(receiveMessage)
addListeners(newConnection)
addConnection(newConnection)
@@ -527,9 +526,8 @@ private[nio] class ConnectionManager(
if (waitingConn.isSaslComplete()) {
logDebug("Client sasl completed for id: " + waitingConn.connectionId)
connectionsAwaitingSasl -= waitingConn.connectionId
- waitingConn.getAuthenticated().synchronized {
- waitingConn.getAuthenticated().notifyAll()
- }
+ waitingConn.registerAfterAuth()
+ wakeupSelector()
return
} else {
var replyToken : Array[Byte] = null
@@ -538,9 +536,8 @@ private[nio] class ConnectionManager(
if (waitingConn.isSaslComplete()) {
logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
connectionsAwaitingSasl -= waitingConn.connectionId
- waitingConn.getAuthenticated().synchronized {
- waitingConn.getAuthenticated().notifyAll()
- }
+ waitingConn.registerAfterAuth()
+ wakeupSelector()
return
}
val securityMsgResp = SecurityMessage.fromResponse(replyToken,
@@ -574,9 +571,11 @@ private[nio] class ConnectionManager(
}
replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
if (connection.isSaslComplete()) {
- logDebug("Server sasl completed: " + connection.connectionId)
+ logDebug("Server sasl completed: " + connection.connectionId +
+ " for: " + connectionId)
} else {
- logDebug("Server sasl not completed: " + connection.connectionId)
+ logDebug("Server sasl not completed: " + connection.connectionId +
+ " for: " + connectionId)
}
if (replyToken != null) {
val securityMsgResp = SecurityMessage.fromResponse(replyToken,
@@ -723,7 +722,8 @@ private[nio] class ConnectionManager(
if (message == null) throw new Exception("Error creating security message")
connectionsAwaitingSasl += ((conn.connectionId, conn))
sendSecurityMessage(connManagerId, message)
- logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId)
+ logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId +
+ " to: " + connManagerId)
} catch {
case e: Exception => {
logError("Error getting first response from the SaslClient.", e)
@@ -744,7 +744,7 @@ private[nio] class ConnectionManager(
val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port)
val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId,
- newConnectionId)
+ newConnectionId, securityManager)
logInfo("creating new sending connection for security! " + newConnectionId )
registerRequests.enqueue(newConnection)
@@ -769,61 +769,23 @@ private[nio] class ConnectionManager(
connectionManagerId.port)
val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId,
- newConnectionId)
+ newConnectionId, securityManager)
logTrace("creating new sending connection: " + newConnectionId)
registerRequests.enqueue(newConnection)
newConnection
}
val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
- if (authEnabled) {
- checkSendAuthFirst(connectionManagerId, connection)
- }
+
message.senderAddress = id.toSocketAddress()
logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " +
"connectionid: " + connection.connectionId)
if (authEnabled) {
- // if we aren't authenticated yet lets block the senders until authentication completes
- try {
- connection.getAuthenticated().synchronized {
- val clock = SystemClock
- val startTime = clock.getTime()
-
- while (!connection.isSaslComplete()) {
- logDebug("getAuthenticated wait connectionid: " + connection.connectionId)
- // have timeout in case remote side never responds
- connection.getAuthenticated().wait(500)
- if (((clock.getTime() - startTime) >= (authTimeout * 1000))
- && (!connection.isSaslComplete())) {
- // took to long to authenticate the connection, something probably went wrong
- throw new Exception("Took to long for authentication to " + connectionManagerId +
- ", waited " + authTimeout + "seconds, failing.")
- }
- }
- }
- } catch {
- case e: Exception => logError("Exception while waiting for authentication.", e)
-
- // need to tell sender it failed
- messageStatuses.synchronized {
- val s = messageStatuses.get(message.id)
- s match {
- case Some(msgStatus) => {
- messageStatuses -= message.id
- logInfo("Notifying " + msgStatus.connectionManagerId)
- msgStatus.markDone(None)
- }
- case None => {
- logError("no messageStatus for failed message id: " + message.id)
- }
- }
- }
- }
+ checkSendAuthFirst(connectionManagerId, connection)
}
logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
connection.send(message)
-
wakeupSelector()
}