aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala47
1 files changed, 35 insertions, 12 deletions
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
index f198aa8564..df4b085d22 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -18,13 +18,13 @@
package org.apache.spark.network.nio
import java.io.IOException
+import java.lang.ref.WeakReference
import java.net._
import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}
-import java.util.{Timer, TimerTask}
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue}
import scala.concurrent.duration._
@@ -32,6 +32,7 @@ import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.language.postfixOps
import com.google.common.base.Charsets.UTF_8
+import io.netty.util.{Timeout, TimerTask, HashedWheelTimer}
import org.apache.spark._
import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer}
@@ -77,7 +78,8 @@ private[nio] class ConnectionManager(
}
private val selector = SelectorProvider.provider.openSelector()
- private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true)
+ private val ackTimeoutMonitor =
+ new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor"))
private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60)
@@ -139,7 +141,10 @@ private[nio] class ConnectionManager(
new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection]
with SynchronizedMap[ConnectionManagerId, SendingConnection]
- private val messageStatuses = new HashMap[Int, MessageStatus]
+ // Tracks sent messages for which we are awaiting acknowledgements. Entries are added to this
+ // map when messages are sent and are removed when acknowledgement messages are received or when
+ // acknowledgement timeouts expire
+ private val messageStatuses = new HashMap[Int, MessageStatus] // [MessageId, MessageStatus]
private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
private val registerRequests = new SynchronizedQueue[SendingConnection]
@@ -899,22 +904,41 @@ private[nio] class ConnectionManager(
: Future[Message] = {
val promise = Promise[Message]()
- val timeoutTask = new TimerTask {
- override def run(): Unit = {
+ // It's important that the TimerTask doesn't capture a reference to `message`, which can cause
+ // memory leaks since cancelled TimerTasks won't necessarily be garbage collected until the time
+ // at which they would originally be scheduled to run. Therefore, extract the message id
+ // from outside of the TimerTask closure (see SPARK-4393 for more context).
+ val messageId = message.id
+ // Keep a weak reference to the promise so that the completed promise may be garbage-collected
+ val promiseReference = new WeakReference(promise)
+ val timeoutTask: TimerTask = new TimerTask {
+ override def run(timeout: Timeout): Unit = {
messageStatuses.synchronized {
- messageStatuses.remove(message.id).foreach ( s => {
+ messageStatuses.remove(messageId).foreach { s =>
val e = new IOException("sendMessageReliably failed because ack " +
s"was not received within $ackTimeout sec")
- if (!promise.tryFailure(e)) {
- logWarning("Ignore error because promise is completed", e)
+ val p = promiseReference.get
+ if (p != null) {
+ // Attempt to fail the promise with a Timeout exception
+ if (!p.tryFailure(e)) {
+ // If we reach here, then someone else has already signalled success or failure
+ // on this promise, so log a warning:
+ logError("Ignore error because promise is completed", e)
+ }
+ } else {
+ // The WeakReference was empty, which should never happen because
+ // sendMessageReliably's caller should have a strong reference to promise.future;
+ logError("Promise was garbage collected; this should never happen!", e)
}
- })
+ }
}
}
}
+ val timeoutTaskHandle = ackTimeoutMonitor.newTimeout(timeoutTask, ackTimeout, TimeUnit.SECONDS)
+
val status = new MessageStatus(message, connectionManagerId, s => {
- timeoutTask.cancel()
+ timeoutTaskHandle.cancel()
s match {
case scala.util.Failure(e) =>
// Indicates a failure where we either never sent or never got ACK'd
@@ -943,7 +967,6 @@ private[nio] class ConnectionManager(
messageStatuses += ((message.id, status))
}
- ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000)
sendMessage(connectionManagerId, message)
promise.future
}
@@ -953,7 +976,7 @@ private[nio] class ConnectionManager(
}
def stop() {
- ackTimeoutMonitor.cancel()
+ ackTimeoutMonitor.stop()
selectorThread.interrupt()
selectorThread.join()
selector.close()