-rw-r--r--  core/pom.xml | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala | 9
-rw-r--r--  core/src/main/scala/org/apache/spark/network/BlockDataManager.scala | 14
-rw-r--r--  core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala | 11
-rw-r--r--  core/src/main/scala/org/apache/spark/network/BlockTransferService.scala | 28
-rw-r--r--  core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala | 166
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala | 95
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala | 76
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala | 111
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala | 59
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala | 132
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala | 99
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala | 104
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala | 44
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala | 47
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala | 47
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala | 162
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala | 40
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala | 140
-rw-r--r--  core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala | 51
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/Serializer.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala | 3
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala | 52
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala | 1
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala | 135
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Utils.scala | 1
-rw-r--r--  core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala | 4
-rw-r--r--  core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala | 161
-rw-r--r--  core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala | 106
-rw-r--r--  core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala | 64
-rw-r--r--  core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala | 107
-rw-r--r--  core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala | 60
-rw-r--r--  core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala | 8
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 261
-rw-r--r--  network/common/pom.xml | 94
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/TransportContext.java | 117
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java | 154
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java | 71
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java | 76
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java | 75
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java (renamed from core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala) | 21
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java | 47
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java (renamed from core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala) | 22
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/TransportClient.java | 159
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java | 182
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java | 167
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java | 76
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java | 66
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java | 80
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java | 41
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/Message.java | 58
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java | 70
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java | 80
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java (renamed from core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala) | 10
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java (renamed from core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala) | 14
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java | 74
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java | 81
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java | 72
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java | 73
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java | 104
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java | 36
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java | 38
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/StreamManager.java | 52
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java | 96
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java | 162
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/server/TransportServer.java | 121
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java | 52
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/IOMode.java | 27
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java | 38
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java | 102
-rw-r--r--  network/common/src/main/java/org/apache/spark/network/util/TransportConf.java | 61
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java | 217
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java | 28
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java | 86
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java | 175
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java | 34
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java | 104
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/TestUtils.java | 30
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java | 102
-rw-r--r--  network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java | 115
-rw-r--r--  pom.xml | 1
-rw-r--r--  project/MimaExcludes.scala | 5
84 files changed, 4431 insertions(+), 1750 deletions(-)
diff --git a/core/pom.xml b/core/pom.xml
index 5cd21e18e8..8020a2daf8 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -45,6 +45,11 @@
</exclusions>
</dependency>
<dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>network</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
<groupId>net.java.dev.jets3t</groupId>
<artifactId>jets3t</artifactId>
</dependency>
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 5c076e5f1c..6a6dfda363 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -32,6 +32,7 @@ import org.apache.spark.api.python.PythonWorkerFactory
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.BlockTransferService
+import org.apache.spark.network.netty.{NettyBlockTransferService}
import org.apache.spark.network.nio.NioBlockTransferService
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.Serializer
@@ -272,7 +273,13 @@ object SparkEnv extends Logging {
val shuffleMemoryManager = new ShuffleMemoryManager(conf)
- val blockTransferService = new NioBlockTransferService(conf, securityManager)
+ val blockTransferService =
+ conf.get("spark.shuffle.blockTransferService", "nio").toLowerCase match {
+ case "netty" =>
+ new NettyBlockTransferService(conf)
+ case "nio" =>
+ new NioBlockTransferService(conf, securityManager)
+ }
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
"BlockManagerMaster",
diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
index e0e9172427..1745d52c81 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
@@ -17,20 +17,20 @@
package org.apache.spark.network
-import org.apache.spark.storage.StorageLevel
-
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.storage.{BlockId, StorageLevel}
+private[spark]
trait BlockDataManager {
/**
- * Interface to get local block data.
- *
- * @return Some(buffer) if the block exists locally, and None if it doesn't.
+ * Interface to get local block data. Throws an exception if the block cannot be found or
+ * cannot be read successfully.
*/
- def getBlockData(blockId: String): Option[ManagedBuffer]
+ def getBlockData(blockId: BlockId): ManagedBuffer
/**
* Put the block locally, using the given storage level.
*/
- def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit
+ def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala
index 34acaa563c..645793fde8 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala
@@ -19,19 +19,24 @@ package org.apache.spark.network
import java.util.EventListener
+import org.apache.spark.network.buffer.ManagedBuffer
+
/**
* Listener callback interface for [[BlockTransferService.fetchBlocks]].
*/
+private[spark]
trait BlockFetchingListener extends EventListener {
/**
- * Called once per successfully fetched block.
+ * Called once per successfully fetched block. After this call returns, data will be released
+ * automatically. If the data will be passed to another thread, the receiver should retain()
+ * and release() the buffer on their own, or copy the data to a new buffer.
*/
def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit
/**
- * Called upon failures. For each failure, this is called only once (i.e. not once per block).
+ * Called at least once per block upon failures.
*/
- def onBlockFetchFailure(exception: Throwable): Unit
+ def onBlockFetchFailure(blockId: String, exception: Throwable): Unit
}
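
A minimal listener sketch that follows the contract described above; the handling bodies are placeholders, not part of this patch:

    import org.apache.spark.network.BlockFetchingListener
    import org.apache.spark.network.buffer.ManagedBuffer

    val listener = new BlockFetchingListener {
      override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
        // Consume the buffer before returning, or retain()/copy it if handing it off
        // to another thread, since it is released automatically after this call.
        val bytes = data.nioByteBuffer()
      }
      override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
        // Under the new contract this is called at least once per failed block.
        System.err.println(s"Failed to fetch $blockId: ${exception.getMessage}")
      }
    }
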
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
index 84d991fa68..b083f46533 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -17,13 +17,19 @@
package org.apache.spark.network
+import java.io.Closeable
+import java.nio.ByteBuffer
+
import scala.concurrent.{Await, Future}
import scala.concurrent.duration.Duration
-import org.apache.spark.storage.StorageLevel
-
+import org.apache.spark.Logging
+import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer}
+import org.apache.spark.storage.{BlockId, StorageLevel}
+import org.apache.spark.util.Utils
-abstract class BlockTransferService {
+private[spark]
+abstract class BlockTransferService extends Closeable with Logging {
/**
* Initialize the transfer service by giving it the BlockDataManager that can be used to fetch
@@ -34,7 +40,7 @@ abstract class BlockTransferService {
/**
* Tear down the transfer service.
*/
- def stop(): Unit
+ def close(): Unit
/**
* Port number the service is listening on, available only after [[init]] is invoked.
@@ -50,9 +56,6 @@ abstract class BlockTransferService {
* Fetch a sequence of blocks from a remote node asynchronously,
* available only after [[init]] is invoked.
*
- * Note that [[BlockFetchingListener.onBlockFetchSuccess]] is called once per block,
- * while [[BlockFetchingListener.onBlockFetchFailure]] is called once per failure (not per block).
- *
* Note that this API takes a sequence so the implementation can batch requests, and does not
* return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as
* the data of a block is fetched, rather than waiting for all blocks to be fetched.
@@ -69,7 +72,7 @@ abstract class BlockTransferService {
def uploadBlock(
hostname: String,
port: Int,
- blockId: String,
+ blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Future[Unit]
@@ -83,7 +86,7 @@ abstract class BlockTransferService {
val lock = new Object
@volatile var result: Either[ManagedBuffer, Throwable] = null
fetchBlocks(hostName, port, Seq(blockId), new BlockFetchingListener {
- override def onBlockFetchFailure(exception: Throwable): Unit = {
+ override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
lock.synchronized {
result = Right(exception)
lock.notify()
@@ -91,7 +94,10 @@ abstract class BlockTransferService {
}
override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
lock.synchronized {
- result = Left(data)
+ val ret = ByteBuffer.allocate(data.size.toInt)
+ ret.put(data.nioByteBuffer())
+ ret.flip()
+ result = Left(new NioManagedBuffer(ret))
lock.notify()
}
}
@@ -123,7 +129,7 @@ abstract class BlockTransferService {
def uploadBlockSync(
hostname: String,
port: Int,
- blockId: String,
+ blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Unit = {
Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf)
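
A hedged usage sketch of the now BlockId-typed synchronous upload; the host, port, and payload here are placeholders:

    import java.nio.ByteBuffer

    import org.apache.spark.network.BlockTransferService
    import org.apache.spark.network.buffer.NioManagedBuffer
    import org.apache.spark.storage.{BlockId, StorageLevel}

    def uploadExample(service: BlockTransferService, payload: Array[Byte]): Unit = {
      val buffer = new NioManagedBuffer(ByteBuffer.wrap(payload))
      // Blocks until the remote end acknowledges the upload (Await on the returned future).
      service.uploadBlockSync("remote-host", 7077, BlockId("rdd_0_0"), buffer, StorageLevel.MEMORY_ONLY)
    }
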
diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
deleted file mode 100644
index 4211ba4e43..0000000000
--- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
+++ /dev/null
@@ -1,166 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.io._
-import java.nio.ByteBuffer
-import java.nio.channels.FileChannel
-import java.nio.channels.FileChannel.MapMode
-
-import scala.util.Try
-
-import com.google.common.io.ByteStreams
-import io.netty.buffer.{ByteBufInputStream, ByteBuf}
-
-import org.apache.spark.util.{ByteBufferInputStream, Utils}
-
-
-/**
- * This interface provides an immutable view for data in the form of bytes. The implementation
- * should specify how the data is provided:
- *
- * - FileSegmentManagedBuffer: data backed by part of a file
- * - NioByteBufferManagedBuffer: data backed by a NIO ByteBuffer
- * - NettyByteBufManagedBuffer: data backed by a Netty ByteBuf
- */
-sealed abstract class ManagedBuffer {
- // Note that all the methods are defined with parenthesis because their implementations can
- // have side effects (io operations).
-
- /** Number of bytes of the data. */
- def size: Long
-
- /**
- * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the
- * returned ByteBuffer should not affect the content of this buffer.
- */
- def nioByteBuffer(): ByteBuffer
-
- /**
- * Exposes this buffer's data as an InputStream. The underlying implementation does not
- * necessarily check for the length of bytes read, so the caller is responsible for making sure
- * it does not go over the limit.
- */
- def inputStream(): InputStream
-}
-
-
-/**
- * A [[ManagedBuffer]] backed by a segment in a file
- */
-final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long)
- extends ManagedBuffer {
-
- /**
- * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889).
- * Avoid unless there's a good reason not to.
- */
- private val MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024;
-
- override def size: Long = length
-
- override def nioByteBuffer(): ByteBuffer = {
- var channel: FileChannel = null
- try {
- channel = new RandomAccessFile(file, "r").getChannel
- // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead.
- if (length < MIN_MEMORY_MAP_BYTES) {
- val buf = ByteBuffer.allocate(length.toInt)
- channel.position(offset)
- while (buf.remaining() != 0) {
- if (channel.read(buf) == -1) {
- throw new IOException("Reached EOF before filling buffer\n" +
- s"offset=$offset\nfile=${file.getAbsolutePath}\nbuf.remaining=${buf.remaining}")
- }
- }
- buf.flip()
- buf
- } else {
- channel.map(MapMode.READ_ONLY, offset, length)
- }
- } catch {
- case e: IOException =>
- Try(channel.size).toOption match {
- case Some(fileLen) =>
- throw new IOException(s"Error in reading $this (actual file length $fileLen)", e)
- case None =>
- throw new IOException(s"Error in opening $this", e)
- }
- } finally {
- if (channel != null) {
- Utils.tryLog(channel.close())
- }
- }
- }
-
- override def inputStream(): InputStream = {
- var is: FileInputStream = null
- try {
- is = new FileInputStream(file)
- ByteStreams.skipFully(is, offset)
- ByteStreams.limit(is, length)
- } catch {
- case e: IOException =>
- if (is != null) {
- Utils.tryLog(is.close())
- }
- Try(file.length).toOption match {
- case Some(fileLen) =>
- throw new IOException(s"Error in reading $this (actual file length $fileLen)", e)
- case None =>
- throw new IOException(s"Error in opening $this", e)
- }
- case e: Throwable =>
- if (is != null) {
- Utils.tryLog(is.close())
- }
- throw e
- }
- }
-
- override def toString: String = s"${getClass.getName}($file, $offset, $length)"
-}
-
-
-/**
- * A [[ManagedBuffer]] backed by [[java.nio.ByteBuffer]].
- */
-final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer {
-
- override def size: Long = buf.remaining()
-
- override def nioByteBuffer() = buf.duplicate()
-
- override def inputStream() = new ByteBufferInputStream(buf)
-}
-
-
-/**
- * A [[ManagedBuffer]] backed by a Netty [[ByteBuf]].
- */
-final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer {
-
- override def size: Long = buf.readableBytes()
-
- override def nioByteBuffer() = buf.nioBuffer()
-
- override def inputStream() = new ByteBufInputStream(buf)
-
- // TODO(rxin): Promote this to top level ManagedBuffer interface and add documentation for it.
- def release(): Unit = buf.release()
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala
new file mode 100644
index 0000000000..8c5ffd8da6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockFetcher.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.nio.ByteBuffer
+import java.util
+
+import org.apache.spark.{SparkConf, Logging}
+import org.apache.spark.network.BlockFetchingListener
+import org.apache.spark.network.netty.NettyMessages._
+import org.apache.spark.serializer.{JavaSerializer, Serializer}
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.client.{RpcResponseCallback, ChunkReceivedCallback, TransportClient}
+import org.apache.spark.storage.BlockId
+import org.apache.spark.util.Utils
+
+/**
+ * Responsible for holding the state for a request for a single set of blocks. This assumes that
+ * the chunks will be returned in the same order as requested, and that there will be exactly
+ * one chunk per block.
+ *
+ * Upon receipt of any block, the listener will be called back. Upon failure part way through,
+ * the listener will receive a failure callback for each outstanding block.
+ */
+class NettyBlockFetcher(
+ serializer: Serializer,
+ client: TransportClient,
+ blockIds: Seq[String],
+ listener: BlockFetchingListener)
+ extends Logging {
+
+ require(blockIds.nonEmpty)
+
+ private val ser = serializer.newInstance()
+
+ private var streamHandle: ShuffleStreamHandle = _
+
+ private val chunkCallback = new ChunkReceivedCallback {
+ // On receipt of a chunk, pass it upwards as a block.
+ def onSuccess(chunkIndex: Int, buffer: ManagedBuffer): Unit = Utils.logUncaughtExceptions {
+ listener.onBlockFetchSuccess(blockIds(chunkIndex), buffer)
+ }
+
+ // On receipt of a failure, fail every block from chunkIndex onwards.
+ def onFailure(chunkIndex: Int, e: Throwable): Unit = {
+ blockIds.drop(chunkIndex).foreach { blockId =>
+ listener.onBlockFetchFailure(blockId, e);
+ }
+ }
+ }
+
+ /** Begins the fetching process, calling the listener with every block fetched. */
+ def start(): Unit = {
+ // Send the RPC to open the given set of blocks. This will return a ShuffleStreamHandle.
+ client.sendRpc(ser.serialize(OpenBlocks(blockIds.map(BlockId.apply))).array(),
+ new RpcResponseCallback {
+ override def onSuccess(response: Array[Byte]): Unit = {
+ try {
+ streamHandle = ser.deserialize[ShuffleStreamHandle](ByteBuffer.wrap(response))
+ logTrace(s"Successfully opened block set: $streamHandle! Preparing to fetch chunks.")
+
+ // Immediately request all chunks -- we expect that the total size of the request is
+ // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
+ for (i <- 0 until streamHandle.numChunks) {
+ client.fetchChunk(streamHandle.streamId, i, chunkCallback)
+ }
+ } catch {
+ case e: Exception =>
+ logError("Failed while starting block fetches", e)
+ blockIds.foreach(blockId => Utils.tryLog(listener.onBlockFetchFailure(blockId, e)))
+ }
+ }
+
+ override def onFailure(e: Throwable): Unit = {
+ logError("Failed while starting block fetches", e)
+ blockIds.foreach(blockId => Utils.tryLog(listener.onBlockFetchFailure(blockId, e)))
+ }
+ })
+ }
+}
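
For orientation, a sketch of how the fetcher is driven; the serializer, client, and listener are assumed to already exist, and the block ids are placeholders:

    import org.apache.spark.network.BlockFetchingListener
    import org.apache.spark.network.client.TransportClient
    import org.apache.spark.network.netty.NettyBlockFetcher
    import org.apache.spark.serializer.Serializer

    def fetchExample(ser: Serializer, client: TransportClient, listener: BlockFetchingListener): Unit = {
      // One fetcher per batch of block ids; the listener gets one callback per block,
      // success or failure.
      new NettyBlockFetcher(ser, client, Seq("shuffle_0_0_0", "shuffle_0_1_0"), listener).start()
    }
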
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
new file mode 100644
index 0000000000..02c657e1d6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.Logging
+import org.apache.spark.network.BlockDataManager
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.client.{TransportClient, RpcResponseCallback}
+import org.apache.spark.network.server.{DefaultStreamManager, RpcHandler}
+import org.apache.spark.storage.{StorageLevel, BlockId}
+
+import scala.collection.JavaConversions._
+
+object NettyMessages {
+
+ /** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */
+ case class OpenBlocks(blockIds: Seq[BlockId])
+
+ /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */
+ case class UploadBlock(blockId: BlockId, blockData: Array[Byte], level: StorageLevel)
+
+ /** Identifier for a fixed number of chunks to read from a stream created by [[OpenBlocks]]. */
+ case class ShuffleStreamHandle(streamId: Long, numChunks: Int)
+}
+
+/**
+ * Serves requests to open blocks by simply registering one chunk per block requested.
+ */
+class NettyBlockRpcServer(
+ serializer: Serializer,
+ streamManager: DefaultStreamManager,
+ blockManager: BlockDataManager)
+ extends RpcHandler with Logging {
+
+ import NettyMessages._
+
+ override def receive(
+ client: TransportClient,
+ messageBytes: Array[Byte],
+ responseContext: RpcResponseCallback): Unit = {
+ val ser = serializer.newInstance()
+ val message = ser.deserialize[AnyRef](ByteBuffer.wrap(messageBytes))
+ logTrace(s"Received request: $message")
+
+ message match {
+ case OpenBlocks(blockIds) =>
+ val blocks: Seq[ManagedBuffer] = blockIds.map(blockManager.getBlockData)
+ val streamId = streamManager.registerStream(blocks.iterator)
+ logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
+ responseContext.onSuccess(
+ ser.serialize(new ShuffleStreamHandle(streamId, blocks.size)).array())
+
+ case UploadBlock(blockId, blockData, level) =>
+ blockManager.putBlockData(blockId, new NioManagedBuffer(ByteBuffer.wrap(blockData)), level)
+ responseContext.onSuccess(new Array[Byte](0))
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
new file mode 100644
index 0000000000..38a3e94515
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import scala.concurrent.{Promise, Future}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.network._
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.client.{RpcResponseCallback, TransportClient, TransportClientFactory}
+import org.apache.spark.network.netty.NettyMessages.UploadBlock
+import org.apache.spark.network.server._
+import org.apache.spark.network.util.{ConfigProvider, TransportConf}
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.storage.{BlockId, StorageLevel}
+import org.apache.spark.util.Utils
+
+/**
+ * A BlockTransferService that uses Netty to fetch a set of blocks at a time.
+ */
+class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService {
+ // TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
+ val serializer = new JavaSerializer(conf)
+
+ // Create a TransportConfig using SparkConf.
+ private[this] val transportConf = new TransportConf(
+ new ConfigProvider { override def get(name: String) = conf.get(name) })
+
+ private[this] var transportContext: TransportContext = _
+ private[this] var server: TransportServer = _
+ private[this] var clientFactory: TransportClientFactory = _
+
+ override def init(blockDataManager: BlockDataManager): Unit = {
+ val streamManager = new DefaultStreamManager
+ val rpcHandler = new NettyBlockRpcServer(serializer, streamManager, blockDataManager)
+ transportContext = new TransportContext(transportConf, streamManager, rpcHandler)
+ clientFactory = transportContext.createClientFactory()
+ server = transportContext.createServer()
+ }
+
+ override def fetchBlocks(
+ hostname: String,
+ port: Int,
+ blockIds: Seq[String],
+ listener: BlockFetchingListener): Unit = {
+ try {
+ val client = clientFactory.createClient(hostname, port)
+ new NettyBlockFetcher(serializer, client, blockIds, listener).start()
+ } catch {
+ case e: Exception =>
+ logError("Exception while beginning fetchBlocks", e)
+ blockIds.foreach(listener.onBlockFetchFailure(_, e))
+ }
+ }
+
+ override def hostName: String = Utils.localHostName()
+
+ override def port: Int = server.getPort
+
+ override def uploadBlock(
+ hostname: String,
+ port: Int,
+ blockId: BlockId,
+ blockData: ManagedBuffer,
+ level: StorageLevel): Future[Unit] = {
+ val result = Promise[Unit]()
+ val client = clientFactory.createClient(hostname, port)
+
+ // Convert or copy nio buffer into array in order to serialize it.
+ val nioBuffer = blockData.nioByteBuffer()
+ val array = if (nioBuffer.hasArray) {
+ nioBuffer.array()
+ } else {
+ val data = new Array[Byte](nioBuffer.remaining())
+ nioBuffer.get(data)
+ data
+ }
+
+ val ser = serializer.newInstance()
+ client.sendRpc(ser.serialize(new UploadBlock(blockId, array, level)).array(),
+ new RpcResponseCallback {
+ override def onSuccess(response: Array[Byte]): Unit = {
+ logTrace(s"Successfully uploaded block $blockId")
+ result.success()
+ }
+ override def onFailure(e: Throwable): Unit = {
+ logError(s"Error while uploading block $blockId", e)
+ result.failure(e)
+ }
+ })
+
+ result.future
+ }
+
+ override def close(): Unit = server.close()
+}
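
A hedged end-to-end sketch of wiring the new service up; the BlockDataManager instance (in practice the BlockManager) is assumed to exist:

    import org.apache.spark.SparkConf
    import org.apache.spark.network.BlockDataManager
    import org.apache.spark.network.netty.NettyBlockTransferService

    def startService(conf: SparkConf, blockDataManager: BlockDataManager): NettyBlockTransferService = {
      val service = new NettyBlockTransferService(conf)
      service.init(blockDataManager)  // creates the TransportContext, client factory, and server
      // service.port is only meaningful after init(); service.close() shuts the server down.
      service
    }
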
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala
deleted file mode 100644
index b5870152c5..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import org.apache.spark.SparkConf
-
-/**
- * A central location that tracks all the settings we exposed to users.
- */
-private[spark]
-class NettyConfig(conf: SparkConf) {
-
- /** Port the server listens on. Default to a random port. */
- private[netty] val serverPort = conf.getInt("spark.shuffle.io.port", 0)
-
- /** IO mode: nio, oio, epoll, or auto (try epoll first and then nio). */
- private[netty] val ioMode = conf.get("spark.shuffle.io.mode", "nio").toLowerCase
-
- /** Connect timeout in secs. Default 60 secs. */
- private[netty] val connectTimeoutMs = conf.getInt("spark.shuffle.io.connectionTimeout", 60) * 1000
-
- /**
- * Percentage of the desired amount of time spent for I/O in the child event loops.
- * Only applicable in nio and epoll.
- */
- private[netty] val ioRatio = conf.getInt("spark.shuffle.io.netty.ioRatio", 80)
-
- /** Requested maximum length of the queue of incoming connections. */
- private[netty] val backLog: Option[Int] = conf.getOption("spark.shuffle.io.backLog").map(_.toInt)
-
- /**
- * Receive buffer size (SO_RCVBUF).
- * Note: the optimal size for receive buffer and send buffer should be
- * latency * network_bandwidth.
- * Assuming latency = 1ms, network_bandwidth = 10Gbps
- * buffer size should be ~ 1.25MB
- */
- private[netty] val receiveBuf: Option[Int] =
- conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt)
-
- /** Send buffer size (SO_SNDBUF). */
- private[netty] val sendBuf: Option[Int] =
- conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt)
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala
deleted file mode 100644
index 3ab13b96d7..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala
+++ /dev/null
@@ -1,132 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import java.util.concurrent.TimeoutException
-
-import com.google.common.base.Charsets.UTF_8
-import io.netty.bootstrap.Bootstrap
-import io.netty.buffer.PooledByteBufAllocator
-import io.netty.channel.socket.SocketChannel
-import io.netty.channel.{ChannelFutureListener, ChannelFuture, ChannelInitializer, ChannelOption}
-import io.netty.handler.codec.LengthFieldBasedFrameDecoder
-import io.netty.handler.codec.string.StringEncoder
-
-import org.apache.spark.Logging
-
-/**
- * Client for fetching data blocks from [[org.apache.spark.network.netty.server.BlockServer]].
- * Use [[BlockFetchingClientFactory]] to instantiate this client.
- *
- * The constructor blocks until a connection is successfully established.
- *
- * See [[org.apache.spark.network.netty.server.BlockServer]] for client/server protocol.
- *
- * Concurrency: thread safe and can be called from multiple threads.
- */
-@throws[TimeoutException]
-private[spark]
-class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, port: Int)
- extends Logging {
-
- private val handler = new BlockFetchingClientHandler
-
- /** Netty Bootstrap for creating the TCP connection. */
- private val bootstrap: Bootstrap = {
- val b = new Bootstrap
- b.group(factory.workerGroup)
- .channel(factory.socketChannelClass)
- // Use pooled buffers to reduce temporary buffer allocation
- .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
- // Disable Nagle's Algorithm since we don't want packets to wait
- .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE)
- .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE)
- .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, factory.conf.connectTimeoutMs)
-
- b.handler(new ChannelInitializer[SocketChannel] {
- override def initChannel(ch: SocketChannel): Unit = {
- ch.pipeline
- .addLast("encoder", new StringEncoder(UTF_8))
- // maxFrameLength = 2G, lengthFieldOffset = 0, lengthFieldLength = 4
- .addLast("framedLengthDecoder", new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 4))
- .addLast("handler", handler)
- }
- })
- b
- }
-
- /** Netty ChannelFuture for the connection. */
- private val cf: ChannelFuture = bootstrap.connect(hostname, port)
- if (!cf.awaitUninterruptibly(factory.conf.connectTimeoutMs)) {
- throw new TimeoutException(
- s"Connecting to $hostname:$port timed out (${factory.conf.connectTimeoutMs} ms)")
- }
-
- /**
- * Ask the remote server for a sequence of blocks, and execute the callback.
- *
- * Note that this is asynchronous and returns immediately. Upstream caller should throttle the
- * rate of fetching; otherwise we could run out of memory.
- *
- * @param blockIds sequence of block ids to fetch.
- * @param listener callback to fire on fetch success / failure.
- */
- def fetchBlocks(blockIds: Seq[String], listener: BlockClientListener): Unit = {
- // It's best to limit the number of "write" calls since it needs to traverse the whole pipeline.
- // It's also best to limit the number of "flush" calls since it requires system calls.
- // Let's concatenate the string and then call writeAndFlush once.
- // This is also why this implementation might be more efficient than multiple, separate
- // fetch block calls.
- var startTime: Long = 0
- logTrace {
- startTime = System.nanoTime
- s"Sending request $blockIds to $hostname:$port"
- }
-
- blockIds.foreach { blockId =>
- handler.addRequest(blockId, listener)
- }
-
- val writeFuture = cf.channel().writeAndFlush(blockIds.mkString("\n") + "\n")
- writeFuture.addListener(new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture): Unit = {
- if (future.isSuccess) {
- logTrace {
- val timeTaken = (System.nanoTime - startTime).toDouble / 1000000
- s"Sending request $blockIds to $hostname:$port took $timeTaken ms"
- }
- } else {
- // Fail all blocks.
- val errorMsg =
- s"Failed to send request $blockIds to $hostname:$port: ${future.cause.getMessage}"
- logError(errorMsg, future.cause)
- blockIds.foreach { blockId =>
- listener.onFetchFailure(blockId, errorMsg)
- handler.removeRequest(blockId)
- }
- }
- }
- })
- }
-
- def waitForClose(): Unit = {
- cf.channel().closeFuture().sync()
- }
-
- def close(): Unit = cf.channel().close()
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala
deleted file mode 100644
index 2b28402c52..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala
+++ /dev/null
@@ -1,99 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import io.netty.channel.epoll.{EpollEventLoopGroup, EpollSocketChannel}
-import io.netty.channel.nio.NioEventLoopGroup
-import io.netty.channel.oio.OioEventLoopGroup
-import io.netty.channel.socket.nio.NioSocketChannel
-import io.netty.channel.socket.oio.OioSocketChannel
-import io.netty.channel.{EventLoopGroup, Channel}
-
-import org.apache.spark.SparkConf
-import org.apache.spark.network.netty.NettyConfig
-import org.apache.spark.util.Utils
-
-/**
- * Factory for creating [[BlockFetchingClient]] by using createClient. This factory reuses
- * the worker thread pool for Netty.
- *
- * Concurrency: createClient is safe to be called from multiple threads concurrently.
- */
-private[spark]
-class BlockFetchingClientFactory(val conf: NettyConfig) {
-
- def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf))
-
- /** A thread factory so the threads are named (for debugging). */
- val threadFactory = Utils.namedThreadFactory("spark-shuffle-client")
-
- /** The following two are instantiated by the [[init]] method, depending ioMode. */
- var socketChannelClass: Class[_ <: Channel] = _
- var workerGroup: EventLoopGroup = _
-
- init()
-
- /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */
- private def init(): Unit = {
- def initOio(): Unit = {
- socketChannelClass = classOf[OioSocketChannel]
- workerGroup = new OioEventLoopGroup(0, threadFactory)
- }
- def initNio(): Unit = {
- socketChannelClass = classOf[NioSocketChannel]
- workerGroup = new NioEventLoopGroup(0, threadFactory)
- }
- def initEpoll(): Unit = {
- socketChannelClass = classOf[EpollSocketChannel]
- workerGroup = new EpollEventLoopGroup(0, threadFactory)
- }
-
- conf.ioMode match {
- case "nio" => initNio()
- case "oio" => initOio()
- case "epoll" => initEpoll()
- case "auto" =>
- // For auto mode, first try epoll (only available on Linux), then nio.
- try {
- initEpoll()
- } catch {
- // TODO: Should we log the throwable? But that always happen on non-Linux systems.
- // Perhaps the right thing to do is to check whether the system is Linux, and then only
- // call initEpoll on Linux.
- case e: Throwable => initNio()
- }
- }
- }
-
- /**
- * Create a new BlockFetchingClient connecting to the given remote host / port.
- *
- * This blocks until a connection is successfully established.
- *
- * Concurrency: This method is safe to call from multiple threads.
- */
- def createClient(remoteHost: String, remotePort: Int): BlockFetchingClient = {
- new BlockFetchingClient(this, remoteHost, remotePort)
- }
-
- def stop(): Unit = {
- if (workerGroup != null) {
- workerGroup.shutdownGracefully()
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala
deleted file mode 100644
index d9d3f7bef0..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala
+++ /dev/null
@@ -1,104 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import com.google.common.base.Charsets.UTF_8
-import io.netty.buffer.ByteBuf
-import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
-
-import org.apache.spark.Logging
-
-
-/**
- * Handler that processes server responses. It uses the protocol documented in
- * [[org.apache.spark.network.netty.server.BlockServer]].
- *
- * Concurrency: thread safe and can be called from multiple threads.
- */
-private[client]
-class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] with Logging {
-
- /** Tracks the list of outstanding requests and their listeners on success/failure. */
- private val outstandingRequests = java.util.Collections.synchronizedMap {
- new java.util.HashMap[String, BlockClientListener]
- }
-
- def addRequest(blockId: String, listener: BlockClientListener): Unit = {
- outstandingRequests.put(blockId, listener)
- }
-
- def removeRequest(blockId: String): Unit = {
- outstandingRequests.remove(blockId)
- }
-
- override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
- val errorMsg = s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}"
- logError(errorMsg, cause)
-
- // Fire the failure callback for all outstanding blocks
- outstandingRequests.synchronized {
- val iter = outstandingRequests.entrySet().iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- entry.getValue.onFetchFailure(entry.getKey, errorMsg)
- }
- outstandingRequests.clear()
- }
-
- ctx.close()
- }
-
- override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) {
- val totalLen = in.readInt()
- val blockIdLen = in.readInt()
- val blockIdBytes = new Array[Byte](math.abs(blockIdLen))
- in.readBytes(blockIdBytes)
- val blockId = new String(blockIdBytes, UTF_8)
- val blockSize = totalLen - math.abs(blockIdLen) - 4
-
- def server = ctx.channel.remoteAddress.toString
-
- // blockIdLen is negative when it is an error message.
- if (blockIdLen < 0) {
- val errorMessageBytes = new Array[Byte](blockSize)
- in.readBytes(errorMessageBytes)
- val errorMsg = new String(errorMessageBytes, UTF_8)
- logTrace(s"Received block $blockId ($blockSize B) with error $errorMsg from $server")
-
- val listener = outstandingRequests.get(blockId)
- if (listener == null) {
- // Ignore callback
- logWarning(s"Got a response for block $blockId but it is not in our outstanding requests")
- } else {
- outstandingRequests.remove(blockId)
- listener.onFetchFailure(blockId, errorMsg)
- }
- } else {
- logTrace(s"Received block $blockId ($blockSize B) from $server")
-
- val listener = outstandingRequests.get(blockId)
- if (listener == null) {
- // Ignore callback
- logWarning(s"Got a response for block $blockId but it is not in our outstanding requests")
- } else {
- outstandingRequests.remove(blockId)
- listener.onFetchSuccess(blockId, new ReferenceCountedBuffer(in))
- }
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala b/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala
deleted file mode 100644
index 9740ee64d1..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-/**
- * A simple iterator that lazily initializes the underlying iterator.
- *
- * The use case is that sometimes we might have many iterators open at the same time, and each of
- * the iterator might initialize its own buffer (e.g. decompression buffer, deserialization buffer).
- * This could lead to too many buffers open. If this iterator is used, we lazily initialize those
- * buffers.
- */
-private[spark]
-class LazyInitIterator(createIterator: => Iterator[Any]) extends Iterator[Any] {
-
- lazy val proxy = createIterator
-
- override def hasNext: Boolean = {
- val gotNext = proxy.hasNext
- if (!gotNext) {
- close()
- }
- gotNext
- }
-
- override def next(): Any = proxy.next()
-
- def close(): Unit = Unit
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala b/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala
deleted file mode 100644
index ea1abf5ecc..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import java.io.InputStream
-import java.nio.ByteBuffer
-
-import io.netty.buffer.{ByteBuf, ByteBufInputStream}
-
-
-/**
- * A buffer abstraction based on Netty's ByteBuf so we don't expose Netty.
- * This is a Scala value class.
- *
- * The buffer's life cycle is NOT managed by the JVM, and thus requiring explicit declaration of
- * reference by the retain method and release method.
- */
-private[spark]
-class ReferenceCountedBuffer(val underlying: ByteBuf) extends AnyVal {
-
- /** Return the nio ByteBuffer view of the underlying buffer. */
- def byteBuffer(): ByteBuffer = underlying.nioBuffer
-
- /** Creates a new input stream that starts from the current position of the buffer. */
- def inputStream(): InputStream = new ByteBufInputStream(underlying)
-
- /** Increment the reference counter by one. */
- def retain(): Unit = underlying.retain()
-
- /** Decrement the reference counter by one and release the buffer if the ref count is 0. */
- def release(): Unit = underlying.release()
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala
deleted file mode 100644
index 8e4dda4ef8..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import io.netty.buffer.ByteBuf
-import io.netty.channel.ChannelHandlerContext
-import io.netty.handler.codec.MessageToByteEncoder
-
-/**
- * A simple encoder for BlockHeader. See [[BlockServer]] for the server to client protocol.
- */
-private[server]
-class BlockHeaderEncoder extends MessageToByteEncoder[BlockHeader] {
- override def encode(ctx: ChannelHandlerContext, msg: BlockHeader, out: ByteBuf): Unit = {
- // message = message length (4 bytes) + block id length (4 bytes) + block id + block data
- // message length = block id length (4 bytes) + size of block id + size of block data
- val blockIdBytes = msg.blockId.getBytes
- msg.error match {
- case Some(errorMsg) =>
- val errorBytes = errorMsg.getBytes
- out.writeInt(4 + blockIdBytes.length + errorBytes.size)
- out.writeInt(-blockIdBytes.length) // use negative block id length to represent errors
- out.writeBytes(blockIdBytes) // next is blockId itself
- out.writeBytes(errorBytes) // error message
- case None =>
- out.writeInt(4 + blockIdBytes.length + msg.blockSize)
- out.writeInt(blockIdBytes.length) // First 4 bytes is blockId length
- out.writeBytes(blockIdBytes) // next is blockId itself
- // msg of size blockSize will be written by ServerHandler
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala
deleted file mode 100644
index 9194c7ced3..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala
+++ /dev/null
@@ -1,162 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import java.net.InetSocketAddress
-
-import com.google.common.base.Charsets.UTF_8
-import io.netty.bootstrap.ServerBootstrap
-import io.netty.buffer.PooledByteBufAllocator
-import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption}
-import io.netty.channel.epoll.{EpollEventLoopGroup, EpollServerSocketChannel}
-import io.netty.channel.nio.NioEventLoopGroup
-import io.netty.channel.oio.OioEventLoopGroup
-import io.netty.channel.socket.SocketChannel
-import io.netty.channel.socket.nio.NioServerSocketChannel
-import io.netty.channel.socket.oio.OioServerSocketChannel
-import io.netty.handler.codec.LineBasedFrameDecoder
-import io.netty.handler.codec.string.StringDecoder
-
-import org.apache.spark.{Logging, SparkConf}
-import org.apache.spark.network.netty.NettyConfig
-import org.apache.spark.storage.BlockDataProvider
-import org.apache.spark.util.Utils
-
-
-/**
- * Server for serving Spark data blocks.
- * This should be used together with [[org.apache.spark.network.netty.client.BlockFetchingClient]].
- *
- * Protocol for requesting blocks (client to server):
- * One block id per line, e.g. to request 3 blocks: "block1\nblock2\nblock3\n"
- *
- * Protocol for sending blocks (server to client):
- * frame-length (4 bytes), block-id-length (4 bytes), block-id, block-data.
- *
- * frame-length should not include the length of itself.
- * If block-id-length is negative, then this is an error message rather than block-data. The real
- * length is the absolute value of the frame-length.
- *
- */
-private[spark]
-class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Logging {
-
- def this(sparkConf: SparkConf, dataProvider: BlockDataProvider) = {
- this(new NettyConfig(sparkConf), dataProvider)
- }
-
- def port: Int = _port
-
- def hostName: String = _hostName
-
- private var _port: Int = conf.serverPort
- private var _hostName: String = ""
- private var bootstrap: ServerBootstrap = _
- private var channelFuture: ChannelFuture = _
-
- init()
-
- /** Initialize the server. */
- private def init(): Unit = {
- bootstrap = new ServerBootstrap
- val bossThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-boss")
- val workerThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-worker")
-
- // Use only one thread to accept connections, and 2 * num_cores for worker.
- def initNio(): Unit = {
- val bossGroup = new NioEventLoopGroup(1, bossThreadFactory)
- val workerGroup = new NioEventLoopGroup(0, workerThreadFactory)
- workerGroup.setIoRatio(conf.ioRatio)
- bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel])
- }
- def initOio(): Unit = {
- val bossGroup = new OioEventLoopGroup(1, bossThreadFactory)
- val workerGroup = new OioEventLoopGroup(0, workerThreadFactory)
- bootstrap.group(bossGroup, workerGroup).channel(classOf[OioServerSocketChannel])
- }
- def initEpoll(): Unit = {
- val bossGroup = new EpollEventLoopGroup(1, bossThreadFactory)
- val workerGroup = new EpollEventLoopGroup(0, workerThreadFactory)
- workerGroup.setIoRatio(conf.ioRatio)
- bootstrap.group(bossGroup, workerGroup).channel(classOf[EpollServerSocketChannel])
- }
-
- conf.ioMode match {
- case "nio" => initNio()
- case "oio" => initOio()
- case "epoll" => initEpoll()
- case "auto" =>
- // For auto mode, first try epoll (only available on Linux), then nio.
- try {
- initEpoll()
- } catch {
- // TODO: Should we log the throwable? But that always happen on non-Linux systems.
- // Perhaps the right thing to do is to check whether the system is Linux, and then only
- // call initEpoll on Linux.
- case e: Throwable => initNio()
- }
- }
-
- // Use pooled buffers to reduce temporary buffer allocation
- bootstrap.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
- bootstrap.childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
-
- // Various (advanced) user-configured settings.
- conf.backLog.foreach { backLog =>
- bootstrap.option[java.lang.Integer](ChannelOption.SO_BACKLOG, backLog)
- }
- conf.receiveBuf.foreach { receiveBuf =>
- bootstrap.option[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf)
- }
- conf.sendBuf.foreach { sendBuf =>
- bootstrap.option[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf)
- }
-
- bootstrap.childHandler(new ChannelInitializer[SocketChannel] {
- override def initChannel(ch: SocketChannel): Unit = {
- ch.pipeline
- .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024
- .addLast("stringDecoder", new StringDecoder(UTF_8))
- .addLast("blockHeaderEncoder", new BlockHeaderEncoder)
- .addLast("handler", new BlockServerHandler(dataProvider))
- }
- })
-
- channelFuture = bootstrap.bind(new InetSocketAddress(_port))
- channelFuture.sync()
-
- val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress]
- _port = addr.getPort
- _hostName = addr.getHostName
- }
-
- /** Shutdown the server. */
- def stop(): Unit = {
- if (channelFuture != null) {
- channelFuture.channel().close().awaitUninterruptibly()
- channelFuture = null
- }
- if (bootstrap != null && bootstrap.group() != null) {
- bootstrap.group().shutdownGracefully()
- }
- if (bootstrap != null && bootstrap.childGroup() != null) {
- bootstrap.childGroup().shutdownGracefully()
- }
- bootstrap = null
- }
-}
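
The request side of this removed protocol is just newline-delimited block ids, split on the server by a LineBasedFrameDecoder capped at 1024 bytes per line. A hypothetical helper, not part of the patch, makes the format explicit:

import com.google.common.base.Charsets.UTF_8

// One block id per line; a trailing newline terminates the last request.
def encodeBlockRequest(blockIds: Seq[String]): Array[Byte] = {
  require(blockIds.forall(_.length < 1024), "block id exceeds the 1024-byte line limit")
  blockIds.mkString("", "\n", "\n").getBytes(UTF_8)
}

// encodeBlockRequest(Seq("block1", "block2", "block3")) encodes "block1\nblock2\nblock3\n"
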
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala
deleted file mode 100644
index 188154d51d..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import com.google.common.base.Charsets.UTF_8
-import io.netty.channel.ChannelInitializer
-import io.netty.channel.socket.SocketChannel
-import io.netty.handler.codec.LineBasedFrameDecoder
-import io.netty.handler.codec.string.StringDecoder
-
-import org.apache.spark.storage.BlockDataProvider
-
-/** Channel initializer that sets up the pipeline for the BlockServer. */
-private[netty]
-class BlockServerChannelInitializer(dataProvider: BlockDataProvider)
- extends ChannelInitializer[SocketChannel] {
-
- override def initChannel(ch: SocketChannel): Unit = {
- ch.pipeline
- .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024
- .addLast("stringDecoder", new StringDecoder(UTF_8))
- .addLast("blockHeaderEncoder", new BlockHeaderEncoder)
- .addLast("handler", new BlockServerHandler(dataProvider))
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala
deleted file mode 100644
index 40dd5e5d1a..0000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala
+++ /dev/null
@@ -1,140 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import java.io.FileInputStream
-import java.nio.ByteBuffer
-import java.nio.channels.FileChannel
-
-import io.netty.buffer.Unpooled
-import io.netty.channel._
-
-import org.apache.spark.Logging
-import org.apache.spark.storage.{FileSegment, BlockDataProvider}
-
-
-/**
- * A handler that processes requests from clients and writes block data back.
- *
- * The messages should have been processed by a LineBasedFrameDecoder and a StringDecoder first
- * so channelRead0 is called once per line (i.e. per block id).
- */
-private[server]
-class BlockServerHandler(dataProvider: BlockDataProvider)
- extends SimpleChannelInboundHandler[String] with Logging {
-
- override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
- logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause)
- ctx.close()
- }
-
- override def channelRead0(ctx: ChannelHandlerContext, blockId: String): Unit = {
- def client = ctx.channel.remoteAddress.toString
-
- // A helper function to send error message back to the client.
- def respondWithError(error: String): Unit = {
- ctx.writeAndFlush(new BlockHeader(-1, blockId, Some(error))).addListener(
- new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture) {
- if (!future.isSuccess) {
- // TODO: Maybe log the success case as well.
- logError(s"Error sending error back to $client", future.cause)
- ctx.close()
- }
- }
- }
- )
- }
-
- def writeFileSegment(segment: FileSegment): Unit = {
- // Send error message back if the block is too large. Even though we are capable of sending
- // large (2G+) blocks, the receiving end cannot handle it so let's fail fast.
- // Once we fixed the receiving end to be able to process large blocks, this should be removed.
- // Also make sure we update BlockHeaderEncoder to support length > 2G.
-
- // See [[BlockHeaderEncoder]] for the way length is encoded.
- if (segment.length + blockId.length + 4 > Int.MaxValue) {
- respondWithError(s"Block $blockId size ($segment.length) greater than 2G")
- return
- }
-
- var fileChannel: FileChannel = null
- try {
- fileChannel = new FileInputStream(segment.file).getChannel
- } catch {
- case e: Exception =>
- logError(
- s"Error opening channel for $blockId in ${segment.file} for request from $client", e)
- respondWithError(e.getMessage)
- }
-
- // Found the block. Send it back.
- if (fileChannel != null) {
- // Write the header and block data. In the case of failures, the listener on the block data
- // write should close the connection.
- ctx.write(new BlockHeader(segment.length.toInt, blockId))
-
- val region = new DefaultFileRegion(fileChannel, segment.offset, segment.length)
- ctx.writeAndFlush(region).addListener(new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture) {
- if (future.isSuccess) {
- logTrace(s"Sent block $blockId (${segment.length} B) back to $client")
- } else {
- logError(s"Error sending block $blockId to $client; closing connection", future.cause)
- ctx.close()
- }
- }
- })
- }
- }
-
- def writeByteBuffer(buf: ByteBuffer): Unit = {
- ctx.write(new BlockHeader(buf.remaining, blockId))
- ctx.writeAndFlush(Unpooled.wrappedBuffer(buf)).addListener(new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture) {
- if (future.isSuccess) {
- logTrace(s"Sent block $blockId (${buf.remaining} B) back to $client")
- } else {
- logError(s"Error sending block $blockId to $client; closing connection", future.cause)
- ctx.close()
- }
- }
- })
- }
-
- logTrace(s"Received request from $client to fetch block $blockId")
-
- var blockData: Either[FileSegment, ByteBuffer] = null
-
- // First make sure we can find the block. If not, send error back to the user.
- try {
- blockData = dataProvider.getBlockData(blockId)
- } catch {
- case e: Exception =>
- logError(s"Error opening block $blockId for request from $client", e)
- respondWithError(e.getMessage)
- return
- }
-
- blockData match {
- case Left(segment) => writeFileSegment(segment)
- case Right(buf) => writeByteBuffer(buf)
- }
-
- } // end of channelRead0
-}
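
The FileSegment branch of the removed handler relied on Netty's zero-copy path. Boiled down, and assuming the removed BlockHeader and FileSegment classes, the send amounts to the sketch below; it is a simplification, not a drop-in replacement for the handler.

import java.io.FileInputStream

import io.netty.channel.{ChannelFuture, ChannelFutureListener, ChannelHandlerContext, DefaultFileRegion}

import org.apache.spark.storage.FileSegment

// Assumes it sits in the removed org.apache.spark.network.netty.server package,
// next to BlockHeader. Write the header eagerly, then hand the file bytes to the
// kernel via a DefaultFileRegion so they never pass through user space; any write
// failure closes the connection, which the client observes as a failed fetch.
def sendSegment(ctx: ChannelHandlerContext, blockId: String, segment: FileSegment): Unit = {
  val fileChannel = new FileInputStream(segment.file).getChannel
  ctx.write(new BlockHeader(segment.length.toInt, blockId))
  ctx.writeAndFlush(new DefaultFileRegion(fileChannel, segment.offset, segment.length))
    .addListener(new ChannelFutureListener {
      override def operationComplete(future: ChannelFuture): Unit = {
        if (!future.isSuccess) ctx.close()
      }
    })
}
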
diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
index e3113205be..11793ea92a 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
@@ -19,12 +19,13 @@ package org.apache.spark.network.nio
import java.nio.ByteBuffer
-import scala.concurrent.Future
-
-import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf}
import org.apache.spark.network._
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
+
+import scala.concurrent.Future
/**
@@ -71,7 +72,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
/**
* Tear down the transfer service.
*/
- override def stop(): Unit = {
+ override def close(): Unit = {
if (cm != null) {
cm.stop()
}
@@ -95,27 +96,34 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
future.onSuccess { case message =>
val bufferMessage = message.asInstanceOf[BufferMessage]
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
+
// SPARK-4064: In some cases(eg. Remote block was removed) blockMessageArray may be empty.
if (blockMessageArray.isEmpty) {
- listener.onBlockFetchFailure(
- new SparkException(s"Received empty message from $cmId"))
+ blockIds.foreach { id =>
+ listener.onBlockFetchFailure(id, new SparkException(s"Received empty message from $cmId"))
+ }
} else {
- for (blockMessage <- blockMessageArray) {
+ for (blockMessage: BlockMessage <- blockMessageArray) {
val msgType = blockMessage.getType
if (msgType != BlockMessage.TYPE_GOT_BLOCK) {
- listener.onBlockFetchFailure(
- new SparkException(s"Unexpected message ${msgType} received from $cmId"))
+ if (blockMessage.getId != null) {
+ listener.onBlockFetchFailure(blockMessage.getId.toString,
+ new SparkException(s"Unexpected message $msgType received from $cmId"))
+ }
} else {
val blockId = blockMessage.getId
+ val networkSize = blockMessage.getData.limit()
listener.onBlockFetchSuccess(
- blockId.toString, new NioByteBufferManagedBuffer(blockMessage.getData))
+ blockId.toString, new NioManagedBuffer(blockMessage.getData))
}
}
}
}(cm.futureExecContext)
future.onFailure { case exception =>
- listener.onBlockFetchFailure(exception)
+ blockIds.foreach { blockId =>
+ listener.onBlockFetchFailure(blockId, exception)
+ }
}(cm.futureExecContext)
}
@@ -127,12 +135,12 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
override def uploadBlock(
hostname: String,
port: Int,
- blockId: String,
+ blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel)
: Future[Unit] = {
checkInit()
- val msg = PutBlock(BlockId(blockId), blockData.nioByteBuffer(), level)
+ val msg = PutBlock(blockId, blockData.nioByteBuffer(), level)
val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg))
val remoteCmId = new ConnectionManagerId(hostName, port)
val reply = cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage)
@@ -154,10 +162,9 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get)
Some(new BlockMessageArray(responseMessages).toBufferMessage)
} catch {
- case e: Exception => {
+ case e: Exception =>
logError("Exception handling buffer message", e)
Some(Message.createErrorMessage(e, msg.id))
- }
}
case otherMessage: Any =>
@@ -172,13 +179,13 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
case BlockMessage.TYPE_PUT_BLOCK =>
val msg = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel)
logDebug("Received [" + msg + "]")
- putBlock(msg.id.toString, msg.data, msg.level)
+ putBlock(msg.id, msg.data, msg.level)
None
case BlockMessage.TYPE_GET_BLOCK =>
val msg = new GetBlock(blockMessage.getId)
logDebug("Received [" + msg + "]")
- val buffer = getBlock(msg.id.toString)
+ val buffer = getBlock(msg.id)
if (buffer == null) {
return None
}
@@ -188,20 +195,20 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
}
}
- private def putBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ private def putBlock(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) {
val startTimeMs = System.currentTimeMillis()
logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes)
- blockDataManager.putBlockData(blockId, new NioByteBufferManagedBuffer(bytes), level)
+ blockDataManager.putBlockData(blockId, new NioManagedBuffer(bytes), level)
logDebug("PutBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs)
+ " with data size: " + bytes.limit)
}
- private def getBlock(blockId: String): ByteBuffer = {
+ private def getBlock(blockId: BlockId): ByteBuffer = {
val startTimeMs = System.currentTimeMillis()
logDebug("GetBlock " + blockId + " started from " + startTimeMs)
- val buffer = blockDataManager.getBlockData(blockId).orNull
+ val buffer = blockDataManager.getBlockData(blockId)
logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs)
+ " and got buffer " + buffer)
- if (buffer == null) null else buffer.nioByteBuffer()
+ buffer.nioByteBuffer()
}
}
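
With this change, BlockFetchingListener reports failures per block id rather than once per request, so a single missing block no longer masks the blocks that did arrive. A minimal sketch of the new callback shape (purely illustrative; the println calls stand in for real handling):

import org.apache.spark.network.BlockFetchingListener
import org.apache.spark.network.buffer.ManagedBuffer

val listener = new BlockFetchingListener {
  override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
    // Per-block success; retain() the buffer before handing it to another thread.
    println(s"fetched $blockId (${data.size} bytes)")
  }

  override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
    // Per-block failure; only this block id is affected.
    println(s"failed to fetch $blockId: ${exception.getMessage}")
  }
}
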
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index a9144cdd97..ca6e971d22 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -17,14 +17,14 @@
package org.apache.spark.serializer
-import java.io.{ByteArrayOutputStream, EOFException, InputStream, OutputStream}
+import java.io._
import java.nio.ByteBuffer
import scala.reflect.ClassTag
-import org.apache.spark.SparkEnv
+import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.util.{ByteBufferInputStream, NextIterator}
+import org.apache.spark.util.{Utils, ByteBufferInputStream, NextIterator}
/**
* :: DeveloperApi ::
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
index 439981d232..1fb5b2c454 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
@@ -24,9 +24,9 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.JavaConversions._
-import org.apache.spark.{SparkEnv, SparkConf, Logging}
+import org.apache.spark.{Logging, SparkConf, SparkEnv}
import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup
import org.apache.spark.storage._
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
index b5cd34cacd..e9805c9c13 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
@@ -23,7 +23,7 @@ import java.nio.ByteBuffer
import com.google.common.io.ByteStreams
import org.apache.spark.SparkEnv
-import org.apache.spark.network.{ManagedBuffer, FileSegmentManagedBuffer}
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.storage._
/**
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
index 63863cc025..b521f0c7fc 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
@@ -18,8 +18,7 @@
package org.apache.spark.shuffle
import java.nio.ByteBuffer
-
-import org.apache.spark.network.ManagedBuffer
+import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.storage.ShuffleBlockId
private[spark]
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 4cc9792365..58510d7232 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -17,15 +17,13 @@
package org.apache.spark.storage
-import java.io.{File, InputStream, OutputStream, BufferedOutputStream, ByteArrayOutputStream}
+import java.io.{BufferedOutputStream, ByteArrayOutputStream, File, InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
-import scala.concurrent.ExecutionContext.Implicits.global
-
-import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, HashMap}
-import scala.concurrent.{Await, Future}
+import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
+import scala.concurrent.{Await, Future}
import scala.util.Random
import akka.actor.{ActorSystem, Props}
@@ -35,11 +33,11 @@ import org.apache.spark._
import org.apache.spark.executor._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.util._
-
private[spark] sealed trait BlockValues
private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues
private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues
@@ -212,21 +210,20 @@ private[spark] class BlockManager(
}
/**
- * Interface to get local block data.
- *
- * @return Some(buffer) if the block exists locally, and None if it doesn't.
+ * Interface to get local block data. Throws an exception if the block cannot be found or
+ * cannot be read successfully.
*/
- override def getBlockData(blockId: String): Option[ManagedBuffer] = {
- val bid = BlockId(blockId)
- if (bid.isShuffle) {
- Some(shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId]))
+ override def getBlockData(blockId: BlockId): ManagedBuffer = {
+ if (blockId.isShuffle) {
+ shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
} else {
- val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]]
+ val blockBytesOpt = doGetLocal(blockId, asBlockResult = false)
+ .asInstanceOf[Option[ByteBuffer]]
if (blockBytesOpt.isDefined) {
val buffer = blockBytesOpt.get
- Some(new NioByteBufferManagedBuffer(buffer))
+ new NioManagedBuffer(buffer)
} else {
- None
+ throw new BlockNotFoundException(blockId.toString)
}
}
}
@@ -234,8 +231,8 @@ private[spark] class BlockManager(
/**
* Put the block locally, using the given storage level.
*/
- override def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = {
- putBytes(BlockId(blockId), data.nioByteBuffer(), level)
+ override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit = {
+ putBytes(blockId, data.nioByteBuffer(), level)
}
/**
@@ -341,17 +338,6 @@ private[spark] class BlockManager(
}
/**
- * A short-circuited method to get blocks directly from disk. This is used for getting
- * shuffle blocks. It is safe to do so without a lock on block info since disk store
- * never deletes (recent) items.
- */
- def getLocalShuffleFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
- val buf = shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
- val is = wrapForCompression(blockId, buf.inputStream())
- Some(serializer.newInstance().deserializeStream(is).asIterator)
- }
-
- /**
* Get block from local block manager.
*/
def getLocal(blockId: BlockId): Option[BlockResult] = {
@@ -869,9 +855,9 @@ private[spark] class BlockManager(
data.rewind()
logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer")
blockTransferService.uploadBlockSync(
- peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel)
- logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %d ms"
- .format((System.currentTimeMillis - onePeerStartTime)))
+ peer.host, peer.port, blockId, new NioManagedBuffer(data), tLevel)
+ logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %s ms"
+ .format(System.currentTimeMillis - onePeerStartTime))
peersReplicatedTo += peer
peersForReplication -= peer
replicationFailed = false
@@ -1126,7 +1112,7 @@ private[spark] class BlockManager(
}
def stop(): Unit = {
- blockTransferService.stop()
+ blockTransferService.close()
diskBlockManager.stop()
actorSystem.stop(slaveActor)
blockInfo.clear()
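
Since getBlockData now throws rather than returning an Option, callers that used to match on None must catch BlockNotFoundException instead. A small illustrative wrapper, not part of the patch, showing the new contract:

import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.storage.{BlockId, BlockManager, BlockNotFoundException}

// Restores Option-style semantics on top of the new throwing API.
def readLocalBlock(blockManager: BlockManager, blockId: BlockId): Option[ManagedBuffer] = {
  try {
    Some(blockManager.getBlockData(blockId))
  } catch {
    case _: BlockNotFoundException => None // the block is not stored locally
  }
}
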
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala
index 9ef453605f..81f5f2d31d 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala
@@ -17,5 +17,4 @@
package org.apache.spark.storage
-
class BlockNotFoundException(blockId: String) extends Exception(s"Block $blockId not found")
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 71b276b5f1..0d6f3bf003 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -19,15 +19,13 @@ package org.apache.spark.storage
import java.util.concurrent.LinkedBlockingQueue
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashSet
-import scala.collection.mutable.Queue
+import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
-import org.apache.spark.{TaskContext, Logging}
-import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener, BlockTransferService}
+import org.apache.spark.{Logging, TaskContext}
+import org.apache.spark.network.{BlockFetchingListener, BlockTransferService}
+import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.serializer.Serializer
-import org.apache.spark.util.Utils
-
+import org.apache.spark.util.{CompletionIterator, Utils}
/**
* An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -88,17 +86,51 @@ final class ShuffleBlockFetcherIterator(
*/
private[this] val results = new LinkedBlockingQueue[FetchResult]
- // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
- // the number of bytes in flight is limited to maxBytesInFlight
+ /**
+ * Current [[FetchResult]] being processed. We track this so we can release the current buffer
+ * in case of a runtime exception when processing the current buffer.
+ */
+ private[this] var currentResult: FetchResult = null
+
+ /**
+ * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
+ * the number of bytes in flight is limited to maxBytesInFlight.
+ */
private[this] val fetchRequests = new Queue[FetchRequest]
- // Current bytes in flight from our requests
+ /** Current bytes in flight from our requests */
private[this] var bytesInFlight = 0L
private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
+ /**
+ * Whether the iterator is still active. If isZombie is true, the callback interface will no
+ * longer place fetched blocks into [[results]].
+ */
+ @volatile private[this] var isZombie = false
+
initialize()
+ /**
+ * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
+ */
+ private[this] def cleanup() {
+ isZombie = true
+ // Release the current buffer if necessary
+ if (currentResult != null && !currentResult.failed) {
+ currentResult.buf.release()
+ }
+
+ // Release buffers in the results queue
+ val iter = results.iterator()
+ while (iter.hasNext) {
+ val result = iter.next()
+ if (!result.failed) {
+ result.buf.release()
+ }
+ }
+ }
+
private[this] def sendRequest(req: FetchRequest) {
logDebug("Sending request for %d blocks (%s) from %s".format(
req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
@@ -110,24 +142,23 @@ final class ShuffleBlockFetcherIterator(
blockTransferService.fetchBlocks(req.address.host, req.address.port, blockIds,
new BlockFetchingListener {
- override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
- results.put(new FetchResult(BlockId(blockId), sizeMap(blockId),
- () => serializer.newInstance().deserializeStream(
- blockManager.wrapForCompression(BlockId(blockId), data.inputStream())).asIterator
- ))
- shuffleMetrics.remoteBytesRead += data.size
- shuffleMetrics.remoteBlocksFetched += 1
- logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
+ override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
+ // Only add the buffer to results queue if the iterator is not zombie,
+ // i.e. cleanup() has not been called yet.
+ if (!isZombie) {
+ // Increment the ref count because we need to pass this to a different thread.
+ // This needs to be released after use.
+ buf.retain()
+ results.put(new FetchResult(BlockId(blockId), sizeMap(blockId), buf))
+ shuffleMetrics.remoteBytesRead += buf.size
+ shuffleMetrics.remoteBlocksFetched += 1
+ }
+ logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
- override def onBlockFetchFailure(e: Throwable): Unit = {
+ override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
- // Note that there is a chance that some blocks have been fetched successfully, but we
- // still add them to the failed queue. This is fine because when the caller see a
- // FetchFailedException, it is going to fail the entire task anyway.
- for ((blockId, size) <- req.blocks) {
- results.put(new FetchResult(blockId, -1, null))
- }
+ results.put(new FetchResult(BlockId(blockId), -1, null))
}
}
)
@@ -138,7 +169,7 @@ final class ShuffleBlockFetcherIterator(
// smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
// nodes, rather than blocking on reading output from one node.
val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
- logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
+ logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
@@ -185,26 +216,34 @@ final class ShuffleBlockFetcherIterator(
remoteRequests
}
+ /**
+ * Fetch the local blocks while we are fetching remote blocks. This is ok because
+ * [[ManagedBuffer]]'s memory is allocated lazily when we create the input stream, so all we
+ * track in-memory are the ManagedBuffer references themselves.
+ */
private[this] def fetchLocalBlocks() {
- // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
- // these all at once because they will just memory-map some files, so they won't consume
- // any memory that might exceed our maxBytesInFlight
- for (id <- localBlocks) {
+ val iter = localBlocks.iterator
+ while (iter.hasNext) {
+ val blockId = iter.next()
try {
+ val buf = blockManager.getBlockData(blockId)
shuffleMetrics.localBlocksFetched += 1
- results.put(new FetchResult(
- id, 0, () => blockManager.getLocalShuffleFromDisk(id, serializer).get))
- logDebug("Got local block " + id)
+ buf.retain()
+ results.put(new FetchResult(blockId, 0, buf))
} catch {
case e: Exception =>
+ // If we see an exception, stop immediately.
logError(s"Error occurred while fetching local blocks", e)
- results.put(new FetchResult(id, -1, null))
+ results.put(new FetchResult(blockId, -1, null))
return
}
}
}
private[this] def initialize(): Unit = {
+ // Add a task completion callback (called in both success and failure cases) to clean up.
+ context.addTaskCompletionListener(_ => cleanup())
+
// Split local and remote blocks.
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
@@ -229,7 +268,8 @@ final class ShuffleBlockFetcherIterator(
override def next(): (BlockId, Option[Iterator[Any]]) = {
numBlocksProcessed += 1
val startFetchWait = System.currentTimeMillis()
- val result = results.take()
+ currentResult = results.take()
+ val result = currentResult
val stopFetchWait = System.currentTimeMillis()
shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait)
if (!result.failed) {
@@ -240,7 +280,21 @@ final class ShuffleBlockFetcherIterator(
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
sendRequest(fetchRequests.dequeue())
}
- (result.blockId, if (result.failed) None else Some(result.deserialize()))
+
+ val iteratorOpt: Option[Iterator[Any]] = if (result.failed) {
+ None
+ } else {
+ val is = blockManager.wrapForCompression(result.blockId, result.buf.createInputStream())
+ val iter = serializer.newInstance().deserializeStream(is).asIterator
+ Some(CompletionIterator[Any, Iterator[Any]](iter, {
+ // Once the iterator is exhausted, release the buffer and set currentResult to null
+ // so we don't release it again in cleanup.
+ currentResult = null
+ result.buf.release()
+ }))
+ }
+
+ (result.blockId, iteratorOpt)
}
}
@@ -254,7 +308,7 @@ object ShuffleBlockFetcherIterator {
* @param blocks Sequence of tuple, where the first element is the block id,
* and the second element is the estimated size, used to calculate bytesInFlight.
*/
- class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) {
+ case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) {
val size = blocks.map(_._2).sum
}
@@ -262,10 +316,11 @@ object ShuffleBlockFetcherIterator {
* Result of a fetch from a remote block. A failure is represented as size == -1.
* @param blockId block id
* @param size estimated size of the block, used to calculate bytesInFlight.
- * Note that this is NOT the exact bytes.
- * @param deserialize closure to return the result in the form of an Iterator.
+ * Note that this is NOT the exact bytes. -1 if the fetch failed.
+ * @param buf [[ManagedBuffer]] for the content. null if the fetch failed.
*/
- class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) {
+ case class FetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer) {
def failed: Boolean = size == -1
+ if (failed) assert(buf == null) else assert(buf != null)
}
}
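
The retain/release discipline is the heart of this change: each buffer is retained when it is handed across threads into the results queue and released exactly once, either when the wrapping CompletionIterator is exhausted or by cleanup() if the task ends early. A condensed sketch of the pattern, assuming a caller-supplied deserialize function:

import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.util.CompletionIterator

def iteratorFor(buf: ManagedBuffer, deserialize: ManagedBuffer => Iterator[Any]): Iterator[Any] = {
  buf.retain() // keep the (possibly Netty-backed) bytes alive while queued
  CompletionIterator[Any, Iterator[Any]](deserialize(buf), {
    buf.release() // exhausting the iterator returns the buffer to its pool
  })
}
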
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 1e881da511..0daab91143 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -43,7 +43,6 @@ import org.json4s._
import tachyon.client.{TachyonFile,TachyonFS}
import org.apache.spark._
-import org.apache.spark.util.SparkUncaughtExceptionHandler
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
/** CallSite represents a place in user code. It can have a short and a long form. */
diff --git a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
index d7b2d2e1e3..840d8273cb 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
@@ -24,10 +24,10 @@ class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll {
// This test suite should run all tests in ShuffleSuite with Netty shuffle mode.
override def beforeAll() {
- System.setProperty("spark.shuffle.use.netty", "true")
+ System.setProperty("spark.shuffle.blockTransferService", "netty")
}
override def afterAll() {
- System.clearProperty("spark.shuffle.use.netty")
+ System.clearProperty("spark.shuffle.blockTransferService")
}
}
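
The transfer service is now selected by name instead of a boolean toggle. A one-line sketch of the new setting (the NIO-based service is presumably selected the same way, by name):

import org.apache.spark.SparkConf

// Select the Netty-based block transfer service explicitly.
val conf = new SparkConf().set("spark.shuffle.blockTransferService", "netty")
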
diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala
deleted file mode 100644
index 02d0ffc86f..0000000000
--- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala
+++ /dev/null
@@ -1,161 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import java.io.{RandomAccessFile, File}
-import java.nio.ByteBuffer
-import java.util.{Collections, HashSet}
-import java.util.concurrent.{TimeUnit, Semaphore}
-
-import scala.collection.JavaConversions._
-
-import io.netty.buffer.{ByteBufUtil, Unpooled}
-
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
-
-import org.apache.spark.SparkConf
-import org.apache.spark.network.netty.client.{BlockClientListener, ReferenceCountedBuffer, BlockFetchingClientFactory}
-import org.apache.spark.network.netty.server.BlockServer
-import org.apache.spark.storage.{FileSegment, BlockDataProvider}
-
-
-/**
- * Test suite that makes sure the server and the client implementations share the same protocol.
- */
-class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll {
-
- val bufSize = 100000
- var buf: ByteBuffer = _
- var testFile: File = _
- var server: BlockServer = _
- var clientFactory: BlockFetchingClientFactory = _
-
- val bufferBlockId = "buffer_block"
- val fileBlockId = "file_block"
-
- val fileContent = new Array[Byte](1024)
- scala.util.Random.nextBytes(fileContent)
-
- override def beforeAll() = {
- buf = ByteBuffer.allocate(bufSize)
- for (i <- 1 to bufSize) {
- buf.put(i.toByte)
- }
- buf.flip()
-
- testFile = File.createTempFile("netty-test-file", "txt")
- val fp = new RandomAccessFile(testFile, "rw")
- fp.write(fileContent)
- fp.close()
-
- server = new BlockServer(new SparkConf, new BlockDataProvider {
- override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = {
- if (blockId == bufferBlockId) {
- Right(buf)
- } else if (blockId == fileBlockId) {
- Left(new FileSegment(testFile, 10, testFile.length - 25))
- } else {
- throw new Exception("Unknown block id " + blockId)
- }
- }
- })
-
- clientFactory = new BlockFetchingClientFactory(new SparkConf)
- }
-
- override def afterAll() = {
- server.stop()
- clientFactory.stop()
- }
-
- /** A ByteBuf for buffer_block */
- lazy val byteBufferBlockReference = Unpooled.wrappedBuffer(buf)
-
- /** A ByteBuf for file_block */
- lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25)
-
- def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ReferenceCountedBuffer], Set[String]) =
- {
- val client = clientFactory.createClient(server.hostName, server.port)
- val sem = new Semaphore(0)
- val receivedBlockIds = Collections.synchronizedSet(new HashSet[String])
- val errorBlockIds = Collections.synchronizedSet(new HashSet[String])
- val receivedBuffers = Collections.synchronizedSet(new HashSet[ReferenceCountedBuffer])
-
- client.fetchBlocks(
- blockIds,
- new BlockClientListener {
- override def onFetchFailure(blockId: String, errorMsg: String): Unit = {
- errorBlockIds.add(blockId)
- sem.release()
- }
-
- override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = {
- receivedBlockIds.add(blockId)
- data.retain()
- receivedBuffers.add(data)
- sem.release()
- }
- }
- )
- if (!sem.tryAcquire(blockIds.size, 30, TimeUnit.SECONDS)) {
- fail("Timeout getting response from the server")
- }
- client.close()
- (receivedBlockIds.toSet, receivedBuffers.toSet, errorBlockIds.toSet)
- }
-
- test("fetch a ByteBuffer block") {
- val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId))
- assert(blockIds === Set(bufferBlockId))
- assert(buffers.map(_.underlying) === Set(byteBufferBlockReference))
- assert(failBlockIds.isEmpty)
- buffers.foreach(_.release())
- }
-
- test("fetch a FileSegment block via zero-copy send") {
- val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(fileBlockId))
- assert(blockIds === Set(fileBlockId))
- assert(buffers.map(_.underlying) === Set(fileBlockReference))
- assert(failBlockIds.isEmpty)
- buffers.foreach(_.release())
- }
-
- test("fetch a non-existent block") {
- val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq("random-block"))
- assert(blockIds.isEmpty)
- assert(buffers.isEmpty)
- assert(failBlockIds === Set("random-block"))
- }
-
- test("fetch both ByteBuffer block and FileSegment block") {
- val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, fileBlockId))
- assert(blockIds === Set(bufferBlockId, fileBlockId))
- assert(buffers.map(_.underlying) === Set(byteBufferBlockReference, fileBlockReference))
- assert(failBlockIds.isEmpty)
- buffers.foreach(_.release())
- }
-
- test("fetch both ByteBuffer block and a non-existent block") {
- val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, "random-block"))
- assert(blockIds === Set(bufferBlockId))
- assert(buffers.map(_.underlying) === Set(byteBufferBlockReference))
- assert(failBlockIds === Set("random-block"))
- buffers.foreach(_.release())
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala
deleted file mode 100644
index f629322ff6..0000000000
--- a/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import java.nio.ByteBuffer
-
-import com.google.common.base.Charsets.UTF_8
-import io.netty.buffer.Unpooled
-import io.netty.channel.embedded.EmbeddedChannel
-
-import org.scalatest.{PrivateMethodTester, FunSuite}
-
-
-class BlockFetchingClientHandlerSuite extends FunSuite with PrivateMethodTester {
-
- test("handling block data (successful fetch)") {
- val blockId = "test_block"
- val blockData = "blahblahblahblahblah"
- val totalLength = 4 + blockId.length + blockData.length
-
- var parsedBlockId: String = ""
- var parsedBlockData: String = ""
- val handler = new BlockFetchingClientHandler
- handler.addRequest(blockId,
- new BlockClientListener {
- override def onFetchFailure(blockId: String, errorMsg: String): Unit = ???
- override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer): Unit = {
- parsedBlockId = bid
- val bytes = new Array[Byte](refCntBuf.byteBuffer().remaining)
- refCntBuf.byteBuffer().get(bytes)
- parsedBlockData = new String(bytes, UTF_8)
- }
- }
- )
-
- val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests)
- assert(handler.invokePrivate(outstandingRequests()).size === 1)
-
- val channel = new EmbeddedChannel(handler)
- val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself
- buf.putInt(totalLength)
- buf.putInt(blockId.length)
- buf.put(blockId.getBytes)
- buf.put(blockData.getBytes)
- buf.flip()
-
- channel.writeInbound(Unpooled.wrappedBuffer(buf))
- assert(parsedBlockId === blockId)
- assert(parsedBlockData === blockData)
-
- assert(handler.invokePrivate(outstandingRequests()).size === 0)
-
- channel.close()
- }
-
- test("handling error message (failed fetch)") {
- val blockId = "test_block"
- val errorMsg = "error erro5r error err4or error3 error6 error erro1r"
- val totalLength = 4 + blockId.length + errorMsg.length
-
- var parsedBlockId: String = ""
- var parsedErrorMsg: String = ""
- val handler = new BlockFetchingClientHandler
- handler.addRequest(blockId, new BlockClientListener {
- override def onFetchFailure(bid: String, msg: String) ={
- parsedBlockId = bid
- parsedErrorMsg = msg
- }
- override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer) = ???
- })
-
- val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests)
- assert(handler.invokePrivate(outstandingRequests()).size === 1)
-
- val channel = new EmbeddedChannel(handler)
- val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself
- buf.putInt(totalLength)
- buf.putInt(-blockId.length)
- buf.put(blockId.getBytes)
- buf.put(errorMsg.getBytes)
- buf.flip()
-
- channel.writeInbound(Unpooled.wrappedBuffer(buf))
- assert(parsedBlockId === blockId)
- assert(parsedErrorMsg === errorMsg)
-
- assert(handler.invokePrivate(outstandingRequests()).size === 0)
-
- channel.close()
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala
deleted file mode 100644
index 3f8d0cf8f3..0000000000
--- a/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import com.google.common.base.Charsets.UTF_8
-import io.netty.buffer.ByteBuf
-import io.netty.channel.embedded.EmbeddedChannel
-
-import org.scalatest.FunSuite
-
-class BlockHeaderEncoderSuite extends FunSuite {
-
- test("encode normal block data") {
- val blockId = "test_block"
- val channel = new EmbeddedChannel(new BlockHeaderEncoder)
- channel.writeOutbound(new BlockHeader(17, blockId, None))
- val out = channel.readOutbound().asInstanceOf[ByteBuf]
- assert(out.readInt() === 4 + blockId.length + 17)
- assert(out.readInt() === blockId.length)
-
- val blockIdBytes = new Array[Byte](blockId.length)
- out.readBytes(blockIdBytes)
- assert(new String(blockIdBytes, UTF_8) === blockId)
- assert(out.readableBytes() === 0)
-
- channel.close()
- }
-
- test("encode error message") {
- val blockId = "error_block"
- val errorMsg = "error encountered"
- val channel = new EmbeddedChannel(new BlockHeaderEncoder)
- channel.writeOutbound(new BlockHeader(17, blockId, Some(errorMsg)))
- val out = channel.readOutbound().asInstanceOf[ByteBuf]
- assert(out.readInt() === 4 + blockId.length + errorMsg.length)
- assert(out.readInt() === -blockId.length)
-
- val blockIdBytes = new Array[Byte](blockId.length)
- out.readBytes(blockIdBytes)
- assert(new String(blockIdBytes, UTF_8) === blockId)
-
- val errorMsgBytes = new Array[Byte](errorMsg.length)
- out.readBytes(errorMsgBytes)
- assert(new String(errorMsgBytes, UTF_8) === errorMsg)
- assert(out.readableBytes() === 0)
-
- channel.close()
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala
deleted file mode 100644
index 3239c710f1..0000000000
--- a/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import java.io.{RandomAccessFile, File}
-import java.nio.ByteBuffer
-
-import io.netty.buffer.{Unpooled, ByteBuf}
-import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler, DefaultFileRegion}
-import io.netty.channel.embedded.EmbeddedChannel
-
-import org.scalatest.FunSuite
-
-import org.apache.spark.storage.{BlockDataProvider, FileSegment}
-
-
-class BlockServerHandlerSuite extends FunSuite {
-
- test("ByteBuffer block") {
- val expectedBlockId = "test_bytebuffer_block"
- val buf = ByteBuffer.allocate(10000)
- for (i <- 1 to 10000) {
- buf.put(i.toByte)
- }
- buf.flip()
-
- val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider {
- override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = Right(buf)
- }))
-
- channel.writeInbound(expectedBlockId)
- assert(channel.outboundMessages().size === 2)
-
- val out1 = channel.readOutbound().asInstanceOf[BlockHeader]
- val out2 = channel.readOutbound().asInstanceOf[ByteBuf]
-
- assert(out1.blockId === expectedBlockId)
- assert(out1.blockSize === buf.remaining)
- assert(out1.error === None)
-
- assert(out2.equals(Unpooled.wrappedBuffer(buf)))
-
- channel.close()
- }
-
- test("FileSegment block via zero-copy") {
- val expectedBlockId = "test_file_block"
-
- // Create random file data
- val fileContent = new Array[Byte](1024)
- scala.util.Random.nextBytes(fileContent)
- val testFile = File.createTempFile("netty-test-file", "txt")
- val fp = new RandomAccessFile(testFile, "rw")
- fp.write(fileContent)
- fp.close()
-
- val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider {
- override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = {
- Left(new FileSegment(testFile, 15, testFile.length - 25))
- }
- }))
-
- channel.writeInbound(expectedBlockId)
- assert(channel.outboundMessages().size === 2)
-
- val out1 = channel.readOutbound().asInstanceOf[BlockHeader]
- val out2 = channel.readOutbound().asInstanceOf[DefaultFileRegion]
-
- assert(out1.blockId === expectedBlockId)
- assert(out1.blockSize === testFile.length - 25)
- assert(out1.error === None)
-
- assert(out2.count === testFile.length - 25)
- assert(out2.position === 15)
- }
-
- test("pipeline exception propagation") {
- val blockServerHandler = new BlockServerHandler(new BlockDataProvider {
- override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = ???
- })
- val exceptionHandler = new SimpleChannelInboundHandler[String]() {
- override def channelRead0(ctx: ChannelHandlerContext, msg: String): Unit = {
- throw new Exception("this is an error")
- }
- }
-
- val channel = new EmbeddedChannel(exceptionHandler, blockServerHandler)
- assert(channel.isOpen)
- channel.writeInbound("a message to trigger the error")
- assert(!channel.isOpen)
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala
new file mode 100644
index 0000000000..0ade1bab18
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import java.io.{EOFException, OutputStream, InputStream}
+import java.nio.ByteBuffer
+
+import scala.reflect.ClassTag
+
+
+/**
+ * A serializer implementation that always returns a single element in a deserialization stream.
+ */
+class TestSerializer extends Serializer {
+ override def newInstance() = new TestSerializerInstance
+}
+
+
+class TestSerializerInstance extends SerializerInstance {
+ override def serialize[T: ClassTag](t: T): ByteBuffer = ???
+
+ override def serializeStream(s: OutputStream): SerializationStream = ???
+
+ override def deserializeStream(s: InputStream) = new TestDeserializationStream
+
+ override def deserialize[T: ClassTag](bytes: ByteBuffer): T = ???
+
+ override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = ???
+}
+
+
+class TestDeserializationStream extends DeserializationStream {
+
+ private var count = 0
+
+ override def readObject[T: ClassTag](): T = {
+ count += 1
+ if (count == 2) {
+ throw new EOFException
+ }
+ new Object().asInstanceOf[T]
+ }
+
+ override def close(): Unit = {}
+}
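
A quick usage sketch, illustrative only: every deserialization stream this serializer produces yields exactly one dummy element before signalling EOF, which is all the fetcher-iterator tests need.

import java.io.ByteArrayInputStream

import org.apache.spark.serializer.TestSerializer

val iter = (new TestSerializer).newInstance()
  .deserializeStream(new ByteArrayInputStream(new Array[Byte](0)))
  .asIterator
assert(iter.size == 1) // exactly one element, regardless of the input bytes
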
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
index ba47fe5e25..6790388f96 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.FunSuite
import org.apache.spark.{SparkEnv, SparkContext, LocalSparkContext, SparkConf}
import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.shuffle.FileShuffleBlockManager
import org.apache.spark.storage.{ShuffleBlockId, FileSegment}
@@ -36,9 +36,9 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext {
private def checkSegments(expected: FileSegment, buffer: ManagedBuffer) {
assert(buffer.isInstanceOf[FileSegmentManagedBuffer])
val segment = buffer.asInstanceOf[FileSegmentManagedBuffer]
- assert(expected.file.getCanonicalPath === segment.file.getCanonicalPath)
- assert(expected.offset === segment.offset)
- assert(expected.length === segment.length)
+ assert(expected.file.getCanonicalPath === segment.getFile.getCanonicalPath)
+ assert(expected.offset === segment.getOffset)
+ assert(expected.length === segment.getLength)
}
test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") {
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index a8c049d749..4e502cf65e 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -17,6 +17,10 @@
package org.apache.spark.storage
+import java.util.concurrent.Semaphore
+
+import scala.concurrent.future
+import scala.concurrent.ExecutionContext.Implicits.global
import org.apache.spark.{TaskContextImpl, TaskContext}
import org.apache.spark.network.{BlockFetchingListener, BlockTransferService}
@@ -27,38 +31,64 @@ import org.mockito.stubbing.Answer
import org.scalatest.FunSuite
+import org.apache.spark.{SparkConf, TaskContext}
+import org.apache.spark.network._
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.serializer.TestSerializer
+
class ShuffleBlockFetcherIteratorSuite extends FunSuite {
+ // Some of the tests are quite tricky because we are testing the cleanup behavior
+ // in the presence of faults.
- test("handle local read failures in BlockManager") {
+ /** Creates a mock [[BlockTransferService]] that returns data from the given map. */
+ private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = {
val transfer = mock(classOf[BlockTransferService])
- val blockManager = mock(classOf[BlockManager])
- doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId
-
- val blIds = Array[BlockId](
- ShuffleBlockId(0,0,0),
- ShuffleBlockId(0,1,0),
- ShuffleBlockId(0,2,0),
- ShuffleBlockId(0,3,0),
- ShuffleBlockId(0,4,0))
-
- val optItr = mock(classOf[Option[Iterator[Any]]])
- val answer = new Answer[Option[Iterator[Any]]] {
- override def answer(invocation: InvocationOnMock) = Option[Iterator[Any]] {
- throw new Exception
+ when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ override def answer(invocation: InvocationOnMock): Unit = {
+ val blocks = invocation.getArguments()(2).asInstanceOf[Seq[String]]
+ val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener]
+
+ for (blockId <- blocks) {
+ if (data.contains(BlockId(blockId))) {
+ listener.onBlockFetchSuccess(blockId, data(BlockId(blockId)))
+ } else {
+ listener.onBlockFetchFailure(blockId, new BlockNotFoundException(blockId))
+ }
+ }
}
+ })
+ transfer
+ }
+
+ private val conf = new SparkConf
+
+ test("successful 3 local reads + 2 remote reads") {
+ val blockManager = mock(classOf[BlockManager])
+ val localBmId = BlockManagerId("test-client", "test-client", 1)
+ doReturn(localBmId).when(blockManager).blockManagerId
+
+ // Make sure blockManager.getBlockData would return the blocks
+ val localBlocks = Map[BlockId, ManagedBuffer](
+ ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
+ ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
+ ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]))
+ localBlocks.foreach { case (blockId, buf) =>
+ doReturn(buf).when(blockManager).getBlockData(meq(blockId))
}
- // 3rd block is going to fail
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any())
- doAnswer(answer).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any())
+ // Make sure remote blocks would return
+ val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+ val remoteBlocks = Map[BlockId, ManagedBuffer](
+ ShuffleBlockId(0, 3, 0) -> mock(classOf[ManagedBuffer]),
+ ShuffleBlockId(0, 4, 0) -> mock(classOf[ManagedBuffer])
+ )
+
+ val transfer = createMockTransfer(remoteBlocks)
- val bmId = BlockManagerId("test-client", "test-client", 1)
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
+ (localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq),
+ (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)
)
val iterator = new ShuffleBlockFetcherIterator(
@@ -66,118 +96,145 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
transfer,
blockManager,
blocksByAddress,
- null,
+ new TestSerializer,
48 * 1024 * 1024)
- // Without exhausting the iterator, the iterator should be lazy and not call
- // getLocalShuffleFromDisk.
- verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any())
-
- assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements")
- // the 2nd element of the tuple returned by iterator.next should be defined when
- // fetching successfully
- assert(iterator.next()._2.isDefined,
- "1st element should be defined but is not actually defined")
- verify(blockManager, times(1)).getLocalShuffleFromDisk(any(), any())
-
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element")
- assert(iterator.next()._2.isDefined,
- "2nd element should be defined but is not actually defined")
- verify(blockManager, times(2)).getLocalShuffleFromDisk(any(), any())
-
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements")
- // 3rd fetch should be failed
- intercept[Exception] {
- iterator.next()
+ // 3 local blocks fetched in initialization
+ verify(blockManager, times(3)).getBlockData(any())
+
+ for (i <- 0 until 5) {
+ assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements")
+ val (blockId, subIterator) = iterator.next()
+ assert(subIterator.isDefined,
+ s"iterator should have 5 elements defined but actually has $i elements")
+
+ // Make sure we release the buffer once the iterator is exhausted.
+ val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
+ verify(mockBuf, times(0)).release()
+ subIterator.get.foreach(_ => Unit) // exhaust the iterator
+ verify(mockBuf, times(1)).release()
}
- verify(blockManager, times(3)).getLocalShuffleFromDisk(any(), any())
+
+ // 3 local blocks, and 2 remote blocks
+ // (but from the same block manager so one call to fetchBlocks)
+ verify(blockManager, times(3)).getBlockData(any())
+ verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any())
}
- test("handle local read successes") {
- val transfer = mock(classOf[BlockTransferService])
+ test("release current unexhausted buffer in case the task completes early") {
val blockManager = mock(classOf[BlockManager])
- doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId
-
- val blIds = Array[BlockId](
- ShuffleBlockId(0,0,0),
- ShuffleBlockId(0,1,0),
- ShuffleBlockId(0,2,0),
- ShuffleBlockId(0,3,0),
- ShuffleBlockId(0,4,0))
+ val localBmId = BlockManagerId("test-client", "test-client", 1)
+ doReturn(localBmId).when(blockManager).blockManagerId
+
+ // Make sure remote blocks would return
+ val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+ val blocks = Map[BlockId, ManagedBuffer](
+ ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
+ ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
+ ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])
+ )
- val optItr = mock(classOf[Option[Iterator[Any]]])
+ // Semaphore to coordinate event sequence in two different threads.
+ val sem = new Semaphore(0)
- // All blocks should be fetched successfully
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any())
+ val transfer = mock(classOf[BlockTransferService])
+ when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ override def answer(invocation: InvocationOnMock): Unit = {
+ val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener]
+ future {
+ // Return the first two blocks, and wait till task completion before returning the 3rd one
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0)))
+ sem.acquire()
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0)))
+ }
+ }
+ })
- val bmId = BlockManagerId("test-client", "test-client", 1)
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
- )
+ (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
+ val taskContext = new TaskContextImpl(0, 0, 0)
val iterator = new ShuffleBlockFetcherIterator(
- new TaskContextImpl(0, 0, 0),
+ taskContext,
transfer,
blockManager,
blocksByAddress,
- null,
+ new TestSerializer,
48 * 1024 * 1024)
- // Without exhausting the iterator, the iterator should be lazy and not call getLocalShuffleFromDisk.
- verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any())
-
- assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements")
- assert(iterator.next()._2.isDefined,
- "All elements should be defined but 1st element is not actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element")
- assert(iterator.next()._2.isDefined,
- "All elements should be defined but 2nd element is not actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements")
- assert(iterator.next()._2.isDefined,
- "All elements should be defined but 3rd element is not actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements")
- assert(iterator.next()._2.isDefined,
- "All elements should be defined but 4th element is not actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 4 elements")
- assert(iterator.next()._2.isDefined,
- "All elements should be defined but 5th element is not actually defined")
-
- verify(blockManager, times(5)).getLocalShuffleFromDisk(any(), any())
+ // Exhaust the first block, and then it should be released.
+ iterator.next()._2.get.foreach(_ => Unit)
+ verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release()
+
+ // Get the 2nd block but do not exhaust the iterator
+ val subIter = iterator.next()._2.get
+
+ // Complete the task; then the 2nd block buffer should be exhausted
+ verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release()
+ taskContext.markTaskCompleted()
+ verify(blocks(ShuffleBlockId(0, 1, 0)), times(1)).release()
+
+ // The 3rd block should not be retained because the iterator is already in zombie state
+ sem.release()
+ verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).retain()
+ verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).release()
}
- test("handle remote fetch failures in BlockTransferService") {
+ test("fail all blocks if any of the remote request fails") {
+ val blockManager = mock(classOf[BlockManager])
+ val localBmId = BlockManagerId("test-client", "test-client", 1)
+ doReturn(localBmId).when(blockManager).blockManagerId
+
+ // Make sure remote blocks would return
+ val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+ val blocks = Map[BlockId, ManagedBuffer](
+ ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
+ ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
+ ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])
+ )
+
+ // Semaphore to coordinate event sequence in two different threads.
+ val sem = new Semaphore(0)
+
val transfer = mock(classOf[BlockTransferService])
when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener]
- listener.onBlockFetchFailure(new Exception("blah"))
+ future {
+ // Return the first block, and then fail.
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+ listener.onBlockFetchFailure(
+ ShuffleBlockId(0, 1, 0).toString, new BlockNotFoundException("blah"))
+ listener.onBlockFetchFailure(
+ ShuffleBlockId(0, 2, 0).toString, new BlockNotFoundException("blah"))
+ sem.release()
+ }
}
})
- val blockManager = mock(classOf[BlockManager])
-
- when(blockManager.blockManagerId).thenReturn(BlockManagerId("test-client", "test-client", 1))
-
- val blId1 = ShuffleBlockId(0, 0, 0)
- val blId2 = ShuffleBlockId(0, 1, 0)
- val bmId = BlockManagerId("test-server", "test-server", 1)
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (bmId, Seq((blId1, 1L), (blId2, 1L))))
+ (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
+ val taskContext = new TaskContextImpl(0, 0, 0)
val iterator = new ShuffleBlockFetcherIterator(
- new TaskContextImpl(0, 0, 0),
+ taskContext,
transfer,
blockManager,
blocksByAddress,
- null,
+ new TestSerializer,
48 * 1024 * 1024)
- iterator.foreach { case (_, iterOption) =>
- assert(!iterOption.isDefined)
- }
+ // Continue only after the mock calls onBlockFetchFailure
+ sem.acquire()
+
+ // The first block should be defined, and the last two are not defined (due to failure)
+ assert(iterator.next()._2.isDefined === true)
+ assert(iterator.next()._2.isDefined === false)
+ assert(iterator.next()._2.isDefined === false)
}
}
diff --git a/network/common/pom.xml b/network/common/pom.xml
new file mode 100644
index 0000000000..e3b7e32870
--- /dev/null
+++ b/network/common/pom.xml
@@ -0,0 +1,94 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-parent</artifactId>
+ <version>1.2.0-SNAPSHOT</version>
+ <relativePath>../../pom.xml</relativePath>
+ </parent>
+
+ <groupId>org.apache.spark</groupId>
+ <artifactId>network</artifactId>
+ <packaging>jar</packaging>
+ <name>Shuffle Streaming Service</name>
+ <url>http://spark.apache.org/</url>
+ <properties>
+ <sbt.project.name>network</sbt.project.name>
+ </properties>
+
+ <dependencies>
+ <!-- Core dependencies -->
+ <dependency>
+ <groupId>io.netty</groupId>
+ <artifactId>netty-all</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-api</artifactId>
+ </dependency>
+
+ <!-- Provided dependencies -->
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ <scope>provided</scope>
+ </dependency>
+
+ <!-- Test dependencies -->
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>log4j</groupId>
+ <artifactId>log4j</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-all</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+
+
+ <build>
+ <outputDirectory>target/java/classes</outputDirectory>
+ <testOutputDirectory>target/java/test-classes</testOutputDirectory>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-surefire-plugin</artifactId>
+ <version>2.17</version>
+ <configuration>
+ <skipTests>false</skipTests>
+ <includes>
+ <include>**/Test*.java</include>
+ <include>**/*Test.java</include>
+ <include>**/*Suite.java</include>
+ </includes>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+</project>
diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
new file mode 100644
index 0000000000..854aa6685f
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import io.netty.channel.Channel;
+import io.netty.channel.socket.SocketChannel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.client.TransportResponseHandler;
+import org.apache.spark.network.protocol.MessageDecoder;
+import org.apache.spark.network.protocol.MessageEncoder;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportChannelHandler;
+import org.apache.spark.network.server.TransportRequestHandler;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to
+ * set up Netty Channel pipelines with a {@link org.apache.spark.network.server.TransportChannelHandler}.
+ *
+ * The TransportClient provides two communication protocols: control-plane RPCs and data-plane
+ * "chunk fetching". The handling of the RPCs is performed outside the scope of the TransportContext
+ * (i.e., by a user-provided handler), and that handler is responsible for setting up streams which
+ * can be streamed through the data plane in chunks using zero-copy IO.
+ *
+ * The TransportServer and TransportClientFactory both create a TransportChannelHandler for each
+ * channel. As each TransportChannelHandler contains a TransportClient, this enables server
+ * processes to send messages back to the client on an existing channel.
+ */
+public class TransportContext {
+ private final Logger logger = LoggerFactory.getLogger(TransportContext.class);
+
+ private final TransportConf conf;
+ private final StreamManager streamManager;
+ private final RpcHandler rpcHandler;
+
+ private final MessageEncoder encoder;
+ private final MessageDecoder decoder;
+
+ public TransportContext(TransportConf conf, StreamManager streamManager, RpcHandler rpcHandler) {
+ this.conf = conf;
+ this.streamManager = streamManager;
+ this.rpcHandler = rpcHandler;
+ this.encoder = new MessageEncoder();
+ this.decoder = new MessageDecoder();
+ }
+
+ public TransportClientFactory createClientFactory() {
+ return new TransportClientFactory(this);
+ }
+
+ public TransportServer createServer() {
+ return new TransportServer(this);
+ }
+
+ /**
+ * Initializes a client or server Netty Channel Pipeline which encodes/decodes messages and
+ * has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or
+ * response messages.
+ *
+ * @return The created TransportChannelHandler, which includes a TransportClient that can
+ * be used to communicate on this channel. The TransportClient is directly associated with a
+ * ChannelHandler to ensure all users of the same channel get the same TransportClient object.
+ */
+ public TransportChannelHandler initializePipeline(SocketChannel channel) {
+ try {
+ TransportChannelHandler channelHandler = createChannelHandler(channel);
+ channel.pipeline()
+ .addLast("encoder", encoder)
+ .addLast("frameDecoder", NettyUtils.createFrameDecoder())
+ .addLast("decoder", decoder)
+ // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
+ // would require more logic to guarantee if this were not part of the same event loop.
+ .addLast("handler", channelHandler);
+ return channelHandler;
+ } catch (RuntimeException e) {
+ logger.error("Error while initializing Netty pipeline", e);
+ throw e;
+ }
+ }
+
+ /**
+ * Creates the server- and client-side handler which is used to handle both RequestMessages and
+ * ResponseMessages. The channel is expected to have been successfully created, though certain
+ * properties (such as the remoteAddress()) may not be available yet.
+ */
+ private TransportChannelHandler createChannelHandler(Channel channel) {
+ TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
+ TransportClient client = new TransportClient(channel, responseHandler);
+ TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
+ streamManager, rpcHandler);
+ return new TransportChannelHandler(client, responseHandler, requestHandler);
+ }
+
+ public TransportConf getConf() { return conf; }
+}
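
A minimal sketch of how the context above is meant to be wired up, assuming application-supplied StreamManager and RpcHandler instances (their construction is outside this file) and a known server port; only APIs visible in this patch are used.

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.util.TransportConf;

public class TransportContextSketch {
  /** Starts a server and returns a client connected to it on the given (assumed) port. */
  public static TransportClient connectLoopback(
      TransportConf conf, StreamManager streamManager, RpcHandler rpcHandler, int port)
      throws Exception {
    TransportContext context = new TransportContext(conf, streamManager, rpcHandler);
    TransportServer server = context.createServer();       // server side; each accepted channel gets the pipeline above
    TransportClientFactory clientFactory = context.createClientFactory();
    return clientFactory.createClient("localhost", port);  // blocks until the connection is established
  }
}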
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
new file mode 100644
index 0000000000..89ed79bc63
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.RandomAccessFile;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+
+import com.google.common.base.Objects;
+import com.google.common.io.ByteStreams;
+import io.netty.channel.DefaultFileRegion;
+
+import org.apache.spark.network.util.JavaUtils;
+
+/**
+ * A {@link ManagedBuffer} backed by a segment in a file.
+ */
+public final class FileSegmentManagedBuffer extends ManagedBuffer {
+
+ /**
+ * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889).
+ * Avoid unless there's a good reason not to.
+ */
+ // TODO: Make this configurable
+ private static final long MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024;
+
+ private final File file;
+ private final long offset;
+ private final long length;
+
+ public FileSegmentManagedBuffer(File file, long offset, long length) {
+ this.file = file;
+ this.offset = offset;
+ this.length = length;
+ }
+
+ @Override
+ public long size() {
+ return length;
+ }
+
+ @Override
+ public ByteBuffer nioByteBuffer() throws IOException {
+ FileChannel channel = null;
+ try {
+ channel = new RandomAccessFile(file, "r").getChannel();
+ // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead.
+ if (length < MIN_MEMORY_MAP_BYTES) {
+ ByteBuffer buf = ByteBuffer.allocate((int) length);
+ channel.position(offset);
+ while (buf.remaining() != 0) {
+ if (channel.read(buf) == -1) {
+ throw new IOException(String.format("Reached EOF before filling buffer\n" +
+ "offset=%s\nfile=%s\nbuf.remaining=%s",
+ offset, file.getAbsoluteFile(), buf.remaining()));
+ }
+ }
+ buf.flip();
+ return buf;
+ } else {
+ return channel.map(FileChannel.MapMode.READ_ONLY, offset, length);
+ }
+ } catch (IOException e) {
+ try {
+ if (channel != null) {
+ long size = channel.size();
+ throw new IOException("Error in reading " + this + " (actual file length " + size + ")",
+ e);
+ }
+ } catch (IOException ignored) {
+ // ignore
+ }
+ throw new IOException("Error in opening " + this, e);
+ } finally {
+ JavaUtils.closeQuietly(channel);
+ }
+ }
+
+ @Override
+ public InputStream createInputStream() throws IOException {
+ FileInputStream is = null;
+ try {
+ is = new FileInputStream(file);
+ ByteStreams.skipFully(is, offset);
+ return ByteStreams.limit(is, length);
+ } catch (IOException e) {
+ try {
+ if (is != null) {
+ long size = file.length();
+ throw new IOException("Error in reading " + this + " (actual file length " + size + ")",
+ e);
+ }
+ } catch (IOException ignored) {
+ // ignore
+ } finally {
+ JavaUtils.closeQuietly(is);
+ }
+ throw new IOException("Error in opening " + this, e);
+ } catch (RuntimeException e) {
+ JavaUtils.closeQuietly(is);
+ throw e;
+ }
+ }
+
+ @Override
+ public ManagedBuffer retain() {
+ return this;
+ }
+
+ @Override
+ public ManagedBuffer release() {
+ return this;
+ }
+
+ @Override
+ public Object convertToNetty() throws IOException {
+ FileChannel fileChannel = new FileInputStream(file).getChannel();
+ return new DefaultFileRegion(fileChannel, offset, length);
+ }
+
+ public File getFile() { return file; }
+
+ public long getOffset() { return offset; }
+
+ public long getLength() { return length; }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("file", file)
+ .add("offset", offset)
+ .add("length", length)
+ .toString();
+ }
+}
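
A short usage sketch for the buffer above: exposing a slice of an on-disk file and reading it back through the two accessors. The path, offset, and length are illustrative only.

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;

import org.apache.spark.network.buffer.FileSegmentManagedBuffer;

public class FileSegmentSketch {
  public static void main(String[] args) throws IOException {
    // Expose bytes [100, 100 + 4096) of an existing file (hypothetical path).
    File file = new File("/tmp/shuffle_0_0_0.data");
    FileSegmentManagedBuffer buf = new FileSegmentManagedBuffer(file, 100, 4096);

    // Segments below MIN_MEMORY_MAP_BYTES are copied into a heap buffer; larger ones are memory-mapped.
    ByteBuffer bytes = buf.nioByteBuffer();
    System.out.println("read " + bytes.remaining() + " bytes");

    // The same segment as a stream; per the ManagedBuffer contract, do not read past the limit.
    InputStream in = buf.createInputStream();
    try {
      in.read(new byte[64]);
    } finally {
      in.close();
    }
    // retain()/release() are no-ops for file-backed buffers, so there is nothing else to clean up.
  }
}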
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
new file mode 100644
index 0000000000..a415db593a
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+/**
+ * This class provides an immutable view of data in the form of bytes. Concrete implementations
+ * specify how the data is provided:
+ *
+ * - {@link FileSegmentManagedBuffer}: data backed by part of a file
+ * - {@link NioManagedBuffer}: data backed by a NIO ByteBuffer
+ * - {@link NettyManagedBuffer}: data backed by a Netty ByteBuf
+ *
+ * The concrete buffer implementation might be managed outside the JVM garbage collector.
+ * For example, in the case of {@link NettyManagedBuffer}, the buffers are reference counted.
+ * In that case, if the buffer is going to be passed around to a different thread, retain/release
+ * should be called.
+ */
+public abstract class ManagedBuffer {
+
+ /** Number of bytes of the data. */
+ public abstract long size();
+
+ /**
+ * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the
+ * returned ByteBuffer should not affect the content of this buffer.
+ */
+ // TODO: Deprecate this, usage may require expensive memory mapping or allocation.
+ public abstract ByteBuffer nioByteBuffer() throws IOException;
+
+ /**
+ * Exposes this buffer's data as an InputStream. The underlying implementation does not
+ * necessarily check for the length of bytes read, so the caller is responsible for making sure
+ * it does not go over the limit.
+ */
+ public abstract InputStream createInputStream() throws IOException;
+
+ /**
+ * Increment the reference count by one if applicable.
+ */
+ public abstract ManagedBuffer retain();
+
+ /**
+ * If applicable, decrement the reference count by one and deallocate the buffer if the
+ * reference count reaches zero.
+ */
+ public abstract ManagedBuffer release();
+
+ /**
+ * Convert the buffer into a Netty object, used to write the data out.
+ */
+ public abstract Object convertToNetty() throws IOException;
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
new file mode 100644
index 0000000000..c806bfa45b
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufInputStream;
+
+/**
+ * A {@link ManagedBuffer} backed by a Netty {@link ByteBuf}.
+ */
+public final class NettyManagedBuffer extends ManagedBuffer {
+ private final ByteBuf buf;
+
+ public NettyManagedBuffer(ByteBuf buf) {
+ this.buf = buf;
+ }
+
+ @Override
+ public long size() {
+ return buf.readableBytes();
+ }
+
+ @Override
+ public ByteBuffer nioByteBuffer() throws IOException {
+ return buf.nioBuffer();
+ }
+
+ @Override
+ public InputStream createInputStream() throws IOException {
+ return new ByteBufInputStream(buf);
+ }
+
+ @Override
+ public ManagedBuffer retain() {
+ buf.retain();
+ return this;
+ }
+
+ @Override
+ public ManagedBuffer release() {
+ buf.release();
+ return this;
+ }
+
+ @Override
+ public Object convertToNetty() throws IOException {
+ return buf.duplicate();
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("buf", buf)
+ .toString();
+ }
+}
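
A small sketch of the reference-counting contract that distinguishes this buffer from the file- and NIO-backed variants: retain() and release() delegate to the underlying ByteBuf, so a buffer handed to another thread must be retained and later released.

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NettyManagedBuffer;

public class NettyBufferSketch {
  public static void main(String[] args) {
    ByteBuf byteBuf = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4});  // refCnt == 1
    ManagedBuffer buf = new NettyManagedBuffer(byteBuf);

    buf.retain();                        // handing the buffer to another thread: refCnt == 2
    System.out.println(buf.size());      // 4 readable bytes
    buf.release();                       // the other thread is done: refCnt == 1
    buf.release();                       // the original owner is done: refCnt == 0, memory reclaimed
  }
}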
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java
new file mode 100644
index 0000000000..f55b884bc4
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBufInputStream;
+import io.netty.buffer.Unpooled;
+
+/**
+ * A {@link ManagedBuffer} backed by a {@link ByteBuffer}.
+ */
+public final class NioManagedBuffer extends ManagedBuffer {
+ private final ByteBuffer buf;
+
+ public NioManagedBuffer(ByteBuffer buf) {
+ this.buf = buf;
+ }
+
+ @Override
+ public long size() {
+ return buf.remaining();
+ }
+
+ @Override
+ public ByteBuffer nioByteBuffer() throws IOException {
+ return buf.duplicate();
+ }
+
+ @Override
+ public InputStream createInputStream() throws IOException {
+ return new ByteBufInputStream(Unpooled.wrappedBuffer(buf));
+ }
+
+ @Override
+ public ManagedBuffer retain() {
+ return this;
+ }
+
+ @Override
+ public ManagedBuffer release() {
+ return this;
+ }
+
+ @Override
+ public Object convertToNetty() throws IOException {
+ return Unpooled.wrappedBuffer(buf);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("buf", buf)
+ .toString();
+ }
+}
+
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java
index 5b6d086630..1fbdcd6780 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala
+++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java
@@ -15,18 +15,17 @@
* limitations under the License.
*/
-package org.apache.spark.storage
-
-import java.nio.ByteBuffer
-
+package org.apache.spark.network.client;
/**
- * An interface for providing data for blocks.
- *
- * getBlockData returns either a FileSegment (for zero-copy send), or a ByteBuffer.
- *
- * Aside from unit tests, [[BlockManager]] is the main class that implements this.
+ * General exception caused by a remote exception while fetching a chunk.
*/
-private[spark] trait BlockDataProvider {
- def getBlockData(blockId: String): Either[FileSegment, ByteBuffer]
+public class ChunkFetchFailureException extends RuntimeException {
+ public ChunkFetchFailureException(String errorMsg, Throwable cause) {
+ super(errorMsg, cause);
+ }
+
+ public ChunkFetchFailureException(String errorMsg) {
+ super(errorMsg);
+ }
}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java
new file mode 100644
index 0000000000..519e6cb470
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * Callback for the result of a single chunk fetch. For a single stream, the callbacks are
+ * guaranteed to be called by the same thread in the same order as the requests for chunks were
+ * made.
+ *
+ * Note that if a general stream failure occurs, all outstanding chunk requests may be failed.
+ */
+public interface ChunkReceivedCallback {
+ /**
+ * Called upon receipt of a particular chunk.
+ *
+ * The given buffer will initially have a refcount of 1, but will be release()'d as soon as this
+ * call returns. You must therefore either retain() the buffer or copy its contents before
+ * returning.
+ */
+ void onSuccess(int chunkIndex, ManagedBuffer buffer);
+
+ /**
+ * Called upon failure to fetch a particular chunk. Note that this may actually be called due
+ * to failure to fetch a prior chunk in this stream.
+ *
+ * After receiving a failure, the stream may or may not be valid. The client should not assume
+ * that the server's side of the stream has been closed.
+ */
+ void onFailure(int chunkIndex, Throwable e);
+}
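
Since the buffer passed to onSuccess is released as soon as the call returns, a callback that defers processing to another thread has to retain the buffer first and release it once the consumer is done. A sketch of that pattern, with a hypothetical queue as the hand-off point:

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;

public class QueueingChunkCallback implements ChunkReceivedCallback {
  private final BlockingQueue<ManagedBuffer> results = new LinkedBlockingQueue<ManagedBuffer>();

  @Override
  public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
    results.add(buffer.retain());  // keep the buffer alive after this call returns
  }

  @Override
  public void onFailure(int chunkIndex, Throwable e) {
    System.err.println("fetch of chunk " + chunkIndex + " failed: " + e);
  }

  /** Consumer side: take a chunk, use it, then drop the reference taken in onSuccess. */
  public void drainOne() throws InterruptedException {
    ManagedBuffer buffer = results.take();
    try {
      System.out.println("received " + buffer.size() + " bytes");
    } finally {
      buffer.release();
    }
  }
}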
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
index 162e9cc682..6ec960d795 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala
+++ b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
@@ -15,18 +15,16 @@
* limitations under the License.
*/
-package org.apache.spark.network.netty.server
+package org.apache.spark.network.client;
/**
- * Header describing a block. This is used only in the server pipeline.
- *
- * [[BlockServerHandler]] creates this, and [[BlockHeaderEncoder]] encodes it.
- *
- * @param blockSize length of the block content, excluding the length itself.
- * If positive, this is the header for a block (not part of the header).
- * If negative, this is the header and content for an error message.
- * @param blockId block id
- * @param error some error message from reading the block
+ * Callback for the result of a single RPC. This will be invoked once with either success or
+ * failure.
*/
-private[server]
-class BlockHeader(val blockSize: Int, val blockId: String, val error: Option[String] = None)
+public interface RpcResponseCallback {
+ /** Successful serialized result from server. */
+ void onSuccess(byte[] response);
+
+ /** Exception either propagated from server or raised on client side. */
+ void onFailure(Throwable e);
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
new file mode 100644
index 0000000000..b1732fcde2
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.io.Closeable;
+import java.util.UUID;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.base.Preconditions;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.protocol.ChunkFetchRequest;
+import org.apache.spark.network.protocol.RpcRequest;
+import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * Client for fetching consecutive chunks of a pre-negotiated stream. This API is intended to allow
+ * efficient transfer of a large amount of data, broken up into chunks with sizes ranging from
+ * hundreds of KB to a few MB.
+ *
+ * Note that while this client deals with the fetching of chunks from a stream (i.e., data plane),
+ * the actual setup of the streams is done outside the scope of the transport layer. The convenience
+ * method "sendRPC" is provided to enable control plane communication between the client and server
+ * to perform this setup.
+ *
+ * For example, a typical workflow might be:
+ * client.sendRPC(new OpenFile("/foo")) --> returns StreamId = 100
+ * client.fetchChunk(streamId = 100, chunkIndex = 0, callback)
+ * client.fetchChunk(streamId = 100, chunkIndex = 1, callback)
+ * ...
+ * client.sendRPC(new CloseStream(100))
+ *
+ * Construct an instance of TransportClient using {@link TransportClientFactory}. A single
+ * TransportClient may be used for multiple streams, but any given stream must be restricted to a
+ * single client, in order to avoid out-of-order responses.
+ *
+ * NB: This class is used to make requests to the server, while {@link TransportResponseHandler} is
+ * responsible for handling responses from the server.
+ *
+ * Concurrency: thread safe and can be called from multiple threads.
+ */
+public class TransportClient implements Closeable {
+ private final Logger logger = LoggerFactory.getLogger(TransportClient.class);
+
+ private final Channel channel;
+ private final TransportResponseHandler handler;
+
+ public TransportClient(Channel channel, TransportResponseHandler handler) {
+ this.channel = Preconditions.checkNotNull(channel);
+ this.handler = Preconditions.checkNotNull(handler);
+ }
+
+ public boolean isActive() {
+ return channel.isOpen() || channel.isActive();
+ }
+
+ /**
+ * Requests a single chunk from the remote side, from the pre-negotiated streamId.
+ *
+ * Chunk indices go from 0 onwards. It is valid to request the same chunk multiple times, though
+ * some streams may not support this.
+ *
+ * Multiple fetchChunk requests may be outstanding simultaneously, and the chunks are guaranteed
+ * to be returned in the same order that they were requested, assuming only a single
+ * TransportClient is used to fetch the chunks.
+ *
+ * @param streamId Identifier that refers to a stream in the remote StreamManager. This should
+ * be agreed upon by client and server beforehand.
+ * @param chunkIndex 0-based index of the chunk to fetch
+ * @param callback Callback invoked upon successful receipt of chunk, or upon any failure.
+ */
+ public void fetchChunk(
+ long streamId,
+ final int chunkIndex,
+ final ChunkReceivedCallback callback) {
+ final String serverAddr = NettyUtils.getRemoteAddress(channel);
+ final long startTime = System.currentTimeMillis();
+ logger.debug("Sending fetch chunk request {} to {}", chunkIndex, serverAddr);
+
+ final StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex);
+ handler.addFetchRequest(streamChunkId, callback);
+
+ channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(
+ new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) throws Exception {
+ if (future.isSuccess()) {
+ long timeTaken = System.currentTimeMillis() - startTime;
+ logger.trace("Sending request {} to {} took {} ms", streamChunkId, serverAddr,
+ timeTaken);
+ } else {
+ String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId,
+ serverAddr, future.cause());
+ logger.error(errorMsg, future.cause());
+ handler.removeFetchRequest(streamChunkId);
+ callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause()));
+ channel.close();
+ }
+ }
+ });
+ }
+
+ /**
+ * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked
+ * with the server's response or upon any failure.
+ */
+ public void sendRpc(byte[] message, final RpcResponseCallback callback) {
+ final String serverAddr = NettyUtils.getRemoteAddress(channel);
+ final long startTime = System.currentTimeMillis();
+ logger.trace("Sending RPC to {}", serverAddr);
+
+ final long requestId = UUID.randomUUID().getLeastSignificantBits();
+ handler.addRpcRequest(requestId, callback);
+
+ channel.writeAndFlush(new RpcRequest(requestId, message)).addListener(
+ new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) throws Exception {
+ if (future.isSuccess()) {
+ long timeTaken = System.currentTimeMillis() - startTime;
+ logger.trace("Sending request {} to {} took {} ms", requestId, serverAddr, timeTaken);
+ } else {
+ String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId,
+ serverAddr, future.cause());
+ logger.error(errorMsg, future.cause());
+ handler.removeRpcRequest(requestId);
+ callback.onFailure(new RuntimeException(errorMsg, future.cause()));
+ channel.close();
+ }
+ }
+ });
+ }
+
+ @Override
+ public void close() {
+ // close is a local operation and should finish within milliseconds; timeout just to be safe
+ channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
+ }
+}
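
A sketch of the control-plane/data-plane split described in the class comment above. How the stream id is negotiated and what the RPC payload contains are application-specific, so here the id is simply passed in rather than parsed from the response.

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;

public class TransportClientSketch {
  /** Opens a stream via an application-defined RPC, then fetches its first two chunks. */
  public static void openAndFetch(final TransportClient client, byte[] openStreamRpc,
      final long streamId) {
    final ChunkReceivedCallback chunkCallback = new ChunkReceivedCallback() {
      @Override public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
        System.out.println("chunk " + chunkIndex + ": " + buffer.size() + " bytes");
      }
      @Override public void onFailure(int chunkIndex, Throwable e) {
        System.err.println("chunk " + chunkIndex + " failed: " + e);
      }
    };

    // Control plane: an opaque message handled by the server-side RpcHandler.
    client.sendRpc(openStreamRpc, new RpcResponseCallback() {
      @Override public void onSuccess(byte[] response) {
        // Data plane: chunk responses arrive in request order on this client.
        client.fetchChunk(streamId, 0, chunkCallback);
        client.fetchChunk(streamId, 1, chunkCallback);
      }
      @Override public void onFailure(Throwable e) {
        System.err.println("failed to open stream: " + e);
      }
    });
  }
}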
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
new file mode 100644
index 0000000000..10eb9ef7a0
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.io.Closeable;
+import java.lang.reflect.Field;
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicReference;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.util.internal.PlatformDependent;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.server.TransportChannelHandler;
+import org.apache.spark.network.util.IOMode;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Factory for creating {@link TransportClient}s by using createClient.
+ *
+ * The factory maintains a connection pool to other hosts and should return the same
+ * {@link TransportClient} for the same remote host. It also shares a single worker thread pool for
+ * all {@link TransportClient}s.
+ */
+public class TransportClientFactory implements Closeable {
+ private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class);
+
+ private final TransportContext context;
+ private final TransportConf conf;
+ private final ConcurrentHashMap<SocketAddress, TransportClient> connectionPool;
+
+ private final Class<? extends Channel> socketChannelClass;
+ private final EventLoopGroup workerGroup;
+
+ public TransportClientFactory(TransportContext context) {
+ this.context = context;
+ this.conf = context.getConf();
+ this.connectionPool = new ConcurrentHashMap<SocketAddress, TransportClient>();
+
+ IOMode ioMode = IOMode.valueOf(conf.ioMode());
+ this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
+ // TODO: Make thread pool name configurable.
+ this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client");
+ }
+
+ /**
+ * Create a {@link TransportClient} connecting to the given remote host / port.
+ *
+ * This blocks until a connection is successfully established.
+ *
+ * Concurrency: This method is safe to call from multiple threads.
+ */
+ public TransportClient createClient(String remoteHost, int remotePort) throws TimeoutException {
+ // Get connection from the connection pool first.
+ // If it is not found or not active, create a new one.
+ final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
+ TransportClient cachedClient = connectionPool.get(address);
+ if (cachedClient != null && cachedClient.isActive()) {
+ return cachedClient;
+ } else if (cachedClient != null) {
+ connectionPool.remove(address, cachedClient); // Remove inactive clients.
+ }
+
+ logger.debug("Creating new connection to " + address);
+
+ Bootstrap bootstrap = new Bootstrap();
+ bootstrap.group(workerGroup)
+ .channel(socketChannelClass)
+ // Disable Nagle's Algorithm since we don't want packets to wait
+ .option(ChannelOption.TCP_NODELAY, true)
+ .option(ChannelOption.SO_KEEPALIVE, true)
+ .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs());
+
+ // Use pooled buffers to reduce temporary buffer allocation
+ bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator());
+
+ final AtomicReference<TransportClient> client = new AtomicReference<TransportClient>();
+
+ bootstrap.handler(new ChannelInitializer<SocketChannel>() {
+ @Override
+ public void initChannel(SocketChannel ch) {
+ TransportChannelHandler clientHandler = context.initializePipeline(ch);
+ client.set(clientHandler.getClient());
+ }
+ });
+
+ // Connect to the remote server
+ ChannelFuture cf = bootstrap.connect(address);
+ if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
+ throw new TimeoutException(
+ String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
+ } else if (cf.cause() != null) {
+ throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause());
+ }
+
+ // Successful connection
+ assert client.get() != null : "Channel future completed successfully with null client";
+ TransportClient oldClient = connectionPool.putIfAbsent(address, client.get());
+ if (oldClient == null) {
+ return client.get();
+ } else {
+ logger.debug("Two clients were created concurrently, second one will be disposed.");
+ client.get().close();
+ return oldClient;
+ }
+ }
+
+ /** Close all connections in the connection pool, and shutdown the worker thread pool. */
+ @Override
+ public void close() {
+ for (TransportClient client : connectionPool.values()) {
+ try {
+ client.close();
+ } catch (RuntimeException e) {
+ logger.warn("Ignoring exception during close", e);
+ }
+ }
+ connectionPool.clear();
+
+ if (workerGroup != null) {
+ workerGroup.shutdownGracefully();
+ }
+ }
+
+ /**
+ * Creates a pooled ByteBuf allocator but disables the thread-local cache. Thread-local caches
+ * are disabled because the ByteBufs are allocated by the event loop thread, but released by the
+ * executor thread rather than the event loop thread. Those thread-local caches actually delay
+ * the recycling of buffers, leading to larger memory usage.
+ */
+ private PooledByteBufAllocator createPooledByteBufAllocator() {
+ return new PooledByteBufAllocator(
+ PlatformDependent.directBufferPreferred(),
+ getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"),
+ getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"),
+ getPrivateStaticField("DEFAULT_PAGE_SIZE"),
+ getPrivateStaticField("DEFAULT_MAX_ORDER"),
+ 0, // tinyCacheSize
+ 0, // smallCacheSize
+ 0 // normalCacheSize
+ );
+ }
+
+ /** Used to get defaults from Netty's private static fields. */
+ private int getPrivateStaticField(String name) {
+ try {
+ Field f = PooledByteBufAllocator.DEFAULT.getClass().getDeclaredField(name);
+ f.setAccessible(true);
+ return f.getInt(null);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
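
A brief sketch of the pooling contract described above: as long as the connection stays active, repeated createClient calls for the same host and port are expected to hand back the same pooled instance. The endpoint below is hypothetical.

import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;

public class ClientPoolSketch {
  public static void demonstratePooling(TransportClientFactory factory) throws Exception {
    TransportClient first = factory.createClient("shuffle-server", 7337);
    TransportClient second = factory.createClient("shuffle-server", 7337);
    // Same remote address and an active connection: the pooled client is reused.
    assert first == second;
    factory.close();  // closes pooled clients and shuts down the shared worker group
  }
}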
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
new file mode 100644
index 0000000000..d8965590b3
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+import com.google.common.annotations.VisibleForTesting;
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.ResponseMessage;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.server.MessageHandler;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * Handler that processes server responses, in response to requests issued from a
+ * {@link TransportClient}. It works by tracking the list of outstanding requests (and their callbacks).
+ *
+ * Concurrency: thread safe and can be called from multiple threads.
+ */
+public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
+ private final Logger logger = LoggerFactory.getLogger(TransportResponseHandler.class);
+
+ private final Channel channel;
+
+ private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches;
+
+ private final Map<Long, RpcResponseCallback> outstandingRpcs;
+
+ public TransportResponseHandler(Channel channel) {
+ this.channel = channel;
+ this.outstandingFetches = new ConcurrentHashMap<StreamChunkId, ChunkReceivedCallback>();
+ this.outstandingRpcs = new ConcurrentHashMap<Long, RpcResponseCallback>();
+ }
+
+ public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) {
+ outstandingFetches.put(streamChunkId, callback);
+ }
+
+ public void removeFetchRequest(StreamChunkId streamChunkId) {
+ outstandingFetches.remove(streamChunkId);
+ }
+
+ public void addRpcRequest(long requestId, RpcResponseCallback callback) {
+ outstandingRpcs.put(requestId, callback);
+ }
+
+ public void removeRpcRequest(long requestId) {
+ outstandingRpcs.remove(requestId);
+ }
+
+ /**
+ * Fire the failure callback for all outstanding requests. This is called when we have an
+ * uncaught exception or premature connection termination.
+ */
+ private void failOutstandingRequests(Throwable cause) {
+ for (Map.Entry<StreamChunkId, ChunkReceivedCallback> entry : outstandingFetches.entrySet()) {
+ entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
+ }
+ for (Map.Entry<Long, RpcResponseCallback> entry : outstandingRpcs.entrySet()) {
+ entry.getValue().onFailure(cause);
+ }
+
+ // It's OK if new fetches appear, as they will fail immediately.
+ outstandingFetches.clear();
+ outstandingRpcs.clear();
+ }
+
+ @Override
+ public void channelUnregistered() {
+ if (numOutstandingRequests() > 0) {
+ String remoteAddress = NettyUtils.getRemoteAddress(channel);
+ logger.error("Still have {} requests outstanding when connection from {} is closed",
+ numOutstandingRequests(), remoteAddress);
+ failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed"));
+ }
+ }
+
+ @Override
+ public void exceptionCaught(Throwable cause) {
+ if (numOutstandingRequests() > 0) {
+ String remoteAddress = NettyUtils.getRemoteAddress(channel);
+ logger.error("Still have {} requests outstanding when connection from {} is closed",
+ numOutstandingRequests(), remoteAddress);
+ failOutstandingRequests(cause);
+ }
+ }
+
+ @Override
+ public void handle(ResponseMessage message) {
+ String remoteAddress = NettyUtils.getRemoteAddress(channel);
+ if (message instanceof ChunkFetchSuccess) {
+ ChunkFetchSuccess resp = (ChunkFetchSuccess) message;
+ ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
+ if (listener == null) {
+ logger.warn("Ignoring response for block {} from {} since it is not outstanding",
+ resp.streamChunkId, remoteAddress);
+ resp.buffer.release();
+ } else {
+ outstandingFetches.remove(resp.streamChunkId);
+ listener.onSuccess(resp.streamChunkId.chunkIndex, resp.buffer);
+ resp.buffer.release();
+ }
+ } else if (message instanceof ChunkFetchFailure) {
+ ChunkFetchFailure resp = (ChunkFetchFailure) message;
+ ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
+ if (listener == null) {
+ logger.warn("Ignoring response for block {} from {} ({}) since it is not outstanding",
+ resp.streamChunkId, remoteAddress, resp.errorString);
+ } else {
+ outstandingFetches.remove(resp.streamChunkId);
+ listener.onFailure(resp.streamChunkId.chunkIndex, new ChunkFetchFailureException(
+ "Failure while fetching " + resp.streamChunkId + ": " + resp.errorString));
+ }
+ } else if (message instanceof RpcResponse) {
+ RpcResponse resp = (RpcResponse) message;
+ RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
+ if (listener == null) {
+ logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
+ resp.requestId, remoteAddress, resp.response.length);
+ } else {
+ outstandingRpcs.remove(resp.requestId);
+ listener.onSuccess(resp.response);
+ }
+ } else if (message instanceof RpcFailure) {
+ RpcFailure resp = (RpcFailure) message;
+ RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
+ if (listener == null) {
+ logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding",
+ resp.requestId, remoteAddress, resp.errorString);
+ } else {
+ outstandingRpcs.remove(resp.requestId);
+ listener.onFailure(new RuntimeException(resp.errorString));
+ }
+ } else {
+ throw new IllegalStateException("Unknown response type: " + message.type());
+ }
+ }
+
+ /** Returns total number of outstanding requests (fetch requests + rpcs) */
+ @VisibleForTesting
+ public int numOutstandingRequests() {
+ return outstandingFetches.size() + outstandingRpcs.size();
+ }
+}
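As a test-style illustration of how the handler pairs responses with callbacks (not part of the patch; an EmbeddedChannel is used only so the handler has a channel to log against):

    import io.netty.channel.embedded.EmbeddedChannel;
    import org.apache.spark.network.client.RpcResponseCallback;
    import org.apache.spark.network.client.TransportResponseHandler;
    import org.apache.spark.network.protocol.RpcResponse;

    public class ResponseHandlerSketch {
      public static void main(String[] args) {
        TransportResponseHandler handler = new TransportResponseHandler(new EmbeddedChannel());

        long requestId = 42L;
        handler.addRpcRequest(requestId, new RpcResponseCallback() {
          @Override public void onSuccess(byte[] response) {
            System.out.println("RPC succeeded with " + response.length + " bytes");
          }
          @Override public void onFailure(Throwable e) {
            System.err.println("RPC failed: " + e);
          }
        });

        // A response with a matching requestId removes the entry and fires onSuccess;
        // an unknown requestId would only be logged and ignored.
        handler.handle(new RpcResponse(requestId, new byte[] {1, 2, 3}));
        assert handler.numOutstandingRequests() == 0;
      }
    }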
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
new file mode 100644
index 0000000000..152af98ced
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Response to {@link ChunkFetchRequest} when there is an error fetching the chunk.
+ */
+public final class ChunkFetchFailure implements ResponseMessage {
+ public final StreamChunkId streamChunkId;
+ public final String errorString;
+
+ public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) {
+ this.streamChunkId = streamChunkId;
+ this.errorString = errorString;
+ }
+
+ @Override
+ public Type type() { return Type.ChunkFetchFailure; }
+
+ @Override
+ public int encodedLength() {
+ return streamChunkId.encodedLength() + 4 + errorString.getBytes(Charsets.UTF_8).length;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ streamChunkId.encode(buf);
+ byte[] errorBytes = errorString.getBytes(Charsets.UTF_8);
+ buf.writeInt(errorBytes.length);
+ buf.writeBytes(errorBytes);
+ }
+
+ public static ChunkFetchFailure decode(ByteBuf buf) {
+ StreamChunkId streamChunkId = StreamChunkId.decode(buf);
+ int numErrorStringBytes = buf.readInt();
+ byte[] errorBytes = new byte[numErrorStringBytes];
+ buf.readBytes(errorBytes);
+ return new ChunkFetchFailure(streamChunkId, new String(errorBytes, Charsets.UTF_8));
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof ChunkFetchFailure) {
+ ChunkFetchFailure o = (ChunkFetchFailure) other;
+ return streamChunkId.equals(o.streamChunkId) && errorString.equals(o.errorString);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamChunkId", streamChunkId)
+ .add("errorString", errorString)
+ .toString();
+ }
+}
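A quick round-trip sketch of the length-prefixed UTF-8 encoding defined above (illustrative only; StreamChunkId is added later in this patch):

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;
    import org.apache.spark.network.protocol.ChunkFetchFailure;
    import org.apache.spark.network.protocol.StreamChunkId;

    public class ChunkFetchFailureRoundTrip {
      public static void main(String[] args) {
        ChunkFetchFailure original =
          new ChunkFetchFailure(new StreamChunkId(5L, 0), "block not found");

        ByteBuf buf = Unpooled.buffer(original.encodedLength());
        original.encode(buf);                 // streamId, chunkIndex, error length, error bytes

        ChunkFetchFailure decoded = ChunkFetchFailure.decode(buf);
        assert decoded.equals(original);
        buf.release();
      }
    }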
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
new file mode 100644
index 0000000000..980947cf13
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Request to fetch a single chunk of a stream. This will correspond to a single
+ * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure).
+ */
+public final class ChunkFetchRequest implements RequestMessage {
+ public final StreamChunkId streamChunkId;
+
+ public ChunkFetchRequest(StreamChunkId streamChunkId) {
+ this.streamChunkId = streamChunkId;
+ }
+
+ @Override
+ public Type type() { return Type.ChunkFetchRequest; }
+
+ @Override
+ public int encodedLength() {
+ return streamChunkId.encodedLength();
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ streamChunkId.encode(buf);
+ }
+
+ public static ChunkFetchRequest decode(ByteBuf buf) {
+ return new ChunkFetchRequest(StreamChunkId.decode(buf));
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof ChunkFetchRequest) {
+ ChunkFetchRequest o = (ChunkFetchRequest) other;
+ return streamChunkId.equals(o.streamChunkId);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamChunkId", streamChunkId)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
new file mode 100644
index 0000000000..ff4936470c
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * Response to {@link ChunkFetchRequest} when a chunk exists and has been successfully fetched.
+ *
+ * Note that the server-side encoding of this message does NOT include the buffer itself, as this
+ * may be written by Netty in a more efficient manner (i.e., zero-copy write).
+ * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer.
+ */
+public final class ChunkFetchSuccess implements ResponseMessage {
+ public final StreamChunkId streamChunkId;
+ public final ManagedBuffer buffer;
+
+ public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) {
+ this.streamChunkId = streamChunkId;
+ this.buffer = buffer;
+ }
+
+ @Override
+ public Type type() { return Type.ChunkFetchSuccess; }
+
+ @Override
+ public int encodedLength() {
+ return streamChunkId.encodedLength();
+ }
+
+ /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */
+ @Override
+ public void encode(ByteBuf buf) {
+ streamChunkId.encode(buf);
+ }
+
+ /** Decoding uses the given ByteBuf as our data, and will retain() it. */
+ public static ChunkFetchSuccess decode(ByteBuf buf) {
+ StreamChunkId streamChunkId = StreamChunkId.decode(buf);
+ buf.retain();
+ NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate());
+ return new ChunkFetchSuccess(streamChunkId, managedBuf);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof ChunkFetchSuccess) {
+ ChunkFetchSuccess o = (ChunkFetchSuccess) other;
+ return streamChunkId.equals(o.streamChunkId) && buffer.equals(o.buffer);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamChunkId", streamChunkId)
+ .add("buffer", buffer)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java
new file mode 100644
index 0000000000..b4e299471b
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Interface for an object which can be encoded into a ByteBuf. Multiple Encodable objects are
+ * stored in a single, pre-allocated ByteBuf, so Encodables must also provide their length.
+ *
+ * Encodable objects should provide a static "decode(ByteBuf)" method which is invoked by
+ * {@link MessageDecoder}. During decoding, if the object uses the ByteBuf as its data (rather than
+ * just copying data from it), then you must retain() the ByteBuf.
+ *
+ * Additionally, when adding a new Encodable Message, add it to {@link Message.Type}.
+ */
+public interface Encodable {
+ /** Number of bytes of the encoded form of this object. */
+ int encodedLength();
+
+ /**
+ * Serializes this object by writing into the given ByteBuf.
+ * This method must write exactly encodedLength() bytes.
+ */
+ void encode(ByteBuf buf);
+}
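For a concrete picture of the contract, here is a minimal hypothetical Encodable with a fixed-size payload; the StreamChunkId class added later in this patch follows exactly this shape.

    import io.netty.buffer.ByteBuf;
    import org.apache.spark.network.protocol.Encodable;

    public final class ExampleId implements Encodable {
      public final long streamId;
      public final int chunkIndex;

      public ExampleId(long streamId, int chunkIndex) {
        this.streamId = streamId;
        this.chunkIndex = chunkIndex;
      }

      @Override
      public int encodedLength() { return 8 + 4; }   // must match exactly what encode() writes

      @Override
      public void encode(ByteBuf buf) {
        buf.writeLong(streamId);
        buf.writeInt(chunkIndex);
      }

      // Static decode() mirrors encode(); for Message types it is invoked from MessageDecoder's switch.
      public static ExampleId decode(ByteBuf buf) {
        return new ExampleId(buf.readLong(), buf.readInt());
      }
    }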
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
new file mode 100644
index 0000000000..d568370125
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import io.netty.buffer.ByteBuf;
+
+/** An on-the-wire transmittable message. */
+public interface Message extends Encodable {
+ /** Used to identify this request type. */
+ Type type();
+
+ /** Preceding every serialized Message is its type, which allows us to deserialize it. */
+ public static enum Type implements Encodable {
+ ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
+ RpcRequest(3), RpcResponse(4), RpcFailure(5);
+
+ private final byte id;
+
+ private Type(int id) {
+ assert id < 128 : "Cannot have more than 128 message types";
+ this.id = (byte) id;
+ }
+
+ public byte id() { return id; }
+
+ @Override public int encodedLength() { return 1; }
+
+ @Override public void encode(ByteBuf buf) { buf.writeByte(id); }
+
+ public static Type decode(ByteBuf buf) {
+ byte id = buf.readByte();
+ switch (id) {
+ case 0: return ChunkFetchRequest;
+ case 1: return ChunkFetchSuccess;
+ case 2: return ChunkFetchFailure;
+ case 3: return RpcRequest;
+ case 4: return RpcResponse;
+ case 5: return RpcFailure;
+ default: throw new IllegalArgumentException("Unknown message type: " + id);
+ }
+ }
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
new file mode 100644
index 0000000000..81f8d7f963
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.MessageToMessageDecoder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Decoder used by the client side to decode server-to-client responses.
+ * This decoder is stateless so it is safe to be shared by multiple threads.
+ */
+@ChannelHandler.Sharable
+public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {
+
+ private final Logger logger = LoggerFactory.getLogger(MessageDecoder.class);
+ @Override
+ public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
+ Message.Type msgType = Message.Type.decode(in);
+ Message decoded = decode(msgType, in);
+ assert decoded.type() == msgType;
+ logger.trace("Received message " + msgType + ": " + decoded);
+ out.add(decoded);
+ }
+
+ private Message decode(Message.Type msgType, ByteBuf in) {
+ switch (msgType) {
+ case ChunkFetchRequest:
+ return ChunkFetchRequest.decode(in);
+
+ case ChunkFetchSuccess:
+ return ChunkFetchSuccess.decode(in);
+
+ case ChunkFetchFailure:
+ return ChunkFetchFailure.decode(in);
+
+ case RpcRequest:
+ return RpcRequest.decode(in);
+
+ case RpcResponse:
+ return RpcResponse.decode(in);
+
+ case RpcFailure:
+ return RpcFailure.decode(in);
+
+ default:
+ throw new IllegalArgumentException("Unexpected message type: " + msgType);
+ }
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
new file mode 100644
index 0000000000..4cb8becc3e
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.MessageToMessageEncoder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Encoder used by the server side to encode server-to-client responses.
+ * This encoder is stateless so it is safe to be shared by multiple threads.
+ */
+@ChannelHandler.Sharable
+public final class MessageEncoder extends MessageToMessageEncoder<Message> {
+
+ private final Logger logger = LoggerFactory.getLogger(MessageEncoder.class);
+
+ /**
+ * Encodes a Message by invoking its encode() method. For non-data messages, we will add one
+ * ByteBuf to 'out' containing the total frame length, the message type, and the message itself.
+ * In the case of a ChunkFetchSuccess, we will also add the ManagedBuffer corresponding to the
+ * data to 'out', in order to enable zero-copy transfer.
+ */
+ @Override
+ public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) {
+ Object body = null;
+ long bodyLength = 0;
+
+ // Only ChunkFetchSuccesses have data besides the header.
+ // The body is used in order to enable zero-copy transfer for the payload.
+ if (in instanceof ChunkFetchSuccess) {
+ ChunkFetchSuccess resp = (ChunkFetchSuccess) in;
+ try {
+ bodyLength = resp.buffer.size();
+ body = resp.buffer.convertToNetty();
+ } catch (Exception e) {
+ // Re-encode this message as BlockFetchFailure.
+ logger.error(String.format("Error opening block %s for client %s",
+ resp.streamChunkId, ctx.channel().remoteAddress()), e);
+ encode(ctx, new ChunkFetchFailure(resp.streamChunkId, e.getMessage()), out);
+ return;
+ }
+ }
+
+ Message.Type msgType = in.type();
+ // All messages have the frame length, message type, and message itself.
+ int headerLength = 8 + msgType.encodedLength() + in.encodedLength();
+ long frameLength = headerLength + bodyLength;
+ ByteBuf header = ctx.alloc().buffer(headerLength);
+ header.writeLong(frameLength);
+ msgType.encode(header);
+ in.encode(header);
+ assert header.writableBytes() == 0;
+
+ out.add(header);
+ if (body != null && bodyLength > 0) {
+ out.add(body);
+ }
+ }
+}
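Worked example of the resulting frame (sketch only, not in the patch), assuming an RpcRequest carrying 3 payload bytes: the wire frame is an 8-byte frame length, a 1-byte message type, then the encoded message, so the header is 8 + 1 + (8 + 4 + 3) = 24 bytes and, with no zero-copy body, the frame length is also 24.

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;
    import org.apache.spark.network.protocol.RpcRequest;

    public class FrameSketch {
      public static void main(String[] args) {
        RpcRequest msg = new RpcRequest(7L, new byte[] {1, 2, 3});

        int headerLength = 8 + msg.type().encodedLength() + msg.encodedLength(); // 8 + 1 + 15 = 24
        long frameLength = headerLength;      // RPC messages have no separate zero-copy body

        ByteBuf header = Unpooled.buffer(headerLength);
        header.writeLong(frameLength);        // total frame length, read first by the remote side
        msg.type().encode(header);            // single type byte (3 = RpcRequest)
        msg.encode(header);                   // requestId, payload length, payload
        assert header.readableBytes() == headerLength;
      }
    }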
diff --git a/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala b/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java
index 0d7695072a..31b15bb17a 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java
@@ -15,11 +15,11 @@
* limitations under the License.
*/
-package org.apache.spark.network.netty
+package org.apache.spark.network.protocol;
-import org.apache.spark.storage.{BlockId, FileSegment}
+import org.apache.spark.network.protocol.Message;
-trait PathResolver {
- /** Get the file segment in which the given block resides. */
- def getBlockLocation(blockId: BlockId): FileSegment
+/** Messages from the client to the server. */
+public interface RequestMessage extends Message {
+ // token interface
}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java
index e28219dd77..6edffd11cf 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java
@@ -15,15 +15,11 @@
* limitations under the License.
*/
-package org.apache.spark.network.netty.client
+package org.apache.spark.network.protocol;
-import java.util.EventListener
-
-
-trait BlockClientListener extends EventListener {
-
- def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit
-
- def onFetchFailure(blockId: String, errorMsg: String): Unit
+import org.apache.spark.network.protocol.Message;
+/** Messages from the server to the client. */
+public interface ResponseMessage extends Message {
+ // token interface
}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
new file mode 100644
index 0000000000..e239d4ffbd
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/** Response to {@link RpcRequest} for a failed RPC. */
+public final class RpcFailure implements ResponseMessage {
+ public final long requestId;
+ public final String errorString;
+
+ public RpcFailure(long requestId, String errorString) {
+ this.requestId = requestId;
+ this.errorString = errorString;
+ }
+
+ @Override
+ public Type type() { return Type.RpcFailure; }
+
+ @Override
+ public int encodedLength() {
+ return 8 + 4 + errorString.getBytes(Charsets.UTF_8).length;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeLong(requestId);
+ byte[] errorBytes = errorString.getBytes(Charsets.UTF_8);
+ buf.writeInt(errorBytes.length);
+ buf.writeBytes(errorBytes);
+ }
+
+ public static RpcFailure decode(ByteBuf buf) {
+ long requestId = buf.readLong();
+ int numErrorStringBytes = buf.readInt();
+ byte[] errorBytes = new byte[numErrorStringBytes];
+ buf.readBytes(errorBytes);
+ return new RpcFailure(requestId, new String(errorBytes, Charsets.UTF_8));
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof RpcFailure) {
+ RpcFailure o = (RpcFailure) other;
+ return requestId == o.requestId && errorString.equals(o.errorString);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("requestId", requestId)
+ .add("errorString", errorString)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
new file mode 100644
index 0000000000..099e934ae0
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}.
+ * This will correspond to a single
+ * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure).
+ */
+public final class RpcRequest implements RequestMessage {
+ /** Used to link an RPC request with its response. */
+ public final long requestId;
+
+ /** Serialized message to send to remote RpcHandler. */
+ public final byte[] message;
+
+ public RpcRequest(long requestId, byte[] message) {
+ this.requestId = requestId;
+ this.message = message;
+ }
+
+ @Override
+ public Type type() { return Type.RpcRequest; }
+
+ @Override
+ public int encodedLength() {
+ return 8 + 4 + message.length;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeLong(requestId);
+ buf.writeInt(message.length);
+ buf.writeBytes(message);
+ }
+
+ public static RpcRequest decode(ByteBuf buf) {
+ long requestId = buf.readLong();
+ int messageLen = buf.readInt();
+ byte[] message = new byte[messageLen];
+ buf.readBytes(message);
+ return new RpcRequest(requestId, message);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof RpcRequest) {
+ RpcRequest o = (RpcRequest) other;
+ return requestId == o.requestId && Arrays.equals(message, o.message);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("requestId", requestId)
+ .add("message", message)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
new file mode 100644
index 0000000000..ed47947832
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/** Response to {@link RpcRequest} for a successful RPC. */
+public final class RpcResponse implements ResponseMessage {
+ public final long requestId;
+ public final byte[] response;
+
+ public RpcResponse(long requestId, byte[] response) {
+ this.requestId = requestId;
+ this.response = response;
+ }
+
+ @Override
+ public Type type() { return Type.RpcResponse; }
+
+ @Override
+ public int encodedLength() { return 8 + 4 + response.length; }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeLong(requestId);
+ buf.writeInt(response.length);
+ buf.writeBytes(response);
+ }
+
+ public static RpcResponse decode(ByteBuf buf) {
+ long requestId = buf.readLong();
+ int responseLen = buf.readInt();
+ byte[] response = new byte[responseLen];
+ buf.readBytes(response);
+ return new RpcResponse(requestId, response);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof RpcResponse) {
+ RpcResponse o = (RpcResponse) other;
+ return requestId == o.requestId && Arrays.equals(response, o.response);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("requestId", requestId)
+ .add("response", response)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java
new file mode 100644
index 0000000000..d46a263884
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Encapsulates a request for a particular chunk of a stream.
+ */
+public final class StreamChunkId implements Encodable {
+ public final long streamId;
+ public final int chunkIndex;
+
+ public StreamChunkId(long streamId, int chunkIndex) {
+ this.streamId = streamId;
+ this.chunkIndex = chunkIndex;
+ }
+
+ @Override
+ public int encodedLength() {
+ return 8 + 4;
+ }
+
+ public void encode(ByteBuf buffer) {
+ buffer.writeLong(streamId);
+ buffer.writeInt(chunkIndex);
+ }
+
+ public static StreamChunkId decode(ByteBuf buffer) {
+ assert buffer.readableBytes() >= 8 + 4;
+ long streamId = buffer.readLong();
+ int chunkIndex = buffer.readInt();
+ return new StreamChunkId(streamId, chunkIndex);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(streamId, chunkIndex);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof StreamChunkId) {
+ StreamChunkId o = (StreamChunkId) other;
+ return streamId == o.streamId && chunkIndex == o.chunkIndex;
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamId", streamId)
+ .add("chunkIndex", chunkIndex)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java
new file mode 100644
index 0000000000..9688705569
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/DefaultStreamManager.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicLong;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * StreamManager which allows registration of an Iterator<ManagedBuffer>, whose buffers are
+ * fetched individually as chunks by the client.
+ */
+public class DefaultStreamManager extends StreamManager {
+ private final Logger logger = LoggerFactory.getLogger(DefaultStreamManager.class);
+
+ private final AtomicLong nextStreamId;
+ private final Map<Long, StreamState> streams;
+
+ /** State of a single stream. */
+ private static class StreamState {
+ final Iterator<ManagedBuffer> buffers;
+
+ // Index of the next chunk the caller is expected to request, used to ensure that chunks
+ // are requested one at a time and in order.
+ int curChunk = 0;
+
+ StreamState(Iterator<ManagedBuffer> buffers) {
+ this.buffers = buffers;
+ }
+ }
+
+ public DefaultStreamManager() {
+ // For debugging purposes, start with a random stream id to help identify different streams.
+ // This does not need to be globally unique, only unique to this class.
+ nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000);
+ streams = new ConcurrentHashMap<Long, StreamState>();
+ }
+
+ @Override
+ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+ StreamState state = streams.get(streamId);
+ if (chunkIndex != state.curChunk) {
+ throw new IllegalStateException(String.format(
+ "Received out-of-order chunk index %s (expected %s)", chunkIndex, state.curChunk));
+ } else if (!state.buffers.hasNext()) {
+ throw new IllegalStateException(String.format(
+ "Requested chunk index beyond end %s", chunkIndex));
+ }
+ state.curChunk += 1;
+ ManagedBuffer nextChunk = state.buffers.next();
+
+ if (!state.buffers.hasNext()) {
+ logger.trace("Removing stream id {}", streamId);
+ streams.remove(streamId);
+ }
+
+ return nextChunk;
+ }
+
+ @Override
+ public void connectionTerminated(long streamId) {
+ // Release all remaining buffers.
+ StreamState state = streams.remove(streamId);
+ if (state != null && state.buffers != null) {
+ while (state.buffers.hasNext()) {
+ state.buffers.next().release();
+ }
+ }
+ }
+
+ /**
+ * Registers a stream of ManagedBuffers which are served as individual chunks one at a time to
+ * callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a
+ * client connection is closed before the iterator is fully drained, then the remaining buffers
+ * will all be release()'d.
+ */
+ public long registerStream(Iterator<ManagedBuffer> buffers) {
+ long myStreamId = nextStreamId.getAndIncrement();
+ streams.put(myStreamId, new StreamState(buffers));
+ return myStreamId;
+ }
+}
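A usage sketch of the registration/fetch cycle above (illustrative; someBuffer() is a placeholder for however the caller obtains ManagedBuffers, e.g. file segments):

    import java.util.Arrays;
    import org.apache.spark.network.buffer.ManagedBuffer;
    import org.apache.spark.network.server.DefaultStreamManager;

    public class StreamManagerSketch {
      public static void main(String[] args) {
        DefaultStreamManager manager = new DefaultStreamManager();
        long streamId = manager.registerStream(
          Arrays.asList(someBuffer(), someBuffer()).iterator());

        ManagedBuffer chunk0 = manager.getChunk(streamId, 0);  // chunks must be fetched in order
        ManagedBuffer chunk1 = manager.getChunk(streamId, 1);  // stream is dropped after the last chunk
        // Requesting chunk 1 before chunk 0 would throw IllegalStateException.
      }

      private static ManagedBuffer someBuffer() {
        return null;  // placeholder only; a real caller returns e.g. a FileSegmentManagedBuffer
      }
    }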
diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java
new file mode 100644
index 0000000000..b80c15106e
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import org.apache.spark.network.protocol.Message;
+
+/**
+ * Handles either request or response messages coming off of Netty. A MessageHandler instance
+ * is associated with a single Netty Channel (though there may be multiple clients on the
+ * same Channel).
+ */
+public abstract class MessageHandler<T extends Message> {
+ /** Handles the receipt of a single message. */
+ public abstract void handle(T message);
+
+ /** Invoked when an exception was caught on the Channel. */
+ public abstract void exceptionCaught(Throwable cause);
+
+ /** Invoked when the channel this MessageHandler is on has been unregistered. */
+ public abstract void channelUnregistered();
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
new file mode 100644
index 0000000000..f54a696b8f
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+
+/**
+ * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s.
+ */
+public interface RpcHandler {
+ /**
+ * Receive a single RPC message. Any exception thrown while in this method will be sent back to
+ * the client in string form as a standard RPC failure.
+ *
+ * @param client A channel client which enables the handler to make requests back to the sender
+ * of this RPC.
+ * @param message The serialized bytes of the RPC.
+ * @param callback Callback which should be invoked exactly once upon success or failure of the
+ * RPC.
+ */
+ void receive(TransportClient client, byte[] message, RpcResponseCallback callback);
+}
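A minimal implementation of the contract above, sketched here to show the callback discipline (hypothetical, not part of the patch):

    import org.apache.spark.network.client.RpcResponseCallback;
    import org.apache.spark.network.client.TransportClient;
    import org.apache.spark.network.server.RpcHandler;

    // Echoes every RPC payload straight back to the caller.
    public class EchoRpcHandler implements RpcHandler {
      @Override
      public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
        // Invoke the callback exactly once; an exception thrown here would instead be
        // turned into an RpcFailure by TransportRequestHandler.
        callback.onSuccess(message);
      }
    }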
diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
new file mode 100644
index 0000000000..5a9a14a180
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * The StreamManager is used to fetch individual chunks from a stream. This is used in
+ * {@link TransportRequestHandler} in order to respond to fetchChunk() requests. Creation of the
+ * stream is outside the scope of the transport layer, but a given stream is guaranteed to be read
+ * by only one client connection, meaning that getChunk() for a particular stream will be called
+ * serially and that once the connection associated with the stream is closed, that stream will
+ * never be used again.
+ */
+public abstract class StreamManager {
+ /**
+ * Called in response to a fetchChunk() request. The returned buffer will be passed as-is to the
+ * client. A single stream will be associated with a single TCP connection, so this method
+ * will not be called in parallel for a particular stream.
+ *
+ * Callers may request chunks out of order or repeat requests, but implementations are not
+ * required to support such access patterns.
+ *
+ * The returned ManagedBuffer will be release()'d after being written to the network.
+ *
+ * @param streamId id of a stream that has been previously registered with the StreamManager.
+ * @param chunkIndex 0-indexed chunk of the stream that's requested
+ */
+ public abstract ManagedBuffer getChunk(long streamId, int chunkIndex);
+
+ /**
+ * Indicates that the TCP connection that was tied to the given stream has been terminated. After
+ * this occurs, we are guaranteed not to read from the stream again, so any state can be cleaned
+ * up.
+ */
+ public void connectionTerminated(long streamId) { }
+}
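DefaultStreamManager (earlier in this patch) is the general-purpose implementation; as an even smaller hypothetical example of the contract, here is a manager whose streams always consist of a single pre-supplied chunk:

    import org.apache.spark.network.buffer.ManagedBuffer;
    import org.apache.spark.network.server.StreamManager;

    public class SingleChunkStreamManager extends StreamManager {
      private final ManagedBuffer onlyChunk;

      public SingleChunkStreamManager(ManagedBuffer onlyChunk) {
        this.onlyChunk = onlyChunk;
      }

      @Override
      public ManagedBuffer getChunk(long streamId, int chunkIndex) {
        if (chunkIndex != 0) {
          throw new IllegalArgumentException("Only chunk 0 exists, got " + chunkIndex);
        }
        return onlyChunk;   // released by the transport layer after it is written out
      }
      // connectionTerminated() keeps the no-op default; there is no per-stream state to clean up.
    }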
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
new file mode 100644
index 0000000000..e491367fa4
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.SimpleChannelInboundHandler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportResponseHandler;
+import org.apache.spark.network.protocol.Message;
+import org.apache.spark.network.protocol.RequestMessage;
+import org.apache.spark.network.protocol.ResponseMessage;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * The single Transport-level Channel handler which is used for delegating requests to the
+ * {@link TransportRequestHandler} and responses to the {@link TransportResponseHandler}.
+ *
+ * All channels created in the transport layer are bidirectional. When the Client initiates a Netty
+ * Channel with a RequestMessage (which gets handled by the Server's RequestHandler), the Server
+ * will produce a ResponseMessage (handled by the Client's ResponseHandler). However, the Server
+ * also gets a handle on the same Channel, so it may then begin to send RequestMessages to the
+ * Client.
+ * This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler,
+ * for the Client's responses to the Server's requests.
+ */
+public class TransportChannelHandler extends SimpleChannelInboundHandler<Message> {
+ private final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class);
+
+ private final TransportClient client;
+ private final TransportResponseHandler responseHandler;
+ private final TransportRequestHandler requestHandler;
+
+ public TransportChannelHandler(
+ TransportClient client,
+ TransportResponseHandler responseHandler,
+ TransportRequestHandler requestHandler) {
+ this.client = client;
+ this.responseHandler = responseHandler;
+ this.requestHandler = requestHandler;
+ }
+
+ public TransportClient getClient() {
+ return client;
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+ logger.warn("Exception in connection from " + NettyUtils.getRemoteAddress(ctx.channel()),
+ cause);
+ requestHandler.exceptionCaught(cause);
+ responseHandler.exceptionCaught(cause);
+ ctx.close();
+ }
+
+ @Override
+ public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
+ try {
+ requestHandler.channelUnregistered();
+ } catch (RuntimeException e) {
+ logger.error("Exception from request handler while unregistering channel", e);
+ }
+ try {
+ responseHandler.channelUnregistered();
+ } catch (RuntimeException e) {
+ logger.error("Exception from response handler while unregistering channel", e);
+ }
+ super.channelUnregistered(ctx);
+ }
+
+ @Override
+ public void channelRead0(ChannelHandlerContext ctx, Message request) {
+ if (request instanceof RequestMessage) {
+ requestHandler.handle((RequestMessage) request);
+ } else {
+ responseHandler.handle((ResponseMessage) request);
+ }
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
new file mode 100644
index 0000000000..352f865935
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.util.Set;
+
+import com.google.common.base.Throwables;
+import com.google.common.collect.Sets;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.protocol.RequestMessage;
+import org.apache.spark.network.protocol.ChunkFetchRequest;
+import org.apache.spark.network.protocol.RpcRequest;
+import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * A handler that processes requests from clients and writes chunk data back. Each handler is
+ * attached to a single Netty channel, and keeps track of which streams have been fetched via this
+ * channel, in order to clean them up if the channel is terminated (see #channelUnregistered).
+ *
+ * The messages should have been processed by the pipeline set up by {@link TransportServer}.
+ */
+public class TransportRequestHandler extends MessageHandler<RequestMessage> {
+ private final Logger logger = LoggerFactory.getLogger(TransportRequestHandler.class);
+
+ /** The Netty channel that this handler is associated with. */
+ private final Channel channel;
+
+ /** Client on the same channel allowing us to talk back to the requester. */
+ private final TransportClient reverseClient;
+
+ /** Returns each chunk that is part of a stream. */
+ private final StreamManager streamManager;
+
+ /** Handles all RPC messages. */
+ private final RpcHandler rpcHandler;
+
+ /** List of all stream ids that have been read on this handler, used for cleanup. */
+ private final Set<Long> streamIds;
+
+ public TransportRequestHandler(
+ Channel channel,
+ TransportClient reverseClient,
+ StreamManager streamManager,
+ RpcHandler rpcHandler) {
+ this.channel = channel;
+ this.reverseClient = reverseClient;
+ this.streamManager = streamManager;
+ this.rpcHandler = rpcHandler;
+ this.streamIds = Sets.newHashSet();
+ }
+
+ @Override
+ public void exceptionCaught(Throwable cause) {
+ }
+
+ @Override
+ public void channelUnregistered() {
+ // Inform the StreamManager that these streams will no longer be read from.
+ for (long streamId : streamIds) {
+ streamManager.connectionTerminated(streamId);
+ }
+ }
+
+ @Override
+ public void handle(RequestMessage request) {
+ if (request instanceof ChunkFetchRequest) {
+ processFetchRequest((ChunkFetchRequest) request);
+ } else if (request instanceof RpcRequest) {
+ processRpcRequest((RpcRequest) request);
+ } else {
+ throw new IllegalArgumentException("Unknown request type: " + request);
+ }
+ }
+
+ private void processFetchRequest(final ChunkFetchRequest req) {
+ final String client = NettyUtils.getRemoteAddress(channel);
+ streamIds.add(req.streamChunkId.streamId);
+
+ logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId);
+
+ ManagedBuffer buf;
+ try {
+ buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
+ } catch (Exception e) {
+ logger.error(String.format(
+ "Error opening block %s for request from %s", req.streamChunkId, client), e);
+ respond(new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e)));
+ return;
+ }
+
+ respond(new ChunkFetchSuccess(req.streamChunkId, buf));
+ }
+
+ private void processRpcRequest(final RpcRequest req) {
+ try {
+ rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() {
+ @Override
+ public void onSuccess(byte[] response) {
+ respond(new RpcResponse(req.requestId, response));
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
+ }
+ });
+ } catch (Exception e) {
+ logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e);
+ respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
+ }
+ }
+
+ /**
+ * Responds to a single message with some Encodable object. If a failure occurs while sending,
+ * it will be logged and the channel closed.
+ */
+ private void respond(final Encodable result) {
+ final String remoteAddress = channel.remoteAddress().toString();
+ channel.writeAndFlush(result).addListener(
+ new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) throws Exception {
+ if (future.isSuccess()) {
+ logger.trace(String.format("Sent result %s to client %s", result, remoteAddress));
+ } else {
+ logger.error(String.format("Error sending result %s to %s; closing connection",
+ result, remoteAddress), future.cause());
+ channel.close();
+ }
+ }
+ }
+ );
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
new file mode 100644
index 0000000000..243070750d
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.server;
+
+import java.io.Closeable;
+import java.net.InetSocketAddress;
+import java.util.concurrent.TimeUnit;
+
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.util.IOMode;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Server for the efficient, low-level streaming service.
+ */
+public class TransportServer implements Closeable {
+ private final Logger logger = LoggerFactory.getLogger(TransportServer.class);
+
+ private final TransportContext context;
+ private final TransportConf conf;
+
+ private ServerBootstrap bootstrap;
+ private ChannelFuture channelFuture;
+ private int port = -1;
+
+ public TransportServer(TransportContext context) {
+ this.context = context;
+ this.conf = context.getConf();
+
+ init();
+ }
+
+ public int getPort() {
+ if (port == -1) {
+ throw new IllegalStateException("Server not initialized");
+ }
+ return port;
+ }
+
+ private void init() {
+
+ IOMode ioMode = IOMode.valueOf(conf.ioMode());
+ EventLoopGroup bossGroup =
+ NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server");
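+ // Share one event loop group between accepting connections (boss) and channel I/O (worker).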
+ EventLoopGroup workerGroup = bossGroup;
+
+ bootstrap = new ServerBootstrap()
+ .group(bossGroup, workerGroup)
+ .channel(NettyUtils.getServerChannelClass(ioMode))
+ .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
+ .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT);
+
+ if (conf.backLog() > 0) {
+ bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog());
+ }
+
+ if (conf.receiveBuf() > 0) {
+ bootstrap.childOption(ChannelOption.SO_RCVBUF, conf.receiveBuf());
+ }
+
+ if (conf.sendBuf() > 0) {
+ bootstrap.childOption(ChannelOption.SO_SNDBUF, conf.sendBuf());
+ }
+
+ bootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
+ @Override
+ protected void initChannel(SocketChannel ch) throws Exception {
+ context.initializePipeline(ch);
+ }
+ });
+
+ channelFuture = bootstrap.bind(new InetSocketAddress(conf.serverPort()));
+ channelFuture.syncUninterruptibly();
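+ // Binding synchronously ensures the actual (possibly ephemeral) port is known before init() returns.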
+
+ port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort();
+ logger.debug("Shuffle server started on port :" + port);
+ }
+
+ @Override
+ public void close() {
+ if (channelFuture != null) {
+ // close is a local operation and should finish within milliseconds; the timeout is just to be safe
+ channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS);
+ channelFuture = null;
+ }
+ if (bootstrap != null && bootstrap.group() != null) {
+ bootstrap.group().shutdownGracefully();
+ }
+ if (bootstrap != null && bootstrap.childGroup() != null) {
+ bootstrap.childGroup().shutdownGracefully();
+ }
+ bootstrap = null;
+ }
+
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java
new file mode 100644
index 0000000000..d944d9da1c
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.util.NoSuchElementException;
+
+/**
+ * Provides a mechanism for constructing a {@link TransportConf} from an underlying source of configuration values.
+ */
+public abstract class ConfigProvider {
+ /** Obtains the value of the given config; throws NoSuchElementException if it doesn't exist. */
+ public abstract String get(String name);
+
+ public String get(String name, String defaultValue) {
+ try {
+ return get(name);
+ } catch (NoSuchElementException e) {
+ return defaultValue;
+ }
+ }
+
+ public int getInt(String name, int defaultValue) {
+ return Integer.parseInt(get(name, Integer.toString(defaultValue)));
+ }
+
+ public long getLong(String name, long defaultValue) {
+ return Long.parseLong(get(name, Long.toString(defaultValue)));
+ }
+
+ public double getDouble(String name, double defaultValue) {
+ return Double.parseDouble(get(name, Double.toString(defaultValue)));
+ }
+
+ public boolean getBoolean(String name, boolean defaultValue) {
+ return Boolean.parseBoolean(get(name, Boolean.toString(defaultValue)));
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/util/IOMode.java b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java
new file mode 100644
index 0000000000..6b208d95bb
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+/**
+ * Selector for which form of low-level IO we should use.
+ * NIO is always available, while EPOLL is only available on Linux.
+ * The mode is selected via configuration (see {@link TransportConf#ioMode()}).
+ */
+public enum IOMode {
+ NIO, EPOLL
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
new file mode 100644
index 0000000000..32ba3f5b07
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.io.Closeable;
+import java.io.IOException;
+
+import com.google.common.io.Closeables;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class JavaUtils {
+ private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class);
+
+ /** Closes the given object, ignoring IOExceptions. */
+ public static void closeQuietly(Closeable closeable) {
+ try {
+ closeable.close();
+ } catch (IOException e) {
+ logger.error("IOException should not have been thrown.", e);
+ }
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
new file mode 100644
index 0000000000..b187234119
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.util.concurrent.ThreadFactory;
+
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import io.netty.channel.Channel;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.ServerChannel;
+import io.netty.channel.epoll.Epoll;
+import io.netty.channel.epoll.EpollEventLoopGroup;
+import io.netty.channel.epoll.EpollServerSocketChannel;
+import io.netty.channel.epoll.EpollSocketChannel;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.channel.socket.nio.NioSocketChannel;
+import io.netty.handler.codec.ByteToMessageDecoder;
+import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
+
+/**
+ * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO.
+ */
+public class NettyUtils {
+ /** Creates a Netty EventLoopGroup based on the IOMode. */
+ public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) {
+
+ ThreadFactory threadFactory = new ThreadFactoryBuilder()
+ .setDaemon(true)
+ .setNameFormat(threadPrefix + "-%d")
+ .build();
+
+ switch (mode) {
+ case NIO:
+ return new NioEventLoopGroup(numThreads, threadFactory);
+ case EPOLL:
+ return new EpollEventLoopGroup(numThreads, threadFactory);
+ default:
+ throw new IllegalArgumentException("Unknown io mode: " + mode);
+ }
+ }
+
+ /** Returns the correct (client) SocketChannel class based on IOMode. */
+ public static Class<? extends Channel> getClientChannelClass(IOMode mode) {
+ switch (mode) {
+ case NIO:
+ return NioSocketChannel.class;
+ case EPOLL:
+ return EpollSocketChannel.class;
+ default:
+ throw new IllegalArgumentException("Unknown io mode: " + mode);
+ }
+ }
+
+ /** Returns the correct ServerSocketChannel class based on IOMode. */
+ public static Class<? extends ServerChannel> getServerChannelClass(IOMode mode) {
+ switch (mode) {
+ case NIO:
+ return NioServerSocketChannel.class;
+ case EPOLL:
+ return EpollServerSocketChannel.class;
+ default:
+ throw new IllegalArgumentException("Unknown io mode: " + mode);
+ }
+ }
+
+ /**
+ * Creates a LengthFieldBasedFrameDecoder where the first 8 bytes are the length of the frame.
+ * This is installed in the pipeline before all other decoders.
+ */
+ public static ByteToMessageDecoder createFrameDecoder() {
+ // maxFrameLength = 2G
+ // lengthFieldOffset = 0
+ // lengthFieldLength = 8
+ // lengthAdjustment = -8, i.e. exclude the 8 byte length itself
+ // initialBytesToStrip = 8, i.e. strip out the length field itself
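+ // Resulting wire format: an 8-byte length (which counts itself) followed by (length - 8) bytes of body.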
+ return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8);
+ }
+
+ /** Returns the remote address on the channel or "<unknown remote>" if none exists. */
+ public static String getRemoteAddress(Channel channel) {
+ if (channel != null && channel.remoteAddress() != null) {
+ return channel.remoteAddress().toString();
+ }
+ return "<unknown remote>";
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
new file mode 100644
index 0000000000..80f65d9803
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+/**
+ * A central location that tracks all the settings we expose to users.
+ */
+public class TransportConf {
+ private final ConfigProvider conf;
+
+ public TransportConf(ConfigProvider conf) {
+ this.conf = conf;
+ }
+
+ /** Port the server listens on. Defaults to 0, letting the OS pick an ephemeral port. */
+ public int serverPort() { return conf.getInt("spark.shuffle.io.port", 0); }
+
+ /** IO mode: nio or epoll */
+ public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); }
+
+ /** Connection timeout in milliseconds. Configured in seconds; defaults to 120 secs. */
+ public int connectionTimeoutMs() {
+ return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000;
+ }
+
+ /** Requested maximum length of the queue of incoming connections. Defaults to -1, which leaves the OS default backlog. */
+ public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); }
+
+ /** Number of threads used in the server thread pool. Defaults to 0, which means 2x the number of cores. */
+ public int serverThreads() { return conf.getInt("spark.shuffle.io.serverThreads", 0); }
+
+ /** Number of threads used in the client thread pool. Defaults to 0, which means 2x the number of cores. */
+ public int clientThreads() { return conf.getInt("spark.shuffle.io.clientThreads", 0); }
+
+ /**
+ * Receive buffer size (SO_RCVBUF).
+ * Note: the optimal size for receive buffer and send buffer should be
+ * latency * network_bandwidth.
+ * Assuming latency = 1ms and network_bandwidth = 10Gbps,
+ * the buffer size should be ~ 1.25MB.
+ */
+ public int receiveBuf() { return conf.getInt("spark.shuffle.io.receiveBuffer", -1); }
+
+ /** Send buffer size (SO_SNDBUF). */
+ public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
new file mode 100644
index 0000000000..738dca9b6a
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.io.File;
+import java.io.RandomAccessFile;
+import java.nio.ByteBuffer;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.TransportConf;
+
+public class ChunkFetchIntegrationSuite {
+ static final long STREAM_ID = 1;
+ static final int BUFFER_CHUNK_INDEX = 0;
+ static final int FILE_CHUNK_INDEX = 1;
+
+ static TransportServer server;
+ static TransportClientFactory clientFactory;
+ static StreamManager streamManager;
+ static File testFile;
+
+ static ManagedBuffer bufferChunk;
+ static ManagedBuffer fileChunk;
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ int bufSize = 100000;
+ final ByteBuffer buf = ByteBuffer.allocate(bufSize);
+ for (int i = 0; i < bufSize; i ++) {
+ buf.put((byte) i);
+ }
+ buf.flip();
+ bufferChunk = new NioManagedBuffer(buf);
+
+ testFile = File.createTempFile("shuffle-test-file", "txt");
+ testFile.deleteOnExit();
+ RandomAccessFile fp = new RandomAccessFile(testFile, "rw");
+ byte[] fileContent = new byte[1024];
+ new Random().nextBytes(fileContent);
+ fp.write(fileContent);
+ fp.close();
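+ // Reference chunk: bytes [10, 10 + (1024 - 25)) of the file, i.e. all but the first 10 and last 15 bytes.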
+ fileChunk = new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25);
+
+ TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
+ streamManager = new StreamManager() {
+ @Override
+ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+ assertEquals(STREAM_ID, streamId);
+ if (chunkIndex == BUFFER_CHUNK_INDEX) {
+ return new NioManagedBuffer(buf);
+ } else if (chunkIndex == FILE_CHUNK_INDEX) {
+ return new FileSegmentManagedBuffer(testFile, 10, testFile.length() - 25);
+ } else {
+ throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex);
+ }
+ }
+ };
+ TransportContext context = new TransportContext(conf, streamManager, new NoOpRpcHandler());
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+ }
+
+ @AfterClass
+ public static void tearDown() {
+ server.close();
+ clientFactory.close();
+ testFile.delete();
+ }
+
+ class FetchResult {
+ public Set<Integer> successChunks;
+ public Set<Integer> failedChunks;
+ public List<ManagedBuffer> buffers;
+
+ public void releaseBuffers() {
+ for (ManagedBuffer buffer : buffers) {
+ buffer.release();
+ }
+ }
+ }
+
+ private FetchResult fetchChunks(List<Integer> chunkIndices) throws Exception {
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ final Semaphore sem = new Semaphore(0);
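+ // Each callback releases one permit; tryAcquire below waits until every requested chunk has completed.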
+
+ final FetchResult res = new FetchResult();
+ res.successChunks = Collections.synchronizedSet(new HashSet<Integer>());
+ res.failedChunks = Collections.synchronizedSet(new HashSet<Integer>());
+ res.buffers = Collections.synchronizedList(new LinkedList<ManagedBuffer>());
+
+ ChunkReceivedCallback callback = new ChunkReceivedCallback() {
+ @Override
+ public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+ buffer.retain();
+ res.successChunks.add(chunkIndex);
+ res.buffers.add(buffer);
+ sem.release();
+ }
+
+ @Override
+ public void onFailure(int chunkIndex, Throwable e) {
+ res.failedChunks.add(chunkIndex);
+ sem.release();
+ }
+ };
+
+ for (int chunkIndex : chunkIndices) {
+ client.fetchChunk(STREAM_ID, chunkIndex, callback);
+ }
+ if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) {
+ fail("Timeout getting response from the server");
+ }
+ client.close();
+ return res;
+ }
+
+ @Test
+ public void fetchBufferChunk() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX));
+ assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX));
+ assertTrue(res.failedChunks.isEmpty());
+ assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk));
+ res.releaseBuffers();
+ }
+
+ @Test
+ public void fetchFileChunk() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(FILE_CHUNK_INDEX));
+ assertEquals(res.successChunks, Sets.newHashSet(FILE_CHUNK_INDEX));
+ assertTrue(res.failedChunks.isEmpty());
+ assertBufferListsEqual(res.buffers, Lists.newArrayList(fileChunk));
+ res.releaseBuffers();
+ }
+
+ @Test
+ public void fetchNonExistentChunk() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(12345));
+ assertTrue(res.successChunks.isEmpty());
+ assertEquals(res.failedChunks, Sets.newHashSet(12345));
+ assertTrue(res.buffers.isEmpty());
+ }
+
+ @Test
+ public void fetchBothChunks() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX));
+ assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX));
+ assertTrue(res.failedChunks.isEmpty());
+ assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk, fileChunk));
+ res.releaseBuffers();
+ }
+
+ @Test
+ public void fetchChunkAndNonExistent() throws Exception {
+ FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, 12345));
+ assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX));
+ assertEquals(res.failedChunks, Sets.newHashSet(12345));
+ assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk));
+ res.releaseBuffers();
+ }
+
+ private void assertBufferListsEqual(List<ManagedBuffer> list0, List<ManagedBuffer> list1)
+ throws Exception {
+ assertEquals(list0.size(), list1.size());
+ for (int i = 0; i < list0.size(); i ++) {
+ assertBuffersEqual(list0.get(i), list1.get(i));
+ }
+ }
+
+ private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception {
+ ByteBuffer nio0 = buffer0.nioByteBuffer();
+ ByteBuffer nio1 = buffer1.nioByteBuffer();
+
+ int len = nio0.remaining();
+ assertEquals(nio0.remaining(), nio1.remaining());
+ for (int i = 0; i < len; i ++) {
+ assertEquals(nio0.get(), nio1.get());
+ }
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java b/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java
new file mode 100644
index 0000000000..7aa37efc58
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/NoOpRpcHandler.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.RpcHandler;
+
+/** Test RpcHandler which always returns a zero-sized success. */
+public class NoOpRpcHandler implements RpcHandler {
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ callback.onSuccess(new byte[0]);
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
new file mode 100644
index 0000000000..43dc0cf8c7
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import io.netty.channel.embedded.EmbeddedChannel;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+import org.apache.spark.network.protocol.Message;
+import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.ChunkFetchRequest;
+import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.RpcRequest;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.MessageDecoder;
+import org.apache.spark.network.protocol.MessageEncoder;
+import org.apache.spark.network.util.NettyUtils;
+
+public class ProtocolSuite {
+ private void testServerToClient(Message msg) {
+ EmbeddedChannel serverChannel = new EmbeddedChannel(new MessageEncoder());
+ serverChannel.writeOutbound(msg);
+
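+ // The client-side embedded channel mirrors the real inbound pipeline: frame reassembly, then message decoding.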
+ EmbeddedChannel clientChannel = new EmbeddedChannel(
+ NettyUtils.createFrameDecoder(), new MessageDecoder());
+
+ while (!serverChannel.outboundMessages().isEmpty()) {
+ clientChannel.writeInbound(serverChannel.readOutbound());
+ }
+
+ assertEquals(1, clientChannel.inboundMessages().size());
+ assertEquals(msg, clientChannel.readInbound());
+ }
+
+ private void testClientToServer(Message msg) {
+ EmbeddedChannel clientChannel = new EmbeddedChannel(new MessageEncoder());
+ clientChannel.writeOutbound(msg);
+
+ EmbeddedChannel serverChannel = new EmbeddedChannel(
+ NettyUtils.createFrameDecoder(), new MessageDecoder());
+
+ while (!clientChannel.outboundMessages().isEmpty()) {
+ serverChannel.writeInbound(clientChannel.readOutbound());
+ }
+
+ assertEquals(1, serverChannel.inboundMessages().size());
+ assertEquals(msg, serverChannel.readInbound());
+ }
+
+ @Test
+ public void requests() {
+ testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2)));
+ testClientToServer(new RpcRequest(12345, new byte[0]));
+ testClientToServer(new RpcRequest(12345, new byte[100]));
+ }
+
+ @Test
+ public void responses() {
+ testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10)));
+ testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0)));
+ testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error"));
+ testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), ""));
+ testServerToClient(new RpcResponse(12345, new byte[0]));
+ testServerToClient(new RpcResponse(12345, new byte[1000]));
+ testServerToClient(new RpcFailure(0, "this is an error"));
+ testServerToClient(new RpcFailure(0, ""));
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
new file mode 100644
index 0000000000..9f216dd2d7
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.Set;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.Sets;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.DefaultStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.TransportConf;
+
+public class RpcIntegrationSuite {
+ static TransportServer server;
+ static TransportClientFactory clientFactory;
+ static RpcHandler rpcHandler;
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
+ rpcHandler = new RpcHandler() {
+ @Override
+ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+ String msg = new String(message, Charsets.UTF_8);
+ String[] parts = msg.split("/");
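+ // Commands have the form "<verb>/<argument>", e.g. "hello/Aaron" or "return error/OK".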
+ if (parts[0].equals("hello")) {
+ callback.onSuccess(("Hello, " + parts[1] + "!").getBytes(Charsets.UTF_8));
+ } else if (parts[0].equals("return error")) {
+ callback.onFailure(new RuntimeException("Returned: " + parts[1]));
+ } else if (parts[0].equals("throw error")) {
+ throw new RuntimeException("Thrown: " + parts[1]);
+ }
+ }
+ };
+ TransportContext context = new TransportContext(conf, new DefaultStreamManager(), rpcHandler);
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+ }
+
+ @AfterClass
+ public static void tearDown() {
+ server.close();
+ clientFactory.close();
+ }
+
+ class RpcResult {
+ public Set<String> successMessages;
+ public Set<String> errorMessages;
+ }
+
+ private RpcResult sendRPC(String ... commands) throws Exception {
+ TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+ final Semaphore sem = new Semaphore(0);
+
+ final RpcResult res = new RpcResult();
+ res.successMessages = Collections.synchronizedSet(new HashSet<String>());
+ res.errorMessages = Collections.synchronizedSet(new HashSet<String>());
+
+ RpcResponseCallback callback = new RpcResponseCallback() {
+ @Override
+ public void onSuccess(byte[] message) {
+ res.successMessages.add(new String(message, Charsets.UTF_8));
+ sem.release();
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ res.errorMessages.add(e.getMessage());
+ sem.release();
+ }
+ };
+
+ for (String command : commands) {
+ client.sendRpc(command.getBytes(Charsets.UTF_8), callback);
+ }
+
+ if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) {
+ fail("Timeout getting response from the server");
+ }
+ client.close();
+ return res;
+ }
+
+ @Test
+ public void singleRPC() throws Exception {
+ RpcResult res = sendRPC("hello/Aaron");
+ assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!"));
+ assertTrue(res.errorMessages.isEmpty());
+ }
+
+ @Test
+ public void doubleRPC() throws Exception {
+ RpcResult res = sendRPC("hello/Aaron", "hello/Reynold");
+ assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!", "Hello, Reynold!"));
+ assertTrue(res.errorMessages.isEmpty());
+ }
+
+ @Test
+ public void returnErrorRPC() throws Exception {
+ RpcResult res = sendRPC("return error/OK");
+ assertTrue(res.successMessages.isEmpty());
+ assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK"));
+ }
+
+ @Test
+ public void throwErrorRPC() throws Exception {
+ RpcResult res = sendRPC("throw error/uh-oh");
+ assertTrue(res.successMessages.isEmpty());
+ assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: uh-oh"));
+ }
+
+ @Test
+ public void doubleTrouble() throws Exception {
+ RpcResult res = sendRPC("return error/OK", "throw error/uh-oh");
+ assertTrue(res.successMessages.isEmpty());
+ assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK", "Thrown: uh-oh"));
+ }
+
+ @Test
+ public void sendSuccessAndFailure() throws Exception {
+ RpcResult res = sendRPC("hello/Bob", "throw error/the", "hello/Builder", "return error/!");
+ assertEquals(res.successMessages, Sets.newHashSet("Hello, Bob!", "Hello, Builder!"));
+ assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !"));
+ }
+
+ private void assertErrorsContain(Set<String> errors, Set<String> contains) {
+ assertEquals(contains.size(), errors.size());
+
+ Set<String> remainingErrors = Sets.newHashSet(errors);
+ for (String contain : contains) {
+ Iterator<String> it = remainingErrors.iterator();
+ boolean foundMatch = false;
+ while (it.hasNext()) {
+ if (it.next().contains(contain)) {
+ it.remove();
+ foundMatch = true;
+ break;
+ }
+ }
+ assertTrue("Could not find error containing " + contain + "; errors: " + errors, foundMatch);
+ }
+
+ assertTrue(remainingErrors.isEmpty());
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java b/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java
new file mode 100644
index 0000000000..f4e0a2426a
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/SystemPropertyConfigProvider.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.util.NoSuchElementException;
+
+import org.apache.spark.network.util.ConfigProvider;
+
+/** Uses System properties to obtain config values. */
+public class SystemPropertyConfigProvider extends ConfigProvider {
+ @Override
+ public String get(String name) {
+ String value = System.getProperty(name);
+ if (value == null) {
+ throw new NoSuchElementException(name);
+ }
+ return value;
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java
new file mode 100644
index 0000000000..38113a918f
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+import com.google.common.base.Preconditions;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * A ManagedBuffer implementation whose contents are the bytes 0, 1, 2, 3, ..., (len-1).
+ *
+ * Used for testing.
+ */
+public class TestManagedBuffer extends ManagedBuffer {
+
+ private final int len;
+ private NettyManagedBuffer underlying;
+
+ public TestManagedBuffer(int len) {
+ Preconditions.checkArgument(len <= Byte.MAX_VALUE);
+ this.len = len;
+ byte[] byteArray = new byte[len];
+ for (int i = 0; i < len; i ++) {
+ byteArray[i] = (byte) i;
+ }
+ this.underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray));
+ }
+
+
+ @Override
+ public long size() {
+ return underlying.size();
+ }
+
+ @Override
+ public ByteBuffer nioByteBuffer() throws IOException {
+ return underlying.nioByteBuffer();
+ }
+
+ @Override
+ public InputStream createInputStream() throws IOException {
+ return underlying.createInputStream();
+ }
+
+ @Override
+ public ManagedBuffer retain() {
+ underlying.retain();
+ return this;
+ }
+
+ @Override
+ public ManagedBuffer release() {
+ underlying.release();
+ return this;
+ }
+
+ @Override
+ public Object convertToNetty() throws IOException {
+ return underlying.convertToNetty();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof ManagedBuffer) {
+ try {
+ ByteBuffer nioBuf = ((ManagedBuffer) other).nioByteBuffer();
+ if (nioBuf.remaining() != len) {
+ return false;
+ } else {
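+ // Compare the other buffer's contents against the expected pattern 0, 1, 2, ..., (len-1).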
+ for (int i = 0; i < len; i ++) {
+ if (nioBuf.get() != i) {
+ return false;
+ }
+ }
+ return true;
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ return false;
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/TestUtils.java b/network/common/src/test/java/org/apache/spark/network/TestUtils.java
new file mode 100644
index 0000000000..56a2b805f1
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/TestUtils.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.net.InetAddress;
+
+public class TestUtils {
+ public static String getLocalHost() {
+ try {
+ return InetAddress.getLocalHost().getHostAddress();
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
new file mode 100644
index 0000000000..3ef964616f
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.util.concurrent.TimeoutException;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.DefaultStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.TransportConf;
+
+public class TransportClientFactorySuite {
+ private TransportConf conf;
+ private TransportContext context;
+ private TransportServer server1;
+ private TransportServer server2;
+
+ @Before
+ public void setUp() {
+ conf = new TransportConf(new SystemPropertyConfigProvider());
+ StreamManager streamManager = new DefaultStreamManager();
+ RpcHandler rpcHandler = new NoOpRpcHandler();
+ context = new TransportContext(conf, streamManager, rpcHandler);
+ server1 = context.createServer();
+ server2 = context.createServer();
+ }
+
+ @After
+ public void tearDown() {
+ JavaUtils.closeQuietly(server1);
+ JavaUtils.closeQuietly(server2);
+ }
+
+ @Test
+ public void createAndReuseBlockClients() throws TimeoutException {
+ TransportClientFactory factory = context.createClientFactory();
+ TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ TransportClient c3 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
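+ // Clients to the same server are expected to be reused (same instance); a client to a different
+ // server should be a separate instance.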
+ assertTrue(c1.isActive());
+ assertTrue(c3.isActive());
+ assertTrue(c1 == c2);
+ assertTrue(c1 != c3);
+ factory.close();
+ }
+
+ @Test
+ public void neverReturnInactiveClients() throws Exception {
+ TransportClientFactory factory = context.createClientFactory();
+ TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ c1.close();
+
+ long start = System.currentTimeMillis();
+ while (c1.isActive() && (System.currentTimeMillis() - start) < 3000) {
+ Thread.sleep(10);
+ }
+ assertFalse(c1.isActive());
+
+ TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ assertFalse(c1 == c2);
+ assertTrue(c2.isActive());
+ factory.close();
+ }
+
+ @Test
+ public void closeBlockClientsWithFactory() throws TimeoutException {
+ TransportClientFactory factory = context.createClientFactory();
+ TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+ TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
+ assertTrue(c1.isActive());
+ assertTrue(c2.isActive());
+ factory.close();
+ assertFalse(c1.isActive());
+ assertFalse(c2.isActive());
+ }
+}
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
new file mode 100644
index 0000000000..17a03ebe88
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import io.netty.channel.local.LocalChannel;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportResponseHandler;
+import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.StreamChunkId;
+
+public class TransportResponseHandlerSuite {
+ @Test
+ public void handleSuccessfulFetch() {
+ StreamChunkId streamChunkId = new StreamChunkId(1, 0);
+
+ TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
+ ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
+ handler.addFetchRequest(streamChunkId, callback);
+ assertEquals(1, handler.numOutstandingRequests());
+
+ handler.handle(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123)));
+ verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any());
+ assertEquals(0, handler.numOutstandingRequests());
+ }
+
+ @Test
+ public void handleFailedFetch() {
+ StreamChunkId streamChunkId = new StreamChunkId(1, 0);
+ TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
+ ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
+ handler.addFetchRequest(streamChunkId, callback);
+ assertEquals(1, handler.numOutstandingRequests());
+
+ handler.handle(new ChunkFetchFailure(streamChunkId, "some error msg"));
+ verify(callback, times(1)).onFailure(eq(0), (Throwable) any());
+ assertEquals(0, handler.numOutstandingRequests());
+ }
+
+ @Test
+ public void clearAllOutstandingRequests() {
+ TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
+ ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
+ handler.addFetchRequest(new StreamChunkId(1, 0), callback);
+ handler.addFetchRequest(new StreamChunkId(1, 1), callback);
+ handler.addFetchRequest(new StreamChunkId(1, 2), callback);
+ assertEquals(3, handler.numOutstandingRequests());
+
+ handler.handle(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12)));
+ handler.exceptionCaught(new Exception("duh duh duhhhh"));
+
+ // the exception should fail the two remaining outstanding fetches (chunk indices 1 and 2)
+ verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any());
+ verify(callback, times(1)).onFailure(eq(1), (Throwable) any());
+ verify(callback, times(1)).onFailure(eq(2), (Throwable) any());
+ assertEquals(0, handler.numOutstandingRequests());
+ }
+
+ @Test
+ public void handleSuccessfulRPC() {
+ TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
+ RpcResponseCallback callback = mock(RpcResponseCallback.class);
+ handler.addRpcRequest(12345, callback);
+ assertEquals(1, handler.numOutstandingRequests());
+
+ handler.handle(new RpcResponse(54321, new byte[7])); // should be ignored
+ assertEquals(1, handler.numOutstandingRequests());
+
+ byte[] arr = new byte[10];
+ handler.handle(new RpcResponse(12345, arr));
+ verify(callback, times(1)).onSuccess(eq(arr));
+ assertEquals(0, handler.numOutstandingRequests());
+ }
+
+ @Test
+ public void handleFailedRPC() {
+ TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
+ RpcResponseCallback callback = mock(RpcResponseCallback.class);
+ handler.addRpcRequest(12345, callback);
+ assertEquals(1, handler.numOutstandingRequests());
+
+ handler.handle(new RpcFailure(54321, "uh-oh!")); // should be ignored
+ assertEquals(1, handler.numOutstandingRequests());
+
+ handler.handle(new RpcFailure(12345, "oh no"));
+ verify(callback, times(1)).onFailure((Throwable) any());
+ assertEquals(0, handler.numOutstandingRequests());
+ }
+}
diff --git a/pom.xml b/pom.xml
index abcb97108c..e4c92470fc 100644
--- a/pom.xml
+++ b/pom.xml
@@ -91,6 +91,7 @@
<module>graphx</module>
<module>mllib</module>
<module>tools</module>
+ <module>network/common</module>
<module>streaming</module>
<module>sql/catalyst</module>
<module>sql/core</module>
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 95152b58e2..adbdc5d1da 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -51,6 +51,11 @@ object MimaExcludes {
// MapStatus should be private[spark]
ProblemFilters.exclude[IncompatibleTemplateDefProblem](
"org.apache.spark.scheduler.MapStatus"),
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.network.netty.PathResolver"),
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.network.netty.client.BlockClientListener"),
+
// TaskContext was promoted to Abstract class
ProblemFilters.exclude[AbstractClassProblem](
"org.apache.spark.TaskContext"),