aboutsummaryrefslogtreecommitdiff
path: root/core/src
diff options
context:
space:
mode:
authorKousuke Saruta <sarutak@oss.nttdata.co.jp>2014-08-06 17:27:55 -0700
committerPatrick Wendell <pwendell@gmail.com>2014-08-06 17:27:55 -0700
commit17caae48b3608552dd6e3ae652043831f932ce95 (patch)
treeefafda9f93d70fde4373936a6869d1d3318c0bb0 /core/src
parent4e008334ee0fb60f9fe8820afa06f7b7f0fa7a6c (diff)
downloadspark-17caae48b3608552dd6e3ae652043831f932ce95.tar.gz
spark-17caae48b3608552dd6e3ae652043831f932ce95.tar.bz2
spark-17caae48b3608552dd6e3ae652043831f932ce95.zip
[SPARK-2583] ConnectionManager error reporting
This patch modifies the ConnectionManager so that error messages are sent in reply when uncaught exceptions occur during message processing. This prevents message senders from hanging while waiting for an acknowledgment if the remote message processing failed. This is an updated version of sarutak's PR, #1490. The main change is to use Futures / Promises to signal errors. Author: Kousuke Saruta <sarutak@oss.nttdata.co.jp> Author: Josh Rosen <joshrosen@apache.org> Closes #1758 from JoshRosen/connection-manager-fixes and squashes the following commits: 68620cb [Josh Rosen] Fix test in BlockFetcherIteratorSuite: 83673de [Josh Rosen] Error ACKs should trigger IOExceptions, so catch only those exceptions in the test. b8bb4d4 [Josh Rosen] Fix manager.id vs managerServer.id typo that broke security tests. 659521f [Josh Rosen] Include previous exception when throwing new one a2f745c [Josh Rosen] Remove sendMessageReliablySync; callers can wait themselves. c01c450 [Josh Rosen] Return Try[Message] from sendMessageReliablySync. 
f1cd1bb [Josh Rosen] Clean up @sarutak's PR #1490 for [SPARK-2583]: ConnectionManager error reporting 7399c6b [Josh Rosen] Merge remote-tracking branch 'origin/pr/1490' into connection-manager-fixes ee91bb7 [Kousuke Saruta] Modified BufferMessage.scala to keep the spark code style 9dfd0d8 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 e7d9aa6 [Kousuke Saruta] rebase to master 326a17f [Kousuke Saruta] Add test cases to ConnectionManagerSuite.scala for SPARK-2583 2a18d6b [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 22d7ebd [Kousuke Saruta] Add test cases to BlockManagerSuite for SPARK-2583 e579302 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 281589c [Kousuke Saruta] Add a test case to BlockFetcherIteratorSuite.scala for fetching block from remote from successfully 0654128 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 ffaa83d [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 12d3de8 [Kousuke Saruta] Added BlockFetcherIteratorSuite.scala 4117b8f [Kousuke Saruta] Modified ConnectionManager to be alble to handle error during processing message 717c9c3 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 6635467 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 e2b8c4a [Kousuke Saruta] Modify to propagete error using ConnectionManager
Diffstat (limited to 'core/src')
-rw-r--r--core/src/main/scala/org/apache/spark/network/BufferMessage.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/network/ConnectionManager.scala143
-rw-r--r--core/src/main/scala/org/apache/spark/network/Message.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/network/SenderTest.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala30
-rw-r--r--core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala38
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala98
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala110
10 files changed, 362 insertions, 89 deletions
diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
index 04df2f3b0d..af35f1fc3e 100644
--- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
+++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
@@ -48,7 +48,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
val security = if (isSecurityNeg) 1 else 0
if (size == 0 && !gotChunkForSendingOnce) {
val newChunk = new MessageChunk(
- new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null)
+ new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, security, senderAddress), null)
gotChunkForSendingOnce = true
return Some(newChunk)
}
@@ -66,7 +66,8 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
}
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
+ typ, id, size, newBuffer.remaining, ackId,
+ hasError, security, senderAddress), newBuffer)
gotChunkForSendingOnce = true
return Some(newChunk)
}
@@ -88,7 +89,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
+ typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer)
return Some(newChunk)
}
None
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
index 4c00225280..95f96b8463 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
@@ -17,6 +17,7 @@
package org.apache.spark.network
+import java.io.IOException
import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
@@ -45,16 +46,26 @@ private[spark] class ConnectionManager(
name: String = "Connection manager")
extends Logging {
+ /**
+ * Used by sendMessageReliably to track messages being sent.
+ * @param message the message that was sent
+ * @param connectionManagerId the connection manager that sent this message
+ * @param completionHandler callback that's invoked when the send has completed or failed
+ */
class MessageStatus(
val message: Message,
val connectionManagerId: ConnectionManagerId,
completionHandler: MessageStatus => Unit) {
+ /** This is non-None if message has been ack'd */
var ackMessage: Option[Message] = None
- var attempted = false
- var acked = false
- def markDone() { completionHandler(this) }
+ def markDone(ackMessage: Option[Message]) {
+ this.synchronized {
+ this.ackMessage = ackMessage
+ completionHandler(this)
+ }
+ }
}
private val selector = SelectorProvider.provider.openSelector()
@@ -442,11 +453,7 @@ private[spark] class ConnectionManager(
messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId)
.foreach(status => {
logInfo("Notifying " + status)
- status.synchronized {
- status.attempted = true
- status.acked = false
- status.markDone()
- }
+ status.markDone(None)
})
messageStatuses.retain((i, status) => {
@@ -475,11 +482,7 @@ private[spark] class ConnectionManager(
for (s <- messageStatuses.values
if s.connectionManagerId == sendingConnectionManagerId) {
logInfo("Notifying " + s)
- s.synchronized {
- s.attempted = true
- s.acked = false
- s.markDone()
- }
+ s.markDone(None)
}
messageStatuses.retain((i, status) => {
@@ -547,13 +550,13 @@ private[spark] class ConnectionManager(
val securityMsgResp = SecurityMessage.fromResponse(replyToken,
securityMsg.getConnectionId.toString)
val message = securityMsgResp.toBufferMessage
- if (message == null) throw new Exception("Error creating security message")
+ if (message == null) throw new IOException("Error creating security message")
sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message)
} catch {
case e: Exception => {
logError("Error handling sasl client authentication", e)
waitingConn.close()
- throw new Exception("Error evaluating sasl response: " + e)
+ throw new IOException("Error evaluating sasl response: ", e)
}
}
}
@@ -661,34 +664,39 @@ private[spark] class ConnectionManager(
}
}
}
- sentMessageStatus.synchronized {
- sentMessageStatus.ackMessage = Some(message)
- sentMessageStatus.attempted = true
- sentMessageStatus.acked = true
- sentMessageStatus.markDone()
- }
+ sentMessageStatus.markDone(Some(message))
} else {
- val ackMessage = if (onReceiveCallback != null) {
- logDebug("Calling back")
- onReceiveCallback(bufferMessage, connectionManagerId)
- } else {
- logDebug("Not calling back as callback is null")
- None
- }
+ var ackMessage : Option[Message] = None
+ try {
+ ackMessage = if (onReceiveCallback != null) {
+ logDebug("Calling back")
+ onReceiveCallback(bufferMessage, connectionManagerId)
+ } else {
+ logDebug("Not calling back as callback is null")
+ None
+ }
- if (ackMessage.isDefined) {
- if (!ackMessage.get.isInstanceOf[BufferMessage]) {
- logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type "
- + ackMessage.get.getClass)
- } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
- logDebug("Response to " + bufferMessage + " does not have ack id set")
- ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
+ if (ackMessage.isDefined) {
+ if (!ackMessage.get.isInstanceOf[BufferMessage]) {
+ logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type "
+ + ackMessage.get.getClass)
+ } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
+ logDebug("Response to " + bufferMessage + " does not have ack id set")
+ ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
+ }
+ }
+ } catch {
+ case e: Exception => {
+ logError(s"Exception was thrown while processing message", e)
+ val m = Message.createBufferMessage(bufferMessage.id)
+ m.hasError = true
+ ackMessage = Some(m)
}
+ } finally {
+ sendMessage(connectionManagerId, ackMessage.getOrElse {
+ Message.createBufferMessage(bufferMessage.id)
+ })
}
-
- sendMessage(connectionManagerId, ackMessage.getOrElse {
- Message.createBufferMessage(bufferMessage.id)
- })
}
}
case _ => throw new Exception("Unknown type message received")
@@ -800,11 +808,7 @@ private[spark] class ConnectionManager(
case Some(msgStatus) => {
messageStatuses -= message.id
logInfo("Notifying " + msgStatus.connectionManagerId)
- msgStatus.synchronized {
- msgStatus.attempted = true
- msgStatus.acked = false
- msgStatus.markDone()
- }
+ msgStatus.markDone(None)
}
case None => {
logError("no messageStatus for failed message id: " + message.id)
@@ -823,11 +827,28 @@ private[spark] class ConnectionManager(
selector.wakeup()
}
+ /**
+ * Send a message and asynchronously wait until an acknowledgment is received or an error occurs.
+ * @param connectionManagerId the message's destination
+ * @param message the message being sent
+ * @return a Future that either returns the acknowledgment message or captures an exception.
+ */
def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message)
- : Future[Option[Message]] = {
- val promise = Promise[Option[Message]]
- val status = new MessageStatus(
- message, connectionManagerId, s => promise.success(s.ackMessage))
+ : Future[Message] = {
+ val promise = Promise[Message]()
+ val status = new MessageStatus(message, connectionManagerId, s => {
+ s.ackMessage match {
+ case None => // Indicates a failure where we either never sent or never got ACK'd
+ promise.failure(new IOException("sendMessageReliably failed without being ACK'd"))
+ case Some(ackMessage) =>
+ if (ackMessage.hasError) {
+ promise.failure(
+ new IOException("sendMessageReliably failed with ACK that signalled a remote error"))
+ } else {
+ promise.success(ackMessage)
+ }
+ }
+ })
messageStatuses.synchronized {
messageStatuses += ((message.id, status))
}
@@ -835,11 +856,6 @@ private[spark] class ConnectionManager(
promise.future
}
- def sendMessageReliablySync(connectionManagerId: ConnectionManagerId,
- message: Message): Option[Message] = {
- Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf)
- }
-
def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) {
onReceiveCallback = callback
}
@@ -862,6 +878,7 @@ private[spark] class ConnectionManager(
private[spark] object ConnectionManager {
+ import ExecutionContext.Implicits.global
def main(args: Array[String]) {
val conf = new SparkConf
@@ -896,7 +913,7 @@ private[spark] object ConnectionManager {
(0 until count).map(i => {
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- manager.sendMessageReliablySync(manager.id, bufferMessage)
+ Await.result(manager.sendMessageReliably(manager.id, bufferMessage), Duration.Inf)
})
println("--------------------------")
println()
@@ -917,8 +934,10 @@ private[spark] object ConnectionManager {
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
manager.sendMessageReliably(manager.id, bufferMessage)
}).foreach(f => {
- val g = Await.result(f, 1 second)
- if (!g.isDefined) println("Failed")
+ f.onFailure {
+ case e => println("Failed due to " + e)
+ }
+ Await.ready(f, 1 second)
})
val finishTime = System.currentTimeMillis
@@ -952,8 +971,10 @@ private[spark] object ConnectionManager {
val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate)
manager.sendMessageReliably(manager.id, bufferMessage)
}).foreach(f => {
- val g = Await.result(f, 1 second)
- if (!g.isDefined) println("Failed")
+ f.onFailure {
+ case e => println("Failed due to " + e)
+ }
+ Await.ready(f, 1 second)
})
val finishTime = System.currentTimeMillis
@@ -982,8 +1003,10 @@ private[spark] object ConnectionManager {
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
manager.sendMessageReliably(manager.id, bufferMessage)
}).foreach(f => {
- val g = Await.result(f, 1 second)
- if (!g.isDefined) println("Failed")
+ f.onFailure {
+ case e => println("Failed due to " + e)
+ }
+ Await.ready(f, 1 second)
})
val finishTime = System.currentTimeMillis
Thread.sleep(1000)
diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala
index 7caccfdbb4..04ea50f629 100644
--- a/core/src/main/scala/org/apache/spark/network/Message.scala
+++ b/core/src/main/scala/org/apache/spark/network/Message.scala
@@ -28,6 +28,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) {
var startTime = -1L
var finishTime = -1L
var isSecurityNeg = false
+ var hasError = false
def size: Int
@@ -87,6 +88,7 @@ private[spark] object Message {
case BUFFER_MESSAGE => new BufferMessage(header.id,
ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
}
+ newMessage.hasError = header.hasError
newMessage.senderAddress = header.address
newMessage
}
diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
index ead663ede7..f3ecca5f99 100644
--- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
+++ b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
@@ -27,6 +27,7 @@ private[spark] class MessageChunkHeader(
val totalSize: Int,
val chunkSize: Int,
val other: Int,
+ val hasError: Boolean,
val securityNeg: Int,
val address: InetSocketAddress) {
lazy val buffer = {
@@ -41,6 +42,7 @@ private[spark] class MessageChunkHeader(
putInt(totalSize).
putInt(chunkSize).
putInt(other).
+ put(if (hasError) 1.asInstanceOf[Byte] else 0.asInstanceOf[Byte]).
putInt(securityNeg).
putInt(ip.size).
put(ip).
@@ -56,7 +58,7 @@ private[spark] class MessageChunkHeader(
private[spark] object MessageChunkHeader {
- val HEADER_SIZE = 44
+ val HEADER_SIZE = 45
def create(buffer: ByteBuffer): MessageChunkHeader = {
if (buffer.remaining != HEADER_SIZE) {
@@ -67,13 +69,14 @@ private[spark] object MessageChunkHeader {
val totalSize = buffer.getInt()
val chunkSize = buffer.getInt()
val other = buffer.getInt()
+ val hasError = buffer.get() != 0
val securityNeg = buffer.getInt()
val ipSize = buffer.getInt()
val ipBytes = new Array[Byte](ipSize)
buffer.get(ipBytes)
val ip = InetAddress.getByAddress(ipBytes)
val port = buffer.getInt()
- new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg,
+ new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg,
new InetSocketAddress(ip, port))
}
}
diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
index b8ea7c2cff..ea2ad104ec 100644
--- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala
+++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
@@ -20,6 +20,10 @@ package org.apache.spark.network
import java.nio.ByteBuffer
import org.apache.spark.{SecurityManager, SparkConf}
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
+import scala.util.Try
+
private[spark] object SenderTest {
def main(args: Array[String]) {
@@ -51,7 +55,8 @@ private[spark] object SenderTest {
val dataMessage = Message.createBufferMessage(buffer.duplicate)
val startTime = System.currentTimeMillis
/* println("Started timer at " + startTime) */
- val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage)
+ val promise = manager.sendMessageReliably(targetConnectionManagerId, dataMessage)
+ val responseStr: String = Try(Await.result(promise, Duration.Inf))
.map { response =>
val buffer = response.asInstanceOf[BufferMessage].buffers(0)
new String(buffer.array, "utf-8")
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
index ccf830e118..938af6f5b9 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
@@ -22,6 +22,7 @@ import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashSet
import scala.collection.mutable.Queue
+import scala.util.{Failure, Success}
import io.netty.buffer.ByteBuf
@@ -118,8 +119,8 @@ object BlockFetcherIterator {
bytesInFlight += req.size
val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
- future.onSuccess {
- case Some(message) => {
+ future.onComplete {
+ case Success(message) => {
val bufferMessage = message.asInstanceOf[BufferMessage]
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
for (blockMessage <- blockMessageArray) {
@@ -135,8 +136,8 @@ object BlockFetcherIterator {
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
}
- case None => {
- logError("Could not get block(s) from " + cmId)
+ case Failure(exception) => {
+ logError("Could not get block(s) from " + cmId, exception)
for ((blockId, size) <- req.blocks) {
results.put(new FetchResult(blockId, -1, null))
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
index c7766a3a65..bf002a42d5 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
@@ -23,6 +23,10 @@ import org.apache.spark.Logging
import org.apache.spark.network._
import org.apache.spark.util.Utils
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
+import scala.util.{Try, Failure, Success}
+
/**
* A network interface for BlockManager. Each slave should have one
* BlockManagerWorker.
@@ -44,13 +48,19 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get)
Some(new BlockMessageArray(responseMessages).toBufferMessage)
} catch {
- case e: Exception => logError("Exception handling buffer message", e)
- None
+ case e: Exception => {
+ logError("Exception handling buffer message", e)
+ val errorMessage = Message.createBufferMessage(msg.id)
+ errorMessage.hasError = true
+ Some(errorMessage)
+ }
}
}
case otherMessage: Any => {
logError("Unknown type message received: " + otherMessage)
- None
+ val errorMessage = Message.createBufferMessage(msg.id)
+ errorMessage.hasError = true
+ Some(errorMessage)
}
}
}
@@ -109,9 +119,9 @@ private[spark] object BlockManagerWorker extends Logging {
val connectionManager = blockManager.connectionManager
val blockMessage = BlockMessage.fromPutBlock(msg)
val blockMessageArray = new BlockMessageArray(blockMessage)
- val resultMessage = connectionManager.sendMessageReliablySync(
- toConnManagerId, blockMessageArray.toBufferMessage)
- resultMessage.isDefined
+ val resultMessage = Try(Await.result(connectionManager.sendMessageReliably(
+ toConnManagerId, blockMessageArray.toBufferMessage), Duration.Inf))
+ resultMessage.isSuccess
}
def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = {
@@ -119,10 +129,10 @@ private[spark] object BlockManagerWorker extends Logging {
val connectionManager = blockManager.connectionManager
val blockMessage = BlockMessage.fromGetBlock(msg)
val blockMessageArray = new BlockMessageArray(blockMessage)
- val responseMessage = connectionManager.sendMessageReliablySync(
- toConnManagerId, blockMessageArray.toBufferMessage)
+ val responseMessage = Try(Await.result(connectionManager.sendMessageReliably(
+ toConnManagerId, blockMessageArray.toBufferMessage), Duration.Inf))
responseMessage match {
- case Some(message) => {
+ case Success(message) => {
val bufferMessage = message.asInstanceOf[BufferMessage]
logDebug("Response message received " + bufferMessage)
BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => {
@@ -130,7 +140,7 @@ private[spark] object BlockManagerWorker extends Logging {
return blockMessage.getData
})
}
- case None => logDebug("No response message received")
+ case Failure(exception) => logDebug("No response message received")
}
null
}
diff --git a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala
index 415ad8c432..846537df00 100644
--- a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.network
+import java.io.IOException
import java.nio._
import org.apache.spark.{SecurityManager, SparkConf}
@@ -25,6 +26,7 @@ import org.scalatest.FunSuite
import scala.concurrent.{Await, TimeoutException}
import scala.concurrent.duration._
import scala.language.postfixOps
+import scala.util.Try
/**
* Test the ConnectionManager with various security settings.
@@ -46,7 +48,7 @@ class ConnectionManagerSuite extends FunSuite {
buffer.flip
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- manager.sendMessageReliablySync(manager.id, bufferMessage)
+ Await.result(manager.sendMessageReliably(manager.id, bufferMessage), 10 seconds)
assert(receivedMessage == true)
@@ -79,7 +81,7 @@ class ConnectionManagerSuite extends FunSuite {
(0 until count).map(i => {
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- manager.sendMessageReliablySync(managerServer.id, bufferMessage)
+ Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds)
})
assert(numReceivedServerMessages == 10)
@@ -118,7 +120,10 @@ class ConnectionManagerSuite extends FunSuite {
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
buffer.flip
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
- manager.sendMessageReliablySync(managerServer.id, bufferMessage)
+ // Expect managerServer to close connection, which we'll report as an error:
+ intercept[IOException] {
+ Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds)
+ }
assert(numReceivedServerMessages == 0)
assert(numReceivedMessages == 0)
@@ -163,6 +168,8 @@ class ConnectionManagerSuite extends FunSuite {
val g = Await.result(f, 1 second)
assert(false)
} catch {
+ case i: IOException =>
+ assert(true)
case e: TimeoutException => {
// we should timeout here since the client can't do the negotiation
assert(true)
@@ -209,7 +216,6 @@ class ConnectionManagerSuite extends FunSuite {
}).foreach(f => {
try {
val g = Await.result(f, 1 second)
- if (!g.isDefined) assert(false) else assert(true)
} catch {
case e: Exception => {
assert(false)
@@ -223,7 +229,31 @@ class ConnectionManagerSuite extends FunSuite {
managerServer.stop()
}
+ test("Ack error message") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "false")
+ val securityManager = new SecurityManager(conf)
+ val manager = new ConnectionManager(0, conf, securityManager)
+ val managerServer = new ConnectionManager(0, conf, securityManager)
+ managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ throw new Exception
+ })
+
+ val size = 10 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+ val bufferMessage = Message.createBufferMessage(buffer)
+
+ val future = manager.sendMessageReliably(managerServer.id, bufferMessage)
+
+ intercept[IOException] {
+ Await.result(future, 1 second)
+ }
+ manager.stop()
+ managerServer.stop()
+
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
index 8dca2ebb31..1538995a6b 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
@@ -17,18 +17,22 @@
package org.apache.spark.storage
+import java.io.IOException
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.future
+import scala.concurrent.ExecutionContext.Implicits.global
+
import org.scalatest.{FunSuite, Matchers}
-import org.scalatest.PrivateMethodTester._
import org.mockito.Mockito._
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.stubbing.Answer
import org.mockito.invocation.InvocationOnMock
-import org.apache.spark._
import org.apache.spark.storage.BlockFetcherIterator._
-import org.apache.spark.network.{ConnectionManager, ConnectionManagerId,
- Message}
+import org.apache.spark.network.{ConnectionManager, Message}
class BlockFetcherIteratorSuite extends FunSuite with Matchers {
@@ -137,4 +141,90 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
assert(iterator.next._2.isDefined, "All elements should be defined but 5th element is not actually defined")
}
+ test("block fetch from remote fails using BasicBlockFetcherIterator") {
+ val blockManager = mock(classOf[BlockManager])
+ val connManager = mock(classOf[ConnectionManager])
+ when(blockManager.connectionManager).thenReturn(connManager)
+
+ val f = future {
+ throw new IOException("Send failed or we received an error ACK")
+ }
+ when(connManager.sendMessageReliably(any(),
+ any())).thenReturn(f)
+ when(blockManager.futureExecContext).thenReturn(global)
+
+ when(blockManager.blockManagerId).thenReturn(
+ BlockManagerId("test-client", "test-client", 1, 0))
+ when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024)
+
+ val blId1 = ShuffleBlockId(0,0,0)
+ val blId2 = ShuffleBlockId(0,1,0)
+ val bmId = BlockManagerId("test-server", "test-server",1 , 0)
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+ (bmId, Seq((blId1, 1L), (blId2, 1L)))
+ )
+
+ val iterator = new BasicBlockFetcherIterator(blockManager,
+ blocksByAddress, null)
+
+ iterator.initialize()
+ iterator.foreach{
+ case (_, r) => {
+ (!r.isDefined) should be(true)
+ }
+ }
+ }
+
+ test("block fetch from remote succeed using BasicBlockFetcherIterator") {
+ val blockManager = mock(classOf[BlockManager])
+ val connManager = mock(classOf[ConnectionManager])
+ when(blockManager.connectionManager).thenReturn(connManager)
+
+ val blId1 = ShuffleBlockId(0,0,0)
+ val blId2 = ShuffleBlockId(0,1,0)
+ val buf1 = ByteBuffer.allocate(4)
+ val buf2 = ByteBuffer.allocate(4)
+ buf1.putInt(1)
+ buf1.flip()
+ buf2.putInt(1)
+ buf2.flip()
+ val blockMessage1 = BlockMessage.fromGotBlock(GotBlock(blId1, buf1))
+ val blockMessage2 = BlockMessage.fromGotBlock(GotBlock(blId2, buf2))
+ val blockMessageArray = new BlockMessageArray(
+ Seq(blockMessage1, blockMessage2))
+
+ val bufferMessage = blockMessageArray.toBufferMessage
+ val buffer = ByteBuffer.allocate(bufferMessage.size)
+ val arrayBuffer = new ArrayBuffer[ByteBuffer]
+ bufferMessage.buffers.foreach{ b =>
+ buffer.put(b)
+ }
+ buffer.flip()
+ arrayBuffer += buffer
+
+ val f = future {
+ Message.createBufferMessage(arrayBuffer)
+ }
+ when(connManager.sendMessageReliably(any(),
+ any())).thenReturn(f)
+ when(blockManager.futureExecContext).thenReturn(global)
+
+ when(blockManager.blockManagerId).thenReturn(
+ BlockManagerId("test-client", "test-client", 1, 0))
+ when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024)
+
+ val bmId = BlockManagerId("test-server", "test-server",1 , 0)
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+ (bmId, Seq((blId1, 1L), (blId2, 1L)))
+ )
+
+ val iterator = new BasicBlockFetcherIterator(blockManager,
+ blocksByAddress, null)
+ iterator.initialize()
+ iterator.foreach{
+ case (_, r) => {
+ (r.isDefined) should be(true)
+ }
+ }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 0ac0269d7c..94bb2c445d 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -25,7 +25,11 @@ import akka.actor._
import akka.pattern.ask
import akka.util.Timeout
-import org.mockito.Mockito.{mock, when}
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.Matchers.any
+import org.mockito.Mockito.{doAnswer, mock, spy, when}
+import org.mockito.stubbing.Answer
+
import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester}
import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.Timeouts._
@@ -33,6 +37,7 @@ import org.scalatest.Matchers
import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
import org.apache.spark.executor.DataReadMethod
+import org.apache.spark.network.{Message, ConnectionManagerId}
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
@@ -1000,6 +1005,109 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
assert(!store.memoryStore.contains(rdd(1, 0)), "rdd_1_0 was in store")
}
+ test("return error message when error occurred in BlockManagerWorker#onBlockMessageReceive") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
+ securityMgr, mapOutputTracker)
+
+ val worker = spy(new BlockManagerWorker(store))
+ val connManagerId = mock(classOf[ConnectionManagerId])
+
+ // setup request block messages
+ val reqBlId1 = ShuffleBlockId(0,0,0)
+ val reqBlId2 = ShuffleBlockId(0,1,0)
+ val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1))
+ val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2))
+ val reqBlockMessages = new BlockMessageArray(
+ Seq(reqBlockMessage1, reqBlockMessage2))
+ val reqBufferMessage = reqBlockMessages.toBufferMessage
+
+ val answer = new Answer[Option[BlockMessage]] {
+ override def answer(invocation: InvocationOnMock)
+ :Option[BlockMessage]= {
+ throw new Exception
+ }
+ }
+
+ doAnswer(answer).when(worker).processBlockMessage(any())
+
+ // Test when exception was thrown during processing block messages
+ var ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId)
+
+ assert(ackMessage.isDefined, "When Exception was thrown in " +
+ "BlockManagerWorker#processBlockMessage, " +
+ "ackMessage should be defined")
+ assert(ackMessage.get.hasError, "When Exception was thown in " +
+ "BlockManagerWorker#processBlockMessage, " +
+ "ackMessage should have error")
+
+ val notBufferMessage = mock(classOf[Message])
+
+ // Test when not BufferMessage was received
+ ackMessage = worker.onBlockMessageReceive(notBufferMessage, connManagerId)
+ assert(ackMessage.isDefined, "When not BufferMessage was passed to " +
+ "BlockManagerWorker#onBlockMessageReceive, " +
+ "ackMessage should be defined")
+ assert(ackMessage.get.hasError, "When not BufferMessage was passed to " +
+ "BlockManagerWorker#onBlockMessageReceive, " +
+ "ackMessage should have error")
+ }
+
+ test("return ack message when no error occurred in BlocManagerWorker#onBlockMessageReceive") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
+ securityMgr, mapOutputTracker)
+
+ val worker = spy(new BlockManagerWorker(store))
+ val connManagerId = mock(classOf[ConnectionManagerId])
+
+ // setup request block messages
+ val reqBlId1 = ShuffleBlockId(0,0,0)
+ val reqBlId2 = ShuffleBlockId(0,1,0)
+ val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1))
+ val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2))
+ val reqBlockMessages = new BlockMessageArray(
+ Seq(reqBlockMessage1, reqBlockMessage2))
+
+ val tmpBufferMessage = reqBlockMessages.toBufferMessage
+ val buffer = ByteBuffer.allocate(tmpBufferMessage.size)
+ val arrayBuffer = new ArrayBuffer[ByteBuffer]
+ tmpBufferMessage.buffers.foreach{ b =>
+ buffer.put(b)
+ }
+ buffer.flip()
+ arrayBuffer += buffer
+ val reqBufferMessage = Message.createBufferMessage(arrayBuffer)
+
+ // setup ack block messages
+ val buf1 = ByteBuffer.allocate(4)
+ val buf2 = ByteBuffer.allocate(4)
+ buf1.putInt(1)
+ buf1.flip()
+ buf2.putInt(1)
+ buf2.flip()
+ val ackBlockMessage1 = BlockMessage.fromGotBlock(GotBlock(reqBlId1, buf1))
+ val ackBlockMessage2 = BlockMessage.fromGotBlock(GotBlock(reqBlId2, buf2))
+
+ val answer = new Answer[Option[BlockMessage]] {
+ override def answer(invocation: InvocationOnMock)
+ :Option[BlockMessage]= {
+ if (invocation.getArguments()(0).asInstanceOf[BlockMessage].eq(
+ reqBlockMessage1)) {
+ return Some(ackBlockMessage1)
+ } else {
+ return Some(ackBlockMessage2)
+ }
+ }
+ }
+
+ doAnswer(answer).when(worker).processBlockMessage(any())
+
+ val ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId)
+ assert(ackMessage.isDefined, "When BlockManagerWorker#onBlockMessageReceive " +
+ "was executed successfully, ackMessage should be defined")
+ assert(!ackMessage.get.hasError, "When BlockManagerWorker#onBlockMessageReceive " +
+ "was executed successfully, ackMessage should not have error")
+ }
+
test("reserve/release unroll memory") {
store = makeBlockManager(12000)
val memoryStore = store.memoryStore